### Two tower model code learning from [this repo](https://github.com/gauravchak/two_tower_models)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [71]:
print(torch.__version__)

2.5.1


In [73]:
torch.cuda.is_available()

True

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


#### Basic two tower Took example from Copilot

In [21]:
import pandas as pd

ratings = pd.read_csv(
    "https://files.grouplens.org/datasets/movielens/ml-100k/u.data",
    sep="\t",
    names=["user_id", "item_id", "rating", "timestamp"]
)

print(ratings.head())


   user_id  item_id  rating  timestamp
0      196      242       3  881250949
1      186      302       3  891717742
2       22      377       1  878887116
3      244       51       2  880606923
4      166      346       1  886397596


In [4]:
positive = ratings[ratings["rating"] >= 4]

In [5]:
pos_pairs = positive[["user_id", "item_id"]].values

In [22]:
import numpy as np

num_items = ratings["item_id"].max() + 1

def sample_negative(user_id, pos_item):
    neg = np.random.randint(1, num_items)
    while neg == pos_item:
        neg = np.random.randint(1, num_items)
    return neg

neg_items = [sample_negative(u, i) for u, i in pos_pairs]


In [23]:
user_ids = pos_pairs[:, 0]
pos_item_ids = pos_pairs[:, 1]
neg_item_ids = np.array(neg_items)


In [24]:
import torch

class MovieLensDataset(torch.utils.data.Dataset):
    def __init__(self, users, pos_items, neg_items):
        self.users = torch.tensor(users, dtype=torch.long)
        self.pos_items = torch.tensor(pos_items, dtype=torch.long)
        self.neg_items = torch.tensor(neg_items, dtype=torch.long)

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        return {
            "user": self.users[idx],
            "pos_item": self.pos_items[idx],
            "neg_item": self.neg_items[idx],
        }

dataset = MovieLensDataset(user_ids, pos_item_ids, neg_item_ids)
loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)


In [16]:
print(user_ids.min(), user_ids.max())
print(pos_item_ids.min(), pos_item_ids.max())
print(neg_item_ids.min(), neg_item_ids.max())

1 943
1 1674
1 1682


In [25]:
class UserTower(torch.nn.Module):
    def __init__(self, num_users, embed_dim=64):
        super().__init__()
        self.embedding = torch.nn.Embedding(num_users + 1, embed_dim)

    def forward(self, user_ids):
        return torch.nn.functional.normalize(self.embedding(user_ids), dim=-1)


In [26]:
class ItemTower(torch.nn.Module):
    def __init__(self, num_items, embed_dim=64):
        super().__init__()
        self.embedding = torch.nn.Embedding(num_items + 1, embed_dim)

    def forward(self, item_ids):
        return torch.nn.functional.normalize(self.embedding(item_ids), dim=-1)


In [18]:
def similarity(u, v):
    return (u * v).sum(dim=-1)

In [27]:
def contrastive_loss(user_emb, pos_emb, neg_emb, margin=0.2):
    pos_score = similarity(user_emb, pos_emb)
    neg_score = similarity(user_emb, neg_emb)
    loss = torch.relu(margin + neg_score - pos_score)
    return loss.mean()


In [28]:
user_tower = UserTower(num_users=ratings["user_id"].max())
item_tower = ItemTower(num_items=ratings["item_id"].max())

optimizer = torch.optim.Adam(
    list(user_tower.parameters()) + list(item_tower.parameters()),
    lr=1e-3
)

for epoch in range(5):
    for batch in loader:
        user = batch["user"]
        pos_item = batch["pos_item"]
        neg_item = batch["neg_item"]

        user_emb = user_tower(user)
        pos_emb = item_tower(pos_item)
        neg_emb = item_tower(neg_item)

        loss = contrastive_loss(user_emb, pos_emb, neg_emb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")


Epoch 0, Loss: 0.2167
Epoch 1, Loss: 0.2007
Epoch 2, Loss: 0.2159
Epoch 3, Loss: 0.1938
Epoch 4, Loss: 0.1702


#### Making Recommendations (Topâ€‘K Retrieval)

In [29]:
all_items = torch.arange(1, num_items)
item_embs = item_tower(all_items)

In [30]:
user_id = 10
user_emb = user_tower(torch.tensor([user_id]))


In [31]:
scores = (user_emb * item_embs).sum(dim=-1)
topk = torch.topk(scores, k=10)
print("Recommended item IDs:", topk.indices.tolist())
print("Scores:", topk.values.tolist())

Recommended item IDs: [1622, 1233, 79, 1024, 81, 1204, 391, 755, 942, 1048]
Scores: [0.3805980980396271, 0.3712553083896637, 0.3366747498512268, 0.33095109462738037, 0.32895219326019287, 0.3208177089691162, 0.32034242153167725, 0.31708914041519165, 0.3086739480495453, 0.2951280474662781]


#### Basic Mips

In [33]:
from typing import Tuple
import torch
import torch.nn as nn

In [34]:
class BaselineMIPSModule(nn.Module):
    def __init__(
        self, 
        corpus_size:int, 
        embedding_dim:int) -> None:
        super(BaselineMIPSModule,self).__init__()
        self.corpus_size = corpus_size
        self.embedding_dim = embedding_dim
        # Create Random Corpus
        self.corpus = torch.randn(corpus_size, embedding_dim) # [C, DI]
    
    def forward(
        self,
        query_embedding: torch.Tensor, # [B, DI],
        num_items: int, # (NI)
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mips_score, indices = torch.topk(
            torch.matmul(query_embedding, self.corpus.T),
            k=num_items,
            dim=1
        )
        
        expanded_indices = indices.unsqueeze(2)
        embeddings = self.corpus[expanded_indices]
        embeddings = embeddings.squeeze(2)
        
        return indices, mips_score, embeddings

#### Two Tower base retrieval

In [35]:
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class TwoTowerBaseRetrieval(nn.Module):
    def __init__(
        self,
        num_items:int,
        user_id_hash_size: int,
        user_id_embedding_dim: int,
        user_feature_size: int,
        item_id_hash_size: int,
        item_id_embedding_dim: int,
        item_feature_size:int,
        user_value_weights: List[float],
        mips_module: BaselineMIPSModule,
    ) -> None:
        super().__init__()
        self.num_items = num_items
        