In [1]:
import torch
import pandas as pd
from torch.utils.data import Dataset


class MovieLensDataset(Dataset):
    """Map-style dataset over a preprocessed MovieLens ratings frame.

    The scalar columns are converted to tensors once, at construction time.
    The variable-length ``genre_list`` column is kept as an array of
    sequences and zero-padded to a common width when an item is fetched.

    Args:
        df: Frame providing the columns ``userid``, ``movieid``, ``gender``,
            ``rating_binary`` and ``genre_list``. Every ``genre_list`` entry
            must be a sequence that converts to an integer tensor.
    """

    def __init__(self, df: pd.DataFrame):
        self.userid = torch.tensor(df["userid"].values, dtype=torch.long)
        self.movieid = torch.tensor(df["movieid"].values, dtype=torch.long)
        self.gender = torch.tensor(df["gender"].values, dtype=torch.float32)
        self.rating_binary = torch.tensor(
            df["rating_binary"].values, dtype=torch.float32
        )

        # `.values` is positional, so a shuffled DataFrame index (for example
        # after a train/validation split) does not affect item lookup.
        self.genre_lists = df["genre_list"].values
        # `default=0` keeps an empty frame constructible instead of raising
        # ValueError from max() on an empty sequence.
        self.max_genres = max((len(g) for g in self.genre_lists), default=0)

    def __len__(self) -> int:
        """Return the number of rows in the underlying frame."""
        return len(self.userid)

    def __getitem__(self, idx) -> dict:
        """Return one sample as a dict of tensors plus the true genre count.

        ``genres`` is padded with zeros up to ``self.max_genres``; since zero
        is only a filler here, consumers should slice with ``genre_length``
        rather than rely on the padding value.
        """
        genres = self.genre_lists[idx]
        genre_tensor = torch.zeros(self.max_genres, dtype=torch.long)
        genre_tensor[: len(genres)] = torch.tensor(genres, dtype=torch.long)

        return {
            "userid": self.userid[idx],
            "movieid": self.movieid[idx],
            "gender": self.gender[idx],
            "genres": genre_tensor,
            "genre_length": len(genres),
            "label": self.rating_binary[idx],
        }

In [2]:
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def collate_fn(batch: list[dict]):
    """Collate per-sample dicts into the inputs consumed by the model.

    Args:
        batch: Samples with the keys ``userid``, ``movieid``, ``gender``,
            ``genres`` (padded tensor), ``genre_length`` and ``label``.

    Returns:
        A ``(sparse_features, dense_features, labels)`` tuple where
        ``sparse_features`` is a ``KeyedJaggedTensor`` over the keys
        ``userid``, ``movieid`` and ``genres``, ``dense_features`` is the
        stacked ``gender`` values with a trailing unit dimension, and
        ``labels`` is the stacked ``label`` values.
    """
    n_samples = len(batch)
    user_ids = torch.stack([sample["userid"] for sample in batch])
    movie_ids = torch.stack([sample["movieid"] for sample in batch])

    # Strip the padding again: only the first `genre_length` ids are real.
    genre_lengths = [sample["genre_length"] for sample in batch]
    genre_values = [
        genre_id
        for sample, n_genres in zip(batch, genre_lengths)
        for genre_id in sample["genres"][:n_genres].tolist()
    ]

    # Values and lengths are laid out key by key: all userids, then all
    # movieids (one value each per sample), then the ragged genre ids.
    flat_values = torch.cat(
        [user_ids, movie_ids, torch.tensor(genre_values, dtype=torch.long)]
    )
    per_key_lengths = torch.tensor(
        [1] * n_samples + [1] * n_samples + genre_lengths
    )
    sparse_features = KeyedJaggedTensor(
        keys=["userid", "movieid", "genres"],
        values=flat_values,
        lengths=per_key_lengths,
    )

    dense_features = torch.stack([sample["gender"] for sample in batch]).unsqueeze(1)
    labels = torch.stack([sample["label"] for sample in batch])
    return sparse_features, dense_features, labels

In [3]:
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchrec.models.dlrm import DLRM
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from sklearn.metrics import roc_auc_score, accuracy_score


