In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os

from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import numpy as np
import skimage
import matplotlib.pyplot as plt

import time
import pickle

from datetime import datetime
from pathlib import Path

import math 

#### Import classes

In [2]:
# Execute the helper notebooks in this kernel so the dataset classes
# (PrepareData3D, SirenDataset), the CNN encoder and the PI-GAN/SIREN
# generator they define become available to the cells below.
%run "custom_datasets.ipynb"
%run "Model Classes/cnn_model.ipynb"
%run "Model Classes/pigan_model.ipynb"

# Render figures in separate Qt windows (presumably required by the
# interactive Show_images viewer used at the end — confirm).
%matplotlib qt

Imported classes.
Imported CNN model.
Imported PI-Gan model.


The support for Qt4  was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
  from matplotlib.backends.qt_compat import QtGui


In [3]:
def set_device():
    """Select the torch device used throughout the notebook.

    Returns:
        torch.device: ``cuda`` when a CUDA GPU is available, otherwise ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')

DEVICE = set_device()
# To force CPU execution, uncomment:
# DEVICE = torch.device('cpu')

print('----------------------------------')
print('Using device for training:', DEVICE)
print('----------------------------------')

----------------------------------
Using device for training: cuda
----------------------------------


#### Dataloader

In [4]:
# Which preprocessed resolution to load; also selects the CNN's flattened size later.
image_size = "small"

# NOTE(review): "Aorta Resvcue" looks like a typo of "Rescue", but it must match
# the on-disk folder name, so it is left unchanged.
data = PrepareData3D(
    ["Aorta Volunteers", "Aorta BaV", "Aorta Resvcue", "Aorta CoA"],
    image_size=image_size,
    norm_min_max=[0,1],
)

# Training split: one subject per batch, reshuffled every epoch.
train_ds = SirenDataset(data.train, DEVICE)
train_dataloader = DataLoader(train_ds, batch_size=1, num_workers=0, shuffle=True)
print(len(train_ds))
# Sanity check: field 1 of a sample is the subject identifier (its file name).
print(next(iter(train_dataloader))[1])

# Validation split, one subject per batch.
val_ds = SirenDataset(data.val, DEVICE)
val_dataloader = DataLoader(val_ds, batch_size=1, num_workers=0, shuffle=True)

print(len(val_ds))


54
('RESV_016.npy',)
18


#### Load models

In [5]:
# Length of the latent code the CNN encoder hands to the SIREN generator.
z_dim = 128

# Size of the CNN's flattened feature map; it depends on the loaded resolution.
flattened_size = 16384 if image_size == "full" else 4096

# Presumably (in, out) sizes per layer, ending with the (flattened_size -> z_dim)
# projection — confirm against cnn_model.ipynb.
# Placed on DEVICE (instead of a hard-coded .cuda()) so CPU-only runs work too.
cnn = CNN((1, 16),
          (16, 32),
          (32, 64),
          (64, 128),
          (flattened_size, z_dim)).to(DEVICE)

In [6]:
# %run "Model Classes/cnn_model.ipynb"


# pcmra = next(iter(train_dataloader))[3]
# print("pcmra:", pcmra.shape)

# out = cnn(pcmra)
# print("out:", out.shape)

In [7]:
# SIREN generator conditioned on the CNN's z_dim-sized code; placed on DEVICE
# (instead of a hard-coded .cuda()) so it matches the data and the encoder.
siren = SirenGenerator(dim=z_dim, dim_hidden=256).to(DEVICE)

#### Optimizers & Loss

In [8]:
# Weight decay shared by both Adam optimizers (0 disables the L2 penalty).
wd = 0

# One Adam optimizer per network; the learning rates are overridden below.
cnn_optim = torch.optim.Adam(params=cnn.parameters(), weight_decay=wd)
siren_optim = torch.optim.Adam(params=siren.parameters(), weight_decay=wd)

# Binary cross-entropy between the generator's mask prediction and the labels
# (assumes the generator ends in a sigmoid — confirm in pigan_model.ipynb).
criterion = nn.BCELoss()

In [9]:
# Lower both learning rates from Adam's default (1e-3) to 5e-5.
cnn_optim.param_groups[0]['lr'] = siren_optim.param_groups[0]['lr'] = 5e-5
print(siren_optim)

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 5e-05
    weight_decay: 0
)


##### Random coordinate subsampling

In [10]:
def choose_random_coords(coords, pcmra_array, mask_array, n=1000):
    """Randomly subsample ``n`` positions shared by the three input tensors.

    The same random indices (drawn without replacement) are applied along dim 1
    of every input, so coordinates, intensities and labels stay aligned.

    Args:
        coords: Tensor indexed as ``[:, position, :]``.
        pcmra_array: Tensor indexed the same way as ``coords``.
        mask_array: Tensor indexed the same way as ``coords``.
        n: Number of positions to keep. It is clamped to the number of
            available positions, so a too-large ``n`` returns every position in
            random order instead of raising.

    Returns:
        Tuple ``(coords, pcmra_array, mask_array)`` restricted to the sampled positions.
    """
    total = coords.shape[1]
    # torch's RNG avoids depending on a module-level `random` import. The index
    # tensor stays on the CPU, which can index both CPU and CUDA tensors.
    idx = torch.randperm(total)[:min(n, total)]
    return coords[:, idx, :], pcmra_array[:, idx, :], mask_array[:, idx, :]

#### Image generation and model saving functions 

In [11]:
def get_complete_image(pcmra, coords, val_n=10000, cnn_model=None, siren_model=None):
    """Predict the generator output for every coordinate of one subject.

    The coordinates are fed to the SIREN generator in chunks of ``val_n`` so the
    whole volume never goes through a single forward pass. Both networks run in
    eval mode with autograd disabled. Each network's previous train/eval mode is
    restored afterwards, even if a forward pass raises.

    Args:
        pcmra: Encoder input for a single subject, passed to the CNN unchanged.
        coords: Coordinate tensor; chunked along dim 1.
        val_n: Maximum number of coordinates per generator forward pass.
        cnn_model: Encoder to use; defaults to the notebook-level ``cnn``.
        siren_model: Generator to use; defaults to the notebook-level ``siren``.

    Returns:
        The generator outputs for all coordinates, concatenated along dim 1
        (an empty 1-D tensor on ``coords.device`` if there are no coordinates).
    """
    if cnn_model is None:
        cnn_model = cnn
    if siren_model is None:
        siren_model = siren

    cnn_was_training, siren_was_training = cnn_model.training, siren_model.training
    cnn_model.eval()
    siren_model.eval()
    try:
        # Inference only: skip building an autograd graph for the full volume.
        with torch.no_grad():
            cnn_out = cnn_model(pcmra)
            chunks = [
                siren_model(cnn_out, coords[:, start:start + val_n, :])
                for start in range(0, coords.shape[1], val_n)
            ]
    finally:
        # Restore each network's own previous mode (both of them, not just the CNN).
        cnn_model.train(cnn_was_training)
        siren_model.train(siren_was_training)

    if not chunks:
        return torch.empty(0, device=coords.device)
    # Single concatenation instead of growing a tensor every iteration.
    return torch.cat(chunks, dim=1)


def save_model(best_loss, losses, dataset="train"):
    """Report loss statistics and checkpoint both networks when the loss improves.

    Prints the mean and standard deviation of ``losses``, both rounded to 6
    decimals. If the rounded mean is below ``best_loss``, the state dicts of
    ``cnn``, ``cnn_optim``, ``siren`` and ``siren_optim`` are written to
    ``Models/{folder}/`` with ``dataset`` as the file-name suffix.

    Args:
        best_loss: Best mean loss seen so far for this split.
        losses: Per-subject losses from the current evaluation.
        dataset: Split label used in the log line and checkpoint file names.

    Returns:
        The updated best loss (the new rounded mean if it improved).
    """
    mean = round(np.mean(losses), 6)
    std = round(np.std(losses), 6)
    print(f"{dataset} \t mean loss: {mean} \t std: {std}")

    # Phrased as `not (a < b)` so a NaN mean (e.g. empty `losses`) never triggers a save.
    if not (mean < best_loss):
        return best_loss

    print(f"New best {dataset} loss, saving model.")
    checkpoints = (
        (cnn, f"Models/{folder}/cnn_{dataset}.pt"),
        (cnn_optim, f"Models/{folder}/cnn_optim_{dataset}.pt"),
        (siren, f"Models/{folder}/siren_{dataset}.pt"),
        (siren_optim, f"Models/{folder}/siren_optim_{dataset}.pt"),
    )
    for component, path in checkpoints:
        torch.save(component.state_dict(), path)

    return mean

#### Load model

In [12]:
# folder = "Models/PI-Gan 02-04-2021 16:20:46 mask_complete dataset_n 30000/"

# best_loss = "train"

# cnn.load_state_dict(torch.load(f"{folder}/cnn_{best_loss}.pt"))
# cnn_optim.load_state_dict(torch.load(f"{folder}/cnn_optim_{best_loss}.pt"))

# siren.load_state_dict(torch.load(f"{folder}/siren_{best_loss}.pt"))
# siren_optim.load_state_dict(torch.load(f"{folder}/siren_optim_{best_loss}.pt"))


#### Train model
For the PCMRA array with a linear output, a loss of about 0.0005 is good.


For the mask with a sigmoid output and BCE loss, a loss of about 0.02 is good. 

In [13]:
# torch.cuda.empty_cache()

In [None]:
# Training schedule; every `print_every` epochs the model is validated and checkpointed.
epochs = 1000
print_every = 5

# Gradient accumulation: gradients from `aggregate_gradient` subjects (batch_size=1)
# are summed before each optimizer step.
aggregate_gradient = 10
batches = 0

# Number of random coordinates sampled per subject per training step.
# n = 393216
n = 30000

output_type = "mask"
dataset = "complete"

now = datetime.now()
dt = now.strftime("%d-%m-%Y %H:%M:%S")

folder = f"PI-Gan {dt} {output_type}_{dataset} dataset_n {n}"

Path(f"Models/{folder}").mkdir(parents=True, exist_ok=True)
# Report the real (relative) directory that save_model writes to.
print(f"Creating path Models/{folder}")

best_train_loss, best_val_loss = 100000, 100000

# Start from clean gradients so re-running this cell never steps on stale ones.
siren_optim.zero_grad()
cnn_optim.zero_grad()

for ep in range(epochs):
    t = time.time()

    cnn.train()
    siren.train()

    losses = []

    for _, _, _, pcmra, coords, pcmra_array, mask_array in train_dataloader:
        # Train on a random subset of coordinates; the mask values are the labels.
        siren_in, _, siren_labels = choose_random_coords(coords, pcmra_array, mask_array, n=n)

        cnn_out = cnn(pcmra)
        siren_out = siren(cnn_out, siren_in)

        loss = criterion(siren_out, siren_labels)
        losses.append(loss.item())

        # Scale by the accumulation length (not the dataset size) so each optimizer
        # step sees the mean gradient over its `aggregate_gradient` subjects.
        (loss / aggregate_gradient).backward()

        batches += 1
        if batches % aggregate_gradient == 0:
            siren_optim.step()
            cnn_optim.step()

            siren_optim.zero_grad()
            cnn_optim.zero_grad()

    if ep % print_every == 0:
        print(f"Epoch {ep} took {round(time.time() - t)} seconds.")

        best_train_loss = save_model(best_train_loss, losses, dataset="train")

        # Validate on the full coordinate grid; gradients are not needed here.
        val_losses = []
        with torch.no_grad():
            for _, _, _, pcmra, coords, pcmra_array, mask_array in val_dataloader:
                siren_out = get_complete_image(pcmra, coords)
                val_losses.append(criterion(siren_out, mask_array).item())

        best_val_loss = save_model(best_val_loss, val_losses, dataset="val")

        print()
        

Creating path \Models\PI-Gan 02-04-2021 17:56:41 mask_complete dataset_n 30000
Epoch 0 took 7 seconds.
train 	 mean loss: 0.587906 	 std: 0.099548
New best train loss, saving model.
val 	 mean loss: 0.573797 	 std: 0.018322
New best val loss, saving model.

Epoch 5 took 5 seconds.
train 	 mean loss: 0.135945 	 std: 0.023211
New best train loss, saving model.
val 	 mean loss: 0.155971 	 std: 0.086799
New best val loss, saving model.

Epoch 10 took 5 seconds.
train 	 mean loss: 0.126712 	 std: 0.029006
New best train loss, saving model.
val 	 mean loss: 0.144014 	 std: 0.080867
New best val loss, saving model.



#### Show results

In [None]:
# idx, subj, proj, pcmra, coords, pcmra_array, mask_array = next(iter(val_dataloader))
# # pcmra, coords = pcmra.unsqueeze(0), coords.unsqueeze(0)
# # pcmra_array, mask_array =  pcmra_array.unsqueeze(0), mask_array.unsqueeze(0)

# siren_out = get_complete_image(pcmra, coords)
# loss = criterion(siren_out, mask_array)            

# print(f"{subj}, loss: {loss}")

# def arrays_to_numpy(*arrays): 
#     print(arrays)
    
    
# slic = 8

# # shape = (128, 128, 24)
# shape = (64, 64, 24)

# fig, axes = plt.subplots(1, 3, figsize=(12,12))
# axes[0].imshow(pcmra_array.cpu().view(shape).detach().numpy()[:, :, slic])
# axes[1].imshow(mask_array.cpu().view(shape).detach().numpy()[:, :, slic])
# # axes[2].imshow(siren_out.cpu().view(shape).detach().numpy()[:, :, slic])
# axes[2].imshow(siren_out.cpu().view(shape).detach().numpy().round()[:, :, slic])

# plt.show()

In [None]:
# for idx, subj, proj, pcmra, coords, pcmra_array, mask_array in val_dataloader: 
    
    
#     siren_out = get_complete_image(pcmra, coords)
#     loss = criterion(siren_out, mask_array)            

#     print(subj, loss.item()) 

#     slic = 12

#     fig, axes = plt.subplots(1, 3, figsize=(12,12))
#     axes[0].imshow(pcmra_array.cpu().view(128, 128, 24).detach().numpy()[:, :, slic])
#     axes[1].imshow(mask_array.cpu().view(128, 128, 24).detach().numpy()[:, :, slic])
#     axes[2].imshow(siren_out.cpu().view(128, 128, 24).detach().numpy().round()[:, :, slic])

#     plt.show()

In [None]:
def scroll_through_output(shape=(64, 64, 24)):
    """Build a scrollable viewer of input, ground truth and prediction for each validation subject.

    Each subject's PCMRA array, mask and generator output are reshaped to
    ``shape``, and the volumes of all subjects are stacked along the last axis.
    The viewer therefore pages through ``shape[2]`` slices per subject.

    Args:
        shape: Volume shape used to reshape each subject's arrays. It must match
            the resolution the data was loaded at (presumably ``(64, 64, 24)``
            for ``image_size="small"``). It was previously ignored because of a
            hard-coded override.

    Returns:
        The ``Show_images`` viewer built from the stacked volumes.
    """
    pcmras, masks, outs, titles = [], [], [], []

    # Visualisation only: no autograd graph is needed.
    with torch.no_grad():
        for _, subj, proj, pcmra, coords, pcmra_array, mask_array in val_dataloader:
            siren_out = get_complete_image(pcmra, coords)

            pcmras.append(pcmra_array.detach().cpu().view(shape))
            masks.append(mask_array.detach().cpu().view(shape))
            outs.append(siren_out.detach().cpu().view(shape))

            # One title per slice of this subject's volume.
            titles += [subj[0] + " " + proj[0]] * shape[2]

    def _stack(volumes):
        # Concatenate once along the slice axis instead of growing a tensor in the loop.
        return (torch.cat(volumes, dim=2) if volumes else torch.Tensor([])).numpy()

    return Show_images(titles, (_stack(pcmras), "pcmras"), (_stack(masks), "masks"), (_stack(outs), "outs"))
    


In [None]:
# Build the viewer over all validation subjects. The returned object is bound to
# `window`, presumably so the Qt viewer stays alive while it is open (confirm).
window = scroll_through_output()