In [None]:
!pip install cudf-cu11==22.12 rmm-cu11==22.12 --extra-index-url=https://pypi.ngc.nvidia.com
!pip install cugraph-cu11==22.12 dask-cuda==22.12 dask-cudf-cu11==22.12  pylibcugraph-cu11==22.12 --extra-index-url=https://pypi.ngc.nvidia.com/
!pip install cuml-cu11==22.12 raft_dask_cu11==22.12 dask-cudf-cu11==22.12  pylibraft_cu11==22.12 ucx-py-cu11==0.29.0 --extra-index-url=https://pypi.ngc.nvidia.com

In [None]:
# Install the Merlin Framework
!pip install -U git+https://github.com/NVIDIA-Merlin/models.git@release-23.02
!pip install -U git+https://github.com/NVIDIA-Merlin/nvtabular.git@release-23.02
!pip install -U git+https://github.com/NVIDIA-Merlin/core.git@release-23.02
!pip install -U git+https://github.com/NVIDIA-Merlin/system.git@release-23.02
!pip install -U git+https://github.com/NVIDIA-Merlin/dataloader.git@release-23.02
!pip install -U git+https://github.com/NVIDIA-Merlin/Transformers4Rec.git@release-23.02
!pip install -U xgboost lightfm implicit

In [101]:
from google.colab import drive

# Mount Google Drive so the data files under MyDrive can be read.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [102]:
import os
import numpy as np 
import gc
import shutil
import glob

import cudf
import nvtabular as nvt
from merlin.dag import ColumnSelector
from merlin.schema import Schema, Tags
import pandas as pd 

In [103]:
# Drive folder holding the raw event CSV. The trailing "" component keeps the trailing slash.
INPUT_DATA_DIR = os.path.join("/content/drive/MyDrive", "ds_data_v3", "")

In [115]:
# Raw 4-day event extract, loaded onto the GPU.
RAW_EVENTS_CSV = os.path.join(INPUT_DATA_DIR, 'cds_2023_march_4days.csv')
raw_df = cudf.read_csv(RAW_EVENTS_CSV)
raw_df.head()

Unnamed: 0,event_time,event_type,user_id,user_session,product_id,quantity,price,brand,category_id,category_code
0,2023-03-02 14:47:11.325028 UTC,view_item,150FEA0F36CF46FA840956CAF974CC0B,150FEA0F36CF46FA840956CAF974CC0B1677768077,CDS88856872,1,1590.0,APPLE,3987,mobile-tablet-accessories
1,2023-03-02 14:22:30.083014 UTC,view_item,38C7CC4F7EC94CEA83CAB02B89F720DB,38C7CC4F7EC94CEA83CAB02B89F720DB1677765651,MKP0690480,1,629.0,XIAOMI,3200,scale
2,2023-03-02 14:18:14.010002 UTC,view_item,38C7CC4F7EC94CEA83CAB02B89F720DB,38C7CC4F7EC94CEA83CAB02B89F720DB1677765651,CDS86195966,1,490.0,SHAPER,3200,scale
3,2023-03-02 14:23:47.644015 UTC,view_item,38C7CC4F7EC94CEA83CAB02B89F720DB,38C7CC4F7EC94CEA83CAB02B89F720DB1677765651,MKP0690480,1,629.0,XIAOMI,3200,scale
4,2023-03-01 22:04:47.048 UTC,view_item,AE731397442F48249A00B00D76DB9328,AE731397442F48249A00B00D76DB93281677708277,MKP1233067,1,990.0,REMAX,3987,mobile-tablet-accessories


In [116]:
# (rows, columns) of the raw event log
(len(raw_df), len(raw_df.columns))

(1033765, 10)

In [109]:
# Disabled: would keep only add_to_cart / purchase events.
# NOTE: the (48961, 10) output below is stale, left over from a run where this filter was enabled.
#raw_df = raw_df.loc[raw_df["event_type"].isin(["add_to_cart","purchase"])]
#raw_df.shape

(48961, 10)

In [117]:
raw_df.head(5)

