# Importing and Utils

In [1]:
import os
import random
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2

from torch.optim import Adam


random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
environment = "kaggle"

if environment == "local":
  data_root = "data"
  BATCH_SIZE = 4
elif environment == "kaggle":
  data_root = "/kaggle/input/marquis-viton-hd"
  BATCH_SIZE = 16

train_folder = os.path.join(data_root, "train")
valid_folder = os.path.join(data_root, "test")

train_file_names = os.listdir(os.path.join(train_folder, "image"))
random.shuffle(train_file_names)

valid_file_names = os.listdir(os.path.join(valid_folder, "image"))
random.shuffle(valid_file_names)

In [3]:
import wandb

LR = 1e-4
EPOCHS = 20
experiment_name = "v0-baseline-0.3"

if environment == "kaggle":
    from kaggle_secrets import UserSecretsClient
    
    user_secrets = UserSecretsClient()
    wandb_api = user_secrets.get_secret("WANDB_API_KEY")
    
    wandb.login(key=wandb_api)

wandb.init(
    project="viton",
    name=experiment_name,
    tags=["torch", environment, "P100", "1GPU"],
    notes="Large Model (1024 Bottleneck + Batch Norm), Smaller Image (288, 384)",
  
    config={
    "learning_rate": LR,
    "architecture": "UNet",
    "epochs": EPOCHS,
    }
)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msashanktalakola2[0m ([33msashanktalakola[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.17.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.17.0
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20240607_035513-pd0tzdee[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mv0-baseline-0.3[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/sashanktalakola/viton[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/sashanktalakola/viton/runs/pd0tzdee[0m


# Dataset

In [4]:
IMG_SIZE = (288, 384)

train_transforms = A.Compose([
    A.Resize(height=IMG_SIZE[0], width=IMG_SIZE[1]),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

label_transforms = A.Compose([
  A.Resize(height=IMG_SIZE[0], width=IMG_SIZE[1]),
  ToTensorV2()
])

valid_transforms = A.Compose([
    A.Resize(height=IMG_SIZE[0], width=IMG_SIZE[1]),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

In [5]:
class TrainDataset(Dataset):
  def __init__(self, image_list, image_folder, transforms, label_transforms):
    self.image_list = image_list
    self.image_folder = image_folder
    self.transforms = transforms
    self.label_transforms = label_transforms

  def __len__(self): return len(self.image_list)

  def __getitem__(self, i):
    image_name = self.image_list[i]

    agnostic_path = os.path.join(self.image_folder, "agnostic", image_name)
    cloth_path = os.path.join(self.image_folder, "cloth", image_name)
    output_img_path = os.path.join(self.image_folder, "image", image_name)

    agnostic_image = cv2.cvtColor(cv2.imread(agnostic_path), cv2.COLOR_BGR2RGB)
    cloth_image = cv2.cvtColor(cv2.imread(cloth_path), cv2.COLOR_BGR2RGB)
    output_img_image = cv2.cvtColor(cv2.imread(output_img_path), cv2.COLOR_BGR2RGB)

    agnostic_image = self.transforms(image=agnostic_image)["image"]
    cloth_image = self.transforms(image=cloth_image)["image"]
    output_img_image = self.label_transforms(image=output_img_image)["image"]

    return agnostic_image, cloth_image, output_img_image.float()

train_dataset = TrainDataset(train_file_names, train_folder, train_transforms, label_transforms)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [6]:
class ValidDataset(Dataset):
  def __init__(self, image_list, image_folder, transforms, label_transforms):
    self.image_list = image_list
    self.image_folder = image_folder
    self.transforms = transforms
    self.label_transforms = label_transforms

  def __len__(self): return len(self.image_list)

  def __getitem__(self, i):
    image_name = self.image_list[i]

    agnostic_path = os.path.join(self.image_folder, "agnostic", image_name)
    cloth_path = os.path.join(self.image_folder, "cloth", image_name)
    output_img_path = os.path.join(self.image_folder, "image", image_name)

    agnostic_image = cv2.cvtColor(cv2.imread(agnostic_path), cv2.COLOR_BGR2RGB)
    cloth_image = cv2.cvtColor(cv2.imread(cloth_path), cv2.COLOR_BGR2RGB)
    output_img_image = cv2.cvtColor(cv2.imread(output_img_path), cv2.COLOR_BGR2RGB)

    agnostic_image = self.transforms(image=agnostic_image)["image"]
    cloth_image = self.transforms(image=cloth_image)["image"]
    output_img_image = self.label_transforms(image=output_img_image)["image"]

    return agnostic_image, cloth_image, output_img_image.float()

valid_dataset = TrainDataset(valid_file_names, valid_folder, train_transforms, label_transforms)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model

In [7]:
class DoubleConv(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
    super(DoubleConv, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.padding = padding

    self.double_conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )

  def forward(self, X):
    X = self.double_conv(X)
    
    return X

In [8]:
class DownSample(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
    super(DownSample, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.padding = padding

    self.double_conv = DoubleConv(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

  def forward(self, X):
    X = self.double_conv(X)
    X_pooled = self.pool(X)

    return X, X_pooled

In [9]:
class UpSample(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
    super(UpSample, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.padding = padding

    self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    self.double_conv = DoubleConv(in_channels, out_channels, kernel_size=kernel_size, padding=padding)

  def forward(self, X, X_skip):
    X = self.up(X)

    X_cat = torch.cat((X, X_skip), dim=1)
    X_cat = self.double_conv(X_cat)

    return X_cat

In [10]:
class UNet(nn.Module):
  def __init__(self, filter_sizes):
    super(UNet, self).__init__()

    self.filter_sizes = filter_sizes
    self.down_sample_blocks = []

    current_in_channels = 6
    for filter_size in filter_sizes:
      self.down_sample_blocks.append(DownSample(current_in_channels, filter_size, kernel_size=3, padding=1).to(device))
      current_in_channels = filter_size

    self.bottleneck = DoubleConv(filter_sizes[-1], filter_sizes[-1]*2, kernel_size=3, padding=1)

    self.up_sample_blocks = []
    for filter_size in filter_sizes[::-1]:
      self.up_sample_blocks.append(UpSample(filter_size*2, filter_size, kernel_size=3, padding=1).to(device))

    self.out_conv = nn.Conv2d(filter_sizes[0], 3, kernel_size=1)


  def forward(self, X_agnostic, X_cloth):
    X = torch.cat((X_agnostic, X_cloth), dim=1)

    X1_skip, X = self.down_sample_blocks[0](X)
    X2_skip, X = self.down_sample_blocks[1](X)
    X3_skip, X = self.down_sample_blocks[2](X)
    X4_skip, X = self.down_sample_blocks[3](X)

    X = self.bottleneck(X)

    X = self.up_sample_blocks[0](X, X4_skip)
    X = self.up_sample_blocks[1](X, X3_skip)
    X = self.up_sample_blocks[2](X, X2_skip)
    X = self.up_sample_blocks[3](X, X1_skip)

    X = self.out_conv(X)

    return X

In [11]:
model = UNet([64, 128, 256, 512]).to(device)

# Training

## Train Loop

In [12]:
loss_fn = torch.nn.MSELoss()
optimizer = Adam(model.parameters(), lr=LR)

In [13]:
from tqdm import tqdm

def train(model, dataloader, loss_fn, optimizer, epoch):
  model.train()

  batch_losses = []

  pbar = tqdm(dataloader, unit="batch", leave=True, desc=f"Training : Epoch [{epoch+1}/{EPOCHS}]")
  for batch_idx, (X1, X2, target) in enumerate(pbar):
    X1 = X1.to(device)
    X2 = X2.to(device)
    target = target.to(device)

    prediction = model(X1, X2)
    loss = loss_fn(prediction, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    batch_losses.append(loss.item())
    pbar.set_postfix({"Batch Loss": loss.item()})

  return batch_losses

## Validation Loop

In [14]:
def valid(model, dataloader, loss_fn, epoch):
  model.eval()

  batch_losses = []

  pbar = tqdm(dataloader, unit="batch", leave=True, desc=f"Validation : Epoch [{epoch+1}/{EPOCHS}]")
  for batch_idx, (X1, X2, target) in enumerate(pbar):
    X1 = X1.to(device)
    X2 = X2.to(device)
    target = target.to(device)

    with torch.inference_mode():
      prediction = model(X1, X2)
      loss = loss_fn(prediction, target)

    batch_losses.append(loss.item())
    pbar.set_postfix({"Batch Loss": loss.item()})

  return batch_losses

In [15]:
os.makedirs(f"saved-models/{experiment_name}", exist_ok=True)

for epoch in range(EPOCHS):
  train_losses = train(model, train_dataloader, loss_fn, optimizer, epoch)
  train_epoch_loss = sum(train_losses) / len(train_losses)
  print(f"Epoch [{epoch+1}/{EPOCHS}]\tLoss: {train_epoch_loss}")

  valid_losses = valid(model, valid_dataloader, loss_fn, epoch)
  valid_epoch_loss = sum(valid_losses) / len(valid_losses)
  print(f"Epoch [{epoch+1}/{EPOCHS}]\tLoss: {valid_epoch_loss}\n")

  wandb.log({"train_epoch_loss": train_epoch_loss, "valid_epoch_loss": valid_epoch_loss})
  torch.save(model.state_dict(), f"saved-models/epoch - {epoch}")

wandb.finish()

Training : Epoch [1/20]: 100%|██████████| 728/728 [20:01<00:00,  1.65s/batch, Batch Loss=3.75e+4]


Epoch [1/20]	Loss: 38216.4814453125


Validation : Epoch [1/20]: 100%|██████████| 127/127 [02:31<00:00,  1.19s/batch, Batch Loss=4e+4]


Epoch [1/20]	Loss: 37888.46465920276



Training : Epoch [2/20]: 100%|██████████| 728/728 [14:46<00:00,  1.22s/batch, Batch Loss=3.68e+4]


Epoch [2/20]	Loss: 37632.2126518501


Validation : Epoch [2/20]: 100%|██████████| 127/127 [01:39<00:00,  1.28batch/s, Batch Loss=3.94e+4]


Epoch [2/20]	Loss: 37307.428672490154



Training : Epoch [3/20]: 100%|██████████| 728/728 [14:47<00:00,  1.22s/batch, Batch Loss=3.63e+4]


Epoch [3/20]	Loss: 37056.22685761504


Validation : Epoch [3/20]: 100%|██████████| 127/127 [01:42<00:00,  1.25batch/s, Batch Loss=3.88e+4]


Epoch [3/20]	Loss: 36735.55041215551



Training : Epoch [4/20]: 100%|██████████| 728/728 [14:54<00:00,  1.23s/batch, Batch Loss=3.48e+4]


Epoch [4/20]	Loss: 36488.90303056319


Validation : Epoch [4/20]: 100%|██████████| 127/127 [01:40<00:00,  1.27batch/s, Batch Loss=3.82e+4]


Epoch [4/20]	Loss: 36172.77448326772



Training : Epoch [5/20]: 100%|██████████| 728/728 [14:52<00:00,  1.23s/batch, Batch Loss=3.61e+4]


Epoch [5/20]	Loss: 35930.54759669042


Validation : Epoch [5/20]: 100%|██████████| 127/127 [01:40<00:00,  1.26batch/s, Batch Loss=3.76e+4]


Epoch [5/20]	Loss: 35617.9702878937



Training : Epoch [6/20]: 100%|██████████| 728/728 [14:50<00:00,  1.22s/batch, Batch Loss=3.66e+4]


Epoch [6/20]	Loss: 35380.91757436899


Validation : Epoch [6/20]: 100%|██████████| 127/127 [01:40<00:00,  1.26batch/s, Batch Loss=3.7e+4]


Epoch [6/20]	Loss: 35070.738004429135



Training : Epoch [7/20]: 100%|██████████| 728/728 [14:54<00:00,  1.23s/batch, Batch Loss=3.65e+4]


Epoch [7/20]	Loss: 34835.583471947975


Validation : Epoch [7/20]: 100%|██████████| 127/127 [01:40<00:00,  1.26batch/s, Batch Loss=3.65e+4]


Epoch [7/20]	Loss: 34532.99953863189



Training : Epoch [8/20]: 100%|██████████| 728/728 [15:01<00:00,  1.24s/batch, Batch Loss=3.3e+4]


Epoch [8/20]	Loss: 34300.87962794042


Validation : Epoch [8/20]: 100%|██████████| 127/127 [01:44<00:00,  1.22batch/s, Batch Loss=3.59e+4]


Epoch [8/20]	Loss: 34000.692882627955



Training : Epoch [9/20]: 100%|██████████| 728/728 [15:08<00:00,  1.25s/batch, Batch Loss=3.41e+4]


Epoch [9/20]	Loss: 33773.98232797476


Validation : Epoch [9/20]: 100%|██████████| 127/127 [01:44<00:00,  1.22batch/s, Batch Loss=3.54e+4]


Epoch [9/20]	Loss: 33476.09767162894



Training : Epoch [10/20]: 100%|██████████| 728/728 [15:00<00:00,  1.24s/batch, Batch Loss=3.33e+4]


Epoch [10/20]	Loss: 33254.9622869806


Validation : Epoch [10/20]: 100%|██████████| 127/127 [01:41<00:00,  1.25batch/s, Batch Loss=3.49e+4]


Epoch [10/20]	Loss: 32959.76470226378



Training : Epoch [11/20]: 100%|██████████| 728/728 [14:53<00:00,  1.23s/batch, Batch Loss=3.14e+4]


Epoch [11/20]	Loss: 32743.916662195228


Validation : Epoch [11/20]: 100%|██████████| 127/127 [01:41<00:00,  1.25batch/s, Batch Loss=3.43e+4]


Epoch [11/20]	Loss: 32452.439960629923



Training : Epoch [12/20]: 100%|██████████| 728/728 [14:53<00:00,  1.23s/batch, Batch Loss=3.23e+4]


Epoch [12/20]	Loss: 32239.837246737636


Validation : Epoch [12/20]: 100%|██████████| 127/127 [01:41<00:00,  1.25batch/s, Batch Loss=3.38e+4]


Epoch [12/20]	Loss: 31950.83146222933



Training : Epoch [13/20]: 100%|██████████| 728/728 [14:58<00:00,  1.23s/batch, Batch Loss=2.99e+4]


Epoch [13/20]	Loss: 31741.007160564044


Validation : Epoch [13/20]: 100%|██████████| 127/127 [01:40<00:00,  1.27batch/s, Batch Loss=3.33e+4]


Epoch [13/20]	Loss: 31456.17413570374



Training : Epoch [14/20]: 100%|██████████| 728/728 [15:04<00:00,  1.24s/batch, Batch Loss=3.05e+4]


Epoch [14/20]	Loss: 31251.65168913118


Validation : Epoch [14/20]: 100%|██████████| 127/127 [01:40<00:00,  1.26batch/s, Batch Loss=3.28e+4]


Epoch [14/20]	Loss: 30971.869156003937



Training : Epoch [15/20]: 100%|██████████| 728/728 [15:05<00:00,  1.24s/batch, Batch Loss=3.06e+4]


Epoch [15/20]	Loss: 30768.13940697974


Validation : Epoch [15/20]: 100%|██████████| 127/127 [01:40<00:00,  1.26batch/s, Batch Loss=3.23e+4]


Epoch [15/20]	Loss: 30492.00610543799



Training : Epoch [16/20]: 100%|██████████| 728/728 [15:06<00:00,  1.24s/batch, Batch Loss=3.07e+4]


Epoch [16/20]	Loss: 30292.467682220125


Validation : Epoch [16/20]: 100%|██████████| 127/127 [01:46<00:00,  1.19batch/s, Batch Loss=3.18e+4]


Epoch [16/20]	Loss: 30020.462967519685



Training : Epoch [17/20]: 100%|██████████| 728/728 [15:05<00:00,  1.24s/batch, Batch Loss=3.08e+4]


Epoch [17/20]	Loss: 29826.432324755322


Validation : Epoch [17/20]: 100%|██████████| 127/127 [01:41<00:00,  1.26batch/s, Batch Loss=3.13e+4]


Epoch [17/20]	Loss: 29557.746278297243



Training : Epoch [18/20]: 100%|██████████| 728/728 [15:00<00:00,  1.24s/batch, Batch Loss=2.85e+4]


Epoch [18/20]	Loss: 29367.734069153503


Validation : Epoch [18/20]: 100%|██████████| 127/127 [01:41<00:00,  1.25batch/s, Batch Loss=3.08e+4]


Epoch [18/20]	Loss: 29102.039047121063



Training : Epoch [19/20]: 100%|██████████| 728/728 [14:57<00:00,  1.23s/batch, Batch Loss=2.8e+4]


Epoch [19/20]	Loss: 28914.694907387533


Validation : Epoch [19/20]: 100%|██████████| 127/127 [01:41<00:00,  1.25batch/s, Batch Loss=3.04e+4]


Epoch [19/20]	Loss: 28653.887718380905



Training : Epoch [20/20]: 100%|██████████| 728/728 [14:56<00:00,  1.23s/batch, Batch Loss=2.75e+4]


Epoch [20/20]	Loss: 28472.257689088256


Validation : Epoch [20/20]: 100%|██████████| 127/127 [01:40<00:00,  1.26batch/s, Batch Loss=2.99e+4]


Epoch [20/20]	Loss: 28214.060254675198



[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m: train_epoch_loss ██▇▇▆▆▆▅▅▄▄▄▃▃▃▂▂▂▁▁
[34m[1mwandb[0m: valid_epoch_loss ██▇▇▆▆▆▅▅▄▄▄▃▃▃▂▂▂▁▁
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m: train_epoch_loss 28472.25769
[34m[1mwandb[0m: valid_epoch_loss 28214.06025
[34m[1mwandb[0m: 
[34m[1mwandb[0m: 🚀 View run [33mv0-baseline-0.3[0m at: [34m[4mhttps://wandb.ai/sashanktalakola/viton/runs/pd0tzdee[0m
[34m[1mwandb[0m: ⭐️ View project at: [34m[4mhttps://wandb.ai/sashanktalakola/viton[0m
[34m[1mwandb[0m: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
[34m[1mwandb[0m: Find logs at: [35m[1m./wandb/run-20240607_035513-pd0tzdee/logs[0m
