In [None]:
from torch.utils.data import Dataset, random_split, WeightedRandomSampler
from datetime import datetime
import matplotlib.pyplot as plt
import os
import sys

root_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))
sys.path.append(root_dir)

from dataset_pretrain import Data_Loader
from model.lossfunction import LayerSegLoss_1, LayerSegLoss_2
from model.generator import Generator
from model.reconstructor import Reconstructor
from model.discriminator import Discriminator

import torch
import torch.nn as nn
from torchvision import transforms
from utils.monitor import Monitor
import numpy as np

# Project folder one level above the current working directory; the script
# presumably runs from a sub-folder of it (see the Data/ and parameters/ paths
# built from ROOT_PATH in the __main__ block).
_cwd_parent = os.path.join(os.getcwd(), os.pardir)
ROOT_PATH = os.path.abspath(_cwd_parent)

def _to_device(device, *tensors):
    """Move every tensor to ``device`` as float32 and return them as a tuple."""
    return tuple(t.to(device=device, dtype=torch.float32) for t in tensors)


def _unwrap(net):
    """Return the inner module of a (Distributed)DataParallel wrapper, else ``net`` itself."""
    if isinstance(net, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return net.module
    return net


def _save_state(net, save_path, filename):
    """Save the unwrapped ``state_dict`` of ``net`` as ``save_path/filename``.

    ``save_path`` is created if it does not exist; an empty ``save_path``
    saves into the current working directory instead of the filesystem root.
    """
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    torch.save(_unwrap(net).state_dict(), os.path.join(save_path, filename))


def train(net_G, net_D, net_R, device, data_path, lr_G=0.00001, lr_D=0.0001, lr_R=0.0001, epochs=40, batch_size=1, image_size=256, save_path='', generator_backbone='', discriminator_backbone='', cross_vaild=''):
    """Jointly pre-train the generator, discriminator and reconstructor.

    Every epoch runs one pass over the training split, doing one
    discriminator update followed by one joint generator + reconstructor
    update per batch. It then runs one pass over the validation split under
    ``torch.no_grad()``. Losses and sample images are reported through
    ``Monitor``.

    When the monitor reports a new best ``Generator_Loss``, the generator and
    reconstructor weights are saved. When it reports a new best
    ``Discriminator_Loss``, the discriminator weights are saved. Weights are
    saved unwrapped, whether or not the nets are wrapped in ``nn.DataParallel``.

    Args:
        net_G, net_D, net_R: Generator, discriminator and reconstructor
            networks, optionally wrapped in ``nn.DataParallel``.
        device: Device every batch tensor is moved to.
        data_path: Folder containing ``LS_<cross_vaild>_nonoverlap_train`` and
            ``LS_<cross_vaild>_nonoverlap_test``.
        lr_G, lr_D, lr_R: Adam learning rates (betas ``(0.5, 0.999)``).
        epochs: Number of training epochs.
        batch_size: Mini-batch size of both loaders.
        image_size: Images are resized to ``(image_size, image_size)``.
        save_path: Folder for the ``best_*.pth`` checkpoints (created if missing).
        generator_backbone, discriminator_backbone: Only used in checkpoint
            file names.
        cross_vaild: Fold tag used in dataset folder and checkpoint file names.
    """
    # Data loading
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(0, 1),  # mean 0 / std 1, i.e. an identity transform
    ])

    train_dataset = Data_Loader(data_path + '/LS_{}_nonoverlap_train'.format(cross_vaild), transform)
    valid_dataset = Data_Loader(data_path + '/LS_{}_nonoverlap_test'.format(cross_vaild), transform)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True)
    # NOTE(review): drop_last=True skips up to batch_size-1 validation samples
    # every epoch; kept because Monitor/criterion may assume full batches -- confirm.
    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True)

    my_monitor = Monitor(epochs=epochs, device=device,
                         train_loss_name_list=['Generator_Loss', 'Discriminator_Loss'],
                         val_loss_name_list=['Generator_Loss_val'],
                         lr_name_list=['lr_G', 'lr_D'],
                         train_dataset=train_loader,
                         val_dataset=valid_loader)

    # Optimizers & loss functions
    optimizer_G = torch.optim.Adam(net_G.parameters(), lr=lr_G, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(net_D.parameters(), lr=lr_D, betas=(0.5, 0.999))
    optimizer_R = torch.optim.Adam(net_R.parameters(), lr=lr_R, betas=(0.5, 0.999))

    # Only the reconstructor's learning rate is decayed (x0.5 every 60 epochs);
    # the generator and discriminator train at a constant learning rate.
    scheduler_R = torch.optim.lr_scheduler.StepLR(optimizer=optimizer_R, step_size=60, gamma=0.5)

    criterion_LS = LayerSegLoss_1()
    criterion_D = nn.BCELoss()

    best_loss_G = float('inf')
    best_loss_D = float('inf')

    ckpt_suffix = '{}_{}_{}.pth'.format(generator_backbone, discriminator_backbone, cross_vaild)

    for epoch in range(epochs):
        my_monitor.train_start(optimizer_list=[optimizer_G, optimizer_D])

        # ---------------- training ----------------
        net_G.train()
        net_D.train()
        net_R.train()
        for image, masks, image_cropping, masks_cropping, real_layer_images, _ in train_loader:
            image, masks, image_cropping, masks_cropping, real_layer_images = _to_device(
                device, image, masks, image_cropping, masks_cropping, real_layer_images)

            # --- discriminator step ---
            optimizer_D.zero_grad()

            zeros_layer = torch.zeros_like(masks_cropping)
            label_real = torch.cat((masks_cropping, masks_cropping), dim=1)
            label_real_2 = torch.cat((masks, masks), dim=1)
            label_fake = torch.cat((masks, zeros_layer), dim=1)

            # G is not updated in this step, so build its output without an
            # autograd graph (saves memory and avoids G gradients that
            # optimizer_G.zero_grad() would discard anyway).
            with torch.no_grad():
                fake_image = net_G(image, masks)

            # Mask duplicated along dim 1 to match the concatenated labels.
            dis_mask = torch.cat((masks, masks), dim=1)

            # NOTE(review): the cropped real images are masked with `masks`,
            # while their labels come from `masks_cropping` -- confirm intended.
            loss_D_real = criterion_D(net_D(image_cropping) * dis_mask, label_real)
            loss_D_real_2 = criterion_D(net_D(real_layer_images) * dis_mask, label_real_2)
            loss_D_fake = criterion_D(net_D(fake_image) * dis_mask, label_fake)

            loss_D = (loss_D_real + loss_D_fake + loss_D_real_2) / 3
            loss_D.backward()
            optimizer_D.step()

            # --- generator + reconstructor step ---
            optimizer_G.zero_grad()
            optimizer_R.zero_grad()

            pre_layer = net_G(image, masks)
            pre_image, _ = net_R(image, pre_layer, masks)
            # This also accumulates grads in net_D; they are cleared by
            # optimizer_D.zero_grad() before the next discriminator step.
            pre_masks = net_D(pre_layer) * dis_mask

            loss_GR = criterion_LS(org_image=image, org_masks=masks, real_layer=real_layer_images,
                                   pre_layer=pre_layer, recon_image=pre_image, pre_masks=pre_masks)
            loss_GR.backward()
            optimizer_G.step()
            optimizer_R.step()

            my_monitor.set_loss(loss_list=[loss_GR, loss_D])

        # ---------------- validation ----------------
        my_monitor.val_start()
        net_G.eval()
        net_D.eval()
        net_R.eval()
        with torch.no_grad():
            # The cropped image/mask entries are not needed for validation.
            for image, masks, _, _, real_layer_images, _ in valid_loader:
                image, masks, real_layer_images = _to_device(device, image, masks, real_layer_images)

                dis_mask = torch.cat((masks, masks), dim=1)

                # Binary union of all mask channels (dim 1 collapsed to size 1).
                org_mask_all = torch.sum(masks, 1).unsqueeze(1)
                org_mask_all[org_mask_all != 0] = 1

                pre_layer = net_G(image, masks)
                pre_image, x = net_R(image, pre_layer, masks)
                pre_masks = net_D(pre_layer) * dis_mask

                loss_GR_val = criterion_LS(org_image=image, org_masks=masks, real_layer=real_layer_images,
                                           pre_layer=pre_layer, recon_image=pre_image, pre_masks=pre_masks)

                my_monitor.set_loss(loss_list=[loss_GR_val])
                my_monitor.set_output_image(number=4, image_list=[image, pre_layer, pre_image, pre_masks, org_mask_all, x])

            my_monitor.show_val_result()

        my_monitor.epoch_summary()

        # ---------------- checkpointing ----------------
        recent_loss_G = my_monitor.get_recent_best_loss(loss_name='Generator_Loss')
        if recent_loss_G < best_loss_G:
            best_loss_G = recent_loss_G
            _save_state(net_G, save_path, 'best_Generator_Pre_' + ckpt_suffix)
            _save_state(net_R, save_path, 'best_Reconstructor_Pre_' + ckpt_suffix)

        recent_loss_D = my_monitor.get_recent_best_loss(loss_name='Discriminator_Loss')
        if recent_loss_D < best_loss_D:
            best_loss_D = recent_loss_D
            _save_state(net_D, save_path, 'best_Discriminator_Pre_' + ckpt_suffix)

        scheduler_R.step()

if __name__ == "__main__":
    # ---- run configuration ----
    layer = 2                          # passed as n_layers to both G and D
    image_size = 256
    generator_backbone = 'transunet'
    discriminator_backbone = 'unet'
    cross_vaild = 'K1'                 # fold tag used in dataset / checkpoint names
    data_path = ROOT_PATH + '/Data'
    save_path = ROOT_PATH + '/parameters/pre_train'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ---- networks ----
    # Constructed in the order G, D, R, so any random weight initialisation
    # draws from the RNG in the same sequence. Each net is then wrapped in
    # DataParallel and moved to `device` (Module.to() works in place and
    # returns the module itself).
    net_generator, net_discriminator, net_reconstructor = [
        nn.DataParallel(net).to(device=device)
        for net in (
            Generator(n_layers=layer, backbone=generator_backbone),
            Discriminator(n_layers=layer, backbone=discriminator_backbone),
            Reconstructor(),
        )
    ]

    train(net_G=net_generator, net_D=net_discriminator, net_R=net_reconstructor,
          device=device,
          data_path=data_path, save_path=save_path, cross_vaild=cross_vaild,
          generator_backbone=generator_backbone,
          discriminator_backbone=discriminator_backbone,
          image_size=image_size,
          epochs=250, batch_size=12,
          lr_G=0.00001, lr_D=0.000001, lr_R=0.00001)