Unnamed: 0,event_time,event_type,user_id,user_session,product_id,quantity,price,brand,category_id,category_code
0,2023-03-02 14:47:11.325028 UTC,view_item,150FEA0F36CF46FA840956CAF974CC0B,150FEA0F36CF46FA840956CAF974CC0B1677768077,CDS88856872,1,1590.0,APPLE,3987,mobile-tablet-accessories
1,2023-03-02 14:22:30.083014 UTC,view_item,38C7CC4F7EC94CEA83CAB02B89F720DB,38C7CC4F7EC94CEA83CAB02B89F720DB1677765651,MKP0690480,1,629.0,XIAOMI,3200,scale
2,2023-03-02 14:18:14.010002 UTC,view_item,38C7CC4F7EC94CEA83CAB02B89F720DB,38C7CC4F7EC94CEA83CAB02B89F720DB1677765651,CDS86195966,1,490.0,SHAPER,3200,scale
3,2023-03-02 14:23:47.644015 UTC,view_item,38C7CC4F7EC94CEA83CAB02B89F720DB,38C7CC4F7EC94CEA83CAB02B89F720DB1677765651,MKP0690480,1,629.0,XIAOMI,3200,scale
4,2023-03-01 22:04:47.048 UTC,view_item,AE731397442F48249A00B00D76DB9328,AE731397442F48249A00B00D76DB93281677708277,MKP1233067,1,990.0,REMAX,3987,mobile-tablet-accessories


In [118]:
# Parse event_time strings at second resolution and store the same instant as
# integer seconds since the epoch.
event_time_parsed = raw_df['event_time'].astype('datetime64[s]')
raw_df['event_time_dt'] = event_time_parsed
raw_df['event_time_ts'] = event_time_parsed.astype('int')
del event_time_parsed
raw_df.head()

Unnamed: 0,event_time,event_type,user_id,user_session,product_id,quantity,price,brand,category_id,category_code,event_time_dt,event_time_ts
0,2023-03-02 14:47:11.325028 UTC,view_item,150FEA0F36CF46FA840956CAF974CC0B,150FEA0F36CF46FA840956CAF974CC0B1677768077,CDS88856872,1,1590.0,APPLE,3987,mobile-tablet-accessories,2023-03-02 14:47:11,1677768431
1,2023-03-02 14:22:30.083014 UTC,view_item,38C7CC4F7EC94CEA83CAB02B89F720DB,38C7CC4F7EC94CEA83CAB02B89F720DB1677765651,MKP0690480,1,629.0,XIAOMI,3200,scale,2023-03-02 14:22:30,1677766950
2,2023-03-02 14:18:14.010002 UTC,view_item,38C7CC4F7EC94CEA83CAB02B89F720DB,38C7CC4F7EC94CEA83CAB02B89F720DB1677765651,CDS86195966,1,490.0,SHAPER,3200,scale,2023-03-02 14:18:14,1677766694
3,2023-03-02 14:23:47.644015 UTC,view_item,38C7CC4F7EC94CEA83CAB02B89F720DB,38C7CC4F7EC94CEA83CAB02B89F720DB1677765651,MKP0690480,1,629.0,XIAOMI,3200,scale,2023-03-02 14:23:47,1677767027
4,2023-03-01 22:04:47.048 UTC,view_item,AE731397442F48249A00B00D76DB9328,AE731397442F48249A00B00D76DB93281677708277,MKP1233067,1,990.0,REMAX,3987,mobile-tablet-accessories,2023-03-01 22:04:47,1677708287


In [119]:
raw_df.isna().any()

event_time       False
event_type       False
user_id          False
user_session      True
product_id       False
quantity         False
price            False
brand             True
category_id       True
category_code     True
event_time_dt    False
event_time_ts    False
dtype: bool

In [121]:
# user_session is the downstream grouping key, so drop events that lack one.
raw_df = raw_df.dropna(subset=['user_session'])
len(raw_df)

1033738

In [122]:
# The raw string timestamp is superseded by event_time_dt / event_time_ts.
raw_df = raw_df.drop(columns=['event_time'])

In [123]:
# Columns carried through unchanged alongside the categorified user_session.
cols = [name for name in raw_df.columns if name != 'user_session']
cols

['event_type',
 'user_id',
 'product_id',
 'quantity',
 'price',
 'brand',
 'category_id',
 'category_code',
 'event_time_dt',
 'event_time_ts']

In [124]:
# Encode user_session strings as integer ids. Every column in `cols` passes through as-is.
df_event = nvt.Dataset(raw_df)
cat_feats = ['user_session'] >> nvt.ops.Categorify()
workflow = nvt.Workflow(cols + cat_feats)
df = workflow.fit_transform(df_event).to_ddf().compute()

In [125]:
df.head(5)

