In [5]:
import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

def ms_hamiltonian_fock(t, y, omega, eta, delta, n_fock):
    """ Time-dependent Hamiltonian for the MS gate """
    # Pauli X matrix and identity matrix for the qubit space
    sigma_x = np.array([[0, 1], [1, 0]])
    identity_q = np.eye(2)

    # Creation and annihilation operators for Fock space
    a = np.diag(np.sqrt(np.arange(1, n_fock)), 1)
    a_dagger = np.diag(np.sqrt(np.arange(1, n_fock)), -1)

    # Interaction term
    interaction = (a + a_dagger)

    # Constructing the full Hamiltonian
    H = omega * np.kron(sigma_x, interaction) + eta * omega * np.cos(delta * t) * np.kron(identity_q, interaction)
    dydt = -1j * np.dot(H, y)
    return np.imag(dydt)  # Returning the imaginary part as a real array for ODE solver compatibility

def initial_state(n_fock):
    """ Initialize the system's state vector in the ground state """
    psi_q = np.array([1, 0, 0, 0], dtype=complex)  # Represents qubit state |00>
    psi_m = np.zeros(n_fock, dtype=complex)
    psi_m[0] = 1  # Motional ground state |0>
    return np.kron(psi_q, psi_m)  # Return the Kronecker product as the initial state

def solve_schrodinger(omega, eta, delta, n_fock, t_final):
    """ Solve the Schrödinger equation using scipy's solve_ivp """
    y0 = initial_state(n_fock)
    t_eval = np.linspace(0, t_final, 500)  # Time grid for the ODE solution
    sol = solve_ivp(ms_hamiltonian_fock, [0, t_final], y0.view(np.float64), args=(omega, eta, delta, n_fock), t_eval=t_eval, method='RK45')
    return t_eval, sol.y.T.view(np.complex128)

# Simulation parameters
omega = 1.0  # Rabi frequency
eta = 0.1    # Lamb-Dicke parameter
delta = 0.2  # Detuning frequency
n_fock = 5   # Number of Fock states
t_final = 10 # Total simulation time

# Run the simulation
t_eval, result = solve_schrodinger(omega, eta, delta, n_fock, t_final)

# Plot the probabilities of each state over time
plt.figure(figsize=(12, 8))
for state_idx in range(result.shape[1]):
    plt.plot(t_eval, np.abs(result[:, state_idx])**2, label=f'State {state_idx}')
plt.title('Probabilities of States Over Time')
plt.xlabel('Time')
plt.ylabel('Probability')
plt.legend()
plt.show()


ValueError: shapes (10,10) and (40,) not aligned: 10 (dim 1) != 40 (dim 0)