<a href="https://colab.research.google.com/github/reutdayan/HebbNet/blob/main/HebbNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

Mount Google Drive so that model weights can be saved there.

In [None]:
from google.colab import drive

# Mount Google Drive under /content/drive so checkpoints can be written to it.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Mounted at /content/drive


Download and load the MNIST dataset

In [None]:
# Build the MNIST train and test splits (downloaded into ./data on first use),
# each converted to tensors with ToTensor. The generator is consumed in order,
# so the training split is created first, then the test split.
training_data, test_data = (
    datasets.MNIST(root="data", train=is_train, download=True, transform=ToTensor())
    for is_train in (True, False)
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 74795023.30it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 35691129.59it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 19813298.83it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 9045835.12it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw






Define hyperparameters

In [None]:
# --- network sizes ---
input_layer_size = 28 * 28   # flattened MNIST image
hidden_layer_size = 2000     # width of the Hebbian layer
output_layer_size = 10       # one output per digit class

# --- optimisation ---
batch_size = 1
epochs = 200
lr = 1              # η: learning rate for the SGD readout layer and for the Hebbian update
lr_decay = 0.95     # per-epoch multiplicative decay of lr (StepLR gamma)
momentum = 5e-4     # SGD momentum of the readout layer
p = 0.01            # fraction (not percentile) of Hebbian updates kept by sparsification: top 1%

In [None]:
# Data loaders for both splits (no shuffling; iteration order is the dataset order).
train_dataloader, test_dataloader = (
    DataLoader(split, batch_size=batch_size) for split in (training_data, test_data)
)

# Peek at the first test batch to confirm tensor shapes and label dtype.
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

# Pick the best available accelerator: CUDA, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Shape of X [N, C, H, W]: torch.Size([1, 1, 28, 28])
Shape of y: torch.Size([1]) torch.int64
Using cuda device


In [None]:
# Define model
class HebbNet(nn.Module):
    """Two-layer classifier: a bias-free hidden layer followed by a linear readout.

    ``forward`` returns the triple ``(flattened_input, hidden_activations, log_probs)``
    so the caller can feed the first two into a Hebbian update rule.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept (flatten, hebbian, classification, relu, softmax)
        # so the module repr and state_dict layout stay the same.
        self.flatten = nn.Flatten()
        self.hebbian_weights = nn.Linear(input_layer_size, hidden_layer_size, bias=False)
        self.classification_weights = nn.Linear(hidden_layer_size, output_layer_size, bias=True)
        self.relu = nn.ReLU()
        # NOTE: despite the attribute name this is a LogSoftmax, so forward returns
        # log-probabilities.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        flat = self.flatten(x)
        hidden = self.relu(self.hebbian_weights(flat))
        log_probs = self.softmax(self.classification_weights(hidden))
        return flat, hidden, log_probs


In [None]:
class BatchedHebbRuleWithActivationThreshold(nn.Module):
    """Batched Hebbian update with an adaptive (running-mean) activation threshold.

    For every sample b the Hebbian term is the outer product ``z_b x_b^T``
    (hidden x input). The threshold of each weight is the mean of that term over
    all samples seen in previous calls. The returned update is the batch mean of
    ``z_b x_b^T - threshold``. On the first call there is no history, so the raw
    term is used.
    """

    def __init__(self, hidden_layer_size=2000, input_layer_size=784, batch_size=64):
        super().__init__()
        # Nominal batch size, kept for backward compatibility. The sample count is
        # taken from each incoming batch, so a short final batch is not over-counted.
        self.batch_size = batch_size
        # Running SUM of per-sample Hebbian terms; divided by (t - 1) it is the threshold.
        self.register_buffer('w1_activation_thresholds', torch.zeros((hidden_layer_size, input_layer_size)))
        # 1 + number of samples accumulated into the buffer so far.
        self.t = 1

    def forward(self, x_input: torch.Tensor, z_hidden: torch.Tensor):
        '''
        Args:
            x_input: torch.Tensor of shape (B, input_layer_size), e.g. (B, 784)
            z_hidden: torch.Tensor of shape (B, hidden_layer_size), e.g. (B, 2000)

        Returns:
            torch.Tensor of shape (hidden_layer_size, input_layer_size), e.g. (2000, 784):
            the batch-mean thresholded Hebbian update.
        '''
        with torch.no_grad():
            # Per-sample outer products z_b x_b^T -> (B, hidden, input)
            activation = torch.einsum('bi,bj->bij', z_hidden, x_input)

            samples_seen = self.t - 1
            if samples_seen == 0:
                delta_w1 = activation
            else:
                delta_w1 = activation - (self.w1_activation_thresholds[None] / samples_seen)

            self.w1_activation_thresholds += activation.sum(0)
            # Advance by the ACTUAL number of samples in this batch (the last batch of
            # an epoch may be smaller than batch_size).
            self.t = self.t + activation.shape[0]
            return delta_w1.mean(0)  # mean over batch


In [None]:
class HebbRuleWithActivationThreshold(nn.Module):
    """Hebbian update with an adaptive (running-mean) activation threshold.

    Each call forms the Hebbian term ``z.T @ x`` (hidden x input). It returns that
    term minus the mean of the terms from all previous calls. On the first call there
    is no history, so the raw term is returned.

    NOTE(review): for B > 1 rows, ``z.T @ x`` sums over the batch while ``t`` still
    advances by 1, so the threshold is a mean of per-call sums. With batch_size=1
    this is a per-sample mean.
    """

    def __init__(self, hidden_layer_size=2000, input_layer_size=784):
        super().__init__()
        # Running SUM of all Hebbian terms seen so far (divided by t - 1 gives the threshold).
        self.register_buffer('w1_activation_thresholds', torch.zeros((hidden_layer_size, input_layer_size)))
        # 1 + number of calls so far.
        self.t = 1

    def forward(self, x: torch.Tensor, z: torch.Tensor):
        """Return the thresholded Hebbian update and fold this call into the running sum.

        Args:
            x: 2-D tensor (B, input_layer_size); presumably the flattened input.
            z: 2-D tensor (B, hidden_layer_size); presumably the hidden activations.
        Returns:
            Tensor of shape (hidden_layer_size, input_layer_size).
        """
        with torch.no_grad():
            hebbian_term = z.T @ x
            previous_calls = self.t - 1
            if previous_calls == 0:
                delta_w1 = hebbian_term
            else:
                delta_w1 = hebbian_term - self.w1_activation_thresholds / previous_calls
            self.w1_activation_thresholds += hebbian_term
            self.t = self.t + 1
        return delta_w1

In [None]:
def gradiant_sparsity(delta_w1, p):
  """Keep only the largest-magnitude fraction ``p`` of the entries of ``delta_w1``.

  Args:
    delta_w1: tensor of weight updates (any shape).
    p: fraction of entries to keep (e.g. 0.01 keeps the top 1%). Values <= 0 keep
       nothing; values >= 1 keep everything.

  Returns:
    A tensor with the same shape, dtype and device as ``delta_w1``. Every entry whose
    absolute value is below the k-th largest absolute value is zeroed. Entries tied
    with that threshold are kept, so slightly more than k can survive.
  """
  numel = delta_w1.numel()
  if numel == 0 or p <= 0:
    return torch.zeros_like(delta_w1)

  # At least one entry: int() rounds small p*numel down to 0, which would make topk
  # return an empty tensor and top_values[-1] raise IndexError.
  # At most numel: topk rejects k > numel.
  num_values_to_keep = min(numel, max(1, int(p * numel)))

  magnitudes = torch.abs(delta_w1)
  top_values, _ = torch.topk(magnitudes.reshape(-1), num_values_to_keep)
  threshold = top_values[-1]  # the k-th largest magnitude

  # zeros_like keeps the result on delta_w1's own device/dtype (no global `device` needed).
  return torch.where(magnitudes >= threshold, delta_w1, torch.zeros_like(delta_w1))




In [None]:
def train(dataloader, model, loss_fn, optimizer, lr, activation_thresholder: HebbRuleWithActivationThreshold, log_every=10000):
    """Run one training epoch of HebbNet.

    Two update rules are applied to each batch:
      * the readout (``classification_weights``) is trained by backprop with ``optimizer``;
      * the Hebbian layer is updated manually as
        ``W1 <- W1 - lr * gradiant_sparsity(activation_thresholder(x, z), p)``,
        where ``p`` is the global sparsification fraction.

    NOTE(review): the Hebbian step SUBTRACTS the thresholded Hebbian term. Confirm that
    this sign is intended.

    Args:
        dataloader: yields ``(X, y)`` batches.
        model: a HebbNet-like module returning ``(x_flat, z_hidden, log_probs)``.
        loss_fn: loss applied to ``(log_probs, y)``.
        optimizer: optimizer over the readout parameters.
        lr: learning rate of the Hebbian update.
        activation_thresholder: callable producing the thresholded Hebbian term.
        log_every: print the loss every this many batches.
    """
    model.train()
    hebbian_weight = model.hebbian_weights.weight
    # The Hebbian layer is never stepped by the optimizer. Autograd would otherwise
    # compute its gradient every batch, and since optimizer.zero_grad() never clears
    # it, that gradient would accumulate forever.
    hebbian_weight.requires_grad_(False)
    hebbian_weight.grad = None

    size = len(dataloader.dataset)
    correct = 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        X, z_hidden, pred = model(X)  # X is now the flattened input
        loss = loss_fn(pred, y)
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()

        # Backpropagation: optimize the classification weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Hebbian layer: thresholded Hebbian term, then top-p gradient sparsity
        delta_w1 = activation_thresholder(X, z_hidden)
        delta_w1 = gradiant_sparsity(delta_w1, p)

        # In-place update outside autograd
        with torch.no_grad():
            hebbian_weight.sub_(lr * delta_w1)

        if batch % log_every == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
    correct /= size
    print(f"Train Accuracy: {(100*correct):>0.1f}%")



In [None]:
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader``, print a summary and return the accuracy in percent.

    Switches the model to eval mode and leaves it there. The reported loss is the
    mean of the per-batch losses.
    """
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            _, _, log_probs = model(inputs)
            test_loss += loss_fn(log_probs, targets).item()
            correct += (log_probs.argmax(1) == targets).type(torch.float).sum().item()
    test_loss = test_loss / len(dataloader)
    correct = correct / len(dataloader.dataset)
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return 100 * correct

In [None]:
def visualize_weights (weights):
    """Show each row of ``weights`` as a 28x28 image, one figure per row (neuron)."""
    for row in weights:
        pixels = row.reshape(28, 28).cpu().detach().numpy()
        plt.imshow(pixels)
        plt.show()

def visualize_activations(weights,X):
  """Plot the Hebbian update matrix computed by the global ``model`` and ``activation_thresholder`` for input ``X``.

  Args:
    weights: unused; kept for backward compatibility.
    X: input batch, on the same device as ``model``.

  Calling the thresholder updates its running statistics (``w1_activation_thresholds``
  and ``t``). Both are snapshotted and restored here, so visualizing does not perturb
  training.
  """
  saved_sum = activation_thresholder.w1_activation_thresholds.clone()
  saved_t = activation_thresholder.t
  try:
    with torch.no_grad():
      X, z_hidden, pred = model(X)
      delta_w1 = activation_thresholder(X, z_hidden).squeeze()
  finally:
    activation_thresholder.w1_activation_thresholds.copy_(saved_sum)
    activation_thresholder.t = saved_t
  # .cpu() is required: .numpy() fails on CUDA/MPS tensors.
  plt.imshow(delta_w1.detach().cpu().numpy())
  plt.show()

In [None]:
def save_model(model, activation_thresholder, optimizer, scheduler, filepath):
    """
    Write a training checkpoint to ``filepath`` with ``torch.save``.

    The checkpoint is a dict with the keys ``model_state_dict``,
    ``activation_thresholder`` (the object itself, pickled, so plain attributes
    such as its counter ``t`` are kept), ``optimizer_state_dict`` and
    ``scheduler_state_dict`` (``None`` when no scheduler is given).

    Args:
        model (nn.Module): model whose state_dict is stored.
        activation_thresholder: object stored as-is.
        optimizer: optimizer whose state_dict is stored.
        scheduler: learning-rate scheduler, or a falsy value for none.
        filepath (str): destination file.
    """
    scheduler_state = None
    if scheduler:
        scheduler_state = scheduler.state_dict()

    checkpoint = dict(
        model_state_dict=model.state_dict(),
        activation_thresholder=activation_thresholder,
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler_state,
    )
    torch.save(checkpoint, filepath)

def load_model(model, optimizer, scheduler, filepath, map_location=None):
    """
    Load a checkpoint written by ``save_model`` into ``model``/``optimizer``/``scheduler``.

    Args:
        model (nn.Module): model to receive the stored state_dict.
        optimizer: optimizer to receive the stored state_dict.
        scheduler: learning-rate scheduler to restore, or a falsy value to skip.
        filepath (str): checkpoint file.
        map_location: forwarded to ``torch.load``. For example, pass ``'cpu'`` to load
            a checkpoint saved on GPU on a CPU-only machine. The default ``None``
            keeps the saved devices.

    Returns:
        model, activation_thresholder, optimizer, scheduler: the updated objects, plus
        the activation_thresholder object unpickled from the checkpoint.
    """
    # SECURITY: weights_only=False allows arbitrary pickled Python objects, which is
    # needed because the checkpoint stores the activation_thresholder object itself.
    # PyTorch >= 2.6 defaults to weights_only=True and would reject these checkpoints.
    # Only load checkpoints from trusted sources.
    state = torch.load(filepath, map_location=map_location, weights_only=False)

    # Load the model, activation_thresholder, and optimizer states
    model.load_state_dict(state['model_state_dict'])
    activation_thresholder = state['activation_thresholder']
    optimizer.load_state_dict(state['optimizer_state_dict'])

    # Load the scheduler state, if one was saved
    scheduler_state = state.get('scheduler_state_dict')
    if scheduler and scheduler_state:
        scheduler.load_state_dict(scheduler_state)

    return model, activation_thresholder, optimizer, scheduler

In [None]:
load_from = -1  # epoch of a saved checkpoint to resume from; -1 trains from scratch
# Directory where per-epoch checkpoints are (optionally) saved
drive_save_dir = '/content/drive/My Drive/HebbNet_model_weights/'

# Build the model
model = HebbNet().to(device)
print(model)

loss_fn = nn.CrossEntropyLoss()
# Only the readout layer is optimized by SGD; the Hebbian layer is updated inside train().
optimizer = torch.optim.SGD([
            {'params': model.classification_weights.parameters(), 'lr': lr, 'momentum': momentum} ])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=lr_decay)
