In [None]:
"""
Train shallow neural networks on a synthetic classification dataset using convex optimization.
"""

import sys
# Append the parent directory to the module search path before the package
# imports in the next cell run.
# NOTE(review): presumably this lets `scnn` be imported from a source checkout
# when the notebook sits one level below the repository root — confirm.
sys.path.append("..")

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from scnn.private.utils.data import gen_classification_data


from scnn.models import ConvexGatedReLU, ConvexReLU
from scnn.solvers import RFISTA, AL, LeastSquaresSolver, CVXPYSolver, ApproximateConeDecomposition
from scnn.regularizers import NeuronGL1, L2, L1
from scnn.metrics import Metrics
from scnn.activations import sample_gate_vectors
from scnn.optimize import optimize_model, optimize

In [None]:
# Build a realizable synthetic classification problem (i.e., Figure 1).
n_train, n_test = 1000, 1000
d = 50
hidden_units = 100
kappa = 10  # condition number

# NOTE(review): the leading 123 is presumably the random seed — confirm against
# gen_classification_data.
(X_train, y_train), (X_test, y_test) = gen_classification_data(
    123,
    n_train,
    n_test,
    d,
    hidden_units,
    kappa,
)

In [None]:
# NOTE(review): `lam` is not referenced by any other visible cell; the
# regularizers below are all constructed with a literal 0.01.
lam = 0.001
max_neurons = 500

# Gate vectors reused by every convex model that is constructed explicitly below.
G = sample_gate_vectors(123, d, max_neurons)

# Quantities to record while optimizing.
# NOTE(review): this one Metrics instance is handed to every optimize_model
# call below — confirm that optimize_model does not write results into it in
# place, otherwise the per-solver metrics objects would alias each other.
metrics = Metrics(
    metric_freq=25,
    model_loss=True,
    train_accuracy=True,
    test_accuracy=True,
    train_mse=True,
    test_mse=True,
    neuron_sparsity=True,
)

# 1. The Functional Approach: `scnn.optimize`

The simplest way to train a neural network with convex optimization is to call `optimize` with the intended problem formulation, a training dataset, some (optional) test data, and a regularizer. 
In this case, we train a neural network with gated ReLU activations.

In [None]:
# One-call training of a gated ReLU network; the second return value is discarded.
model, _ = optimize(
    "gated_relu",
    max_neurons,
    X_train,
    y_train,
    X_test,
    y_test,
    regularizer=NeuronGL1(0.01),
    verbose=True,
    device="cpu",
)

# 2. The Object-Oriented Approach: `scnn.optimize_model`

For more control over the model and optimization procedure, we can use `optimize_model`.
We instantiate the convex formulation ourselves and choose an appropriate solver. 
In this approach, we set the gate vectors for the convex reformulation manually.
We can also directly specify the optimizer parameters if we so choose.
The following code trains the same neural network as the first approach.

In [None]:
# Pieces handed to optimize_model: the penalty, the convex model built on the
# shared gate vectors, and the solver bound to that model.
regularizer = NeuronGL1(0.01)
model = ConvexGatedReLU(G)
solver = RFISTA(model, tol=1e-6)

In [None]:
# Train the gated ReLU model and keep both the fitted model and its metrics.
grelu_model, grelu_metrics = optimize_model(
    model, solver, metrics,
    X_train, y_train,
    X_test, y_test,
    regularizer=regularizer,
    verbose=True,
)

In [None]:
# Training accuracy: count of predictions whose sign equals the label, divided
# by the number of training examples.
# NOTE(review): assumes the model output and y_train have the same shape;
# otherwise `==` would broadcast — confirm.
(np.sign(grelu_model(X_train)) == y_train).sum() / len(y_train)

## 2.1 Changing the Optimization Method

The main advantage of the second (object-oriented) approach is that it is easy to change the optimization method used.
Previously, we trained a gated ReLU model using R-FISTA, a solver for unconstrained problems based on proximal-gradient methods.
Now we train with a variety of optimization methods, which lead to different final models.

