In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as f 
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torchvision
import itertools
import glob
from torch.utils.data import DataLoader,Dataset
from torch.autograd import Variable
import datetime
import time
from torchvision.utils import save_image
import os
from PIL import Image

# Resnet blocks as per the paper

In [None]:
class resnet(torch.nn.Module):
    """Residual block for 256-channel feature maps.

    The residual branch is two 3x3 convolutions (256 -> 256 channels), each
    preceded by a one-pixel reflection pad so the spatial size is unchanged.
    The branch contains no normalisation or activation layers.
    """

    def __init__(self):
        super().__init__()
        self.resblock = self.make_block()

    def make_block(self):
        """Build the residual branch.

        The plain list of layers is also kept on ``self.conv_block``.

        Returns:
            nn.Sequential: pad -> conv -> pad -> conv.
        """
        self.conv_block = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, 256, kernel_size=3, padding=0, bias=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, 256, kernel_size=3, padding=0, bias=True),
        ]
        return nn.Sequential(*self.conv_block)

    def forward(self, x):
        """Return the input plus the output of the residual branch."""
        residual = self.resblock(x)
        return x + residual

# Generator model 

In [None]:
class Generator(nn.Module):
    """Generator with a shared encoder and two decoder heads.

    * ``self.conv`` encodes a 3-channel image into 256-channel features
      (7x7 stem, two stride-2 convolutions, six ``resnet`` blocks).
    * ``self.content_deconv`` decodes the features into 30 channels, read as
      ten 3-channel candidate images after a ``tanh``.
    * ``self.attention_deconv`` decodes the features into 10 channels, read as
      ten single-channel masks after a softmax over the channel axis.

    The output is the sum of the ten candidate images, each weighted by its
    mask.
    """

    def __init__(self):
        super().__init__()

        # Down-sampling encoder: 3 -> 64 -> 128 -> 256 channels.
        encoder = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(256),
            nn.ReLU(True),
        ]
        # Six residual blocks on the 256-channel features.
        for _ in range(6):
            encoder.append(resnet())
        self.conv = nn.Sequential(*encoder)

        # Two structurally identical up-sampling heads that differ only in
        # their number of output channels. The content head is built first so
        # layer creation order is content -> attention.
        self.content_deconv = self._decoder_head(30)
        self.attention_deconv = self._decoder_head(10)
        self.tanh = nn.Tanh()

    @staticmethod
    def _decoder_head(out_channels):
        """Build one up-sampling head ending in ``out_channels`` channels."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2),
            nn.ReflectionPad2d(1),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, out_channels, kernel_size=7, stride=1, padding=0, bias=True),
        )

    def forward(self, x):
        """Translate a batch of images.

        Args:
            x: tensor whose channel axis (dim 1) has size 3, as required by
                the first convolution.

        Returns:
            Tensor with 3 channels: the mask-weighted sum of the ten
            candidate images.
        """
        features = self.conv(x)
        content = self.content_deconv(features)
        attention = self.attention_deconv(features)

        # Candidate images squashed by tanh; masks normalised across the
        # channel axis so the ten masks sum to one at every pixel.
        image = self.tanh(content)
        attention = nn.Softmax(dim=1)(attention)

        op = None
        for k in range(10):
            candidate = image[:, 3 * k:3 * k + 3, :, :]
            # Repeat the single-channel mask over the three colour channels.
            mask = attention[:, k:k + 1, :, :].repeat(1, 3, 1, 1)
            weighted = candidate * mask
            op = weighted if op is None else op + weighted

        return op


# Discriminator Model

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a map of per-patch scores.

    Four stride-2 4x4 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels)
    are followed by an asymmetric zero pad and a final 4x4 convolution down
    to one channel. For a 128x128 input the output has shape (N, 1, 8, 8).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # NOTE(review): the LeakyReLU slope is 0.02 throughout; confirm this
        # is intended (0.2 is the more common choice) before changing it.
        model = [
            # First stage has no normalisation layer.
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.02, True),

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=True),
            # The norm's feature count must match the 256 channels produced
            # by the convolution above (it was previously declared as 128).
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.02, True),

            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.02, True),

            # Pad one pixel on the left and top only, so the final
            # stride-1 convolution keeps the spatial size.
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1, bias=True),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Return the one-channel score map for the image batch ``x``."""
        return self.model(x)

