# **Temporal Behaviour-Based Product Suggestion System for E-Commerce üõçÔ∏è**

**Uploading Dataset & Reading File**

In [29]:
import numpy as np
import pandas as pd
file_path = '/kaggle/input/ecommerce-events-history-in-cosmetics-shop/2019-Dec.csv'
df = pd.read_csv(file_path)

print(df.head())

                event_time        event_type  product_id          category_id  \
0  2019-12-01 00:00:00 UTC  remove_from_cart     5712790  1487580005268456287   
1  2019-12-01 00:00:00 UTC              view     5764655  1487580005411062629   
2  2019-12-01 00:00:02 UTC              cart        4958  1487580009471148064   
3  2019-12-01 00:00:05 UTC              view     5848413  1487580007675986893   
4  2019-12-01 00:00:07 UTC              view     5824148  1487580005511725929   

  category_code      brand  price    user_id  \
0           NaN      f.o.x   6.27  576802932   
1           NaN        cnd  29.05  412120092   
2           NaN     runail   1.19  494077766   
3           NaN  freedecor   0.79  348405118   
4           NaN        NaN   5.56  576005683   

                           user_session  
0  51d85cb0-897f-48d2-918b-ad63965c12dc  
1  8adff31e-2051-4894-9758-224bfa8aec18  
2  c99a50e8-2fac-4c4d-89ec-41c05f114554  
3  722ffea5-73c0-4924-8e8f-371ff8031af4  
4  28172809-7e

**Important Columns**

In [30]:
df = df[['event_time', 'event_type', 'product_id', 'category_id', 'brand', 'price', 'user_id', 'user_session']]
df.head()

Unnamed: 0,event_time,event_type,product_id,category_id,brand,price,user_id,user_session
0,2019-12-01 00:00:00 UTC,remove_from_cart,5712790,1487580005268456287,f.o.x,6.27,576802932,51d85cb0-897f-48d2-918b-ad63965c12dc
1,2019-12-01 00:00:00 UTC,view,5764655,1487580005411062629,cnd,29.05,412120092,8adff31e-2051-4894-9758-224bfa8aec18
2,2019-12-01 00:00:02 UTC,cart,4958,1487580009471148064,runail,1.19,494077766,c99a50e8-2fac-4c4d-89ec-41c05f114554
3,2019-12-01 00:00:05 UTC,view,5848413,1487580007675986893,freedecor,0.79,348405118,722ffea5-73c0-4924-8e8f-371ff8031af4
4,2019-12-01 00:00:07 UTC,view,5824148,1487580005511725929,,5.56,576005683,28172809-7e4a-45ce-bab0-5efa90117cd5


**Data Preprocessing**

**Check for Missing Values**

In [31]:
check_columns = ['event_time', 'event_type', 'product_id', 'category_id', 'brand', 'price', 'user_id', 'user_session']

missing_values = df[check_columns].isnull().sum()

print(missing_values)

event_time            0
event_type            0
product_id            0
category_id           0
brand           1510289
price                 0
user_id               0
user_session        779
dtype: int64


**Handling Missing Values for Brand**

In [32]:
# Standardizing
df['brand'] = df['brand'].fillna('').astype(str).str.lower().str.strip().str.replace(r'[.\-\s]', '', regex=True)
df['brand'].replace('', np.nan, inplace=True)

# Mode logic for missing values
brand_map = df.groupby('category_id')['brand'].apply(lambda x: x.mode()[0] if not x.mode().empty else np.nan)
df['brand'] = df['brand'].fillna(df['category_id'].map(brand_map))

# Final catch-all
df['brand'].fillna('unknown', inplace=True)

