In [None]:
from torch.utils.data import Sampler, DataLoader
import matplotlib.pyplot as plt
import random
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision.io import read_image
from torchvision import transforms
import torchvision.transforms.functional as tf

import cv2
from torch.utils.data import Dataset, DataLoader
import torchtext 
import os
import matplotlib.pyplot as plt
import time

# Reproducibility: seed every RNG module imported above. numpy drives the horizontal-flip
# augmentation and torch drives weight init, random crops and dropout; `random` is imported
# too, so it is seeded as well. torch.manual_seed also seeds CUDA devices.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Fetch the pix2pix "facades" archive (torchtext stores it under ./.data/) and unpack it
# into ./data, which the dataset cells below read as data/facades/<split>.
FACADES_URL = "http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz"
FACADES_ARCHIVE = "./.data/facades.tar.gz"
torchtext.utils.download_from_url(FACADES_URL)
torchtext.utils.extract_archive(FACADES_ARCHIVE, "data")

In [None]:
# Debug: list the working directory (presumably to check that the download/extract cell
# above created ./.data and ./data here).
!ls -al
!pwd

In [None]:
class CustomDataset(Dataset):
  """Paired pix2pix dataset: each JPEG holds the target (left half) and the input (right half).

  Images are read from ``os.path.join(img_dir, data)`` and must be named ``1.jpg`` ... ``N.jpg``.

  Args:
    img_dir: dataset root, e.g. "data/facades".
    data: split sub-directory ("train", "val", "test").
    transform: if truthy, apply the same random jitter to both halves (resize to
      ``jitter_size``, random-crop back to the original size, 50% horizontal flip).
    jitter_size: (H, W) to resize to before the random crop. Defaults to (286, 286).

  Each item is an ``(input, real)`` pair of float tensors scaled to [-1, 1].
  """
  def __init__(self, img_dir, data="train", transform=True, jitter_size=(286, 286)):
    self.transform = transform
    self.jitter_size = jitter_size
    self.img_dir = os.path.join(img_dir, data)

  def __len__(self):
    # Count only the files __getitem__ can actually load. Other directory entries
    # (e.g. .DS_Store, checkpoint folders) would inflate the length and make the last
    # indices point at JPEGs that do not exist.
    return sum(1 for name in os.listdir(self.img_dir) if name.endswith(".jpg"))

  def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, f"{idx+1}.jpg")
    img = read_image(img_path, torchvision.io.image.ImageReadMode.RGB)
    half = img.shape[2] // 2
    inp = img[:, :, half:]
    real = img[:, :, :half]
    if self.transform:
      inp, real = self.transform_imgs(inp, real, self.jitter_size)
    # uint8 [0, 255] -> float [-1, 1].
    inp = inp / 127.5 - 1
    real = real / 127.5 - 1
    return inp.float(), real.float()

  def transform_imgs(self, input, real, resize_dim):
    """Jitter ``input`` and ``real`` identically and return them.

    Both are resized to ``resize_dim``, cropped back to the original (H, W) at a shared
    random offset, and flipped horizontally together with probability 0.5.
    """
    org_dim = (input.shape[1], input.shape[2])
    resize = transforms.Resize(size=resize_dim)
    input, real = resize(input), resize(real)
    # One set of crop parameters for both images keeps the pair aligned.
    i, j, h, w = transforms.RandomCrop.get_params(input, output_size=org_dim)
    input, real = tf.crop(input, i, j, h, w), tf.crop(real, i, j, h, w)
    if np.random.rand() > 0.5:
      input, real = tf.hflip(input), tf.hflip(real)
    return input, real

In [None]:
# Load one augmented (input, real) pair for a visual sanity check. Note that this uses
# the "val" split even though the variable is named train_data.
train_data = CustomDataset(img_dir="data/facades", data="val", transform=True)
preview_idx = 1
img, real = train_data[preview_idx]

In [None]:
# Dataset tensors are scaled to [-1, 1]. Map them back to [0, 1] before plotting;
# otherwise imshow clips every negative value and the images render mostly black.
plt.imshow(img.permute(1, 2, 0) * 0.5 + 0.5)
plt.show()

plt.imshow(real.permute(1, 2, 0) * 0.5 + 0.5)
plt.show()