# Weights initialisation

In [None]:
def weights_init(m):
    """Initialise one module in place; meant for ``Module.apply``.

    * Classes whose name contains ``'Conv'``: weights drawn from
      N(mean=0.0, std=0.02); the bias is left untouched.
    * Classes whose name contains ``'BatchNorm2d'``: weights drawn from
      N(mean=1.0, std=0.02) and bias set to 0.
    * Every other module (including ``InstanceNorm2d``, whose name matches
      neither pattern) is left unchanged.
    """
    layer_name = m.__class__.__name__
    if 'Conv' in layer_name:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm2d' in layer_name:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
    

# Buffer to store previously generated images

In [None]:
class replaybuffer():
    """Pool of previously generated images.

    ``push_and_pop`` is fed each new batch of generated images. While the pool
    holds fewer than ``max_size`` images it stores every image and returns it
    unchanged. Once full, each incoming image is, with probability 0.5,
    swapped with a randomly chosen stored image (the stored one is returned);
    otherwise the incoming image is returned and the pool is left as it is.

    Args:
        max_size: maximum number of images kept. A value <= 0 disables the
            pool: every image is returned unchanged.
    """

    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Exchange a batch of images with the pool.

        Args:
            data: tensor whose first axis indexes the images of the batch.

        Returns:
            A tensor with the same shape as ``data``, detached from the
            autograd graph, mixing incoming and previously stored images.
        """
        to_return = []
        # detach(): the pool must not keep the generator graph alive.
        for element in data.detach():
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                # Pool not full yet: store and pass the image through.
                self.data.append(element)
                to_return.append(element)
            elif self.max_size > 0 and np.random.uniform(0, 1) > 0.5:
                # np.random.randint excludes its upper bound, so pass
                # max_size (not max_size - 1) to make every slot eligible.
                i = np.random.randint(0, self.max_size)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        return torch.cat(to_return)
        

# Class to reduce the learning rate

In [None]:
class LambdaLR:
    """Linear learning-rate decay rule.

    ``step`` returns a multiplier that is 1.0 until ``decay_start_epoch`` and
    then falls linearly, reaching 0.0 at ``n_epochs``.

    Args:
        n_epochs: epoch at which the multiplier reaches zero.
        offset: value added to the epoch passed to ``step``.
        decay_start_epoch: epoch at which the decay begins.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        self.n_epochs, self.offset, self.decay_start_epoch = (
            n_epochs,
            offset,
            decay_start_epoch,
        )

    def step(self, epoch):
        """Return the learning-rate multiplier for ``epoch``."""
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        decay_length = self.n_epochs - self.decay_start_epoch
        return 1.0 - epochs_into_decay / decay_length

In [None]:
# Two generators and two discriminators, one pair per translation direction.
generator_A = Generator()
generator_B = Generator()
discriminator_A = Discriminator()
discriminator_B = Discriminator()

# Use the GPU when one is available and fall back to the CPU otherwise, so
# this cell no longer has to be edited by hand on machines without CUDA.
_model_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

_networks = (generator_A, generator_B, discriminator_A, discriminator_B)
# Custom initialisation first (same order as the models were created) ...
for _net in _networks:
    _net.apply(weights_init)
# ... then move every network to the selected device.
for _net in _networks:
    _net.to(_model_device)

In [None]:
# Adam settings shared by all three optimizers.
_adam_common = dict(lr=0.0002, betas=(0.5, 0.999))

# A single optimizer updates both generators; only it uses weight decay.
optimizerG = torch.optim.Adam(
    itertools.chain(generator_A.parameters(), generator_B.parameters()),
    weight_decay=1e-5,
    **_adam_common,
)
# Each discriminator gets its own optimizer.
optimizerDA = torch.optim.Adam(discriminator_A.parameters(), **_adam_common)
optimizerDB = torch.optim.Adam(discriminator_B.parameters(), **_adam_common)

