In [None]:
# Group project for CSE 6240: Web Search and Text Mining at Georgia Tech
# Author: Kien Tran (github.com/trantrikien239)
# Load IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch import Tensor
# Print the installed torch build; the install cell exports this same version
# string as $TORCH to select the matching wheel index.
print(torch.__version__)

2.0.0+cu118


In [None]:
# Install required packages.
import os
# Export the torch version so the shell commands below can expand ${TORCH}
# inside the wheel-index URLs.
os.environ['TORCH'] = torch.__version__

# NOTE(review): none of these installs pins a version; pyg-lib comes from a
# nightly index and torch_geometric from the Git default branch, so a re-run
# may pull different builds than the ones recorded in the output.
!pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install pyg-lib -f https://data.pyg.org/whl/nightly/torch-${TORCH}.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-2.0.0+cu118.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcu118/torch_scatter-2.1.1%2Bpt20cu118-cp39-cp39-linux_x86_64.whl (10.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m102.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.1.1+pt20cu118
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-2.0.0+cu118.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcu118/torch_sparse-0.6.17%2Bpt20cu118-cp39-cp39-linux_x86_64.whl (4.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m84.6 MB/s[0m eta [36m0:00:00[0m
Installing collected p

In [None]:
# Mount Google Drive into the Colab filesystem; the project folder used by the
# following %cd lives under this mount point.
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
# Switch to the project folder so the relative paths used later
# (./data/..., saved_model/..., stats/...) resolve against it.
# NOTE(review): hard-coded Colab/Drive path; adjust when running elsewhere.
%cd /content/drive/MyDrive/CSE6240/

/content/drive/.shortcut-targets-by-id/1y1Lvxa-gIYQlLQs7Dzoz9XVXO1BPTpP5/CSE6240


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch_geometric import nn
from torch_geometric.data import Data, HeteroData
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from torch_geometric.loader import LinkNeighborLoader

In [None]:
from models.interaction_func import DotProduct, MLP2LayersV1, MLP2LayersV2
from models.graphsage import GraphSage2Layers, GraphSage3Layers, GraphSageLinkPred,\
    GraphSageLinkPredNoEmb
from models.gat import GAT2Layers, GAT3Layers, GATLinkPred, GATLinkPredNoEmb

In [None]:
from models.utils import prep_dataset, eval_auc, train, train_early_stop

In [None]:
from sklearn.metrics import roc_auc_score, classification_report

In [None]:
# Seed edges per training mini-batch (2048 * 4); the validation loaders use
# four times this value.
BATCH_SIZE = 8_192

In [None]:
# Read the edge list and the labelled (positive + negative) pairs for each split.
train_edges = pd.read_parquet('./data/graphs/train_edges.parquet')
train_labels = pd.read_parquet('./data/graphs/train_labels_neg.parquet')

val_edges = pd.read_parquet('./data/graphs/val_edges.parquet')
val_labels = pd.read_parquet('./data/graphs/val_labels_neg.parquet')

test_edges = pd.read_parquet('./data/graphs/test_edges.parquet')
test_labels = pd.read_parquet('./data/graphs/test_labels_neg.parquet')

In [None]:
# Baseline node features stored as NumPy arrays: one row per user / product.
user_feat = np.load('./data/feature_agg/train_user_features_norm.npy')
prod_feat = np.load('./data/feature_emb/products.npy')

In [None]:
# Load the saved node2vec embedding matrix onto the CPU.
# NOTE(review): torch.load unpickles the file, so only point this at a
# checkpoint from a trusted source.
z = torch.load("saved_model/node2vec_128_v2.pkl", map_location="cpu")
z.shape

torch.Size([256210, 128])

In [None]:
# Build a dense 0-based index (enc_user_id) over the distinct user_ids found in
# the training edges, ordered by user_id.
enc_user_id = (
    train_edges[["user_id"]]
    .drop_duplicates()
    .sort_values("user_id")
    .reset_index(drop=True)
    .reset_index()
    .rename(columns={"index": "enc_user_id"})
)
enc_user_id_dict = dict(zip(enc_user_id["user_id"], enc_user_id["enc_user_id"]))

# Attach the encoded id to every edge and label frame (adds a column in place).
# NOTE(review): Series.map yields NaN for any user_id missing from the training
# edges - confirm the val/test frames contain no such users.
for _frame in (train_edges, val_edges, test_edges,
               train_labels, val_labels, test_labels):
    _frame["enc_user_id"] = _frame["user_id"].map(enc_user_id_dict)



In [None]:
# Display the training edges, which now carry the enc_user_id column.
train_edges

Unnamed: 0,user_id,product_id,weight,enc_user_id
0,1,196,0.700000,0
1,1,10258,0.600000,0
2,1,10326,0.100000,0
3,1,12427,0.700000,0
4,1,13032,0.200000,0
...,...,...,...,...
8675716,206209,41665,0.076923,101695
8675717,206209,43961,0.153846,101695
8675718,206209,44325,0.076923,101695
8675719,206209,48697,0.076923,101695


In [None]:
# Offset added to a raw user_id to obtain the row of that user in the node2vec
# matrix `z` (fake_uid is used to index `z` when the user embeddings are built).
# NOTE(review): presumably user nodes were shifted by 50,000 when node2vec was
# trained so they do not collide with product ids - confirm against that script.
N2V_USER_ID_OFFSET = 50_000

# One row per training user, ordered by enc_user_id, so positional row i
# corresponds to encoded user i.
map_uid = (
    train_edges[["user_id", "enc_user_id"]]
    .drop_duplicates()
    .sort_values("enc_user_id")
    .reset_index(drop=True)
)
map_uid["fake_uid"] = map_uid["user_id"] + N2V_USER_ID_OFFSET
map_uid

Unnamed: 0,user_id,enc_user_id,fake_uid
0,1,0,50001
1,2,1,50002
2,3,2,50003
3,7,3,50007
4,13,4,50013
...,...,...,...
101691,206202,101691,256202
101692,206206,101692,256206
101693,206207,101693,256207
101694,206208,101694,256208


In [None]:
# Gather each training user's node2vec row (in enc_user_id order) and convert
# the result to a NumPy array.
_user_rows = torch.from_numpy(map_uid["fake_uid"].to_numpy(dtype="int64"))
n2v_user_emb = z[_user_rows].detach().cpu().numpy()
n2v_user_emb.shape

(101696, 128)

In [None]:
# Product embeddings are the leading rows of z, one per row of prod_feat.
n_prod = len(prod_feat)
n2v_prod_emb = z[:n_prod].detach().cpu().numpy()
n2v_prod_emb.shape

(49689, 128)

In [None]:
# "Advanced" feature sets: node2vec embedding columns followed by the baseline
# feature columns, joined column-wise per node.
user_features = np.concatenate([n2v_user_emb, user_feat], axis=1)
prod_features = np.concatenate([n2v_prod_emb, prod_feat], axis=1)

### Prepare dataset and data loaders

In [None]:
# Training graph and mini-batch loader built from the baseline features only.
base_train_data = prep_dataset(train_edges, train_labels, user_feat, prod_feat)
_base_train_buy = base_train_data["user", "buy", "prod"]
base_train_loader = LinkNeighborLoader(
    data=base_train_data,
    # Neighbors sampled per hop around each seed edge.
    num_neighbors=[20, 10, 20],
    # Seed edges and their labels come from the (user, buy, prod) relation.
    edge_label_index=(("user", "buy", "prod"), _base_train_buy.edge_label_index),
    edge_label=_base_train_buy.edge_label,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
# Validation loaders for the baseline feature set, split by encoded user id:
#   * lite  - labels with enc_user_id <  10_000 (handed to train_early_stop)
#   * large - labels with enc_user_id >= 10_000 (scored with eval_auc afterwards)
# NOTE(review): label rows whose enc_user_id is NaN satisfy neither comparison
# and are therefore left out of both splits.
base_val_data_lite = prep_dataset(val_edges, 
                        val_labels[val_labels["enc_user_id"] < 10_000], 
                        user_feat, prod_feat)
# Define the validation seed edges:
base_val_loader_lite = LinkNeighborLoader(
    data=base_val_data_lite,
    num_neighbors=[20, 10, 20],
    edge_label_index=(("user", "buy", "prod"), 
                      base_val_data_lite["user", "buy", "prod"].edge_label_index),
    edge_label=base_val_data_lite["user", "buy", "prod"].edge_label,
    batch_size=BATCH_SIZE * 4,
    shuffle=False,
)
# Stored under a `base_` name (like its lite sibling) so that it is not
# confused with, or rebound by, the advanced-feature `val_data_large`.
base_val_data_large = prep_dataset(val_edges, 
                        val_labels[val_labels["enc_user_id"] >= 10_000], 
                        user_feat, prod_feat)
# Define the validation seed edges:
base_val_loader_large = LinkNeighborLoader(
    data=base_val_data_large,
    num_neighbors=[20, 10, 20],
    edge_label_index=(("user", "buy", "prod"), 
                      base_val_data_large["user", "buy", "prod"].edge_label_index),
    edge_label=base_val_data_large["user", "buy", "prod"].edge_label,
    batch_size=BATCH_SIZE * 4,
    shuffle=False,
)


In [None]:
# Node counts and per-node feature widths of the baseline feature matrices.
(n_user, base_user_feat_size), (n_prod, base_prod_feat_size) = (
    user_feat.shape,
    prod_feat.shape,
)

In [None]:
# Training graph and loader for the advanced features (node2vec + baseline).
train_data = prep_dataset(train_edges, train_labels, user_features, prod_features)
_train_buy = train_data["user", "buy", "prod"]
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[20, 10, 20],
    edge_label_index=(("user", "buy", "prod"), _train_buy.edge_label_index),
    edge_label=_train_buy.edge_label,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Pull a single mini-batch to sanity-check its structure.
sampled_data = next(iter(train_loader))

print("Sampled mini-batch:")
print("===================")
print(sampled_data)

# Number of seed edges in the batch and the range of their labels.
_sampled_buy = sampled_data["user", "buy", "prod"]
print(_sampled_buy.edge_label_index.size(1))
print(_sampled_buy.edge_label.min())
print(_sampled_buy.edge_label.max())

Sampled mini-batch:
HeteroData(
  [1muser[0m={
    node_id=[85936],
    x=[85936, 150],
    n_id=[85936]
  },
  [1mprod[0m={
    node_id=[40418],
    x=[40418, 896],
    n_id=[40418]
  },
  [1m(user, buy, prod)[0m={
    edge_index=[2, 468335],
    edge_label_index=[2, 8192],
    edge_label=[8192],
    e_id=[468335],
    input_id=[8192]
  },
  [1m(prod, rev_buy, user)[0m={
    edge_index=[2, 1067969],
    e_id=[1067969]
  }
)
8192
tensor(0.)
tensor(1.)


In [None]:
# Validation loaders for the advanced features, split by encoded user id:
# lite (< 10_000) and large (>= 10_000).
_lite_mask = val_labels["enc_user_id"] < 10_000
val_data_lite = prep_dataset(val_edges, val_labels[_lite_mask],
                             user_features, prod_features)
_val_lite_buy = val_data_lite["user", "buy", "prod"]
val_loader_lite = LinkNeighborLoader(
    data=val_data_lite,
    num_neighbors=[20, 10, 20],
    edge_label_index=(("user", "buy", "prod"), _val_lite_buy.edge_label_index),
    edge_label=_val_lite_buy.edge_label,
    batch_size=BATCH_SIZE * 4,
    shuffle=False,
)

_large_mask = val_labels["enc_user_id"] >= 10_000
val_data_large = prep_dataset(val_edges, val_labels[_large_mask],
                              user_features, prod_features)
_val_large_buy = val_data_large["user", "buy", "prod"]
val_loader_large = LinkNeighborLoader(
    data=val_data_large,
    num_neighbors=[20, 10, 20],
    edge_label_index=(("user", "buy", "prod"), _val_large_buy.edge_label_index),
    edge_label=_val_large_buy.edge_label,
    batch_size=BATCH_SIZE * 4,
    shuffle=False,
)


In [None]:
# Node counts and feature widths of the advanced feature matrices, plus the
# (node types, edge types) metadata of the training graph passed to the models.
(n_user, user_feat_size), (n_prod, prod_feat_size) = (
    user_features.shape,
    prod_features.shape,
)
train_meta_data = train_data.metadata()
train_meta_data

(['user', 'prod'], [('user', 'buy', 'prod'), ('prod', 'rev_buy', 'user')])

In [None]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Exp 1: Feature set x NCF

Experiment with GraphSAGE, try combinations of:
- Baseline Features vs Advanced Features (Baseline + Node2Vec embeddings)
- Vanilla GraphSAGE vs (GraphSAGE + NCF-inspired embeddings)

In [None]:
# Exp 1 grid: {baseline, advanced} features x {NoEmb, Emb} GraphSAGE variants.
# Settings shared by all four models.
_shared_cfg = dict(
    hidden_channels=256,
    dataset_metadata=train_meta_data,
    n_gnn_layers=3,
    interaction_func="mlp2layers_v1",
)
_base_dims = dict(user_feat_size=base_user_feat_size,
                  prod_feat_size=base_prod_feat_size)
_adv_dims = dict(user_feat_size=user_feat_size,
                 prod_feat_size=prod_feat_size)
# Only the Emb variants receive the node counts
# (presumably to size per-node embedding tables - TODO confirm in models.graphsage).
_node_counts = dict(n_user=n_user, n_prod=n_prod)

# Construction order is kept as-is so parameter initialisation draws from the
# RNG in the same sequence.
gs_base_noemb = GraphSageLinkPredNoEmb(**_shared_cfg, **_base_dims)

gs_base_emb = GraphSageLinkPred(**_shared_cfg, **_node_counts, **_base_dims)

gs_adv_noemb = GraphSageLinkPredNoEmb(**_shared_cfg, **_adv_dims)

gs_adv_emb = GraphSageLinkPred(**_shared_cfg, **_node_counts, **_adv_dims)

In [None]:
# Baseline features, no node embeddings: train with early stopping, passing
# the lite validation loader for monitoring.
gs_base_noemb, df_stats = train_early_stop(
    gs_base_noemb,
    base_train_loader,
    base_val_loader_lite,
    device,
    epochs=20,
    return_stats=True,
)

# Score the trained model on the large validation split.
auc1, gt1, pred1 = eval_auc(gs_base_noemb, base_val_loader_large, device)
print("AUC: ", auc1)

# Pass the model scores through a sigmoid and threshold at 0.5 to get hard
# labels for the classification report.
_pred_scores = torch.tensor(pred1)
pred_proba = torch.sigmoid(_pred_scores).cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt1, pred_binary))

  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 000, Loss: 0.5242, AUC: 0.8757, Precision: 0.7715, Recall: 0.8250, F1: 0.7973


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 001, Loss: 0.4158, AUC: 0.8903, Precision: 0.7672, Recall: 0.8754, F1: 0.8177


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 002, Loss: 0.3993, AUC: 0.8939, Precision: 0.7612, Recall: 0.8977, F1: 0.8238


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 003, Loss: 0.3878, AUC: 0.9000, Precision: 0.7991, Recall: 0.8396, F1: 0.8189


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 004, Loss: 0.3814, AUC: 0.9038, Precision: 0.7774, Recall: 0.8972, F1: 0.8330


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 005, Loss: 0.3764, AUC: 0.9054, Precision: 0.7910, Recall: 0.8773, F1: 0.8319


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 006, Loss: 0.3721, AUC: 0.9087, Precision: 0.7931, Recall: 0.8824, F1: 0.8354


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 007, Loss: 0.3683, AUC: 0.9091, Precision: 0.7799, Recall: 0.9076, F1: 0.8389


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 008, Loss: 0.3638, AUC: 0.9114, Precision: 0.7978, Recall: 0.8791, F1: 0.8365


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 009, Loss: 0.3607, AUC: 0.9125, Precision: 0.7907, Recall: 0.8915, F1: 0.8381


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 010, Loss: 0.3571, AUC: 0.9132, Precision: 0.7967, Recall: 0.8867, F1: 0.8393


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 011, Loss: 0.3536, AUC: 0.9140, Precision: 0.8104, Recall: 0.8648, F1: 0.8367


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 012, Loss: 0.3496, AUC: 0.9176, Precision: 0.8102, Recall: 0.8729, F1: 0.8404


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 013, Loss: 0.3465, AUC: 0.9191, Precision: 0.8174, Recall: 0.8648, F1: 0.8404


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 014, Loss: 0.3413, AUC: 0.9205, Precision: 0.8201, Recall: 0.8647, F1: 0.8418


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 015, Loss: 0.3384, AUC: 0.9201, Precision: 0.8097, Recall: 0.8828, F1: 0.8447
Early stopping at epoch 15


  0%|          | 0/19 [00:00<?, ?it/s]

AUC:  0.921811370009544
              precision    recall  f1-score   support

         0.0       0.87      0.79      0.83    298582
         1.0       0.81      0.88      0.85    298582

    accuracy                           0.84    597164
   macro avg       0.84      0.84      0.84    597164
weighted avg       0.84      0.84      0.84    597164



In [None]:
# Persist the per-epoch training statistics of gs_base_noemb (relative to the
# working directory) and show the last recorded epochs.
df_stats.to_parquet("stats/gs_base_noemb.parquet")
df_stats.tail()

Unnamed: 0,epoch,train_loss,val_auc,val_precision,val_recall,val_f1
10,10,0.357136,0.913243,0.796652,0.886722,0.839278
11,11,0.353589,0.914026,0.810435,0.864805,0.836738
12,12,0.34959,0.917622,0.810198,0.872865,0.840365
13,13,0.346458,0.91914,0.817375,0.864836,0.840436
14,14,0.341329,0.92049,0.820074,0.864743,0.841816


In [None]:
# Baseline features with node embeddings: train with early stopping, passing
# the lite validation loader for monitoring.
gs_base_emb, df_stats = train_early_stop(
    gs_base_emb,
    base_train_loader,
    base_val_loader_lite,
    device,
    epochs=20,
    return_stats=True,
)

# Score the trained model on the large validation split.
auc1, gt1, pred1 = eval_auc(gs_base_emb, base_val_loader_large, device)
print("AUC: ", auc1)

# Sigmoid + 0.5 threshold turns the model scores into hard labels.
_pred_scores = torch.tensor(pred1)
pred_proba = torch.sigmoid(_pred_scores).cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt1, pred_binary))

# Save the per-epoch statistics and show the last recorded epochs.
df_stats.to_parquet("stats/gs_base_emb.parquet")
df_stats.tail()

  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 000, Loss: 0.5006, AUC: 0.9030, Precision: 0.7987, Recall: 0.8399, F1: 0.8188


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 001, Loss: 0.3524, AUC: 0.9219, Precision: 0.8785, Recall: 0.7562, F1: 0.8128


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 002, Loss: 0.3204, AUC: 0.9288, Precision: 0.8314, Recall: 0.8708, F1: 0.8506


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 003, Loss: 0.3033, AUC: 0.9308, Precision: 0.8240, Recall: 0.8892, F1: 0.8553


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 004, Loss: 0.2946, AUC: 0.9312, Precision: 0.8282, Recall: 0.8896, F1: 0.8578


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 005, Loss: 0.2848, AUC: 0.9317, Precision: 0.8289, Recall: 0.8883, F1: 0.8576


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 006, Loss: 0.2766, AUC: 0.9307, Precision: 0.8251, Recall: 0.8914, F1: 0.8570
Early stopping at epoch 6


  0%|          | 0/19 [00:00<?, ?it/s]

AUC:  0.9308854058463027
              precision    recall  f1-score   support

         0.0       0.88      0.81      0.84    298582
         1.0       0.82      0.89      0.86    298582

    accuracy                           0.85    597164
   macro avg       0.85      0.85      0.85    597164
weighted avg       0.85      0.85      0.85    597164



Unnamed: 0,epoch,train_loss,val_auc,val_precision,val_recall,val_f1
1,1,0.352375,0.921884,0.878484,0.756177,0.812755
2,2,0.32037,0.928752,0.83135,0.870757,0.850597
3,3,0.303344,0.930788,0.823964,0.889202,0.855341
4,4,0.294643,0.931244,0.828172,0.889636,0.857805
5,5,0.284756,0.931668,0.828936,0.888272,0.857579


In [None]:
# Compare model sizes: total number of parameter elements per model.
for _label, _net in (("gs_base_noemb", gs_base_noemb),
                     ("gs_base_emb", gs_base_emb)):
    print(f"{_label} num parameters:", sum(p.numel() for p in _net.parameters()))

gs_base_noemb num parameters: 1056769
gs_base_emb num parameters: 39811329


In [None]:
# Advanced features, no node embeddings: train with early stopping, passing
# the lite validation loader for monitoring.
gs_adv_noemb, df_stats = train_early_stop(
    gs_adv_noemb,
    train_loader,
    val_loader_lite,
    device,
    epochs=20,
    return_stats=True,
)

# Score the trained model on the large validation split.
auc1, gt1, pred1 = eval_auc(gs_adv_noemb, val_loader_large, device)
print("AUC: ", auc1)

# Sigmoid + 0.5 threshold turns the model scores into hard labels.
_pred_scores = torch.tensor(pred1)
pred_proba = torch.sigmoid(_pred_scores).cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt1, pred_binary))

# Save the per-epoch statistics and show the last recorded epochs.
df_stats.to_parquet("stats/gs_adv_noemb.parquet")
df_stats.tail()

  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 000, Loss: 0.4321, AUC: 0.9242, Precision: 0.8416, Recall: 0.8559, F1: 0.8487


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 001, Loss: 0.3182, AUC: 0.9391, Precision: 0.8584, Recall: 0.8707, F1: 0.8645


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 002, Loss: 0.3061, AUC: 0.9415, Precision: 0.8454, Recall: 0.8970, F1: 0.8704


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 003, Loss: 0.2994, AUC: 0.9439, Precision: 0.8598, Recall: 0.8849, F1: 0.8722


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 004, Loss: 0.2954, AUC: 0.9448, Precision: 0.8588, Recall: 0.8905, F1: 0.8744


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 005, Loss: 0.2917, AUC: 0.9456, Precision: 0.8656, Recall: 0.8820, F1: 0.8737


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 006, Loss: 0.2888, AUC: 0.9466, Precision: 0.8565, Recall: 0.9018, F1: 0.8786


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 007, Loss: 0.2854, AUC: 0.9473, Precision: 0.8554, Recall: 0.9067, F1: 0.8803


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 008, Loss: 0.2817, AUC: 0.9481, Precision: 0.8560, Recall: 0.9094, F1: 0.8819


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 009, Loss: 0.2786, AUC: 0.9491, Precision: 0.8646, Recall: 0.8995, F1: 0.8817


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 010, Loss: 0.2761, AUC: 0.9492, Precision: 0.8628, Recall: 0.9016, F1: 0.8818


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 011, Loss: 0.2734, AUC: 0.9497, Precision: 0.8647, Recall: 0.9025, F1: 0.8832


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 012, Loss: 0.2707, AUC: 0.9502, Precision: 0.8595, Recall: 0.9107, F1: 0.8843


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 013, Loss: 0.2684, AUC: 0.9510, Precision: 0.8636, Recall: 0.9086, F1: 0.8855


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 014, Loss: 0.2659, AUC: 0.9505, Precision: 0.8632, Recall: 0.9062, F1: 0.8842
Early stopping at epoch 14


  0%|          | 0/19 [00:00<?, ?it/s]

AUC:  0.9501611890190083
              precision    recall  f1-score   support

         0.0       0.90      0.85      0.88    298582
         1.0       0.86      0.91      0.88    298582

    accuracy                           0.88    597164
   macro avg       0.88      0.88      0.88    597164
weighted avg       0.88      0.88      0.88    597164



Unnamed: 0,epoch,train_loss,val_auc,val_precision,val_recall,val_f1
9,9,0.278626,0.949139,0.864567,0.899464,0.88167
10,10,0.276086,0.949184,0.862839,0.901572,0.88178
11,11,0.273366,0.949725,0.864653,0.902502,0.883172
12,12,0.270659,0.950234,0.859471,0.910655,0.884323
13,13,0.268385,0.950985,0.86363,0.908609,0.885549


In [None]:
# Advanced features with node embeddings: train with early stopping, passing
# the lite validation loader for monitoring.
gs_adv_emb, df_stats = train_early_stop(
    gs_adv_emb,
    train_loader,
    val_loader_lite,
    device,
    epochs=20,
    return_stats=True,
)

# Score the trained model on the large validation split.
auc1, gt1, pred1 = eval_auc(gs_adv_emb, val_loader_large, device)
print("AUC: ", auc1)

# Sigmoid + 0.5 threshold turns the model scores into hard labels.
_pred_scores = torch.tensor(pred1)
pred_proba = torch.sigmoid(_pred_scores).cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt1, pred_binary))

# Save the per-epoch statistics and show the last recorded epochs.
df_stats.to_parquet("stats/gs_adv_emb.parquet")
df_stats.tail()

  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 000, Loss: 0.4984, AUC: 0.9039, Precision: 0.8215, Recall: 0.7993, F1: 0.8103


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 001, Loss: 0.3446, AUC: 0.9300, Precision: 0.8531, Recall: 0.8473, F1: 0.8502


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 002, Loss: 0.3118, AUC: 0.9345, Precision: 0.8488, Recall: 0.8707, F1: 0.8596


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 003, Loss: 0.2974, AUC: 0.9361, Precision: 0.8537, Recall: 0.8688, F1: 0.8612


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 004, Loss: 0.2854, AUC: 0.9366, Precision: 0.8434, Recall: 0.8839, F1: 0.8632


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 005, Loss: 0.2753, AUC: 0.9359, Precision: 0.8393, Recall: 0.8916, F1: 0.8646
Early stopping at epoch 5


  0%|          | 0/19 [00:00<?, ?it/s]

AUC:  0.936256474882958
              precision    recall  f1-score   support

         0.0       0.89      0.83      0.85    298582
         1.0       0.84      0.89      0.86    298582

    accuracy                           0.86    597164
   macro avg       0.86      0.86      0.86    597164
weighted avg       0.86      0.86      0.86    597164



Unnamed: 0,epoch,train_loss,val_auc,val_precision,val_recall,val_f1
0,0,0.498448,0.903909,0.821487,0.79933,0.810257
1,1,0.344648,0.930033,0.853085,0.84732,0.850193
2,2,0.311799,0.934506,0.848776,0.870695,0.859596
3,3,0.297435,0.936121,0.853718,0.868804,0.861195
4,4,0.285409,0.936593,0.84343,0.883932,0.863206


# Exp 2: GraphSAGE: Depth and Size

Since we have determined that gs_adv_noemb is the best setting so far, let's experiment with a few variations of this model, namely its depth and hidden size.

In [None]:
# Depth experiment: same setup as gs_adv_noemb but with 2 GNN layers instead
# of the 3 used so far.
_cfg_2gnn = dict(
    hidden_channels=256,
    user_feat_size=user_feat_size,
    prod_feat_size=prod_feat_size,
    dataset_metadata=train_meta_data,
    n_gnn_layers=2,
    interaction_func="mlp2layers_v1",
)
gs_adv_noemb_2gnn = GraphSageLinkPredNoEmb(**_cfg_2gnn)

# Train with early stopping, passing the lite validation loader for monitoring.
gs_adv_noemb_2gnn, df_stats = train_early_stop(
    gs_adv_noemb_2gnn,
    train_loader,
    val_loader_lite,
    device,
    epochs=20,
    return_stats=True,
)

# Score the trained model on the large validation split.
auc1, gt1, pred1 = eval_auc(gs_adv_noemb_2gnn, val_loader_large, device)
print("AUC: ", auc1)

# Sigmoid + 0.5 threshold turns the model scores into hard labels.
_pred_scores = torch.tensor(pred1)
pred_proba = torch.sigmoid(_pred_scores).cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt1, pred_binary))

# Save the per-epoch statistics and show the last recorded epochs.
df_stats.to_parquet("stats/gs_adv_noemb_2gnn.parquet")
df_stats.tail()

  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 000, Loss: 0.4198, AUC: 0.9289, Precision: 0.8353, Recall: 0.8740, F1: 0.8542


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 001, Loss: 0.3122, AUC: 0.9416, Precision: 0.8519, Recall: 0.8884, F1: 0.8698


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 002, Loss: 0.3016, AUC: 0.9435, Precision: 0.8573, Recall: 0.8875, F1: 0.8721


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 003, Loss: 0.2968, AUC: 0.9445, Precision: 0.8568, Recall: 0.8923, F1: 0.8742


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 004, Loss: 0.2917, AUC: 0.9461, Precision: 0.8601, Recall: 0.8942, F1: 0.8768


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 005, Loss: 0.2873, AUC: 0.9471, Precision: 0.8688, Recall: 0.8819, F1: 0.8753


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 006, Loss: 0.2830, AUC: 0.9485, Precision: 0.8645, Recall: 0.8956, F1: 0.8798


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 007, Loss: 0.2791, AUC: 0.9494, Precision: 0.8664, Recall: 0.8977, F1: 0.8818


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 008, Loss: 0.2759, AUC: 0.9498, Precision: 0.8697, Recall: 0.8923, F1: 0.8808


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 009, Loss: 0.2733, AUC: 0.9503, Precision: 0.8639, Recall: 0.9055, F1: 0.8842


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 010, Loss: 0.2712, AUC: 0.9504, Precision: 0.8586, Recall: 0.9125, F1: 0.8847


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 011, Loss: 0.2681, AUC: 0.9509, Precision: 0.8665, Recall: 0.9027, F1: 0.8843


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 012, Loss: 0.2657, AUC: 0.9515, Precision: 0.8595, Recall: 0.9149, F1: 0.8864


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 013, Loss: 0.2632, AUC: 0.9518, Precision: 0.8626, Recall: 0.9120, F1: 0.8866


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 014, Loss: 0.2602, AUC: 0.9514, Precision: 0.8608, Recall: 0.9141, F1: 0.8867
Early stopping at epoch 14


  0%|          | 0/19 [00:00<?, ?it/s]

AUC:  0.9507084269152051
              precision    recall  f1-score   support

         0.0       0.91      0.85      0.88    298582
         1.0       0.86      0.91      0.88    298582

    accuracy                           0.88    597164
   macro avg       0.88      0.88      0.88    597164
weighted avg       0.88      0.88      0.88    597164



Unnamed: 0,epoch,train_loss,val_auc,val_precision,val_recall,val_f1
9,9,0.273253,0.950337,0.863869,0.90554,0.884214
10,10,0.27121,0.950357,0.858556,0.912453,0.884684
11,11,0.268074,0.950871,0.866512,0.90275,0.88426
12,12,0.265689,0.951491,0.859535,0.914933,0.886369
13,13,0.263183,0.951765,0.862571,0.911988,0.886592


In [None]:
# Compare model sizes: total number of parameter elements per model.
for _label, _net in (("gs_adv_noemb", gs_adv_noemb),
                     ("gs_adv_noemb_2gnn", gs_adv_noemb_2gnn)):
    print(f"{_label} num parameters:", sum(p.numel() for p in _net.parameters()))

gs_adv_noemb num parameters: 1122305
gs_adv_noemb_2gnn num parameters: 859649


The two-layer GraphSAGE performs slightly better than the three-layer version (AUC 0.9507 vs. 0.9502 on the large validation split). This result suggests that the deeper, three-layer model encourages over-fitting to the training set, which harms generalization.

Following Occam's razor, we will proceed with the two-layer models.

In [None]:
# Hidden-size experiment (128 and 512 instead of the default 256);
# this cell trains the 128-dimensional, 2-layer variant.
_cfg_128d = dict(
    hidden_channels=128,
    user_feat_size=user_feat_size,
    prod_feat_size=prod_feat_size,
    dataset_metadata=train_meta_data,
    n_gnn_layers=2,
    interaction_func="mlp2layers_v1",
)
gs_adv_noemb_2gnn_128d = GraphSageLinkPredNoEmb(**_cfg_128d)

# Train with early stopping, passing the lite validation loader for monitoring.
gs_adv_noemb_2gnn_128d, df_stats = train_early_stop(
    gs_adv_noemb_2gnn_128d,
    train_loader,
    val_loader_lite,
    device,
    epochs=20,
    return_stats=True,
)

# Score the trained model on the large validation split.
auc1, gt1, pred1 = eval_auc(gs_adv_noemb_2gnn_128d, val_loader_large, device)
print("AUC: ", auc1)

# Sigmoid + 0.5 threshold turns the model scores into hard labels.
_pred_scores = torch.tensor(pred1)
pred_proba = torch.sigmoid(_pred_scores).cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt1, pred_binary))

# Save the per-epoch statistics and show the last recorded epochs.
df_stats.to_parquet("stats/gs_adv_noemb_2gnn_128d.parquet")
df_stats.tail()


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 000, Loss: 0.4296, AUC: 0.9295, Precision: 0.8346, Recall: 0.8781, F1: 0.8558


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 001, Loss: 0.3157, AUC: 0.9408, Precision: 0.8633, Recall: 0.8689, F1: 0.8661


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 002, Loss: 0.3032, AUC: 0.9430, Precision: 0.8473, Recall: 0.8997, F1: 0.8727


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 003, Loss: 0.2982, AUC: 0.9443, Precision: 0.8718, Recall: 0.8681, F1: 0.8699


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 004, Loss: 0.2941, AUC: 0.9454, Precision: 0.8485, Recall: 0.9071, F1: 0.8768


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 005, Loss: 0.2908, AUC: 0.9465, Precision: 0.8595, Recall: 0.8969, F1: 0.8778


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 006, Loss: 0.2868, AUC: 0.9472, Precision: 0.8588, Recall: 0.8988, F1: 0.8783


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 007, Loss: 0.2842, AUC: 0.9485, Precision: 0.8725, Recall: 0.8812, F1: 0.8768


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 008, Loss: 0.2804, AUC: 0.9489, Precision: 0.8593, Recall: 0.9048, F1: 0.8815


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 009, Loss: 0.2773, AUC: 0.9497, Precision: 0.8622, Recall: 0.9065, F1: 0.8838


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 010, Loss: 0.2756, AUC: 0.9501, Precision: 0.8583, Recall: 0.9102, F1: 0.8835


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 011, Loss: 0.2735, AUC: 0.9504, Precision: 0.8702, Recall: 0.8944, F1: 0.8821


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 012, Loss: 0.2716, AUC: 0.9506, Precision: 0.8669, Recall: 0.9020, F1: 0.8841


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 013, Loss: 0.2695, AUC: 0.9509, Precision: 0.8648, Recall: 0.9014, F1: 0.8827


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 014, Loss: 0.2686, AUC: 0.9515, Precision: 0.8686, Recall: 0.9006, F1: 0.8843


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 015, Loss: 0.2669, AUC: 0.9508, Precision: 0.8728, Recall: 0.8939, F1: 0.8832
Early stopping at epoch 15


  0%|          | 0/19 [00:00<?, ?it/s]

AUC:  0.9504605318465849
              precision    recall  f1-score   support

         0.0       0.89      0.87      0.88    298582
         1.0       0.87      0.89      0.88    298582

    accuracy                           0.88    597164
   macro avg       0.88      0.88      0.88    597164
weighted avg       0.88      0.88      0.88    597164



Unnamed: 0,epoch,train_loss,val_auc,val_precision,val_recall,val_f1
10,10,0.275609,0.950138,0.858329,0.91019,0.883499
11,11,0.273525,0.950355,0.870182,0.89438,0.882115
12,12,0.271577,0.950561,0.866937,0.902037,0.884139
13,13,0.269527,0.950917,0.864813,0.901355,0.882706
14,14,0.268606,0.951506,0.868624,0.900642,0.884343


In [None]:
# Hidden-size ablation: 512-d, two-GNN-layer GraphSAGE variant.
gs_adv_noemb_2gnn_512d = GraphSageLinkPredNoEmb(
    hidden_channels=512,
    user_feat_size=user_feat_size,
    prod_feat_size=prod_feat_size,
    dataset_metadata=train_meta_data,
    n_gnn_layers=2,
    interaction_func="mlp2layers_v1",
)

# Train with early stopping (epochs=20 is presumably the upper bound) and keep
# the per-epoch statistics that are returned alongside the model.
gs_adv_noemb_2gnn_512d, df_stats = train_early_stop(
    gs_adv_noemb_2gnn_512d,
    train_loader,
    val_loader_lite,
    device,
    epochs=20,
    return_stats=True,
)

# Score the trained model on the large validation loader.
auc1, gt1, pred1 = eval_auc(gs_adv_noemb_2gnn_512d, val_loader_large, device)
print("AUC: ", auc1)

# A sigmoid is applied here, so pred1 is presumably raw logits (TODO confirm
# against eval_auc); threshold the probabilities at 0.5 for the report.
pred_proba = torch.tensor(pred1).sigmoid().cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt1, pred_binary))

