# Setup

In [None]:
import sys

# Hard-coded project root; presumably this is what makes the `src` package
# imported further down resolvable -- TODO confirm for your environment.
app = "/app"
already_on_path = app in sys.path
if not already_on_path:
    sys.path.append(app)

In [None]:
import torch
from torch import nn, Tensor
from torch import optim
from torch.nn import Parameter, GELU, Tanh, Sigmoid, Linear, Conv2d
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

import numpy as np
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import plotly.express as px
from plotly import graph_objects as go
from tqdm import tqdm


In [None]:
from src.models import dd

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

# A 'dendritic' clustering layer

Inspired by Larkum ME (2022), "Are Dendrites Conceptually Useful?", Neuroscience, https://doi.org/10.1016/j.neuroscience.2022.03.008

A 'dendritic' fully connected layer extends the classical fully connected `Linear` layer. It uses a convolution filter `conv_filter` to aggregate the activity of neighbouring synapses. The filter is moved along the sequence of synapses with the indicated `stride`. Note that this is a **fixed filter** -- it is NOT a learnable convolution.

# Toy example with simple classification task by a 2-layer MLP

We create a dataset using sklearn's `make_moons` function.

In [None]:
# Generate 1000 noisy two-moons points with a fixed seed ...
X, y = make_moons(random_state=42, noise=0.2, n_samples=1000)

# ... and hold out 20 % of them (again with a fixed seed) for evaluation.
split = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_test, y_train, y_test = split

In [None]:
# Inputs are stored as float32 tensors, class labels as int64 tensors.
X_train, X_test = (
    torch.tensor(features, dtype=torch.float32) for features in (X_train, X_test)
)
y_train, y_test = (
    torch.tensor(targets, dtype=torch.int64) for targets in (y_train, y_test)
)

This is what it looks like:

In [None]:
# Plot the full (unsplit) dataset, one colour per class.
moons_fig = px.scatter(
    x=X[:, 0],
    y=X[:, 1],
    # labels are cast to strings -- presumably so that plotly treats them as
    # categories rather than as a continuous colour scale
    color=y.astype(str),
    labels={'color': 'Class'},
    title='2D Classification Dataset Created by make_moons',
    width=500,
    height=500,
)
moons_fig

We convert this into a torch dataset

In [None]:
# Pair inputs with labels, then serve them in reshuffled mini-batches of 100.
train_data = TensorDataset(X_train, y_train)
train_loader = DataLoader(dataset=train_data, shuffle=True, batch_size=100)

Define the model as a 2-layer MLP. The first layer is a normal `Linear` module which expands the original dimension. The second layer is the 'dendritic' layer. It uses a convolution filter to aggregate the activity of neighbouring synapses. The filter is moved along the sequence of synapses with the indicated stride. This is a **fixed filter** -- it is NOT a learnable convolution.

In [None]:
class dMLP(nn.Module):
    """Two-layer 'dendritic' MLP.

    A regular ``nn.Linear`` layer first expands the 2 input features to 10
    units (followed by a ReLU); a ``dd.DendriticFullyConnected`` layer then
    maps those 10 units onto the 2 outputs.

    Args:
        stride: forwarded as ``stride`` to the dendritic layer.
        conv_filter: forwarded as ``conv_filter`` to the dendritic layer.
    """

    def __init__(self, stride, conv_filter):
        super().__init__()
        # kept on the instance so the configuration can be inspected later
        self.stride, self.conv_filter = stride, conv_filter
        self.act_fn = nn.ReLU()
        self.fc1 = nn.Linear(in_features=2, out_features=10)
        self.fc2 = dd.DendriticFullyConnected(
            10, 2, conv_filter=conv_filter, stride=stride
        )

    def forward(self, x):
        """Return ``(outputs, (None, state))``.

        ``state`` is the second value returned by the dendritic layer; the
        leading ``None`` is the placeholder for the first (plain) layer.
        """
        hidden = self.act_fn(self.fc1(x))
        outputs, dendritic_state = self.fc2(hidden)
        return outputs, (None, dendritic_state)

For comparison, a similar classical MLP

In [None]:
class MLP(nn.Module):
    """Classical MLP with 2 layers: ``Linear -> ReLU -> Linear``.

    The layer sizes default to the values used for the two-moons toy task, so
    ``MLP()`` builds exactly the 2 -> 10 -> 2 network.

    Args:
        in_features: size of each input sample (default 2).
        hidden_features: number of units in the hidden layer (default 10).
        out_features: number of outputs, i.e. class logits (default 2).
    """

    def __init__(self, in_features: int = 2, hidden_features: int = 10,
                 out_features: int = 2):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.act_fn1 = nn.ReLU()

    def forward(self, x):
        """Return ``(outputs, None)``.

        The second element is always ``None``; it only exists so that this
        model can be used wherever a ``(outputs, states)`` pair is expected.
        """
        x = self.act_fn1(self.fc1(x))
        x = self.fc2(x)
        return x, None  # for compatibility with dMLP

A simple training loop:

In [None]:

# this was an attempt to include a constraint to clamp the state to a given upstate
# it does not seem to help so far
# def loss_fn(outputs, states, labels):
#     alpha = 0.5  # state regularization coefficient
#     up_state = 0.5  # upper bound for state regularization
#     # include a constraint on the state to encourage clamping it
#     # to up_state
#     if states is not None:
#         states = [torch.relu(s - up_state).mean() for s in states]
#         state_regul = sum(states) / len(states)
#     else:
#         state_regul = torch.tensor(0)
#     loss = criterion(outputs, labels) + alpha * state_regul
#     return loss, state_regul


class Trainer:
    """Minimal training harness for models whose forward pass returns a pair
    ``(outputs, states)``.

    The model is trained with Adam on a cross-entropy loss; the training loss
    and the loss on the held-out data are written to TensorBoard once per
    epoch.

    Args:
        model: module to optimise. ``model(x)`` must return a 2-tuple whose
            first element is passed to ``nn.CrossEntropyLoss``.
        train_loader: iterable yielding ``(inputs, labels)`` mini-batches.
        testdata: pair ``(X_test, y_test)`` used for the held-out loss and
            for the accuracy.
    """

    def __init__(self, model, train_loader, testdata):
        self.model = model
        self.criterion = nn.CrossEntropyLoss()
        self.train_loader = train_loader
        self.X_test = testdata[0]
        self.y_test = testdata[1]

    def train(self, epochs: int, lr: float):
        """Train ``self.model`` for ``epochs`` passes over ``self.train_loader``.

        Args:
            epochs: number of passes over the training loader.
            lr: learning rate handed to ``optim.Adam``.

        After training, the last batch loss and the held-out accuracy are
        printed. Nothing but a short notice is printed if no training step
        was executed (``epochs == 0`` or an empty loader).
        """
        writer = SummaryWriter()  # open new writer --> /runs
        # Optimise the model owned by this trainer -- not whatever object the
        # notebook-global name `model` happens to be bound to.
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        epoch, loss = None, None
        try:
            for epoch in tqdm(range(epochs)):
                loss_sum, n_seen = 0.0, 0
                for inputs, labels in self.train_loader:
                    optimizer.zero_grad()
                    outputs, _ = self.model(inputs)
                    loss = self.criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()
                    # the criterion averages over the batch, so weight by the
                    # batch size to obtain a per-sample mean for the epoch
                    loss_sum += loss.item() * len(labels)
                    n_seen += len(labels)
                # One log entry per epoch: the step axis is the epoch, so
                # logging every batch would stack several points on one step
                # and re-evaluate the held-out data for every batch.
                if n_seen:
                    self.log(loss_sum / n_seen, epoch, writer)
        finally:
            # flush and release the event file even if training is interrupted
            writer.close()

        if loss is None:
            print("No training step was executed")
            return
        print(f"Done at epoch {epoch}, Loss: {loss.item()}")
        self.accuracy()

    def log(self, loss, epoch, writer):
        """Write training and held-out loss for step ``epoch`` to ``writer``.

        Args:
            loss: training loss to record (Python number or scalar tensor).
            epoch: value used as the step of both scalars.
            writer: object providing ``add_scalar(tag, value, step)``.
        """
        # write to tensorboard
        writer.add_scalar("Loss/train", float(loss), epoch)
        # loss on the held-out test data (we are lazy and don't use a
        # separate validation set)
        outputs = self._predict_held_out()
        valid_loss = self.criterion(outputs, self.y_test)
        writer.add_scalar("Loss/valid", valid_loss.item(), epoch)

    def accuracy(self):
        """Print and return the fraction of correctly classified test samples."""
        outputs = self._predict_held_out()
        predicted = outputs.argmax(dim=1)
        accuracy = (predicted == self.y_test).sum().item() / len(self.y_test)
        print(f'Accuracy: {accuracy}')
        return accuracy

    def _predict_held_out(self):
        """Run the model on ``self.X_test`` without gradients, in eval mode.

        The previous train/eval mode of the model is restored afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                outputs, _ = self.model(self.X_test)
        finally:
            self.model.train(was_training)
        return outputs

# Seed torch's global RNG so that the weight initialisation below (and the
# shuffling of the training DataLoader, which draws from the same RNG) is
# reproducible across "Restart Kernel -> Run All".
torch.manual_seed(42)

# new model
stride = 2
# fixed (non-learnable) filter of shape (1, 1, 3) with all weights 0.33
conv_filter = torch.tensor([[[0.33] * 3]])
# model = dMLP(stride=stride, conv_filter=conv_filter)
model = MLP()


In [None]:

# train: 1000 epochs of Adam with learning rate 1e-3, evaluated on the test split
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    testdata=(X_test, y_test),
)
trainer.train(lr=0.001, epochs=1000)


In [None]:
# launch tensorboard
%tensorboard --logdir runs

Visualization of the weights:

In [None]:
# Heatmap of the first layer's weight matrix (transposed), colour bar hidden.
fc1_weights = model.fc1.weight.T.data
fc1_fig = px.imshow(fc1_weights, aspect='auto', width=1000)
fc1_fig.update_layout(coloraxis_showscale=False)
fc1_fig.show()

In [None]:
# Heatmap of the second layer's weight matrix, colour bar hidden,
# followed by the shape of that matrix.
fc2_layer = model.fc2
fc2_fig = px.imshow(fc2_layer.weight.data, aspect='auto', width=1000)
fc2_fig.update_layout(coloraxis_showscale=False)
fc2_fig.show()
print(fc2_layer.weight.size())
