In [1]:
import sys
import os

# Expose the directory one level above the notebook's working directory to
# the import system, so the `src.*` modules imported in the next cell resolve.
project_root = os.path.abspath("..")
if sys.path.count(project_root) == 0:
    # Prepend (rather than append) so this checkout wins over any other
    # installed copy with the same package name.
    sys.path[:0] = [project_root]

In [2]:
import torch
import torch.nn as nn

from src.utils.data import (
    load_data_100k, load_data_1m, load_data_monti
)

from src.GLocalKernel.kernel import KernelNet
from src.GLocalKernel.model import GlobalLocalKernelModel
from src.GLocalKernel.train import GLKTrainer

In [3]:
# Project root = parent of the notebook's working directory.
# This is deliberately the same expression the first cell uses when it
# extends `sys.path`, so code imports and dataset paths always refer to the
# same checkout. (The previous `dirname(getcwd())` + `join(..., ".")` form
# produced the same directory through a no-op join.)
PROJECT_ROOT = os.path.abspath("..")

# Dataset directories, all located under <PROJECT_ROOT>/data
MOVIELENS_100K_PATH = os.path.join(PROJECT_ROOT, "data", "MovieLens_100K")
MOVIELENS_1M_PATH = os.path.join(PROJECT_ROOT, "data", "MovieLens_1M")
DOUBAN_MONTI_PATH = os.path.join(PROJECT_ROOT, "data", "DoubanMonti")

# Load the MovieLens 100k dataset.
# NOTE(review): n_m / n_u look like the movie and user counts, and
# train_m / test_m presumably are the observation masks that accompany the
# train_r / test_r rating arrays -- confirm against src.utils.data.
n_m, n_u, train_r, train_m, test_r, test_m = load_data_100k(
    path=MOVIELENS_100K_PATH,
)

.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~
Number of users: 943
Number of movies: 1682
Number of training ratings: 80000
Number of test ratings: 20000
.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~


In [4]:
# Seed PyTorch's global RNG before any module is constructed so that
# torch-driven randomness (e.g. weight initialisation, dropout) is repeatable
# on a fresh "Restart Kernel -> Run All". Without this, the metrics logged
# by the training cells cannot be reproduced run to run.
# NOTE(review): this covers torch only; if data loading or the trainer use
# `random` / NumPy RNGs, or non-deterministic GPU kernels are involved,
# further seeding / determinism flags are needed.
SEED = 0
torch.manual_seed(SEED)

# Local kernel network.
# NOTE(review): hyper-parameter meanings (hidden width, kernel dimension,
# regularisation strengths) are defined by KernelNet -- see
# src/GLocalKernel/kernel.py; the values here are tuned for MovieLens 100k
# only as far as this notebook shows.
kernel_net = KernelNet(
    n_u=n_u,
    n_layers=2,
    n_hid=500,
    n_dim=5,
    lambda_s=6e-3,
    lambda_2=20.,
    activation=nn.Sigmoid(),
    dropout_rate=0.5,
)

# Full model: wraps the local kernel network created above.
model = GlobalLocalKernelModel(
    local_kernel_net=kernel_net,
    gk_size=3,
    dot_scale=1.,
    n_m=n_m,
)

In [5]:
# Two independent Adam optimizers with the same learning rate:
#   - one over the local kernel network's parameters only (passed as optimizer_p),
#   - one over every parameter of the full model (passed as optimizer_f).
pretrain_optimizer = torch.optim.Adam(model.local_kernel_net.parameters(), lr=5e-4)
finetune_optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

# epoch_p / epoch_f are forwarded unchanged to the trainer.
# NOTE(review): presumably the epoch caps for the pre-training and
# fine-tuning phases -- confirm in GLKTrainer.
trainer = GLKTrainer(
    model=model,
    optimizer_p=pretrain_optimizer,
    optimizer_f=finetune_optimizer,
    epoch_p=500,
    epoch_f=1000,
)

In [6]:
# Wrap the loaded train/test arrays as torch tensors for the trainer.
pretrain_inputs = {
    "train_r": torch.Tensor(train_r),
    "train_m": torch.Tensor(train_m),
    "test_r": torch.Tensor(test_r),
    "test_m": torch.Tensor(test_m),
}

# Run the pre-training phase; its return value is kept because the
# fine-tuning cell passes it on as `train_r_local`.
# NOTE(review): tol_p / patience_p look like early-stopping controls --
# confirm in GLKTrainer.pre_train_model.
train_r_local = trainer.pre_train_model(
    **pretrain_inputs,
    tol_p=1e-4,
    patience_p=5,
)

Pre-training epoch:  0 Train RMSE:  2.7412224 Test RMSE:  2.7641919
~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.
Pre-training epoch:  5 Train RMSE:  2.5902865 Test RMSE:  2.6205754
~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.
Pre-training epoch:  10 Train RMSE:  2.1206656 Test RMSE:  2.205946
~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.
Pre-training epoch:  15 Train RMSE:  1.5191805 Test RMSE:  1.6055671
~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.
Pre-training epoch:  20 Train RMSE:  1.2067753 Test RMSE:  1.255105
~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.
Pre-training epoch:  25 Train RMSE:  1.1595027 Test RMSE:  1.185662
~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.

In [7]:
# Tensor inputs for the fine-tuning phase: the same train/test arrays as in
# pre-training, plus the `train_r_local` result returned by pre-training.
finetune_inputs = {
    "train_r": torch.Tensor(train_r),
    "train_r_local": torch.Tensor(train_r_local),
    "train_m": torch.Tensor(train_m),
    "test_r": torch.Tensor(test_r),
    "test_m": torch.Tensor(test_m),
}

# NOTE(review): tol_f / patience_f look like early-stopping controls --
# confirm in GLKTrainer.fine_tune_model.
trainer.fine_tune_model(
    **finetune_inputs,
    tol_f=1e-5,
    patience_f=10,
)

Fine-tuning epoch:  0 
 Train RMSE:  1.0287783 Train MAE:  0.82412225 Train NDCG:  0.828443425675224 
 Test RMSE:  1.0648279 Test MAE:  0.85451466 Test NDCG:  0.8266275621490495
~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.
Fine-tuning epoch:  20 
 Train RMSE:  0.96687627 Train MAE:  0.77436453 Train NDCG:  0.8788539520981458 
 Test RMSE:  1.0143763 Test MAE:  0.8127619 Test NDCG:  0.8702093253709747
~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.
Fine-tuning epoch:  40 
 Train RMSE:  0.9196411 Train MAE:  0.7307349 Train NDCG:  0.8901041519128633 
 Test RMSE:  0.9725099 Test MAE:  0.7726566 Test NDCG:  0.8823517473047845
~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.~.
Fine-tuning epoch:  60 
 Train RMSE:  0.90016305 Train MAE:  0.7143672 Train NDCG:  0.897130768841865 
 Test RMSE:  0.95778 Test MAE:  0.76024324 Test NDCG:  0.8