In [1]:
import os
# Request the GPU platform from JAX. The variable is written before
# `import jax` below so it is already in the environment when JAX is imported.
os.environ['JAX_PLATFORM_NAME'] = 'gpu'

import jax
# Sanity check: the printed list should contain a GPU device
# (the recorded run shows `CudaDevice(id=0)`).
print("JAX devices:", jax.devices())


JAX devices: [CudaDevice(id=0)]


Improvements to do:

1. main_config['policy']['model'].update({
    # three layers, each 256 units
    'encoder_hidden_size_list': [256, 256, 256]
})

2. learn 9x9 first,
then 9x64,
then 9x64x2-64,
then the (32,32)-9 block structure from heatmap
then 64x64 optional

3.




In [2]:
from easydict import EasyDict as edict
from ding.config import compile_config
from ding.entry import serial_pipeline

# Main experiment configuration handed to DI-engine's `serial_pipeline`.
# Sections: `env` (environment construction / vectorisation) and `policy`
# (model, learner, collector, evaluator, exploration and replay settings).
main_config = edict({
    "exp_name": "pdqn_exchange_cnot",

    # ────────────────────── environment ────────────────────── #
    "env": {
        "import_names": ["exch_gym_env"],
        "type": "ExchangeCNOTEnvDI",
        "max_episode_steps": 18,
        "collector_env_num": 8,
        "evaluator_env_num": 3,
        "use_act_scale": True,
    },

    # ───────────────────────── policy ───────────────────────── #
    "policy": {
        "type": "pdqn_command",
        "cuda": True,  # use GPU for training
        # ‣ model description → **one** dict for both branches
        "model": {
            "obs_shape": 163,
            # `action_shape` only describes the hybrid action space.
            "action_shape": edict({
                "action_type_shape": 5,   # discrete: 5 neighbour pairs
                "action_args_shape": 1,   # continuous: swap-power p
            }),
            # Encoder widths are a model-level argument (three layers of 256
            # units). They previously sat inside `action_shape`, where the
            # model does not look for them, so the setting had no effect.
            "encoder_hidden_size_list": [256, 256, 256],
        },

        # ‣ learning hyper-params
        "learn": {
            "multi_gpu": False,
            "hook": {"load_on_driver": True},
            "train_epoch": 100,
            "batch_size": 64,

            # ──► PDQN needs these two ◄──
            "learning_rate_dis": 1e-3,   # discrete Q-network
            "learning_rate_cont": 1e-3,  # continuous Q-network
            "update_circle": 10,
            "weight_decay": 0,
        },
        # ‣ data collection / evaluation
        "collect": {
            "n_sample": 320,
            "unroll_len": 1,
            "noise": True,
            # NEW – Gaussian with σ=0.7 mapped to [-2,2]
            # NOTE(review): confirm that the policy/env actually reads
            # `noise` / `action_args_noise`; stock PDQN configs appear to use
            # `noise_sigma` for the continuous-branch noise.
            "action_args_noise": {          # <-- continuous branch noise
                "type": "normal",
                "sigma": 0.7
            }
        },
        "eval":    {"evaluator": {"eval_freq": 1000, "n_episode": 5}},

        # ‣ misc
        "other": {
            # ε-greedy schedule for the discrete branch: exponential decay
            # from `start` to `end`.
            "eps": {
                "type": "exp",
                "start": 1.0,
                "end": 0.05,
                "decay": 10000,
            },
            "replay_buffer": {"replay_buffer_size": 100_000},
        },
    },
})

# Companion "create" config: names the registered components that
# `serial_pipeline` should instantiate for the settings in `main_config`.
create_config = edict({
    # An `env_manager` entry is supplied here (author's note: without it
    # compile_config crashes).
    "env_manager": {"type": "base"},
    # Module to import and the registered environment class to build.
    "env": {"import_names": ["exch_gym_env"], "type": "ExchangeCNOTEnvDI"},
    # Registered policy name.
    "policy": {"type": "pdqn"},
})

