# Floquet Cavity Cooling - GPU Training

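This notebook runs the Floquet cavity-cooling training pipeline on a Colab GPU: a quick physics check, GRAPE optimization, and SAC agent training.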

In [None]:
# Install dependencies
!pip install -q jax[cuda12] jaxlib flax optax diffrax gymnasium stable-baselines3 matplotlib

In [None]:
# Clone repo and checkout dev branch
import os
if not os.path.exists('/content/abouie_proj'):
    !git clone https://github.com/sattary/abouie_proj.git /content/abouie_proj
%cd /content/abouie_proj
!git checkout dev
!git pull origin dev

In [None]:
# Add src to path
import sys
sys.path.insert(0, '/content/abouie_proj')
!ls src/

In [None]:
# Check which devices JAX can see and whether any of them looks like a GPU
import jax

devices = jax.devices()
print(f"JAX devices: {devices}")

# A device is treated as a GPU when its lowercase string form mentions
# either "cuda" or "gpu".
device_labels = [str(d).lower() for d in devices]
gpu_available = any(
    tag in label
    for label in device_labels
    for tag in ('cuda', 'gpu')
)
print(f"GPU available: {gpu_available}")

In [None]:
# Import project modules from the repo's `src` package.
# These imports only resolve if the repo root is on sys.path
# (done in an earlier cell); a failure here stops before any training runs.
from src.physics import SystemParams, build_operators, thermal_cavity_ground_qubits
from src.floquet import create_constant_cycle, find_floquet_steady_state
from src.baseline import compute_stochastic_limit, StochasticParams
from src.optimization import run_grape_optimization, GRAPEConfig
from src.rl import FloquetCoolingEnv, train_sac
# Reaching this line means every import above succeeded.
print("All imports OK!")

## 1. Quick Physics Test

In [None]:
# Smoke-test the physics engine: build the system and read the initial occupation
import jax.numpy as jnp

params = SystemParams(kappa=0.05, gamma1=0.01, T_bath=0.5, T_atom=0.05)
ops = build_operators(params)
rho = thermal_cavity_ground_qubits(params)

# Expectation value of the cavity number operator: Re Tr(n_cav @ rho)
occupation_trace = jnp.trace(ops.n_cav @ rho)
n_init = float(jnp.real(occupation_trace))
print(f"Initial cavity occupation: {n_init:.4f}")

## 2. GRAPE Optimization (GPU-accelerated)

In [None]:
# Configure and launch the GRAPE optimization
grape_settings = {
    'n_steps': 20,
    'T_cycle': 0.5,
    'n_cycles_eval': 100,
    'learning_rate': 0.02,
    'n_iterations': 200,
    'g_max': 1.5,
    'delta_max': 0.3,
}
config = GRAPEConfig(**grape_settings)

# `params` comes from the physics test cell above
optimal_cycle, history = run_grape_optimization(params, config)

# Report the last entry of the optimization history
print(f"\nFinal <n>: {history[-1]:.4f}")

In [None]:
# Plot GRAPE results: convergence history and the optimized control sequences
import matplotlib.pyplot as plt
import numpy as np

# Baseline for comparison: the value returned by compute_stochastic_limit.
# kappa, T_bath and T_atom match the SystemParams used in the physics test cell.
stoch = StochasticParams(
    omega_c=5.0, omega_a=5.0, kappa=0.05,
    T_bath=0.5, T_atom=0.05, lambda_ex=5.0,
    g=0.5, tau=0.05, R=5.0, chi=2.0,
)
n_stoch, _ = compute_stochastic_limit(stoch)

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Panel 0: optimization history against the stochastic baseline
axes[0].plot(history)
axes[0].axhline(n_stoch, color='r', linestyle='--', label='Stochastic')
axes[0].set_xlabel('Iteration')
axes[0].set_ylabel('<n>')
axes[0].set_title('GRAPE convergence')
axes[0].legend()

# n_steps piecewise-constant control values need n_steps + 1 time edges.
# With only n_steps points and where='post', the last value would get zero
# width (never drawn) and the other segments would be stretched to
# T_cycle / (n_steps - 1).
# NOTE(review): assumes equal-width segments of T_cycle / n_steps — confirm
# against the cycle definition in src.floquet.
t_edges = np.linspace(0, config.T_cycle, config.n_steps + 1)
t = t_edges[:-1]  # start time of each control segment

g_values = np.asarray(optimal_cycle.g_sequence)
delta_values = np.asarray(optimal_cycle.delta_sequence)

# Repeat the final value so the last segment is drawn up to T_cycle.
axes[1].step(t_edges, np.append(g_values, g_values[-1]), 'g-', where='post')
axes[1].set_xlabel('Time (ns)')
axes[1].set_ylabel('g(t)')
axes[1].set_title('Optimized g(t)')

axes[2].step(t_edges, np.append(delta_values, delta_values[-1]), 'purple', where='post')
axes[2].set_xlabel('Time (ns)')
axes[2].set_ylabel('delta(t)')
axes[2].set_title('Optimized delta(t)')

plt.tight_layout()
plt.show()

# Relative improvement of the final GRAPE value over the stochastic baseline
print(f"Stochastic limit: {n_stoch:.4f}")
print(f"Improvement: {(n_stoch - history[-1])/n_stoch*100:.1f}%")

## 3. SAC Training (longer run)

In [None]:
# Train SAC agent (adjust timesteps as needed)
sac_settings = dict(
    total_timesteps=50000,
    n_steps_per_cycle=20,
    n_cycles_per_episode=50,
)
# train_sac hands back the trained model, its callback and the environment
model, callback, env = train_sac(**sac_settings)

In [None]:
# Evaluate the trained SAC agent and compare with the stochastic baseline
from src.rl import evaluate_trained_agent

results, g_seq, delta_seq = evaluate_trained_agent(model, env, n_eval_episodes=5)

# Smallest 'n_cav' over all evaluation results; `n_stoch` comes from the
# GRAPE plotting cell above.
n_cav_values = [entry['n_cav'] for entry in results]
best_n = min(n_cav_values)
print(f"Best n_cav: {best_n:.4f} (target: {n_stoch:.4f})")