In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import relu

In [2]:
def test_r2(y_true, y_pred, plotTitle=None, saveLoc=None):
    ''' Return the R^2 (coefficient of determination) of y_pred against y_true.

        When plotTitle is truthy, also draw a predictions-vs-labels scatter
        with a red reference line from (min label, min label) to
        (max label, max label), titled with plotTitle and the score rounded
        to 4 decimals. When saveLoc is also given, the figure is written to
        f'{saveLoc}--r2.png' before being shown.
    '''
    score = r2_score(y_true, y_pred)
    if not plotTitle:
        return score

    plt.figure(figsize=(6,6))
    plt.title(f'{plotTitle}\nR^2: {round(score, 4)}')
    plt.xlabel('Labels', size=20)
    plt.ylabel('Predictions', size=20)
    plt.scatter(y_true, y_pred)
    # y = x guide over the label range: perfect predictions fall on this line.
    lo, hi = min(y_true), max(y_true)
    plt.plot([lo, hi], [lo, hi], c='r')
    plt.grid()
    if saveLoc:
        plt.savefig(f'{saveLoc}--r2.png', facecolor='w', bbox_inches='tight')
    plt.show()
    plt.close()
    return score

def plot_batched_losses(batched_losses, plotTitle=None, saveLoc=None):
    ''' Plot a training-loss curve from an array of per-batch losses.

        The first axis of batched_losses is the epoch. Statistics are taken
        over the last axis: the per-epoch mean is drawn as a line, over a
        shaded band spanning the per-epoch min and max. When saveLoc is given,
        the figure is written to f'{saveLoc}--loss.png' before being shown.
    '''
    n_epochs = batched_losses.shape[0]
    lowest, highest, average = (
        stat(batched_losses, axis=-1) for stat in (np.min, np.max, np.mean)
    )

    plt.figure()
    plt.title(plotTitle)
    plt.xlabel('epoch', size=20)
    plt.ylabel('loss', size=20)
    epoch_idx = list(range(n_epochs))
    plt.fill_between(epoch_idx, lowest, highest, alpha=0.3, color='tab:red', label='Batch min-max')
    plt.plot(average, lw=2, label='Batch average')
    plt.xlim(0, n_epochs)
    plt.legend()
    plt.grid()
    if saveLoc:
        plt.savefig(f'{saveLoc}--loss.png', facecolor='w', bbox_inches='tight')
    plt.show()
    plt.close()
        

In [3]:
# Feed-Forward Neural Network

class FFNN(nn.Module):
    ''' Fully connected regression network with ReLU hidden activations.

        Layer widths run input_size -> hidden_layers[0] -> ... ->
        hidden_layers[-1] -> 1. The final layer is linear (no activation), so
        the network emits one unbounded value per sample.

        Args:
            hidden_layers: sequence of hidden-layer widths. A single width
                (e.g. [16]) gives one hidden layer. An empty sequence gives a
                plain linear model input_size -> 1.
            input_size: number of input features (last dimension of X).
    '''

    def __init__(self, hidden_layers, input_size):
        super().__init__()
        # Pair up consecutive widths of the full list. This does not depend on a
        # loop variable leaking out of a zip, so one (or zero) hidden layers work.
        # Layers are created in the same order as before, so seeded
        # initialisation is unchanged.
        widths = [input_size, *hidden_layers, 1]
        self.layers = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])
        )

    def forward(self, X):
        ''' Map X of shape (..., input_size) to predictions of shape (..., 1). '''
        y = X
        for layer in self.layers[:-1]:
            y = relu(layer(y))
        # No activation on the output layer (regression target).
        return self.layers[-1](y)

            
def get_FFNN(hidden_layers=(8, 8), epochs=100, batchSize=32, optimizer_name='Adam', learnRate=0.001):
    ''' Inputs: hyperparameters for a FFNN model
            hidden_layers  : sequence of hidden-layer widths, passed to FFNN
            epochs         : number of passes over the training set
            batchSize      : mini-batch size of the shuffled DataLoader
            optimizer_name : name of a class in torch.optim (e.g. 'Adam', 'SGD')
            learnRate      : learning rate passed to that optimizer
        Output: a function that takes input data and labels and trains a new FFNN model
    '''
    def train_FFNN(X_train, y_train, X_test, y_test):
        ''' Inputs: training and testing data and labels (as np arrays)
                X_train : (n_train, n_features) features
                y_train : (n_train,) or (n_train, 1) labels
                X_test  : (n_test, n_features) features to predict on
                y_test  : not used for training or prediction; accepted so the
                          caller can pass the full split (run_nTrials does)
            Output: model,
                    y_pred (np.ndarray, shape (n_test, 1)),
                    batched_losses (np.ndarray, shape (epochs, n_batches) of
                    per-batch training MSE values)
            Raises: AttributeError if optimizer_name is not in torch.optim
        '''
        # Labels must be a column (n, 1) to match the network's single output for MSELoss.
        y_train = np.asarray(y_train)
        if y_train.ndim == 1:
            y_train = y_train.reshape(-1, 1)

        # Convert once to float32 (the dtype of the nn.Linear weights) instead of per batch.
        X_train_t = torch.as_tensor(X_train, dtype=torch.float32)
        y_train_t = torch.as_tensor(y_train, dtype=torch.float32)
        trainset = torch.utils.data.TensorDataset(X_train_t, y_train_t)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batchSize, shuffle=True)

        # Init FFNN
        model = FFNN(hidden_layers, input_size=X_train_t.shape[1])
        loss_func = nn.MSELoss()
        # getattr restricts the lookup to torch.optim; eval on a formatted string
        # would execute arbitrary code contained in optimizer_name.
        optimizer = getattr(torch.optim, optimizer_name)(model.parameters(), lr=learnRate)

        # Fitting
        batched_losses = []
        for _ in range(epochs):
            epoch_losses = []
            for X_batch, y_batch in trainloader:
                loss = loss_func(model(X_batch), y_batch)
                epoch_losses.append(loss.item())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batched_losses.append(epoch_losses)

        # Test prediction: inference only, so skip autograd graph construction.
        with torch.no_grad():
            y_pred = model(torch.as_tensor(X_test, dtype=torch.float32))
        return model, y_pred.numpy(), np.array(batched_losses)
    return train_FFNN
        
