## Setup

In [68]:
import logging
import os
import random

import time

from datetime import date
from typing import Dict, Optional

import gin

import torch
import torch.distributed as dist

from generative_recommenders.research.data.eval import (
    _avg,
    add_to_summary_writer,
    eval_metrics_v2_from_tensors,
    get_eval_state,
)

from generative_recommenders.research.data.reco_dataset import get_reco_dataset
from generative_recommenders.research.indexing.utils import get_top_k_module
from generative_recommenders.research.modeling.sequential.autoregressive_losses import (
    BCELoss,
    InBatchNegativesSampler,
    LocalNegativesSampler,
)
from generative_recommenders.research.modeling.sequential.embedding_modules import (
    EmbeddingModule,
    LocalEmbeddingModule,
)
from generative_recommenders.research.modeling.sequential.encoder_utils import (
    get_sequential_encoder,
)
from generative_recommenders.research.modeling.sequential.features import (
    movielens_seq_features_from_row,
)
from generative_recommenders.research.modeling.sequential.input_features_preprocessors import (
    LearnablePositionalEmbeddingInputFeaturesPreprocessor,
)
from generative_recommenders.research.modeling.sequential.losses.sampled_softmax import (
    SampledSoftmaxLoss,
)
from generative_recommenders.research.modeling.sequential.output_postprocessors import (
    L2NormEmbeddingPostprocessor,
    LayerNormEmbeddingPostprocessor,
)
from generative_recommenders.research.modeling.similarity_utils import (
    get_similarity_function,
)
from generative_recommenders.research.trainer.data_loader import create_data_loader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

In [17]:
# Notebook-level configuration. Only a few of these values are consumed by the
# cells visible in this notebook section; the group headings for the remaining
# ones are inferred from the variable names -- TODO confirm.
# --- Process / distributed settings (single process: rank 0 of world size 1) ---
rank: int = 0
world_size: int = 1
master_port: int = 12355
# --- Dataset and batching ---
# dataset_name, max_sequence_length and positional_sampling_ratio are passed to
# get_reco_dataset; local_batch_size is passed to create_data_loader.
dataset_name: str = "ml-1m"
max_sequence_length: int = 200
positional_sampling_ratio: float = 1.0
local_batch_size: int = 128
eval_batch_size: int = 128
eval_user_max_batch_size: Optional[int] = None
# --- Model selection (presumably; not used in the visible cells) ---
main_module: str = "SASRec"
main_module_bf16: bool = False
dropout_rate: float = 0.2
user_embedding_norm: str = "l2_norm"
# --- Loss / negative sampling (presumably; not used in the visible cells) ---
sampling_strategy: str = "in-batch"
loss_module: str = "SampledSoftmaxLoss"
loss_weights: Optional[Dict[str, float]] = {}
num_negatives: int = 1
loss_activation_checkpoint: bool = False
item_l2_norm: bool = False
temperature: float = 0.05
# --- Optimisation schedule (presumably; not used in the visible cells) ---
num_epochs: int = 101
learning_rate: float = 1e-3
num_warmup_steps: int = 0
weight_decay: float = 1e-3
# --- Evaluation / checkpoint cadence (presumably; not used in the visible cells) ---
top_k_method: str = "MIPSBruteForceTopK"
eval_interval: int = 100
full_eval_every_n: int = 1
save_ckpt_every_n: int = 1000
partial_eval_num_iters: int = 32
# --- Embedding / sequence settings ---
embedding_module_type: str = "local"
# Passed as item_embedding_dim when LocalEmbeddingModule is constructed.
item_embedding_dim: int = 240
interaction_module_type: str = ""
# Used as max_output_length = gr_output_length + 1 when building sequence features.
gr_output_length: int = 10
l2_norm_eps: float = 1e-6
enable_tf32: bool = False
# NOTE(review): no seeding call appears in the visible cells -- confirm the seed is applied.
random_seed: int = 42

## Dataset

In [26]:
# Confirm which dataset the following cells will load.
print(dataset_name)

ml-1m


