In [None]:
%run /home/ptenkaate/scratch/Master-Thesis/convert_ipynb_to_py_files.ipynb

#### Colab run

In [None]:
# from google.colab import drive
# drive.mount('/content/drive/')

In [None]:
# cd drive/MyDrive/master_thesis/pi-gan_sequential

In [None]:
# !pip install kornia
# !pip install pydicom
# !pip install torchinfo
# !pip install einops

In [2]:
import warnings

import torch
import torch.nn.functional as F

from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, Compose, ToTensor, Normalize

import argparse
import os
import math 
import skimage
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import time
import pickle

from datetime import datetime
from pathlib import Path

from py_files.new_dataset import *

from py_files.cnn_model import *
from py_files.pigan_model import *

from py_files.seq_pi_gan_functions import *

Imported CNN and Mapping functions.
Imported PI-Gan model.
Loaded all helper functions.


#### Train the model

In [3]:
def train():
    """Train the models in two sequential phases and checkpoint at evaluation steps.

    Phase 1 runs ``ARGS.pcmra_epochs`` epochs with ``output="pcmra"`` and an MSE
    criterion; phase 2 runs ``ARGS.mask_epochs`` epochs with ``output="mask"``
    and a BCE criterion. Every ``ARGS.eval_every`` epochs the train and
    validation loaders are evaluated with ``val_model`` and the accumulated
    loss table is written through ``save_loss``.

    Configuration is read from the module-level ``ARGS`` object. The helpers
    ``get_folder``, ``initialize_dataloaders``, ``load_models_and_optims``,
    ``load_pretrained_models``, ``train_model``, ``val_model`` and ``save_loss``
    are not defined in this notebook (presumably they come from the
    ``py_files`` star-imports -- TODO confirm).

    Returns:
        None. Results are persisted as side effects under the run folder.
    """

    # NOTE: this silences *all* warnings process-wide for the rest of the session.
    warnings.filterwarnings("ignore")

    ##### path to which the model should be saved #####
    path = get_folder(ARGS)

    print(path)
    ##### save ARGS #####
    with open(f"{path}/ARGS.txt", "w") as f:
        print(vars(ARGS), file=f)

    ##### data preparation #####
    train_dl, val_dl, test_dl = initialize_dataloaders(ARGS)
    print("train batch:", next(iter(train_dl))[1][:5])
    print("eval batch:", next(iter(val_dl))[1][:5])
    print("test batch:", next(iter(test_dl))[1][:5])

    ##### initialize models and optimizers #####
    models, optims, schedulers = load_models_and_optims(ARGS)

    ##### load pretrained model #####
    if ARGS.pretrained:
        print(f"Loading pretrained model from '{ARGS.pretrained}'.")
        load_pretrained_models(ARGS.pretrained, ARGS.pretrained_best_dataset, ARGS.pretrained_best_loss,
                               models, optims, pretrained_models=ARGS.pretrained_models)

        if ARGS.pretrained_lr_reset:
            # The value may arrive as a string from the command line; optimizers
            # need a real number, so cast once up front.
            reset_lr = float(ARGS.pretrained_lr_reset)

            # Apply the reset to every optimizer that exists, instead of looking
            # names up in a hard-coded table (which raised KeyError for others).
            for name, optim in optims.items():
                for param_group in optim.param_groups:
                    param_group["lr"] = reset_lr
                print(f"{name} lr: {optim.param_groups[0]['lr']}")

    ##### loss functions: BCE for the mask phase, MSE for the pcmra phase #####
    mask_criterion, pcmra_criterion = nn.BCELoss(), nn.MSELoss()

    ##### columns: epoch, train loss mean, train loss std, val loss mean, val loss std #####
    mask_losses, pcmra_losses, dice_losses = np.empty((0, 5)), np.empty((0, 5)), np.empty((0, 5))

    ##### phase 1: pcmra #####
    for ep in range(ARGS.pcmra_epochs):

        epoch_start = time.time()

        # Put every model back in training mode at the start of each epoch.
        for model in models.values():
            model.train()

        train_model(train_dl, models, optims, schedulers, pcmra_criterion, ARGS, output="pcmra")

        if ep % ARGS.eval_every == 0:

            # Reported time covers the training pass only, not the evaluation below.
            print(f"Epoch {ep} took {round(time.time() - epoch_start, 2)} seconds.")

            t_pcmra_mean, t_pcmra_std, _, _ = \
                val_model(train_dl, models, pcmra_criterion, ARGS, output="pcmra")

            v_pcmra_mean, v_pcmra_std, _, _ = \
                val_model(val_dl, models, pcmra_criterion, ARGS, output="pcmra")

            pcmra_losses = np.append(pcmra_losses, [[ep, t_pcmra_mean, t_pcmra_std,
                                                     v_pcmra_mean, v_pcmra_std]], axis=0)

            save_loss(path, pcmra_losses, models, optims, name="pcmra_loss",
                      save_models=True)

    ##### phase 2: mask #####
    for ep in range(ARGS.mask_epochs):

        epoch_start = time.time()

        for model in models.values():
            model.train()

        train_model(train_dl, models, optims, schedulers, mask_criterion, ARGS, output="mask")

        if ep % ARGS.eval_every == 0:

            print(f"Epoch {ep} took {round(time.time() - epoch_start, 2)} seconds.")

            t_mask_mean, t_mask_std, t_dice_mean, t_dice_std = \
                val_model(train_dl, models, mask_criterion, ARGS, output="mask")

            v_mask_mean, v_mask_std, v_dice_mean, v_dice_std = \
                val_model(val_dl, models, mask_criterion, ARGS, output="mask")

            mask_losses = np.append(mask_losses, [[ep, t_mask_mean, t_mask_std,
                                                   v_mask_mean, v_mask_std]], axis=0)

            dice_losses = np.append(dice_losses, [[ep, t_dice_mean, t_dice_std,
                                                   v_dice_mean, v_dice_std]], axis=0)

            # Models are passed with save_models=True for the mask loss only;
            # the dice table is written with save_models=False.
            save_loss(path, mask_losses, models, optims, name="mask_loss",
                      save_models=True)

            save_loss(path, dice_losses, models, optims, name="dice_loss",
                      save_models=False)

