In [1]:
import torch
import torch.nn.functional as F

from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, Compose, ToTensor, Normalize

import argparse
import os
import math 
import skimage
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import time
import pickle

from datetime import datetime
from pathlib import Path

# from data_classes.py_files.data_classes import *
from data_classes.py_files.custom_datasets import *

from model_classes.py_files.cnn_model import *
from model_classes.py_files.pigan_model import *

Imported Project and Show_images classes.
Imported data preparation and custom Dataset classes.
Imported CNN model.
Imported PI-Gan model.


#### Import classes

In [2]:
def set_device():
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return DEVICE 

DEVICE = set_device()

print('----------------------------------')
print('Using device for training:', DEVICE)
print('----------------------------------')

----------------------------------
Using device for training: cuda
----------------------------------


#####  Random coords subsample

In [3]:
def choose_random_coords(*arrays, n=1000): 
    
    mx = arrays[0].shape[1]
    rand_idx = random.sample(range(mx), n)

    arrays = [array[:, rand_idx, :] for array in arrays]
    
    return arrays

#### Image generation and model saving functions 

In [4]:
def get_complete_image(cnn, siren, pcmra, coords, val_n = 10000): 
    
    cnn.eval(); siren.eval() #evaluation mode    
    image = torch.Tensor([]).cuda() # initialize results tensor
    cnn_out = cnn(pcmra) # get representation
    
    n_slices = math.ceil(coords.shape[1] / val_n) # number of batches
    for i in range(n_slices):
        coords_in = coords[:, (i*val_n) : ((i+1)*val_n), :]
        siren_out = siren(cnn_out, coords_in)
        image = torch.cat((image, siren_out.detach()), 1)
    
    cnn.train(); cnn.train()
    
    return image 


def save_info(path, losses, cnn, siren, cnn_optim, siren_optim): 
    
    np.save(f"{path}/losses.npy", losses)
    
    eps = losses[:, 0]
    train_losses = losses[:, 1]
    val_losses = losses[:, 3]

    if train_losses[-1] == train_losses.min(): 
        print(f"New best train loss: {round(train_losses[-1], 5)}, saving model.")

        torch.save(cnn.state_dict(), f"{path}/cnn_train.pt")
        torch.save(cnn_optim.state_dict(), f"{path}/cnn_optim_train.pt")
        
        torch.save(siren.state_dict(), f"{path}/siren_train.pt")
        torch.save(siren_optim.state_dict(), f"{path}/siren_optim_train.pt")
    
    if train_losses[-1] == train_losses.min(): 
        print(f"New best val loss: {round(val_losses[-1], 5)}, saving model.")

        torch.save(cnn.state_dict(), f"{path}/cnn_val.pt")
        torch.save(cnn_optim.state_dict(), f"{path}/cnn_optim_val.pt")
        
        torch.save(siren.state_dict(), f"{path}/siren_val.pt")
        torch.save(siren_optim.state_dict(), f"{path}/siren_optim_val.pt")


    fig, ax = plt.subplots()
    
    fig.patch.set_facecolor('white')
    ax.plot(eps[1:], train_losses[1:], label='Train loss')
    ax.plot(eps[1:], val_losses[1:], label='Eval loss')

    plt.xlabel('Epochs')
    plt.ylabel('BCELoss')
    legend = ax.legend(loc='upper right')
    
    plt.savefig(f"{path}/loss_plot.png")

## Train model

In [5]:
def get_folder(ARGS): 
    now = datetime.now()
    dt = now.strftime("%d-%m-%Y %H:%M:%S")
    path = f"saved_runs/pi-gan {dt} {ARGS.name}"
    
    Path(f"{path}").mkdir(parents=True, exist_ok=True)   

    return path
    

def initialize_dataloaders(projects, ARGS):
    assert(ARGS.dataset in ["full", "small"])

    data = PrepareData3D(projects, image_size=ARGS.dataset, norm_min_max=ARGS.norm_min_max)

    train_ds = SirenDataset(data.train, DEVICE) 
    train_dl = DataLoader(train_ds, batch_size=1, num_workers=0, shuffle=ARGS.shuffle)
    print("Train subjects:", train_ds.__len__())

    val_ds = SirenDataset(data.val, DEVICE) 
    val_dl = DataLoader(val_ds, batch_size=1, num_workers=0, shuffle=False)
    print("Validation subjects:", val_ds.__len__())
    
    return train_dl, val_dl
    

