<a href="https://colab.research.google.com/github/uzum4ke/Topological-Quantum-Compilation/blob/main/RL_for_TQC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install stable-baselines3[extra]

In [None]:
pip install sb3-contrib

In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
!pip install cloud-tpu-client
!pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html


In [23]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Report the XLA device handle for this runtime
print("TPU device:", xm.xla_device())

# Report whether a CUDA GPU is visible as well
if not torch.cuda.is_available():
    print("GPU is not available. Make sure to enable GPU in the runtime settings.")
else:
    print("GPU is available.")
    print(f"Device name: {torch.cuda.get_device_name(0)}")



TPU device: xla:0
GPU is not available. Make sure to enable GPU in the runtime settings.


In [24]:
import gymnasium as gym
from gymnasium import spaces

#from scipy.linalg import sqrtm

from stable_baselines3 import PPO, DQN
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.callbacks import BaseCallback

import matplotlib.pyplot as plt

from tqdm import tqdm

import numpy as np

import random

from sympy import Matrix, symbols, eye, KroneckerProduct

import warnings

# Suppress RuntimeWarnings
# NOTE(review): this filter is process-wide and hides every RuntimeWarning, which
# includes NumPy's divide-by-zero / invalid-value warnings. Numerical problems in the
# reward computation will therefore surface only as inf/nan values, not as warnings.
warnings.filterwarnings("ignore", category=RuntimeWarning)

In [25]:
######################################
# Configuration and Hyperparameters
######################################

LEARNING_RATE = 1e-4  # Learning rate handed to the RL model constructor
ALPHA = 0.25  # Leakage weight
BETA = 0.50   # Closeness weight
GAMMA = 0.25  # Unitarity weight (a reward-term weight, not an RL discount factor)
SEQUENCE_LENGTH = 100  # Number of compositions (passed to the environment as max_length)

In [26]:
def direct_sum(A, B):
    """Return the block-diagonal matrix with A in the upper-left and B in the lower-right.

    Either operand may be a SymPy ``Matrix`` or a scalar; a scalar is promoted
    to a 1x1 ``Matrix`` before the blocks are assembled.

    Raises:
        ValueError: if either operand (after promotion) is not square.
    """
    # Promote both operands first, then validate them in A-then-B order.
    upper, lower = (
        operand if isinstance(operand, Matrix) else Matrix([[operand]])
        for operand in (A, B)
    )

    for label, block in (("A", upper), ("B", lower)):
        if block.rows != block.cols:
            raise ValueError(f"Matrix {label} is not square: {block.rows}x{block.cols}")

    n, m = upper.rows, lower.rows
    total = n + m

    # Start from zeros so the off-diagonal blocks stay empty.
    result = Matrix.zeros(total, total)
    result[:n, :n] = upper
    result[n:total, n:total] = lower
    return result

def tensor_product(A, B):
    """Return the SymPy ``KroneckerProduct`` object built from A and B (in that order)."""
    product = KroneckerProduct(A, B)
    return product

In [27]:
# R matrix: numeric diagonal 2x2 matrix of unit-modulus phases
R_matrix = np.array([
    [ np.exp(-4j * np.pi / 5), 0                      ],
    [ 0                      , np.exp(3j * np.pi / 5) ]
])

# Symbolic stand-ins for the two diagonal entries of R
R_tt1 = symbols("R_tt1")  # Top-left diagonal
R_ttt = symbols("R_ttt")  # Bottom-right diagonal

sym_R = Matrix([
    [R_tt1, 0],
    [0, R_ttt]
])

# F matrix: numeric real symmetric 2x2 matrix expressed through the golden ratio
phi = (1 + np.sqrt(5)) / 2  # Golden ratio
F_matrix = np.array([
    [ 1/phi          , np.sqrt(1/phi) ],
    [ np.sqrt(1/phi) , -1/phi         ]
])

# Symbolic stand-ins for the four entries of F
F_11 = symbols("F_11")
F_12 = symbols("F_12")
F_21 = symbols("F_21")
F_22 = symbols("F_22")

sym_F = Matrix([
    [F_11, F_12],
    [F_21, F_22]
])

