In [2]:
import dynamiqs as dq
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import numpy as np 

In [61]:
na = 20 # Hilbert space dimension
a = dq.destroy(na) # annihilaiton operator
b = dq.destroy(na) # annihilation operator
#psi0 = dq.tensor(dq.fock(na, 0), dq.fock(nb, 0)) # initial state
psi0 = dq.tensor(dq.coherent(na, 0.0), dq.coherent(na, 0.0)) # start in the vaccum
dq.set_precision('double')  # 'simple' by default

# Define constants
pi2 = 2.0 * np.pi
w_a0 = 5.26 * pi2 # GHz to Hz
w_b0 = 7.70 * pi2 # GHz to Hz
psi_a = 0.06
psi_b = 0.29
E_j = 42.76  # GHz to Hz
DE_j = 0.47  # GHz to Hz
w_d = 7.623 * pi2 # GHz to Hz
w_p = 2.891 * pi2  # GHz to Hz
e_d = -3.815e-3 * pi2 # MHz to Hz
e_p = 0.122  # rad

"""
w_a0 = 5.26e9 * 2 * jnp.pi 
w_b0 = 7.70e9 * 2 * jnp.pi
psi_a = 0.06
psi_b = 0.29
E_j = 42.76e9 #where h? 
DE_j = 0.47e9 #where h?
w_d = 7.623e9 * 2 * jnp.pi
w_p = 2.891e9 * 2 * jnp.pi
e_d = -3.815e6 * 2 * jnp.pi
e_p = 0.122
"""
tsave = jnp.linspace(0, 1e-7, 100)

def ATS(t):
    e_t = e_p * jnp.cos(w_p * t)
    A = psi_a * (a + a.dag()) + psi_b * (b + b.dag())
    #scalar = (-2*E_j * e_t * psi_a * (a + a.dag())) - (2 * E_j * e_t * psi_b * (b + b.dag())) + ((1/3)* E_j * e_t * (A @ A @ A)) \
    scalar = -2*E_j * jnp.sin(e_t) * jnp.sin(dq.to_jax(A)) \
        + 2 * DE_j * jnp.cos(e_t) * jnp.cos(dq.to_jax(A))
    #jax.debug.print("ATS scalar at t={}: {}", t, scalar)  # JAX-safe debugging    
    return scalar 

def d(t):
    scalar = 2 * e_d * jnp.cos(w_d * t)
    #jax.debug.print("d scalar at t={}: {}", t, scalar)  # JAX-safe debugging
    return scalar * (jnp.eye(na))

def o(t):
    return jnp.eye(na)


H_0 = w_a0 * (dq.dag(a) @ a) + w_b0 * (dq.dag(b) @ b)
#H_0 = dq.modulated(o, (H_0))
#print(H_0(1))
H_ATS = dq.modulated(ATS, (jnp.eye(na)))
#H_ATS = dq.modulated(ATS, (a + dq.dag(a)))

#print(H_ATS(1))
#H_d = dq.modulated(d, (b + dq.dag(b)))
H_d = dq.modulated(d, (b + dq.dag(b)))


H = H_0  + H_ATS + H_d #+ H_d.dag() + H_ATS.dag()

kappa_a = 9.3e-6
kappa_b = 2.6e-3
alpha = 2.0 # cat size
beta = 1.0 
#two_photon_loss_a = jnp.sqrt(kappa_a) * (a @ a - alpha**2 * dq.eye(na))
#two_photon_loss_b = jnp.sqrt(kappa_b) * (b @ b - beta**2 * dq.eye(na))
loss_ops = [jnp.sqrt(kappa_a) * a + jnp.sqrt(kappa_b) * b]
#loss_ops = [two_photon_loss_a, two_photon_loss_b]
print(H.shape)
print(psi0.shape)
psi0 = np.reshape(psi0, (na,na))
print(H.shape)
print(psi0.shape)
res = dq.mesolve(H, loss_ops, psi0, tsave, solver=dq.solver.Tsit5(max_steps=1000000))