In [6]:
def train_epoch(cnn, siren, dataloader, cnn_optim, siren_optim, criterion, batch_count, ARGS):
    losses = []
    
    for _, _, _, pcmra, coords, _, mask_array in dataloader:
        siren_in, siren_labels = choose_random_coords(coords, mask_array)

        cnn_out = cnn(pcmra)
        siren_out = siren(cnn_out, siren_in)

        loss = criterion(siren_out, siren_labels) 
        losses.append(loss.item())
        loss = loss / dataloader.__len__()
        loss.backward()

        batch_count += 1
        if batch_count % ARGS.acc_steps == 0: 
            siren_optim.step()
            cnn_optim.step()   

            siren_optim.zero_grad()
            cnn_optim.zero_grad()
    
    mean, std = round(np.mean(losses), 6), round(np.std(losses), 6)
    
    return mean, std, batch_count


def val_epoch(cnn, siren, dataloader, cnn_optim, siren_optim, criterion):
    losses = []

    for idx, subj, proj, pcmra, coords, pcmra_array, mask_array in dataloader:    
        siren_out = get_complete_image(cnn, siren, pcmra, coords)
        loss = criterion(siren_out, mask_array)            

        losses.append(loss.item())
    
    mean, std = round(np.mean(losses), 6), round(np.std(losses), 6)
    
    return mean, std

In [7]:
def train():  
    
    ##### path to wich the model should be saved #####
    path = get_folder(ARGS)
    
    ##### save ARGS #####
    with open(f'{path}/ARGS.txt', 'w') as f:
        print(vars(ARGS), file=f)
        
    ##### data preparation #####
    train_dl, val_dl = initialize_dataloaders(["Aorta Volunteers", "Aorta BaV",
                                               "Aorta Resvcue", "Aorta CoA"], ARGS)

    
    ##### initialize models #####
    cnn = CNN((1, 16), (16, 32), (32, 64), (64, 128),
              (ARGS.flattened_size, ARGS.z_dim)).cuda()
    
    siren = SirenGenerator(dim=ARGS.z_dim, dim_hidden=256).cuda()
    
    ##### initialize optimizers #####
    cnn_optim = torch.optim.Adam(lr=ARGS.cnn_lr, params=cnn.parameters(), 
                                 weight_decay=ARGS.cnn_wd)
    
    siren_optim = torch.optim.Adam(lr=ARGS.siren_lr, params=siren.parameters(), 
                                   weight_decay=ARGS.siren_wd)
    
    ##### loss function #####
    criterion = nn.BCELoss()
    
    
    ##### epoch, train loss mean, train loss std, #####
    ##### val loss mean, val loss std #####
    losses = np.empty((0, 5))

    batch_count = 0     
    
    for ep in range(ARGS.epochs):
    
        t = time.time() 

        cnn.train(); siren.train()

        t_loss_mean, t_loss_std, batch_count = train_epoch(cnn, siren, train_dl, 
                                                           cnn_optim, siren_optim, 
                                                           criterion, batch_count, ARGS)
        
        if ep % ARGS.eval_every == 0: 

            print(f"Epoch {ep} took {round(time.time() - t)} seconds.")

            v_loss_mean, v_loss_std = val_epoch(cnn, siren, val_dl, 
                                                cnn_optim, siren_optim, criterion)
            
            losses = np.append(losses, [[ep ,t_loss_mean, t_loss_std, 
                                         v_loss_mean, v_loss_std]], axis=0)
            
            save_info(path, losses, cnn, siren, cnn_optim, siren_optim)

In [None]:
class init_ARGS(object): 
    def __init__(self): 
        self.name = ""
        self.dataset = "small"
        self.epochs = 500
        self.acc_steps = 10
        self.shuffle = True
        self.norm_min_max = [0, 1]
        self.shuffle = True
        self.flattened_size = 4096
        self.z_dim = 128
        self.shuffle = True
        self.cnn_lr = 1e-4
        self.siren_lr = 1e-4
        self.cnn_wd = 0
        self.siren_wd = 0
        self.eval_every = 2

        
        print("WARNING: ARGS class initialized.")
        
ARGS = init_ARGS()
        
train()  

Train subjects: 54
Validation subjects: 18
Epoch 0 took 4 seconds.
New best train loss: 0.67144, saving model.
New best val loss: 0.63832, saving model.
Epoch 2 took 4 seconds.
New best train loss: 0.27788, saving model.
New best val loss: 0.14501, saving model.
Epoch 4 took 4 seconds.
New best train loss: 0.13694, saving model.
New best val loss: 0.12213, saving model.
Epoch 6 took 4 seconds.
New best train loss: 0.13041, saving model.
New best val loss: 0.12257, saving model.
Epoch 8 took 4 seconds.
New best train loss: 0.12776, saving model.
New best val loss: 0.12128, saving model.
Epoch 10 took 4 seconds.
New best train loss: 0.12454, saving model.
New best val loss: 0.11824, saving model.


#### Dataloader

In [None]:
image_size = "small"

