In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import torch, torch.optim as optim, torch.nn as nn
import sys, os
sys.path.append("./..")
from dataset.ASMGMovieLens import ASMGMovieLens
from utils.save import get_timestamp, save_as_json

from torch.utils.data import DataLoader
from model import MF

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'{device = }')

In [None]:
# Logical CPUs visible to Python (None when it cannot be determined).
n_cpus = os.cpu_count()
n_cpus

In [None]:
# Training / data configuration for this run.
train_params = {
    # preprocessed MovieLens interactions consumed by ASMGMovieLens
    "input_path": "../../data/preprocessed/ml_processed.csv",
    # first period evaluated as the test set
    "test_start_period": 25,
    # number of periods immediately before the test period used for training
    "train_window": 10,
    "n_epochs": 10,
    "batch_size": 64,
    # checkpoints are written to <dir>/<stem>-<timestamp>.pth
    "model_checkpoint_dir": "./../../model/MF",
    "model_stem_filename": "first_mf",
}

# Model size configuration, passed positionally to MF.
# NOTE(review): presumably n_users / n_items are the counts of distinct ids
# in the preprocessed data — TODO confirm they cover the largest id.
model_params = {
    "n_users": 43183,
    "n_items": 51149,
    "n_latents": 10,
}

In [None]:
# Initialize training variables: test on `test_start_period` and train on the
# `train_window` periods immediately before it (both window bounds inclusive).
test_period = train_params["test_start_period"]
train_window_end = test_period - 1
train_window_begin = test_period - train_params["train_window"]


In [None]:
# Datasets used for training (periods train_window_begin..train_window_end)
# and testing (test_period only).
train_data = ASMGMovieLens(train_params["input_path"], train_window_begin, train_window_end)
test_data = ASMGMovieLens(train_params["input_path"], test_period)

# Raw frame, loaded separately so the same train/test slices can be inspected with pandas.
data_df = pd.read_csv(train_params["input_path"])
# only the last expression of a cell is rendered, so display this one explicitly
display(data_df.head())


def select_period(start, end=None, period_col="period"):
    """Return a boolean row mask of `data_df` for periods in [start, end].

    `end` defaults to `start`, i.e. a single period. Both bounds are inclusive.

    NOTE(review): assumes the preprocessed CSV stores the period index in a
    column named "period" — TODO confirm against ASMGMovieLens.

    Raises:
        KeyError: if `period_col` is not a column of `data_df`.
    """
    if end is None:
        end = start
    if period_col not in data_df.columns:
        raise KeyError(
            f"column {period_col!r} not found; available: {list(data_df.columns)}")
    return data_df[period_col].between(start, end)


df_train = data_df.loc[select_period(train_window_begin, train_window_end), :]
df_test = data_df.loc[select_period(test_period), :]

df_train.info()

In [None]:
# initialize dataloaders
# torch.manual_seed(6202)
# train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
# test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
# for x_minibatch, y_minibatch in train_dataloader:

#     # parse 
#     user_minibatch = x_minibatch[:, 0].squeeze()
#     item_minibatch = x_minibatch[:, 1].squeeze()
#     break

In [None]:
# Matrix-factorisation model sized from `model_params`, then moved to `device`.
mf = MF(
    model_params["n_users"],
    model_params["n_items"],
    model_params["n_latents"],
)
mf = mf.to(device)

In [None]:
# Number of mini-batches per epoch, using the configured batch size.
# The last batch may be partial, so round up.
import math
math.ceil(len(train_data) / train_params["batch_size"])

In [None]:
# Build a unique checkpoint path: <model_checkpoint_dir>/<model_stem_filename>-<timestamp>.pth
timestamp = get_timestamp()

# exist_ok=True creates the directory if needed without the race of a separate exists() check.
os.makedirs(train_params["model_checkpoint_dir"], exist_ok=True)

model_checkpoint_path = os.path.join(
    train_params["model_checkpoint_dir"],
    f'{train_params["model_stem_filename"]}-{timestamp}.pth')


