In [1]:
# %run convert_ipynb_to_py_files.ipynb

In [1]:
import torch
import torch.nn.functional as F

from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, Compose, ToTensor, Normalize

import argparse
import os
import math 
import skimage
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import time
import pickle

from datetime import datetime
from pathlib import Path

from py_files.new_dataset import *

from py_files.cnn_model import *
from py_files.pigan_model import *

from py_files.seq_pi_gan_functions import *

Imported CNN and Mapping functions.
Imported PI-Gan model.
Loaded all helper functions.


#### Select the compute device (model and helper classes are imported above)

In [2]:
DEVICE = set_device()

# Echo the chosen device between separator lines.
# NOTE(review): the captured output shows this banner twice, so set_device()
# appears to print it already — this echo may be redundant.
separator = '----------------------------------'
print(separator)
print('Using device for training:', DEVICE)
print(separator)

----------------------------------
Using device for training: cuda
----------------------------------
----------------------------------
Using device for training: cuda
----------------------------------


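# Select the saved run to evaluate

By default the most recent run in `saved_runs/` is used; uncomment the hard-coded run name in the selection cell below to evaluate a different one.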

In [3]:
# Show every saved run directory, alphabetically.
sorted(os.listdir('saved_runs'))

['pi-gan 11-05-2021 10:36:55 ']

In [4]:
def _run_started_at(run_name):
    """Parse the start timestamp embedded in a saved-run directory name.

    Run directories look like ``"pi-gan 11-05-2021 10:36:55 <optional note>"``.
    NOTE(review): the date is assumed to be DD-MM-YYYY — confirm against the
    training code that creates these directories.

    Returns ``datetime.min`` for names that do not match, so they sort oldest.
    """
    parts = run_name.split()
    try:
        return datetime.strptime(" ".join(parts[1:3]), "%d-%m-%Y %H:%M:%S")
    except ValueError:
        return datetime.min


# Plain string sorting compares "DD-MM-YYYY" day-first, so sorted(...)[-1] is not
# reliably the newest run; order by the parsed start time instead (name breaks ties).
run = max(os.listdir(path='saved_runs'), key=lambda name: (_run_started_at(name), name))
# run = "pi-gan 06-05-2021 11:23:14 very trained pcmra"

ARGS = load_args(run, print_changed=False)

print(run)

for var, val in vars(ARGS).items():
    print(f"{var.ljust(15)} \t {val}")

pi-gan 11-05-2021 10:36:55 
device          	 GPU
print_models    	 True
name            	 
pretrained      	 None
pretrained_best 	 train
dataset         	 small
rotate          	 False
translate       	 False
flip            	 False
norm_min_max    	 [0, 1]
seed            	 34
epochs          	 100
batch_size      	 24
eval_every      	 5
shuffle         	 True
n_coords_sample 	 500
cnn_setup       	 1
mapping_setup   	 2
dim_hidden      	 256
siren_hidden_layers 	 3
first_omega_0   	 30.0
hidden_omega_0  	 30.0
pcmra_first_omega_0 	 30.0
pcmra_hidden_omega_0 	 30.0
cnn_lr          	 0.0001
cnn_wd          	 0
mapping_lr      	 0.0001
pcmra_mapping_lr 	 0.0001
siren_lr        	 0.0001
siren_wd        	 0
pcmra_siren_lr  	 0.0001
pcmra_siren_wd  	 0


# Evaluation

In [5]:
##### data preparation #####
# Turn off the rotate/translate/flip options (presumably data augmentation) and
# load one subject per batch, then build the train/val/test dataloaders.
for option in ("rotate", "translate", "flip"):
    setattr(ARGS, option, False)
ARGS.batch_size = 1
train_dl, val_dl, test_dl = initialize_dataloaders(ARGS)

----------------------------------
Using device for training: cuda
----------------------------------
Train subjects: 84
Val subjects: 28
Test subjects: 28


#### Load models

In [6]:
# Build the models together with their optimizers and schedulers from the run's
# arguments; only `models` is used in the evaluation cells that follow.
(models, optims, schedulers) = load_models_and_optims(ARGS)

----------------------------------
Using device for training: cuda
----------------------------------
CNN
CNN1(
  (model): Sequential(
    (0): ConvLayer(
      (model): Sequential(
        (0): Conv3d(1, 16, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
        (1): ReLU()
        (2): LayerNorm((24, 64, 64), eps=1e-05, elementwise_affine=True)
      )
    )
    (1): ConvLayer(
      (model): Sequential(
        (0): Conv3d(16, 16, kernel_size=(5, 5, 5), stride=(2, 2, 2), padding=(2, 2, 2))
        (1): ReLU()
        (2): LayerNorm((12, 32, 32), eps=1e-05, elementwise_affine=True)
      )
    )
    (2): ConvLayer(
      (model): Sequential(
        (0): Conv3d(16, 32, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
        (1): ReLU()
        (2): LayerNorm((12, 32, 32), eps=1e-05, elementwise_affine=True)
      )
    )
    (3): ConvLayer(
      (model): Sequential(
        (0): Conv3d(32, 32, kernel_size=(5, 5, 5), stride=(2, 2, 2), padding=(2, 2, 2))
     

In [15]:
# Which checkpoint variant to restore; files are stored as "<model>_<best_loss>.pt".
best_loss = "train"

for model_name, net in models.items():
    checkpoint_path = f"saved_runs/{run}/{model_name}_{best_loss}.pt"
    # map_location stages the saved tensors on this session's device, so a
    # checkpoint written on a GPU machine still loads on a CPU-only machine.
    net.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))

In [16]:
##### loss function #####
# Binary cross-entropy between predicted and true masks, averaged over elements.
# BCELoss expects probabilities, so mask predictions are assumed to lie in [0, 1] — TODO confirm.
criterion = nn.BCELoss(reduction="mean")

In [17]:
def scroll_through_output(dataloader, shape=(64, 64, 24), first=100):
    """Run the loaded models over ``dataloader`` and open a scrollable viewer.

    For every batch the complete mask prediction (and, when
    ``ARGS.reconstruction`` is not ``"mask"``, the PCMRA reconstruction) is
    produced with ``get_complete_image``. Inputs, targets and outputs are
    reshaped to ``shape`` and concatenated along its last axis so that
    ``Show_images`` can page through them slice by slice.

    Relies on the notebook globals ``models``, ``ARGS`` and ``criterion``.

    Args:
        dataloader: yields ``(idx, subj, proj, pcmra, coords, pcmra_array,
            mask_array)``. Assumed to use ``batch_size == 1`` (``.item()`` and
            ``[0]`` are used below), as set in the data-preparation cell.
        shape: per-subject volume shape; its last axis is the scrolled slice
            axis and gets one title per slice.
        first: maximum number of batches to display.

    Returns:
        The ``Show_images`` window.

    Raises:
        ValueError: if no batch was collected (empty loader or ``first <= 0``).
    """
    show_pcmra_output = ARGS.reconstruction != "mask"

    def to_volume(tensor):
        # Drop autograd history, move to CPU and reshape to the display volume.
        return tensor.detach().cpu().reshape(shape)

    pcmras, masks, mask_outs, pcmra_outs = [], [], [], []
    titles = []

    # Inference only: skipping autograd bookkeeping saves (GPU) memory.
    with torch.no_grad():
        for n_seen, (idx, subj, proj, pcmra, coords, pcmra_array, mask_array) in enumerate(dataloader):
            # `idx` is the dataset index, not a counter (its order is random for a
            # shuffled loader), so count processed batches to honour `first`.
            if n_seen >= first:
                break

            mask_out = get_complete_image(models, pcmra, coords, ARGS)
            if show_pcmra_output:
                pcmra_out = get_complete_image(models, pcmra, coords, ARGS, output="pcmra")

            loss = criterion(mask_out, mask_array)

            pcmras.append(to_volume(pcmra_array))
            masks.append(to_volume(mask_array))
            mask_outs.append(to_volume(mask_out))
            if show_pcmra_output:
                pcmra_outs.append(to_volume(pcmra_out))

            # One title per slice along the scrolled (last) axis.
            title = f"{idx.item()} {subj[0]} {proj[0]}, loss: {round(loss.item(), 4)}"
            titles.extend([title] * shape[-1])

    if not masks:
        raise ValueError("scroll_through_output: no batches to display "
                         "(empty dataloader or first <= 0).")

    panels = [(pcmras, "pcmras"), (masks, "masks"), (mask_outs, "mask output")]
    if show_pcmra_output:
        panels.append((pcmra_outs, "pcmra output"))

    # Concatenate all subjects along the slice axis once per panel, instead of
    # re-concatenating the growing tensor inside the loop.
    window = Show_images(titles, *[(torch.cat(volumes, dim=-1).numpy(), label)
                                   for volumes, label in panels])
    return window

In [19]:
# Browse the test split; pass train_dl or val_dl instead to inspect the other splits.
window = scroll_through_output(dataloader=test_dl)