# Announce which model family this cell registered.
print('PyTorch implemented:', 'Feed-Forward Neural Network (FFNN)', sep='\t')

PyTorch implemented:	Feed-Forward Neural Network (FFNN)


In [4]:
def _plot_r2_histogram(r2_vals, plotTitle, histColor, saveLoc):
    ''' Histogram of R^2 scores in 0.05-wide bins over [0.3, 1], with every
        score below 0.3 lumped into one overflow bar labelled "< 0.3".
    '''
    min_cutoff = 0.3  # the hard-coded tick labels below are written for this value
    bins = np.arange(min_cutoff - 0.15, 1.01, 0.05)
    # Move low scores into the [0.20, 0.25) bin so they form the overflow bar.
    clipped = np.where(r2_vals < min_cutoff, min_cutoff - 0.06, r2_vals)

    plt.figure()
    plt.title(plotTitle)
    counts, _, _ = plt.hist(clipped, bins, color=histColor, rwidth=0.8)
    plt.axvline(x=min_cutoff, c='k', ls='--', lw=2)
    plt.xlabel('R^2')
    # Start at the first bin edge: limiting to (min_cutoff, 1) would clip the
    # overflow bar and its "< 0.3" label off the plot.
    plt.xlim(bins[0], 1)
    plt.xticks(bins, ['', '  < 0.3', '', '0.3', '', '0.4', '', '0.5', '',
                      '0.6', '', '0.7', '', '0.8', '', '0.9', '', '1.0'])
    plt.yticks(list(range(0, int(max(counts))+3, 2)))
    plt.grid(axis='y')
    if saveLoc: plt.savefig(f'{saveLoc}--hist.png', facecolor='w', bbox_inches='tight')
    plt.show()
    plt.close()


def run_nTrials(train_model, X, y, saveLoc=None, histColor='k', nTrials=50):
    ''' Train and evaluate a model on nTrials independent random splits.

        Inputs:
            train_model : callable (X_train, y_train, X_test, y_test) ->
                          (model, y_pred, batched_losses), e.g. from get_FFNN
            X, y        : full dataset, re-split every trial by train_test_split
            saveLoc     : optional path prefix. When given, the figures
                          (--r2.png, --loss.png, --hist.png), the best model
                          (.pth) and a text summary (.txt) are written, and
                          each result also keeps its trained model
            histColor   : bar colour of the R^2 histogram
            nTrials     : number of train/test repetitions (must be >= 1)
        Output: list of per-trial dicts with keys 'r2', 'y_test', 'y_pred',
                'batched_losses' (plus 'model' when saveLoc is given)
        Raises: ValueError if nTrials < 1
    '''
    if nTrials < 1:
        raise ValueError(f'nTrials must be >= 1, got {nTrials}')

    # Run trials
    results = []
    print(f'Starting {nTrials} trials...')
    for trial in range(nTrials):
        # NOTE(review): unpacked as (X_train, y_train, X_test, y_test). sklearn's
        # train_test_split returns (X_train, X_test, y_train, y_test); confirm
        # this name refers to a project helper with the order used here.
        X_train, y_train, X_test, y_test = train_test_split(X, y)
        model, y_pred, batched_losses = train_model(X_train, y_train, X_test, y_test)
        r2 = test_r2(y_test, y_pred)
        # Losses are always kept because the best run's curve is plotted below.
        # Trained models are only retained when they may be saved.
        res_dict = {'r2': r2, 'y_test': y_test, 'y_pred': y_pred,
                    'batched_losses': batched_losses}
        if saveLoc:
            res_dict['model'] = model
        results.append(res_dict)
        print(f'[{trial}]', round(r2,3), end='\t' if (trial+1)%5 else '\n')
    r2_vals = np.array([res['r2'] for res in results])

    # R^2 stats
    best = max(results, key=lambda res: res['r2'])
    r2_report = f'\nHighest R^2: {round(best["r2"], 3)}'
    r2_report += f'\nAverage R^2: {round(np.mean(r2_vals), 3)} +/- {round(np.std(r2_vals), 4)}'
    print(r2_report)

    # Plots: (1) pred vs. labels; (2) loss vs. epoch; (3) R^2 histogram
    plotTitle = saveLoc.split('/')[-1] if saveLoc else 'Best Model'
    test_r2(best['y_test'], best['y_pred'], plotTitle, saveLoc)
    plot_batched_losses(best['batched_losses'], plotTitle, saveLoc)
    _plot_r2_histogram(r2_vals, plotTitle, histColor, saveLoc)

    # Save model and txt summary
    if saveLoc:
        torch.save(best['model'], saveLoc+'.pth')
        with open(saveLoc+'.txt', 'w') as txtf:
            txtf.write(plotTitle + '\n' + r2_report + '\n\n')
            # y_train is from the last trial. tolist() yields plain floats
            # (list() would render np.float64(...) under NumPy 2).
            txtf.write(f'X.shape:{X.shape} / y_train.shape:{y_train.shape}\n\n' + str(r2_vals.tolist()))

    return results