<a href="https://colab.research.google.com/github/vaglino/V-prop-SNN/blob/main/V_prop_SNN__edge_based_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -*- coding: utf-8 -*-
"""
Edge-list V-prop SNN for MNIST – using
O(E) memory for spikes / voltages.
"""

# ——————————————————————————————————————————————————————————————
# Imports
# ——————————————————————————————————————————————————————————————
import jax, jax.numpy as jnp
from jax import random, jit, lax
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
import seaborn as sns, optax, time
from sklearn.utils import resample
from sklearn.preprocessing import StandardScaler
from skimage.transform import resize
from keras.datasets import mnist


# ——————————————————————————————————————————————————————————————
# Global constants
# ——————————————————————————————————————————————————————————————
# Firing threshold: a neuron's output is max(0, Vm - thresh)
# (see synapse_integration_edges).
thresh = 0.5
# Module-level PRNG key; it is threaded through initialise_edges further down.
key    = random.PRNGKey(1)
# Time step (dt) and space step (ds); v(ds, dt) turns them into a spike speed.
dt = ds = 1.0
# Number of simulation steps scanned per inference call.
t_max  = 30
# Membrane time constant: Vm moves by (-Vm + I_syn) * dt / tau per step.
tau    = 10.0
# NOTE(review): α is not referenced anywhere in the visible code.
α      = 0.5
# NOTE(review): v_rest is not referenced anywhere in the visible code.
v_rest = 0.0


# ==============================================================
# 0. Original dense-implementation helpers (unchanged bodies)
# ==============================================================

