# 1. Set up

In [117]:
import logging
import os
import random

import time

from datetime import date
from typing import Dict, Optional

import gin

import torch
import torch.distributed as dist

from generative_recommenders.research.data.eval import (
    _avg,
    add_to_summary_writer,
    eval_metrics_v2_from_tensors,
    get_eval_state,
)

from generative_recommenders.research.data.reco_dataset import get_reco_dataset
from generative_recommenders.research.indexing.utils import get_top_k_module
from generative_recommenders.research.modeling.sequential.autoregressive_losses import (
    BCELoss,
    InBatchNegativesSampler,
    LocalNegativesSampler,
)
from generative_recommenders.research.modeling.sequential.embedding_modules import (
    EmbeddingModule,
    LocalEmbeddingModule,
)
from generative_recommenders.research.modeling.sequential.encoder_utils import (
    get_sequential_encoder,
)
from generative_recommenders.research.modeling.sequential.features import (
    movielens_seq_features_from_row,
)
from generative_recommenders.research.modeling.sequential.input_features_preprocessors import (
    LearnablePositionalEmbeddingInputFeaturesPreprocessor,
)
from generative_recommenders.research.modeling.sequential.losses.sampled_softmax import (
    SampledSoftmaxLoss,
)
from generative_recommenders.research.modeling.sequential.output_postprocessors import (
    L2NormEmbeddingPostprocessor,
    LayerNormEmbeddingPostprocessor,
)
from generative_recommenders.research.modeling.similarity_utils import (
    get_similarity_function,
)
from generative_recommenders.research.trainer.data_loader import create_data_loader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

In [118]:
# Run configuration for the walkthrough. Every value below is a plain
# module-level name that later cells read directly.

# Process placement (rank is also reused as the device index further down).
rank: int = 0
world_size: int = 1
master_port: int = 12355

# Dataset and batching.
dataset_name: str = "ml-1m"
max_sequence_length: int = 200
positional_sampling_ratio: float = 1.0
local_batch_size: int = 128
eval_batch_size: int = 128
eval_user_max_batch_size: Optional[int] = None

# Model, loss and optimisation settings.
main_module: str = "HSTU"
main_module_bf16: bool = False
dropout_rate: float = 0.2
user_embedding_norm: str = "l2_norm"
sampling_strategy: str = "in-batch"
loss_module: str = "SampledSoftmaxLoss"
loss_weights: Optional[Dict[str, float]] = {}
num_negatives: int = 1
loss_activation_checkpoint: bool = False
item_l2_norm: bool = False
temperature: float = 0.05
num_epochs: int = 101
learning_rate: float = 1e-3
num_warmup_steps: int = 0
weight_decay: float = 1e-3

# Evaluation / checkpoint settings.
top_k_method: str = "MIPSBruteForceTopK"
eval_interval: int = 100
full_eval_every_n: int = 1
save_ckpt_every_n: int = 1000
partial_eval_num_iters: int = 32

# Embedding and sequence-shape settings.
embedding_module_type: str = "local"
item_embedding_dim: int = 240
interaction_module_type: str = ""
gr_output_length: int = 10
l2_norm_eps: float = 1e-6
enable_tf32: bool = False
random_seed: int = 42

# Apply the seed immediately: later cells initialise embedding weights and run
# dropout, so without this the printed values differ on every kernel restart.
random.seed(random_seed)
torch.manual_seed(random_seed)

# 2. Dataset

In [119]:
# Echo the dataset identifier chosen in the configuration cell.
print(dataset_name)

ml-1m


In [120]:
# Build the dataset object from the configured name, sequence length and
# positional sampling ratio; chronological ordering is requested explicitly.
dataset = get_reco_dataset(
    dataset_name=dataset_name,
    chronological=True,
    max_sequence_length=max_sequence_length,
    positional_sampling_ratio=positional_sampling_ratio,
)

In [121]:
# The trailing partial batch is dropped only when more than one process runs.
drop_last_batch = world_size > 1

# create_data_loader returns the sampler together with the loader built on it.
train_data_sampler, train_data_loader = create_data_loader(
    dataset.train_dataset,
    batch_size=local_batch_size,
    rank=rank,
    world_size=world_size,
    shuffle=True,
    drop_last=drop_last_batch,
)

