### Two tower model code learning from [this repo](https://github.com/gauravchak/two_tower_models)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [5]:
print(torch.__version__)

2.5.1


In [6]:
torch.cuda.is_available()

True

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


#### Basic two tower Took example from Copilot

In [9]:
import pandas as pd

ratings = pd.read_csv(
    "https://files.grouplens.org/datasets/movielens/ml-100k/u.data",
    sep="\t",
    names=["user_id", "item_id", "rating", "timestamp"]
)

print(ratings.head())


   user_id  item_id  rating  timestamp
0      196      242       3  881250949
1      186      302       3  891717742
2       22      377       1  878887116
3      244       51       2  880606923
4      166      346       1  886397596


In [11]:
positive = ratings[ratings["rating"] >= 4]

In [12]:
pos_pairs = positive[["user_id", "item_id"]].values

In [13]:
import numpy as np

num_items = ratings["item_id"].max() + 1

def sample_negative(user_id, pos_item):
    neg = np.random.randint(1, num_items)
    while neg == pos_item:
        neg = np.random.randint(1, num_items)
    return neg

neg_items = [sample_negative(u, i) for u, i in pos_pairs]


In [14]:
user_ids = pos_pairs[:, 0]
pos_item_ids = pos_pairs[:, 1]
neg_item_ids = np.array(neg_items)


In [25]:
import torch

class MovieLensDataset(torch.utils.data.Dataset):
    def __init__(self, users, pos_items, neg_items):
        self.users = torch.tensor(users, dtype=torch.long,device=device)
        self.pos_items = torch.tensor(pos_items, dtype=torch.long,device=device)
        self.neg_items = torch.tensor(neg_items, dtype=torch.long, device=device)

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        return {
            "user": self.users[idx],
            "pos_item": self.pos_items[idx],
            "neg_item": self.neg_items[idx],
        }

dataset = MovieLensDataset(user_ids, pos_item_ids, neg_item_ids)
loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)


In [26]:
print(user_ids.min(), user_ids.max())
print(pos_item_ids.min(), pos_item_ids.max())
print(neg_item_ids.min(), neg_item_ids.max())

1 943
1 1674
1 1682


In [27]:
class UserTower(torch.nn.Module):
    def __init__(self, num_users, embed_dim=64):
        super().__init__()
        self.embedding = torch.nn.Embedding(num_users + 1, embed_dim,device=device)

    def forward(self, user_ids):
        return torch.nn.functional.normalize(self.embedding(user_ids), dim=-1)


In [28]:
class ItemTower(torch.nn.Module):
    def __init__(self, num_items, embed_dim=64):
        super().__init__()
        self.embedding = torch.nn.Embedding(num_items + 1, embed_dim,device=device)

    def forward(self, item_ids):
        return torch.nn.functional.normalize(self.embedding(item_ids), dim=-1)


In [32]:
def similarity(u, v):
    return (u * v).sum(dim=-1)  # Dot product

In [33]:
def contrastive_loss(user_emb, pos_emb, neg_emb, margin=0.2):
    pos_score = similarity(user_emb, pos_emb)
    neg_score = similarity(user_emb, neg_emb)
    loss = torch.relu(margin + neg_score - pos_score)
    return loss.mean()


In [34]:
user_tower = UserTower(num_users=ratings["user_id"].max())
item_tower = ItemTower(num_items=ratings["item_id"].max())

optimizer = torch.optim.Adam(
    list(user_tower.parameters()) + list(item_tower.parameters()),
    lr=1e-3
)

