# In [None]:
# -*- coding: utf-8 -*-
"""Implements SRGAN models: https://arxiv.org/abs/1609.04802
TODO:
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

def swish(x):
    """Return the swish activation ``x * sigmoid(x)``, applied element-wise.

    Args:
        x: input tensor of any shape.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # torch.sigmoid replaces F.sigmoid, which is a deprecated alias that
    # emits a warning on every call in PyTorch 0.4.1 - 1.x.
    return x * torch.sigmoid(x)

class FeatureExtractor(nn.Module):
    """Truncated copy of a CNN's ``features`` stack.

    Args:
        cnn: any object exposing a ``features`` module whose children are
            the layers to run in order.
        feature_layer: index of the last child layer to keep (inclusive).
            If it exceeds the number of layers, all of them are kept.
    """

    def __init__(self, cnn, feature_layer=11):
        super(FeatureExtractor, self).__init__()
        all_layers = list(cnn.features.children())
        # Slice end is exclusive, hence the +1 to include ``feature_layer``.
        kept_layers = all_layers[:feature_layer + 1]
        self.features = nn.Sequential(*kept_layers)

    def forward(self, x):
        """Run ``x`` through the retained layers and return the activations."""
        activations = self.features(x)
        return activations


class residualBlock(nn.Module):
    """Residual block: conv -> BN -> swish -> conv -> BN, plus identity skip.

    Args:
        in_channels: number of channels of the input tensor.
        k: kernel size of both convolutions.
        n: number of channels produced by both convolutions.
        s: stride of both convolutions.

    The input is added to the block output unchanged, so both must have the
    same shape: use ``in_channels == n``, ``s == 1`` and an odd ``k``.
    """

    def __init__(self, in_channels=64, k=3, n=64, s=1):
        super(residualBlock, self).__init__()

        # "Same" padding for odd kernels. The previous hard-coded padding of 1
        # only preserved the spatial size for k == 3, so any other kernel size
        # made the skip-connection addition fail with a shape mismatch.
        pad = k // 2

        self.conv1 = nn.Conv2d(in_channels, n, k, stride=s, padding=pad)
        self.bn1 = nn.BatchNorm2d(n)
        self.conv2 = nn.Conv2d(n, n, k, stride=s, padding=pad)
        self.bn2 = nn.BatchNorm2d(n)

    def forward(self, x):
        """Return ``bn2(conv2(swish(bn1(conv1(x))))) + x``."""
        y = swish(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(y)) + x