## Run as .ipynb

In [None]:
# for cnn_setup, mapping_setup in [(-1, -1)]:

#     ARGS = init_ARGS()
    
#     ARGS.cnn_setup = cnn_setup
#     ARGS.mapping_setup = mapping_setup
    
#     ARGS.batch_size = 1
    
#     ARGS.pcmra_epochs = 0
    
#     ARGS.patience = 100 
    
#     ARGS.first_omega_0 = 30
    
#     print(vars(ARGS))

#     train()  

#     torch.cuda.empty_cache()    


In [5]:
# Create the configuration object that the rest of the notebook reads as ARGS.
ARGS = init_ARGS()

# Switch off the flip, crop and rotate options for this interactive session.
ARGS.flip = False 
ARGS.crop = False 
ARGS.rotate = False 

# Build the train / validation / test dataloaders from the adjusted configuration.
train_dl, val_dl, test_dl = initialize_dataloaders(ARGS)

----------------------------------
Using device for training: cuda
----------------------------------
Train subjects: 86
Val subjects: 29
Test subjects: 29


In [6]:
import numbers
from collections import defaultdict

# Gaussian blur applied to the masks before their gradient is taken.
# The cubic kernel has an edge length of ceil(3 * sigma) voxels.
# NOTE(review): GaussianSmoothing pads by kernel_size // 2, so an even edge
# length (e.g. sigma = 1.2 -> 4) makes the output one voxel larger per axis.
sigma = 1.0
size = math.ceil(3 * sigma)
blur_layer = GaussianSmoothing(1, [size] * 3, sigma, dim=3).cuda()

