In [None]:
import os
import numpy as np
from PIL import Image
import torchvision.transforms as T
from torchvision.utils import save_image

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import optim

In [None]:
class ConvBlock(nn.Module):
  """Reflect-padded Conv2d -> InstanceNorm2d -> optional ReLU.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels produced by the convolution.
    kernel_size, stride, padding: forwarded to nn.Conv2d (padding uses 'reflect' mode).
    use_activation: when True (default) a ReLU is applied after the normalization.
  """
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_activation=True):
    super().__init__()
    conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, padding_mode='reflect')
    self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)
    self.inst_norm = nn.InstanceNorm2d(out_channels)
    self.use_activation = use_activation

  def forward(self, x):
    normalized = self.inst_norm(self.conv(x))
    return F.relu(normalized) if self.use_activation else normalized

In [None]:
class ResidualBlock(nn.Module):
  """Two ConvBlocks with an identity skip connection: relu(x + conv2(conv1(x))).

  The skip connection adds the input element-wise to the block output. The block must
  therefore preserve both the channel count and the spatial size, e.g. stride=1 with
  padding = kernel_size // 2.

  Args:
    in_channels: channels of the input; must equal out_channels.
    out_channels: channels produced by both convolutions.
    kernel_size, stride, padding: forwarded to both ConvBlocks.

  Raises:
    ValueError: if in_channels != out_channels (the identity add would fail).
  """
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
    super().__init__()
    if in_channels != out_channels:
      raise ValueError(
          f"ResidualBlock requires in_channels == out_channels for the identity skip, "
          f"got {in_channels} and {out_channels}")
    self.conv1 = ConvBlock(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, use_activation=True)
    # conv2 consumes conv1's output, so its input width is out_channels (not in_channels).
    self.conv2 = ConvBlock(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, use_activation=False)

  def forward(self, x):
    # NOTE(review): several CycleGAN implementations return x + block(x) with no trailing ReLU;
    # the ReLU is kept here to preserve this model's existing behaviour — confirm it is intended.
    return F.relu(x + self.conv2(self.conv1(x)))

In [None]:
class FractionalConvBlock(nn.Module):
  """Upsampling stage: ConvTranspose2d -> InstanceNorm2d -> ReLU.

  Args:
    in_channels, out_channels: channel widths of the transposed convolution.
    kernel_size, stride, padding, output_padding: forwarded to nn.ConvTranspose2d.
  """
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding):
    super().__init__()
    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                  stride=stride, padding=padding, output_padding=output_padding)
    stages = [upsample, nn.InstanceNorm2d(out_channels), nn.ReLU()]
    # Kept as a single Sequential named `frac_seq` so state_dict keys stay frac_seq.0.*.
    self.frac_seq = nn.Sequential(*stages)

  def forward(self, x):
    upsampled = self.frac_seq(x)
    return upsampled

