In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Hebbian(nn.Module):
    def __init__(self, input_size, output_size, learning_rate=0.001):
        super(Hebbian, self).__init__()
        self.linear = nn.Linear(input_size, output_size, bias=False)
        self.learning_rate = learning_rate

    def forward(self, x):
        return self.linear(x)

    def update_weights(self, x, y):
        delta_w = torch.mm(x.t(), y) * self.learning_rate
        delta_w /= torch.norm(delta_w) + 1e-8
        self.linear.weight.data.add_(delta_w.t())

In [3]:
class HebbianVgg19(nn.Module):
    def __init__(self, num_classes=10):
        super(HebbianVgg19, self).__init__()
        vgg19 = torchvision.models.vgg19(pretrained=True)
        self.features = nn.Sequential(*list(vgg19.children())[:-1])
        for param in self.features.parameters():
            param.requires_grad = False
         
        self.classifier = Hebbian(25088, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        y = self.classifier(x)
        return y

In [4]:
# CIFAR10
transform = transforms.Compose(
    [transforms.Resize(224),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# MNIST
'''transform = transforms.Compose(
    [transforms.Resize(224),
     transforms.Grayscale(num_output_channels=3),
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))
     ])'''

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False, num_workers=2)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = HebbianVgg19().to(device)

criterion = nn.CrossEntropyLoss()

Files already downloaded and verified
Files already downloaded and verified




In [5]:
from tqdm import tqdm

epochs = 10

for epoch in range(epochs):
    running_loss = 0.0
    correct = 0
    total = 0
    with tqdm(trainloader, unit="batch") as tepoch:
        for i, data in enumerate(tepoch):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            net.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            output_grad = torch.zeros_like(outputs)
            for idx, l in enumerate(labels):
                output_grad[idx, l] = 1
            output_grad -= nn.functional.softmax(outputs, dim=1)
            output_grad /= output_grad.size(0)

            net.classifier.update_weights(net.features(inputs).view(inputs.size(0), -1).detach(), output_grad.detach())

            running_loss += loss.item()
            tepoch.set_postfix(loss=running_loss / (i + 1), accuracy=correct / total * 100)

    net.eval()
    test_loss = 0.0
    test_correct = 0
    test_total = 0

    with torch.no_grad(), tqdm(total=len(testloader)) as pbar:
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            loss = criterion(outputs, labels)

            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

            test_loss += loss.item()
            pbar.update(1)

    print(f"Epoch {epoch + 1}, Train Loss: {running_loss / (i + 1)}, Train Acc: {correct / total * 100}. Test Loss: {test_loss / len(testloader)}, Test Acc: {test_correct / test_total * 100}")

print("Finished Training")

100%|██████████| 500/500 [05:22<00:00,  1.55batch/s, accuracy=70.4, loss=2.62]
  1%|          | 100/10000 [00:35<58:56,  2.80it/s] 


Epoch 1, Train Loss: 2.616222856402397, Train Acc: 70.39. Test Loss: 2.107984893321991, Test Acc: 78.21000000000001


100%|██████████| 500/500 [05:07<00:00,  1.63batch/s, accuracy=81.1, loss=1.69]
  1%|          | 100/10000 [00:33<55:51,  2.95it/s]


Epoch 2, Train Loss: 1.6891072260737419, Train Acc: 81.138. Test Loss: 2.329500181674957, Test Acc: 76.79


100%|██████████| 500/500 [05:09<00:00,  1.62batch/s, accuracy=84.5, loss=1.36]
  1%|          | 100/10000 [00:34<56:43,  2.91it/s]


Epoch 3, Train Loss: 1.3563514048457146, Train Acc: 84.476. Test Loss: 2.6183029890060423, Test Acc: 78.25999999999999


100%|██████████| 500/500 [05:09<00:00,  1.61batch/s, accuracy=87, loss=1.09]  
  1%|          | 100/10000 [00:34<56:43,  2.91it/s]


Epoch 4, Train Loss: 1.0940768716931344, Train Acc: 87.044. Test Loss: 2.0865228736400603, Test Acc: 81.14


100%|██████████| 500/500 [05:12<00:00,  1.60batch/s, accuracy=88.6, loss=0.943]
  1%|          | 100/10000 [00:35<59:09,  2.79it/s] 


Epoch 5, Train Loss: 0.9426973187029362, Train Acc: 88.64999999999999. Test Loss: 2.0720505994558334, Test Acc: 82.21000000000001


100%|██████████| 500/500 [05:21<00:00,  1.55batch/s, accuracy=90.1, loss=0.799]
  1%|          | 100/10000 [00:36<1:00:09,  2.74it/s]


Epoch 6, Train Loss: 0.7994300420507788, Train Acc: 90.106. Test Loss: 2.1199687603116035, Test Acc: 82.44


100%|██████████| 500/500 [05:05<00:00,  1.64batch/s, accuracy=91.3, loss=0.694]
  1%|          | 100/10000 [00:34<57:03,  2.89it/s] 


Epoch 7, Train Loss: 0.6938239350989461, Train Acc: 91.328. Test Loss: 1.98048497736454, Test Acc: 83.23


100%|██████████| 500/500 [05:10<00:00,  1.61batch/s, accuracy=92.1, loss=0.636]
  1%|          | 100/10000 [00:34<56:39,  2.91it/s]


Epoch 8, Train Loss: 0.6355333993807435, Train Acc: 92.144. Test Loss: 2.1381836956739426, Test Acc: 82.78999999999999


100%|██████████| 500/500 [05:09<00:00,  1.62batch/s, accuracy=92.8, loss=0.565]
  1%|          | 100/10000 [00:34<56:52,  2.90it/s]


Epoch 9, Train Loss: 0.5648845233730971, Train Acc: 92.824. Test Loss: 2.310579543709755, Test Acc: 82.36


100%|██████████| 500/500 [05:19<00:00,  1.57batch/s, accuracy=93.5, loss=0.484]
  1%|          | 100/10000 [00:36<59:51,  2.76it/s] 

Epoch 10, Train Loss: 0.4835554581945762, Train Acc: 93.54400000000001. Test Loss: 2.883821804523468, Test Acc: 81.93
Finished Training



