In [None]:
from ding.entry import serial_pipeline
from config import main_config, create_config

serial_pipeline([main_config, create_config], seed=42, max_env_step=25000,)

In [11]:


# ───────────────────────── manual_inference.py ──────────────────────────
import os
import torch
import numpy as np
from easydict import EasyDict as edict

# 1)  DI-engine PDQN policy
from ding.policy.pdqn import PDQNPolicy          # or pdqn_command
# 2)  Your custom environment
from exch_gym_env import ExchangeCNOTEnvDI

# ─── change these two paths as needed ───────────────────────────────────
CKPT1 = "pdqn_exchange_cnot_250605_010501/ckpt/iteration_12.pth.tar"
CKPT2 = "pdqn_exchange_cnot_250605_010501/ckpt/ckpt_best.pth.tar"
# ------------------------------------------------------------------------

def run_checkpoint(ckpt_path: str) -> None:
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(ckpt_path)

    # -------- build a *minimal* PDQN policy object ----------------------
    cfg = edict(
        type="pdqn",
        cuda=torch.cuda.is_available(),
        on_policy=False,
        model=dict(
            obs_shape=168,                      # <-- keep same as training
            action_shape=edict(
                action_type_shape=5,
                action_args_shape=1,
                encoder_hidden_size_list=[256, 256, 256],
            ),
        ),
        collect=dict(n_sample=0),
        other=dict(replay_buffer=dict(replay_buffer_size=1)),
        model_load_mode="ckpt",
        load_path=ckpt_path,
    )

    policy  = PDQNPolicy(cfg, enable_field=["eval"])
    device  = torch.device("cuda" if cfg.cuda else "cpu")

    # -------- load weights ---------------------------------------------
    state = torch.load(ckpt_path, map_location="cpu")
    policy._model.load_state_dict(state["model"], strict=False)
    if hasattr(policy, "_target_model"):
        policy._target_model.load_state_dict(state["model"], strict=False)

    model = policy._eval_model      # wrapped with HybridArgmaxSampleWrapper
    model.eval()                    # no dropout / batch-norm
    model.to(device)

    # -------- evaluation episode ---------------------------------------
    env   = ExchangeCNOTEnvDI(use_act_scale=True)
    obs   = env.reset()                                # numpy array (168,)
    seq   = []                                         # [(pair_idx, p), …]
    total = 0.0
    done  = False

    while not done and len(seq) < env.max_depth:
        # --- observation tensor (shape 1 × N) --------------------------
        obs_t = torch.as_tensor(obs, dtype=torch.float32,
                                device=device).unsqueeze(0)

        # --- writable action-mask tensor (shape 1 × 5) -----------------
        mask_np = np.asarray(env.valid_mask, dtype=np.float32).copy()
        mask_t  = torch.from_numpy(mask_np).to(device).unsqueeze(0)

        with torch.no_grad():
            # 1) continuous argument
            cont_out    = model.forward(obs_t, mode="compute_continuous")
            action_args = cont_out["action_args"]               # (1, 1)

            # 2) discrete choice with mask
            inputs  = {
                "state":       obs_t,
                "action_args": action_args,
                "action_mask": mask_t,          # the crucial piece
            }
            dis_out = model.forward(inputs, mode="compute_discrete")
            logits = dis_out["logit"]             # shape: (1, 5)
            probs  = torch.softmax(logits, dim=-1)
            pair_idx = torch.multinomial(probs, num_samples=1).item()
            p_value  = float(action_args[0])

        act = {"action_type": torch.tensor([pair_idx]), "action_args": torch.tensor([p_value])}
        pair_idx  = int(act["action_type"][0])
        p_value   = float(act["action_args"][0])

        seq.append((pair_idx, p_value))
        obs, reward, done, info = env.step((pair_idx, p_value))
        total += reward

    # -------- print results --------------------------------------------
    print(f"\n=== BEST-PATH SEQUENCE  ({os.path.basename(ckpt_path)}) ===\n")
    for t, (idx, p) in enumerate(seq, 1):
        print(f"Step {t:2d}: pair = {idx},  p = {p:+.4f}")

    print("\nFinal metrics:")
    print(f" 64×64 fidelity    : {info.get('fid64', np.nan):.6f}")
    print(f"  9×9 block fidelity: {info.get('fid9',  np.nan):.6f}")
    print(f" Total return       : {total:.6f}")

# ─── run the two checkpoints -------------------------------------------
run_checkpoint(CKPT1)
run_checkpoint(CKPT2)


=== BEST-PATH SEQUENCE  (iteration_12.pth.tar) ===

Step  1: pair = 0,  p = -0.0144
Step  2: pair = 1,  p = -0.0152
Step  3: pair = 3,  p = -0.0125
Step  4: pair = 2,  p = -0.0159
Step  5: pair = 1,  p = -0.0174
Step  6: pair = 2,  p = -0.0156
Step  7: pair = 0,  p = -0.0194
Step  8: pair = 3,  p = -0.0199
Step  9: pair = 2,  p = -0.0218
Step 10: pair = 0,  p = -0.0222
Step 11: pair = 2,  p = -0.0217
Step 12: pair = 3,  p = -0.0238
Step 13: pair = 3,  p = -0.0257
Step 14: pair = 2,  p = -0.0251
Step 15: pair = 1,  p = -0.0237
Step 16: pair = 2,  p = -0.0213
Step 17: pair = 3,  p = -0.0223
Step 18: pair = 3,  p = -0.0226

Final metrics:
 64×64 fidelity    : nan
  9×9 block fidelity: nan
 Total return       : -52.680750

