In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys

import numpy as np
import matplotlib.pyplot as plt
import torch

from scipy.io import loadmat, savemat

# Make the shared ../setpaths helper importable, then ask it for the project
# directories. By name the four returned values are the library, data, results
# and base directories (setpaths' exact semantics are defined outside this file).
setpaths_dir = "../setpaths"
sys.path.append(setpaths_dir)
from setpaths import setpaths  # must come after the sys.path.append above
libpath, datpath, resultpath, basepath = setpaths(setpaths_dir)

# basepath goes on sys.path so the project-local `lib` package imported next resolves.
sys.path.append(basepath)
from lib.DOTDataset_class import DOTDataset
from lib.utils import train_model, showIms
from lib.vgg_loss import Vgg16

In [None]:
# Configuration: compute device and which trained model to analyse.
dev = 'cuda'
run_sim = False  # True -> model trained on simulated data; False -> experimental-data model

# Root of the saved results (hard-coded absolute path; adjust for your machine).
results_root = '/home/yz142/unrolled_DOT/unrolled_DOT_results'

# NOTE: these checkpoints hold full Python objects, not just state_dicts.
# The model/unet are used directly as callables after loading, and the exp
# checkpoint also holds the dataset object. On PyTorch >= 2.6, torch.load
# defaults to weights_only=True and will refuse such files; pass
# weights_only=False there. Only do that for trusted files, since it unpickles
# arbitrary objects.
if run_sim:
    path2model = os.path.join(results_root, 'sim')
    model_fname = 'model_5_1_22_unrolled_jac_train=f_test=f_NL=3_nEpoch=200_lossFunc=MSE_untied=T_vgg=T_unet_nfilts=32'
    model_d = torch.load(os.path.join(path2model, model_fname + '.pt'))
    # The simulated model's test data lives in a companion .mat file.
    dat_d = loadmat(os.path.join(path2model, model_fname + '.mat'))
    meas_test_torch = torch.tensor(dat_d['meas_test_np']).to(dev)
    truth_test_torch = torch.tensor(dat_d['truth_test_np']).to(dev)
    # NOTE(review): assumes the first axis of the truth array is an image side length -- TODO confirm
    imSz = truth_test_torch.shape[0]
else:
    path2model = os.path.join(results_root, 'exp')
    model_fname = 'model_allTrainingDat_30-Sep-2021_EML_NL=1_nEpoch=400_lossFunc=MAE_untied=T_vgg=T_unet_nfilts=16_act=shrink'
    model_d = torch.load(os.path.join(path2model, model_fname + '.pt'))
    full_dataset = model_d['full_dataset']
    # Fetch the test split once rather than calling getFullTestSet() twice.
    test_set = full_dataset.getFullTestSet()
    meas_test_torch = test_set[0].to(dev)
    truth_test_torch = test_set[1].to(dev)
    _, _, _, imSz, _ = full_dataset.getDims()

In [None]:
def _to_vgg_batch(ims):
    """Convert an (imSz, imSz, N) image stack into an (N, 3, imSz, imSz) float batch.

    The single image channel is repeated 3 times. Presumably this is because
    Vgg16 takes 3-channel (RGB) input.
    """
    return ims.permute(2, 0, 1).unsqueeze(1).repeat(1, 3, 1, 1).float()


# Put the trained networks in eval mode so any BatchNorm/Dropout layers they
# contain use inference behaviour on the test set. Otherwise BatchNorm would use
# training-mode batch statistics and Dropout would apply random masks.
model = model_d['model'].to(dev).eval()
unet_trained = model_d['unet'].to(dev).eval()
vgg_net = Vgg16(requires_grad=False).to(dev)

# Pure inference: no autograd graph is needed. Skipping it saves a lot of GPU
# memory when the whole test set goes through three networks at once.
with torch.no_grad():
    recon_test_torch = unet_trained(model(meas_test_torch))

    # NOTE(review): assumes the test tensors reshape (C order) to
    # (imSz, imSz, n_samples) -- TODO confirm against the dataset layout.
    truth_ims = torch.reshape(truth_test_torch, (imSz, imSz, -1))
    recon_ims = torch.reshape(recon_test_torch, (imSz, imSz, -1))
    truth_reshape = _to_vgg_batch(truth_ims)
    recon_reshape = _to_vgg_batch(recon_ims)

    # The outputs expose .relu2_2 / .relu4_3, which the save and plot cells read.
    vgg_truth = vgg_net(truth_reshape)
    vgg_recon = vgg_net(recon_reshape)

