In [None]:
%load_ext autoreload
%autoreload 2

# from IPython.core.interactiveshell import InteractiveShell
# InteractiveShell.ast_node_interactivity = "all"

import os
import numpy as np
import matplotlib.pyplot as plt

import torch
from torchvision import datasets
from torchvision.transforms import Compose, ToTensor, Resize, InterpolationMode

import math
import random
import sys
import h5py
from scipy.io import savemat
import configparser

# Make the shared `setpaths` helper importable, then ask it for the project directories
# (library, data, results, and repository base paths) used throughout this notebook.
setpaths_dir = "../setpaths"
sys.path.append(setpaths_dir)
from setpaths import setpaths
libpath, datpath, resultpath, basepath = setpaths(setpaths_dir)

# Project-local helpers are imported from the `lib` package found under basepath.
sys.path.append(basepath)
from lib.DOTDataset_class import DOTDataset
from lib.utils import train_model, showIms

In [None]:
rand_seed = 0
GPUID = 5

# Pin the notebook to a single GPU. torch only reads CUDA_VISIBLE_DEVICES when the CUDA
# runtime is initialized, so set it before any torch call that could touch CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = str(GPUID)

# Seed every RNG this notebook draws from:
#   - numpy (global legacy RNG),
#   - torch (dataset shuffles via randperm, Poisson measurement noise),
#   - Python's `random` (random.sample in the debug visualization).
random.seed(rand_seed)
np.random.seed(rand_seed)
torch_seed = torch.manual_seed(rand_seed)

In [None]:
# Set parameters
# Other experiment configs used with this notebook (swap in as needed):
# configname = "unet_vgg_train_fashion_test_fashion.ini"
# configname = "train_fashion_test_fashion.ini"
# configname = "train_fashion_test_mnist.ini"
# configname = "train_mnist_test_mnist.ini"
# configname = "nLayers_experiment_train_test_mnist.ini"
configname = 'circphantom_test.ini'

fullconfigpath = os.path.join("settings", configname)
config = configparser.ConfigParser()
# ConfigParser.read silently skips files it cannot open and returns the list it did read;
# fail here with a clear message instead of a KeyError on "Settings" further down.
if not config.read(fullconfigpath):
    raise FileNotFoundError("Config file not found: %s" % fullconfigpath)
settings = config["Settings"]


def _parse_int_list(text):
    """Parse a comma-separated list of ints, ignoring whitespace and empty items ("1, 2," -> [1, 2])."""
    return [int(tok) for tok in text.split(',') if tok.strip()]


def _parse_optional_float(text):
    """Return None when `text` is missing, empty or 'None' (case-insensitive); otherwise float(text)."""
    if text is None or text.strip().lower() in ("", "none"):
        return None
    return float(text)


# Forward model parameters
jac_dir = settings["jac_dir"]
abs_mua = float(settings["abs_mua"])
bin_sz = int(settings["bin_sz"])
int_time = float(settings["int_time"]) # Seconds
pile_up = float(settings["pile_up"]) # Pile-up point in cnts/sec

train_set_select = settings["train_set_select"] # 'f' for fashion mnist
test_set_select = settings["test_set_select"] # 'm' for mnist

# Training parameters
nTrain = int(settings["nTrain"])
nTest = int(settings["nTest"])
batch_sz = int(settings["batch_sz"])
nLayers = _parse_int_list(settings["nLayers"])  # one training run per entry
LR = float(settings["LR"])
scale_initial_val = LR
nEpochs = int(settings["nEpochs"])
# if None use learnable L1 coefficient, else use lam1.
# Parsed to float like every other numeric setting (previously passed through as a raw string).
lam1 = _parse_optional_float(settings.get("lam1", fallback=None))
untied = settings.getboolean("untied")
showEvery = int(settings["showEvery"])
lossFunc = settings["lossFunc"]
vgg_weight = float(settings["vgg_weight"])
unet_nfilts = int(settings["unet_nfilts"])
displayIndices = _parse_int_list(settings["displayIndices"])  # empty string -> []
measNormalization = float(settings["measNormalization"])

RUN_DEBUG = True
debug_vis_inds = displayIndices

In [None]:
# Load in Jacobian