=== BEST-PATH SEQUENCE  (ckpt_best.pth.tar) ===

Step  1: pair = 2,  p = +0.0414
Step  2: pair = 1,  p = +0.0424
Step  3: pair = 4,  p = +0.0440
Step  4: pair = 0,  p = +0.0434
Step  5: pair = 3,  p = +0.0406
Step  6: pair = 0,  p = +0.0454
Step  7: pair = 4,  p = +0.04

In [1]:
"""
Replay the forward-target gate list and print reward + fidelity each step.
"""

import math
import jax.numpy as jnp

from exch_gym_env import ExchangeCNOTEnvDI, _fidelity    # env + helper
from fw_target       import gate_specs, U_circuit        # sequence + target

# -----------------------------------------------------------------------------
# The environment hard-codes five nearest-neighbour pairs in this order:
# [(0,1), (1,2), (2,3), (3,4), (4,5)].
# We invert that list once so each (i, j) → correct discrete index.
PAIR_TO_INDEX = {(0, 1): 0, (1, 2): 1, (2, 3): 2, (3, 4): 3, (4, 5): 4}

# -----------------------------------------------------------------------------
def main():
    env = ExchangeCNOTEnvDI(
        max_depth=len(gate_specs) + 2,   # small cushion
        obs_mode="both"                  # whatever you like → "block"|"full"|"both"
    )
    env.reset()

    print("\nReplay of fw_target gate list:")
    print("step |  pair  |      p       |  reward   | fidelity")
    print("-----+--------+--------------+-----------+----------")

    for step_idx, (p, (i, j)) in enumerate(gate_specs, 1):
        pair_idx = PAIR_TO_INDEX[(i, j)]
        _, reward, done, _ = env.step((pair_idx, p))

        # compute fidelity wrt the published target circuit
        fid_now = float(_fidelity(jnp.asarray(env.U), jnp.asarray(U_circuit)))

        print(f"{step_idx:4d} | {i}-{j}  | {p:+.6f} | {reward:+9.3f} | {fid_now:8.6f}")

        if done:
            print(f"\nEpisode terminated early by reward logic at step {step_idx}.\n")
            break

    # final sanity check
    if not done:
        frob_dist = jnp.linalg.norm(env.U - U_circuit)
        assert math.isclose(fid_now, 1.0, abs_tol=1e-6) and frob_dist < 1e-6, \
            "Replay finished but circuit differs from target!"

    env.close()


if __name__ == "__main__":
    main()


  register_for_torch(TreeValue)
  register_for_torch(FastTreeValue)



Replay of fw_target gate list:
step |  pair  |      p       |  reward   | fidelity
-----+--------+--------------+-----------+----------
[debug] step 1  raw_reward=13.0
   1 | 3-4  | +1.695913 |   +13.000 | 0.067011
[debug] step 2  raw_reward=5.0
   2 | 4-5  | +0.108173 |    +5.000 | 0.064629
[debug] step 3  raw_reward=-3.0
   3 | 2-3  | +0.500000 |    -3.000 | 0.046148
[debug] step 4  raw_reward=-3.0
   4 | 3-4  | +1.000000 |    -3.000 | 0.027403
[debug] step 5  raw_reward=15.0
   5 | 2-3  | -0.500000 |   +15.000 | 0.028145
[debug] step 6  raw_reward=-1.0
   6 | 4-5  | -0.500000 |    -1.000 | 0.031295
[debug] step 7  raw_reward=15.5
   7 | 1-2  | +1.000000 |   +15.500 | 0.097912
[debug] step 8  raw_reward=3.0
   8 | 3-4  | -0.500000 |    +3.000 | 0.098648
[debug] step 9  raw_reward=-6.0
   9 | 2-3  | -0.500000 |    -6.000 | 0.083570
[debug] step 10  raw_reward=12.33772233983162
  10 | 4-5  | +1.000000 |   +12.338 | 0.125226
[debug] step 11  raw_reward=-6.3166247903554
  11 | 1-2  | -0

In [2]:
test_gate_sequence_runs_cleanly()

  pair_idx, p = int(action[0]), float(action[1])


[debug] step 1  raw_reward=13.0
step 01 pair=3 p=+1.695913  reward=+13.000000  fid64=0.067011  fid9=0.155245
[debug] step 2  raw_reward=5.0
step 02 pair=4 p=+0.108173  reward=+5.000000  fid64=0.064629  fid9=0.152741
[debug] step 3  raw_reward=97.0
Episode finished early at step 3


In [18]:
def _fidelity(A: np.ndarray, B: np.ndarray) -> float:

    """
    Computes the fidelity between two square matrices A and B.
    Fidelity is defined as:
    F(A, B) = (Tr(A†A) + |Tr(B†A)|²) / (n * (n + 1))

    updated from using just falttened inner product divided by product of norms.

    # inner = jnp.vdot(A, B)
    # return jnp.abs(inner) / (jnp.linalg.norm(A) * jnp.linalg.norm(B))

    """

    if A.shape != B.shape or A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("A and B must be square matrices of the same size")

    n = A.shape[0]

    # --- core computation ----------------------------------------------------
    term1 = np.trace(A.conj().T @ A)                 # Tr(A†A)
    term2 = np.trace(B.conj().T @ A)                 # Tr(B†A)
    fidelity_val = (term1 + abs(term2)**2) / (n * (n + 1))

    return fidelity_val

In [19]:
from fw_target import U_circuit as TARGET_FULL


print(_fidelity(TARGET_FULL, TARGET_FULL))  # should be 1.0

(0.9954603734941675+8.045487353259573e-14j)
