# FM Training - Brand-Specific

Train Amazon SageMaker Factorization Machines models, tuning hyperparameters with Optuna and tracking runs in MLflow.
Hyperparameter search ranges are configured per brand; the simulated dataset size is fixed.

In [None]:
import os
import boto3
import sagemaker
import numpy as np
import mlflow
import optuna
from sagemaker import get_execution_role

# Resolve the SageMaker execution context: session, IAM role, default artifact
# bucket, region and AWS account id (the latter two build the MLflow ARN later).
session = sagemaker.Session()
role = get_execution_role()
bucket = session.default_bucket()
region = session.boto_region_name
account_id = boto3.client('sts').get_caller_identity()['Account']

for label, value in (('Region', region), ('Bucket', bucket), ('Account', account_id)):
    print(f'{label}: {value}')

In [None]:
# Notebook parameters. Every value is a string and is converted to int in the
# following cell — presumably so an external runner (e.g. papermill) can inject
# overrides; TODO confirm how this notebook is executed.
brand = 'betmax'
max_trials = '10'  # upper bound on the number of Optuna trials
early_stopping = '3'  # window size (in trials) used by the early-stopping callback
num_factors_min = '16'  # inclusive lower bound searched for FM num_factors
num_factors_max = '64'  # inclusive upper bound searched for FM num_factors
epochs_min = '10'  # inclusive lower bound searched for training epochs
epochs_max = '30'  # inclusive upper bound searched for training epochs
experiment_name = 'fm_betmax'  # MLflow experiment the runs are logged under
project_name = 'fm-gambling-recommender'  # used to build the MLflow tracking-server ARN

In [None]:
# The parameter cell supplies strings; convert them to ints and validate the
# search configuration here, so a bad range fails immediately instead of inside
# the first Optuna trial (after data generation and the S3 upload).
max_trials = int(max_trials)
early_stopping = int(early_stopping)
num_factors_min = int(num_factors_min)
num_factors_max = int(num_factors_max)
epochs_min = int(epochs_min)
epochs_max = int(epochs_max)

if max_trials < 1:
    raise ValueError(f'max_trials must be >= 1, got {max_trials}')
if early_stopping < 1:
    raise ValueError(f'early_stopping must be >= 1, got {early_stopping}')
# trial.suggest_int(low, high) requires low <= high; FM needs positive values.
if not 1 <= num_factors_min <= num_factors_max:
    raise ValueError(
        f'num_factors range must satisfy 1 <= min <= max, '
        f'got {num_factors_min}-{num_factors_max}'
    )
if not 1 <= epochs_min <= epochs_max:
    raise ValueError(
        f'epochs range must satisfy 1 <= min <= max, got {epochs_min}-{epochs_max}'
    )

# Simulated dataset size is fixed for every brand.
N_USERS = 500
N_GAMES = 50

print(f'Brand: {brand}')
print(f'Data: {N_USERS} users, {N_GAMES} games')
print(f'Trials: {max_trials}, Early stopping: {early_stopping}')
print(f'Factors: {num_factors_min}-{num_factors_max}, Epochs: {epochs_min}-{epochs_max}')

In [None]:
# Point MLflow at the project's SageMaker-managed tracking server (addressed by
# ARN) and turn on system-metrics logging, then select this brand's experiment.
os.environ.update({
    'MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING': 'true',
    'MLFLOW_TRACKING_URI': (
        f'arn:aws:sagemaker:{region}:{account_id}:'
        f'mlflow-tracking-server/{project_name}-dev-mlflow'
    ),
})
mlflow.set_experiment(experiment_name)
print(f'MLflow experiment: {experiment_name}')

## Generate Data

In [None]:
from scripts.simulate_gambling_data import generate_demo_data

# Simulate train/test data for this brand; n_features is the input width later
# passed to train_fm_model.
X_train, y_train, X_test, y_test = generate_demo_data(
    brand=brand, n_users=N_USERS, n_games=N_GAMES
)
n_features = X_train.shape[1]

print(
    f'Train: {X_train.shape}, Test: {X_test.shape}',
    f'Features: {n_features}',
    sep='\n',
)

In [None]:
from scripts.fm_sagemaker import write_to_s3

# Upload both splits under a brand-specific S3 prefix; each trial writes its
# model artifacts below output_path.
prefix = f'fm-training/{brand}'
train_path, test_path = (
    write_to_s3(features, labels, bucket, prefix, key)
    for features, labels, key in (
        (X_train, y_train, 'train/train.protobuf'),
        (X_test, y_test, 'test/test.protobuf'),
    )
)
output_path = f's3://{bucket}/{prefix}/output'

print(f'Data uploaded to s3://{bucket}/{prefix}/')

## Optuna Optimization

In [None]:
from scripts.fm_sagemaker import train_fm_model