if __name__ == "__main__":
    # Both configs travel together as a two-element list.
    serial_pipeline(
        [main_config, create_config],
        seed=42,
        max_env_step=20000,
    )


  register_for_torch(TreeValue)
  register_for_torch(FastTreeValue)
  from .autonotebook import tqdm as notebook_tqdm


  return torch.from_numpy(item).to(dtype)
  pair_idx = int(action['action_type'])


  return F.mse_loss(input, target, reduction=self.reduction)


## Unit testing

In [3]:
# test_exchange_cnot_env.py
import math
import numpy as np
import pytest
import math, logging, numpy as np, pytest
from exch_gym_env import ExchangeCNOTEnvDI, NEIGHBORS   # adjust import path if needed

# Message-only logging format; `log` is created here but the cell below
# reports through print().
logging.basicConfig(level=logging.INFO, format="%(message)s")
log = logging.getLogger("cnot‐env")

# Swap powers expressed in units of pi.
p1 = math.acos(-1 / math.sqrt(3)) / math.pi      # ≈ 0.695913 (so 1 - p1 ≈ 0.304087)
p2 = math.asin( 1 / 3)            / math.pi      # ≈ 0.108173

# Gate sequence replayed against the environment, in order.
# Each entry is (swap power p, [qubit_a, qubit_b]); the pair is translated to
# a discrete action index via NEIGHBORS before stepping the env.
gate_specs = [
    ( 1+p1,  [3,4] ),
    # ( p1,    [3,4] ),
    ( p2,    [4,5] ),
    ( 0.5,   [2,3] ),
    ( 1.0,   [3,4] ),
    (-0.5,   [2,3] ),
    (-0.5,   [4,5] ),
    ( 1.0,   [1,2] ),
    (-0.5,   [3,4] ),
    (-0.5,   [2,3] ),
    ( 1.0,   [4,5] ),
    (-0.5,   [1,2] ),
    ( 0.5,   [3,4] ),
    (-0.5,   [2,3] ),
    ( 1.0,   [4,5] ),
    ( 1.0,   [1,2] ),
    (-0.5,   [3,4] ),
    (-0.5,   [2,3] ),
    (-0.5,   [4,5] ),
    ( 1.0,   [3,4] ),
    ( 0.5,   [2,3] ),
    ( 1-p2,  [4,5] ),
    # ( -p1,   [3,4] ),
    ( 1-p1,  [3,4] ),
]


def pair_to_index(pair):
    """Return the position in ``NEIGHBORS`` of the unordered pair ``pair``.

    Args:
        pair: Two-element sequence of qubit labels (list, tuple, ...).
            Orientation does not matter: ``[a, b]`` and ``[b, a]`` resolve to
            the same index.

    Returns:
        int: index of the first matching entry in ``NEIGHBORS``.

    Raises:
        ValueError: if ``pair`` is not a sequence or matches no entry of
            ``NEIGHBORS`` in either orientation.
    """
    try:
        # Normalise to a tuple so lists, tuples and arrays compare alike.
        wanted = tuple(pair)
    except TypeError:
        raise ValueError(f"{pair!r} is not a pair of qubit labels") from None
    for idx, (i, j) in enumerate(NEIGHBORS):
        if wanted in ((i, j), (j, i)):
            return idx
    raise ValueError(
        f"{pair!r} is not a neighbouring pair; known pairs: {NEIGHBORS!r}"
    )

# Smoke run: replay `gate_specs` through the environment and print the reward
# and the two fidelities reported in `info` after every step. There are no
# assertions here; the table is meant to be inspected by eye.
env = ExchangeCNOTEnvDI(max_depth=30, obs_mode="block")
obs = env.reset()
cum_r = 0.0  # running sum of per-step rewards
print("step | pair | p        | r   | fid64   | fid9")
for k, (p, pair) in enumerate(gate_specs, 1):
    # Hybrid action: discrete pair index plus a one-element list with the power.
    ts = env.step({"action_type": pair_to_index(pair), "action_args": [p]})
    cum_r += ts.reward
    print(f"{k:4d} | {pair} | {p:+.6f} | {ts.reward:+.3f} | "
          f"{ts.info['fid64']:.6f} | {ts.info['fid9']:.6f}")
    # Stop replaying as soon as the env reports the episode finished.
    if ts.done:
        break

