# In [None]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
import os
import sys

# Append the directory two levels above the current working directory to the
# module search path, so the project modules imported below (`dataset_fake`,
# `model.*`) can be found.
root_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
sys.path.append(root_dir)

from dataset_fake import Data_Loader
from model.generator import Generator
from model.reconstructor import Reconstructor
from model.discriminator import Discriminator

import cv2
from PIL import Image
import torch.nn as nn

# Base directory for data, checkpoints and results: the parent of the cwd.
ROOT_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# ---- experiment configuration ------------------------------------------
cross_vaild = 'K1'                   # fold tag used in the test-data path
layer = 2                            # n_layers for Generator / Discriminator
image_size = 256                     # inputs are resized to image_size x image_size
generator_backbone = 'transunet'
discriminator_backbone = 'unet'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- test data ----------------------------------------------------------
data_path = f'{ROOT_PATH}/Data/LS_{cross_vaild}_nonoverlap_test'
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])
image_dataset = Data_Loader(data_path, transform)
data = torch.utils.data.DataLoader(
    dataset=image_dataset,
    batch_size=1,
    drop_last=True,
)

# ---- networks -----------------------------------------------------------
net_generator = Generator(n_layers=layer, backbone=generator_backbone)
net_discriminator = Discriminator(n_layers=layer, backbone=discriminator_backbone)
net_reconstructor = Reconstructor(if_CP=False)

# Wrap each network in DataParallel and move it to `device`
# (nn.Module.to returns the module itself, so the wrapper is what gets bound).
net_G = nn.DataParallel(net_generator).to(device=device)
net_D = nn.DataParallel(net_discriminator).to(device=device)
net_R = nn.DataParallel(net_reconstructor).to(device=device)

def predict(net_G, net_D, net_R, image, masks):
    parameter_path = ROOT_PATH + '/parameters'
    
    pretrain_path = '{}/best_Generator_Main_transunet_unet_K1.pth'.format(parameter_path)
    state_dict = torch.load(pretrain_path)
    net_G.module.load_state_dict(state_dict)

    pretrain_path = '{}/best_Reconstructor_Main_transunet_unet_K1.pth'.format(parameter_path)
    state_dict = torch.load(pretrain_path)
    net_R.module.load_state_dict(state_dict)

    pretrain_path = '{}/best_Discriminator_Main_transunet_unet_K1.pth'.format(parameter_path)
    state_dict = torch.load(pretrain_path)
    net_D.module.load_state_dict(state_dict)

    dis_mask = torch.clone(masks)
    dis_mask = torch.cat((dis_mask, dis_mask), dim=1)

    pre_layer = net_G(image, masks)

    pre_image, x = net_R(image, pre_layer, masks)
    print(x[0])
    pre_masks = net_D(pre_layer) * dis_mask

    cv2.imwrite(ROOT_PATH + '/results/fake_overlap/recon_image_{}.jpg'.format(image_name),
                pre_image.cpu().detach().numpy()[0][0] * 255)
    cv2.imwrite(ROOT_PATH + '/results/fake_overlap/upper_layer_{}.jpg'.format(image_name),
                pre_layer.cpu().detach().numpy()[0][1] * 255)
    cv2.imwrite(ROOT_PATH + '/results/fake_overlap/lower_layer_{}.jpg'.format(image_name),
                pre_layer.cpu().detach().numpy()[0][0] * 255)
    cv2.imwrite(ROOT_PATH + '/results/fake_overlap/upper_mask_{}.bmp'.format(image_name),
                pre_masks.cpu().detach().numpy()[0][1] * 255)
    cv2.imwrite(ROOT_PATH + '/results/fake_overlap/lower_mask_{}.bmp'.format(image_name),
                pre_masks.cpu().detach().numpy()[0][0] * 255)



parameter_path = ROOT_PATH + '/parameters'

# Inference only: put every network in eval mode once (not per batch) and
# disable autograd for all forward passes made by predict().
net_G.eval()
net_D.eval()
net_R.eval()

with torch.no_grad():
    for batch_idx, batch in enumerate(data):
        image, layer_masks, image_cropping, masks_cropping, real_layer_images, name = batch
        print('{} / {}'.format(batch_idx + 1, len(data)))

        image = image.to(device=device, dtype=torch.float32)
        masks = layer_masks.to(device=device, dtype=torch.float32)
        # NOTE(review): the three tensors below are moved to the device but
        # never passed to predict(); drop them if no later cell uses them.
        image_cropping = image_cropping.to(device=device, dtype=torch.float32)
        masks_cropping = masks_cropping.to(device=device, dtype=torch.float32)
        real_layer_images = real_layer_images.to(device=device, dtype=torch.float32)

        # batch_size=1, so `name` holds one entry. predict() reads the
        # module-level `image_name` to name its output files.
        image_name = name[0]
        print(image_name)

        # Called without `mode=`; the predict() defined above takes only
        # (net_G, net_D, net_R, image, masks).
        predict(net_G, net_D, net_R, image, masks)
        print('Main Done')
