In [None]:
# Standard library
import configparser
import os
import sys
import time

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from scipy.io import savemat
from torch.utils.data import DataLoader


setpaths_dir = "../setpaths"
sys.path.append(setpaths_dir)
from setpaths import setpaths
libpath, datpath, resultpath, basepath = setpaths(setpaths_dir)

sys.path.append(basepath)
from lib.utils import getDatasetMat, train_model, showIms
from lib.unet_flexible import UNet_wrapper

In [None]:
class Two_Step_Model(nn.Module):
    """Chain a first-stage model with a second-stage U-Net.

    Both sub-modules are moved to ``dev`` when the wrapper is built, and the
    forward pass feeds the first model's output straight into the U-Net.

    Args:
        model: first-stage module (in this notebook, ``model_d['model']``,
            the pre-trained network loaded from disk).
        unet: second-stage module applied to ``model``'s output.
        dev: torch device both sub-modules are moved to.
    """

    def __init__(self, model, unet, dev):
        super().__init__()
        # nn.Module.to() moves the module in place and returns the same object,
        # so the caller's references point at the moved modules too.
        self.model = model.to(dev)
        self.unet = unet.to(dev)

    def forward(self, x):
        """Return ``unet(model(x))``."""
        first_stage = self.model(x)
        refined = self.unet(first_stage)
        return refined

In [None]:
# Choose the visible GPU first: CUDA only honours CUDA_VISIBLE_DEVICES if it is
# set before the CUDA runtime is initialised in this process.
GPUID = 6
os.environ['CUDA_VISIBLE_DEVICES'] = str(GPUID)

# Seed NumPy and PyTorch so runs are reproducible.
rand_seed = 0
np.random.seed(rand_seed)
torch_seed = torch.manual_seed(rand_seed)

In [None]:
# Training settings come from an .ini file in the fig12_recon_exp settings folder.
configname = "tof_EML_dot_train_settings.ini"
# configname = "exp_vgg_unet.ini"  # alternative settings file

fullconfigpath = os.path.join("../fig12_recon_exp/settings", configname)
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list of
# files it did parse; fail here with a clear message instead of a KeyError below.
if not config.read(fullconfigpath):
    raise FileNotFoundError("Could not read config file: %s" % fullconfigpath)

# Values read from the [Settings] section.
settings = config["Settings"]
nTest = settings.getint("nTest")
LR = settings.getfloat("LR")
lossFunc = settings["lossFunc"]
measNormalization = settings.getfloat("measNormalization")
loadFname = settings["loadFname"]
datVarName = settings["datVarName"]
# Comma-separated column indices of the reconstructions shown in previews.
displayIndices = [int(i) for i in settings["displayIndices"].split(',')]

# Fine-tuning hyper-parameters fixed in this notebook (not read from the config).
nEpochs = 500
batch_sz = 900
unet_nfilts = 16     # number of filters passed to UNet_wrapper
vgg_weight = 0.01    # forwarded to train_model via train_dict["vgg_weight"]
showEvery = 100      # forwarded to train_model via train_dict["showEvery"]

In [None]:
dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load the pre-trained first-stage model (trained without a U-Net).
# NOTE(review): absolute, user-specific path — consider building it from
# `resultpath` (returned by setpaths) so the notebook runs on other machines.
path2model = '/home/yz142/unrolled_DOT/unrolled_DOT_results/exp'
model_fname = 'model_allTrainingDat_30-Sep-2021_EML_NL=1_nEpoch=2000_lossFunc=MAE_untied=T_vgg=F_unet_nfilts=0_act=shrink'
# The checkpoint holds whole pickled objects (the dataset and the model), and
# torch.load unpickles them, which can execute arbitrary code: only load trusted files.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which would reject
# this checkpoint; pass weights_only=False there.
checkpoint_file = f'{path2model}/{model_fname}.pt'
model_d = torch.load(checkpoint_file)
full_dataset = model_d['full_dataset']

# getDims() yields five values; the fourth is the (square) image side length.
_, _, _, imSz, _ = full_dataset.getDims()

# Stack a newly constructed U-Net on top of the pre-trained model.
unet = UNet_wrapper(imSz=imSz, nfilts=unet_nfilts, input_channels=1, bn_input=True)
model = Two_Step_Model(model=model_d['model'], unet=unet, dev=dev)

In [None]:
# Visualize current reconstructions

# Preview reconstructions from the pre-trained first-stage model alone,
# i.e. before the U-Net is trained.
testMeas, testTruth = full_dataset.getFullTestSet()

