In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import pandas as pd
from tqdm import tqdm
import numpy as np
import random

# Set device: use MPS if available on MacBook Pro, else CUDA, else CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    
device

device(type='mps')

In [2]:

# ---------------------------
# 1. Contrastive Dataset Definition
# ---------------------------
class ContrastiveProductDataset(Dataset):
    def __init__(self, csv_file):
        self.df = pd.read_csv(csv_file)
        required_cols = ['asin', 'cleaned_review', 'cleaned_metadata', 'rating', 'price']
        for col in required_cols:
            if col not in self.df.columns:
                raise ValueError(f"CSV file is missing required column: {col}")
        # Shuffle the data initially
        self.df = self.df.sample(frac=1).reset_index(drop=True)
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # View 1: cleaned review text (user perspective)
        view1 = str(row['cleaned_review']).strip()
        # View 2: combine metadata with price and rating (product details)
        view2 = f"{str(row['cleaned_metadata']).strip()}. Price: {row['price']}. Rating: {row['rating']}."
        return {"view1": view1, "view2": view2}


In [3]:
def my_collate_fn(batch, tokenizer, max_length=128):
    view1_texts = [item["view1"] for item in batch]
    view2_texts = [item["view2"] for item in batch]
    
    encoded_view1 = tokenizer(
        view1_texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    encoded_view2 = tokenizer(
        view2_texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    
    return {
        "view1_input_ids": encoded_view1["input_ids"],
        "view1_attention_mask": encoded_view1["attention_mask"],
        "view2_input_ids": encoded_view2["input_ids"],
        "view2_attention_mask": encoded_view2["attention_mask"],
    }

In [4]:

# ---------------------------
# 2. Model Architecture: Dual Encoder with SBERT
# ---------------------------
class DualEncoderSBERT(nn.Module):
    def __init__(self, model_name="sentence-transformers/all-MiniLM-L12-v2", embed_dim=128):
        """
        Uses SBERT as a base model. The chosen model outputs 384-dimensional embeddings.
        We then add a projection head mapping to a lower-dimensional space (e.g., 128 dims).
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Load SBERT-style model
        self.encoder = AutoModel.from_pretrained(model_name)
        
        # Projection head: maps from 384 (MiniLM) to embed_dim.
        self.projection = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )
    
    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first token's embedding as the sentence representation (like [CLS])
        cls_emb = outputs.last_hidden_state[:, 0]  # shape: (batch, hidden_size)
        proj = self.projection(cls_emb)
        # Normalize embeddings for cosine similarity
        return F.normalize(proj, p=2, dim=1)


In [5]:

# ---------------------------
# 3. Contrastive Loss (InfoNCE)
# ---------------------------
def info_nce_loss(embeddings1, embeddings2, temperature=0.07):
    """
    Computes InfoNCE loss between two sets of embeddings.
    For each sample, the positive pair is at the same index,
    while other samples in the batch are treated as negatives.
    """
    logits = torch.mm(embeddings1, embeddings2.t()) / temperature
    labels = torch.arange(logits.shape[0]).to(logits.device)
    loss = F.cross_entropy(logits, labels)
    return loss


In [6]:
def evaluate_model(model_view1, model_view2, dataloader, temperature=0.07):
    """
    Computes retrieval metrics on the validation set.
    We'll compute Recall@1 and Mean Reciprocal Rank (MRR).
    The assumption is that for each sample, the correct pair is on the diagonal.
    """
    model_view1.eval()
    model_view2.eval()
    
    all_emb_view1 = []
    all_emb_view2 = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            view1_ids = batch["view1_input_ids"].to(device)
            view1_mask = batch["view1_attention_mask"].to(device)
            view2_ids = batch["view2_input_ids"].to(device)
            view2_mask = batch["view2_attention_mask"].to(device)
            
            emb1 = model_view1(view1_ids, view1_mask)
            emb2 = model_view2(view2_ids, view2_mask)
            
            all_emb_view1.append(emb1.cpu())
            all_emb_view2.append(emb2.cpu())
    
    all_emb_view1 = torch.cat(all_emb_view1, dim=0)
    all_emb_view2 = torch.cat(all_emb_view2, dim=0)
    
    # Compute similarity matrix
    sims = torch.mm(all_emb_view1, all_emb_view2.t())
    sims_np = sims.numpy()
    
    # For each query (row), compute the rank of the correct pair (diagonal element)
    ranks = []
    for i in range(sims_np.shape[0]):
        # Sort indices in descending order of similarity
        sorted_indices = np.argsort(-sims_np[i])
        # Find rank of the i-th sample (correct pairing)
        rank = np.where(sorted_indices == i)[0][0] + 1  # ranks start at 1
        ranks.append(rank)
    
    ranks = np.array(ranks)
    recall_at_1 = np.mean(ranks == 1)
    mrr = np.mean(1.0 / ranks)
    
    return recall_at_1, mrr

In [7]:
from functools import partial
from torch.optim import AdamW


def train_and_evaluate(train_csv, val_csv, epochs=3, batch_size=32, lr=2e-5,
                       temperature=0.07, max_length=128, num_workers=4):
    # Create datasets and dataloaders for train and validation splits.
    train_dataset = ContrastiveProductDataset(train_csv)
    val_dataset = ContrastiveProductDataset(val_csv)
    
    # Initialize tokenizer for collate function.
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L12-v2")
    
    collate = partial(my_collate_fn, tokenizer=tokenizer, max_length=max_length)
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate
    )
    
    # Initialize two models for the two views.
    model_view1 = DualEncoderSBERT(model_name="sentence-transformers/all-MiniLM-L12-v2").to(device)
    model_view2 = DualEncoderSBERT(model_name="sentence-transformers/all-MiniLM-L12-v2").to(device)
    
    optimizer = AdamW(list(model_view1.parameters()) + list(model_view2.parameters()), lr=lr)
    
    best_recall = 0.0
    best_epoch = 0
    
    for epoch in range(epochs):
        model_view1.train()
        model_view2.train()
        total_loss = 0.0
        
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{epochs}"):
            view1_ids = batch["view1_input_ids"].to(device)
            view1_mask = batch["view1_attention_mask"].to(device)
            view2_ids = batch["view2_input_ids"].to(device)
            view2_mask = batch["view2_attention_mask"].to(device)
            
            emb1 = model_view1(view1_ids, view1_mask)
            emb2 = model_view2(view2_ids, view2_mask)
            
            loss = info_nce_loss(emb1, emb2, temperature)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} - Training Loss: {avg_loss:.4f}")
        
        # Evaluate on validation set after each epoch
        recall_at_1, mrr = evaluate_model(model_view1, model_view2, val_loader, temperature)
        print(f"Epoch {epoch+1} - Validation Recall@1: {recall_at_1:.4f}, MRR: {mrr:.4f}")
        
        # Save best model based on Recall@1
        if recall_at_1 > best_recall:
            best_recall = recall_at_1
            best_epoch = epoch + 1
            torch.save(model_view1.state_dict(), "best_model_view1.pt")
            torch.save(model_view2.state_dict(), "best_model_view2.pt")
            print(f"New best model found at epoch {epoch+1} with Recall@1: {best_recall:.4f}")
    
    print(f"Training complete. Best model at epoch {best_epoch} with Recall@1: {best_recall:.4f}")


In [8]:
def test_evaluate(test_csv, max_length=128, batch_size=32, num_workers=4):
    # Load test dataset
    test_dataset = ContrastiveProductDataset(test_csv)
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L12-v2")
    collate = partial(my_collate_fn, tokenizer=tokenizer, max_length=max_length)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate
    )
    
    # Load the best models saved from training.
    model_view1 = DualEncoderSBERT(model_name="sentence-transformers/all-MiniLM-L12-v2").to(device)
    model_view2 = DualEncoderSBERT(model_name="sentence-transformers/all-MiniLM-L12-v2").to(device)
    model_view1.load_state_dict(torch.load("best_model_view1.pt", map_location=device))
    model_view2.load_state_dict(torch.load("best_model_view2.pt", map_location=device))
    
    recall_at_1, mrr = evaluate_model(model_view1, model_view2, test_loader)
    print(f"Test Evaluation - Recall@1: {recall_at_1:.4f}, MRR: {mrr:.4f}")


In [9]:
train_csv='/Users/sanamoin/Documents/sites/gadgets/recommendation_engine/data/filtered_splits/electronics_train.csv'
val_csv='/Users/sanamoin/Documents/sites/gadgets/recommendation_engine/data/filtered_splits/electronics_val.csv'
test_csv='/Users/sanamoin/Documents/sites/gadgets/recommendation_engine/data/filtered_splits/electronics_test.csv'


In [11]:
from multiprocessing import freeze_support

freeze_support()
# Train and evaluate on validation set.
train_and_evaluate(train_csv, val_csv, epochs=3, batch_size=32, lr=2e-5,
                   temperature=0.07, max_length=128, num_workers=4)


Training Epoch 1/3:   0%|          | 0/62251 [00:00<?, ?it/s]Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'ContrastiveProductDataset' on <module '__main__' (<class '_frozen_importlib.BuiltinImporter'>)>
Training Epoch 1/3:   0%|          | 0/62251 [00:23<?, ?it/s]


KeyboardInterrupt: 

In [None]:

# Finally, evaluate on the test set using the best saved model.
test_evaluate(test_csv, max_length=128, batch_size=32, num_workers=4)