In [1]:
#! pip install transformers datasets peft faiss-cpu

In [1]:
import pandas as pd
import torch
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
from peft import get_peft_model, LoraConfig, TaskType
from PIL import Image
from torchvision import transforms
import faiss
import requests
from io import BytesIO

### DATA CLEANING

In [2]:
# Load the "raw_meta_All_Beauty" config of Amazon Reviews 2023 (its single "full" split)
# and prefix every column with "product_" so later merges cannot collide on names.
product_meta_data = load_dataset(
    "McAuley-Lab/Amazon-Reviews-2023",
    "raw_meta_All_Beauty",
    split="full",
    trust_remote_code=True,
)
df_meta = pd.DataFrame.from_records(product_meta_data).add_prefix("product_")

In [3]:
df_meta.shape

(112590, 16)

In [4]:
# Columns to clean
target_cols = ['product_title', 'product_description', 'product_images']

# 1) Drop rows with a missing value in any target column.
#    `dropna` already treats a literal `None` in object columns as missing, so the former
#    extra `x is None` filter was redundant (row count was unchanged) and also tied the
#    notebook to pandas >= 2.1 via `DataFrame.map`.
df_meta_clean = df_meta.dropna(subset=target_cols)

In [5]:
df_meta_clean.shape

(112590, 16)

In [6]:
# 2) Define what an “invalid” string is
invalid_strs = {'', 'n/a', 'none', 'na'}

# 3) Validator for product_images
def images_valid(img_dict):
    if not isinstance(img_dict, dict):
        return False
    # only consider these keys for actual URLs
    for key in ('hi_res', 'large', 'thumb'):
        urls = img_dict.get(key, [])
        if not isinstance(urls, (list, tuple)):
            continue
        for url in urls:
            if isinstance(url, str) and url.strip().lower() not in invalid_strs:
                return True
    return False

# 4) General validator for text fields
def text_valid(x):
    return isinstance(x, str) and x.strip().lower() not in invalid_strs

# 5) Keep rows whose title is a usable string and whose image dict has a usable URL.
#    product_description is deliberately not validated here: at this point it may still
#    hold raw list values (they are flattened to strings in a later cell), and text_valid
#    accepts only `str`, so the check would drop list-valued descriptions.
title_ok = df_meta_clean['product_title'].apply(text_valid)
images_ok = df_meta_clean['product_images'].apply(images_valid)
df_meta_clean = df_meta_clean[title_ok & images_ok].reset_index(drop=True)

print(f"After cleaning: {len(df_meta_clean)} rows")

After cleaning: 112578 rows


In [7]:
def extract_first_valid_image(images_dict):
    """Return the first usable image URL from a product's ``images`` dict.

    Keys are tried in order of preference ('hi_res', 'large', 'thumb'); within a key,
    URLs are tried in sequence order. A URL is usable when it is a string that, after
    stripping and lower-casing, is not a placeholder ('', 'none', 'n/a', 'na').

    Args:
        images_dict: mapping of size key -> list/tuple of URLs (entries may be None).

    Returns:
        The stripped URL, or None if ``images_dict`` is not a dict or has no usable URL.
    """
    if not isinstance(images_dict, dict):
        return None

    image_keys = ('hi_res', 'large', 'thumb')
    invalid_strs = {'', 'none', 'n/a', 'na'}

    for key in image_keys:
        urls = images_dict.get(key, [])
        # Accept tuples as well as lists, matching images_valid; otherwise a row could
        # pass validation yet yield no URL here.
        if not isinstance(urls, (list, tuple)):
            continue
        for url in urls:
            if isinstance(url, str):
                candidate = url.strip()
                if candidate.lower() not in invalid_strs:
                    return candidate
    return None

# Materialise the preferred image URL for each product in its own column.
df_meta_clean['product_image_url'] = df_meta_clean['product_images'].apply(extract_first_valid_image)

# Defensive: discard any row for which no usable URL could be extracted.
df_meta_clean = df_meta_clean.dropna(subset=['product_image_url']).reset_index(drop=True)

print(f"Final dataset with extracted image URLs: {len(df_meta_clean)} rows")


Final dataset with extracted image URLs: 112578 rows


