In [1]:
import logging
import os
import random

import time

from datetime import date
from typing import Dict, Optional

import gin

import torch
import torch.distributed as dist

from generative_recommenders.research.data.eval import (
    _avg,
    add_to_summary_writer,
    eval_metrics_v2_from_tensors,
    get_eval_state,
)

from generative_recommenders.research.data.reco_dataset import get_reco_dataset
from generative_recommenders.research.indexing.utils import get_top_k_module
from generative_recommenders.research.modeling.sequential.autoregressive_losses import (
    BCELoss,
    InBatchNegativesSampler,
    LocalNegativesSampler,
)
from generative_recommenders.research.modeling.sequential.embedding_modules import (
    EmbeddingModule,
    LocalEmbeddingModule,
)
from generative_recommenders.research.modeling.sequential.encoder_utils import (
    get_sequential_encoder,
)
from generative_recommenders.research.modeling.sequential.features import (
    movielens_seq_features_from_row,
)
from generative_recommenders.research.modeling.sequential.input_features_preprocessors import (
    LearnablePositionalEmbeddingInputFeaturesPreprocessor,
)
from generative_recommenders.research.modeling.sequential.losses.sampled_softmax import (
    SampledSoftmaxLoss,
)
from generative_recommenders.research.modeling.sequential.output_postprocessors import (
    L2NormEmbeddingPostprocessor,
    LayerNormEmbeddingPostprocessor,
)
from generative_recommenders.research.modeling.similarity_utils import (
    get_similarity_function,
)
from generative_recommenders.research.trainer.data_loader import create_data_loader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

In [4]:
rank: int = 0
world_size: int = 1
master_port: int = 12355
dataset_name: str = "ml-1m"
max_sequence_length: int = 200
positional_sampling_ratio: float = 1.0
local_batch_size: int = 128
eval_batch_size: int = 128
eval_user_max_batch_size: Optional[int] = None
main_module: str = "SASRec"
main_module_bf16: bool = False
dropout_rate: float = 0.2
user_embedding_norm: str = "l2_norm"
sampling_strategy: str = "in-batch"
loss_module: str = "SampledSoftmaxLoss"
loss_weights: Optional[Dict[str, float]] = {}
num_negatives: int = 1
loss_activation_checkpoint: bool = False
item_l2_norm: bool = False
temperature: float = 0.05
num_epochs: int = 101
learning_rate: float = 1e-3
num_warmup_steps: int = 0
weight_decay: float = 1e-3
top_k_method: str = "MIPSBruteForceTopK"
eval_interval: int = 100
full_eval_every_n: int = 1
save_ckpt_every_n: int = 1000
partial_eval_num_iters: int = 32
embedding_module_type: str = "local"
item_embedding_dim: int = 240
interaction_module_type: str = ""
gr_output_length: int = 10
l2_norm_eps: float = 1e-6
enable_tf32: bool = False
random_seed: int = 42

In [5]:
print(dataset_name)

ml-1m


In [6]:
dataset = get_reco_dataset(
    dataset_name=dataset_name,
    max_sequence_length=max_sequence_length,
    chronological=True,
    positional_sampling_ratio=positional_sampling_ratio,
)

In [7]:
train_data_sampler, train_data_loader = create_data_loader(
    dataset.train_dataset,
    batch_size=local_batch_size,
    world_size=world_size,
    rank=rank,
    shuffle=True,
    drop_last=world_size > 1,
)

In [8]:
print(type(train_data_sampler))
print(type(train_data_loader))

<class 'torch.utils.data.distributed.DistributedSampler'>
<class 'torch.utils.data.dataloader.DataLoader'>


In [19]:
for i, batch in enumerate(train_data_loader):
    print(f"batch {i + 1} data type: {type(batch)}")
    print(f"batch {i + 1} data shape: {batch.shape if isinstance(batch, torch.Tensor) else None}")
    
    print(f"batch {i + 1}")
    for key in batch:
        print(f"  key: {key}")
        print(f"    data type: {type(batch[key])}")
        print(f"    data shape: {batch[key].shape if isinstance(batch[key], torch.Tensor) else None}")
        print(f"    data: {batch[key]}")

    if i >= 0:
        break

batch 1 data type: <class 'dict'>
batch 1 data shape: None
batch 1
  key: user_id
    data type: <class 'torch.Tensor'>
    data shape: torch.Size([128])
    data: tensor([5536, 4688,  343,  238, 3595, 5447, 5372,  100, 3076, 5615, 2641, 2955,
        4430, 2527, 3897, 1687, 1807, 2045,  576, 5661, 5854, 3005, 4033,  595,
        2914, 1539, 1366,  866, 3014, 2250, 2656,  232, 3272,  181, 4890, 3320,
        4363, 1465, 2409,  113, 5925, 4348, 5381, 2156, 5380, 1091, 5383, 4141,
        5848, 3424, 2565, 5841, 5050, 2731,  708,  142,  365,  404, 3000, 1471,
        5816,  677, 4309,  290, 1510, 5939, 2843, 2298, 3725,  141, 2496, 1430,
        1741, 2978, 5206, 5820, 2880,  535, 5026,  437,  291, 3708,  983, 4030,
        3499, 2196,  855, 3347, 2229, 2121, 4856, 2485, 4361,    9, 2989, 2524,
        5482, 1160, 4591,  783, 2697, 3810, 4446, 5129, 5956, 1296, 2366, 3101,
        3295, 3665, 1369, 5983, 2546,  884,  797, 4967, 4228, 5587, 5144, 2048,
        3768,  944,  665, 3865, 3891