In [1]:
# SETUP

!pip install transformers datasets peft accelerate faiss-cpu --quiet

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPProcessor, CLIPModel
from peft import get_peft_model, LoraConfig, TaskType
import pandas as pd
import numpy as np
from PIL import Image
import os
import faiss
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()


  scaler = GradScaler()


In [2]:
# CONFIG
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 128
NUM_EPOCHS = 10
MODEL_NAME = "openai/clip-vit-base-patch32"
DATA_PATH = "./product_data.csv"

In [3]:
# DATASET
import requests
from io import BytesIO

class ProductDataset(Dataset):
    def __init__(self, dataframe):
        self.texts = dataframe["product_text"].tolist()
        self.image_urls = dataframe["product_image_url"].tolist()
        self.processor = CLIPProcessor.from_pretrained(MODEL_NAME)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        url = self.image_urls[idx]

        try:
            response = requests.get(url, timeout=5)
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception:
            image = Image.new("RGB", (224, 224), "white")  # fallback blank image

        return {
            "text": text,
            "image": image
        }


def collate_fn(batch):
    processor = CLIPProcessor.from_pretrained(MODEL_NAME)
    texts  = [x["text"] for x in batch]
    images = [x["image"] for x in batch]
    enc = processor(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)
    return enc


In [4]:
# MODEL + LORA
def get_model_with_lora():
    base = CLIPModel.from_pretrained(MODEL_NAME)
    config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.FEATURE_EXTRACTION
    )
    model = get_peft_model(base, config)
    return model.to(device)


In [5]:
# TRAINING LOOP
def train_model(model, dataloader):
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0.0
    
        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
            batch = {k: v.to(device) for k, v in batch.items()}
    
            with autocast():  # <<<<<<< Mixed Precision starts here
                text_embs = model.get_text_features(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"]
                )
                image_embs = model.get_image_features(
                    pixel_values=batch["pixel_values"]
                )
    
                text_embs = F.normalize(text_embs, p=2, dim=-1)
                image_embs = F.normalize(image_embs, p=2, dim=-1)
    
                logits_per_text = text_embs @ image_embs.t()
                logits_per_image = logits_per_text.t()

                B = logits_per_text.size(0)
                labels = torch.arange(B, device=device)
    
                loss_t2i = F.cross_entropy(logits_per_text, labels)
                loss_i2t = F.cross_entropy(logits_per_image, labels)
                loss = (loss_t2i + loss_i2t) / 2
                
                optimizer.zero_grad()
                scaler.scale(loss).backward()        # <<< Scaled backprop
                scaler.step(optimizer)               # <<< Scaled optimizer step
                scaler.update()                      # <<< Update the scaler
    
    
                total_loss += loss.item()
    
        print(f"Epoch {epoch+1} avg loss: {total_loss / len(dataloader):.4f}")


In [6]:
# EMBEDDING GENERATION
def generate_embeddings(model, dataset):
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)
    text_embs, image_embs = [], []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = {k:v.to(device) for k,v in batch.items()}
            t = model.get_text_features(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
            i = model.get_image_features(pixel_values=batch["pixel_values"])
            text_embs.append(F.normalize(t, dim=-1).cpu())
            image_embs.append(F.normalize(i, dim=-1).cpu())

    return torch.cat(text_embs), torch.cat(image_embs)


In [7]:
#  FAISS INDEXING
def build_faiss_index(embeddings):
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings.numpy())
    return index


In [8]:
# RECOMMENDATION
def recommend(query_emb, index, df, k=5):
    query_emb = F.normalize(query_emb, dim=-1).cpu().numpy()
    scores, ids = index.search(query_emb, k)
    return df.iloc[ids[0]], scores[0]


In [None]:
# RUN ALL
df = pd.read_csv(DATA_PATH)
df["product_text"] = df.apply(lambda x : str(x["product_title"]) + " " + str(x["product_description"]), axis=1)

df_train = df[
    df["product_text"].str.strip().astype(bool) &
    df["product_image_url"].str.strip().astype(bool)
].reset_index(drop=True)

dataset = ProductDataset(df_train)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

model = get_model_with_lora()
train_model(model, dataloader)

# Generate embeddings after fine-tuning
text_embs, image_embs = generate_embeddings(model, dataset)
combined_embs = (text_embs + image_embs) / 2