In [122]:
# Show the concrete classes of the sampler and the loader, one per line.
print(type(train_data_sampler), type(train_data_loader), sep="\n")

<class 'torch.utils.data.distributed.DistributedSampler'>
<class 'torch.utils.data.dataloader.DataLoader'>


In [123]:
# Pull a single mini-batch from the training loader; the temporary iterator is
# not kept around.
batch = next(iter(train_data_loader))
# The batch is a mapping; list the field names it carries.
print(batch.keys())

dict_keys(['user_id', 'historical_ids', 'historical_ratings', 'historical_timestamps', 'history_lengths', 'target_ids', 'target_ratings', 'target_timestamps'])


In [124]:
# Bind every field of the batch to its own name for the inspection cells below.
(
    user_id,
    historical_ids,
    historical_ratings,
    historical_timestamps,
    history_lengths,
    target_ids,
    target_ratings,
    target_timestamps,
) = (
    batch[field]
    for field in (
        "user_id",
        "historical_ids",
        "historical_ratings",
        "historical_timestamps",
        "history_lengths",
        "target_ids",
        "target_ratings",
        "target_timestamps",
    )
)

In [125]:
# Shape first, then the user ids of the batch.
print(user_id.shape, user_id, sep="\n")

torch.Size([128])
tensor([5536, 4688,  343,  238, 3595, 5447, 5372,  100, 3076, 5615, 2641, 2955,
        4430, 2527, 3897, 1687, 1807, 2045,  576, 5661, 5854, 3005, 4033,  595,
        2914, 1539, 1366,  866, 3014, 2250, 2656,  232, 3272,  181, 4890, 3320,
        4363, 1465, 2409,  113, 5925, 4348, 5381, 2156, 5380, 1091, 5383, 4141,
        5848, 3424, 2565, 5841, 5050, 2731,  708,  142,  365,  404, 3000, 1471,
        5816,  677, 4309,  290, 1510, 5939, 2843, 2298, 3725,  141, 2496, 1430,
        1741, 2978, 5206, 5820, 2880,  535, 5026,  437,  291, 3708,  983, 4030,
        3499, 2196,  855, 3347, 2229, 2121, 4856, 2485, 4361,    9, 2989, 2524,
        5482, 1160, 4591,  783, 2697, 3810, 4446, 5129, 5956, 1296, 2366, 3101,
        3295, 3665, 1369, 5983, 2546,  884,  797, 4967, 4228, 5587, 5144, 2048,
        3768,  944,  665, 3865, 3891,  720, 3213, 5435])


In [126]:
# Shape of the id history, followed by the first row as a sample.
print(historical_ids.shape, historical_ids[0], sep="\n")

