In [None]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import savemat
import configparser

import torch
import torch.nn as nn
from torch.utils.data import random_split
from torch.utils.data import DataLoader

import sys
import time

# Make the repository's path helper importable (path is relative to the notebook's
# working directory), then ask it for the project directories used below.
# NOTE(review): the meaning of each returned path (library / data / results / repo
# root) is inferred from the names — confirm against setpaths.py.
setpaths_dir = "../setpaths"
sys.path.append(setpaths_dir)
from setpaths import setpaths
libpath, datpath, resultpath, basepath = setpaths(setpaths_dir)

# The project-local `lib` package is presumably located under basepath, so the
# sys.path entry must be added before these imports run.
sys.path.append(basepath)
from lib.utils import getDatasetMat, train_model, showIms
from lib.FC_Conv import FC_Conv

In [None]:
# Expose only one physical GPU to this process. Do this before anything can
# initialise CUDA so the mask is guaranteed to take effect.
GPUID = 2
os.environ['CUDA_VISIBLE_DEVICES'] = str(GPUID)

# Fix the numpy and torch RNGs so weight init and data splits are reproducible.
# (torch.manual_seed also queues seeding of CUDA devices lazily.)
rand_seed = 0
np.random.seed(rand_seed)
torch_seed = torch.manual_seed(rand_seed)

#### Load training and reconstruction parameters from the configuration file

In [None]:
# Select the experiment configuration. Presets available in settings/:
#   automap_settings.ini, FC_settings.ini,
#   automap_confocal_settings.ini, FC_confocal_settings.ini
configname = "automap_confocal_settings.ini"

# Path is relative to the notebook's working directory.
fullconfigpath = os.path.join("settings", configname)
config = configparser.ConfigParser()
# ConfigParser.read silently skips files it cannot open and returns the list of
# files it did parse; fail loudly here rather than with an opaque KeyError on
# config["Settings"] below.
if not config.read(fullconfigpath):
    raise FileNotFoundError("Could not read configuration file: %s" % fullconfigpath)
settings = config["Settings"]

# Training / reconstruction parameters. Indexing (rather than SectionProxy.getint)
# keeps a missing option a hard KeyError instead of silently returning None.
nTest = int(settings["nTest"])
batch_sz = int(settings["batch_sz"])
nEpochs = int(settings["nEpochs"])
LR = float(settings["LR"])
lossFunc = settings["lossFunc"]
measNormalization = float(settings["measNormalization"])
loadFname = settings["loadFname"]
datVarName = settings["datVarName"]
# Comma-separated list of test indices to visualise; tolerate a trailing comma / blanks.
displayIndices = [int(i) for i in settings["displayIndices"].split(',') if i.strip()]
showEvery = int(settings["showEvery"])
unet_nfilts = int(settings["unet_nfilts"])
trainAutomap = settings.getboolean("trainAutomap", fallback=True)

# Network architecture.
# FC_sizes: widths of the hidden fully connected layer(s).
# NOTE(review): 1681 == 41**2, presumably the reconstructed image's pixel count — confirm.
FC_sizes = [1681]
# An empty conv_sizes yields a purely fully connected network; otherwise the
# AUTOMAP-style convolutional configuration is used.
# NOTE(review): tuples appear to be (padding, n_filters) with kernel size
# 2*padding + 1, and the trailing int presumably configures the final layer —
# confirm against lib.FC_Conv.
conv_sizes = [(2, 32), (2, 32), 3] if trainAutomap else []

#### Set up output paths, load the data and model, and train

In [None]:
# Directory that receives this experiment's outputs.
savepath = os.path.join(resultpath, 'exp')

# Base filename: network type, plus a '_conf' suffix when the data file name
# contains 'conf' (plain substring match).
model_savename = "model_" + ("automap" if trainAutomap else "FC")
if 'conf' in loadFname:
    model_savename = model_savename + '_conf'

# NOTE(review): not referenced elsewhere in this notebook; presumably reserved for
# intermediate checkpoints — confirm or remove.
intermed_path = os.path.join(savepath, f"intermed_{model_savename}")

In [None]:
# Load in data
loadFpath = os.path.join(datpath, loadFname + '.mat')
full_dataset, trainInds, testInds = getDatasetMat(matpath=loadFpath, nTest=nTest, measNormalization=measNormalization, datVarName=datVarName)
trainMeas, trainTruth = full_dataset.getFullTrainSet()
testMeas, testTruth = full_dataset.getFullTestSet()
szA, NBINS, nSrcDet, imY, imX = full_dataset.getDims()

dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = FC_Conv(szA, FC_sizes, conv_sizes, dev)

In [None]:
# Adam over the network's parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Options consumed by lib.utils.train_model.
# NOTE(review): the None / 0 entries presumably belong to other network types
# (unrolled nets, VGG perceptual loss) and are unused for FC / AUTOMAP — confirm.
train_dict = {"batch_sz": batch_sz,
              "nEpochs": nEpochs,
              "showEvery": showEvery,
              "lossFunc": lossFunc,
              "untied": None,
              "nLayers": None,
              "scale_mag": None,
              "lam1": None,
              "LR": LR,
              "vgg_weight": 0,
              "unet_nfilts": unet_nfilts}

# train_model is handed the dataset itself (batch size via train_dict), so no
# DataLoader is built here; the previously created, unused one was removed.
model, epoch_arr, train_losses, test_losses, misc_out = train_model(
    dataset_in=full_dataset,
    train_d=train_dict,
    dev=dev,
    visInds=displayIndices,
    model=model,
    optim=optimizer,
)

In [None]:
# Test results

# Reconstruct the held-out test set on CPU and time the forward pass.
Y_test_torch, truthIms_torch = full_dataset.getFullTestSet()
_, _, _, imY, imX = full_dataset.getDims()
# NOTE(review): assumes measurements are stored one sample per column, i.e.
# (meas_dim, nIms) — confirm.
nIms = Y_test_torch.shape[1]

# eval() switches off training-only behaviour (dropout, batch-norm statistic
# updates, if the networks contain any); train_model may leave them in train mode.
model = model.to('cpu').eval()
Y_test_torch = Y_test_torch.cpu()
if unet_nfilts > 0:
    # Optional refinement network returned by train_model when unet_nfilts > 0.
    unet = misc_out['unet'].to('cpu').eval()

# Inference only: no_grad avoids building an autograd graph over the whole test
# set (memory) and keeps the timing representative of a pure forward pass.
starttime = time.perf_counter()
with torch.no_grad():
    reconIms = model(Y_test_torch)
    if unet_nfilts > 0:
        reconIms = unet(reconIms)
finishtime = time.perf_counter()
reconTime = finishtime - starttime

# NOTE(review): C-order reshape assumes each output column is an image flattened in
# row-major (y, x) order — confirm against FC_Conv's output layout.
reconIms_np = np.reshape(reconIms.detach().numpy(), (imY, imX, nIms))
truthIms_np = truthIms_torch.cpu().detach().numpy()

In [None]:
# Persist the trained model and results: a torch checkpoint for Python and a .mat
# file for MATLAB-side analysis.
os.makedirs(savepath, exist_ok=True)

fullsavepath_model = os.path.join(savepath, model_savename + '.pt')
fullsavepath_mat = os.path.join(savepath, model_savename + '.mat')

# NOTE(review): this pickles the whole model object and dataset (not a state_dict),
# so reloading needs lib.FC_Conv / lib.utils importable and, on torch >= 2.6,
# torch.load(..., weights_only=False). Only load such files from trusted sources.
pydict = {
    "full_dataset": full_dataset,
    "train_dict": train_dict,
    "model": model,
    "epoch_arr": epoch_arr,
    "train_losses": train_losses,
    "test_losses": test_losses,
    "trainInds": trainInds,
    "testInds": testInds,
}
# Extra training outputs (e.g. the optional unet, runtime_arr) are stored alongside.
pydict.update(misc_out)

matdict = {
    "epoch_arr": epoch_arr,
    "train_losses": train_losses,
    "test_losses": test_losses,
    "trainInds": trainInds,
    "testInds": testInds,
    "reconIms_np": reconIms_np,
    "truthIms_np": truthIms_np,
    "runtime_arr": misc_out["runtime_arr"],
    # Test-set inference time (seconds) measured in the evaluation cell.
    "reconTime": reconTime,
}

torch.save(pydict, fullsavepath_model)
savemat(fullsavepath_mat, matdict)

print("Saved model to: %s" % fullsavepath_model)
print("Saved results to: %s" % fullsavepath_mat)