# Initialization

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import datetime as dt
import numpy as np
from copy import deepcopy
import copy

In [2]:
import argparse
import torch
import numpy as np
import os
import datetime
import torch.nn as nn
import torchvision
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import time
import math

from collections import OrderedDict
from typing import List, Tuple, Union
import matplotlib.pyplot as plt

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")

Training on cuda:0


In [3]:
import random
SEED = 42

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

## Set Hyperparameters

In [4]:
input_size = 889
hidden_size = 100
output_size = input_size
batch_size = 8
lr = 0.01
numrounds = 10
USE_CUDA = torch.cuda.is_available()

# Load Data

In [5]:
train_data = np.load(f'./train_data_diginetica.npy', allow_pickle=True)
valid_data = np.load(f'./test_data_diginetica.npy', allow_pickle=True)

In [6]:
train_data[0].columns

Index(['sessionId', 'userId', 'itemId', 'timeframe', 'time', 'userId2',
       'delta_t_a', 'delta_t_b', 'h_a', 'm_a', 's_a', 'h_b', 'm_b', 's_b'],
      dtype='object')

In [7]:
print(f"len train: {len(train_data)}; len test: {len(valid_data)}")

len train: 45; len test: 45


In [8]:
max_user = len(valid_data)
max_user

45

In [9]:
# concat all train data as one dataframe
train_combined = np.concatenate(train_data)
#convert to dataframe
train_combined = pd.DataFrame(train_combined)
train_combined.shape

(1455, 14)

In [10]:
train_combined[2].nunique()

889

In [11]:
# Step 1: Extract unique item IDs from the combined DataFrame
all_unique_items = train_combined[2].unique()

# Step 2: Create a universal item index mapping
universal_item_map = pd.DataFrame({
    'item_idx': np.arange(len(all_unique_items)),
    'itemId': all_unique_items
})

In [12]:
universal_item_map

Unnamed: 0,item_idx,itemId
0,0,115599.0
1,1,79898.0
2,2,35039.0
3,3,11604.0
4,4,87524.0
...,...,...
884,884,3694.0
885,885,90072.0
886,886,10440.0
887,887,35015.0


# DataLoader Preparation

In [13]:
class Dataset(object):
    def __init__(self, path, sep=',', session_key='sessionId', item_key='itemId', time_key='time', n_sample=-1, itemmap=None, itemstamp=None, time_sort=False):
        # Read csv
        #self.df = pd.read_csv(path, sep=sep, dtype={session_key: int, item_key: int, time_key: float})
        self.df = path
        self.session_key = session_key
        self.item_key = item_key
        self.time_key = time_key
        self.time_sort = time_sort
        if n_sample > 0:
            self.df = self.df[:n_sample]

        # Add colummn item index to data
        self.add_item_indices(itemmap=itemmap)
        """
        Sort the df by time, and then by session ID. That is, df is sorted by session ID and
        clicks within a session are next to each other, where the clicks within a session are time-ordered.
        """
        self.df.sort_values([session_key, time_key], inplace=True)
        self.click_offsets = self.get_click_offset()
        self.session_idx_arr = self.order_session_idx()

    def add_item_indices(self, itemmap=None):
        """
        Add item index column named "item_idx" to the df
        Args:
            itemmap (pd.DataFrame): mapping between the item Ids and indices
        """
        if itemmap is None:
            item_ids = self.df[self.item_key].unique()  # type is numpy.ndarray
            item2idx = pd.Series(data=np.arange(len(item_ids)),
                                 index=item_ids)
            # Build itemmap is a DataFrame that have 2 columns (self.item_key, 'item_idx)
            itemmap = pd.DataFrame({self.item_key: item_ids,
                                   'item_idx': item2idx[item_ids].values})
        self.itemmap = itemmap
        self.df = pd.merge(self.df, self.itemmap, on=self.item_key, how='inner')

    def get_click_offset(self):
        """
        self.df[self.session_key] return a set of session_key
        self.df[self.session_key].nunique() return the size of session_key set (int)
        self.df.groupby(self.session_key).size() return the size of each session_id
        self.df.groupby(self.session_key).size().cumsum() retunn cumulative sum
        """
        offsets = np.zeros(self.df[self.session_key].nunique() + 1, dtype=np.int32)
        offsets[1:] = self.df.groupby(self.session_key).size().cumsum()
        return offsets

    def order_session_idx(self):
        if self.time_sort:
            sessions_start_time = self.df.groupby(self.session_key)[self.time_key].min().values
            session_idx_arr = np.argsort(sessions_start_time)
        else:
            session_idx_arr = np.arange(self.df[self.session_key].nunique())
        return session_idx_arr
    
    def __len__(self):
        return len(self.session_idx_arr)

    @property
    def items(self):
        return self.itemmap[self.item_key].unique()