In [None]:
class DownSampleBlock(nn.Module):
  """Encoder/discriminator block: 4x4 Conv2d -> optional BatchNorm2d -> LeakyReLU.

  Args:
    in_ch: input channel count.
    out_ch: output channel count.
    use_batchnorm: insert BatchNorm2d after the convolution when True.
    stride: convolution stride (2 with padding 1 halves even H and W).
    padding: convolution padding.
    negative_slope: LeakyReLU slope. Defaults to 0.01 (PyTorch's default, the value this
      block has always used). The pix2pix paper uses 0.2.
  """
  def __init__(self, in_ch, out_ch, use_batchnorm=False, stride=2, padding=1, negative_slope=0.01):
    super().__init__()
    self.conv1 = nn.Conv2d(in_ch, out_ch, 4, stride=stride, padding=padding, bias=False)
    # Weights drawn from N(0, 0.02).
    nn.init.normal_(self.conv1.weight, mean=0.0, std=0.02)
    self.bn = nn.BatchNorm2d(out_ch) if use_batchnorm else None
    self.relu = nn.LeakyReLU(negative_slope)

  def forward(self, x):
    x = self.conv1(x)
    # Explicit None check rather than relying on Module truthiness.
    if self.bn is not None:
      x = self.bn(x)
    return self.relu(x)

class UpSampleBlock(nn.Module):
  def __init__(self, in_ch, out_ch, use_dropout=True):
    super().__init__()
    self.conv1 = nn.ConvTranspose2d(in_ch, out_ch, 4,stride=2, padding=1,bias=False)
    nn.init.normal_(self.conv1.weight, mean=0.0, std=0.02)
    self.bn = nn.BatchNorm2d(out_ch)
    self.relu  = nn.ReLU()
    self.dropout = nn.Dropout(0.5) if use_dropout else None
  
  def forward(self, x):
    x = self.conv1(x)
    x = self.bn(x)
    if self.dropout:
      x = self.dropout(x)
    x = self.relu(x)
    return x

In [None]:
# Smoke-test the building blocks on the preview image.
print(img.shape)
down = DownSampleBlock(3, 3)
x = down(img.float())        # unbatched (C, H, W) input
print(x.shape)
up = UpSampleBlock(3, 3)
x = up(x.unsqueeze(0))       # add a batch dimension: BatchNorm2d needs 4-D input
print(x.shape)

In [None]:
class Encoder(nn.Module):
  """Chain of DownSampleBlocks that also records every intermediate activation.

  Args:
    chs: channel sizes; block i maps chs[i] -> chs[i+1].
    batch_norm: per-block flags; batch_norm[i] is passed as use_batchnorm to block i.
  """
  def __init__(self, chs, batch_norm):
    super().__init__()
    blocks = []
    for i, (c_in, c_out) in enumerate(zip(chs[:-1], chs[1:])):
      blocks.append(DownSampleBlock(c_in, c_out, batch_norm[i]))
    self.enc_blocks = nn.ModuleList(blocks)

  def forward(self, x):
    """Return (final activation, list of every block's output in order)."""
    features = []
    for blk in self.enc_blocks:
      x = blk(x)
      features.append(x)
    return x, features

In [None]:
# Smoke-test the encoder on a batch of two copies of the preview image.
chs = [3, 64, 128, 256, 512, 512, 512, 512, 512]
bn = [i != 0 for i in range(len(chs))]   # no BatchNorm on the first block
enc = Encoder(chs, bn)
single = img.float().unsqueeze(0)
y = torch.cat([single, single])
print(y.shape)
x, ftrs = enc(y)

In [None]:
class Decoder(nn.Module):
  """Stack of UpSampleBlocks with U-Net skip concatenation.

  Block i maps ``in_ch -> chs[i+1]``. ``in_ch`` is ``chs[0]`` for the first block, which
  sees only the bottleneck, and ``2 * chs[i]`` afterwards: the previous output
  concatenated with a skip feature that is assumed to also have ``chs[i]`` channels.

  Args:
    chs: decoder channel sizes (by default the Generator passes the encoder's list reversed).
    dropout: per-block flags. ``dropout[i]`` is used as ``use_dropout`` for block i,
      including block 0.
  """
  def __init__(self, chs, dropout):
    super().__init__()
    blocks = []
    for i in range(len(chs) - 1):
      in_ch = chs[0] if i == 0 else 2 * chs[i]
      blocks.append(UpSampleBlock(in_ch, chs[i + 1], dropout[i]))
    # forward() zips these blocks with the skip list and stops at the shorter of the two.
    # With the Generator's wiring (one fewer skip than blocks), the final block never runs.
    # It is kept so state_dict keys and shapes stay compatible with saved checkpoints.
    self.dec_blocks = nn.ModuleList(blocks)

  def forward(self, x, encoder_features):
    """Upsample ``x``, concatenating skip feature k on dim 1 after block k.

    Args:
      x: bottleneck activation.
      encoder_features: skip tensors ordered to match the blocks (the Generator passes
        them deepest first).
    Returns:
      The final concatenated tensor.
    """
    for block, ftr in zip(self.dec_blocks, encoder_features):
      x = torch.cat([block(x), ftr], dim=1)
    return x

