In [None]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import savemat

import torch

import sys

# Make the local `setpaths` helper importable, then resolve the project paths
# (library, data, results, base) from it. Each sys.path insertion is guarded so
# re-running this cell does not keep appending duplicate entries; an entry that
# is already present is found first anyway, so import resolution is unchanged.
setpaths_dir = "../setpaths"
if setpaths_dir not in sys.path:
    sys.path.append(setpaths_dir)
from setpaths import setpaths
libpath, datpath, resultpath, basepath = setpaths(setpaths_dir)

# basepath must be importable so the project `lib` package can be found.
if basepath not in sys.path:
    sys.path.append(basepath)
from lib.DOTDataset_class import DOTDataset
from lib.utils import getDatasetMat, train_model

In [None]:
# Device selection: CUDA only honors CUDA_VISIBLE_DEVICES if it is set before
# CUDA is initialized in this process, so set it before any other torch call.
GPUID = 7
os.environ['CUDA_VISIBLE_DEVICES'] = str(GPUID)

# Reproducibility: seed numpy's global RNG and torch.
rand_seed = 0
np.random.seed(rand_seed)
# torch.manual_seed seeds the RNG on all devices and returns the default torch.Generator.
torch_seed = torch.manual_seed(rand_seed)

#### Training and reconstruction parameters

In [None]:
# Hyperparameters shared by every configuration trained below.
nTest = 500                 # test-set size requested from getDatasetMat
batch_sz = 4500             # passed to train_model as "batch_sz"
LR = 2e-4                   # learning rate, passed to train_model as "LR"
scale_initial_val = LR      # passed to train_model as "scale_mag"; starts equal to LR

# If None use a learnable L1 coefficient, otherwise use the fixed value lam1.
lam1 = None
# Passed to train_model as "untied" (presumably untied per-layer weights -- TODO confirm).
untied = True

showEvery = 50              # passed to train_model as "showEvery"
measNormalization = 7.8     # passed to both getDatasetMat and train_model

#### Load data, select device, and prepare output paths

In [None]:
# Load the dataset from the .mat file; getDatasetMat also returns the train and
# test index arrays (the split size is controlled by `nTest`).
loadFname = "allTrainingDat_30-Sep-2021_EML"
loadFpath = os.path.join(datpath, loadFname + '.mat')
datVarName = 'diff_L'  # presumably the variable read from the .mat file -- TODO confirm
full_dataset, trainInds, testInds = getDatasetMat(matpath=loadFpath, nTest=nTest, measNormalization=measNormalization, datVarName=datVarName)

# Send to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Output directory for saved models and results. exist_ok=True avoids the
# check-then-create race and makes re-running this cell safe.
savepath = os.path.join(resultpath, 'performance_analysis')
os.makedirs(savepath, exist_ok=True)

# Single-letter flag embedded in output filenames.
untied_str = 'T' if untied else 'F'

minthresh_final = 0.0  # NOTE(review): not referenced in the cells shown -- confirm it is still needed

In [None]:
# Experiment configurations. Every entry starts from the same baseline
# (2000 epochs, 3 layers, MAE loss, shrink activation) and overrides one setting.
# dict(base, key=value) keeps the baseline's key order, so each entry has keys
# nEpochs, nLayers, lossFunc, actfunc in that order.
_base_cfg = {"nEpochs": 2000, "nLayers": 3, "lossFunc": "MAE", "actfunc": "shrink"}

cfgs = (
    # Test varying number of layers
    [dict(_base_cfg, nLayers=n_layers) for n_layers in (1, 3, 5, 7, 10)]
    # Test loss function used to optimize
    + [dict(_base_cfg, lossFunc=loss_name) for loss_name in ("SSIM", "MSE")]
    # Test varying number of iterations (2000 itr already tested)
    + [dict(_base_cfg, nEpochs=n_epochs) for n_epochs in (250, 500)]
)

In [None]:
# For each configuration in `cfgs`: train a model, save it (with its training
# settings and loss history) to a .pt file, then reconstruct the full test set
# and save truth/reconstruction images plus losses, indices and runtimes to a .mat file.
for j, d in enumerate(cfgs):
    train_dict = {"nLayers": d["nLayers"],
                  "scale_mag": scale_initial_val,
                  "lam1": lam1,
                  "actfunc": d["actfunc"],
                  "LR": LR,
                  "nEpochs": d["nEpochs"],
                  "batch_sz": batch_sz,
                  "showEvery": showEvery,
                  "untied": untied,
                  "lossFunc": d["lossFunc"],
                  "measNormalization": measNormalization}

    # Train Model -----------------------------------------------------------
    model, epoch_arr, train_losses, test_losses, misc_out = train_model(full_dataset, train_dict, device)

    # The filename encodes every setting that varies across `cfgs`.
    model_savename = "model_%s_NL=%d_nEpoch=%d_lossFunc=%s_untied=%s_actfunc=%s" % (loadFname, d["nLayers"], d["nEpochs"], d["lossFunc"], untied_str, d["actfunc"])

    fullsavepath_model = os.path.join(savepath, model_savename + '.pt')
    model_dict = {
        "train_dict": train_dict,
        "model": model,
        "epoch_arr": epoch_arr,
        "train_losses": train_losses,
        "test_losses": test_losses
    }
    torch.save(model_dict, fullsavepath_model)
    print("Saved model to: %s" % fullsavepath_model)

    # Test results -----------------------------------------------------------
    _, _, _, imY, imX = full_dataset.getDims() # szA, NBINS, nSrcDet, imY, imX = allDat.getDims()
    Y_test_torch, X_test_torch = full_dataset.getFullTestSet()
    # Test-set size, read from axis 1 of the test measurements. It is kept under
    # its own name so the `nTest` config value (an input to getDatasetMat) is not
    # overwritten as a side effect of this loop.
    nTestSamples = Y_test_torch.shape[1]

    # NOTE(review): model.eval() is not called before inference; confirm the
    # model has no layers whose behavior differs between train and eval mode.
    with torch.no_grad():
        X_test_recon_torch = model(Y_test_torch.to(device))

    # Reshape reconstructions and ground truth into (imY, imX, nTestSamples)
    # stacks before writing them to the .mat file.
    X_test_recon_im = X_test_recon_torch.detach().cpu().numpy()
    X_test_recon_im = X_test_recon_im.reshape((imY, imX, nTestSamples))

    X_test = X_test_torch.detach().cpu().numpy()
    X_test = X_test.reshape((imY, imX, nTestSamples))

    fullsavepath_results = os.path.join(savepath, "result_%s.mat" % model_savename)
    matdict = {
        "truthIms": X_test,
        "reconIms": X_test_recon_im,
        "epoch_arr": epoch_arr,
        "train_losses": train_losses,
        "test_losses": test_losses,
        "trainInds": trainInds,
        "testInds": testInds,
        "runtime_arr": misc_out["runtime_arr"],
    }
    savemat(fullsavepath_results, matdict)
    print("Completed %d/%d. Saved result to file: %s" % (j+1, len(cfgs), fullsavepath_results))