In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import torch, torch.optim as optim, torch.nn as nn
import sys, os
sys.path.append("./..")
from dataset.ASMGMovieLens import ASMGMovieLens
from utils.save import get_timestamp, save_as_json, get_path_from_re

from torch.utils.data import DataLoader # is this tqdm importer?
from model import MF

# Pick the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'{device = }')

# control flow parameters


In [None]:
# Control-flow switches read by the cells below:
#   do_train -- run the training loop and save a checkpoint afterwards
#   do_test  -- when not training, load a saved checkpoint for evaluation
do_train, do_test = True, True

In [None]:
# Data, optimisation and checkpointing settings for this run.
train_params = {
    "input_path": "../../data/preprocessed/ml_processed.csv",
    "test_start_period": 25,   # period used to build the test split
    "train_window": 10,        # number of periods preceding the test period
    "n_epochs": 10,
    "batch_size": 64,
    "learning_rate": 1e-3,  # 1e-2 is the ASMG MF implementation
    "model_checkpoint_dir": './../../model/MF',
    "model_filename_stem": 'first_mf',
}

# Constructor arguments for the MF model (passed positionally in this order).
model_params = {
    "n_users": 43183,
    "n_items": 51149,
    "n_latents": 8,
}

In [None]:
# Initialize training variables: the test period and the span of periods
# that precede it and are used for training.
# NOTE(review): the window covers `train_window` periods only if both bounds
# are treated as inclusive by ASMGMovieLens -- confirm.
test_period = train_params["test_start_period"]
train_window_end = test_period - 1
train_window_begin = test_period - train_params["train_window"]


In [None]:
# Training split: built from the periods between the two window bounds.
train_data = ASMGMovieLens(
    train_params["input_path"],
    train_window_begin,
    train_window_end,
)

# Test split: built from the single test period (only a start period is given).
test_data = ASMGMovieLens(
    train_params["input_path"],
    test_period,
)

In [None]:
# Build the matrix-factorization model from the configured sizes, then move
# its parameters to the selected device.
mf = MF(
    model_params["n_users"],
    model_params["n_items"],
    model_params["n_latents"],
)
mf = mf.to(device)


In [None]:
# Number of mini-batches per epoch (fractional when the last batch is short).
# Uses the configured batch size so this stays correct if the config changes.
len(train_data) / train_params["batch_size"]

In [None]:
# get model checkpoint path
timestamp = get_timestamp()

# exist_ok=True creates the directory only when missing, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(train_params["model_checkpoint_dir"], exist_ok=True)

# <checkpoint_dir>/<filename_stem>-<timestamp>.pth
model_checkpoint_path = f'{train_params["model_checkpoint_dir"]}/' \
    f'{train_params["model_filename_stem"]}-{timestamp}.pth'


In [None]:
# Seed torch's global RNG (used e.g. for DataLoader shuffling).
# NOTE(review): `mf` is constructed in an earlier cell, before this seed is
# set, so the seed does not cover its parameter initialisation -- confirm
# whether that is intended.
torch.manual_seed(6202)

# Defined unconditionally: the evaluation cell reuses `loss_function`.
loss_function = nn.BCELoss()
optimizer = optim.Adam(mf.parameters(), lr=train_params["learning_rate"])