class upsampleBlock(nn.Module):
    """2x upsampling block: 3x3 convolution, PixelShuffle(2), then swish.

    PixelShuffle(2) rearranges every group of 4 channels into a 2x2 spatial
    patch, so the block returns ``out_channels / 4`` channels at twice the
    input height and width (``out_channels`` must be divisible by 4).

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the convolution, i.e.
            before the pixel shuffle.
    """

    def __init__(self, in_channels, out_channels):
        super(upsampleBlock, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.shuffler = nn.PixelShuffle(upscale_factor=2)

    def forward(self, x):
        """Return the activated, spatially doubled feature map for ``x``."""
        expanded = self.conv(x)
        upscaled = self.shuffler(expanded)
        return swish(upscaled)

class Generator(nn.Module):
    """SRGAN-style generator mapping a 3-channel image to an upscaled one.

    Layout: 9x9 conv + swish, ``n_residual_blocks`` residual blocks, a
    3x3 conv + BN joined to the stem output by a long skip connection,
    ``upsample_factor // 2`` upsample blocks, and a final 9x9 conv back to
    3 channels.

    Args:
        n_residual_blocks: number of ``residualBlock`` modules in the trunk.
        upsample_factor: requested scale factor. Each upsample block doubles
            the resolution and ``upsample_factor // 2`` blocks are created,
            so the real scale is ``2 ** (upsample_factor // 2)``; it matches
            ``upsample_factor`` for 2 and 4 only.
            NOTE(review): the block count is kept as-is so existing
            checkpoints keep loading; confirm before using larger factors.
    """

    def __init__(self, n_residual_blocks, upsample_factor):
        super(Generator, self).__init__()
        self.n_residual_blocks = n_residual_blocks
        self.upsample_factor = upsample_factor

        self.conv1 = nn.Conv2d(3, 64, 9, stride=1, padding=4)

        # Sub-module names are part of the state-dict keys; do not rename.
        for i in range(self.n_residual_blocks):
            self.add_module('residual_block' + str(i+1), residualBlock())

        self.conv2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Floor division: plain "/" yields a float on Python 3 and range()
        # rejects it. "//" gives the same integer on Python 2 and 3.
        for i in range(self.upsample_factor // 2):
            self.add_module('upsample' + str(i+1), upsampleBlock(64, 256))

        self.conv3 = nn.Conv2d(64, 3, 9, stride=1, padding=4)

    def forward(self, x):
        """Return the generated image for the input batch ``x``."""
        x = swish(self.conv1(x))

        # No copy needed: nothing below modifies ``x`` in place.
        y = x
        for i in range(self.n_residual_blocks):
            y = getattr(self, 'residual_block' + str(i+1))(y)

        # Long skip connection around the whole residual trunk.
        x = self.bn2(self.conv2(y)) + x

        for i in range(self.upsample_factor // 2):
            x = getattr(self, 'upsample' + str(i+1))(x)

        return self.conv3(x)

class Discriminator(nn.Module):
    """Convolutional real/fake classifier producing one probability per image.

    Layout: a 3x3 conv, seven 3x3 conv + BN stages that alternate stride 2
    and stride 1 while widening 64 -> 512 channels, a stride-4 unpadded
    conv + BN, then two fully connected layers (2048 -> 1024 -> 1) and a
    sigmoid. Every convolution is followed by swish; the hidden FC layer
    uses ELU.

    ``fc1`` expects 512 * 2 * 2 = 2048 features. The feature map is
    therefore average-pooled to 2x2 before flattening, so inputs of other
    sizes (anything large enough for the convolutions) no longer fail with
    a shape mismatch. For inputs whose feature map is already 2x2 the
    pooling is the identity, so previously working configurations and
    checkpoints behave as before.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Attribute names are part of the state-dict keys; do not rename.
        self.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1)

        self.conv2 = nn.Conv2d(64, 64, 3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, 3, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 256, 3, stride=2, padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        self.conv7 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
        self.bn7 = nn.BatchNorm2d(512)
        self.conv8 = nn.Conv2d(512, 512, 3, stride=2, padding=1)
        self.conv8_bn = nn.BatchNorm2d(512)
        self.conv9 = nn.Conv2d(512, 512, 3, stride=4, padding=0)
        self.conv9_bn = nn.BatchNorm2d(512)

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, 1)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of probabilities in (0, 1)."""
        x = swish(self.conv1(x))

        x = swish(self.bn2(self.conv2(x)))
        x = swish(self.bn3(self.conv3(x)))
        x = swish(self.bn4(self.conv4(x)))
        x = swish(self.bn5(self.conv5(x)))
        x = swish(self.bn6(self.conv6(x)))
        x = swish(self.bn7(self.conv7(x)))
        x = swish(self.conv8_bn(self.conv8(x)))
        x = swish(self.conv9_bn(self.conv9(x)))

        # Force a 2x2 map so the flattened size always matches fc1 (2048).
        x = F.adaptive_avg_pool2d(x, 2)
        x = x.view(x.size(0), -1)

        x = F.elu(self.fc1(x))
        return torch.sigmoid(self.fc2(x))

# In [None]:
import matplotlib
# import matplotlib.pyplot as plt
matplotlib.use('Agg')

import argparse
import os

import torch
import torch.optim as optim
import torch.nn as nn
from torch.autograd import Variable

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from tensorboard_logger import configure, log_value

# from newmodel import Generator, Discriminator, FeatureExtractor
# from utilsnew import Visualizer2

# Command-line configuration of the training run.
# NOTE: parse_args() reads sys.argv. Inside a notebook kernel sys.argv holds
# the kernel's own arguments, so pass an explicit list there instead,
# e.g. parser.parse_args([]).
parser = argparse.ArgumentParser()
# parser.add_argument('--dataset', type=str, default='cifar100', help='cifar10 | cifar100 | folder')
parser.add_argument('--dataroot', type=str, default='./data', help='path to dataset')
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers')
parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
parser.add_argument('--imageSize', type=int, default=100, help='the low resolution image size')
parser.add_argument('--upSampling', type=int, default=2, help='low to high resolution scaling factor')
parser.add_argument('--nEpochs', type=int, default=100, help='number of epochs to train for')
parser.add_argument('--gEpochs', type=int, default=5, help='number of epochs to pre-train the generator for')
parser.add_argument('--lrG', type=float, default=0.00001, help='learning rate for generator')
parser.add_argument('--lrD', type=float, default=0.0000001, help='learning rate for discriminator')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--nGPU', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', type=str, default='', help="path to netG (to continue training)")
parser.add_argument('--netD', type=str, default='', help="path to netD (to continue training)")
parser.add_argument('--out', type=str, default='checkpoints', help='folder to output model checkpoints')

opt = parser.parse_args()
print(opt)

# Create the checkpoint folder. An already existing directory is fine; any
# other failure (permissions, a regular file in the way, ...) is re-raised
# here instead of surfacing much later as a failing torch.save().
try:
    os.makedirs(opt.out)
except OSError:
    if not os.path.isdir(opt.out):
        raise

# Remind the user about an available GPU that was not requested.
if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

# ImageNet channel statistics used by every forward normalisation below.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

# Side length of the high-resolution images: low-res size times scale factor.
_hr_side = opt.imageSize * opt.upSampling

# Applied by the datasets: resize to a square high-res image, as a tensor.
transform = transforms.Compose([
    transforms.Resize((_hr_side, _hr_side)),
    transforms.ToTensor(),
])

# Normalisation applied to the high-res tensors inside the training loops.
normalize = transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std)

# High-res tensor -> normalised low-res tensor (the generator input).
scale = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(opt.imageSize),
    transforms.ToTensor(),
    transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
])

