<a href="https://colab.research.google.com/github/puru9860/Cycle-GAN/blob/main/CycleGan_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install git+https://github.com/albumentations-team/albumentations.git
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from PIL import Image
import os
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
import random
import copy
import sys
from tqdm import tqdm
from torchvision.utils import save_image

Collecting git+https://github.com/albumentations-team/albumentations.git
  Cloning https://github.com/albumentations-team/albumentations.git to /tmp/pip-req-build-yy3q7cwi
  Running command git clone -q https://github.com/albumentations-team/albumentations.git /tmp/pip-req-build-yy3q7cwi
Building wheels for collected packages: albumentations
  Building wheel for albumentations (setup.py) ... [?25l[?25hdone
  Created wheel for albumentations: filename=albumentations-1.0.0-cp37-none-any.whl size=98151 sha256=cc89e907d24e59bb5d586bc205a8e7f4c424f71df7d9071e62819cd7883a96aa
  Stored in directory: /tmp/pip-ephem-wheel-cache-i2f6ma8l/wheels/e2/85/3e/2a40fac5cc1f43ced656603bb2fca1327b30ec7de1b1b66517
Successfully built albumentations
Installing collected packages: albumentations
  Found existing installation: albumentations 0.1.12
    Uninstalling albumentations-0.1.12:
      Successfully uninstalled albumentations-0.1.12
Successfully installed albumentations-1.0.0


In [None]:
# !pip install -U git+https://github.com/albu/albumentations > /dev/null 

In [None]:
# ---- Run configuration -------------------------------------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE  = "cuda" if torch.cuda.is_available() else "cpu"
# Dataset roots; the training script appends "masked" / "without" to each.
TRAIN_DIR= "/content/drive/MyDrive/unmask_2/train/"
VAL_DIR = "/content/drive/MyDrive/unmask_2/val/"
BATCH_SIZE = 1
# Learning rate for both Adam optimizers; also forced onto loaded checkpoints.
LEARNING_RATE = 1e-5
# Weight of the identity loss in the generator objective (0.0 disables its effect).
LAMBDA_IDENTITY = 0.0
# Weight of the cycle-consistency loss in the generator objective.
LAMBDA_CYCLE = 10
NUM_WORKERS = 2
NUM_EPOCHS = 1000
# When True the training script restores all four networks from the paths below.
LOAD_MODEL = True
# When True the training script writes checkpoints every 10th epoch.
SAVE_MODEL = False
# Checkpoint locations (M = "masked" domain, W = "without" domain).
CHECKPOINT_GEN_M = "/content/drive/MyDrive/unmask_2/ck/genM.pth.tar"
CHECKPOINT_GEN_W = "/content/drive/MyDrive/unmask_2/ck/genW.pth.tar"
CHECKPOINT_CRITIC_M = "/content/drive/MyDrive/unmask_2/ck/criticM.pth.tar"
CHECKPOINT_CRITIC_W = "/content/drive/MyDrive/unmask_2/ck/criticW.pth.tar"


In [None]:
# Discriminator Block

class Block(nn.Module):
  """Discriminator stage: 4x4 reflect-padded conv -> InstanceNorm -> LeakyReLU(0.2).

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by the convolution.
    stride: stride of the convolution.
  """

  def __init__(self,in_channels,out_channels,stride):
    super().__init__()
    stages = [
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        ),
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(0.2),
    ]
    # Attribute name and layer order are kept so checkpoint keys stay "conv.<i>.*".
    self.conv = nn.Sequential(*stages)

  def forward(self,x):
    """Apply the conv/norm/activation stack to ``x``."""
    out = self.conv(x)
    return out


class Discriminator(nn.Module):
  """Convolutional discriminator producing a spatial map of scores in (0, 1).

  The network is an initial strided conv (no normalisation), one ``Block`` per
  entry of ``features[1:]`` and a final conv that reduces to a single channel.
  The output is passed through a sigmoid.

  Args:
    in_channels: number of channels of the input image.
    features: channel widths of the successive stages. Any indexable
      sequence (list or tuple) works; the default is a tuple so that the
      default value cannot be mutated and shared between instances.
  """

  def __init__(self,in_channels=3,features=(64,128,256,512)):
    super().__init__()
    self.initial = nn.Sequential(
        nn.Conv2d(
            in_channels,
            features[0],
            kernel_size=4,
            stride=2,
            padding=1,
            padding_mode="reflect"
        ),
        nn.LeakyReLU(0.2),
    )

    layers = []
    in_channels = features[0]
    for feature in features[1:]:
      # Stages downsample with stride 2, except the stage whose width equals
      # the last entry of ``features``, which uses stride 1.
      layers.append(Block(in_channels,feature,stride=1 if feature == features[-1] else 2))
      in_channels = feature
    # Final projection to one score channel.
    layers.append(nn.Conv2d(in_channels,1,kernel_size=4,stride=1,padding=1,padding_mode="reflect"))
    self.model = nn.Sequential(*layers)


  def forward(self,x):
    """Return the sigmoid-squashed score map for the image batch ``x``."""
    x = self.initial(x)
    return torch.sigmoid(self.model(x))

In [None]:
# Generator model
class ConvBlock(nn.Module):
  """Generator building block: (transposed) conv -> InstanceNorm -> optional ReLU.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels produced by the convolution.
    down: when True use a reflect-padded ``nn.Conv2d``; otherwise use
      ``nn.ConvTranspose2d``.
    use_act: when True finish with an in-place ReLU, otherwise with identity.
    **kwargs: forwarded unchanged to the chosen convolution layer
      (kernel_size, stride, padding, ...).
  """

  def __init__(self,in_channels,out_channels, down=True,use_act=True, **kwargs):
    super().__init__()
    if down:
      conv_layer = nn.Conv2d(in_channels,out_channels,padding_mode="reflect",**kwargs)
    else:
      conv_layer = nn.ConvTranspose2d(in_channels,out_channels,**kwargs)

    if use_act:
      activation = nn.ReLU(inplace=True)
    else:
      activation = nn.Identity()

    # Keep the three-slot layout so state-dict keys remain "conv.0.*".
    self.conv = nn.Sequential(
        conv_layer,
        nn.InstanceNorm2d(out_channels),
        activation,
    )

  def forward(self,x):
    """Run ``x`` through conv, normalisation and activation."""
    out = self.conv(x)
    return out
  
class ResidualBlock(nn.Module):
  """Two 3x3 ``ConvBlock`` layers wrapped in a skip connection.

  The second ``ConvBlock`` has no activation; the block output is the input
  plus the result of the two convolutions, so the channel count is unchanged.

  Args:
    channels: number of input (and output) channels.
  """

  def __init__(self,channels):
    super().__init__()
    conv_args = dict(kernel_size=3,padding=1)
    self.block = nn.Sequential(
        ConvBlock(channels,channels,**conv_args),
        ConvBlock(channels,channels,use_act=False,**conv_args),
    )

  def forward(self,x):
    """Return ``x`` plus the output of the convolutional branch."""
    residual = self.block(x)
    return x + residual

class Generator(nn.Module):
  """Encoder / residual / decoder image-to-image generator with tanh output.

  Layout: 7x7 stem conv -> two stride-2 ``ConvBlock`` downsamplers ->
  ``num_residuals`` ``ResidualBlock`` layers -> two stride-2 transposed
  ``ConvBlock`` upsamplers -> 7x7 output conv followed by ``tanh``.

  Args:
    img_channels: channels of the input and output images.
    num_features: channel width after the stem convolution.
    num_residuals: number of residual blocks in the bottleneck.
  """

  def __init__(self,img_channels,num_features=64,num_residuals=9):
    super().__init__()
    width_1x = num_features
    width_2x = num_features*2
    width_4x = num_features*4

    # Submodules are registered in the same order as before, so parameter
    # ordering (and therefore saved optimizer state) is unaffected.
    self.initial = nn.Sequential(
      nn.Conv2d(img_channels,width_1x,kernel_size=7,stride=1,padding=3,padding_mode="reflect"),
      nn.ReLU(inplace=True),
    )

    downsample_args = dict(kernel_size=3,stride=2,padding=1)
    self.down_blocks = nn.ModuleList(
        [
          ConvBlock(width_1x,width_2x,**downsample_args),
          ConvBlock(width_2x,width_4x,**downsample_args),
        ]
    )

    bottleneck = [ResidualBlock(width_4x) for _ in range(num_residuals)]
    self.residual_blocks = nn.Sequential(*bottleneck)

    upsample_args = dict(down=False,kernel_size=3,stride=2,padding=1,output_padding=1)
    self.up_blocks = nn.ModuleList(
        [
          ConvBlock(width_4x,width_2x,**upsample_args),
          ConvBlock(width_2x,width_1x,**upsample_args),
        ]
    )

    self.last = nn.Conv2d(width_1x,img_channels,kernel_size=7,stride=1,padding=3,padding_mode="reflect")

  def forward(self,x):
    """Translate the image batch ``x``; the result lies in [-1, 1] (tanh)."""
    feats = self.initial(x)
    for down in self.down_blocks:
      feats = down(feats)

    feats = self.residual_blocks(feats)

    for up in self.up_blocks:
      feats = up(feats)

    return torch.tanh(self.last(feats))

In [None]:
# Preprocessing pipeline applied to each (masked, without) image pair:
# resize to 256x256, random horizontal flip, scale pixels to [-1, 1]
# ((x / 255 - 0.5) / 0.5) and convert to a channels-first tensor.
# "image0" is registered as an additional image target so the second image
# of the pair goes through the same pipeline call as "image".
transforms = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
            max_pixel_value=255,
        ),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)

