In [None]:
import os
import sys
import time
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%load_ext autoreload

import torch
import torch.nn as nn
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.tuner import ContinuousParameter, HyperparameterTuner
from sagemaker.amazon.amazon_estimator import get_image_uri

# Make the project's local modules importable from this notebook: resolve the
# sibling `py-conjugated` directory (relative to the current working directory)
# to an absolute path and register it on sys.path exactly once.
module_path = os.path.abspath('../py-conjugated/')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)
import morphology_networks as net
import model_training as train
import model_testing as test
import physically_informed_loss_functions as pilf
import network_utils as nuts

# Seed PyTorch's global random number generator so repeated runs start from the
# same random state. Only torch is seeded here; numpy and the stdlib `random`
# module are left unseeded.
torch.manual_seed(28)

In [None]:
# Location of the OPV data: the bucket name plus the key prefixes of the
# training and test splits. These are handed to nuts.OPV_ImDataset below.
# NOTE(review): the bucket name is hard-coded to one specific AWS account and
# region, so the notebook cannot be pointed at another bucket without editing
# this cell.
data_bucket = 'sagemaker-us-east-2-362637960691'
train_data_path = 'py-conjugated/raw_data/OPV/train_set/'
test_data_path = 'py-conjugated/raw_data/OPV/test_set/'
# Unused (commented out). Note that it spells the prefix `py_conjugated` with an
# underscore, unlike the hyphenated `py-conjugated` prefixes above.
# model_states_path = 's3://{}/py_conjugated/model_states/OPV/OPV_encoder_1/'.format(data_bucket)

In [None]:
# Pick up any edits made to the project modules since they were imported
# (relies on the autoreload extension loaded in the first cell).
%autoreload

# Build one dataset per split from the bucket and key prefixes configured above.
# What OPV_ImDataset loads and returns per item is defined in network_utils and
# is not visible from this notebook.
train_dataset = nuts.OPV_ImDataset(data_bucket, train_data_path)
test_dataset = nuts.OPV_ImDataset(data_bucket, test_data_path)

# Batch the datasets. `fit` reads these two loaders as notebook-level globals.
# NOTE(review): `shuffle` is not passed, so the training loader uses the
# DataLoader default ordering on every epoch - confirm this is intended.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = 26)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size = 10)

In [None]:
# Names of the tracked device metrics, in the positional order in which the
# training / evaluation helpers return their per-metric values
# (index 0 -> pce, 1 -> voc, 2 -> jsc, 3 -> ff).
_OPV_TARGETS = ('pce', 'voc', 'jsc', 'ff')


def _record_epoch(per_target, totals, values):
    """Append one epoch of per-target values, and their sum, to the running histories."""
    for position, name in enumerate(_OPV_TARGETS):
        per_target[name].append(values[position])
    totals.append(sum(values))