def create_directed_small_world(
    n_total=402,
    n_inputs=196,
    n_outputs=10,
    k_nearest=4,
    p_rewire=0.1,
    p_connect=0.1
):
    """
    Build a directed small-world adjacency matrix.

    Neurons are indexed as [inputs | hidden | outputs].  The result has
    – no inbound edges to inputs,
    – no outbound edges from outputs,
    – a ring-rewired hidden layer.
    A[i,j] == 1 ⇒ edge i → j.

    Parameters
    ----------
    n_total   : total number of neurons.
    n_inputs  : number of input neurons (first rows/columns of A).
    n_outputs : number of output neurons (last rows/columns of A).
    k_nearest : each hidden neuron initially points at its next
                ``k_nearest`` ring neighbours.
    p_rewire  : probability that a ring edge is rewired to a random target.
    p_connect : probability of an extra hidden → hidden shortcut;
                input → hidden and hidden → output use ``2 * p_connect``.

    Returns
    -------
    A      : int32 array of shape (n_total, n_total).
    layers : list of ``(name, (size,), first_index, last_index)`` tuples for
             the input, hidden and output groups.

    The graph is generated from the fixed ``random.PRNGKey(0)``, so the same
    arguments always give the same matrix.
    """
    key = random.PRNGKey(0)

    A = jnp.zeros((n_total, n_total), dtype=jnp.int32)

    hidden_start = n_inputs
    hidden_end   = n_total - n_outputs
    n_hidden     = hidden_end - hidden_start
    hidden_slice = slice(hidden_start, hidden_end)
    # Explicit bounds instead of a negative index: ``-n_outputs:`` would
    # select *every* row when n_outputs == 0.
    output_slice = slice(hidden_end, n_total)

    # 2) regular ring – one vectorised update instead of a per-neuron loop
    ring_src = jnp.arange(n_hidden)[:, None]
    ring_tgt = (ring_src + jnp.arange(1, k_nearest + 1)[None, :]) % n_hidden
    A = A.at[ring_src + hidden_start, ring_tgt + hidden_start].set(1)

    # 3) Watts–Strogatz rewiring
    hidden_connections = A[hidden_slice, hidden_slice]
    key, subkey = random.split(key)
    rand_mat    = random.uniform(subkey, hidden_connections.shape)
    to_rewire   = (rand_mat < p_rewire) & (hidden_connections == 1)
    hidden_connections = hidden_connections & (~to_rewire)
    src_rows, _ = jnp.where(to_rewire)
    num_r       = src_rows.shape[0]
    if num_r:
        key, subkey = random.split(key)
        new_cols = random.randint(subkey, (num_r,), 0, n_hidden)
        # A rewired edge that would become a self-loop is simply dropped.
        not_self = src_rows != new_cols
        hidden_connections = hidden_connections.at[
            src_rows[not_self], new_cols[not_self]
        ].set(1)
    A = A.at[hidden_slice, hidden_slice].set(hidden_connections)

    # 4) extra random shortcuts (skipped where the reverse edge already exists)
    key, subkey = random.split(key)
    extra = (random.uniform(subkey, (n_hidden, n_hidden)) < p_connect)
    extra = extra & (~(hidden_connections.T == 1))
    hidden_connections = (hidden_connections == 1) | extra
    A = A.at[hidden_slice, hidden_slice].set(hidden_connections.astype(jnp.int32))

    # 5) input → hidden
    key, subkey = random.split(key)
    in2hid = (random.uniform(subkey, (n_inputs, n_hidden)) < 2*p_connect)
    A = A.at[:n_inputs, hidden_slice].set(in2hid)

    # 6) hidden → output
    key, subkey = random.split(key)
    hid2out = (random.uniform(subkey, (n_hidden, n_outputs)) < 2*p_connect)
    A = A.at[hidden_slice, output_slice].set(hid2out)

    # 7) guarantee ≥1 outgoing edge per input
    need = (jnp.sum(A[:n_inputs], 1) == 0)
    if jnp.any(need):
        key, subkey = random.split(key)
        t = random.randint(subkey, (n_inputs,), hidden_start, hidden_end)
        rows = jnp.arange(n_inputs)[need]
        A = A.at[rows, t[need]].set(1)

    # 8) guarantee ≥1 inbound edge per output
    need = (jnp.sum(A[:, output_slice], 0) == 0)
    if jnp.any(need):
        key, subkey = random.split(key)
        s = random.randint(subkey, (n_outputs,), hidden_start, hidden_end)
        cols = hidden_end + jnp.arange(n_outputs)[need]
        A = A.at[s[need], cols].set(1)

    # 9) strip forbidden connections
    A = A.at[:, :n_inputs].set(0)
    A = A.at[output_slice, :].set(0)

    layers = [
        ('Input',  (n_inputs,), 0, n_inputs-1),
        ('Hidden', (n_hidden,), hidden_start, hidden_end-1),
        ('Output', (n_outputs,), n_total-n_outputs, n_total-1)
    ]
    return A, layers


def plot_adjacency_matrix(A):
    """Display the adjacency matrix ``A`` as a seaborn heat-map."""
    fig_size = (6, 6)
    plt.figure(figsize=fig_size)
    sns.heatmap(A)
    plt.title("Adjacency matrix")
    plt.show()