# Substitution dictionary: symbol -> numeric entry at the matching (row, column).
# F_12 / F_21 now use their own indices; F_matrix is symmetric, so values are unchanged.
subs = {
    R_tt1: R_matrix[0, 0],
    R_ttt: R_matrix[1, 1],
    F_11: F_matrix[0, 0],
    F_12: F_matrix[0, 1],
    F_21: F_matrix[1, 0],
    F_22: F_matrix[1, 1]
}

# 2x2 identity used in the Kronecker products of the braid representations
I_2 = eye(2)

# Permutation matrix exchanging rows 0 and 3 of a 5x5 identity.
# NOTE: row_swap mutates I_5 in place and P14 is the same object, so after this
# cell I_5 is the permutation matrix, not the identity.
I_5 = eye(5)
I_5.row_swap(0, 3)
P14 = I_5


In [28]:
# Braid representations
#
# Each rho_k is a symbolic 5x5 matrix: a 1x1 block holding R_ttt placed on the
# diagonal next to a 4x4 block (see direct_sum). The symbols are replaced by numbers
# later via the `subs` dictionary. `.doit()` is called on the KroneckerProduct object
# returned by tensor_product before it is handed to direct_sum.

# R on the first 2x2 factor, identity on the second
rho_1 = direct_sum(R_ttt, tensor_product(sym_R, I_2).doit())

# F R F on the first 2x2 factor, identity on the second
rho_2 = direct_sum(R_ttt, tensor_product(sym_F * sym_R * sym_F, I_2).doit())

# Block-diagonal (R_ttt, R, F R F) conjugated by the permutation P14
rho_3 = P14 * direct_sum(R_ttt, direct_sum(sym_R, sym_F * sym_R * sym_F)) * P14

# Identity on the first 2x2 factor, F R F on the second
rho_4 = direct_sum(R_ttt, tensor_product(I_2, sym_F * sym_R * sym_F).doit())

# Identity on the first 2x2 factor, R on the second
rho_5 = direct_sum(R_ttt, tensor_product(I_2, sym_R).doit())

# CNOT gate: 4x4 permutation matrix that exchanges the last two basis states
cnot_gate = Matrix([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 0, 1],
    [0, 0, 1, 0]
])