In [27]:
# Build the dataset object for `dataset_name`; later cells read
# `dataset.train_dataset` and `dataset.max_item_id` from it.
dataset = get_reco_dataset(
    dataset_name=dataset_name,
    max_sequence_length=max_sequence_length,
    chronological=True,  # NOTE(review): presumably orders each history by time -- confirm in get_reco_dataset
    positional_sampling_ratio=positional_sampling_ratio,
)

In [28]:
# Create the training sampler/loader pair over dataset.train_dataset
# (the recorded output shows a DistributedSampler and a DataLoader).
train_data_sampler, train_data_loader = create_data_loader(
    dataset.train_dataset,
    batch_size=local_batch_size,
    world_size=world_size,
    rank=rank,
    shuffle=True,
    drop_last=world_size > 1,  # evaluates to False here because world_size == 1
)

In [29]:
# Show the concrete classes returned by create_data_loader, one per line.
print(type(train_data_sampler), type(train_data_loader), sep="\n")

<class 'torch.utils.data.distributed.DistributedSampler'>
<class 'torch.utils.data.dataloader.DataLoader'>


In [None]:
# Pull one batch from the training loader to inspect its structure.
# The loader was created with shuffle=True, so the batch contents may differ between runs.
batch = next(iter(train_data_loader))
print(batch.keys())  # the batch is a dict keyed by field name

dict_keys(['user_id', 'historical_ids', 'historical_ratings', 'historical_timestamps', 'history_lengths', 'target_ids', 'target_ratings', 'target_timestamps'])


In [34]:
# Bind every field of the batch to a same-named variable for inspection below.
(
    user_id,
    historical_ids,
    historical_ratings,
    historical_timestamps,
    history_lengths,
    target_ids,
    target_ratings,
    target_timestamps,
) = (
    batch[field]
    for field in (
        "user_id",
        "historical_ids",
        "historical_ratings",
        "historical_timestamps",
        "history_lengths",
        "target_ids",
        "target_ratings",
        "target_timestamps",
    )
)

In [36]:
# Shape first, then the full tensor of user ids.
print(user_id.shape, user_id, sep="\n")

torch.Size([128])
tensor([5536, 4688,  343,  238, 3595, 5447, 5372,  100, 3076, 5615, 2641, 2955,
        4430, 2527, 3897, 1687, 1807, 2045,  576, 5661, 5854, 3005, 4033,  595,
        2914, 1539, 1366,  866, 3014, 2250, 2656,  232, 3272,  181, 4890, 3320,
        4363, 1465, 2409,  113, 5925, 4348, 5381, 2156, 5380, 1091, 5383, 4141,
        5848, 3424, 2565, 5841, 5050, 2731,  708,  142,  365,  404, 3000, 1471,
        5816,  677, 4309,  290, 1510, 5939, 2843, 2298, 3725,  141, 2496, 1430,
        1741, 2978, 5206, 5820, 2880,  535, 5026,  437,  291, 3708,  983, 4030,
        3499, 2196,  855, 3347, 2229, 2121, 4856, 2485, 4361,    9, 2989, 2524,
        5482, 1160, 4591,  783, 2697, 3810, 4446, 5129, 5956, 1296, 2366, 3101,
        3295, 3665, 1369, 5983, 2546,  884,  797, 4967, 4228, 5587, 5144, 2048,
        3768,  944,  665, 3865, 3891,  720, 3213, 5435])


In [38]:
# Shape of the padded history matrix, then the first user's item-id row.
print(historical_ids.shape, historical_ids[0], sep="\n")