def load_mnist_data(
    train_size=0.8, val_size=0.1, test_size=0.1,
    resample_fraction=0.2, target_size=(14,14), random_state=42
):
    """
    Load MNIST, standardise, resize, subsample and split it.

    The Keras train and test sets are merged, scaled to [0, 1], standardised
    per pixel with ``StandardScaler`` and resized to ``target_size``.  A
    class-stratified subsample of ``resample_fraction * len(X)`` images is
    then split into train / validation / test parts.

    The subsample is drawn *without* replacement whenever it fits in the
    dataset, so the same image can never land in more than one split.
    Sampling with replacement is only used when more samples are requested
    than exist (``resample_fraction > 1``).

    Parameters
    ----------
    train_size, val_size, test_size : split fractions; must sum to 1.0.
        The test split receives whatever remains after train and validation.
    resample_fraction : fraction of the merged dataset to keep.
    target_size : (height, width) the images are resized to.
    random_state : seed for both the subsampling and the split permutation.

    Returns
    -------
    (x_train, x_val, x_test, y_train, y_val, y_test,
     y_train_lbl, y_val_lbl, y_test_lbl)
        ``x_*`` are JAX arrays of shape (n, H, W, 1), ``y_*`` are one-hot JAX
        arrays with 10 classes, and the ``*_lbl`` entries are the integer
        labels as NumPy arrays.

    Raises
    ------
    ValueError
        If the three split fractions do not sum to 1.0.
    """
    if not np.isclose(train_size + val_size + test_size, 1.0):
        raise ValueError("Splits must sum to 1.0")

    (X_tr, y_tr), (X_te, y_te) = mnist.load_data()
    X, y = np.concatenate([X_tr, X_te]), np.concatenate([y_tr, y_te])

    X = X.astype(np.float32) / 255.0
    X = StandardScaler().fit_transform(X.reshape(len(X), -1)).reshape(X.shape)

    def _resize(imgs, size):
        out = np.zeros((len(imgs), *size))
        for i, img in enumerate(imgs):
            out[i] = resize(img, size, anti_aliasing=True)
        return out
    X = _resize(X, target_size)[..., None]

    n_tot = int(resample_fraction * len(X))
    # replace=False keeps every image unique, so the splits below cannot share
    # samples; stratify=y keeps the class proportions of the full dataset.
    X_small, y_small = resample(
        X, y,
        n_samples=n_tot,
        replace=n_tot > len(X),
        stratify=y,
        random_state=random_state,
    )

    n_tr = int(n_tot*train_size)
    n_v  = int(n_tot*val_size)
    perm = np.random.RandomState(random_state).permutation(n_tot)
    tr, v, te = perm[:n_tr], perm[n_tr:n_tr+n_v], perm[n_tr+n_v:]

    x_train, x_val, x_test = map(jnp.array, (X_small[tr], X_small[v], X_small[te]))
    y_train = jax.nn.one_hot(y_small[tr], 10)
    y_val   = jax.nn.one_hot(y_small[v],  10)
    y_test  = jax.nn.one_hot(y_small[te], 10)

    return (x_train, x_val, x_test,
            y_train, y_val, y_test,
            y_small[tr], y_small[v], y_small[te])


# Basic math helpers
@jit
def v(ds, dt):
    """Return the propagation speed ``ds / dt`` (distance per time step)."""
    speed = ds / dt
    return speed
@jit
def relu_shift(x, th):
    """Shifted ReLU: ``x - th`` where that is positive, otherwise 0."""
    shifted = x - th
    return jnp.maximum(0., shifted)


# ==============================================================
# 1. Build the graph, convert to edge lists
# ==============================================================

# Layer sizes are defined once and passed to the graph builder, so the
# constants used by the simulation always agree with the generated graph.
N_INPUTS  = 784
N_OUTPUTS = 10

adj_dense, layers = create_directed_small_world(
    n_total=1510,
    n_inputs=N_INPUTS,
    n_outputs=N_OUTPUTS,
    k_nearest=16,
    p_rewire=0.4,
    p_connect=0.1,
)
plot_adjacency_matrix(adj_dense)

N_NEURONS = adj_dense.shape[0]

# Edge list: edge e runs from neuron src[e] to neuron tgt[e].
src_np, tgt_np = np.nonzero(np.array(adj_dense))
src = jnp.array(src_np, dtype=jnp.int32)
tgt = jnp.array(tgt_np, dtype=jnp.int32)
E   = len(src)

# Distance a spike advances along its edge per time step.
vmax = v(ds, dt)


# ==============================================================
# 2. Per-edge initialisation
# ==============================================================

def initialize_lengths(adj, lengths=jnp.arange(4., 6., 1.), key=random.PRNGKey(1)):
    """
    Draw a random length for every entry of ``adj`` from ``lengths``.

    Each entry is sampled independently from ``lengths`` and multiplied by
    ``adj``, so positions where ``adj`` is 0 get length 0.

    Returns ``(L, key)`` where ``key`` is the unused half of the split and
    can be passed on to the next random operation.
    """
    next_key, draw_key = random.split(key)
    sampled = random.choice(draw_key, lengths, adj.shape)
    return adj * sampled, next_key

