# In [None]:
from torch.utils.data import Dataset, random_split, WeightedRandomSampler
from datetime import datetime
import matplotlib.pyplot as plt
import os
import sys

# Directory two levels above the working directory, appended to sys.path
# before the project-local imports that follow (presumably so that
# dataset_downstream / model.* / utils.* resolve — confirm repo layout).
root_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
sys.path.append(root_dir)

from torch.utils.data import Subset

from dataset_downstream import Data_Loader
from model.lossfunction import LayerSegLoss_1, LayerSegLoss_2
from model.generator import Generator
from model.reconstructor import Reconstructor
from model.discriminator import Discriminator
from model.registrater import Registrater

import torch
import torch.nn as nn
from torchvision import transforms
from utils.monitor import Monitor
import numpy as np
import torch.nn.functional as F

# Absolute path of the parent of the working directory; the __main__ block
# builds the data and checkpoint directories from it.
# NOTE(review): this is one level above cwd, while root_dir (near the top of
# the file) is two levels above — confirm which root is intended.
ROOT_PATH = os.path.abspath(os.path.join(os.getcwd(), ".."))


def data_loading_to_device(*args, device='cpu'):
    """Move each tensor in ``args`` to ``device`` and cast it to float32.

    Args:
        *args: tensors to convert.
        device: target device, anything accepted by ``Tensor.to``.

    Returns:
        tuple: the converted tensors, in the order they were given.
    """
    converted = []
    for tensor in args:
        converted.append(tensor.to(device=device, dtype=torch.float32))
    return tuple(converted)


def _masked_layer_loss(criterion, moving_reg, fixed_reg, moving_mask_reg, fixed_mask_reg, layer, eps=1e-6):
    """Compute the IoU-weighted masked loss for one layer channel.

    The registered images are compared on channel ``layer`` only inside the
    union of the two registered masks. The result is divided by the masks'
    intersection-over-union, so pairs that overlap less get a larger loss.

    Args:
        criterion: loss callable applied to the two masked images.
        moving_reg, fixed_reg: registered images, indexed as ``[:, layer]``.
        moving_mask_reg, fixed_mask_reg: registered masks, indexed as ``[:, layer]``.
        layer: channel index (the training loop uses 0 = lower, 1 = upper).
        eps: lower bound on the IoU, so that an empty overlap gives a large
            but finite loss instead of ``inf``/``nan``.

    Returns:
        tuple: ``(loss, region)``, where ``region`` is the union mask used to
        mask both images.
    """
    moving_mask = moving_mask_reg[:, layer]
    fixed_mask = fixed_mask_reg[:, layer]

    # Union of the two masks: every positive entry is set to 1 (in place).
    region = moving_mask + fixed_mask
    region[region > 0] = 1

    intersection = torch.sum(torch.logical_and(moving_mask, fixed_mask))
    union = torch.sum(torch.logical_or(moving_mask, fixed_mask))
    # An empty union (0/0) or zero overlap (x/0) would make the loss nan/inf.
    # That would poison every gradient and, through Adam, the weights.
    iou = (intersection / union.clamp(min=1)).clamp(min=eps)

    loss = criterion(moving_reg[:, layer] * region, fixed_reg[:, layer] * region) / iou
    return loss, region


def _print_registration_params(parmeter):
    """Print the first four predicted parameters of both layers for sample 0."""
    params = parmeter.detach().cpu().numpy()[0]
    # NOTE(review): the mask channels are documented as "0 - lower, 1 - upper",
    # yet index 0 is printed as "upper" here — confirm the parameter layer order.
    print('Output_upper: S({:.5f}) R({:.5f}) T({:.5f}, {:.5f})'.format(*params[0][:4]),
          'Output_lower: S({:.5f}) R({:.5f}) T({:.5f}, {:.5f})'.format(*params[1][:4]))


