In [27]:
import pennylane as qml
from pennylane import numpy as np
import embedding
import data
import torch

In [28]:
# Number of qubits shared by the trainable layers and the simulator.
n_qubits = 4
# Size the device from n_qubits rather than repeating the literal, so the
# circuit width and the simulator width cannot drift apart.
dev = qml.device("default.qubit", wires=n_qubits)

In [35]:
def trainable_embedding(parameters):
    """Apply one trainable layer of single- and two-qubit rotations.

    Gates are queued in four consecutive groups, each consuming the next
    slice of ``parameters``:

    1. ``RX`` on every wire                    (``n_qubits`` angles)
    2. ``IsingXX`` on each neighbouring pair   (``n_qubits - 1`` angles)
    3. ``RY`` on every wire                    (``n_qubits`` angles)
    4. ``IsingYY`` on each neighbouring pair   (``n_qubits - 1`` angles)

    Args:
        parameters: Indexable sequence holding at least ``4 * n_qubits - 2``
            rotation angles, read in the order listed above.
    """
    # Start index of each gate group inside `parameters` (RX starts at 0).
    xx_start = n_qubits
    ry_start = 2 * n_qubits - 1
    yy_start = 3 * n_qubits - 1
    n_pairs = n_qubits - 1

    for wire in range(n_qubits):
        qml.RX(parameters[wire], wires=wire)
    for left in range(n_pairs):
        qml.IsingXX(parameters[xx_start + left], wires=[left, left + 1])
    for wire in range(n_qubits):
        qml.RY(parameters[ry_start + wire], wires=wire)
    for left in range(n_pairs):
        qml.IsingYY(parameters[yy_start + left], wires=[left, left + 1])
    

@qml.qnode(dev)
def qnode(inputs, params1, params2, params3):
    """Circuit comparing two 4-feature samples packed into ``inputs``.

    The first half of the circuit alternates ``trainable_embedding`` with
    ``embedding.Noisy_Four_QuantumEmbedding1`` on ``inputs[0:4]`` for the
    three weight vectors in order. The second half applies
    ``embedding.Noisy_Four_QuantumEmbedding1_inverse`` on ``inputs[4:8]``
    followed by the adjoint of ``trainable_embedding``, walking the weight
    vectors in reverse order.

    Args:
        inputs: Indexable holding the first sample in positions 0-3 and the
            second sample in positions 4-7.
        params1, params2, params3: Angle vectors for the three trainable
            layers.

    Returns:
        The measurement probabilities over wires 0-3.
    """
    # NOTE(review): slicing on the first axis assumes `inputs` is a single
    # 8-element sample; confirm how the installed TorchLayer passes batches.
    first_sample, second_sample = inputs[0:4], inputs[4:8]
    layer_weights = (params1, params2, params3)

    for weights in layer_weights:
        trainable_embedding(parameters=weights)
        embedding.Noisy_Four_QuantumEmbedding1(first_sample)

    # Presumably the `_inverse` helper undoes the forward embedding, which
    # would make this half the adjoint of the first -- verify in `embedding`.
    for weights in reversed(layer_weights):
        embedding.Noisy_Four_QuantumEmbedding1_inverse(second_sample)
        qml.adjoint(trainable_embedding)(weights)

    return qml.probs(wires=range(4))

# Each trainable_embedding call reads n_qubits RX + (n_qubits - 1) IsingXX
# + n_qubits RY + (n_qubits - 1) IsingYY angles, i.e. 4 * n_qubits - 2
# (14 for 4 qubits). Deriving the size keeps it in step with n_qubits.
weight_shapes = {
    name: 4 * n_qubits - 2 for name in ("params1", "params2", "params3")
}

In [30]:
class Model_quantum(torch.nn.Module):
    """Torch module that exposes ``qnode`` as a trainable pairwise model."""

    def __init__(self):
        super().__init__()
        # TorchLayer registers the weights described by `weight_shapes`
        # as parameters of this module.
        self.qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)

    def forward(self, x1, x2):
        """Score each pair of samples.

        Args:
            x1: Tensor whose rows are the first sample of every pair.
            x2: Tensor whose rows are the second sample of every pair.

        Returns:
            Column 0 of the quantum layer's output for every row, i.e. the
            first entry of the probability vector returned by ``qnode``.
        """
        # Place both samples side by side along the feature axis.
        paired = torch.cat((x1, x2), dim=1)
        layer_out = self.qlayer(paired)
        return layer_out[:, 0]


In [31]:
# --- Data selection ---
# Labels kept from the dataset and the feature-reduction scheme passed to
# the data loader.
classes = [0, 1]
feature_reduction = "PCA4"

