
# All the necessary imports

In [1]:
import torch
import numpy as np
import torch.nn as nn
import torchvision
from torch.utils.data  import DataLoader,Dataset
import cv2
import matplotlib.pyplot as plt
import itertools 

# Resnet blocks 
### Used 6 in my experiment because I worked with 128x128 image size (suggested in the official paper to use 6 for 128x128 and 9 for 256x256 and up)

In [2]:
class resnet6block(nn.Module):
    def __init__(self):
        super(resnet6block,self).__init__()
        self.conv_block = self.make_block()
        
    def make_block(self):
        self.conv_block = []
        self.conv_block+=[nn.ReflectionPad2d(1)]
        self.conv_block+=[nn.Conv2d(256,256,kernel_size=3,padding=0,bias=True)]
        self.conv_block+=[nn.ReflectionPad2d(1)]
        self.conv_block+=[nn.Conv2d(256,256,kernel_size=3,padding=0,bias=True)]
        
        return nn.Sequential(*self.conv_block)
    def forward(self,x):
        return x + self.conv_block(x)

# The Generator class of our model
### NOTE: You can add 9 resnet blocks by changing the 6 to 9 in the for loop in the __init__() method

In [3]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(3,64,kernel_size=7,stride=1,padding=0,bias=True),
                nn.InstanceNorm2d(64),
                nn.ReLU(True)]
        #downsampling
        model+=[
               nn.Conv2d(64,128,kernel_size=3,stride=2,padding=1,bias=True),
                nn.InstanceNorm2d(128),
                nn.ReLU(True)]
        model +=[nn.Conv2d(128,256,kernel_size=3,stride=2,padding=1,bias=True),
                nn.InstanceNorm2d(256),
                nn.ReLU(True)]
        #6 resnet layers
        for i in range(6):
            model+= [resnet6block()]
        #upsampling layer
        model += [nn.Upsample(scale_factor=2),
                  nn.ReflectionPad2d(1),
                 nn.Conv2d(256,128,kernel_size=3,stride=1,padding=0,bias=True),
                 nn.InstanceNorm2d(128),
                 nn.ReLU(True)]
        model += [nn.Upsample(scale_factor=2),
                 nn.ReflectionPad2d(1),
                 nn.Conv2d(128,64,kernel_size=3,stride=1,padding=0,bias=True),
                 nn.InstanceNorm2d(64),
                 nn.ReLU(True)]
        model+=[nn.ReflectionPad2d(3),
               nn.Conv2d(64,3,kernel_size=7,stride=1,padding=0),
               nn.Tanh()]
        self.model = nn.Sequential(*model)
        
        #return self.model
    def forward(self,x):
        return self.model(x)

# The Discriminator class

### Exactly as suggested in the official paper

In [4]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        model=[nn.Conv2d(3,64,kernel_size=4,stride=2,padding=1,bias=True), 
              nn.LeakyReLU(0.02,True)]
        
        model +=[nn.Conv2d(64,128,kernel_size=4,stride=2,padding=1,bias=True),
                nn.InstanceNorm2d(128),
                nn.LeakyReLU(0.02,True)]
        
        model +=[ nn.Conv2d(128,256,kernel_size=4,stride=2,padding=1,bias=True),
                nn.InstanceNorm2d(128),
                nn.LeakyReLU(0.02,True)]
       
        model +=[ nn.Conv2d(256,512,kernel_size=4,stride=2,padding=1,bias=True),
                nn.InstanceNorm2d(512),
                nn.LeakyReLU(0.02,True)]
        
        model +=[nn.ZeroPad2d((1,0,1,0))]
        
        model +=[nn.Conv2d(512,1,kernel_size=4,stride=1,padding=1,bias=True)]
        
        self.model = nn.Sequential(*model)
        
        
    
    def  forward(self,x):
        return self.model(x)

## Weights initialising 

In [5]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv')!=-1:
        torch.nn.init.normal_(m.weight.data,0.0,0.02)
    elif classname.find('BatchNorm2d')!=-1:
        torch.nn.init.normal_(m.weight.data,1.0,0.02)
        torch.nn.init.constant_(m.bias.data,0.0)
    

# The buffer store 50 previously generated images
### As suggested in the paper

In [6]:
class replaybuffer():
    def __init__(self,max_size=50):
        self.max_size= max_size
        self.data =[]
    
    def push_and_pop(self,data):
        to_return =[]
        for element in data.data:
            element= torch.unsqueeze(element,0)
            if len(self.data)<self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if np.random.uniform(0,1)>0.5:
                    i = np.random.randint(0,self.max_size-1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return torch.autograd.Variable(torch.cat(to_return))
        

## Gradually decrease the Learning rates of the model

In [7]:
class LambdaLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)

In [8]:
generator_A = Generator()
generator_B = Generator()
discriminator_A = Discriminator()
discriminator_B = Discriminator()
generator_A.apply(weights_init)
generator_B.apply(weights_init)
discriminator_A.apply(weights_init)
discriminator_B.apply(weights_init)
generator_A.cuda()
generator_B.cuda()
discriminator_A.cuda()
discriminator_B.cuda()

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.02, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.02, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.02, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.02, inplace=True)
    (11): ZeroPad2d(padding=(1, 0, 1, 0), value=0.0)
    (12): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)

