# Introduction to Multiplexed Gradient Descent (MGD)
In an era of emerging specialized neuromorphic hardware, the demand for backprop-free gradient descent is more than ever. MGD is a perturbative method that offers a simple and efficient path to training arbitrary networks with respect to an objective function, so long as a gradient exists.  This notebook is a minimal example toward a practical understanding of MGD. A few things to note:
 - We utilize the JAX python package for flexible and efficient neural network operations
 - Where ever the @jax.jit is used, the function that follows will be subject to "just in time" compilation
 - JAX requires explicit seed (random key) definitions.  Where ever a key is defined, simply think of that as a seed.
 - This can simply be thought of as preparing (compiling) the function to be used efficiently
 - We otherwise avoid obfuscating our implementation with calls to pacakages like jax.flax are even more efficent, but opaque

In [2]:
import jax
import optax
import copy

import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import matplotlib.pyplot as plt

import src.utils.helper_functions as hf
import src.data_loader as dat

In [3]:
from jax import config
config.update("jax_default_matmul_precision", "float32")
print(jax.devices())

[CudaDevice(id=0)]


In [5]:
from datasets import load_dataset
cache_dir="~/.cache/huggingface/datasets"

trainset = load_dataset("mnist", split = 'train', cache_dir = cache_dir)
testset  = load_dataset("mnist", split = 'test', cache_dir = cache_dir)

def preprocess(sample):
    sample['label'] = jax.nn.one_hot(sample['label'], 10)
    return sample

trainset = trainset.map(preprocess).shuffle(seed=0)
testset = testset.map(preprocess).shuffle(seed=0)

# Make datasets output in numpy format
trainset = trainset.with_format('numpy')
testset  = testset.with_format('numpy')

# Copy to memory
trainset = trainset[:]
testset  = testset[:]

trainset['image'] = np.array(trainset['image'])/255
testset['image']  = np.array(testset['image'])/255

## To Start: Initialize (any) network
 - MGD is model-free and therefore indifferent to network topology, activation function, and node variation. 
 - We here implement a simple multi-layer feedforward neural network (an MLP).
 - In the recurrent tutorial, we examine more varied topologies.

In [6]:
class InfiniteLoader:
    def __init__(self, X, y, batch_size=32, shuffle=True):
        self.X = np.array(X)
        self.y = np.array(y)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(X))
        self.ptr = 0
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __iter__(self):
        return self

    def __next__(self):
        if self.ptr + self.batch_size > len(self.X):
            self.ptr = 0
            if self.shuffle:
                np.random.shuffle(self.indices)

        idx = self.indices[self.ptr:self.ptr+self.batch_size]
        self.ptr += self.batch_size

        xb = jnp.array(self.X[idx])
        yb = jnp.array(self.y[idx])
        return xb, yb
    
batch_size = 64
trainloader = InfiniteLoader(X=trainset['image'], y=trainset['label'], batch_size=batch_size, shuffle=True)
testloader  = InfiniteLoader(X=testset['image'], y=testset['label'], batch_size=batch_size, shuffle=True)

In [7]:

class MLP(nn.Module):
    hidden_sizes: list[int]

    @nn.compact
    def __call__(self, x):
        x = x.reshape(x.shape[0], -1)
        for size in self.hidden_sizes[:-1]:
            x = nn.relu(nn.Dense(size)(x))
        x = nn.Dense(self.hidden_sizes[-1])(x)
        return x
    
x_batch , y_batch = next(trainloader)
layer_dims        = [128,128,10]

model = MLP(hidden_sizes=layer_dims)

key    = jax.random.PRNGKey(0)
params = model.init(key, x_batch)['params']
print(jax.tree.map(lambda p: p.shape, params))

{'Dense_0': {'bias': (128,), 'kernel': (784, 128)}, 'Dense_1': {'bias': (128,), 'kernel': (128, 128)}, 'Dense_2': {'bias': (10,), 'kernel': (128, 10)}}