ValueError: Qarrays have incompatible Hilbert space dimensions. Got (20, 5) and (20, 20).

In [39]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.animation import PillowWriter
from IPython.display import HTML, display
from functools import reduce     # Import reduce to sum qarrays without starting with 0# ---------------------------
# 6. Animate the Wigner function of mode a
# ---------------------------
fig_w, ax = plt.subplots(2, 2, figsize=(6, 6))

def update(frame):
    # Obtain the reduced state of mode a by tracing out mode b.
    for a in [0,1]:
        for b in [0,1]:
            ax_w = ax[a][b]
            ax_w.cla()  # Clear the axis.
            rho_a = dq.ptrace(res.states[:][a][frame][][b], 0)
            dq.plot.wigner(rho_a, ax=ax_w)
            ax_w.set_title(f"Mode a Wigner Function\nTime = {tsave[frame]*1e8:.2f}")

ani = animation.FuncAnimation(fig_w, update, frames=len(tsave), repeat=False)
gif_filename = 'wigner_mode_a.gif'
ani.save(gif_filename, writer=PillowWriter(fps=len(tsave) / 4))
plt.close(fig_w)
display(HTML(f'<img src="{gif_filename}">'))

In [87]:
na = 20 # Hilbert space dimension
nb = 5
#a, b = dq.destroy(na, nb) # annihilaiton operator
a = dq.tensor(dq.destroy(na), dq.eye(nb))     # Mode a annihilation operator
b = dq.tensor(dq.eye(na), dq.destroy(nb))
identity = dq.tensor(dq.eye(na), dq.eye(nb))
psi0 = dq.tensor(dq.fock(na, 0), dq.fock(nb, 0)) # initial state
#psi0 = dq.tensor(dq.coherent(na, 0.0), dq.coherent(nb, 0.0)) # start in the vaccum
dq.set_precision('double')  # 'simple' by default

# Define constants
pi2 = 2.0 * np.pi
w_a0 = 5.26 * pi2 # GHz to Hz
w_b0 = 7.70 * pi2 # GHz to Hz
psi_a = 0.06
psi_b = 0.29
E_j = 42.76  # GHz to Hz
DE_j = 0.47  # GHz to Hz
w_d = 7.623 * pi2 # GHz to Hz
w_p = 2.891 * pi2  # GHz to Hz
e_d = -3.815e-3 * pi2 # MHz to Hz
e_p = 0.122  # rad

"""
w_a0 = 5.26e9 * 2 * jnp.pi 
w_b0 = 7.70e9 * 2 * jnp.pi
psi_a = 0.06
psi_b = 0.29
E_j = 42.76e9 #where h? 
DE_j = 0.47e9 #where h?
w_d = 7.623e9 * 2 * jnp.pi
w_p = 2.891e9 * 2 * jnp.pi
e_d = -3.815e6 * 2 * jnp.pi
e_p = 0.122
"""
tsave = jnp.linspace(0, 1e-7, 100)

def ATS(t):
    e_t = e_p * jnp.cos(w_p * t)
    A = psi_a * (a + a.dag()) + psi_b * (b + b.dag())
    #Attempted taylor expansion from tutorial 2
    #term1 = (-2*E_j * e_t * psi_a * (a + a.dag())) - (2 * E_j * e_t * psi_b * (b + b.dag())) + ((1/3)* E_j * e_t * (A @ A @ A)) \
    term1 = -2*E_j * jnp.sin(e_t) * jnp.sin(dq.to_jax(A)) 
    print(f"term1 {term1.shape}")
    term2 = 2 * DE_j * jnp.cos(e_t) * jnp.cos(dq.to_jax(A))
    print(f"term2 {term1.shape}")

    scalar = term1 + term2
    print(f"scalar {scalar.shape}")
    #jax.debug.print("ATS scalar at t={}: {}", t, scalar)  # JAX-safe debugging    
    return scalar 

