In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
import os
from PIL import Image
from torchvision.io import read_image
import pandas as pd
import matplotlib.pyplot as plt
import torchvision
import random

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
# Unpack the dataset archive into the current working directory.
# presumably this creates the ./images/ folder read by later cells — confirm against the archive layout
!unzip images.zip

In [None]:
# Delete the archive once it has been extracted. After this cell runs, the
# unzip cell above cannot be re-run until images.zip is provided again.
!rm -rf images.zip

In [None]:
# Sanity check: load one file and report the shape of its first 512 columns
# (the slice taken along the last axis).
img = read_image('./images/1.jpg')
print(img[..., :512].shape)

In [None]:
#U-Net, Generator

class Genearator(nn.Module):
    """U-Net style encoder/decoder generator for 3x512x512 images.

    Five stride-2 convolutions shrink the input from 512 to 16 pixels per
    side, a 1x1 bottleneck keeps that resolution, and five stride-2
    transposed convolutions grow it back to 512. Each decoder stage receives
    the running feature map concatenated (along the channel axis) with the
    encoder output of the same resolution; the last layer additionally sees
    the input image itself. The output goes through ``Tanh``.

    Submodule names and the order of layers inside each ``nn.Sequential``
    define the ``state_dict`` keys, so they must stay as they are for saved
    checkpoints to keep loading.
    """

    @staticmethod
    def _stem(in_channel, out_channel, kernel, stride):
        # First encoder stage: convolution + LeakyReLU, no normalisation.
        return nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=kernel, stride=stride, padding=1),
            nn.LeakyReLU(0.2),
        )

    @staticmethod
    def _down(in_channel, out_channel, kernel, stride, padding):
        # Encoder stage: convolution -> batch norm -> LeakyReLU.
        return nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=kernel, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.2),
        )

    @staticmethod
    def _up(in_channel, out_channel, kernel, stride):
        # Decoder stage: transposed convolution -> batch norm -> ReLU -> dropout.
        return nn.Sequential(
            nn.ConvTranspose2d(in_channel, out_channel, kernel_size=kernel, stride=stride, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    @staticmethod
    def _head(in_channel, out_channel):
        # Output layer: 1x1 transposed convolution squashed by Tanh.
        return nn.Sequential(
            nn.ConvTranspose2d(in_channel, out_channel, kernel_size=1, stride=1),
            nn.Tanh(),
        )

    def __init__(self):
        super().__init__()

        # Encoder; the trailing number is the spatial size each stage produces.
        self.encode1 = self._stem(3, 64, 4, 2)          # 256
        self.encode2 = self._down(64, 128, 4, 2, 1)     # 128
        self.encode3 = self._down(128, 256, 4, 2, 1)    # 64
        self.encode4 = self._down(256, 256, 4, 2, 1)    # 32
        self.encode5 = self._down(256, 256, 4, 2, 1)    # 16

        # 1x1 convolution: resolution and channel count (256) are unchanged.
        self.bottleneck = self._down(256, 256, 1, 1, 0)  # 16

        # Decoder; input channels are "running features + skip connection".
        self.decode1 = self._up(256 + 256, 256, 4, 2)   # 32
        self.decode2 = self._up(256 + 256, 256, 4, 2)   # 64
        self.decode3 = self._up(256 + 256, 128, 4, 2)   # 128
        self.decode4 = self._up(128 + 128, 64, 4, 2)    # 256
        self.decode5 = self._up(64 + 64, 3, 4, 2)       # 512

        # Mixes the 3 decoded channels with the 3 input channels.
        self.final = self._head(3 + 3, 3)               # 512

    def forward(self, x):
        """Translate ``x`` into an image of the same size.

        Args:
            x: tensor that can be viewed as ``(-1, 3, 512, 512)``; an
                unbatched ``(3, 512, 512)`` image gains a batch axis.

        Returns:
            Tensor of shape ``(N, 3, 512, 512)`` with values in ``[-1, 1]``
            (Tanh output).
        """
        source_image = x.view(-1, 3, 512, 512)

        # Run the encoder, remembering every stage's output for the skips.
        skips = []
        features = source_image
        for stage in (self.encode1, self.encode2, self.encode3, self.encode4, self.encode5):
            features = stage(features)
            skips.append(features)

        features = self.bottleneck(features)  # (N, 256, 16, 16)

        # Walk back up, pairing each decoder stage with the most recent
        # unused encoder output (deepest first).
        for stage in (self.decode1, self.decode2, self.decode3, self.decode4, self.decode5):
            features = stage(torch.cat((features, skips.pop()), dim=1))

        return self.final(torch.cat((features, source_image), dim=1))

In [None]:
# PatchGAN discriminator

class Discriminator(nn.Module):
    """PatchGAN-style discriminator over a 6-channel 512x512 input.

    Three stride-2 convolutions reduce 512 -> 256 -> 128 -> 64, then a
    stride-1 convolution produces a single-channel 63x63 map. No activation
    follows the last convolution, so the output values are raw scores.

    The position of every layer inside ``dis_seq`` defines the
    ``state_dict`` keys and must not change.
    """

    def __init__(self):
        super().__init__()

        # Settings shared by all four convolutions.
        conv_kwargs = dict(kernel_size=4, padding=1, padding_mode="reflect")

        self.dis_seq = nn.Sequential(
            nn.Conv2d(6, 64, stride=2, **conv_kwargs),      # 256
            nn.LeakyReLU(),

            nn.Conv2d(64, 128, stride=2, **conv_kwargs),    # 128
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),

            nn.Conv2d(128, 256, stride=2, **conv_kwargs),   # 64
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),

            nn.Conv2d(256, 1, stride=1, **conv_kwargs),     # 63
        )

    def forward(self, x):
        """Score ``x`` after viewing it as ``(-1, 6, 512, 512)``.

        Returns:
            Tensor of shape ``(N, 1, 63, 63)``.
        """
        return self.dis_seq(x.view(-1, 6, 512, 512))

