In [None]:
import os
import time
import numpy
import matplotlib.pyplot as plt


import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader

import torchvision
from torchvision import transforms
from torchvision import utils as vutils
import torchvision.models as models

import argparse
import random
from tqdm import tqdm

from models import weights_init, Discriminator, Generator
from operation import copy_G_params, load_params, get_dir
from operation import ImageFolder, InfiniteSamplerWrapper
from diffaug import DiffAugment
# DiffAugment policy; train() applies it to both the real batch and every generated image.
policy = 'color,translation'
import lpips
# LPIPS perceptual loss (net-lin on a VGG backbone). train_d() uses it for the
# discriminator's reconstruction terms. use_gpu=True presumably means a CUDA
# device is required for this cell to run -- confirm.
percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

from helper_function import get_covariance, frechet_distance
from torchvision.models import inception_v3


#torch.backends.cudnn.benchmark = True

def crop_image_by_part(image, part):
    """Return one quadrant of a batch of images.

    Args:
        image: rank-4 tensor indexed as ``image[:, :, rows, cols]``
            (presumably NCHW).
        part: quadrant index -- 0 top-left, 1 top-right, 2 bottom-left,
            3 bottom-right.

    Returns:
        A slice (view) of ``image`` covering the requested quadrant. Height is
        halved with ``shape[2]`` and width with ``shape[3]``, so non-square
        inputs are cropped correctly (square inputs behave exactly as before).

    Raises:
        ValueError: if ``part`` is not 0, 1, 2 or 3 (this previously returned
            None silently, which only failed later and far from the cause).
    """
    if part not in (0, 1, 2, 3):
        raise ValueError("part must be 0, 1, 2 or 3, got %r" % (part,))
    half_h = image.shape[2] // 2
    half_w = image.shape[3] // 2
    # Parts 0/1 are the top half, 2/3 the bottom half; even parts are the left half.
    rows = slice(None, half_h) if part < 2 else slice(half_h, None)
    cols = slice(None, half_w) if part % 2 == 0 else slice(half_w, None)
    return image[:, :, rows, cols]

def train_d(net, data, label="real"):
    """Accumulate discriminator gradients for one batch (no optimizer step).

    Non-"real" labels: ``net(data, label)`` yields predictions, and the loss is
    ``relu(m + pred).mean()`` with a random margin ``m`` drawn in [0.8, 1.0).

    ``label == "real"``: a random quadrant index ``part`` in 0..3 is drawn and
    passed to ``net``, which returns predictions plus three reconstructions
    (full, small, quadrant). The loss is ``relu(m - pred).mean()`` plus the
    summed ``percept`` losses of each reconstruction against ``data`` resized
    to match (the quadrant is compared with ``crop_image_by_part(data, part)``).

    ``backward()`` is called on the loss; the caller owns zero_grad/step.

    Returns:
        "real": ``(pred.mean().item(), rec_all, rec_small, rec_part)``
        otherwise: ``pred.mean().item()``
    """
    if label != "real":
        pred = net(data, label)
        margin = torch.rand_like(pred) * 0.2 + 0.8
        fake_loss = F.relu(margin + pred).mean()
        fake_loss.backward()
        return pred.mean().item()

    part = random.randint(0, 3)
    pred, [rec_all, rec_small, rec_part] = net(data, label, part=part)

    hinge = F.relu(torch.rand_like(pred) * 0.2 + 0.8 - pred).mean()
    loss_all = percept(rec_all, F.interpolate(data, rec_all.shape[2])).sum()
    loss_small = percept(rec_small, F.interpolate(data, rec_small.shape[2])).sum()
    quadrant = crop_image_by_part(data, part)
    loss_part = percept(rec_part, F.interpolate(quadrant, rec_part.shape[2])).sum()

    total = hinge + loss_all + loss_small + loss_part
    total.backward()
    return pred.mean().item(), rec_all, rec_small, rec_part

Setting up Perceptual loss...
Loading model from: /home/dnn4/pythonCodeArea/Ashish/Final_Dissertation/FastGAn/FastGAN-pytorch/lpips/weights/v0.1/vgg.pth
...[net-lin [vgg]] initialized
...Done


In [None]:
import cv2
import numpy as np
from PIL import Image

