In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
from utils import load_dataset, evaluate

# make sure you've reloaded your solvers module after all previous patches:
%load_ext autoreload
%autoreload 2
from solvers import (
    MatrixCompletionObjective,
    nuclear_norm_lmo,
    FrankWolfe,
    PairwiseFrankWolfe
)

# Experiment settings read by the rest of this cell.
config = {
    # Dataset name: 'ml-100k', 'jester2', or 'ml-1m'.
    'dataset': 'ml-100k',
    'test_fraction': 0.2,
    'seed': 42,
    # Multiplies ||M_true||_* to obtain tau.
    'tau_scale': 1.0,
    'max_iter': 100,
    'tol': 1e-6,
    # Step-size rule: 'analytic' or 'line_search'.
    'step_method': 'analytic',
}

# --------------------------
# 1) Load data
# --------------------------
# Load the configured dataset; load_dataset returns three values that are
# bound to M_obs, mask_train and M_true.
dataset_name = config['dataset']
print(f"Loading {dataset_name} ...")
M_obs, mask_train, M_true = load_dataset(
    name=dataset_name,
    test_fraction=config['test_fraction'],
    seed=config['seed'],
)

# Objective built from the observed matrix and the training mask.
obj = MatrixCompletionObjective(M_obs, mask_train)

# tau = tau_scale * ||M_true||_* (nuclear norm of M_true).
# NOTE(review): tau is derived from M_true, which is also the reference used
# for the RMSE evaluation later in this cell -- confirm this is intended.
nuclear_norm_true = np.linalg.norm(M_true, ord='nuc')
tau = config['tau_scale'] * nuclear_norm_true
print(f"→ tau = {tau:.3f}")

# --------------------------
# 2) Monkey‐patch solvers to capture snapshots
# --------------------------
def add_snapshots(solver):
    """Wrap ``solver.run`` so per-iteration iterates land in ``solver.snapshots``.

    The solver is patched in place: ``solver.snapshots`` is initialised with a
    single all-zeros matrix shaped like ``obj.M_obs``, and ``solver.run`` is
    replaced by a wrapper that

    1. calls the original ``run`` (which is expected to fill
       ``solver.history`` with ``(t, gap, objective_value)`` triples), then
    2. replays one update ``X <- X + gamma * (S - X)`` per history entry,
       where ``S`` comes from ``nuclear_norm_lmo(grad, tau)`` and ``gamma``
       from ``solver._choose_step``, storing a copy of ``X`` after each step.

    The replay reads the module-level ``obj``, ``tau`` and
    ``nuclear_norm_lmo`` at call time, and it roughly doubles the cost of a
    ``run`` call because every gradient / LMO evaluation is repeated.

    The replay hard-codes the direction ``S - X``. If the wrapped solver
    updates its iterate differently, the replayed snapshots will not match
    what the solver actually computed; in that case a ``RuntimeWarning`` is
    emitted instead of silently returning the replayed matrix.

    Args:
        solver: object exposing ``run(X0)``, ``history`` and
            ``_choose_step(X, grad, S, d, t)``.

    Returns:
        The same ``solver`` object, with ``snapshots`` set and ``run`` wrapped.
    """
    # Local import keeps this block self-contained (warnings is stdlib).
    import warnings

    # Snapshot list available even before run() is called: the zero start.
    solver.snapshots = [np.zeros_like(obj.M_obs)]

    orig_run = solver.run

    def run_and_snapshot(X0=None):
        # Keep the solver's own answer; previously it was discarded and the
        # replayed iterate was returned in its place.
        solver_result = orig_run(X0)

        # Replay from the same starting point the solver was given.
        X = X0.copy() if X0 is not None else np.zeros_like(obj.M_obs)
        snapshots = [X.copy()]
        for t, _gap, _objv in solver.history:
            grad = obj.gradient(X)
            atom = nuclear_norm_lmo(grad, tau)
            S = atom.to_matrix()
            d = S - X
            gamma = solver._choose_step(X, grad, S, d, t)
            X = X + gamma * d
            snapshots.append(X.copy())
        solver.snapshots = snapshots

        # If the original run did not hand back a comparable matrix, fall back
        # to the replayed iterate (the value this wrapper used to return).
        comparable = (
            isinstance(solver_result, np.ndarray)
            and solver_result.shape == X.shape
        )
        if not comparable:
            return X

        if not np.allclose(solver_result, X):
            warnings.warn(
                "add_snapshots: replayed iterate differs from the solver's "
                "result; solver.snapshots only approximate the true iterates.",
                RuntimeWarning,
                stacklevel=2,
            )
        return solver_result

    solver.run = run_and_snapshot
    return solver

# --------------------------
# 3) Run FW and PFW
# --------------------------
# results[key] = (train_rmses, test_rmses), one RMSE value per stored snapshot.
results = {}
for key, SolverCls in [('FW', FrankWolfe), ('PFW', PairwiseFrankWolfe)]:
    print(f"\n=== Running {key} ===")
    solver = SolverCls(
        objective=obj,
        lmo_fn=nuclear_norm_lmo,
        tau=tau,
        max_iter=config['max_iter'],
        tol=config['tol'],
        step_method=config['step_method'],
    )
    solver = add_snapshots(solver)

    # perf_counter is the clock meant for measuring intervals; time.time()
    # follows the wall clock and can jump if the system time is adjusted.
    # NOTE: solver.run is the wrapper installed by add_snapshots, so the
    # elapsed time covers the solve AND the snapshot replay, not the solve
    # alone.
    t0 = time.perf_counter()
    X_final = solver.run()
    elapsed = time.perf_counter() - t0
    print(f"{key} done in {elapsed:.2f}s, iters={len(solver.history)}")

    # RMSE of every snapshot on the training mask ...
    train_rmses = [
        evaluate(M_true, Xk, mask_train)
        for Xk in solver.snapshots
    ]
    # ... and on its complement.
    # assumes mask_train is a boolean array so that ~ is a logical NOT -- TODO confirm
    test_rmses = [
        evaluate(M_true, Xk, ~mask_train)
        for Xk in solver.snapshots
    ]
    results[key] = (train_rmses, test_rmses)

# --------------------------
# 4) Plot RMSE curves
# --------------------------
plt.figure(figsize=(7, 5))
# Each solver gets its own x-axis: the two runs may record a different number
# of snapshots (e.g. one stops early on the tolerance), and plotting a curve
# against an x-array of another length makes plt.plot raise.
for key in ('FW', 'PFW'):
    train_curve, test_curve = results[key]
    iters = np.arange(len(train_curve))
    plt.plot(iters, train_curve, label=f'{key} train')
    plt.plot(iters, test_curve, label=f'{key} test')
plt.xlabel('Iteration')
plt.ylabel('RMSE')
plt.title(f"{config['dataset']} RMSE vs iteration (step={config['step_method']})")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()