Unnamed: 0,user_session,event_type,user_id,product_id,quantity,price,brand,category_id,category_code,event_time_dt,event_time_ts
0,123604,view_item,150FEA0F36CF46FA840956CAF974CC0B,CDS88856872,1,1590.0,APPLE,3987,mobile-tablet-accessories,2023-03-02 14:47:11,1677768431
1,3516,view_item,38C7CC4F7EC94CEA83CAB02B89F720DB,MKP0690480,1,629.0,XIAOMI,3200,scale,2023-03-02 14:22:30,1677766950
2,3516,view_item,38C7CC4F7EC94CEA83CAB02B89F720DB,CDS86195966,1,490.0,SHAPER,3200,scale,2023-03-02 14:18:14,1677766694
3,3516,view_item,38C7CC4F7EC94CEA83CAB02B89F720DB,MKP0690480,1,629.0,XIAOMI,3200,scale,2023-03-02 14:23:47,1677767027
4,156056,view_item,AE731397442F48249A00B00D76DB9328,MKP1233067,1,990.0,REMAX,3987,mobile-tablet-accessories,2023-03-01 22:04:47,1677708287


In [126]:
df.isna().any()

user_session     False
event_type       False
user_id          False
product_id       False
quantity         False
price            False
brand             True
category_id       True
category_code     True
event_time_dt    False
event_time_ts    False
dtype: bool

In [127]:
# Release the raw frame. Only the categorified `df` is used from here on.
del raw_df
gc.collect()

1986

In [128]:
%%time
df = df.sort_values(['user_session', 'event_time_ts']).reset_index(drop=True)

print("Count with in-session repeated interactions: {}".format(len(df)))
# Sorts the dataframe by session and timestamp, to remove consecutive repetitions
df['product_id_past'] = df['product_id'].shift(1).fillna(0)
df['session_id_past'] = df['user_session'].shift(1).fillna(0)
#Keeping only no consecutive repeated in session interactions
df = df[~((df['user_session'] == df['session_id_past']) & \
             (df['product_id'] == df['product_id_past']))]
print("Count after removed in-session repeated interactions: {}".format(len(df)))
del(df['product_id_past'])
del(df['session_id_past'])

gc.collect()


Count with in-session repeated interactions: 1033738
Count after removed in-session repeated interactions: 881641
CPU times: user 342 ms, sys: 12.8 ms, total: 355 ms
Wall time: 343 ms


72

In [129]:
# Earliest event_time_ts per product, used later for the item-recency feature.
item_first_interaction_df = (
    df.groupby('product_id')
    .agg({'event_time_ts': 'min'})
    .reset_index()
    .rename(columns={'event_time_ts': 'prod_first_event_time_ts'})
)
gc.collect()


17

In [130]:
# Attach each product's first-seen timestamp to every one of its events.
# NOTE(review): a cudf merge may not preserve the left frame's row order, so the
# session/time sort from the dedup cell may not survive here. The session Groupby
# later sorts by event_time_ts on its own.
df = df.merge(item_first_interaction_df, on=['product_id'], how='left').reset_index(drop=True)
del item_first_interaction_df
gc.collect()

0

In [131]:
df.event_time_dt.min()

numpy.datetime64('2023-02-28T17:00:00')

In [132]:
# Keep only events strictly before this cutoff.
# NOTE(review): the cutoff looks sized for a 10-day extract; confirm it is intended for the 4-day CSV.
EVENT_TIME_CUTOFF = np.datetime64('2023-03-11')
df = df[df['event_time_dt'] < EVENT_TIME_CUTOFF].reset_index(drop=True)

In [133]:
df.event_time_dt.max()

numpy.datetime64('2023-03-04T16:59:59')

In [134]:
# The NVTabular workflow rebuilds event_time_dt from event_time_ts.
df = df.drop(columns=['event_time_dt'])

In [135]:
df.head(5)

Unnamed: 0,user_session,event_type,user_id,product_id,quantity,price,brand,category_id,category_code,event_time_ts,prod_first_event_time_ts
0,48,view_item,20b8dc0111babc685ae38b73ef380241,MKP1298321,1,1250.0,MAYFINE,3396,mayfine-handbag-mf-10-1175-black-color-mkp1298321,1677764817,1677647096
1,48,view_item,20b8dc0111babc685ae38b73ef380241,MKP1233034,1,2500.0,ELLE,3396,,1677764883,1677639205
2,48,view_item,20b8dc0111babc685ae38b73ef380241,CDS30190313,1,942.0,ELLE,3925,elle-travel-horizontal-crossover-tote-bag-larg...,1677764904,1677676522
3,48,view_item,20b8dc0111babc685ae38b73ef380241,CDS30187092,1,877.0,ELLE PROMO,3925,elle-travel-horizontal-crossover-tote-bag-smal...,1677764943,1677764943
4,48,view_item,20b8dc0111babc685ae38b73ef380241,CDS30190313,1,942.0,ELLE,3925,elle-travel-horizontal-crossover-tote-bag-larg...,1677764946,1677676522