class DLRMRecommender:
    """Training and evaluation wrapper around a TorchRec ``DLRM`` model.

    Builds one embedding table per sparse feature (``userid``, ``movieid``,
    ``genres``), a DLRM with a single dense input feature, a
    ``BCEWithLogitsLoss`` criterion and an Adam optimizer (lr=0.001).

    Args:
        num_users: Number of rows in the ``userid`` embedding table.
        num_movies: Number of rows in the ``movieid`` embedding table.
        num_genres: Number of rows in the ``genres`` embedding table.
        embedding_dim: Embedding width shared by all three tables.
        dense_arch_layer_sizes: Layer sizes of the DLRM dense arch.
            NOTE(review): DLRM is expected to require the last entry to
            equal ``embedding_dim`` — confirm against the TorchRec docs.
        over_arch_layer_sizes: Layer sizes of the DLRM over arch; the
            training code assumes one logit per sample.
        device: Device string used for the model and every batch. The
            default is resolved once, when the class is defined.
    """

    def __init__(
        self,
        num_users: int,
        num_movies: int,
        num_genres: int,
        embedding_dim: int = 64,
        dense_arch_layer_sizes: list[int] = [128, 64],
        over_arch_layer_sizes: list[int] = [128, 64, 1],
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        self.device = device

        # One table per sparse feature; the table name doubles as the
        # feature name that the collated KeyedJaggedTensor uses as its key.
        eb_configs = [
            EmbeddingBagConfig(
                name=feature,
                embedding_dim=embedding_dim,
                num_embeddings=cardinality,
                feature_names=[feature],
            )
            for feature, cardinality in (
                ("userid", num_users),
                ("movieid", num_movies),
                ("genres", num_genres),
            )
        ]

        embedding_bag_collection = EmbeddingBagCollection(
            tables=eb_configs,
            device=torch.device(device),
        )

        self.model = DLRM(
            embedding_bag_collection=embedding_bag_collection,
            dense_in_features=1,
            # Copy so the shared default lists (mutable default arguments)
            # can never be modified through the model.
            dense_arch_layer_sizes=list(dense_arch_layer_sizes),
            over_arch_layer_sizes=list(over_arch_layer_sizes),
            dense_device=device,
        ).to(device)

        self.criterion = nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

    def train_epoch(self, train_loader: DataLoader) -> float:
        """Run one optimisation pass over ``train_loader``.

        Args:
            train_loader: Iterable of ``(sparse_features, dense_features,
                labels)`` batches that also supports ``len()``.

        Returns:
            The unweighted mean of the per-batch losses.

        Raises:
            ZeroDivisionError: If ``train_loader`` yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        for sparse_features, dense_features, labels in tqdm(
            train_loader, total=len(train_loader)
        ):
            sparse_features = sparse_features.to(self.device)
            dense_features = dense_features.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()

            logits = self.model(dense_features, sparse_features)
            # squeeze(-1), not squeeze(): a trailing batch of size 1 must
            # stay shape (1,) so it still matches `labels`.
            loss = self.criterion(logits.squeeze(-1), labels)

            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        return total_loss / num_batches

    def evaluate(self, val_loader: DataLoader) -> dict[str, float]:
        """Score the model on ``val_loader`` without updating weights.

        Args:
            val_loader: Iterable of ``(sparse_features, dense_features,
                labels)`` batches.

        Returns:
            Dict with ``"loss"`` (unweighted mean of per-batch losses),
            ``"auc"`` (ROC AUC of the sigmoid probabilities) and
            ``"accuracy"`` (probabilities thresholded at 0.5).
        """
        self.model.eval()
        pred_chunks = []
        label_chunks = []
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for sparse_features, dense_features, labels in val_loader:
                sparse_features = sparse_features.to(self.device)
                dense_features = dense_features.to(self.device)
                labels = labels.to(self.device)

                # Keep the batch dimension even for a single-sample batch.
                logits = self.model(dense_features, sparse_features).squeeze(-1)
                loss = self.criterion(logits, labels)

                preds = torch.sigmoid(logits)

                # Collect whole arrays and concatenate once at the end
                # instead of extending a Python list element by element.
                pred_chunks.append(preds.cpu().numpy())
                label_chunks.append(labels.cpu().numpy())
                total_loss += loss.item()
                num_batches += 1

        all_preds = np.concatenate(pred_chunks)
        all_labels = np.concatenate(label_chunks)

        auc = roc_auc_score(all_labels, all_preds)
        acc = accuracy_score(all_labels, (all_preds > 0.5).astype(int))

        return {"loss": total_loss / num_batches, "auc": auc, "accuracy": acc}

In [4]:
from sklearn.model_selection import train_test_split


def train_recommendation_system(
    parquet_path: str,
    num_epochs: int = 20,
    batch_size: int = 512,
    embedding_dim: int = 64,
    val_split: float = 0.2,
    checkpoint_path: str = "best_dlrm_model.pt",
):
    """Train a DLRM recommender on a preprocessed MovieLens parquet file.

    Reads the frame, derives the embedding-table sizes from the largest id
    in each sparse column, makes a seeded train/validation split, trains
    for ``num_epochs`` and saves the model state whenever validation AUC
    improves.

    Args:
        parquet_path: Parquet file with the columns ``userid``, ``movieid``,
            ``gender``, ``rating_binary`` and ``genre_list``.
        num_epochs: Number of passes over the training split.
        batch_size: Batch size for both data loaders.
        embedding_dim: Embedding width; also used as the final dense-arch
            layer size so the two always agree.
        val_split: Fraction of rows held out for validation.
        checkpoint_path: Where the best model ``state_dict`` is written.

    Returns:
        The ``DLRMRecommender`` as it stands after the final epoch. This is
        not necessarily the best checkpoint; that one is on disk at
        ``checkpoint_path``.
    """
    df = pd.read_parquet(parquet_path)

    # Ids index directly into the embedding tables, so each table needs
    # max_id + 1 rows. Cast to int so plain Python ints, not numpy scalars,
    # are passed to the embedding configs.
    num_users = int(df["userid"].max()) + 1
    num_movies = int(df["movieid"].max()) + 1

    # Stream over the nested lists instead of building one flat list.
    max_genre_id = max(
        genre_id for genre_list in df["genre_list"] for genre_id in genre_list
    )
    num_genres = int(max_genre_id) + 1

    print("Dataset stats:")
    print(f"  Users: {num_users}")
    print(f"  Movies: {num_movies}")
    print(f"  Genres: {num_genres}")
    print(f"  Samples: {len(df)}")

    train_df, val_df = train_test_split(df, test_size=val_split, random_state=42)

    train_dataset = MovieLensDataset(train_df)
    val_dataset = MovieLensDataset(val_df)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=4,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=4,
        pin_memory=True,
    )

    recommender = DLRMRecommender(
        num_users=num_users,
        num_movies=num_movies,
        num_genres=num_genres,
        embedding_dim=embedding_dim,
        # DLRM needs the dense arch to end at the embedding width. The old
        # hard-coded [256, 128] only worked when embedding_dim was 128 and
        # failed for this function's own default of 64.
        dense_arch_layer_sizes=[256, embedding_dim],
        over_arch_layer_sizes=[256, 128, 64, 1],
    )

    best_auc = 0

    for epoch in range(num_epochs):
        train_loss = recommender.train_epoch(train_loader)
        val_metrics = recommender.evaluate(val_loader)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss: {val_metrics['loss']:.4f}")
        print(f"  Val AUC: {val_metrics['auc']:.4f}")
        print(f"  Val Accuracy: {val_metrics['accuracy']:.4f}")

        # Checkpoint only on strict improvement of validation AUC.
        if val_metrics["auc"] > best_auc:
            best_auc = val_metrics["auc"]
            torch.save(recommender.model.state_dict(), checkpoint_path)
            print(f"  ✓ New best model saved! (AUC: {best_auc:.4f})")

        print()

    print(f"Training complete! Best validation AUC: {best_auc:.4f}")
    return recommender

In [5]:
# Short 3-epoch training run on the preprocessed parquet file; the trained
# wrapper is kept in `recommender` for use in later cells.
run_config = {
    "parquet_path": "data/preprocessed.parquet",
    "num_epochs": 3,
    "batch_size": 1024,
    "embedding_dim": 128,
    "val_split": 0.2,
}
recommender = train_recommendation_system(**run_config)

Dataset stats:
  Users: 6041
  Movies: 3953
  Genres: 18
  Samples: 1000209


  0%|          | 0/782 [00:00<?, ?it/s]W1209 17:03:33.577000 31544 torch/fx/_symbolic_trace.py:52] is_fx_tracing will return true for both fx.symbolic_trace and torch.export. Please use is_fx_tracing_symbolic_tracing() for specifically fx.symbolic_trace or torch.compiler.is_compiling() for specifically torch.export/compile.
W1209 17:03:33.577000 31545 torch/fx/_symbolic_trace.py:52] is_fx_tracing will return true for both fx.symbolic_trace and torch.export. Please use is_fx_tracing_symbolic_tracing() for specifically fx.symbolic_trace or torch.compiler.is_compiling() for specifically torch.export/compile.
W1209 17:03:33.577000 31546 torch/fx/_symbolic_trace.py:52] is_fx_tracing will return true for both fx.symbolic_trace and torch.export. Please use is_fx_tracing_symbolic_tracing() for specifically fx.symbolic_trace or torch.compiler.is_compiling() for specifically torch.export/compile.
W1209 17:03:33.577000 31547 torch/fx/_symbolic_trace.py:52] is_fx_tracing will return true for both 

Epoch 1/3
  Train Loss: 0.5545
  Val Loss: 0.5337
  Val AUC: 0.7981
  Val Accuracy: 0.7315
  ✓ New best model saved! (AUC: 0.7981)



100%|██████████| 782/782 [00:10<00:00, 75.50it/s]


Epoch 2/3
  Train Loss: 0.5109
  Val Loss: 0.5232
  Val AUC: 0.8081
  Val Accuracy: 0.7391
  ✓ New best model saved! (AUC: 0.8081)



100%|██████████| 782/782 [00:10<00:00, 73.80it/s]


Epoch 3/3
  Train Loss: 0.4677
  Val Loss: 0.5238
  Val AUC: 0.8112
  Val Accuracy: 0.7405
  ✓ New best model saved! (AUC: 0.8112)

Training complete! Best validation AUC: 0.8112