## 1. Make a (totally normal) foward pass
- The forward pass for MGD is exactly like any old MLP forward pass

In [8]:
logits = model.apply({'params': params}, x_batch)
print(f"Activations of final layer (for random sample of batch):")
print(logits[np.random.randint(batch_size)])

Activations of final layer (for random sample of batch):
[-0.16950952  0.13096733  0.10776845  0.18692796  0.03496152 -0.00521973
 -0.0693294  -0.22302842  0.121273   -0.107751  ]


Finally, like with any forward pass, we determine the loss with respect to our objective function.  Here, we use cross-entropy.

In [9]:
@jax.jit
def loss_CE(logits, labels):
    return optax.softmax_cross_entropy(logits=logits, labels=labels).mean()

cost0 = loss_CE(logits, y_batch)
print(f"Forward pass cost: \n{cost0:.6}")

Forward pass cost: 
2.28932


## 2. Perturb all learnable parameters simultaneously
 - Select some (small) value $\epsilon$
 - Random perturb all learnable parameters by $\pm \epsilon$
 - Store those perturbation values for every parameter
 - Apply those perturbations to a copy of the original parametes
   - In hardware, there would be no copying. The pert value would just be stored locally.

In [10]:
@jax.jit
def sample_perturbations(params,epsilon,i):
    key = jax.random.PRNGKey(i, impl=None)
    return jax.tree.map(
        lambda p: jax.random.choice(key, jnp.array([-1,1])*epsilon, shape=(p.shape)), params
    )

@jax.jit
def apply_perturbations(theta,perturbations):
    return jax.tree.map(lambda param, pert: param+pert, theta, perturbations)

epsilon = 1e-6

# create a set of epsilon-sized perturbations for given the shape of parameters
perturbations = sample_perturbations(params,epsilon,0)

# apply perturbations to copy
params_perturbed = apply_perturbations(params,perturbations)

## 3. Perform forward pass (again), but now with perturbed parameters
 - You'll notice these results are not exactly equivalent to the originals.

In [11]:
logits_perturbed = logits = model.apply({'params': params_perturbed}, x_batch)
cost_perturbed   = loss_CE(logits_perturbed, y_batch)

print(f"Perturbed forward pass cost: \n{cost_perturbed:.6}")

Perturbed forward pass cost: 
2.28932


## 4. Collect the gradient
 - Collect the gradient at every parameter
 - Defined simple as the difference in orginal and perturbed cost, weighted by learning rate
 - A good default learning rate is $\eta = 1/\epsilon^2$

In [12]:
@jax.jit
def init_grad(params):
    return jax.tree.map(lambda p: jnp.zeros(shape=(p.shape)), params)

@jax.jit
def collect_grad(perts,delta_c,G):
    return jax.tree.map(lambda G, p: G + p*delta_c, G, perts)

delta_cost = cost_perturbed - cost0
print(f"Difference in cost before and after perturbing network parameters: {delta_cost:.6}")

gradient = init_grad(params)
gradient = collect_grad(perturbations, delta_cost, gradient)

Difference in cost before and after perturbing network parameters: -4.76837e-07


## 5. Update parameters
 - Simply apply this gradient to every parameter component-wise
 - Weight update by $\tau_\theta$ (the number of iterations for which the gradient was collected)
   - Analogous to batch size
   - In this case $\tau_\theta=1$

In [13]:
@jax.jit
def MGD_update(params,G,eta,tau_theta):
    return jax.tree.map(
        lambda p, G: p - G*eta/tau_theta, params, G
    )

eta = 1/epsilon**2
tau_theta = batch_size
params = MGD_update(params,gradient,eta,tau_theta)

## That's it! Now just repeat
 - Everything else about learning is analogous to the backprop procedure
 - However, importantly, a sophisticated gradient calculation, weight transport, and network model are *not* needed
 - Instead, two forward passes, local short-term storage of perturbations, and one global broadcase are used
 - That's it!