In [7]:
# Pull a single batch from the training loader and run it through
# transform_batch in one step, so `batch` always holds the transformed version.
batch = transform_batch(next(iter(train_dl)), ARGS)

In [95]:
# Start a wall-clock timer; the elapsed time printed at the end of this cell
# also includes the time needed to define the functions below.
t = time.time()


def get_surface_and_norm(batch):
    """Derive a surface indicator and surface normals from the masks in a batch.

    The second-to-last element of ``batch`` (the masks) is blurred with the
    module-level ``blur_layer`` and differentiated with ``gradient3d``. Voxels
    whose gradient magnitude is at least half of the maximum magnitude found in
    the whole tensor are flagged as surface.

    Returns:
        (surface, norm): ``surface`` is a float32 0/1 indicator; ``norm`` is the
        normalized gradient with every non-surface voxel zeroed out.
    """
    smoothed_masks = blur_layer(batch[-2])

    direction, magnitude = gradient3d(smoothed_masks, normalize=True, s=2)

    # One global threshold for the entire batch, not one per subject.
    threshold = 0.5 * torch.max(magnitude)
    surface = (magnitude >= threshold).type(torch.float32)

    return surface, direction * surface

def reshape_arrays(*arrays):
    """Flatten all trailing dims of each tensor and move channels last.

    Each input of shape (B, C, ...) becomes a (B, N, C) view, where N is the
    product of the trailing dimensions. Returns the results as a list, in the
    same order as the inputs.
    """
    reshaped = []
    for tensor in arrays:
        batch_size, channels = tensor.shape[:2]
        flat = tensor.view(batch_size, channels, -1)
        reshaped.append(flat.transpose(1, 2))
    return reshaped
    

def get_siren_batch(batch, n=5000):
    """Sample ``n`` coordinates per subject, with the matching pcmra and mask values.

    For every subject, points are drawn with replacement (``np.random.choice``)
    from the positions whose ``loss_cover`` entry is non-zero.

    Args:
        batch: 6-tuple ``(idx, subj, proj, pcmras, masks, loss_covers)``. The
            last three are assumed to be (B, C, ...) tensors sharing their
            spatial shape -- TODO confirm against the dataloader.
        n: number of points sampled per subject.

    Returns:
        Tuple ``(coords, pcmra, mask)`` of tensors stacked over subjects, each
        with leading shape (B, n). Previously the samples were built and then
        discarded, so the function returned ``None``.

    Raises:
        ValueError: from ``np.random.choice`` when a subject has no non-zero
            ``loss_cover`` position and ``n > 0``.
    """
    _, _, _, pcmras, masks, loss_covers = batch

    # Coordinate grid for the volume, shared by all subjects.
    # assumes get_coords returns one row per voxel, in the same flattening
    # order as reshape_arrays -- TODO confirm.
    coords = get_coords(*pcmras.shape[2:]).to(pcmras.device)

    # Flatten the volumes to (B, N, C) so one row index addresses one voxel.
    pcmra_array, mask_array, loss_cover_array = reshape_arrays(pcmras, masks, loss_covers)

    samples = []
    for pcmra, mask, loss_cover in zip(pcmra_array, mask_array, loss_cover_array):

        # Rows with a non-zero loss_cover are the only valid sampling targets.
        valid_idx = (loss_cover != 0).nonzero()[:, 0].cpu().numpy()
        sample_idx = np.random.choice(valid_idx, n)

        samples.append((coords[sample_idx, :], pcmra[sample_idx, :], mask[sample_idx, :]))

    # Regroup per field and stack over subjects: (B, n, C) for each output.
    return tuple(torch.stack(field, 0) for field in zip(*samples))

# Exercise the sampler once on `batch` (defined in an earlier cell); the return
# value is not stored.
get_siren_batch(batch, n=5000)