class CLAHETransform(object):
    """torchvision-style transform that applies CLAHE to an RGB image's lightness.

    The image is converted RGB -> LAB. Contrast Limited Adaptive Histogram
    Equalization is applied to the L channel only, leaving the A/B colour
    channels untouched. The result is converted back to RGB and returned as a
    ``PIL.Image``.

    Args:
        clip_limit: forwarded to ``cv2.createCLAHE(clipLimit=...)``.
        tile_grid_size: forwarded to ``cv2.createCLAHE(tileGridSize=...)``.
    """

    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size

    def __call__(self, image):
        """Return a new ``PIL.Image`` with CLAHE applied to ``image``'s L channel.

        ``image`` must be convertible by ``np.array`` to an RGB array. It is
        presumably a PIL RGB image coming out of ``transforms.Resize``, and
        OpenCV's CLAHE expects 8- or 16-bit data -- confirm for other inputs.
        """
        lab = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2LAB)
        lightness, chan_a, chan_b = cv2.split(lab)
        # Built per call rather than stored on self; presumably this keeps the
        # transform picklable for DataLoader worker processes -- confirm.
        clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)
        # Equalize once (the previous version ran CLAHE twice and discarded one result).
        lab = cv2.merge((clahe.apply(lightness), chan_a, chan_b))
        return Image.fromarray(cv2.cvtColor(lab, cv2.COLOR_LAB2RGB))

    def __repr__(self):
        # Lets transforms.Compose print a meaningful description of this step.
        return "%s(clip_limit=%r, tile_grid_size=%r)" % (
            type(self).__name__, self.clip_limit, self.tile_grid_size)

