# Setup

In [1]:
from pathlib import Path
APP = "/app" if Path("/app").exists() else "../app"
RUNS = "/runs" if Path("/runs").exists() else "../runs"
(APP, RUNS)

('../app', '../runs')

In [2]:
import sys

app = "/app"
if app not in sys.path:
    sys.path.append(app)

In [3]:
import torch
from torch import nn, Tensor
from torch import optim
from torch.nn import Parameter, GELU, Tanh, Sigmoid, Linear, Conv2d
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

import numpy as np
from sklearn.datasets import make_moons, make_classification
from sklearn.model_selection import train_test_split
import plotly.express as px
from plotly import graph_objects as go
from tqdm import tqdm


In [4]:
from src.models import dd

# A 'dendritic' clustering layer

Inspired by Larkum ME, 2022, "Are Dendrites Conceptually Useful, Neuroscience https://doi.org/10.1016/j.neuroscience.2022.03.008

A 'dendritic' fully connected layer extends the classical fully connected `Linear` layer. It usess a convolution filter `conv_filter` to aggregate the activity of neighbouring synapses. The filter is moved along the sequence of synapses with the indicated `stride`. Note that this is a **fixed filter** -- it is NOT a learnable convolution.

# Toy example with simple classification task by a 2-layer MLP

We create a dataset using sklean `make_moons` function

In [41]:
N_FEATURES = 50
N_CLASSES = 5

In [42]:
# X, y = make_moons(n_samples=1000, noise=0.2, random_state=42)
X, y = make_classification(
    n_samples=1000,
    n_classes=N_CLASSES,
    n_features=N_FEATURES, n_informative=3, n_redundant=10, n_repeated=0,
    random_state=42, n_clusters_per_class=1,
)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [43]:
# pca reduce to 3 dimensions for visualization
from sklearn.decomposition import PCA
pca = PCA(n_components=3)
X_pca = pca.fit_transform(X)

This is what it looks like:

In [45]:
fig = px.scatter_3d(
    x=X_pca[:, 0], y=X_pca[:, 1], z=X_pca[:, 2],
    color=y.astype(str),
    labels={'color': 'Class'},
    width=500, height=500,
    title='2D Classification Dataset Created by classification'
).update_traces(marker=dict(size=2)).show()






We convert this into a torch dataset

In [46]:
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.int64)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.int64)

In [47]:
train_data = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_data, batch_size=100, shuffle=True)

A simple training loop:

In [48]:

# this was an attempt to include a network-level constraint 
# to clamp the state to a given upstate max value

# def loss_fn(outputs, states, labels):
#     alpha = 0.5  # state regularization coefficient
#     up_state = 0  # upper bound for state regularization
#     # include a constraint on the state to encourage clamping it
#     # to up_state
#     if states is not None:
#         states = [torch.relu(s - up_state).mean() for s in states]
#         state_regul = sum(states) / len(states)
#     else:
#         state_regul = torch.tensor(0)
#     loss = criterion(outputs, labels) + alpha * state_regul
#     return loss, state_regul

class Trainer:

    def __init__(self, model, train_loader, testdata):
      self.model = model
      self.criterion = nn.CrossEntropyLoss()
      self.train_loader = train_loader
      self.X_test = testdata[0]
      self.y_test = testdata[1]

    def train(self, epochs: int, lr: float, run_name:str):
        writer = SummaryWriter(run_name)  # open new writer --> /runs
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        for epoch in tqdm(range(epochs)):
            for inputs, labels in self.train_loader:
                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)  # loss_fn(outputs, states, labels)
                loss.backward()
                optimizer.step()
                self.log(loss, epoch, writer)
        print(f"Done at epoch {epoch}, Loss: {loss.item()}")
        self.accuracy()
        writer.close()

    def log(self, loss, epoch, writer):
        # write to tensorboard
        writer.add_scalar("Loss/train", loss.data, epoch)
        # loss on Trainset (we are lazy and don't use a separate validation set)
        with torch.no_grad():
            outputs = self.model(self.X_test)
            valid_loss = self.criterion(outputs, self.y_test)
        writer.add_scalar("Loss/valid", valid_loss, epoch)

    def accuracy(self):
        # checking accuracy quickly
        with torch.no_grad():
            outputs = self.model(self.X_test)
            predicted = outputs.argmax(1)
            accuracy = (predicted == self.y_test).sum().item() / len(self.y_test)
        print(f'Accuracy: {accuracy}')

Define the model as a 2-layer MLP. The first layer is a normal `Linear` module wich expands the original dimension. The second layer is the 'dendritic' layer. It usess a convolution filter to aggregate the activity of neighbouring synapses. The filter is moved along the sequence of synapses with the indicated stride. This is a **fixed filter** -- it is NOT a learnable convolution.

In [49]:
HIDDEN = 100

In [50]:
class dMLP(nn.Module):
    """'dendritic' MLP with 2 hidden layers, the first classic to expand,
    the next dendritic to integrate.
    """

    def __init__(self, stride, conv_filter, **kwargs):
        super(dMLP, self).__init__()
        self.stride = stride
        self.conv_filter = conv_filter
        # self.act_fn = nn.ReLU()
        self.fc1 = dd.DendriticFullyConnected(N_FEATURES, HIDDEN, conv_filter=self.conv_filter, stride=self.stride, **kwargs)
        self.fc2 = nn.Linear(HIDDEN, N_CLASSES)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