In [None]:
# Seed torch's global RNG so minibatch shuffling is reproducible.
torch.manual_seed(6202)

# NOTE(review): BCELoss expects probabilities in [0, 1]; presumably MF applies
# a sigmoid to its scores — TODO confirm.
loss_function = nn.BCELoss()
optimizer = optim.Adam(mf.parameters())

# Build the loader once: shuffle=True already reshuffles at the start of every epoch.
# os.cpu_count() can return None, which DataLoader rejects; 0 means load in the main process.
train_dataloader = DataLoader(
    train_data, batch_size=train_params["batch_size"], shuffle=True,
    num_workers=os.cpu_count() or 0)

log_every = 2000  # print the running mean loss every `log_every` mini-batches

# train
for epoch in range(1, train_params["n_epochs"] + 1):

    cum_epoch_loss = 0.
    running_loss = 0.
    running_count = 0

    # set model in training mode
    mf.train()

    for i, (x_minibatch, y_minibatch) in enumerate(train_dataloader):

        # parse input: column 0 holds user ids, column 1 item ids.
        # reshape(-1) rather than squeeze() keeps a 1-D tensor even when a batch has one row.
        user_minibatch = x_minibatch[:, 0].reshape(-1).to(device)
        item_minibatch = x_minibatch[:, 1].reshape(-1).to(device)

        # do not accumulate the gradients from the last mini-batch
        optimizer.zero_grad()

        # get loss; BCELoss requires the target dtype to match the scores' dtype
        scores = mf(user_minibatch, item_minibatch)
        loss = loss_function(scores, y_minibatch.to(device=device, dtype=scores.dtype))

        # loss is a per-sample mean, so weight it by the actual number of rows:
        # the final batch of an epoch may be smaller than batch_size.
        n_samples = x_minibatch.size(0)
        sum_loss = loss.item() * n_samples
        cum_epoch_loss += sum_loss

        # compute backpropagation gradients
        loss.backward()

        # update parameters
        optimizer.step()

        # report mean loss per item over the last `log_every` mini-batches
        running_loss += sum_loss
        running_count += n_samples
        if i % log_every == log_every - 1:
            print(f'[{epoch = }, {i + 1:5d}] '
                  f'loss: {running_loss / running_count:.3f}',
                  end="\r")
            running_loss = 0.
            running_count = 0

    # report epoch statistics: mean loss per training item
    epoch_loss = cum_epoch_loss / len(train_data)
    print(f"{epoch_loss = :.3f}")


In [None]:
# Persist the learned parameters (the state_dict) to the checkpoint path built earlier.
trained_state = mf.state_dict()
torch.save(trained_state, model_checkpoint_path)

In [None]:
# Save the run configuration next to the checkpoint, under the same stem without ".pth".
# os.path.splitext strips only the final extension; str.replace would also remove
# ".pth" occurring elsewhere in the path.
config_path = os.path.splitext(model_checkpoint_path)[0]
save_as_json({**train_params, **model_params}, config_path)

In [None]:
# Reload the saved weights into a fresh CPU model to check the checkpoint round-trips.
model2 = MF(
    model_params["n_users"], model_params["n_items"], model_params["n_latents"])
# map_location="cpu": model2 lives on the CPU, and this lets a checkpoint written
# from a CUDA model load even on a machine without a GPU.
model2.load_state_dict(torch.load(model_checkpoint_path, map_location="cpu"))
# evaluation mode: disables training-only behaviour (e.g. dropout) if MF has any
model2.eval()


# test: a fixed order (no shuffling) keeps evaluation reproducible
test_dataloader = DataLoader(
    test_data, batch_size=train_params["batch_size"], shuffle=False)

with torch.no_grad():
    # evaluate on the held-out test period, not the training data
    for i, (x_minibatch, y_minibatch) in enumerate(test_dataloader):

        # parse input (reshape(-1) keeps a 1-D tensor even for a single-row batch)
        user_minibatch = x_minibatch[:, 0].reshape(-1)
        item_minibatch = x_minibatch[:, 1].reshape(-1)