# --- Training ---
# Pairs drawn per optimisation step, and the intended number of steps.
batch_size = 25
iterations = 1000

In [32]:
# Load the 'mnist' train/test split for the configured classes and
# feature reduction.
X_train, X_test, Y_train, Y_test = data.data_load_and_process(
    "mnist",
    feature_reduction=feature_reduction,
    classes=classes,
)
# Build a fresh batch of random sample pairs for the hybrid model.
def new_data(batch_size, X, Y):
    """Draw random pairs of samples and label each pair by class agreement.

    Args:
        batch_size: Number of pairs to draw.
        X: Indexable collection of feature vectors.
        Y: Indexable collection of labels aligned with ``X``.

    Returns:
        A tuple ``(X1, X2, Y_pair)``: two float32 tensors holding the first
        and second sample of every pair, and an integer tensor that is 1
        where both samples share a label and 0 otherwise.
    """
    X1_new, X2_new, Y_new = [], [], []
    for _ in range(batch_size):
        # Draw n then m, one pair at a time, so a seeded RNG is consumed in
        # a fixed, predictable order.
        n, m = np.random.randint(len(X)), np.random.randint(len(X))
        X1_new.append(X[n])
        X2_new.append(X[m])
        Y_new.append(1 if Y[n] == Y[m] else 0)

    # Collapse each list into a single array before handing it to torch:
    # torch.tensor on a list of ndarrays converts element by element, which
    # is slow and emits a UserWarning.
    X1_new = torch.tensor(np.array(X1_new)).to(torch.float32)
    X2_new = torch.tensor(np.array(X2_new)).to(torch.float32)
    Y_new = torch.tensor(Y_new)
    return X1_new, X2_new, Y_new

In [33]:
def train_models():
    """Train a fresh ``Model_quantum`` with SGD on randomly drawn pairs.

    Runs ``iterations`` optimisation steps. Each step draws ``batch_size``
    pairs from ``X_train`` / ``Y_train`` via ``new_data`` and minimises the
    mean-squared error between the model output and the pair label. Every
    10th step the loss and every trainable parameter (with its registered
    name) are printed.

    Returns:
        None.
    """
    model = Model_quantum()
    model.train()
    loss_fn = torch.nn.MSELoss()
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    # Use the notebook-level `iterations` setting instead of a duplicated
    # literal, so changing the setting actually changes the run length.
    for it in range(iterations):
        X1_batch, X2_batch, Y_batch = new_data(batch_size, X_train, Y_train)
        pred = model(X1_batch, X2_batch)
        loss = loss_fn(pred.to(torch.float32), Y_batch.to(torch.float32))

        opt.zero_grad()
        loss.backward()
        opt.step()

        if it % 10 == 0:
            print(f"Iterations: {it} Loss: {loss.item()}")
            # named_parameters() yields the registered name of each
            # parameter; `Parameter.name` is not populated and prints None.
            for name, p in model.named_parameters():
                if p.requires_grad:
                    print(name, p.data)

In [36]:
# Run the training loop; the loss and the trainable parameters are printed
# every 10 iterations.
train_models()

Iterations: 0 Loss: 0.4451241195201874
None tensor([6.0954, 5.9903, 2.0830, 4.5080, 2.5952, 3.0116, 4.2362, 3.8846, 4.5247,
        5.2666, 5.3386, 3.7550, 0.3044, 3.3324])
None tensor([2.5892, 0.4239, 3.3226, 1.3400, 5.5752, 4.4081, 4.6339, 5.3736, 4.5695,
        5.7000, 0.2401, 4.1922, 1.2487, 1.2864])
None tensor([4.5073, 1.5615, 4.2716, 1.3607, 2.8627, 4.7606, 4.6614, 4.2588, 0.3656,
        1.1730, 6.0872, 1.9473, 4.2389, 4.7746])
Iterations: 10 Loss: 0.5022210478782654
None tensor([6.0955, 5.9905, 2.0827, 4.5080, 2.5952, 3.0116, 4.2353, 3.8850, 4.5239,
        5.2673, 5.3395, 3.7553, 0.3042, 3.3326])
None tensor([2.5892, 0.4234, 3.3231, 1.3406, 5.5756, 4.4083, 4.6342, 5.3733, 4.5699,
        5.6998, 0.2400, 4.1917, 1.2474, 1.2862])
None tensor([4.5069, 1.5626, 4.2727, 1.3607, 2.8636, 4.7607, 4.6604, 4.2586, 0.3649,
        1.1721, 6.0869, 1.9471, 4.2392, 4.7739])
Iterations: 20 Loss: 0.38370054960250854
None tensor([6.0956, 5.9904, 2.0830, 4.5078, 2.5947, 3.0110, 4.2362, 3.8854,