In [1]:
!pip install qutip
!pip install 'shimmy>=0.2.1'
!pip install stable_baselines3

Collecting qutip
  Downloading qutip-4.7.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.3/16.3 MB[0m [31m37.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: qutip
Successfully installed qutip-4.7.3
Collecting shimmy>=0.2.1
  Downloading Shimmy-1.3.0-py3-none-any.whl (37 kB)
Collecting gymnasium>=0.27.0 (from shimmy>=0.2.1)
  Downloading gymnasium-0.29.1-py3-none-any.whl (953 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.9/953.9 kB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
Collecting farama-notifications>=0.0.1 (from gymnasium>=0.27.0->shimmy>=0.2.1)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium, shimmy
Successfully installed farama-notifications-0.0.4 gymnasium-0.29.1 shimmy-1.3.0
Collecting stable_baselines3
  Downloading stable_baselines3-2.1.0-py3-none-any.whl (178 kB

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
np.bool = np.bool_
from qutip import *
from IPython.display import HTML
from matplotlib import animation
from base64 import b64encode
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy



In [None]:
N = 35
w = 1 * 2 * np.pi              # oscillator frequency
tlist = np.linspace(0, 4, 101) # periods
Tau = 0.1
r = 0.5

# operators
a = destroy(N)
n = num(N)
x = (a + a.dag())/np.sqrt(2)
p = -1j * (a - a.dag())/np.sqrt(2)
S = displace(1, r)
squeezing_operator = x * p + p * x

# the quantum harmonic oscillator Hamiltonian
H = w * a.dag() * a + w/2 * (S.dag() * squeezing_operator * S + qeye(1))

c_ops = []
e_ops = [np.sqrt(Tau)*x**2]
# uncomment to see how things change when disspation is included
#c_ops = [np.sqrt(0.25) * a]

In [None]:
def plot_expect_with_variance(N, op_list, op_title, states):
    """
    Plot the expectation value of an operator (list of operators)
    with an envelope that describes the operators variance.
    """

    fig, axes = plt.subplots(1, len(op_list), figsize=(14,3))

    for idx, op in enumerate(op_list):

        e_op = expect(op, states)
        v_op = variance(op, states)

        axes[idx].fill_between(tlist, e_op - np.sqrt(v_op), e_op + np.sqrt(v_op), color="green", alpha=0.5);
        axes[idx].plot(tlist, e_op, label="expectation")
        axes[idx].set_xlabel('Time')
        axes[idx].set_title(op_title[idx])

    return fig, axes

def display_embedded_video(filename):
    video = open(filename, "rb").read()
    video_encoded = b64encode(video).decode("ascii")
    video_tag = '<video controls alt="test" src="data:video/x-m4v;base64,{0}">'.format(video_encoded)
    return HTML(video_tag)

In [None]:
psi0 = coherent(N, 2.0)
result_expect = mesolve(H, psi0, tlist, c_ops, e_ops)
result = mesolve(H, psi0, tlist, c_ops, [])

In [None]:
plot_expect_with_variance(N, [n, x, p, x**2, H], [r'$n$', r'$x$', r'$p$', r'$x2$', r'$H$'], result.states);

In [None]:
import gym
from gym import spaces

N = 35
w = 1*2*np.pi
Tau = 0.1

a = destroy(N)
n = num(N)
x = (a + a.dag())/np.sqrt(2)
p = -1j * (a - a.dag())/np.sqrt(2)
squeezing_operator = x * p + p * x

c_ops = []
#c_ops = [np.sqrt(0.25) * a]
e_ops = []
#e_ops = [np.sqrt(Tau)*x**2]

class QuantumHarmonicOscillatorEnv(gym.Env):
    def __init__(self):
        super(QuantumHarmonicOscillatorEnv, self).__init__()
        self.observation_space = spaces.Box(low=-50, high=50, shape=(1,), dtype=np.float32)
        self.action_space = spaces.Box(low=-5, high=5, shape=(1,), dtype=np.float32)

        # Define initial quantum state and time
        self.psi = coherent(N,2.0)
        self.time = 0

    def step(self, action):
        r = action[0]
        S = displace(1, r)

        H = w * a.dag() * a + w/2 * (S.dag() * squeezing_operator * S + qeye(1))
        tlist = np.linspace(self.time,self.time+1,101)

        result = mesolve(H, self.psi, tlist, c_ops)
        mean_current = 0
        for i in range(-4,-1):
          mean_current = mean_current + expect(H,result.states[i])/4
        reward = -1*abs(mean_current)
        observation = np.array([mean_current], dtype=np.float32)

        if mean_current<1:
          done = True
        else:
          #print("Action: ", action)
          #print("Mean current: ", mean_current)
          done = False

        self.psi = result.states[-1]
        time = tlist[-1]
        # Return the next state, reward, and whether the episode is done
        return observation, reward, done, {}

    def reset(self):
        # Reset the environment to the initial state and time
        self.psi = coherent(N,2.0)
        self.time = 0
        obs = np.array(expect(H, self.psi), dtype=np.float32)
        obs = np.mean(obs)  # Take the mean to reduce it to a scalar value
        return np.array([obs], dtype=np.float32)  # Wrap it in a one-dimensional array

In [None]:
env = QuantumHarmonicOscillatorEnv()
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10)  # You can adjust the number of training steps
model.save("ppo_quantum_harmonic_oscillator")
mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=2)
print(f"Mean reward: {mean_reward:.2f}")

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Action:  [0.2860758]
Mean current:  21.205481344716194
Action:  [-0.6944237]
Mean current:  21.205200916784236
Action:  [0.62947667]
Mean current:  21.204907552457485
Action:  [0.5707612]
Mean current:  21.204619222313287
Action:  [0.89259464]
Mean current:  21.20434048239442
Action:  [-1.0115054]
Mean current:  21.204033792443656
Action:  [-0.6743158]
Mean current:  21.20376098453864
Action:  [-0.6239888]
Mean current:  21.20349465175215
Action:  [0.36826047]
Mean current:  21.20320176218889
Action:  [0.17312]
Mean current:  21.20291482229562
Action:  [-0.43720528]
Mean current:  21.202622997955608
Action:  [0.7766227]
Mean current:  21.20232554426552
Action:  [-0.7930486]
Mean current:  21.202039625852734
Action:  [-0.10302491]
Mean current:  21.201769991307152
Action:  [-0.16321123]
Mean current:  21.20147882009578
Action:  [-0.7914489]
Mean current:  21.201189960896734
Action:  [-0.834568



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Action:  [0.08089218]
Mean current:  19.94900966846799
Action:  [0.08089213]
Mean current:  19.94851393096777
Action:  [0.08089209]
Mean current:  19.94801150957191
Action:  [0.08089206]
Mean current:  19.947460548956123
Action:  [0.08089201]
Mean current:  19.947072506421577
Action:  [0.08089198]
Mean current:  19.94652539509721
Action:  [0.08089194]
Mean current:  19.94608466491649
Action:  [0.0808919]
Mean current:  19.945539825726755
Action:  [0.08089186]
Mean current:  19.945120033490415
Action:  [0.08089183]
Mean current:  19.94467154718203
Action:  [0.08089179]
Mean current:  19.944165393499024
Action:  [0.08089175]
Mean current:  19.943801395902504
Action:  [0.08089172]
Mean current:  19.943315910210817
Action:  [0.08089168]
Mean current:  19.942758009712943
Action:  [0.08089163]
Mean current:  19.942295574424
Action:  [0.0808916]
Mean current:  19.941733008469967
Action:  [0.08089156]
Mean current:  19.9412189240

In [None]:
from qutip.ipynbtools import version_table
version_table()