J_full_fname = "%s/%s/J_multisrc_interp.mat" % (datpath, jac_dir)


def _header_int(headers, key):
    """Read scalar header field `key` (stored as a 2-D array, hence [0][0]) as a Python int."""
    return int(headers[key][0][0])


with h5py.File(J_full_fname, 'r') as f:
    J = f["J_final"][:]
    # h5py group handle; only readable while the file is open (i.e. inside this `with`).
    Jheaders = f["Jheaders"]
    VOX_W = _header_int(Jheaders, "VOX_W")
    # Bug fix: VOX_L used to be read from the "VOX_W" field. Fall back to VOX_W only when
    # the file has no VOX_L entry at all (square voxel grid).
    VOX_L = _header_int(Jheaders, "VOX_L") if "VOX_L" in Jheaders else VOX_W
    NBINS = _header_int(Jheaders, "NBINS")
    SRC_W = _header_int(Jheaders, "SRC_W")
    SRC_L = _header_int(Jheaders, "SRC_L")
    SENS_W = _header_int(Jheaders, "SENS_W")
    SENS_L = _header_int(Jheaders, "SENS_L")
    bkg_sig = f["bkg_final"][:]

Jheaders_py = {
    "VOX_W": VOX_W,
    "VOX_L": VOX_L,
    "NBINS": NBINS,
    "SRC_W": SRC_W,
    "SRC_L": SRC_L,
    "SENS_W": SENS_W,
    "SENS_L": SENS_L,
}

# Arrays read from a MATLAB file via h5py come back with their axes reversed;
# np.transpose with no axes argument reverses all axes back.
J = torch.tensor(np.transpose(J), dtype=torch.float)
bkg_sig = torch.tensor(np.transpose(bkg_sig), dtype=torch.float)

nvox = VOX_W * VOX_L
J = torch.reshape(J, (NBINS, -1, nvox))        # (time bins, source-detector pairs, voxels)
bkg_sig = torch.reshape(bkg_sig, (NBINS, -1))  # (time bins, source-detector pairs)

nsrcdet = J.shape[1]

In [None]:
# Load in datasets (MNIST and Fashion MNIST)

mnist_full_set = datasets.MNIST(
    root=datpath,
    train=True,
    download=True,
    transform=ToTensor(),
)

# NOTE(review): Fashion-MNIST is taken from its *test* split (train=False), while MNIST uses its
# train split, so the Fashion-MNIST pool is much smaller. Confirm this asymmetry is intended.
fashion_full_set = datasets.FashionMNIST(
    root=datpath,
    train=False,
    download=True,
    transform=ToTensor()
)

print("Done loading datasets!")

len_mnist = len(mnist_full_set)
len_fashion = len(fashion_full_set)

# Both permutations come from the shared (seeded) torch RNG; keep this call order for reproducibility.
fashion_inds = torch.randperm(len_fashion)
mnist_inds = torch.randperm(len_mnist)

fashion_set_shuffled = fashion_full_set.data[fashion_inds,:,:]
mnist_set_shuffled = mnist_full_set.data[mnist_inds,:,:]

shuffled_sets = {'f': fashion_set_shuffled, 'm': mnist_set_shuffled}


def _pick_shuffled_set(selector, role):
    """Return the shuffled raw image stack for `selector` ('f' or 'm'); raise ValueError otherwise."""
    key = selector.lower()
    if key not in shuffled_sets:
        raise ValueError("Unknown %s_set_select %r: expected 'f' (Fashion-MNIST) or 'm' (MNIST)"
                         % (role, selector))
    return shuffled_sets[key]


train_source = _pick_shuffled_set(train_set_select, "train")
test_source = _pick_shuffled_set(test_set_select, "test")

# Training takes the first nTrain shuffled images; testing takes the next nTest of its source
# (the slice always starts at nTrain). Slicing past the end would silently return fewer images.
if len(train_source) < nTrain:
    raise ValueError("nTrain=%d exceeds the %d images available for train_set_select=%r"
                     % (nTrain, len(train_source), train_set_select))
if len(test_source) < nTrain + nTest:
    raise ValueError("nTrain+nTest=%d exceeds the %d images available for test_set_select=%r"
                     % (nTrain + nTest, len(test_source), test_set_select))