The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  df['brand'].replace('', np.nan, inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  df['brand'].fillna('unknown', inplace=True)


**Check for Unkown Brand Labels**

In [33]:
unknown_count = (df['brand'] == 'unknown').sum()
print(unknown_count)

457482


**Check for Null Values after Preprocessing Brand**

In [34]:
check_columns = ['brand']

missing_values = df[check_columns].isnull().sum()

print(missing_values)

brand    0
dtype: int64


**Handling Missing Values for User Session**

In [35]:
df = df.dropna(subset=['user_session'])

**Handling Missing Values for Negative Price**

In [36]:
# Find negative price rows
negative_prices = df[df['price'] < 0]

# Count negatives by event_type
event_type_counts = negative_prices.groupby('event_type').size()

# Count negatives by brand
brand_counts = negative_prices.groupby('brand').size().sort_values(ascending=False)

print("Total negative prices:", len(negative_prices))
print("Negative Prices by Event Type:\n", event_type_counts)
print("Negative Prices by Brand:\n", brand_counts.head())

Total negative prices: 18
Negative Prices by Event Type:
 event_type
purchase    18
dtype: int64
Negative Prices by Brand:
 brand
unknown    18
dtype: int64


In [37]:
df = df.drop(df[df['price'] < 0].index)
df.head()

Unnamed: 0,event_time,event_type,product_id,category_id,brand,price,user_id,user_session
0,2019-12-01 00:00:00 UTC,remove_from_cart,5712790,1487580005268456287,fox,6.27,576802932,51d85cb0-897f-48d2-918b-ad63965c12dc
1,2019-12-01 00:00:00 UTC,view,5764655,1487580005411062629,cnd,29.05,412120092,8adff31e-2051-4894-9758-224bfa8aec18
2,2019-12-01 00:00:02 UTC,cart,4958,1487580009471148064,runail,1.19,494077766,c99a50e8-2fac-4c4d-89ec-41c05f114554
3,2019-12-01 00:00:05 UTC,view,5848413,1487580007675986893,freedecor,0.79,348405118,722ffea5-73c0-4924-8e8f-371ff8031af4
4,2019-12-01 00:00:07 UTC,view,5824148,1487580005511725929,unknown,5.56,576005683,28172809-7e4a-45ce-bab0-5efa90117cd5


**Check for Negative Values After Preprocessing Price**

In [38]:
# Find negative price rows
negative_prices = df[df['price'] < 0]
print("Total negative prices:", len(negative_prices))

Total negative prices: 0


**Feature Engineering**

**User Table & Product Table**

In [39]:
# User Table
user_agg = df.groupby(['user_id', 'event_type']).size().unstack(fill_value=0)

# Calculate session-level stats
user_sessions = df.groupby('user_id')['user_session'].nunique().rename("session_count")
user_avg_price = df.groupby('user_id')['price'].mean().rename("user_avg_price")
user_category_div = df.groupby('user_id')['category_id'].nunique().rename("distinct_categories")

# Combine
user_features = pd.concat([user_agg, user_sessions, user_avg_price, user_category_div], axis=1) #axis = 1 -> Horizontal Concat

# Rename
user_features.rename(columns={
    'cart': 'user_cart_count',
    'purchase': 'user_purchase_count',
    'remove_from_cart': 'user_remove_count',
    'view': 'user_view_count'
}, inplace=True)

# Ratios
user_features['user_cart_to_purchase_ratio'] = user_features['user_purchase_count'] / (user_features['user_cart_count'] + 1e-9)
user_features['user_remove_rate'] = user_features['user_remove_count'] / (user_features['user_cart_count'] + 1e-9)

# Avg per session
user_features['user_avg_views_per_session'] = user_features['user_view_count'] / user_features['session_count']
user_features['user_avg_carts_per_session'] = user_features['user_cart_count'] / user_features['session_count']
user_features['user_avg_removes_per_session'] = user_features['user_remove_count'] / user_features['session_count']
user_features['user_avg_purchases_per_session'] = user_features['user_purchase_count'] / user_features['session_count']

user_features = user_features.reset_index()

# Product Table
product_agg = df.groupby(['product_id', 'event_type']).size().unstack(fill_value=0)

product_avg_price = df.groupby('product_id')['price'].mean().rename("product_avg_price")
product_unique_users = df.groupby('product_id')['user_id'].nunique().rename("unique_users")

# Combine
product_features = pd.concat([product_agg, product_avg_price, product_unique_users], axis=1)

# Rename
product_features.rename(columns={
    'cart': 'product_cart_count',
    'purchase': 'product_purchase_count',
    'remove_from_cart': 'product_remove_count',
    'view': 'product_view_count'
}, inplace=True)

# Ratios
product_features['view_to_cart_ratio'] = product_features['product_cart_count'] / (product_features['product_view_count'] + 1e-9)
product_features['product_cart_to_purchase_ratio'] = product_features['product_purchase_count'] / (product_features['product_cart_count'] + 1e-9)
product_features['product_remove_rate'] = product_features['product_remove_count'] / (product_features['product_cart_count'] + 1e-9)

# Repeat purchase rate: fraction of users who purchased product >1 times
purchases = df[df['event_type'] == 'purchase'].groupby(['product_id','user_id']).size().reset_index(name='count')
repeat_rate = purchases[purchases['count']>1].groupby('product_id')['user_id'].nunique() / product_unique_users
product_features['repeat_purchase_rate'] = repeat_rate.fillna(0)

product_features = product_features.reset_index()

# Preview
print("User Table:")
print(user_features)

print("\nProduct Table:")
print(product_features)

User Table:
          user_id  user_cart_count  user_purchase_count  user_remove_count  \
0         1180452                0                    0                  0   
1         2963072                7                    0                  0   
2         4661182                2                    0                  2   
3         4891613                0                    0                  0   
4         6217356                0                    0                  0   
...           ...              ...                  ...                ...   
370103  595413843                0                    0                  0   
370104  595413976                0                    0                  0   
370105  595414210                0                    0                  0   
370106  595414257                0                    0                  0   
370107  595414541                0                    0                  0   

        user_view_count  session_count  user_avg_pr

In [40]:
# Merge user features into df
df = df.merge(user_features, on='user_id', how='left')

# Merge product features into df
df = df.merge(product_features, on='product_id', how='left')

**Temporal Features**

In [41]:
df['event_time'] = pd.to_datetime(df['event_time'])

# Sort events by user & time
df = df.sort_values(by=['user_id', 'event_time']).reset_index(drop=True)

# 1. Œît (time since last event per user)
df['delta_t'] = df.groupby('user_id')['event_time'].diff().dt.total_seconds()
df['delta_t'] = df['delta_t'].fillna(0)  # First event for each user

# 2. Recency Features
latest_time = df['event_time'].max()

# Recency of each user‚Äôs last action
user_last_action = df.groupby('user_id')['event_time'].max().reset_index()
user_last_action['user_recency'] = (latest_time - user_last_action['event_time']).dt.total_seconds()

# Map back to df
df = df.merge(user_last_action[['user_id', 'user_recency']], on='user_id', how='left')

# 3. Session Duration & Length
# Duration per session
session_durations = df.groupby('user_session')['event_time'].agg(lambda x: (x.max() - x.min()).total_seconds())
session_durations = session_durations.rename("session_duration").reset_index()

# Length (number of events per session)
session_lengths = df.groupby('user_session').size().rename("session_length").reset_index()

# Merge into df
df = df.merge(session_durations, on='user_session', how='left')
df = df.merge(session_lengths, on='user_session', how='left')


# 4. Time of Day / Day of Week (cyclical encoding ready for Transformer)
df['hour'] = df['event_time'].dt.hour
df['day_of_week'] = df['event_time'].dt.dayofweek  # Monday=0, Sunday=6

# Cyclical encoding (sin/cos for Transformer embeddings)
df['hour_sin'] = np.sin(2 * np.pi * df['hour'] / 24)
df['hour_cos'] = np.cos(2 * np.pi * df['hour'] / 24)
df['dow_sin'] = np.sin(2 * np.pi * df['day_of_week'] / 7)
df['dow_cos'] = np.cos(2 * np.pi * df['day_of_week'] / 7)

print("Temporal features added: delta_t, recency, session_duration, session_length, cyclical time features")
print(df.head())


Temporal features added: delta_t, recency, session_duration, session_length, cyclical time features
                 event_time event_type  product_id          category_id  \
0 2019-12-28 14:32:56+00:00       view     5881337  1487580012096782476   
1 2019-12-22 12:50:22+00:00       view     5746011  1487580009051717646   
2 2019-12-22 12:50:45+00:00       view     5707747  1487580009051717646   
3 2019-12-22 12:50:58+00:00       view     5746011  1487580009051717646   
4 2019-12-22 12:53:12+00:00       view     5707747  1487580009051717646   

      brand  price  user_id                          user_session  \
0     fedua  25.40  1180452  a4818e6d-9069-4aa8-8731-572ac266283f   
1    runail  34.92  2963072  3bf2bbbb-1a32-4f06-a263-a3f49db74750   
2  jessnail  73.02  2963072  3bf2bbbb-1a32-4f06-a263-a3f49db74750   
3    runail  34.92  2963072  3bf2bbbb-1a32-4f06-a263-a3f49db74750   
4  jessnail  73.02  2963072  3bf2bbbb-1a32-4f06-a263-a3f49db74750   

   user_cart_count  user_purchase_

**Split Data for Train, Validation, Test**

In [42]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder

# Split data
unique_users = df["user_id"].unique()
train_users, temp_users = train_test_split(unique_users, test_size=0.2, random_state=42)
val_users, test_users = train_test_split(temp_users, test_size=0.5, random_state=42)

train_df = df[df["user_id"].isin(train_users)].copy()
val_df   = df[df["user_id"].isin(val_users)].copy()
test_df  = df[df["user_id"].isin(test_users)].copy()

for df_split in [train_df, val_df, test_df]:
    df_split['delta_t_raw'] = df_split['delta_t'].copy()

print("Train size:", len(train_df))
print("Val size:", len(val_df))
print("Test size:", len(test_df))

Train size: 2804072
Val size: 359184
Test size: 369233


**Normalization & Encoding**

In [43]:
# 1. Normalization
from sklearn.preprocessing import StandardScaler
import numpy as np

# Define feature groups
time_features = ['delta_t', 'user_recency', 'session_duration', 'session_length']
other_num_features = [
    'price', 'session_count', 'user_avg_price',
    'user_cart_to_purchase_ratio', 'user_remove_rate', 'user_avg_views_per_session',
    'user_avg_carts_per_session', 'user_avg_removes_per_session',
    'user_avg_purchases_per_session', 'product_avg_price', 'view_to_cart_ratio',
    'product_cart_to_purchase_ratio', 'product_remove_rate', 'repeat_purchase_rate'
]

# Apply log transformation to time features to handle skew and ensure non-negativity
for df_split in [train_df, val_df, test_df]:
    for col in time_features:
        df_split[col] = np.log1p(df_split[col].clip(lower=0)) # Use log(1+x) and clip at 0 just in case

# Initialize scaler
scaler = StandardScaler()

# Combine all numeric features for scaling
num_features = time_features + other_num_features

# Fit the scaler ONLY on the training data
scaler.fit(train_df[num_features])

# Transform the training, validation, and test data
train_df[num_features] = scaler.transform(train_df[num_features])
val_df[num_features] = scaler.transform(val_df[num_features])
test_df[num_features] = scaler.transform(test_df[num_features])

print("Normalization applied")

Normalization applied


In [44]:
categorical_features = [
    "event_type", "category_id", "product_id", 
    "brand", "user_id", "user_session"
]
encoders = {}

for col in categorical_features:
    # Initialize a new encoder for each column
    encoder = LabelEncoder()
    
    # Fit and transform the training data
    train_df[col + "_enc"] = encoder.fit_transform(train_df[col])
    
    # Store the fitted encoder
    encoders[col] = encoder

    unknown_index = len(encoder.classes_)
    # Handle the validation and test sets
    for df_split in [val_df, test_df]:
        # Identify known labels
        known_labels_mask = df_split[col].isin(encoder.classes_)
        
        # Apply transform only on known labels, others will be NaN
        # Use .loc to avoid SettingWithCopyWarning
        df_split.loc[known_labels_mask, col + "_enc"] = encoder.transform(df_split.loc[known_labels_mask, col])
        
        # Fill any NaNs (which were the unknown labels) with a special value, e.g., -1
        df_split[col + "_enc"] = df_split[col + "_enc"].fillna(unknown_index)
        
        # Ensure the column is of integer type
        df_split[col + "_enc"] = df_split[col + "_enc"].astype(int)

print("Data Encoded")

Data Encoded


**Sequence Dataset Preparation**

In [45]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class EventSequenceDataset(Dataset):
    def __init__(self, df, group_col="user_id", seq_len=50):
        self.seq_len = seq_len
        self.grouped_data = list(df.groupby(group_col))
        
        # Define which columns from the DataFrame correspond to which model inputs
        self.feature_cols = {
            "user_ids": "user_id_enc",
            "product_ids": "product_id_enc",
            "event_types": "event_type_enc",
            "categories": "category_id_enc",
            "time_deltas": "delta_t"
        }
        
        # List of numeric features to be bundled together
        self.numeric_cols = [
            'price', 'user_recency', 'session_duration', 'session_length',
            'user_avg_price', 'user_cart_to_purchase_ratio', 'user_remove_rate',
            'user_avg_views_per_session', 'product_avg_price', 'view_to_cart_ratio',
            'product_cart_to_purchase_ratio', 'product_remove_rate', 'repeat_purchase_rate',
            'hour_sin', 'hour_cos', 'dow_sin', 'dow_cos'
        ]

    def __len__(self):
        return len(self.grouped_data)

    def __getitem__(self, idx):
        _, group = self.grouped_data[idx]
        group = group.sort_values("event_time")

        # Truncate if longer than seq_len
        if len(group) > self.seq_len:
            group = group.tail(self.seq_len)
        
        # Prepare the dictionary that will hold the final tensors
        sequence = {}
        
        # Extract sequences for categorical/ID features
        for key, col in self.feature_cols.items():
            sequence[key] = torch.tensor(group[col].values, dtype=torch.long)

        sequence["raw_time_deltas"] = torch.tensor(group["delta_t_raw"].values, dtype=torch.float32)
            
        # Extract and stack numeric features
        numeric_data = group[self.numeric_cols].values
        sequence["numeric_feats"] = torch.tensor(numeric_data, dtype=torch.float32)

        # --- Padding and Mask Creation ---
        seq_len_actual = len(group)
        pad_len = self.seq_len - seq_len_actual
        
        # Create the mask first: 1 for real data, 0 for padding
        sequence["mask"] = torch.cat([torch.ones(seq_len_actual), torch.zeros(pad_len)], dim=0).float()

        if pad_len > 0:
            for key, tensor in sequence.items():
                if key == "mask":
                    continue # Mask is already handled
                    
                if key == "numeric_feats":
                    # Pad numeric features with zeros
                    pad_tensor = torch.zeros(pad_len, tensor.shape[1])
                else:
                    # Pad categorical/ID features with 0
                    pad_tensor = torch.zeros(pad_len, dtype=tensor.dtype)
                
                sequence[key] = torch.cat([tensor, pad_tensor], dim=0)

        # Ensure time_deltas has the extra dimension the model expects (B, L, 1)
        sequence["time_deltas"] = sequence["time_deltas"].unsqueeze(-1).float()

        return sequence

**Checks if Event Time is sorted**

In [46]:
# This checks if the 'event_time' is always increasing within each user's group.
is_sorted = df.groupby('user_id')['event_time'].is_monotonic_increasing.all()

if is_sorted:
    print("DataFrame is already correctly sorted")
else:
    print(" Re-sort")

DataFrame is already correctly sorted


**Wrapping the Dataset**

In [47]:
from torch.utils.data import DataLoader

seq_len = 50

train_dataset = EventSequenceDataset(train_df, seq_len=seq_len)
val_dataset = EventSequenceDataset(val_df, seq_len=seq_len)
test_dataset = EventSequenceDataset(test_df, seq_len=seq_len)

**Dataloaders**

In [48]:
batch_size = 32

train_loader = DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=4, 
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=4, 
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset, 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=4,
    pin_memory=True
)

