In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, concatenate_datasets
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as v2
from transformers import AutoImageProcessor, SwinModel, SwinConfig
from huggingface_hub import PyTorchModelHubMixin

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
import huggingface_hub

# Never paste the access token into the notebook: it would be saved with the file
# and leak through version control or sharing. Read it from the HF_TOKEN
# environment variable instead; when the variable is unset, token=None makes
# login() ask for the token interactively.
huggingface_hub.login(token=os.environ.get('HF_TOKEN'))

In [None]:
# Load the two anchor/positive image-pair datasets from the Hub and merge their
# 'train' splits into a single dataset.
onthelook_dataset = load_dataset('yainage90/onthelook-fashion-anchor-positive-images')
kream_dataset = load_dataset('yainage90/kream-fashion-anchor-positive-images')
combined_dataset = concatenate_datasets([onthelook_dataset['train'], kream_dataset['train']])

# Hold out 5% of the rows, stratified on 'category' so that every category is
# represented in both splits. The seed pins the shuffle: without it each
# Restart & Run All would draw a different held-out set, so validation losses
# (and the checkpoints selected from them) would not be comparable across runs.
dataset = combined_dataset.train_test_split(
    test_size=0.05,
    shuffle=True,
    stratify_by_column='category',
    seed=42,
)

In [None]:
# Hub checkpoint that supplies the Swin configuration and the image processor
# (the processor's image_mean / image_std are used for input normalisation).
ckpt = "microsoft/swin-base-patch4-window7-224"
config = SwinConfig.from_pretrained(ckpt)
image_processor = AutoImageProcessor.from_pretrained(ckpt)

# Input resolution and backbone feature width, one value per line.
print(config.image_size, config.hidden_size, sep='\n')

In [None]:
# Category names; one model is trained per entry. CustomDataset keeps the rows
# whose 'category' value equals the position of the name in this list
# (labels.index(label)), so the order of the entries matters.
# NOTE(review): presumably this order mirrors the integer encoding of the
# datasets' 'category' column — confirm against the dataset features.
labels = ['bag', 'bottom', 'hat', 'outer', 'shoes', 'top']

In [None]:
class CustomDataset(Dataset):
    """Anchor/positive image pairs of one category, augmented and normalised.

    Args:
        dataset: Dataset exposing a 'category' column plus 'anchor_image' and
            'positive_image' entries per row; must support ``select``.
        label: Category name; rows are kept when their 'category' value equals
            ``labels.index(label)`` (``labels`` is the notebook-level list).
        config: Object whose ``image_size`` gives the square resize target.
        image_processor: Object providing ``image_mean`` and ``image_std`` for
            the final normalisation step.

    Raises:
        ValueError: If ``label`` is not present in ``labels``.
    """

    def __init__(self, dataset, label, config, image_processor):
        category_id = labels.index(label)
        category_mask = np.array(dataset['category']) == category_id
        self.dataset = dataset.select(np.nonzero(category_mask)[0])
        self.image_processor = image_processor

        side = config.image_size
        # Steps before ToTensor run on the loaded image; RandomErasing and
        # Normalize run on the tensor that ToTensor produces.
        # NOTE(review): assumes every image yields as many channels as
        # image_mean/image_std have entries — confirm for the source datasets.
        pipeline = [
            v2.Resize((side, side)),
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomApply([v2.RandomRotation(degrees=(-90, 90))], p=0.3),
            v2.RandomApply([v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)], p=0.2),
            v2.RandomApply([v2.RandomAdjustSharpness(sharpness_factor=2)], p=0.1),
            v2.RandomApply([v2.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))], p=0.1),
            v2.ToTensor(),
            v2.RandomErasing(p=0.1, scale=(0.02, 0.33), ratio=(0.3, 3.3)),
            v2.Normalize(mean=self.image_processor.image_mean, std=self.image_processor.image_std),
        ]
        self.transform = v2.Compose(pipeline)

    def __len__(self):
        """Return the number of rows kept for the selected category."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return ``(anchor, positive)`` for row ``i``.

        The transform is applied to the anchor first and then, independently,
        to the positive, so the two images receive different random
        augmentations.
        """
        row = self.dataset[i]
        return tuple(self.transform(row[column]) for column in ('anchor_image', 'positive_image'))

In [None]:
class ImageEncoder(nn.Module, PyTorchModelHubMixin):
    """Swin backbone followed by a linear projection to unit-length 128-d embeddings.

    NOTE(review): the backbone is created with ``SwinModel(config=config)``,
    i.e. from the configuration only — nothing in this class loads weights
    from a checkpoint. Confirm that starting training from freshly
    initialised backbone weights is intended.
    """

    def __init__(self):
        super().__init__()
        # `config` is the notebook-level SwinConfig.
        self.swin = SwinModel(config=config)
        self.embedding_layer = nn.Linear(config.hidden_size, 128)

    def forward(self, image_tensor):
        """Encode a batch of images.

        Args:
            image_tensor: Batch of image tensors passed straight to the Swin model.

        Returns:
            The backbone's pooled output projected to 128 features and
            L2-normalised along dim 1.
        """
        pooled = self.swin(image_tensor).pooler_output
        projected = self.embedding_layer(pooled)
        return F.normalize(projected, p=2, dim=1)

    