# Inverse of the ImageNet normalisation (mean' = -mean/std, std' = 1/std),
# followed by conversion to a low-res PIL image, for correct visualization.
backtrans = transforms.Compose([
    transforms.Normalize(mean=[-2.118, -2.036, -1.804],
                         std=[4.367, 4.464, 4.444]),
    transforms.ToPILImage(),
    transforms.Resize(opt.imageSize),
])

# Two image folders under the data root: 'fake' feeds the generator,
# 'real' provides the discriminator's positive examples.
dataset, datasetreal = [
    datasets.ImageFolder(root=os.path.join(opt.dataroot, _subdir),
                         transform=transform)
    for _subdir in ('fake', 'real')
]
# assert dataset

dataloader, dataloaderreal = [
    torch.utils.data.DataLoader(_folder,
                                batch_size=opt.batchSize,
                                shuffle=True,
                                num_workers=int(opt.workers))
    for _folder in (dataset, datasetreal)
]

# Generator with 16 residual blocks; optionally resume from a checkpoint.
netG = Generator(16, opt.upSampling) #6
#netG.apply(weights_init)
if opt.netG != '':
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    netG.load_state_dict(torch.load(opt.netG))
print(netG)

netD = Discriminator()
#netD.apply(weights_init)
if opt.netD != '':
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

# For the content loss
feature_extractor = FeatureExtractor(torchvision.models.vgg19(pretrained=True))
# The VGG network is a fixed loss network: no optimizer ever updates it.
# Freezing its weights avoids computing and accumulating gradients for them;
# gradients still flow through it to the generator output.
feature_extractor.eval()
for _vgg_param in feature_extractor.parameters():
    _vgg_param.requires_grad = False
print(feature_extractor)
content_criterion = nn.MSELoss()
adversarial_criterion = nn.BCELoss()

# BCE targets, one label per image; sized for full batches only.
target_real = Variable(torch.ones(opt.batchSize,1))
target_fake = Variable(torch.zeros(opt.batchSize,1))

# if gpu is to be used
if opt.cuda:
    netG.cuda()
    netD.cuda()
    feature_extractor.cuda()
    content_criterion.cuda()
    adversarial_criterion.cuda()
    target_real = target_real.cuda()
    target_fake = target_fake.cuda()

optimG = optim.Adam(netG.parameters(), lr=opt.lrG)
optimD = optim.SGD(netD.parameters(), lr=opt.lrD, momentum=0.9, nesterov=True)