In [8]:
def flatten_description(desc):
    """Collapse a product description into a single string.

    List/tuple input: non-string entries and empty strings are dropped, each remaining
    entry is stripped and given a trailing period if it does not already end in
    sentence punctuation, and the entries are joined with single spaces. The result
    reads "A. B." rather than the unreadable "A.B".
    String input is stripped. Anything else becomes "".
    """
    if isinstance(desc, str):
        return desc.strip()
    if isinstance(desc, (list, tuple)):
        sentences = []
        for part in desc:
            if not isinstance(part, str):
                continue
            part = part.strip()
            if not part:
                continue
            if part[-1] not in ".!?":
                part += "."
            sentences.append(part)
        return " ".join(sentences)
    return ""

# Replace each raw description (list or str) with a single flattened string.
df_meta_clean['product_description'] = [flatten_description(d) for d in df_meta_clean['product_description']]

In [10]:
df_meta_clean['product_main_category'].unique()

array(['All Beauty', 'Premium Beauty'], dtype=object)

In [11]:
# Columns exported for the downstream summarisation / embedding pipeline.
export_cols = [
    'product_main_category', 'product_title', 'product_description', 'product_image_url',
    'product_categories', 'product_details', 'product_features',
]
# Fixed seed so the 20k-row "All Beauty" sample is identical on every Restart & Run All.
SAMPLE_SEED = 42
sample_data = (
    df_meta_clean.loc[df_meta_clean['product_main_category'] == "All Beauty", export_cols]
    .sample(20000, random_state=SAMPLE_SEED)
    .to_dict('records')
)

In [12]:
# Every cleaned product, restricted to the exported columns, as a list of dicts.
full_data = df_meta_clean.loc[:, [
    'product_main_category', 'product_title', 'product_description', 'product_image_url',
    'product_categories', 'product_details', 'product_features',
]].to_dict(orient='records')

In [None]:
# Persist the sample and the full export as CSV files (no index column).
for records, csv_path in (
    (sample_data, 'product_data_beauty_sample.csv'),
    (full_data, 'product_data_beauty_full.csv'),
):
    pd.DataFrame(records).to_csv(csv_path, index=False)

In [None]:
import json
# Pretty-printed JSON of every cleaned record; read back below as RAW_PATH for summarisation.
with open('full_data_beauty.json', mode='w') as out_fh:
    json.dump(full_data, out_fh, indent=4)

In [None]:
import json
from pathlib import Path
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Configuration
RAW_PATH       = Path("full_data_beauty.json")
PROCESSED_PATH = Path("processed_for_clip.json")
MODEL_NAME     = "facebook/bart-large-cnn"
BATCH_SIZE     = 16
MAX_OUT        = 256

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Intentionally NOT wrapped in torch.nn.DataParallel. The wrapper only parallelises
# forward() and does not expose `.generate()`, because nn.Module attribute lookup only
# resolves parameters, buffers and submodules. `model.generate(...)` would therefore raise
# AttributeError on multi-GPU machines, so generation runs on a single device.
if torch.cuda.device_count() > 1:
    print(f"{torch.cuda.device_count()} GPUs visible; generation runs on a single device ({device})")

model.to(device)
model.eval()

def batch_standardize(batch_raw: list) -> list:
    """
    Summarize a batch of raw product dictionaries into clean text paragraphs.

    Each product is rendered as a "Product Title / Description / Features / Details"
    prompt, tokenized (truncated to 1024 tokens) and summarized with 4-beam search
    using the module-level `tokenizer`, `model`, `device` and `MAX_OUT`.

    Args:
        batch_raw: list of product dicts (keys 'product_title', 'product_description',
            'product_features', 'product_details'; any may be missing or null).

    Returns:
        A list of stripped summary strings, one per product, in input order
        ([] for an empty batch).
    """
    if not batch_raw:
        return []

    def _text(value) -> str:
        # Missing keys and JSON nulls render as "" rather than the literal "None".
        return "" if value is None else str(value)

    def _features(value) -> str:
        # A list/tuple is joined with "; "; a plain string is used as-is (joining it would
        # split it per character); None becomes "". Entries are str()-ed so a non-string
        # entry cannot raise TypeError.
        if isinstance(value, (list, tuple)):
            return "; ".join(str(v) for v in value if v is not None)
        return _text(value)

    batch_texts = [
        f"Product Title: {_text(item.get('product_title'))}\n"
        f"Product Description: {_text(item.get('product_description'))}\n"
        f"Product Features: {_features(item.get('product_features'))}\n"
        f"Product Details: {_text(item.get('product_details'))}"
        for item in batch_raw
    ]

    inputs = tokenizer(
        batch_texts,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
        padding=True
    ).to(device)

    # nn.DataParallel does not forward .generate(); call it on the wrapped module if present.
    generator = getattr(model, "module", model)

    with torch.no_grad():
        summary_ids = generator.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=MAX_OUT,
            num_beams=4,
            early_stopping=True
        )

    summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
    return [summary.strip() for summary in summaries]