# Build FAISS
index = build_faiss_index(combined_embs)


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
  with autocast():  # <<<<<<< Mixed Precision starts here
Epoch 1/10: 100%|███████████████████████████| 880/880 [6:32:43<00:00, 26.78s/it]


Epoch 1 avg loss: 4.4192


Epoch 2/10: 100%|███████████████████████████| 880/880 [6:20:23<00:00, 25.94s/it]


Epoch 2 avg loss: 4.2260


Epoch 3/10: 100%|███████████████████████████| 880/880 [6:14:21<00:00, 25.52s/it]


Epoch 3 avg loss: 4.1854


Epoch 4/10: 100%|███████████████████████████| 880/880 [6:19:14<00:00, 25.86s/it]


Epoch 4 avg loss: 4.1649


Epoch 5/10: 100%|███████████████████████████| 880/880 [6:27:05<00:00, 26.39s/it]


Epoch 5 avg loss: 4.1514


Epoch 6/10: 100%|███████████████████████████| 880/880 [6:08:50<00:00, 25.15s/it]


Epoch 6 avg loss: 4.1412


Epoch 7/10: 100%|███████████████████████████| 880/880 [6:21:39<00:00, 26.02s/it]


Epoch 7 avg loss: 4.1335


Epoch 8/10: 100%|███████████████████████████| 880/880 [6:27:49<00:00, 26.44s/it]


Epoch 8 avg loss: 4.1268


Epoch 9/10:  57%|██████████████▎          | 505/880 [3:48:39<2:45:37, 26.50s/it]

In [None]:
from torchvision import transforms

# Setup preprocessing
image_preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Prepare model for inference
model.eval()

def normalize(tensor):
    return F.normalize(tensor, dim=-1)

# Unified Query Function
def unified_query(input_text=None, input_image_path=None, k=5):
    assert input_text or input_image_path, "Provide at least text or image input"

    inputs = {}
    if input_text:
        inputs.update({"text": input_text})
    if input_image_path:
        image = Image.open(input_image_path).convert("RGB")
        inputs.update({"images": image})

    processor = CLIPProcessor.from_pretrained(MODEL_NAME)
    encoded = processor(return_tensors="pt", padding=True, truncation=True, **inputs)
    encoded = {k: v.to(device) for k, v in encoded.items()}

    with torch.no_grad():
        if input_text and input_image_path:
            text_emb  = model.get_text_features(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
            image_emb = model.get_image_features(pixel_values=encoded["pixel_values"])
            query_emb = normalize((text_emb + image_emb) / 2)
        elif input_text:
            text_emb  = model.get_text_features(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
            query_emb = normalize(text_emb)
        else:
            image_emb = model.get_image_features(pixel_values=encoded["pixel_values"])
            query_emb = normalize(image_emb)

    # Score against both sets of embeddings
    scores_text  = (query_emb @ text_embs.T).squeeze()
    scores_image = (query_emb @ image_embs.T).squeeze()

    # Combine scores
    combined_scores = (scores_text + scores_image) / 2

    # Top K results
    topk_indices = torch.topk(combined_scores, k=k).indices
    top_scores = combined_scores[topk_indices].cpu().numpy()
    top_items = df.iloc[topk_indices.cpu().numpy()]

    return top_items, top_scores


In [None]:

torch.save(text_embs, "models/catalog_text_embs.pt")
torch.save(image_embs, "models/catalog_image_embs.pt")

text_embs  = normalize(torch.load("models/catalog_text_embs.pt").to(device))
image_embs = normalize(torch.load("models/catalog_image_embs.pt").to(device))

model = get_peft_model(CLIPModel.from_pretrained("openai/clip-vit-base-patch32"), lora_cfg).to(device)
model.load_state_dict(torch.load("models/finetuned_clip.pt"))

results, scores = unified_query(input_text="denim jacket for women")
display(results.head())


In [None]:
# Text only
results, scores = unified_query(input_text="minimalist black backpack")
display(results[["text"]])

# Image only
results, scores = unified_query(input_image_path="./query_imgs/backpack.jpg")
display(results[["text"]])

# Both text + image
results, scores = unified_query(
    input_text="black backpack for school",
    input_image_path="./query_imgs/backpack.jpg"
)
display(results[["text"]])