In [None]:
class Generator(nn.Module):
  """CycleGAN-style image-to-image generator.

  Layout: 7x7 conv (64) -> two stride-2 downsampling convs (128, 256) ->
  `num_of_residual_blocks` ResidualBlocks at 256 channels -> two stride-2 transposed
  convs (128, 64) -> 7x7 conv back to `in_channels` -> tanh.

  Args:
    in_channels: channels of the input image; the output has the same channel count.
    num_of_residual_blocks: number of ResidualBlocks in the bottleneck.

  forward returns a tensor with `in_channels` channels and values in [-1, 1] (tanh). Its
  spatial size equals the input's when height and width are divisible by 4.
  """
  def __init__(self, in_channels, num_of_residual_blocks):
    super().__init__()
    # Encoder.
    self.conv1 = ConvBlock(in_channels=in_channels, out_channels=64, kernel_size=7, stride=1, padding=3)
    self.conv2 = ConvBlock(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
    self.conv3 = ConvBlock(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)

    # Residual bottleneck. The plain list is kept as an attribute for compatibility.
    # Parameters are registered through the nn.Sequential below.
    self.num_of_residual_blocks = num_of_residual_blocks
    self.residual_blocks = [
        ResidualBlock(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
        for _ in range(self.num_of_residual_blocks)
    ]
    self.residual_layers = nn.Sequential(*self.residual_blocks)

    # Decoder.
    self.frac_conv1 = FractionalConvBlock(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1)
    self.frac_conv2 = FractionalConvBlock(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)
    # Output projection. Only its convolution is used in forward(); see the comment there.
    self.conv4 = ConvBlock(in_channels=64, out_channels=in_channels, kernel_size=7, stride=1, padding=3, use_activation=False)

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.residual_layers(x)
    x = self.frac_conv1(x)
    x = self.frac_conv2(x)
    # Apply only conv4's convolution and skip its InstanceNorm2d. Instance-normalising the
    # final image channels would force every output channel of every image to zero mean /
    # unit variance before tanh, strongly constraining the colour statistics the generator
    # can produce. The ConvBlock wrapper is kept so state_dict keys (conv4.conv.*) are unchanged.
    return torch.tanh(self.conv4.conv(x))

In [None]:
class DiscConvBlock(nn.Module):
  """Discriminator stage: reflect-padded Conv2d -> optional InstanceNorm2d -> LeakyReLU(0.2).

  Args:
    in_channels, out_channels: channel widths of the convolution.
    kernel_size, stride, padding: forwarded to nn.Conv2d (padding uses 'reflect' mode).
    use_instancenorm: when False the normalization is skipped (used for the first stage).
  """
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_instancenorm=True):
    super().__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                          stride=stride, padding=padding, padding_mode='reflect')
    # Always constructed; only applied when use_instancenorm is True.
    self.inst_norm = nn.InstanceNorm2d(out_channels)
    self.use_instancenorm = use_instancenorm

  def forward(self, x):
    features = self.conv(x)
    if not self.use_instancenorm:
      return F.leaky_relu(features, negative_slope=0.2)
    return F.leaky_relu(self.inst_norm(features), negative_slope=0.2)

In [None]:
class Discriminator(nn.Module):
  """Convolutional patch discriminator producing a spatial map of scores in (0, 1).

  Four DiscConvBlock stages (64, 128, 256, 512 channels; 4x4 kernels; padding 1) are followed
  by a 4x4 conv to a single channel and a sigmoid.

  Args:
    in_channels: channels of the input image.
  """
  def __init__(self, in_channels):
    super().__init__()
    # (out_channels, stride, use_instancenorm) for each LeakyReLU stage.
    stage_specs = [(64, 2, False), (128, 2, True), (256, 2, True), (512, 1, True)]
    layers = []
    width = in_channels
    for out_ch, step, normalize in stage_specs:
      layers.append(DiscConvBlock(in_channels=width, out_channels=out_ch, kernel_size=4,
                                  stride=step, padding=1, use_instancenorm=normalize))
      width = out_ch
    layers.append(nn.Conv2d(in_channels=width, out_channels=1, kernel_size=4, stride=1,
                            padding=1, padding_mode='reflect'))
    layers.append(nn.Sigmoid())
    # Single Sequential named `disc_seq` so state_dict keys stay disc_seq.<index>.*.
    self.disc_seq = nn.Sequential(*layers)

  def forward(self, x):
    scores = self.disc_seq(x)
    return scores

