# Echo State Network example

In [None]:
import time

import matplotlib.pyplot as plt
import torch
import torchvision

from qbraid_algorithms.esn import EchoStateNetwork, EchoStateReservoir

Load the MNIST training and test datasets.

In [None]:
# Convert each image to a tensor (ToTensor scales pixels to [0, 1]), then
# normalize with mean 0.5 / std 0.5, mapping pixel values to [-1, 1].
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.5], [0.5])]
)
# download=True fetches MNIST into ./data on the first run and is a no-op once
# the files are cached; without it torchvision raises RuntimeError when the
# dataset is not already present on disk.
trainset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
# NOTE(review): the loaders below are not used by the training/evaluation cells
# in this notebook, which iterate trainset/testset directly.
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True)
testset = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(testset, shuffle=False)

In [None]:
# The ESN input size is the number of pixels in one image, i.e. the product of
# every dimension of the raw image tensor except the leading sample axis.
input_size = trainset.data.shape[1:].numel()
# The number of output classes is the count of distinct label values in the training set.
output_size = trainset.targets.unique().numel()

Initialize the echo state network (ESN)

In [None]:
# Keyword arguments forwarded unchanged to EchoStateReservoir.
# NOTE(review): the precise meaning of each key (notably "a" and "leak") is
# defined by EchoStateReservoir — confirm against its documentation.
hyperparams = {
    "hidden_size": 5000,  # presumably the number of reservoir units — confirm
    "sparsity": 0.9,
    "spectral_radius": 0.99,
    "a": 0.6,
    "leak": 1.0,
}

reservoir = EchoStateReservoir(input_size, **hyperparams)

# Build the network on top of the reservoir, sized to emit one score per class,
# then cast it with .float() (presumably so it matches float32 image tensors).
esn = EchoStateNetwork(reservoir, output_size)
esn = esn.float()

Initialize the optimizer with the ESN parameters and define the loss criterion

In [None]:
# Classification loss: CrossEntropyLoss expects unnormalized class scores and
# integer class-index targets.
criterion = torch.nn.CrossEntropyLoss()
# Adam over all parameters the ESN exposes, with a small learning rate.
optimizer = torch.optim.Adam(params=esn.parameters(), lr=1e-5)

Train the network on a fixed subset of the training set

In [None]:
import itertools

# Number of training samples used per epoch. trainset is iterated in order (the
# shuffling trainloader is not used), so every epoch sees the same first images.
nsamples = 1000
nepochs = 200
loss_values = []  # mean training loss of each epoch, one entry per epoch

start = time.perf_counter()  # monotonic clock intended for measuring durations
for _ in range(nepochs):
    running_loss = 0.0
    seen = 0
    # islice yields exactly the first `nsamples` (image, label) pairs, or fewer
    # if the dataset is smaller.
    for images, labels in itertools.islice(trainset, nsamples):
        optimizer.zero_grad()  # clear gradients left over from the previous step
        outputs = esn(images)  # forward pass
        # The target is wrapped in a 1-element tensor to form a batch of one label.
        loss = criterion(outputs, torch.tensor([labels]))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        seen += 1

    # Record this epoch's mean loss over the samples actually processed.
    if seen:
        loss_values.append(running_loss / seen)
end = time.perf_counter()

minutes, seconds = divmod(int(end - start), 60)
print(f"Training duration: {minutes} min {seconds} sec")

In [None]:
# Plot the mean training loss recorded for each epoch. The x-axis is derived from
# loss_values itself, so the plot stays valid even if fewer than nepochs losses
# were recorded (e.g. an interrupted run).
fig, ax = plt.subplots(dpi=100)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.plot(range(1, len(loss_values) + 1), loss_values, color="red")
plt.show()

Measure the network's accuracy on a subset of the test set

In [None]:
import itertools

# Evaluate on the first nsamples // 4 test images (a quarter of the per-epoch
# training sample count).
ntest = nsamples // 4
total = 0
correct = 0

# Inference only: disable autograd so no computation graph is built per sample.
with torch.no_grad():
    for images, labels in itertools.islice(testset, ntest):
        # The predicted class is the index of the largest output score.
        pred = torch.argmax(esn(images)).item()
        if pred == labels:
            correct += 1
        total += 1

if total == 0:
    raise ValueError("No test samples were evaluated; increase nsamples.")
percent_correct = correct * 100.0 / total

print(f"Accuracy: {percent_correct:.2f} %")