# Persist the training statistics and display their last rows.
df_stats.to_parquet("stats/gs_adv_noemb_2gnn_512d.parquet")
df_stats.tail()


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 000, Loss: 0.4087, AUC: 0.9386, Precision: 0.8593, Recall: 0.8700, F1: 0.8646


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 001, Loss: 0.3061, AUC: 0.9425, Precision: 0.8550, Recall: 0.8860, F1: 0.8702


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 002, Loss: 0.2988, AUC: 0.9438, Precision: 0.8515, Recall: 0.8975, F1: 0.8739


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 003, Loss: 0.2936, AUC: 0.9454, Precision: 0.8648, Recall: 0.8849, F1: 0.8747


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 004, Loss: 0.2886, AUC: 0.9470, Precision: 0.8591, Recall: 0.8977, F1: 0.8779


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 005, Loss: 0.2828, AUC: 0.9485, Precision: 0.8685, Recall: 0.8905, F1: 0.8794


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 006, Loss: 0.2782, AUC: 0.9496, Precision: 0.8636, Recall: 0.9025, F1: 0.8827


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 007, Loss: 0.2742, AUC: 0.9505, Precision: 0.8771, Recall: 0.8864, F1: 0.8818


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 008, Loss: 0.2709, AUC: 0.9512, Precision: 0.8687, Recall: 0.9002, F1: 0.8842


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 009, Loss: 0.2676, AUC: 0.9510, Precision: 0.8633, Recall: 0.9058, F1: 0.8840
Early stopping at epoch 9


  0%|          | 0/19 [00:00<?, ?it/s]

