In [1]:
# SETUP

#!pip install transformers datasets peft accelerate faiss-cpu --quiet

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPProcessor, CLIPModel,get_cosine_schedule_with_warmup
from peft import get_peft_model, LoraConfig, TaskType
import pandas as pd
import numpy as np
from PIL import Image
import os
import faiss
from tqdm import tqdm
from torch.amp import autocast, GradScaler
import requests
from io import BytesIO
scaler = GradScaler()
import time
import pickle

In [2]:
from utils import *

In [None]:
# CONFIG
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CSV_PATH      = "meta_data_beauty.csv"
BATCH_SIZE = 128
NUM_EPOCHS = 20 # 10 gave best results
MODEL_NAME = "openai/clip-vit-base-patch32"
N_SAMPLES = 20000
df = pd.read_csv(CSV_PATH)

In [4]:

class ProductDataset(Dataset):
    def __init__(self, dataframe, model_name="openai/clip-vit-base-patch32"):
        self.texts = dataframe["product_text"].tolist()
        self.image_urls = dataframe["product_image_url"].tolist()
        self.processor = CLIPProcessor.from_pretrained(model_name, use_fast=True)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        url  = self.image_urls[idx]
        try:
            img = Image.open(BytesIO(requests.get(url, timeout=5).content)).convert("RGB")
        except:
            img = Image.new("RGB", (224,224), "white")
        return {"text": text, "image": img}

    def collate_fn(self, batch):
        texts  = [ex["text"] for ex in batch]
        images = [ex["image"] for ex in batch]

        # 1) Tokenize text
        tokenized = self.processor.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt"
        )

        # 2) Preprocess images
        # Note: depending on your transformers version this may be `.feature_extractor` or `.image_processor`
        image_inputs = self.processor.image_processor(
            images=images,
            return_tensors="pt"
        )

        # 3) Merge
        tokenized["pixel_values"] = image_inputs["pixel_values"]
        return tokenized


In [5]:
# MODEL + LORA
def get_model_with_lora():
    base = CLIPModel.from_pretrained(MODEL_NAME)
    config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.FEATURE_EXTRACTION
    )
    model = get_peft_model(base, config)
    return model.to(device)


In [6]:
def train_model(model, dataloader):
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    scaler = GradScaler()

    # Scheduler: cosine with 10% warm-up
    total_steps = NUM_EPOCHS * len(dataloader)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    start_time = time.time()

    for epoch in range(NUM_EPOCHS):
        total_loss = 0.0

        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
            batch = {k: v.to(DEVICE) for k, v in batch.items()}

            optimizer.zero_grad()

            with autocast(device_type='cuda', dtype=torch.float16):
                text_embs = model.get_text_features(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"]
                )
                image_embs = model.get_image_features(
                    pixel_values=batch["pixel_values"]
                )

                logits_per_text = text_embs @ image_embs.t()
                logits_per_image = logits_per_text.t()

                B = logits_per_text.size(0)
                labels = torch.arange(B, device=DEVICE)
                loss_t2i = F.cross_entropy(logits_per_text, labels)
                loss_i2t = F.cross_entropy(logits_per_image, labels)
                loss = (loss_t2i + loss_i2t) / 2

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1} Loss: {total_loss / len(dataloader):.4f}")

    total_training_time = time.time() - start_time
    print(f"Total Training Time: {total_training_time / 60:.2f} minutes")

In [7]:
def generate_embeddings(model, dataset):
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,  
        num_workers=4,
        pin_memory=True,
        collate_fn=dataset.collate_fn  
    )
    text_embs, image_embs = [], []  
    model.eval()  
    model.to(device)  
    with torch.no_grad():  
        for batch in tqdm(dataloader, desc="Generating embeddings"):
            batch = {k: v.to(device) for k, v in batch.items()}
            text_embeddings = model.get_text_features(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
            image_embeddings = model.get_image_features(pixel_values=batch["pixel_values"])
            text_embs.append(F.normalize(text_embeddings, p=2, dim=-1).cpu())  # L2 normalization
            image_embs.append(F.normalize(image_embeddings, p=2, dim=-1).cpu())  # L2 normalization
    text_embs = torch.cat(text_embs, dim=0)
    image_embs = torch.cat(image_embs, dim=0)
    return text_embs, image_embs

In [8]:
#  FAISS INDEXING
def build_faiss_index(embeddings):
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings.numpy())
    return index

In [9]:
# MODEL FINETUNING 
df_train = df.sample(N_SAMPLES)
dataset = ProductDataset(df_train, model_name=MODEL_NAME)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    collate_fn=dataset.collate_fn
)

In [10]:
model = get_model_with_lora()
train_model(model, loader)

Epoch 1/20: 100%|█████████████████████████████| 157/157 [16:42<00:00,  6.39s/it]


Epoch 1 Loss: 2.1830


Epoch 2/20: 100%|█████████████████████████████| 157/157 [15:54<00:00,  6.08s/it]


Epoch 2 Loss: 1.5900


Epoch 3/20: 100%|█████████████████████████████| 157/157 [16:02<00:00,  6.13s/it]


Epoch 3 Loss: 1.2741


Epoch 4/20: 100%|█████████████████████████████| 157/157 [16:21<00:00,  6.25s/it]