In [14]:
class GRUDataset(Dataset):
    def __init__(self, data, itemmap, session_key='sessionId', item_key='itemId', time_key='time'):
        self.data = data
        self.itemmap = itemmap
        self.session_key = session_key
        self.item_key = item_key
        self.time_key = time_key

        # Map items to indices
        self.data = pd.merge(self.data, self.itemmap, on=self.item_key, how='inner')

        # Sort by session and time
        self.data.sort_values([self.session_key, self.time_key], inplace=True)

        # Group data by session and collect item indices
        self.sessions = self.data.groupby(self.session_key)['item_idx'].apply(list)

    def __len__(self):
        return len(self.sessions)

    def __getitem__(self, index):
        session_items = self.sessions.iloc[index]
        sequence = torch.tensor(session_items[:-1], dtype=torch.long)
        target = torch.tensor(session_items[1:], dtype=torch.long)
        return sequence, target

In [15]:
# class DataLoader():
#     def __init__(self, dataset, batch_size=1):
#         """
#         A class for creating session-parallel mini-batches.

#         Args:
#              dataset (SessionDataset): the session dataset to generate the batches from
#              batch_size (int): size of the batch
#         """
#         self.dataset = dataset
#         self.batch_size = batch_size

#     def __iter__(self):
#         """ Returns the iterator for producing session-parallel training mini-batches.

#         Yields:
#             input (B,): torch.FloatTensor. Item indices that will be encoded as one-hot vectors later.
#             target (B,): a Variable that stores the target item indices
#             masks: Numpy array indicating the positions of the sessions to be terminated
#         """
#         # initializations
#         df = self.dataset.df
#         click_offsets = self.dataset.click_offsets
#         session_idx_arr = self.dataset.session_idx_arr

#         iters = np.arange(self.batch_size)
#         maxiter = iters.max()
#         start = click_offsets[session_idx_arr[iters]]
#         end = click_offsets[session_idx_arr[iters] + 1]
#         mask = []  # indicator for the sessions to be terminated
#         finished = False

#         while not finished:
#             minlen = (end - start).min()
#             # Item indices(for embedding) for clicks where the first sessions start
#             idx_target = df.item_idx.values[start]

#             for i in range(minlen - 1):
#                 # Build inputs & targets
#                 idx_input = idx_target
#                 idx_target = df.item_idx.values[start + i + 1]
#                 input = torch.LongTensor(idx_input)
#                 target = torch.LongTensor(idx_target)
#                 yield input, target, mask

#             # click indices where a particular session meets second-to-last element
#             start = start + (minlen - 1)
#             # see if how many sessions should terminate
#             mask = np.arange(len(iters))[(end - start) <= 1]
#             for idx in mask:
#                 maxiter += 1
#                 if maxiter >= len(click_offsets) - 1:
#                     finished = True
#                     break
#                 # update the next starting/ending point
#                 iters[idx] = maxiter
#                 start[idx] = click_offsets[session_idx_arr[maxiter]]
#                 end[idx] = click_offsets[session_idx_arr[maxiter] + 1]

#     def __len__(self):
#         # Return the number of batches in the dataset
#         return (len(self.dataset) + self.batch_size - 1) // self.batch_size

In [16]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def collate_fn(batch):
    sequences, targets = zip(*batch)
    sequences_padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    targets_padded = pad_sequence(targets, batch_first=True, padding_value=-1)
    return sequences_padded, targets_padded

def get_loader(data, itemmap, batch_size=32, shuffle=True):
    dataset = GRUDataset(data, itemmap=itemmap)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)


# Model Architecture