def objective(trial):
    """Train one SageMaker FM model for an Optuna trial and return its test RMSE.

    Samples ``num_factors`` and ``epochs`` from the brand's configured inclusive
    ranges and ``mini_batch_size`` from a fixed set. Logs them to a nested MLflow
    run, trains via ``train_fm_model``, and reads the final ``test:rmse`` metric
    from the training-job description.

    Returns:
        The job's final ``test:rmse``. Returns the sentinel ``999.0`` when training
        fails or the job reports no ``test:rmse``, so that one bad trial scores
        badly instead of aborting the whole study.
    """
    num_factors = trial.suggest_int('num_factors', num_factors_min, num_factors_max)
    epochs = trial.suggest_int('epochs', epochs_min, epochs_max)
    mini_batch_size = trial.suggest_categorical('mini_batch_size', [100, 200, 500])

    with mlflow.start_run(run_name=f'{brand}-trial-{trial.number}', nested=True):
        mlflow.log_params({
            'brand': brand,
            'num_factors': num_factors,
            'epochs': epochs,
            'mini_batch_size': mini_batch_size,
        })

        try:
            fm = train_fm_model(
                train_path=train_path,
                test_path=test_path,
                output_path=f'{output_path}/trial-{trial.number}',
                role=role,
                n_features=n_features,
                num_factors=num_factors,
                epochs=epochs,
                mini_batch_size=mini_batch_size,
            )
            # describe() goes through the estimator's own SageMaker session, so
            # the lookup uses the same region/credentials the job ran with
            # (a bare boto3 client would fall back to the default region).
            job_description = fm.latest_training_job.describe()
        except Exception as e:
            # Deliberate best effort: record the failure and score the sentinel.
            mlflow.log_param('error', str(e)[:200])
            print(f'Trial {trial.number}: FAILED ({type(e).__name__}: {e})')
            return 999.0

        test_rmse = next(
            (
                m['Value']
                for m in job_description.get('FinalMetricDataList', [])
                if m['MetricName'] == 'test:rmse'
            ),
            None,
        )
        if test_rmse is None:
            # Don't log the sentinel as if it were a real test_rmse measurement.
            mlflow.log_param('error', 'test:rmse missing from FinalMetricDataList')
            print(f'Trial {trial.number}: no test:rmse metric reported')
            return 999.0

        mlflow.log_metric('test_rmse', test_rmse)
        print(f'Trial {trial.number}: RMSE={test_rmse:.4f}')
        return test_rmse


def early_stop_callback(study, trial, patience=None):
    """Stop the study once ``patience`` successful trials bring no improvement.

    Optuna invokes this after each finished trial as ``callback(study, trial)``.
    Only successful trials count: values that are ``None``, or at or above the
    ``999`` sentinel that ``objective`` returns on failure, are ignored. The study
    is stopped when the best value among the last ``patience`` successful trials
    is no better than the best value achieved before them.

    Args:
        study: The running Optuna study.
        trial: The trial that just finished. It is unused but required by
            Optuna's callback signature.
        patience: The window size, in successful trials. Defaults to the
            notebook's ``early_stopping`` parameter. A value <= 0 disables early
            stopping.
    """
    if patience is None:
        patience = early_stopping
    if patience <= 0:
        return

    # `is not None` rather than truthiness, so a perfect 0.0 score is kept.
    values = [t.value for t in study.trials if t.value is not None and t.value < 999]
    # We need at least one successful trial before the window to compare against.
    if len(values) <= patience:
        return

    # Compare against the best *before* the window. study.best_value includes the
    # window itself, so comparing against it would always trigger a stop.
    best_before_window = min(values[:-patience])
    if min(values[-patience:]) >= best_before_window:
        study.stop()

In [None]:
# Parent MLflow run for the whole study; each trial in objective() opens a
# nested run under it.
with mlflow.start_run(run_name=f'{brand}-optuna-study'):
    # Log the full search configuration, not just max_trials. Each brand uses
    # different ranges, so these are needed to compare or reproduce studies.
    mlflow.log_params({
        'brand': brand,
        'n_users': N_USERS,
        'n_games': N_GAMES,
        'max_trials': max_trials,
        'early_stopping': early_stopping,
        'num_factors_min': num_factors_min,
        'num_factors_max': num_factors_max,
        'epochs_min': epochs_min,
        'epochs_max': epochs_max,
    })

    study = optuna.create_study(
        study_name=f'{brand}-fm-study',
        direction='minimize',
    )

    # early_stop_callback may end the study before max_trials is reached.
    study.optimize(
        objective,
        n_trials=max_trials,
        callbacks=[early_stop_callback],
    )

    mlflow.log_params({f'best_{k}': v for k, v in study.best_params.items()})
    mlflow.log_metric('best_rmse', study.best_value)
    # Number of trials actually run (can be fewer than max_trials).
    mlflow.log_metric('n_trials_run', len(study.trials))

    # objective() scores failed trials as 999.0, so a best value that high means
    # no trial produced a usable model.
    if study.best_value >= 999:
        print('WARNING: every trial failed; the best params below are not meaningful.')

    print(f'Best trial: {study.best_trial.number}')
    print(f'Best RMSE: {study.best_value:.4f}')
    print(f'Best params: {study.best_params}')

## Optuna Visualization

In [None]:
import optuna.visualization as vis

# Objective value per trial, together with the running best.
fig = vis.plot_optimization_history(study=study)
fig.show()

In [None]:
# Hyperparameter importance for the objective. Importance evaluation needs
# variation across completed trials and can raise when the study is tiny or
# every trial scored the same (e.g. all failed with the 999.0 sentinel). Skip the
# plot in that case instead of aborting a Run-All / automated execution.
try:
    fig = vis.plot_param_importances(study)
except (RuntimeError, ValueError) as err:
    print(f'Skipping parameter importance plot: {err}')
else:
    fig.show()

In [None]:
# Each trial drawn as a line across its hyperparameter values and objective.
fig = vis.plot_parallel_coordinate(study=study)
fig.show()

In [None]:
# Objective value plotted against each hyperparameter individually.
fig = vis.plot_slice(study=study)
fig.show()

## Summary

In [None]:
# Final recap of the tuning run for this brand, one fact per line.
print(
    f'Brand: {brand}',
    f'Completed trials: {len(study.trials)}',
    f'Best RMSE: {study.best_value:.4f}',
    f'Best params: {study.best_params}',
    sep='\n',
)