In [None]:
class CycleGanDataset(Dataset):
  """Unpaired summer/winter image dataset for CycleGAN training.

  Item `index` pairs summer file `index % n_summer` with winter file `index % n_winter`. The
  dataset length is the size of the larger folder, so the smaller one is cycled.

  Args:
    summer_root_dir: directory whose entries are all summer images.
    winter_root_dir: directory whose entries are all winter images.
    transform: optional callable applied to each PIL RGB image. The default is
      ToTensor followed by Normalize(0.5, 0.5), which maps pixels to [-1, 1] to match the
      generators' tanh output range.

  Raises:
    ValueError: if either directory is empty.
  """
  def __init__(self, summer_root_dir, winter_root_dir, transform=None):
    self.summer_root_dir = summer_root_dir
    self.winter_root_dir = winter_root_dir
    # Sorted so the index -> file mapping does not depend on filesystem listing order.
    self.summer_images = sorted(os.listdir(summer_root_dir))
    self.winter_images = sorted(os.listdir(winter_root_dir))

    self.length_summer = len(self.summer_images)
    self.length_winter = len(self.winter_images)
    if self.length_summer == 0 or self.length_winter == 0:
      # Otherwise __getitem__ would fail later with a ZeroDivisionError from `index % 0`.
      raise ValueError(
          f"Empty image directory: {summer_root_dir!r} has {self.length_summer} files, "
          f"{winter_root_dir!r} has {self.length_winter} files")
    self.length_dataset = max(self.length_summer, self.length_winter)

    if transform is None:
      transform = T.Compose([
          T.ToTensor(),  # [0, 1]
          T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # -> [-1, 1]
      ])
    self.image_transform = transform

  def __len__(self):
    return self.length_dataset

  def _load_image(self, root_dir, file_name):
    """Open one image, force 3-channel RGB, apply the transform and close the file."""
    with Image.open(os.path.join(root_dir, file_name)) as img:
      return self.image_transform(img.convert("RGB"))

  def __getitem__(self, index):
    summer_image = self._load_image(self.summer_root_dir, self.summer_images[index % self.length_summer])
    winter_image = self._load_image(self.winter_root_dir, self.winter_images[index % self.length_winter])
    return summer_image, winter_image

In [None]:
# Training split: trainA holds summer images, trainB holds winter images.
train_dataset = CycleGanDataset(
    summer_root_dir="/content/drive/MyDrive/SummerToWinterCycleGANDataset/trainA",
    winter_root_dir="/content/drive/MyDrive/SummerToWinterCycleGANDataset/trainB",
)

In [None]:
# Reshuffled every epoch; batch_size=1 is DataLoader's default, spelled out for clarity.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

In [None]:
# One discriminator and one generator per domain, all operating on 3-channel images.
IMG_CHANNELS = 3
NUM_RESIDUAL_BLOCKS = 9

disc_summer = Discriminator(in_channels=IMG_CHANNELS).to(device)
disc_winter = Discriminator(in_channels=IMG_CHANNELS).to(device)

gen_summer = Generator(in_channels=IMG_CHANNELS, num_of_residual_blocks=NUM_RESIDUAL_BLOCKS).to(device)
gen_winter = Generator(in_channels=IMG_CHANNELS, num_of_residual_blocks=NUM_RESIDUAL_BLOCKS).to(device)

In [None]:
# A single Adam optimizer updates both discriminators jointly.
disc_params = [*disc_summer.parameters(), *disc_winter.parameters()]
optimizer_disc = optim.Adam(disc_params, lr=2e-4, betas=(0.5, 0.999))

In [None]:
# A single Adam optimizer updates both generators jointly.
gen_params = [*gen_summer.parameters(), *gen_winter.parameters()]
optimizer_gen = optim.Adam(gen_params, lr=2e-4, betas=(0.5, 0.999))

In [None]:
# L1 is used for the cycle-consistency terms, MSE for the least-squares adversarial terms.
l1_loss, mse = nn.L1Loss(), nn.MSELoss()

In [None]:
# Training hyper-parameters and output locations.
NUM_EPOCHS = 200
LAMBDA_CYCLE = 10        # weight of each cycle-consistency (L1) term in the generator loss
SAMPLE_EVERY = 200       # save sample translations every N batches
CHECKPOINT_EVERY = 10    # save all four models every N epochs
SAMPLE_DIR = "/content/drive/MyDrive/teststow"
CHECKPOINT_DIR = "/content/drive/MyDrive/stow_checkpoints"

