In [None]:
# Execute the helper notebook below in this kernel before the imports that follow.
# NOTE(review): presumably it converts the project's .ipynb files into the
# `py_files.*` modules imported in the next cell — confirm.
# NOTE(review): absolute, user-specific path; this will not run on another machine.
%run /home/ptenkaate/scratch/Master-Thesis/convert_ipynb_to_py_files.ipynb

In [None]:
import torch
import torch.nn.functional as F

from torchinfo import summary

from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, Compose, ToTensor, Normalize

import argparse
import os
import math 
import skimage
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import time
import pickle

from datetime import datetime
from pathlib import Path

from py_files.new_dataset import *

from py_files.cnn_model import *
from py_files.pigan_model import *

from py_files.seq_pi_gan_functions import *

In [None]:
# Build the encoder and print a layer-by-layer summary for a dummy input batch.
encoder = Encoder()
input_size = (32, 1, 24, 64, 64)  # NOTE(review): presumably (batch, channel, depth, height, width) — confirm
summary(encoder, depth=2, input_size=input_size)

print(encoder)

In [None]:
class Decoder(nn.Module):
    """3-D convolutional decoder that upsamples a 512-channel latent volume
    into a single-channel volume.

    For an input of shape (N, 512, D, H, W) the output has shape
    (N, 1, 16*D + 8, 32*H + 32, 32*W + 32); e.g. a (N, 512, 1, 1, 1) latent
    becomes a (N, 1, 24, 64, 64) volume. The last layer has no activation.

    All layers live in ``self.model`` (an ``nn.Sequential``) in a fixed order,
    so the state-dict keys (``model.0.weight`` ... ``model.11.bias``) line up
    with checkpoints saved from this class.
    """

    def __init__(self):
        super().__init__()

        def doubling_deconv(c_in, c_out):
            # kernel 3, stride 2, padding 1, output_padding 1 -> every spatial dim doubles.
            return nn.ConvTranspose3d(c_in, c_out, 3, stride=2, padding=1, output_padding=1)

        layers = [
            nn.ConvTranspose3d(512, 128, (3, 4, 4), stride=2, padding=0, output_padding=0),
            nn.ReLU(),
            doubling_deconv(128, 64),
            nn.ReLU(),
            doubling_deconv(64, 32),
            nn.ReLU(),
            # align_corners=False is the default for trilinear mode; spelled out explicitly.
            nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False),
            nn.Conv3d(32, 32, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv3d(32, 16, 3, stride=1, padding=1),
            nn.ReLU(),
            # Keeps depth, doubles height and width.
            nn.ConvTranspose3d(16, 1, 3, stride=(1, 2, 2), padding=1, output_padding=(0, 1, 1)),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Decode latent ``x`` by running it through the layer stack."""
        return self.model(x)
    
# Infer the decoder's input shape by pushing a dummy batch through the encoder.
# The dummy input is created on whatever device the encoder currently lives on
# (torchinfo's `summary` may have moved it) instead of assuming CUDA, and no
# autograd graph is built for this shape probe.
_encoder_param = next(encoder.parameters(), None)
_encoder_device = _encoder_param.device if _encoder_param is not None else torch.device("cpu")
with torch.no_grad():
    output_size = encoder(torch.randn(input_size, device=_encoder_device)).shape
print(tuple(output_size))
decoder = Decoder()
summary(decoder, input_size=output_size, depth=2)


In [None]:
# NOTE(review): the architecture check above instantiates `Encoder`, but training
# uses `Encoder_1`; confirm Encoder_1 yields the latent shape the Decoder expects.
LEARNING_RATE = 1e-4

encoder, decoder = Encoder_1().cuda(), Decoder().cuda()
e_optim, d_optim = (torch.optim.Adam(params=net.parameters(), lr=LEARNING_RATE)
                    for net in (encoder, decoder))

In [None]:
saved_run = "pi-gan 18-05-2021 20:47:06 trained quite well"
run_dir = f"saved_runs/{saved_run}"

# Restore both halves of the auto-encoder from that run.
# (torch.load unpickles the file — only load checkpoints you produced yourself.)
encoder.load_state_dict(torch.load(f"{run_dir}/encoder_train.pt"))
decoder.load_state_dict(torch.load(f"{run_dir}/decoder_train.pt"))

In [None]:
# Optional: lower the learning rate when resuming from a checkpoint.
# for optimizer in (e_optim, d_optim):
#     for group in optimizer.param_groups:
#         group['lr'] = 1e-5

ARGS = init_ARGS()
print(vars(ARGS))

# Folder the model and run metadata are saved to.
path = get_folder(ARGS)

# Record the run's arguments next to the model.
with open(f"{path}/ARGS.txt", "w") as args_file:
    args_file.write(f"{vars(ARGS)}\n")

# Data preparation; show element [1] of the first test batch as a sanity check.
train_dl, val_dl, test_dl = initialize_dataloaders(ARGS)
print(next(iter(test_dl))[1])


In [None]:
from itertools import islice

ARGS.epochs = 10000
EVAL_EVERY = 20            # evaluate + checkpoint every this many epochs
N_TRAIN_EVAL_BATCHES = 10  # training batches used to estimate the train loss

# ARGS.txt was written before the override above; rewrite it so the run folder
# records the epoch count actually used.
with open(f"{path}/ARGS.txt", "w") as f:
    print(vars(ARGS), file=f)

##### loss function #####
criterion = nn.MSELoss()

# One row per evaluation: [epoch, train mean, train std, val mean, val std].
losses = np.empty((0, 5))


def _eval_losses(dataloader, transform, max_batches=None):
    """Per-batch reconstruction losses of encoder -> decoder on `dataloader`.

    Runs under ``torch.no_grad()``; the caller is responsible for eval mode.
    When `transform` is true each batch first goes through
    ``transform_batch(batch, ARGS)``. At most `max_batches` batches are read
    (all of them when ``None``).

    Returns:
        np.ndarray of the per-batch ``criterion`` values.
    """
    batch_losses = []
    with torch.no_grad():
        # islice stops without pulling an extra batch from the loader.
        for batch in islice(dataloader, max_batches):
            if transform:
                batch = transform_batch(batch, ARGS)
            _, _, _, pcmra, _, _, _ = get_siren_batch(batch)
            out = decoder(encoder(pcmra))
            batch_losses.append(criterion(out, pcmra).item())
    return np.array(batch_losses)


for ep in range(ARGS.epochs):
    t = time.time()

    for batch in train_dl:
        batch = transform_batch(batch, ARGS)
        _, _, _, pcmra, _, _, _ = get_siren_batch(batch)

        # Reset gradients before the forward pass: if an earlier run of this
        # cell was interrupted after backward(), stale gradients are discarded.
        e_optim.zero_grad()
        d_optim.zero_grad()

        out = decoder(encoder(pcmra))
        loss = criterion(out, pcmra)
        loss.backward()

        e_optim.step()
        d_optim.step()

    if (ep + 1) % EVAL_EVERY == 0:
        print(f"Epoch {ep} took {round(time.time() - t, 2)} seconds.")

        encoder.eval()
        decoder.eval()
        try:
            train_loss = _eval_losses(train_dl, transform=True, max_batches=N_TRAIN_EVAL_BATCHES)
            print("train loss", train_loss.mean())

            # Validation batches are not passed through transform_batch.
            # NOTE(review): presumably transform_batch is train-time augmentation — confirm.
            val_loss = _eval_losses(val_dl, transform=False)
            print("val loss", val_loss.mean())
        finally:
            # Restore training mode even if evaluation fails or is interrupted,
            # so a re-run of this cell never trains in eval mode.
            encoder.train()
            decoder.train()

        losses = np.append(losses, [[ep, train_loss.mean(), train_loss.std(),
                                     val_loss.mean(), val_loss.std()]], axis=0)
        np.save(f"{path}/losses.npy", losses)

        models = {"encoder": encoder, "decoder": decoder}
        optims = {"encoder": e_optim, "decoder": d_optim}
        save_loss(path, losses, models, optims, name="loss")

torch.cuda.empty_cache()

In [None]:
def scroll_through_output(dataloader, shape=(64, 64, 24), transform=True):
    """Reconstruct every batch in `dataloader` and open a viewer comparing
    inputs with reconstructions.

    Each batch is optionally passed through ``transform_batch(batch, ARGS)``,
    unpacked with ``get_siren_batch``, and its ``pcmra`` tensor is run through
    ``encoder`` -> ``decoder``. The mean ``criterion`` loss over all batches is
    printed. Inputs and outputs are reshaped into stacks of 2-D slices and
    handed to ``Show_images``.

    The models run in ``eval()`` mode without gradient tracking. Their previous
    train/eval mode is restored afterwards, even on error.

    Args:
        dataloader: iterable of batches understood by ``get_siren_batch``.
        shape: volume shape. ``shape[0]`` and ``shape[1]`` give the size of
            each 2-D slice, i.e. the last two dims of ``pcmra``. ``shape[2]``
            is not used. Defaults to (64, 64, 24).
        transform: whether to apply ``transform_batch`` to each batch.

    Returns:
        The window object returned by ``Show_images``.
    """
    height, width = shape[0], shape[1]

    def _to_slices(volume):
        # Flatten all leading dims into one slice axis and move it last:
        # (..., height, width) -> (height, width, n_slices), on the CPU.
        return volume.contiguous().view(-1, height, width).permute(1, 2, 0).detach().cpu()

    def _stack(slices):
        # Concatenate once at the end instead of re-copying on every batch.
        if not slices:
            return torch.empty(height, width, 0)
        return torch.cat(slices, dim=2)

    pcmra_slices, out_slices, batch_losses = [], [], []

    encoder_was_training, decoder_was_training = encoder.training, decoder.training
    encoder.eval()
    decoder.eval()
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for batch in dataloader:
                if transform:
                    batch = transform_batch(batch, ARGS)

                _, _, _, pcmra, _, _, _ = get_siren_batch(batch)

                out = decoder(encoder(pcmra))
                batch_losses.append(criterion(out, pcmra).item())

                pcmra_slices.append(_to_slices(pcmra))
                out_slices.append(_to_slices(out))
    finally:
        encoder.train(encoder_was_training)
        decoder.train(decoder_was_training)

    print(np.array(batch_losses).mean())
    window = Show_images("Comparison", (_stack(pcmra_slices).numpy(), "pcmras"),
                         (_stack(out_slices).numpy(), "output"))

    return window

In [None]:
# Uncomment to switch matplotlib to the Qt backend before opening the viewer:
# %matplotlib qt

comparison_window = scroll_through_output(test_dl, transform=True, shape=(64, 64, 24))
comparison_window

In [None]:
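# Uncomment to save the trained weights and optimizer states to the run folder.
# NOTE: the checkpoint-loading cell reads encoder_train.pt, not cnn_train.pt.
# torch.save(encoder.state_dict(), f"{path}/cnn_train.pt")
# torch.save(e_optim.state_dict(), f"{path}/cnn_optim_train.pt")

# torch.save(decoder.state_dict(), f"{path}/decoder_train.pt")
# torch.save(d_optim.state_dict(), f"{path}/decoder_optim_train.pt")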


In [None]:
if torch.cuda.is_available():
    # Release cached-but-unused GPU memory held by PyTorch's caching allocator.
    torch.cuda.empty_cache()