torch.Size([128, 200])
tensor([2676, 3528,   31, 3436,  605, 1726, 1609,  277,  458, 2264, 3270, 3437,
        2005,    2, 2021, 1702,  541, 1344, 1748, 2672, 3476, 1258, 1997, 1994,
        3499, 1982, 2455, 1407, 2710, 1332, 2459, 1347, 1339, 1974, 1327,  879,
        2004, 1762,  742, 2121, 2513, 2717,  196, 1970, 2122, 1977, 2315, 1380,
        1088, 1086, 1964, 1459, 1625, 2583, 1092, 1464, 1892, 2752,  695, 2413,
         924, 1247,  377,  802, 1059, 3479, 3501, 1216, 2581, 2340, 1755,  289,
        1835,  249,  163,  461,  236, 2802, 2369, 3394, 2875,  804, 2950, 1556,
         750, 1199, 1240, 2010,  589,  198, 2407, 2527, 1921, 1603,  316,  185,
         442, 3033,  611, 3354, 1037,  172,  256,  880, 2701, 1882,  319,  373,
        3147, 2819, 1834, 2391, 3203, 3102, 3518, 3505, 2118,  832, 3176, 3101,
        2058, 3555, 2707, 2353, 1608,  422,  490, 1722,   16, 3686, 3005, 1422,
        2273,  376, 1518, 3370, 3219, 2881, 1597,  230, 1438, 2956, 1047, 3557,
        2803,  22

In [127]:
# Shape of the rating history, followed by the first row as a sample.
print(historical_ratings.shape, historical_ratings[0], sep="\n")

torch.Size([128, 200])
tensor([3, 3, 2, 2, 2, 1, 3, 2, 3, 2, 4, 1, 1, 2, 2, 1, 4, 2, 3, 3, 4, 4, 4, 1,
        3, 3, 1, 3, 4, 2, 2, 2, 1, 3, 3, 3, 1, 2, 1, 1, 3, 1, 2, 3, 1, 1, 1, 3,
        2, 3, 3, 3, 5, 3, 3, 4, 3, 2, 3, 1, 5, 3, 4, 2, 4, 3, 4, 3, 3, 1, 3, 3,
        3, 3, 2, 3, 1, 2, 2, 2, 1, 3, 1, 1, 5, 4, 4, 3, 4, 3, 3, 4, 3, 2, 2, 2,
        2, 2, 1, 4, 2, 1, 2, 2, 2, 1, 3, 4, 4, 5, 3, 4, 3, 3, 2, 4, 3, 3, 4, 3,
        3, 4, 4, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 3, 3, 2, 3, 3, 3,
        2, 3, 3, 2, 3, 3, 3, 3, 1, 2, 1, 3, 3, 3, 3, 2, 4, 3, 2, 3, 4, 3, 1, 2,
        4, 3, 4, 3, 1, 1, 3, 2, 4, 4, 3, 4, 5, 4, 2, 2, 4, 5, 4, 2, 3, 2, 2, 4,
        1, 4, 5, 5, 1, 2, 4, 1])


In [128]:
# Shape of the timestamp history, followed by one sample row.
# NOTE(review): this prints row 1, whereas the id and rating cells print row 0.
print(historical_timestamps.shape, historical_timestamps[1], sep="\n")

torch.Size([128, 200])
tensor([963617959, 963617985, 963617985, 963617985, 963618003, 963618093,
        963618093, 963618093, 963618116, 963618136, 963618179, 963618179,
        963618179, 963618196, 963618216, 963618266, 963618299, 963618331,
        963618350, 963618375, 963618375, 963618414, 963618414, 963618414,
        963618414, 963618414, 963618450, 963618450, 963618495, 963618565,
        963618565, 963618583, 963618583, 963618600, 963618600, 963618648,
        963618648, 963618648, 963618670, 963618670, 963618670, 963618693,
        963618693, 963618729, 963618748, 963618760, 963618781, 963618781,
        963618781, 963618815, 963618815, 963618987, 963619024, 963619071,
        963619071, 963619071, 963619099, 963619099, 963619099, 963619129,
        963619129, 963619129, 963619129, 963619129, 963619129, 963619153,
        963619153, 963619181, 963619181, 963619181, 963619246, 963619246,
        963619246, 963619246, 963619246, 963619246, 963619289, 963619289,
        9636192

In [129]:
# One length value per batch row.
print(history_lengths.shape)
print(history_lengths)
# Tensor.max() reduces in a single tensor op; the builtin max() would walk the
# tensor element by element in Python. The printed result is the same
# zero-dimensional tensor.
print(history_lengths.max())

torch.Size([128])
tensor([200, 130, 200, 200,  29, 129,  34,  74,  20, 200, 200, 189, 200,  19,
         19, 200, 200,  23,  55, 123, 200, 109, 200,  98,  70,  25, 141,  34,
        168,  65, 130,  88, 200, 200, 169, 200, 190, 200, 132,  66, 195,  22,
         22, 200,  18,  31,  26, 200, 200,  69, 200, 200,  83, 200,  18,  45,
         45,  48, 104,  25,  30, 200, 130,  31,  31, 109,  37, 118,  24,  21,
        200, 114, 200, 171,  27,  46, 200,  19, 200,  50, 100,  33,  76, 122,
        200,  34, 200,  43,  37, 189,  77, 200, 130, 104, 103,  24, 200, 105,
        200, 117,  32,  65, 200, 108, 200,  46,  39, 200,  72, 200, 200,  26,
        116,  97,  81, 102,  21,  53,  30,  33, 200,  43,  18,  55,  98, 200,
        135,  44])
tensor(200)


In [130]:
# Shape first, then the target item ids of the batch.
print(target_ids.shape, target_ids, sep="\n")

torch.Size([128])
tensor([ 423, 1639, 2498, 3420, 3005, 3604, 3317,  265, 1722, 1242, 2374, 3328,
        2953, 1287, 2359, 2013,  208, 2881, 1956, 2355, 1980, 1371, 3869,  454,
        1923,  898, 2739, 3510, 2431, 3745, 2018, 2567, 1606, 3246,  515, 1326,
         457, 1928, 2067, 3883, 1230, 3811, 2791, 1681, 2836,  648,  750, 1552,
         303, 2706, 3831,  921, 3285, 2133, 3675, 2761, 2762, 3159, 2950,  268,
        1092, 3230, 3186, 1188, 2456, 3481, 2668,    3, 3386, 3469, 1278, 1251,
         417, 1256, 1954, 1265, 1235, 3114, 3649, 3836, 1936, 3846, 3701, 2144,
        2121, 3300, 2245, 3006, 3263, 1207, 2700, 3115, 2761,  367, 1371, 1219,
        3686, 2714,  549, 2141, 2070, 2922,  181,  377, 3906, 1307,  368, 1243,
         587, 2124, 2158, 2528, 3189,   34,  969, 3821,  593,  337, 1088, 1252,
        3442,  318, 1219, 1081, 2928,  733, 1231,   24])


In [131]:
# Shape first, then the target ratings of the batch.
print(target_ratings.shape, target_ratings, sep="\n")

torch.Size([128])
tensor([3, 5, 3, 4, 3, 5, 3, 3, 3, 4, 2, 4, 1, 5, 5, 5, 2, 3, 4, 3, 2, 2, 4, 5,
        5, 4, 5, 5, 3, 4, 5, 1, 3, 5, 4, 2, 5, 4, 5, 1, 4, 5, 4, 2, 3, 3, 5, 3,
        2, 4, 4, 3, 3, 4, 4, 4, 5, 4, 1, 4, 3, 4, 4, 2, 3, 5, 2, 2, 2, 5, 5, 5,
        4, 4, 4, 2, 5, 5, 2, 5, 4, 3, 3, 5, 3, 4, 4, 5, 3, 4, 3, 4, 3, 3, 1, 4,
        4, 4, 3, 3, 5, 4, 1, 4, 3, 5, 2, 5, 4, 4, 5, 4, 4, 5, 5, 3, 4, 4, 5, 5,
        2, 5, 5, 2, 3, 3, 5, 4])


In [132]:
# Shape first, then the target timestamps of the batch.
print(target_timestamps.shape, target_timestamps, sep="\n")

torch.Size([128])
tensor([1003623343,  963679419, 1044805014,  976767198,  966630460,  996013330,
         961607731,  977594963,  969851412,  959307449, 1028484651,  971207152,
         965106138,  974058435,  965779227,  974714308,  974742329,  974668616,
         976063418,  958780306,  958346913,  970569489,  966208707,  975898200,
         971763054,  981396825,  974778476,  975277824, 1013919071,  974596665,
         973531539,  976812980,  971259833,  977090726,  962740343,  970861801,
         965186085,  974768388, 1008702921,  977507339,  959186558,  965485055,
         960928975,  974632724,  960393128,  974930242,  960355691,  965353589,
         957786013,  967359392,  996354386, 1037881424,  983492534,  973244412,
         975542300,  980559291,  976405459,  985827325,  970622299,  974753567,
         957915227,  975612817,  967745654,  976564993,  974749071,  957216437,
         972545433,  974500185,  966218328,  977357662,  974435324,  974761580,
         974712244,  9

In [133]:
# Reuse the process rank as the device index. No model is created here; later
# cells pass this value as `device=device` or to `.to(device)`.
device: int = rank

In [134]:
# Convert the raw batch into sequence features on `device`.
# Note: this rebinds `target_ids` and `target_ratings`, replacing the tensors
# that were unpacked straight from the batch earlier.
seq_features, target_ids, target_ratings = movielens_seq_features_from_row(
    batch,
    max_output_length=gr_output_length + 1,
    device=device,
)

In [135]:
# Shapes and the first five rows of the rebound target tensors.
print(
    target_ids.shape,
    target_ids[:5],
    target_ratings.shape,
    target_ratings[:5],
    sep="\n",
)

torch.Size([128, 1])
tensor([[ 423],
        [1639],
        [2498],
        [3420],
        [3005]], device='cuda:0')
torch.Size([128, 1])
tensor([[3],
        [5],
        [3],
        [4],
        [3]], device='cuda:0')


In [136]:
# Shapes of the individual fields carried by seq_features.
print(seq_features.past_lengths.shape)
print(seq_features.past_ids.shape)
# Compare against None explicitly: truth-testing a tensor with more than one
# element raises "Boolean value of Tensor ... is ambiguous", so the previous
# `if seq_features.past_embeddings` form only worked while the field was None.
print(
    seq_features.past_embeddings.shape
    if seq_features.past_embeddings is not None
    else None
)
print(seq_features.past_payloads["timestamps"].shape)
print(seq_features.past_payloads["ratings"].shape)

torch.Size([128])
torch.Size([128, 211])
None
torch.Size([128, 211])
torch.Size([128, 211])


# 3. Embedding

In [145]:
# The two numbers that size the item embedding table.
print(dataset.max_item_id, item_embedding_dim, sep="\n")

3952
240


In [146]:
# Construct the local item-embedding module, then move it onto `device`.
embedding_module: EmbeddingModule = LocalEmbeddingModule(
    item_embedding_dim=item_embedding_dim,
    num_items=dataset.max_item_id,
)
embedding_module = embedding_module.to(device)

Initialize _item_emb.weight as truncated normal: torch.Size([3953, 240]) params


In [147]:
# Display the module's debug string (rendered as the cell output).
embedding_module.debug_str()

'local_emb_d240'

In [148]:
# List every registered parameter together with its shape.
for name, params in embedding_module.named_parameters():
    print(name, params.shape, sep=": ")

_item_emb.weight: torch.Size([3953, 240])


In [149]:
# past_ids is two-dimensional: B rows, N positions per row.
B, N = seq_features.past_ids.size()
print(B, N, sep="\n")

128
211


In [150]:
# In place: for every row, write its target id at column index `past_lengths`
# of past_ids. The call returns the mutated tensor, which the cell displays.
seq_features.past_ids.scatter_(
    1,
    seq_features.past_lengths.view(-1, 1),
    target_ids.view(-1, 1),
)

tensor([[2676, 3528,   31,  ...,    0,    0,    0],
        [1711, 1254, 1196,  ...,    0,    0,    0],
        [1136, 1278, 2394,  ...,    0,    0,    0],
        ...,
        [2000, 1201, 1222,  ...,    0,    0,    0],
        [ 480, 1777, 2359,  ...,    0,    0,    0],
        [1358, 1230,  480,  ...,    0,    0,    0]], device='cuda:0')

In [151]:
# Look up the item embedding for every id in past_ids (including the target ids
# scattered in by the previous cell).
input_embeddings = embedding_module.get_item_embeddings(seq_features.past_ids)

In [152]:
# Layout is batch x seq x dim.
print(input_embeddings.size())

torch.Size([128, 211, 240])


# 4. Model

## 4.1. Preprocessing

In [154]:
# Values fed to the input preprocessor, in order:
#   - dataset.max_sequence_length and gr_output_length -> max_sequence_len
#   - item_embedding_dim                               -> embedding_dim
#   - dropout_rate                                     -> dropout_rate
print(
    dataset.max_sequence_length,
    gr_output_length,
    item_embedding_dim,
    dropout_rate,
    sep="\n",
)

200
10
240
0.2


In [158]:
# Positional-embedding preprocessor: history length plus the output positions
# plus one extra slot gives the longest sequence it must handle.
input_preproc_module = LearnablePositionalEmbeddingInputFeaturesPreprocessor(
    embedding_dim=item_embedding_dim,
    dropout_rate=dropout_rate,
    max_sequence_len=dataset.max_sequence_length + gr_output_length + 1,
)
input_preproc_module = input_preproc_module.to(device)

# Display the module's debug string as the cell output.
input_preproc_module.debug_str()

'posi_d0.2'

In [159]:
# Run the preprocessor on the sequence features and the looked-up embeddings.
# It returns three values; the third one is not used here.
past_lengths, user_embeddings, _ = input_preproc_module(
    past_ids=seq_features.past_ids,
    past_lengths=seq_features.past_lengths,
    past_payloads=seq_features.past_payloads,
    past_embeddings=input_embeddings,
)

In [161]:
# Shapes of the preprocessor outputs.
print(past_lengths.shape, user_embeddings.shape, sep="\n")

torch.Size([128])
torch.Size([128, 211, 240])


## 4.2. Forward