In [1]:
import gymnasium as gym
import pennylane as qml
import numpy as np
import torch
import random
import matplotlib.pyplot as plt
from collections import deque, namedtuple


# Environment
# circuit() angle-encodes state component i on wire i, so the qubit count
# tracks the length of the state vector it is fed.
N_QUBITS = 4
env = gym.make("CartPole-v1")

# Simulator backend used by the `circuit` QNode below.
dev = qml.device("default.qubit", wires=N_QUBITS)

In [4]:
# One (state, action, next_state, reward) experience record kept in the replay buffer.
Transition = namedtuple("Transition", ["state", "action", "next_state", "reward"])

class ReplayMemory:
    """Fixed-capacity FIFO buffer of `Transition` records for experience replay."""

    def __init__(self, capacity):
        # A bounded deque drops its oldest element automatically once `capacity` is reached.
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        return len(self.memory)

    def push(self, *args):
        """Saves a transition."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions, drawn uniformly without replacement.

        Raises ValueError (from `random.sample`) if more items are requested than are stored.
        """
        return random.sample(self.memory, k=batch_size)

In [6]:
# Classical Critic -> Classical NN
class V(torch.nn.Module):
    """Classical critic: a one-hidden-layer ReLU MLP mapping a state vector to a scalar value.

    Args:
        in_features: length of the input state vector (default 4, the previously
            hard-coded width).
        hidden_features: width of the hidden ReLU layer (default 256).
    """

    def __init__(self, in_features=4, hidden_features=256):
        super().__init__()
        # Attribute names kept as `fc1` / `fc_v` so existing state_dicts still load.
        self.fc1 = torch.nn.Linear(in_features, hidden_features)
        self.fc_v = torch.nn.Linear(hidden_features, 1)

    def forward(self, x):
        """Return the value estimate, shape (..., 1), for `x` of shape (..., in_features)."""
        x = torch.relu(self.fc1(x))
        return self.fc_v(x)

In [7]:
# Parameterized Rotation Layer
def layer(W):
    """Queue an RX, RY, RZ rotation (in that order) on each of the N_QUBITS wires.

    W is indexed as W[wire, k]; column k holds the angle for the k-th gate of
    (RX, RY, RZ). Wires are processed in ascending order.
    """
    rotations = (qml.RX, qml.RY, qml.RZ)
    for wire in range(N_QUBITS):
        for k, gate in enumerate(rotations):
            gate(W[wire, k], wires=wire)


# Quantum Circuit
@qml.qnode(dev, interface="torch")
def circuit(W, s):
    """Variational quantum circuit on N_QUBITS wires.

    W: layer parameters; W[d] is the per-layer angle block passed to `layer`.
       Entries W[0] .. W[3] are used.
    s: state vector; component i is angle-encoded on wire i as RY(pi * s[i]).
       NOTE(review): raw, unbounded state components wrap around the rotation
       period; consider normalizing upstream — confirm intent.

    Returns a list with the <PauliY> expectation value of each wire.
    """
    # Input encoding
    for wire in range(N_QUBITS):
        qml.RY(np.pi * s[wire], wires=wire)

    # Three rotation layers, each followed by a nearest-neighbour CNOT ladder.
    for depth in range(3):
        layer(W[depth])
        for ctrl in range(N_QUBITS - 1):
            qml.CNOT(wires=[ctrl, ctrl + 1])

    # Final rotation layer plus two fixed long-range entanglers (assume >= 4 wires).
    layer(W[3])
    qml.CNOT(wires=[0, 2])
    qml.CNOT(wires=[1, 3])
    return [qml.expval(qml.PauliY(wire)) for wire in range(N_QUBITS)]

In [None]:
MEMORY_CAPACITY = 10000
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
N_LAYERS = 4          # circuit() reads W[0] .. W[3]
ACTOR_LR = 1e-3       # Adam learning rate for the circuit parameters W
CRITIC_LR = 1e-5      # Adam learning rate for the classical critic v
SEED = 42

# Seed the Python, NumPy and torch global RNGs so Restart & Run All is reproducible.
# (gymnasium envs/spaces keep their own RNGs; seed them via env.reset(seed=...) /
# env.action_space.seed(...) where the episodes are run.)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

D = ReplayMemory(MEMORY_CAPACITY)
# Actor parameters, shape (N_LAYERS, N_QUBITS, 3). The previous
# `torch.DoubleTensor(...)` returned *uninitialized* memory (arbitrary values,
# possibly NaN/inf), and `torch.autograd.Variable` is deprecated. Small random
# angles start each rotation layer close to the identity.
W = (0.01 * torch.randn(N_LAYERS, N_QUBITS, 3, dtype=torch.float64)).requires_grad_()
v = V()
circuit_pi = circuit
optimizer1 = torch.optim.Adam([W], lr=ACTOR_LR)
optimizer2 = torch.optim.Adam(v.parameters(), lr=CRITIC_LR)

In [None]:
# Number of calls to select_action so far; drives the epsilon decay schedule.
# It was previously read before ever being assigned (NameError on the first call).
# The name is kept as-is in case other cells read it.
stops_done = 0


def select_action(state):
    """Epsilon-greedy action selection.

    The exploration rate decays exponentially from EPS_START towards EPS_END with
    time constant EPS_DECAY calls. With probability eps a random action is drawn
    from `env.action_space`; otherwise the argmax of `v(state)` along dim 1 is taken.

    Args:
        state: input tensor for `v`; presumably a (1, obs_dim) batch — TODO confirm
            against the training loop.

    Returns:
        LongTensor of shape (1, 1) holding the chosen action index.
    """
    global stops_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * np.exp(-1. * stops_done / EPS_DECAY)
    stops_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # `.max(1)` returns a (values, indices) namedtuple; the old
            # `.max(1)[1].indices` called `.indices` on the index tensor -> AttributeError.
            # NOTE(review): `v` is the critic V whose head has a single output, so
            # this argmax is always 0 and the greedy branch never picks another
            # action. It presumably should query the actor (`circuit_pi(W, ...)`)
            # — confirm which model and output-to-action mapping is intended.
            return v(state).max(1).indices.view(1, 1)
    return torch.tensor([[env.action_space.sample()]], dtype=torch.long)