In [None]:
# Execute the project's notebook-conversion helper before importing from py_files/
# (presumably it regenerates the py_files/*.py modules from their notebooks — confirm).
# NOTE(review): absolute, user-specific path; consider making it relative to the repo root.
%run /home/ptenkaate/scratch/Master-Thesis/convert_ipynb_to_py_files.ipynb

In [1]:
import torch
import torch.nn.functional as F

from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, Compose, ToTensor, Normalize

import argparse
import os
import math 
import skimage
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import time
import pickle

from datetime import datetime
from pathlib import Path

from py_files.new_dataset import *

from py_files.cnn_model import *
from py_files.pigan_model import *

from py_files.seq_pi_gan_functions import *

%matplotlib qt

Imported CNN and Mapping functions.
Imported PI-Gan model.
Loaded all helper functions.


The support for Qt4  was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
  from matplotlib.backends.qt_compat import QtGui


#### Select the compute device

In [2]:
# set_device() appears to print its own "Using device for training" banner: with extra
# explicit prints here the banner was shown twice. So only the assignment is kept.
DEVICE = set_device()

----------------------------------
Using device for training: cuda
----------------------------------
----------------------------------
Using device for training: cuda
----------------------------------


# Set the run you want to evaluate 

In [3]:
def run_sort_key(run_name):
    """Sort key that lists saved runs in chronological order.

    Run directories are named like "pi-gan DD-MM-YYYY HH:MM:SS <description>". The
    day-first date means a plain string sort is only chronological within one month, so
    the timestamp is parsed instead. Names that do not follow this pattern (e.g. hidden
    folders) sort first, alphabetically.

    Args:
        run_name: directory name inside saved_runs/.

    Returns:
        A tuple usable as a ``sorted`` key.
    """
    parts = run_name.split(" ")
    try:
        stamp = datetime.strptime(" ".join(parts[1:3]), "%d-%m-%Y %H:%M:%S")
    except ValueError:
        return (0, datetime.min, run_name)
    return (1, stamp, run_name)


sorted(os.listdir(path='saved_runs'), key=run_sort_key)