For comparison, a similar classical MLP

In [51]:
class MLP(nn.Module):
    """Classical MLP with 2 layers
    """
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(N_FEATURES, HIDDEN)
        self.fc2 = nn.Linear(HIDDEN, N_CLASSES)
        self.act_fn = nn.ReLU()

    def forward(self, x):
        x = self.act_fn(self.fc1(x))
        x = self.fc2(x)
        return x

In [52]:
from datetime import datetime
from src.models.dd import Hill, MyRELU

# normal 
mlp = MLP()
model = (mlp,)
run_name = (f"{RUNS}/MLP-{datetime.now().isoformat().replace(':','-')}",)

# dendritic MLP
stride=3
kernel_size=10
conv_filter =  torch.tensor([[[1/kernel_size] * kernel_size]])    # torch.tensor([[[1/kernel_size]* kernel_size]])  # torch.tensor([[[0.001, 0.01, 0.2, 0.5, 0.2, 0.01, 0.001]]])  # torch.tensor([[[1/kernel_size]* kernel_size]])
dmlp = dMLP(
    stride=stride,
    conv_filter=conv_filter,
    cluster_act_fn=Hill(2, 0.5),
    bias=True
)
model += (dmlp,)
run_name += (f"{RUNS}/dMLP-{datetime.now().isoformat().replace(':','-')}",)

model

(MLP(
   (fc1): Linear(in_features=50, out_features=100, bias=True)
   (fc2): Linear(in_features=100, out_features=5, bias=True)
   (act_fn): ReLU()
 ),
 dMLP(
   (fc1): DendriticFullyConnected(
     (nmda): Linear(in_features=50, out_features=100, bias=True)
     (non_nmda): Linear(in_features=50, out_features=100, bias=True)
   )
   (fc2): Linear(in_features=100, out_features=5, bias=True)
 ))

In [53]:
# # NMDA antagonist simulation
# params = dict(p for p in dmlp.named_parameters())
# params['fc1.nmda.weight'].requires_grad_(False)
# params['fc1.nmda.bias'].requires_grad_(False)
# params['fc1.non_nmda.weight'].requires_grad_(False)
# params['fc1.non_nmda.bias'].requires_grad_(False)
# same for fc2
# params['fc2.weight'].requires_grad_(False)
# params['fc2.bias'].requires_grad_(False)
# params.keys()

In [54]:
#train
for run, m in zip(run_name, model):
    print(f"training {run}")
    print(m)
    trainer = Trainer(m, train_loader, (X_test, y_test))
    trainer.train(epochs=100, lr=0.001, run_name=run)

training ../runs/MLP-2024-01-28T17-14-34.474679
MLP(
  (fc1): Linear(in_features=50, out_features=100, bias=True)
  (fc2): Linear(in_features=100, out_features=5, bias=True)
  (act_fn): ReLU()
)


100%|██████████| 100/100 [00:00<00:00, 177.81it/s]


Done at epoch 99, Loss: 0.03526962921023369
Accuracy: 0.695
training ../runs/dMLP-2024-01-28T17-14-34.475322
dMLP(
  (fc1): DendriticFullyConnected(
    (nmda): Linear(in_features=50, out_features=100, bias=True)
    (non_nmda): Linear(in_features=50, out_features=100, bias=True)
  )
  (fc2): Linear(in_features=100, out_features=5, bias=True)
)


100%|██████████| 100/100 [02:29<00:00,  1.50s/it]

Done at epoch 99, Loss: 0.193076491355896
Accuracy: 0.7





In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard
# %reload_ext tensorboar

In [None]:
# launch tensorboard
%tensorboard --logdir ../runs

Visualization of the weights:

In [55]:
for name in model[1].named_modules(): print(name)

('', dMLP(
  (fc1): DendriticFullyConnected(
    (nmda): Linear(in_features=50, out_features=100, bias=True)
    (non_nmda): Linear(in_features=50, out_features=100, bias=True)
  )
  (fc2): Linear(in_features=100, out_features=5, bias=True)
))
('fc1', DendriticFullyConnected(
  (nmda): Linear(in_features=50, out_features=100, bias=True)
  (non_nmda): Linear(in_features=50, out_features=100, bias=True)
))
('fc1.nmda', Linear(in_features=50, out_features=100, bias=True))
('fc1.non_nmda', Linear(in_features=50, out_features=100, bias=True))
('fc2', Linear(in_features=100, out_features=5, bias=True))


In [56]:
px.imshow(
    dict(model[1].named_modules())['fc1.non_nmda'].weight.data,
    width=1000, aspect='auto',
).update_layout(coloraxis_showscale=False).show()
print(dict(model[1].named_modules())['fc1.non_nmda'].weight.size())

torch.Size([100, 50])


In [57]:
px.imshow(
    dict(model[1].named_modules())['fc1.nmda'].weight.data,
    width=1000, aspect='auto',
).update_layout(coloraxis_showscale=False).show()
print(dict(model[1].named_modules())['fc1.nmda'].weight.size())


torch.Size([100, 50])