In [None]:
# Save the VGG feature maps, and the images they were computed from, to a .mat file.
savepath = os.path.join(resultpath, 'vgg')
model_savename = "model_vgg_pretrained=%s" % (model_fname)

# exist_ok=True replaces the separate isdir() check and its check-then-create race.
os.makedirs(savepath, exist_ok=True)

# detach() before cpu() so the host copy is never recorded in an autograd graph.
matdict = {
    "vgg_recon2_2": vgg_recon.relu2_2.detach().cpu().numpy(),
    "vgg_recon4_3": vgg_recon.relu4_3.detach().cpu().numpy(),
    "vgg_truth2_2": vgg_truth.relu2_2.detach().cpu().numpy(),
    "vgg_truth4_3": vgg_truth.relu4_3.detach().cpu().numpy(),
    "truth_ims": truth_ims.detach().cpu().numpy(),
    "recon_ims": recon_ims.detach().cpu().numpy(),
}

fullsavepath_mat = os.path.join(savepath, model_savename + '.mat')
savemat(fullsavepath_mat, matdict)

# The file holds VGG features and images, not model weights.
print("Saved VGG features to: %s" % fullsavepath_mat)

In [None]:
plot_ind = 40

nplot_cols = 8
plot_inds2_2 = np.arange(0, 64)
plot_inds4_3 = np.arange(0, 200)

# Plot ground truth
print("Truth")
truth_plot = truth_reshape[plot_ind,0,:,:].cpu().detach().numpy()
plt.imshow(truth_plot)
plt.gca().axis('off')
plt.show()

# Plot relu2_2
print("ReLU2_2")
nplot_rows = int(np.ceil(len(plot_inds2_2)/float(nplot_cols)))
plt.figure(figsize=(nplot_cols*3, nplot_rows*3))
for i in range(len(plot_inds2_2)):
    relu2_2_i = vgg_truth.relu2_2[plot_ind,plot_inds2_2[i],:,:].cpu().detach().numpy()
    plt.subplot(nplot_rows, nplot_cols, i+1)
    plt.imshow(relu2_2_i)
    plt.gca().axis('off')
plt.show()

# Plot relu4_3
print("ReLU4_3")
nplot_rows = int(np.ceil(len(plot_inds4_3)/float(nplot_cols)))
plt.figure(figsize=(nplot_cols*3, nplot_rows*3))
for i in range(len(plot_inds4_3)):
    relu4_3_i = vgg_truth.relu4_3[plot_ind,plot_inds4_3[i],:,:].cpu().detach().numpy()
    plt.subplot(nplot_rows, nplot_cols, i+1)
    plt.imshow(relu4_3_i)
    plt.gca().axis('off')
plt.show()

In [None]:
# Plot recon
print("Recon")
recon_plot = recon_reshape[plot_ind,0,:,:].cpu().detach().numpy()
plt.imshow(recon_plot)
plt.gca().axis('off')
plt.show()

# Plot relu2_2
print("ReLU2_2")
nplot_rows = int(np.ceil(len(plot_inds2_2)/float(nplot_cols)))
plt.figure(figsize=(nplot_cols*3, nplot_rows*3))
for i in range(len(plot_inds2_2)):
    relu2_2_i = vgg_recon.relu2_2[plot_ind,plot_inds2_2[i],:,:].cpu().detach().numpy()
    plt.subplot(nplot_rows, nplot_cols, i+1)
    plt.imshow(relu2_2_i)
    plt.gca().axis('off')
plt.show()

# Plot relu4_3
print("ReLU4_3")
nplot_rows = int(np.ceil(len(plot_inds4_3)/float(nplot_cols)))
plt.figure(figsize=(nplot_cols*3, nplot_rows*3))
for i in range(len(plot_inds4_3)):
    relu4_3_i = vgg_recon.relu4_3[plot_ind,plot_inds4_3[i],:,:].cpu().detach().numpy()
    plt.subplot(nplot_rows, nplot_cols, i+1)
    plt.imshow(relu4_3_i)
    plt.gca().axis('off')
plt.show()