In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

import torch
import torch.nn as nn
from torch.nn.functional import relu

In [2]:
# Support functions

def train_test_split(X, y, test_size = 0.2, return_test_indices=False):
    ''' Randomly split input data X and labels y into training and testing arrays.

        Inputs:
            X: array indexed as X[rows, :], one row per sample
            y: array of labels, len(y) gives the number of samples
            test_size: fraction of samples to hold out; int(test_size*n) rows are drawn
                       without replacement using the global np.random state
            return_test_indices: if True, also return the sampled test row indices
        Output: X_train, y_train, X_test, y_test (plus the test indices if requested)
                Training rows keep their original order; test rows follow the sampled order.
    '''
    n = len(y)
    test_choice = np.random.choice(np.arange(n), replace=False, size=int(test_size*n))

    # A boolean mask replaces the per-index `in` scan over the test array (quadratic in n)
    # and remains a valid index even when no training rows are left.
    train_mask = np.ones(n, dtype=bool)
    train_mask[test_choice] = False

    X_train, y_train = X[train_mask,:], y[train_mask]
    X_test, y_test = X[test_choice,:], y[test_choice]
    if return_test_indices:
        return X_train, y_train, X_test, y_test, test_choice
    return X_train, y_train, X_test, y_test


def pred_vs_true(y_true, y_pred, plotTitle=None, plotColor='tab:blue', saveLoc=None):
    ''' Score an array of predicted values against label values and optionally plot them.

        Inputs:
            y_true, y_pred: label values and predictions, passed straight to the sklearn metrics
            plotTitle: if given, draw a predictions-vs-labels scatter plot with this title
            plotColor: color of the y = x reference line
            saveLoc: if given (and a plot is drawn), save the figure to f'{saveLoc}.png'
        Output: coefficient of determination (R^2), mean absolute error (MAE),
                and root mean squared error (RMSE)
    '''
    r2 = r2_score(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)
    # The `squared=False` shortcut was deprecated in scikit-learn 1.4 and removed in 1.6.
    # Taking the root of the per-output MSE and averaging gives the same value on every version.
    rmse = np.mean(np.sqrt(mean_squared_error(y_true, y_pred, multioutput='raw_values')))
    if plotTitle:
        plt.figure(figsize=(6,6))
        plt.title(f'{plotTitle}\nR^2: {round(r2,3)},   MAE: {round(mae,2)},   RMSE: {round(rmse,2)}', size=14)
        plt.xlabel('Test Labels', size=20)
        plt.ylabel('Predictions', size=20)
        plt.scatter(y_true, y_pred, c='k')
        # np.min/np.max give scalar limits whether the labels are a list, 1D array or column vector
        ymin, ymax = np.min(y_true), np.max(y_true)
        plt.plot([ymin, ymax], [ymin, ymax], c=plotColor)  # ideal-prediction reference line
        plt.grid()
        if saveLoc: plt.savefig(f'{saveLoc}.png', facecolor='w', bbox_inches='tight')
        plt.show()
        plt.close()
    return r2, mae, rmse
    
'''
def get_plateau_index(y, deriv_step=10, round_decimals=0):
    dy = np.array([round((b-a)/deriv_step, 0) for a,b in zip(y[:-deriv_step], y[deriv_step:])])
    zeros = np.nonzero(dy == 0)[0]
    return zeros[0] if len(zeros)>0 else None
'''