# Make sure the output directories exist before anything is written into them.
os.makedirs(SAMPLE_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

count = 1
for epoch in range(1, NUM_EPOCHS + 1):
  for idx, (summer_image, winter_image) in enumerate(train_dataloader):
    summer_image = summer_image.to(device)
    winter_image = winter_image.to(device)

    # ---- Discriminator step (least-squares targets: real -> 1, fake -> 0) ----
    fake_summer = gen_summer(winter_image)
    disc_summer_real = disc_summer(summer_image)
    # detach(): the discriminator update must not backpropagate into the generators.
    disc_summer_fake = disc_summer(fake_summer.detach())
    disc_summer_real_loss = mse(disc_summer_real, torch.ones_like(disc_summer_real))
    disc_summer_fake_loss = mse(disc_summer_fake, torch.zeros_like(disc_summer_fake))
    disc_summer_loss = disc_summer_real_loss + disc_summer_fake_loss

    fake_winter = gen_winter(summer_image)
    disc_winter_real = disc_winter(winter_image)
    disc_winter_fake = disc_winter(fake_winter.detach())
    disc_winter_real_loss = mse(disc_winter_real, torch.ones_like(disc_winter_real))
    disc_winter_fake_loss = mse(disc_winter_fake, torch.zeros_like(disc_winter_fake))
    disc_winter_loss = disc_winter_real_loss + disc_winter_fake_loss

    disc_loss = (disc_summer_loss + disc_winter_loss) / 2

    optimizer_disc.zero_grad()
    disc_loss.backward()
    optimizer_disc.step()

    # ---- Generator step: fool the updated discriminators + cycle consistency ----
    disc_summer_fake = disc_summer(fake_summer)
    disc_winter_fake = disc_winter(fake_winter)
    loss_gen_summer = mse(disc_summer_fake, torch.ones_like(disc_summer_fake))
    loss_gen_winter = mse(disc_winter_fake, torch.ones_like(disc_winter_fake))

    cycle_summer = gen_summer(fake_winter)
    cycle_winter = gen_winter(fake_summer)
    cycle_summer_loss = l1_loss(summer_image, cycle_summer)
    cycle_winter_loss = l1_loss(winter_image, cycle_winter)

    gen_loss = (loss_gen_summer + loss_gen_winter
                + cycle_summer_loss * LAMBDA_CYCLE + cycle_winter_loss * LAMBDA_CYCLE)

    optimizer_gen.zero_grad()
    # This backward also leaves gradients on the discriminator parameters. They are
    # cleared by optimizer_disc.zero_grad() before the next discriminator step.
    gen_loss.backward()
    optimizer_gen.step()

    if idx % SAMPLE_EVERY == 0:
      # Generators end in tanh, so outputs lie in [-1, 1]. save_image clamps to [0, 1],
      # so rescale first to avoid losing the negative half of the range.
      save_image(fake_summer.detach() * 0.5 + 0.5, f"{SAMPLE_DIR}/summer_{count}_{idx}.jpeg")
      save_image(fake_winter.detach() * 0.5 + 0.5, f"{SAMPLE_DIR}/winter_{count}_{idx}.jpeg")
      count += 1

  # Checkpoint once per qualifying epoch, after its last batch, not after every batch of it.
  # File names keep the existing `epoch{epoch-1}` numbering.
  if epoch % CHECKPOINT_EVERY == 0:
    for ckpt_name, model in (("summer_gen", gen_summer), ("winter_gen", gen_winter),
                             ("summer_disc", disc_summer), ("winter_disc", disc_winter)):
      torch.save(model.state_dict(), f"{CHECKPOINT_DIR}/{ckpt_name}_epoch{epoch-1}.pth")