# Process images so they are properly sized and max intensity equal to mu_a
f_resize = Resize((VOX_W, VOX_L))

train_dat = f_resize(train_source[:nTrain,:,:].double())
test_dat = f_resize(test_source[nTrain:(nTrain+nTest)].double())

final_recon_vis = np.transpose(np.concatenate([train_dat[debug_vis_inds,:,:], test_dat[debug_vis_inds,:,:]]), (1,2,0))

if len(displayIndices) > 0:
    showIms(final_recon_vis)

# Scale every image so its own maximum equals abs_mua.
test_dat *= (abs_mua / torch.amax(test_dat, dim=(1,2)))[:,None,None]
train_dat *= (abs_mua / torch.amax(train_dat, dim=(1,2)))[:,None,None]

In [None]:
# Generate noisy measurements
#
# Model: each measurement is (background counts) - (absorbed counts). Both terms are scaled so the
# brightest time-integrated background channel equals int_time * pile_up photons, Poisson noise is
# drawn independently for each term, and the difference is mapped back to the unscaled units.

# --- Temporal binning: merge every bin_sz consecutive time bins (the final bin may be shorter) ---
nbins_final = math.ceil(NBINS / bin_sz)

J_binned = torch.zeros((nbins_final, nsrcdet, nvox))
bkg_binned = torch.zeros((nbins_final, nsrcdet))
binned_chunks = zip(torch.split(J, bin_sz, dim=0), torch.split(bkg_sig, bin_sz, dim=0))
for bin_idx, (J_chunk, bkg_chunk) in enumerate(binned_chunks):
    J_binned[bin_idx] = J_chunk.sum(dim=0)
    bkg_binned[bin_idx] = bkg_chunk.sum(dim=0)

# Forward operator, shape (time bins, source-detector pairs, voxels), in float64.
J_mat = J_binned.double()

# Images as column vectors: (voxels, samples)
test_mu = test_dat.reshape(-1, nvox).transpose(0, 1)
train_mu = train_dat.reshape(-1, nvox).transpose(0, 1)

# Noise-free absorption signal J*mu for each sample
m_test_clean = J_mat @ test_mu
m_train_clean = J_mat @ train_mu

# Scale factor: brightest time-integrated background channel -> int_time * pile_up photons
normfac = (int_time * pile_up) / bkg_binned.sum(dim=0).amax()

# Same background for every sample: tile along a new trailing sample axis
bkg_clean_test = bkg_binned.unsqueeze(-1).repeat(1, 1, nTest)
bkg_clean_train = bkg_binned.unsqueeze(-1).repeat(1, 1, nTrain)

# Light surviving absorption (counts cannot go negative)
abs_clean_test = (bkg_clean_test - m_test_clean).clamp(min=0)
abs_clean_train = (bkg_clean_train - m_train_clean).clamp(min=0)

# Convert to expected photon counts
bkg_clean_test_norm = normfac * bkg_clean_test
bkg_clean_train_norm = normfac * bkg_clean_train
abs_clean_test_norm = normfac * abs_clean_test
abs_clean_train_norm = normfac * abs_clean_train

# Poisson draws; keep this exact order so the seeded torch RNG yields the same noise realization
bkg_noisy_test = torch.poisson(bkg_clean_test_norm)
bkg_noisy_train = torch.poisson(bkg_clean_train_norm)
abs_noisy_test = torch.poisson(abs_clean_test_norm)
abs_noisy_train = torch.poisson(abs_clean_train_norm)

# Measured difference in photon counts, then back in unscaled units
m_noisy_test_raw = bkg_noisy_test - abs_noisy_test
m_noisy_train_raw = bkg_noisy_train - abs_noisy_train
m_noisy_test = m_noisy_test_raw / normfac
m_noisy_train = m_noisy_train_raw / normfac

# Flattened forward matrix ((time bins * source-detector pairs), voxels) for export
J_mat_np = J_mat.reshape(nbins_final*nsrcdet, nvox).cpu().detach().numpy()

In [None]:
# Prepare model for training