In [136]:
# NOTE(review): the file is named "10days" although the input is the 4-day CSV; confirm the name.
PREPROCESSED_PARQUET = os.path.join("./", 'cds_2023_march_10days.parquet')
df.to_parquet(PREPROCESSED_PARQUET)
df.head()

Unnamed: 0,user_session,event_type,user_id,product_id,quantity,price,brand,category_id,category_code,event_time_ts,prod_first_event_time_ts
0,48,view_item,20b8dc0111babc685ae38b73ef380241,MKP1298321,1,1250.0,MAYFINE,3396,mayfine-handbag-mf-10-1175-black-color-mkp1298321,1677764817,1677647096
1,48,view_item,20b8dc0111babc685ae38b73ef380241,MKP1233034,1,2500.0,ELLE,3396,,1677764883,1677639205
2,48,view_item,20b8dc0111babc685ae38b73ef380241,CDS30190313,1,942.0,ELLE,3925,elle-travel-horizontal-crossover-tote-bag-larg...,1677764904,1677676522
3,48,view_item,20b8dc0111babc685ae38b73ef380241,CDS30187092,1,877.0,ELLE PROMO,3925,elle-travel-horizontal-crossover-tote-bag-smal...,1677764943,1677764943
4,48,view_item,20b8dc0111babc685ae38b73ef380241,CDS30190313,1,942.0,ELLE,3925,elle-travel-horizontal-crossover-tote-bag-larg...,1677764946,1677676522


In [137]:
# Categorical columns encoded as integer ids by Categorify (start_index=1).
categorical_columns = ['user_session', 'category_code', 'brand', 'user_id', 'product_id', 'category_id', 'event_type']
cat_feats = categorical_columns >> nvt.ops.Categorify(start_index=1)

In [138]:
# Create time features.
session_ts = ['event_time_ts']

# Epoch seconds -> datetime column named event_time_dt.
to_datetime_op = nvt.ops.LambdaOp(lambda col: cudf.to_datetime(col, unit='s'))
session_time = session_ts >> to_datetime_op >> nvt.ops.Rename(name='event_time_dt')

# Day of week taken from event_time_dt.
weekday_op = nvt.ops.LambdaOp(lambda col: col.dt.weekday)
sessiontime_weekday = session_time >> weekday_op >> nvt.ops.Rename(name='et_dayofweek')


In [139]:
def get_cycled_feature_value_sin(col, max_value):
    """Sine component of a cyclic feature.

    ``col`` (plus a 1e-6 offset) is divided by ``max_value`` to get a fraction
    of a full cycle, which is converted to radians and passed to ``np.sin``.
    """
    return np.sin(2*np.pi*((col + 0.000001) / max_value))

def get_cycled_feature_value_cos(col, max_value):
    """Cosine component of a cyclic feature.

    ``col`` (plus a 1e-6 offset) is divided by ``max_value`` to get a fraction
    of a full cycle, which is converted to radians and passed to ``np.cos``.
    """
    return np.cos(2*np.pi*((col + 0.000001) / max_value))
# Day of week shifted by +1 and placed on a 7-step cycle, so the last and first
# weekday end up adjacent.
weekday_sin = (
    sessiontime_weekday
    >> (lambda col: get_cycled_feature_value_sin(col+1, 7))
    >> nvt.ops.Rename(name='et_dayofweek_sin')
)
weekday_cos = (
    sessiontime_weekday
    >> (lambda col: get_cycled_feature_value_cos(col+1, 7))
    >> nvt.ops.Rename(name='et_dayofweek_cos')
)

In [140]:

# Compute item recency with a custom NVTabular operator.
class ItemRecency(nvt.ops.Operator):
    """Adds ``<column>_age_days``: days between each event and the item's first event.

    For every selected timestamp column (seconds), subtracts the row's
    ``prod_first_event_time_ts`` and divides by 86400. Negative ages are zeroed
    by multiplying with the ``>= 0`` mask.
    """

    def transform(self, columns, gdf):
        seconds_per_day = 60 * 60 * 24
        for name in columns.names:
            first_seen = gdf['prod_first_event_time_ts']
            age_days = (gdf[name] - first_seen) / seconds_per_day
            gdf[name + "_age_days"] = age_days * (age_days >= 0)
        return gdf

    def output_column_names(self, columns):
        # One "<input>_age_days" output per selected input column.
        return ColumnSelector([name + "_age_days" for name in columns.names])

    @property
    def dependencies(self):
        # Extra input column read by transform() besides the selected ones.
        return ["prod_first_event_time_ts"]
    
