In [None]:
# %run /home/ptenkaate/scratch/Master-Thesis/convert_ipynb_to_py_files.ipynb

In [None]:
import torch
import torch.nn.functional as F

from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, Compose, ToTensor, Normalize

import argparse
import os
import math 
import skimage
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import time
import pickle

from datetime import datetime
from pathlib import Path

from py_files.new_dataset import *

from py_files.cnn_model import *
from py_files.pigan_model import *

from py_files.seq_pi_gan_functions import *

In [None]:
def double_conv(in_channels, out_channels):
    """Build a block of two 3x3x3 convolutions, each followed by an in-place ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes
    (``in_channels -> out_channels -> out_channels``).

    Args:
        in_channels: number of channels of the tensor fed to the block.
        out_channels: number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` holding ``[Conv3d, ReLU, Conv3d, ReLU]``.
    """
    layers = [
        nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


class UNet(nn.Module):
    """3D convolutional encoder-decoder built from ``double_conv`` blocks.

    The encoder applies three ``double_conv`` + 2x max-pool stages followed by
    a bottleneck ``double_conv``; the decoder applies three 2x trilinear
    upsampling + ``double_conv`` stages and a final 1x1x1 convolution.

    NOTE: despite the class name, encoder feature maps are *not* concatenated
    into the decoder (there are no skip connections), so all information has
    to pass through the bottleneck.

    Each spatial dimension of the input is halved three times (with flooring)
    and doubled three times, so the output only has the same spatial shape as
    the input when every spatial dimension is divisible by 8.

    Args:
        in_channels: channels of the input volume (default 1).
        out_channels: channels of the output volume (default 1).
        base_channels: width of the first encoder stage; it doubles at every
            deeper stage (default 64, giving 64/128/256/512).
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=64):
        super().__init__()

        # Channel widths of the four resolution levels.
        c1, c2, c3, c4 = (base_channels * 2 ** level for level in range(4))

        # Attribute names and construction order are kept stable so that
        # state_dict keys and seeded initialisation match earlier checkpoints.
        self.dconv_down1 = double_conv(in_channels, c1)
        self.dconv_down2 = double_conv(c1, c2)
        self.dconv_down3 = double_conv(c2, c3)
        self.dconv_down4 = double_conv(c3, c4)

        self.maxpool = nn.MaxPool3d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)

        self.dconv_up3 = double_conv(c4, c3)
        self.dconv_up2 = double_conv(c3, c2)
        self.dconv_up1 = double_conv(c2, c1)

        self.conv_last = nn.Conv3d(c1, out_channels, 1)

    def forward(self, x):
        """Run the encoder-decoder on a 5D tensor (batch, channel, and three spatial dims).

        Returns the output of the final 1x1x1 convolution; no activation is
        applied to it.
        """
        # Encoder: the intermediate feature maps are not kept, since no skip
        # connections consume them.
        x = self.maxpool(self.dconv_down1(x))
        x = self.maxpool(self.dconv_down2(x))
        x = self.maxpool(self.dconv_down3(x))

        # Bottleneck.
        x = self.dconv_down4(x)

        # Decoder: upsample, then refine with a double convolution.
        x = self.dconv_up3(self.upsample(x))
        x = self.dconv_up2(self.upsample(x))
        x = self.dconv_up1(self.upsample(x))

        return self.conv_last(x)

In [None]:
# ARGS = init_ARGS()

# ARGS.rotate, ARGS.translate, ARGS.flip = False, True, True
# ARGS.epochs = 50

# ARGS.batch_size = 24


# ##### path to which the model should be saved #####
# path = get_folder(ARGS)

# ##### save ARGS #####
# with open(f"{path}/ARGS.txt", "w") as f:
#     print(vars(ARGS), file=f)

# ##### data preparation #####
# train_dl, val_dl, test_dl = initialize_dataloaders(ARGS) 

In [None]:
##### model #####
# Instantiate the network and move its parameters to the GPU
# (this cell requires a CUDA device).
unet = UNet()
unet = unet.cuda()

##### loss function #####
# Mean squared error; the training loop compares the network output with its own input.
criterion = nn.MSELoss()

##### optimizer #####
optim = torch.optim.Adam(unet.parameters(), lr=1e-4)

# Empty (0, 5) array. Presumably a per-epoch loss log; nothing in the visible
# cells writes to it or defines the five columns -- TODO confirm or remove.
losses = np.empty(shape=(0, 5))

In [None]:
# Train `unet` as an autoencoder: the target of the MSE loss is the input volume itself.
for ep in range(ARGS.epochs):

    t = time.time()

    ##### training #####
    unet.train()
    ep_loss = []

    for _, _, _, pcmra, coords, pcmra_array, mask_array in train_dl:

        # Clear gradients *before* the backward pass, so gradients left over
        # from an interrupted or re-run cell never leak into the first update.
        optim.zero_grad()

        out = unet(pcmra)

        loss = criterion(out, pcmra)
        loss.backward()

        optim.step()

        ep_loss.append(loss.item())

    # Mean training loss of this epoch.
    print(np.array(ep_loss).mean())

    # Validate every 5th epoch (including epoch 0).
    if ep % 5 == 0:

        # Reported time covers the training pass only.
        print(f"Epoch {ep} took {round(time.time() - t, 2)} seconds.")

        ##### validation #####
        ep_loss = []

        unet.eval()

        # No gradients are needed for validation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for _, _, _, pcmra, coords, pcmra_array, mask_array in val_dl:

                out = unet(pcmra)

                loss = criterion(out, pcmra)

                ep_loss.append(loss.item())

        print("val loss", np.array(ep_loss).mean())

In [None]:
# Qualitative check: reconstruct one validation batch and show one slice of
# one sample next to its reconstruction.
idx, subj, proj, pcmra, coords, pcmra_array, mask_array = next(iter(val_dl))

# Use the network trained in this notebook (`unet`). Inference needs no
# autograd graph.
unet.eval()
with torch.no_grad():
    out = unet(pcmra)

# Reconstruction error (MSE) on this batch.
print(criterion(pcmra, out).item())

# Preferred sample / slice to display, clamped so that a smaller batch or a
# shallower volume does not raise an IndexError.
i = min(6, pcmra.shape[0] - 1)
slce = min(8, pcmra.shape[2] - 1)
in_i = pcmra[i, 0, slce, :, :]
out_i = out[i, 0, slce, :, :]

plt.imshow(in_i.detach().cpu())
plt.title(f"Input (sample {i}, slice {slce})")
plt.show()

plt.imshow(out_i.detach().cpu())
plt.title(f"Reconstruction (sample {i}, slice {slce})")
plt.show()



In [None]:
from torchinfo import summary

# Layer-by-layer summary of `unet` for an input of this shape
# (first entry is the batch size, second the channel count, then three spatial dims).
summary_input_shape = (32, 1, 24, 64, 64)
summary(unet, input_size=summary_input_shape)