AUC:  0.950500113883344
              precision    recall  f1-score   support

         0.0       0.90      0.86      0.88    298582
         1.0       0.86      0.91      0.88    298582

    accuracy                           0.88    597164
   macro avg       0.88      0.88      0.88    597164
weighted avg       0.88      0.88      0.88    597164



Unnamed: 0,epoch,train_loss,val_auc,val_precision,val_recall,val_f1
4,4,0.288583,0.947036,0.85905,0.897666,0.877933
5,5,0.282764,0.948543,0.868485,0.890535,0.879372
6,6,0.278197,0.949553,0.863631,0.902533,0.882653
7,7,0.274247,0.950465,0.877117,0.886443,0.881755
8,8,0.270922,0.951203,0.868677,0.900239,0.884176


In [None]:
# Model-size comparison across the hidden-size ablation: sum the element
# counts of every parameter tensor for each two-layer GraphSAGE variant.
for model_label, model_obj in (
    ("gs_adv_noemb_2gnn_128d", gs_adv_noemb_2gnn_128d),
    ("gs_adv_noemb_2gnn", gs_adv_noemb_2gnn),
    ("gs_adv_noemb_2gnn_512d", gs_adv_noemb_2gnn_512d),
):
    print(
        f"{model_label} num parameters:",
        sum(weight.numel() for weight in model_obj.parameters()),
    )