In [17]:
class SASRec(nn.Module):
    def __init__(self, item_size, hidden_size, num_blocks, num_heads, dropout=0.5):
        super(SASRec, self).__init__()
        self.item_size = item_size
        self.hidden_size = hidden_size
        self.num_blocks = num_blocks
        self.num_heads = num_heads

        # Item embedding layer
        self.item_embedding = nn.Embedding(item_size, hidden_size)

        # Positional encoding
        self.position_embedding = nn.Embedding(1000, hidden_size)  # Assuming max sequence length is 1000

        # Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                dropout=dropout
            ) for _ in range(num_blocks)
        ])

        # Output fully connected layer
        self.fc = nn.Linear(hidden_size, item_size)

    def forward(self, x):
        batch_size, max_seq_len = x.size()
        print(f"x size: {x.size()}")

        positions = torch.arange(max_seq_len, dtype=torch.long, device=x.device)
        positions = positions.unsqueeze(0).expand_as(x)

        # Create padding mask before transposing x
        # It should be True where the input is a padding token (0)
        padding_mask = (x == 0)

        x = self.item_embedding(x) + self.position_embedding(positions)

        # Transpose x to match the 'sequence first' format
        x = x.transpose(0, 1)  # shape becomes (max_seq_len, batch_size, hidden_size)

        for transformer in self.transformer_blocks:
            x = transformer(x, src_key_padding_mask=padding_mask)

        # Transpose back to 'batch first' for the final layer
        x = x.transpose(0, 1)

        x = self.fc(x)
        return x

In [18]:
# class PointWiseFeedForward(nn.Module):
#     def __init__(self, hidden_units, dropout_rate):
#         super(PointWiseFeedForward, self).__init__()

#         self.conv1 = nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
#         self.dropout1 = nn.Dropout(p=dropout_rate)
#         self.relu = nn.ReLU()
#         self.conv2 = nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
#         self.dropout2 = nn.Dropout(p=dropout_rate)

#     def forward(self, inputs):
#         outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
#         outputs = outputs.transpose(-1, -2)  # as Conv1D requires (N, C, Length)
#         outputs += inputs
#         return outputs

# class SASRec(nn.Module):
#     def __init__(self, item_size, hidden_size, num_blocks, num_heads, dropout=0.5):
#         super(SASRec, self).__init__()
#         self.item_size = item_size
#         self.hidden_size = hidden_size
#         self.num_blocks = num_blocks
#         self.num_heads = num_heads

#         # Item and positional embeddings
#         self.item_embedding = nn.Embedding(item_size, hidden_size)
#         self.position_embedding = nn.Embedding(1000, hidden_size)  # Assuming max sequence length is 1000

#         self.emb_dropout = nn.Dropout(p=dropout)

#         self.transformer_blocks = nn.ModuleList()
#         for _ in range(num_blocks):
#             self.transformer_blocks.append(
#                 nn.TransformerEncoderLayer(
#                     d_model=hidden_size,
#                     nhead=num_heads,
#                     dim_feedforward=hidden_size,
#                     dropout=dropout
#                 )
#             )

#             self.transformer_blocks.append(PointWiseFeedForward(hidden_size, dropout))

#         self.last_layernorm = nn.LayerNorm(hidden_size, eps=1e-8)
#         self.fc = nn.Linear(hidden_size, item_size)

#     def forward(self, x):
#         batch_size, seq_len = x.size()

#         # Positional Encoding
#         positions = torch.arange(seq_len, dtype=torch.long, device=x.device)
#         positions = positions.unsqueeze(0).expand_as(x)

#         # Item Embedding + Positional Embedding
#         x = self.item_embedding(x) + self.position_embedding(positions)
#         x = self.emb_dropout(x)

#         # Transpose to format (seq_len, batch_size, hidden_size) for Transformer
#         x = x.transpose(0, 1)

#         # Create a mask to prevent attention to padded positions (assuming padding token index is 0)
#         padding_mask = (x == 0).all(dim=2).transpose(0, 1)  # Transposing to shape (batch_size, seq_len)

#         # Attention Mask for causal (auto-regressive) model
#         src_mask = torch.triu(torch.ones((seq_len, seq_len), device=x.device), diagonal=1).bool()