torch.Size([128, 200])
tensor([3528,   31, 3436,  605, 1726, 1609,  277,  458, 2264, 3270, 3437, 2005,
           2, 2021, 1702,  541, 1344, 1748, 2672, 3476, 1258, 1997, 1994, 3499,
        1982, 2455, 1407, 2710, 1332, 2459, 1347, 1339, 1974, 1327,  879, 2004,
        1762,  742, 2121, 2513, 2717,  196, 1970, 2122, 1977, 2315, 1380, 1088,
        1086, 1964, 1459, 1625, 2583, 1092, 1464, 1892, 2752,  695, 2413,  924,
        1247,  377,  802, 1059, 3479, 3501, 1216, 2581, 2340, 1755,  289, 1835,
         249,  163,  461,  236, 2802, 2369, 3394, 2875,  804, 2950, 1556,  750,
        1199, 1240, 2010,  589,  198, 2407, 2527, 1921, 1603,  316,  185,  442,
        3033,  611, 3354, 1037,  172,  256,  880, 2701, 1882,  319,  373, 3147,
        2819, 1834, 2391, 3203, 3102, 3518, 3505, 2118,  832, 3176, 3101, 2058,
        3555, 2707, 2353, 1608,  422,  490, 1722,   16, 3686, 3005, 1422, 2273,
         376, 1518, 3370, 3219, 2881, 1597,  230, 1438, 2956, 1047, 3557, 2803,
         225, 168

In [40]:
# Shape of the ratings matrix, then the first user's ratings row.
print(historical_ratings.shape, historical_ratings[0], sep="\n")

torch.Size([128, 200])
tensor([3, 2, 2, 2, 1, 3, 2, 3, 2, 4, 1, 1, 2, 2, 1, 4, 2, 3, 3, 4, 4, 4, 1, 3,
        3, 1, 3, 4, 2, 2, 2, 1, 3, 3, 3, 1, 2, 1, 1, 3, 1, 2, 3, 1, 1, 1, 3, 2,
        3, 3, 3, 5, 3, 3, 4, 3, 2, 3, 1, 5, 3, 4, 2, 4, 3, 4, 3, 3, 1, 3, 3, 3,
        3, 2, 3, 1, 2, 2, 2, 1, 3, 1, 1, 5, 4, 4, 3, 4, 3, 3, 4, 3, 2, 2, 2, 2,
        2, 1, 4, 2, 1, 2, 2, 2, 1, 3, 4, 4, 5, 3, 4, 3, 3, 2, 4, 3, 3, 4, 3, 3,
        4, 4, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 3, 3, 2, 3, 3, 3, 2,
        3, 3, 2, 3, 3, 3, 3, 1, 2, 1, 3, 3, 3, 3, 2, 4, 3, 2, 3, 4, 3, 1, 2, 4,
        3, 4, 3, 1, 1, 3, 2, 4, 4, 3, 4, 5, 4, 2, 2, 4, 5, 4, 2, 3, 2, 2, 4, 1,
        4, 5, 5, 1, 2, 4, 1, 3])


In [53]:
# Shape of the timestamps matrix, then the row at index 1 (the second user).
print(historical_timestamps.shape, historical_timestamps[1], sep="\n")

torch.Size([128, 200])
tensor([963617959, 963617985, 963617985, 963617985, 963618003, 963618093,
        963618093, 963618093, 963618116, 963618136, 963618179, 963618179,
        963618179, 963618196, 963618216, 963618266, 963618299, 963618331,
        963618350, 963618375, 963618375, 963618414, 963618414, 963618414,
        963618414, 963618414, 963618450, 963618450, 963618495, 963618565,
        963618565, 963618583, 963618583, 963618600, 963618600, 963618648,
        963618648, 963618648, 963618670, 963618670, 963618670, 963618693,
        963618693, 963618729, 963618748, 963618760, 963618781, 963618781,
        963618781, 963618815, 963618815, 963618987, 963619024, 963619071,
        963619071, 963619071, 963619099, 963619099, 963619099, 963619129,
        963619129, 963619129, 963619129, 963619129, 963619129, 963619153,
        963619153, 963619181, 963619181, 963619181, 963619246, 963619246,
        963619246, 963619246, 963619246, 963619246, 963619289, 963619289,
        9636192

In [45]:
# Per-user history lengths: shape, values, and the longest history in the batch.
print(history_lengths.shape)
print(history_lengths)
# Use the tensor's own reduction: builtin max() would iterate the tensor
# element by element in Python. The printed result is still a 0-d tensor.
print(history_lengths.max())

torch.Size([128])
tensor([200, 131, 200, 200,  30, 130,  35,  75,  21, 200, 200, 190, 200,  20,
         20, 200, 200,  24,  56, 124, 200, 110, 200,  99,  71,  26, 142,  35,
        169,  66, 131,  89, 200, 200, 170, 200, 191, 200, 133,  67, 196,  23,
         23, 200,  19,  32,  27, 200, 200,  70, 200, 200,  84, 200,  19,  46,
         46,  49, 105,  26,  31, 200, 131,  32,  32, 110,  38, 119,  25,  22,
        200, 115, 200, 172,  28,  47, 200,  20, 200,  51, 101,  34,  77, 123,
        200,  35, 200,  44,  38, 190,  78, 200, 131, 105, 104,  25, 200, 106,
        200, 118,  33,  66, 200, 109, 200,  47,  40, 200,  73, 200, 200,  27,
        117,  98,  82, 103,  22,  54,  31,  34, 200,  44,  19,  56,  99, 200,
        136,  45])
tensor(200)


In [47]:
# Shape of the target item ids, then the ids themselves.
print(target_ids.shape, target_ids, sep="\n")

torch.Size([128])
tensor([ 224,  194, 1817, 3379, 2976, 2877, 3565, 1208, 1573, 2020, 3428, 3624,
        2798, 2000, 2245, 2537, 2735,   61, 3624,    1, 2460, 3638, 3868, 3831,
        3363, 1938, 1682,  260,  508, 3717, 3510, 2539, 3394, 2329, 1393, 2119,
        1627, 2633,  594, 3895, 2858, 1162, 2915,  393, 3174,   89,  260,  736,
         432, 3510, 3512,  278, 3299, 3422, 3753, 3160,  593, 3498, 3394, 3173,
        1589, 2414, 3826, 2671, 1976, 3534, 3846, 2399, 2959, 2203, 1303, 3534,
         900,  978, 3623, 3361, 2361, 2987,  193, 3062, 3194,  370, 2408, 2529,
        2401, 2918,  574, 1784, 3869, 2300, 1282, 3566,  596, 2294, 1376, 1240,
        3614, 1623,   36,  673, 3793, 3753, 3442, 1358,  469, 1225, 1586, 1394,
        2427, 2502, 3310, 2657, 3555, 1784,  933, 3408, 1121, 2501,  587,  541,
        3000,  608, 1529, 1381, 1962, 2006, 2791,  780])


In [49]:
# Shape of the target ratings, then the ratings themselves.
print(target_ratings.shape, target_ratings, sep="\n")

torch.Size([128])
tensor([3, 4, 2, 5, 2, 4, 2, 2, 5, 4, 4, 1, 1, 5, 4, 4, 3, 5, 2, 5, 2, 3, 5, 4,
        4, 5, 5, 4, 5, 1, 5, 3, 4, 5, 1, 2, 3, 5, 5, 4, 1, 5, 5, 1, 3, 3, 5, 4,
        1, 5, 4, 4, 3, 3, 5, 4, 4, 4, 1, 4, 3, 4, 4, 3, 1, 2, 5, 3, 5, 5, 5, 2,
        4, 4, 3, 2, 5, 3, 1, 5, 4, 3, 5, 5, 5, 5, 3, 4, 3, 4, 4, 3, 3, 4, 3, 3,
        4, 3, 4, 1, 5, 5, 2, 4, 3, 4, 3, 4, 1, 4, 4, 5, 3, 3, 5, 4, 1, 4, 5, 5,
        4, 5, 5, 2, 4, 3, 5, 4])


In [51]:
# Shape of the target timestamps, then the timestamps themselves.
print(target_timestamps.shape, target_timestamps, sep="\n")

torch.Size([128])
tensor([1004157486,  963679456, 1045945209,  976767198,  966630460,  996013330,
         961607747,  977594963,  969851412,  959307449, 1028484729,  971207165,
         965106162,  974058435,  965779263,  974714308,  974742329,  974668725,
         976063550,  958780306,  958346936,  970569489,  966208724,  983261337,
         971763054,  981396825,  974778511,  975277867, 1013919102,  974596713,
         973531630,  976812993,  971259833,  977090747,  962740382,  970861825,
         965186085,  974768388, 1008704143,  977507389,  959813284,  965485145,
         960928975,  974632776,  960393128,  974930258,  960355691,  965353589,
         957786029,  975562172,  996354648, 1044278237,  983492567,  973244412,
         975542365,  980559291,  976405459,  988503219,  970622299,  974753669,
         957915227,  975612817,  967745970,  976564993,  974749071,  957216463,
         972545479,  974500216,  966218328,  977357662,  974435367,  974761580,
         974712244,  9

In [55]:
# Use the process rank as the device index; tensors later created with it
# report device='cuda:0' when rank is 0. No model is created in this cell.
device = rank

In [56]:
# Convert the raw batch into sequence features placed on `device`.
# NOTE: this rebinds target_ids / target_ratings, replacing the 1-D tensors
# read from the batch earlier; the recorded output shows them as (128, 1) on cuda:0.
# max_output_length is gr_output_length + 1 (= 11 with the configuration above).
# NOTE(review): past_ids is later printed with width 211 = 200 + 11, so the helper
# presumably extends the padded history by max_output_length slots -- confirm.
seq_features, target_ids, target_ratings = movielens_seq_features_from_row(
    batch,
    device=device,
    max_output_length=gr_output_length + 1,
)

In [61]:
# Shapes and first five rows of the rebound target tensors, one item per line.
print(
    target_ids.shape,
    target_ids[:5],
    target_ratings.shape,
    target_ratings[:5],
    sep="\n",
)

torch.Size([128, 1])
tensor([[ 224],
        [ 194],
        [1817],
        [3379],
        [2976]], device='cuda:0')
torch.Size([128, 1])
tensor([[3],
        [4],
        [2],
        [5],
        [2]], device='cuda:0')


In [64]:
# Inspect the shapes of the fields held by the sequence-feature container.
print(seq_features.past_lengths.shape)
print(seq_features.past_ids.shape)
# Compare against None explicitly: truth-testing a tensor with more than one
# element raises RuntimeError, so `x.shape if x else None` only works while
# past_embeddings happens to be None.
print(
    seq_features.past_embeddings.shape
    if seq_features.past_embeddings is not None
    else None
)
print(seq_features.past_payloads["timestamps"].shape)
print(seq_features.past_payloads["ratings"].shape)

torch.Size([128])
torch.Size([128, 211])
None
torch.Size([128, 211])
torch.Size([128, 211])


## Embedding

In [66]:
# Largest item id in the dataset and the configured embedding width.
print(dataset.max_item_id, item_embedding_dim, sep="\n")

3952
240


In [70]:
# Instantiate the item-embedding module. With max_item_id = 3952 the logged
# weight shape is (3953, 240), i.e. one more row than the num_items passed in.
embedding_module: EmbeddingModule = LocalEmbeddingModule(
    num_items=dataset.max_item_id,
    item_embedding_dim=item_embedding_dim
)

Initialize _item_emb.weight as truncated normal: torch.Size([3953, 240]) params


In [71]:
# List every parameter registered on the embedding module together with its shape.
for name, params in embedding_module.named_parameters():
    print(f"{name}: {params.shape}")

_item_emb.weight: torch.Size([3953, 240])


In [74]:
# Unpack the two dimensions of past_ids and display them, one per line.
B, N = seq_features.past_ids.size()
print(B, N, sep="\n")

128
211


In [75]:
# In-place scatter along dim=1: for each row i, writes target_ids[i] into
# column past_lengths[i] of past_ids. The call returns the modified tensor,
# which the notebook displays as the cell result.
# NOTE(review): presumably this places the target right after the last history
# item; it requires past_lengths[i] < past_ids.shape[1] for every row -- confirm.
seq_features.past_ids.scatter_(
    dim=1,
    index=seq_features.past_lengths.view(-1, 1),
    src=target_ids.view(-1, 1),
)

tensor([[3528,   31, 3436,  ...,    0,    0,    0],
        [1711, 1254, 1196,  ...,    0,    0,    0],
        [1278, 2394,  899,  ...,    0,    0,    0],
        ...,
        [1201, 1222, 2366,  ...,    0,    0,    0],
        [ 480, 1777, 2359,  ...,    0,    0,    0],
        [1358, 1230,  480,  ...,    0,    0,    0]], device='cuda:0')