**Building Neural HawkesTransformer Model**

In [49]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class NeuralHawkesTransformer(nn.Module):
    def __init__(self, 
                 num_users, num_products, num_event_types, num_categories,
                 num_numeric_features,
                 embed_dim=64, 
                 n_heads=4, 
                 n_layers=2, 
                 dropout=0.2):
        super().__init__()

        # 1. Embedding Layers
        self.user_emb = nn.Embedding(num_users, embed_dim)
        self.product_emb = nn.Embedding(num_products, embed_dim)
        self.event_emb = nn.Embedding(num_event_types, embed_dim)
        self.category_emb = nn.Embedding(num_categories, embed_dim)
        self.numeric_proj = nn.Linear(num_numeric_features, embed_dim)
        self.time_proj = nn.Linear(1, embed_dim)

        # 2. Feature Combiner Layer
        # We will concatenate 6 vectors of size embed_dim
        combined_feature_dim = embed_dim * 6 
        # Project this combined vector back to the model's main dimension (embed_dim)
        self.feature_combiner = nn.Linear(combined_feature_dim, embed_dim)

        # 3. Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=n_heads,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # 4. Output Heads
        self.intensity_fc = nn.Linear(embed_dim, 1)
        self.event_pred_fc = nn.Linear(embed_dim, num_event_types)
        self.product_pred_fc = nn.Linear(embed_dim, num_products)

    def forward(self, user_ids, product_ids, event_types, categories,
                numeric_feats, time_deltas, src_key_padding_mask=None):
        
        # 1. Embeddings
        u = self.user_emb(user_ids)
        p = self.product_emb(product_ids)
        e = self.event_emb(event_types)
        c = self.category_emb(categories)
        n = self.numeric_proj(numeric_feats)
        t = self.time_proj(time_deltas)

        # 2. Concatenate features instead of adding
        x_concat = torch.cat([u, p, e, c, n, t], dim=-1) # (B, L, embed_dim * 6)
        
        # 3. Project the combined vector
        x = self.feature_combiner(x_concat) # (B, L, embed_dim)

        # 4. Transformer Encoding
        h = self.transformer(x, src_key_padding_mask=src_key_padding_mask)

        # 5. Outputs
        intensity = F.softplus(self.intensity_fc(h))
        event_logits = self.event_pred_fc(h)
        product_logits = self.product_pred_fc(h)

        return intensity, event_logits, product_logits