In [None]:
# Generator configuration. These module-level lists become Generator's default arguments.
enc_chs = [3,64,128,256,512,512,512,512,512]
# Size the flag lists from enc_chs/dec_chs themselves, not from `chs` (defined in the
# encoder smoke-test cell), so this cell does not depend on another cell having run.
enc_bn = [True]*len(enc_chs)
enc_bn[0]=False                      # no BatchNorm on the first encoder block

y = torch.cat([img.float().unsqueeze(0),img.float().unsqueeze(0)])
dec_chs = enc_chs[::-1]
dec_dropout = [False]*len(dec_chs)
dec_dropout[0:3]=[True]*3            # dropout only in the first three decoder blocks

In [None]:
class Generator(nn.Module):
  """pix2pix U-Net generator: Encoder -> Decoder with skip concatenation -> Tanh head.

  The defaults are the module-level ``enc_chs``/``enc_bn``/``dec_chs``/``dec_dropout``
  lists, evaluated once when this class is defined.

  Args:
    enc_chs: encoder channel list.
    enc_bn: per-encoder-block BatchNorm flags.
    dec_chs: decoder channel list.
    dec_dropout: per-decoder-block dropout flags.
    lambdaa: weight of the L1 term in ``get_loss``.
  """
  def __init__(self, enc_chs=enc_chs, enc_bn = enc_bn, dec_chs=dec_chs, dec_dropout=dec_dropout,lambdaa=100):
    super().__init__()
    self.encoder = Encoder(enc_chs, enc_bn)
    self.decoder = Decoder(dec_chs, dec_dropout)
    # The head consumes the decoder's last concatenation: the output of the last decoder
    # block that forward() reaches (dec_chs[-2] channels) plus the first encoder feature
    # map (enc_chs[1] channels). That is 64 + 64 = 128 for the default config. This assumes
    # len(dec_chs) == len(enc_chs), as with the mirrored default. The head emits dec_chs[-1] channels.
    head_in_ch = dec_chs[-2] + enc_chs[1]
    self.head = nn.ConvTranspose2d(head_in_ch, dec_chs[-1], 4, stride=2, padding=1, bias=False)
    nn.init.normal_(self.head.weight, mean=0.0, std=0.02)
    self.head_act = nn.Tanh()   # output in [-1, 1], matching the dataset's scaling
    self.lambdaa = lambdaa

  def forward(self, x):
    """Map a batch of input images to generated images in [-1, 1]."""
    x, ftrs = self.encoder(x)
    # Skips are consumed deepest first. The deepest "feature" is the bottleneck x itself,
    # so it is dropped from the skip list.
    ftrs.reverse()
    out = self.decoder(x, ftrs[1:])
    return self.head_act(self.head(out))

  def get_loss(self,disc_out, gen_out, target):
    """Generator loss: BCE(disc_out, 1) + lambdaa * L1(gen_out, target).

    Args:
      disc_out: discriminator logits for (input, generated) pairs.
      gen_out: generator output.
      target: ground-truth images.
    Returns:
      (total loss, GAN term, L1 term).
    """
    loss_fn = torch.nn.BCEWithLogitsLoss()
    gan_loss = loss_fn(disc_out, torch.ones_like(disc_out))
    l1_loss = nn.L1Loss()(gen_out, target)
    loss = gan_loss + self.lambdaa * l1_loss
    return loss, gan_loss, l1_loss