#         # Pass through each layer in the transformer block
#         for layer in self.transformer_blocks:
#             if isinstance(layer, nn.TransformerEncoderLayer):
#                 x = layer(x, src_mask=src_mask, src_key_padding_mask=padding_mask)
#             else:
#                 # Pointwise feedforward layers
#                 x = layer(x)

#         # Transpose back to (batch_size, seq_len, hidden_size)
#         x = x.transpose(0, 1)

#         # Output layer
#         x = self.fc(x)

#         return x

## Loss Function

In [19]:
class TOP1MaxLoss(torch.nn.Module):
    def __init__(self):
        super(TOP1MaxLoss, self).__init__()

    def forward(self, scores, targets):
        # Initialize loss
        loss = 0.0

        # Loop over each element in the batch
        for i in range(scores.size(0)):  # Loop over batch
            for j in range(targets.size(1)):  # Loop over sequence
                if targets[i, j] == -1:  # Skip padding
                    continue

                # Get the score of the target item
                pos_score = scores[i, targets[i, j]]

                # Calculate the difference with all other items
                diff = -torch.sigmoid(pos_score - scores[i])

                # Exclude the positive item from the loss
                diff[targets[i, j]] = 0

                # Add to the total loss
                loss += torch.sum(diff)

        # Average the loss
        loss = loss / (scores.size(0) * targets.size(1))

        return loss

# Train & Test

In [20]:
def evaluate(net, dataloader, device, k):
    """Evaluate the network on the given data loader for top-k recommendation."""
    net.to(device)
    net.eval()
    criterion = nn.CrossEntropyLoss(ignore_index=-1)  # Replace with your loss function
    total_recall = 0.0
    total_mrr = 0.0
    total_count = 0
    total_loss = 0.0

    with torch.no_grad():
        for x, y in dataloader:
            data, target = x.to(device), y.to(device)
            outputs = net(data)
            outputs = outputs.view(-1, outputs.size(-1))  # Flatten output
            target2 = target.view(-1)  # Flatten target

            # Calculate total loss
            total_loss += criterion(outputs, target2).item()

            # Select top-k items
            _, top_k_indices = torch.topk(outputs, k, dim=1)

            # Calculate recall and MRR for each batch
            for i in range(data.size(0)):
                for target_item in target[i]:
                    if target_item == -1:  # Skip padding or any special token
                        continue
                    target_item_scalar = target_item.item()
                    top_k_items = top_k_indices[i].tolist()

                    # Calculate Recall@k
                    if target_item_scalar in top_k_items:
                        total_recall += 1

                    # Calculate MRR@k
                    if target_item_scalar in top_k_items:
                        rank = top_k_items.index(target_item_scalar)
                        total_mrr += 1 / (rank + 1)

                total_count += len(target[i][target[i] != -1])  # Count non-padding elements

    avg_recall = total_recall / total_count
    avg_mrr = total_mrr / total_count
    avg_loss = total_loss / len(dataloader)

    results = {
        'recall': avg_recall,
        'mrr': avg_mrr
    }

    return avg_loss, results


In [21]:
def train(net, trainloader, epochs, device, valloader=None):
    """Train the network for session-based recommendation."""
    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=-1)  # Assuming -1 is your padding index
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)  # Adam optimizer is often used for Transformers

    print(f"Training {epochs} epoch(s) w/ {len(trainloader)} batches each")
    start_time = time.time()

    net.to(device)
    net.train()

    for epoch in range(epochs):
        total_loss = 0.0

        for x, y in trainloader:
            data, target = x.to(device), y.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = net(data)
            outputs = outputs.view(-1, outputs.size(-1))  # Flatten output for cross-entropy
            target = target.view(-1)  # Flatten target

            # Compute loss and backpropagate
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Calculate metrics
        val_loss, val_results = evaluate(net, valloader, device, k=3)

        print(f"Epoch {epoch + 1}: Loss: {total_loss / len(trainloader):.4f}, Recall: {val_results['recall']:.4f}, MRR: {val_results['mrr']:.4f}")

        net.train()  # Ensure the network is in training mode

    total_time = time.time() - start_time
    net.to("cpu")  # Move model back to CPU

    print(f"Training completed in {total_time:.2f} seconds")

    return val_results