# Item age in days at event time, then log-transformed and standardized.
recency_features = ['event_time_ts'] >> ItemRecency()
recency_features_norm = (
    recency_features
    >> nvt.ops.LogOp()
    >> nvt.ops.Normalize()
    >> nvt.ops.Rename(name='product_recency_days_log_norm')
)

# All time-derived nodes, combined for the session Groupby.
time_features = session_time + sessiontime_weekday + weekday_sin + weekday_cos + recency_features_norm



In [141]:
# Price: log transform followed by normalization.
price_log = (
    ['price']
    >> nvt.ops.LogOp()
    >> nvt.ops.Normalize()
    >> nvt.ops.Rename(name='price_log_norm')
)

In [142]:
# Relative price to the average price for the category_id.
def relative_price_to_avg_categ(col, gdf):
    """Return (price - avg) / (avg + 1e-5), zeroed wherever the category average is not positive.

    ``col`` holds the per-row category average price, and ``gdf['price']`` the row price.
    """
    epsilon = 1e-5
    has_avg = (col > 0).astype(int)
    relative = (gdf['price'] - col) / (col + epsilon)
    return relative * has_avg
    
# Mean price per category_id, attached to each row by JoinGroupby.
avg_category_id_pr = (
    ['category_id']
    >> nvt.ops.JoinGroupby(cont_cols=['price'], stats=["mean"])
    >> nvt.ops.Rename(name='avg_category_id_price')
)
relative_price_to_avg_category = (
    avg_category_id_pr
    >> nvt.ops.LambdaOp(relative_price_to_avg_categ, dependency=['price'])
    >> nvt.ops.Rename(name="relative_price_to_avg_categ_id")
)


In [143]:
# Every input node the session-level Groupby needs.
groupby_feats = (
    ['event_time_ts', 'user_session']
    + cat_feats
    + time_features
    + price_log
    + relative_price_to_avg_category
)

In [144]:
# Define the Groupby workflow: one output row per user_session, with events sorted
# by event_time_ts. Output names are "<column>-<agg>" (name_sep="-"), e.g. "product_id-list".
session_aggs = {
    'user_id': ['first'],
    'product_id': ["list", "count"],
    'category_code': ["list"],
    'event_type': ["list"],
    'brand': ["list"],
    'category_id': ["list"],
    'event_time_ts': ["first"],
    'event_time_dt': ["first"],
    'et_dayofweek_sin': ["list"],
    'et_dayofweek_cos': ["list"],
    'price_log_norm': ["list"],
    'relative_price_to_avg_categ_id': ["list"],
    'product_recency_days_log_norm': ["list"],
}
groupby_features = groupby_feats >> nvt.ops.Groupby(
    groupby_cols=["user_session"],
    sort_cols=["event_time_ts"],
    aggs=session_aggs,
    name_sep="-",
)

In [145]:
# Per-session list columns to be truncated into sequence features.
session_list_columns = (
    'product_id-list',
    'category_code-list',
    'event_type-list',
    'brand-list',
    'category_id-list',
    'et_dayofweek_sin-list',
    'et_dayofweek_cos-list',
    'price_log_norm-list',
    'relative_price_to_avg_categ_id-list',
    'product_recency_days_log_norm-list',
)
groupby_features_list = groupby_features[session_list_columns]


In [146]:
# Sequence truncation length, and the shortest session kept.
# The model's sequence_length should match SESSIONS_MAX_LENGTH.
SESSIONS_MAX_LENGTH, MINIMUM_SESSION_LENGTH = 20, 2


In [147]:
# Keep the first SESSIONS_MAX_LENGTH entries of each list (the earliest events, given
# the Groupby sort) and add the "_seq" suffix to the column names.
groupby_features_trim = (
    groupby_features_list
    >> nvt.ops.ListSlice(0, SESSIONS_MAX_LENGTH)
    >> nvt.ops.Rename(postfix='_seq')
)

In [148]:
# Session day index: whole days between each session's 'event_time_dt-first' and the
# minimum of that column, starting at 1.
# NOTE(review): LambdaOp presumably runs per partition, so col.min() is the minimum of the
# current partition only. The index is consistent only while the dataset is one partition.
day_index = (
    groupby_features['event_time_dt-first']
    >> nvt.ops.LambdaOp(lambda col: (col - col.min()).dt.days + 1)
    >> nvt.ops.Rename(f=lambda col: "day_index")
)


In [149]:
# Final session-level outputs: id, untruncated length, truncated sequences, day index.
selected_features = (
    groupby_features['user_session', 'product_id-count']
    + groupby_features_trim
    + day_index
)