gs_adv_noemb_2gnn_128d num parameters: 299009
gs_adv_noemb_2gnn num parameters: 859649
gs_adv_noemb_2gnn_512d num parameters: 2767361


A hidden dimension of 256 produces the best AUC on the validation set, although the 128-d and 512-d versions do not perform much worse.

In [None]:
# Serialize every trained GraphSAGE variant to saved_model/<name>.pkl.
# NOTE(review): torch.save on the module object pickles the whole model, so
# loading it later presumably requires the same class definitions to be
# importable - confirm this is intended rather than saving a state_dict.
for ckpt_name, ckpt_model in (
    ("gs_adv_noemb", gs_adv_noemb),
    ("gs_adv_noemb_2gnn", gs_adv_noemb_2gnn),
    ("gs_adv_noemb_2gnn_128d", gs_adv_noemb_2gnn_128d),
    ("gs_adv_noemb_2gnn_512d", gs_adv_noemb_2gnn_512d),
):
    with open(f"saved_model/{ckpt_name}.pkl", "wb") as f:
        torch.save(ckpt_model, f=f)


# Exp 3: Graph Attention Network

With insights from GraphSAGE, we made a judgement call to focus on advanced features, a pure Graph Attention Network (without NCF-style internal embedding learning), and message-passing GNN layers only. That said, there are a few hyperparameters we can experiment with to understand the effect of the network's architecture on its learning effectiveness. We experiment with two aspects: the hidden size and the number of attention heads.