In [None]:
def train(args):
    """Train a FastGAN generator/discriminator pair.

    Reads ``path``, ``iter``, ``ckpt``, ``batch_size``, ``im_size``,
    ``workers``, ``start_iter`` and ``save_interval`` from ``args``, plus
    whatever ``get_dir`` reads. Iterations run from ``args.start_iter`` up to
    and including ``args.iter``. When resuming from ``args.ckpt``, training
    starts at the step after the one stored in the checkpoint filename.

    Every ``save_interval`` iterations, sample and reconstruction grids are
    written to the image folder. At those iterations, and always at the final
    one, two checkpoints are written to the model folder:

    * ``<it>.pth`` -- G/D state dicts taken after ``load_params`` swaps the
      moving-average generator params into ``netG``.
    * ``all_<it>.pth`` -- raw G/D weights, the averaged params and both
      optimizer states (the format this function resumes from).

    Only ``'None'`` or ``None`` for ``args.ckpt`` mean "start fresh".
    """
    data_root = args.path
    total_iterations = args.iter
    checkpoint = args.ckpt
    batch_size = args.batch_size
    im_size = args.im_size
    ndf = 64
    ngf = 64
    nz = 256  # latent dimension
    nlr = 0.0002
    nbeta1 = 0.5
    use_cuda = True
    multi_gpu = True
    dataloader_workers = args.workers
    current_iteration = args.start_iter
    save_interval = args.save_interval
    saved_model_folder, saved_image_folder = get_dir(args)

    device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

    trans = transforms.Compose([
        transforms.Resize((int(im_size), int(im_size))),
        CLAHETransform(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    if 'lmdb' in data_root:
        from operation import MultiResolutionDataset
        dataset = MultiResolutionDataset(data_root, trans, 1024)
    else:
        dataset = ImageFolder(root=data_root, transform=trans)

    # The infinite sampler lets the loop call next() indefinitely; shuffle must
    # stay False because a custom sampler is supplied.
    dataloader = iter(DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                 sampler=InfiniteSamplerWrapper(dataset),
                                 num_workers=dataloader_workers, pin_memory=True))

    netG = Generator(ngf=ngf, nz=nz, im_size=im_size)
    netG.apply(weights_init)
    netD = Discriminator(ndf=ndf, im_size=im_size)
    netD.apply(weights_init)
    netG.to(device)
    netD.to(device)

    avg_param_G = copy_G_params(netG)
    fixed_noise = torch.randn(8, nz, device=device)

    optimizerG = optim.Adam(netG.parameters(), lr=nlr, betas=(nbeta1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=nlr, betas=(nbeta1, 0.999))

    if checkpoint not in (None, 'None'):
        ckpt = torch.load(checkpoint, map_location=device)
        # Checkpoints saved from nn.DataParallel models carry a 'module.' prefix.
        netG.load_state_dict({k.replace('module.', ''): v for k, v in ckpt['g'].items()})
        netD.load_state_dict({k.replace('module.', ''): v for k, v in ckpt['d'].items()})
        avg_param_G = ckpt['g_ema']
        optimizerG.load_state_dict(ckpt['opt_g'])
        optimizerD.load_state_dict(ckpt['opt_d'])
        # 'all_<it>.pth' is written after step <it> has already run, so resume
        # at the following step instead of repeating it.
        saved_step = os.path.splitext(os.path.basename(checkpoint))[0].split('_')[-1]
        current_iteration = int(saved_step) + 1
        print("checkpoint loaded successfully")
        del ckpt

    if multi_gpu:
        netG = nn.DataParallel(netG)
        netD = nn.DataParallel(netD)

    for iteration in tqdm(range(current_iteration, total_iterations + 1)):
        real_image = next(dataloader).to(device)
        current_batch_size = real_image.size(0)
        noise = torch.randn(current_batch_size, nz, device=device)

        # netG returns several images per sample (iterated below); each is augmented.
        fake_images = netG(noise)

        real_image = DiffAugment(real_image, policy=policy)
        fake_images = [DiffAugment(fake, policy=policy) for fake in fake_images]

        ## Discriminator step: real batch (with reconstruction losses) + detached fakes.
        netD.zero_grad()
        err_dr, rec_img_all, rec_img_small, rec_img_part = train_d(netD, real_image, label="real")
        train_d(netD, [fi.detach() for fi in fake_images], label="fake")
        optimizerD.step()

        ## Generator step: maximise D's prediction on the fakes.
        netG.zero_grad()
        pred_g = netD(fake_images, "fake")
        err_g = -pred_g.mean()
        err_g.backward()
        optimizerG.step()

        # Exponential moving average of the generator parameters (decay 0.999).
        for p, avg_p in zip(netG.parameters(), avg_param_G):
            avg_p.mul_(0.999).add_(0.001 * p.data)

        if iteration % 100 == 0:
            print("GAN: loss d: %.5f    loss g: %.5f"%(err_dr, -err_g.item()))

        save_images = iteration % save_interval == 0
        if save_images or iteration == total_iterations:
            # Swap the averaged params into netG once for sampling and for the
            # lightweight '<it>.pth' checkpoint, then restore the raw weights.
            backup_para = copy_G_params(netG)
            load_params(netG, avg_param_G)
            if save_images:
                with torch.no_grad():
                    vutils.save_image(netG(fixed_noise)[0].add(1).mul(0.5),
                                      saved_image_folder + '/%d.jpg' % iteration, nrow=4)
                    vutils.save_image(torch.cat([
                            F.interpolate(real_image, 128),
                            rec_img_all, rec_img_small,
                            rec_img_part]).add(1).mul(0.5),
                        saved_image_folder + '/rec_%d.jpg' % iteration)
            torch.save({'g': netG.state_dict(), 'd': netD.state_dict()},
                       saved_model_folder + '/%d.pth' % iteration)
            load_params(netG, backup_para)
            torch.save({'g': netG.state_dict(),
                        'd': netD.state_dict(),
                        'g_ema': avg_param_G,
                        'opt_g': optimizerG.state_dict(),
                        'opt_d': optimizerD.state_dict()},
                       saved_model_folder + '/all_%d.pth' % iteration)



In [None]:
# if __name__ == "__main__":
#     parser = argparse.ArgumentParser(description='region gan')

#     parser.add_argument('--path', type=str, default='../lmdbs/art_landscape_1k', help='path of resource dataset, should be a folder that has one or many sub image folders inside')
#     parser.add_argument('--output_path', type=str, default='./', help='Output path for the train results')
#     parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
#     parser.add_argument('--name', type=str, default='test1', help='experiment name')
#     parser.add_argument('--iter', type=int, default=50000, help='number of iterations')
#     parser.add_argument('--start_iter', type=int, default=0, help='the iteration to start training')
#     parser.add_argument('--batch_size', type=int, default=8, help='mini batch number of images')
#     parser.add_argument('--im_size', type=int, default=1024, help='image resolution')
#     parser.add_argument('--ckpt', type=str, default='None', help='checkpoint weight path if have one')
#     parser.add_argument('--workers', type=int, default=2, help='number of workers for dataloader')
#     parser.add_argument('--save_interval', type=int, default=100, help='number of iterations to save model')

class argument:
    """Minimal stand-in for the argparse namespace that ``train``/``FD`` read.

    Each keyword is stored verbatim as an attribute of the same name, so a
    later cell can override individual settings. ``ckpt='None'`` (the string)
    is what ``train`` treats as "no checkpoint to resume from".
    """

    def __init__(self, path='./', output_path='./', cuda=0, name='result', iter=50000,
                 start_iter=0, batch_size=8, im_size=512, ckpt='None', workers=2, save_interval=50):
        # Tuple assignment binds left to right, so attribute order matches the
        # parameter order (relevant if vars(self) is ever serialised).
        self.path, self.output_path = path, output_path
        self.cuda, self.name = cuda, name
        self.iter, self.start_iter = iter, start_iter
        self.batch_size, self.im_size = batch_size, im_size
        self.ckpt, self.workers, self.save_interval = ckpt, workers, save_interval

In [None]:
# Resume training from an earlier run's full checkpoint into a new run folder.
# NOTE(review): run name "HR Sliver NMI" does not match the data folder
# "CR Scale RIS", and the checkpoint comes from run "test1" -- confirm intended.
arg = argument(
    path="../../DATA/Classification_Dataset/original/Train/CR Scale RIS",
    ckpt="train_results/test1/models/all_50000.pth",
    iter=100000,
    name="HR Sliver NMI",
)
train(arg)

# Calculation of FID (Frechet Inception Distance) for saved generator checkpoints

In [None]:
# Settings for the FID evaluation run; all other fields keep argument() defaults.
arg = argument(
    iter=50000,
    name="HR Scab Patch",
    path="../../DATA/04_Mar_24/HR Scab Patch",
)

In [None]:
def FD(args, epochs=None):
    """Compute the Frechet distance between real images and saved generators' samples.

    First, an ImageNet-pretrained Inception-v3 (``fc`` replaced by Identity)
    embeds every image under ``args.path`` once. Then, for each checkpoint
    iteration in ``epochs``:

    1. The generator weights ``'g'`` are loaded from ``<model folder>/<epoch>.pth``.
       The model folder is the one returned by ``get_dir``, the same place
       ``train`` saves to.
    2. ``len(dataset)`` fakes are sampled and embedded.
    3. The fake statistics are compared with the real ones using
       ``frechet_distance``.

    Args:
        args: run configuration (``path``, ``batch_size``, ``im_size`` plus
            whatever ``get_dir`` reads, e.g. ``name``).
        epochs: iterable of checkpoint iterations to evaluate. Defaults to
            77500..100000 in steps of 500 (the previously hard-coded range).

    Returns:
        (FID, epochs): list of distances and the matching list of iterations.

    Errors while loading or embedding now propagate. Previously a bare
    ``except`` hid them and FID was silently computed on partial features.

    NOTE(review): torchvision's pretrained Inception-v3 expects 299x299,
    ImageNet-normalised input; here tensors are [-1, 1] at ``im_size``, so the
    scores are presumably not comparable with standard FID numbers -- confirm.
    NOTE(review): ``train`` applies CLAHETransform, but these real statistics
    do not -- confirm this is intended.
    NOTE(review): with only ``len(dataset)`` samples (56 in the recorded run)
    the covariance estimates are presumably rank-deficient; treat the scores
    as relative.
    """
    if epochs is None:
        epochs = range(77500, 100001, 500)
    epochs = list(epochs)

    data_root = args.path
    batch_size = args.batch_size
    im_size = args.im_size
    nz = 256  # latent dimension, must match training
    use_cuda = True
    saved_model_folder, _ = get_dir(args)

    device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

    trans = transforms.Compose([
        transforms.Resize((int(im_size), int(im_size))),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    if 'lmdb' in data_root:
        from operation import MultiResolutionDataset
        dataset = MultiResolutionDataset(data_root, trans, 1024)
    else:
        dataset = ImageFolder(root=data_root, transform=trans)

    print("size of dataset",len(dataset))
    n_samples = len(dataset)
    if n_samples == 0:
        raise ValueError("no images found under %r" % (data_root,))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    inception_model = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1)
    inception_model.fc = torch.nn.Identity()
    inception_model = inception_model.to(device).eval()

    # Real statistics are computed once and reused for every checkpoint.
    real_features_list = []
    with torch.no_grad():
        for real_samples in tqdm(dataloader):
            real_features_list.append(inception_model(real_samples.to(device)).cpu())
    real_features_all = torch.cat(real_features_list)
    mu_real = real_features_all.mean(0)
    sigma_real = get_covariance(real_features_all)

    net_ig = Generator(ngf=64, nz=nz, nc=3, im_size=im_size)
    net_ig.to(device)
    net_ig.eval()

    FID = []
    for epoch in epochs:
        state = torch.load(saved_model_folder + '/%d.pth' % epoch, map_location='cpu')
        # Checkpoints saved from nn.DataParallel carry a 'module.' prefix.
        net_ig.load_state_dict({k.replace('module.', ''): v for k, v in state['g'].items()})
        print('load checkpoint success, epoch %d'%epoch)

        # Draw exactly n_samples fakes in batch_size chunks (matches the real
        # set's batch sizes without re-reading the real images each time).
        fake_features_list = []
        with torch.no_grad():
            for start in tqdm(range(0, n_samples, batch_size)):
                current = min(batch_size, n_samples - start)
                noise = torch.randn(current, nz, device=device)
                # net_ig returns several outputs; [0] is the image fed to Inception.
                fake_samples = net_ig(noise)[0]
                fake_features_list.append(inception_model(fake_samples).cpu())
        fake_features_all = torch.cat(fake_features_list)
        mu_fake = fake_features_all.mean(0)
        sigma_fake = get_covariance(fake_features_all)
        fid = frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake).item()
        print(fid)
        FID.append(fid)
    return FID,epochs


In [None]:
# y: FID score per evaluated checkpoint; z: the matching checkpoint iterations.
y,z = FD(arg)

size of dataset 56


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 20.76it/s]


load checkpoint success, epoch 77500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.85it/s]


54.294979095458984
load checkpoint success, epoch 78000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 11.28it/s]