In [150]:
# Drop sessions with fewer than MINIMUM_SESSION_LENGTH events (count taken before truncation).
filtered_sessions = selected_features >> nvt.ops.Filter(
    f=lambda df: df["product_id-count"] >= MINIMUM_SESSION_LENGTH
)

In [151]:
# Silence numba's CUDA low-occupancy warnings.
import numba

numba.config.CUDA_LOW_OCCUPANCY_WARNINGS = 0

In [152]:
# Fit the session-level workflow on the event frame and expose the result as a dask DataFrame.
dataset = nvt.Dataset(df)
workflow = nvt.Workflow(filtered_sessions)
sessions_gdf = workflow.fit_transform(dataset).to_ddf()

In [153]:
sessions_gdf.head(n=3)

Unnamed: 0,user_session,product_id-count,product_id-list_seq,category_code-list_seq,event_type-list_seq,brand-list_seq,category_id-list_seq,et_dayofweek_sin-list_seq,et_dayofweek_cos-list_seq,price_log_norm-list_seq,relative_price_to_avg_categ_id-list_seq,product_recency_days_log_norm-list_seq,day_index
0,2,330,"[5403, 3, 5, 16, 86, 2477, 407, 9329, 2161, 49...","[74, 14, 2, 2, 2, 52, 52, 20, 39, 2, 54, 2, 31...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[207, 3, 12, 12, 12, 12, 12, 12, 12, 12, 12, 1...","[84, 15, 2, 2, 2, 43, 43, 17, 27, 2, 53, 2, 49...","[-0.9749281119157542, -0.9749281119157542, -0....","[-0.22252005886297657, -0.22252005886297657, -...","[2.1784331798553467, -0.20048125088214874, -0....","[-0.48128364896438747, -0.0008073852542247119,...","[0.007566578686237335, 0.007566578686237335, 0...",3
1,3,280,"[1089, 3539, 1540, 806, 17411, 13505, 6624, 27...","[19, 76, 11, 11, 75, 29, 29, 162, 7, 15, 15, 1...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[23, 23, 23, 23, 23, 23, 23, 121, 253, 44, 25,...","[20, 76, 55, 55, 79, 72, 72, 169, 22, 32, 326,...","[-0.9749281119157542, -0.9749281119157542, -0....","[-0.22252005886297657, -0.22252005886297657, -...","[-1.8582172393798828, -1.036149501800537, -1.4...","[-0.8825235764056624, -0.354679024075132, -0.7...","[0.009187988005578518, 0.009187988005578518, 0...",3
2,4,259,"[3, 1633, 2970, 4059, 2572, 4909, 2572, 1633, ...","[14, 40, 40, 40, 40, 40, 40, 40, 40, 40, 326, ...","[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[3, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10...","[15, 47, 47, 47, 47, 47, 47, 47, 47, 47, 401, ...","[0.43388293040961884, 0.43388293040961884, 0.4...","[-0.9009692573551896, -0.9009692573551896, -0....","[-0.20048125088214874, -0.08019714057445526, -...","[-0.0008073852542247119, -0.009803943114713387...","[-0.019997386261820793, -0.019997386261820793,...",1


In [154]:
list(workflow.output_schema.column_names)

['user_session',
 'product_id-count',
 'product_id-list_seq',
 'category_code-list_seq',
 'event_type-list_seq',
 'brand-list_seq',
 'category_id-list_seq',
 'et_dayofweek_sin-list_seq',
 'et_dayofweek_cos-list_seq',
 'price_log_norm-list_seq',
 'relative_price_to_avg_categ_id-list_seq',
 'product_recency_days_log_norm-list_seq',
 'day_index']

In [155]:
# Persist the fitted ETL workflow next to the notebook.
workflow_path = os.path.join('.', 'workflow_etl')
workflow.save(workflow_path)

In [156]:
import locale
def getpreferredencoding(do_setlocale=True):
    """Replacement for locale.getpreferredencoding that always reports "UTF-8".

    ``do_setlocale`` is accepted only to match the stdlib signature and is ignored.
    """
    return "UTF-8"


# Monkey-patch the stdlib so later callers see UTF-8 regardless of the runtime locale
# (presumably a workaround for a non-UTF-8 default in this environment).
locale.getpreferredencoding = getpreferredencoding

In [157]:
# Column whose values define the per-partition output folders.
PARTITION_COL = 'day_index'

# Folder for the partitioned parquet files (overridable via the OUTPUT_FOLDER env var).
OUTPUT_FOLDER = os.environ.get("OUTPUT_FOLDER", './' + "sessions_by_day")
# os.makedirs handles paths containing spaces, which the unquoted shell
# `mkdir -p $OUTPUT_FOLDER` would split into several arguments.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