In [None]:
class Discriminator(nn.Module):
  """PatchGAN discriminator over (input, target) image pairs.

  The two images are concatenated on the channel axis (6 channels into the first block),
  and the network returns a map of raw logits. There is no sigmoid, so pair it with
  BCEWithLogitsLoss.
  """
  def __init__(self):
    super().__init__()
    self.down1 = DownSampleBlock(6,64,False)
    self.down2 = DownSampleBlock(64,128,True)
    self.down3 = DownSampleBlock(128,256,True)
    self.down4 = DownSampleBlock(256,512,True,stride=1, padding=1)
    self.conv = nn.Conv2d(512, 1, 4,stride=1, padding=1, bias=False)
    nn.init.normal_(self.conv.weight, mean=0.0, std=0.02)

  def forward(self, input, target):
    x = torch.cat([input, target], dim=1)
    for layer in (self.down1, self.down2, self.down3, self.down4, self.conv):
      x = layer(x)
    return x

  def get_loss(self, out, target):
    """Discriminator BCE loss: real pairs -> 1, generated pairs -> 0.

    Args:
      out: logits for (input, real target) pairs. The notebook passes ``dis_real`` here.
      target: logits for (input, generated) pairs. The notebook passes ``dis_gen`` here.
    Returns:
      Sum of the real and fake BCE terms (scalar tensor).
    """
    loss_fn = torch.nn.BCEWithLogitsLoss()
    # Previously the labels were swapped (real scored against 0, fake against 1), which
    # trained the discriminator to call real pairs fake.
    loss = loss_fn(out, torch.ones_like(out))
    loss = loss + loss_fn(target, torch.zeros_like(target))
    return loss

In [None]:
# Forward-pass smoke test on a batch of two identical preview pairs. The generator output
# is expected to match the input's shape; the discriminator emits one logit map per sample.
gen = Generator()
y1 = torch.cat([img.float().unsqueeze(0)] * 2)
y2 = torch.cat([real.float().unsqueeze(0)] * 2)

pred = gen(y1)
print(pred.shape)

dis = Discriminator()
pred_dis = dis(y1, y2)
print(pred_dis.shape)

In [None]:
# Evaluate both losses once on the smoke-test batch (no optimizer step).
gen_out = gen(y1)
dis_real = dis(y1, y2)
dis_gen = dis(y1, gen_out.detach())
dis_loss = dis.get_loss(dis_real, dis_gen)
print(dis_loss.item())
# dis_gen was computed from a detached output, so this GAN term carries no gradient
# back to gen. That is fine for a printout but not for training.
gen_loss, gen_gan_loss, gen_l1_loss = gen.get_loss(dis_gen, gen_out, y2)
print(*(t.item() for t in (gen_loss, gen_gan_loss, gen_l1_loss)))

In [None]:
# Show the preview input next to the freshly constructed generator's output. Both are
# mapped from [-1, 1] back to [0, 1] for display.
for shown in (y[0], pred[0].detach()):
    plt.imshow(shown.permute(1, 2, 0) * 0.5 + 0.5)
    plt.show()

print(pred[0].max(), pred[0].min())

In [None]:
# Training and held-out data, plus fresh models on the target device.
train_data = CustomDataset(img_dir="data/facades", data="train", transform=True)
# No random jitter on the test split: evaluation should see the images unmodified.
# This matches the later test-visualization cells, which also use transform=False.
test_data = CustomDataset(img_dir="data/facades", data="test", transform=False)

train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=16, shuffle=True)

gen = Generator().to(device)
dis = Discriminator().to(device)

In [None]:
# train_inp, train_tar = next(iter(train_dataloader))
# test_inp, test_tar = next(iter(test_dataloader))

# l = list(enumerate(train_dataloader))
# len(l)

In [None]:
# Training length, and identical Adam settings (lr 2e-4, betas (0.5, 0.999)) for both networks.
epochs = 300
adam_kwargs = dict(lr=2e-4, betas=(0.5, 0.999))
gen_optim = torch.optim.Adam(gen.parameters(), **adam_kwargs)
dis_optim = torch.optim.Adam(dis.parameters(), **adam_kwargs)

In [None]:
# from torch.utils.tensorboard import SummaryWriter
# writer = SummaryWriter('runs/pix2pix')

In [None]:
def generate_images(model, test_input, tar, idx=0):
  """Plot input, ground truth and model prediction for sample ``idx`` of a batch.

  Args:
    model: generator. It is run on the whole batch, moved to the global ``device``.
    test_input: batch of input images in [-1, 1], shape (N, C, H, W), on the CPU.
    tar: matching batch of target images in [-1, 1], on the CPU.
    idx: index of the sample within the batch to display (default 0).
  """
  test_input = test_input.float()
  # Inference only, so no autograd graph is built. The model's train/eval mode is left
  # untouched, so BatchNorm and Dropout behave exactly as the caller configured them.
  with torch.no_grad():
    prediction = model(test_input.to(device)).cpu()
  plt.figure(figsize=(15, 15))

  display_list = [test_input[idx], tar[idx], prediction[idx]]
  title = ['Input Image', 'Ground Truth', 'Predicted Image']

  for i, (image, caption) in enumerate(zip(display_list, title)):
    plt.subplot(1, 3, i + 1)
    plt.title(caption)
    # Map pixel values from [-1, 1] to [0, 1] for plotting.
    plt.imshow((image * 0.5 + 0.5).permute(1, 2, 0))
    plt.axis('off')
  plt.show()