data = PrepareData3D(["Aorta Volunteers", "Aorta BaV", "Aorta Resvcue", "Aorta CoA"], 
                     image_size=image_size, norm_min_max=[0,1])

train_ds = SirenDataset(data.train, DEVICE) 
train_dataloader = DataLoader(train_ds, batch_size=1, num_workers=0, shuffle=True)
print(train_ds.__len__())
print(next(iter(train_dataloader))[1])

val_ds = SirenDataset(data.val, DEVICE) 
val_dataloader = DataLoader(val_ds, batch_size=1, num_workers=0, shuffle=True)

print(val_ds.__len__())


#### Load models

In [None]:
z_dim = 128

flattened_size = [16384 if image_size=="full" else 4096][0]

cnn = CNN((1, 16), 
          (16, 32), 
          (32, 64), 
          (64, 128), 
          (flattened_size, z_dim)).cuda()

In [None]:
# %run "Model Classes/cnn_model.ipynb"


# pcmra = next(iter(train_dataloader))[3]
# print("pcmra:", pcmra.shape)

# out = cnn(pcmra)
# print("out:", out.shape)

In [None]:
siren = SirenGenerator(dim=z_dim, dim_hidden=256).cuda()

#### Optimizers & Loss

In [None]:
wd = 0

siren_optim = torch.optim.Adam(params=siren.parameters(), weight_decay=wd)
cnn_optim = torch.optim.Adam(params=cnn.parameters(), weight_decay=wd)

# def l2_loss(out, ground_truth): 
#     return ((out - ground_truth)**2).mean()

# criterion = l2_loss

criterion = nn.BCELoss()

In [None]:
cnn_optim.param_groups[0]['lr'] = 5e-5
siren_optim.param_groups[0]['lr'] = 5e-5
print(siren_optim)

#### Load model

In [None]:
# folder = "Models/PI-Gan 02-04-2021 16:20:46 mask_complete dataset_n 30000/"

# best_loss = "train"

# cnn.load_state_dict(torch.load(f"{folder}/cnn_{best_loss}.pt"))
# cnn_optim.load_state_dict(torch.load(f"{folder}/cnn_optim_{best_loss}.pt"))

# siren.load_state_dict(torch.load(f"{folder}/siren_{best_loss}.pt"))
# siren_optim.load_state_dict(torch.load(f"{folder}/siren_optim_{best_loss}.pt"))


#### Train model
for pcmra array with linear output, 0.000500 is good.


for mask with sigmoid output and BCE, 0.02 is good. 

In [None]:
# torch.cuda.empty_cache()

In [None]:
epochs = 1000
print_every = 5

aggregate_gradient = 10
batches = 0

# n = 393216
n = 30000

output_type = "mask"
dataset = "complete"


folder = f"PI-Gan {dt} {output_type}_{dataset} dataset_n {n}"

Path(f"Models/{folder}").mkdir(parents=True, exist_ok=True)   
print(f"Creating path \\Models\\{folder}")
    


for ep in range(epochs):
    
    t = time.time() 
    
    cnn.train()
    siren.train()

    losses = []
        
    for idx, subj, proj, pcmra, coords, pcmra_array, mask_array in train_dataloader:
        siren_in, _, siren_labels = choose_random_coords(coords, pcmra_array, mask_array, n=n)

        cnn_out = cnn(pcmra)
        siren_out = siren(cnn_out, siren_in)
        
        loss = criterion(siren_out, siren_labels) 
        losses.append(loss.item())
        
        loss = loss / train_ds.__len__()
        loss.backward()
        
        batches += 1

        if batches % aggregate_gradient == 0: 
            siren_optim.step()
            cnn_optim.step()   
            
            siren_optim.zero_grad()
            cnn_optim.zero_grad()
    

    if ep % print_every == 0: 
        
        print(f"Epoch {ep} took {round(time.time() - t)} seconds.")
        
        best_train_loss = save_model(best_train_loss, losses, dataset="train")
        
        val_losses = []
        
        for idx, subj, proj, pcmra, coords, pcmra_array, mask_array in val_dataloader:    
            siren_out = get_complete_image(pcmra, coords)
            loss = criterion(siren_out, mask_array)            
        
            val_losses.append(loss.item())
            
        best_val_loss = save_model(best_val_loss, val_losses, dataset="val")
                
        print()
        

#### Show results

In [None]:
# idx, subj, proj, pcmra, coords, pcmra_array, mask_array = next(iter(val_dataloader))
# # pcmra, coords = pcmra.unsqueeze(0), coords.unsqueeze(0)
# # pcmra_array, mask_array =  pcmra_array.unsqueeze(0), mask_array.unsqueeze(0)

