In [None]:
"""
Train convex LassoNet models on a synthetic classification dataset using convex optimization.
"""

import sys
sys.path.append("..")

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from convex_nn import opt_nn
from convex_nn.datasets import generate_synthetic_classification

In [None]:
# Generate a realizable synthetic classification problem (i.e., the setting of Figure 1).
n_train = 250
n_test = 250
d = 50
hidden_units = 100
kappa = 10  # condition number

# NOTE(review): the leading 123 is presumably the random seed — confirm against
# generate_synthetic_classification's signature.
train_data, test_data = generate_synthetic_classification(
    123,
    n_train,
    n_test,
    d,
    hidden_units,
    kappa,
)
X_train, y_train = train_data
X_test, y_test = test_data

In [None]:
# Step 1: Obtain a convex formulation by solving an initial problem.

# Regularization strengths for the sweep in Step 2, ordered from largest (1.0)
# down to smallest (0.01).
lambda_path = np.logspace(-2, 0, 10)[::-1]

# Start with a very large regularization strength, intended to make the initial
# solution fully sparse.
lam = 10

convex_model, grelu_metrics = opt_nn(
    X_train,
    y_train,
    X_test,
    y_test,
    formulation="grelu_lasso_net",
    reg_strength=lam,
    max_patterns=250,
    train_metrics=["accuracy"],
    test_metrics=["accuracy"],
    additional_metrics=["feature_sparsity", "active_features", "step_size"],
    backend="numpy",
    return_convex_form=True,  # return the convex-form model so Step 2 can warm-start from it
    verbose=True,
)


In [None]:
# Step 2: Trace the regularization path by warm-starting each solve from the
# previous solution.
#
# The sweep advances its own model variable (`path_model`) instead of
# reassigning `convex_model`. Re-running this cell therefore always starts the
# path from the Step 1 solution, not from the last model a previous run left
# behind.
# NOTE(review): this assumes opt_nn does not mutate `warm_start` in place —
# confirm.

sparsities = []
accuracies = []

path_model = convex_model
for lam in lambda_path:
    print(f"\n Trying lambda = {lam}")

    path_model, grelu_metrics = opt_nn(
        X_train,
        y_train,
        X_test,
        y_test,
        train_metrics=["accuracy"],
        test_metrics=["accuracy"],
        additional_metrics=["feature_sparsity", "active_features", "step_size"],
        warm_start=path_model,  # warm start at the previous solution on the path
        formulation="grelu_lasso_net",
        reg_strength=lam,
        backend="numpy",
        return_convex_form=True,
        verbose=True,
    )

    # Record the final feature sparsity and test accuracy for this lambda.
    # Entries are appended in the same order as lambda_path.
    sparsities.append(grelu_metrics["feature_sparsity"][-1])
    accuracies.append(grelu_metrics["test_accuracy"][-1])

In [None]:
# Plot Results
#
# lambda_path runs from largest to smallest, and the per-lambda results were
# recorded in that same order. Both are flipped so the x-axis increases.


def _plot_path(ax, lambdas, values, title):
    """Draw one regularization-path curve on `ax` with the shared styling.

    Args:
        ax: matplotlib Axes to draw on.
        lambdas: regularization strengths, in the order they were swept.
        values: per-lambda metric values, in the same order as `lambdas`.
        title: panel title.
    """
    ax.plot(
        np.flip(lambdas),
        np.flip(values),
        label="Gated Relu LassoNet",
        color="#ff7f0e",
        marker="^",
        markevery=1,
        markersize=14,
        linewidth=3,
    )
    # Both panels use the same title and label sizes.
    ax.set_title(title, fontsize=22)
    ax.set_xlabel("Reg. Strength", fontsize=18)
    ax.set_xscale("log")


fig = plt.figure(figsize=(18, 6))
spec = fig.add_gridspec(ncols=2, nrows=1)
ax0 = fig.add_subplot(spec[0, 0])
ax1 = fig.add_subplot(spec[0, 1])

_plot_path(ax0, lambda_path, accuracies, "Test Accuracy")
_plot_path(ax1, lambda_path, sparsities, "Feature Sparsity")

# Both panels plot the same labelled series, so one shared legend (taken from
# ax0) is placed below the subplots.
handles, labels = ax0.get_legend_handles_labels()
legend = fig.legend(
    handles=handles,
    labels=labels,
    loc="lower center",
    borderaxespad=0.1,
    fancybox=False,
    shadow=False,
    ncol=2,
    fontsize=16,
    frameon=False,
)
# Leave room at the bottom of the figure for the shared legend.
fig.subplots_adjust(
    bottom=0.15,
)