def initialize_weights(adj, p_inhib=0.0, key=random.PRNGKey(2)):
    """
    Draw random signed weights for every entry of ``adj``.

    Magnitudes are uniform in [0, 1).  An entry is positive when a second
    uniform draw exceeds ``p_inhib`` and negative otherwise.  The result is
    multiplied by ``adj``, so positions where ``adj`` is 0 get weight 0.

    Returns ``(W, key)`` where ``key`` is the unused part of the split.
    """
    next_key, magnitude_key, sign_key = random.split(key, 3)
    magnitude = adj * random.uniform(magnitude_key, adj.shape)
    excitatory = random.uniform(sign_key, adj.shape) > p_inhib
    sign = excitatory * 2 - 1          # True → +1, False → -1
    return magnitude * sign * adj, next_key

def initialise_edges(adj, key):
    """
    Create per-edge lengths and weights for the global edge list.

    Dense length and weight matrices are drawn with ``initialize_lengths``
    (choices 3, 3+ds, … below 8) and ``initialize_weights`` (no inhibitory
    edges), then gathered at the ``(src, tgt)`` pairs.  Edges leaving an
    input neuron are forced to length 3.

    Returns ``(L_e, W_e, key)``: the two length-E vectors and the PRNG key
    left over after both draws.
    """
    length_choices = jnp.arange(3., 8., ds)
    dense_lengths, key = initialize_lengths(adj, length_choices, key)
    dense_weights, key = initialize_weights(adj, 0.0, key)

    leaves_input = src < N_INPUTS
    L_e = jnp.where(leaves_input, 3.0, dense_lengths[src, tgt])
    W_e = dense_weights[src, tgt]
    return L_e, W_e, key

# Draw the per-edge lengths and weights once; `key` is replaced by the key
# returned from the initialiser.
L_e, W_e, key = initialise_edges(adj_dense, key)


# ==============================================================
# 3. Edge-tensor primitives (1 ↔ 1 with dense ops)
# ==============================================================

# True for every edge whose source is an input neuron (index < N_INPUTS).
SRC_EDGES = (src < N_INPUTS)        # boolean mask, length E

def scatter_edges_to_nodes(edge_vals):
    """(B,E) → (B,N) – add every edge value onto the edge's *target* neuron."""
    def _accumulate(per_edge):
        # One sample: start from zeros and scatter-add along `tgt`.
        node_totals = jnp.zeros(N_NEURONS)
        return node_totals.at[tgt].add(per_edge)

    return jax.vmap(_accumulate)(edge_vals)

@jit
def take_step_edges(S):
    """Advance every in-flight spike (S > 0) by ``dt * vmax``; idle edges stay put."""
    in_flight = S > 0
    return S + in_flight * dt * vmax

@jit
def check_arrival_edges(S, L_e, *, atol=1e-5, rtol=1e-8):
    """
    Flag the edges whose spike position equals the edge length.

    The comparison is done with ``jnp.isclose`` rather than exact equality,
    so tiny floating-point differences between position and length still
    count as an arrival.

    Parameters
    ----------
    S   : spike positions; at the call site in this file this is a
          (batch, E) array of distances already travelled.
    L_e : edge lengths, broadcast against ``S``.
    atol, rtol : keyword-only tolerances forwarded to ``jnp.isclose``.
          Note that these defaults (atol=1e-5, rtol=1e-8) are the reverse of
          ``jnp.isclose``'s own defaults.

    Returns
    -------
    Boolean array, True where ``S`` is within tolerance of ``L_e``.
    """
    arrived = jnp.isclose(S, L_e, rtol=rtol, atol=atol)
    return arrived