61.90937805175781
load checkpoint success, epoch 78500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.00it/s]


55.44194412231445
load checkpoint success, epoch 79000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.30it/s]


57.55289840698242
load checkpoint success, epoch 79500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 11.55it/s]


53.049556732177734
load checkpoint success, epoch 80000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.10it/s]


50.32158279418945
load checkpoint success, epoch 80500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 11.51it/s]


56.59404754638672
load checkpoint success, epoch 81000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.26it/s]


56.961158752441406
load checkpoint success, epoch 81500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.22it/s]


53.41719436645508
load checkpoint success, epoch 82000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.14it/s]


58.539493560791016
load checkpoint success, epoch 82500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.06it/s]


57.01541519165039
load checkpoint success, epoch 83000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.41it/s]


53.60749816894531
load checkpoint success, epoch 83500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 11.11it/s]


54.236480712890625
load checkpoint success, epoch 84000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.37it/s]


52.32200622558594
load checkpoint success, epoch 84500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.54it/s]


53.532615661621094
load checkpoint success, epoch 85000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.71it/s]


51.14384841918945
load checkpoint success, epoch 85500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.81it/s]


54.35237503051758
load checkpoint success, epoch 86000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.40it/s]


54.5759162902832
load checkpoint success, epoch 86500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 11.96it/s]