**Get lengths of DataFrames**

In [50]:
# 1. Define the list of all numeric features
numeric_feature_names = [
    'price', 'user_recency', 'session_duration', 'session_length',
    'user_avg_price', 'user_cart_to_purchase_ratio', 'user_remove_rate',
    'user_avg_views_per_session', 'product_avg_price', 'view_to_cart_ratio',
    'product_cart_to_purchase_ratio', 'product_remove_rate', 'repeat_purchase_rate',
    'hour_sin', 'hour_cos', 'dow_sin', 'dow_cos'
]

# 2. Get the count of numeric features
num_numeric_features = len(numeric_feature_names)

# 3. Get the vocabulary sizes from your fitted encoders
num_users = len(encoders['user_id'].classes_) + 1
num_products = len(encoders['product_id'].classes_) + 1
num_event_types = len(encoders['event_type'].classes_) + 1
num_categories = len(encoders['category_id'].classes_) + 1 

print(f"Number of numeric features: {num_numeric_features}")
print(f"Vocabulary size - Users: {num_users}")
print(f"Vocabulary size - Products: {num_products}")
print(f"Vocabulary size - Event Types: {num_event_types}")
print(f"Vocabulary size - Categories: {num_categories}")

# 4. Instantiate the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = NeuralHawkesTransformer(
    num_users=num_users,
    num_products=num_products,
    num_event_types=num_event_types,
    num_categories=num_categories,
    num_numeric_features=num_numeric_features
).to(device) # Move the model to the correct device

Number of numeric features: 17
Vocabulary size - Users: 296087
Vocabulary size - Products: 43580
Vocabulary size - Event Types: 5
Vocabulary size - Categories: 479


In [51]:
from torch.utils.data import Dataset

class YourEventSequenceDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Get the data for a single sequence at the given index
        sequence_data = self.data[index]

        # Construct and return a dictionary with the correct keys
        return {
            "user_ids": sequence_data['user_tensor'],
            "product_ids": sequence_data['product_tensor'],
            "event_types": sequence_data['event_type_tensor'],
            "categories": sequence_data['category_tensor'],
            "numeric_feats": sequence_data['numeric_feats_tensor'],
            "time_deltas": sequence_data['delta_tensor'],
            "mask": sequence_data['mask_tensor']
        }

**Model Train & Validate**

In [52]:
import os
import math
import torch
import torch.nn as nn
from torch.optim import AdamW
from tqdm.auto import tqdm

# Hyperparameters / config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

epochs = 10
learning_rate = 3e-4
weight_decay = 1e-2
grad_clip = 1.0

time_loss_weight = 1.0
event_loss_weight = 1.0
product_loss_weight = 0.5

ckpt_dir = "checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)

# Helpers: masked CE and intensity NLL
eps = 1e-9

def masked_cross_entropy(logits, targets, mask):
    """
    logits: (B, L, C)
    targets: (B, L) long
    mask: (B, L) float 1/0
    """
    B, L, C = logits.shape
    logits_flat = logits.view(B * L, C)
    targets_flat = targets.view(B * L)
    mask_flat = mask.view(B * L)

    loss_per_pos = nn.functional.cross_entropy(logits_flat, targets_flat, reduction='none')  # (B*L,)
    loss_per_pos = loss_per_pos * mask_flat
    loss = loss_per_pos.sum() / (mask_flat.sum() + eps)
    return loss


def hawkes_time_nll_from_intensities(intensity, deltas, mask):
    """
    Approximates time log-likelihood using intensity values at event timestamps and
    trapezoidal rule for the integral term over observed intervals.

    intensity: (B, L, 1) positive (Œª_i) evaluated at each event time
    deltas: (B, L) time delta to previous event in seconds (Œît_i). For the first event Œît can be 0.
            We will use dt_i = deltas[:, i] for interval between i-1 and i.
    mask: (B, L) float 1/0 valid positions.

    NLL = - sum_i log Œª(t_i) + sum_intervals ‚à´ Œª(t) dt
    We approximate ‚à´_{t_i}^{t_{i+1}} Œª(t) dt ‚âà 0.5*(Œª_i + Œª_{i+1}) * dt_{i+1}
    where dt_{i+1} is time difference between t_{i+1} and t_i (i.e., deltas[:, i+1]).
    """

    # Ensure shapes
    B, L, _ = intensity.shape
    lambda_vals = intensity.squeeze(-1)  # (B, L)

    # log term: - sum log lambda at event times (only where mask==1)
    log_lambda = torch.log(lambda_vals + eps)  # (B, L)
    log_term = - (log_lambda * mask).sum()

    # integral term via trapezoidal rule over intervals where both endpoints exist
    # compute lambda_i and lambda_{i+1} and multiply by dt_{i+1}
    # shift lambda and mask by one to get next-step values
    lambda_i = lambda_vals[:, :-1]      # (B, L-1)
    lambda_ip1 = lambda_vals[:, 1:]     # (B, L-1)
    mask_i = mask[:, :-1]
    mask_ip1 = mask[:, 1:]
    valid_interval_mask = mask_i * mask_ip1  # only intervals where both ends exist

    # dt for interval i->i+1 is deltas[:, 1:] (seconds between t_i and t_{i+1})
    dt_ip1 = deltas[:, 1:]  # (B, L-1)

    # trapezoid integral per interval: 0.5*(Œª_i + Œª_{i+1}) * dt
    trap = 0.5 * (lambda_i + lambda_ip1) * dt_ip1 * valid_interval_mask

    integral_term = trap.sum()

    # total NLL (sum over batch and timesteps)
    nll = log_term + integral_term

    # Optionally normalize per valid event count to keep scale consistent
    event_count = mask.sum()
    nll = nll / (event_count + eps)

    return nll

