In [1]:
"""
Train convex Gated ReLU and ReLU networks on a synthetic classification dataset
using convex optimization, and compare their training and test curves.
"""

import sys

# Make the parent directory importable (presumably so the `convex_nn` package
# can be imported from a checkout one level up). Guarded so that re-running
# this cell does not keep appending duplicate ".." entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
import numpy as np
import matplotlib.pyplot as plt

from convex_nn.private.utils.data import gen_classification_data


from convex_nn.models import ConvexGatedReLU, ConvexReLU
from convex_nn.solvers import RFISTA, AL
from convex_nn.regularizers import NeuronGL1
from convex_nn.metrics import Metrics
from convex_nn.activations import sample_gate_vectors
from convex_nn.optimize import optimize_model, optimize

In [3]:
# Realizable synthetic classification problem (i.e. the setting of Figure 1).
n_train, n_test = 1000, 1000
d = 50              # input dimension passed to the generator
hidden_units = 100  # width of the data-generating network — TODO confirm
kappa = 10          # condition number

# The leading 123 is presumably the generator's random seed.
(X_train, y_train), (X_test, y_test) = gen_classification_data(
    123, n_train, n_test, d, hidden_units, kappa
)

In [4]:
# Regularization strength, and the number of gate vectors / neurons to sample.
lam, max_neurons = 0.001, 500

In [6]:
# Quick path: let the high-level `optimize` helper build and fit a convex
# Gated ReLU model end to end, running on the CPU.
model, metrics = optimize(
    "gated_relu",
    max_neurons,
    lam,
    X_train,
    y_train,
    X_test,
    y_test,
    verbose=True,
    device="cpu",
)

INFO:convex_nn:Pre-Optimization Metrics: Train Set objective: 0.5, Train Set grad_norm: 0.08252203464508057, Train Set base_objective: 0.5, Train Set accuracy: 0.0, Test Set nc_accuracy: 0.0, group_sparsity: 1.0, 


fista:   0%|          | 0/10000 [00:00<?, ?it/s]

Train Set objective: 0.5, Train Set grad_norm: 0.08252203464508057, Train Set base_objective: 0.5, Train Set accuracy: 0.0, Test Set nc_accuracy: 0.0, group_sparsity: 1.0, 
Train Set objective: 0.20554351806640625, Train Set grad_norm: 0.01135412510484457, Train Set base_objective: 0.20554351806640625, Train Set accuracy: 0.949999988079071, Test Set nc_accuracy: 0.9120000004768372, group_sparsity: 0.0, 
Train Set objective: 0.1671169102191925, Train Set grad_norm: 0.0030854204669594765, Train Set base_objective: 0.1671169102191925, Train Set accuracy: 0.9340000152587891, Test Set nc_accuracy: 0.9259999990463257, group_sparsity: 0.0, 
Train Set objective: 0.15418946743011475, Train Set grad_norm: 0.001024787430651486, Train Set base_objective: 0.15418946743011475, Train Set accuracy: 0.9359999895095825, Test Set nc_accuracy: 0.9259999990463257, group_sparsity: 0.0, 
Train Set objective: 0.1473623514175415, Train Set grad_norm: 0.0007096211775206029, Train Set base_objective: 0.147362351

Train Set objective: 0.11709718406200409, Train Set grad_norm: 4.3172667574253865e-06, Train Set base_objective: 0.11709718406200409, Train Set accuracy: 1.0, Test Set nc_accuracy: 0.9449999928474426, group_sparsity: 0.09600000083446503, 
Train Set objective: 0.11694306135177612, Train Set grad_norm: 3.656086846604012e-06, Train Set base_objective: 0.11694306135177612, Train Set accuracy: 1.0, Test Set nc_accuracy: 0.9449999928474426, group_sparsity: 0.1120000034570694, 
Train Set objective: 0.11679454147815704, Train Set grad_norm: 3.2279076549457386e-06, Train Set base_objective: 0.11679454147815704, Train Set accuracy: 1.0, Test Set nc_accuracy: 0.9440000057220459, group_sparsity: 0.12600000202655792, 
Train Set objective: 0.11664826422929764, Train Set grad_norm: 3.1884308100416092e-06, Train Set base_objective: 0.11664826422929764, Train Set accuracy: 1.0, Test Set nc_accuracy: 0.9430000185966492, group_sparsity: 0.13600000739097595, 
Train Set objective: 0.11650378257036209, Trai

