In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Hebbian(nn.Module):
    def __init__(self, input_size, output_size, learning_rate=0.001):
        super(Hebbian, self).__init__()
        self.linear = nn.Linear(input_size, output_size, bias=False)
        self.learning_rate = learning_rate

    def forward(self, x):
        return self.linear(x)

    def update_weights(self, x, y):
        delta_w = torch.mm(x.t(), y) * self.learning_rate
        delta_w /= torch.norm(delta_w) + 1e-8
        self.linear.weight.data.add_(delta_w.t())

class HebbianResNet50(nn.Module):
    def __init__(self, num_classes=10):
        super(HebbianResNet50, self).__init__()
        resnet50 = torchvision.models.resnet50(pretrained=True)
        self.features = nn.Sequential(*list(resnet50.children())[:-1])
        for param in self.features.parameters():
            param.requires_grad = False
         
        self.hebb1 = Hebbian(2048, 2048)
        self.hebb2 = Hebbian(2048, num_classes)
        self.gelu = nn.GELU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.gelu(self.hebb1(x))
        x = self.softmax(self.hebb2(x))
        return x

transform = transforms.Compose(
    [transforms.Resize(224),
     transforms.Grayscale(num_output_channels=3),
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))
     ])

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False, num_workers=2)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = HebbianResNet50().to(device)

criterion = nn.CrossEntropyLoss()

from tqdm import tqdm

epochs = 10

for epoch in range(epochs):
    running_loss = 0.0
    correct = 0
    total = 0
    with tqdm(trainloader, unit="batch") as tepoch:
        for i, data in enumerate(tepoch):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            net.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)

            feat = net.features(inputs).view(inputs.size(0), -1).detach()

            output_grad = torch.zeros_like(outputs)
            for idx, l in enumerate(labels):
                output_grad[idx, l] = 1
            output_grad -= outputs
            output_grad /= output_grad.size(0)

            net.hebb1.update_weights(feat, net.hebb1(feat).detach())
            net.hebb2.update_weights(net.hebb1(feat).detach(), output_grad.detach())

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            running_loss += loss.item()
            tepoch.set_postfix(loss=running_loss / (i + 1), accuracy=correct / total * 100)

    net.eval()
    test_loss = 0.0
    test_correct = 0
    test_total = 0

    with torch.no_grad(), tqdm(total=len(testset)) as pbar:
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            loss = criterion(outputs, labels)

            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

            test_loss += loss.item()
            pbar.update(1)

    print(f"Epoch {epoch + 1}, Train Loss: {running_loss / (i + 1)}, Train Acc: {correct / total * 100}. Test Loss: {test_loss / len(testloader)}, Test Acc: {test_correct / test_total * 100}")

print("Finished Training")


100%|██████████| 600/600 [02:38<00:00,  3.78batch/s, accuracy=10.1, loss=2.36]
  1%|          | 100/10000 [00:14<23:58,  6.88it/s]


Epoch 1, Train Loss: 2.3590021761258444, Train Acc: 10.145. Test Loss: 2.363150744438171, Test Acc: 9.8


 60%|██████    | 361/600 [01:29<00:59,  4.05batch/s, accuracy=9.88, loss=2.36]


KeyboardInterrupt: 