### 2.1.1 Cone Decompositions

We want to train a ReLU model, but directly solving the corresponding convex optimization problem, which has complicating constraints, can be costly.
In this case, we use an approximate cone decomposition to convert a gated ReLU model into a ReLU neural network.

In [None]:
# Start from a gated ReLU model; a ReLU model will be output.
model = ConvexGatedReLU(G)
solver = ApproximateConeDecomposition(model)
# Bind the penalty in this cell instead of inheriting the notebook-global
# `regularizer`: a later cell rebinds that name to L2(0.01), so re-running this
# cell out of order would otherwise silently train with a different penalty.
regularizer = NeuronGL1(0.01)
cd_model, cd_metrics = optimize_model(
    model,
    solver,
    metrics,
    X_train,
    y_train,
    X_test,
    y_test,
    # Passed by keyword, matching the other optimize_model calls in this notebook.
    regularizer=regularizer,
    verbose=True,
)

In [None]:
# Training accuracy: count of predictions whose sign equals the label, divided
# by the number of training examples.
# NOTE(review): assumes the model output and y_train have the same shape;
# otherwise `==` would broadcast — confirm.
(np.sign(cd_model(X_train)) == y_train).sum() / len(y_train)

### 2.1.2 Direct ReLU Training

Of course, sometimes we prefer to directly solve the convex formulation of the ReLU training problem. 
We can use the built-in augmented Lagrangian method (AL) to do this.
One advantage of this approach is that it produces models with smaller weights.

In [None]:
# Solve the convex ReLU formulation directly with the augmented Lagrangian solver.
model = ConvexReLU(G)
solver = AL(model)
# Bind the penalty in this cell instead of inheriting the notebook-global
# `regularizer`: a later cell rebinds that name to L2(0.01), so re-running this
# cell out of order would otherwise silently train with a different penalty.
regularizer = NeuronGL1(0.01)
relu_model, relu_metrics = optimize_model(
    model,
    solver,
    metrics,
    X_train,
    y_train,
    X_test,
    y_test,
    # Passed by keyword, matching the other optimize_model calls in this notebook.
    regularizer=regularizer,
    verbose=True,
)

In [None]:
# Training accuracy: count of predictions whose sign equals the label, divided
# by the number of training examples.
# NOTE(review): assumes the model output and y_train have the same shape;
# otherwise `==` would broadcast — confirm.
(np.sign(relu_model(X_train)) == y_train).sum() / len(y_train)

### 2.1.3 High-Accuracy Interior Point Methods

The R-FISTA and AL methods are suitable for generating moderate-accuracy solutions, fast. 
For very high-accuracy solutions, we use CVXPY as an interface to open-source and commercial interior point methods.
Interior point methods do not produce (neuron) sparse solutions in general, so we provide a post-optimization clean-up phase that sparsifies the solution.

In [None]:
# Gated ReLU problem solved through CVXPY with the "ecos" backend.
# Note that commercial solvers like MOSEK/Gurobi can be used if they are installed.
regularizer = NeuronGL1(0.01)
model = ConvexGatedReLU(G)
# NOTE(review): clean_sol=True presumably enables the post-optimization
# clean-up phase described in the text — confirm against CVXPYSolver.
solver = CVXPYSolver(model, "ecos", clean_sol=True)
cvxpy_model, cvxpy_metrics = optimize_model(
    model, solver, metrics,
    X_train, y_train,
    X_test, y_test,
    regularizer=regularizer,
    verbose=True,
)

In [None]:
# Training accuracy: count of predictions whose sign equals the label, divided
# by the number of training examples.
# NOTE(review): assumes the model output and y_train have the same shape;
# otherwise `==` would broadcast — confirm.
(np.sign(cvxpy_model(X_train)) == y_train).sum() / len(y_train)

### 2.1.4 Fast Quadratic Solvers

Finally, we can use super-fast iterative solvers by changing the model formulation to make the entire problem quadratic.
Specifically, changing the regularizer to an L2-squared penalty for gated ReLU models yields a ridge-regression problem that does not correspond to a non-convex model.
However, it performs comparably in practice and can be trained quickly even on CPU.

