In [3]:
import math
from torch import nn
import torch.nn.functional as F
import random

from torchvision import transforms
import matplotlib.pyplot as plt

class Visualizer:
    """Live side-by-side preview of a low-res, a real high-res and a generated image.

    Args:
        show_step: the figure is refreshed once every ``show_step`` calls to ``show``.
        image_size: size passed to the resize transform applied to every displayed image.
    """

    def __init__(self, show_step=10, image_size=30):
        # transforms.Resize replaces transforms.Scale, which is deprecated
        # (torchvision warns "please use transforms.Resize instead").
        self.transform = transforms.Compose([transforms.Normalize(mean = [-2.118, -2.036, -1.804], # Equivalent to un-normalizing ImageNet (for correct visualization)
                                                                    std = [4.367, 4.464, 4.444]),
                                            transforms.ToPILImage(),
                                            transforms.Resize(image_size)])

        self.show_step = show_step
        self.step = 0  # calls to show() since the last refresh

        self.figure, (self.lr_plot, self.hr_plot, self.fake_plot) = plt.subplots(1,3)
        self.figure.show()

        # Image artists returned by imshow; created lazily on the first refresh
        # and updated in place afterwards.
        self.lr_image_ph = None
        self.hr_image_ph = None
        self.fake_hr_image_ph = None

    def show(self, inputsG, inputsD_real, inputsD_fake):
        """Count one call and, on every ``show_step``-th call, redraw the three panels.

        One random index into the first dimension of ``inputsG`` is drawn and the
        same index is used for all three inputs, so the three batches must be at
        least that long.

        Args:
            inputsG: batch fed to the generator (shown in the left panel).
            inputsD_real: batch of real images (middle panel).
            inputsD_fake: batch of generated images (right panel).
        """

        self.step += 1
        if self.step == self.show_step:
            self.step = 0

            i = random.randint(0, inputsG.size(0) -1)

            lr_image = self.transform(inputsG[i])
            hr_image = self.transform(inputsD_real[i])
            fake_hr_image = self.transform(inputsD_fake[i])

            if self.lr_image_ph is None:
                self.lr_image_ph = self.lr_plot.imshow(lr_image)
                self.hr_image_ph = self.hr_plot.imshow(hr_image)
                self.fake_hr_image_ph = self.fake_plot.imshow(fake_hr_image)
            else:
                self.lr_image_ph.set_data(lr_image)
                self.hr_image_ph.set_data(hr_image)
                self.fake_hr_image_ph.set_data(fake_hr_image)

            self.figure.canvas.draw()

In [13]:
class ResidualBlock(nn.Module):
    """Residual unit: conv -> batch-norm -> PReLU -> conv -> batch-norm, plus identity skip.

    Both 3x3 convolutions use padding 1 and keep the channel count, so the
    branch output has the same shape as the input and can be added to it.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``x`` plus the output of the two-stage convolutional branch."""
        branch = self.prelu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return branch + x
        
