In [None]:
"""
Train convex LassoNet models on a synthetic classification dataset using convex optimization.
"""

import sys

# Make the parent directory (relative to the current working directory)
# importable -- presumably so the local `convex_nn` checkout resolves without
# being installed; confirm against the repository layout.
# The membership guard keeps the cell idempotent: re-running it no longer
# appends a duplicate ".." entry to sys.path each time.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from convex_nn import optimize
from convex_nn.utils.data import gen_classification_data

In [None]:
# Build a realizable synthetic classification problem (i.e. Figure 1).
# The sizes below are forwarded positionally to gen_classification_data.
n_train, n_test = 250, 250
d = 50  # presumably the input feature dimension -- confirm against gen_classification_data
hidden_units = 100
kappa = 10  # condition number

# The leading 123 is presumably the random seed -- confirm against gen_classification_data.
train_split, test_split = gen_classification_data(
    123,
    n_train,
    n_test,
    d,
    hidden_units,
    kappa,
)
X_train, y_train = train_split
X_test, y_test = test_split

In [None]:
# Fit the "grelu_lasso_net" formulation on the synthetic problem.
# optimize() returns the fitted model together with a metrics object that the
# plotting cell later indexes by string key (e.g. "train_accuracy").
grelu_options = dict(
    train_metrics=["accuracy"],
    test_metrics=["accuracy"],
    additional_metrics=["feature_sparsity", "active_features", "step_size"],
    max_patterns=250,
    formulation="grelu_lasso_net",
    reg_strength=0.05,
    backend="numpy",
    verbose=True,
)

grelu_model, grelu_metrics = optimize(
    X_train,
    y_train,
    X_test,
    y_test,
    **grelu_options,
)

In [None]:
# Fit the "relu_lasso_net" formulation on the same train/test split.
# Compared with the gated run, this one additionally records the
# "base_objective" training metric and the "constraint_gaps" diagnostic, and
# uses a different regularization strength (0.1).
relu_options = dict(
    train_metrics=["base_objective", "accuracy"],
    test_metrics=["accuracy"],
    additional_metrics=[
        "feature_sparsity",
        "active_features",
        "step_size",
        "constraint_gaps",
    ],
    max_patterns=250,
    formulation="relu_lasso_net",
    reg_strength=0.1,
    backend="numpy",
    verbose=True,
)

relu_model, relu_metrics = optimize(
    X_train,
    y_train,
    X_test,
    y_test,
    **relu_options,
)

In [None]:
# Inspect the weights of the final models (in the non-convex formulation).


def _print_weight_summary(W1, W2, theta):
    """Print layer shapes and sparsity ratios for one set of model weights.

    Args:
        W1: first-layer weights as returned by ``get_weights()``.
        W2: second-layer weights as returned by ``get_weights()``.
        theta: skip-layer weights as returned by ``get_weights()``.

    Both ratios are normalized by ``theta.shape[1]``.
    Assumes axis 1 of ``theta`` and of ``W1`` indexes the input features --
    TODO confirm against the model's ``get_weights`` documentation.
    """
    n_features = theta.shape[1]

    print("Layer Shapes:", W1.shape, W2.shape, theta.shape)
    print("Skip-Layer Sparsity:", np.sum(theta == 0) / n_features)
    # Count a column as sparse only when *every* entry along axis 0 is exactly
    # zero. Comparing the signed column sum to zero would also count columns
    # whose non-zero weights happen to cancel.
    print("Network Sparsity:", np.sum(np.all(W1 == 0, axis=0)) / n_features)


print("Gated ReLU Model")
W1, W2, theta = grelu_model.get_weights()
_print_weight_summary(W1, W2, theta)

print("ReLU Model")
W1, W2, theta = relu_model.get_weights()

# Note: the ReLU network *is* feature sparse; the sparsity arises from cancellation in the forward pass and is
# "hidden" when working with the non-convex formulation.
_print_weight_summary(W1, W2, theta)

In [None]:
def _plot_metric_comparison(ax, relu_values, grelu_values, title):
    """Draw the ReLU and gated LassoNet curves for one metric on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        relu_values: per-iteration metric values for the ReLU model.
        grelu_values: per-iteration metric values for the gated ReLU model.
        title: title shown above the panel.

    The x-axis is the position of each value in its sequence, so the two
    curves may have different lengths.
    """
    curves = (
        ("Relu LassoNet", relu_values, "#ff7f0e", "^"),
        ("Gated LassoNet", grelu_values, "#1f77b4", "v"),
    )
    for label, values, color, marker in curves:
        ax.plot(
            np.arange(len(values)),
            values,
            label=label,
            color=color,
            marker=marker,
            markevery=0.1,
            markersize=14,
            linewidth=3,  # numeric width; previously passed as the string "3"
        )

    ax.set_title(title, fontsize=22)
    ax.set_xlabel("Iterations", fontsize=18)


# One row of three panels: training objective, training accuracy, test accuracy.
fig = plt.figure(figsize=(18, 6))
spec = fig.add_gridspec(ncols=3, nrows=1)
ax0, ax1, ax2 = (fig.add_subplot(spec[0, col]) for col in range(3))

# NOTE(review): the ReLU curve uses "train_base_objective" while the gated curve
# uses "train_objective", which the gated optimize() call does not request
# explicitly -- confirm the key is recorded by default and that the two
# quantities are meant to be compared on the same axes.
_plot_metric_comparison(
    ax0,
    relu_metrics["train_base_objective"],
    grelu_metrics["train_objective"],
    "Training Objective",
)
_plot_metric_comparison(
    ax1,
    relu_metrics["train_accuracy"],
    grelu_metrics["train_accuracy"],
    "Training Accuracy",
)
_plot_metric_comparison(
    ax2,
    relu_metrics["test_accuracy"],
    grelu_metrics["test_accuracy"],
    "Test Accuracy",
)

# All panels share the same two series, so a single figure-level legend is
# built from the first panel's handles and placed below the axes.
handles, labels = ax0.get_legend_handles_labels()
legend = fig.legend(
    handles=handles,
    labels=labels,
    loc="lower center",
    borderaxespad=0.1,
    fancybox=False,
    shadow=False,
    ncol=2,
    fontsize=16,
    frameon=False,
)

# Leave room at the bottom of the figure for the legend.
fig.subplots_adjust(
    bottom=0.15,
)