def plot_losses(batched_losses, valid_loss, plotTitle=None, plotColor='tab:blue', saveLoc=None):
    ''' Plot training and validation loss curves against the epoch index.

        Inputs:
            batched_losses: array with epochs along the first axis; min, max and mean
                            are taken over the last axis
            valid_loss: sequence of validation losses, plotted against its own index
            plotTitle, plotColor: figure title and color of the training curves
            saveLoc: if given, save the figure to f'{saveLoc}.png' before showing it
    '''
    n_epochs = batched_losses.shape[0]
    epoch_axis = list(range(n_epochs))
    lowest, highest, average = (reduce(batched_losses, axis=-1)
                                for reduce in (np.min, np.max, np.mean))

    plt.figure()
    plt.title(plotTitle, size=14)
    for set_label, text in ((plt.xlabel, 'Epoch'), (plt.ylabel, 'Loss')):
        set_label(text, size=20)

    # Shaded band: spread of the batch losses within each epoch
    plt.fill_between(epoch_axis, lowest, highest, alpha=0.3, color=plotColor,
                     label='Training Loss (Batch Min-Max)')
    plt.plot(average, lw=2, c=plotColor, label='Training Loss (Batch Average)')
    plt.plot(valid_loss, lw=3, c='k', label='Validation Loss')
    #valid_plateau = get_plateau_index(valid_loss, deriv_step=10, round_decimals=0)
    #if valid_plateau != None:
    #    plt.axvline(x=valid_plateau, ls='--', c='k', label=f'Epoch {valid_plateau}')
    plt.xlim(0, n_epochs)
    plt.legend()
    plt.grid()

    if saveLoc:
        plt.savefig(f'{saveLoc}.png', facecolor='w', bbox_inches='tight')
    plt.show()
    plt.close()

'''
def r2_histogram(r2_vals, plotTitle=None, histColor='tab:blue', saveLoc=None):
    plt.figure()
    plt.title(plotTitle)
    min_cutoff = 0.3
    bins = np.arange(min_cutoff - 0.15, 1.01, 0.05)
    counts, _, _ = plt.hist(np.where(r2_vals < min_cutoff, min_cutoff-0.06, r2_vals),
                            bins, color=histColor, rwidth=0.8)
    plt.axvline(x=min_cutoff, c='k', ls='--', lw=2)
    plt.xlabel('R^2 value', size=20)
    plt.xlim(min_cutoff, 1)
    plt.xticks(bins, ['', '  < 0.3', '', '0.3', '', '0.4', '', '0.5', '',
                      '0.6', '', '0.7', '', '0.8', '', '0.9', '', '1.0'])
    plt.yticks(list(range(0, int(max(counts))+3, 2)))
    plt.grid(axis='y')
    if saveLoc: plt.savefig(f'{saveLoc}.png', facecolor='w', bbox_inches='tight')
    plt.show()
    plt.close()


def histogram(vals, plotTitle=None, xLabel='<Some_Metric?>', histColor='tab:blue', nBins=10, saveLoc=None):
    plt.figure()
    plt.title(plotTitle)
    counts, _, _ = plt.hist(vals, bins=nBins, rwidth=0.8, color=histColor)
                            # fill=True, edgecolor=histColor, linewidth=2, facecolor='w', hatch='/')
    plt.xlabel(xLabel, size=20)
    plt.yticks(list(range(0, int(max(counts))+3, 2)))
    plt.grid(axis='y')
    if saveLoc: plt.savefig(f'{saveLoc}.png', facecolor='w', bbox_inches='tight')
    plt.show()
    plt.close()
'''   
    
def compare_metrics(r2_vals, mae_vals, rmse_vals, plotTitle=None, plotColor='tab:blue', saveLoc=None):
    ''' Scatter the MAE and RMSE values against the matching R^2 values.

        MAE is drawn as hollow diamonds and RMSE as filled circles, both in plotColor.
        If saveLoc is given, the figure is saved to f'{saveLoc}.png' before being shown.
    '''
    # (error values, marker, extra plot options) for each error metric
    error_series = (
        (mae_vals, 'D', {'label': 'MAE', 'markerfacecolor': 'white'}),
        (rmse_vals, 'o', {'label': 'RMSE'}),
    )

    plt.figure()
    plt.title(plotTitle, size=14)
    plt.xlabel('R^2', size=20)
    plt.ylabel('Error', size=20)
    for errors, marker, options in error_series:
        plt.plot(r2_vals, errors, marker, c=plotColor, **options)
    plt.grid()
    # Legend is anchored outside the axes, to the upper right
    plt.legend(bbox_to_anchor=(1,1), loc='upper left', shadow=True, fontsize=16)

    if saveLoc:
        plt.savefig(f'{saveLoc}.png', facecolor='w', bbox_inches='tight')
    plt.show()
    plt.close()
    