# train
if do_train:

    # Build the loader once; with shuffle=True it reshuffles on every pass.
    # os.cpu_count() may return None, in which case load in the main process.
    train_dataloader = DataLoader(
        train_data, batch_size=train_params["batch_size"], shuffle=True,
        num_workers=os.cpu_count() or 0)

    for epoch in range(1, train_params["n_epochs"] + 1):

        cum_epoch_loss = 0.    # sum of per-sample losses over the epoch
        n_epoch_samples = 0    # number of samples seen in the epoch
        running_loss = 0.      # sum of batch-mean losses since the last report

        # set model on training mode
        mf.train()

        for i, (x_minibatch, y_minibatch) in enumerate(train_dataloader):

            # parse input: column 0 holds users, column 1 holds items.
            # atleast_1d keeps the batch axis when the final batch has a
            # single sample (squeeze() alone would yield a 0-d tensor).
            user_minibatch = torch.atleast_1d(x_minibatch[:, 0].squeeze()).to(device)
            item_minibatch = torch.atleast_1d(x_minibatch[:, 1].squeeze()).to(device)

            # do not accumulate the gradients from the last mini-batch
            optimizer.zero_grad()

            # get loss (BCELoss averages over the mini-batch by default)
            scores = mf(user_minibatch, item_minibatch)
            loss = loss_function(scores, y_minibatch.to(device))
            ave_loss = loss.item()

            # weight by the actual batch size so a short final batch is not
            # counted as a full one in the epoch mean
            batch_len = x_minibatch.shape[0]
            cum_epoch_loss += ave_loss * batch_len
            n_epoch_samples += batch_len

            # compute backpropagation gradients
            loss.backward()

            # update parameters
            optimizer.step()

            # report mean loss per item
            running_loss += ave_loss
            if i % 2000 == 1999:    # print every 2000 mini-batches
                print(f'[{epoch = }, {i + 1:5d}]'
                      f'loss: {running_loss / 2000:.3f}',
                      end="\r")
                running_loss = 0.

        # report epoch statistics: mean loss per sample
        epoch_loss = cum_epoch_loss / max(n_epoch_samples, 1)
        print(f"{epoch_loss = :.3f}")


In [None]:
if do_train:

    # save model weights plus the parameters that produced them
    torch.save(mf.state_dict(), model_checkpoint_path)
    print("saved model at:", model_checkpoint_path)
    # splitext drops only the trailing ".pth"; str.replace would also strip
    # any other ".pth" occurring earlier in the path.
    save_as_json({**train_params, **model_params},
                 os.path.splitext(model_checkpoint_path)[0])
elif do_test:

    # load model for test
    mf = MF(
        model_params["n_users"], model_params["n_items"], model_params["n_latents"])

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    try:
        loaded_checkpoint_path = model_checkpoint_path
        state_dict = torch.load(loaded_checkpoint_path, map_location=device)
    except FileNotFoundError:
        # No checkpoint for this run's timestamp: fall back to one matching
        # the filename stem (presumably the latest -- see get_path_from_re).
        latest_checkpoint_path = get_path_from_re(f'{train_params["model_checkpoint_dir"]}/'
                                                  f'{train_params["model_filename_stem"]}*.pth')
        loaded_checkpoint_path = latest_checkpoint_path
        state_dict = torch.load(loaded_checkpoint_path, map_location=device)

    mf.load_state_dict(state_dict)
    print("loaded model:", loaded_checkpoint_path)


In [None]:

# test -- honours the `do_test` switch instead of always running
if do_test:

    # os.cpu_count() may return None, in which case load in the main process.
    test_dataloader = DataLoader(
        test_data, batch_size=train_params["batch_size"], shuffle=False,
        num_workers=os.cpu_count() or 0)

    cum_test_loss = 0.    # sum of per-sample losses
    n_test_samples = 0    # number of samples evaluated
    running_loss = 0.     # sum of batch-mean losses since the last report

    mf.to(device)
    # evaluation mode: disables training-only behaviour such as dropout
    mf.eval()

    with torch.no_grad():
        for i, (x_minibatch, y_minibatch) in enumerate(test_dataloader):

            # parse input: column 0 holds users, column 1 holds items.
            # atleast_1d keeps the batch axis when the final batch has a
            # single sample (squeeze() alone would yield a 0-d tensor).
            user_minibatch = torch.atleast_1d(x_minibatch[:, 0].squeeze()).to(device)
            item_minibatch = torch.atleast_1d(x_minibatch[:, 1].squeeze()).to(device)

            # get loss (batch mean), weighted by the actual batch size so a
            # short final batch is not counted as a full one
            scores = mf(user_minibatch, item_minibatch)
            loss = loss_function(scores, y_minibatch.to(device))
            ave_loss = loss.item()
            batch_len = x_minibatch.shape[0]
            cum_test_loss += ave_loss * batch_len
            n_test_samples += batch_len

            # report mean loss per item
            running_loss += ave_loss
            if i % 1000 == 999:    # print every 1000 mini-batches
                print(f'running test loss: {running_loss / 1000:.3f}',
                      end="\r")
                running_loss = 0.

        # report test statistics: mean loss per sample (guarded against an
        # empty test set)
        test_loss = cum_test_loss / max(n_test_samples, 1)
        print(f"\n{test_loss = :.3f}")