In [None]:
import torch
import torchvision
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

In [None]:
def weights_init_normal(m):
    """Initialise one module's parameters in place (pix2pix / DCGAN style).

    Meant to be passed to ``model.apply(weights_init_normal)``, which calls it
    on every submodule.

    - Modules whose class name contains "Conv" (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02).
    - Modules whose class name contains "BatchNorm2d": weight ~ N(1, 0.02),
      bias = 0.

    Parameters that do not exist (e.g. ``BatchNorm2d(affine=False)`` has no
    weight/bias, or a custom container merely named "...Conv...") are skipped
    instead of raising ``AttributeError``.

    Args:
        m: any ``nn.Module``.
    """
    classname = m.__class__.__name__
    # nn.Module.__getattr__ raises AttributeError for unknown names, and
    # affine-less norm layers register weight/bias as None -> use getattr.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if "Conv" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.constant_(bias.data, 0.0)

In [None]:
# Discriminator architecture

class discriminator(nn.Module):
    """PatchGAN discriminator that scores an (image, condition) pair.

    The two inputs are stacked on the channel axis, so the first convolution
    sees ``2 * input_channels`` channels. Four 4x4 / stride-2 convolutions
    (the first one without BatchNorm) each halve the spatial size, then a
    left/top-only zero pad and a final 4x4 / stride-1 convolution produce a
    single-channel map passed through a sigmoid. A 256x256 pair therefore
    yields an (N, 1, 16, 16) map of per-patch scores.

    Args:
        input_channels: channels of EACH of the two input images.
    """

    def __init__(self, input_channels=3):
        super().__init__()

        def down(c_in, c_out, batch_norm=True):
            # conv(4x4, stride 2, pad 1) -> [BatchNorm] -> LeakyReLU(0.2)
            parts = [nn.Conv2d(c_in, c_out, kernel_size=(4, 4), stride=(2, 2), padding=1)]
            if batch_norm:
                parts.append(nn.BatchNorm2d(c_out))
            parts.append(nn.LeakyReLU(0.2))
            return nn.Sequential(*parts)

        self.block_1 = down(input_channels * 2, 64, batch_norm=False)
        self.block_2 = down(64, 128)
        self.block_3 = down(128, 256)
        self.block_4 = down(256, 512)
        # Pad one pixel on the left and on the top only (left, right, top, bottom).
        self.block_5 = nn.ZeroPad2d((1, 0, 1, 0))
        self.block_6 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x, y):
        """Return per-patch scores in [0, 1] for the pair (x, y).

        Args:
            x, y: NCHW batches with equal N, H and W; they are concatenated
                along dim 1 (channels).
        """
        out = torch.cat([x, y], dim=1)
        stages = (self.block_1, self.block_2, self.block_3,
                  self.block_4, self.block_5, self.block_6)
        for stage in stages:
            out = stage(out)
        return out
            
    
def test(image_size=256):
    """Smoke-test the discriminator on a random image pair and print its output shape.

    Defaults to 256x256, the resolution the data pipeline resizes to, so it
    checks the configuration actually used for training. At 256x256 the
    patch map must be (1, 1, 16, 16); other sizes only print the shape.

    Args:
        image_size: height and width of the random input images.
    """
    x = torch.randn((1, 3, image_size, image_size))
    y = torch.randn((1, 3, image_size, image_size))
    preds = discriminator(input_channels=3)
    output = preds(x, y)
    print(output.shape)
    if image_size == 256:
        assert tuple(output.shape) == (1, 1, 16, 16), output.shape
    


In [None]:
# Run the discriminator shape smoke test. Inside a notebook kernel __name__ is
# "__main__", so executing this cell always runs it.
if __name__ == "__main__":
    test()

In [None]:
# Generator architecture (U-Net)

