# SimCLR-Style Contrastive Learning
Setting up self-supervised visual representation learning on outfit images using a pre-trained ResNet50 encoder and SimCLR-style contrastive learning. This approach leverages both original and segmented images to create augmented pairs for contrastive training.

In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from src import config

In [2]:
# Report which compute device the project configuration selected.
print(f"Using device: {config.DEVICE}")

Using device: mps


In [3]:
# Backbone: a pre-trained ResNet-50 from torchvision.
encoder = torchvision.models.resnet50(pretrained=True)

# Freeze everything except parameters whose name mentions "layer4" or "fc".
for name, param in encoder.named_parameters():
    param.requires_grad = any(tag in name for tag in ("layer4", "fc"))

# Swap the classification head for a pass-through so the encoder returns
# whatever previously fed into `fc`, then move the model to the target device.
encoder.fc = torch.nn.Identity()
encoder = encoder.to(config.DEVICE)



In [4]:
# Random augmentation pipeline: PIL image in, normalised tensor out.
# The order of the steps matters (ToTensor must precede Normalize).
contrastive_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=3),
        transforms.ToTensor(),
        # Channel statistics commonly used with ImageNet-pretrained torchvision models.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [5]:
# Adam optimiser over every encoder parameter, learning rate 1e-4.
optimizer = torch.optim.Adam(params=encoder.parameters(), lr=1e-4)

In [6]:
# Restore encoder and optimiser state from the saved checkpoint when resuming
# is enabled in the config and the checkpoint file exists.
checkpoint_file = os.path.join(config.CHECKPOINT_PATH, "contrastive_encoder.pth")

# Default so `start_epoch` is always defined, even when nothing is restored.
start_epoch = 0

if config.RESUME_CHECKPOINT and os.path.exists(checkpoint_file):
    # map_location remaps the stored tensors onto the configured device, so a
    # checkpoint written on a different device type still loads here.
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    encoder.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The checkpoint stores the last finished epoch; continue with the next one.
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming training from epoch {start_epoch}")

Resuming training from epoch 100