@jit
def excite_inputs_edges(S, V, x_batch):
    """
    Launch a spike on every idle edge that leaves an input neuron.

    An input edge is idle when its position is 0.  Idle input edges get
    position ``dt * vmax`` and carry the pixel value of their source neuron;
    busy input edges and all other edges are returned unchanged.

    Returns the updated ``(S, V)``.
    """
    batch = x_batch.shape[0]
    pixels = x_batch.reshape(batch, -1)                 # one column per input neuron

    pos_in = S[:, SRC_EDGES]
    val_in = V[:, SRC_EDGES]
    idle = (pos_in == 0)

    launched_pos = jnp.where(idle, dt*vmax, pos_in)
    launched_val = jnp.where(idle, pixels[:, src[SRC_EDGES]], val_in)

    S = S.at[:, SRC_EDGES].set(launched_pos)
    V = V.at[:, SRC_EDGES].set(launched_val)
    return S, V

@jit
def synapse_integration_edges(Vm, W_e, V_arr):
    """
    Integrate arriving spikes into the membrane potentials.

    Arriving edge values are weighted by ``W_e`` and summed per target
    neuron, then ``Vm`` takes one leaky step of size ``dt / tau`` towards
    that input current.

    Returns ``(Y, Vm)`` where ``Y = max(0, Vm - thresh)`` is the
    supra-threshold part of the updated potential.
    """
    weighted = V_arr * W_e
    I_syn = scatter_edges_to_nodes(weighted)
    leak_rate = dt / tau
    Vm = Vm + (I_syn - Vm) * leak_rate
    Y = jnp.maximum(Vm - thresh, 0.)
    return Y, Vm

@jit
def update_spikes_edges(S, V_exc, fired):
    """
    Decide which edges start a new spike this step.

    An edge starts a spike when its source neuron fired and the edge is
    idle (position 0).  The new spike carries the source neuron's
    ``V_exc`` value.

    Returns ``(newS, newV)``: the boolean start mask and the carried values
    (0 where no spike starts).
    """
    source_fired = fired[:, src]
    source_value = V_exc[:, src]
    newS = source_fired & (S == 0)
    return newS, newS * source_value

@jit
def delete_finished_edges(S, V, arrived):
    """Zero the position and value of every edge whose spike has arrived."""
    still_travelling = ~arrived
    return S * still_travelling, V * still_travelling

@jit
def reset_vm(Vm, fired):
    """Set the potential of every fired neuron to -0.2; leave the others unchanged."""
    cleared = Vm - Vm * fired          # fired neurons drop to 0
    return cleared - 0.2 * fired       # … and then 0.2 below that


# ==============================================================
# 4. **Fixed** RSNN step & inference – stores Vm *before* reset
# ==============================================================

def make_edge_step(W_e, L_e, dropout_rate):
    """
    Build the per-time-step function used as the body of ``lax.scan``.

    ``W_e`` (edge weights), ``L_e`` (edge lengths) and ``dropout_rate`` are
    closed over; the returned function takes ``(carry, key)`` and returns
    ``(new_carry, None)``.
    """
    @jit
    def _step(carry, key):
        """
        Advance the network by one time step.

        ``carry`` is ``(S, V, Vm, x, acc)``: per-edge spike positions,
        per-edge carried values, per-neuron membrane potentials, the input
        batch, and the running sum of ``Vm``.  ``key`` is the PRNG key for
        this step's dropout mask.
        """
        S, V, Vm, x, acc = carry

        # 1) excite inputs: idle input edges start carrying their pixel value
        S, V = excite_inputs_edges(S, V, x)

        # 2-3) spike arrivals: keep only the values of edges that reached their length
        arrived = check_arrival_edges(S, L_e)
        V_arr   = V * arrived

        # 4) synaptic integration: V_exc is the supra-threshold part of the new Vm
        V_exc, Vm = synapse_integration_edges(Vm, W_e, V_arr)

        # 5) threshold (the last N_OUTPUTS neurons are never allowed to fire)
        fired = V_exc > 0
        fired = fired.at[:, -N_OUTPUTS:].set(False)

        # 6) dropout: dropped neurons do not fire; survivors are rescaled by 1/keep_p
        #    NOTE(review): dropout_rate == 1.0 makes keep_p zero and divides by zero.
        keep_p = 1.0 - dropout_rate
        mask   = random.bernoulli(key, keep_p, fired.shape)
        fired &= mask
        V_exc  = V_exc * mask / keep_p

        # 7) spawn new spikes – uses S *before* arrived edges are cleared, so an
        #    edge whose spike arrives this step is not idle yet
        newS, newV = update_spikes_edges(S, V_exc, fired)

        # 8) delete finished spikes
        S, V = delete_finished_edges(S, V, arrived)

        # 9) **accumulate Vm BEFORE reset**
        acc = acc + Vm                                   # running sum

        # 10) reset fired neurons
        Vm = reset_vm(Vm, fired)

        # 11) advance spikes & add newborns (newborns start at dt * vmax)
        S = take_step_edges(S) + newS * dt * vmax
        V = V + newV

        return (S, V, Vm, x, acc), None
    return _step