In [29]:
class GateApproxEnv(gym.Env):
    """Gymnasium environment that composes braid matrices to approximate a target gate.

    State: the running product of the chosen 5x5 braid matrices, starting from
    the identity. The lower-right 4x4 block of that product is compared with
    the 4x4 target gate.

    Action: an index into ``braid_gates``; the selected matrix is
    right-multiplied onto the running product.

    Observation: a 50-vector holding the flattened real parts followed by the
    flattened imaginary parts of the 5x5 product.

    Reward: the negative weighted sum of a leakage penalty, a closeness error
    and a unitarity error (weights ``alpha``, ``beta``, ``gamma``). Non-final
    steps receive 1% of that value.

    Episode length: drawn uniformly from ``[1, max_length]`` on every reset,
    using the environment's seeded generator.
    """

    metadata = {'render_modes': ['human']}

    def __init__(self, braid_gates, target_gate, subs, max_length,
                 alpha, beta, gamma, local_equivalence_class=False):
        """Build the environment.

        Args:
            braid_gates: iterable of SymPy matrices (5x5) available as actions.
            target_gate: SymPy matrix (4x4) to approximate.
            subs: mapping from SymPy symbols to numeric values.
            max_length: maximum number of compositions per episode.
            alpha: weight of the leakage term in the reward.
            beta: weight of the closeness term in the reward.
            gamma: weight of the unitarity term in the reward.
            local_equivalence_class: if True, closeness is measured with
                Makhlin invariants; otherwise with a normalized Schatten-2 distance.

        Raises:
            ValueError: if any symbol is left unsubstituted in a gate.
        """
        super().__init__()

        self.max_length = max_length
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.local_equivalence_class = local_equivalence_class

        # Precompute numeric versions of the braid gates and the target gate
        self.braid_gates = [
            self._to_numeric(g, subs, "Not all symbols substituted in a braid gate.")
            for g in braid_gates
        ]
        self.target_gate = self._to_numeric(
            target_gate, subs, "Not all symbols substituted in target_gate."
        )

        # Action space: one action per braid gate
        self.action_space = spaces.Discrete(len(self.braid_gates))

        # Observation: flattened real+imag parts of the 5x5 gate = 50-dim vector
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(50,), dtype=np.float64)

        # Define episode bookkeeping up front so step() before reset() cannot
        # fail on missing attributes; reset() overwrites both values.
        self.current_step = 0
        self.random_episode_length = max_length
        self.reset_composition()

    @staticmethod
    def _to_numeric(expr, subs, error_message):
        """Substitute ``subs`` into a SymPy matrix and return it as a complex ndarray.

        ``free_symbols`` is used so that symbols buried inside composite
        expressions (e.g. products) are detected, not only bare symbols.
        """
        evaluated = expr.subs(subs).evalf()
        if evaluated.free_symbols:
            raise ValueError(error_message)
        return np.array(evaluated.tolist(), dtype=complex)

    def reset_composition(self):
        """Initialize the gate composition as a numeric 5x5 identity matrix."""
        self.current_composition = np.eye(5, dtype=complex)
        self.current_length = 0
        self.gate_stack = []

    def reset(self, *, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        The episode length is drawn from ``self.np_random`` so that passing
        ``seed`` makes it reproducible.
        """
        super().reset(seed=seed)
        self.reset_composition()
        self.current_step = 0

        # Randomize the episode length between 1 and max_length (inclusive)
        self.random_episode_length = int(self.np_random.integers(1, self.max_length + 1))

        return self._get_obs(), {}

    def take_action(self, action):
        """Right-multiply the chosen braid gate onto the running composition.

        When ``max_length`` compositions have already been applied, a warning
        is printed and the state is left untouched.
        """
        if self.current_length >= self.max_length:
            print("Warning: Maximum composition length reached. No action taken.")
            return

        # Accept plain ints as well as NumPy integer scalars / 0-d arrays
        action = int(action)
        self.current_composition = self.current_composition @ self.braid_gates[action]
        self.gate_stack.append(f"{action}")
        self.current_length += 1

    def compute_reward(self):
        """Score the current composition.

        Returns:
            Tuple ``(leakage, unitarity_error, closeness_error, reward)`` of floats.
            ``leakage`` is ``|M[0, 0]|`` of the 5x5 composition, reported
            unchanged for callers. ``reward`` is the negative weighted sum
            of ``1 - |M[0, 0]|``, the closeness error and the unitarity error.
        """
        M = self.current_composition
        T = self.target_gate

        # Magnitude of the top-left entry; callers treat values near 1 as good
        leakage = np.abs(M[0, 0])

        # Lower-right 4x4 block that is compared with the target
        M_4x4 = M[1:5, 1:5]

        # Unitarity check: nuclear norm of (U^dagger U - I)
        UdagU = M_4x4.conjugate().T @ M_4x4
        unitarity_error = self.schatten_p_norm(UdagU - np.eye(4), 1)

        if self.local_equivalence_class:
            closeness_error = self.local_equivalence_distance(T, M_4x4)
        else:
            # Closeness: Schatten 2-norm distance between norm-normalized matrices
            A_normalized = M_4x4 / self.schatten_p_norm(M_4x4, 2)
            T_normalized = T / self.schatten_p_norm(T, 2)
            closeness_error = self.schatten_p_norm(A_normalized - T_normalized, 2)

        # Penalise the shortfall of |M[0,0]| from 1. Penalising |M[0,0]| itself
        # would push the agent away from the sequences the evaluation keeps
        # (those with |M[0,0]| close to 1).
        leakage_penalty = max(0.0, 1.0 - leakage)

        # Reward (negative weighted sum)
        reward = -(self.alpha * leakage_penalty
                   + self.beta * closeness_error
                   + self.gamma * unitarity_error)

        return float(leakage), float(unitarity_error), float(closeness_error), float(reward)

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``.

        The episode terminates once the randomly drawn episode length is
        reached; earlier steps receive the reward scaled by 0.01.
        """
        self.take_action(action)
        self.current_step += 1

        terminated = (self.current_step >= self.random_episode_length)
        truncated = False

        _, _, _, final_reward = self.compute_reward()
        reward = final_reward if terminated else final_reward * 0.01

        return self._get_obs(), reward, terminated, truncated, {}

    def _get_obs(self):
        """Return the real parts followed by the imaginary parts of the composition, flattened."""
        M = self.current_composition
        return np.concatenate([M.real.flatten(), M.imag.flatten()])

    def render(self, mode='human'):
        """No-op; the environment has no visual representation."""
        pass

    def close(self):
        """No-op; the environment holds no external resources."""
        pass

    def schatten_p_norm(self, T, p):
        """
        Compute the Schatten p-norm of a matrix T using its singular values.

        ``p == np.inf`` yields the largest singular value, ``p == 1`` the sum
        of singular values, and any other ``p`` the p-th root of the sum of
        p-th powers of the singular values.
        """
        # Ensure T is a NumPy array
        T = np.array(T, dtype=complex)

        # Compute singular values of T
        singular_values = np.linalg.svd(T, compute_uv=False)

        if p == np.inf:  # Special case for p = infinity
            schatten_norm = np.max(singular_values)
        elif p == 1:  # Special case for p = 1 (nuclear norm)
            schatten_norm = np.sum(singular_values)
        else:
            # General case for arbitrary p
            schatten_norm = (np.sum(singular_values**p))**(1/p)

        return schatten_norm

    def compute_makhlin_invariants(self, U):
        """Return the tuple ``(g_1, g_2, g_3)`` computed from the 4x4 matrix ``U``.

        With ``U_B = Q^dagger U Q`` and ``m_U = U_B^T U_B``:
        ``g_1`` and ``g_2`` are the real and imaginary parts of
        ``tr(m_U)^2 / (16 det U)``, and
        ``g_3 = (tr(m_U)^2 - tr(m_U^2)) / (4 det U)`` (returned as computed,
        i.e. possibly complex).
        """
        i = 1j  # Complex unit (sqrt(-1))
        Q = (1 / np.sqrt(2)) * np.array([
            [1,  0,  0,  i],
            [0,  i,  1,  0],
            [0,  i, -1,  0],
            [1,  0,  0, -i]
        ], dtype=complex)

        # Compute U_B = Q^\dagger U Q
        U_B = Q.conjugate().T @ U @ Q

        # Makhlin matrix: m_U = (U_B)^T U_B
        m_U = (U_B.T) @ U_B

        # Compute trace and related quantities
        tr_mU = np.trace(m_U)
        tr_mU2 = np.trace(m_U @ m_U)
        det_U = np.linalg.det(U)  # determinant of U

        # Guard the divisions below against a (near-)singular U
        if np.abs(det_U) < 1e-12:
            print("Warning: Determinant is very small, adding regularization.")
            det_U += 1e-12

        # Compute complex quantity: (tr^2(m_U) / (16 * det(U)))
        complex_val = (tr_mU**2) / (16.0 * det_U)

        g_1 = complex_val.real
        g_2 = complex_val.imag

        # g_3 = (tr^2(m_U) - tr(m_U^2)) / (4 * det(U))
        g_3 = ((tr_mU**2) - tr_mU2) / (4.0 * det_U)

        return (g_1, g_2, g_3)

    def local_equivalence_distance(self, E, U):
        """
        Compute the closeness measure d_E(U) using the Makhlin invariants.

        d_E(U) = sum_{i=1}^3 (Δg_i)^2, where Δg_i = |g_i(E) - g_i(U)|
        """
        # Compute Makhlin invariants for E and U
        gE = self.compute_makhlin_invariants(E)
        gU = self.compute_makhlin_invariants(U)

        # Sum of squared absolute differences of the three invariants
        return sum((abs(e - u))**2 for e, u in zip(gE, gU))



In [30]:
# Action set: the five braid generators; an action index k selects braid_gates[k].
braid_gates = [rho_1, rho_2, rho_3, rho_4, rho_5]

# The string literal below is never executed; it keeps an alternative ten-element
# gate set that also contains the matrix inverse of every generator.
"""
braid_gates = [
        rho_1,       rho_2,       rho_3,       rho_4,       rho_5,
        rho_1.inv(), rho_2.inv(), rho_3.inv(), rho_4.inv(), rho_5.inv()
    ]
"""

# Gate that the composed braid is compared against
target_gate = cnot_gate

In [31]:
# Instantiate the environment (the RL model itself is created in a later cell).
# local_equivalence_class=True makes the closeness term use the Makhlin-invariant
# distance instead of the normalized Schatten-norm distance.
env = GateApproxEnv(
    braid_gates=braid_gates,
    target_gate=target_gate,
    subs=subs,
    max_length=SEQUENCE_LENGTH,
    alpha=ALPHA,
    beta=BETA,
    gamma=GAMMA,
    local_equivalence_class=True
)

In [None]:
total_timesteps = 1000000
verbose = 1

# Optional deeper network; it only takes effect if passed to DQN via policy_kwargs below.
policy_kwargs = dict(net_arch=[256 for _ in range(9)], activation_fn=torch.nn.GELU)

#device = xm.xla_device() <--- Too Slow
# Fall back to CPU when no CUDA device is present instead of hard-coding "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

# RL discount factor for DQN. This is deliberately NOT the module-level GAMMA:
# GAMMA is the unitarity weight of the reward (0.25), and using it as the discount
# would make the agent almost ignore the terminal reward. 0.99 is the library default.
DISCOUNT_FACTOR = 0.99

model = DQN(
    "MlpPolicy",
    env,
    #policy_kwargs=policy_kwargs,
    #device=device,
    verbose=verbose,
    learning_rate=LEARNING_RATE,
    gamma=DISCOUNT_FACTOR,
    tensorboard_log="./tensorboard_logs/"
    )

model.learn(total_timesteps=total_timesteps)


In [1]:
def evaluate_policy(env, model, num_sequences,
                    min_leakage=0.8, max_closeness=0.001, unitarity_floor=0.0001):
    """Sample episodes from ``model`` and print the gate sequences that pass a quality filter.

    For every episode the policy is rolled out (stochastically) until the
    environment reports termination or truncation; the final composition is
    then scored with ``env.compute_reward()``. A row is printed only when the
    reported leakage value exceeds ``min_leakage`` and the closeness error is
    below ``max_closeness``.

    Args:
        env: environment exposing ``reset``, ``step``, ``gate_stack``,
            ``compute_reward`` and ``max_length``.
        model: object exposing ``predict(obs, deterministic=...)``.
        num_sequences: number of episodes to sample.
        min_leakage: lower bound (exclusive) on the leakage value for a row to be printed.
        max_closeness: upper bound (exclusive) on the closeness error for a row to be printed.
        unitarity_floor: unitarity errors below this value are displayed as 0.

    Returns:
        None. Results are printed as a table.
    """
    # Size the Operator column from the environment actually being evaluated
    # (a sequence can never be longer than env.max_length); at least 10 wide.
    operator_column_width = max(env.max_length, 10)

    # Table header and divider
    print(f"{'Operator':<{operator_column_width}} | {'Closeness':<10} | {'Leakage':<10} | {'Unitarity':<10}")
    print("-" * (operator_column_width + 34))

    for _ in tqdm(range(num_sequences)):
        obs, _info = env.reset()
        done = False

        while not done:
            # Sample an action from the trained policy
            action, _state = model.predict(obs, deterministic=False)
            obs, _reward, terminated, truncated, _step_info = env.step(action)
            done = terminated or truncated

        # Sequence of action indices applied in this episode, as one string
        sequence = ''.join(map(str, env.gate_stack))

        leakage, unitarity_error, closeness_error, _ = env.compute_reward()

        # Display numerically negligible unitarity errors as exactly zero
        if unitarity_error < unitarity_floor:
            unitarity_error = 0.0

        if (leakage > min_leakage) and (closeness_error < max_closeness):
            print(f"{sequence:<{operator_column_width}} | {closeness_error:<10.4e} | {leakage:<10.3f} | {unitarity_error:<10.3e}")

In [None]:
# Roll out 10,000 episodes with the trained model; only sequences passing the
# filter inside evaluate_policy are printed.
evaluate_policy(env, model, num_sequences=10000)

In [34]:
# The string literal below is never executed. It keeps an older list of sequences
# that use action digits up to 9, which only index valid gates when the environment
# is built with ten braid gates.
"""
test = ["30",
        "880",
        "5300",
        "48440",
        "553000",
        "4334300",
        "37770000",
        "441001048",
        "2555337582",
        "59672955693",
        "374050007970",
        "306947382105",
        "6595888969003",
        "75139699757375",
        "78450807782427077787",
        "287430798424700981936977"
        ]
"""
# Hand-picked gate sequences to score. Each character is one action index
# (0-4 here), applied left to right by construct_gate_from_string.
test = [
    "000",
    "0000",
    "00000",
    "222000",
    "0000000",
    "00000000",
    "000000000",
    "2221001222",
    "22210012220",
    "223443100122",
    "2213404301242",
    "34224334310031224",
    "242314040312211412221",
    "422211222020214000112021",
    "101422102223130431111322222",
    "31423322434332420442310301422",
    "222223043422422024043333320003",
    "2344130440331032213334400344312",
    "423330314001220244224333334034032",
    "132424244140142042111112011202122020400",
    "231333001244422220433403402422343302043"
]

In [35]:
def construct_gate_from_string(string_of_numbers, env):
    """Apply every character of ``string_of_numbers`` to ``env`` as an integer action.

    Characters are processed left to right, each converted with ``int`` and
    passed to ``env.take_action``. The environment is not reset here.

    Returns:
        The same ``env`` object, after all actions were applied.
    """
    for digit in list(string_of_numbers):
        env.take_action(int(digit))

    return env


In [36]:
# Build a fresh environment for scoring the hand-written sequences in `test`
env = GateApproxEnv(
    braid_gates=braid_gates, target_gate=target_gate, subs=subs,
    max_length=SEQUENCE_LENGTH,
    alpha=ALPHA, beta=BETA, gamma=GAMMA,
    local_equivalence_class=True,
)

# Operator column: as wide as the longest sequence, but never narrower than 10
max_operator_width = max(map(len, map(str, test)))
operator_column_width = max(max_operator_width, 10)

# Header row, then a divider sized to the full table width
print(f"{'Operator':<{operator_column_width}} | {'Closeness':<10} | {'Leakage':<10} | {'Unitarity':<10}")
print("-" * (operator_column_width + 34))

# Score each sequence from a clean composition and print one table row per sequence
for t in test:
    env.reset_composition()
    gate_object = construct_gate_from_string(t, env)
    leakage, unitarity_error, closeness_error, total_error = gate_object.compute_reward()

    # Show numerically negligible unitarity errors as exactly zero
    unitarity_error = 0.0 if unitarity_error < 0.0001 else unitarity_error

    print(f"{t:<{operator_column_width}} | {closeness_error:<10.3e} | {leakage:<10.3f} | {unitarity_error:<10.3e}")


Operator                                | Closeness  | Leakage    | Unitarity 
-------------------------------------------------------------------------
000                                     | 5.000e+00  | 1.000      | 0.000e+00 
0000                                    | 5.000e+00  | 1.000      | 0.000e+00 
00000                                   | 5.000e+00  | 1.000      | 0.000e+00 
222000                                  | 2.908e+00  | 0.954      | 9.017e-02 
0000000                                 | 5.000e+00  | 1.000      | 0.000e+00 
00000000                                | 5.000e+00  | 1.000      | 0.000e+00 
000000000                               | 5.000e+00  | 1.000      | 0.000e+00 
2221001222                              | 4.635e-01  | 0.976      | 4.760e-02 
22210012220                             | 4.635e-01  | 0.976      | 4.760e-02 
223443100122                            | 2.169e-05  | 0.976      | 4.760e-02 
2213404301242                           | 2.169e-05  | 0.