# imCLR-Style Contrastive Learning
Setting up self-supervised visual representation learning on outfit images using a pre-trained ResNet50 encoder and SimCLR-style contrastive learning. This approach leverages both original and segmented images to create augmented pairs for contrastive training.

In [1]:
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from sklearn.manifold import TSNE
from torch.utils.data import Dataset

from src import config

In [2]:
print("Using device:", config.DEVICE)

Using device: mps


In [3]:
os.makedirs(config.CHECKPOINT_PATH, exist_ok=True)

In [4]:
# Load pre-trained ResNet-50
encoder = torchvision.models.resnet50(pretrained=True)
for name, param in encoder.named_parameters():
    if "layer4" in name or "fc" in name: # "layer3" in name or 
        param.requires_grad = True
    else:
        param.requires_grad = False

encoder.fc = torch.nn.Identity()
encoder = encoder.to(config.DEVICE)
# model_name = "arize-ai/resnet-50-fashion-mnist-quality-drift"
# model = AutoModelForImageClassification.from_pretrained(model_name)
# processor = AutoImageProcessor.from_pretrained(model_name)



In [5]:
def contrastive_loss(z1, z2, temperature=0.5):
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    representations = torch.cat([z1, z2], dim=0)
    similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
    labels = torch.cat([torch.arange(z1.size(0)) for _ in range(2)], dim=0).to(config.DEVICE)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    logits = similarity_matrix / temperature
    loss = F.cross_entropy(logits, labels)
    return loss

def nt_xent_loss(z1, z2, temperature=0.5):
    batch_size = z1.size(0)
    z = torch.cat([z1, z2], dim=0)
    z = F.normalize(z, dim=1)
    similarity_matrix = torch.matmul(z, z.T)
    mask = torch.eye(2 * batch_size, dtype=torch.bool).to(z.device)
    similarity_matrix = similarity_matrix.masked_fill(mask, -9e15)
    positives = torch.cat([torch.arange(batch_size, 2*batch_size), torch.arange(0, batch_size)]).to(z.device)
    logits = similarity_matrix / temperature
    loss = F.cross_entropy(logits, positives)
    return loss

In [6]:
or_pos_dir = config.ORIGINAL_POS_OUTFITS_DIR
or_neg_dir = config.ORIGINAL_NEG_OUTFITS_DIR
seg_pos_dir = config.SEGMENTED_POS_OUTFITS_DIR
seg_neg_dir = config.SEGMENTED_NEG_OUTFITS_DIR

# Collect image paths and labels
image_paths = []
for class_idx, folder in enumerate([seg_neg_dir, seg_pos_dir]):
    for img_name in os.listdir(folder):
        if img_name.lower().endswith(config.IMAGE_FILE_EXTENSIONS):
            img_path = os.path.join(folder, img_name)
            image_paths.append(img_path)

random.seed(42)
random.shuffle(image_paths)

In [7]:
print(f"Images: {len(image_paths)}")

Images: 5347


In [8]:
class CustomImageDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        img1 = self.transform(image)
        img2 = self.transform(image)
        return img1, img2

In [9]:
# Image transformations
contrastive_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [10]:
# Load datasets
train_dataset = CustomImageDataset(
    image_paths=image_paths,
    transform=contrastive_transform
)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

In [11]:
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-4)
start_epoch = 0
train_losses = []    # Track training loss per epoch

In [12]:
# Load checkpoint if resuming
if config.RESUME_CHECKPOINT and os.path.exists(os.path.join(config.CHECKPOINT_PATH, f"contrastive_encoder.pth")):
    checkpoint = torch.load(os.path.join(config.CHECKPOINT_PATH, f"contrastive_encoder.pth"))
    encoder.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming training from epoch {start_epoch}")

Resuming training from epoch 100


In [13]:
start = time.time()

for epoch in range(start_epoch, config.EPOCHS):
    encoder.train()
    total_loss = 0.0
    for img1, img2 in train_loader:
        img1, img2 = img1.to(config.DEVICE), img2.to(config.DEVICE)
        
        z1, z2 = encoder(img1), encoder(img2)
        loss = contrastive_loss(z1, z2)

        optimizer.zero_grad() # Reset gradients
        loss.backward() # Compute gradients
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0) # Clip gradients
        optimizer.step() # Update parameters
        
        total_loss += loss.item()
        train_losses.append(total_loss / len(train_loader))
    
    # Save checkpoint
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': encoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item()
    }
    torch.save(checkpoint, os.path.join(config.CHECKPOINT_PATH, "contrastive_encoder.pth"))
    print(f"Epoch {epoch+1} / {config.EPOCHS}, Loss: {total_loss / len(train_loader):.4f}")
    print(f"⏱️ Took {time.time() - start:.2f}s")


Epoch 1 / 100, Loss: 6.8237
⏱️ Took 173.23s
Epoch 2 / 100, Loss: 6.3213
⏱️ Took 342.28s
Epoch 3 / 100, Loss: 6.1844
⏱️ Took 519.77s
Epoch 4 / 100, Loss: 6.1326
⏱️ Took 700.48s
Epoch 5 / 100, Loss: 6.1012
⏱️ Took 1403.73s
Epoch 6 / 100, Loss: 6.0745
⏱️ Took 1622.85s
Epoch 7 / 100, Loss: 6.0651
⏱️ Took 1794.31s
Epoch 8 / 100, Loss: 6.0307
⏱️ Took 1974.78s
Epoch 9 / 100, Loss: 6.0045
⏱️ Took 2165.32s
Epoch 10 / 100, Loss: 5.9931
⏱️ Took 2640.63s
Epoch 11 / 100, Loss: 5.9752
⏱️ Took 2813.08s
Epoch 12 / 100, Loss: 5.9519
⏱️ Took 3025.52s
Epoch 13 / 100, Loss: 5.9490
⏱️ Took 3265.00s
Epoch 14 / 100, Loss: 5.9434
⏱️ Took 3488.93s
Epoch 15 / 100, Loss: 5.9326
⏱️ Took 3716.48s
Epoch 16 / 100, Loss: 5.9353
⏱️ Took 3923.03s
Epoch 17 / 100, Loss: 5.9368
⏱️ Took 4147.85s
Epoch 18 / 100, Loss: 5.9210
⏱️ Took 5060.39s
Epoch 19 / 100, Loss: 5.9143
⏱️ Took 5233.22s
Epoch 20 / 100, Loss: 5.9066
⏱️ Took 5426.39s
Epoch 21 / 100, Loss: 5.8909
⏱️ Took 5629.67s
Epoch 22 / 100, Loss: 5.8985
⏱️ Took 5862.12s
E