@jit
def RSNN_inference_edge(W_e, L_e, x_batch, rng_key, dropout_rate):
    """
    Run the network for ``t_max`` steps and return the time-averaged Vm.

    All edge and neuron state starts at zero.  ``rng_key`` is split into one
    key per step for the dropout mask.

    Returns an array of shape (B, N_NEURONS): the sum of the membrane
    potentials over all steps divided by ``t_max``.
    """
    batch = x_batch.shape[0]
    edge_zeros = jnp.zeros((batch, E))
    node_zeros = jnp.zeros((batch, N_NEURONS))

    step_fn = make_edge_step(W_e, L_e, dropout_rate)
    step_keys = random.split(rng_key, t_max)

    # carry layout: (S, V, Vm, x, acc)
    initial = (edge_zeros, edge_zeros, node_zeros, x_batch, node_zeros)
    final, _ = lax.scan(step_fn, initial, step_keys)

    vm_sum = final[-1]
    return vm_sum / t_max


# ==============================================================
# 5. Loss / prediction wrappers
# ==============================================================

@jit
def loss_edge(W_e, L_e, xb, yb, key, drop):
    """
    Mean softmax cross-entropy between the output neurons and ``yb``.

    The logits are the time-averaged membrane potentials of the last
    ``N_OUTPUTS`` neurons.  ``yb`` holds one-hot targets, ``key`` seeds the
    dropout masks and ``drop`` is the dropout rate.

    The log-probabilities come from ``jax.nn.log_softmax`` instead of
    ``log(softmax + eps)``: this avoids the loss floor and the vanishing
    gradient that occur when a probability underflows.
    """
    Vm_mean = RSNN_inference_edge(W_e, L_e, xb, key, drop)    # (B, N)
    logits  = Vm_mean[:, -N_OUTPUTS:]                         # (B, N_OUTPUTS)
    log_probs = jax.nn.log_softmax(logits, axis=1)
    return -jnp.mean(jnp.sum(yb * log_probs, axis=1))


@jit
def predict_edge(W_e, L_e, xb):
    """
    Return the predicted class index for every sample in ``xb``.

    Inference runs with dropout disabled, so the dropout mask is all-True
    and the PRNG key has no effect on the result.  A fixed local key is used
    instead of the module-level ``key``, which would otherwise be captured
    at jit-trace time as a hidden dependency.
    """
    eval_key = random.PRNGKey(0)
    Vm_mean = RSNN_inference_edge(W_e, L_e, xb, eval_key, 0.0)
    logits  = Vm_mean[:, -N_OUTPUTS:]
    return jnp.argmax(logits, 1)


# ==============================================================
# 6. Optimiser
# ==============================================================

# Gradient pipeline: clip the global gradient norm to 5.0, then apply Adam
# with learning rate 0.005.
optimizer = optax.chain(
    optax.clip_by_global_norm(5.0),
    optax.adam(0.005)
)

# Only the edge weights W_e are optimised, so the state is built from them.
opt_state = optimizer.init(W_e)