# Tensorboard run directory encodes output folder, batch size and both LRs.
configure('logs/' + 'genimage-' + str(opt.out) + str(opt.batchSize) + '-' + str(opt.lrG) + '-' + str(opt.lrD), flush_secs=5)
# NOTE(review): Visualizer2 is neither defined nor imported in the visible
# code (the `from utilsnew import Visualizer2` line above is commented out);
# make sure it is in scope before running.
visualizer = Visualizer2()
dire ='resultimages/' +str(opt.out) +'/'
if not os.path.exists(dire):
    os.makedirs(dire)

# Reusable, uninitialised batch buffers; every slot is overwritten before use.
# Low-resolution generator inputs:
inputsG = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)

# Normalised high-resolution images from the 'real' folder:
inputsGreal = torch.FloatTensor(opt.batchSize, 3, opt.imageSize*opt.upSampling, opt.imageSize*opt.upSampling)

# Pre-train generator
# The generator alone is trained for opt.gEpochs epochs with a pixel-wise MSE
# loss between its output and the normalised high-resolution image.
print('Generator pre-training')
for epoch in range(opt.gEpochs):
    # Loss of the last processed batch in this epoch (None if there was none).
    last_pixel_loss = None

    for i, data in enumerate(dataloader):
        inputs, _ = data

        # The buffers are sized for full batches; skip a short trailing batch.
        if inputs.size(0) != opt.batchSize:
            continue

        # Build the low-resolution inputs first (scale() expects the raw
        # image), then normalise the high-resolution target.
        for j in range(opt.batchSize):
            inputsG[j] = scale(inputs[j])
            inputs[j] = normalize(inputs[j])

        # High-resolution target and generator output
        if opt.cuda:
            inputsD_real = Variable(inputs.cuda())
            inputsD_fake = netG(Variable(inputsG).cuda())
        else:
            inputsD_real = Variable(inputs)
            inputsD_fake = netG(Variable(inputsG))

        ######### Train generator #########
        netG.zero_grad()

        lossG_content = content_criterion(inputsD_fake, inputsD_real)
        lossG_content.backward()

        # Update generator weights
        optimG.step()

        # .item() replaces .data[0], which raises on 0-dim loss tensors in
        # current PyTorch releases (requires PyTorch >= 0.4).
        last_pixel_loss = lossG_content.item()

        if i%50==0:
            # Status and display
            print('[%d/%d][%d/%d] Loss_G: %.4f'
                  % (epoch, opt.gEpochs, i, len(dataloader), last_pixel_loss,))

    # Log the last batch loss of the epoch; nothing to log if every batch
    # was skipped (previously this hit an undefined or stale variable).
    if last_pixel_loss is not None:
        log_value('G_pixel_loss', last_pixel_loss, epoch)
    torch.save(netG.state_dict(), '%s/netG_pretrain_%d.pth' % (opt.out, epoch))


print('Adversarial training')


def _cycle_full_batches(loader, batch_size):
    """Yield image batches of exactly ``batch_size`` from ``loader``, forever.

    The loader is restarted each time it is exhausted. Short (trailing)
    batches are skipped because the preallocated buffers and the target
    tensors are sized for full batches.

    Raises:
        ValueError: if a complete pass over ``loader`` contains no full
            batch (otherwise this would spin forever).
    """
    while True:
        produced = False
        for images, _ in loader:
            if images.size(0) == batch_size:
                produced = True
                yield images
        if not produced:
            raise ValueError('dataloader yields no batch of size %d' % batch_size)


