In [None]:
"""
Explore regularization path using warm starts and convex optimization.
"""

import sys
# Append the parent directory to the import search path. Presumably this lets
# the notebook import a local checkout of `scnn` -- verify against repo layout.
sys.path.append("..")

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from scnn.optimize import optimize_path
from scnn.private.utils.data import gen_classification_data
from scnn.models import ConvexGatedReLU, ConvexReLU
from scnn.solvers import RFISTA, AL
from scnn.regularizers import NeuronGL1
from scnn.metrics import Metrics
from scnn.activations import sample_gate_vectors

In [None]:
# Generate a realizable synthetic classification problem (i.e. Figure 1).
n_train = 250
n_test = 250
d = 50
hidden_units = 100
kappa = 10  # condition number

# The generator returns a (train, test) pair, each of which is an (X, y) pair.
train_split, test_split = gen_classification_data(
    123,
    n_train,
    n_test,
    d,
    hidden_units,
    kappa,
)
X_train, y_train = train_split
X_test, y_test = test_split

In [None]:
# Instantiate convex model and other options.
max_neurons = 500
lambda_path = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]

# Gate vectors are drawn from a dedicated, seeded generator.
gate_rng = np.random.default_rng(123)
G = sample_gate_vectors(gate_rng, d, max_neurons)

# One NeuronGL1 regularizer per entry of lambda_path, in the same order.
path = list(map(NeuronGL1, lambda_path))

# Quantities to record every `metric_freq` iterations.
metrics = Metrics(
    metric_freq=25,
    model_loss=True,
    train_accuracy=True,
    train_mse=True,
    test_mse=True,
    test_accuracy=True,
    neuron_sparsity=True,
)

# 1. Gated ReLU Models

In [None]:
# Gated ReLU model built from the sampled gate vectors G, paired with an
# RFISTA solver at tolerance 1e-6.
# NOTE(review): `model` and `solver` are rebound further down for the ReLU
# run, so re-running cells out of order can pair the wrong model with a fit.
model = ConvexGatedReLU(G)
solver = RFISTA(model, tol=1e-6)

In [None]:
# Run the gated ReLU model over every regularizer in `path`, recording
# `metrics` on the train and test sets. The returned metric path is iterated
# later to pull out summary statistics.
gated_model_path, gated_metric_path = optimize_path(
    model, solver, path, metrics,
    X_train, y_train, X_test, y_test,
    verbose=True,
)

# 2. ReLU Models

In [None]:
# ReLU model built from the same gate vectors G, paired with the AL solver
# using its default settings.
# NOTE(review): this rebinds the `model` and `solver` names used earlier for
# the gated ReLU run; the gated objects are no longer reachable by name.
model = ConvexReLU(G)
solver = AL(model)

In [None]:
# Run the ReLU model over every regularizer in `path`, recording `metrics`
# on the train and test sets. The returned metric path is iterated later to
# pull out summary statistics.
relu_model_path, relu_metric_path = optimize_path(
    model, solver, path, metrics,
    X_train, y_train, X_test, y_test,
    verbose=True,
)

In [None]:
# Extract summary statistics: the last recorded value of each metric for
# every fit along the regularization path.
#
# Test accuracy is read here (not train accuracy) because the accuracy panel
# plotted from these lists is titled "Test Accuracy", and `test_accuracy` is
# enabled on the Metrics object. The loop variable is deliberately not named
# `metrics`, to avoid confusion with the module-level Metrics instance.
gated_accuracies = [m.test_accuracy[-1] for m in gated_metric_path]
gated_sparsities = [m.neuron_sparsity[-1] for m in gated_metric_path]

relu_accuracies = [m.test_accuracy[-1] for m in relu_metric_path]
relu_sparsities = [m.neuron_sparsity[-1] for m in relu_metric_path]

In [None]:
# Plot Results
#
# Two side-by-side panels sharing one figure-level legend: accuracy on the
# left and neuron sparsity on the right, each drawn against the regularization
# strength on a log-scaled x axis.

# (legend label, line color, marker) for each model family. The order here is
# the draw order, which also determines the legend order.
series_styles = [
    ("Relu", "#ff7f0e", "^"),
    ("Gated Relu", "#1f77b4", "v"),
]

# (panel title, values per legend label) for each subplot column.
panels = [
    ("Test Accuracy", {"Relu": relu_accuracies, "Gated Relu": gated_accuracies}),
    ("Neuron Sparsity", {"Relu": relu_sparsities, "Gated Relu": gated_sparsities}),
]

fig = plt.figure(figsize=(18, 6))
spec = fig.add_gridspec(ncols=2, nrows=1)

axes = []
for col, (title, curves) in enumerate(panels):
    ax = fig.add_subplot(spec[0, col])
    for label, color, marker in series_styles:
        # Both x and y are flipped together so points stay paired while the
        # line is drawn from the smallest regularization strength upward.
        ax.plot(
            np.flip(lambda_path),
            np.flip(curves[label]),
            label=label,
            color=color,
            marker=marker,
            markevery=1,
            markersize=14,
            linewidth=3,
        )
    ax.set_title(title, fontsize=22)
    ax.set_xlabel("Regularization Strength", fontsize=18)
    ax.set_xscale("log")
    axes.append(ax)

# Keep the original per-panel names available for any follow-up cells.
ax0, ax1 = axes

# Both panels carry the same two series, so the left panel's handles are
# enough to build a single legend for the whole figure.
handles, labels = ax0.get_legend_handles_labels()
legend = fig.legend(
    handles=handles,
    labels=labels,
    loc="lower center",
    borderaxespad=0.1,
    fancybox=False,
    shadow=False,
    ncol=2,
    fontsize=16,
    frameon=False,
)
# Leave room at the bottom of the figure for the legend.
fig.subplots_adjust(
    bottom=0.15,
)