In [1]:
# --- Data ---------------------------------------------------------------
# Preprocessed ratings table consumed by the loading cell.
PARQUET_PATH: str = "data/preprocessed.parquet"
# Fraction of rows held out for validation (passed as `test_size`).
VAL_SPLIT: float = 0.2

# --- Training -----------------------------------------------------------
BATCH_SIZE: int = 2048     # rows per DataLoader batch (train and val)
EMBEDDING_DIM: int = 128   # width passed to the model as `embedding_dim`
NUM_EPOCHS: int = 5        # full passes over the training split

In [2]:
import pandas as pd
from sklearn.model_selection import train_test_split


df = pd.read_parquet(PARQUET_PATH)

# Vocabulary sizes are (largest id + 1) so every id in 0..max is a valid index.
# NOTE(review): assumes ids are non-negative integers — confirm in preprocessing.
num_users = df["userid"].max() + 1
num_movies = df["movieid"].max() + 1

# Each row's `genre_list` is a sequence of genre ids; flatten all of them to
# find the largest one.
all_genres = [genre for row_genres in df["genre_list"] for genre in row_genres]
num_genres = max(all_genres) + 1

print(
    "Dataset stats:\n"
    f"  Users: {num_users}\n"
    f"  Movies: {num_movies}\n"
    f"  Genres: {num_genres}\n"
    f"  Samples: {len(df)}"
)

Dataset stats:
  Users: 6041
  Movies: 3953
  Genres: 18
  Samples: 1000209


In [3]:
from torch.utils.data import DataLoader
from utils.collator import collate_fn
from utils.dataset import MovieLensDataset

# Fixed seed so the train/validation partition is identical on every run.
SPLIT_SEED = 42
# Worker processes per DataLoader.
NUM_WORKERS = 4

train_df, val_df = train_test_split(df, test_size=VAL_SPLIT, random_state=SPLIT_SEED)
train_dataset = MovieLensDataset(train_df)
val_dataset = MovieLensDataset(val_df)

# Options shared by both loaders, defined once so they cannot drift apart.
# `persistent_workers` keeps worker processes alive between iterations of a
# loader (i.e. between epochs) instead of re-spawning them each time; it is
# only valid when NUM_WORKERS > 0, hence the guard.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    persistent_workers=NUM_WORKERS > 0,
)

# Only the training batches are shuffled; validation is read in a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

  ALLREDUCE = partial(_ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook)
  FP16_COMPRESS = partial(
  BF16_COMPRESS = partial(
  QUANTIZE_PER_TENSOR = partial(
  QUANTIZE_PER_CHANNEL = partial(
  POWER_SGD = partial(
  POWER_SGD_RANK2 = partial(
  BATCHED_POWER_SGD = partial(
  BATCHED_POWER_SGD_RANK2 = partial(
  NOOP = partial(


In [4]:
import copy

import torch
from arch.dlrm import DLRMRecommender

SEED = 42  # same value used for the train/val split
CHECKPOINT_PATH = "best_dlrm_model.pt"

# Seed torch's default RNG before building the model. This makes weight
# initialisation repeatable. The training loader is built without an explicit
# `generator`, so its shuffle order is also drawn from this RNG when iterated.
# GPU kernels may still be nondeterministic.
torch.manual_seed(SEED)

dlrm_model = DLRMRecommender(
    num_users=num_users,
    num_movies=num_movies,
    num_genres=num_genres,
    embedding_dim=EMBEDDING_DIM,
    dense_arch_layer_sizes=[256, 128],
    over_arch_layer_sizes=[256, 128, 64, 1],
)

# Start below any attainable score so the first finite AUC is always recorded.
best_auc = float("-inf")
best_epoch = None   # 1-based epoch that produced best_auc
best_state = None   # in-memory copy of the best weights

for epoch in range(NUM_EPOCHS):
    train_loss = dlrm_model.train_epoch(train_loader)
    val_metrics = dlrm_model.evaluate(val_loader)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_metrics['loss']:.4f}")
    print(f"  Val AUC: {val_metrics['auc']:.4f}")
    print(f"  Val Accuracy: {val_metrics['accuracy']:.4f}")

    if val_metrics["auc"] > best_auc:
        best_auc = val_metrics["auc"]
        best_epoch = epoch + 1
        torch.save(dlrm_model.model.state_dict(), CHECKPOINT_PATH)
        # Deep copy: state_dict() returns references to the live parameters,
        # which later epochs would otherwise overwrite.
        best_state = copy.deepcopy(dlrm_model.model.state_dict())
        print(f"  ✓ New best model saved! (AUC: {best_auc:.4f})")

    print()

# Validation AUC can peak before the last epoch once the model starts to
# overfit. Roll the in-memory model back to the best weights so later cells
# use the checkpointed model, not the final epoch's.
if best_state is not None:
    dlrm_model.model.load_state_dict(best_state)
    print(f"Restored weights from epoch {best_epoch}.")

print(f"Training complete! Best validation AUC: {best_auc:.4f}")

  0%|          | 0/391 [00:00<?, ?it/s]W1211 15:43:39.290000 30166 torch/fx/_symbolic_trace.py:52] is_fx_tracing will return true for both fx.symbolic_trace and torch.export. Please use is_fx_tracing_symbolic_tracing() for specifically fx.symbolic_trace or torch.compiler.is_compiling() for specifically torch.export/compile.
W1211 15:43:39.290000 30170 torch/fx/_symbolic_trace.py:52] is_fx_tracing will return true for both fx.symbolic_trace and torch.export. Please use is_fx_tracing_symbolic_tracing() for specifically fx.symbolic_trace or torch.compiler.is_compiling() for specifically torch.export/compile.
W1211 15:43:39.290000 30173 torch/fx/_symbolic_trace.py:52] is_fx_tracing will return true for both fx.symbolic_trace and torch.export. Please use is_fx_tracing_symbolic_tracing() for specifically fx.symbolic_trace or torch.compiler.is_compiling() for specifically torch.export/compile.
W1211 15:43:39.290000 30174 torch/fx/_symbolic_trace.py:52] is_fx_tracing will return true for both 

Epoch 1/5
  Train Loss: 0.5632
  Val Loss: 0.5364
  Val AUC: 0.7958
  Val Accuracy: 0.7295
  ✓ New best model saved! (AUC: 0.7958)



100%|██████████| 391/391 [00:09<00:00, 42.46it/s]


Epoch 2/5
  Train Loss: 0.5133
  Val Loss: 0.5309
  Val AUC: 0.8025
  Val Accuracy: 0.7343
  ✓ New best model saved! (AUC: 0.8025)



100%|██████████| 391/391 [00:07<00:00, 53.99it/s]


Epoch 3/5
  Train Loss: 0.4791
  Val Loss: 0.5317
  Val AUC: 0.8087
  Val Accuracy: 0.7378
  ✓ New best model saved! (AUC: 0.8087)



100%|██████████| 391/391 [00:07<00:00, 55.70it/s]


Epoch 4/5
  Train Loss: 0.4348
  Val Loss: 0.5468
  Val AUC: 0.8034
  Val Accuracy: 0.7317



100%|██████████| 391/391 [00:07<00:00, 51.13it/s]


Epoch 5/5
  Train Loss: 0.3705
  Val Loss: 0.6081
  Val AUC: 0.7915
  Val Accuracy: 0.7223

Training complete! Best validation AUC: 0.8087
