<a href="https://colab.research.google.com/github/saandeep17/two_tower_retrieval/blob/main/tte_retrieval_random_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.4 kB)
Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl (30.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.7/30.7 MB[0m [31m51.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.10.0


In [None]:
import torch
import torch.nn as nn
import numpy as np
import faiss

# Synthetic interaction log: id-space sizes, embedding width, and number of rows
num_users, num_items = 1000, 2000
embedding_dim = 32
interactions = 10000

# Seed NumPy's global RNG so the draws below are identical on every run
np.random.seed(42)

# One (user, item, label) triple per interaction; ids are drawn uniformly from
# [0, num_users) and [0, num_items). The draw order matters for reproducibility.
user_ids = np.random.randint(low=0, high=num_users, size=interactions)
item_ids = np.random.randint(low=0, high=num_items, size=interactions)

# Binary 0/1 targets, drawn independently of the ids
labels = np.random.randint(low=0, high=2, size=interactions)

# Simple two-tower model
class TwoTowerModel(nn.Module):
  """Two-tower model in which each tower is a single embedding table.

  A user/item pair is scored by the dot product of the user's embedding and
  the item's embedding.

  Args:
    num_users: number of rows in the user embedding table.
    num_items: number of rows in the item embedding table.
    embedding_dim: width of both embedding tables.
  """

  def __init__(self, num_users, num_items, embedding_dim):
    super().__init__()
    self.user_embedding = nn.Embedding(num_users, embedding_dim)
    self.item_embedding = nn.Embedding(num_items, embedding_dim)

  def forward(self, user_ids, item_ids):
    """Return the dot-product score for each (user, item) pair.

    Args:
      user_ids: integer tensor of user indices, any shape.
      item_ids: integer tensor of item indices, same (or broadcastable) shape.

    Returns:
      Tensor of raw scores with the same shape as the id tensors.
    """
    u = self.encode_user(user_ids)
    v = self.encode_item(item_ids)
    # Reduce over the embedding axis (always the last one) so scalar ids and
    # ids with extra leading dimensions work, not only flat 1-D batches.
    return (u * v).sum(dim=-1)

  def encode_user(self, user_id):
    """Look up the embedding(s) for the given user index tensor."""
    return self.user_embedding(user_id)

  def encode_item(self, item_id):
    """Look up the embedding(s) for the given item index tensor."""
    return self.item_embedding(item_id)


# model training
# Seeding NumPy does not seed torch's RNG, which initialises the embedding
# tables, so seed torch as well to make the run reproducible.
torch.manual_seed(42)

model = TwoTowerModel(num_users, num_items, embedding_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.BCEWithLogitsLoss()

BATCH_SIZE = 256
EPOCHS = 3

# Convert the NumPy arrays to tensors once, outside the loops, with explicit
# dtypes: integer ids for the embedding lookups, float targets for
# BCEWithLogitsLoss.
user_ids_t = torch.as_tensor(user_ids, dtype=torch.long)
item_ids_t = torch.as_tensor(item_ids, dtype=torch.long)
labels_t = torch.as_tensor(labels, dtype=torch.float32)

for epoch in range(EPOCHS):
  # Fresh shuffle of all interactions each epoch (drawn from NumPy's RNG)
  perm = torch.as_tensor(np.random.permutation(interactions), dtype=torch.long)
  epoch_loss = 0.0
  for i in range(0, interactions, BATCH_SIZE):
    batch_idx = perm[i:i+BATCH_SIZE]
    u = user_ids_t[batch_idx]
    v = item_ids_t[batch_idx]
    y = labels_t[batch_idx]

    optimizer.zero_grad()
    scores = model(u, v)
    loss = loss_fn(scores, y)
    loss.backward()
    optimizer.step()

    # Weight by batch size: the final batch is smaller than BATCH_SIZE
    epoch_loss += loss.item() * len(batch_idx)
  # Report the mean loss over the whole epoch rather than only the last
  # (small, noisy) mini-batch.
  print(f"Epoch {epoch+1} loss: {epoch_loss / interactions:.4f}")









Epoch 1 loss: 3.3311
Epoch 2 loss: 1.5438
Epoch 3 loss: 0.4948
