In [1]:
# %run convert_ipynb_to_py_files.ipynb

In [2]:
import torch
import torch.nn.functional as F

from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, Compose, ToTensor, Normalize

import argparse
import os
import math 
import skimage
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import time
import pickle

from datetime import datetime
from pathlib import Path

# from data_classes.py_files.custom_datasets import *
# from data_classes.py_files.data_classes import *
from data_classes.py_files.new_dataset import *
# 
from model_classes.py_files.cnn_model import *
from model_classes.py_files.pigan_model import *

from functions import *

%matplotlib qt

Imported CNN and Mapping functions.
Imported PI-Gan model.
Loaded all helper functions.


The support for Qt4  was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
  from matplotlib.backends.qt_compat import QtGui


#### Import classes

In [3]:
# Select the compute device and report it, framed by separator lines.
DEVICE = set_device()

print('----------------------------------',
      f'Using device for training: {DEVICE}',
      '----------------------------------',
      sep='\n')

----------------------------------
Using device for training: cuda
----------------------------------
----------------------------------
Using device for training: cuda
----------------------------------


# Set the run you want to evaluate 

In [4]:
# Show the saved run directories (sorted by name) so one can be picked below.
sorted(os.listdir('saved_runs'))

['old',
 'pi-gan 03-05-2021 12:25:53 ',
 'pi-gan 03-05-2021 13:13:07 ',
 'pi-gan 03-05-2021 14:00:17 ',
 'pi-gan 03-05-2021 14:47:32 ',
 'pi-gan 03-05-2021 15:34:40 ',
 'pi-gan 03-05-2021 16:12:17 ',
 'pi-gan 03-05-2021 16:50:03 ',
 'pi-gan 03-05-2021 17:37:11 ',
 'pi-gan 03-05-2021 18:23:52 ',
 'pi-gan 03-05-2021 19:10:42 ',
 'pi-gan 03-05-2021 19:57:58 ',
 'pi-gan 03-05-2021 20:36:38 ',
 'pi-gan 03-05-2021 21:23:56 ',
 'pi-gan 03-05-2021 22:11:17 ',
 'pi-gan 03-05-2021 22:58:42 ',
 'pi-gan 03-05-2021 23:46:00 ',
 'pi-gan 04-05-2021 00:23:45 ',
 'pi-gan 04-05-2021 01:01:32 ',
 'pi-gan 04-05-2021 01:48:08 ',
 'pi-gan 04-05-2021 02:34:43 ',
 'pi-gan 04-05-2021 03:21:39 ',
 'pi-gan 04-05-2021 04:08:37 ',
 'pi-gan 04-05-2021 04:46:19 ',
 'pi-gan 04-05-2021 05:24:02 ',
 'pi-gan 04-05-2021 06:10:53 ',
 'pi-gan 04-05-2021 06:57:42 ',
 'pi-gan 04-05-2021 07:44:33 ',
 'pi-gan 04-05-2021 08:31:23 ',
 'pi-gan 04-05-2021 09:09:03 ',
 'pi-gan 04-05-2021 09:46:50 ',
 'pi-gan 04-05-2021 10:34:13 ',


In [12]:
# Run directory (under saved_runs/) to evaluate. To take the last run by
# name instead, use: run = sorted(os.listdir(path='saved_runs'))[-1]
run = "pi-gan 06-05-2021 11:23:14 very trained pcmra"

ARGS = load_args(run, print_changed=False)

print(run)

# Print every loaded argument as "<name padded to 15> \t <value>".
for var, val in vars(ARGS).items():
    print(f"{var:<15} \t {val}")

pi-gan 06-05-2021 11:23:14 very trained pcmra
device          	 GPU
print_models    	 False
name            	 
pretrained      	 None
pretrained_best 	 train
reconstruction  	 pcmra
share_mapping   	 False
pcmra_lambda    	 1
mask_lambda     	 1
dataset         	 small
rotate          	 False
translate       	 False
flip            	 False
norm_min_max    	 [0, 1]
seed            	 34
epochs          	 500
batch_size      	 2
eval_every      	 5
shuffle         	 True
n_coords_sample 	 -1
cnn_setup       	 1
mapping_setup   	 2
dim_hidden      	 256
siren_hidden_layers 	 3
first_omega_0   	 30.0
hidden_omega_0  	 30.0
pcmra_first_omega_0 	 30
pcmra_hidden_omega_0 	 30.0
cnn_lr          	 0.0002
cnn_wd          	 0
mapping_lr      	 0.0002
pcmra_mapping_lr 	 0.0002
siren_lr        	 0.0001
siren_wd        	 0
pcmra_siren_lr  	 0.0002
pcmra_siren_wd  	 0
scheduler_on    	 pcmra


# Evaluation

In [13]:
##### data preparation #####
# Disable the rotate / translate / flip options (presumably data
# augmentation) for evaluation.
ARGS.rotate = False
ARGS.translate = False
ARGS.flip = False
# scroll_through_output reads idx.item() and subj[0], so keep one item per batch.
ARGS.batch_size = 1

train_dl, val_dl, test_dl = initialize_dataloaders(ARGS)

----------------------------------
Using device for training: cuda
----------------------------------
Train subjects: 84
[['16-01-22_Jarik_kt-pca_done2', 'Aorta Volunteers', 'scaled'], ['16-01-27 Valentine kt-pca_done2', 'Aorta Volunteers', 'scaled'], ['16-01-27_Claudia_kt-pca_done2', 'Aorta Volunteers', 'scaled'], ['16-02-03_Feiko_kt-pca_done2', 'Aorta Volunteers', 'scaled'], ['16-02-10_Luuk_kt-pca_done2', 'Aorta Volunteers', 'scaled'], ['16-02-10_Michelle_kt-pca_done', 'Aorta Volunteers', 'scaled'], ['16-02-12_Marjolein_kt-pca_done', 'Aorta Volunteers', 'scaled'], ['16-03-18_Ruud_kt-pca_done2', 'Aorta Volunteers', 'scaled'], ['16-04-13_Pim_kt-pca_done2', 'Aorta Volunteers', 'scaled'], ['16-05-25_Emile_kt-pca_done', 'Aorta Volunteers', 'scaled']]
Val subjects: 28
Test subjects: 28


#### Load models

In [14]:
# Build the models (a dict keyed by model name) from the run's arguments.
# The optimizers and schedulers are not used by the evaluation cells shown here.
models, optims, schedulers = load_models_and_optims(ARGS)

----------------------------------
Using device for training: cuda
----------------------------------
CNN
MAPPING
SIREN
PCMRA_MAPPING
PCMRA_SIREN


In [15]:
# Which checkpoint variant to restore; files are named "<model>_<best_loss>.pt"
# inside saved_runs/<run>/.
best_loss = "train"

for model in models.keys():
    # map_location="cpu" lets checkpoints saved from CUDA tensors load on any
    # machine (including CPU-only ones); load_state_dict then copies the
    # weights onto whatever device the model's parameters already live on.
    state_dict = torch.load(f"saved_runs/{run}/{model}_{best_loss}.pt", map_location="cpu")
    models[model].load_state_dict(state_dict)

In [16]:
##### loss function #####
# Binary cross-entropy, used to score the mask output against mask_array.
criterion = torch.nn.BCELoss()

In [17]:
def scroll_through_output(dataloader, shape=(64, 64, 24), first=100):
    """Run the loaded models over a dataloader and open a slice viewer.

    For each batch, the mask output is computed with ``get_complete_image``.
    Unless ``ARGS.reconstruction == "mask"``, the PCMRA reconstruction is
    computed as well. The input PCMRA, the target mask and the output(s) are
    viewed as ``shape``. They are then concatenated along axis 2, so
    ``Show_images`` can scroll through all volumes slice by slice.

    Relies on the notebook globals ``models``, ``ARGS`` and ``criterion``.

    Args:
        dataloader: iterable of ``(idx, subj, proj, pcmra, coords,
            pcmra_array, mask_array)`` batches. Batch size must be 1, because
            ``idx.item()`` and ``subj[0]`` / ``proj[0]`` are used.
        shape: volume shape each array is viewed as; the slices of all
            volumes are stacked along its third axis.
        first: maximum number of batches to visualise; ``None`` means all.

    Returns:
        Whatever ``Show_images`` returns (the viewer window).

    NOTE(review): the models are not switched to ``eval()`` here. Confirm
    whether they contain dropout or batch-norm layers.
    """
    show_pcmra_out = ARGS.reconstruction != "mask"

    pcmra_vols, mask_vols, mask_out_vols, pcmra_out_vols = [], [], [], []
    titles = []

    def to_volume(tensor):
        # Detach before moving to CPU so no autograd graph is carried along.
        return tensor.detach().cpu().view(shape)

    def stack(volumes):
        # A single concatenation at the end avoids re-copying everything seen
        # so far on every iteration. torch.cat rejects an empty list, so an
        # empty loader yields an empty tensor instead.
        return torch.cat(volumes, 2) if volumes else torch.Tensor([])

    for n_seen, batch in enumerate(dataloader, start=1):
        idx, subj, proj, pcmra, coords, pcmra_array, mask_array = batch

        mask_out = get_complete_image(models, pcmra, coords, ARGS)
        if show_pcmra_out:
            pcmra_out = get_complete_image(models, pcmra, coords, ARGS, output="pcmra")

        loss = criterion(mask_out, mask_array)

        pcmra_vols.append(to_volume(pcmra_array))
        mask_vols.append(to_volume(mask_array))
        mask_out_vols.append(to_volume(mask_out))
        if show_pcmra_out:
            pcmra_out_vols.append(to_volume(pcmra_out))

        # The viewer needs one title per slice of this volume.
        title = f"{idx.item()} {subj[0]} {proj[0]}, loss: {round(loss.item(), 4)}"
        titles += [title] * shape[2]

        # Count batches rather than comparing the dataset index, so the limit
        # holds for shuffled loaders and shows exactly `first` items.
        if first is not None and n_seen >= first:
            break

    images = [
        (stack(pcmra_vols).numpy(), "pcmras"),
        (stack(mask_vols).numpy(), "masks"),
        (stack(mask_out_vols).numpy(), "mask output"),
    ]
    if show_pcmra_out:
        images.append((stack(pcmra_out_vols).numpy(), "pcmra output"))

    window = Show_images(titles, *images)
    return window

In [19]:
# Visualise one split; uncomment a line below to inspect train or val instead.
# The result is kept in `window`, presumably so the interactive (Qt) viewer
# is not garbage-collected while open — confirm.
# window = scroll_through_output(train_dl)
# window = scroll_through_output(val_dl)
window = scroll_through_output(test_dl)