# Summary uses the last timestep seen (`ts` from the final loop iteration).
print("-"*64)
print(f"terminated: {ts.done}   total reward: {cum_r:+.3f}")
print(f"final fidelities  F64={ts.info['fid64']:.6f}  F9={ts.info['fid9']:.6f}")
env.close()




step | pair | p        | r   | fid64   | fid9
   1 | [3, 4] | +1.695913 | +13.000 | 0.067011 | 0.155245
   2 | [4, 5] | +0.108173 | +5.000 | 0.064629 | 0.152741
   3 | [2, 3] | +0.500000 | -3.000 | 0.046148 | 0.104504
   4 | [3, 4] | +1.000000 | -3.000 | 0.027403 | 0.090456
   5 | [2, 3] | -0.500000 | +15.000 | 0.028145 | 0.085771
   6 | [4, 5] | -0.500000 | -1.000 | 0.031295 | 0.076565
   7 | [1, 2] | +1.000000 | +15.500 | 0.097912 | 0.102477
   8 | [3, 4] | -0.500000 | +3.000 | 0.098648 | 0.089825
   9 | [2, 3] | -0.500000 | -6.000 | 0.083570 | 0.072815
  10 | [4, 5] | +1.000000 | +12.338 | 0.125226 | 0.133955
  11 | [1, 2] | -0.500000 | -6.317 | 0.063352 | 0.080147
  12 | [3, 4] | +0.500000 | +4.036 | 0.089911 | 0.098619
  13 | [2, 3] | -0.500000 | -6.606 | 0.083674 | 0.079254
  14 | [4, 5] | +1.000000 | -2.742 | 0.064606 | 0.066368
  15 | [1, 2] | +1.000000 | +9.627 | 0.212291 | 0.124714
  16 | [3, 4] | -0.500000 | -7.000 | 0.149831 | 0.107227
  17 | [2, 3] | -0.500000 | -7.123 | 0

### PPO

In [4]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax
import gymnax
from wrappers import LogWrapper, FlattenObservationWrapper


class ActorCritic(nn.Module):
    """Separate policy and value MLPs evaluated on the same observation.

    Attributes:
        action_dim: Width of the policy's logit layer. Annotated as a
            sequence, but `make_train` passes the action space's `.n`.
        activation: ``"relu"`` selects ReLU; any other value selects tanh.

    Calling the module returns ``(pi, value)`` where ``pi`` is a
    ``distrax.Categorical`` built from the policy logits and ``value`` is the
    critic output with its trailing axis squeezed away.
    """

    action_dim: Sequence[int]
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        act_fn = nn.relu if self.activation == "relu" else nn.tanh

        def dense(features, scale):
            # Orthogonal weights with the given gain, zero bias.
            return nn.Dense(
                features, kernel_init=orthogonal(scale), bias_init=constant(0.0)
            )

        # Layers are created actor-first, then critic, in the same order as
        # before so the auto-assigned submodule names line up.

        # Policy head: two hidden layers of 64 units, then small-gain logits.
        hidden = x
        for _ in range(2):
            hidden = act_fn(dense(64, np.sqrt(2))(hidden))
        logits = dense(self.action_dim, 0.01)(hidden)
        pi = distrax.Categorical(logits=logits)

        # Value head: same trunk shape, scalar output.
        hidden = x
        for _ in range(2):
            hidden = act_fn(dense(64, np.sqrt(2))(hidden))
        value = dense(1, 1.0)(hidden)

        return pi, jnp.squeeze(value, axis=-1)