# On a CPU tensor, .cpu().detach().numpy() returns an array that SHARES memory with the tensor.
# Copy the measurements so the in-place rescaling below does not silently rescale
# m_noisy_test / m_noisy_train (which the debug cell plots as "Meas (normalized)").
test_noisy_dat = m_noisy_test.cpu().detach().numpy().copy()
train_noisy_dat = m_noisy_train.cpu().detach().numpy().copy()
# Ground-truth images as (VOX_W, VOX_L, samples)
test_truth = torch.permute(test_dat, (1,2,0)).cpu().detach().numpy()
train_truth = torch.permute(train_dat, (1,2,0)).cpu().detach().numpy()

# Scale each split so its largest measurement equals measNormalization.
# NOTE(review): train and test are each scaled by their OWN max, so the same raw measurement
# maps to different network inputs in the two splits. Confirm this is intended (scaling test
# by the training max is the usual choice).
test_noisy_dat *= measNormalization / np.amax(test_noisy_dat)
train_noisy_dat *= measNormalization / np.amax(train_noisy_dat)
full_dataset = DOTDataset(trainMeas=train_noisy_dat, trainTruth=train_truth, testMeas=test_noisy_dat, testTruth=test_truth)

In [None]:
def _source_images(set_select):
    """Return the shuffled raw (not yet resized) image stack selected by 'm' (MNIST) or 'f' (Fashion-MNIST)."""
    key = set_select.lower()
    if key == 'm':
        return mnist_full_set.data[mnist_inds,:,:]
    if key == 'f':
        return fashion_full_set.data[fashion_inds,:,:]
    raise ValueError("Unknown set selector %r: expected 'm' or 'f'" % set_select)


def _assert_truth_matches_source(truth_cols, source_ims, offset, count, label):
    """Assert column k of `truth_cols` equals (up to a scale) resized `source_ims[offset + k]`, k < count.

    Both images are max-normalized OUT of place: `truth_cols[:, k].numpy()` may be a view into the
    dataset's stored ground truth, and an in-place `/=` would silently rescale it.
    """
    for k in range(count):
        stored = np.reshape(truth_cols[:,k].cpu().detach().numpy(), (VOX_W, VOX_L))
        expected = np.reshape(f_resize(source_ims[offset + k].double()[None,:,:]).cpu().detach().numpy(), (VOX_W, VOX_L))
        assert np.allclose(stored / np.amax(stored), expected / np.amax(expected)), "%s images do not match!" % label


def _first_overlap(test_cols, train_cols, n_test, n_train, chunk=4096):
    """Return the first test column index that is allclose to any training column, or None.

    Vectorized replacement for an n_test x n_train loop of torch.allclose(test_im, train_im):
    torch.isclose(test, train) applies the same tolerance (|test - train| <= atol + rtol*|train|).
    Training columns are compared in chunks to bound the size of the temporary buffers.
    """
    train_cols = train_cols[:, :n_train]
    for p in range(n_test):
        probe = test_cols[:, p:p+1]
        for start in range(0, n_train, chunk):
            if torch.isclose(probe, train_cols[:, start:start+chunk]).all(dim=0).any():
                return p
    return None


def _collocated_sum(meas, sample_idx):
    """Time-integrated collocated (source position == detector position) signal for one sample.

    `meas` is a (time bins, source-detector pairs, samples) tensor from the noise cell; it is
    viewed as (time bins, SRC_W, SRC_L, SENS_W, SENS_L, samples), which requires the source and
    detector grids to coincide. Returns a (SRC_W, SRC_L) numpy array.
    """
    grid = torch.reshape(meas, (nbins_final, SRC_W, SRC_L, SENS_W, SENS_L, -1))
    per_sample = grid[..., sample_idx].cpu().detach().numpy()
    out = np.zeros((SRC_W, SRC_L))
    for sc in range(SRC_W):
        for sr in range(SRC_L):
            out[sc, sr] = np.sum(per_sample[:, sc, sr, sc, sr])
    return out


