In [1]:
!pip install ipywidgets pandas

Collecting ipywidgets
  Downloading ipywidgets-8.1.7-py3-none-any.whl.metadata (2.4 kB)
Collecting widgetsnbextension~=4.0.14 (from ipywidgets)
  Downloading widgetsnbextension-4.0.14-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab_widgets~=3.0.15 (from ipywidgets)
  Downloading jupyterlab_widgets-3.0.15-py3-none-any.whl.metadata (20 kB)
Downloading ipywidgets-8.1.7-py3-none-any.whl (139 kB)
Downloading jupyterlab_widgets-3.0.15-py3-none-any.whl (216 kB)
Downloading widgetsnbextension-4.0.14-py3-none-any.whl (2.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m69.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: widgetsnbextension, jupyterlab_widgets, ipywidgets
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3/3[0m [ipywidgets]
[1A[2KSuccessfully installed ipywidgets-8.1.7 jupyterlab_widgets-3.0.15 widgetsnbextension-4.0.14


In [2]:
!pip install torchvision fashion-clip

Collecting fashion-clip
  Downloading fashion_clip-0.2.2-py3-none-any.whl.metadata (11 kB)
Collecting boto3>=1.10.50 (from fashion-clip)
  Downloading boto3-1.38.35-py3-none-any.whl.metadata (6.6 kB)
Collecting appdirs>=1.4.4 (from fashion-clip)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting pyarrow>=7.0.0 (from fashion-clip)
  Downloading pyarrow-20.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting python-dotenv>=0.19.2 (from fashion-clip)
  Using cached python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB)
Collecting annoy>=1.17.0 (from fashion-clip)
  Downloading annoy-1.17.3.tar.gz (647 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m647.5/647.5 kB[0m [31m29.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting transformers>=4.26.1 (from fashion-clip)
  Downloading transformers-4.52.4-py3-none-any.whl.metadata (38 kB)
Collecting ipyplot>=1.1.1 (from fashion-clip)


In [3]:
'''
This script fine-tunes a projection head on top of a frozen FashionCLIP image encoder using a triplet loss. 
It is designed to learn better visual embeddings for fashion images by bringing similar images closer and pushing dissimilar ones apart.

Key components:

Dataset Loader: Loads anchor, positive, and negative images from a CSV file and applies transforms.

Projection Head: A simple linear layer that maps CLIP’s image features into a lower-dimensional normalized space.

Loss Function: InfoNCE-based triplet loss that encourages anchor-positive pairs to be closer than anchor-negative pairs.

Training Loop: For each epoch, it extracts image features using the frozen CLIP encoder, projects them, calculates the loss, 
and updates only the projection head.

Evaluation (optional): Calculates average validation loss on a held-out triplet set.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
from fashion_clip.fashion_clip import FashionCLIP
from torch.utils.data import Dataset,DataLoader

In [None]:
# download dataset 
!pip install gdown
gdown --folder 1g-2bfL18NnH9lWxuiedlGOLlFPXr98Ur

In [21]:

class TripletFashionDataset(Dataset):
    def __init__(self, csv_file, image_folder, transform=None):
        self.data = pd.read_csv(csv_file)
        self.image_folder = image_folder
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        def find_image_path(name):
            for ext in ['.jpg', '.png']:
                path = os.path.join(self.image_folder, name + ext)
                if os.path.exists(path):
                    return path
            raise FileNotFoundError(f"Image {name} not found in supported formats.")

        a = find_image_path(self.data.iloc[idx, 0])
        p = find_image_path(self.data.iloc[idx, 1])
        n = find_image_path(self.data.iloc[idx, 2])

        return self.transform(Image.open(a).convert("RGB")), \
               self.transform(Image.open(p).convert("RGB")), \
               self.transform(Image.open(n).convert("RGB"))
        
train_dataset = TripletFashionDataset("train_triplets.csv","./fashion_images")
val_dataset = TripletFashionDataset("val_triplets.csv","./fashion_images")
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")

Train samples: 80
Val samples: 20


In [22]:
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=512, output_dim=256):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return F.normalize(self.fc(x), p=2, dim=-1)

def info_nce_loss(anchor, positive, negative, temperature=0.07):
    sim_ap = (anchor * positive).sum(dim=-1) / temperature
    sim_an = (anchor * negative).sum(dim=-1) / temperature
    logits = torch.cat([sim_ap.unsqueeze(1), sim_an.unsqueeze(1)], dim=1)
    labels = torch.zeros(anchor.size(0), dtype=torch.long, device=anchor.device)
    return F.cross_entropy(logits, labels)

In [23]:
def train(train_loader, val_loader, epochs=5, batch_size=16, lr=1e-4, output_path='projection_head.pth'):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load and freeze FashionCLIP
    fclip = FashionCLIP('fashion-clip')
    clip_model = fclip.model
    clip_model.eval()
    for param in clip_model.parameters():
        param.requires_grad = False

    # Projection head and optimizer
    projector = ProjectionHead(input_dim=512, output_dim=256).to(device)
    optimizer = torch.optim.Adam(projector.parameters(), lr=lr)

    def extract_features(anchor_img, positive_img, negative_img):
        """extract image features with CLIP."""
        anchor_img = anchor_img.to(device)
        positive_img = positive_img.to(device)
        negative_img = negative_img.to(device)

        with torch.no_grad():
            a_feat = clip_model.get_image_features(pixel_values=anchor_img)
            p_feat = clip_model.get_image_features(pixel_values=positive_img)
            n_feat = clip_model.get_image_features(pixel_values=negative_img)
        return a_feat, p_feat, n_feat

    for epoch in range(epochs):
        # === Training ===
        projector.train()
        total_train_loss = 0
        for anchor_img, positive_img, negative_img in train_loader:
            a_feat, p_feat, n_feat = extract_features(anchor_img, positive_img, negative_img)

            a_proj = projector(a_feat)
            p_proj = projector(p_feat)
            n_proj = projector(n_feat)

            loss = info_nce_loss(a_proj, p_proj, n_proj)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)

        # === Evaluation ===
        projector.eval()
        total_val_loss = 0
        with torch.no_grad():
            for anchor_img, positive_img, negative_img in val_loader:
                a_feat, p_feat, n_feat = extract_features(anchor_img, positive_img, negative_img)

                a_proj = projector(a_feat)
                p_proj = projector(p_feat)
                n_proj = projector(n_feat)

                val_loss = info_nce_loss(a_proj, p_proj, n_proj)
                total_val_loss += val_loss.item()
        avg_val_loss = total_val_loss / len(val_loader)

        print(f"Epoch [{epoch+1}/{epochs}] - Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    torch.save(projector.state_dict(), output_path)
    print(f"✅ Projection head saved to: {output_path}")

In [25]:
train(
    train_loader,
    val_loader,
    epochs=20,
    batch_size=16,
    lr=1e-4,
    output_path='projection_head.pth'
)

Epoch [1/20] - Train Loss: 0.1875 | Val Loss: 0.0401
Epoch [2/20] - Train Loss: 0.1111 | Val Loss: 0.0339
Epoch [3/20] - Train Loss: 0.0661 | Val Loss: 0.0293
Epoch [4/20] - Train Loss: 0.0419 | Val Loss: 0.0257
Epoch [5/20] - Train Loss: 0.0275 | Val Loss: 0.0229
Epoch [6/20] - Train Loss: 0.0180 | Val Loss: 0.0206
Epoch [7/20] - Train Loss: 0.0125 | Val Loss: 0.0189
Epoch [8/20] - Train Loss: 0.0097 | Val Loss: 0.0173
Epoch [9/20] - Train Loss: 0.0073 | Val Loss: 0.0160
Epoch [10/20] - Train Loss: 0.0058 | Val Loss: 0.0150
Epoch [11/20] - Train Loss: 0.0047 | Val Loss: 0.0142
Epoch [12/20] - Train Loss: 0.0041 | Val Loss: 0.0134
Epoch [13/20] - Train Loss: 0.0035 | Val Loss: 0.0128
Epoch [14/20] - Train Loss: 0.0031 | Val Loss: 0.0123
Epoch [15/20] - Train Loss: 0.0027 | Val Loss: 0.0118
Epoch [16/20] - Train Loss: 0.0025 | Val Loss: 0.0114
Epoch [17/20] - Train Loss: 0.0022 | Val Loss: 0.0110
Epoch [18/20] - Train Loss: 0.0020 | Val Loss: 0.0107
Epoch [19/20] - Train Loss: 0.0019 | 