# siren_out = get_complete_image(pcmra, coords)
# loss = criterion(siren_out, mask_array)            

# print(f"{subj}, loss: {loss}")

# def arrays_to_numpy(*arrays): 
#     print(arrays)
    
    
# slic = 8

# # shape = (128, 128, 24)
# shape = (64, 64, 24)

# fig, axes = plt.subplots(1, 3, figsize=(12,12))
# axes[0].imshow(pcmra_array.cpu().view(shape).detach().numpy()[:, :, slic])
# axes[1].imshow(mask_array.cpu().view(shape).detach().numpy()[:, :, slic])
# # axes[2].imshow(siren_out.cpu().view(shape).detach().numpy()[:, :, slic])
# axes[2].imshow(siren_out.cpu().view(shape).detach().numpy().round()[:, :, slic])

# plt.show()

In [None]:
# for idx, subj, proj, pcmra, coords, pcmra_array, mask_array in val_dataloader: 
    
    
#     siren_out = get_complete_image(pcmra, coords)
#     loss = criterion(siren_out, mask_array)            

#     print(subj, loss.item()) 

#     slic = 12

#     fig, axes = plt.subplots(1, 3, figsize=(12,12))
#     axes[0].imshow(pcmra_array.cpu().view(128, 128, 24).detach().numpy()[:, :, slic])
#     axes[1].imshow(mask_array.cpu().view(128, 128, 24).detach().numpy()[:, :, slic])
#     axes[2].imshow(siren_out.cpu().view(128, 128, 24).detach().numpy().round()[:, :, slic])

#     plt.show()

In [None]:
def scroll_through_output(shape=(64, 64, 24)):
    pcmras = masks = outs = torch.Tensor([])
    titles = []

    for idx, subj, proj, pcmra, coords, pcmra_array, mask_array in val_dataloader: 

        siren_out = get_complete_image(pcmra, coords)
        loss = criterion(siren_out, mask_array) 

        pcmras = torch.cat((pcmras, pcmra_array.cpu().view(shape).detach()), 2)
        masks = torch.cat((masks, mask_array.cpu().view(shape).detach()), 2)
        outs = torch.cat((outs, siren_out.cpu().view(shape).detach()), 2)

        titles += [subj[0] + " " + proj[0] for i in range(shape[2])]

    return Show_images(titles, (pcmras.numpy(), "pcmras"), (masks.numpy(), "masks"), (outs.numpy(), "outs"))
    


In [None]:
window = scroll_through_output()

## Arguments

In [None]:
if __name__ == "__main__":
    PARSER = argparse.ArgumentParser()

    # Arguments for training
    PARSER.add_argument('--name', type=str, default="", 
                        help='Name of the folder where the output should be saved.')
    
    PARSER.add_argument('--dataset', type=str, default="small", 
                        help='The dataset which we train on.')
    
    PARSER.add_argument('--epochs', type=int, default=500, 
                        help='Number of epochs.')
    
    PARSER.add_argument('--acc_steps', type=int, default=10, 
                        help='Number of subjects that the gradient is \
                        accumulated over before taking an optim step.')
    
    PARSER.add_argument('--shuffle', type=bool, default=True, 
                        help='Shuffle the train dataloader?')
    
    PARSER.add_argument('--norm_min_max', type=list, default=[0, 1], 
                        help='List with min and max for normalizing input.')
    
    PARSER.add_argument('--flattened_size', type=int, default=16384, 
                        help='Size of cnn conv output.')
    
    PARSER.add_argument('--z_dim', type=int, default=128, 
                        help='Size of the latent pcmra representation.')
    
    PARSER.add_argument('--cnn_lr', type=float, default=0, 
                        help='Learning rate of cnn optim.')

    PARSER.add_argument('--siren_lr', type=float, default=0, 
                        help='Learning rate of siren optim.')

    PARSER.add_argument('--cnn_wd', type=float, default=0, 
                        help='Weight decay of cnn optim.')

    PARSER.add_argument('--siren_wd', type=float, default=0, 
                        help='Weight decay of siren optim.')
    
    PARSER.add_argument('--eval_every', type=int, default=10, 
                        help='Set the # epochs after which evaluation should be done.')
    
    ARGS = PARSER.parse_args()
    
    train()

In [None]:
class ARGS(object): 
    def __init__(self): 
        self.dataset = "small"
        self.epochs = 500
        self.acc_steps = 10
        self.shuffle = True
        self.norm_min_max = [0, 1]
        self.shuffle = True
        self.flattened_size = 16384
        self.z_dim = 128
        self.shuffle = True
        self.cnn_lr = 0
        self.siren_lr = 0
        self.cnn_wd = 0
        self.siren_wd = 0
        self.eval_every = 10
        
train()

    