In [158]:
from transformers4rec.data.preprocessing import save_time_based_splits

# Write one sub-folder per PARTITION_COL value under OUTPUT_FOLDER.
# NOTE(review): timestamp_col is 'user_session', a Categorify id rather than a time value.
# If the splitter orders rows by this column, the splits are not chronological; confirm.
sessions_dataset = nvt.Dataset(sessions_gdf)
save_time_based_splits(
    data=sessions_dataset,
    output_dir=OUTPUT_FOLDER,
    partition_col=PARTITION_COL,
    timestamp_col='user_session',
)

Creating time-based splits: 100%|██████████| 4/4 [00:00<00:00,  4.07it/s]


In [159]:
!ls $OUTPUT_FOLDER

1  2  3  4


In [160]:
!pip install torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [161]:
import os
import glob

import torch 
import transformers4rec.torch as tr

from transformers4rec.torch.ranking_metric import NDCGAt, RecallAt
from transformers4rec.torch.utils.examples_utils import wipe_memory

In [162]:
# Alias the T4Rec schema class so it does not shadow merlin.schema.Schema,
# which is imported earlier in the notebook.
from merlin_standard_lib import Schema as T4RecSchema

# Schema object passed to the TabularSequenceFeatures class.
# NOTE(review): this file lives under ds_data/, not ds_data_v3/ where the input CSV
# comes from. Confirm its embedding cardinalities cover this dataset's categorified ids.
SCHEMA_PATH = '/content/drive/MyDrive/ds_data/schema_tutorial.pb'

x_cat_names = ['product_id-list_seq', 'category_id-list_seq', 'brand-list_seq']
x_cont_names = ['product_recency_days_log_norm-list_seq', 'et_dayofweek_sin-list_seq', 'et_dayofweek_cos-list_seq',
                'price_log_norm-list_seq', 'relative_price_to_avg_categ_id-list_seq']

schema = T4RecSchema().from_proto_text(SCHEMA_PATH)
schema = schema.select_by_name(x_cat_names + x_cont_names)

In [163]:
sequence_length = 20  # should equal SESSIONS_MAX_LENGTH used in the ETL
d_model = 192
# Input module: turns the schema's sequence features into d_model-wide inputs
# and applies masked-language-model ("mlm") masking.
inputs = tr.TabularSequenceFeatures.from_schema(
    schema,
    max_sequence_length=sequence_length,
    aggregation="concat",
    d_output=d_model,
    masking="mlm",
)

In [None]:
# Disabled alternative body: MLP projection followed by a single-layer GRU instead of
# the XLNet transformer block defined in the next cell. Kept for reference only.
# body = tr.SequentialBlock(
#         inputs,
#         tr.MLPBlock([d_model]),
#         tr.Block(torch.nn.GRU(input_size=d_model, hidden_size=d_model, num_layers=1), [None, 20, d_model])
# )

In [164]:
# Build the HF XLNet config with the notebook's model width and sequence length.
transformer_config = tr.XLNetConfig.build(
    d_model=d_model, n_head=4, n_layer=2, total_seq_length=sequence_length
)
# Model body: inputs (with masking) -> MLP -> XLNet block. The MLP width uses d_model
# rather than a hard-coded 192 so it stays equal to the transformer's hidden size
# if d_model is changed.
body = tr.SequentialBlock(
    inputs,
    tr.MLPBlock([d_model]),
    tr.TransformerBlock(transformer_config, masking=inputs.masking),
)
# Next-item prediction head reporting NDCG@{10,20} and Recall@{10,20}.
head = tr.Head(
    body,
    tr.NextItemPredictionTask(
        weight_tying=True,
        metrics=[
            NDCGAt(top_ks=[10, 20], labels_onehot=True),
            RecallAt(top_ks=[10, 20], labels_onehot=True),
        ],
    ),
)

# End-to-end Model class wrapping the head.
model = tr.Model(head)




In [165]:
# Display the model architecture (module repr).
model 

