In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from MemSE.train_test_loop import test, test_mse_th, test_mse_sim
from MemSE.model_load import load_memristor
from MemSE.dataset import get_dataloader

# Run on the GPU when available; fall back to CPU so the notebook still executes elsewhere.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Fix RNG seeds so noise-dependent measurements (std_noise is swept below) are
# reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

bs = 4
# Number of batches evaluated per measurement (~100 samples). Integer division:
# `100 / bs` produced a float, which is unusable as a batch counter / range bound.
nb_batch = 100 // bs
train_loader, valid_loader, test_loader, nclasses, input_shape = get_dataloader('CIFAR10', bs=bs)
criterion = nn.CrossEntropyLoss().to(device)

models_names = ['smallest_vgg', 'really_small_vgg']

# NOTE(review): N is forwarded to load_memristor; its meaning lives in MemSE.model_load — confirm.
N = 128

# Default MemSE wrapper for the first model. NOTE(review): 0.01 is presumably the
# initial noise std (Figure 4 also uses 0.01 as its std) — confirm.
memse = load_memristor(models_names[0], nclasses, 'all', device, input_shape, 0.01, N)

# Figure 2: theoretical vs. simulated MSE and accuracy as a function of the noise std

In [None]:
# Sweep the noise std and record, for each model, the accuracy plus the
# theoretical and simulated MSE averaged over `nb_batch` validation batches.
accuracies = {n: [] for n in models_names}
mses_th = {n: [] for n in models_names}
mses_sim = {n: [] for n in models_names}
std_noise_fig_2 = np.linspace(1e-3, 1e-1, 10)

for network in models_names:
    # Load each model separately: previously the single `memse` built for
    # models_names[0] was reused, so every network's curves came from that one model.
    # A local name keeps the global `memse` used by later cells untouched.
    net_memse = load_memristor(network, nclasses, 'all', device, input_shape, 0.01, N)
    for sig in std_noise_fig_2:
        net_memse.quanter.std_noise = sig
        net_memse.quant()
        try:
            accuracies[network].append(test(valid_loader, net_memse.quanter, criterion, device=device, batch_stop=nb_batch))
            mses, _ = test_mse_th(valid_loader, net_memse, device, nb_batch)
            mses_th[network].append(np.mean(mses))
            mses_sim[network].append(np.mean(test_mse_sim(valid_loader, net_memse, device, nb_batch)))
        finally:
            # Always undo quant(), even if an evaluation raises, so a re-run
            # does not start from a quantized model.
            net_memse.unquant()

In [None]:
# MSE (left axis) and accuracy (right axis) versus noise std, one set of curves per model.
fig_2_figure, fig_2 = plt.subplots()
fig_2_acc = fig_2.twinx()
for network in models_names:
    fig_2.plot(std_noise_fig_2, mses_th[network], label=f'{network=} th')
    fig_2.plot(std_noise_fig_2, mses_sim[network], label=f'{network=} sim')

    # Dashed so accuracy curves are distinguishable: the twin axes restarts the colour cycle.
    fig_2_acc.plot(std_noise_fig_2, accuracies[network], linestyle='--', label=f'{network=} accuracy')

fig_2.set_xlabel('std_noise')
fig_2.set_ylabel('MSE')
fig_2_acc.set_ylabel('accuracy')
fig_2.set_title('MSE and accuracy vs. noise std')

# plt.legend() only covered the current (twin) axes and dropped the MSE curves;
# merge both axes' handles into one legend drawn on the top (twin) axes.
mse_handles, mse_labels = fig_2.get_legend_handles_labels()
acc_handles, acc_labels = fig_2_acc.get_legend_handles_labels()
fig_2_acc.legend(mse_handles + acc_handles, mse_labels + acc_labels)
plt.show()

# Figure 3: theoretical vs. simulated MSE as a function of Gmax, for several noise std values

In [None]:
# Sweep Gmax for several noise std values on a single model and record the
# simulated and theoretical MSE averaged over `nb_batch` validation batches.
std_noise_fig_3 = [0.01, 0.003, 0.001]
# np.logspace takes base-10 *exponents*: the former np.logspace(1e-1, 1e2)
# spanned 10**0.1 .. 10**100. geomspace takes the endpoints directly: 0.1 .. 100.
Gmax_fig_3 = np.geomspace(1e-1, 1e2, 50)
network = models_names[0]

# Reload so this cell always measures `network` and does not depend on state
# (Gmax / noise) left behind by earlier cells or a previous run of this one.
memse = load_memristor(network, nclasses, 'all', device, input_shape, 0.01, N)

mses_sim_fig_3 = {n: [] for n in std_noise_fig_3}
mses_th_fig_3 = {n: [] for n in std_noise_fig_3}
for sig in std_noise_fig_3:
    # Previously `sig` was never applied, so all three curves were identical.
    memse.quanter.std_noise = sig
    for Gmax in Gmax_fig_3:
        memse.quanter.init_gmax(Gmax)
        # NOTE(review): quantize around the measurements as the Figure 2 cell does;
        # confirm init_gmax is meant to precede quant().
        memse.quant()
        try:
            mses_sim_fig_3[sig].append(np.mean(test_mse_sim(valid_loader, memse, device, nb_batch)))
            mses, _ = test_mse_th(valid_loader, memse, device, nb_batch)
            mses_th_fig_3[sig].append(np.mean(mses))
        finally:
            memse.unquant()

In [None]:
# Theoretical (solid) vs. simulated (dashed) MSE as a function of Gmax, one pair per noise std.
fig_3_figure, fig_3 = plt.subplots()
for sig in std_noise_fig_3:
    fig_3.plot(Gmax_fig_3, mses_th_fig_3[sig], label=f'Theory {sig=}')
    fig_3.plot(Gmax_fig_3, mses_sim_fig_3[sig], linestyle='--', label=f'Simulation {sig=}')

# Gmax is sampled geometrically, so a log x-axis spaces the points evenly.
fig_3.set_xscale('log')
fig_3.set_xlabel('Gmax')
fig_3.set_ylabel('MSE')
fig_3.set_title('MSE vs. Gmax')
# The original cell drew labelled curves but never showed a legend.
fig_3.legend()
plt.show()

# Figure 4 results

In [None]:
# Figure 4 setup: one fixed noise std and a list of per-layer Gmax values (still empty).
std_noise_fig_4, Gmax_fig_4 = 0.01, list()
# TODO: load the per-layer Gmax values into Gmax_fig_4.