In [None]:
class MaskedDataset(Dataset):
  """Unpaired dataset of "masked" and "without" images read from two folders.

  The dataset length is the size of the larger folder; the smaller folder is
  cycled through with a modulo so every index yields one image of each kind.

  Args:
    root_masked: directory containing the masked images.
    root_without: directory containing the images without a mask.
    transform: optional callable invoked as
      ``transform(image=<masked array>, image0=<without array>)`` and expected
      to return a mapping with the keys ``"image"`` and ``"image0"``.
  """

  def __init__(self,root_masked,root_without,transform=None):
    self.root_masked = root_masked
    self.root_without  = root_without
    self.transform = transform

    # os.listdir returns entries in arbitrary order; sort them so that a given
    # index always refers to the same file across runs and machines.
    self.masked_images = sorted(os.listdir(root_masked))
    self.without_images = sorted(os.listdir(root_without))
    self.masked_len = len(self.masked_images)
    self.without_len = len(self.without_images)
    self.length_dataset = max(self.masked_len,self.without_len)

  def __len__(self):
    """Return the number of entries of the larger of the two folders."""
    return self.length_dataset

  @staticmethod
  def _load_rgb(path):
    """Read the image at ``path`` as an RGB NumPy array and close the file."""
    # Image.open keeps the file handle open until the image is closed; the
    # context manager releases it as soon as the pixel data has been copied.
    with Image.open(path) as img:
      return np.array(img.convert("RGB"))

  def __getitem__(self,index):
    """Return the ``(masked, without)`` pair for ``index``.

    Without a transform the items are NumPy arrays; with a transform they are
    whatever the transform stores under ``"image"`` and ``"image0"``.
    """
    # Wrap around independently so the shorter folder is reused.
    masked_name = self.masked_images[index % self.masked_len]
    without_name = self.without_images[index % self.without_len]

    masked_img = self._load_rgb(os.path.join(self.root_masked,masked_name))
    without_img = self._load_rgb(os.path.join(self.root_without,without_name))

    if self.transform:
      augmentations = self.transform(image=masked_img,image0=without_img)
      masked_img = augmentations["image"]
      without_img = augmentations["image0"]

    return masked_img,without_img

