# Trainer of Field-aware Factorization Machine

In [None]:
from functools import partial
import numpy as np
import pandas as pd
import re
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.utils.data
import torchtext
import torecsys as trs
from typing import Dict, Tuple, List

In [None]:
# load the MovieLens "latest-small" sample data as an example
# (the download call is left disabled; presumably it only has to run once
# to populate ./data -- TODO confirm)
# trs.data.sampledata.download_ml_data(size="latest-small", dir="./data")
_, movies_df, ratings_df, _ = trs.data.sampledata.load_ml_data(
    size="latest-small",
    dir="./data",
)
# disabled movie-side feature engineering (release year and one-hot genres);
# the cells visible in this notebook only consume ratings_df
# movies_df["year"] = movies_df.title.apply(lambda x: re.findall(r"\((\d+)\)", x))
# movies_df["year"] = movies_df.year.apply(lambda x: int(x[0]) if len(x) > 0 else np.nan)
# movies_df = pd.concat([
#     movies_df, 
#     pd.get_dummies(movies_df.genres.apply(
#         lambda x: x.split("|")).apply(pd.Series).stack()).sum(level=0)
# ], axis=1).drop(["title", "genres"], axis=1)
# merged = pd.merge(ratings_df, movies_df, on="movieId")

In [None]:
# set hyper-parameters of model

# sizes of the user / item vocabularies: largest raw id plus one, so that
# every raw id can be used directly as an index (these sizes are passed to
# the embedding layers further down)
user_size = ratings_df["userId"].max() + 1
item_size = ratings_df["movieId"].max() + 1

# embedding size and number of fields handed to the FFM model
# (the two fields used below are user id and movie id)
embed_size = 16
num_fields = 2


In [None]:
# split data into training set (90%) and testing set (10%)
# a fixed random_state makes the shuffle deterministic, so every run of the
# notebook trains and evaluates on the same partition of ratings_df
train_df, test_df = train_test_split(ratings_df, test_size=0.1, random_state=42)

In [None]:
# define the inputs' schema and the collate_fn used by the dataloaders
# each entry maps a dataframe column to a two-item list; the first item is
# the input name referenced later by the inputs wrapper ("user_id",
# "movie_id"), the second presumably selects how dict_collate_fn collates
# that column -- TODO confirm against torecsys docs
schema = {
    "userId": ["user_id", "single_index"],
    "movieId": ["movie_id", "single_index"],
    "rating": ["labels", "values"],
}

# pre-bind the schema so the DataLoader can call collate_fn with the batch only
dict_collate_fn = trs.data.dataloader.dict_collate_fn
collate_fn = partial(dict_collate_fn, schema=schema)

In [None]:
# initialize training and testing datasets
# both dataframes are wrapped the same way, restricted to the same columns
columns = ["userId", "movieId", "rating"]
to_dataset = partial(trs.data.dataset.DataFrameToDataset, columns=columns)
train_set = to_dataset(train_df)
test_set = to_dataset(test_df)

In [None]:
# initialize training and testing dataloaders
# both loaders share batch size, worker count and collate function;
# only the training loader shuffles its samples
loader_options = {
    "batch_size": 1024,
    "num_workers": 0,
    "collate_fn": collate_fn,
}

train_dl = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_options)
test_dl = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_options)

In [None]:
# initialize embedding fields
# both embeddings cover the same two fields (user, item) and read the same
# two named inputs, so the shared pieces are spelled out once
field_sizes = (user_size, item_size)
field_names = ("user_id", "movie_id")

# first argument is 1 here versus embed_size below; presumably this is the
# embedding size, making this the first-order term -- TODO confirm
feat_inputs_embedding = trs.inputs.base.MultiIndicesEmbedding(
    1, list(field_sizes)
)
field_aware_embedding = trs.inputs.base.MultiIndicesFieldAwareEmbedding(
    embed_size, list(field_sizes)
)

# define schema of wrapper: output name -> (embedding layer, names of inputs it reads)
# NOTE: this rebinds the name `schema`; collate_fn is unaffected because
# partial() already captured the earlier schema dict
schema = {
    "feat_inputs": (feat_inputs_embedding, list(field_names)),
    "field_emb_inputs": (field_aware_embedding, list(field_names)),
}

# initialize inputs wrapper
inputs_wrapper = trs.inputs.InputsWrapper(schema)

In [None]:
# initialize field-aware factorization machine model
# built from the same embed_size and num_fields set in the hyper-parameter cell
ffm = trs.models.FieldAwareFactorizationMachineModel(embed_size, num_fields)

In [None]:
# initialize the trainer that ties the inputs wrapper and the FFM model together
# a single epoch is enough for this example; use_jit=False presumably turns
# off TorchScript compilation -- TODO confirm against torecsys docs
trainer = trs.Trainer(
    inputs_wrapper=inputs_wrapper,
    model=ffm,
    epochs=1,
    verboses=1,
    use_jit=False,
)

In [None]:
# show the trainer's `sequential` attribute as the cell output
# (presumably the inputs wrapper chained with the model -- TODO confirm)
trainer.sequential

In [None]:
# train on the training dataloader using the settings given to the trainer above
trainer.fit(train_dl)

In [None]:
# sanity check: run prediction on the first test batch only and print the result
batch = next(iter(test_dl), None)
if batch is not None:
    print(trainer.predict(batch))