# Solo Training

In [22]:
item_size = len(universal_item_map)  # Assuming universal_item_map contains all unique items
hidden_size = 400  # Size of embeddings and hidden layers
num_blocks = 3  # Number of Transformer layers
num_heads = 4  # Number of heads in the multi-head attention mechanism
dropout_rate = 0.5  # Dropout rate

In [23]:
#list of recall and mrr for each user
recall_list = []
mrr_list = []

for i in range(max_user):
    print(f"Training on user {i}...")
    local_train = train_data[i]
    local_test = valid_data[i]

    trainloader = get_loader(local_train, itemmap=universal_item_map, batch_size=batch_size)
    testloader = get_loader(local_test, itemmap=universal_item_map, batch_size=batch_size)

    # Initialize the network
    net = SASRec(item_size=item_size, 
                      hidden_size=hidden_size, 
                      num_blocks=num_blocks, 
                      num_heads=num_heads, 
                      dropout=dropout_rate)
    net.to(DEVICE)

    # Train the network
    train_res = train(net, trainloader, numrounds, DEVICE, testloader)

    # Evaluate the network
    loss, results = evaluate(net, testloader, DEVICE, k=5)

    print(f"Recall@5: {results['recall']:.4f}")
    print(f"MRR@5: {results['mrr']:.4f}")

    recall_list.append(results['recall'])
    mrr_list.append(results['mrr'])

print(f"Average Recall@5: {np.mean(recall_list):.4f}")
print(f"Average MRR@5: {np.mean(mrr_list):.4f}")

Training on user 0...
Training 10 epoch(s) w/ 1 batches each
x size: torch.Size([5, 8])
x size: torch.Size([1, 4])
Epoch 1: Loss: 6.7559, Recall: 0.5000, MRR: 0.2500
x size: torch.Size([5, 8])
x size: torch.Size([1, 4])
Epoch 2: Loss: 5.0779, Recall: 1.0000, MRR: 0.6667
x size: torch.Size([5, 8])
x size: torch.Size([1, 4])
Epoch 3: Loss: 4.0281, Recall: 0.5000, MRR: 0.1667
x size: torch.Size([5, 8])
x size: torch.Size([1, 4])
Epoch 4: Loss: 3.0280, Recall: 0.5000, MRR: 0.2500
x size: torch.Size([5, 8])
x size: torch.Size([1, 4])
Epoch 5: Loss: 3.0327, Recall: 0.0000, MRR: 0.0000
x size: torch.Size([5, 8])
x size: torch.Size([1, 4])
Epoch 6: Loss: 3.1781, Recall: 0.0000, MRR: 0.0000
x size: torch.Size([5, 8])
x size: torch.Size([1, 4])
Epoch 7: Loss: 3.0014, Recall: 0.0000, MRR: 0.0000
x size: torch.Size([5, 8])
x size: torch.Size([1, 4])
Epoch 8: Loss: 2.9383, Recall: 0.5000, MRR: 0.2500
x size: torch.Size([5, 8])
x size: torch.Size([1, 4])
Epoch 9: Loss: 3.0131, Recall: 0.5000, MRR: 0

In [24]:
# # Define the range of values for each hyperparameter
# hidden_sizes = [100, 200, 300, 400]
# num_blocks_list = [2, 3, 4]
# num_heads_list = [2, 4, 6]
# dropout_rates = [0.3, 0.5, 0.7]

# # Variables to store the best hyperparameters and corresponding recall
# best_recall = 0
# best_hyperparameters = {}

# # Loop over each combination of hyperparameters
# for hidden_size in hidden_sizes:
#     for num_blocks in num_blocks_list:
#         for num_heads in num_heads_list:

#             if hidden_size % num_heads != 0:
#                 continue
#             for dropout_rate in dropout_rates:
#                 # Reset the lists to store recall and MRR for each user
#                 recall_list = []
#                 mrr_list = []

#                 for i in range(max_user):
#                     print(f"Training on user {i} with hidden_size={hidden_size}, num_blocks={num_blocks}, num_heads={num_heads}, dropout_rate={dropout_rate}...")

