# Optimization in Deep Learning

In [2]:
import jax
import jax.numpy as jnp
from jax import random
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from optymus.methods.first_order import StochasticGradientDescent

### Simple neural network

In [6]:
# Define the MLP model
def init_params(layer_sizes, key):
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        w_key, b_key = random.split(key)
        params.append({
            'w': random.normal(w_key, (n_in, n_out)) * jnp.sqrt(2 / n_in),
            'b': jnp.zeros((n_out,))
        })
        key = random.split(key)[0]
    return params

def relu(x):
    return jnp.maximum(0, x)

def forward(params, x):
    for layer in params[:-1]:
        x = relu(jnp.dot(x, layer['w']) + layer['b'])
    return jnp.dot(x, params[-1]['w']) + params[-1]['b']

def loss_fn(params, x, y):
    logits = forward(params, x)
    return jnp.mean((logits - y) ** 2)

# Generate a simple dataset
X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Standardize the features
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Convert to jax arrays
X_train = jnp.array(X_train)
y_train = jnp.array(y_train).reshape(-1, 1)
X_test = jnp.array(X_test)
y_test = jnp.array(y_test).reshape(-1, 1)

# Initialize the MLP
layer_sizes = [20, 64, 32, 1]
key = random.PRNGKey(0)
initial_params = init_params(layer_sizes, key)

# Flatten the parameters for SGD
def flatten_params(params):
    return jnp.concatenate([p.ravel() for layer in params for p in layer.values()])

def unflatten_params(flat_params, layer_sizes):
    params = []
    idx = 0
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        w_size = n_in * n_out
        b_size = n_out
        w = flat_params[idx:idx+w_size].reshape((n_in, n_out))
        idx += w_size
        b = flat_params[idx:idx+b_size]
        idx += b_size
        params.append({'w': w, 'b': b})
    return params

flat_initial_params = flatten_params(initial_params)

# Define the objective function for SGD
def objective(flat_params, batch):
    params = unflatten_params(flat_params, layer_sizes)
    x, y = batch[:, :-1], batch[:, -1:]
    return loss_fn(params, x, y)

# Create and run SGD
sgd = StochasticGradientDescent(
    f_obj=objective,
    tol=1e-6,
    learning_rate=0.01,
    max_iter=1000,
    batch_size=32,
    verbose=True
)

# Combine X_train and y_train into a single array for batching
train_data = jnp.column_stack((X_train, y_train))

result = sgd.optimize(flat_initial_params, train_data)

# Extract the optimized parameters
optimized_params = unflatten_params(result['xopt'], layer_sizes)

# Evaluate the model
@jax.jit
def accuracy(params, X, y):
    pred = forward(params, X) > 0.5
    return jnp.mean(pred.flatten() == y.flatten())

train_accuracy = accuracy(optimized_params, X_train, y_train)
test_accuracy = accuracy(optimized_params, X_test, y_test)

print(f"Training accuracy: {train_accuracy:.4f}")
print(f"Test accuracy: {test_accuracy:.4f}")
print(f"Number of iterations: {result['num_iter']}")
print(f"Training time: {result['time']:.2f} seconds")

SGD 0: 100%|██████████| 1000/1000 [03:08<00:00,  5.30it/s]


Training accuracy: 0.6400
Test accuracy: 0.6400
Number of iterations: 1000
Training time: 188.65 seconds