class Transition(NamedTuple):
    """One environment step as recorded during a rollout.

    `make_train` builds these positionally as
    ``Transition(done, action, value, reward, log_prob, last_obs, info)``.
    """

    # Done flag returned by env.step for this step.
    done: jnp.ndarray
    # Action sampled from the policy.
    action: jnp.ndarray
    # Critic output for the observation the action was taken from.
    value: jnp.ndarray
    # Reward returned by env.step.
    reward: jnp.ndarray
    # Log-probability of `action` under the policy that sampled it.
    log_prob: jnp.ndarray
    # Observation *before* the step (the input the policy saw).
    obs: jnp.ndarray
    # Info returned by env.step; the debug callback in `make_train` indexes it
    # by string keys, so in practice it is a mapping rather than a bare array.
    info: jnp.ndarray


def make_train(config):
    """Build a PPO training function for a gymnax environment.

    ``config`` is mutated in place: two derived entries are written into it,
    ``NUM_UPDATES`` (= TOTAL_TIMESTEPS // NUM_STEPS // NUM_ENVS) and
    ``MINIBATCH_SIZE`` (= NUM_ENVS * NUM_STEPS // NUM_MINIBATCHES).

    Keys read from ``config``: TOTAL_TIMESTEPS, NUM_STEPS, NUM_ENVS,
    NUM_MINIBATCHES, UPDATE_EPOCHS, ENV_NAME, LR, ANNEAL_LR, MAX_GRAD_NORM,
    ACTIVATION, GAMMA, GAE_LAMBDA, CLIP_EPS, VF_COEF, ENT_COEF and, optionally,
    DEBUG.

    Args:
        config: dict of hyper-parameters (see above).

    Returns:
        A function ``train(rng)`` that takes a JAX PRNG key and returns
        ``{"runner_state": ..., "metrics": ...}``.
    """
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    # Environment is created once here and closed over by `train`.
    env, env_params = gymnax.make(config["ENV_NAME"])
    env = FlattenObservationWrapper(env)
    env = LogWrapper(env)

    def linear_schedule(count):
        """Learning rate decaying linearly from LR as updates progress.

        ``count`` is divided by the number of gradient steps taken per update
        (NUM_MINIBATCHES * UPDATE_EPOCHS) to obtain the update index.
        """
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    def train(rng):
        """Run the full training loop; ``rng`` is a JAX PRNG key."""
        # INIT NETWORK
        network = ActorCritic(
            env.action_space(env_params).n, activation=config["ACTIVATION"]
        )
        rng, _rng = jax.random.split(rng)
        # Dummy observation used only to initialise the parameters.
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        # Optimiser: global-norm gradient clipping followed by Adam, with
        # either the linear schedule or a constant learning rate.
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        # One reset key per environment copy; env_params is shared (in_axes None).
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            """One PPO update: collect a rollout, compute GAE, optimise."""
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                """Advance every environment copy by one step."""
                train_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(
                    env.step, in_axes=(0, 0, 0, None)
                )(rng_step, env_state, action, env_params)
                # Store the observation the action was chosen from (last_obs),
                # not the one returned by the step.
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            # Roll out NUM_STEPS steps; traj_batch stacks the transitions
            # along a new leading axis.
            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, rng = runner_state
            # Value of the observation after the last collected step, used to
            # bootstrap the advantage recursion.
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                """Return (advantages, value targets) via GAE."""
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    # TD residual; (1 - done) drops the bootstrap term on
                    # steps that ended an episode.
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    # Carry this step's value so the preceding step sees it
                    # as its next_value.
                    return (gae, value), gae

                # reverse=True: iterate from the last step back to the first.
                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                # Value targets are advantage + the value recorded at rollout time.
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                """One pass over the rollout, split into shuffled minibatches."""
                def _update_minbatch(train_state, batch_info):
                    """Apply one gradient step on a single minibatch."""
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        # Clipped value loss: limit how far the new value may
                        # move from the rollout-time value, then take the
                        # larger of the clipped and unclipped squared errors.
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        # Probability ratio between current and rollout policy.
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        # Advantages are normalised within the minibatch.
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        # Entropy enters with a minus sign, so higher entropy
                        # lowers the loss.
                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    # With has_aux=True, `total_loss` here is the pair
                    # (loss, (value_loss, loss_actor, entropy)).
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                # Batching and Shuffling
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
                ), "batch size must be equal to number of steps * number of envs"
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                # Merge the two leading axes of every leaf into one batch axis.
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                # Mini-batch Updates
                # Split the batch axis into (NUM_MINIBATCHES, minibatch) so
                # scan can iterate over minibatches.
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss
            # Updating Training State and Metrics:
            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            train_state = update_state[0]
            # The per-step env info of this rollout is what gets returned as metrics.
            metric = traj_batch.info
            rng = update_state[-1]

            # Debugging mode
            # Prints one line per finished episode found in this rollout.
            # The info keys used below are presumably added by LogWrapper — confirm.
            if config.get("DEBUG"):
                def callback(info):
                    return_values = info["returned_episode_returns"][info["returned_episode"]]
                    timesteps = info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
                    for t in range(len(timesteps)):
                        print(f"global step={timesteps[t]}, episodic return={return_values[t]}")
                jax.debug.callback(callback, metric)

            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        # Outer loop: NUM_UPDATES rollouts + optimisation phases.
        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )
        return {"runner_state": runner_state, "metrics": metric}

    return train


