# MNIST CLASSIFIER
### Let's build a hybrid classical-quantum algorithm to classify the digits of the MNIST dataset

In [1]:
# First, we import the libaries
import sys

sys.path.append('../')
import quforge.quforge as qf
from quforge.quforge import State as State

In [2]:
#Load the dataset
import torch
import torchvision
import torchvision.transforms as transforms

N_train = 1000 #number of samples from the training set
N_test = 1000 #number of samples from the test set

transform_train = transforms.Compose([
    transforms.ToTensor(),
])

trainset = torchvision.datasets.MNIST(root='../../datasets/', train=True, download=True, transform=transform_train)
indices = torch.arange(N_train)
trainset = torch.utils.data.dataset.Subset(trainset, indices)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True, num_workers=1)

testset = torchvision.datasets.MNIST(root='../../datasets/', train=False, download=True, transform=transform_train)
indices = torch.arange(N_test)
testset = torch.utils.data.dataset.Subset(testset, indices)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=1)

In [3]:
#Now, we define the hybrid model

class Circuit(qf.Module):
    def __init__(self, dim, wires):
        super(Circuit, self).__init__()

        self.encoder = qf.Sequential(
            qf.Linear(784, 10)
        )

        self.init = qf.H(dim=dim, index=range(wires))
        self.qencoder = qf.RZ(dim=dim, index=range(wires))

        self.layers1 = qf.Sequential(
            qf.RX(dim=dim, index=range(wires)),
            qf.RY(dim=dim, index=range(wires)),
            qf.RZ(dim=dim, index=range(wires)),
        )

        self.layers2 = []
        for i in range(9):
            self.layers2.append(qf.CNOT(dim=dim, wires=wires, index=[0,1]))
        self.layers2 = qf.ModuleList(self.layers2)

        self.layers3 = qf.Sequential(
            qf.RX(dim=dim, index=range(wires)),
            qf.RY(dim=dim, index=range(wires)),
            qf.RZ(dim=dim, index=range(wires)),
        )

    def forward(self, x):
        x = x.reshape((1, 784))
        x = self.encoder(x)
        x = x.flatten()

        y = State('0-0-0-0-0-0-0-0-0-0', device=x.device)
        y = self.init(y)
        y = self.qencoder(y, param=x)
        y = self.layers1(y)
        for i in range(9):
            y = self.layers2[i](y)
        y = self.layers3(y)

        return y 

In [4]:
#Instatiate the circuit and define the optimizer
device = 'cuda'
model = Circuit(dim=2, wires=10).to(device)
optimizer = qf.optim.Adam(model.parameters(), lr=0.001, betas=(0.9,0.999))

In [5]:
#Define the target states
targets = []
for i in range(10):
    state = ''
    for j in range(10):
        if j == i:
            state += '1'
        else:
            state += '0'
        if j < 9:
            state += '-'
    targets.append(State(state, dim=2, device=device))

targets_arg = []
for i in range(10):
    targets_arg.append(torch.argmax(abs(targets[i])))

In [6]:
#Let's train the model
for epoch in range(8):
    acc_train = 0
    acc_test = 0
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(device)

        output = model(inputs)
        target = targets[labels[0]].reshape(output.shape)
        F = qf.fidelity(target, output)
        loss = (1-F)**2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, m = qf.measure(output, index=range(10), dim=2)
        predict = torch.argmax(m).item()
        if predict == targets_arg[labels[0]]:
            acc_train += 1

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(testloader):
            inputs = inputs.to(device)
            output = model(inputs)
            _, m = qf.measure(output, index=range(10), dim=2)

            predict = torch.argmax(m)
            if predict == targets_arg[labels[0]]:
                acc_test += 1

    acc_train = acc_train/N_train
    acc_test = acc_test/N_test
    print(epoch, acc_train, acc_test)

0 0.489 0.478
