In [23]:
import os
from pathlib import Path

import torch
import numpy as np
import brainbox
from brainbox import trainer as bb_trainer

from block import datasets, trainer
from block.models import CIFAR10Model
from block.datasets.transforms import Normalize, ToClip, List

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
path = Path(os.path.abspath("")).parent

We ran training with the following paremeters:

python train.py --method=fast_naive --t_len=4 --beta_requires_grad=False --readout_max=False --single_spike=True --gamma=0.1 --dataset=cifar10 --load_spatial_dims=True --use_augmentation=True --epoch=140 --batch=128 --lr=0.001 --track_activity=False"

## Load dataset

In [15]:
t_len = 4

train_transform = List.get_cifar10_transform(t_len, use_augmentation=True)
test_transform = brainbox.datasets.transforms.Compose(
    [
        Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ToClip(t_len)
        
    ]
)
dataset = datasets.CIFAR10Dataset(os.path.join(path, "data"), t_len=4, transform=train_transform)
test_dataset = datasets.CIFAR10Dataset(os.path.join(path, "data"),  train=False, t_len=4, transform=test_transform)

Files already downloaded and verified
Files already downloaded and verified


## Load model

In [25]:
model_path = f"{path}/results/datasets/cifar10"
model_id = "7a223732d8f147a4a5c23167c56586e0"
device = "cuda"
dtype = torch.float
model = trainer.Trainer.load_model(model_path, model_id, device, dtype)

In [26]:
def accuracy_metric(output, target):
    _, predictions = torch.max(output, 1)
    return (predictions == target).sum().cpu().item()

In [27]:
scores = bb_trainer.compute_metric(model, dataset, accuracy_metric, batch_size=128)
print(f"Train acc {np.sum(scores)/len(dataset)}")

Train acc 0.88822


In [28]:
scores = bb_trainer.compute_metric(model, test_dataset, accuracy_metric, batch_size=128)
print(f"Test acc {np.sum(scores)/len(test_dataset)}")

Test acc 0.8214