In [None]:
# save and load checkpoints

def save_checkpoint(model,optimizer,filename="my_checkpoint.pth.tar"):
  """Serialise the model and optimizer state dicts to ``filename``.

  The file holds a dict with the keys ``"state_dict"`` (model) and
  ``"optimizer"`` (optimizer), written with ``torch.save``.
  """
  print("=>Saving checkpoint")
  model_state = model.state_dict()
  optimizer_state = optimizer.state_dict()
  torch.save({"state_dict": model_state, "optimizer": optimizer_state},filename)


def load_checkpoint(checkpoint_file,model,optimizer,lr):
  """Restore ``model`` and ``optimizer`` from a file written by ``save_checkpoint``.

  Args:
    checkpoint_file: path of the checkpoint to read.
    model: module whose weights are replaced by the saved ``"state_dict"``.
    optimizer: optimizer whose state is replaced by the saved ``"optimizer"``.
    lr: learning rate to set on every parameter group after loading.
  """
  print("=>Loading Checkpoint")
  saved = torch.load(checkpoint_file,map_location=DEVICE)
  model.load_state_dict(saved["state_dict"])
  optimizer.load_state_dict(saved["optimizer"])

  # Loading the optimizer state also restores the learning rate stored in the
  # checkpoint, so the requested value is applied afterwards.
  for group in optimizer.param_groups:
    group["lr"] = lr