In [None]:
# Ridge-regression variant: gated ReLU model with an L2 penalty, solved by the
# super-fast least-squares solver.
regularizer = L2(0.01)
model = ConvexGatedReLU(G)
solver = LeastSquaresSolver(model, tol=1e-8)
lstsq_model, lstsq_metrics = optimize_model(
    model, solver, metrics,
    X_train, y_train,
    X_test, y_test,
    regularizer=regularizer,
    verbose=True,
)

In [None]:
# Training accuracy: count of predictions whose sign equals the label, divided
# by the number of training examples.
# NOTE(review): assumes the model output and y_train have the same shape;
# otherwise `==` would broadcast — confirm.
(np.sign(lstsq_model(X_train)) == y_train).sum() / len(y_train)

# 3. Training Times and Test Metrics

We briefly summarize results for the different optimizers and models discussed above.

In [None]:
# Summary figure: training objective against the index of each recorded value,
# then training and test accuracy against the recorded time.

# One row per curve: (metrics object, legend label, color, marker).
# Keeping the styling in a single table means every panel draws a given run
# identically (the label "ReLU" is now spelled the same everywhere).
# Ridge Regression uses a square marker so that it no longer shares "X" with
# Cone Decomp.
_curves = [
    (relu_metrics, "ReLU", "#ff7f0e", "^"),
    (grelu_metrics, "Gated ReLU", "#1f77b4", "v"),
    (cd_metrics, "Cone Decomp.", "#2ca02c", "X"),
    (lstsq_metrics, "Ridge Regression", "#d62728", "s"),
]


def _plot_curves(ax, x_of, y_of, title, xlabel):
    """Draw every entry of `_curves` on `ax`, then set the panel title and x-label.

    Args:
        ax: matplotlib axes to draw on.
        x_of: callable mapping a metrics object to the x-values of its curve.
        y_of: callable mapping a metrics object to the y-values of its curve.
        title: panel title.
        xlabel: x-axis label.
    """
    for run_metrics, label, color, marker in _curves:
        ax.plot(
            x_of(run_metrics),
            y_of(run_metrics),
            label=label,
            color=color,
            marker=marker,
            markevery=0.1,
            markersize=14,
            linewidth=3,  # numeric width instead of the string "3"
        )
    ax.set_title(title, fontsize=22)
    ax.set_xlabel(xlabel, fontsize=18)


fig = plt.figure(figsize=(18, 6))
spec = fig.add_gridspec(ncols=3, nrows=1)

# Left panel: objective values on a log y-axis.
ax0 = fig.add_subplot(spec[0, 0])
_plot_curves(
    ax0,
    lambda m: np.arange(len(m.objective)),
    lambda m: m.objective,
    "Training Objective",
    "Iterations",
)
ax0.set_yscale("log")

# Middle panel: training accuracy against time on a log x-axis.
ax1 = fig.add_subplot(spec[0, 1])
_plot_curves(
    ax1,
    lambda m: m.time,
    lambda m: m.train_accuracy,
    "Training Accuracy",
    "Time (S)",
)
ax1.set_xscale("log")
ax1.set_ylim([0.5, 1])

# Right panel: test accuracy against time on a log x-axis.
ax2 = fig.add_subplot(spec[0, 2])
_plot_curves(
    ax2,
    lambda m: m.time,
    lambda m: m.test_accuracy,
    "Test Accuracy",
    "Time (S)",
)
ax2.set_xscale("log")
ax2.set_ylim([0.5, 1])

# A single figure-level legend built from the first panel's handles; the
# bottom margin is enlarged to leave room for it below the axes.
handles, labels = ax0.get_legend_handles_labels()
legend = fig.legend(
    handles=handles,
    labels=labels,
    loc="lower center",
    borderaxespad=0.1,
    fancybox=False,
    shadow=False,
    ncol=4,
    fontsize=16,
    frameon=False,
)
fig.subplots_adjust(
    bottom=0.15,
)