Y_test_torch = testMeas.to(dev)
# Inference only: no_grad() avoids building an autograd graph over the whole test set.
# NOTE(review): the model is not switched to eval() here, so any dropout /
# batch-norm layers it may contain run in their current (possibly training) mode.
with torch.no_grad():
    X_pred_test = model_d['model'](Y_test_torch)
# Predictions are indexed as (pixels, images); keep the selected columns and
# fold the pixel axis back into imSz x imSz images.
predIms_test = np.reshape(X_pred_test.cpu().detach().numpy()[:, displayIndices], (imSz, imSz, -1))
showIms(predIms_test)

In [None]:
# Only the U-Net's parameters are handed to the optimizer, so the pre-trained
# first-stage weights are not among the parameters it updates.
optimizer = torch.optim.Adam(model.unet.get_params(), lr=LR)
# NOTE(review): train_dataloader is not used below; train_model receives the dataset itself.
train_dataloader = DataLoader(full_dataset, batch_size=batch_sz, shuffle=True)

# Settings passed to train_model (and later saved with the checkpoint).
# Entries set to None are not configured by this notebook.
train_dict = {
    "batch_sz": batch_sz,
    "nEpochs": nEpochs,
    "showEvery": showEvery,
    "lossFunc": lossFunc,
    "untied": None,
    "nLayers": None,
    "scale_mag": None,
    "lam1": None,
    # NOTE(review): recorded as 0. although the optimizer above uses `LR`;
    # check whether train_model reads this key before correcting the stored value.
    "LR": 0.,
    "vgg_weight": vgg_weight,
    # NOTE(review): recorded as 0 although the U-Net built earlier uses
    # `unet_nfilts` filters; confirm how train_model interprets this key.
    "unet_nfilts": 0,
}

training_outputs = train_model(
    dataset_in=full_dataset,
    train_d=train_dict,
    dev=dev,
    visInds=displayIndices,
    model=model,
    optim=optimizer,
)
model, epoch_arr, train_losses, test_losses, misc_out = training_outputs

In [None]:
# Test results: time one CPU forward pass of the full two-step model over the test set.
Y_test_torch, truthIms_torch = full_dataset.getFullTestSet()
_, _, _, imY, imX = full_dataset.getDims()
# Measurements are indexed as (measurement, image), so axis 1 counts test images.
nIms = Y_test_torch.shape[1]

model = model.to('cpu')
# Switch to inference mode. Without this, layers such as batch norm (which
# bn_input=True presumably enables in the U-Net) would normalise with the test
# batch's statistics and overwrite their running averages.
model.eval()
Y_test_torch = Y_test_torch.cpu()

# no_grad() skips autograd bookkeeping, which would otherwise inflate both
# memory use and the measured reconstruction time.
with torch.no_grad():
    starttime = time.perf_counter()
    reconIms = model(Y_test_torch)
    finishtime = time.perf_counter()
reconTime = finishtime - starttime
print("Reconstruction time for %d test images: %.4f s" % (nIms, reconTime))

reconIms_np = np.reshape(reconIms.detach().numpy(), (imY, imX, nIms))
truthIms_np = truthIms_torch.cpu().detach().numpy()

In [None]:
# Save the fine-tuned model (PyTorch checkpoint) and the test results (.mat).
savepath = os.path.join(resultpath, 'exp')
# exist_ok replaces the separate isdir() check (and its check-then-create race).
os.makedirs(savepath, exist_ok=True)

model_savename = "model_2-step_pretrained=%s" % (model_fname)
fullsavepath_model = os.path.join(savepath, model_savename + '.pt')

# Full checkpoint: dataset, model and training history. The train/test split
# indices are copied over from the pre-trained checkpoint.
pydict = {
    "full_dataset": full_dataset,
    "train_dict": train_dict,
    "model": model,
    "epoch_arr": epoch_arr,
    "train_losses": train_losses,
    "test_losses": test_losses,
    "trainInds": model_d['trainInds'],
    "testInds": model_d['testInds'],
}
# Merge any extra outputs returned by train_model into the checkpoint.
pydict.update(misc_out)

# MATLAB-friendly subset: losses, split indices, reconstructions and ground truth,
# plus the measured CPU reconstruction time (previously computed but never saved).
matdict = {
    "epoch_arr": epoch_arr,
    "train_losses": train_losses,
    "test_losses": test_losses,
    "trainInds": model_d['trainInds'],
    "testInds": model_d['testInds'],
    "reconIms_np": reconIms_np,
    "truthIms_np": truthIms_np,
    "reconTime": reconTime,
}

torch.save(pydict, fullsavepath_model)

fullsavepath_mat = os.path.join(savepath, model_savename + '.mat')
savemat(fullsavepath_mat, matdict)

print("Saved model to: %s" % fullsavepath_model)
print("Saved results to: %s" % fullsavepath_mat)