In [None]:
def train_fn(disc_M,disc_W,gen_M,gen_W,loader,opt_disc,opt_gen,L1,mse,d_scaler,g_scaler):
  """Train both discriminators and both generators for one pass over ``loader``.

  For every batch the discriminators are updated first (real images are pushed
  towards 1, generated images towards 0); the generators are then updated with
  adversarial and cycle-consistency losses, plus identity losses when
  ``LAMBDA_IDENTITY`` is non-zero.

  Args:
    disc_M: discriminator scoring "masked" images.
    disc_W: discriminator scoring "without" images.
    gen_M: generator mapping "without" images to "masked" images.
    gen_W: generator mapping "masked" images to "without" images.
    loader: iterable yielding ``(masked, without)`` tensor batches.
    opt_disc: optimizer stepping the discriminator parameters.
    opt_gen: optimizer stepping the generator parameters.
    L1: callable ``L1(a, b)`` used for cycle and identity losses.
    mse: callable ``mse(pred, target)`` used for the adversarial losses.
    d_scaler: gradient scaler (``scale`` / ``step`` / ``update``) for ``opt_disc``.
    g_scaler: gradient scaler (``scale`` / ``step`` / ``update``) for ``opt_gen``.

  Side effects:
    Reads the module-level ``DEVICE``, ``LAMBDA_CYCLE`` and ``LAMBDA_IDENTITY``.
    Every 200 batches writes the current fake images to the checkpoint folder
    on Drive, and keeps a tqdm progress bar updated with the running mean
    discriminator scores for real / fake "without" images.
  """
  W_reals = 0
  W_fakes = 0
  loop = tqdm(loader,leave=True)

  for idx , (masked,without) in enumerate(loop):
    masked = masked.to(DEVICE)
    without = without.to(DEVICE)

    # ---- Train discriminators disc_W and disc_M ----
    with torch.cuda.amp.autocast():
      fake_without = gen_W(masked)
      D_W_real = disc_W(without)
      # detach(): the discriminator update must not back-propagate into gen_W.
      D_W_fake = disc_W(fake_without.detach())
      W_reals += D_W_real.mean().item()
      W_fakes += D_W_fake.mean().item()
      D_W_real_loss = mse(D_W_real,torch.ones_like(D_W_real))
      # Fakes are labelled 0. The previous code used ones_like here, which
      # trained disc_W to score generated images as real.
      D_W_fake_loss = mse(D_W_fake,torch.zeros_like(D_W_fake))
      D_W_loss = D_W_real_loss + D_W_fake_loss

      fake_masked = gen_M(without)
      D_M_real = disc_M(masked)
      D_M_fake = disc_M(fake_masked.detach())
      D_M_real_loss = mse(D_M_real, torch.ones_like(D_M_real))
      D_M_fake_loss = mse(D_M_fake, torch.zeros_like(D_M_fake))
      D_M_loss = D_M_real_loss + D_M_fake_loss

      D_loss = (D_W_loss + D_M_loss)/2

    opt_disc.zero_grad()
    d_scaler.scale(D_loss).backward()
    d_scaler.step(opt_disc)
    d_scaler.update()

    # ---- Train generators gen_W and gen_M ----
    with torch.cuda.amp.autocast():
      # adversarial loss for both generators: they want fakes scored as 1
      D_W_fake = disc_W(fake_without)
      D_M_fake = disc_M(fake_masked)
      loss_G_W = mse(D_W_fake, torch.ones_like(D_W_fake))
      loss_G_M = mse(D_M_fake, torch.ones_like(D_M_fake))

      # cycle loss: translating there and back should reproduce the input
      cycle_masked = gen_M(fake_without)
      cycle_without = gen_W(fake_masked)
      cycle_masked_loss = L1(masked, cycle_masked)
      cycle_without_loss = L1(without, cycle_without)

      # add all together
      G_loss = (
          loss_G_M
          + loss_G_W
          + cycle_masked_loss * LAMBDA_CYCLE
          + cycle_without_loss * LAMBDA_CYCLE
      )

      # identity loss: only computed when it is weighted, which saves two
      # generator forward passes per batch when LAMBDA_IDENTITY is 0
      if LAMBDA_IDENTITY:
        identity_masked = gen_M(masked)
        identity_without = gen_W(without)
        identity_masked_loss = L1(masked, identity_masked)
        identity_without_loss = L1(without, identity_without)
        G_loss = (
            G_loss
            + identity_without_loss * LAMBDA_IDENTITY
            + identity_masked_loss * LAMBDA_IDENTITY
        )

    opt_gen.zero_grad()
    g_scaler.scale(G_loss).backward()
    g_scaler.step(opt_gen)
    g_scaler.update()

    if idx % 200 == 0:
      # Generator output is tanh-bounded; *0.5+0.5 maps [-1, 1] back to [0, 1].
      save_image(fake_without*0.5+0.5, f"/content/drive/MyDrive/unmask_2/ck/without_{idx}.png")
      save_image(fake_masked*0.5+0.5, f"/content/drive/MyDrive/unmask_2/ck/masked_{idx}.png")

    loop.set_postfix(W_real=W_reals/(idx+1), W_fake=W_fakes/(idx+1))



