In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from src.data.load_dataset import load_mnist, load_kmnist
from src.models.networks import V1_mnist_RFNet, classical_RFNet
from src.models.utils import train, test

In [2]:
# Global seed so that Restart & Run All reproduces the same weights and data order.
# 10 matches the seed passed explicitly to classical_RFNet further down.
SEED = 10
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA
np.random.seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### MNIST

In [11]:
# Build the MNIST train / validation / test loaders (2048 is presumably the batch size).
# NOTE(review): the validation logs produced with this split report 54000 samples,
# so 0.1 does not look like the validation fraction — confirm the meaning of the
# second argument in src.data.load_dataset.load_mnist.
mnist_loaders = load_mnist(2048, 0.1)
train_loader, val_loader, test_loader = mnist_loaders

In [12]:
# V1-inspired random-feature network on MNIST.
# NOTE(review): (h, s, f, c) are passed positionally to V1_mnist_RFNet, whose signature
# is not visible here. h presumably plays the role hidden_size plays for classical_RFNet
# (both are 200 for MNIST); confirm what s, f and c control in src.models.networks.
h, s, f, c = 200, 5.34, 2, None

# V1_mnist_RFNet is built without a seed (unlike classical_RFNet(seed=10)), so seed the
# global RNGs here. Re-running this cell on its own then rebuilds the same network.
torch.manual_seed(10)
np.random.seed(10)
model = V1_mnist_RFNet(h, s, f, c).to(device)

# hyperparams
lr = 1E-2
optimizer = optim.SGD(model.parameters(), lr=lr)

# train; test() logs a "Test set: ..." line per call, including the val_loader calls.
# Keep every epoch's validation result instead of overwriting it, so the curve can be
# plotted or compared with the classical baseline afterwards.
epochs = 20
log_interval = 5
v1_mnist_val_history = []
for epoch in range(1, epochs + 1):
    train(log_interval, device, model, train_loader, optimizer, epoch, verbose=True)
    val_accuracy = test(model, device, val_loader)
    v1_mnist_val_history.append(val_accuracy)

# calculate and print test accuracy
test_accuracy = test(model, device, test_loader)


Test set: Average loss: 1331.860718. Accuracy: 22403/54000 (41.49%)


Test set: Average loss: 835.472961. Accuracy: 15019/54000 (27.81%)


Test set: Average loss: 676.884033. Accuracy: 28046/54000 (51.94%)


Test set: Average loss: 276.345001. Accuracy: 33985/54000 (62.94%)


Test set: Average loss: 216.392761. Accuracy: 34525/54000 (63.94%)


Test set: Average loss: 69.657921. Accuracy: 35763/54000 (66.23%)


Test set: Average loss: 141.212891. Accuracy: 38513/54000 (71.32%)


Test set: Average loss: 101.331673. Accuracy: 36952/54000 (68.43%)


Test set: Average loss: 44.618378. Accuracy: 41895/54000 (77.58%)


Test set: Average loss: 43.193829. Accuracy: 42321/54000 (78.37%)


Test set: Average loss: 43.135883. Accuracy: 42320/54000 (78.37%)


Test set: Average loss: 44.988132. Accuracy: 42157/54000 (78.07%)


Test set: Average loss: 82.002724. Accuracy: 41946/54000 (77.68%)


Test set: Average loss: 24.351297. Accuracy: 45826/54000 (84.86%)


Test set: Average loss: 79.862709. Accu

In [8]:
## classical network
# Baseline: classical random-feature network trained on the same MNIST loaders.
inp_size, hidden_size = (1, 28, 28), 200
model = classical_RFNet(inp_size, hidden_size, seed=10).to(device)

# optimizer
lr = 1E-2
optimizer = optim.SGD(model.parameters(), lr=lr)

# train; keep every epoch's validation result instead of overwriting it.
# NOTE(review): this baseline runs 10 epochs while the V1 network above ran 20 with the
# same optimizer and learning rate — confirm the unequal budget is intended.
epochs = 10
log_interval = 5
classical_mnist_val_history = []
for epoch in range(1, epochs + 1):
    train(log_interval, device, model, train_loader, optimizer, epoch, verbose=True)
    val_accuracy = test(model, device, val_loader)
    classical_mnist_val_history.append(val_accuracy)

# calculate and print test accuracy
test_accuracy = test(model, device, test_loader)



Test set: Average loss: 2.310560. Accuracy: 5062/54000 (9.37%)


Test set: Average loss: 2.289449. Accuracy: 5905/54000 (10.94%)


Test set: Average loss: 2.269177. Accuracy: 7073/54000 (13.10%)


Test set: Average loss: 2.249593. Accuracy: 8584/54000 (15.90%)


Test set: Average loss: 2.230593. Accuracy: 10419/54000 (19.29%)


Test set: Average loss: 2.212053. Accuracy: 12454/54000 (23.06%)


Test set: Average loss: 2.193958. Accuracy: 14714/54000 (27.25%)


Test set: Average loss: 2.176253. Accuracy: 16918/54000 (31.33%)


Test set: Average loss: 2.158844. Accuracy: 19036/54000 (35.25%)


Test set: Average loss: 2.141750. Accuracy: 20939/54000 (38.78%)


Test set: Average loss: 2.135391. Accuracy: 3886/10000 (38.86%)



### KMNIST

In [None]:
# Build the KMNIST train / validation / test loaders (128 is presumably the batch size).
# NOTE(review): the MNIST cell calls load_mnist(2048, 0.1), but this cell passes 0.9 as
# the split argument — confirm the two datasets are meant to be split differently.
kmnist_loaders = load_kmnist(128, 0.9)
train_loader, val_loader, test_loader = kmnist_loaders

In [None]:
# V1-inspired random-feature network on KMNIST.
# NOTE(review): (h, s, f, c) are passed positionally to V1_mnist_RFNet, whose signature
# is not visible here. h presumably matches hidden_size of the classical baseline below
# (both are 100); confirm what s, f and c control in src.models.networks.
h, s, f, c = 100, 5, 2, None

# V1_mnist_RFNet is built without a seed (unlike classical_RFNet(seed=10)), so seed the
# global RNGs here. Re-running this cell on its own then rebuilds the same network.
torch.manual_seed(10)
np.random.seed(10)
model = V1_mnist_RFNet(h, s, f, c).to(device)

# hyperparams
lr = 1E-4
optimizer = optim.Adam(model.parameters(), lr=lr)

# train; keep every epoch's validation result instead of overwriting it.
epochs = 5
log_interval = 5
v1_kmnist_val_history = []
for epoch in range(1, epochs + 1):
    train(log_interval, device, model, train_loader, optimizer, epoch, verbose=True)
    val_accuracy = test(model, device, val_loader)
    v1_kmnist_val_history.append(val_accuracy)

# calculate and print test accuracy
test_accuracy = test(model, device, test_loader)

In [None]:
## classical network
# Baseline: classical random-feature network trained on the same KMNIST loaders.
inp_size, hidden_size = (1, 28, 28), 100
model = classical_RFNet(inp_size, hidden_size, seed=10).to(device)

# optimizer
lr = 1E-4
optimizer = optim.Adam(model.parameters(), lr=lr)

# train; keep every epoch's validation result instead of overwriting it.
epochs = 5
log_interval = 5
classical_kmnist_val_history = []
for epoch in range(1, epochs + 1):
    train(log_interval, device, model, train_loader, optimizer, epoch, verbose=True)
    val_accuracy = test(model, device, val_loader)
    classical_kmnist_val_history.append(val_accuracy)

# calculate and print test accuracy
test_accuracy = test(model, device, test_loader)

