In [None]:
import numpy as np

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from model import I2INet, CCNet
from City_dataloader import dataloader, vdataloader
from torch.optim import lr_scheduler
from torch.autograd import Variable
from pytorch_ssim import MSSSIM
from torchvision import datasets, models, transforms
from model_resnet import Discriminator
import torchvision.transforms.functional as TF

%matplotlib inline
import os
import matplotlib
import matplotlib.pyplot as plt

from skimage.measure import compare_ssim as ssim
from skimage.measure import compare_psnr as psnr

from vgg_features import vgg_features
from PIL import Image
from tqdm import tqdm

In [None]:
def mse(x, y):
    """Mean squared error between two arrays.

    Parameters
    ----------
    x, y : array-like
        Arrays of the same (or broadcast-compatible) shape.

    Returns
    -------
    float
        ``mean((x - y) ** 2)`` over all elements.
    """
    # Cast to float64 first: subtracting uint8 image arrays would wrap around.
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return float(np.mean(diff ** 2))

def kl(p, q):
    """Kullback-Leibler divergence D(P || Q) for discrete distributions.

    Parameters
    ----------
    p, q : array-like, dtype=float, shape=n
        Discrete probability distributions over the same n outcomes.

    Returns
    -------
    float
        sum_i p_i * log(p_i / q_i). Terms with p_i == 0 contribute 0
        (the 0 * log 0 = 0 convention); the result is inf if q_i == 0 for
        some p_i > 0.
    """
    # np.float was removed in NumPy 1.24; builtin float is the same float64 dtype.
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p, q = np.broadcast_arrays(p, q)

    # Evaluate the log only on the support of p, so p == 0 entries never
    # produce 0 * log(0) = nan (and the associated RuntimeWarnings).
    support = p != 0
    with np.errstate(divide='ignore'):
        return float(np.sum(p[support] * np.log(p[support] / q[support])))

In [None]:
import random

# Global seeding for reproducibility.
seed = 8
random.seed(seed)                 # Python random module
np.random.seed(seed)              # NumPy global RNG
torch.manual_seed(seed)           # torch CPU RNG
torch.cuda.manual_seed_all(seed)  # every visible GPU (silently ignored without CUDA)
# Trade cuDNN autotuning speed for deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def worker_init_fn(worker_id):
    """Seed NumPy and `random` inside a DataLoader worker process.

    PyTorch already gives each worker its own torch seed (available via
    torch.initial_seed()), derived from the main-process RNG. Reusing it here
    keeps the NumPy/`random` streams reproducible under `seed`, while making
    them differ between workers and between epochs. A constant seed would make
    every worker repeat the same random sequence.

    Parameters
    ----------
    worker_id : int
        Worker index supplied by DataLoader (unused; already folded into the
        torch worker seed).
    """
    worker_seed = torch.initial_seed() % 2**32  # np.random.seed needs a 32-bit value
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [None]:
# Training hyper-parameters.
batch_size = 8    # samples per training iteration (the validation loader uses 1)
lr_rate = 0.0001  # Adam learning rate shared by the generator and discriminator

In [None]:
# Dataset root. Override with the CITY_DATA_ROOT environment variable instead
# of editing the notebook.
DATA_ROOT = os.environ.get('CITY_DATA_ROOT', '/home/shyam.nandan/DeepLabv3.pytorch-master/data')

# ToTensor (for PIL / uint8 input) gives [0, 1]; Normalize(0.5, 0.5) then maps
# each channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

dataset = dataloader(DATA_ROOT, transform)
loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, num_workers=4,
    pin_memory=False, worker_init_fn=worker_init_fn)

vdataset = vdataloader(DATA_ROOT, transform)
# Validation order is fixed (shuffle=False). This makes per-image outputs and
# the "last sample" saved at the end of evaluation reproducible.
vloader = torch.utils.data.DataLoader(
    vdataset, batch_size=1, shuffle=False, num_workers=4,
    pin_memory=False, worker_init_fn=worker_init_fn)

In [None]:
# Use up to the first two GPUs (previously hard-coded as [0, 1]), without
# crashing on a single-GPU machine. The nets stay wrapped in DataParallel
# either way, so checkpoint keys keep the same 'module.' prefix.
device_ids = list(range(min(2, torch.cuda.device_count())))

model = nn.DataParallel(I2INet().cuda(), device_ids)
print(model)
netD = nn.DataParallel(Discriminator().cuda(), device_ids)
print(netD)

loss_fn = nn.L1Loss()
msssim_loss = MSSSIM()  # only referenced by the commented-out alternative generator loss
optimizer = optim.Adam(model.parameters(), lr=lr_rate, betas=(0.5, 0.999))
optimizerD = optim.Adam(filter(lambda p: p.requires_grad, netD.parameters()), lr=lr_rate, betas=(0.5, 0.999))
# Generator LR is multiplied by 0.1 every 15 scheduler steps; the discriminator LR stays fixed.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
vggf = vgg_features()