# Adam optimizer with lr = 0.0002 , beta1 = 0.5 and beta2 = 0.999
## As suggested in the official paper

In [11]:
optimizerG = torch.optim.Adam(itertools.chain(generator_A.parameters(),generator_B.parameters()),lr = 0.0002, betas=(0.5,0.999),weight_decay=1e-5)
optimizerDA = torch.optim.Adam(discriminator_A.parameters(),lr = 0.0002, betas=(0.5,0.999))
optimizerDB = torch.optim.Adam(discriminator_B.parameters(),lr=0.0002,betas =(0.5,0.999))

# The losses as suggested by the paper

In [12]:
GANloss = nn.MSELoss()
cycle = nn.L1Loss()
identity = nn.L1Loss()

In [13]:
lr_g = torch.optim.lr_scheduler.LambdaLR(optimizerG,lr_lambda=LambdaLR(200,0,100).step)
lr_da = torch.optim.lr_scheduler.LambdaLR(optimizerDA,lr_lambda=LambdaLR(200,0,100).step)
lr_db = torch.optim.lr_scheduler.LambdaLR(optimizerDB,lr_lambda=LambdaLR(200,0,100).step)

In [14]:
fake_A_buffer = replaybuffer()
fake_B_buffer = replaybuffer()

## Loading the data

In [15]:
imageA = torchvision.datasets.ImageFolder('/path/to/data/',transform=torchvision.transforms.Compose(
[torchvision.transforms.Resize(128),
 torchvision.transforms.ToTensor(),
 torchvision.transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
        
    ]))

In [16]:
imageB = torchvision.datasets.ImageFolder('/path/to/data/',transform=torchvision.transforms.Compose(
[torchvision.transforms.Resize(128),
 torchvision.transforms.ToTensor(),
 torchvision.transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
        
    ]))

In [17]:
dataA = DataLoader(imageA,shuffle=True,batch_size = 1, num_workers=8 )
dataB = DataLoader(imageB,shuffle=True,batch_size = 1 , num_workers=8)

# Targets for the patch gan discriminator used in the paper

In [18]:
target_real = torch.tensor(np.ones((1,1,8,8)),dtype = torch.float32).cuda()
target_fake = torch.tensor(np.zeros((1,1,8,8)),dtype= torch.float32).cuda()

# Training loop

In [None]:
for i in range(200):
    for imgA,imgB in zip(enumerate(dataA),enumerate(dataB)):
        generator_A.train()
        generator_B.train()
        optimizerG.zero_grad()
       
        setA = imgA[1][0].cuda()
        setB = imgB[1][0].cuda()
        
        fakeA = generator_A(setB)
        fakeB = generator_B(setA)
        
        rec_A = generator_A(fakeB)
        rec_B = generator_B(fakeA)
        
        pred_fake_A = discriminator_A(fakeA)
        gla = GANloss(pred_fake_A,target_real)
        
        pred_fake_B = discriminator_B(fakeB)
        glb = GANloss(pred_fake_B,target_real)
        
        cycle_a = cycle(rec_A,setA)
        cycle_b = cycle(rec_B,setB)
        
        identity_A = generator_A(setA)
        identity_B = generator_B(setB)
        
        identity_loss =(identity(identity_A,setA)+identity(identity_B,setB))/2
        
        generator_loss = (gla+glb)/2+((cycle_a+cycle_b)/2)*10.0+identity_loss*5.0
        
        generator_loss.backward()
        optimizerG.step()
        
        #########################################
        optimizerDA.requires_grad = True
        optimizerDA.zero_grad()
        
        pred_real_A = discriminator_A(setA)
        dar_loss= GANloss(pred_real_A,target_real)
        
        pred_fake_A = discriminator_A(fake_A_buffer.push_and_pop(fakeA))
        daf_loss = GANloss(pred_fake_A,target_fake)
        
        da_loss = (dar_loss+daf_loss)/2
        da_loss.backward()
        optimizerDA.step()
        
        ###########################################
        optimizerDB.requires_grad=True
        optimizerDB.zero_grad()
        pred_real_B = discriminator_B(setB.detach())
        dbr_loss = GANloss(pred_real_B,target_real)
        
        pred_fake_b = discriminator_B(fake_B_buffer.push_and_pop(fakeB))
        dbf_loss = GANloss(pred_fake_b,target_fake)
        
        db_loss = (dbf_loss+dbr_loss)/2
        db_loss.backward()
        optimizerDB.step()
        
        ###########################################
        
        if(imgA[0]%10==0 or imgA[0]==0):
            fakeim = torch.cat([setB,fakeA,setA,fakeB],2)
            fakeim = fakeim.squeeze(0)
            fakeim = fakeim*.5+.5
            fakeim = np.transpose(fakeim.cpu().detach())
            plt.imshow(fakeim)
            plt.show()

            print(f'epoch: {i} G:{generator_loss}, D:{db_loss+da_loss}, Cycle:{((cycle_a+cycle_b)/2)*10}')
    lr_da.step()
    lr_db.step()
    lr_g.step()
    
    torch.save(generator_A.state_dict(),'Path')
    torch.save(generator_B.state_dict(),'Path')
    torch.save(discriminator_A.state_dict(),'Path')
    torch.save(discriminator_B.state_dict(),'Path')
        
        
        

# Check out the official repository https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix


# If you'd like a read of the official paper https://arxiv.org/abs/1703.10593 