In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from torchvision.datasets import MNIST

import torchmetrics

from tqdm import tqdm

import torchhd

from torchhd.models import Centroid
from torchhd import embeddings

In [27]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("...using {} device".format(device))

# The position/level embeddings are randomly generated and the DataLoaders
# shuffle with torch's RNG; seed it so the reported accuracy is reproducible
# after Restart & Run All. torch.manual_seed also seeds CUDA devices.
SEED = 42
torch.manual_seed(SEED)

# Hypervector dimensionality.
D = 10000
# Side length of the square input images; IMG_SIZE * IMG_SIZE position hypervectors.
IMG_SIZE = 28
# Number of discrete levels used to embed pixel intensities.
N_LEVELS = 1000
# Samples per DataLoader batch.
# NOTE(review): 1 makes the train/test loops slow; the encoder presumably
# broadcasts over a batch dimension, so larger values may work (watch memory).
BATCH_SIZE = 1

...using cpu device


In [28]:
# ToTensor converts the PIL images to float tensors scaled to [0, 1].
transform = torchvision.transforms.ToTensor()

train_dataset = MNIST("./data/", train=True, transform=transform, download=True)
train_labeled_data = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_dataset = MNIST("./data/", train=False, transform=transform, download=True)
# Accuracy does not depend on evaluation order, so iterate the test set in a
# fixed order: shuffling here only adds nondeterminism to the test pass.
test_labeled_data = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [29]:
class Encoder(nn.Module):
  """Map images to hypervectors by binding pixel positions with pixel levels.

  Every pixel position owns a random hypervector (``self.position``) and every
  pixel intensity is looked up in a level embedding (``self.value``). The two
  are bound per pixel, bundled across all pixels with ``torchhd.multiset``,
  and the result is passed through ``torchhd.hard_quantize``.

  Args:
    out_features: hypervector dimensionality.
    size: side length of the square input image; ``size * size`` position
      hypervectors are allocated.
    levels: number of levels in the intensity embedding.
  """

  def __init__(self, out_features, size, levels):
    super().__init__()
    # Keep this registration order: the embeddings draw from the RNG in turn.
    self.flatten = torch.nn.Flatten()
    self.position = embeddings.Random(size * size, out_features)
    self.value = embeddings.Level(levels, out_features)

  def forward(self, x):
    """Return one quantized hypervector per image in ``x``.

    Assumes ``x`` is (batch, 1, size, size) with intensities in the Level
    embedding's range (presumably [0, 1], i.e. ToTensor output) — TODO confirm.
    """
    pixels = self.flatten(x)
    per_pixel = torchhd.bind(self.position.weight, self.value(pixels))
    bundled = torchhd.multiset(per_pixel)
    return torchhd.hard_quantize(bundled)
  
# Build the image encoder first (it consumes the RNG), then the centroid
# classifier with one row per MNIST class; both live on the chosen device.
encode = Encoder(D, IMG_SIZE, N_LEVELS).to(device)

n_classes = len(train_dataset.classes)
model = Centroid(D, n_classes).to(device)

In [30]:
# Accumulate each class centroid from the encoded training hypervectors.
# Reset the centroids first so that re-running this cell does not add the
# training set twice, or add raw sums on top of weights that the evaluation
# cell has already normalized with model.normalize().
with torch.no_grad():
  model.reset_parameters()

  for samples, labels in tqdm(train_labeled_data, desc="Training"):
    samples = samples.to(device)
    labels = labels.to(device)

    samples_hv = encode(samples)
    model.add(samples_hv, labels)

Training: 100%|██████████| 60000/60000 [04:53<00:00, 204.61it/s]


In [31]:
# Multiclass accuracy over the whole test set, accumulated on the CPU.
accuracy = torchmetrics.Accuracy("multiclass", num_classes=n_classes)

with torch.no_grad():
  # Normalize the trained centroids once before scoring. With dot=True this
  # presumably turns the scores into cosine similarities — TODO confirm in torchhd.
  model.normalize()

  for images, targets in tqdm(test_labeled_data, desc="Testing"):
    scores = model(encode(images.to(device)), dot=True)
    # targets were never moved off the CPU, so bring the scores back to match.
    accuracy.update(scores.cpu(), targets)

Testing: 100%|██████████| 10000/10000 [00:47<00:00, 212.76it/s]


In [33]:
# Convert the aggregated fraction to a percentage for reporting.
test_accuracy_pct = accuracy.compute().item() * 100
print(f"Testing accuracy of {test_accuracy_pct:.3f}%")

Testing accuracy of 82.840%
