# Imports and Utilities

In [15]:
import os
import random
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2

from torch.optim import Adam


SEED = 42

# Seed every RNG the notebook relies on: `random` drives the file-list shuffle,
# while torch drives weight initialisation and DataLoader(shuffle=True).
# torch.manual_seed seeds the generators of all devices, CUDA included.
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
data_root = "data"

train_folder = os.path.join(data_root, "train")
test_folder = os.path.join(data_root, "test")

# os.listdir returns entries in an arbitrary, filesystem-dependent order. Sort
# first so the seeded shuffle gives the same permutation on every machine, and
# skip hidden entries (e.g. .DS_Store, .ipynb_checkpoints) that are not images.
train_file_names = sorted(
    name for name in os.listdir(os.path.join(train_folder, "image"))
    if not name.startswith(".")
)
random.shuffle(train_file_names)

# Dataset

In [3]:
# Network input resolution, as (height, width).
IMG_SIZE = (384, 512)

# Deterministic preprocessing, applied separately to each image of a sample.
# Steps: resize to IMG_SIZE; normalise each RGB channel with the standard
# ImageNet mean/std; convert the numpy image into a torch tensor.
train_transforms = A.Compose(
    [
        A.Resize(width=IMG_SIZE[1], height=IMG_SIZE[0]),
        A.Normalize(std=(0.229, 0.224, 0.225), mean=(0.485, 0.456, 0.406)),
        ToTensorV2(),
    ]
)

In [4]:
BATCH_SIZE = 2

class TrainDataset(Dataset):
  """Paired try-on samples: (agnostic image, cloth image, target image).

  Each sample is looked up by the same file name in three sub-folders of
  ``image_folder``: ``agnostic/``, ``cloth/`` and ``image/``.

  Args:
    image_list: file names to serve, one per sample.
    image_folder: root folder containing the three sub-folders.
    transforms: callable invoked as ``transforms(image=array)["image"]``.

  NOTE(review): ``transforms`` is applied independently to each of the three
  images. That is fine for deterministic steps (Resize/Normalize), but random
  augmentations would de-synchronise the inputs and the target.
  """
  def __init__(self, image_list, image_folder, transforms):
    self.image_list = image_list
    self.image_folder = image_folder
    self.transforms = transforms

  def __len__(self): return len(self.image_list)

  def _load_rgb(self, sub_folder, image_name):
    """Read ``image_folder/sub_folder/image_name`` and convert it from BGR to RGB."""
    path = os.path.join(self.image_folder, sub_folder, image_name)
    image = cv2.imread(path)
    if image is None:
      # cv2.imread returns None (rather than raising) for missing or unreadable
      # files; fail here with the offending path instead of an opaque
      # cvtColor assertion.
      raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

  def __getitem__(self, i):
    """Return the transformed (agnostic, cloth, target) triple for sample ``i``.

    Raises:
      FileNotFoundError: if any of the three image files cannot be read.
    """
    image_name = self.image_list[i]

    agnostic_image = self.transforms(image=self._load_rgb("agnostic", image_name))["image"]
    cloth_image = self.transforms(image=self._load_rgb("cloth", image_name))["image"]
    output_img_image = self.transforms(image=self._load_rgb("image", image_name))["image"]

    return agnostic_image, cloth_image, output_img_image

# Wrap the shuffled file list in the dataset; the loader reshuffles every epoch.
train_dataset = TrainDataset(
    image_list=train_file_names,
    image_folder=train_folder,
    transforms=train_transforms,
)
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE)

# Model

In [5]:
class DoubleConv(nn.Module):
  """Two stacked ``Conv2d -> ReLU`` stages, the basic U-Net building unit.

  The first convolution maps ``in_channels`` to ``out_channels`` and the second
  keeps ``out_channels``. With the defaults ``kernel_size=3, padding=1`` the
  spatial size is preserved.
  """
  def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
    super().__init__()

    # Hyper-parameters are kept as attributes for introspection.
    self.in_channels, self.out_channels = in_channels, out_channels
    self.kernel_size, self.padding = kernel_size, padding

    conv_kwargs = dict(kernel_size=kernel_size, padding=padding)
    self.double_conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, **conv_kwargs),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, **conv_kwargs),
        nn.ReLU(inplace=True),
    )

  def forward(self, X):
    return self.double_conv(X)