In [None]:
# Baseline: the model's output on smoke-test sample 0 before the training loop runs.
generate_images(gen, y1, y2, idx=0)

In [None]:
# Main training loop: one discriminator step and one generator step per batch.
# Anomaly detection is a debugging aid that slows every backward pass considerably.
# Make sure it is off, because a kernel that ran an earlier version of this cell may still have it on.
torch.autograd.set_detect_anomaly(False)
bce_loss_fn = nn.BCEWithLogitsLoss()
l1_loss_fn = nn.L1Loss()

for epoch in range(epochs):
    # Per-epoch sums of [G total, G GAN, G L1, D total].
    running_loss = np.zeros(4, dtype=float)
    for batch_idx, (inp, tar) in enumerate(train_dataloader):
        inp, tar = inp.to(device), tar.to(device)

        # Train D: real pairs -> 1, generated pairs -> 0. The label tensors are shaped
        # like D's output (30x30 patch maps for this data), so they always match.
        dis.zero_grad()
        real_patch = dis(inp, tar)
        real_gan_loss = bce_loss_fn(real_patch, torch.ones_like(real_patch))

        fake = gen(inp)
        # detach: the D step must not push gradients into the generator.
        fake_patch = dis(inp, fake.detach())
        fake_gan_loss = bce_loss_fn(fake_patch, torch.zeros_like(fake_patch))

        D_loss = real_gan_loss + fake_gan_loss
        D_loss.backward()
        dis_optim.step()

        # Train G: have D label generated pairs as real, plus gen.lambdaa (100 by default)
        # times the L1 distance to the target. This backward also leaves gradients on D's
        # parameters; dis.zero_grad() clears them before the next D step.
        gen.zero_grad()
        fake_patch = dis(inp, fake)
        fake_gan_loss = bce_loss_fn(fake_patch, torch.ones_like(fake_patch))
        L1_loss = l1_loss_fn(fake, tar)
        G_loss = fake_gan_loss + gen.lambdaa * L1_loss
        G_loss.backward()
        gen_optim.step()

        running_loss += np.array([G_loss.item(), fake_gan_loss.item(), L1_loss.item(), D_loss.item()], dtype=float)
        if (batch_idx+1) % 4 == 0:
            print("Step:",batch_idx+1,"Gen loss:",round(G_loss.item(),2),"Gen GAN loss:",round(fake_gan_loss.item(),2),"Gen L1:",round(L1_loss.item(),2),"Dis loss:",round(D_loss.item(),2))

    # Average over the number of batches; max(..., 1) guards against an empty loader.
    n = max(len(train_dataloader), 1)
    running_loss = np.around(running_loss / n, decimals=2)
    print("Epoch:",epoch+1,"Gen loss:",running_loss[0],"Gen GAN loss:",running_loss[1],"Gen L1:",running_loss[2],"Dis loss:",running_loss[3])
    if (epoch+1) % 5 == 0:
        generate_images(gen, y1, y2, 0)

In [None]:
# NOTE(review): the "_500" suffix suggests 500 epochs, but if the notebook runs top to
# bottom only the first training loop (epochs = 300) has finished here. Confirm the name.
gen_state = gen.state_dict()
torch.save(gen_state, 'PIX2PIX_GEN_500.ckpt')

In [None]:
# NOTE(review): the "_500" suffix suggests 500 epochs, but if the notebook runs top to
# bottom only the first training loop (epochs = 300) has finished here. Confirm the name.
dis_state = dis.state_dict()
torch.save(dis_state, 'PIX2PIX_DIS_500.ckpt')

In [None]:
# Visualize every sample of one random batch of un-augmented test images.
test_data = CustomDataset(img_dir="data/facades", data="test", transform=False)
test_dataloader = DataLoader(test_data, batch_size=16, shuffle=True)
test_inp, test_tar = next(iter(test_dataloader))

for sample_idx in range(test_inp.shape[0]):
    generate_images(gen, test_inp, test_tar, sample_idx)