In [None]:
# GAT experiment: 128 hidden channels, 1 attention head, 2 GNN layers.
gat_128d_1h = GATLinkPredNoEmb(
    hidden_channels=128,
    user_feat_size=user_feat_size,
    prod_feat_size=prod_feat_size,
    dataset_metadata=train_meta_data,
    n_gnn_layers=2,
    num_heads=1,
    interaction_func="mlp2layers_v1",
)

# Train with early stopping (epochs=20 is presumably the upper bound) and keep
# the per-epoch statistics that are returned alongside the model.
gat_128d_1h, df_stats = train_early_stop(
    gat_128d_1h,
    train_loader,
    val_loader_lite,
    device,
    epochs=20,
    return_stats=True,
)

# Score the trained model on the large validation loader.
auc1, gt1, pred1 = eval_auc(gat_128d_1h, val_loader_large, device)
print("AUC: ", auc1)

# A sigmoid is applied here, so pred1 is presumably raw logits (TODO confirm
# against eval_auc); threshold the probabilities at 0.5 for the report.
pred_proba = torch.tensor(pred1).sigmoid().cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt1, pred_binary))

# Persist the training statistics and display their last rows.
df_stats.to_parquet("stats/gat_128d_1h.parquet")
df_stats.tail()


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 000, Loss: 0.5176, AUC: 0.9032, Precision: 0.8390, Recall: 0.7770, F1: 0.8068


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 001, Loss: 0.3556, AUC: 0.9260, Precision: 0.8431, Recall: 0.8618, F1: 0.8523


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 002, Loss: 0.3342, AUC: 0.9312, Precision: 0.8528, Recall: 0.8620, F1: 0.8574


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 003, Loss: 0.3275, AUC: 0.9321, Precision: 0.8437, Recall: 0.8771, F1: 0.8600


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 004, Loss: 0.3243, AUC: 0.9341, Precision: 0.8489, Recall: 0.8751, F1: 0.8618


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 005, Loss: 0.3209, AUC: 0.9352, Precision: 0.8577, Recall: 0.8631, F1: 0.8604


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 006, Loss: 0.3182, AUC: 0.9358, Precision: 0.8521, Recall: 0.8744, F1: 0.8631


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 007, Loss: 0.3171, AUC: 0.9371, Precision: 0.8409, Recall: 0.8929, F1: 0.8661


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 008, Loss: 0.3136, AUC: 0.9380, Precision: 0.8472, Recall: 0.8894, F1: 0.8678


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 009, Loss: 0.3110, AUC: 0.9384, Precision: 0.8535, Recall: 0.8795, F1: 0.8663


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 010, Loss: 0.3093, AUC: 0.9394, Precision: 0.8620, Recall: 0.8691, F1: 0.8655


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 011, Loss: 0.3085, AUC: 0.9399, Precision: 0.8499, Recall: 0.8911, F1: 0.8700


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 012, Loss: 0.3070, AUC: 0.9401, Precision: 0.8429, Recall: 0.9005, F1: 0.8707


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 013, Loss: 0.3052, AUC: 0.9401, Precision: 0.8484, Recall: 0.8953, F1: 0.8712
Early stopping at epoch 13


  0%|          | 0/19 [00:00<?, ?it/s]