# Load and process.
# Summarising the full catalogue is very slow (9h41m in the recorded run), so an existing
# PROCESSED_PATH is reused unless FORCE_RECOMPUTE is set to True.
FORCE_RECOMPUTE = False

if PROCESSED_PATH.exists() and not FORCE_RECOMPUTE:
    with PROCESSED_PATH.open("r", encoding="utf-8") as f:
        processed = json.load(f)
    print(f"Loaded {len(processed)} cached entries from {PROCESSED_PATH}")
else:
    with RAW_PATH.open("r", encoding="utf-8") as f:
        all_products = json.load(f)

    processed = []
    for start in tqdm(range(0, len(all_products), BATCH_SIZE), desc="Processing"):
        batch = all_products[start : start + BATCH_SIZE]
        texts = batch_standardize(batch)
        for raw, text in zip(batch, texts):
            processed.append({
                "text": text,
                "image_url": raw.get("product_image_url", "")
            })
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Save output
    with PROCESSED_PATH.open("w", encoding="utf-8") as f:
        json.dump(processed, f, indent=2)

    print(f"\nDone! Written {len(processed)} entries to {PROCESSED_PATH}")


Processing: 100%|█████████████████████████| 7037/7037 [9:41:08<00:00,  4.95s/it]



Done! Written 112578 entries to processed_for_clip.json


In [16]:
# One row per product: summarised "text" and its "image_url".
cleaned_text = pd.DataFrame.from_records(processed)

In [17]:
cleaned_text

Unnamed: 0,text,image_url
0,Product Title: Howard LC0008 Leather Condition...,https://m.media-amazon.com/images/I/71i77AuI9x...
1,Yes to Tomatoes Detoxifying Charcoal Cleanser ...,https://m.media-amazon.com/images/I/71g1lP0pMb...
2,Product Title: Eye Patch Black Adult with Tie ...,https://m.media-amazon.com/images/I/31bz+uqzWC...
3,Product Title: Tattoo Eyebrow Stickers. Waterp...,https://m.media-amazon.com/images/I/71GJhXQGvy...
4,Precision Plunger Bars for Cartridge Grips – 9...,https://m.media-amazon.com/images/I/31TgqAZ8kQ...
...,...,...
112573,"Product Title: TOPREETY 24""120gr 3/4 Full Head...",https://m.media-amazon.com/images/I/71Ud1D40lg...
112574,Product Title: Pets Playmate Pet Grooming Glov...,https://m.media-amazon.com/images/I/61o89FR2Dj...
112575,Makeup Brushes Set Cosmetics Tools Kit Peacock...,https://m.media-amazon.com/images/I/71pKBX5Xrx...
112576,Product Title: Xcoser Pretty Party Anna Wig Ha...,https://m.media-amazon.com/images/I/61t3vpvoZK...


In [18]:
import re
import unicodedata

# Typographic characters mapped to ASCII look-alikes. The table is applied before Unicode
# decomposition so the explicit mapping wins (compatibility decomposition turns '″' into
# two '′', which would yield "''" instead of '"'). It is applied again afterwards to catch
# characters the decomposition itself produces (e.g. '‴' -> '′′′').
_ASCII_PUNCT_TABLE = str.maketrans({
    '“': '"', '”': '"', '’': "'", '‘': "'",
    '–': '-', '—': '-', 'œ': 'oe', '‚': ',',
    '•': '-', '…': '...', '″': '"', '′': "'"
})