visualcount=0
# Endless stream of full batches from the 'real' folder. It is independent of
# the epoch boundaries of the main ('fake') loader.
realdata = _cycle_full_batches(dataloaderreal, opt.batchSize)
for epoch in range(opt.nEpochs):
    mean_generator_content_loss = 0.0
    mean_generator_adversarial_loss = 0.0
    mean_generator_total_loss = 0.0
    mean_discriminator_loss = 0.0
    # Real-folder loss of the last processed batch (None if there was none).
    last_real_loss = None

    for i, data in enumerate(dataloader):
        inputs, _ = data

        # Skip a short trailing batch before doing any work for it.
        if inputs.size(0) != opt.batchSize:
            continue

        ######### Train discriminator #########
        netD.zero_grad()

        # With images from the 'real' folder (target: real)
        inputsreal = next(realdata)
        for k in range(opt.batchSize):
            inputsGreal[k] = normalize(inputsreal[k])

        if opt.cuda:
            inputsDreal = Variable(inputsGreal.cuda())
        else:
            inputsDreal = Variable(inputsGreal)

        outputsre = netD(inputsDreal)
        Dreal = outputsre.data.mean().item()
        lossDreal = adversarial_criterion(outputsre, target_real)

        # Build the low-resolution inputs first (scale() expects the raw
        # image), then normalise the high-resolution image.
        for j in range(opt.batchSize):
            inputsG[j] = scale(inputs[j])
            inputs[j] = normalize(inputs[j])

        # High-resolution 'fake'-folder images and the generator output
        if opt.cuda:
            inputsD_real = Variable(inputs.cuda())
            inputsD_fake = netG(Variable(inputsG).cuda())
        else:
            inputsD_real = Variable(inputs)
            inputsD_fake = netG(Variable(inputsG))

        # With high-resolution images from the 'fake' folder (target: fake)
        outputs = netD(inputsD_real)
        D_real = outputs.data.mean().item()

        # With generated images (target: fake)
        outputsnew = netD(inputsD_fake.detach()) # Don't need to compute gradients wrt weights of netG (for efficiency)
        D_fake = outputsnew.data.mean().item()

        # The two folder-based terms are weighted 10x relative to the
        # generated-image term.
        lossD = adversarial_criterion(outputsnew, target_fake) + 10*(adversarial_criterion(outputs, target_fake) + lossDreal)
        # .item() replaces .data[0], which raises on 0-dim loss tensors in
        # current PyTorch releases (requires PyTorch >= 0.4).
        mean_discriminator_loss+=lossD.item()
        lossD.backward()

        # Update discriminator weights
        optimD.step()

        ######### Train generator #########
        netG.zero_grad()

        # Target features are detached: no gradient through the real branch.
        real_features = Variable(feature_extractor(inputsD_real).data)
        fake_features = feature_extractor(inputsD_fake)

        lossG_content = content_criterion(fake_features, real_features)
        # NOTE(review): the generator is optimised towards target_fake, i.e.
        # towards being classified as fake, whereas a standard GAN generator
        # loss uses the real target. Left unchanged because it may be
        # intentional for this real/fake-folder setup; please confirm.
        lossG_adversarial = adversarial_criterion(netD(inputsD_fake), target_fake)
        mean_generator_content_loss += lossG_content.item()

        lossG_total = 0.1*lossG_content + lossG_adversarial
        mean_generator_adversarial_loss += lossG_adversarial.item()

        mean_generator_total_loss += lossG_total.item()
        lossG_total.backward()

        # Update generator weights
        optimG.step()

        last_real_loss = lossDreal.item()

        if i%50==0:
            # Status and display
            print('[%d/%d][%d/%d] Loss_Dreal: %.4f D(x): %.4f '% (epoch, opt.nEpochs, i, len(dataloader), last_real_loss,Dreal))
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G (Content/Advers): %.4f/%.4f D(x): %.4f D(G(z)): %.4f'
                  % (epoch, opt.nEpochs, i, len(dataloader),
                     lossD.item(), lossG_content.item(), lossG_adversarial.item(), D_real, D_fake,))
        if i%200==0:
            visualcount = visualizer.show(inputsG, inputsD_fake.cpu().data,visualcount,str(opt.out))

    # 'D_real_loss' is the last batch value, not an epoch mean; skipped when
    # no batch was processed (previously an undefined or stale variable).
    if last_real_loss is not None:
        log_value('D_real_loss', last_real_loss, epoch)
    # NOTE: the sums are divided by len(dataloader), which also counts a
    # skipped short batch, so the means are slightly low in that case.
    log_value('G_content_loss', mean_generator_content_loss/len(dataloader), epoch)
    log_value('G_advers_loss', mean_generator_adversarial_loss/len(dataloader), epoch)
    log_value('generator_total_loss', mean_generator_total_loss/len(dataloader), epoch)
    log_value('D_advers_loss', mean_discriminator_loss/len(dataloader), epoch)

    # Do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.out, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.out, epoch))
