# Imports

In [1]:
"""
Demo single-file notebook to train a ConvNet on CIFAR10 using SoftHebb, an unsupervised, efficient and bio-plausible learning algorithm.
Based on demo.py from the official repo
"""
import math
import warnings
from tqdm.notebook import tqdm
import datetime
from pathlib import Path
from copy import deepcopy
from collections import OrderedDict
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import random

import numpy as np


import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from torch.optim.lr_scheduler import StepLR
import torchvision

def seed_init_fn(seed):
    seed = seed % 2 ** 32
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

# Device
device = torch.device('cuda:0')

# For Lateral Support

In [2]:
def create_circular_mask_batched(h, w, centers, radii):
    """
    Create circular masks for batched inputs using PyTorch tensors.
    
    Args:
    - h: Height of the grid (or a tensor of heights for each batch).
    - w: Width of the grid (or a tensor of widths for each batch).
    - centers: A tensor of shape (batch_size, 2) representing the (y, x) center of the circle for each batch.
               If None, defaults to the center of each grid.
    - radii: A tensor of shape (batch_size,) representing the radius of the circle for each batch.
             If None, defaults to the minimum distance from the center to the edges of each grid.
    
    Returns:
    - A boolean mask (tensor) of shape (batch_size, h, w), where True represents points inside the circle.
    """
    batch_size = centers.size(0)

    # Ensure `h` and `w` are tensors if they are scalar inputs
    if isinstance(h, int):
        h = torch.tensor([h] * batch_size)
    if isinstance(w, int):
        w = torch.tensor([w] * batch_size)

    # Create the grid for each batch (meshgrid for each image)
    Y, X = torch.meshgrid([torch.arange(h.max()), torch.arange(w.max())], indexing='ij')
    
    # Expand X and Y to match batch size and grid size
    X = X.unsqueeze(0).expand(batch_size, -1, -1).float().to(device=centers.device)
    Y = Y.unsqueeze(0).expand(batch_size, -1, -1).float().to(device=centers.device)

    # Calculate the distance of each point in the grid from the center of the circle
    dist_from_center = torch.sqrt((X - centers[:, 1].unsqueeze(1).unsqueeze(2)) ** 2 + 
                                  (Y - centers[:, 0].unsqueeze(1).unsqueeze(2)) ** 2)

    # Create a mask for each batch based on the radius
    masks = dist_from_center <= radii.unsqueeze(1).unsqueeze(2)

    return masks

# Model architecture definition

