In [2]:
# Utility imports
# General imports
import torch
import torch.nn as nn
import os
# import tensorflow as tf
from scipy import stats
# Used for distributions libraries.
from scipy import stats
# Utility imports
from utils.losses import *
from utils.plotting import *
from utils.training import *

import pickle
import sys
from lightning.pytorch.accelerators import find_usable_cuda_devices


In [3]:
# --- Hyperparameters forwarded to create_model_original with optimizer='ECD' ---
eta = 1600.0
F0 = -0.3
lr_vals = [1.0]  # learning rates to scan
nu_vals = [0]    # nu values to scan

# --- Scan bookkeeping ---
optimizer = 'ECD'
loss_funcs = ['linear', 'square', 'exponl']
reps = 100  # independent training runs per (loss, learning rate, nu) configuration
file_name = 'scan_bce_ECD_results'

# Use the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [None]:
# Keep at most the first N fold-8 training examples.
N = 1_000_000
X = np.load('data/zenodo/fold/8/X_trn.npy')[:N]
y = np.load('data/zenodo/fold/8/y_trn.npy')[:N]
y = y.astype('float32')
# split_data (from utils) returns the split data plus m, s; m and s are later
# passed to the lr_fn helpers (presumably normalization statistics -- TODO confirm).
data, m, s = split_data(X, y)