Model(
  (heads): ModuleList(
    (0): Head(
      (body): SequentialBlock(
        (0): TabularSequenceFeatures(
          (_aggregation): ConcatFeatures()
          (to_merge): ModuleDict(
            (continuous_module): ContinuousFeatures(
              (filter_features): FilterFeatures()
            )
            (categorical_module): SequenceEmbeddingFeatures(
              (filter_features): FilterFeatures()
              (embedding_tables): ModuleDict(
                (category_id-list_seq): Embedding(1010, 64, padding_idx=0)
                (brand-list_seq): Embedding(4192, 64, padding_idx=0)
                (product_id-list_seq): Embedding(176394, 64, padding_idx=0)
              )
            )
          )
          (projection_module): SequentialBlock(
            (0): DenseBlock(
              (0): Linear(in_features=197, out_features=192, bias=True)
              (1): ReLU(inplace=True)
            )
          )
          (_masking): MaskedLanguageModeling()
        )
   

In [None]:
# Disabled: manual NVTabular dataloader factory. The Trainer below is configured with
# data_loader_engine='nvtabular' and presumably builds its own loaders; kept for reference.
# # import NVTabular dependencies
# from transformers4rec.torch.utils.data_utils import NVTabularDataLoader


# # dictionary representing max sequence length for column
# sparse_features_max = {
#     fname: sequence_length
#     for fname in x_cat_names + x_cont_names
# }

# # Define a `get_dataloader` function to call in the training loop
# def get_dataloader(path, batch_size=32):

#     return NVTabularDataLoader.from_schema(
#         schema,
#         path, 
#         batch_size,
#         max_sequence_length=sequence_length,
#         sparse_names=x_cat_names + x_cont_names,
#         sparse_max=sparse_features_max,
# )

In [166]:
from transformers4rec.config.trainer import T4RecTrainingArguments
from transformers4rec.torch import Trainer
from transformers4rec.torch.utils.examples_utils import wipe_memory

# Set arguments for training.
# max_sequence_length must equal the length the model was built for (sequence_length,
# used by TabularSequenceFeatures and the XLNet config). A different value presumably
# makes the dataloader pad list features to a length the model was not configured for.
training_args = T4RecTrainingArguments(
    output_dir="./tmp",
    max_sequence_length=sequence_length,
    data_loader_engine='nvtabular',
    num_train_epochs=1000,
    dataloader_drop_last=False,
    # Samples per optimizer step = 2500 * 10 = 25,000.
    per_device_train_batch_size=2500,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=10,
    learning_rate=0.000666,
    report_to=[],
    logging_steps=1,
)

PyTorch: setting up devices


In [167]:
# T4Rec Trainer for training and evaluation, built from the model, arguments and input schema.
trainer = Trainer(model=model, args=training_args, schema=schema, compute_metrics=True)

In [170]:
# Reuse the folder the time-based splits were written to, so the two stay in sync
# even when the OUTPUT_FOLDER environment variable overrides the default.
OUTPUT_DIR = OUTPUT_FOLDER
OUTPUT_DIR

'./sessions_by_day'

In [171]:
%%time
# start_time_window_index = 1
# final_time_window_index = 4
# for time_index in range(start_time_window_index, final_time_window_index):
#     # Set data 
#     time_index_train = time_index
#     time_index_eval = time_index + 1
train_paths = glob.glob(os.path.join(OUTPUT_DIR, f"*/train.parquet"))
eval_paths = glob.glob(os.path.join(OUTPUT_DIR, f"*/valid.parquet"))
# Train on day related to time_index 
print('*'*20)
# print("Launch training for day %s are:" %time_index)
print('*'*20 + '\n')
trainer.train_dataset_or_path = train_paths
trainer.reset_lr_scheduler()
trainer.train()
trainer.state.global_step +=1
# Evaluate on the following day
trainer.eval_dataset_or_path = eval_paths
train_metrics = trainer.evaluate(metric_key_prefix='eval')
print('*'*20)
# print("Eval results for day %s are:\t" %time_index_eval)
print('\n' + '*'*20 + '\n')
for key in sorted(train_metrics.keys()):
    print(" %s = %s" % (key, str(train_metrics[key]))) 
wipe_memory()

***** Running training *****
  Num examples = 82500
  Num Epochs = 1000
  Instantaneous batch size per device = 2500
  Total train batch size (w. parallel, distributed & accumulation) = 25000
  Gradient Accumulation steps = 10
  Total optimization steps = 3000


********************
********************



ValueError: ignored

In [None]:
# NOTE(review): _save_model_and_checkpoint is a private (underscore) Trainer method and
# may change between Transformers4Rec releases.
trainer._save_model_and_checkpoint(save_model_class=True)

In [None]:
# Eval dataloader, presumably built from trainer.eval_dataset_or_path set in the training cell.
dataloader = trainer.get_eval_dataloader()

In [None]:
# Print the prediction tensor shape for the first eval batch. torch.no_grad() avoids
# building an autograd graph for this inspection-only forward pass.
with torch.no_grad():
    for b in dataloader:
        print(model(b)['predictions'].size())
        break