def fit(model, criterion, lr, epochs = 30):
    """
    Train `model` with Adam for `epochs` epochs, evaluating after every epoch.

    Every epoch makes one call to `train.train_OPV_m2py_model` on the
    notebook-level `train_dataloader`, followed by one call to
    `test.eval_OPV_m2py_model` on the notebook-level `test_dataloader`. The
    per-target values returned by those helpers, and their per-epoch sums, are
    collected into per-epoch histories.

    Parameters
    ----------
    model : object exposing `.parameters()`
        Network to optimise. The same instance is handed to the training and
        evaluation helpers (presumably updated in place by the training helper).
    criterion : callable
        Loss function forwarded unchanged to both helpers.
    lr : float
        Learning rate for a freshly constructed `torch.optim.Adam`.
    epochs : int, default 30
        Number of train/evaluate rounds. Must be at least 1.

    Returns
    -------
    dict
        'lr'                               : the learning rate used
        'best_loss_epoch'                  : epoch index with the smallest summed test loss
        'best_acc_epoch'                   : epoch index with the smallest summed test "acc"
        'best_r2_epoch'                    : epoch index with the largest summed test r2
        '<target>_loss' / 'test_losses'    : per-epoch test losses and their sums
        '<target>_acc'  / 'test_accs'      : per-epoch test accs and their sums
        '<target>_r2'   / 'test_r2s'       : per-epoch test r2s and their sums
        'train_<target>_loss' / 'train_losses' : per-epoch training losses and their sums
        where <target> is one of 'pce', 'voc', 'jsc', 'ff'.

    Raises
    ------
    ValueError
        If `epochs` is less than 1: with no recorded epoch there is no best
        epoch to report.
    """
    # Fail early with a clear message instead of the opaque
    # "min() arg is an empty sequence" raised once the empty histories are searched.
    if epochs < 1:
        raise ValueError(f'epochs must be at least 1, got {epochs}')

    optimizer = torch.optim.Adam(model.parameters(), lr = lr)

    # One per-target history plus one list of per-epoch totals for each metric family.
    train_by_target, train_totals = {name: [] for name in _OPV_TARGETS}, []
    loss_by_target, loss_totals = {name: [] for name in _OPV_TARGETS}, []
    acc_by_target, acc_totals = {name: [] for name in _OPV_TARGETS}, []
    r2_by_target, r2_totals = {name: [] for name in _OPV_TARGETS}, []

    for epoch in range(epochs):
        train_losses = train.train_OPV_m2py_model(model = model,
                                                  training_data_set = train_dataloader,
                                                  criterion = criterion,
                                                  optimizer = optimizer)
        _record_epoch(train_by_target, train_totals, train_losses)

        test_losses, test_accs, test_r2s = test.eval_OPV_m2py_model(model = model,
                                                                    test_data_set = test_dataloader,
                                                                    criterion = criterion)
        _record_epoch(loss_by_target, loss_totals, test_losses)
        _record_epoch(acc_by_target, acc_totals, test_accs)
        _record_epoch(r2_by_target, r2_totals, test_r2s)

        print('Finished epoch ', epoch)

    fit_results = {
        'lr': lr,
        'best_loss_epoch': loss_totals.index(min(loss_totals)),
        # NOTE(review): the best "acc" epoch is chosen with min(), i.e. the metric
        # is treated as lower-is-better (error-like). Confirm against what
        # eval_OPV_m2py_model actually returns.
        'best_acc_epoch': acc_totals.index(min(acc_totals)),
        'best_r2_epoch': r2_totals.index(max(r2_totals)),
    }

    # Test-set histories: the four per-target lists followed by the per-epoch totals.
    for suffix, by_target, totals_key, totals in (
        ('loss', loss_by_target, 'test_losses', loss_totals),
        ('acc', acc_by_target, 'test_accs', acc_totals),
        ('r2', r2_by_target, 'test_r2s', r2_totals),
    ):
        for name in _OPV_TARGETS:
            fit_results[f'{name}_{suffix}'] = by_target[name]
        fit_results[totals_key] = totals

    # Training histories. The summed training loss used to be computed every
    # epoch but was never returned; it is now exposed as 'train_losses'.
    for name in _OPV_TARGETS:
        fit_results[f'train_{name}_loss'] = train_by_target[name]
    fit_results['train_losses'] = train_totals

    return fit_results

In [None]:
# Network to tune and the loss it is trained with. The meaning of the
# constructor argument (8) is defined in morphology_networks and is not
# visible from this notebook.
model = net.OPV_m2py_NN(8)
criterion = torch.nn.MSELoss()

# Candidate learning rates: 50 evenly spaced values from 1e-5 to 1e-1 inclusive.
# NOTE(review): with linear spacing the step is about 2e-3, so only the first
# candidate lies below 2e-3. A log-spaced grid would cover the small-lr range.
lrs = np.linspace(start = 1e-5, stop = 1e-1, num = 50)

In [None]:
%autoreload

lr_opt = {}

for i, lr in enumerate(lrs):
    print(f'  optimization loop {i}')
    print('-----------------------------')
    
    lr_opt[i] = fit(model, criterion, lr, epochs = 20)

In [None]:
# Show the raw sweep results: one entry per learning-rate index, each holding
# the dictionary returned by `fit`.
# NOTE(review): this renders every per-epoch list for every candidate, which is
# a very large cell output. Consider displaying a compact summary instead.
lr_opt

In [None]:
# Persist the sweep results next to the notebook as JSON. The file is
# overwritten if it already exists, and the integer keys of `lr_opt` are
# written as strings because JSON object keys are always strings.
# NOTE(review): this assumes every value collected by `fit` is JSON-serialisable
# (plain Python numbers) - confirm against what the train/eval helpers return.
with open('./20200722_OPVNN4_hpo_results-r1.json', mode = 'w') as results_file:
    json.dump(lr_opt, results_file)