INFO:convex_nn:Termination criterion satisfied at iteration 61/10000. Exiting optimization loop.
INFO:convex_nn:Post-Optimization Metrics: Train Set objective: 0.11440114676952362, Train Set grad_norm: 9.364478046336444e-07, Train Set base_objective: 0.11440114676952362, Train Set accuracy: 1.0, Test Set nc_accuracy: 0.9449999928474426, group_sparsity: 0.3840000033378601, 


Train Set objective: 0.11440114676952362, Train Set grad_norm: 9.364478046336444e-07, Train Set base_objective: 0.11440114676952362, Train Set accuracy: 1.0, Test Set nc_accuracy: 0.9449999928474426, group_sparsity: 0.3840000033378601, 


In [9]:
# Explicit path: sample gate vectors from a seeded generator, then build the
# convex Gated ReLU model, its R-FISTA solver, the group-L1 neuron regularizer
# and the set of metrics to record during optimization.
G = sample_gate_vectors(np.random.default_rng(123), d, max_neurons)
model = ConvexGatedReLU(G)
solver = RFISTA(model, tol=1e-8)
regularizer = NeuronGL1(lam)
metrics = Metrics(
    metric_freq=25,  # how often metrics are recorded
    model_loss=True,
    train_accuracy=True,
    train_mse=True,
    test_mse=True,
    test_accuracy=True,
    neuron_sparsity=True,
)

In [10]:
# Fit the Gated ReLU model with the explicitly configured solver/regularizer.
grelu_model, grelu_metrics = optimize_model(
    model, solver, metrics,
    X_train, y_train,
    X_test, y_test,
    regularizer,
    verbose=True,
)

INFO:convex_nn:Pre-Optimization Metrics: Train Set objective: 0.5, Train Set grad_norm: 0.08310095220804214, Train Set base_objective: 0.5, Train Set accuracy: 0.0, Train Set nc_squared_error: 0.5, Test Set nc_accuracy: 0.0, Test Set squared_error: 0.5, group_sparsity: 1.0, 


fista:   0%|          | 0/10000 [00:00<?, ?it/s]

Train Set objective: 0.5, Train Set grad_norm: 0.08310095220804214, Train Set base_objective: 0.5, Train Set accuracy: 0.0, Train Set nc_squared_error: 0.5, Test Set nc_accuracy: 0.0, Test Set squared_error: 0.5, group_sparsity: 1.0, 
Train Set objective: 0.11960909515619278, Train Set grad_norm: 7.062597433105111e-06, Train Set base_objective: 0.11960909515619278, Train Set accuracy: 1.0, Train Set nc_squared_error: 0.03016858734190464, Test Set nc_accuracy: 0.9399999976158142, Test Set squared_error: 0.11168362200260162, group_sparsity: 0.006000000052154064, 
Train Set objective: 0.11572559922933578, Train Set grad_norm: 2.001475195356761e-06, Train Set base_objective: 0.11572559922933578, Train Set accuracy: 1.0, Train Set nc_squared_error: 0.029887598007917404, Test Set nc_accuracy: 0.9409999847412109, Test Set squared_error: 0.10983996838331223, group_sparsity: 0.2759999930858612, 
Train Set objective: 0.11423622071743011, Train Set grad_norm: 8.815756018520915e-07, Train Set base

INFO:convex_nn:Termination criterion satisfied at iteration 140/10000. Exiting optimization loop.
INFO:convex_nn:Post-Optimization Metrics: Train Set objective: 0.11366136372089386, Train Set grad_norm: 9.954113977528323e-09, Train Set base_objective: 0.11366136372089386, Train Set accuracy: 1.0, Train Set nc_squared_error: 0.02936391532421112, Test Set nc_accuracy: 0.9390000104904175, Test Set squared_error: 0.1150580644607544, group_sparsity: 0.7039999961853027, 