def d(t):
    scalar = 2 * e_d * jnp.cos(w_d * t)
    #jax.debug.print("d scalar at t={}: {}", t, scalar)  # JAX-safe debugging
    return scalar * (b + dq.dag(b))

def o(t):
    return identity
#print(a.shape)
#print(b.shape)

def get_H(t):
    H_0 = w_a0 * (dq.dag(a) @ a) + w_b0 * (dq.dag(b) @ b)
    H_d = d(t)
    H_ATS = ATS(t)
    return H_0 + H_ATS + H_d

#H_0 = w_a0 * (dq.dag(a) @ a) + w_b0 * (dq.dag(b) @ b)
#H_0 = dq.modulated(o, (H_0))
#print(H_0(1))
#H_ATS = dq.modulated(ATS, dq.eye(1))
#H_ATS = dq.modulated(ATS, (a + dq.dag(a)))

#print(H_ATS(1))
#H_d = dq.modulated(d, (b + dq.dag(b)))
#H_d = dq.modulated(d, ())
#print(f"H_0 {H_0.shape}")
#print(f"HATS {H_ATS.shape}")
#print(f"H_d {H_d.shape}")

#H = H_0  + H_ATS + H_d #+ H_d.dag() + H_ATS.dag()
H = dq.timecallable(get_H)
#print(H(1))
kappa_a = 9.3e-6
kappa_b = 2.6e-3
alpha = 2.0 # cat size
beta = 1.0 
two_photon_loss_a = jnp.sqrt(kappa_a) * (a @ a - alpha**2 * identity)
two_photon_loss_b = jnp.sqrt(kappa_b) * (b @ b - beta**2 * identity)
loss_ops = [jnp.sqrt(kappa_a) * a, jnp.sqrt(kappa_b) * b]
loss_ops = [two_photon_loss_a, two_photon_loss_b]
#print(f"H {H.shape}")
#print(f"psi0 {psi0.shape}")
#psi0 = np.reshape(psi0, (na,na))
#print(f"psi0 {psi0.shape}")
res = dq.mesolve(H, loss_ops, psi0, tsave, solver=dq.solver.Tsit5(max_steps=1000000))
#print(res.states.shape)


term1 (100, 100)
term2 (100, 100)
scalar (100, 100)
term1 (100, 100)
term2 (100, 100)
scalar (100, 100)
term1 (100, 100)
term2 (100, 100)
scalar (100, 100)


  return H_0 + H_ATS + H_d
  return self.__add__(y)
  return H_0 + H_ATS + H_d
  return self.__add__(y)
  return H_0 + H_ATS + H_d
  return self.__add__(y)


term1 (100, 100)
term2 (100, 100)
scalar (100, 100)
term1 (100, 100)
term2 (100, 100)
scalar (100, 100)
term1 (100, 100)
term2 (100, 100)
scalar (100, 100)


|██████████| 100.0% ◆ elapsed 4.03ms ◆ remaining 0.00ms


In [88]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.animation import PillowWriter
from IPython.display import HTML, display
from functools import reduce     # Import reduce to sum qarrays without starting with 0# ---------------------------
# 6. Animate the Wigner function of mode a
# ---------------------------
fig_w, ax_w = plt.subplots(1, 1, figsize=(6, 6))

def update(frame):
    # Obtain the reduced state of mode a by tracing out mode b.
    ax_w.cla()  # Clear the axis.
    rho_a = dq.ptrace(res.states[frame], 0)
    dq.plot.wigner(rho_a, ax=ax_w)
    ax_w.set_title(f"Mode a Wigner Function\nTime = {tsave[frame]*1e8:.2f}")

ani = animation.FuncAnimation(fig_w, update, frames=len(tsave), repeat=False)
gif_filename = 'wigner_mode_a.gif'
ani.save(gif_filename, writer=PillowWriter(fps=len(tsave) / 4))
plt.close(fig_w)
display(HTML(f'<img src="{gif_filename}">'))