#                     local_train = train_data[i]
#                     local_test = valid_data[i]

#                     trainloader = get_loader(local_train, itemmap=universal_item_map, batch_size=batch_size)
#                     testloader = get_loader(local_test, itemmap=universal_item_map, batch_size=batch_size)

#                     # Initialize the network with current set of hyperparameters
#                     net = SASRec(item_size=item_size, 
#                                  hidden_size=hidden_size, 
#                                  num_blocks=num_blocks, 
#                                  num_heads=num_heads, 
#                                  dropout=dropout_rate)
#                     net.to(DEVICE)

#                     # Train and evaluate the network
#                     train_res = train(net, trainloader, numrounds, DEVICE, testloader)
#                     loss, results = evaluate(net, testloader, DEVICE, k=5)

#                     recall_list.append(results['recall'])
#                     mrr_list.append(results['mrr'])

#                 # Compute average recall
#                 avg_recall = np.mean(recall_list)

#                 # Update best hyperparameters if necessary
#                 if avg_recall > best_recall:
#                     best_recall = avg_recall
#                     best_hyperparameters = {
#                         'hidden_size': hidden_size,
#                         'num_blocks': num_blocks,
#                         'num_heads': num_heads,
#                         'dropout_rate': dropout_rate
#                     }

# # Print out the best hyperparameters
# print("Best Hyperparameters:")
# print(best_hyperparameters)

In [25]:
print(net)