In [None]:
# Continue training for another `epochs` epochs. Epoch numbers in the log continue from
# start_epoch (the 300 epochs of the first loop).
epochs = 200
start_epoch = 300
# Anomaly detection is a debugging aid that slows every backward pass; keep it off.
torch.autograd.set_detect_anomaly(False)
bce_loss_fn = nn.BCEWithLogitsLoss()
l1_loss_fn = nn.L1Loss()

for epoch in range(epochs):
    # Per-epoch sums of [G total, G GAN, G L1, D total].
    running_loss = np.zeros(4, dtype=float)
    for batch_idx, (inp, tar) in enumerate(train_dataloader):
        inp, tar = inp.to(device), tar.to(device)

        # Train D: real pairs -> 1, generated pairs -> 0. The label tensors are shaped
        # like D's output (30x30 patch maps for this data), so they always match.
        dis.zero_grad()
        real_patch = dis(inp, tar)
        real_gan_loss = bce_loss_fn(real_patch, torch.ones_like(real_patch))

        fake = gen(inp)
        # detach: the D step must not push gradients into the generator.
        fake_patch = dis(inp, fake.detach())
        fake_gan_loss = bce_loss_fn(fake_patch, torch.zeros_like(fake_patch))

        D_loss = real_gan_loss + fake_gan_loss
        D_loss.backward()
        dis_optim.step()

        # Train G: have D label generated pairs as real, plus gen.lambdaa (100 by default)
        # times the L1 distance to the target. This backward also leaves gradients on D's
        # parameters; dis.zero_grad() clears them before the next D step.
        gen.zero_grad()
        fake_patch = dis(inp, fake)
        fake_gan_loss = bce_loss_fn(fake_patch, torch.ones_like(fake_patch))
        L1_loss = l1_loss_fn(fake, tar)
        G_loss = fake_gan_loss + gen.lambdaa * L1_loss
        G_loss.backward()
        gen_optim.step()

        running_loss += np.array([G_loss.item(), fake_gan_loss.item(), L1_loss.item(), D_loss.item()], dtype=float)
        if (batch_idx+1) % 4 == 0:
            print("Step:",batch_idx+1,"Gen loss:",round(G_loss.item(),2),"Gen GAN loss:",round(fake_gan_loss.item(),2),"Gen L1:",round(L1_loss.item(),2),"Dis loss:",round(D_loss.item(),2))

    # Average over the number of batches; max(..., 1) guards against an empty loader.
    n = max(len(train_dataloader), 1)
    running_loss = np.around(running_loss / n, decimals=2)
    print("Epoch:",start_epoch+epoch+1,"Gen loss:",running_loss[0],"Gen GAN loss:",running_loss[1],"Gen L1:",running_loss[2],"Dis loss:",running_loss[3])
    if (epoch+1) % 5 == 0:
        generate_images(gen, y1, y2, 0)

In [None]:
# After the extra epochs, visualize every sample of one random, un-augmented test batch.
test_loader_kwargs = dict(batch_size=16, shuffle=True)
test_data = CustomDataset(img_dir="data/facades", data="test", transform=False)
test_dataloader = DataLoader(test_data, **test_loader_kwargs)
test_inp, test_tar = next(iter(test_dataloader))
for sample_idx, _ in enumerate(test_inp):
    generate_images(gen, test_inp, test_tar, sample_idx)

In [None]:
# Debug: show the working directory's contents and path (e.g. before mounting Drive).
!ls
!pwd

In [None]:
# Mount Google Drive in the Colab VM (Colab-only: google.colab is not available elsewhere).
from google.colab import drive
drive.mount('/content/drive')

# Folder inside "My Drive" to work from. The assert always passes here because
# FOLDERNAME is set to a string on the line above.
FOLDERNAME = "cs231n/"
assert FOLDERNAME is not None, "[!] Enter the foldername."

# Make Python files stored in that Drive folder importable.
import sys
sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))

# Change the working directory into the Drive folder so relative paths used afterwards
# (such as the checkpoint files saved in the following cells) land on Drive.
%cd /content/drive/My\ Drive/$FOLDERNAME

In [None]:
# Save the generator weights to the current working directory (the Drive folder if the
# mount cell above ran).
gen_state = gen.state_dict()
torch.save(gen_state, 'PIX2PIX_GEN_500.ckpt')

In [None]:
# Save the discriminator weights to the current working directory (the Drive folder if
# the mount cell above ran).
dis_state = dis.state_dict()
torch.save(dis_state, 'PIX2PIX_DIS_500.ckpt')