In [1]:
import os

# JAX picks its backend up from the environment, so the platform has to be
# pinned here, *before* the `import jax` below.
# NOTE(review): newer JAX releases document `JAX_PLATFORMS` and describe
# `JAX_PLATFORM_NAME` as the deprecated spelling — confirm for the installed version.
os.environ.update(JAX_PLATFORM_NAME="gpu")

import jax

# Sanity check: the printed list should contain a GPU device
# (the recorded run reported CudaDevice(id=0)).
print("JAX devices:", jax.devices())


JAX devices: [CudaDevice(id=0)]


In [1]:
from easydict import EasyDict as edict
from ding.config import compile_config
from ding.entry import serial_pipeline

main_config = edict({
    "exp_name": "pdqn_exchange_cnot",

    # ────────────────────── environment ────────────────────── #
    "env": {
        "import_names": ["exch_gym_env"],   # module(s) DI-engine imports to find the env class
        "type": "ExchangeCNOTEnvDI",
        # NOTE(review): the reference gate sequence replayed in the unit-testing
        # section uses 22 gates (with max_depth=30). If this key caps the episode
        # length, 18 steps may be too few to reach the target — confirm.
        "max_episode_steps": 18,
        "collector_env_num": 8,
        "evaluator_env_num": 3,
        "use_act_scale": True,
    },

    # ───────────────────────── policy ───────────────────────── #
    "policy": {
        # NOTE(review): compile_config presumably derives this from
        # create_config's "pdqn" as well, making this entry redundant — confirm.
        "type": "pdqn_command",
        "cuda": True,  # use GPU for training
        # Model description: ONE dict shared by the discrete and continuous branches.
        "model": {
            "obs_shape": 163,
            "action_shape": edict({
                "action_type_shape": 5,   # discrete: 5 neighbour pairs
                "action_args_shape": 1,   # continuous: swap-power p
            }),
        },

        # Learning hyper-parameters.
        "learn": {
            "multi_gpu": False,
            "hook": {"load_on_driver": True},
            # NOTE(review): DI-engine's serial pipeline performs
            # `learn.update_per_collect` gradient steps per collection (PDQN
            # default 3); `train_epoch` may not be read at all — confirm.
            "train_epoch": 100,
            "batch_size": 64,

            # PDQN takes one learning rate per branch.
            "learning_rate_dis": 1e-3,   # discrete Q-network
            "learning_rate_cont": 1e-3,  # continuous Q-network
            "update_circle": 10,
            "weight_decay": 0,
        },
        # Data collection / evaluation.
        "collect": {
            "n_sample": 320,
            "unroll_len": 1,
            # Std-dev of the Gaussian exploration noise added to the continuous
            # action argument. DI-engine's PDQN policy reads `collect.noise_sigma`
            # (default 0.1); keys such as `noise` / `action_args_noise` are not
            # read by it and would leave the default in effect.
            "noise_sigma": 0.7,
        },
        "eval":    {"evaluator": {"eval_freq": 1000, "n_episode": 5}},

        # Misc: exploration-epsilon schedule and replay-buffer capacity.
        "other": {
            "eps": {
                "type": "exp",
                "start": 1.0,
                "end": 0.05,
                "decay": 10000,
            },
            "replay_buffer": {"replay_buffer_size": 100_000},
        },
    },
})

# Creation config: the minimal pieces DI-engine needs to pick the classes to build.
create_config = edict({
    # 1. env_manager entry — compile_config fails without it.
    "env_manager": {
        "type": "base",
    },
    # 2. env: must name the same registered class as main_config.env.
    "env": {
        "import_names": ["exch_gym_env"],
        "type": "ExchangeCNOTEnvDI",
    },
    # 3. policy: registered policy name.
    "policy": {
        "type": "pdqn",
    },
})

if __name__ == "__main__":
    # NOTE(review): the last run of this cell failed with
    # `TypeError: not support item type: <class 'jaxlib._jax.ArrayImpl'>`,
    # i.e. DI-engine's data helpers received a JAX array. The env presumably
    # returns JAX arrays from reset()/step(); converting them to NumPy
    # (np.asarray) inside the env should fix it — TODO confirm.
    # serial_pipeline expects a (main_config, create_config) pair.
    serial_pipeline((main_config, create_config), seed=42)


  register_for_torch(TreeValue)
  register_for_torch(FastTreeValue)
  from .autonotebook import tqdm as notebook_tqdm


TypeError: not support item type: <class 'jaxlib._jax.ArrayImpl'>

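## Unit testing: replay a reference gate sequence

The next cell steps `ExchangeCNOTEnvDI` through a fixed sequence of `(swap power, neighbour pair)` gates and prints the reward and the `fid64` / `fid9` fidelities after each step.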

In [None]:
# test_exchange_cnot_env.py
import math
import numpy as np
import pytest
import math, logging, numpy as np, pytest
from exch_gym_env import ExchangeCNOTEnvDI, NEIGHBORS   # adjust import path if needed

logging.basicConfig(level=logging.INFO, format="%(message)s")
# Plain ASCII hyphen so getLogger("cnot-env") elsewhere returns this same logger.
log = logging.getLogger("cnot-env")

# Swap powers for the reference sequence, expressed as fractions of π.
# NOTE(review): acos(+1/√3)/π ≈ 0.304086724 is the complementary value (= 1 - p1);
# the code uses the -1/√3 form — confirm that is the intended angle.
p1 = math.acos(-1 / math.sqrt(3)) / math.pi      # ≈ 0.695913276
p2 = math.asin(1 / 3) / math.pi                  # ≈ 0.108173448

# (swap power p, neighbour pair), applied in order by the replay cell.
gate_specs = [
    ( 1+p1,  [3,4] ),
    # ( p1,    [3,4] ),
    ( p2,    [4,5] ),
    ( 0.5,   [2,3] ),
    ( 1.0,   [3,4] ),
    (-0.5,   [2,3] ),
    (-0.5,   [4,5] ),
    ( 1.0,   [1,2] ),
    (-0.5,   [3,4] ),
    (-0.5,   [2,3] ),
    ( 1.0,   [4,5] ),
    (-0.5,   [1,2] ),
    ( 0.5,   [3,4] ),
    (-0.5,   [2,3] ),
    ( 1.0,   [4,5] ),
    ( 1.0,   [1,2] ),
    (-0.5,   [3,4] ),
    (-0.5,   [2,3] ),
    (-0.5,   [4,5] ),
    ( 1.0,   [3,4] ),
    ( 0.5,   [2,3] ),
    ( 1-p2,  [4,5] ),
    # ( -p1,   [3,4] ),
    ( 1-p1,  [3,4] ),
]


def pair_to_index(pair, neighbors=None):
    """Return the discrete action index of the neighbour pair ``pair``.

    Matching ignores orientation: ``[3, 4]`` and ``[4, 3]`` give the same
    index. ``pair`` may be a list or a tuple.

    Args:
        pair: two labels, e.g. ``[3, 4]``.
        neighbors: sequence of 2-element pairs to search. Defaults to the
            module-level ``NEIGHBORS`` imported from ``exch_gym_env``.

    Returns:
        int: position of the first entry of ``neighbors`` that matches ``pair``.

    Raises:
        ValueError: if ``pair`` matches no entry in either orientation.
    """
    if neighbors is None:
        neighbors = NEIGHBORS
    # Compare as tuples so list and tuple inputs are treated alike.
    key = tuple(pair)
    for idx, (i, j) in enumerate(neighbors):
        if key in ((i, j), (j, i)):
            return idx
    raise ValueError(
        f"{pair!r} is not a neighbour pair; expected one of {list(neighbors)!r}"
    )

# Replay the reference gate sequence, printing reward and fidelities per step.
# TODO: this cell only prints; add assertions (e.g. on the final fidelities)
# to make it an actual test.
env = ExchangeCNOTEnvDI(max_depth=30, obs_mode="block")
try:
    obs = env.reset()
    cum_r = 0.0
    print("step | pair | p        | r   | fid64   | fid9")
    for k, (p, pair) in enumerate(gate_specs, 1):
        ts = env.step({"action_type": pair_to_index(pair), "action_args": [p]})
        cum_r += ts.reward
        print(f"{k:4d} | {pair} | {p:+.6f} | {ts.reward:+.3f} | "
              f"{ts.info['fid64']:.6f} | {ts.info['fid9']:.6f}")
        if ts.done:  # the env ended the episode before the sequence ran out
            break

    print("-"*64)
    print(f"terminated: {ts.done}   total reward: {cum_r:+.3f}")
    print(f"final fidelities  F64={ts.info['fid64']:.6f}  F9={ts.info['fid9']:.6f}")
finally:
    # Close the env even if a step or the summary raises, so a failed run
    # does not leave it open in the kernel.
    env.close()




  register_for_torch(TreeValue)
  register_for_torch(FastTreeValue)


step | pair | p        | r   | fid64   | fid9
   1 | [3, 4] | +1.695913 | -1.000 | 0.229334 | 0.273391
   2 | [4, 5] | +0.108173 | -1.000 | 0.223945 | 0.267107
   3 | [2, 3] | +0.500000 | -1.000 | 0.177044 | 0.215730
   4 | [3, 4] | +1.000000 | -1.000 | 0.110696 | 0.149005
   5 | [2, 3] | -0.500000 | +1.500 | 0.114028 | 0.197360
   6 | [4, 5] | -0.500000 | +1.000 | 0.127438 | 0.142548
   7 | [1, 2] | +1.000000 | +8.500 | 0.290135 | 0.270230
   8 | [3, 4] | -0.500000 | +1.000 | 0.291383 | 0.217327
   9 | [2, 3] | -0.500000 | -1.000 | 0.263663 | 0.161522
  10 | [4, 5] | +1.000000 | +6.500 | 0.334522 | 0.401898
  11 | [1, 2] | -0.500000 | -5.750 | 0.221102 | 0.205846
  12 | [3, 4] | +0.500000 | -1.500 | 0.275631 | 0.288777
  13 | [2, 3] | -0.500000 | -4.250 | 0.263796 | 0.201068
  14 | [4, 5] | +1.000000 | -4.500 | 0.224025 | 0.109072
  15 | [1, 2] | +1.000000 | +15.750 | 0.448050 | 0.375983
  16 | [3, 4] | -0.500000 | -5.000 | 0.370208 | 0.320141
  17 | [2, 3] | -0.500000 | -5.250 | 0.27