In [None]:
# Adversarial training.
# - Discriminator: least-squares GAN loss (real -> 1, fake -> 0).
# - Generator: 100 * L1 reconstruction + LSGAN adversarial term
#   + 5 * L1 distance between vgg_features activations.
LOG_EVERY = 10  # iterations averaged into each track_loss point

model.train()
track_loss = []
short_avg = 0.0
step = 0
for Ep in range(25):
    for data in tqdm(loader):
        Img = data['Img'].cuda()
        tImg = data['tImg'].cuda()
        output = model(tImg)

        # --- Discriminator update ---
        # detach(): lossD must not backpropagate into the generator graph.
        optimizerD.zero_grad()
        real_out = netD(Img)
        fake_out = netD(output.detach())
        lossD = torch.mean((fake_out - 0)**2) + torch.mean((real_out - 1)**2)
        lossD.backward()
        optimizerD.step()

        # --- Generator update ---
        # The generator weights have not changed since the forward pass above,
        # so `output` is reused. netD is re-run because its weights just changed.
        fake_out = netD(output)
        with torch.no_grad():
            # The target features are a constant; no graph is needed for them.
            target_feats = vggf.forward(Img)
        loss = (100 * loss_fn(output, Img)
                + torch.mean((fake_out - 1)**2)
                + 5 * loss_fn(vggf.forward(output), target_feats))
        #loss = loss_fn(output, tImg)  + 20*(1 - msssim_loss(output, tImg))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Every logged point is the mean of exactly LOG_EVERY iterations,
        # even when a bucket spans an epoch boundary.
        short_avg += loss.item()
        step += 1
        if step % LOG_EVERY == 0:
            track_loss.append(short_avg / LOG_EVERY)
            short_avg = 0.0

    # Step the LR schedule after the epoch's optimizer updates (required
    # ordering since PyTorch 1.1; stepping first skips one base-LR epoch).
    exp_lr_scheduler.step()
    # Checkpoint the weights produced by this epoch.
    torch.save(model.state_dict(), 'EpochG_10_vgg.pth')
    torch.save(netD.state_dict(), 'EpochD_10_vgg.pth')
plt.plot(track_loss);

In [None]:
# Re-plot the loss curve, skipping its first logged point.
loss_curve = track_loss[1:]
plt.plot(loss_curve)

In [None]:
# Persist generator and discriminator weights (the same files the training loop writes).
for net, ckpt_path in ((model, 'EpochG_10_vgg.pth'), (netD, 'EpochD_10_vgg.pth')):
    torch.save(net.state_dict(), ckpt_path)

In [None]:
# Rebuild the generator with the same DataParallel wrapping used in training,
# so the 'module.'-prefixed checkpoint keys match (strict=True). Use up to two
# GPUs instead of a hard-coded [0, 1], so loading works on a single GPU.
device_ids = list(range(min(2, torch.cuda.device_count())))
model = nn.DataParallel(I2INet().cuda(), device_ids)
model.load_state_dict(torch.load('EpochG_10_vgg.pth'), strict=True)
model.eval()

In [None]:
avg_loss = 0
avg_mse = 0
avg_ssim = 0
avg_psnr = 0
avg_kl = 0
# Inference only: no_grad() stops autograd from recording a graph (and keeping
# its activations alive) for every validation image.
with torch.no_grad():
    for i, data in enumerate(tqdm(vloader)):
        target = data['Img']
        output = model(data['tImg'].cpu())
        # F.upsample is a deprecated alias; F.interpolate performs the same resize.
        output = torch.nn.functional.interpolate(output, scale_factor=2, mode='bilinear', align_corners=True)
        # range=(-1, 1) matches the Normalize(0.5, 0.5) scaling applied to the inputs.
        torchvision.utils.save_image(output, './'+data['Path'][0], nrow=1, padding=0, normalize=True, range=(-1,1))

        # NOTE(review): metric computation is disabled. If re-enabled, assuming
        # the usual (N, C, H, W) layout, reshape((512, 1024, 3)) does not move
        # the channel axis; use tensor[0].permute(1, 2, 0) instead. The averages
        # should also be divided by the image count i + 1, not i.
        #t = target.data.cpu().numpy().reshape((512, 1024, 3))
        #p = output.data.cpu().numpy().reshape((512, 1024, 3))

        #avg_psnr += psnr(t, p, data_range = t.max() - t.min())
        #avg_mse  += mse(t, p)
        #avg_ssim += ssim(t, p, data_range = t.max() - t.min(), multichannel=True)
#print('Avg_psne-mse-ssim-avgloss',avg_psnr/i,avg_mse/i, avg_ssim/i)

In [None]:
# Dump the last validation sample from the loop above: the generator output,
# the ground truth, and the network input, each rescaled from [-1, 1].
save_kwargs = dict(nrow=1, padding=0, normalize=True, range=(-1, 1))
torchvision.utils.save_image(output, 'test1.png', **save_kwargs)
torchvision.utils.save_image(target, 'test2.png', **save_kwargs)
torchvision.utils.save_image(data['tImg'].cuda(), 'test3.png', **save_kwargs)