In [6]:
class DownSample(nn.Module):
  """U-Net encoder stage: a ``DoubleConv`` followed by 2x2 max-pooling.

  ``forward`` returns a pair: the pre-pool features (kept for the skip
  connection) and the pooled features (input to the next, coarser stage).
  """
  def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
    super().__init__()

    self.in_channels, self.out_channels = in_channels, out_channels
    self.kernel_size, self.padding = kernel_size, padding

    self.double_conv = DoubleConv(
        in_channels, out_channels, kernel_size=kernel_size, padding=padding
    )
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

  def forward(self, X):
    features = self.double_conv(X)
    return features, self.pool(features)

In [7]:
class UpSample(nn.Module):
  """U-Net decoder stage: 2x transposed-conv upsampling, skip concat, DoubleConv.

  ``self.up`` maps ``in_channels`` to ``out_channels``. The result is
  concatenated with the skip tensor along dim 1 and fed to
  ``DoubleConv(in_channels, out_channels)``. The concatenation therefore only
  has the expected channel count when the skip carries
  ``in_channels - out_channels`` channels, i.e. ``out_channels`` in the
  standard ``in_channels == 2 * out_channels`` layout.
  """
  def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
    super(UpSample, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.padding = padding

    self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    self.double_conv = DoubleConv(in_channels, out_channels, kernel_size=kernel_size, padding=padding)

  def forward(self, X, X_skip):
    """Upsample ``X``, align it spatially with ``X_skip``, concatenate, convolve.

    Args:
      X: coarser feature map; presumably ``(N, in_channels, H, W)`` (TODO confirm).
      X_skip: encoder features from the matching level.

    Returns:
      The ``DoubleConv`` output, with the spatial size of ``X_skip``.
    """
    X = self.up(X)

    # MaxPool2d floors odd sizes, so after 2x upsampling X can be one pixel
    # smaller than the skip tensor. Pad X (split as evenly as possible) so
    # torch.cat also works for inputs whose size is not divisible by 2**depth.
    diff_h = X_skip.shape[-2] - X.shape[-2]
    diff_w = X_skip.shape[-1] - X.shape[-1]
    if diff_h or diff_w:
      X = nn.functional.pad(
          X, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]
      )

    X_cat = torch.cat((X, X_skip), dim=1)
    X_cat = self.double_conv(X_cat)

    return X_cat

In [8]:
class UNet(nn.Module):
  """U-Net that fuses an agnostic image with a cloth image into a prediction.

  The forward pass runs in five steps:

  1. Concatenate the two inputs along dim 1; together they must carry
     ``in_channels`` channels.
  2. Encode through one ``DownSample`` per entry of ``filter_sizes``.
  3. Pass through a ``DoubleConv`` bottleneck that doubles the last width.
  4. Decode through one ``UpSample`` per level, deepest first, each consuming
     the matching skip connection.
  5. Project to ``out_channels`` with a 1x1 convolution.

  Args:
    filter_sizes: encoder channel widths, shallow to deep. The decoder's
      channel arithmetic requires each entry to be double the previous one
      (e.g. ``[16, 32, 64, 128]``).
    in_channels: channels of ``cat(X_agnostic, X_cloth)``; defaults to 6.
    out_channels: channels of the prediction; defaults to 3.

  Raises:
    ValueError: if ``filter_sizes`` is empty.
  """
  def __init__(self, filter_sizes, in_channels=6, out_channels=3):
    super(UNet, self).__init__()

    if len(filter_sizes) == 0:
      raise ValueError("filter_sizes must contain at least one channel width")

    self.filter_sizes = filter_sizes

    # nn.ModuleList (not a plain Python list) registers each stage as a
    # submodule. Its parameters then appear in model.parameters()/state_dict()
    # and follow model.to(device). With a plain list the optimizer never saw
    # them, so the encoder/decoder stayed at their random initialisation.
    self.down_sample_blocks = nn.ModuleList()
    current_in_channels = in_channels
    for filter_size in filter_sizes:
      self.down_sample_blocks.append(DownSample(current_in_channels, filter_size, kernel_size=3, padding=1))
      current_in_channels = filter_size

    self.bottleneck = DoubleConv(filter_sizes[-1], filter_sizes[-1]*2, kernel_size=3, padding=1)

    self.up_sample_blocks = nn.ModuleList()
    for filter_size in filter_sizes[::-1]:
      self.up_sample_blocks.append(UpSample(filter_size*2, filter_size, kernel_size=3, padding=1))

    self.out_conv = nn.Conv2d(filter_sizes[0], out_channels, kernel_size=1)


  def forward(self, X_agnostic, X_cloth):
    """Predict the output image from the agnostic and cloth inputs.

    Inputs are presumably ``(N, C, H, W)`` tensors with matching N, H and W
    (TODO confirm). H and W need not be divisible by ``2**len(filter_sizes)``
    if ``UpSample`` aligns sizes before concatenation.
    """
    X = torch.cat((X_agnostic, X_cloth), dim=1)

    skips = []
    for down in self.down_sample_blocks:
      X_skip, X = down(X)
      skips.append(X_skip)

    X = self.bottleneck(X)

    # The decoder consumes skips deepest-first, mirroring the encoder.
    for up, X_skip in zip(self.up_sample_blocks, reversed(skips)):
      X = up(X, X_skip)

    return self.out_conv(X)