def train(net_E, device, data_path, lr_E=0.0001, epochs=40, batch_size=1, image_size=256, save_path='', cross_vaild=''):
    """Train the registration network ``net_E`` and save its best checkpoint.

    Args:
        net_E: model called as ``net_E(moving_image, fixed_image, moving_masks,
            fixed_masks)``. It returns ``(moving_reg, fixed_reg, moving_mask_reg,
            fixed_mask_reg, params)``. It may be wrapped in ``nn.DataParallel``.
        device: device every batch is moved to.
        data_path: directory containing ``DR_<cross_vaild>_train`` and
            ``DR_<cross_vaild>_test``, which are passed to ``Data_Loader``.
        lr_E: Adam learning rate.
        epochs: number of training epochs.
        batch_size: batch size for both loaders. Both loaders shuffle and drop
            the last incomplete batch.
        image_size: side length the images are resized to.
        save_path: directory for ``best_Registrater_Pre_<cross_vaild>.pth``.
            It is created if missing. An empty string means the current directory.
        cross_vaild: cross-validation fold tag, e.g. ``'K1'``.
    """
    # Data loading. Normalize(0, 1) computes (x - 0) / 1, i.e. it leaves values unchanged.
    transform = transforms.Compose([transforms.Resize((image_size, image_size)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(0, 1)
                                    ])

    train_dataset = Data_Loader(data_path + '/DR_{}_train'.format(cross_vaild), transform)
    valid_dataset = Data_Loader(data_path + '/DR_{}_test'.format(cross_vaild), transform)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True
                                               )
    # NOTE(review): drop_last=True also drops the final partial *validation*
    # batch. If the test split has fewer than batch_size samples, validation
    # runs no batches at all. Confirm Monitor handles partial batches before
    # changing this.
    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True
                                               )

    my_monitor = Monitor(epochs=epochs, device=device,
                         train_loss_name_list=['Registration_Loss'],
                         val_loss_name_list=['Registration_Loss_val'],
                         lr_name_list=['lr_E'],
                         train_dataset=train_loader,
                         val_dataset=valid_loader
                         )

    # Optimizer & loss function
    optimizer_E = torch.optim.Adam(net_E.parameters(), lr=lr_E, betas=(0.5, 0.999))
    # NOTE: this scheduler is never stepped, so the learning rate stays at lr_E.
    scheduler_E = torch.optim.lr_scheduler.StepLR(optimizer=optimizer_E, step_size=50, gamma=0.5)
    criterion_E = nn.MSELoss()
    best_loss_E = float('inf')

    # Save the inner module when wrapped in DataParallel (as before). Plain
    # modules are now saved too; previously nothing was written for them.
    model_to_save = net_E.module if isinstance(net_E, nn.DataParallel) else net_E
    checkpoint_path = os.path.join(save_path, 'best_Registrater_Pre_{}.pth'.format(cross_vaild))

    for _ in range(epochs):
        # ---- training ----
        my_monitor.train_start(optimizer_list=[optimizer_E])
        net_E.train()
        for moving_image_layer, fixed_image_layer, moving_layer_masks, fixed_layer_masks in train_loader:
            moving_image_layer, fixed_image_layer, moving_layer_masks, fixed_layer_masks = data_loading_to_device(
                moving_image_layer, fixed_image_layer, moving_layer_masks, fixed_layer_masks, device=device)

            optimizer_E.zero_grad()
            moving_reg, fixed_reg, moving_mask_reg, fixed_mask_reg, parmeter = net_E(moving_image_layer,
                                                                                     fixed_image_layer,
                                                                                     moving_layer_masks,
                                                                                     fixed_layer_masks)
            _print_registration_params(parmeter)

            # Channel 0 is the lower layer, channel 1 the upper layer.
            loss_lower, _ = _masked_layer_loss(criterion_E, moving_reg, fixed_reg,
                                               moving_mask_reg, fixed_mask_reg, layer=0)
            loss_upper, _ = _masked_layer_loss(criterion_E, moving_reg, fixed_reg,
                                               moving_mask_reg, fixed_mask_reg, layer=1)
            loss_E = loss_lower + loss_upper

            # Only one backward pass is made per forward graph, so the graph
            # does not need to be retained.
            loss_E.backward()
            optimizer_E.step()

            my_monitor.set_loss(loss_list=[loss_E])

        # ---- validation ----
        my_monitor.val_start()
        net_E.eval()
        with torch.no_grad():
            for moving_image_layer, fixed_image_layer, moving_layer_masks, fixed_layer_masks in valid_loader:
                moving_image_layer, fixed_image_layer, moving_layer_masks, fixed_layer_masks = data_loading_to_device(
                    moving_image_layer, fixed_image_layer, moving_layer_masks, fixed_layer_masks, device=device)

                moving_reg, fixed_reg, moving_mask_reg, fixed_mask_reg, parmeter = net_E(moving_image_layer,
                                                                                         fixed_image_layer,
                                                                                         moving_layer_masks,
                                                                                         fixed_layer_masks)

                loss_lower, region_lower = _masked_layer_loss(criterion_E, moving_reg, fixed_reg,
                                                              moving_mask_reg, fixed_mask_reg, layer=0)
                loss_upper, region_upper = _masked_layer_loss(criterion_E, moving_reg, fixed_reg,
                                                              moving_mask_reg, fixed_mask_reg, layer=1)
                loss_E_val = loss_lower + loss_upper

                # Per-pixel squared error inside each layer's union mask,
                # stacked along dim 1 as [lower, upper] for visualisation.
                loss_image_lower = F.mse_loss(moving_reg[:, 0] * region_lower,
                                              fixed_reg[:, 0] * region_lower, reduction='none')
                loss_image_upper = F.mse_loss(moving_reg[:, 1] * region_upper,
                                              fixed_reg[:, 1] * region_upper, reduction='none')
                loss_image = torch.cat((loss_image_lower.unsqueeze(1), loss_image_upper.unsqueeze(1)), dim=1)

                my_monitor.set_loss(loss_list=[loss_E_val])
                my_monitor.set_output_image(number=3,
                                            image_list=[moving_image_layer, moving_image_layer, moving_image_layer,
                                                        fixed_image_layer, fixed_image_layer, fixed_image_layer,
                                                        moving_reg, loss_image, parmeter])

            my_monitor.show_val_result_downstream()

        my_monitor.epoch_summary()

        # Keep the checkpoint with the lowest loss seen so far.
        # NOTE(review): selection uses the *training* loss 'Registration_Loss';
        # 'Registration_Loss_val' may be the intended criterion — confirm.
        recent_best = my_monitor.get_recent_best_loss(loss_name='Registration_Loss')
        if recent_best < best_loss_E:
            best_loss_E = recent_best
            if save_path:
                os.makedirs(save_path, exist_ok=True)
            torch.save(model_to_save.state_dict(), checkpoint_path)


if __name__ == "__main__":
    # Train the registration network on cross-validation fold K1.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    image_size = 256

    # The model is wrapped in DataParallel; train() saves the wrapped module's weights.
    net_registrater = Registrater(in_channels=4)
    net_registrater = nn.DataParallel(net_registrater)
    net_registrater.to(device=device)

    data_path = ROOT_PATH + '/Data'
    save_path = ROOT_PATH + '/parameters/downstream'
    cross_vaild = 'K1'

    train(net_E=net_registrater,
          device=device, data_path=data_path, epochs=150, batch_size=50, lr_E=0.000001,
          image_size=image_size, cross_vaild=cross_vaild, save_path=save_path)