class Upsample(nn.Module):
    """Sub-pixel upscaling stage: 3x3 conv -> PixelShuffle -> PReLU.

    The convolution expands the channels by ``up_scale ** 2`` so that the
    pixel shuffle can trade them for an ``up_scale``-times larger spatial grid
    while returning to ``in_channels`` channels.

    Args:
        in_channels: channels of the incoming feature map.
        up_scale: upscale factor handed to ``nn.PixelShuffle``.
    """

    def __init__(self, in_channels, up_scale):
        super(Upsample, self).__init__()
        expanded_channels = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pix_shuff = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Apply convolution, pixel shuffle and PReLU in that order."""
        return self.prelu(self.pix_shuff(self.conv(x)))

        
class Generator(nn.Module):
    """Super-resolution generator: head conv, five residual blocks, upsampling tail.

    The input must have 3 channels. Every ``Upsample(64, 2)`` stage in the tail
    doubles the spatial size, so the output is ``2 ** upsample_block_num`` times
    larger than the input. Output values lie in [0, 1].

    Args:
        upsample_block_num: number of 2x ``Upsample`` stages in the tail.
            Defaults to 1, which was the previously hard-coded value.
    """

    def __init__(self, upsample_block_num=1):
        super(Generator,self).__init__()
        # Head: wide 9x9 receptive field, 3 -> 64 channels.
        self.block1=nn.Sequential(
            nn.Conv2d(3,64,kernel_size=9,padding=4),
            nn.PReLU(),
        )
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.PReLU())

        # Tail: the requested number of 2x upsampling stages, then 64 -> 3 channels.
        block8=[Upsample(64,2) for _ in range (upsample_block_num)]
        block8.append(nn.Conv2d(64,3,kernel_size=9,padding=4))
        self.block8=nn.Sequential(*block8)

    def forward(self,x):
        """Map a 3-channel image batch to its upscaled version with values in [0, 1]."""
        block1=self.block1(x)
        block2=self.block2(block1)
        block3=self.block3(block2)
        block4=self.block4(block3)
        block5=self.block5(block4)
        block6=self.block6(block5)
        block7=self.block7(block6)
        # Long skip connection: head features are added back before upsampling.
        block8=self.block8(block1+block7)

        # tanh gives (-1, 1); shift and scale to (0, 1). The tensor method is
        # equivalent to F.tanh and avoids the deprecation warning some PyTorch
        # releases emit for the functional form.
        return (block8.tanh()+1)/2

class Discriminator(nn.Module):
    """Convolutional real/fake classifier for 3-channel images.

    Eight 3x3 convolutions (four of them with stride 2) grow the channels from
    64 to 512, an adaptive average pool collapses the spatial dimensions to
    1x1, and two 1x1 convolutions reduce to a single score per image.
    """

    def __init__(self):
        super(Discriminator,self).__init__()
        self.net=nn.Sequential(
            nn.Conv2d(3,64,kernel_size=3,padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64,64,kernel_size=3,stride=2,padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64,128,kernel_size=3,padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128,128,kernel_size=3,stride=2,padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            # Size-independent head: pool to 1x1, then 1x1 convs act as dense layers.
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1)
        )

    # forward must be a method of the class. It used to be nested inside
    # __init__, where it was only an unused local function, so calling the
    # module failed because nn.Module had no forward implementation.
    def forward(self,x):
        """Return one sigmoid score in [0, 1] per input image, shape ``(batch,)``."""
        batch_size=x.size(0)
        return self.net(x).view(batch_size).sigmoid()
            
class Feature(nn.Module):
    """Truncated copy of a model's ``features`` stack, used as a feature extractor.

    Args:
        model: object exposing a ``features`` module whose children are the layers.
        feature_layer: index of the last child to keep (inclusive); children
            ``0 .. feature_layer`` are retained.
    """

    def __init__(self, model, feature_layer=11):
        super(Feature, self).__init__()
        layers = list(model.features.children())
        kept = layers[:feature_layer + 1]
        self.features = nn.Sequential(*kept)  # for features

    def forward(self, x):
        """Run ``x`` through the retained layers and return the activation."""
        activation = self.features(x)
        return activation
    




    
        
            
            
        
    
    
    

In [None]:
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
import torch
import torchvision.transforms.functional as TF

# Batch size shared by the DataLoader and the pre-allocated buffers below.
bs=15

# ---- Models ----
gen=Generator()
disc=Discriminator()

visualizer = Visualizer(image_size=64)

# ---- Preprocessing ----
transform = transforms.Compose([
    transforms.ToTensor()])

# ImageNet channel statistics, applied to the real high-res images.
normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                std = [0.229, 0.224, 0.225])

# Produces the generator input from a tensor image.
# transforms.Resize replaces the deprecated transforms.Scale.
# NOTE(review): CIFAR-10 images are already 32x32, so Resize(32) presumably
# leaves the resolution unchanged - confirm the intended low-res size.
scale = transforms.Compose([transforms.ToPILImage(),
                            transforms.Resize(32),
                            transforms.ToTensor(),
                            transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                                std = [0.229, 0.224, 0.225])
])

# Truncated pretrained VGG19 used for the perceptual (feature) loss.
feature_extractor=Feature(torchvision.models.vgg19(pretrained=True))

# ---- Losses ----
content_criterion=nn.MSELoss()

adv_criterion=nn.BCELoss()

# ---- Data ----
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs,
                                         shuffle=True)
print(len(dataloader))

# ---- Move everything to the GPU ----
gen.cuda()
disc.cuda()
feature_extractor.cuda()
content_criterion.cuda()
adv_criterion.cuda()

# Tensor.cuda() returns a copy rather than moving in place, so the result has
# to be kept; the previous bare `ones_const.cuda()` left the tensor on the CPU.
ones_const = Variable(torch.ones(bs, 1)).cuda()

# ---- Optimizers ----
optim_g=optim.Adam(gen.parameters(),lr=0.01)
optim_d=optim.Adam(disc.parameters(),lr=0.01)

# Reusable GPU buffer that the training loops fill with the low-res batch.
low_res=torch.FloatTensor(bs,3,32,32)
low_res=Variable(low_res.cuda())

# Pretraining the generator with the pixel-wise MSE content loss only.

for epoch in range(2):
    mean_gen_loss=0.0

    for i,data in enumerate(dataloader,0):

        high_res_real,labels=data

        # The DataLoader does not drop the last batch, so it can be shorter
        # than bs; size everything from the batch actually received.
        batch_len=high_res_real.size(0)

        # Downsample into low_res. scale() must see the image before it is
        # normalized in place on the next line.
        for j in range(batch_len):
            low_res[j]=scale(high_res_real[j])
            high_res_real[j]=normalize(high_res_real[j])

        high_res_real=high_res_real.cuda()
        low_res_batch=low_res[:batch_len]

        gen.zero_grad()
        high_res_fake=gen(low_res_batch)

        # Resize the generator output to the 32x32 size of the real images.
        # align_corners=False is the default behaviour, stated explicitly to
        # silence the bilinear-mode warning.
        high_res_fake=F.interpolate(high_res_fake,size=(32,32),mode='bilinear',align_corners=False)
        gen_content_loss=content_criterion(high_res_fake,high_res_real)
        # .item() reads the scalar; indexing a 0-dim loss with [0] fails on PyTorch >= 0.5.
        mean_gen_loss+=gen_content_loss.item()
        gen_content_loss.backward()

        optim_g.step()

        visualizer.show(low_res_batch.cpu().data, high_res_real.cpu().data, high_res_fake.cpu().data)
        
# SRGAN training: alternate one discriminator step and one generator step per batch.

import sys  # needed for the in-place status line below; not imported elsewhere in the notebook

# NOTE(review): the original loop read opt.nEpochs / opt.batchSize, but no `opt`
# object exists in this notebook. The epoch count below mirrors the pretraining
# loop - adjust to the intended value.
n_epochs = 2

for epoch in range(n_epochs):
    mean_generator_content_loss = 0.0
    mean_generator_adversarial_loss = 0.0
    mean_generator_total_loss = 0.0
    mean_discriminator_loss = 0.0

    for i, data in enumerate(dataloader):
        # Generate data
        high_res_real, _ = data

        # The DataLoader does not drop the last batch, so it can be shorter than bs.
        batch_len = high_res_real.size(0)

        # Downsample images to low resolution (scale() must run before the
        # in-place normalization of the same image).
        for j in range(batch_len):
            low_res[j] = scale(high_res_real[j])
            high_res_real[j] = normalize(high_res_real[j])

        # Generate real and fake inputs
        high_res_real = Variable(high_res_real.cuda())
        high_res_fake = gen(low_res[:batch_len])
        # Resize the generator output to the 32x32 size of the real images
        # before it is used anywhere, so the discriminator and the feature
        # extractor see real and fake inputs of identical shape.
        high_res_fake = F.interpolate(high_res_fake, size=(32, 32), mode='bilinear', align_corners=False)

        # Noisy labels, shaped (batch,) to match the discriminator output.
        target_real = (torch.rand(batch_len)*0.5 + 0.7).cuda()
        target_fake = (torch.rand(batch_len)*0.3).cuda()

        ######### Train discriminator #########
        disc.zero_grad()

        # detach() keeps the discriminator step from back-propagating into the generator.
        discriminator_loss = adv_criterion(disc(high_res_real), target_real) + \
                             adv_criterion(disc(high_res_fake.detach()), target_fake)
        mean_discriminator_loss += discriminator_loss.item()

        discriminator_loss.backward()
        optim_d.step()

        ######### Train generator #########
        gen.zero_grad()

        real_features = feature_extractor(high_res_real).detach()
        fake_features = feature_extractor(high_res_fake)

        generator_content_loss = content_criterion(high_res_fake, high_res_real) + 0.006*content_criterion(fake_features, real_features)
        mean_generator_content_loss += generator_content_loss.item()
        fake_prediction = disc(high_res_fake)
        # The generator is rewarded when the discriminator outputs 1 for fakes.
        generator_adversarial_loss = adv_criterion(fake_prediction, torch.ones_like(fake_prediction))
        mean_generator_adversarial_loss += generator_adversarial_loss.item()

        generator_total_loss = generator_content_loss + 1e-3*generator_adversarial_loss
        mean_generator_total_loss += generator_total_loss.item()

        generator_total_loss.backward()
        optim_g.step()


        ######### Status and display #########
        sys.stdout.write('\r[%d/%d][%d/%d] Discriminator_Loss: %.4f Generator_Loss (Content/Advers/Total): %.4f/%.4f/%.4f' % (epoch, n_epochs, i, len(dataloader),
        discriminator_loss.item(), generator_content_loss.item(), generator_adversarial_loss.item(), generator_total_loss.item()))
        visualizer.show(low_res[:batch_len].cpu().data, high_res_real.cpu().data, high_res_fake.cpu().data)

    

  "please use transforms.Resize instead.")
  "matplotlib is currently using a non-GUI backend, "


Files already downloaded and verified
3334


  "See the documentation of nn.Upsample for details.".format(mode))
