In [None]:
cd ..

In [None]:
import numpy as np
import scipy.sparse as sp

from pandas import DataFrame

from sklearn.datasets import make_regression
from sklearn.metrics import mean_squared_error, median_absolute_error
from sklearn.model_selection import train_test_split

from fastfm import als

# Callbacks in fastfm

Callbacks are one of the key features that fastfm provides to better understand, modify and tune factorization machine models.

A callback is a user-provided function that gets executed at every iteration of the chosen solver. This allows the user
to easily step into and interact with the optimization routine. This tutorial shows two use cases:

- collecting various performance metrics at every iteration, e.g. to draw learning curves
- influencing further iterations, e.g. early stopping to avoid overfitting

## Create Learning Curves

In [None]:
# Synthetic regression problem with 25 features (fixed seeds for reproducibility).
# The design matrix is handed to the FM solver as a scipy CSC sparse matrix and
# one third of the samples is held out as the test set.
X, y = make_regression(random_state=2, n_features=25)
X_train, X_test, y_train, y_test = train_test_split(
    sp.csc_matrix(X),
    y,
    random_state=42,
    test_size=0.33,
)

In [None]:
fm = als.FMRegression(n_iter=50, rank=4, l2_reg=0.4)

EVAL_EVERY = 2  # evaluate only every n-th iteration to reduce the evaluation cost

i = 0  # number of solver iterations seen by the callback so far
records = []  # (iteration, rmse_train, rmse_test, mae_train, mae_test) per evaluation


def learning_curve_callback(arg):
    """Record train/test error metrics while the model is being fitted.

    fastfm executes this function at every solver iteration; the argument it
    passes is not used here. Metrics are computed only every ``EVAL_EVERY``
    iterations and appended to the global ``records`` list together with the
    iteration number, so they can be plotted as learning curves afterwards.
    """
    global i
    if i % EVAL_EVERY == 0:
        pred_test = fm.predict(X_test)
        pred_train = fm.predict(X_train)

        rmse_test = np.sqrt(mean_squared_error(y_test, pred_test))
        rmse_train = np.sqrt(mean_squared_error(y_train, pred_train))

        # NOTE: "mae" in this notebook is the *median* absolute error.
        mae_test = median_absolute_error(y_test, pred_test)
        mae_train = median_absolute_error(y_train, pred_train)

        records.append((i, rmse_train, rmse_test, mae_train, mae_test))
    i += 1


fm.fit(X_train, y_train, callback=learning_curve_callback)

iterations_processed_wo = i

# Index by solver iteration so the x-axis shows real iteration numbers rather
# than the position of each evaluation inside `records`.
df = DataFrame.from_records(
    records,
    columns=['iteration', 'rmse_train', 'rmse_test', 'mae_train', 'mae_test'],
    index='iteration',
)
ax = df[['rmse_train', 'rmse_test']].plot(title='RMSE on train and test set')
ax.set_ylabel('RMSE');

In [None]:
# The "mae" columns hold the *median* absolute error (sklearn `median_absolute_error`),
# not the mean absolute error — the title makes that explicit for the reader.
ax = df[['mae_train', 'mae_test']].plot(title='Median absolute error on train and test set')
ax.set_ylabel('median absolute error');

## Early Stopping

The learning curves above clearly show that our model starts overfitting after only a few iterations. Knowing this, we could
now increase the regularization `l2_reg`. However, let's assume we instead want to use early stopping to prevent overfitting.

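We therefore need to stop training as soon as the test error stops improving. Let's use our callback mechanism to achieve
this. (To keep the example short we monitor the test set here; in practice, early stopping should monitor a separate
validation set so that the test score remains an unbiased estimate of generalization performance.)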

In [None]:
fm = als.FMRegression(n_iter=50, rank=4, l2_reg=0.4)

EVAL_EVERY = 2  # evaluate only every n-th iteration to reduce the evaluation cost

i = 0  # number of solver iterations seen by the callback so far
records_mae = []  # (iteration, mae_train, mae_test) per evaluation


def early_stopping_callback(*args):
    """Stop the solver as soon as the test-set median absolute error increases.

    fastfm executes this function at every solver iteration; the positional
    arguments it passes are not used here. Returning ``True`` is used in this
    notebook to signal the solver to stop.

    NOTE: the held-out *test* set is used as the stopping criterion to keep the
    example short; in a real workflow use a separate validation split so the
    test score remains an unbiased estimate.
    """
    global i
    iteration = i
    # Count this iteration *before* a possible early return, so that `i` is the
    # number of processed iterations (matching `iterations_processed_wo`).
    i += 1

    if iteration % EVAL_EVERY != 0:
        return None

    pred_test = fm.predict(X_test)
    pred_train = fm.predict(X_train)

    # "mae" here is the *median* absolute error.
    mae_test = median_absolute_error(y_test, pred_test)
    mae_train = median_absolute_error(y_train, pred_train)

    prev_test = None
    if records_mae:
        _, _, prev_test = records_mae[-1]

    # Record this evaluation even when stopping, so the plot shows the increase
    # in test error that triggered the stop.
    records_mae.append((iteration, mae_train, mae_test))

    if prev_test is not None and prev_test < mae_test:
        print("EARLY STOP!")
        return True
    return None


fm.fit(X_train, y_train, callback=early_stopping_callback)

iterations_processed_w = i

# Separate name so the learning-curve `df` from the earlier cell is not clobbered.
df_early_stop = DataFrame.from_records(
    records_mae,
    columns=['iteration', 'mae_train', 'mae_test'],
    index='iteration',
)
ax = df_early_stop.plot(title='Median absolute error with early stopping')
ax.set_ylabel('median absolute error');

In [None]:
# Compare how many solver iterations ran with and without the early-stopping callback.
print(f"Maximum number of iterations: {iterations_processed_wo}")
print(f"Iterations were stopped, via the early stopping callback, at iteration: {iterations_processed_w}")

We hope you found this simple example of how to use callbacks with fastfm inspiring. The mechanism makes it easy
to implement more complex stopping rules or to monitor the learning process.