# Size the threshold buffer from the hyperparameters so it always matches the Hebbian layer.
activation_thresholder = HebbRuleWithActivationThreshold(
    hidden_layer_size=hidden_layer_size, input_layer_size=input_layer_size).to(device)

# Resume from a saved checkpoint if requested
if load_from > -1:
  checkpoint_path = os.path.join(drive_save_dir, f'model_epoch_{load_from}.pth')
  model, activation_thresholder, optimizer, scheduler = load_model(model, optimizer, scheduler, checkpoint_path)


for t in range(load_from + 1, epochs):
    print(f"Epoch {t}\n-------------------------------")
    # The scheduler's current lr is also used as the Hebbian-layer learning rate,
    # so scheduler.step() below decays both layers' learning rates.
    train(train_dataloader, model, loss_fn, optimizer, scheduler.get_last_lr()[0], activation_thresholder)
    test(test_dataloader, model, loss_fn)
    scheduler.step()
    # Save the model weights after each epoch
    # save_model(model, activation_thresholder, optimizer, scheduler, os.path.join(drive_save_dir, f"model_epoch_{t}.pth"))
    torch.cuda.empty_cache()
print("Done!")


NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (hebbian_weights): Linear(in_features=784, out_features=2000, bias=False)
  (classification_weights): Linear(in_features=2000, out_features=10, bias=True)
  (relu): ReLU()
  (softmax): LogSoftmax(dim=1)
)
Epoch 0
-------------------------------
loss: 2.356591  [    1/60000]
loss: 2.207929  [10001/60000]
loss: 2.090814  [20001/60000]
loss: 37446008.000000  [30001/60000]
loss: 3.105266  [40001/60000]
loss: 2.914796  [50001/60000]
Train Accuracy: 9.5%
Test Error: 
 Accuracy: 11.9%, Avg loss: 8269772064.914073 