# Instantiate model, optimizer, scaler

optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1, verbose=True)

scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

# Training loop
def train_one_epoch(model, loader, optimizer, scaler):
    model.train()
    total_loss = 0.0
    total_time_loss = 0.0
    total_event_loss = 0.0
    total_product_loss = 0.0
    num_batches = 0

    pbar = tqdm(loader, desc="train", leave=False)
    for batch in pbar:
        # unpack batch (adapt names to your DataLoader)
        user_ids = batch["user_ids"].to(device)           # (B, L)
        product_ids = batch["product_ids"].to(device)     # (B, L)
        event_types = batch["event_types"].to(device)     # (B, L)
        categories = batch.get("categories", torch.zeros_like(user_ids)).to(device)  # (B, L)
        numeric_feats = batch.get("numeric_feats", torch.zeros(user_ids.size(0), user_ids.size(1), 0)).to(device)  # (B, L, F)
        deltas = batch["time_deltas"].squeeze(-1).to(device)  # (B, L)
        raw_deltas = batch["raw_time_deltas"].to(device)
        mask = batch["mask"].to(device)                     # (B, L) float

        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=(scaler is not None)):
            intensity, event_logits, product_logits = model(
                user_ids, product_ids, event_types, categories,
                numeric_feats, deltas.unsqueeze(-1), src_key_padding_mask=(mask == 0))
            # intensity: (B, L, 1)
            # event_logits: (B, L, C_event)
            # product_logits: (B, L, C_product)

            # Time NLL
            time_nll = hawkes_time_nll_from_intensities(intensity, raw_deltas, mask)

            # Event CE (predict current event type)
            event_ce = masked_cross_entropy(event_logits, event_types, mask)

            # Product CE (predict current product)
            product_ce = masked_cross_entropy(product_logits, product_ids, mask)

            loss = (time_loss_weight * time_nll +
                    event_loss_weight * event_ce +
                    product_loss_weight * product_ce)

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        total_loss += loss.item()
        total_time_loss += time_nll.item()
        total_event_loss += event_ce.item()
        total_product_loss += product_ce.item()
        num_batches += 1
        pbar.set_postfix({"loss": total_loss / num_batches})

    return {
        "loss": total_loss / num_batches,
        "time_loss": total_time_loss / num_batches,
        "event_loss": total_event_loss / num_batches,
        "product_loss": total_product_loss / num_batches
    }

def validate(model, loader):
    model.eval()
    total_loss = 0.0
    total_time_loss = 0.0
    total_event_loss = 0.0
    total_product_loss = 0.0
    num_batches = 0
    
    # --- METRICS TO ADD ---
    total_event_correct = 0
    total_product_correct = 0
    total_events_masked = 0

    with torch.no_grad():
        pbar = tqdm(loader, desc="val", leave=False)
        for batch in pbar:
            user_ids = batch["user_ids"].to(device)
            product_ids = batch["product_ids"].to(device)
            event_types = batch["event_types"].to(device)
            categories = batch.get("categories", torch.zeros_like(user_ids)).to(device)
            numeric_feats = batch.get("numeric_feats", torch.zeros(user_ids.size(0), user_ids.size(1), 0)).to(device)
            deltas = batch["time_deltas"].squeeze(-1).to(device)
            raw_deltas = batch["raw_time_deltas"].to(device)
            mask = batch["mask"].to(device)

            intensity, event_logits, product_logits = model(
                user_ids, product_ids, event_types, categories,
                numeric_feats, deltas.unsqueeze(-1), src_key_padding_mask=(mask == 0))

            # --- Calculate Losses (unchanged) ---
            time_nll = hawkes_time_nll_from_intensities(intensity, raw_deltas, mask)
            event_ce = masked_cross_entropy(event_logits, event_types, mask)
            product_ce = masked_cross_entropy(product_logits, product_ids, mask)
            loss = (time_loss_weight * time_nll +
                    event_loss_weight * event_ce +
                    product_loss_weight * product_ce)

            total_loss += loss.item()
            total_time_loss += time_nll.item()
            total_event_loss += event_ce.item()
            total_product_loss += product_ce.item()
            num_batches += 1
            
            # --- CALCULATE METRICS (New) ---
            # Get the index of the highest probability prediction
            pred_events = torch.argmax(event_logits, dim=2)
            pred_products = torch.argmax(product_logits, dim=2)
            
            # Compare predictions to targets, only where mask is 1
            total_event_correct += ((pred_events == event_types) * mask).sum().item()
            total_product_correct += ((pred_products == product_ids) * mask).sum().item()
            total_events_masked += mask.sum().item()

    # Return a dictionary with all metrics
    return {
        "loss": total_loss / (num_batches + eps),
        "time_loss": total_time_loss / (num_batches + eps),
        "event_loss": total_event_loss / (num_batches + eps),
        "product_loss": total_product_loss / (num_batches + eps),
        "event_accuracy": total_event_correct / (total_events_masked + eps),
        "product_accuracy": total_product_correct / (total_events_masked + eps)
    }

# Full training run
best_val_loss = float("inf")
for epoch in range(1, epochs + 1):
    print(f"\n=== Epoch {epoch}/{epochs} ===")
    train_metrics = train_one_epoch(model, train_loader, optimizer, scaler)
    val_metrics = validate(model, val_loader)

    print(f"Train Loss: {train_metrics['loss']:.5f} | Time: {train_metrics['time_loss']:.5f} | "
          f"Event: {train_metrics['event_loss']:.5f} | Prod: {train_metrics['product_loss']:.5f}")
    print(f"Val   Loss: {val_metrics['loss']:.5f} | Time: {val_metrics['time_loss']:.5f} | "
          f"Event: {val_metrics['event_loss']:.5f} | Prod: {val_metrics['product_loss']:.5f}")

    # scheduler step on validation loss
    scheduler.step(val_metrics["loss"])

    # save best checkpoint
    if val_metrics["loss"] < best_val_loss:
        best_val_loss = val_metrics["loss"]
        ckpt_path = os.path.join(ckpt_dir, f"best_model_epoch{epoch}.pt")
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scaler_state_dict": scaler.state_dict() if scaler is not None else None
        }, ckpt_path)
        print("Saved best model ->", ckpt_path)