In [7]:
class CustomVisualizationDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing image files on disk with labels.

    Args:
        image_paths: Sequence of image file paths, indexed by sample position.
        labels: Sequence of labels; ``labels[i]`` belongs to ``image_paths[i]``.
        transform: Optional callable applied to the RGB PIL image before it is
            returned. When ``None`` the converted PIL image is returned as-is.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as an ``(image, label)`` tuple."""
        img_path = self.image_paths[idx]
        # The context manager releases the underlying file handle as soon as
        # convert() has produced the RGB copy, instead of leaving it to the GC.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        label = self.labels[idx]
        # Explicit None check: a transform is applied whenever one was supplied.
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [8]:
# Source folders from the project config. Only the segmented folders are used
# below; the original-image folders are kept as names for convenience.
or_pos_dir = config.ORIGINAL_POS_OUTFITS_DIR
or_neg_dir = config.ORIGINAL_NEG_OUTFITS_DIR
seg_pos_dir = config.SEGMENTED_POS_OUTFITS_DIR
seg_neg_dir = config.SEGMENTED_NEG_OUTFITS_DIR

# Collect (path, label) pairs from the segmented folders.
image_paths = []
labels = []
for class_idx, folder in enumerate([seg_neg_dir, seg_pos_dir]):  # 0=negative, 1=positive
    # os.listdir returns entries in arbitrary order; sorting makes the sample
    # order (and therefore the embedding order) reproducible across runs.
    for img_name in sorted(os.listdir(folder)):
        if img_name.lower().endswith(config.IMAGE_FILE_EXTENSIONS):  # Filter images
            image_paths.append(os.path.join(folder, img_name))
            labels.append(class_idx)

# Embeddings extracted for visualisation should be deterministic: the random
# crop / flip / colour-jitter / grayscale / blur steps of `contrastive_transform`
# would give a different embedding for the same image on every pass. Use a fixed
# resize + centre crop with the same normalisation instead.
visualization_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset = CustomVisualizationDataset(image_paths, labels, visualization_transform)
# shuffle=False keeps batches in the same order as `image_paths` / `labels`.
loader = DataLoader(dataset, batch_size=64, shuffle=False)

In [9]:
# Run every batch of `loader` through the encoder in eval mode and gather the
# outputs (as NumPy arrays on the CPU) together with their labels.
encoder.eval()
embeddings, labels_list = [], []

with torch.no_grad():  # no gradients are needed for feature extraction
    for images, batch_labels in loader:
        images = images.to(config.DEVICE)
        z = encoder(images)
        embeddings += [z.cpu().numpy()]
        labels_list += list(batch_labels.numpy())

# Stack the per-batch arrays along the sample axis; labels become one array.
embeddings = np.concatenate(embeddings, axis=0)
labels_array = np.asarray(labels_list)

In [29]:
import umap
from matplotlib.lines import Line2D

# Project the encoder embeddings to 2D with UMAP (fixed seed) and save a
# scatter plot coloured by class.
umap_result = umap.UMAP(n_components=2, random_state=42).fit_transform(embeddings)

color_map = np.array(['blue', 'orange'])  # index 0 -> blue, index 1 -> orange
# Use the labels gathered in the same pass as `embeddings` so colours cannot
# drift out of sync with the points if the dataset cell is re-run on its own.
point_colors = color_map[labels_array]

fig, ax = plt.subplots(figsize=(8, 8))
# No `cmap` here: `c` already holds explicit colour names, so a colormap would
# be ignored and matplotlib would only emit a warning.
ax.scatter(umap_result[:, 0], umap_result[:, 1], c=point_colors, alpha=0.7)

# Proxy artists so the legend shows one entry per class.
legend_elements = [
    Line2D([0], [0], marker='o', color='w', label='Bad Outfit', markerfacecolor='blue', markersize=10),
    Line2D([0], [0], marker='o', color='w', label='Good Outfit', markerfacecolor='orange', markersize=10)
]
ax.legend(handles=legend_elements, title='Outfit Quality')

ax.set_title('2D UMAP Visualization of Image Embeddings')
ax.set_xlabel('Dimension 1')
ax.set_ylabel('Dimension 2')
fig.savefig('embeddings-2d-visualization.png')
plt.close(fig)  # the figure is only written to disk, not shown inline

  warn(
  scatter = plt.scatter(umap_result[:, 0], umap_result[:, 1], c=color_map[labels], cmap='coolwarm', alpha=0.7)


In [30]:
# Project the encoder embeddings to 3D with UMAP (fixed seed) and save a
# scatter plot coloured by class.
umap_result = umap.UMAP(n_components=3, random_state=42).fit_transform(embeddings)

color_map = np.array(['blue', 'orange'])  # index 0 -> blue, index 1 -> orange
# Use the labels gathered in the same pass as `embeddings` so colours cannot
# drift out of sync with the points if the dataset cell is re-run on its own.
point_colors = color_map[labels_array]

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Create scatter plot
ax.scatter(
    umap_result[:, 0], umap_result[:, 1], umap_result[:, 2],
    c=point_colors, alpha=0.7
)

# Add axis labels and title
ax.set_xlabel('Dimension 1')
ax.set_ylabel('Dimension 2')
ax.set_zlabel('Dimension 3')
ax.set_title('3D UMAP Visualization of Embeddings')

# Proxy artists so the legend shows one entry per class.
legend_elements = [
    Line2D([0], [0], marker='o', color='w', label='Bad Outfit', markerfacecolor='blue', markersize=10),
    Line2D([0], [0], marker='o', color='w', label='Good Outfit', markerfacecolor='orange', markersize=10)
]
ax.legend(handles=legend_elements, title='Outfit Quality')

fig.tight_layout()
fig.savefig('embeddings-3d-visualization.png')
plt.close(fig)  # the figure is only written to disk, not shown inline

  warn(


In [None]:
# Repeats the embedding-extraction pass from the earlier cell: every batch of
# `loader` goes through the encoder in eval mode, and the outputs and labels
# are collected again, overwriting `embeddings` and `labels_array`.
encoder.eval()
embeddings, labels_list = [], []

with torch.no_grad():  # no gradients are needed for feature extraction
    for images, batch_labels in loader:
        images = images.to(config.DEVICE)
        z = encoder(images)
        embeddings += [z.cpu().numpy()]
        labels_list += list(batch_labels.numpy())

# Stack the per-batch arrays along the sample axis; labels become one array.
embeddings = np.concatenate(embeddings, axis=0)
labels_array = np.asarray(labels_list)

In [None]:
# Classify every image in config.TEST_DIR and compare the prediction with the
# label embedded in the file name ('good' / 'bad').
#
# NOTE(review): `model` and `preprocess_image` are not defined in the visible
# part of this notebook, so this cell fails on a fresh kernel unless they are
# defined elsewhere. Presumably `model` is a two-class classifier and
# `preprocess_image` returns a batched tensor — confirm and define/import them.
test_image_files = sorted(
    f for f in os.listdir(config.TEST_DIR)
    if f.lower().endswith(config.IMAGE_FILE_EXTENSIONS)
)

class_names = ['bad', 'good']  # 0=negative, 1=positive

correct_count = 0
faulty_count = 0
not_rated = 0  # files whose name contains neither 'good' nor 'bad'

for i, image_file in enumerate(test_image_files):
    image_path = os.path.join(config.TEST_DIR, image_file)
    image_tensor = preprocess_image(image_path).to(config.DEVICE)

    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, 1).item()

    predicted_name = class_names[predicted_class]
    print(f"Image {i+1}: {image_file}")
    print(f"Predicted: {predicted_name}")
    print(f"Confidence: {probabilities[0][predicted_class].item():.2%}")

    # Condition checks
    if 'good' not in image_file and 'bad' not in image_file:
        # Without a ground-truth label in the name the prediction cannot be
        # scored; keep it out of both the correct and the faulty tally.
        print("classification not rated (no 'good'/'bad' in file name)\n")
        not_rated += 1
    elif predicted_name in image_file:
        print("classification correct\n")
        correct_count += 1
    else:
        print("classification faulty\n")
        faulty_count += 1

# Final summary (percentages are relative to the rated images only)
rated_total = correct_count + faulty_count
if rated_total == 0:
    # Guard against division by zero when nothing could be scored.
    print("No rated test images found; nothing to summarise.")
else:
    print(f"Total correct classifications: {correct_count} out of {rated_total} ({correct_count / rated_total * 100:.2f}%)")
    print(f"Total faulty classifications: {faulty_count} out of {rated_total} ({faulty_count / rated_total * 100:.2f}%)")
if not_rated:
    print(f"Images skipped as not rated: {not_rated}")