AUC:  0.9396213379180625
              precision    recall  f1-score   support

         0.0       0.89      0.84      0.86    298582
         1.0       0.85      0.90      0.87    298582

    accuracy                           0.87    597164
   macro avg       0.87      0.87      0.87    597164
weighted avg       0.87      0.87      0.87    597164



Unnamed: 0,epoch,train_loss,val_auc,val_precision,val_recall,val_f1
8,8,0.313612,0.938004,0.847233,0.889388,0.867799
9,9,0.311015,0.938414,0.853469,0.87953,0.866303
10,10,0.3093,0.939374,0.862001,0.869083,0.865527
11,11,0.30849,0.939883,0.849932,0.891062,0.870011
12,12,0.306987,0.940142,0.842895,0.900487,0.87074


In [None]:
# GAT experiment: 128 hidden channels, 2 attention heads, 2 GNN layers.
gat_128d_2h = GATLinkPredNoEmb(
    hidden_channels=128,
    user_feat_size=user_feat_size,
    prod_feat_size=prod_feat_size,
    dataset_metadata=train_meta_data,
    n_gnn_layers=2,
    num_heads=2,
    interaction_func="mlp2layers_v1",
)

# Train with early stopping (epochs=20 is presumably the upper bound) and keep
# the per-epoch statistics that are returned alongside the model.
gat_128d_2h, df_stats = train_early_stop(
    gat_128d_2h,
    train_loader,
    val_loader_lite,
    device,
    epochs=20,
    return_stats=True,
)

# Score the trained model on the large validation loader.
auc1, gt1, pred1 = eval_auc(gat_128d_2h, val_loader_large, device)
print("AUC: ", auc1)

# A sigmoid is applied here, so pred1 is presumably raw logits (TODO confirm
# against eval_auc); threshold the probabilities at 0.5 for the report.
pred_proba = torch.tensor(pred1).sigmoid().cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt1, pred_binary))

# Persist the training statistics and display their last rows.
df_stats.to_parquet("stats/gat_128d_2h.parquet")
df_stats.tail()


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 000, Loss: 0.5176, AUC: 0.9042, Precision: 0.8678, Recall: 0.7384, F1: 0.7979


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 001, Loss: 0.3519, AUC: 0.9306, Precision: 0.8429, Recall: 0.8729, F1: 0.8576


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 002, Loss: 0.3263, AUC: 0.9357, Precision: 0.8432, Recall: 0.8870, F1: 0.8646


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 003, Loss: 0.3161, AUC: 0.9380, Precision: 0.8436, Recall: 0.8905, F1: 0.8664


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 004, Loss: 0.3108, AUC: 0.9395, Precision: 0.8564, Recall: 0.8763, F1: 0.8663


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 005, Loss: 0.3090, AUC: 0.9394, Precision: 0.8402, Recall: 0.9005, F1: 0.8693
Early stopping at epoch 5


  0%|          | 0/19 [00:00<?, ?it/s]

AUC:  0.9397369431063298
              precision    recall  f1-score   support

         0.0       0.89      0.83      0.86    298582
         1.0       0.84      0.90      0.87    298582

    accuracy                           0.86    597164
   macro avg       0.87      0.86      0.86    597164
weighted avg       0.87      0.86      0.86    597164



Unnamed: 0,epoch,train_loss,val_auc,val_precision,val_recall,val_f1
0,0,0.517607,0.904249,0.86785,0.738413,0.797916
1,1,0.351936,0.930595,0.842873,0.872896,0.857621
2,2,0.326312,0.935685,0.843196,0.887032,0.864559
3,3,0.3161,0.93801,0.843567,0.890535,0.866415
4,4,0.310804,0.939514,0.856398,0.876337,0.866253


In [None]:
# GAT experiment: 256 hidden channels, 2 attention heads, 2 GNN layers.
gat_256d_2h = GATLinkPredNoEmb(
    hidden_channels=256,
    user_feat_size=user_feat_size,
    prod_feat_size=prod_feat_size,
    dataset_metadata=train_meta_data,
    n_gnn_layers=2,
    num_heads=2,
    interaction_func="mlp2layers_v1",
)

# Train with early stopping (epochs=20 is presumably the upper bound) and keep
# the per-epoch statistics that are returned alongside the model.
gat_256d_2h, df_stats = train_early_stop(
    gat_256d_2h,
    train_loader,
    val_loader_lite,
    device,
    epochs=20,
    return_stats=True,
)

# Score the trained model on the large validation loader.
auc1, gt1, pred1 = eval_auc(gat_256d_2h, val_loader_large, device)
print("AUC: ", auc1)

# A sigmoid is applied here, so pred1 is presumably raw logits (TODO confirm
# against eval_auc); threshold the probabilities at 0.5 for the report.
pred_proba = torch.tensor(pred1).sigmoid().cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt1, pred_binary))

# Persist the training statistics and display their last rows.
df_stats.to_parquet("stats/gat_256d_2h.parquet")
df_stats.tail()


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 000, Loss: 0.4707, AUC: 0.9268, Precision: 0.8489, Recall: 0.8570, F1: 0.8529


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 001, Loss: 0.3311, AUC: 0.9342, Precision: 0.8604, Recall: 0.8573, F1: 0.8588


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 002, Loss: 0.3201, AUC: 0.9370, Precision: 0.8559, Recall: 0.8721, F1: 0.8639


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 003, Loss: 0.3120, AUC: 0.9368, Precision: 0.8799, Recall: 0.8218, F1: 0.8499
Early stopping at epoch 3


  0%|          | 0/19 [00:00<?, ?it/s]

AUC:  0.9371042581254536
              precision    recall  f1-score   support

         0.0       0.83      0.89      0.86    298582
         1.0       0.88      0.82      0.85    298582

    accuracy                           0.86    597164
   macro avg       0.86      0.86      0.86    597164
weighted avg       0.86      0.86      0.86    597164



Unnamed: 0,epoch,train_loss,val_auc,val_precision,val_recall,val_f1
0,0,0.470697,0.926768,0.848865,0.857023,0.852925
1,1,0.33111,0.934178,0.860365,0.857271,0.858815
2,2,0.320132,0.937018,0.8559,0.872059,0.863904


In [None]:
# GAT experiment: 256 hidden channels, 4 attention heads, 2 GNN layers.
gat_256d_4h = GATLinkPredNoEmb(
    hidden_channels=256,
    user_feat_size=user_feat_size,
    prod_feat_size=prod_feat_size,
    dataset_metadata=train_meta_data,
    n_gnn_layers=2,
    num_heads=4,
    interaction_func="mlp2layers_v1",
)

# Train with early stopping (epochs=20 is presumably the upper bound) and keep
# the per-epoch statistics that are returned alongside the model.
gat_256d_4h, df_stats = train_early_stop(
    gat_256d_4h,
    train_loader,
    val_loader_lite,
    device,
    epochs=20,
    return_stats=True,
)

# Score the trained model on the large validation loader.
auc1, gt1, pred1 = eval_auc(gat_256d_4h, val_loader_large, device)
print("AUC: ", auc1)

# A sigmoid is applied here, so pred1 is presumably raw logits (TODO confirm
# against eval_auc); threshold the probabilities at 0.5 for the report.
pred_proba = torch.tensor(pred1).sigmoid().cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt1, pred_binary))

# Persist the training statistics and display their last rows.
df_stats.to_parquet("stats/gat_256d_4h.parquet")
df_stats.tail()


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 000, Loss: 0.4780, AUC: 0.9271, Precision: 0.8499, Recall: 0.8496, F1: 0.8497


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 001, Loss: 0.3242, AUC: 0.9373, Precision: 0.8404, Recall: 0.8930, F1: 0.8659


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 002, Loss: 0.3107, AUC: 0.9401, Precision: 0.8511, Recall: 0.8860, F1: 0.8682


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 003, Loss: 0.3051, AUC: 0.9411, Precision: 0.8769, Recall: 0.8480, F1: 0.8622


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 004, Loss: 0.3028, AUC: 0.9425, Precision: 0.8583, Recall: 0.8817, F1: 0.8698


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 005, Loss: 0.2995, AUC: 0.9431, Precision: 0.8442, Recall: 0.9067, F1: 0.8743


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 006, Loss: 0.2970, AUC: 0.9442, Precision: 0.8476, Recall: 0.9039, F1: 0.8749


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 007, Loss: 0.2951, AUC: 0.9446, Precision: 0.8532, Recall: 0.8973, F1: 0.8747


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 008, Loss: 0.2940, AUC: 0.9448, Precision: 0.8661, Recall: 0.8789, F1: 0.8724


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 009, Loss: 0.2918, AUC: 0.9457, Precision: 0.8664, Recall: 0.8814, F1: 0.8738


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 010, Loss: 0.2910, AUC: 0.9458, Precision: 0.8569, Recall: 0.8969, F1: 0.8764


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 011, Loss: 0.2898, AUC: 0.9462, Precision: 0.8701, Recall: 0.8767, F1: 0.8734


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 012, Loss: 0.2885, AUC: 0.9464, Precision: 0.8541, Recall: 0.9050, F1: 0.8788


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 013, Loss: 0.2874, AUC: 0.9465, Precision: 0.8539, Recall: 0.9049, F1: 0.8786


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 014, Loss: 0.2859, AUC: 0.9465, Precision: 0.8481, Recall: 0.9138, F1: 0.8798
Early stopping at epoch 14


  0%|          | 0/19 [00:00<?, ?it/s]