if __name__ == "__main__":
    # Hyper-parameters consumed by make_train (which also adds NUM_UPDATES and
    # MINIBATCH_SIZE to this dict).
    config = {
        "LR": 2.5e-4,              # Adam learning rate (starting value when annealed)
        "NUM_ENVS": 4,             # environment copies stepped in parallel via vmap
        "NUM_STEPS": 128,          # steps collected per environment per update
        # Float literal: NUM_UPDATES = 5e5 // 128 // 4 therefore comes out as 976.0.
        "TOTAL_TIMESTEPS": 5e5,
        "UPDATE_EPOCHS": 4,        # passes over each rollout
        "NUM_MINIBATCHES": 4,      # minibatches per pass
        "GAMMA": 0.99,             # discount factor
        "GAE_LAMBDA": 0.95,        # GAE lambda
        "CLIP_EPS": 0.2,           # clip range for both the ratio and the value update
        "ENT_COEF": 0.01,          # weight of the entropy term
        "VF_COEF": 0.5,            # weight of the value loss
        "MAX_GRAD_NORM": 0.5,      # global-norm gradient clipping threshold
        "ACTIVATION": "tanh",      # "relu" selects ReLU, anything else tanh
        "ENV_NAME": "CartPole-v1", # passed to gymnax.make
        "ANNEAL_LR": True,         # use the linear learning-rate schedule
        "DEBUG": True,             # print episodic returns from inside the jitted loop
    }
    rng = jax.random.PRNGKey(30)
    # Training is compiled as a whole with jit and run once.
    train_jit = jax.jit(make_train(config))
    out = train_jit(rng)

2025-06-01 16:49:51.383560: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748821791.444458  554626 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748821791.460259  554626 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1748821791.585564  554626 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1748821791.585781  554626 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1748821791.585783  554626 computation_placer.cc:177] computation placer alr

global step=56, episodic return=14.0
global step=72, episodic return=18.0
global step=100, episodic return=11.0
global step=124, episodic return=31.0
global step=136, episodic return=16.0
global step=160, episodic return=15.0
global step=176, episodic return=44.0
global step=176, episodic return=13.0
global step=208, episodic return=12.0
global step=252, episodic return=19.0
global step=252, episodic return=29.0
global step=268, episodic return=15.0
global step=324, episodic return=18.0
global step=356, episodic return=45.0
global step=364, episodic return=28.0
global step=396, episodic return=32.0
global step=396, episodic return=18.0
global step=416, episodic return=13.0
global step=452, episodic return=14.0
global step=460, episodic return=16.0
global step=480, episodic return=31.0
global step=496, episodic return=20.0
global step=504, episodic return=13.0
global step=576, episodic return=20.0
global step=576, episodic return=29.0
global step=580, episodic return=19.0
global step=63