=== Epoch 1/10 ===


  scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None


train:   0%|          | 0/9253 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(scaler is not None)):


val:   0%|          | 0/1157 [00:00<?, ?it/s]

  output = torch._nested_tensor_from_mask(
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
     Exception ignored in:   <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>^^
^Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
^    ^^self._shutdown_workers()^^
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
^^Exception ignored in:     ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>if w.is_alive():


  File "/usr/lib/python3.11/multiprocessing/process.py", line 160

Train Loss: 46.69596 | Time: 42.11101 | Event: 0.04354 | Prod: 9.08284
Val   Loss: 10.61847 | Time: 6.70146 | Event: 0.00153 | Prod: 7.83095
Saved best model -> checkpoints/best_model_epoch1.pt

=== Epoch 2/10 ===


train:   0%|          | 0/9253 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Exception ignored in: Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__

Exception ignored in: Traceback (most recent call last):
    <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
self._shutdown_workers()    

self._shutdown_workers()Exception ignored in:       File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__

if w.i

val:   0%|          | 0/1157 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80><function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80><function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>

<function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>Traceback (most recent call last):
Exception ignored in: Traceback (most recent call last):

  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
        
Traceback (most recent call last):
self._shutdown_workers()Traceback (most recent call last):
self._shutdown_workers()
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers

Train Loss: 10.49403 | Time: 7.77381 | Event: 0.00748 | Prod: 5.42550
Val   Loss: 8.32266 | Time: 6.90026 | Event: 0.00091 | Prod: 2.84299
Saved best model -> checkpoints/best_model_epoch2.pt

=== Epoch 3/10 ===


train:   0%|          | 0/9253 [00:00<?, ?it/s]

val:   0%|          | 0/1157 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive

    Traceback (most recent call last):
Exception ignored in: assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
     
 Traceback (most recent call last):
 self._shutdown_workers()  File "/usr/local/lib/python3.11/dist-packages/to

Train Loss: 8.61327 | Time: 7.70422 | Event: 0.00460 | Prod: 1.80890
Val   Loss: 7.40640 | Time: 6.89683 | Event: 0.00063 | Prod: 1.01787
Saved best model -> checkpoints/best_model_epoch3.pt

=== Epoch 4/10 ===


train:   0%|          | 0/9253 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():Exception ignored in: 
 <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
  Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      self._shutdown_workers() 
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
^^    ^if w.is_alive():
 ^   ^ ^ ^ ^^^^^^^
^  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    ^assert self._parent_pid == os.getpid(), 'can only test a child process'^
^ ^ ^  ^ ^ ^  
    File "/usr/

val:   0%|          | 0/1157 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
   Exception ignored in:   <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
     ^self._shutdown_workers()^^
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
^    if w.is_alive():^Exception ignored in: 
^<function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80> ^^
 ^ Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataload

Train Loss: 8.16838 | Time: 7.80050 | Event: 0.00274 | Prod: 0.73028
Val   Loss: 7.29101 | Time: 7.01547 | Event: 0.00004 | Prod: 0.55100
Saved best model -> checkpoints/best_model_epoch4.pt

=== Epoch 5/10 ===


train:   0%|          | 0/9253 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>^
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
^^    ^self._shutdown_workers()^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    ^^if w.is_alive():

   File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
     assert self._parent_pid == os.getpid(), 'can only test a child process' 
       ^^ ^^^ ^ ^   ^^  ^^^^^^^
^  File 

val:   0%|          | 0/1157 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
     Exception ignored in:   <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>^
^Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
^    ^Exception ignored in: self._shutdown_workers()<function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>^
^
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del_

Train Loss: 8.12798 | Time: 7.91846 | Event: 0.00222 | Prod: 0.41461
Val   Loss: 7.15317 | Time: 6.96389 | Event: 0.00008 | Prod: 0.37839
Saved best model -> checkpoints/best_model_epoch5.pt

=== Epoch 6/10 ===


train:   0%|          | 0/9253 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80><function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80><function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>


Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
            self._shutdown_workers()
self._shutdown_workers()self._shutdown_workers()  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers


  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
  F

val:   0%|          | 0/1157 [00:00<?, ?it/s]

Train Loss: 7.80648 | Time: 7.65775 | Event: 0.00202 | Prod: 0.29344
Val   Loss: 9.20951 | Time: 9.05872 | Event: 0.00010 | Prod: 0.30138

=== Epoch 7/10 ===


train:   0%|          | 0/9253 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

val:   0%|          | 0/1157 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
     Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80> 
   Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
^    ^self._shutdown_workers()^^
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
^    ^if w.is_alive():
^^ ^ ^ ^^ ^ ^^ Exception 

Train Loss: 8.02662 | Time: 7.90707 | Event: 0.00199 | Prod: 0.23512
Val   Loss: 8.80110 | Time: 8.68420 | Event: 0.00008 | Prod: 0.23364

=== Epoch 8/10 ===


train:   0%|          | 0/9253 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

val:   0%|          | 0/1157 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^Exception ignored in: ^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>^
^Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
^    ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
^    if w.is_alive():^
^ 
   File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
      assert self._parent_pid == os.getpid(), 'can only test a child process' 
   ^^^ ^ ^  ^^  ^ ^ ^ Exception ignore

Train Loss: 8.39246 | Time: 8.30662 | Event: 0.00120 | Prod: 0.16930
Val   Loss: 10.17707 | Time: 10.07369 | Event: 0.00007 | Prod: 0.20661

=== Epoch 9/10 ===


train:   0%|          | 0/9253 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

val:   0%|          | 0/1157 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>^
^^Traceback (most recent call last):
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>    ^self._shutdown_workers()


Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
Exception ignored in:   File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is

Train Loss: 8.31168 | Time: 8.23076 | Event: 0.00129 | Prod: 0.15925
Val   Loss: 10.02379 | Time: 9.92014 | Event: 0.00002 | Prod: 0.20726

=== Epoch 10/10 ===


train:   0%|          | 0/9253 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80><function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>
Traceback (most recent call last):

  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
Traceback (most recent call last):
      File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
Exception ignored in: Exception ignored in: self._shutdown_workers()    <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80><function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>

self._shutdown_workers()
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__

  File "/usr/local/l

val:   0%|          | 0/1157 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80><function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80><function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>


Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
            self._shutdown_workers()self._shutdown_workers()self._shutdown_workers()

Exception ignored in:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers

<function _MultiProcessingDataLoaderIter.__del__ at 0x7a77a50bfd80>  File "/usr/local/lib/pyth

Train Loss: 8.49960 | Time: 8.42895 | Event: 0.00091 | Prod: 0.13948
Val   Loss: 12.50792 | Time: 12.41239 | Event: 0.00001 | Prod: 0.19104


**Test Dataset**

In [56]:
# 3. Load Your Best Checkpoint
best_model_path = "checkpoints/best_model_epoch5.pt" 
model.load_state_dict(torch.load(best_model_path)['model_state_dict'])

print(f"Successfully loaded best model from {best_model_path}")

# We can just re-use your existing 'validate' function
test_metrics = validate(model, test_loader) 

print("Final Test Set Results")
print(f"  Test Loss: {test_metrics['loss']:.5f}")
print(f"  Test Time Loss: {test_metrics['time_loss']:.5f}")
print(f"  Test Event Loss: {test_metrics['event_loss']:.5f}")
print(f"  Test Product Loss: {test_metrics['product_loss']:.5f}")

print("Final Test Set Accuracy")
print(f"  Event Accuracy: {test_metrics['event_accuracy']*100:.2f}%")
print(f"  Product Accuracy: {test_metrics['product_accuracy']*100:.2f}%")

Successfully loaded best model from checkpoints/best_model_epoch5.pt


                                                        

Final Test Set Results
  Test Loss: 7.35122
  Test Time Loss: 7.15859
  Test Event Loss: 0.00019
  Test Product Loss: 0.38488
Final Test Set Accuracy
  Event Accuracy: 99.99%
  Product Accuracy: 96.26%




**Saving Artifacts**

In [57]:
import pickle

# 1. Save the encoders dictionary
with open("encoders.pkl", "wb") as f:
    pickle.dump(encoders, f)

# 2. Save the scaler
with open("scaler.pkl", "wb") as f:
    pickle.dump(scaler, f)

print("Artifacts saved.")

Artifacts saved.


**Model Predictions & Output**

In [None]:

import pandas as pd
import pickle
import random
from tqdm import tqdm

# Load All Artifacts
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load encoders
with open("encoders.pkl", 'rb') as f:
    encoders = pickle.load(f)

# Load scaler
with open("scaler.pkl", 'rb') as f:
    scaler = pickle.load(f)

# Initialize Model and Load Best Weights
num_users = len(encoders['user_id'].classes_) + 1
num_products = len(encoders['product_id'].classes_) + 1
num_event_types = len(encoders['event_type'].classes_) + 1
num_categories = len(encoders['category_id'].classes_) + 1

# Get numeric feature count
numeric_feature_names = [
    'price', 'user_recency', 'session_duration', 'session_length',
    'user_avg_price', 'user_cart_to_purchase_ratio', 'user_remove_rate',
    'user_avg_views_per_session', 'product_avg_price', 'view_to_cart_ratio',
    'product_cart_to_purchase_ratio', 'product_remove_rate', 'repeat_purchase_rate',
    'hour_sin', 'hour_cos', 'dow_sin', 'dow_cos'
]
num_numeric_features = len(numeric_feature_names)

# Instantiate the model
model = NeuralHawkesTransformer(
    num_users=num_users,
    num_products=num_products,
    num_event_types=num_event_types,
    num_categories=num_categories,
    num_numeric_features=num_numeric_features
).to(device)

# Load the best checkpoint from training
best_model_path = "checkpoints/best_model_epoch5.pt" 
model.load_state_dict(torch.load(best_model_path, map_location=device)['model_state_dict'])
model.eval() # Set model to evaluation mode

print(f"Model {best_model_path} and artifacts loaded.")

# Define the Prediction Function
def predict_for_user(user_id_original, model, encoders, numeric_feature_names, seq_len=50):
    
    # 1. Find which split the user is in to get their processed data
    user_df = None
    if user_id_original in train_users:
        user_df = train_df[train_df['user_id'] == user_id_original].sort_values("event_time")
    elif user_id_original in val_users:
        user_df = val_df[val_df['user_id'] == user_id_original].sort_values("event_time")
    elif user_id_original in test_users:
        user_df = test_df[test_df['user_id'] == user_id_original].sort_values("event_time")

    if user_df is None or user_df.empty:
        print(f"\n--- Prediction for User: {user_id_original} ---")
        print("  Error: No event data found for this user in any split.")
        return

    # 2. Prepare the batch
    if len(user_df) > seq_len:
        user_df_seq = user_df.tail(seq_len)
    else:
        user_df_seq = user_df
    
    seq_len_actual = len(user_df_seq)
    pad_len = seq_len - seq_len_actual
    
    batch = {}
    
    # Get categorical/ID features
    batch['user_ids'] = torch.tensor(user_df_seq["user_id_enc"].values, dtype=torch.long)
    batch['product_ids'] = torch.tensor(user_df_seq["product_id_enc"].values, dtype=torch.long)
    batch['event_types'] = torch.tensor(user_df_seq["event_type_enc"].values, dtype=torch.long)
    batch['categories'] = torch.tensor(user_df_seq["category_id_enc"].values, dtype=torch.long)
    batch['time_deltas'] = torch.tensor(user_df_seq["delta_t"].values, dtype=torch.float32).unsqueeze(-1)
    
    # Get numeric features
    numeric_data = user_df_seq[numeric_feature_names].values
    batch['numeric_feats'] = torch.tensor(numeric_data, dtype=torch.float32)
    
    # Create mask
    batch['mask'] = torch.cat([torch.ones(seq_len_actual), torch.zeros(pad_len)], dim=0).float()
    
    # Apply padding
    for key, tensor in batch.items():
        if key == "mask": continue
        if pad_len > 0:
            if key == "numeric_feats": pad_tensor = torch.zeros(pad_len, tensor.shape[1])
            elif key == "time_deltas": pad_tensor = torch.zeros(pad_len, 1, dtype=tensor.dtype)
            else: pad_tensor = torch.zeros(pad_len, dtype=tensor.dtype)
            batch[key] = torch.cat([tensor, pad_tensor], dim=0)

    # Add batch dimension and move to device
    for key, tensor in batch.items():
        batch[key] = tensor.unsqueeze(0).to(device)

    # 3. Get model prediction
    with torch.no_grad():
        intensity, event_logits, product_logits = model(
            batch['user_ids'], batch['product_ids'], batch['event_types'],
            batch['categories'], batch['numeric_feats'], batch['time_deltas'],
            src_key_padding_mask=(batch['mask'] == 0)
        )
    
    # 4. Decode the prediction for the last real event
    last_real_event_idx = seq_len_actual - 1
    
    last_intensity = intensity[0, last_real_event_idx, 0].item()
    last_event_logits = event_logits[0, last_real_event_idx]
    last_product_logits = product_logits[0, last_real_event_idx]
    
    # Decode Event
    event_probs = F.softmax(last_event_logits, dim=0)
    top_event_prob, top_event_idx = torch.max(event_probs, dim=0)
    pred_event_type = encoders['event_type'].inverse_transform([top_event_idx.item()])[0]
    
    # Decode Product
    product_probs = F.softmax(last_product_logits, dim=0)
    top_prod_prob, top_prod_idx = torch.max(product_probs, dim=0)
    pred_product_id = encoders['product_id'].inverse_transform([top_prod_idx.item()])[0]
    
    # Decode Time
    last_raw_delta_t_hours = user_df_seq.iloc[-1]["delta_t_raw"] / 3600
    
    # Print Results
    print(f"\n Prediction for User: {user_id_original}")
    #print(f"Last Event: '{user_df_seq.iloc[-1]['event_type']}' (Product {user_df_seq.iloc[-1]['product_id']}) {last_raw_delta_t_hours:.2f} hours after previous event.")
    
    print("\nModel Predicts:")
    print(f"  ‚û°Ô∏è Most Likely Next Event: '{pred_event_type}' (Confidence: {top_event_prob.item()*100:.2f}%)")
    print(f"  ‚û°Ô∏è Most Likely Next Product: {pred_product_id} (Confidence: {top_prod_prob.item()*100:.2f}%)")
    print(f"  ‚û°Ô∏è Time Intensity: {last_intensity:.4f} (A higher value means the event is expected sooner)")
    
    # Show top 3 product predictions
    top_3_prod_probs, top_3_prod_indices = torch.topk(product_probs, 3)
    top_3_product_ids = encoders['product_id'].inverse_transform(top_3_prod_indices.cpu().numpy())
    
    print("\n  Top 3 Product Suggestions:")
    for i in range(3):
        print(f"    {i+1}. {top_3_product_ids[i]} (Confidence: {top_3_prod_probs[i].item()*100:.2f}%)")

# Run Predictions from Test Dataset
print("\n OUTPUT ON TEST SET USERS")

sample_test_users = random.sample(list(test_users), 10)
for user_id_orig in sample_test_users:
    predict_for_user(
        user_id_orig, 
        model, 
        encoders, 
        numeric_feature_names,
        seq_len=seq_len        
    )
    print("-" * 50)

Model checkpoints/best_model_epoch5.pt and artifacts loaded.

 OUTPUT ON TEST SET USERS

 Prediction for User: 591565561

Model Predicts:
  ‚û°Ô∏è Most Likely Next Event: 'view' (Confidence: 100.00%)
  ‚û°Ô∏è Most Likely Next Product: 5765554 (Confidence: 99.60%)
  ‚û°Ô∏è Time Intensity: 14.2343 (A higher value means the event is expected sooner)

  Top 3 Product Suggestions:
    1. 5765554 (Confidence: 99.60%)
    2. 5773606 (Confidence: 0.08%)
    3. 5808493 (Confidence: 0.03%)
--------------------------------------------------

 Prediction for User: 524703210

Model Predicts:
  ‚û°Ô∏è Most Likely Next Event: 'view' (Confidence: 100.00%)
  ‚û°Ô∏è Most Likely Next Product: 5864653 (Confidence: 90.57%)
  ‚û°Ô∏è Time Intensity: 17.7921 (A higher value means the event is expected sooner)

  Top 3 Product Suggestions:
    1. 5864653 (Confidence: 90.57%)
    2. 5911266 (Confidence: 0.40%)
    3. 5817692 (Confidence: 0.20%)
--------------------------------------------------

 Prediction for

In [65]:
import torch
from tqdm.auto import tqdm
from torch.nn.functional import softmax
from sklearn.metrics import mean_squared_error, mean_absolute_error
import numpy as np

K = 3 
eps = 1e-9

# Precision/Recall
def calculate_top_k_metrics_only_product(product_logits, product_ids, mask, k):
    
    # Ignore padded events
    relevant_mask = mask.bool()
    
    # 1. Get the top K predicted product indices
    # topk_indices shape: (B, L, K)
    topk_indices = torch.topk(product_logits, k=k, dim=2).indices
    
    # 2. Expand targets for comparison: product_ids shape (B, L) -> (B, L, 1)
    target_ids = product_ids.unsqueeze(-1)
    
    # 3. Check if the true target ID is present anywhere in the top K predictions
    # hit_mask shape: (B, L, K) -> sum over K -> (B, L) bool
    hit_mask = (topk_indices == target_ids).sum(dim=2).bool()
    
    # Apply the mask to only consider real events
    valid_hits = hit_mask * relevant_mask
    
    # Total events where a recommendation was possible
    total_events = relevant_mask.sum().item()
    
    if total_events == 0:
        return 0.0, 0.0, 0 # Precision, Recall, Event Count
    
    # Recall sum (Total Hits)
    recall_sum = valid_hits.sum().item()
    
    # Precision@K is typically (Hits / (K * Number of Samples))
    precision = recall_sum / (k * total_events)
    # Recall@K (Hit Rate) is Total Hits / Total Relevant Events
    recall = recall_sum / total_events

    return precision, recall, total_events


# Evaluation Function

def evaluate_test_set(model, test_loader, k):
    
    model.eval()
    total_p = 0.0
    total_r = 0.0
    total_events_counted = 0

    # Ensure device is set
    device = next(model.parameters()).device
    
    with torch.no_grad():
        pbar = tqdm(test_loader, desc=f"Final Test Evaluation (K={k})", leave=True)
        for batch in pbar:
            # 1. Unpack and move data to the correct device
            user_ids = batch["user_ids"].to(device)
            product_ids = batch["product_ids"].to(device)
            event_types = batch["event_types"].to(device)
            categories = batch.get("categories", torch.zeros_like(user_ids)).to(device)
            numeric_feats = batch.get("numeric_feats", torch.zeros(user_ids.size(0), user_ids.size(1), 0)).to(device)
            deltas = batch["time_deltas"].to(device)
            raw_deltas = batch["raw_time_deltas"].to(device)
            mask = batch["mask"].to(device)

            # 2. Forward Pass
            intensity, _, product_logits = model(
                user_ids, product_ids, event_types, categories,
                numeric_feats, deltas, src_key_padding_mask=(mask == 0))
            
            # Product Ranking
            batch_p, batch_r, batch_events = calculate_top_k_metrics_only_product(
                product_logits, product_ids, mask, k
            )
            total_p += batch_p * batch_events # Un-normalize, then re-normalize later
            total_r += batch_r * batch_events
            total_events_counted += batch_events
            

    final_metrics = {}
    
    # Ranking Metrics
    if total_events_counted > 0:
        final_metrics[f"Precision@{k}"] = total_p / total_events_counted
        final_metrics[f"Recall@{k}"] = total_r / total_events_counted
    else:
        final_metrics[f"Precision@{k}"] = 0.0
        final_metrics[f"Recall@{k}"] = 0.0

    return final_metrics

final_test_metrics = evaluate_test_set(model, test_loader, k=K)

print("Product Ranking Metrics:")
print(f"  - Precision@{K}: {final_test_metrics[f'Precision@{K}']*100:.3f}%")
print(f"  - Recall@{K}:    {final_test_metrics[f'Recall@{K}']*100:.3f}%")

Final Test Evaluation (K=3):   0%|          | 0/1157 [00:00<?, ?it/s]

Product Ranking Metrics:
  - Precision@3: 32.296%
  - Recall@3:    96.889%