# Loss functions

In [None]:
# Adversarial loss: mean squared error against the real/fake target maps.
GANloss = nn.MSELoss()
# Cycle-consistency and identity losses: mean absolute error (two separate
# module instances).
cycle, identity = nn.L1Loss(), nn.L1Loss()

In [None]:
# One scheduler per optimizer. Each wraps PyTorch's LambdaLR scheduler around
# the notebook's own LambdaLR rule (200 epochs, offset 0, decay from epoch 100).
# NOTE(review): with that rule the multiplier stays at 1.0 for the first 100
# scheduler steps, so a run shorter than that never decays the learning rate.
lr_g, lr_da, lr_db = [
    torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=LambdaLR(200, 0, 100).step)
    for opt in (optimizerG, optimizerDA, optimizerDB)
]

In [None]:
# Independent history pools for generated A-domain and B-domain images
# (default capacity of the replaybuffer class).
fake_A_buffer, fake_B_buffer = replaybuffer(), replaybuffer()

# Loading the dataset

In [None]:
class ImageDataset(Dataset):
    """Paired or unpaired image dataset read from two sibling folders.

    Files are collected from ``<root>/<mode>/a/*.*`` (domain A) and
    ``<root>/<mode>/b/*.*`` (domain B), each list sorted by path.

    Args:
        root: directory containing the ``<mode>`` folder.
        transforms_: callable applied to every opened image; when ``None``
            the opened PIL image is returned as is.
        unaligned: if true, the B image is drawn at random for every item
            instead of being indexed in step with A.
        mode: name of the sub-folder of ``root`` to read (default ``'train'``).
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
        self.transform = transforms_
        self.unaligned = unaligned

        self.files_A = sorted(glob.glob(os.path.join(root, '%s/a' % mode) + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, '%s/b' % mode) + '/*.*'))

    def _load(self, path):
        """Open ``path`` and apply the transform, if one was given."""
        image = Image.open(path)
        if self.transform is None:
            return image
        return self.transform(image)

    def __getitem__(self, index):
        """Return ``{'A': ..., 'B': ...}`` for ``index``.

        Indices wrap around the shorter file list, so any index below
        ``len(self)`` is valid.
        """
        item_A = self._load(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            # np.random.randint excludes its upper bound: pass len(files_B)
            # so the last file can be drawn and a single-file folder works.
            index_B = np.random.randint(0, len(self.files_B))
        else:
            index_B = index % len(self.files_B)
        item_B = self._load(self.files_B[index_B])

        return {'A': item_A, 'B': item_B}

    def __len__(self):
        """Length of the longer of the two file lists."""
        return max(len(self.files_A), len(self.files_B))

In [None]:
# Preprocessing applied to every image loaded by the dataset.
# NOTE(review): Resize is given a single int, not an (h, w) pair; check how
# that treats non-square images, because the training targets below assume a
# fixed 8x8 discriminator output.
tm = torchvision.transforms.Compose([
    torchvision.transforms.Resize(128),
    # PIL image -> float tensor.
    torchvision.transforms.ToTensor(),
    # Per-channel (x - 0.5) / 0.5 for three channels.
    torchvision.transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])

In [None]:
# Root folder of the dataset; ImageDataset reads <root>/train/a and
# <root>/train/b. Set the ATTENTIONGAN_DATA_ROOT environment variable to
# point elsewhere; the original location is kept as the default.
_data_root = os.environ.get('ATTENTIONGAN_DATA_ROOT', '/home/prateek/Desktop/data1/')

# Unaligned sampling: each A image is paired with a randomly drawn B image.
dataloader = DataLoader(
    ImageDataset(_data_root, transforms_=tm, unaligned=True),
    batch_size=1,
    shuffle=True,
    num_workers=8,
)

In [None]:
# Discriminator targets, shaped (batch size, channels, rows, columns).
# Rows and columns depend on the image size: 8x8 is the discriminator's
# output size for 128x128 inputs.
# Built directly on the GPU when one is available, otherwise on the CPU.
_target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
target_real = torch.ones((1, 1, 8, 8), dtype=torch.float32, device=_target_device)
target_fake = torch.zeros((1, 1, 8, 8), dtype=torch.float32, device=_target_device)

# Training loop

### AttentionGAN converges much faster than CycleGAN, so I trained for 60 epochs, which is also what the paper recommends

In [None]:
#training loop
def _discriminator_step(discriminator, optimizer, real, fake, buffer):
    """Run one optimisation step for a single discriminator.

    The discriminator is scored on ``real`` against ``target_real`` and on a
    batch drawn from ``buffer`` (fed with ``fake``) against ``target_fake``.

    Returns:
        The averaged real/fake loss tensor that was back-propagated.
    """
    optimizer.zero_grad()
    loss_real = GANloss(discriminator(real), target_real)
    # push_and_pop returns detached images, so no gradient reaches the
    # generators from this step.
    loss_fake = GANloss(discriminator(buffer.push_and_pop(fake)), target_fake)
    loss = (loss_real + loss_fake) / 2
    loss.backward()
    optimizer.step()
    return loss


def _show_preview(setA, setB, fakeA, fakeB):
    """Show the first sample of each batch side by side: B, fake A, A, fake B."""
    # Concatenate along the width axis, then move channels last for imshow.
    # (Reversing all axes with np.transpose displayed every image transposed.)
    panel = torch.cat([setB, fakeA, setA, fakeB], 3)[0]
    # Undo the (x - 0.5) / 0.5 normalisation and keep values within [0, 1].
    panel = (panel * .5 + .5).clamp(0, 1)
    plt.imshow(panel.detach().cpu().permute(1, 2, 0).numpy())
    plt.show()


# Batches are moved to whatever device the generators live on.
device = next(generator_A.parameters()).device

for epoch in range(0,60):
    generator_A.train()
    generator_B.train()
    for i,batch in enumerate(dataloader):
        setA = batch['A'].to(device)
        setB = batch['B'].to(device)

        ######################################### generators
        optimizerG.zero_grad()

        # Translate each domain into the other ...
        fakeA = generator_A(setB)
        fakeB = generator_B(setA)

        # ... and back again for the cycle-consistency term.
        rec_A = generator_A(fakeB)
        rec_B = generator_B(fakeA)

        # Adversarial terms: the generators want fakes scored as real.
        gla = GANloss(discriminator_A(fakeA),target_real)
        glb = GANloss(discriminator_B(fakeB),target_real)

        cycle_a = cycle(rec_A,setA)
        cycle_b = cycle(rec_B,setB)

        # Identity term: a generator fed an image of its own output domain
        # should return it unchanged.
        identity_A = generator_A(setA)
        identity_B = generator_B(setB)
        identity_loss =(identity(identity_A,setA)+identity(identity_B,setB))/2

        # Weights: adversarial 1, cycle 10, identity 5.
        generator_loss = (gla+glb)/2+((cycle_a+cycle_b)/2)*10.0+identity_loss*5.0

        generator_loss.backward()
        optimizerG.step()

        ######################################### discriminators
        da_loss = _discriminator_step(discriminator_A, optimizerDA, setA, fakeA, fake_A_buffer)
        db_loss = _discriminator_step(discriminator_B, optimizerDB, setB, fakeB, fake_B_buffer)

        ######################################### progress report
        if i%100==0:
            _show_preview(setA, setB, fakeA, fakeB)

            # .item() prints plain numbers instead of tensor reprs.
            print(f'epoch: {epoch} G:{generator_loss.item()}, D:{(db_loss+da_loss).item()}, Cycle:{(((cycle_a+cycle_b)/2)*10).item()}')

    # One scheduler step per epoch.
    lr_da.step()
    lr_db.step()
    lr_g.step()

    # Overwrite the checkpoints after every epoch.
    torch.save(generator_A.state_dict(),'/home/prateek/Desktop/GA.pt')
    torch.save(generator_B.state_dict(),'/home/prateek/Desktop/GB.pt')
    torch.save(discriminator_A.state_dict(),'/home/prateek/Desktop/DA.pt')
    torch.save(discriminator_B.state_dict(),'/home/prateek/Desktop/DB.pt')
        
        
        