def clean_product_text(text: str) -> str:
    """Normalise free text to a single line of printable ASCII.

    Steps: map typographic punctuation to ASCII, fold accents ("crème" -> "creme"),
    replace every remaining non-printable-ASCII run (other non-ASCII characters and
    control characters such as NUL, tab, newline, DEL) with a space, collapse
    whitespace, and trim leading/trailing spaces, quotes, dashes, periods and commas.

    Args:
        text: input string; any non-str value yields "".

    Returns:
        The cleaned string (possibly empty).
    """
    if not isinstance(text, str):
        return ""

    text = text.translate(_ASCII_PUNCT_TABLE)

    # NFKD = compatibility decomposition: folds full-width/ligature forms and splits accented
    # letters into base letter + combining mark, so dropping the mark keeps the letter
    # instead of the whole letter being replaced by a space below.
    text = unicodedata.normalize("NFKD", text)
    text = "".join(ch for ch in text if not unicodedata.combining(ch))
    text = text.translate(_ASCII_PUNCT_TABLE)

    # Keep only printable ASCII (0x20-0x7E); everything else becomes a space.
    text = re.sub(r"[^\x20-\x7E]+", " ", text)
    text = re.sub(r"\s+", " ", text)

    return text.strip(" '\"-.,")


In [19]:
# Normalise every summary to single-line printable ASCII (function passed directly; no lambda needed).
cleaned_text['text'] = cleaned_text['text'].apply(clean_product_text)

In [None]:
# Write the table with the column names expected by the embedding section.
column_names = {'text': 'product_text', 'image_url': 'product_image_url'}
cleaned_text.rename(columns=column_names).to_csv('meta_data_beauty.csv', index=False)

In [21]:
cleaned_text

Unnamed: 0,text,image_url
0,Product Title: Howard LC0008 Leather Condition...,https://m.media-amazon.com/images/I/71i77AuI9x...
1,Yes to Tomatoes Detoxifying Charcoal Cleanser ...,https://m.media-amazon.com/images/I/71g1lP0pMb...
2,Product Title: Eye Patch Black Adult with Tie ...,https://m.media-amazon.com/images/I/31bz+uqzWC...
3,Product Title: Tattoo Eyebrow Stickers. Waterp...,https://m.media-amazon.com/images/I/71GJhXQGvy...
4,Precision Plunger Bars for Cartridge Grips - 9...,https://m.media-amazon.com/images/I/31TgqAZ8kQ...
...,...,...
112573,"Product Title: TOPREETY 24""120gr 3/4 Full Head...",https://m.media-amazon.com/images/I/71Ud1D40lg...
112574,Product Title: Pets Playmate Pet Grooming Glov...,https://m.media-amazon.com/images/I/61o89FR2Dj...
112575,Makeup Brushes Set Cosmetics Tools Kit Peacock...,https://m.media-amazon.com/images/I/71pKBX5Xrx...
112576,Product Title: Xcoser Pretty Party Anna Wig Ha...,https://m.media-amazon.com/images/I/61t3vpvoZK...


## EMBEDDING APPROACH

In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

from torch.utils.data import Dataset
from transformers import CLIPProcessor,CLIPModel
from PIL import Image
import requests
from io import BytesIO

In [None]:
from utils import load_and_clean_data,get_model,generate_embeddings,save_embeddings,build_faiss_index
# Output directory for embeddings and the FAISS index (created if it does not exist).
SAVE_DIR = "artifacts_zeroshot_beauty"
os.makedirs(name=SAVE_DIR, exist_ok=True)

In [None]:
# Cleaned product text + image URL table written at the end of the data-cleaning section.
prod_data = pd.read_csv("meta_data_beauty.csv")

In [8]:
prod_data

Unnamed: 0,product_text,product_image_url
0,Product Title: Howard LC0008 Leather Condition...,https://m.media-amazon.com/images/I/71i77AuI9x...
1,Yes to Tomatoes Detoxifying Charcoal Cleanser ...,https://m.media-amazon.com/images/I/71g1lP0pMb...
2,Product Title: Eye Patch Black Adult with Tie ...,https://m.media-amazon.com/images/I/31bz+uqzWC...
3,Product Title: Tattoo Eyebrow Stickers. Waterp...,https://m.media-amazon.com/images/I/71GJhXQGvy...
4,Precision Plunger Bars for Cartridge Grips - 9...,https://m.media-amazon.com/images/I/31TgqAZ8kQ...
...,...,...
112573,"Product Title: TOPREETY 24""120gr 3/4 Full Head...",https://m.media-amazon.com/images/I/71Ud1D40lg...
112574,Product Title: Pets Playmate Pet Grooming Glov...,https://m.media-amazon.com/images/I/61o89FR2Dj...
112575,Makeup Brushes Set Cosmetics Tools Kit Peacock...,https://m.media-amazon.com/images/I/71pKBX5Xrx...
112576,Product Title: Xcoser Pretty Party Anna Wig Ha...,https://m.media-amazon.com/images/I/61t3vpvoZK...