In [9]:
# Four encoder levels; each width doubles the previous one, as UNet's decoder expects.
model = UNet(filter_sizes=[16, 32, 64, 128])
model = model.to(device)

# Training

## Train Loop

In [18]:
EPOCHS = 10

# Mean-squared error between each prediction and its target image.
loss_fn = nn.MSELoss()
optimizer = Adam(params=model.parameters(), lr=1e-4)

In [29]:
from tqdm import tqdm

def train(model, dataloader, loss_fn, optimizer, epoch):
  """Run one training epoch and return the per-batch losses.

  Args:
    model: network called as ``model(X1, X2)``.
    dataloader: iterable of ``(X1, X2, target)`` batches.
    loss_fn: criterion called as ``loss_fn(prediction, target)``.
    optimizer: optimizer stepping ``model``'s parameters.
    epoch: zero-based epoch index, used only in the progress-bar label
      (together with the global ``EPOCHS``).

  Returns:
    list[float]: the loss of every batch, in iteration order.
  """
  model.train()

  batch_losses = []

  pbar = tqdm(dataloader, unit="batch", leave=False, desc=f"Training : Epoch [{epoch+1}/{EPOCHS}]")
  for X1, X2, target in pbar:
    X1 = X1.to(device)
    X2 = X2.to(device)
    target = target.to(device)

    optimizer.zero_grad()
    prediction = model(X1, X2)
    loss = loss_fn(prediction, target)
    loss.backward()
    optimizer.step()

    # .item() forces a device->host sync; do it once per batch and reuse it.
    loss_value = loss.item()
    batch_losses.append(loss_value)
    pbar.set_postfix({"Batch Loss": loss_value})

  return batch_losses

## Validation Loop

In [30]:
def valid(model, dataloader, loss_fn, epoch):
  """Evaluate ``model`` over ``dataloader`` without gradients; return per-batch losses.

  Args:
    model: network called as ``model(X1, X2)``; switched to eval mode here.
    dataloader: iterable of ``(X1, X2, target)`` batches.
    loss_fn: criterion called as ``loss_fn(prediction, target)``.
    epoch: zero-based epoch index, used only in the progress-bar label
      (together with the global ``EPOCHS``).

  Returns:
    list[float]: the loss of every batch, in iteration order.
  """
  model.eval()

  batch_losses = []

  pbar = tqdm(dataloader, unit="batch", leave=False, desc=f"Validation : Epoch [{epoch+1}/{EPOCHS}]")
  # No autograd bookkeeping is needed anywhere in the evaluation pass.
  with torch.inference_mode():
    for X1, X2, target in pbar:
      X1 = X1.to(device)
      X2 = X2.to(device)
      target = target.to(device)

      prediction = model(X1, X2)
      loss = loss_fn(prediction, target)

      # Single device->host sync per batch.
      loss_value = loss.item()
      batch_losses.append(loss_value)
      pbar.set_postfix({"Batch Loss": loss_value})

  return batch_losses

In [None]:
# NOTE(review): there is no separate validation split yet, so `valid` runs on
# train_dataloader. "Valid Loss" is therefore an eval-mode loss on the
# training data, not an estimate of generalisation.
for epoch in range(EPOCHS):
  train_losses = train(model, train_dataloader, loss_fn, optimizer, epoch)
  # Report nan rather than raising ZeroDivisionError if the loader yielded no batches.
  mean_train_loss = sum(train_losses) / len(train_losses) if train_losses else float("nan")
  print(f"Epoch [{epoch+1}/{EPOCHS}]\tTrain Loss: {mean_train_loss}")

  valid_losses = valid(model, train_dataloader, loss_fn, epoch)
  mean_valid_loss = sum(valid_losses) / len(valid_losses) if valid_losses else float("nan")
  print(f"Epoch [{epoch+1}/{EPOCHS}]\tValid Loss: {mean_valid_loss}")