48.90548324584961
load checkpoint success, epoch 87000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.40it/s]


49.619327545166016
load checkpoint success, epoch 87500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.85it/s]


55.35722732543945
load checkpoint success, epoch 88000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.34it/s]


54.20051574707031
load checkpoint success, epoch 88500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.62it/s]


55.45621109008789
load checkpoint success, epoch 89000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.04it/s]


57.36295700073242
load checkpoint success, epoch 89500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.60it/s]


59.099971771240234
load checkpoint success, epoch 90000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.37it/s]


54.865234375
load checkpoint success, epoch 90500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.26it/s]


56.0011100769043
load checkpoint success, epoch 91000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 11.28it/s]


54.8021240234375
load checkpoint success, epoch 91500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.88it/s]


51.938575744628906
load checkpoint success, epoch 92000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.62it/s]


54.81195068359375
load checkpoint success, epoch 92500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.48it/s]


51.88690185546875
load checkpoint success, epoch 93000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.03it/s]


55.61791229248047
load checkpoint success, epoch 93500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 11.90it/s]


53.71633529663086
load checkpoint success, epoch 94000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 11.77it/s]


56.821163177490234
load checkpoint success, epoch 94500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.06it/s]


63.63819122314453
load checkpoint success, epoch 95000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.83it/s]


57.48065948486328
load checkpoint success, epoch 95500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.82it/s]


53.95498275756836
load checkpoint success, epoch 96000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.68it/s]


54.77057647705078
load checkpoint success, epoch 96500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.26it/s]


53.1870231628418
load checkpoint success, epoch 97000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 11.69it/s]


57.89139938354492
load checkpoint success, epoch 97500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.95it/s]


53.665618896484375
load checkpoint success, epoch 98000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.21it/s]


57.38545227050781
load checkpoint success, epoch 98500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.21it/s]


58.847618103027344
load checkpoint success, epoch 99000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.35it/s]


55.35979080200195
load checkpoint success, epoch 99500


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.99it/s]


60.85419845581055
load checkpoint success, epoch 100000


100%|█████████████████████████████████████████████| 7/7 [00:00<00:00, 10.38it/s]


55.826568603515625