In [3]:
class SoftHebbLinear(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            t_invert: float = 12, # Invert of the temperature for softmax
            # Lateral Support params
            n_lateral_neighbors=0, # 0 - turn off lateral support
            lateral_period=None, # None - turn off 2D
            lateral_gain_coef=0, # 0 - no effect
    ) -> None:
        """
        This is the implementation of Linear layer trained with SoftHebb method.
        The code is adapted and augmented from https://github.com/NeuromorphicComputing/SoftHebb/tree/main
        """
        super(SoftHebbLinear, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        weight_range = 14 / math.sqrt(in_channels)
        self.weight = nn.Parameter(weight_range * torch.randn((out_channels, in_channels)))
        self.t_invert = torch.tensor(t_invert)
        # Lateral Support params
        self.n_lateral_neighbors = n_lateral_neighbors
        self.lateral_period = lateral_period
        self.lateral_gain_coef = lateral_gain_coef

    def forward(self, x):
        weighted_input = F.linear(x, self.weight, None)

        if self.training:
            # Post-synaptic activation, for plastic update, is weighted input passed through a softmax.
            # Non-winning neurons (those not with the highest activation) receive the negated post-synaptic activation.
            batch_size, out_channels = weighted_input.shape
            flat_weighted_inputs = weighted_input
            # Compute the winner neuron for each batch element
            flat_softwta_activs = torch.softmax(self.t_invert * flat_weighted_inputs, dim=1)
            flat_softwta_activs = - flat_softwta_activs  # Turn all postsynaptic activations into anti-Hebbian
            win_neurons = torch.argmax(flat_weighted_inputs, dim=1)  # winning neuron for each pixel in each input
            competing_idx = torch.arange(flat_weighted_inputs.size(0))  # indices of all pixel-input elements
            # Turn winner neurons' activations back to hebbian (pqlet - i.e. make positive, not anti-hebbian term)
            flat_softwta_activs[:, win_neurons] = - flat_softwta_activs[:, win_neurons]

            # Lateral Support
            if self.n_lateral_neighbors > 0:
                # 1D case
                if self.lateral_period is None:
                    # Vectorize the generation of indices for lateral neighbors
                    # Range of neighbor offsets
                    neighbors_range = torch.hstack(
                        (torch.arange(-self.n_lateral_neighbors, 0, device=win_neurons.device),
                        torch.arange(1, self.n_lateral_neighbors + 1, device=win_neurons.device))
                    )
                    # Create the neighbors for each neuron by adding the offsets
                    ids_support = win_neurons.unsqueeze(1) + neighbors_range.unsqueeze(0)
                    # Clip the resulting indices to be within valid range
                    ids_support = torch.clip(ids_support, min=0, max=flat_softwta_activs.size(0) - 1).T
                    # Set values
                    flat_softwta_activs[:, ids_support] = \
                        + flat_softwta_activs[:, ids_support] \
                        - self.lateral_gain_coef * flat_softwta_activs[competing_idx, win_neurons].unsqueeze(1)
                # 2D case
                else:
                    assert type(self.lateral_period)==int
                    # Assume win_neurons are 1D indices of neurons, and n_lateral_period is the number of columns (width of 2D grid)
                    n_rows = flat_softwta_activs.size(1) // self.lateral_period  # Total rows in the 2D grid
                    # Convert winning neurons' 1D id to 2D
                    centers_from_win = torch.stack((
                        win_neurons.divide(self.lateral_period, rounding_mode='floor'),
                        win_neurons.remainder(self.lateral_period),
                    )).T
                    # Get circular 2D masks - boolean only by now
                    radii = self.n_lateral_neighbors * torch.ones(len(win_neurons), device=win_neurons.device)
                    mask_lsupp = create_circular_mask_batched(n_rows, self.lateral_period, 
                                                                centers=centers_from_win, 
                                                                radii=radii
                                                                )
                    # Apply masks
                    flat_softwta_activs = \
                        flat_softwta_activs \
                        - self.lateral_gain_coef * flat_softwta_activs[competing_idx, win_neurons].unsqueeze(1) * mask_lsupp.flatten(1)

            softwta_activs = flat_softwta_activs
            # ===== compute plastic update Δw = y*(x - u*w) = y*x - (y*u)*w =======================================
            yx = torch.matmul(softwta_activs.T, x)

            yu = torch.multiply(softwta_activs, weighted_input)
            yu = torch.sum(yu.t(), dim=1).unsqueeze(1)
            
            delta_weight = yx - yu.view(-1, 1,) * self.weight
            delta_weight.div_(torch.abs(delta_weight).amax() + 1e-30)  # Scale [min/max , 1]
            self.weight.grad = delta_weight  # store in grad to be used with common optimizers

        return weighted_input


class ModelSoftHebb(nn.Module):
    def __init__(self, in_channels, out_channels, t_invert, layer_kwargs=None):
        super(ModelSoftHebb, self).__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels

        if layer_kwargs is None:
            layer_kwargs = {}
        self.layer_kwargs = layer_kwargs 
        
        self.bn1 = nn.BatchNorm1d(in_channels, affine=False)
        self.ff1 = SoftHebbLinear(
            in_channels=in_channels, 
            out_channels=out_channels, 
            t_invert=t_invert,
            **self.layer_kwargs
        ) 

    def forward(self, x):
        out = self.ff1(self.bn1(x))
        return out

# Optimization definitions - Scheduler, Optimizer

In [4]:
class TensorLRSGD(optim.SGD):
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step, using a non-scalar (tensor) learning rate.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad
                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.add_(-group['lr'] * d_p)
        return loss

# Dataset definition - CIFAR10

In [5]:
import pandas as pd
from pathlib import Path
import numpy as np
from torch.utils.data import DataLoader

datadir = Path("../")

# Load DataFrames with embedding ids and texts
df_acts_filename = datadir / "tinystories1M-TinyStories1_gpt2token-acts_id.csv"
df_texts_filename = datadir / "tinystories1M-TinyStories1_gpt2token-texts.csv"
df_acts_tokens = pd.read_csv(df_acts_filename)
df_texts= pd.read_csv(df_texts_filename)

# Load embedding matrix
with open(datadir / 'emb_matrix.npy', 'rb') as f:
    embs_flat = np.load(f)
    
unsup_trainloader = DataLoader(embs_flat, batch_size=64, pin_memory=True, num_workers=0, shuffle=True)

# Prepare model, optimizers, data

In [6]:
# seed training 
seed_init_fn(42)

# Layer kwargs
layer_kwargs_2D = {
    'n_lateral_neighbors':1,
    'lateral_period':60,
    'lateral_gain_coef':-0.3,
}
# Model 
model = ModelSoftHebb(
    in_channels=64, 
    out_channels=60*30, 
    t_invert=100,
    layer_kwargs=layer_kwargs_2D,
)
model.to(device)

unsup_optimizer = TensorLRSGD([
    # Basically just SGD despite the complicated class, becaouse of just 'lr' param
    {"params": model.ff1.parameters(), "lr": -0.2, },  # SGD does descent, but we assign the positive updates, so set lr to negative
], lr=0)

# Unsupervised training

In [7]:
n_epochs_unsup=1
# Unsupervised training with SoftHebb
print(f'{datetime.datetime.now().strftime(f"%H-%M-%S")}| Unsupervised Training - {n_epochs_unsup} epoch')

dd_ker_vals = {'ff1':[]}

for epoch in range(n_epochs_unsup):
    running_loss = 0.0
    for step, data in tqdm(enumerate(unsup_trainloader, 0), total=len(unsup_trainloader)):
        global_step = (step+1 + epoch*len(unsup_trainloader))
        
        inputs = data
        inputs = inputs.flatten(1)
        inputs = inputs.to(device)

        # zero the parameter gradients
        unsup_optimizer.zero_grad()

        # forward + update computation
        with torch.no_grad():
            outputs = model(inputs)

        # optimize
        unsup_optimizer.step()
                    
        # save kernels for plots
        ker_vals = deepcopy(model.ff1.weight.cpu().detach())
        dd_ker_vals['ff1'].append(ker_vals)

print(f'{datetime.datetime.now().strftime(f"%H-%M-%S")}| End of Unsupervised Training - {n_epochs_unsup} epoch')

20-36-12| Unsupervised Training - 1 epoch


  0%|          | 0/10666 [00:00<?, ?it/s]

20-37-08| End of Unsupervised Training - 1 epoch


# View top tokents to some neuron 

In [48]:
ic = 15 + 60
name = f"feat"
df_acts_tokens[name] = (embs_flat @ model.ff1.weight[ic].detach().cpu().numpy())
df_acts_tokens.sort_values(by=name, ascending=False).head(10)


Unnamed: 0,emb_id,token_str,seq_id,feat,sort_value
607536,607536,dad,2782,0.203379,0.101648
73537,73537,Jill,367,0.195074,0.107867
148784,148784,dad,737,0.192554,0.089043
463881,463881,dad,2166,0.192051,0.094174
674053,674053,children,3153,0.19181,0.144837
18562,18562,parents,99,0.19136,0.109007
577363,577363,friends,2638,0.191189,0.116324
271151,271151,dad,1212,0.190664,0.086794
12402,12402,daddy,69,0.189825,0.093433
162029,162029,pa,768,0.188926,0.083726


# Plot tokents weights are specific to in the 2D grid of output neurons

In [16]:
# Save top 3 unique tokens for every output neuron
joint_text = []
for ic in tqdm(range((model.ff1.weight.shape[0]))):
    activ_ic = (embs_flat @ model.ff1.weight[ic].detach().cpu().numpy())
    df_acts_tokens['sort_value'] = activ_ic
    top_token_str = df_acts_tokens.sort_values(by='sort_value', ascending=False).token_str
    joint_text.append('<br>'.join(top_token_str[:3].unique()))

  0%|          | 0/1800 [00:00<?, ?it/s]


The behavior of `series[i:j]` with an integer-dtype index is deprecated. In a future version, this will be treated as *label-based* indexing, consistent with e.g. `series[i]` lookups. To retain the old behavior, use `series.iloc[i:j]`. To get the future behavior, use `series.loc[i:j]`.



In [60]:
import plotly.express as px

# layer_kwargs_2D['lateral_period']
# fig = px.scatter(x=np.arange(0,60,1).repeat(10), y=np.tile(np.arange(0,10), 60), text=joint_text[:600])
x_rows = 16
fig = px.scatter(
    x=np.tile(np.arange(0,60), x_rows), 
    y=np.arange(0,x_rows,1).repeat(60), 
    text=joint_text[:60*x_rows]
)

fig.update_traces(textposition='top center')

fig.update_layout(
    height=1200,
    width=3000,
    plot_bgcolor='white',  # Background color of the plot area
    paper_bgcolor='white', # Background color of the area outside the plot
    title_text='2D Grid of neurons and their associate tokens',
    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    # xaxis_range=[11, 31],
    # yaxis_range=[-0.5, x_rows+1],
    font=dict(size=10) 
)
fig.show()

from pathlib import Path
dir = Path('image-pres')
dir.mkdir(exist_ok=True)
fig.write_image(f'{dir}/2D-neuron-map-words.png', format='png', scale=6, width=3000, height=1000)