In [None]:
# -*- coding: utf-8 -*-
%matplotlib inline
import matplotlib.pyplot as plt

import math
import numpy as np
import pandas as pd
import seaborn as sns

#from sklearn.metrics import accuracy_score, log_loss
#from sklearn.model_selection import StratifiedKFold
from sklearn.decomposition import PCA

import torch
import torch.nn as nn
#from torch.optim import Adam, SGD
import torch.utils.checkpoint as checkpoint

import torchvision
from torchvision.transforms import v2

import warnings
import time
#import functools
#import copy
from tqdm import tqdm

# Reproducibility and numeric defaults shared by every later cell.
seed = 1001
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(seed)
torch.set_default_dtype(torch.float32)
# Hide UserWarnings emitted by later cells.
warnings.simplefilter('ignore', UserWarning)

In [None]:
import torchquantum as tq
from torchquantum.measurement import expval_joint_analytical

In [None]:
# Record library versions next to the results for reproducibility.
for _lib in (torch, tq):
    print(_lib.__version__)

In [None]:
n_gpu = torch.cuda.device_count()
print(n_gpu)
# Run on the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = "cpu"
gpu_names = [torch.cuda.get_device_name(f"cuda:{idx}") for idx in range(n_gpu)]
print(gpu_names)

In [None]:
# Download/load the MNIST train and test splits.
mnist_train = torchvision.datasets.MNIST(root='./data', train=True, download=True)
mnist_test = torchvision.datasets.MNIST(root='./data', train=False, download=True)
# `.data` / `.targets` are the current attribute names; the `train_data`, `train_labels`,
# `test_data`, `test_labels` aliases are deprecated (their warnings are otherwise hidden
# by the UserWarning filter).
data_tr, label_tr = mnist_train.data, mnist_train.targets
data_te, label_te = mnist_test.data, mnist_test.targets
print(data_tr.shape, data_te.shape)

In [None]:
# Show the first `num` training digits with their labels on a 2-row grid.
num = 10
n_cols = max(1, math.ceil(num / 2))  # ceil so an odd `num` still fits on the grid
fig, axes = plt.subplots(2, n_cols, figsize=(2 * n_cols, 4))
for i, ax in enumerate(axes.flat):
    if i < num:
        # .cpu() keeps this cell working if it is re-run after the data was moved to the GPU
        ax.imshow(data_tr[i].detach().cpu().numpy(), cmap='gray_r')
        ax.set_title(f"label {int(label_tr[i])}")
    ax.axis('off')
plt.show()

In [None]:
dataset_name = 'mnist'
# Number of distinct labels in the training set. torch.unique works for CPU and CUDA
# tensors alike (np.unique fails on a CUDA tensor, e.g. when this cell is re-run after
# the data has been moved to the GPU).
n_class = int(torch.unique(label_tr).numel())

In [None]:
# Circuit width. QNNsubModel's two data-encoding blocks upload n_half_qubits and
# n_latter_half_qubits features per wire respectively.
n_qubits = 8
n_half_qubits, n_latter_half_qubits = n_qubits // 2, n_qubits - n_qubits // 2

In [None]:
class ConstCoeffLayer(nn.Module):
    """Multiply the input by a fixed coefficient (applied to the ensemble logits in this notebook).

    The coefficient is a plain attribute rather than a parameter or buffer, so it is not
    trained and does not appear in ``state_dict``.
    """

    def __init__(self, coeff):
        super().__init__()
        self.coeff = coeff

    def forward(self, x):
        return x.mul(self.coeff)

In [None]:
def calc_exp_val(qdev, obs):
    """Return the batched expectation value <psi| P |psi> of a Pauli string P.

    Args:
        qdev: object whose ``states`` tensor holds the batched state vector as
            (bsz, 2, ..., 2); the Pauli for wire i is applied to axis i + 1.
        obs: Pauli string with exactly one character per wire, each one of
            'I', 'X', 'Y', 'Z'; character i acts on wire i.

    Returns:
        Real tensor of shape (bsz,).

    Raises:
        AssertionError: if ``len(obs)`` differs from the number of wires.
        ValueError: if ``obs`` contains a character other than I/X/Y/Z.
    """
    bra_state = qdev.states
    n_wires = bra_state.dim() - 1
    assert len(obs) == n_wires, f"observable {obs!r} has {len(obs)} factors for {n_wires} wires"
    # Build the matrices directly in the state's dtype/device to avoid per-call casts/transfers.
    paulis = {
        'X': torch.tensor([[0, 1], [1, 0]], dtype=bra_state.dtype, device=bra_state.device),
        'Y': torch.tensor([[0, -1j], [1j, 0]], dtype=bra_state.dtype, device=bra_state.device),
        'Z': torch.tensor([[1, 0], [0, -1]], dtype=bra_state.dtype, device=bra_state.device),
    }
    ket = bra_state
    for wire, label in enumerate(obs):
        if label == 'I':
            continue
        if label not in paulis:
            # previously an unknown label silently reused the last matrix
            raise ValueError(f"unsupported Pauli label {label!r} at position {wire}; expected one of 'IXYZ'")
        # Contract the matrix's column index with this wire's axis, then move the
        # resulting (row) axis back to the wire's position.
        ket = torch.movedim(torch.tensordot(paulis[label], ket, dims=([1], [wire + 1])), 0, wire + 1)
    # <psi|P|psi> summed over all wire axes; works for any number of wires.
    return (bra_state.conj() * ket).reshape(bra_state.shape[0], -1).sum(dim=1).real

In [None]:
# 14x14 => 7x14x2
# 2n_qubitsx28 => 14x7x8 = 14x28x2
class QNNsubModel(nn.Module):
    """Data re-uploading circuit on ``n_qubits`` wires mapping n_qubits**2 features to Pauli expectations.

    With D = n_depth_per_block, the circuit is:
      1. D variational layers (RX on every wire, RY on every wire, CZ ring)
      2. encoding block: n_half_qubits features per wire, then a CZ ring
      3. D variational layers
      4. encoding block: the remaining n_latter_half_qubits features per wire, then a CZ ring
      5. D variational layers, with no CZ ring after the last one
    Then <X_i> and <Z_i> are measured for i < n_class // 2, giving 2 * (n_class // 2) outputs.

    The trainable angles are supplied per call as ``phi`` (not stored on the module) and are
    consumed in the order above: (6 * D + 2 * n_qubits) * n_qubits of them in total.
    """

    def __init__(self, n_depth_per_block):
        super().__init__()
        self.n_depth_per_block = n_depth_per_block

    @staticmethod
    def _cz_ring(qdev):
        """Apply CZ between every wire i and (i + 1) mod n_qubits."""
        for i in range(n_qubits):
            qdev.cz(wires=[i, (i + 1) % n_qubits])

    @classmethod
    def _variational_layers(cls, qdev, phi, offset, n_layers, last_cz=True):
        """Apply ``n_layers`` of (RX all wires, RY all wires, CZ ring); return the next unused phi index.

        When ``last_cz`` is False, the CZ ring after the final layer is omitted.
        """
        for k in range(n_layers):
            base = offset + 2 * k * n_qubits
            for i in range(n_qubits):
                tq.functional.rx(qdev, wires=i, params=phi[base + i])
            for i in range(n_qubits):
                tq.functional.ry(qdev, wires=i, params=phi[base + n_qubits + i])
            if last_cz or k < n_layers - 1:
                cls._cz_ring(qdev)
        return offset + 2 * n_layers * n_qubits

    @classmethod
    def _encoding_block(cls, qdev, x, phi, phi_offset, x_offset, n_per_wire):
        """Upload ``n_per_wire`` features onto each wire, then apply a CZ ring; return the next unused phi index.

        Wire i receives features x[:, x_offset + n_per_wire*i + j] for j < n_per_wire. Each
        feature is sandwiched between two trainable rotations: RY(phi)-RX(x)-RY(phi) for even
        j, RX(phi)-RY(x)-RX(phi) for odd j.
        """
        for i in range(n_qubits):
            for j in range(n_per_wire):
                p = phi_offset + 2 * (n_per_wire * i + j)
                feature = x[:, x_offset + n_per_wire * i + j]
                if j % 2 == 0:
                    outer, inner = tq.functional.ry, tq.functional.rx
                else:
                    outer, inner = tq.functional.rx, tq.functional.ry
                outer(qdev, wires=i, params=phi[p])
                inner(qdev, wires=i, params=feature)
                outer(qdev, wires=i, params=phi[p + 1])
        cls._cz_ring(qdev)
        return phi_offset + 2 * n_per_wire * n_qubits

    def forward(self, x, phi):
        """Run the circuit on a batch.

        Args:
            x: 2-D tensor (bsz, n_features) with n_features >= n_qubits**2; the values are used as rotation angles.
            phi: 1-D tensor of at least (6 * D + 2 * n_qubits) * n_qubits trainable angles.

        Returns:
            Tensor (bsz, 2 * (n_class // 2)): the <X_i> values first, then the <Z_i> values.
        """
        depth = self.n_depth_per_block
        bsz, n_features = x.shape
        if n_features < n_qubits * n_qubits:
            raise ValueError(f"x needs at least {n_qubits * n_qubits} features, got {n_features}")
        n_params = (6 * depth + 2 * n_qubits) * n_qubits
        if phi.numel() < n_params:
            raise ValueError(f"phi needs at least {n_params} angles, got {phi.numel()}")

        qdev = tq.QuantumDevice(n_wires=n_qubits, bsz=bsz, device=x.device, record_op=False)
        offset = self._variational_layers(qdev, phi, 0, depth)
        offset = self._encoding_block(qdev, x, phi, offset, 0, n_half_qubits)
        offset = self._variational_layers(qdev, phi, offset, depth)
        offset = self._encoding_block(qdev, x, phi, offset, n_half_qubits * n_qubits, n_latter_half_qubits)
        self._variational_layers(qdev, phi, offset, depth, last_cz=False)

        observables = ["I" * i + pauli + "I" * (n_qubits - 1 - i)
                       for pauli in ("X", "Z") for i in range(n_class // 2)]
        return torch.stack([calc_exp_val(qdev, obs) for obs in observables], dim=1)

In [None]:
# 14x14 => 7x14x2
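# Patch windows per image axis (1-based inclusive -> 0-based inclusive), n_qubits=8 pixels each:
#   n_qnn=4: 1:8, 7:14 -> 0:7, 6:13
#   n_qnn=9: 1:8, 4:11, 7:14 -> 0:7, 3:10, 6:13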
class QNNModel(nn.Module):
    """Ensemble of ``n_qnn`` QNNsubModel circuits over overlapping image patches.

    A learnable (img_size, img_size) positional bias is added to each input image. Then
    n_qubits x n_qubits patches are cut at top-left corners on a sqrt(n_qnn) x sqrt(n_qnn)
    grid spanning [0, img_size - n_qubits]. Each patch is flattened and fed to its own
    sub-circuit with its own angle vector, and the sub-circuit outputs are averaged.

    Args:
        n_qnn: number of sub-circuits; must be a perfect square.
        n_depth_per_block: variational depth D passed to every QNNsubModel.
        img_size: side length of the (square) input images (default 14).
    """

    def __init__(self, n_qnn, n_depth_per_block, img_size=14):
        super().__init__()
        grid = math.isqrt(n_qnn)
        if grid * grid != n_qnn:
            raise ValueError(f"n_qnn must be a perfect square, got {n_qnn}")
        max_start = img_size - n_qubits
        if max_start < 0:
            raise ValueError(f"img_size={img_size} is smaller than the patch size n_qubits={n_qubits}")
        step = max_start // (grid - 1) if grid > 1 else 0
        if grid > 1 and step == 0:
            raise ValueError(f"a {grid}x{grid} patch grid does not fit in a {img_size}x{img_size} image")

        self.n_qnn = n_qnn
        self.n_depth_per_block = n_depth_per_block
        n_params = (3 * 2 * n_depth_per_block + 2 * n_qubits) * n_qubits
        # Sample on the CPU (same RNG stream as before), then keep the angles on the same
        # device as pos_bias so the model is usable without an explicit .to(device).
        self.params_list = nn.ParameterList(
            [(torch.rand(n_params) * math.pi).to(device) for _ in range(n_qnn)]
        )
        self.pos_bias = nn.Parameter(torch.zeros(img_size, img_size, device=device))
        # ModuleList so the sub-circuits are registered children (they hold no parameters,
        # so state_dict keys are unchanged).
        self.qnn_list = nn.ModuleList([QNNsubModel(n_depth_per_block) for _ in range(n_qnn)])
        # Patch start offsets: [0, 6] for n_qnn=4, [0, 3, 6] for 9, [0, 2, 4, 6] for 16 (img_size=14, n_qubits=8).
        self.qnn_index_list = [k * step for k in range(grid)]

    def forward(self, x):
        """Map images (n_data, img_size, img_size) to averaged sub-circuit outputs (n_data, n_outputs)."""
        n_data = x.shape[0]
        biased = x + self.pos_bias
        patches = torch.stack(
            [biased[:, r:r + n_qubits, c:c + n_qubits].reshape(n_data, n_qubits * n_qubits)
             for r in self.qnn_index_list for c in self.qnn_index_list],
            dim=0,
        )  # (n_qnn, n_data, n_qubits**2)
        # Checkpointing trades recomputation for memory during backward.
        outputs = [checkpoint.checkpoint(qnn, patch, phi, use_reentrant=False)
                   for qnn, patch, phi in zip(self.qnn_list, patches, self.params_list)]
        return torch.stack(outputs, dim=1).mean(dim=1)  # (n_data, n_qnn, n_out) -> (n_data, n_out)

In [None]:
def train(data, label, model, accumulation_steps):
    """Forward + backward one (micro-)batch for gradient accumulation.

    The cross-entropy loss is divided by ``accumulation_steps`` before ``backward()``, so
    gradients accumulate into the parameters' ``.grad``. No optimizer step or zero_grad
    happens here. Prints the unscaled loss and the accuracy.

    Returns:
        (scaled loss as a float, i.e. loss / accumulation_steps; accuracy in [0, 1])
    """
    logits = model(data)  # (bsz, n_class)
    scaled_loss = torch.nn.CrossEntropyLoss()(logits, label) / accumulation_steps
    scaled_loss.backward()
    with torch.no_grad():
        probs = nn.functional.softmax(logits, dim=1)
        n_correct = (probs.argmax(dim=1) == label).sum().item()
    acc = n_correct / len(label)
    print(f"train loss: {scaled_loss.item()*accumulation_steps:.5f} train acc: {acc:.3f}", end='\n')
    return scaled_loss.item(), acc

def valid(data, label, model):
    """Evaluate ``model`` on one batch without tracking gradients.

    Prints the loss and accuracy, and returns (cross-entropy loss as a float, accuracy in [0, 1]).
    """
    with torch.no_grad():
        logits = model(data)
        ce = torch.nn.CrossEntropyLoss()(logits, label)
        hits = nn.functional.softmax(logits, dim=1).argmax(dim=1) == label
        acc = hits.sum().item() / len(label)
    print(f"valid loss: {ce.item():.5f} valid acc: {acc:.3f}", end='\n')
    return ce.item(), acc

In [None]:
# Map pixel intensities in [0, 255] to rotation angles in [0, 2*pi/n_qubits], then
# downsample 28x28 -> 14x14 with 2x2 average pooling (the (N, H, W) tensor is pooled per image).
# The shape guard makes this cell idempotent: re-running it leaves already-processed
# 14x14 data alone instead of rescaling and pooling it a second time.
_pool = torch.nn.AvgPool2d((2, 2), stride=(2, 2))
if data_tr.shape[-2:] == (28, 28):
    data_tr = _pool(data_tr / 255 * 2 * math.pi / n_qubits)  # (N, 28, 28) -> (N, 14, 14)
if data_te.shape[-2:] == (28, 28):
    data_te = _pool(data_te / 255 * 2 * math.pi / n_qubits)  # (N, 28, 28) -> (N, 14, 14)
print(data_tr.shape, data_te.shape)

In [None]:
# Number of saved epoch checkpoints to evaluate, and the constant logit scale used by ConstCoeffLayer.
max_epochs, coeff = 50, 100
data_tr, data_te = (t.to(device) for t in (data_tr, data_te))
label_tr, label_te = (t.to(device) for t in (label_tr, label_te))

In [None]:
def _evaluate(model, inputs, targets):
    """Return (mean cross-entropy, accuracy) of ``model`` on the full (inputs, targets) set.

    Runs under torch.no_grad(), since this is evaluation only and no autograd graph is needed.
    The argmax is taken on the logits directly: softmax is monotonic, so normalising first
    does not change the predicted class.
    """
    with torch.no_grad():
        logits = model(inputs)
        loss = nn.functional.cross_entropy(logits, targets).item()
        acc = (logits.argmax(dim=1) == targets).sum().item() / len(targets)
    return loss, acc


# NOTE(review): the reason for pausing after every full-dataset pass is not documented
# (presumably to let the GPU cool down / release memory); confirm before removing.
pause_s = 10

for n_qnn in [4, 9, 16]:
    for n_depth_per_block in [50, 100, 150, 200]:
        model = torch.nn.Sequential(QNNModel(n_qnn, n_depth_per_block), ConstCoeffLayer(coeff)).to(device)
        # Checkpoints of the depth-50 runs live in a directory without the depth suffix.
        depth_suffix = '' if n_depth_per_block == 50 else str(n_depth_per_block)
        dir_name = f"tmp_8qubits_{n_qnn}qnn{depth_suffix}"
        # NOTE: "c100" is part of the saved file names and is not derived from `coeff`.
        prefix_name = f"mnist_{n_qnn}qnn{n_depth_per_block}_c100_8qubits_ensembling_cos"
        path_prefix = f"{dir_name}/{prefix_name}"

        history = []
        # epoch_i == -1 evaluates the initial weights; epoch_i >= 0 the weights saved after that epoch.
        for epoch_i in tqdm(range(-1, max_epochs)):
            tag = "init" if epoch_i < 0 else f"epoch{epoch_i}"
            model.load_state_dict(torch.load(f"{path_prefix}_{tag}.pt", weights_only=True))
            train_loss, train_acc = _evaluate(model, data_tr, label_tr)
            time.sleep(pause_s)
            test_loss, test_acc = _evaluate(model, data_te, label_te)
            time.sleep(pause_s)
            history.append({'epochs': epoch_i, 'train_loss': train_loss, 'test_loss': test_loss,
                            'train_acc': train_acc, 'test_acc': test_acc})
        pd.DataFrame(history).to_csv(f"{path_prefix}_loss_acc.csv", index=False)