class encoder(nn.Module):
    """U-Net downsampling step.

    4x4 / stride-2 / pad-1 convolution (halves even H and W), then optional
    InstanceNorm, LeakyReLU(0.2) and optional dropout.

    Args:
        input_channels: channels of the incoming feature map.
        output_channels: channels produced by the convolution.
        normalize: append InstanceNorm2d after the convolution.
        dropout: dropout probability; 0 (falsy) omits the dropout layer.
    """

    def __init__(self, input_channels, output_channels, normalize=True, dropout=0.0):
        super(encoder, self).__init__()
        layers = [nn.Conv2d(input_channels, output_channels, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.InstanceNorm2d(output_channels))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class decoder(nn.Module):
    """U-Net upsampling step with skip connection.

    4x4 / stride-2 / pad-1 transposed convolution (doubles H and W),
    InstanceNorm, ReLU and optional dropout; the result is then concatenated
    with the mirrored encoder output along the channel axis, so the returned
    tensor has ``output_channels + skip_input.shape[1]`` channels.

    Args:
        input_channels: channels of the incoming feature map.
        output_channels: channels produced by the transposed convolution.
        dropout: dropout probability; 0 (falsy) omits the dropout layer.
    """

    def __init__(self, input_channels, output_channels, dropout=0.0):
        super(decoder, self).__init__()
        layers = [
            nn.ConvTranspose2d(input_channels, output_channels, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(output_channels),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        # NCHW tensors: dim 1 is the channel axis, so this stacks feature maps.
        return torch.cat((x, skip_input), 1)


class generator(nn.Module):
    """pix2pix U-Net generator.

    Eight stride-2 encoder steps (seven plus the bottleneck) shrink H and W by
    256, so a 256x256 input reaches 1x1 at the bottleneck. Seven decoder steps
    plus a final transposed convolution restore the input size, each step
    concatenating the matching encoder output. H and W should be multiples of
    256 so every skip connection lines up. The output goes through Tanh, so
    values lie in [-1, 1].

    Args:
        input_channels: channels of the input image.
        output_channels: channels of the generated image.
        base_channels: width of the first encoder layer (64 in pix2pix); the
            other widths scale with it (x2, x4, x8).
    """

    def __init__(self, input_channels=3, output_channels=3, base_channels=64):
        super(generator, self).__init__()
        c = base_channels
        # encoder (c=64): C64-C128-C256-C512-C512-C512-C512, then the C512 bottleneck
        self.down1 = encoder(input_channels, c, normalize=False)
        self.down2 = encoder(c, c * 2)
        self.down3 = encoder(c * 2, c * 4)
        self.down4 = encoder(c * 4, c * 8)
        self.down5 = encoder(c * 8, c * 8)
        self.down6 = encoder(c * 8, c * 8)
        self.down7 = encoder(c * 8, c * 8)
        # No normalisation here: for a 256x256 input this output is 1x1, and
        # InstanceNorm over a single spatial element is degenerate.
        self.bottle_neck = encoder(c * 8, c * 8, normalize=False)

        # decoder (c=64, channels after concatenation):
        # CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128 -> dropout on the first three only.
        # Every step after up1 receives its predecessor's output concatenated
        # with a skip connection, hence the doubled input widths.
        self.up1 = decoder(c * 8, c * 8, dropout=0.5)
        self.up2 = decoder(c * 8 * 2, c * 8, dropout=0.5)
        self.up3 = decoder(c * 8 * 2, c * 8, dropout=0.5)
        self.up4 = decoder(c * 8 * 2, c * 8)
        self.up5 = decoder(c * 8 * 2, c * 4)
        self.up6 = decoder(c * 4 * 2, c * 2)
        self.up7 = decoder(c * 2 * 2, c)

        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(c * 2, output_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Translate an NCHW image batch; output has the same N, H and W."""
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)

        bottle_neck = self.bottle_neck(d7)

        u1 = self.up1(bottle_neck, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final_up(u7)
                

In [None]:
class Arial2Map(Dataset):
    """Paired aerial-photo / map images read from ``<root>/maps/maps/train/``.

    Each file is one image: its left half is returned under ``'arial'`` and its
    right half under ``'map'``. Samples are addressed by 1-based file name:
    item ``idx`` is ``<idx + 1>.jpg``.

    Args:
        root: dataset root directory (joined by plain string concatenation).
        transform: optional callable applied to each half separately.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.file = root + '/maps/maps/train/'

    def __len__(self):
        # Every directory entry counts as one sample.
        entries = os.listdir(self.file)
        return len(entries)

    def __getitem__(self, idx):
        path = self.file + str(idx + 1) + '.jpg'
        img = Image.open(path)
        width, height = img.size
        middle = width / 2
        left = img.crop((0, 0, middle, height))
        right = img.crop((middle, 0, width, height))

        if self.transform:
            left, right = self.transform(left), self.transform(right)
        return {'arial': left, 'map': right}

In [None]:
# Resize each image half to 256x256 and convert it to a float CHW tensor.
# NOTE(review): ToTensor yields values in [0, 1] while the generator ends in
# Tanh ([-1, 1]); pix2pix usually normalises to [-1, 1] — consider adding
# transforms.Normalize if the plotting cells are adjusted to match.
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

In [None]:
# Training split of the maps dataset under the pix2pix-dataset root.
train_data = Arial2Map(root='../input/pix2pix-dataset', transform=transform)


In [None]:
# Number of training pairs (one per entry in the train directory).
len(train_data)

In [None]:
# Preview one aerial photo; matplotlib wants channels last, so CHW -> HWC.
plt.imshow(train_data[71]['arial'].movedim(0, -1))

In [None]:
# Shuffled mini-batches of two image pairs each.
dataloader = torch.utils.data.DataLoader(
    train_data,
    batch_size=2,
    shuffle=True,
)

In [None]:
# Train on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
epochs = 200  # full passes over the training set
device

In [None]:
# Per-iteration loss histories, appended to by the training loop.
D_loss, G_loss = [], []

In [None]:
G = generator().to(device)
D = discriminator().to(device)
# Apply the N(0, 0.02) initialisation implemented by weights_init_normal
# (the pix2pix paper's init); without these calls both networks would keep
# PyTorch's default initialisation and the function would be unused.
G.apply(weights_init_normal)
D.apply(weights_init_normal)

criterion_gan = torch.nn.MSELoss()       # adversarial loss on D's patch scores
criterion_pixelwise = torch.nn.L1Loss()  # reconstruction loss against the real map
lambda_pixel = 100                       # weight of the L1 term relative to the GAN term
optimizer_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [None]:
# Training loop: for every batch, one generator update followed by one
# discriminator update.
for epoch in range(epochs):
    for i, imgs in enumerate(dataloader):
        real_arials = imgs['arial'].to(device)
        real_maps = imgs['map'].to(device)

        # ---- train generator ----
        optimizer_G.zero_grad()
        output_gen = G(real_arials)
        pred_fake = D(output_gen, real_arials)
        # Real/fake targets take the exact shape of the discriminator's patch
        # map instead of a hard-coded (N, 1, 16, 16), so they stay valid for
        # any input resolution. ones_like/zeros_like do not require grad.
        valid = torch.ones_like(pred_fake)
        fake = torch.zeros_like(pred_fake)
        # adversarial term: G is rewarded when D scores its output as real
        gan_loss = criterion_gan(pred_fake, valid)
        # pixelwise term against the ground-truth map
        L1_loss = criterion_pixelwise(output_gen, real_maps)
        Gen_loss = gan_loss + lambda_pixel * L1_loss
        Gen_loss.backward()
        optimizer_G.step()

        # ---- train discriminator ----
        # Gen_loss.backward() also accumulated gradients into D's parameters;
        # clear them so D is updated from Disc_loss alone.
        optimizer_D.zero_grad()
        pred_real = D(real_maps, real_arials)
        loss_real = criterion_gan(pred_real, valid)
        # detach(): the discriminator step must not backpropagate into G
        pred_fake = D(output_gen.detach(), real_arials)
        loss_fake = criterion_gan(pred_fake, fake)
        Disc_loss = (loss_real + loss_fake) / 2
        Disc_loss.backward()
        optimizer_D.step()

        D_loss.append(Disc_loss.item())
        G_loss.append(Gen_loss.item())

        if (i + 1) % 500 == 0:
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f adv: %f]\n"
                % (
                    epoch,
                    epochs,
                    i,
                    len(dataloader),
                    Disc_loss.item(),
                    Gen_loss.item(),
                    gan_loss.item(),
                )
            )
        

In [None]:
# Translate one training aerial photo. This is inference only, so skip
# building an autograd graph. G is not switched to eval() here, so its
# dropout layers stay active and repeated runs give different outputs.
with torch.no_grad():
    out = G(train_data[89]['arial'].unsqueeze(0).to(device))

In [None]:
# Generated map. The generator ends in Tanh (range [-1, 1]) while the training
# targets come from ToTensor ([0, 1]); clamp so imshow receives a valid float
# RGB range instead of clipping with a warning.
plt.imshow(out.detach().cpu().squeeze(0).permute(1, 2, 0).clamp(0, 1))

In [None]:
# Input aerial photo for the sample above (CHW -> HWC for matplotlib).
plt.imshow(train_data[89]['arial'].movedim(0, -1))

In [None]:
# Ground-truth map for the same sample (CHW -> HWC for matplotlib).
plt.imshow(train_data[89]['map'].movedim(0, -1))