class ContrastiveLoss(nn.Module):
    """Cross-entropy over temperature-scaled anchor/positive similarities.

    Row ``i`` of the similarity matrix holds the dot products of anchor ``i``
    with every positive in the batch; the matching positive (column ``i``) is
    the target class, and the other columns act as negatives.

    Args:
        margin: Stored on the module but not used by ``forward``.
        temperature: Divisor applied to the similarity matrix.
    """

    def __init__(self, margin=1.0, temperature=0.07):
        super().__init__()
        self.margin = margin
        self.temperature = temperature

    def forward(self, anchor, positive):
        """Return the mean cross-entropy loss for a batch of embedding pairs.

        Args:
            anchor: 2-D tensor with one embedding per row.
            positive: 2-D tensor with the same number of rows; row ``i`` is
                the positive for ``anchor[i]``.
        """
        logits = (anchor @ positive.T) / self.temperature
        # The diagonal entry of each row is the correct "class".
        targets = torch.arange(anchor.shape[0], device=anchor.device)
        return F.cross_entropy(logits, targets)

In [None]:
scaler = GradScaler()

def train(model, dataloader, criterion, optimizer):
    model.train()
    total_loss = 0

    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        
        anchor, positive = batch
        
        with autocast(device.type):
            anchor_emb = model(anchor.to(device))
            positive_emb = model(positive.to(device))
            loss = criterion(anchor_emb, positive_emb)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        
            total_loss += loss.item()
    
    return total_loss / len(dataloader)


def evaluate(model, dataloader, criterion):
    """Compute the mean per-batch loss without tracking gradients.

    Args:
        model: Module mapping a batch of images to embeddings; switched to
            eval mode by this call.
        dataloader: Iterable of ``(anchor, positive)`` batches; must support ``len``.
        criterion: Callable returning a scalar loss from
            ``(anchor_embeddings, positive_embeddings)``.

    Returns:
        The mean of the per-batch losses.
    """
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for anchor, positive in tqdm(dataloader):
            anchor_emb = model(anchor.to(device))
            positive_emb = model(positive.to(device))
            running_loss += criterion(anchor_emb, positive_emb).item()

    return running_loss / len(dataloader)

In [None]:
def train_label_model(label):
    """Train an ImageEncoder on one category and checkpoint its best epoch.

    Builds train/test CustomDatasets for ``label`` from the notebook-level
    ``dataset``, trains with AdamW and ContrastiveLoss, and saves the model,
    ``config`` and ``image_processor`` to ``./model_ckpt/<label>`` every time
    the validation loss improves. Training stops early after 3 consecutive
    epochs without improvement.

    Args:
        label: Category name; must be an entry of ``labels``.
    """
    train_dataset = CustomDataset(dataset['train'], label=label, config=config, image_processor=image_processor)
    test_dataset = CustomDataset(dataset['test'], label=label, config=config, image_processor=image_processor)

    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=8, pin_memory=True)
    # Keep validation batches in a fixed order: ContrastiveLoss uses the other
    # items of each batch as negatives, so reshuffling would change the batch
    # composition — and with it the loss that drives early stopping — from
    # epoch to epoch.
    # NOTE(review): CustomDataset applies its random augmentations to the test
    # split too, so val_loss is still not fully deterministic.
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=8, pin_memory=True)

    print(f'train: {len(train_dataset)}, test: {len(test_dataset)}')

    model = ImageEncoder().to(device)
    criterion = ContrastiveLoss().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-8)

    model_save_dir = f'./model_ckpt/{label}'
    os.makedirs(model_save_dir, exist_ok=True)

    min_loss = float("inf")
    early_stop_counter = 0

    # Epochs are numbered from 1; the upper bound is exclusive, so at most 29
    # epochs run.
    for epoch in range(1, 30):
        train_loss = train(model, train_dataloader, criterion, optimizer)
        val_loss = evaluate(model, test_dataloader, criterion)

        print(f'Epoch: {epoch}, Train loss: {train_loss:.4f}, Eval loss: {val_loss:.4f}')

        if val_loss < min_loss:
            # New best validation loss: overwrite the checkpoint directory.
            model.save_pretrained(model_save_dir)
            config.save_pretrained(model_save_dir)
            image_processor.save_pretrained(model_save_dir)
            min_loss = val_loss
            early_stop_counter = 0
        else:
            early_stop_counter += 1
            if early_stop_counter >= 3:
                # `epoch` is already 1-based, so report it as is to match the
                # per-epoch log line above.
                print(f"Early stopped at epoch {epoch}")
                break

In [None]:
# Train one model per category, sequentially, in the order given by `labels`.
for label in labels:
    train_label_model(label)