# Seconds elapsed since `t` was set at the top of this cell.
print(time.time() - t)

0.03341102600097656


In [118]:
# Start a wall-clock timer for this cell.
t = time.time()

    
def get_siren_batch_sdf(batch, n=5000, sdf_split=0.5):
    """Sample ``n`` points per subject, a fixed share of which lie on the mask surface.

    ``int(n * sdf_split)`` points are drawn from the surface voxels returned by
    ``get_surface_and_norm``; the remaining points are drawn from the positions
    whose ``loss_cover`` entry is non-zero. Both draws are with replacement
    (``np.random.choice``). Surface samples come first in the output.

    Args:
        batch: 6-tuple ``(idx, subj, proj, pcmras, masks, loss_covers)``. The
            last three are assumed to be (B, C, ...) tensors sharing their
            spatial shape -- TODO confirm against the dataloader.
        n: total number of points sampled per subject.
        sdf_split: fraction of the ``n`` points taken from the surface.

    Returns:
        Tuple ``(coords, pcmra, mask, surface, norm)`` of tensors stacked over
        subjects, each with leading shape (B, n). Previously the concatenated
        tensors were built and then discarded, so the function returned ``None``.

    Raises:
        ValueError: from ``np.random.choice`` when a subject has no surface
            voxel (or no non-zero ``loss_cover`` position) to draw from.
    """
    _, _, _, pcmras, masks, loss_covers = batch

    surface_n = int(n * sdf_split)
    random_n = n - surface_n

    # Surface indicator and normals derived from the (blurred) masks.
    surfaces, norms = get_surface_and_norm(batch)

    # Coordinate grid for the volume, shared by all subjects.
    # assumes get_coords returns one row per voxel, in the same flattening
    # order as reshape_arrays -- TODO confirm.
    coords = get_coords(*pcmras.shape[2:]).to(pcmras.device)

    # Flatten the volumes to (B, N, C) so one row index addresses one voxel.
    surface_array, norm_array = reshape_arrays(surfaces, norms)
    pcmra_array, mask_array, loss_cover_array = reshape_arrays(pcmras, masks, loss_covers)

    samples = []
    for pcmra, mask, loss_cover, surface, norm in \
            zip(pcmra_array, mask_array, loss_cover_array, surface_array, norm_array):

        # select n * sdf_split points that lie on the surface
        surface_idx = (surface != 0).nonzero()[:, 0].flatten().cpu().numpy()
        surface_idx = np.random.choice(surface_idx, surface_n)

        # fill up with random points that have a non-zero loss_cover
        random_idx = (loss_cover != 0).nonzero()[:, 0].cpu().numpy()
        random_idx = np.random.choice(random_idx, random_n)

        sample_idx = np.concatenate((surface_idx, random_idx))

        samples.append((coords[sample_idx, :], pcmra[sample_idx, :], mask[sample_idx, :],
                        surface[sample_idx, :], norm[sample_idx, :]))

    # Regroup per field and stack over subjects: (B, n, C) for each output.
    return tuple(torch.stack(field, 0) for field in zip(*samples))
    


# NOTE: nothing in this cell calls get_siren_batch_sdf, so this measures only
# the time taken to define the function, not to run it.
print(time.time() - t) 

0.0012083053588867188


tensor([0., 0., 0.], device='cuda:0')


In [117]:
# print(masks.shape)

# t = time.time()
# with torch.no_grad():
#     masks_blurred = blur_layer(masks)     

# grad, grad_magn = gradient3d(masks_blurred, normalize=True, s=2)

# surface_indicator = (grad_magn >= 0.5*torch.max(grad_magn)).type(torch.float32)
# si = surface_indicator

# surface_norm = grad * surface_indicator

# print(time.time() - t)

# surface_array = surface_indicator.view(24, -1, 1)
# sa = surface_array
# print(sa.shape)

