In [1]:
pip install ipywidgets pandas

Collecting ipywidgets
  Obtaining dependency information for ipywidgets from https://files.pythonhosted.org/packages/58/6a/9166369a2f092bd286d24e6307de555d63616e8ddb373ebad2b5635ca4cd/ipywidgets-8.1.7-py3-none-any.whl.metadata
  Downloading ipywidgets-8.1.7-py3-none-any.whl.metadata (2.4 kB)
Collecting widgetsnbextension~=4.0.14 (from ipywidgets)
  Obtaining dependency information for widgetsnbextension~=4.0.14 from https://files.pythonhosted.org/packages/ca/51/5447876806d1088a0f8f71e16542bf350918128d0a69437df26047c8e46f/widgetsnbextension-4.0.14-py3-none-any.whl.metadata
  Downloading widgetsnbextension-4.0.14-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab_widgets~=3.0.15 (from ipywidgets)
  Obtaining dependency information for jupyterlab_widgets~=3.0.15 from https://files.pythonhosted.org/packages/43/6a/ca128561b22b60bd5a0c4ea26649e68c8556b82bc70a0c396eebc977fe86/jupyterlab_widgets-3.0.15-py3-none-any.whl.metadata
  Downloading jupyterlab_widgets-3.0.15-py3-none-any.whl.met

In [1]:
pip install torchvision fashion-clip


Collecting torchvision
  Obtaining dependency information for torchvision from https://files.pythonhosted.org/packages/f6/00/bdab236ef19da050290abc2b5203ff9945c84a1f2c7aab73e8e9c8c85669/torchvision-0.22.1-cp311-cp311-macosx_11_0_arm64.whl.metadata
  Using cached torchvision-0.22.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (6.1 kB)
Collecting fashion-clip
  Obtaining dependency information for fashion-clip from https://files.pythonhosted.org/packages/87/9b/b5ef30166a21aac6412ce0c306bee9d1e6181b9f38b89e5f12b572613807/fashion_clip-0.2.2-py3-none-any.whl.metadata
  Using cached fashion_clip-0.2.2-py3-none-any.whl.metadata (11 kB)
Collecting torch==2.7.1 (from torchvision)
  Obtaining dependency information for torch==2.7.1 from https://files.pythonhosted.org/packages/5b/2b/d36d57c66ff031f93b4fa432e86802f84991477e522adcdffd314454326b/torch-2.7.1-cp311-none-macosx_11_0_arm64.whl.metadata
  Using cached torch-2.7.1-cp311-none-macosx_11_0_arm64.whl.metadata (29 kB)
Collecting sympy>=1.13.3 (f

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
from fashion_clip.fashion_clip import FashionCLIP
from torch.utils.data import Dataset,DataLoader

[34marchive[m[m/            demo_1.ipynb        fine_tune.ipynb     test.py
archive.zip         [34mfashion_images[m[m/     match.py            train_triplets.csv
crawler.py          fclip.py            [34mmyenv[m[m/              val_triplets.csv
demo.py             [34mfclipenv[m[m/           shirt.jpg


In [8]:

class TripletFashionDataset(Dataset):
    def __init__(self, csv_file, image_folder, transform=None):
        self.data = pd.read_csv(csv_file)
        self.image_folder = image_folder
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        def find_image_path(name):
            for ext in ['.jpg', '.png']:
                path = os.path.join(self.image_folder, name + ext)
                if os.path.exists(path):
                    return path
            raise FileNotFoundError(f"Image {name} not found in supported formats.")

        a = find_image_path(self.data.iloc[idx, 0])
        p = find_image_path(self.data.iloc[idx, 1])
        n = find_image_path(self.data.iloc[idx, 2])

        return self.transform(Image.open(a).convert("RGB")), \
               self.transform(Image.open(p).convert("RGB")), \
               self.transform(Image.open(n).convert("RGB"))
        
train_dataset = TripletFashionDataset("train_triplets.csv","./fashion_images")
val_dataset = TripletFashionDataset("val_triplets.csv","./fashion_images")
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")

Train samples: 80
Val samples: 20


In [10]:
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=512, output_dim=256):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return F.normalize(self.fc(x), p=2, dim=-1)

def info_nce_loss(anchor, positive, negative, temperature=0.07):
    sim_ap = (anchor * positive).sum(dim=-1) / temperature
    sim_an = (anchor * negative).sum(dim=-1) / temperature
    logits = torch.cat([sim_ap.unsqueeze(1), sim_an.unsqueeze(1)], dim=1)
    labels = torch.zeros(anchor.size(0), dtype=torch.long, device=anchor.device)
    return F.cross_entropy(logits, labels)

In [13]:
def train(epochs=5, batch_size=16, lr=1e-4, output_path='projection_head.pth'):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    fclip = FashionCLIP('fashion-clip')
    fclip.model.eval()
    for param in fclip.model.parameters():
        param.requires_grad = False
        
    

    projector = ProjectionHead(input_dim=512, output_dim=256).to(device)
    optimizer = torch.optim.Adam(projector.parameters(), lr=lr)

    for epoch in range(epochs):
        projector.train()
        total_loss = 0

        for anchor_img, positive_img, negative_img in train_loader:
            anchor_img = anchor_img.to(device)
            positive_img = positive_img.to(device)
            negative_img = negative_img.to(device)

            with torch.no_grad():
                
                a_feat = fclip.model.get_image_features(pixel_values=anchor_img)
                p_feat = fclip.model.get_image_features(pixel_values=anchor_img)
                n_feat = fclip.model.get_image_features(pixel_values=anchor_img)

            a_proj = projector(a_feat)
            p_proj = projector(p_feat)
            n_proj = projector(n_feat)

            loss = info_nce_loss(a_proj, p_proj, n_proj)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        val_loss = evaluate(visual_encoder, projector, val_loader, device)
        print(f"Epoch [{epoch+1}/{epochs}] - Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f}")

    torch.save(projector.state_dict(), output_path)
    print(f"✅ Projection head saved to: {output_path}")


def evaluate(model, projector, dataloader, device):
    model.eval()
    projector.eval()
    total_loss = 0

    with torch.no_grad():
        for anchor_img, positive_img, negative_img in dataloader:
            anchor_img = anchor_img.to(device)
            positive_img = positive_img.to(device)
            negative_img = negative_img.to(device)

            a_feat = model(anchor_img)["pooler_output"]
            p_feat = model(positive_img)["pooler_output"]
            n_feat = model(negative_img)["pooler_output"]

            a_proj = projector(a_feat)
            p_proj = projector(p_feat)
            n_proj = projector(n_feat)

            loss = info_nce_loss(a_proj, p_proj, n_proj)
            total_loss += loss.item()

    return total_loss / len(dataloader)

In [14]:
train(
    epochs=10,
    batch_size=16,
    lr=1e-4,
    output_path='projection_head.pth'
)

AttributeError: 'FashionCLIP' object has no attribute 'image_encoder'