In [3]:
# Feed-Forward Neural Network

class FFNN(nn.Module):
    ''' Fully connected feed-forward network with ReLU activations and one linear output unit.

        hidden_layers: sequence of hidden layer widths, e.g. [8, 8]; may hold a single width,
                       or be empty (the model is then a single Linear(input_size, 1) layer)
        input_size: number of input features
    '''

    def __init__(self, hidden_layers, input_size):
        super(FFNN, self).__init__()
        # Chain input -> hidden widths -> 1 output. Building every layer from one size list
        # avoids relying on a loop variable that is never bound for a single hidden layer.
        sizes = [input_size, *hidden_layers, 1]
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )

    def forward(self, X):
        ''' Apply ReLU after every layer except the last, which stays linear. '''
        y = X
        for layer in self.layers[:-1]:
            y = relu(layer(y))
        y = self.layers[-1](y)
        return y

            
def get_FFNN(hidden_layers=[8,8], epochs=100, batchSize=32, optimizer_name='Adam', learnRate=0.001,
             epoch_cutoff=False # if True, stop training once validation loss converges,
                                 # i.e. (valid_loss[i-step] - valid_loss[i])/step < epsilon
                                 # defaults: step=10, epsilon=0.5; or pass in tuple argument (step, epsilon)
            ):
    ''' Inputs: hyperparameters for a FFNN model
            hidden_layers: hidden layer widths passed to FFNN
            epochs: maximum number of training epochs
            batchSize: mini-batch size of the training DataLoader
            optimizer_name: name of an optimizer class in torch.optim (e.g. 'Adam', 'SGD')
            learnRate: learning rate passed to the optimizer
            epoch_cutoff: False (no early stop), True (step=10, epsilon=0.5) or a (step, epsilon) tuple
        Output: a function that takes input data and labels and trains a new FFNN model
    '''

    def train_FFNN(X_train, y_train, X_valid, y_valid):
        ''' Inputs: training and validation data and labels (as np arrays)
            Output: model, batched_loss (2D array: epochs x batches), valid_loss (1D array: one value per epoch)
        '''

        # Preprocess data: labels become column vectors so they match the model's (n, 1) output
        if len(y_train.shape)==1: y_train = y_train.reshape(-1, 1)
        if len(y_valid.shape)==1: y_valid = y_valid.reshape(-1, 1)
        trainset = torch.utils.data.TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batchSize, shuffle=True)
        # The validation tensors do not change between epochs, so convert them once.
        # Casting the labels to float32 keeps them the same dtype as the predictions.
        X_valid_t = torch.from_numpy(X_valid).float()
        y_valid_t = torch.from_numpy(y_valid).float()

        # Init FFNN
        model = FFNN(hidden_layers, input_size=X_train.shape[1])
        loss_func = nn.MSELoss()
        # Attribute lookup instead of eval(): an unknown name still raises AttributeError,
        # but the string is never executed as code
        optimizer = getattr(torch.optim, optimizer_name)(model.parameters(), lr=learnRate)

        if not isinstance(epoch_cutoff, bool):
            step, ep = epoch_cutoff
        elif epoch_cutoff:
            step, ep = 10, 0.5
        else:
            step, ep = 0, 0 # No cutoff, arbitrary values

        # Fitting
        batched_losses = []
        valid_loss = []
        for epoch in range(epochs):

            # Early stop: only checked after more than 3*step epochs, to prevent immediate "convergence"
            if epoch_cutoff and epoch > 3*step:
                # Average per-epoch drop in validation loss over the last `step` epochs
                if (valid_loss[-step-1] - valid_loss[-1])/step < ep:
                    break

            model.train()
            epoch_losses = []
            for (X_batch, y_batch) in trainloader:
                y_batch_pred = model(X_batch.float())
                loss = loss_func(y_batch_pred, y_batch.float())
                epoch_losses.append(loss.item())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batched_losses.append(epoch_losses)

            # Validation pass: no gradients are needed, so skip building the autograd graph
            model.eval()
            with torch.no_grad():
                epoch_valid_loss = loss_func(model(X_valid_t), y_valid_t)
            valid_loss.append(epoch_valid_loss.item())

        return model, np.array(batched_losses), np.array(valid_loss)

    return train_FFNN
        