In [21]:
def compute_accuracy(apply_fn, theta, x, y):
    logits  = apply_fn({'params': theta}, x)
    preds   = jnp.argmax(logits, axis=-1)
    targets = jnp.argmax(y, axis=-1)
    return (preds == targets).mean()

def loss_fn_prejit(apply_fn,params,x,y):
    logits = apply_fn({'params': params}, x)
    loss = optax.softmax_cross_entropy(logits, y).mean()
    return loss

loss_fn = jax.jit(loss_fn_prejit, static_argnums=(0,))

# deriveable normalization method
def simple_eta_norm(K, epsilon):
    return 1/(K*(epsilon)**2)

In [None]:
# hyperparams
epochs         = 25
epsilon        = 1e-6
layer_dims     = [128,64,10]
learning_rate  = 1e2
decay          = 0.1
batch_size     = tau_theta = 128
validation_mod = 1
    
trainloader = InfiniteLoader(X=trainset['image'], y=trainset['label'], batch_size=batch_size)
testloader  = InfiniteLoader(X=testset['image'],  y=testset['label'],  batch_size=batch_size)

x_batch , y_batch = next(trainloader)

# initializing network
model    = MLP(hidden_sizes=layer_dims)
key      = jax.random.PRNGKey(0)
theta    = model.init(key, x_batch)['params']
apply_fn = model.apply

# normalizing eta and weighting with learning rate
num_params = jnp.size(jax.flatten_util.ravel_pytree(theta)[0])
eta0       = simple_eta_norm(num_params,epsilon) * learning_rate

# corresponding batch size to iterations per epoch / test
batches      = trainloader.X.shape[0]//batch_size
test_batches = testloader.X.shape[0]//batch_size
iterations   = epochs * batches

# for recording training data
costs    = []
accs     = []
valpochs = []
epoch    = 0
epoch_costs      = []
for i in range(iterations):

    # setup for epoch
    x_batch, y_batch = next(trainloader)
    eta              = eta0/(1+decay*epoch)

    # initialize an empty gradient with same shapes as parameters (theta)
    gradient = init_grad(theta)

    # Step 1: Forward pass
    cost0  = loss_fn(apply_fn, theta, x_batch, y_batch)
    
    # Step 2: Perturb
    perturbations   = sample_perturbations(theta,epsilon,i)
    theta_perturbed = apply_perturbations(theta,perturbations)

    # Step 3: Perturbed forward pass
    cost_perturbed   = loss_fn(apply_fn, theta_perturbed, x_batch, y_batch)

    # Step 4: Collect gradient
    delta_cost = (cost_perturbed - cost0)*-1
    gradient   = collect_grad(perturbations, delta_cost*eta, gradient)

    # Step 5: Make update
    theta = optax.apply_updates(theta, gradient)

    epoch_costs.append(cost0)

    # validate (test) only with specified frequency according to validation mod
    if i % batches == 0: epoch+=1
    if i % batches == 0 and epoch % validation_mod == 0:
        
        # testing prediction accuracy over all batches in test set
        batch_accs = [compute_accuracy(apply_fn, theta, *next(testloader)) for _ in range(test_batches)]

        # recording data
        accs.append(np.mean(batch_accs))
        valpochs.append(epoch)
        costs.append(np.mean(epoch_costs))
        epoch_costs = []

        # updating user
        print(f"Iter {i} :: Epoch {epoch}  ::  cost {costs[-1]:.6}  :: accuracy {accs[-1]:.3}     ", 
              end='\r')

plt.style.use('seaborn-v0_8-muted')
plt.figure(figsize=(8,3))
plt.plot(valpochs, accs)
plt.plot(valpochs, costs/np.max(costs))
plt.xlabel("Epochs")
plt.ylabel("Cost")
plt.title("Training a Neural Net with MGD on MNIST")
plt.show()

Iter 0 :: Epoch 1  ::  cost 2.31715  :: accuracy 0.0913     