SASRec(
  (item_embedding): Embedding(889, 400)
  (position_embedding): Embedding(1000, 400)
  (transformer_blocks): ModuleList(
    (0-2): 3 x TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=400, out_features=400, bias=True)
      )
      (linear1): Linear(in_features=400, out_features=2048, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
      (linear2): Linear(in_features=2048, out_features=400, bias=True)
      (norm1): LayerNorm((400,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((400,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (dropout2): Dropout(p=0.5, inplace=False)
    )
  )
  (fc): Linear(in_features=400, out_features=889, bias=True)
)


In [26]:
from torchinfo import summary
summary(net)

Layer (type:depth-idx)                                       Param #
SASRec                                                       --
├─Embedding: 1-1                                             355,600
├─Embedding: 1-2                                             400,000
├─ModuleList: 1-3                                            --
│    └─TransformerEncoderLayer: 2-1                          --
│    │    └─MultiheadAttention: 3-1                          641,600
│    │    └─Linear: 3-2                                      821,248
│    │    └─Dropout: 3-3                                     --
│    │    └─Linear: 3-4                                      819,600
│    │    └─LayerNorm: 3-5                                   800
│    │    └─LayerNorm: 3-6                                   800
│    │    └─Dropout: 3-7                                     --
│    │    └─Dropout: 3-8                                     --
│    └─TransformerEncoderLayer: 2-2                          --
│    │  

# Centralized

In [27]:
#combine all train data as one dataframe
train_combined = np.concatenate(train_data)
train_combined = pd.DataFrame(train_combined)

#set the column name
train_combined.columns = train_data[0].columns


#combine all test data as one dataframe
test_combined = np.concatenate(valid_data)
test_combined = pd.DataFrame(test_combined)

#set the column name
test_combined.columns = valid_data[0].columns

trainloader = get_loader(train_combined, itemmap=universal_item_map, batch_size=batch_size)
testloader = get_loader(test_combined, itemmap=universal_item_map, batch_size=batch_size)

# Initialize the network
net = SASRec(item_size=item_size, 
                      hidden_size=hidden_size, 
                      num_blocks=num_blocks, 
                      num_heads=num_heads, 
                      dropout=dropout_rate)
net.to(DEVICE)

# Train the network
train_res = train(net, trainloader, numrounds, DEVICE, testloader)

# Evaluate the network
loss, results = evaluate(net, testloader, DEVICE, k=5)

print(f"Recall@5: {results['recall']:.4f}")
print(f"MRR@5: {results['mrr']:.4f}")


Training 10 epoch(s) w/ 34 batches each
x size: torch.Size([8, 5])
x size: torch.Size([8, 9])
x size: torch.Size([8, 12])
x size: torch.Size([8, 14])
x size: torch.Size([8, 11])
x size: torch.Size([8, 8])
x size: torch.Size([8, 13])
x size: torch.Size([8, 7])
x size: torch.Size([8, 13])
x size: torch.Size([8, 12])
x size: torch.Size([8, 19])
x size: torch.Size([8, 5])
x size: torch.Size([8, 5])
x size: torch.Size([8, 11])
x size: torch.Size([8, 10])
x size: torch.Size([8, 12])
x size: torch.Size([8, 18])
x size: torch.Size([8, 8])
x size: torch.Size([8, 8])
x size: torch.Size([8, 12])
x size: torch.Size([8, 15])
x size: torch.Size([8, 14])
x size: torch.Size([8, 20])
x size: torch.Size([8, 9])
x size: torch.Size([8, 10])
x size: torch.Size([8, 8])
x size: torch.Size([8, 12])
x size: torch.Size([8, 8])
x size: torch.Size([8, 9])
x size: torch.Size([8, 11])
x size: torch.Size([8, 17])
x size: torch.Size([8, 10])
x size: torch.Size([8, 6])
x size: torch.Size([2, 9])
x size: torch.Size([8,

# FL Settings

In [28]:
numrounds = 50

## Client

In [29]:
class Client():
  def __init__(self, client_config:dict):
    # client config as dict to make configuration dynamic
    self.id = client_config["id"]
    self.config = client_config
    self.__model = None

    # check if CUDA is available
    if torch.cuda.is_available():
      self.device = 'cuda'
    else:
       self.device = 'cpu'

    self.train_loader = self.config["train_data"]
    self.valid_loader = self.config["test_data"]

  @property
  def model(self):
    return self.__model

  @model.setter
  def model(self, model):
    self.__model = model

  def __len__(self):
    """Return a total size of the client's local data."""
    return len(self.train_loader.sampler)

  def train(self):
    results = train(net=self.model,
                    trainloader= self.train_loader,
                    epochs= self.config["local_epoch"],
                    device= self.device,
                    valloader= self.valid_loader)
    print(f"Train result client {self.id}: {results}")

  def test(self):
    loss,result = evaluate(net = self.model,
                    dataloader= self.valid_loader,
                    device=self.device, k=5)
    print(f"Test result client {self.id}: {loss, result}")
    return result

## Server

In [30]:
class FedAvg():
  def __init__(self):
    self.globalmodel = SASRec(item_size=item_size, 
                      hidden_size=hidden_size, 
                      num_blocks=num_blocks, 
                      num_heads=num_heads, 
                      dropout=dropout_rate)
    self.rounds = 0
    self.params = {}

    # check if CUDA is available
    if torch.cuda.is_available():
      self.device = 'cuda'
    else:
       self.device = 'cpu'


  def aggregate(self, round):
    #v1:update the aggregate to save the model with round and date indicator
    modelparams = []
    for i in self.params.keys():
      modelparams.append(self.params[i])

    avg_weights = {}
    for name in modelparams[0].keys():
      avg_weights[name] = torch.mean(torch.stack([w[name] for w in modelparams]), dim = 0)

    self.globalmodel.load_state_dict(avg_weights)

    #current timestamp
    current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    # filename = f"{path_glob_m}/global_model_round_{round}_{current_time}.pth"
    # torch.save(self.globalmodel.state_dict(), filename)

  def clientstrain(self, clientconfig):
    clients = clientconfig
    for i in clients.keys():
      test_client = Client(clients[i])
      test_client.model = copy.deepcopy(self.globalmodel)
      test_client.model.to(self.device)
      test_client.train()
      # test_client.test()
      self.params[i] = test_client.model.state_dict()

  def initiate_FL(self, clientconfig, serverdata):
    clients = clientconfig
    print("Round: {}".format(self.rounds))

    print("Obtaining Weights!!")
    self.clientstrain(clients)

    #### Aggregate model
    print("Aggregating Model!!")
    self.aggregate(self.rounds)

    #### Replace parameters with global model parameters
    for i in self.params.keys():
        self.params[i] = self.globalmodel.state_dict()


    servertest = serverdata
    loss, results = evaluate(net = self.globalmodel,
                    dataloader= servertest,
                    device=self.device, k=5)
    print("Round {} metrics:".format(self.rounds))
    print("Server Loss = {}".format(loss))
    print("Server Recall = {}".format(results['recall']))
    print("Round {} finished!".format(self.rounds))
    self.rounds += 1
    return clients, results['recall']

## Main

In [31]:
clients = {}

for i in range(max_user):
  clients[i] = {"id": i, "val_size": 0.25, "batch_size": batch_size, "local_epoch": 1}
  clients[i]['train_data'] = get_loader(train_data[i], itemmap=universal_item_map, batch_size=batch_size)
  clients[i]['test_data'] = get_loader(valid_data[i], itemmap=universal_item_map, batch_size=batch_size)
  print(f"client: {i}")
  print(f"Number of batches in the dataloader train: {len(clients[i]['train_data'])}")
  print(f"Number of batches in the dataloader test: {len(clients[i]['test_data'])}")

serverdata = get_loader(valid_data[37], itemmap=universal_item_map, batch_size=batch_size)
server = FedAvg() ### initialize server

allrecall = []
for i in range(numrounds):
  clients, recall = server.initiate_FL(clients, serverdata)
  allrecall.append(recall)

print("\n")
print("-" * 50)
print("Recall of all rounds: {}".format(allrecall))

client: 0
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
client: 1
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
client: 2
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
client: 3
Number of batches in the dataloader train: 2
Number of batches in the dataloader test: 1
client: 4
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
client: 5
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
client: 6
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
client: 7
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
client: 8
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1


client: 9
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
client: 10
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
client: 11
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
client: 12
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
client: 13
Number of batches in the dataloader train: 2
Number of batches in the dataloader test: 1
client: 14
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
client: 15
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
client: 16
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
client: 17
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
client: 18
Number of batches in the dataloader train: 1
Number of batches in the dataloader test: 1
c

## Test to All Clients

In [32]:
final_model = server.globalmodel
final_model.to(DEVICE)

recall_clients = []
mrr_clients = []
loss_clients = []

# loop for each client
for i in range(max_user):
    print(f"Testing on user {i}...")
    local_test = valid_data[i]
    testloader = get_loader(local_test, itemmap=universal_item_map, batch_size=batch_size)

    # Evaluate the network
    loss, results = evaluate(final_model, testloader, DEVICE, k=5)

    print(f"Recall@5: {results['recall']:.4f}")
    print(f"MRR@5: {results['mrr']:.4f}")

    recall_clients.append(results['recall'])
    mrr_clients.append(results['mrr'])
    loss_clients.append(loss) 

print(f"Average Recall@5: {np.mean(recall_clients):.4f}")
print(f"Average MRR@5: {np.mean(mrr_clients):.4f}")

Testing on user 0...
x size: torch.Size([1, 4])
Recall@5: 1.0000
MRR@5: 0.7500
Testing on user 1...
x size: torch.Size([1, 3])
Recall@5: 0.6667
MRR@5: 0.2778
Testing on user 2...
x size: torch.Size([1, 2])
Recall@5: 0.5000
MRR@5: 0.5000
Testing on user 3...
x size: torch.Size([3, 1])
Recall@5: 0.3333
MRR@5: 0.3333
Testing on user 4...
x size: torch.Size([1, 1])
Recall@5: 0.0000
MRR@5: 0.0000
Testing on user 5...
x size: torch.Size([1, 1])
Recall@5: 0.0000
MRR@5: 0.0000
Testing on user 6...
x size: torch.Size([1, 1])
Recall@5: 0.0000
MRR@5: 0.0000
Testing on user 7...
x size: torch.Size([1, 2])
Recall@5: 0.0000
MRR@5: 0.0000
Testing on user 8...
x size: torch.Size([1, 3])
Recall@5: 0.0000
MRR@5: 0.0000
Testing on user 9...
x size: torch.Size([1, 1])
Recall@5: 1.0000
MRR@5: 1.0000
Testing on user 10...
x size: torch.Size([1, 3])
Recall@5: 1.0000
MRR@5: 0.7500
Testing on user 11...
x size: torch.Size([1, 3])
Recall@5: 0.3333
MRR@5: 0.1111
Testing on user 12...
x size: torch.Size([1, 1])
R