if RUN_DEBUG:
    with torch.no_grad():

        trainMeas, trainTruth = full_dataset.getFullTrainSet()
        testMeas, testTruth = full_dataset.getFullTestSet()

        # Test that the stored images correctly correspond to the source mnist / fashion mnist images
        full_set_train = _source_images(train_set_select)
        _assert_truth_matches_source(trainTruth, full_set_train, 0, nTrain, "Training")
        print("All Training Images Match!")

        full_set_test = _source_images(test_set_select)
        _assert_truth_matches_source(testTruth, full_set_test, nTrain, nTest, "Testing")
        print("All Testing Images Match!")
        print("\n")

        # Ensure no overlap between training and test images
        assert _first_overlap(testTruth, trainTruth, nTest, nTrain) is None, "Found Overlap"
        print("Found no overlaps!")
        print("\n")

        # Visualize collocated measurements. The source and detector grids must coincide; the
        # product check alone would allow mismatched grids that index out of range or off-diagonal.
        if SRC_W == SENS_W and SRC_L == SENS_L and SRC_W*SRC_L*SENS_W*SENS_L == nsrcdet:
            panel_titles = ("Bkg Clean", "Abs Clean", "Bkg Noisy", "Abs Noisy",
                            "Meas (photon counts)", "Meas (normalized)")
            splits = (
                ("test", test_truth, (bkg_clean_test_norm, abs_clean_test_norm, bkg_noisy_test,
                                      abs_noisy_test, m_noisy_test_raw, m_noisy_test)),
                ("train", train_truth, (bkg_clean_train_norm, abs_clean_train_norm, bkg_noisy_train,
                                        abs_noisy_train, m_noisy_train_raw, m_noisy_train)),
            )

            print("Visualize collocated measurements")
            for split_name, truth_stack, meas_tensors in splits:
                for vis_idx in debug_vis_inds:
                    print("%s measurements" % split_name)
                    panels = [(truth_stack[:,:,vis_idx], "%s im %d" % (split_name, vis_idx))]
                    panels += [(_collocated_sum(meas, vis_idx), panel_title)
                               for meas, panel_title in zip(meas_tensors, panel_titles)]

                    print("Tot # of background photons: %.2e" % (int_time * pile_up))
                    plt.figure(figsize=(30,3))
                    for pos, (panel_im, panel_title) in enumerate(panels, start=1):
                        plt.subplot(1, len(panels), pos)
                        plt.imshow(panel_im)
                        plt.title(panel_title)
                        plt.axis("off")
                        _ = plt.colorbar()
                    plt.show()
        print("\n")

        # Visualize training and test sets: left ncols_per_class columns are training samples,
        # the remaining columns are test samples.
        print("Visualize training and test sets")
        figure = plt.figure(figsize=(20, 8))
        cols, rows = 6, 3
        ncols_per_class = 3
        n_per_split = rows*ncols_per_class
        # min() avoids random.sample raising ValueError when a split has fewer images than slots.
        train_rand_inds = random.sample(range(nTrain), min(n_per_split, nTrain))
        test_rand_inds = random.sample(range(nTest), min(n_per_split, nTest))
        train_k = 0
        test_k = 0
        for i in range(1, cols * rows + 1):
            if (((i-1) % cols) // ncols_per_class) > 0:
                if test_k >= len(test_rand_inds):
                    continue
                img = testTruth[:,test_rand_inds[test_k]]
                im_title = "Test"
                test_k += 1
            else:
                if train_k >= len(train_rand_inds):
                    continue
                img = trainTruth[:,train_rand_inds[train_k]]
                im_title = "Train"
                train_k += 1
            img = torch.reshape(img, (VOX_W, VOX_L))
            figure.add_subplot(rows, cols, i)
            plt.axis("off")
            plt.title(im_title)
            plt.imshow(img.squeeze())
            plt.colorbar()
        plt.show()

In [None]:
# -------------------------------------------------------------------------
# Train model

# Send to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Post-training evaluation is done on the CPU.
cpu_dev = 'cpu'

# One full train / evaluate / save cycle per requested layer count.
for L_i, NL in enumerate(nLayers):

    print("Running simulation %d/%d with %d layers" % (L_i+1, len(nLayers), NL))

    # Set learning parameters
    train_dict = {"nLayers": NL,
                  "scale_mag": scale_initial_val,
                  "lam1": lam1,
                  "LR": LR,
                  "batch_sz": batch_sz,
                  "nEpochs": nEpochs,
                  "showEvery": showEvery,
                  "untied": untied,
                  "lossFunc": lossFunc,
                  "vgg_weight": vgg_weight,
                  "unet_nfilts": unet_nfilts,}

    # Perform training sequence
    model, epoch_arr, train_losses, test_losses, misc_out = train_model(full_dataset, train_dict, device, A=None, visInds=displayIndices)

    # -------------------------------------------------------------------------
    # Test Trained model

    testMeas, testTruth = full_dataset.getFullTestSet()
    trainMeas, trainTruth = full_dataset.getFullTrainSet()

    model.to(cpu_dev)

    with torch.no_grad():
        if unet_nfilts > 0:
            # Unrolled model followed by a U-Net refinement stage.
            unet = misc_out["unet"]
            unet.send2dev(cpu_dev)
            X_test_torch = unet(model(testMeas.to(cpu_dev)))

            # The training set is pushed through in mini-batches of batch_sz (math is imported
            # in the first cell); reconstructions are stored as (voxels, samples).
            nBatches = math.ceil(nTrain / batch_sz)
            X_train_torch = torch.zeros(VOX_W*VOX_L, nTrain)
            for b in range(nBatches):
                b_start = b*batch_sz
                b_end = min(b_start + batch_sz, nTrain)
                X_train_torch[:,b_start:b_end] = unet(model(trainMeas[:,b_start:b_end].to(cpu_dev)))
        else:
            X_test_torch = model(testMeas.to(cpu_dev))
            X_train_torch = model(trainMeas.to(cpu_dev))

        # Column-per-sample (voxels, samples) -> (VOX_W, VOX_L, samples) image stacks
        meas_test_np = testMeas.cpu().detach().numpy()
        recon_test_np = np.reshape(X_test_torch.cpu().detach().numpy(), (VOX_W, VOX_L, -1))
        truth_test_np = np.reshape(testTruth.cpu().detach().numpy(), (VOX_W, VOX_L, -1))

        meas_train_np = trainMeas.cpu().detach().numpy()
        recon_train_np = np.reshape(X_train_torch.cpu().detach().numpy(), (VOX_W, VOX_L, -1))
        truth_train_np = np.reshape(trainTruth.cpu().detach().numpy(), (VOX_W, VOX_L, -1))

    # No clipping is applied to the reconstructions; only show when indices were requested
    # (same guard as the dataset-loading cell).
    if len(displayIndices) > 0:
        print("Final reconstructions")
        final_recon_vis = recon_test_np[:,:,displayIndices]
        showIms(final_recon_vis)

    # -------------------------------------------------------------------------
    # Save results

    savepath = os.path.join(resultpath, 'sim')
    os.makedirs(savepath, exist_ok=True)

    untied_str = 'T' if untied else 'F'
    vgg_str = 'T' if vgg_weight > 0 else 'F'
    model_savename = "model_%s_train=%s_test=%s_NL=%d_nEpoch=%d_lossFunc=%s_untied=%s_vgg=%s_unet_nfilts=%d" % (jac_dir, train_set_select, test_set_select, NL, nEpochs, lossFunc, untied_str, vgg_str, unet_nfilts)

    fullsavepath_model = os.path.join(savepath, model_savename + '.pt')
    fullsavepath_mat = os.path.join(savepath, model_savename + '.mat')

    # Python-side artifacts (model object, dataset, training history) for torch.load
    pydict = {
        "train_dict": train_dict,
        "model": model,
        "epoch_arr": epoch_arr,
        "train_losses": train_losses,
        "test_losses": test_losses,
        "full_dataset": full_dataset,
    }

    for k in misc_out:
        pydict[k] = misc_out[k]

    # Array results for MATLAB
    matdict = {
        "meas_test_np": meas_test_np,
        "recon_test_np": recon_test_np,
        "truth_test_np": truth_test_np,
        "meas_train_np": meas_train_np,
        "recon_train_np": recon_train_np,
        "truth_train_np": truth_train_np,
        "epoch_arr": epoch_arr,
        "train_losses": train_losses,
        "truthIms": truth_test_np,
        "diff_meas": meas_test_np[None,...],
        "test_losses": test_losses,
        "J_mat_np": J_mat_np,
        "Jheaders": Jheaders_py,
    }

    torch.save(pydict, fullsavepath_model)
    savemat(fullsavepath_mat, matdict)

    print("Saved model to: %s" % fullsavepath_model)