In [9]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import requests
from io import BytesIO
from transformers import CLIPProcessor

class ProductCLIPDataset(Dataset):
    """Map-style dataset pairing each product's text with its product image.

    Items are dicts ``{"text": str, "image": PIL.Image}``. Images are downloaded
    lazily in ``__getitem__``; on any download or decode failure a blank white
    placeholder image is returned instead. ``collate_fn`` turns a list of items into
    CLIP-ready tensors (``input_ids``, ``attention_mask``, ``pixel_values``).

    Args:
        df: DataFrame with ``product_text`` and ``product_image_url`` columns.
        model_name: Hugging Face id of the CLIP checkpoint whose processor is used.
        timeout: seconds to wait for each image download before falling back.
    """

    # Placeholder used when an image cannot be fetched or decoded.
    FALLBACK_SIZE = (224, 224)
    FALLBACK_COLOR = (255, 255, 255)

    def __init__(self, df, model_name="openai/clip-vit-base-patch32", timeout=10):
        self.texts = df["product_text"].tolist()
        self.urls = df["product_image_url"].tolist()
        self.timeout = timeout
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def __len__(self):
        return len(self.texts)

    def _load_image(self, url):
        """Download ``url`` and decode it as RGB; return the placeholder on failure."""
        try:
            # A timeout prevents a single stalled server from hanging the run forever.
            # `.content` honours Content-Encoding (unlike `.raw`) and fully reads the
            # body, so the connection is released.
            resp = requests.get(url, timeout=self.timeout)
            resp.raise_for_status()
            return Image.open(BytesIO(resp.content)).convert("RGB")
        except Exception:
            # Best-effort by design: broken/missing URLs (including NaN read from CSV)
            # must not abort a multi-hour embedding run. Catching Exception rather than
            # using a bare `except:` keeps KeyboardInterrupt/SystemExit working.
            return Image.new("RGB", self.FALLBACK_SIZE, color=self.FALLBACK_COLOR)

    def __getitem__(self, idx):
        return {
            "text": self.texts[idx],
            "image": self._load_image(self.urls[idx])
        }

    def collate_fn(self, batch):
        """Tokenize the texts and preprocess the images of a list of items."""
        texts = [ex["text"] for ex in batch]
        images = [ex["image"] for ex in batch]

        # Tokenize text (padded to the longest in the batch, truncated to the model limit)
        tokenized = self.processor.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt"
        )

        # Process images
        image_inputs = self.processor.image_processor(
            images,
            return_tensors="pt"
        )

        return {
            "input_ids": tokenized["input_ids"],
            "attention_mask": tokenized["attention_mask"],
            "pixel_values": image_inputs["pixel_values"]
        }


In [10]:
# Load Zero-Shot CLIP Model
model_zs = get_model(approach="zero_shot", save_dir=SAVE_DIR)
print("Loaded model:", model_zs.__class__.__name__)

Loaded model: CLIPModel


In [11]:
# Generate text & image embeddings
dataset = ProductCLIPDataset(prod_data)
text_embs, image_embs = generate_embeddings(model_zs, dataset, batch_size=32)
print("Generated embeddings:", text_embs.shape, image_embs.shape)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Generating embeddings: 100%|██████████████| 3519/3519 [2:24:01<00:00,  2.46s/it]


Generated embeddings: torch.Size([112578, 512]) torch.Size([112578, 512])


In [None]:
save_embeddings(text_embs, image_embs, SAVE_DIR)

# Fuse text and image into one vector per product for the FAISS index.
# Each modality is L2-normalised first so neither dominates the average merely because
# its raw embeddings have a larger norm; if generate_embeddings already returns unit
# vectors, this is equivalent to the previous formula. Dividing by 2 before the final
# normalisation has no effect, so it is omitted.
combined = F.normalize(
    F.normalize(text_embs, dim=-1) + F.normalize(image_embs, dim=-1),
    dim=-1,
)
index_path = build_faiss_index(combined, SAVE_DIR)

FAISS index saved to: artifacts_zeroshot_fash/faiss.index
