In [22]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import numpy as np
import os

class Block(nn.Module):
  """Discriminator building block: 4x4 reflect-padded conv -> InstanceNorm2d -> LeakyReLU(0.2).

  Args:
    in_chan: number of input channels.
    out_chan: number of output channels.
    stride: convolution stride.
  """

  def __init__(self, in_chan, out_chan, stride):
    super().__init__()
    conv = nn.Conv2d(
        in_chan,
        out_chan,
        kernel_size=4,
        stride=stride,
        padding=1,
        bias=True,
        padding_mode="reflect",
    )
    norm = nn.InstanceNorm2d(out_chan)
    act = nn.LeakyReLU(0.2)
    # Kept as a single Sequential named `conv` so state_dict keys are conv.0.*.
    self.conv = nn.Sequential(conv, norm, act)

  def forward(self, x):
    out = self.conv(x)
    return out

In [23]:
class Discr(nn.Module):
  """PatchGAN discriminator: maps an image to a grid of per-patch scores in [0, 1].

  Args:
    in_chan: number of input image channels.
    features: channel widths. ``features[0]`` is used by the un-normalised
      stem conv (stride 2); every further width adds a ``Block`` with stride 2,
      except the last one, which uses stride 1. Any indexable sequence works;
      the default is a tuple to avoid a shared mutable default argument.
  """

  def __init__(self, in_chan = 3, features = (64, 128, 256, 512)):
    super().__init__()
    self.initial = nn.Sequential(
        nn.Conv2d(in_chan, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
        nn.LeakyReLU(0.2)
    )
    layers = []
    in_chan = features[0]
    last_index = len(features) - 1
    for i, feature in enumerate(features[1:], start=1):
      # Choose the stride by position, not by value, so repeated widths such
      # as (64, 256, 256) still downsample at every layer but the last.
      layers.append(Block(in_chan, feature, stride=1 if i == last_index else 2))
      in_chan = feature

    # 1-channel output head: one raw score per receptive-field patch.
    layers.append(nn.Conv2d(in_chan, 1, 4, stride=1, padding=1, padding_mode="reflect"))
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    # Sigmoid squashes patch scores into [0, 1].
    return torch.sigmoid(self.model(self.initial(x)))
  
def test_discriminator():
  """Smoke test: return the output shape of ``Discr(3)`` on a random 5x3x256x256 batch.

  With the default features this is torch.Size([5, 1, 30, 30]), i.e. one
  score per patch.
  """
  x = torch.randn((5, 3, 256, 256))
  model = Discr(3)
  with torch.no_grad():  # shape check only; no autograd graph needed
    return model(x).shape


# Backward-compatible alias. A later cell also defines a smoke test named
# ``test``, which rebinds this name; ``test_discriminator`` stays reachable.
test = test_discriminator

In [24]:
class Convblock(nn.Module):
  """Conv (or transposed conv) -> InstanceNorm2d -> optional ReLU.

  Args:
    in_chan: number of input channels.
    out_chan: number of output channels.
    down: True for a reflect-padded Conv2d, False for a ConvTranspose2d.
    use_act: True to finish with ReLU(inplace=True), False for nn.Identity().
    **kwargs: forwarded to the conv layer (kernel_size, stride, padding, ...).
  """

  def __init__(self, in_chan, out_chan, down = True, use_act = True, **kwargs):
    super().__init__()
    if down:
      layer = nn.Conv2d(in_chan, out_chan, padding_mode="reflect", **kwargs)
    else:
      layer = nn.ConvTranspose2d(in_chan, out_chan, **kwargs)
    norm = nn.InstanceNorm2d(out_chan)
    activation = nn.ReLU(inplace=True) if use_act else nn.Identity()
    self.conv = nn.Sequential(layer, norm, activation)

  def forward(self, x):
    out = self.conv(x)
    return out

In [25]:
class Resblock(nn.Module):
  """Residual unit: ``x + F(x)``, where F is two 3x3 Convblocks and only the first has ReLU.

  Args:
    channels: channel count, unchanged by the block so the skip sum is valid.
  """

  def __init__(self, channels):
    super().__init__()
    conv_kwargs = dict(kernel_size=3, padding=1)
    first = Convblock(channels, channels, **conv_kwargs)
    second = Convblock(channels, channels, use_act=False, **conv_kwargs)
    self.block = nn.Sequential(first, second)

  def forward(self, x):
    residual = self.block(x)
    return x + residual

In [26]:
class Generator(nn.Module):
  """ResNet-style CycleGAN generator (image -> image of the same spatial size).

  Layout: 7x7 stem conv + ReLU -> two stride-2 Convblocks (downsampling) ->
  ``num_res`` Resblocks at ``4 * num_features`` channels -> two stride-2
  transposed Convblocks (upsampling) -> 7x7 conv back to ``img_chan``
  channels, squashed by tanh into [-1, 1].

  Args:
    img_chan: number of image channels (input and output).
    num_features: channel width after the stem; doubled by each downsampling stage.
    num_res: number of residual blocks in the bottleneck.
  """

  def __init__(self, img_chan, num_features = 64, num_res = 9):
    super().__init__()
    self.initial = nn.Sequential(
        nn.Conv2d(img_chan, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
        nn.ReLU(inplace=True),
    )

    # Channel width at full, 1/2 and 1/4 resolution.
    widths = [num_features, num_features * 2, num_features * 4]
    down_pairs = list(zip(widths[:-1], widths[1:]))
    up_pairs = [(c_out, c_in) for c_in, c_out in reversed(down_pairs)]

    self.down_blocks = nn.ModuleList([
        Convblock(c_in, c_out, kernel_size=3, stride=2, padding=1)
        for c_in, c_out in down_pairs
    ])

    self.residual_block = nn.Sequential(*[Resblock(widths[-1]) for _ in range(num_res)])

    self.up_blocks = nn.ModuleList([
        Convblock(c_in, c_out, down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
        for c_in, c_out in up_pairs
    ])

    self.last = nn.Conv2d(num_features, img_chan, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

  def forward(self, x):
    stages = [self.initial, *self.down_blocks, self.residual_block, *self.up_blocks]
    for stage in stages:
      x = stage(x)
    return torch.tanh(self.last(x))

def test_generator():
  """Smoke test: print and return the Generator output shape on a random 2x3x256x256 batch."""
  x = torch.randn((2, 3, 256, 256))
  # num_res is passed by keyword: positionally, 9 would bind to num_features
  # (the channel width), not to the number of residual blocks.
  gen = Generator(3, num_res=9)
  with torch.no_grad():  # shape check only; avoids storing activations for backward
    shape = gen(x).shape
  print(shape)
  return shape


# Backward-compatible alias for the original ``test`` name.
test = test_generator

In [27]:
!pip install --upgrade albumentations

Requirement already up-to-date: albumentations in /usr/local/lib/python3.7/dist-packages (0.5.2)


In [28]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

# ---- Runtime -----------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 4

# ---- Dataset locations ---------------------------------------------------------
TRAIN_DIR = "data/train"
VAL_DIR = "data/val"

# ---- Optimisation ----------------------------------------------------------------
BATCH_SIZE = 1
LEARNING_RATE = 1e-5
NUM_EPOCHS = 10
LAMBDA_CYCLE = 10        # weight on the two cycle-consistency L1 terms
LAMBDA_IDENTITY = 0.0    # weight on the two identity L1 terms

# ---- Checkpointing ---------------------------------------------------------------
LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_GEN_H = "genh.pth.tar"
CHECKPOINT_GEN_Z = "genz.pth.tar"
CHECKPOINT_CRITIC_H = "critich.pth.tar"
CHECKPOINT_CRITIC_Z = "criticz.pth.tar"

# ---- Augmentation ------------------------------------------------------------------
# NOTE: this rebinds the name ``transforms``, which the first cell imported from
# torchvision, to an albumentations pipeline.
# ``additional_targets`` routes the second image ("image0") through the same
# randomly-sampled ops as "image", so both images get the same flip.
transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        # (x - 0.5 * 255) / (0.5 * 255): maps pixels in [0, 255] to [-1, 1].
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets=dict(image0="image"),
)

### Dataset

In [29]:
from PIL import Image

class HorseZebraDataset(Dataset):
  """Unpaired horse/zebra image dataset for CycleGAN training.

  Each item pairs one zebra image with one horse image. The folders may hold
  different numbers of files: the dataset length is the larger count, and the
  shorter list is cycled with a modulo index.

  Args:
    dir_zebra: directory containing zebra images.
    dir_horse: directory containing horse images.
    transform: optional callable invoked as ``transform(image=zebra, image0=horse)``.
      It must return a mapping with ``'image'`` and ``'image0'`` keys.

  Raises:
    ValueError: if either directory contains no usable files.
  """

  def __init__(self, dir_zebra, dir_horse, transform = None):
    self.root_zebra = dir_zebra
    self.dir_horse = dir_horse
    self.transform = transform

    self.horse_image = self._list_images(dir_horse)
    self.zebra_image = self._list_images(dir_zebra)

    self.zebra_len = len(self.zebra_image)
    self.horse_len = len(self.horse_image)
    # Fail early: otherwise __getitem__ would hit `ind % 0` (ZeroDivisionError).
    if self.zebra_len == 0:
      raise ValueError(f"no image files found in zebra directory {dir_zebra!r}")
    if self.horse_len == 0:
      raise ValueError(f"no image files found in horse directory {dir_horse!r}")
    self.len_ds = max(self.zebra_len, self.horse_len)

  @staticmethod
  def _list_images(directory):
    """Return the sorted names of visible regular files in ``directory``.

    Sorting makes the order platform-independent. Sub-directories and hidden
    entries (e.g. Colab's ``.ipynb_checkpoints``, ``.DS_Store``) are skipped
    so that PIL is never asked to open them.
    """
    return sorted(
        name for name in os.listdir(directory)
        if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
    )

  @staticmethod
  def _load_rgb(path):
    """Read the image at ``path`` as an H x W x 3 RGB array and close the file afterwards."""
    with Image.open(path) as img:
      return np.array(img.convert("RGB"))

  def __len__(self):
    return self.len_ds

  def __getitem__(self, ind):
    zebra_path = os.path.join(self.root_zebra, self.zebra_image[ind % self.zebra_len])
    horse_path = os.path.join(self.dir_horse, self.horse_image[ind % self.horse_len])
    zebra_img = self._load_rgb(zebra_path)
    horse_img = self._load_rgb(horse_path)

    if self.transform:
      aug = self.transform(image=zebra_img, image0=horse_img)
      zebra_img = aug['image']
      horse_img = aug['image0']

    return zebra_img, horse_img

### Train

In [39]:
from tqdm import tqdm
from tqdm import tqdm_notebook
from torchvision.utils import save_image
import torch.optim as optim

In [44]:
# Follow the config cell's DEVICE (which falls back to "cpu" when CUDA is
# unavailable) instead of hard-coding "cuda".
device = torch.device(DEVICE)

def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen , l1, mse, d_scaler, g_scaler):
  """Run one CycleGAN training epoch over ``loader``.

  Each batch goes through two steps:
    1. Update both discriminators with MSE targets: 1 for real images, 0 for
       generated ones.
    2. Update both generators. The loss combines adversarial MSE (target 1),
       LAMBDA_CYCLE-weighted cycle-consistency L1 and LAMBDA_IDENTITY-weighted
       identity L1.

  Both steps run under ``torch.cuda.amp.autocast`` with separate gradient scalers.

  Args:
    disc_H: discriminator for horse images.
    disc_Z: discriminator for zebra images.
    gen_Z: generator, horse -> zebra.
    gen_H: generator, zebra -> horse.
    loader: iterable of ``(zebra, horse)`` batches.
    opt_disc: optimizer over both discriminators' parameters.
    opt_gen: optimizer over both generators' parameters.
    l1: loss module for the cycle and identity terms.
    mse: loss module for the adversarial terms.
    d_scaler: GradScaler for the discriminator step.
    g_scaler: GradScaler for the generator step.
  """
  # tqdm_notebook is deprecated (it prints a warning); tqdm.auto selects the
  # notebook widget when one is available.
  from tqdm.auto import tqdm as auto_tqdm

  # save_image does not create missing parent directories.
  os.makedirs("/content/output", exist_ok=True)

  loop = auto_tqdm(loader, leave=True)
  # Running sums across the epoch. The progress bar divides them by the batch
  # count, showing the mean discriminator score on real / fake horses so far.
  H_reals = 0.0
  H_fakes = 0.0

  for idx, (zebra, horse) in enumerate(loop):
    zebra = zebra.to(device)
    horse = horse.to(device)

    # ---- Discriminator step. Fakes are detached so no gradient reaches the generators.
    with torch.cuda.amp.autocast():
      fake_horse = gen_H(zebra)
      D_H_real = disc_H(horse)
      D_H_fake = disc_H(fake_horse.detach())
      H_reals += D_H_real.mean().item()
      H_fakes += D_H_fake.mean().item()

      D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
      D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
      D_H_loss = D_H_fake_loss + D_H_real_loss

      fake_zebra = gen_Z(horse)
      D_Z_real = disc_Z(zebra)
      D_Z_fake = disc_Z(fake_zebra.detach())
      D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
      D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
      D_Z_loss = D_Z_fake_loss + D_Z_real_loss

      D_loss = (D_H_loss + D_Z_loss) / 2

    opt_disc.zero_grad()
    d_scaler.scale(D_loss).backward()
    d_scaler.step(opt_disc)
    d_scaler.update()

    # ---- Generator step. The discriminators see the non-detached fakes. Any
    # gradient this leaves on discriminator parameters is cleared by
    # opt_disc.zero_grad() before the next discriminator backward pass.
    with torch.cuda.amp.autocast():
      D_H_fake = disc_H(fake_horse)
      D_Z_fake = disc_Z(fake_zebra)
      loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
      loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

      cycle_zebra = gen_Z(fake_horse)
      cycle_horse = gen_H(fake_zebra)
      cycle_zebra_loss = l1(zebra, cycle_zebra)
      cycle_horse_loss = l1(horse, cycle_horse)

      G_loss = (loss_G_H + loss_G_Z
                + cycle_zebra_loss * LAMBDA_CYCLE
                + cycle_horse_loss * LAMBDA_CYCLE)

      # The identity terms cost two extra generator passes. When their weight
      # is zero they add nothing to G_loss, so they are skipped.
      if LAMBDA_IDENTITY != 0:
        id_zebra_loss = l1(zebra, gen_Z(zebra))
        id_horse_loss = l1(horse, gen_H(horse))
        G_loss = G_loss + id_horse_loss * LAMBDA_IDENTITY + id_zebra_loss * LAMBDA_IDENTITY

    opt_gen.zero_grad()
    g_scaler.scale(G_loss).backward()
    g_scaler.step(opt_gen)
    g_scaler.update()

    if idx % 100 == 0:
      # Map generator outputs from [-1, 1] (tanh) back to [0, 1] for saving.
      save_image(fake_horse * 0.5 + 0.5, f"/content/output/horse_{idx}.png")
      save_image(fake_zebra * 0.5 + 0.5, f"/content/output/zebra_{idx}.png")
    loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
      
def _save_checkpoint(model, optimizer, path):
  """Save ``model`` and ``optimizer`` state dicts to ``path`` with torch.save."""
  torch.save({"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}, path)


def _load_checkpoint(path, model, optimizer):
  """Restore state written by ``_save_checkpoint``; do nothing if ``path`` does not exist."""
  if not os.path.exists(path):
    return
  # torch.load unpickles its input; only point it at checkpoints you produced yourself.
  checkpoint = torch.load(path, map_location=device)
  model.load_state_dict(checkpoint["state_dict"])
  optimizer.load_state_dict(checkpoint["optimizer"])


def main():
  """Build the horse<->zebra CycleGAN and train it for NUM_EPOCHS epochs.

  Models and optimizers are restored from the CHECKPOINT_* files when
  LOAD_MODEL is set and the files exist. They are written back after every
  epoch when SAVE_MODEL is set, so an interrupted run keeps completed epochs.
  """
  disc_H = Discr(3).to(device)  # scores real vs. generated horse images
  disc_Z = Discr(3).to(device)  # scores real vs. generated zebra images
  gen_Z = Generator(img_chan=3).to(device)  # horse -> zebra
  gen_H = Generator(img_chan=3).to(device)  # zebra -> horse
  opt_disc = optim.Adam(list(disc_H.parameters()) + list(disc_Z.parameters()),
                        lr=LEARNING_RATE, betas=(0.5, 0.999))
  opt_gen = optim.Adam(list(gen_H.parameters()) + list(gen_Z.parameters()),
                       lr=LEARNING_RATE, betas=(0.5, 0.999))

  # Each pair shares one optimizer, so its state is stored alongside both models.
  checkpoints = [
      (gen_H, opt_gen, CHECKPOINT_GEN_H),
      (gen_Z, opt_gen, CHECKPOINT_GEN_Z),
      (disc_H, opt_disc, CHECKPOINT_CRITIC_H),
      (disc_Z, opt_disc, CHECKPOINT_CRITIC_Z),
  ]
  if LOAD_MODEL:
    for model, optimizer, path in checkpoints:
      _load_checkpoint(path, model, optimizer)

  l1 = nn.L1Loss()
  mse = nn.MSELoss()

  dataset = HorseZebraDataset("/content/zebra", "/content/horse", transform=transforms)
  loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                      num_workers=NUM_WORKERS, pin_memory=True)

  g_scaler = torch.cuda.amp.GradScaler()
  d_scaler = torch.cuda.amp.GradScaler()

  for epoch in range(NUM_EPOCHS):
    train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler)
    if SAVE_MODEL:
      for model, optimizer, path in checkpoints:
        _save_checkpoint(model, optimizer, path)

In [45]:
# Build models, optimizers and the data loader, then run NUM_EPOCHS epochs of train_fn.
main()

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


HBox(children=(FloatProgress(value=0.0, max=1334.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1334.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1334.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1334.0), HTML(value='')))

KeyboardInterrupt: ignored