In [1]:
import pennylane as qml
#from pennylane import numpy as np
import torch

head = ["A", "B", "X", "Y", "C"]
relations = [["A","B","X","Y", "C"], ["X","Y"], ["A","X"], ["A","Y"], ["B","X"], ["B", "Y"], ["C"]]
fds = [[["A", "B"], ["X", "Y", "C"]], 
           [["B", "X", "Y"], ["A", "C"]], 
           [["X", "C"], ["A", "B", "Y"]], 
           [["A", "X", "Y"], ["B", "C"]], 
           [["A", "C"], ["B", "X", "Y"]], 
           [["Y", "C"], ["A", "B", "X"]]]
rmax = 1

n_qubits = len(head)*rmax
print(n_qubits)

variables_to_marginals = {}
for i in range(len(head)):
    variables_to_marginals[head[i]] = [j for j in range(i * rmax, (i+1) * rmax)]

5


In [2]:
def entropy(probabilities):
    probabilities = torch.clamp(probabilities, 1e-10, 1.0)
    entropy = -torch.sum(probabilities * torch.log2(probabilities))
    return entropy

In [3]:
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface="torch")
def circuit(rotations, measured_qubits):
    for i in range(n_qubits):
        qml.RY(rotations[i][0], wires=i)
        qml.RZ(rotations[i][1], wires=i)
        qml.CNOT(wires=[i, (i+1) % n_qubits])
    return qml.probs(wires=measured_qubits)

In [4]:
def cost_less_one(x):
    return torch.where(x > 1, torch.square(x - 1), 0.0)

In [5]:
indices_head = [marg for var in head for marg in variables_to_marginals[var]]
indices_relations = [[marg for var in relation for marg in variables_to_marginals[var]] for relation in relations]
indices1 = [[marg for var in fd[0] + fd[1] for marg in variables_to_marginals[var]] for fd in fds]
indices2 = [[marg for var in fd[1] for marg in variables_to_marginals[var]] for fd in fds]

def cost_function(weights):
    entropy_sum = torch.tensor(0.0)
    hu0 = circuit(weights, indices_head)
    entropy_sum += -entropy(hu0)
    
    entropies = torch.stack([entropy(circuit(weights, indices)) for indices in indices_relations])
    entropy_sum += torch.sum(cost_less_one(entropies))
    
    entropies1 = torch.stack([entropy(circuit(weights, indices)) for indices in indices1])
    entropies2 = torch.stack([entropy(circuit(weights, indices)) for indices in indices2])
    entropy_sum += torch.sum(entropies1 - entropies2)
    
    return entropy_sum

In [6]:
steps = 100

weights = torch.rand(n_qubits, 2, requires_grad=True)
opt = torch.optim.AdamW([weights], lr=0.1)

for i in range(steps):
    opt.zero_grad()
    loss = cost_function(weights)
    loss.backward()
    opt.step()
    print("Step {}: Loss: {}".format(i + 1, loss.item()))

Step 1: Loss: 3.836354970932007
Step 2: Loss: 1.8560353517532349
Step 3: Loss: 0.7422341704368591
Step 4: Loss: 0.3020022213459015
Step 5: Loss: 0.1459791213274002
Step 6: Loss: 0.0629500076174736
Step 7: Loss: -0.004909770097583532
Step 8: Loss: -0.06699071079492569
Step 9: Loss: -0.12144710123538971
Step 10: Loss: -0.1551903337240219
Step 11: Loss: -0.18495559692382812
Step 12: Loss: -0.20584332942962646
Step 13: Loss: -0.21061336994171143
Step 14: Loss: -0.216363325715065
Step 15: Loss: -0.22685085237026215
Step 16: Loss: -0.24367322027683258
Step 17: Loss: -0.26818186044692993
Step 18: Loss: -0.3013626039028168
Step 19: Loss: -0.34332647919654846
Step 20: Loss: -0.39275839924812317
Step 21: Loss: -0.44637376070022583
Step 22: Loss: -0.49783581495285034
Step 23: Loss: -0.5468140244483948
Step 24: Loss: -0.5973817110061646
Step 25: Loss: -0.6488492488861084
Step 26: Loss: -0.7030434608459473
Step 27: Loss: -0.760107159614563
Step 28: Loss: -0.8168932795524597
Step 29: Loss: -0.868674

In [7]:
indices = [marg for var in head for marg in variables_to_marginals[var]]
hu0 = circuit(weights, indices)
total = entropy(hu0).item()
print("HU0: ", total)

HU0:  1.0000444527595604


In [8]:
for fd in fds:
    indices1 = [marg for var in fd[0] + fd[1] for marg in variables_to_marginals[var]]
    indices2 = [marg for var in fd[1] for marg in variables_to_marginals[var]]
    marginal1 = circuit(weights, indices1)
    marginal2 = circuit(weights, indices2)
    entr1 = entropy(marginal1)
    entr2 = entropy(marginal2)
    print("fd", entr1 - entr2)
    
for relation in relations:
    indices = [marg for var in relation for marg in variables_to_marginals[var]]
    marginal = circuit(weights, indices)
    entr = entropy(marginal)
    print("rel ", entr)

fd tensor(1.3193e-05, dtype=torch.float64, grad_fn=<SubBackward0>)
fd tensor(4.1908e-05, dtype=torch.float64, grad_fn=<SubBackward0>)
fd tensor(1.8263e-06, dtype=torch.float64, grad_fn=<SubBackward0>)
fd tensor(6.7317e-06, dtype=torch.float64, grad_fn=<SubBackward0>)
fd tensor(1.3846e-05, dtype=torch.float64, grad_fn=<SubBackward0>)
fd tensor(2.1477e-06, dtype=torch.float64, grad_fn=<SubBackward0>)
rel  tensor(1.0000, dtype=torch.float64, grad_fn=<NegBackward0>)
rel  tensor(1.0000, dtype=torch.float64, grad_fn=<NegBackward0>)
rel  tensor(1.0000, dtype=torch.float64, grad_fn=<NegBackward0>)
rel  tensor(1.0000, dtype=torch.float64, grad_fn=<NegBackward0>)
rel  tensor(1.0000, dtype=torch.float64, grad_fn=<NegBackward0>)
rel  tensor(1.0000, dtype=torch.float64, grad_fn=<NegBackward0>)
rel  tensor(1.0000, dtype=torch.float64, grad_fn=<NegBackward0>)


In [9]:
print(weights)

tensor([[ 8.0887e-04,  5.1333e-01],
        [-1.5722e+00,  4.2235e-02],
        [ 1.2741e-03, -1.0305e-02],
        [-2.0693e-03, -3.0259e-02],
        [ 1.3217e-03, -2.6881e-01]], requires_grad=True)