In [None]:
# Build both networks and place their parameters on the selected device.
# The discriminator is created first, then the generator.
discriminator = Discriminator().to(device)
generator = Genearator().to(device)

In [None]:
# Restore both networks from their saved checkpoints. The files are
# deserialised on the CPU and load_state_dict copies the values into the
# existing parameters.
generator.load_state_dict(
    torch.load('generator.pt', map_location='cpu', weights_only=True)
)
discriminator.load_state_dict(
    torch.load('discriminator.pt', map_location='cpu', weights_only=True)
)

In [None]:
#DataLoader

class ImageDataset(Dataset):
    """Image dataset indexed by a CSV file.

    Column 0 of the CSV holds the file name relative to ``img_dir`` and
    column 1 holds the label returned alongside the image.

    Args:
        annotations_file: path (or file-like object) passed to ``pd.read_csv``.
        img_dir: directory the file names are resolved against.
        transform: optional callable applied to the image tensor.
        target_transform: optional callable applied to the label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_dir = img_dir
        self.img_labels = pd.read_csv(annotations_file)
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        """Number of rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is read from disk, scaled from 0..255 to 0..1 and cast to
        float before ``transform`` (if any) is applied.
        """
        file_name = self.img_labels.iloc[idx, 0]
        image = read_image(os.path.join(self.img_dir, file_name)).div(255).float()
        label = self.img_labels.iloc[idx, 1]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label

# ImageDataset yields pixel values in [0, 1]; Normalize with mean = std = 0.5
# on every channel maps them to [-1, 1].
transform = transforms.Compose([
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# The DataLoader below needs batch_size, but the hyperparameter cell that
# defines it comes later in the notebook, so a fresh "Restart & Run All" would
# stop here with a NameError. Define it here with the same value; keep the two
# in sync if the value is ever changed.
batch_size = 16

images_dataset = ImageDataset(img_dir = 'images',
                             annotations_file = 'labels.csv',
                             transform = transform)
images_dl = DataLoader(images_dataset, batch_size=batch_size)

# NOTE(review): this is one more than the number of rows in labels.csv. Any
# consumer that treats it as an inclusive upper bound on 1-based file numbers
# would overshoot by one — confirm the intended meaning.
total_images = len(images_dataset)+1

In [None]:
# Training hyperparameters.
epochs = 10_000     # number of passes over the data loader
lr = 2e-4           # learning rate handed to both Adam optimizers
batch_size = 16     # samples per mini-batch
l1_var = 100        # weight of the L1 term in the generator loss

In [None]:
# Objectives: binary cross-entropy on raw (pre-sigmoid) scores for the
# adversarial terms, and mean absolute error for the reconstruction term.
loss_bce = nn.BCEWithLogitsLoss()
loss_l1 = nn.L1Loss()

# One Adam optimizer per network, with identical settings.
optimizer_G = optim.Adam(params=generator.parameters(), betas=(0.5, 0.999), lr=lr)
optimizer_D = optim.Adam(params=discriminator.parameters(), betas=(0.5, 0.999), lr=lr)

In [None]:
# The training loop saves checkpoints and shows a preview every this many epochs.
save_and_show_result_val = 20


def show_result():
    """Plot source, generated and target images for one random dataset file.

    A file is drawn from the rows of ``images_dataset`` so that only files
    listed in the annotations can be requested. The file is read once; its
    first 512 columns are the source and the remaining columns the target.

    The generator is run in eval mode without gradient tracking and is put
    back into whatever mode it was in before the call, so invoking this from
    inside the training loop does not leave the model in eval mode.
    """
    row = random.randrange(len(images_dataset))
    file_name = str(images_dataset.img_labels.iloc[row, 0])
    file_path = os.path.join(images_dataset.img_dir, file_name)

    # Normalisation is per channel, so normalising the pair before splitting
    # gives the same values as splitting first.
    pair = transform((read_image(file_path) / 255).float()).to(device)
    test_src_image = pair[:, :, :512]
    test_target_image = pair[:, :, 512:]

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            gen_img = generator(test_src_image)
    finally:
        # Restore the previous mode even if the forward pass fails.
        generator.train(was_training)

    def to_display(tensor):
        # Undo the [-1, 1] normalisation: imshow expects float RGB in [0, 1]
        # with the channel axis last.
        return (tensor.detach().cpu() * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0)

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].imshow(to_display(test_src_image))
    axs[0].set_title("Source(" + file_name + ")")

    axs[1].imshow(to_display(gen_img.squeeze(0)))
    axs[1].set_title("Generated")

    axs[2].imshow(to_display(test_target_image))
    axs[2].set_title("Target")

    plt.show()

def show_result_single(img):
    """Run the generator on one image file and display the output.

    Args:
        img: path of the image file to read. After scaling and normalisation
            it is fed to the generator as is, so it has to be viewable as
            ``(-1, 3, 512, 512)``.

    The generator is run in eval mode without gradient tracking and is put
    back into its previous mode afterwards.
    """
    src_img = (read_image(img)/255).float()
    src_img = transform(src_img).to(device)

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            gen_img = generator(src_img)
    finally:
        # Restore the previous mode even if the forward pass fails.
        generator.train(was_training)

    # The generator output lies in [-1, 1]; imshow expects float RGB in [0, 1]
    # with the channel axis last.
    picture = (gen_img.cpu().squeeze(0) * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0)
    plt.imshow(picture)
    plt.show()

In [None]:
# Training

for epoch in range(epochs):

    # The preview helpers run the generator in eval mode; make sure Dropout
    # and BatchNorm are in training behaviour for every epoch.
    generator.train()
    discriminator.train()

    for img, label in images_dl:

        img = img.to(device)

        # Each sample holds the pair side by side along the width axis:
        # the first 512 columns are the source, the rest is the target.
        src_img = img[:, :, :, :512]
        tgt_img = img[:, :, :, 512:]

        gen_img = generator(src_img)

        # ---- Generator step ----

        optimizer_G.zero_grad()

        srcGenCat_gen = torch.cat((src_img, gen_img), dim=1)
        # The discriminator output must stay attached to the graph here:
        # detaching it would cut the adversarial gradient and leave the
        # generator learning from the L1 term only.
        dis_call_gen = discriminator(srcGenCat_gen).squeeze()
        G_adv_loss = loss_bce(dis_call_gen, torch.ones_like(dis_call_gen))
        G_l1_loss = loss_l1(gen_img, tgt_img)
        loss_G = G_adv_loss + l1_var * G_l1_loss
        loss_G.backward()
        optimizer_G.step()

        # ---- Discriminator step ----

        # Also discards the gradients the generator's backward pass left on
        # the discriminator parameters.
        optimizer_D.zero_grad()

        srcTgtCat = torch.cat((src_img, tgt_img), dim=1)
        # The generated image is detached so this step does not backpropagate
        # into the generator.
        srcGenCat = torch.cat((src_img, gen_img.detach()), dim=1)
        dis_call = discriminator(srcTgtCat).squeeze()
        real_loss = loss_bce(dis_call, torch.ones_like(dis_call))
        dis_call = discriminator(srcGenCat).squeeze()
        fake_loss = loss_bce(dis_call, torch.zeros_like(dis_call))
        loss_D = (fake_loss+real_loss)/2
        loss_D.backward()
        optimizer_D.step()

    # Report the losses of the last mini-batch of the epoch.
    gen_loss = loss_G.item()
    dis_loss = loss_D.item()
    print(f"Epoch{epoch+1} Discriminator Loss: {dis_loss}, Generator Loss: {gen_loss}")

    # Periodically checkpoint both networks and show a preview.
    if ((epoch+1)%save_and_show_result_val) == 0:

        torch.save(generator.state_dict(), 'generator.pt')
        torch.save(discriminator.state_dict(), 'discriminator.pt')

        show_result()


In [None]:
# Plot a source / generated / target triplet using the current generator.
show_result()

In [None]:
# Run the generator on a standalone image file and display its output.
show_result_single('test.jpg')

# Deploy

In [None]:
!pip install onnx

In [None]:
# Export in inference mode so Dropout is disabled and BatchNorm uses its
# running statistics; switching before any forward pass also keeps the random
# dummy batch from touching those statistics.
generator.eval()

# torch.onnx.export expects an example *input* of the model. A batch of one
# 3x512x512 image matches the shape the generator works on internally.
dummy_input = torch.randn(1, 3, 512, 512, device=device)

with torch.no_grad():
    torch.onnx.export(generator, dummy_input, "generator.onnx", input_names=['input'])