AUC:  0.9465907407837573
              precision    recall  f1-score   support

         0.0       0.91      0.84      0.87    298582
         1.0       0.85      0.91      0.88    298582

    accuracy                           0.87    597164
   macro avg       0.88      0.87      0.87    597164
weighted avg       0.88      0.87      0.87    597164



Unnamed: 0,epoch,train_loss,val_auc,val_precision,val_recall,val_f1
9,9,0.291847,0.94573,0.866429,0.88139,0.873845
10,10,0.290962,0.945776,0.856881,0.89686,0.876414
11,11,0.289783,0.946152,0.870104,0.87674,0.87341
12,12,0.288462,0.946414,0.854062,0.904951,0.878771
13,13,0.287353,0.94652,0.853858,0.90492,0.878648


In [None]:
# GAT experiment: 512 hidden channels, 4 attention heads, 2 GNN layers.
gat_512d_4h = GATLinkPredNoEmb(
    hidden_channels=512,
    user_feat_size=user_feat_size,
    prod_feat_size=prod_feat_size,
    dataset_metadata=train_meta_data,
    n_gnn_layers=2,
    num_heads=4,
    interaction_func="mlp2layers_v1",
)

# Train with early stopping (epochs=20 is presumably the upper bound) and keep
# the per-epoch statistics that are returned alongside the model.
gat_512d_4h, df_stats = train_early_stop(
    gat_512d_4h,
    train_loader,
    val_loader_lite,
    device,
    epochs=20,
    return_stats=True,
)

# Score the trained model on the large validation loader.
auc1, gt1, pred1 = eval_auc(gat_512d_4h, val_loader_large, device)
print("AUC: ", auc1)

# A sigmoid is applied here, so pred1 is presumably raw logits (TODO confirm
# against eval_auc); threshold the probabilities at 0.5 for the report.
pred_proba = torch.tensor(pred1).sigmoid().cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt1, pred_binary))

# Persist the training statistics and display their last rows.
df_stats.to_parquet("stats/gat_512d_4h.parquet")
df_stats.tail()


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 000, Loss: 0.4679, AUC: 0.9261, Precision: 0.8422, Recall: 0.8580, F1: 0.8500


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 001, Loss: 0.3253, AUC: 0.9378, Precision: 0.8500, Recall: 0.8828, F1: 0.8661


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 002, Loss: 0.3095, AUC: 0.9410, Precision: 0.8461, Recall: 0.8956, F1: 0.8701


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 003, Loss: 0.3043, AUC: 0.9424, Precision: 0.8550, Recall: 0.8890, F1: 0.8717


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 004, Loss: 0.3004, AUC: 0.9433, Precision: 0.8618, Recall: 0.8828, F1: 0.8722


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 005, Loss: 0.2977, AUC: 0.9445, Precision: 0.8641, Recall: 0.8807, F1: 0.8723


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 006, Loss: 0.2955, AUC: 0.9443, Precision: 0.8491, Recall: 0.9011, F1: 0.8744
Early stopping at epoch 6


  0%|          | 0/19 [00:00<?, ?it/s]

AUC:  0.944198776796188
              precision    recall  f1-score   support

         0.0       0.90      0.84      0.87    298582
         1.0       0.85      0.90      0.87    298582

    accuracy                           0.87    597164
   macro avg       0.87      0.87      0.87    597164
weighted avg       0.87      0.87      0.87    597164



Unnamed: 0,epoch,train_loss,val_auc,val_precision,val_recall,val_f1
1,1,0.325265,0.93785,0.849975,0.882754,0.866054
2,2,0.309459,0.940974,0.846089,0.895558,0.87012
3,3,0.30431,0.942354,0.855009,0.889016,0.871681
4,4,0.300442,0.943306,0.861821,0.882847,0.872207
5,5,0.297716,0.944451,0.864122,0.880677,0.872321


In [33]:
# GAT experiment: 512 hidden channels, 8 attention heads, 2 GNN layers.
gat_512d_8h = GATLinkPredNoEmb(
    hidden_channels=512,
    user_feat_size=user_feat_size,
    prod_feat_size=prod_feat_size,
    dataset_metadata=train_meta_data,
    n_gnn_layers=2,
    num_heads=8,
    interaction_func="mlp2layers_v1",
)

# Train with early stopping (epochs=20 is presumably the upper bound) and keep
# the per-epoch statistics that are returned alongside the model.
gat_512d_8h, df_stats = train_early_stop(
    gat_512d_8h,
    train_loader,
    val_loader_lite,
    device,
    epochs=20,
    return_stats=True,
)

# Score the trained model on the large validation loader.
auc1, gt1, pred1 = eval_auc(gat_512d_8h, val_loader_large, device)
print("AUC: ", auc1)

# A sigmoid is applied here, so pred1 is presumably raw logits (TODO confirm
# against eval_auc); threshold the probabilities at 0.5 for the report.
pred_proba = torch.tensor(pred1).sigmoid().cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt1, pred_binary))

# Persist the training statistics and display their last rows.
df_stats.to_parquet("stats/gat_512d_8h.parquet")
df_stats.tail()


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 000, Loss: 0.4696, AUC: 0.9317, Precision: 0.8422, Recall: 0.8778, F1: 0.8596


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 001, Loss: 0.3186, AUC: 0.9386, Precision: 0.8507, Recall: 0.8817, F1: 0.8660


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 002, Loss: 0.3062, AUC: 0.9417, Precision: 0.8540, Recall: 0.8857, F1: 0.8695


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 003, Loss: 0.3011, AUC: 0.9435, Precision: 0.8469, Recall: 0.9031, F1: 0.8741


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 004, Loss: 0.2971, AUC: 0.9449, Precision: 0.8492, Recall: 0.9013, F1: 0.8745


  0%|          | 0/81 [00:00<?, ?it/s]

Evaluating on the lite validation set...


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 005, Loss: 0.2948, AUC: 0.9448, Precision: 0.8515, Recall: 0.9013, F1: 0.8756
Early stopping at epoch 5


  0%|          | 0/19 [00:00<?, ?it/s]

AUC:  0.9446599423615921
              precision    recall  f1-score   support

         0.0       0.90      0.84      0.87    298582
         1.0       0.85      0.90      0.88    298582

    accuracy                           0.87    597164
   macro avg       0.87      0.87      0.87    597164
weighted avg       0.87      0.87      0.87    597164



Unnamed: 0,epoch,train_loss,val_auc,val_precision,val_recall,val_f1
0,0,0.469568,0.931729,0.842177,0.877763,0.859602
1,1,0.318624,0.938609,0.850742,0.881731,0.865959
2,2,0.306215,0.941736,0.853957,0.885699,0.869539
3,3,0.301101,0.943544,0.846852,0.90306,0.874053
4,4,0.297133,0.944887,0.849243,0.901293,0.874494


# Test performance

In [34]:
# Build the test graph from the held-out edges/labels plus the node features.
test_data = prep_dataset(test_edges, test_labels, user_features, prod_features)

# Define the test seed edges: the labelled ("user", "buy", "prod") edges.
test_buy_edges = test_data["user", "buy", "prod"]
test_loader = LinkNeighborLoader(
    data=test_data,
    # Neighbour sample sizes, one entry per sampling iteration (hop).
    num_neighbors=[20, 10, 20],
    edge_label_index=(("user", "buy", "prod"), test_buy_edges.edge_label_index),
    edge_label=test_buy_edges.edge_label,
    batch_size=BATCH_SIZE * 4,
    # Keep a fixed order so evaluation batches are visited deterministically.
    shuffle=False,
)


In [None]:
# GraphSAGE models
# gs_adv_noemb
# gs_adv_noemb_2gnn
# gs_adv_noemb_2gnn_128d
# gs_adv_noemb_2gnn_512d


In [None]:
# Test-set evaluation of gs_adv_noemb: report model size, AUC and a
# per-class classification report.
print(
    "gs_adv_noemb num parameters:",
    sum(weight.numel() for weight in gs_adv_noemb.parameters()),
)
print("Performance on test set:")