# Status message printed when this cell finishes running, i.e. after FFNN and get_FFNN are defined
print('PyTorch implemented:\tFeed-Forward Neural Network (FFNN)')

PyTorch implemented:	Feed-Forward Neural Network (FFNN)


In [4]:
"""
def predict(model, X_test):
    y_pred = model(torch.from_numpy(X_test).float())
    return y_pred.detach().numpy()
"""

def run_nTrials(train_model, X_train, y_train, X_test=None, y_test=None,
                plotTitle='<model>', plotColor='tab:blue', saveFolder=None,
                savePlots = False, # if True, save in same folder; or input subfolder name as string
                best_metric_tuple = (max, 'r2'),
                verb=True,
                nTrials=50):
    ''' Train nTrials models on random train/validation splits, score each on a test set,
        print summary statistics and plot the best trial.

        Inputs:
            train_model: function (X_tr, y_tr, X_valid, y_valid) -> (model, batched_losses, valid_loss),
                         e.g. the function returned by get_FFNN
            X_train, y_train: data and labels (np arrays); 20% is held out for validation in every trial
            X_test, y_test: fixed test set; if both are omitted, a fresh random 20% of
                            (X_train, y_train) is used as the test set in every trial
            plotTitle: title of the plots and base name of all saved files
            plotColor: color used in the plots
            saveFolder: if given, save the best model (.pth) and a text summary (.txt) there
            savePlots: if True, also save the plots in saveFolder; if a string, in that subfolder
            best_metric_tuple: (selector, metric name), e.g. (max, 'r2') or (min, 'rmse');
                               the metric is one of the result keys 'r2', 'mae', 'rmse', 'epochs'
            verb: if True, print per-trial results and all reports
            nTrials: number of trials to run
        Output: results (list with one dict per trial), r2_vals, mae_vals, rmse_vals
    '''
    if nTrials < 1:
        raise ValueError('nTrials must be at least 1')

    # Determine test set
    if X_test is None and y_test is None:
        print('No test set specified. Randomly sample 20% of data as a non-universal test set in every trial.')
        X, y = X_train, y_train
        print(f'X.shape:{X.shape}\n')
        non_universal_test = True
    elif X_test is None or y_test is None:
        # Fail early with a clear message instead of an obscure error inside the first trial
        raise ValueError('X_test and y_test must be given together (or both omitted)')
    else:
        non_universal_test = False
        print(f'X_train.shape:{X_train.shape} / y_test.shape:{y_test.shape}\n')


    # Run trials
    opt, metric = best_metric_tuple
    results = []
    epoch_vals, r2_vals, mae_vals, rmse_vals = [], [], [], [] # redundant for code clarity
    if verb: print(f'Starting {nTrials} trials:\t [#] epochs, {metric}')

    for trial in range(nTrials):

        if non_universal_test:
            # Only the test set is drawn here. The validation set is carved out once, just below;
            # splitting it off here as well would silently discard 20% of the training data.
            X_train, y_train, X_test, y_test = train_test_split(X, y)

        X_tr, y_tr, X_valid, y_valid = train_test_split(X_train, y_train)
        model, batched_losses, valid_loss = train_model(X_tr, y_tr, X_valid, y_valid)
        y_pred = model(torch.from_numpy(X_test).float()).detach().numpy()
        r2, mae, rmse = pred_vs_true(y_test, y_pred)  # no plotTitle: metrics only, no plot
        r2_vals.append(r2)
        mae_vals.append(mae)
        rmse_vals.append(rmse)
        epochs = len(valid_loss)  # one validation loss per completed epoch
        epoch_vals.append(epochs)
        res_dict = {'r2':r2, 'mae':mae, 'rmse':rmse, 'epochs':epochs, 'y_test':y_test, 'y_pred':y_pred}
        res_dict.update({'model':model, 'batched_losses':batched_losses, 'valid_loss':valid_loss})
        results.append(res_dict)
        if verb:
            # Five trials per output line
            print(f'[{trial}] {epochs}, {round(res_dict[metric],2)}'.ljust(25),
                  end='' if (trial+1)%5 else '\n')
        else:
            print('.', end='')


    # Stats
    r2_report = f'\nHighest R^2: {round(max(r2_vals), 3)}'
    r2_report += f'\nAverage R^2: {round(np.mean(r2_vals), 3)} +/- {round(np.std(r2_vals), 4)}'
    if verb or metric=='r2': print(r2_report)

    # MAE is an error: the best trial has the lowest value (same convention as the RMSE report)
    mae_report = f'\nLowest MAE: {round(min(mae_vals), 3)}'
    mae_report += f'\nAverage MAE: {round(np.mean(mae_vals), 3)} +/- {round(np.std(mae_vals), 4)}'
    if verb or metric=='mae': print(mae_report)

    rmse_report = f'\nLowest RMSE: {round(min(rmse_vals), 2)}'
    rmse_report += f'\nAverage RMSE: {round(np.mean(rmse_vals), 2)} +/- {round(np.std(rmse_vals), 3)}'
    if verb or metric=='rmse': print(rmse_report)

    epoch_report = f'\nMin Epochs: {min(epoch_vals)}, Max Epochs: {max(epoch_vals)}'
    epoch_report += f'\nAverage Epochs: {round(np.mean(epoch_vals),1)} +/- {round(np.std(epoch_vals), 1)}'
    if verb: print(epoch_report)


    # Save best model and txt summary
    best = opt(results, key = lambda res: res[metric])
    savePlot, saveTrain, saveMetrics = None, None, None

    if saveFolder:
        os.makedirs(saveFolder, exist_ok=True)
        torch.save(best['model'], saveFolder+'/'+plotTitle+'.pth')
        with open(f'{saveFolder}/{plotTitle}.txt', 'w') as txtf:
            txtf.write(plotTitle + '\n\n')
            txtf.write(f'X_train.shape:{X_train.shape} / y_test.shape:{y_test.shape}\n')
            txtf.write(f'{r2_report}\n{r2_vals}\n')
            txtf.write(f'{mae_report}\n{mae_vals}\n')
            txtf.write(f'{rmse_report}\n{rmse_vals}\n')
            txtf.write(f'{epoch_report}\n{epoch_vals}\n')

        if savePlots:
            if isinstance(savePlots, str):
                # Create the requested subfolder itself (not a hard-coded 'plots') and
                # tolerate it already existing
                os.makedirs(f'{saveFolder}/{savePlots}', exist_ok=True)
                savePlot = saveFolder+'/'+savePlots+'/'+plotTitle
            else:
                savePlot = saveFolder+'/'+plotTitle
            saveTrain = savePlot + '.train'
            saveMetrics = savePlot +'.metrics'


    # Display and Save Plots (the save locations stay None unless saveFolder and savePlots are set)
    pred_vs_true(best['y_test'], best['y_pred'], plotTitle+f',  epochs: {best["epochs"]}', plotColor, saveLoc=savePlot)
    plot_losses(best['batched_losses'], best['valid_loss'], plotTitle, plotColor, saveLoc=saveTrain)
    compare_metrics(r2_vals, mae_vals, rmse_vals, plotTitle, plotColor, saveLoc=saveMetrics)

    return results, r2_vals, mae_vals, rmse_vals