for epoch in range(5):
    for batch in loader:
        user = batch["user"]
        pos_item = batch["pos_item"]
        neg_item = batch["neg_item"]

        user_emb = user_tower(user)
        pos_emb = item_tower(pos_item)
        neg_emb = item_tower(neg_item)

        loss = contrastive_loss(user_emb, pos_emb, neg_emb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")


Epoch 0, Loss: 0.2098
Epoch 1, Loss: 0.1925
Epoch 2, Loss: 0.1978
Epoch 3, Loss: 0.2163
Epoch 4, Loss: 0.2108


#### Making Recommendations (Topâ€‘K Retrieval)

In [35]:
all_items = torch.arange(1, num_items,device=device)
item_embs = item_tower(all_items)

In [36]:
user_id = 10
user_emb = user_tower(torch.tensor([user_id],device=device))


In [37]:
scores = (user_emb * item_embs).sum(dim=-1)
topk = torch.topk(scores, k=10)
print("Recommended item IDs:", topk.indices.tolist())
print("Scores:", topk.values.tolist())

Recommended item IDs: [252, 15, 1199, 863, 713, 1310, 351, 1574, 1088, 183]
Scores: [0.4566866457462311, 0.38565975427627563, 0.38276970386505127, 0.38113704323768616, 0.36996063590049744, 0.3659084141254425, 0.35912391543388367, 0.3317814767360687, 0.32569026947021484, 0.32416480779647827]


#### Basic Mips

In [39]:
from typing import Tuple
import torch
import torch.nn as nn

In [40]:
class BaselineMIPSModule(nn.Module):
    def __init__(
        self, 
        corpus_size:int, 
        embedding_dim:int) -> None:
        super(BaselineMIPSModule,self).__init__()
        self.corpus_size = corpus_size
        self.embedding_dim = embedding_dim
        # Create Random Corpus
        self.corpus = torch.randn(corpus_size, embedding_dim,device=device) # [C, DI]
    
    def forward(
        self,
        query_embedding: torch.Tensor, # [B, DI],
        num_items: int, # (NI)
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mips_score, indices = torch.topk(
            torch.matmul(query_embedding, self.corpus.T),
            k=num_items,
            dim=1
        )
        
        expanded_indices = indices.unsqueeze(2)
        embeddings = self.corpus[expanded_indices]
        embeddings = embeddings.squeeze(2)
        
        return indices, mips_score, embeddings

#### Two Tower base retrieval

In [41]:
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

##### Define Arch

In [48]:
class TwoTowerBaseRetrieval(nn.Module):
    def __init__(
        self,
        num_items:int,
        user_id_hash_size: int,
        user_id_embedding_dim: int,
        user_features_size: int,
        item_id_hash_size: int,
        item_id_embedding_dim: int,
        item_features_size:int,
        user_value_weights: List[float],
        mips_module: BaselineMIPSModule,
    ) -> None:
        super().__init__()
        self.num_items = num_items
        
        self.user_value_weights = torch.Tensor(user_value_weights).to(device)
        self.mips_module = mips_module
        
        # Create the machinery for user tower

        # 1. Create a module to represent user preference by a table lookup.
        # Please see https://github.com/gauravchak/user_preference_modeling
        # for other ways to represent user preference embedding.
        self.user_id_embedding_arch = nn.Embedding(
            user_id_hash_size, user_id_embedding_dim,device=device
            )
        # 2. Create an arch to process the user_features. We are using one
        # hidden layer of 256 dimensions. This is just a reasonable default.
        # You can experiment with other architectures.
        self.user_features_arch = nn.Sequential(
            nn.Linear(user_features_size, 256,device=device),
            nn.ReLU(),
            nn.Linear(256, user_id_embedding_dim,device=device),
        )
        
        # 3. Create an arch to process the user_tower_input
        # Input dimension =
        #   user_id_embedding_dim from get_user_embedding,
        #      essentially based on user_id
        #   + user_id_embedding_dim from user_features_arch,
        #      essentially based on user_features
        # Output dimension = item_id_embedding_dim
        # The output of this arch will be used for MIPS module.
        # Hence the output dimension needs to be same as the item tower output.
        self.user_tower_arch = nn.Linear(
            in_features=2 * user_id_embedding_dim,
            out_features=item_id_embedding_dim,
            device=device
        )
        
        # Create arch for item tower

        # 1. Embedding layers for item id
        self.item_id_embedding_arch = nn.Embedding(
            item_id_hash_size,
            item_id_embedding_dim,
            device=device
        )
        
        # 2. Create an arch to process the item_features
        self.item_featurs_arch = nn.Sequential(
            nn.Linear(item_features_size, 256,device=device),
            nn.ReLU(),
            nn.Linear(256, item_id_embedding_dim,device=device),
        )
        
        # 3. Create an arch to process the item_tower_input
        self.item_tower_arch = nn.Linear(
            in_features=2 * item_id_embedding_dim,
            out_features=item_id_embedding_dim,
            device=device
        )
    def get_user_embedding(
        self,
        user_id: torch.Tensor,
        user_features: torch.Tensor,
    ) -> torch.Tensor:
        user_id_embedding = self.user_id_embedding_arch(user_id)
        return user_id_embedding
    
    def process_user_features(
        self,
        user_id: torch.Tensor,
        user_features: torch.Tensor,
        user_history: torch.Tensor,
    ) -> torch.Tensor:
        user_id_emb = self.get_user_embedding(user_id, user_features) # [B, DU]
        user_features_emb = self.user_features_arch(user_features) #[B, DU]
        user_tower_input = torch.cat([user_id_emb, user_features_emb], dim=1) # [B, 2*DU]
        return user_tower_input
    
    def compute_user_embedding(
        self,
        user_id: torch.Tensor,
        user_features: torch.Tensor,
        user_history: torch.Tensor,
    ) -> torch.Tensor:
        user_tower_input = self.process_user_features(
            user_id, user_features, user_history
        ) # [B, 2*DU]
        # Compute user embedding
        user_embedding = self.user_tower_arch(user_tower_input) # [B, DI]
        return user_embedding
    
    def compute_item_embedding(
        self,
        item_id: torch.Tensor,
        item_features: torch.Tensor,
    ) -> torch.Tensor:
        # Process item id
        item_id_emb = self.item_id_embedding_arch(item_id) # [B, DI]
        # Process item features
        item_features_emb = self.item_featurs_arch(item_features) # [B, DI]
        # Concatenate item id embedding and item features embedding
        item_tower_input = torch.cat([item_id_emb, item_features_emb], dim=1) # [B, 2*DI]
        # Compute item embedding
        item_embedding = self.item_tower_arch(item_tower_input) # [B, DI]
        return item_embedding
    
    def debias_net_user_value(
        self, 
        net_user_value: torch.Tensor, # [B]
        position: torch.Tensor, # [B]
        user_embeddings: torch.Tensor, # [B, DI]
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        return net_user_value, torch.zeros_like(net_user_value,device=device)
    
    def compute_training_loss(
        self,
        user_embeddings: torch.Tensor, #[B, DI ]
        item_embeddings: torch.Tensor, #[B, DI ]
        position: torch.Tensor, #[B, ]
        labels: torch.Tensor, #[B, T]
    ) -> torch.Tensor:
        # Compute similarity scores
        scores = torch.matmul(
            user_embeddings, # [B, DU]
            item_embeddings.T,  # [B, DI]
        ) # [B, B]
        
        target = torch.arange(scores.shape[0]).to(scores.device) #[B, ]
        loss = F.cross_entropy(scores, target, reduction='none') #[B, ]
        
        net_user_value = torch.sum(
            labels * self.user_value_weights,
            dim=-1 
        ) #[B, ]
        
        # optional debiasing step
        net_user_value, additional_loss = self.debias_net_user_value(
            net_user_value=net_user_value,
            position=position,
            user_embeddings=user_embeddings,
        ) #[B], [1]
        
        # Floor by epsilon to avoid zero division
        net_user_value = torch.clamp(net_user_value, min=0.000001) #[B, ]
        
        # Normalize net user value by the max value in the batch
        net_user_value = net_user_value / torch.max(net_user_value) #[B, ]
        
        # compute the product of loss and net user value
        weighted_loss = loss * net_user_value #[B, ]
        loss = torch.mean(weighted_loss) #[1]
        return loss
    
    def train_forward(
        self,
        user_id: torch.Tensor, #[B]
        user_features: torch.Tensor, #[B, IU]
        user_history: torch.Tensor, #[B, H]
        item_id: torch.Tensor, # [B]
        item_features: torch.Tensor, #[B, II]
        position: torch.Tensor, #[B]
        labels: torch.Tensor, #[B, T]
        ) -> float:
        # compute user embeddings
        user_embeddings = self.compute_user_embedding(
            user_id, user_features, user_history
        )
        
        # compute item embseddings
        
        item_embeddings = self.compute_item_embeddings(
            item_id, item_features
        ) # [B, DI]
        
        loss = self.compute_training_loss(
            user_embeddings = user_embeddings,
            item_embeddings = item_embeddings,
            position = position,
            labels = labels,
        )
        return loss
    
    def forward(
        self,
        user_id: torch.Tensor,  # [B]
        user_features: torch.Tensor,  # [B, IU]
        user_history: torch.Tensor,  # [B, H]
    ) -> torch.Tensor:
        """This is used for inference.
        Compute the user embedding and return the top num_items items using the mips module.
        Args:
            user_id (torch.Tensor): Tensor representing the user ID. Shape: [B]
            user_features (torch.Tensor): Tensor representing the user features. Shape: [B, IU]
            user_history (torch.Tensor): Tensor representing the user history. Shape: [B, H]
        Returns:
            torch.Tensor: Tensor representing the top num_items items. Shape: [B, num_items]
        """
        # Compute the user embedding
        user_embedding = self.compute_user_embedding(
            user_id, user_features, user_history
        )
        # Query the mips module to get the top num_items items and their
        # embeddings. The embeddings aren't strictly necessary in the base
        # implementation.
        top_items, _, _ = self.mips_module(
            query_embedding=user_embedding, num_items=self.num_items
        )  # indices [B, num_items], mips_scores [B, NI], embeddings [B, NI, DI]  # noqa
        return top_items

##### Helper Fxn

In [49]:
num_items = 10
user_id_hash_size = 100
user_id_embedding_dim = 50
user_features_size = 20
item_id_hash_size = 150
item_id_embedding_dim = 40
item_features_size = 30
tasknum: int = 3
user_value_weights = [0.1, 0.2, 0.3]  # dimension = tasknum
user_history_seqlen: int = 128
corpus_size: int = 1001
mips_module = BaselineMIPSModule(
    corpus_size=corpus_size,
    embedding_dim=item_id_embedding_dim,
)
candidate_generator = TwoTowerBaseRetrieval(
    num_items=num_items,
    user_id_hash_size=user_id_hash_size,
    user_id_embedding_dim=user_id_embedding_dim,
    user_features_size=user_features_size,
    item_id_hash_size=item_id_hash_size,
    item_id_embedding_dim=item_id_embedding_dim,
    item_features_size=item_features_size,
    user_value_weights=user_value_weights,
    mips_module=mips_module,
)
batch_size = 32
user_id = torch.randint(
    0, user_id_hash_size, (batch_size,),device=device
)
user_features = torch.randn(
    batch_size, user_features_size,device=device
)
user_history = torch.randint(
    low=0,
    high=num_items,
    size=(batch_size, user_history_seqlen),device=device
)

item_recommendations = candidate_generator(
            user_id, user_features, user_history
        )

In [50]:
print(item_recommendations.shape)
print(torch.Size([batch_size, candidate_generator.num_items]))

torch.Size([32, 10])
torch.Size([32, 10])
