In [2]:
from platform import python_version

print(python_version())

3.10.12


In [3]:
import random 
import time

In [4]:
import optuna
import wandb 

In [5]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim

import torch_geometric

from torch_geometric.nn.conv import MessagePassing 
from torch_geometric.utils import degree

In [6]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import scipy.sparse as sp

In [7]:
from sklearn import preprocessing as pp
from sklearn.model_selection import train_test_split

from tqdm import tqdm

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device 

device(type='cuda')

In [9]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mpkyoyetera[0m ([33mhidden-leaf[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [10]:
column_names = ['user_id', 'item_id', 'rating', 'timestamp']

df = pd.read_csv('../data/ml-100k/u.data', sep='\t', names=column_names)
df.head()

Unnamed: 0,user_id,item_id,rating,timestamp
0,196,242,3,881250949
1,186,302,3,891717742
2,22,377,1,878887116
3,244,51,2,880606923
4,166,346,1,886397596


In [11]:
# 80/20 split
train_data, test_data = train_test_split(df.values, test_size=0.2, random_state=42)

train_df = pd.DataFrame(train_data, columns=column_names)
test_df = pd.DataFrame(test_data, columns=column_names)

print(f"Train size: {len(train_df)}")
print(f"Test size: {len(test_df)}")

Train size: 80000
Test size: 20000


In [12]:
# Relabel user and item IDs to start from 0
user_le, item_le = pp.LabelEncoder(), pp.LabelEncoder()

train_df['user_id_index'] = user_le.fit_transform(train_df['user_id'].values)
train_df['item_id_index'] = item_le.fit_transform(train_df['item_id'].values)

train_user_ids = train_df['user_id'].unique()
train_item_ids = train_df['item_id'].unique()

print(f"Number of unique users: {len(train_user_ids)}")
print(f"Number of unique items: {len(train_item_ids)}")



Number of unique users: 943
Number of unique items: 1653


In [13]:
test_df = test_df[(test_df['user_id'].isin(train_user_ids)) & (test_df['item_id'].isin(train_item_ids))]
print(f"Test size: {len(test_df)}")

Test size: 19969


In [14]:
test_df['user_id_index'] = user_le.transform(test_df['user_id'].values)
test_df['item_id_index'] = item_le.transform(test_df['item_id'].values)

In [15]:
n_users = train_df['user_id_index'].nunique()
n_items = train_df['item_id_index'].nunique()

print(f"Number of unique users: {n_users}")
print(f"Number of unique items: {n_items}")

Number of unique users: 943
Number of unique items: 1653


### Mini-batch sampling

In [16]:
def data_loader(data, _batch_size, n_usr, n_itm):
    
    def negative_sampler(x):
        while True:
            neg_id = random.randint(0, n_itm-1)
            if neg_id not in x:
                return neg_id
    
    intersection_df = data.groupby("user_id_index")["item_id_index"].apply(list).reset_index()
    indices = [x for x in range(n_usr)]
    
    if n_usr < _batch_size:
        _users = [random.choice(indices) for _ in range(_batch_size)]
    else:
        _users = random.sample(indices, _batch_size)
    
    _users.sort()
    users_df = pd.DataFrame(_users, columns=["users"])
    
    intersection_df = pd.merge(intersection_df, users_df, how="right", left_on="user_id_index", right_on="users")
    positive_items = intersection_df["item_id_index"].apply(lambda x : random.choice(x)).values
    negative_items = intersection_df["item_id_index"].apply(lambda x : negative_sampler(x)).values
    
    return torch.LongTensor(list(_users)).to(device), \
        torch.LongTensor(list(positive_items)).to(device) + n_usr, \
        torch.LongTensor(list(negative_items)).to(device) + n_usr

# test
data_loader(train_df, 8, n_users, n_items)


(tensor([ 85, 161, 217, 326, 637, 671, 805, 929], device='cuda:0'),
 tensor([1270, 1196, 1145, 1230, 1578, 1217, 1038, 1223], device='cuda:0'),
 tensor([1792, 1442, 1058, 1878, 1019,  976, 2557, 2149], device='cuda:0'))

### Edge Index 

In [17]:
u_t = torch.LongTensor(train_df.user_id_index)
i_t = torch.LongTensor(train_df.item_id_index) + n_users


In [18]:
train_edge_index = torch.stack((
    torch.cat([u_t, i_t]), 
    torch.cat([i_t, u_t]))
).to(device)  # .to(torch.int64)

train_edge_index

tensor([[ 806,  473,  462,  ..., 1417, 1264, 1142],
        [2347, 1601, 1210,  ...,  436,  283,  221]], device='cuda:0')

In [19]:
train_edge_index.dtype

torch.int64

In [20]:
# Confirm shapes 
train_edge_index[:, -1], train_edge_index[:, 0]

(tensor([1142,  221], device='cuda:0'), tensor([ 806, 2347], device='cuda:0'))

In [21]:
train_edge_index[:, len(train_df)-1], train_edge_index[:, len(train_df)]

(tensor([ 221, 1142], device='cuda:0'), tensor([2347,  806], device='cuda:0'))

## LightGCNConv architecture

In [22]:
class LightGCNConv(MessagePassing):
    def __init__(self, **kwargs):
        super().__init__(aggr='add')

    def forward(self, x, edge_index):
        # Compute normalization
        from_, to_ = edge_index
        deg = degree(to_, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[from_] * deg_inv_sqrt[to_]
        
        # Start propagating messages (no update after aggregation)
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j


## NGCF Conv model 

In [118]:
class NGCFConv(MessagePassing):
    def __init__(self, _latent_dim, dropout, bias=True, **kwargs):  
        super(NGCFConv, self).__init__(aggr='add', **kwargs)
        
        self.dropout = dropout
        
        self.lin_1 = nn.Linear(_latent_dim, _latent_dim, bias=bias)
        self.lin_2 = nn.Linear(_latent_dim, _latent_dim, bias=bias)
        
        self.init_parameters()

    def init_parameters(self):
        nn.init.xavier_uniform_(self.lin_1.weight)
        nn.init.xavier_uniform_(self.lin_2.weight)

    def forward(self, x, edge_index):
        # Compute normalization
        from_, to_ = edge_index
        
        deg = degree(to_, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[from_] * deg_inv_sqrt[to_]
        
        # Start propagating messages
        _out = self.propagate(edge_index, x=(x, x), norm=norm)
        
        # Update after aggregation
        _out += self.lin_1(x)
        _out = F.dropout(_out, self.dropout, self.training)
        
        return F.leaky_relu(_out)

    def message(self, x_j, x_i, norm):
        return norm.view(-1, 1) * (self.lin_1(x_j) + self.lin_2(x_j * x_i))

  

In [119]:
class RecommenderSys(nn.Module):
    def __init__(self, _latent_dim, _num_layers, num_users, num_items, dropout=0.5):
        super(RecommenderSys, self).__init__()
        
        self.embedding = nn.Embedding(num_users + num_items, _latent_dim)
        self.dropout = nn.Dropout(dropout)
        
        # NGCF convs
        self.convs = nn.ModuleList(NGCFConv(_latent_dim, dropout) for _ in range(_num_layers))
        
        self.init_parameters()
        
    def init_parameters(self):
        nn.init.xavier_uniform_(self.embedding.weight) 
        
    def forward(self, edge_index):
        emb_init = self.embedding.weight
        embs = [emb_init]
        
        emb = emb_init
        for conv in self.convs:
            emb = conv(x=emb, edge_index=edge_index)
            embs.append(emb)
        
        out = (torch.cat(embs, dim=-1))
        
        return emb_init, out
        
    def encode_minibatch(self, _users, positive_items, negative_items, edge_index):
        emb0, _out = self(edge_index)  # .to(torch.int64))
        
        return (
            _out[_users],
            _out[positive_items],
            _out[negative_items],
            emb0[_users],
            emb0[positive_items],
            emb0[negative_items],
        )

sysss = RecommenderSys(64, 2, n_users, n_items) 
sysss

RecommenderSys(
  (embedding): Embedding(2596, 64)
  (dropout): Dropout(p=0.5, inplace=False)
  (convs): ModuleList(
    (0-1): 2 x NGCFConv()
  )
)

In [120]:
# BPR loss 
def compute_bpr_loss(_users, user_emb, pos_emb_, neg_emb_, init_user_emb, init_pos_emb, init_neg_emb):
    # Compute loss from initial embeddings, for regularization 
    reg_loss_ = (1/2) * (init_user_emb.norm().pow(2) + init_pos_emb.norm().pow(2) + init_neg_emb.norm().pow(2)) / float(len(_users))
    
    # Compute BPR loss from user, and positive item, and negative item embeddings
    pos_scores = torch.mul(user_emb, pos_emb_).sum(dim=1)
    neg_scores = torch.mul(user_emb, neg_emb_).sum(dim=1)
    
    bpr_loss_ = torch.mean(F.softplus(neg_scores - pos_scores))
    
    return bpr_loss_, reg_loss_



In [127]:
def get_metrics(user_embeddings, item_embeddings, n_users_, n_items_, train_data_, test_data_, K):
    # test_user_ids = torch.LongTensor(test_data_['user_id_index'].unique())
    # test_user_Embed_wts = user_embeddings[test_user_ids]
    
    # compute the score of all user-item pairs
    relevance_score = torch.matmul(user_embeddings, torch.transpose(item_embeddings, 0, 1))
    
    # create dense tensor of all user-item interactions
    # i = torch.stack((
    # torch.LongTensor(train_df['user_id_index'].values),
    # torch.LongTensor(train_df['item_id_index'].values)
    # ))
    # v = torch.ones((len(train_df)), dtype=torch.float64)

    # Dense tensor for all user-item interactions in the training data
    i = torch.stack((
        torch.LongTensor(train_data_['user_id_index'].values),
        torch.LongTensor(train_data_['item_id_index'].values)
    ))
    v = torch.ones((len(train_data_)), dtype=torch.float64)

    interactions_t = torch.sparse.FloatTensor(i, v, (n_users_, n_items_)).to_dense().to(device)
    
    # mask out training user-item interactions from metric computation
    relevance_score = torch.mul(relevance_score, (1 - interactions_t))
    
    # compute top scoring items for each user
    topk_relevance_indices = torch.topk(relevance_score, K).indices
    topk_relevance_indices_df = pd.DataFrame(
        topk_relevance_indices.cpu().numpy(),
        columns=['top_index_'+str(x+1) for x in range(K)]
    )
    topk_relevance_indices_df['user_ID'] = topk_relevance_indices_df.index
    topk_relevance_indices_df['top_relevant_item'] = topk_relevance_indices_df[['top_index_'+str(x+1) for x in range(K)]].values.tolist()
    topk_relevance_indices_df = topk_relevance_indices_df[['user_ID', 'top_relevant_item']]
    
    # measure overlap between recommended (top-scoring) and held-out user-item interactions
    test_interacted_items = test_data_.groupby('user_id_index')['item_id_index'].apply(list).reset_index()
    
    metrics_df = pd.merge(
        test_interacted_items,
        topk_relevance_indices_df,
        how='left',
        left_on='user_id_index',
        right_on = ['user_ID']
    )
    metrics_df['intersecting_item'] = [
        list(set(a).intersection(b)) for a, b in zip(metrics_df.item_id_index, metrics_df.top_relevant_item)
    ]
    
    metrics_df['recall'] = metrics_df.apply(lambda x : len(x['intersecting_item']) / len(x['item_id_index']), axis=1) 
    metrics_df['precision'] = metrics_df.apply(lambda x : len(x['intersecting_item']) / K, axis=1)
    
    return metrics_df['recall'].mean(), metrics_df['precision'].mean()



In [132]:
wandb.init(project="recsys-gnn", group="ngcf")


In [133]:
latent_dim = 64
num_layers = 3
batch_size = 2048
epochs = 50
decay = 1e-4
learning_rate = 1e-3
k = 20 

config = {
    "latent_dim": latent_dim,
    "num_layers": num_layers,
    "epochs": epochs,
    "decay": decay,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "k": k,
}
wandb.config.update(config)


In [134]:
global_counter = 0

model = RecommenderSys(latent_dim, num_layers, n_users, n_items)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [87]:
# # global_counter = 0
# 
# def train_and_eval(model, optimizer, train_df):
#     loss_list_epoch = []
#     bpr_loss_list_epoch = []
#     reg_loss_list_epoch = []
#     
#     recall_list = []
#     precision_list = []
# 
#     for epoch in tqdm(range(epochs)):
#         n_batch = int(len(train_df) / batch_size)
#     
#         final_loss_list = []
#         bpr_loss_list = []
#         reg_loss_list = []
#     
#         model.train()
#         for batch_idx in range(n_batch):
#     
#             optimizer.zero_grad()
#             
#             users, pos_items, neg_items = data_loader(train_df, batch_size, n_users, n_items)
#             users_emb, pos_emb, neg_emb, userEmb0,  posEmb0, negEmb0 = model.encode_minibatch(users, pos_items, neg_items, train_edge_index)
#             
#             bpr_loss, reg_loss = compute_bpr_loss(users, users_emb, pos_emb, neg_emb, userEmb0,  posEmb0, negEmb0)
#             reg_loss = decay * reg_loss
#             final_loss = bpr_loss + reg_loss
#             
#             final_loss.backward()
#             optimizer.step()
#             
#             final_loss_list.append(final_loss.item())
#             bpr_loss_list.append(bpr_loss.item())
#             reg_loss_list.append(reg_loss.item())
#             
#             wandb.log({
#                 "epoch": epoch,
#                 # "batch_idx": batch_idx,
#                 "training loss": final_loss.item(),
#                 "bpr_loss": bpr_loss.item(),
#                 "reg_loss": reg_loss.item(),
#             }, step=global_counter)
# 
#         model.eval()
#         with torch.no_grad():
#             _, out = model(train_edge_index)
#             final_user_Embed, final_item_Embed = torch.split(out, (n_users, n_items))
#             test_topK_recall,  test_topK_precision = get_metrics(
#             final_user_Embed, final_item_Embed, n_users, n_items, train_df, test_df, k
#             )
# 
#         loss_list_epoch.append(round(np.mean(final_loss_list),4))
#         bpr_loss_list_epoch.append(round(np.mean(bpr_loss_list),4))
#         reg_loss_list_epoch.append(round(np.mean(reg_loss_list),4))
#         
#         recall_list.append(round(test_topK_recall,4))
#         precision_list.append(round(test_topK_precision,4))
#         
#         global_counter += 1
# 
#     return (
#         loss_list_epoch, 
#         bpr_loss_list_epoch, 
#         reg_loss_list_epoch, 
#         recall_list, 
#         precision_list
#     )


In [135]:
wandb.watch(model, log="all")

for epoch in tqdm(range(epochs)):
    n_batch = int(len(train_df) / batch_size)

    model.train()
    for batch_idx in range(n_batch):
        optimizer.zero_grad()
        
        users, pos_items, neg_items = data_loader(train_df, batch_size, n_users, n_items)
        
        users_emb, pos_emb, neg_emb, userEmb0, posEmb0, negEmb0 = model.encode_minibatch(users, pos_items, neg_items, train_edge_index)
        
        bpr_loss, reg_loss = compute_bpr_loss(users, users_emb, pos_emb, neg_emb, userEmb0,  posEmb0, negEmb0)
        reg_loss *= decay  #  * reg_loss
        final_loss = bpr_loss + reg_loss
        
        final_loss.backward()
        optimizer.step()
        
        wandb.log({
            # "epoch": epoch,
            # "batch_idx": batch_idx,
            "training loss": final_loss.item(),
            "bpr_loss": bpr_loss.item(),
            "reg_loss": reg_loss.item(),
        }, step=global_counter)

        global_counter += 1

    model.eval()
    with torch.no_grad():
        _, out = model(train_edge_index)
        
        final_user_embed, final_item_embed = torch.split(out, (n_users, n_items))
        
        test_topK_recall, test_topK_precision = get_metrics(
            final_user_embed,
            final_item_embed,
            n_users,
            n_items,
            train_df,
            test_df,
            k
        )
        
        wandb.log({
            # "epoch": epoch,
            f"test_topK@{k}_recall": test_topK_recall,
            f"test_topK@{k}_precision": test_topK_precision
        },
        step=global_counter)

wandb.finish()


100%|██████████| 50/50 [05:06<00:00,  6.12s/it]


0,1
bpr_loss,█▆▇▅▅▄▄▄▃▃▃▃▃▃▂▃▂▂▂▃▃▃▂▂▂▂▃▂▁▂▂▂▁▂▂▁▂▂▁▁
reg_loss,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
test_topK@20_precision,▁▁▁▁▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇██████████████
test_topK@20_recall,▁▁▁▁▄▄▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇█▇███████████
training loss,█▆▇▅▅▄▄▄▃▃▃▃▃▃▂▃▂▂▂▃▃▃▂▂▂▂▃▂▁▂▂▂▁▂▂▁▂▂▁▁

0,1
bpr_loss,0.22674
reg_loss,4e-05
test_topK@20_precision,0.24644
test_topK@20_recall,0.30128
training loss,0.22678


## Cross validation

with optuna

In [138]:
def define_model(trial):
    dim = trial.suggest_int("latent_dim", 16, 128)
    n_layers = trial.suggest_int("num_layers", 1, 3)
    
    model_ = RecommenderSys(dim, n_layers, n_users, n_items).to(device)
    
    return model_



In [143]:
def objective(trial):
    wandb.init(project="recsys-gnn", group="ngcf-crossval")
    
    model_ = define_model(trial)
    
    lr_ = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    wt_decay = trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True)
    
    optimizer = optim.AdamW(model_.parameters(), lr=lr_)
    
    wandb.config.update({
        "learning_rate": lr_,
        "weight_decay": wt_decay
    })
    wandb.watch(model, log="all")
    
    global_counter_ = 0
    
    for _ in tqdm(range(epochs)):
        n_batch_ = int(len(train_df) / batch_size)
    
        model.train()
        for batch_idx_ in range(n_batch_):
            optimizer.zero_grad()
            
            users, pos_items, neg_items = data_loader(train_df, batch_size, n_users, n_items)
            
            users_emb, pos_emb, neg_emb, userEmb0, posEmb0, negEmb0 = model_.encode_minibatch(users, pos_items, neg_items, train_edge_index)
            
            bpr_loss, reg_loss = compute_bpr_loss(users, users_emb, pos_emb, neg_emb, userEmb0,  posEmb0, negEmb0)
            reg_loss *= wt_decay  #  * reg_loss
            final_loss = bpr_loss + reg_loss
            
            final_loss.backward()
            optimizer.step()
            
            wandb.log({
                "training loss": final_loss.item(),
                "bpr_loss": bpr_loss.item(),
                "reg_loss": reg_loss.item(),
            }, step=global_counter_)
    
            global_counter_ += 1
    
        model_.eval()
        with torch.no_grad():
            _, out = model_(train_edge_index)
            
            final_user_embed, final_item_embed = torch.split(out, (n_users, n_items))
            
            test_topK_recall, test_topK_precision = get_metrics(
                final_user_embed,
                final_item_embed,
                n_users,
                n_items,
                train_df,
                test_df,
                k
            )
            
            wandb.log({
                # "epoch": epoch,
                f"test_topK@{k}_recall": test_topK_recall,
                f"test_topK@{k}_precision": test_topK_precision
            },
            step=global_counter)
    
    wandb.finish()

    trial.set_user_attr('best_model', model_)
    
    return test_topK_recall


In [144]:
def call_back(study_, trial):
    if study_.best_trial.number == trial.number:
        study_.set_user_attr('best_model', trial.user_attrs['best_model'])
        

In [145]:
study = optuna.create_study(pruner=optuna.pruners.MedianPruner(n_warmup_steps=10), direction='maximize')

[I 2023-12-11 21:56:01,164] A new study created in memory with name: no-name-fd9159e6-0cb9-4e46-bea5-90423a6cdcf8


In [146]:
study.optimize(objective, n_trials=50, callbacks=[call_back])

  2%|▏         | 1/50 [00:07<06:17,  7.71s/it]
[W 2023-12-11 21:56:14,991] Trial 0 failed with parameters: {'latent_dim': 42, 'num_layers': 2, 'learning_rate': 0.0003585815106995502, 'weight_decay': 7.54296403702685e-05} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/home/patrick/Documents/MSc/Fall23/Project/proj-env/lib/python3.10/site-packages/optuna/study/_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
  File "/tmp/ipykernel_2599943/3326939693.py", line 38, in objective
    "training loss": final_loss.item(),
KeyboardInterrupt
[W 2023-12-11 21:56:14,993] Trial 0 failed with value None.


KeyboardInterrupt: 