In [None]:
# ---- Build the four networks (M = "masked" domain, W = "without" domain) ----
disc_M = Discriminator(in_channels=3).to(DEVICE)
disc_W = Discriminator(in_channels=3).to(DEVICE)
gen_M = Generator(img_channels=3,num_residuals=9).to(DEVICE)
gen_W = Generator(img_channels=3,num_residuals=9).to(DEVICE)

# One optimizer drives both discriminators, another drives both generators.
opt_disc = optim.Adam(
    list(disc_M.parameters()) + list(disc_W.parameters()),
    lr = LEARNING_RATE,
    betas = (0.5,0.999)
)
opt_gen = optim.Adam(
    list(gen_M.parameters()) + list(gen_W.parameters()),
    lr = LEARNING_RATE,
    betas = (0.5,0.999)
)

# L1 is used for the cycle / identity terms, mse for the adversarial terms.
L1 = nn.L1Loss()
mse = nn.MSELoss()

# Optionally resume from saved checkpoints. Each call also loads the optimizer
# state stored in that file, so for the shared optimizers the state from the
# second load (gen_W / disc_W) replaces the one from the first.
if LOAD_MODEL:
  load_checkpoint(
      CHECKPOINT_GEN_M,gen_M,opt_gen,LEARNING_RATE,
  )
  load_checkpoint(
      CHECKPOINT_GEN_W,gen_W,opt_gen,LEARNING_RATE,
  )
  load_checkpoint(
      CHECKPOINT_CRITIC_M,disc_M,opt_disc,LEARNING_RATE,
  )
  load_checkpoint(
      CHECKPOINT_CRITIC_W,disc_W,opt_disc,LEARNING_RATE,
  )

# Training data: <TRAIN_DIR>/masked and <TRAIN_DIR>/without.
dataset = MaskedDataset(root_masked=TRAIN_DIR+"masked",root_without=TRAIN_DIR+"without",transform=transforms)

# Validation data uses the same transform pipeline (including the random flip).
val_dataset =  MaskedDataset(root_masked=VAL_DIR+"masked",root_without=VAL_DIR+"without",transform=transforms)

# NOTE(review): val_loader is created but not referenced in the visible code.
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle = True,
    num_workers = NUM_WORKERS,
    pin_memory = True
)
# Separate gradient scalers for the generator and discriminator updates
# (train_fn runs its forward passes under torch.cuda.amp.autocast).
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

# ---- Main training loop ----
for epoch in range(NUM_EPOCHS):
  print(f"\nepoch: {epoch}")
  train_fn(disc_M,disc_W,gen_M,gen_W,loader,opt_disc,opt_gen,L1,mse,d_scaler,g_scaler)

  # Checkpoint every 10th epoch (including epoch 0) when saving is enabled.
  if SAVE_MODEL and epoch % 10 ==0  :
    save_checkpoint(gen_M,opt_gen,filename=CHECKPOINT_GEN_M)
    save_checkpoint(gen_W,opt_gen,filename=CHECKPOINT_GEN_W)
    save_checkpoint(disc_M,opt_disc,filename=CHECKPOINT_CRITIC_M)
    save_checkpoint(disc_W,opt_disc,filename=CHECKPOINT_CRITIC_W)


In [None]:
# Load a single image from disk as an RGB NumPy array.
# NOTE(review): presumably the start of a manual inference test; the rest of
# this cell is not visible here.
masked_img = np.array(Image.open('/content/h1.jpg').convert("RGB"))