Epoch 1
-------------------------------
loss: 541.339111  [    1/60000]
loss: 2.551543  [10001/60000]
loss: 2.126985  [20001/60000]
loss: 2.478526  [30001/60000]
loss: 4.136813  [40001/60000]
loss: 2.949136  [50001/60000]
Train Accuracy: 9.2%
Test Error: 
 Accuracy: 12.0%, Avg loss: 7497279728.182299 

Epoch 2
-------------------------------
loss: 3.323491  [    1/60000]
loss: 2.540766  [10001/60000]
loss: 2.196618  [20

Visualize the Hebbian weights

In [None]:
#visualize_weights(model.hebbian_weights.weight)

In [None]:
# Optional: plot the Hebbian update of every hidden neuron (as a 28x28 image) for the
# first test batch. Disabled by default because it opens one figure per hidden neuron.
SHOW_DELTA_W1 = False
if SHOW_DELTA_W1:
    X = next(iter(test_dataloader))[0].to(device)
    # Calling the thresholder updates its running statistics; snapshot and restore them
    # so this visualization does not perturb training.
    saved_thresholds = activation_thresholder.w1_activation_thresholds.clone()
    saved_t = activation_thresholder.t
    with torch.no_grad():
        X, z_hidden, pred = model(X)
        delta_w1 = activation_thresholder(X, z_hidden).squeeze()
    activation_thresholder.w1_activation_thresholds.copy_(saved_thresholds)
    activation_thresholder.t = saved_t
    # Uncomment to visualize the sparsified update instead:
    # delta_w1 = gradiant_sparsity(delta_w1, p)
    delta_w1 = delta_w1.reshape(hidden_layer_size, 28, 28)
    for d_w in delta_w1:
        plt.imshow(d_w.detach().cpu().numpy(), cmap='gray')
        plt.show()

SyntaxError: ignored

In [None]:
# Directory containing the per-epoch .pth checkpoints
drive_save_dir = '/content/drive/My Drive/HebbNet_model_weights/'


def _checkpoint_epoch(filename):
    """Epoch number encoded in a 'model_epoch_<N>.pth' filename."""
    return int(filename.split("_")[-1].split(".")[0])


def plot_accuracy_curve(epoch_values, accuracy_values, title):
    """Line plot of accuracy (%) against epoch."""
    plt.figure(figsize=(10, 6))
    plt.plot(epoch_values, accuracy_values, marker='o', linestyle='-')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.grid(True)
    plt.show()


# NOTE: deliberately not named `epochs`; that would overwrite the epoch-count
# hyperparameter used by the training loop.
checkpoint_epochs = []
train_accuracies = []
test_accuracies = []

checkpoint_files = [
    name for name in os.listdir(drive_save_dir)
    if name.startswith("model_epoch_") and name.endswith(".pth")
]
# os.listdir order is arbitrary; sort numerically so curves are drawn in epoch order.
checkpoint_files.sort(key=_checkpoint_epoch)

for filename in checkpoint_files:
    epoch = _checkpoint_epoch(filename)
    checkpoint_path = os.path.join(drive_save_dir, filename)
    model, _, _, _ = load_model(model, optimizer, scheduler, checkpoint_path)
    print(epoch)
    train_accuracy = test(train_dataloader, model, loss_fn)
    test_accuracy = test(test_dataloader, model, loss_fn)
    checkpoint_epochs.append(epoch)
    train_accuracies.append(train_accuracy)
    test_accuracies.append(test_accuracy)

plot_accuracy_curve(checkpoint_epochs, test_accuracies, 'Test Accuracy vs. Epoch')
plot_accuracy_curve(checkpoint_epochs, train_accuracies, 'Train Accuracy vs. Epoch')

In [None]:
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

# Gather the whole training set in ONE pass over the loader. The images and labels
# then come from the same batches and stay aligned, and the 60k batches are not
# iterated twice.
data_batches, label_batches = zip(*train_dataloader)
all_data = torch.cat(data_batches, dim=0).to(device)
all_labels = torch.cat(label_batches, dim=0).to(device)
print(all_data.shape)
print(all_labels.shape)

# Inference only: no autograd graph over the full dataset.
with torch.no_grad():
    X, z_hidden, pred = model(all_data)

print(X.shape)



Sanity check: PCA + K-means on random one-hot vectors (this does not achieve good clustering)

In [None]:
import numpy as np
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

n_components = 2   # number of PCA components to project onto
num_samples = 100
num_classes = 10
rng = np.random.default_rng(0)  # fixed seed so the demo is reproducible

# Random class index per sample, and the matching one-hot rows.
labels = rng.integers(num_classes, size=num_samples)
vectors = np.eye(num_classes)[labels]

pca = PCA(n_components=n_components)
X_pca = pca.fit_transform(vectors)

k = num_classes  # number of clusters
kmeans = KMeans(n_clusters=k, n_init=10, random_state=0)
y_kmeans = kmeans.fit_predict(X_pca)


def _scatter_2d(points, colors, title):
    """Scatter the first two columns of ``points``, colored by ``colors``."""
    plt.scatter(points[:, 0], points[:, 1], c=colors, cmap='rainbow')
    plt.title(title)
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')
    plt.show()


_scatter_2d(X_pca, y_kmeans, 'K-means Clusters')
_scatter_2d(X_pca, labels, 'True Labels')