# print(time.time() - t)

# idxs = []
# for s in sa:
#     s = s.squeeze()
#     surface_idx = (s != 0).nonzero().flatten().cpu().numpy()
#     idxs.append(np.random.choice(surface_idx, 2500))
    
# #     print(surface_idx.shape)

# print(time.time() - t)

# print(np.array(idxs))


In [4]:
class GaussianSmoothing(nn.Module):
    """
    Apply gaussian smoothing on a
    1d, 2d or 3d tensor. Filtering is performed separately for each channel
    in the input using a depthwise convolution.
    Arguments:
        channels (int): Number of channels of the input tensors. Output will
            have this number of channels as well.
        kernel_size (int, sequence): Size of the gaussian kernel. A single
            number is used for every dimension.
        sigma (float, sequence): Standard deviation of the gaussian kernel.
            A single number is used for every dimension.
        dim (int, optional): The number of dimensions of the data.
            Default value is 2 (spatial).
    Raises:
        RuntimeError: if ``dim`` is not 1, 2 or 3.
    """
    def __init__(self, channels, kernel_size, sigma, dim=2):
        super(GaussianSmoothing, self).__init__()

        convolutions = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d}
        if dim not in convolutions:
            raise RuntimeError(
                'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
            )
        self.conv = convolutions[dim]

        if isinstance(kernel_size, numbers.Number):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, numbers.Number):
            sigma = [sigma] * dim

        # Radius along the first axis, kept as an attribute for existing users.
        # It is computed after the scalar has been expanded: the old code did
        # `list // 2` for scalar kernel sizes and raised a TypeError.
        self.kernel_radius = kernel_size[0] // 2
        # Per-axis padding, so anisotropic kernels are padded correctly too.
        self.padding = tuple(size // 2 for size in kernel_size)

        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            # exp(-(x - mean)^2 / (2 * std^2)). The previous exponent divided by
            # (2 * std) before squaring, which widened the kernel to an
            # effective standard deviation of sqrt(2) * std.
            kernel = kernel * (1 / (std * math.sqrt(2 * math.pi)) *
                               torch.exp(-0.5 * ((mgrid - mean) / std) ** 2))

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        # Registered as a buffer: it follows .cuda()/.to() but is not trained.
        self.register_buffer('weight', kernel)
        self.groups = channels

    def forward(self, input):
        """
        Apply gaussian filter to input.
        Arguments:
            input (torch.Tensor): Input to apply gaussian filter on.
        Returns:
            filtered (torch.Tensor): Filtered output. The spatial size is
                preserved for odd kernel sizes.
        """
        return self.conv(input, weight=self.weight, groups=self.groups, padding=self.padding)
    
    
def gradient3d(data, normalize=False, s=2):
    """Central-difference spatial gradient of a 5-D volume tensor.

    The three trailing dimensions are zero-padded by ``s // 2`` on each side and
    differenced over a span of ``s`` voxels, so the outputs keep the spatial
    size of ``data``. Because of the zero padding, values near the border are
    differenced against zeros.

    Args:
        data: tensor of shape (B, C, X, Y, Z).
        normalize: if True, divide the gradient by its magnitude (plus a small
            epsilon). This broadcast only works for C == 1.
        s: span of the finite difference in voxels; must be a positive even
            number. Previously only ``s=2`` produced matching shapes because
            the padding was hard-coded to 1.

    Returns:
        (grad, grad_magn): ``grad`` concatenates the x, y and z differences
        along dim 1; ``grad_magn`` is their Euclidean norm.

    Raises:
        ValueError: if ``s`` is not a positive even number.
    """
    if s < 2 or s % 2 != 0:
        raise ValueError('s must be a positive even number. Received {}.'.format(s))

    half = s // 2
    data_padded = torch.nn.functional.pad(data, (half, half, half, half, half, half, 0, 0, 0, 0))

    # Along the differenced axis take the +half / -half neighbours; along the
    # other two axes strip the padding again.
    inner = slice(half, -half)
    grad_x = (data_padded[:, :, s:, inner, inner] - data_padded[:, :, :-s, inner, inner]) / s
    grad_y = (data_padded[:, :, inner, s:, inner] - data_padded[:, :, inner, :-s, inner]) / s
    grad_z = (data_padded[:, :, inner, inner, s:] - data_padded[:, :, inner, inner, :-s]) / s

    grad = torch.cat([grad_x, grad_y, grad_z], dim=1)
    grad_magn = torch.sqrt(grad_x**2 + grad_y**2 + grad_z**2)

    if normalize:
        # eps keeps flat regions (zero magnitude) from dividing by zero.
        eps = 1e-8
        grad = grad / (grad_magn + eps)

    return grad, grad_magn

## Run as .py

In [None]:
def str2bool(v):
    """Convert a command-line token to a bool (for use as an argparse ``type``).

    Booleans pass through unchanged. Strings are matched case-insensitively
    against yes/true/t/y/1 and no/false/f/n/0.

    Raises:
        argparse.ArgumentTypeError: if the string is none of the above.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in {'yes', 'true', 't', 'y', '1'}:
        return True
    if token in {'no', 'false', 'f', 'n', '0'}:
        return False

    raise argparse.ArgumentTypeError('Boolean value expected.')

if __name__ == "__main__":

    # Command-line entry point: parse the run configuration into the
    # module-level ARGS (which train() reads) and start training.
    PARSER = argparse.ArgumentParser()


    # Arguments for training
    PARSER.add_argument('--device', type=str, default="GPU", 
                        help='Device that should be used.')

    PARSER.add_argument('--print_models', type=str2bool, nargs="?", const=True, default=False, 
                        help='Print the models after initialization or not.')

    PARSER.add_argument('--name', type=str, default="", 
                        help='Name of the folder where the output should be saved.')



    # pretrained params 

    PARSER.add_argument('--pretrained', type=str, default=None, 
                        help='Folder name of pretrained model that should be loaded.')

    PARSER.add_argument('--pretrained_best_dataset', type=str, default="train", 
                        help='Pretrained model with lowest [train, val] loss.')

    PARSER.add_argument('--pretrained_best_loss', type=str, default="mask", 
                        help='Loss whose best checkpoint should be loaded (e.g. "mask").')

    PARSER.add_argument('--pretrained_models', type=str, default=None, 
                        help='Choose which pretrained models to load. None = all models')

    # Parsed as a float: this value is written straight into the optimizers'
    # param groups, where a string learning rate would break the update step.
    PARSER.add_argument('--pretrained_lr_reset', type=float, default=None, 
                        help='Reset the lr to a value.')



    # data
    PARSER.add_argument('--dataset', type=str, default="new", 
                        help='The dataset which we train on.')

    PARSER.add_argument('--seed', type=int, default=34, 
                        help='Seed for initializing dataloader')

    PARSER.add_argument('--rotate', type=str2bool, nargs="?", const=True, default=True, 
                        help='Rotations of the same image')

    PARSER.add_argument('--translate', type=str2bool, nargs="?", const=True, default=True, 
                        help='Translations of the same image')

    PARSER.add_argument('--translate_max_pixels', type=int, default=20, 
                        help='Translation max in height and width.')

    PARSER.add_argument('--flip', type=str2bool, nargs="?", const=True, default=True, 
                        help='Flips the train image')

    PARSER.add_argument('--crop', type=str2bool, nargs="?", const=True, default=True, 
                        help='Crops the train image')

    PARSER.add_argument('--stretch', type=str2bool, nargs="?", const=True, default=True, 
                        help='Stretches the train image')

    PARSER.add_argument('--stretch_factor', type=float, default=1.2, 
                        help='Stretch maximum of the train image')

    # Two numbers on the command line ("--norm_min_max 0 1"). The former
    # type=list split the raw argument string into single characters.
    PARSER.add_argument('--norm_min_max', type=float, nargs=2, default=[0, 1], 
                        metavar=('MIN', 'MAX'), 
                        help='List with min and max for normalizing input.')



    # train variables
    PARSER.add_argument('--pcmra_epochs', type=int, default=5000, 
                        help='Number of epochs for pcmra training.')

    PARSER.add_argument('--mask_epochs', type=int, default=5000, 
                        help='Number of epochs for mask training.')

    PARSER.add_argument('--batch_size', type=int, default=24, 
                        help='Batch size.')

    PARSER.add_argument('--eval_every', type=int, default=50, 
                        help='Set the # epochs after which evaluation should be done.')

    PARSER.add_argument('--shuffle', type=str2bool, nargs="?", const=True, default=True, 
                        help='Shuffle the train dataloader?')

    PARSER.add_argument('--n_coords_sample', type=int, default=5000, 
                        help='Number of coordinates that should be sampled for each subject.')

    PARSER.add_argument('--min_lr', type=float, default=1e-5, 
                        help='Minimum lr, input for lr scheduler.')



    # CNN
    PARSER.add_argument('--cnn_setup', type=int, default=-1, 
                        help='Setup of the CNN.')

    PARSER.add_argument('--pcmra_train_cnn', type=str2bool, nargs="?", const=True, default=True, 
                        help='Whether to also train the cnn during pcmra reconstruction.')

    PARSER.add_argument('--mask_train_cnn', type=str2bool, nargs="?", const=True, default=True, 
                        help='Whether to also train the cnn during mask segmentation.')



    # Mapping
    PARSER.add_argument('--mapping_setup', type=int, default=-1, 
                        help='Setup of the Mapping network.')



    # SIREN
    PARSER.add_argument('--dim_hidden', type=int, default=256, 
                        help='Dimension of hidden SIREN layers.')

    PARSER.add_argument('--siren_hidden_layers', type=int, default=3, 
                        help='Number of hidden SIREN layers.')


    PARSER.add_argument('--first_omega_0', type=float, default=30., 
                        help='Omega_0 of first layer.')

    PARSER.add_argument('--hidden_omega_0', type=float, default=30., 
                        help='Omega_0 of hidden layer.')


    PARSER.add_argument('--pcmra_first_omega_0', type=float, default=30., 
                        help='Omega_0 of first layer of PCMRA siren.')

    PARSER.add_argument('--pcmra_hidden_omega_0', type=float, default=30., 
                        help='Omega_0 of hidden layer of PCMRA siren.')



    # optimizers
    PARSER.add_argument('--cnn_lr', type=float, default=1e-4, 
                        help='Learning rate of cnn optim.')

    PARSER.add_argument('--cnn_wd', type=float, default=0, 
                        help='Weight decay of cnn optim.')


    PARSER.add_argument('--mapping_lr', type=float, default=1e-4, 
                        help='Learning rate of mapping optim.')

    PARSER.add_argument('--pcmra_mapping_lr', type=float, default=1e-4, 
                        help='Learning rate of PCMRA mapping optim.')


    PARSER.add_argument('--siren_lr', type=float, default=1e-4, 
                        help='Learning rate of siren optim.')

    PARSER.add_argument('--siren_wd', type=float, default=0, 
                        help='Weight decay of siren optim.')


    PARSER.add_argument('--pcmra_siren_lr', type=float, default=1e-4, 
                        help='Learning rate of PCMRA siren optim.')    

    PARSER.add_argument('--pcmra_siren_wd', type=float, default=0, 
                        help='Weight decay of PCMRA siren optim.')

    PARSER.add_argument('--patience', type=int, default=200, 
                        help='Patience of the LR scheduler.')


    ARGS = PARSER.parse_args()

    train()