In [1]:
device = "cpu"
import os,torch,idx2numpy,pytorch_spiking,time
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import pennylane as qml
from tqdm import tqdm
from pennylane import numpy as np
torch.manual_seed(0)
np.random.seed(0)
batch = 100
os.environ["OMP_NUM_THREADS"] = "16"

In [2]:
data_path = './raw/'  # ta" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu" if torch.cuda.is_available() else "cuda")
train_dataset = torchvision.datasets.FashionMNIST(root=data_path, train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch, shuffle=True, num_workers=0)

test_set = torchvision.datasets.FashionMNIST(root=data_path, train=False, download=True, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch, shuffle=False, num_workers=0)
train_images = []
train_labels = []
for image, label in train_dataset:
    train_images.append(image.numpy())  # Convert image tensor to numpy array
    train_labels.append(label)

# Step 2: Process image data (optional: normalization)
train_images = np.array(train_images)  # Convert list of arrays to numpy array
train_images = train_images.astype(np.float32) / 255.0  # Normalize image data

# Step 3: Optionally, split or further process the data
# For example, if you want to split the data into train_sequences and test_sequences:
split = np.array_split(train_images, [int(len(train_images) * 0.7)])   
train_sequences = split[0]
test_sequences = split[1]
split = np.array_split(train_labels,[int(len(train_labels)*0.7)])
train_labels = split[0]
test_labels = split[1]

# repeat the images for n_steps
n_steps = 10
datalength = 8000
train_sequences = np.tile(train_sequences[:, None], (1, n_steps, 1, 1))[:datalength]  
test_sequences = np.tile(test_sequences[:, None], (1, n_steps, 1, 1))
train_labels = train_labels[:datalength]
test_labels = test_labels

In [3]:
spiking_model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),       
    pytorch_spiking.SpikingActivation(torch.nn.ReLU(), spiking_aware_training=True,dt=0.01),     
    # use average pooling layer to average spiking output over time
    pytorch_spiking.TemporalAvgPool(),      
    torch.nn.Linear(128, 10),        
)
display(spiking_model)

Sequential(
  (0): Linear(in_features=784, out_features=128, bias=True)
  (1): SpikingActivation(
    (activation): ReLU()
  )
  (2): TemporalAvgPool()
  (3): Linear(in_features=128, out_features=10, bias=True)
)

In [4]:
#define quantum parameters
n_qubits = 5         
nqubits=n_qubits
batch_size = 4              
num_epochs = 1           
q_depth = 2            
q_delta = 0.01              
start_time = time.time()    
tensor_length = n_qubits*(n_qubits-1)*q_depth+n_qubits

In [5]:
def H_layer(nqubits):
    for idx in range(nqubits):
        qml.Hadamard(wires=idx)

def RZ_layer(w):
    for idx, element in enumerate(w):   
        qml.RZ(element, wires=idx)      

def RY_layer(w):
    for idx, element in enumerate(w):   
        qml.RY(element, wires=idx)   

def RX_layer(w):
    for idx, element in enumerate(w):   
        qml.RX(element, wires=idx)  
                
def entangling_layer(nqubits,weights):
    p = nqubits
    weights_ = (weight for weight in weights)    
    for i in range(1,nqubits):
        param = next(weights_)
        if i == 1:
            qml.CNOT(wires=[0,i])
            qml.RX(param, wires=i)
            p+=1   
            param = next(weights_)
            qml.CNOT(wires=[0,i])

        if i == 2:
            qml.CNOT(wires=[0,i])
            qml.RY(param, wires=i)
            p+=1                       
            param = next(weights_)
            qml.CNOT(wires=[0,i])

        if i == 3:
            qml.CNOT(wires=[0,i])
            qml.RZ(param, wires=i)
            p+=1   
            qml.CNOT(wires=[0,i])

        if i == 4:
            qml.CNOT(wires=[0,i])
            qml.RX(param, wires=i)
            p+=1   
            qml.CNOT(wires=[0,i])

In [6]:
dev = qml.device("default.qubit", wires=n_qubits)
@qml.qnode(dev, interface="torch")

def quantum_net(q_input_features, q_weights_flat):
    q_weights = q_weights_flat
    RY_layer(q_input_features)
    entangling_layer(nqubits,q_weights)
    exp_vals = [qml.expval(qml.PauliZ(position)) for position in range(1)]      

    return tuple(exp_vals)    
    
class DressedQuantumNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.pre_net = nn.Linear(128, n_qubits)      
        self.q_params = nn.Parameter(q_delta * torch.randn(tensor_length))       
        self.post_net = nn.Linear(1, 10)       

    def forward(self, input_features):
        pre_out = self.pre_net(input_features)       
        q_in = torch.tanh(pre_out) * np.pi / 2.0     
        q_out = torch.Tensor(0, 1)      
        q_out = q_out.to(device)               
        for elem in q_in:
            q_out_elem = quantum_net(elem, self.q_params).float().unsqueeze(0)    
            q_out = torch.cat((q_out, q_out_elem))    

        return self.post_net(q_out)      

print(qml.draw(quantum_net)([0,0,0,0,0],np.random.rand(tensor_length)))

0: ──RY(0.00)─╭●───────────╭●─╭●───────────╭●─╭●───────────╭●─╭●───────────╭●─┤  <Z>
1: ──RY(0.00)─╰X──RX(0.55)─╰X─│────────────│──│────────────│──│────────────│──┤     
2: ──RY(0.00)─────────────────╰X──RY(0.60)─╰X─│────────────│──│────────────│──┤     
3: ──RY(0.00)─────────────────────────────────╰X──RZ(0.42)─╰X─│────────────│──┤     
4: ──RY(0.00)─────────────────────────────────────────────────╰X──RX(0.65)─╰X─┤     


In [7]:
quantum_model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),       
    pytorch_spiking.SpikingActivation(torch.nn.ReLU(), spiking_aware_training=True,dt=0.01),     
    # use average pooling layer to average spiking output over time
    pytorch_spiking.TemporalAvgPool(),      
    torch.nn.Linear(128, 128),        
)
quantum_model.fc = DressedQuantumNet()   
display(quantum_model)

Sequential(
  (0): Linear(in_features=784, out_features=128, bias=True)
  (1): SpikingActivation(
    (activation): ReLU()
  )
  (2): TemporalAvgPool()
  (3): Linear(in_features=128, out_features=128, bias=True)
  (fc): DressedQuantumNet(
    (pre_net): Linear(in_features=128, out_features=5, bias=True)
    (post_net): Linear(in_features=1, out_features=10, bias=True)
  )
)

In [8]:
def train(input_model, q_input_model, train_x, test_x):
    minibatch_size = 32
    optimizer = torch.optim.Adam(input_model.parameters())   
    optimizer = torch.optim.Adam(q_input_model.parameters())

    input_model.train()
    q_input_model.train()

    for j in range(101):
        train_acc = 0
        loss_acc = 0
        #cm = np.zeros((10, 10), dtype=np.int32)
        for i in tqdm(range(train_x.shape[0] // minibatch_size)):     
            input_model.zero_grad()    
            q_input_model.zero_grad()

            batch_in = train_x[i * minibatch_size : (i + 1) * minibatch_size]   
            # flatten images
            batch_in = batch_in.reshape((-1,) + train_x.shape[1:-2] + (784,))   
            batch_label = train_labels[i * minibatch_size : (i + 1) * minibatch_size]
            s_out = input_model(torch.tensor(batch_in))     
            q_out = q_input_model(torch.tensor(batch_in))
            output = s_out*0.4 + q_out*0.6
            logp = torch.nn.functional.log_softmax(output, dim=-1) 
            batch_label = batch_label.astype(np.int64)                   
            logpy = torch.gather(logp, 1, torch.tensor(batch_label).view(-1, 1))  
            loss = -logpy.mean()    
            loss_acc += loss 
 
            
            loss.backward()   
            optimizer.step()
        
            train_acc += torch.mean(
                torch.eq(torch.argmax(output, dim=1), torch.tensor(batch_label)).float()   
            )
        train_acc /= i + 1
        print("Train accuracy (%d): " % j, train_acc.numpy())    
        loss_acc /=i + 1
        print("Loss (%d): " % j, loss_acc.detach().numpy())
        # compute test accuracy
        input_model.eval()                             
        q_input_model.eval()

        test_acc = 0
        for i in range(test_x.shape[0] // minibatch_size):
            batch_in = test_x[i * minibatch_size : (i + 1) * minibatch_size]
            batch_in = batch_in.reshape((-1,) + test_x.shape[1:-2] + (784,))
            batch_label = test_labels[i * minibatch_size : (i + 1) * minibatch_size]
            s_out = input_model(torch.tensor(batch_in))     
            q_out = q_input_model(torch.tensor(batch_in))
            output = s_out*0.4 + q_out*0.6

            test_acc += torch.mean(
                torch.eq(torch.argmax(output, dim=1), torch.tensor(batch_label)).float()
        )
        test_acc /= i + 1
        print("Test accuracy:", test_acc.numpy())

In [9]:
train(spiking_model, quantum_model, train_sequences, test_sequences)

  0%|          | 0/250 [00:00<?, ?it/s]


RuntimeError: The size of tensor a (10) must match the size of tensor b (128) at non-singleton dimension 2