In [11]:
import numpy as np
import pennylane as qml
import torch
import sklearn.datasets

In [12]:
import copy

In [13]:
torch.set_default_dtype(torch.float64)

In [14]:
n_qubits = 2
dev = qml.device("default.qubit", wires=n_qubits, shots=1000)

@qml.qnode(dev, diff_method="spsa", interface="torch")
def qnode(inputs, weights):
    qml.templates.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.templates.StronglyEntanglingLayers(weights, wires=range(n_qubits))
    return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))

weight_shapes = {"weights": (3, n_qubits, 3)}



class QNodeFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, qlayer):
        ctx.save_for_backward(input)
        ctx.qlayer = qlayer
        return qlayer(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        qlayer = ctx.qlayer
        epsilon = 0.01
        
        # Gradient w.r.t. input
        perturbation_input = (torch.rand_like(input) * 2 - 1) * epsilon
        positive_input = input + perturbation_input
        negative_input = input - perturbation_input
        
        loss_positive = qlayer(positive_input).sum()
        loss_negative = qlayer(negative_input).sum()
        
        gradient_input = (loss_positive - loss_negative) / (2 * epsilon) * perturbation_input
        gradient_input *= grad_output  # Incorporate grad_output due to chain rule

        # Gradient w.r.t. qlayer's parameters
        gradients_weights = []
        for p in qlayer.parameters():
            perturbation_weight = (torch.rand_like(p) * 2 - 1) * epsilon
            p.data += perturbation_weight

            loss_positive = qlayer(input).sum()
            loss_negative = qlayer(input).sum()

            gradient_weight = (loss_positive - loss_negative) / (2 * epsilon) * perturbation_weight / epsilon
            gradients_weights.append(gradient_weight * grad_output.sum())  # Weighting by grad_output

            p.data -= perturbation_weight  # Reset to original value

        # Update gradients for qlayer's parameters
        for p, grad in zip(qlayer.parameters(), gradients_weights):
            if p.grad is None:
                p.grad = grad.detach()
            else:
                p.grad += grad.detach()

        return gradient_input, None

# Wrapper around the custom autograd function
class CustomQLayer(torch.nn.Module):
    def __init__(self, qlayer):
        super(CustomQLayer, self).__init__()
        self.qlayer = qlayer

    def forward(self, x):
        return QNodeFunction.apply(x, self.qlayer)


qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)
clayer1 = torch.nn.Linear(2, 2)
clayer2 = torch.nn.Linear(2, 2)
softmax = torch.nn.Softmax(dim=1)

model = torch.nn.Sequential(clayer1, qlayer, clayer2, softmax)

model[1] = CustomQLayer(qlayer)

samples = 100
x, y = sklearn.datasets.make_moons(samples)
y_hot = np.zeros((samples, 2))
y_hot[np.arange(samples), y] = 1

X = torch.tensor(x, dtype=torch.float64)
Y = torch.tensor(y_hot, dtype=torch.float64)


opt = torch.optim.SGD(model.parameters(), lr=0.5)

loss = torch.nn.L1Loss()






Average loss over epoch 1: 0.5108
Average loss over epoch 2: 0.5004
Average loss over epoch 3: 0.5067
Average loss over epoch 4: 0.5028
Average loss over epoch 5: 0.5107
Average loss over epoch 6: 0.5039
Average loss over epoch 7: 0.5029
Average loss over epoch 8: 0.5017


In [22]:
for i in qlayer.parameters():
    print(i.shape)

torch.Size([3, 2, 3])


In [19]:
epochs = 8
batch_size = 5
batches = samples // batch_size

dataset = torch.utils.data.TensorDataset(X, Y)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                          shuffle=True, drop_last=True)


for epoch in range(epochs):

    running_loss = 0

    for x, y in data_loader:
        opt.zero_grad()

        loss_evaluated = loss(model(x), y)
        loss_evaluated.backward()

        opt.step()

        running_loss += loss_evaluated

    avg_loss = running_loss / batches
    print("Average loss over epoch {}: {:.4f}".format(epoch + 1, avg_loss))

Average loss over epoch 1: 0.4995
Average loss over epoch 2: 0.4992
Average loss over epoch 3: 0.5024
Average loss over epoch 4: 0.5041
Average loss over epoch 5: 0.5041
Average loss over epoch 6: 0.4989
Average loss over epoch 7: 0.4990
Average loss over epoch 8: 0.5018