In [None]:
# Fit a convex ReLU model on the same gate vectors `G`, using the AL solver.
# NOTE(review): the `metrics` object from the Gated ReLU cell has already been
# passed to `optimize_model`. If that call records its history into the object
# it is given, reusing it here would alias or mix the two runs' curves
# (grelu_metrics vs relu_metrics). A fresh Metrics with identical settings
# keeps the two runs independent regardless of that library detail.
relu_metrics_tracker = Metrics(
    metric_freq=25,
    model_loss=True,
    train_accuracy=True,
    train_mse=True,
    test_mse=True,
    test_accuracy=True,
    neuron_sparsity=True,
)
model = ConvexReLU(G)
solver = AL(model)
relu_model, relu_metrics = optimize_model(
    model,
    solver,
    relu_metrics_tracker,
    X_train,
    y_train,
    X_test,
    y_test,
    regularizer,
    verbose=True,
)

In [None]:
# Hand-computed training accuracy of the ReLU model (assumes labels are in
# {-1, +1}, so the sign of the prediction is the predicted class).
# Both sides are flattened so an (n, 1) prediction is compared element-wise
# with the labels instead of broadcasting against an (n,) label vector into
# an (n, n) matrix, which would silently give a meaningless number.
np.mean(np.ravel(np.sign(relu_model(X_train))) == np.ravel(y_train))

In [None]:
# Compare the ReLU and Gated ReLU runs side by side: training objective,
# training accuracy and test accuracy (one panel each, one shared legend).
# NOTE(review): the x-axis is the index of each recorded metric value. With
# `Metrics(metric_freq=25)` consecutive points are presumably several solver
# iterations apart, so the "Iterations" label may be off by that factor — confirm.

# (metrics object, line style) for each run. The style is shared by every
# panel so a model keeps the same label, colour and marker everywhere.
runs_to_plot = [
    (relu_metrics, {"label": "ReLU", "color": "#ff7f0e", "marker": "^"}),
    (grelu_metrics, {"label": "Gated ReLU", "color": "#1f77b4", "marker": "v"}),
]


def plot_metric_panel(ax, runs, metric_name, title):
    """Plot one recorded metric for every run onto ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    runs : list of ``(metrics, style)`` pairs; ``getattr(metrics, metric_name)``
        must be the sequence of recorded values, and ``style`` is a dict of
        extra ``Axes.plot`` keyword arguments (label, color, marker).
    metric_name : attribute of the metrics object to plot, e.g. ``"objective"``.
    title : panel title.

    Returns
    -------
    The same ``ax``, so the call can be assigned directly.
    """
    for run_metrics, style in runs:
        values = getattr(run_metrics, metric_name)
        ax.plot(
            np.arange(len(values)),
            values,
            markevery=0.1,
            markersize=14,
            linewidth=3,
            **style,
        )
    ax.set_title(title, fontsize=22)
    ax.set_xlabel("Iterations", fontsize=18)
    return ax


fig = plt.figure(figsize=(18, 6))
spec = fig.add_gridspec(ncols=3, nrows=1)
ax0 = plot_metric_panel(fig.add_subplot(spec[0, 0]), runs_to_plot, "objective", "Training Objective")
ax1 = plot_metric_panel(fig.add_subplot(spec[0, 1]), runs_to_plot, "train_accuracy", "Training Accuracy")
ax2 = plot_metric_panel(fig.add_subplot(spec[0, 2]), runs_to_plot, "test_accuracy", "Test Accuracy")

# Every panel plots the same runs, so the first panel's handles cover the legend.
handles, labels = ax0.get_legend_handles_labels()
legend = fig.legend(
    handles=handles,
    labels=labels,
    loc="lower center",
    borderaxespad=0.1,
    fancybox=False,
    shadow=False,
    ncol=2,
    fontsize=16,
    frameon=False,
)
# Leave room under the panels for the bottom-centred legend.
fig.subplots_adjust(
    bottom=0.15,
)