In [1]:
import os
import json
import torch
import clip
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
from torch import nn, optim
import matplotlib.pyplot as plt

In [2]:
# Dataset
class FashionDataset(Dataset):
    def __init__(self, img_dir, styles_dir, transform=None):
        self.img_dir = img_dir
        self.styles_dir = styles_dir
        self.transform = transform
        self.images = []
        self.descriptions = []
        self.labels = []

        for img_file in os.listdir(self.img_dir):
            if img_file.endswith(".jpg"):
                img_path = os.path.join(self.img_dir, img_file)
                json_path = os.path.join(self.styles_dir, img_file.replace('.jpg', '.json'))
                if os.path.exists(json_path):
                    with open(json_path, 'r') as f:
                        style_data = json.load(f)
                    description = style_data.get("data", {}).get("productDescriptors", {}).get("description", {}).get("value", '')
                    category = style_data.get("data", {}).get("subCategory", {}).get("typeName", '')
                    
                    self.images.append(img_path)
                    self.descriptions.append(description)
                    self.labels.append(category)

        self.text = clip.tokenize(self.labels)


    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        description = self.descriptions[idx]
        label = self.labels[idx]
        text = self.text[idx]

        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        
        return image, description, label, text



In [3]:
# Load CLIP model and preprocess
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

In [4]:
# Prepare dataset
img_dir = "./data/fashion-dataset/images/"
styles_dir = "./data/fashion-dataset/styles/"
dataset = FashionDataset(img_dir, styles_dir, transform=preprocess)

In [6]:
# # display a transformed image
# transformed_image, description, label, text = dataset[0]
# transformed_image = transformed_image.permute(1, 2, 0).numpy() 

# plt.imshow(transformed_image)
# plt.show()

# print("label: ", label, ", text: ", text.shape)

In [7]:
# Split the dataset into training and validation sets
train_size = int(0.08 * len(dataset))
val_size = int(0.02 * len(dataset))
remain_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, remain_dataset = random_split(dataset, [train_size, val_size, remain_size], generator=torch.Generator().manual_seed(123))

# train_dataset = torch.utils.data.Subset(train_dataset, list(range(500))) 
# val_dataset = torch.utils.data.Subset(val_dataset, list(range(100)))

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
print("Train size: ", train_size, ", Val size: ", val_size)

Train size:  3555 , Val size:  888


In [8]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# # Create a label-to-index mapping
unique_labels = list(set(dataset.labels))
# label_to_index = {label: idx for idx, label in enumerate(unique_labels)}

In [9]:
def convert_models_to_fp32(model): 
    for p in model.parameters(): 
        p.data = p.data.float() 
        p.grad.data = p.grad.data.float() 

# Lists to store accuracy values for plotting
train_accuracies = []
val_accuracies = []

# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    for images, _, labels, text in train_loader:
        images = images.to(device)
        text = text.to(device)

        # Forward pass
        logits_per_image, logits_per_text = model(images, text)

        # Compute the loss
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        total_loss = (criterion(logits_per_image, ground_truth) + criterion(logits_per_text, ground_truth)) / 2

        # Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        convert_models_to_fp32(model)
        optimizer.step()
        clip.model.convert_weights(model)

        running_loss += total_loss.item()

    # Validation loop
    model.eval()
    val_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        for images, _, labels, text in val_loader:
            images = images.to(device)
            text = text.to(device)

            # Forward pass
            logits_per_image, logits_per_text = model(images, text)

            # Compute the loss
            ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
            total_loss = (criterion(logits_per_image, ground_truth) + criterion(logits_per_text, ground_truth)) / 2
            val_loss += total_loss.item()

            # Calculate accuracy
            labels_text = torch.cat([clip.tokenize(f"{c}") for c in unique_labels]).to(device)
            # print("images: ", images)
            # print("labels_text: ", labels_text)
            image_features = model.encode_image(images)
            text_features = model.encode_text(labels_text)
            # print("image_features: ", image_features)
            # print("text_features: ", text_features)

            # Normalize image and text features
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)

            # Calculate similarity scores
            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            # print("similarity: ", similarity)
            # print("similarity: ", similarity.shape)
            values, indices = similarity.topk(1, dim=-1)
            # print("indices: ", indices.shape)
            indices = indices.squeeze()

            # print("indices: ", indices.shape)
            # print("values: ", values)
            predicted_labels = [unique_labels[i] for i in indices]
            # print("Predicted labels: ", predicted_labels)
            # print("Actual labels: ", labels)
            correct_predictions += sum([pred == actual for pred, actual in zip(predicted_labels, labels)])
            # correct_predictions += (unique_labels[indices] == labels).sum().item()
            total_predictions += len(labels)

    val_accuracy = correct_predictions / total_predictions
    print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {running_loss/train_size:.4f}, Val Loss: {val_loss/val_size:.4f}, Val Accuracy: {val_accuracy:.4f}")

print("Training complete.")

Epoch [1/50], Train Loss: 0.1091, Val Loss: 0.1088, Val Accuracy: 0.0000
Epoch [2/50], Train Loss: 0.1030, Val Loss: 0.0904, Val Accuracy: 0.3468
Epoch [3/50], Train Loss: 0.0851, Val Loss: 0.0787, Val Accuracy: 0.4617
Epoch [4/50], Train Loss: 0.0776, Val Loss: 0.0778, Val Accuracy: 0.3164
Epoch [5/50], Train Loss: 0.0705, Val Loss: 0.0751, Val Accuracy: 0.4200
Epoch [6/50], Train Loss: 0.0665, Val Loss: 0.0662, Val Accuracy: 0.6160
Epoch [7/50], Train Loss: 0.0631, Val Loss: 0.0701, Val Accuracy: 0.6655
Epoch [8/50], Train Loss: 0.0603, Val Loss: 0.0652, Val Accuracy: 0.6047
Epoch [9/50], Train Loss: 0.0575, Val Loss: 0.0721, Val Accuracy: 0.6610
Epoch [10/50], Train Loss: 0.0564, Val Loss: 0.0689, Val Accuracy: 0.5912
Epoch [11/50], Train Loss: 0.0546, Val Loss: 0.0678, Val Accuracy: 0.7016
Epoch [12/50], Train Loss: 0.0541, Val Loss: 0.0680, Val Accuracy: 0.7095
Epoch [13/50], Train Loss: 0.0524, Val Loss: 0.0650, Val Accuracy: 0.7331
Epoch [14/50], Train Loss: 0.0508, Val Loss: 0.

In [None]:
# Plotting training and validation accuracy
plt.figure(figsize=(10, 5))
plt.plot(range(1, num_epochs + 1), train_accuracies, label='Train Accuracy')
plt.plot(range(1, num_epochs + 1), val_accuracies, label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy Over Epochs')
plt.legend()
plt.show()