Epoch 4 Loss: 1.1488


Epoch 5/20: 100%|█████████████████████████████| 157/157 [15:34<00:00,  5.95s/it]


Epoch 5 Loss: 1.0794


Epoch 6/20: 100%|█████████████████████████████| 157/157 [15:40<00:00,  5.99s/it]


Epoch 6 Loss: 1.0280


Epoch 7/20: 100%|█████████████████████████████| 157/157 [15:40<00:00,  5.99s/it]


Epoch 7 Loss: 0.9993


Epoch 8/20: 100%|█████████████████████████████| 157/157 [15:22<00:00,  5.88s/it]


Epoch 8 Loss: 0.9636


Epoch 9/20: 100%|█████████████████████████████| 157/157 [15:39<00:00,  5.98s/it]


Epoch 9 Loss: 0.9420


Epoch 10/20: 100%|████████████████████████████| 157/157 [15:21<00:00,  5.87s/it]


Epoch 10 Loss: 0.9134


Epoch 11/20: 100%|████████████████████████████| 157/157 [15:31<00:00,  5.93s/it]


Epoch 11 Loss: 0.9084


Epoch 12/20: 100%|████████████████████████████| 157/157 [15:34<00:00,  5.95s/it]


Epoch 12 Loss: 0.8892


Epoch 13/20: 100%|████████████████████████████| 157/157 [15:16<00:00,  5.84s/it]


Epoch 13 Loss: 0.8773


Epoch 14/20: 100%|████████████████████████████| 157/157 [14:32<00:00,  5.56s/it]


Epoch 14 Loss: 0.8664


Epoch 15/20: 100%|████████████████████████████| 157/157 [14:50<00:00,  5.67s/it]


Epoch 15 Loss: 0.8621


Epoch 16/20: 100%|████████████████████████████| 157/157 [14:39<00:00,  5.60s/it]


Epoch 16 Loss: 0.8508


Epoch 17/20: 100%|████████████████████████████| 157/157 [14:49<00:00,  5.67s/it]


Epoch 17 Loss: 0.8512


Epoch 18/20: 100%|████████████████████████████| 157/157 [15:02<00:00,  5.75s/it]


Epoch 18 Loss: 0.8430


Epoch 19/20: 100%|████████████████████████████| 157/157 [14:58<00:00,  5.72s/it]


Epoch 19 Loss: 0.8493


Epoch 20/20: 100%|████████████████████████████| 157/157 [14:59<00:00,  5.73s/it]

Epoch 20 Loss: 0.8454
Total Training Time: 308.57 minutes





## Testing

In [None]:
# ── Load Model & Generate Embeddings ────────────────────────────────────
SAVE_DIR = "artifacts_lora_mp_beauty/"
tuned_model = get_model(approach="lora_opt", save_dir=SAVE_DIR)
tuned_model.to(DEVICE)

PeftModelForFeatureExtraction(
  (base_model): LoraModel(
    (model): CLIPModel(
      (text_model): CLIPTextTransformer(
        (embeddings): CLIPTextEmbeddings(
          (token_embedding): Embedding(49408, 512)
          (position_embedding): Embedding(77, 512)
        )
        (encoder): CLIPEncoder(
          (layers): ModuleList(
            (0-11): 12 x CLIPEncoderLayer(
              (self_attn): CLIPSdpaAttention(
                (k_proj): Linear(in_features=512, out_features=512, bias=True)
                (v_proj): lora.Linear(
                  (base_layer): Linear(in_features=512, out_features=512, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.1, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=512, out_features=8, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=8, out_

In [12]:
df_full = pd.read_csv('product_data_fash_full_cleaned.csv')

In [13]:
# Generate embeddings after fine-tuning and building faiss index
dataset = ProductDataset(df_full, model_name=MODEL_NAME)
text_embs, image_embs = generate_embeddings(tuned_model, dataset)
combined_embs = F.normalize(text_embs + image_embs, dim=-1)  # [N, D]
index = build_faiss_index(combined_embs)

Generating embeddings: 100%|████████████████| 880/880 [1:29:38<00:00,  6.11s/it]


In [None]:
SAVE_DIR = "artifacts_lora_mp_beauty/"
os.makedirs(SAVE_DIR, exist_ok=True)

tuned_model.save_pretrained(os.path.join(SAVE_DIR, "clip_lora_model_mp"))
CLIPProcessor.from_pretrained(MODEL_NAME, use_fast=True).save_pretrained(os.path.join(SAVE_DIR, "clip_processor_mp"))

torch.save(text_embs, os.path.join(SAVE_DIR, "text_embeddings.pt"))
torch.save(image_embs, os.path.join(SAVE_DIR, "image_embeddings.pt"))
torch.save(combined_embs, os.path.join(SAVE_DIR, "combined_embeddings.pt"))

with open(os.path.join(SAVE_DIR, "product_metadata.pkl"), "wb") as f:
    pickle.dump(df_full.to_dict(), f)

faiss.write_index(index, os.path.join(SAVE_DIR, "faiss.index"))

print(f"Model, processor, embeddings, metadata, and FAISS index saved in {SAVE_DIR}")


Model, processor, embeddings, metadata, and FAISS index saved in artifacts_20k_mp_fash/