@jit
def train_edge_step(W_e, L_e, opt_state, xb, yb, key):
    """
    Perform one optimiser update of the edge weights on a single batch.

    The loss is evaluated with a dropout rate of 0.2 and differentiated with
    respect to ``W_e`` only; ``L_e`` is left untouched.

    Returns ``(W_e, opt_state, loss)`` with the updated weights and state.
    """
    loss_and_grad = jax.value_and_grad(loss_edge)     # gradient w.r.t. first argument
    loss, grads = loss_and_grad(W_e, L_e, xb, yb, key, 0.2)

    updates, new_state = optimizer.update(grads, opt_state, W_e)
    new_weights = optax.apply_updates(W_e, updates)
    return new_weights, new_state, loss


# ==============================================================
# 7. Data
# ==============================================================

# Full-resolution (28x28) MNIST with an 80/10/10 split; resample_fraction=1.0
# requests as many samples as the merged dataset contains.
(x_train, x_val, x_test,
 y_train, y_val, y_test,
 y_train_lbl, y_val_lbl, y_test_lbl) = load_mnist_data(
     train_size=0.8, val_size=0.1, test_size=0.1,
     resample_fraction=1.0, target_size=(28,28), random_state=42
)


# ==============================================================
# 8. Training loop
# ==============================================================

batch_size = 32
epochs     = 20
rng = random.PRNGKey(42)


print("Starting training…"); t0 = time.time()
for ep in range(epochs):
    # --- shuffle training data ------------------------------------------------
    # The permutation is drawn from the seeded `rng` stream (not the unseeded
    # global NumPy generator) so that a run is reproducible end to end.
    rng, perm_key = random.split(rng)
    perm = random.permutation(perm_key, len(x_train))
    x_tr, y_tr = x_train[perm], y_train[perm]

    # --- TRAIN -----------------------------------------------------------------
    # Only full batches are used; a trailing partial batch is dropped.
    train_loss, nb = 0.0, 0
    rng, sk = random.split(rng)
    for i in range(0, len(x_train) - batch_size + 1, batch_size):
        xb, yb = x_tr[i:i+batch_size], y_tr[i:i+batch_size]
        sk, sub = random.split(sk)
        W_e, opt_state, lv = train_edge_step(W_e, L_e, opt_state, xb, yb, sub)
        train_loss += float(lv);  nb += 1
    train_loss /= nb

    # --- VALIDATE (no dropout) -------------------------------------------------
    val_loss, nbv = 0.0, 0
    preds = []
    for j in range(0, len(x_val) - batch_size + 1, batch_size):
        xb, yb = x_val[j:j+batch_size], y_val[j:j+batch_size]
        rng, sub = random.split(rng)
        val_loss += float(loss_edge(W_e, L_e, xb, yb, sub, 0.0));  nbv += 1
        preds.append(predict_edge(W_e, L_e, xb))
    val_loss /= nbv

    # Batches are taken in order, so predictions line up with the first
    # len(preds) validation labels.
    preds = jnp.concatenate(preds)
    val_acc = float((preds == y_val_lbl[:len(preds)]).mean()) * 100.0

    # --- LOG -------------------------------------------------------------------
    print(f"Epoch {ep+1:02d} | "
          f"train loss {train_loss:.4f} | "
          f"val loss {val_loss:.4f} | "
          f"val acc {val_acc:.2f} %")

print("Training finished in", round(time.time() - t0, 1), "s")

# ==============================================================
# 9. Test accuracy
# ==============================================================

# Predict the test set in full batches (a trailing partial batch is skipped)
# and compare against the matching prefix of the integer labels.
pred = jnp.concatenate([
    predict_edge(W_e, L_e, x_test[start:start + batch_size])
    for start in range(0, len(x_test) - batch_size + 1, batch_size)
])
acc  = (pred == y_test_lbl[:len(pred)]).mean() * 100
print(f"Final test accuracy: {acc:.2f} %")

# NOTE(review): everything above runs at import time; this guard only adds a
# final banner when the file is executed as a script.
if __name__ == "__main__":
    print("Edge-list SNN – bug-fixed and self-contained")