class train_val_loader(pl.LightningDataModule):
    """Lightning DataModule serving a pre-split train/validation dataset.

    Args:
        data: ``(X_train, X_test, y_train, y_test)`` arrays, unpacked in that
            order by ``prepare_data`` (presumably the first value returned by
            ``split_data`` -- TODO confirm).
        N: Nominal sample count; both loaders use ``int(0.1 * N)`` as batch
            size (clamped to at least 1).
        workers: Number of DataLoader worker processes (0 = load in the main
            process).
    """

    def __init__(self, data, N, workers):
        super().__init__()
        self.N = N
        self.data = data
        self.workers = workers

    def prepare_data(self):
        """Cast every split to float32 tensors and wrap them in TensorDatasets."""
        # NOTE(review): Lightning calls prepare_data on a single process, so
        # assigning state here is only safe for single-device runs (the
        # notebook uses devices=1); multi-device training would need `setup`.
        X_train, X_test, y_train, y_test = self.data
        self.X_train = torch.from_numpy(X_train.astype(np.float32))
        self.X_test = torch.from_numpy(X_test.astype(np.float32))
        self.y_train = torch.from_numpy(y_train.astype(np.float32))
        self.y_test = torch.from_numpy(y_test.astype(np.float32))

        self.train_data = TensorDataset(self.X_train, self.y_train)
        self.test_data = TensorDataset(self.X_test, self.y_test)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader using this module's batch size and worker settings."""
        return DataLoader(
            dataset,
            # BatchSampler rejects batch_size=0, which int(0.1 * N) yields for N < 10.
            batch_size=max(1, int(0.1 * self.N)),
            shuffle=shuffle,
            num_workers=self.workers,
            # persistent_workers=True raises ValueError when num_workers == 0.
            persistent_workers=self.workers > 0,
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.train_data, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the held-out split."""
        return self._make_loader(self.test_data, shuffle=False)

# Wrap the split data for Lightning.
num_workers = 20  # DataLoader worker processes
train_val_data = train_val_loader(data, N, num_workers)

# Training schedule shared by every run in the scan.
max_epochs = 500
min_epochs = 15
patience = 10  # epochs without val_loss improvement before EarlyStopping

# Held-out fold-8 test set used by the MAE evaluation.
# The features are cast to float32 to match what the training DataModule feeds
# the model (train_val_loader casts its inputs to float32). Uncast float64
# features would not match float32 model weights.
X_mae = torch.from_numpy(np.load('data/zenodo/fold/8/X_tst.npy').astype(np.float32))
# Reference values that mae/stdae compare the model's output against.
lr_tst = torch.from_numpy(np.load('data/zenodo/fold/8/lr_tst.npy'))

def _abs_errors(model_lr, X=None, target=None, cutoff=100.0):
    """Absolute errors ``|model_lr(X) - target|``, keeping only entries below ``cutoff``.

    ``X`` and ``target`` default to the module-level ``X_mae`` and ``lr_tst``.
    Entries at or above ``cutoff`` are dropped as outliers. NaNs are dropped
    too, because they fail the ``<`` comparison.
    """
    if X is None:
        X = X_mae
    if target is None:
        target = lr_tst
    # Evaluation only: avoid building an autograd graph over the whole test set.
    with torch.no_grad():
        abs_dif = abs(model_lr(X) - target)
    return abs_dif[abs_dif < cutoff]


def mae(model_lr, X=None, target=None, cutoff=100.0):
    """Mean absolute error of ``model_lr`` against the target, ignoring outliers >= ``cutoff``."""
    return _abs_errors(model_lr, X, target, cutoff).mean()


def stdae(model_lr, X=None, target=None, cutoff=100.0):
    """Standard deviation of the absolute error (same outlier filtering as ``mae``)."""
    return _abs_errors(model_lr, X, target, cutoff).std()

def _resolve_loss(loss_func):
    """Map a classifier name to ``(loss_fn, output activation, lr_fn)``.

    Raises:
        ValueError: for a name with no branch here. Such names used to fall
            through silently, leaving ``loss_fn`` undefined (NameError on a
            fresh kernel) or reusing a stale value from an earlier run.
    """
    if loss_func == 'odds':
        return bce, 'sigmoid', odds_lr
    if loss_func == 'probit':
        return probit_bce, 'linear', probit_lr
    if loss_func == 'arctan':
        return arctan_bce, 'linear', arctan_lr
    raise ValueError(
        f"Unknown loss_func {loss_func!r}; expected one of 'odds', 'probit', 'arctan'"
    )


# Validate every entry of loss_funcs before launching any (long) training run.
loss_specs = {name: _resolve_loss(name) for name in loss_funcs}

filestr = 'models/zenodo/mlc/'
dict_list = []  # one result dict per trained model
for loss_func in loss_funcs:
    loss_fn, output, lr_fn = loss_specs[loss_func]
    model_path = filestr + loss_func + '/ecd/'
    # Create the checkpoint directory once, idempotently.
    os.makedirs(model_path, exist_ok=True)
    for learning_rate in lr_vals:
        for nu in nu_vals:
            params = {'loss_fun': loss_fn, 'd': 6, 'output': output, 'optimizer': optimizer,
                      'learning_rate': learning_rate, 'eta': eta, 'F0': F0, 'nu': nu}
            for i in range(reps):
                # Keep only the weights with the lowest val_loss for this run.
                checkpoint_callback = ModelCheckpoint(
                    dirpath=model_path,
                    filename='model_{}'.format(i),
                    monitor='val_loss',
                    mode='min',
                    save_weights_only=True,
                )

                # Match the accelerator to the device chosen in the config cell
                # instead of hard-coding 'cuda'.
                trainer = pl.Trainer(
                    accelerator=device.type,
                    devices=1,
                    max_epochs=max_epochs,
                    callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=patience),
                               checkpoint_callback],
                    min_epochs=min_epochs,
                    enable_progress_bar=False,
                )

                model = create_model_original(**params)
                model.to(device)
                trainer.fit(model, train_val_data)

                train_losses = model.train_hist
                val_losses = model.val_hist

                # Reload the best checkpoint and evaluate on CPU: X_mae / lr_tst
                # are CPU tensors and .numpy() requires a CPU tensor.
                checkpoint = torch.load(checkpoint_callback.best_model_path, map_location='cpu')
                model.load_state_dict(checkpoint['state_dict'])
                model.cpu()
                model.eval()
                with torch.no_grad():
                    lr = lr_fn(model, m, s)
                    mae_1 = mae(lr).detach().numpy()

                scan_res = dict(mae=mae_1, optimizer=optimizer, learning_rate=learning_rate,
                                eta=eta, F0=F0, nu=nu, classifier=loss_func,
                                train_loss=train_losses, val_loss=val_losses,
                                path=checkpoint_callback.best_model_path, patience=patience)
                dict_list.append(scan_res)
                print(f'eta: {eta} lr: {learning_rate} mae: {mae_1} classifier: ', loss_func, f'path: {model_path}')
                del model

In [None]:
# Persist the per-run results collected by the training loop (dict_list).
# The previous version dumped an undefined name `results`.
results_dir = 'models/zenodo/bce/'
os.makedirs(results_dir, exist_ok=True)
with open(results_dir + file_name + '.pkl', 'wb') as fout:
    # The `with` statement closes the file; no explicit close() is needed.
    pickle.dump(dict_list, fout)