['pi-gan 21-05-2021 18:39:52 cnn setup 5 mapping setup 8',
 'pi-gan 21-05-2021 23:33:42 cnn setup 5 mapping setup 7',
 'pi-gan 22-05-2021 04:28:59 cnn setup 6 mapping setup 7',
 'pi-gan 22-05-2021 09:21:36 cnn setup 7 mapping setup 7',
 'pi-gan 22-05-2021 15:04:31 cnn setup 8 mapping setup 7',
 'pi-gan 22-05-2021 18:48:27 cnn setup 9 mapping setup 7',
 'pi-gan 22-05-2021 22:29:28 cnn setup 10 mapping setup 7',
 'pi-gan 23-05-2021 03:47:35 cnn setup 13 mapping setup 7',
 'pi-gan 23-05-2021 08:39:43 cnn setup 11 mapping setup 6',
 'pi-gan 23-05-2021 10:29:06 cnn setup 12 mapping setup 6',
 'pi-gan 23-05-2021 14:22:41 cnn setup 12 mapping setup 6 continuing',
 'pi-gan 23-05-2021 16:26:08 cnn setup 14 mapping setup 7',
 'pi-gan 23-05-2021 20:33:49 cnn setup 14 mapping setup 7 continuing',
 'pi-gan 23-05-2021 22:18:32 cnn setup 15 mapping setup 7',
 'pi-gan 24-05-2021 01:52:31 cnn setup 6 mapping setup 7 with weight decay',
 'pi-gan 24-05-2021 16:53:49 cnn setup 14 mapping setup 7 weight de

In [5]:
# Saved run to evaluate (one of the names listed from saved_runs/ above). To pick the newest
# run use `run = sorted(os.listdir(path='saved_runs'))[-1]` instead. Plain sorting compares
# the day-first "DD-MM-YYYY" names as strings, which is chronological only within one month.
run = "pi-gan 26-05-2021 15:49:32 no translation"

ARGS = load_args(run, print_changed=False)

print(run)

# Pad each argument name to the longest one so the values form a single column. A fixed
# width of 15 left longer names such as `pretrained_best_dataset` misaligned.
name_width = max(map(len, vars(ARGS)), default=0)
for var, val in vars(ARGS).items():
    print(f"{var.ljust(name_width)}   {val}")

pi-gan 26-05-2021 15:49:32 no translation
device          	 GPU
print_models    	 False
name            	 no translation
pretrained      	 None
pretrained_best_dataset 	 train
pretrained_best_loss 	 mask
pretrained_models 	 None
pretrained_lr_reset 	 None
dataset         	 small
rotate          	 False
translate       	 False
flip            	 False
crop            	 False
stretch         	 False
stretch_factor  	 1.2
norm_min_max    	 [0, 1]
seed            	 34
pcmra_epochs    	 2500
mask_epochs     	 2000
batch_size      	 24
eval_every      	 50
shuffle         	 True
n_coords_sample 	 5000
cnn_setup       	 14
pcmra_train_cnn 	 True
mask_train_cnn  	 False
mapping_setup   	 7
dim_hidden      	 256
siren_hidden_layers 	 3
first_omega_0   	 30.0
hidden_omega_0  	 30.0
pcmra_first_omega_0 	 30.0
pcmra_hidden_omega_0 	 30.0
cnn_lr          	 0.0001
cnn_wd          	 0
mapping_lr      	 0.0001
pcmra_mapping_lr 	 0.0001
siren_lr        	 0.0001
siren_wd        	 0
pcmra_siren_lr  	 0.00

# Evaluation

In [7]:
# Override the batch size stored with the run (24) so evaluation loads 3 samples per batch.
ARGS.batch_size = 3
[train_dl, val_dl, test_dl] = initialize_dataloaders(ARGS)

----------------------------------
Using device for training: cuda
----------------------------------
Train subjects: 84
Val subjects: 28
Test subjects: 28


#### Load models

In [8]:
# Build the networks, optimizers and schedulers with print_models disabled
# (presumably this suppresses printing the model architectures).
ARGS.print_models = False
(models, optims, schedulers) = load_models_and_optims(ARGS)

----------------------------------
Using device for training: cuda
----------------------------------
CNN
MAPPING
SIREN
PCMRA_MAPPING
PCMRA_SIREN


In [9]:
# Restore the saved parameters of `run`. The dataset/loss pair presumably selects which
# "best" checkpoint is loaded (here: best mask loss on the training set).
ARGS.pretrained_best_dataset, ARGS.pretrained_best_loss = "train", "mask"
load_pretrained_models(
    run,
    ARGS.pretrained_best_dataset,
    ARGS.pretrained_best_loss,
    models,
    optims,
    pretrained_models=None,
)

Loading params from cnn
Loading params from mapping
Loading params from siren
Loading params from pcmra_mapping
Loading params from pcmra_siren


In [10]:
# Loss functions used by scroll_through_output to score the reconstructions:
# mean squared error on the pcmra output, binary cross-entropy on the mask output.
pcmra_criterion = nn.MSELoss()
mask_criterion = nn.BCELoss()

In [11]:
def _to_volume(flat, shape):
    """Return one flattened sample as a CPU tensor viewed to `shape`, with its first axis moved last.

    Moving the first axis last lets successive samples be concatenated along dim 2.
    """
    return flat.detach().cpu().view(shape).permute(1, 2, 0)


def scroll_through_output(dataloader, shape=(24, 64, 64), transform=False):
    """Run the loaded models over `dataloader`, print the mean losses and open an image viewer.

    Reads the notebook globals ``models``, ``ARGS``, ``mask_criterion`` and ``pcmra_criterion``.

    Args:
        dataloader: iterable of batches accepted by ``get_siren_batch``.
        shape: 3-D shape each flattened sample is viewed to before display (list or tuple).
        transform: if true, each batch is first passed through ``transform_batch(batch, ARGS)``.

    Returns:
        The ``Show_images`` window with the panels "pcmras", "masks", "mask output" and
        "pcmra output". Each panel holds all samples concatenated along its last axis.

    Raises:
        ValueError: if `dataloader` yields no samples.
    """
    panels = {"pcmras": [], "masks": [], "mask output": [], "pcmra output": []}
    mask_loss_sum = pcmra_loss_sum = 0.0
    n_samples = 0

    with torch.no_grad():
        for batch in dataloader:
            if transform:
                batch = transform_batch(batch, ARGS)

            _, _, _, pcmra, coords, pcmra_array, mask_array, _ = get_siren_batch(batch)

            mask_out = get_complete_image(models, pcmra, coords, ARGS)
            pcmra_out = get_complete_image(models, pcmra, coords, ARGS, output="pcmra")

            # The criteria (nn.BCELoss / nn.MSELoss with the default 'mean' reduction in this
            # notebook) average within a batch. Weight each batch mean by its size so a smaller
            # final batch is not over-weighted. Every sample has the same number of elements
            # (each is viewed to `shape`), so the result is the mean over all samples.
            batch_size = len(mask_array)
            mask_loss_sum += mask_criterion(mask_out, mask_array).item() * batch_size
            pcmra_loss_sum += pcmra_criterion(pcmra_out, pcmra_array).item() * batch_size
            n_samples += batch_size

            for label, samples in (("pcmras", pcmra_array), ("masks", mask_array),
                                   ("mask output", mask_out), ("pcmra output", pcmra_out)):
                panels[label].extend(_to_volume(sample, shape) for sample in samples)

    if n_samples == 0:
        raise ValueError("dataloader yielded no samples")

    print("MASK: ", mask_loss_sum / n_samples)
    print("PCMRA:", pcmra_loss_sum / n_samples)

    # Concatenate each panel once at the end instead of re-concatenating for every sample.
    return Show_images("Results", *[(torch.cat(volumes, dim=2).numpy(), label)
                                    for label, volumes in panels.items()])

In [12]:
# Evaluate on un-augmented data by default. With APPLY_TRANSFORMS = True each batch is passed
# through transform_batch using the augmentation flags below. The flags are only written to
# ARGS in that case: with transform=False scroll_through_output never applies them
# (NOTE(review): get_complete_image also receives ARGS — confirm it ignores these flags).
APPLY_TRANSFORMS = False
EVAL_SPLIT = "train"  # one of "train", "val", "test"

if APPLY_TRANSFORMS:
    ARGS.rotate = True
    ARGS.translate = True
    ARGS.flip = True
    ARGS.crop = True
    ARGS.stretch = True

eval_dl = {"train": train_dl, "val": val_dl, "test": test_dl}[EVAL_SPLIT]
window = scroll_through_output(eval_dl, transform=APPLY_TRANSFORMS)

MASK:  0.005189981727328684
PCMRA: 0.00047619071119697765


In [None]:
# NOTE(review): disabled experiment kept for reference. It plots slice 8 of the first
# input pcmra volume in a val_dl batch next to the model's pcmra reconstruction, then prints
# their MSE. It unpacks raw val_dl batches into 7 names, whereas scroll_through_output
# goes through get_siren_batch (8 values). Verify the batch layout before re-enabling.
# criterion = nn.MSELoss()

# for idx, subj, proj, pcmra, coords, pcmra_array, mask_array in val_dl:
#     if idx == 1:
#         break

# print(pcmra.shape)
# out = get_complete_image(models, pcmra, coords, ARGS, output="pcmra")
# out = out.view(1, 1, 64, 64, 24).permute(0, 1, 4, 2, 3)

# plt.imshow(pcmra[0, 0, 8, :, :].cpu())
# plt.show()

# plt.imshow(out[0, 0, 8, :, :].cpu().detach())
# plt.show()

# print(criterion(pcmra, out).item())