auc2, gt2, pred2 = eval_auc(gs_adv_noemb, test_loader, device)
print("AUC: ", auc2)

# A sigmoid is applied here, so pred2 is presumably raw logits (TODO confirm
# against eval_auc); threshold the probabilities at 0.5 for the report.
pred_proba = torch.tensor(pred2).sigmoid().cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt2, pred_binary))

gs_adv_noemb num parameters: 1122305
Performance on test set:


  0%|          | 0/21 [00:00<?, ?it/s]

AUC:  0.9500063038576192
              precision    recall  f1-score   support

         0.0       0.90      0.86      0.88    328203
         1.0       0.86      0.90      0.88    328203

    accuracy                           0.88    656406
   macro avg       0.88      0.88      0.88    656406
weighted avg       0.88      0.88      0.88    656406



In [None]:
# Test-set evaluation of gs_adv_noemb_2gnn: report model size, AUC and a
# per-class classification report.
print(
    "gs_adv_noemb_2gnn num parameters:",
    sum(weight.numel() for weight in gs_adv_noemb_2gnn.parameters()),
)
print("Performance on test set:")

auc2, gt2, pred2 = eval_auc(gs_adv_noemb_2gnn, test_loader, device)
print("AUC: ", auc2)

# A sigmoid is applied here, so pred2 is presumably raw logits (TODO confirm
# against eval_auc); threshold the probabilities at 0.5 for the report.
pred_proba = torch.tensor(pred2).sigmoid().cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt2, pred_binary))

gs_adv_noemb_2gnn num parameters: 859649
Performance on test set:


  0%|          | 0/21 [00:00<?, ?it/s]

AUC:  0.9504444586320195
              precision    recall  f1-score   support

         0.0       0.90      0.85      0.88    328203
         1.0       0.86      0.91      0.88    328203

    accuracy                           0.88    656406
   macro avg       0.88      0.88      0.88    656406
weighted avg       0.88      0.88      0.88    656406



In [None]:
# Test-set evaluation of gs_adv_noemb_2gnn_128d: report model size, AUC and a
# per-class classification report.
print(
    "gs_adv_noemb_2gnn_128d num parameters:",
    sum(weight.numel() for weight in gs_adv_noemb_2gnn_128d.parameters()),
)
print("Performance on test set:")

auc2, gt2, pred2 = eval_auc(gs_adv_noemb_2gnn_128d, test_loader, device)
print("AUC: ", auc2)

# A sigmoid is applied here, so pred2 is presumably raw logits (TODO confirm
# against eval_auc); threshold the probabilities at 0.5 for the report.
pred_proba = torch.tensor(pred2).sigmoid().cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt2, pred_binary))

gs_adv_noemb_2gnn_128d num parameters: 299009
Performance on test set:


  0%|          | 0/21 [00:00<?, ?it/s]

AUC:  0.9501906084608093
              precision    recall  f1-score   support

         0.0       0.89      0.87      0.88    328203
         1.0       0.87      0.89      0.88    328203

    accuracy                           0.88    656406
   macro avg       0.88      0.88      0.88    656406
weighted avg       0.88      0.88      0.88    656406



In [None]:
# Test-set evaluation of gs_adv_noemb_2gnn_512d: report model size, AUC and a
# per-class classification report.
print(
    "gs_adv_noemb_2gnn_512d num parameters:",
    sum(weight.numel() for weight in gs_adv_noemb_2gnn_512d.parameters()),
)
print("Performance on test set:")

auc2, gt2, pred2 = eval_auc(gs_adv_noemb_2gnn_512d, test_loader, device)
print("AUC: ", auc2)

# A sigmoid is applied here, so pred2 is presumably raw logits (TODO confirm
# against eval_auc); threshold the probabilities at 0.5 for the report.
pred_proba = torch.tensor(pred2).sigmoid().cpu().numpy()
pred_binary = (pred_proba > 0.5).astype(int)
print(classification_report(gt2, pred_binary))

gs_adv_noemb_2gnn_512d num parameters: 2767361
Performance on test set:


  0%|          | 0/21 [00:00<?, ?it/s]

AUC:  0.9502408344139297
              precision    recall  f1-score   support

         0.0       0.90      0.86      0.88    328203
         1.0       0.86      0.90      0.88    328203

    accuracy                           0.88    656406
   macro avg       0.88      0.88      0.88    656406
weighted avg       0.88      0.88      0.88    656406



In [None]:
# Graph Attention Models

In [36]:
# Report the parameter count of gat_128d_1h, then score it on test_loader.
print(
    "gat_128d_1h num parameters:",
    sum(param.numel() for param in gat_128d_1h.parameters()),
)
print("Performance on test set:")
auc2, gt2, pred2 = eval_auc(gat_128d_1h, test_loader, device)
print("AUC: ", auc2)
# The third output of eval_auc is squashed with a sigmoid and cut at 0.5 to get
# hard 0/1 labels (presumably raw logits -- confirm against eval_auc).
pred_proba = torch.tensor(pred2).sigmoid().cpu().numpy()
pred_binary = np.greater(pred_proba, 0.5).astype(int)
print(classification_report(gt2, pred_binary))

gat_128d_1h num parameters: 300545
Performance on test set:


  0%|          | 0/21 [00:00<?, ?it/s]

AUC:  0.9395048443990365
              precision    recall  f1-score   support

         0.0       0.89      0.84      0.86    328203
         1.0       0.85      0.89      0.87    328203

    accuracy                           0.87    656406
   macro avg       0.87      0.87      0.87    656406
weighted avg       0.87      0.87      0.87    656406



In [37]:
# Report the parameter count of gat_256d_4h, then score it on test_loader.
print(
    "gat_256d_4h num parameters:",
    sum(param.numel() for param in gat_256d_4h.parameters()),
)
print("Performance on test set:")
auc2, gt2, pred2 = eval_auc(gat_256d_4h, test_loader, device)
print("AUC: ", auc2)
# The third output of eval_auc is squashed with a sigmoid and cut at 0.5 to get
# hard 0/1 labels (presumably raw logits -- confirm against eval_auc).
pred_proba = torch.tensor(pred2).sigmoid().cpu().numpy()
pred_binary = np.greater(pred_proba, 0.5).astype(int)
print(classification_report(gt2, pred_binary))

gat_256d_4h num parameters: 862721
Performance on test set:


  0%|          | 0/21 [00:00<?, ?it/s]

AUC:  0.9461773800112936
              precision    recall  f1-score   support

         0.0       0.90      0.84      0.87    328203
         1.0       0.85      0.91      0.88    328203

    accuracy                           0.87    656406
   macro avg       0.88      0.87      0.87    656406
weighted avg       0.88      0.87      0.87    656406



In [35]:
# Report the parameter count of gat_512d_8h, then score it on test_loader.
print(
    "gat_512d_8h num parameters:",
    sum(param.numel() for param in gat_512d_8h.parameters()),
)
print("Performance on test set:")
auc2, gt2, pred2 = eval_auc(gat_512d_8h, test_loader, device)
print("AUC: ", auc2)
# The third output of eval_auc is squashed with a sigmoid and cut at 0.5 to get
# hard 0/1 labels (presumably raw logits -- confirm against eval_auc).
pred_proba = torch.tensor(pred2).sigmoid().cpu().numpy()
pred_binary = np.greater(pred_proba, 0.5).astype(int)
print(classification_report(gt2, pred_binary))

gat_512d_8h num parameters: 2773505
Performance on test set:


  0%|          | 0/21 [00:00<?, ?it/s]

AUC:  0.9441353539310111
              precision    recall  f1-score   support

         0.0       0.89      0.84      0.87    328203
         1.0       0.85      0.90      0.87    328203

    accuracy                           0.87    656406
   macro avg       0.87      0.87      0.87    656406
weighted avg       0.87      0.87      0.87    656406



In [38]:
# Persist the three GAT models as saved_model/<name>.pkl, relative to the
# current working directory.
# NOTE(review): torch.save is handed the whole module object rather than a
# state_dict -- confirm that whatever loads these files can import the model
# classes, or switch to saving state_dict() instead.
# open(..., "wb") does not create missing parent directories, so make sure the
# target folder exists before writing (no-op when it is already there).
os.makedirs("saved_model", exist_ok=True)
for _gat_name, _gat_model in (
    ("gat_128d_1h", gat_128d_1h),
    ("gat_256d_4h", gat_256d_4h),
    ("gat_512d_8h", gat_512d_8h),
):
    with open(f"saved_model/{_gat_name}.pkl", "wb") as f:
        torch.save(_gat_model, f=f)

# References
1. [PyTorch Geometric's tutorial on heterogeneous graphs](https://pytorch-geometric.readthedocs.io/en/latest/tutorial/heterogeneous.html)
2. [Inductive Representation Learning on Large Graphs, Hamilton et al. (the GraphSAGE paper)](https://arxiv.org/abs/1706.02216)