<a href="https://colab.research.google.com/github/yguo005/Recommendation_System/blob/main/Retrieval_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [24]:
import torch

# PyTorch way (works if PyTorch installed)
gpu_available = torch.cuda.is_available()
print("GPU available:", gpu_available)
if gpu_available:
    print("GPU name:", torch.cuda.get_device_name(0))

GPU available: True
GPU name: Tesla T4


In [25]:
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

Using device: cuda


In [26]:
def load_inter(path):
    return pd.read_csv(path, sep="\t", dtype=str)

df = load_inter("amazon-beauty-train.inter")
print(df.head())

# Keep only positive interactions
df["label"] = pd.to_numeric(df["label"], errors="coerce").fillna(0).astype(int)
df = df[df["label"] == 1].copy()

  user_id item_id    timestamp label
0    2238     657  908755200.0   1.0
1    2254     661  912297600.0   1.0
2    2274     671  921542400.0   1.0
3    2355     687  928368000.0   1.0
4    2197     647  937267200.0   1.0


In [27]:
df["user_idx"] = df["user_id"].astype("category").cat.codes
df["item_idx"] = df["item_id"].astype("category").cat.codes

num_users = df["user_idx"].nunique()
num_items = df["item_idx"].nunique()

print("Users:", num_users, "Items:", num_items)

Users: 1192466 Items: 210651


In [28]:
class InterDataset(Dataset):
    def __init__(self, df):
        self.users = torch.tensor(df["user_idx"].values, dtype=torch.long)
        self.items = torch.tensor(df["item_idx"].values, dtype=torch.long)

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        return self.users[idx], self.items[idx]

dataset = InterDataset(df)
loader = DataLoader(dataset, batch_size=4096, shuffle=True)

In [29]:
class TwoTower(nn.Module):
    def __init__(self, num_users, num_items, emb_dim=64):
        super().__init__()
        self.user_emb = nn.Embedding(num_users, emb_dim)
        self.item_emb = nn.Embedding(num_items, emb_dim)

    def forward(self, user, item):
        u = self.user_emb(user)
        i = self.item_emb(item)
        # Dot product
        return (u * i).sum(dim=1)

model = TwoTower(num_users, num_items, 64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

In [30]:
epochs = 5
for epoch in range(epochs):
    model.train()
    total_loss = 0
    t0 = time.time()

    for user, item in loader:
        user, item = user.to(device), item.to(device)

        # positive examples
        pos_score = model(user, item)

        # negative sampling
        neg_items = torch.randint(0, num_items, item.shape, device=device)
        neg_score = model(user, neg_items)

        # labels
        labels = torch.cat([
            torch.ones_like(pos_score),
            torch.zeros_like(neg_score),
        ])

        scores = torch.cat([pos_score, neg_score])

        loss = criterion(scores, labels.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss:.4f} - Time: {time.time()-t0:.2f}s")

Epoch 1/5 - Loss: 1207.6955 - Time: 30.07s
Epoch 2/5 - Loss: 1130.8715 - Time: 29.23s
Epoch 3/5 - Loss: 1053.5127 - Time: 29.49s
Epoch 4/5 - Loss: 983.8007 - Time: 29.15s
Epoch 5/5 - Loss: 922.6302 - Time: 29.10s


In [31]:
user_emb = model.user_emb.weight.detach().cpu().numpy()
item_emb = model.item_emb.weight.detach().cpu().numpy()

np.save("user_embeddings.npy", user_emb)
np.save("item_embeddings.npy", item_emb)

print("Saved user_embeddings.npy and item_embeddings.npy")


Saved user_embeddings.npy and item_embeddings.npy


In [32]:
import faiss

dim = item_emb.shape[1]
index = faiss.IndexFlatIP(dim)  # inner product search
index.add(item_emb.astype("float32"))

faiss.write_index(index, "faiss_item_index.bin")
print("Saved faiss_item_index.bin")

Saved faiss_item_index.bin
