In [39]:
from typing import Tuple, List, Dict, Any

from rdkit import Chem
from rdkit.Chem import AllChem

import numpy as np
from math import ceil, log2
import re
import pandas as pd
import optax
import csv
import flax.linen as nn
import math

import pennylane as qml
from pennylane import qchem
from pennylane.templates import StronglyEntanglingLayers

import jax
import jax.numpy as jnp
from jax.nn.initializers import normal

import haiku as hk

In [40]:
from chembl_webresource_client.new_client import new_client

# Using the ChEMBL API to get the molecules dataset

molecule = new_client.molecule

# Filter for drug-like small molecules interesting for human use
druglike_molecules = molecule.filter(
    molecule_properties__heavy_atoms__lte=15,           # Heavy atoms less than 20
    molecule_properties__alogp__lte=5,                  # LogP less than 5 (Lipophilicity and membrane permeability)
    molecule_properties__mw_freebase__lte=300,          # Molecular weight less than 300 g/mol
    molecule_properties__qed_weighted__gte=0.5,         # QED weighted greater than 0.5 (Drug-likeness)
    molecule_properties__num_ro5_violations__lte=1,     # At most 1 Rule of 5 violation (Drug-likeness filter)

)

print("Training molecules set: ", len(druglike_molecules))  # Check how many molecules match the filter criteria

Training molecules set:  65778


## 1. Get the alphabet used for the SMILEs representation

Build the alphabet considering the structure of some special tokens

In [41]:
# Regex sin grupos capturantes y con reconocimiento de átomos minúsculos

TOKEN_PATTERN = re.compile(
    r'\[[^\[\]]+\]'      # atoms in brackets (e.g. [C@@H])
    r'|\([A-Za-z0-9]{1,3}\)'  # short groups in parentheses (1–3 chars only)
    r'|Br|Cl'            # 2-letter atoms (e.g. Br, Cl)
    r'|[A-Z](?![a-z])'   # capital letters (e.g. C, N, O)
    r'|[a-z]'            # lowercase letters (e.g. c, n, o, s)
    r'|\d'               # ring numbers (e.g. 1, 2, 3)
    r'|=|#|\/|\\|\(|\)'  # bonds and standalone parentheses
)


def tokenize_smiles(smiles: str):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"SMILES inválido: {smiles}")
    can = Chem.MolToSmiles(mol, canonical=True) # canonical=True → consistent representation (better for building vocab, training reproducibility)
    return TOKEN_PATTERN.findall(can)


In [42]:
# Define the subset of molecules we are going to train the model with
molecules_subset = druglike_molecules[:600]

max_len = 0
alphabet = set()

# We use a subset of molecules to build the alphabet
for mol in molecules_subset:  # Limiting to 1000 molecules for performance
    smiles = mol.get('molecule_structures', {}).get('canonical_smiles')
    if smiles:
        tokens = tokenize_smiles(smiles)
        if max_len < len(tokens):
            max_len = len(tokens)
        alphabet.update(tokens)

alphabet = sorted(alphabet)
alphabet = ['<SOS>'] + alphabet  + ['<EOS>'] # Add Start-of-Secuence, End-of-secuence


In [43]:
print("Alphabet of SMILES characters:", alphabet)

VOCABULARY_SIZE = len(alphabet)
BITS_PER_TOKEN = ceil(log2(VOCABULARY_SIZE))  # n bits por token

print("Total unique characters in SMILES:", VOCABULARY_SIZE)
print("Maximum length of SMILES in dataset:", max_len)
print("Bits per token:", BITS_PER_TOKEN)

Alphabet of SMILES characters: ['<SOS>', '#', '(', '(Br)', '(C)', '(C2)', '(CC)', '(CC1)', '(CC2)', '(CCC)', '(CCN)', '(CCO)', '(CN)', '(CNC)', '(CO)', '(CS)', '(Cl)', '(F)', '(I)', '(N)', '(O)', '(O1)', '(O2)', '(OC)', '(c1)', '(c12)', '(c23)', '(n1)', ')', '/', '1', '2', '3', '=', 'Br', 'C', 'Cl', 'F', 'I', 'N', 'O', 'P', 'S', '[C@@H]', '[C@@]', '[C@H]', '[C@]', '[Cl-]', '[Li+]', '[N+]', '[Na+]', '[O-]', '[PH]', '[S+]', '[nH]', '\\', 'c', 'n', 'o', 's', '<EOS>']
Total unique characters in SMILES: 61
Maximum length of SMILES in dataset: 31
Bits per token: 6


•	Símbolos de enlaces y paréntesis: #, (, ), /, \, =

•	Dígitos simples para cierres de anillos: '1', '2', '3', '4', '5'

•	Átomos orgánicos y halógenos comunes, tanto mayúsculas (alifáticos) como minúsculas (aromáticos)

•	Tokens entre corchetes para isótopos, estados de carga, quiralidad, etc.

•	Un token especial '<PAD>' para padding en modelos ML

In [44]:
# Diccionario token → índice
token_to_index = {tok: i for i, tok in enumerate(alphabet)}

def print_token_bits(tokens, token_to_index):
    for tok in tokens:
        idx = token_to_index.get(tok, None)
        if idx is None:
            print(f"Token '{tok}' no está en el diccionario.")
            continue
        binary = format(idx, f'0{BITS_PER_TOKEN}b')
        print(f"'{tok}' → index {idx} → {binary}")

print_token_bits(alphabet, token_to_index)

'<SOS>' → index 0 → 000000
'#' → index 1 → 000001
'(' → index 2 → 000010
'(Br)' → index 3 → 000011
'(C)' → index 4 → 000100
'(C2)' → index 5 → 000101
'(CC)' → index 6 → 000110
'(CC1)' → index 7 → 000111
'(CC2)' → index 8 → 001000
'(CCC)' → index 9 → 001001
'(CCN)' → index 10 → 001010
'(CCO)' → index 11 → 001011
'(CN)' → index 12 → 001100
'(CNC)' → index 13 → 001101
'(CO)' → index 14 → 001110
'(CS)' → index 15 → 001111
'(Cl)' → index 16 → 010000
'(F)' → index 17 → 010001
'(I)' → index 18 → 010010
'(N)' → index 19 → 010011
'(O)' → index 20 → 010100
'(O1)' → index 21 → 010101
'(O2)' → index 22 → 010110
'(OC)' → index 23 → 010111
'(c1)' → index 24 → 011000
'(c12)' → index 25 → 011001
'(c23)' → index 26 → 011010
'(n1)' → index 27 → 011011
')' → index 28 → 011100
'/' → index 29 → 011101
'1' → index 30 → 011110
'2' → index 31 → 011111
'3' → index 32 → 100000
'=' → index 33 → 100001
'Br' → index 34 → 100010
'C' → index 35 → 100011
'Cl' → index 36 → 100100
'F' → index 37 → 100101
'I' → index 38

The input SMILEs must be the same size, so we need to use the padding to make it uniform

In [45]:
basis_encoded_dataset = []
token_to_index = {tok: i for i, tok in enumerate(alphabet)}

def smiles_to_bits(tokens: list, max_len: int) -> np.ndarray:
    """Convert tokens to a 2D array of shape (max_len, n_bits)"""
    padded_tokens = ['<SOS>'] + tokens + ['<EOS>']
    bit_matrix = []
    for tok in padded_tokens:
        idx = token_to_index[tok]
        bits = list(f"{idx:0{BITS_PER_TOKEN}b}")  # length of the binary string depends on the number of bits required to represent the alphabet
        bit_matrix.append([int(b) for b in bits])
    return np.array(bit_matrix)  # shape (max_len, n_bits)

for mol in molecules_subset:
    smiles = mol.get('molecule_structures', {}).get('canonical_smiles')
    if smiles:
        tokens = tokenize_smiles(smiles)
        print(f"Processing SMILES: {tokens}")
        if all(tok in token_to_index for tok in tokens):
            bit_matrix = smiles_to_bits(tokens, max_len)
            basis_encoded_dataset.append(bit_matrix)
            print(f"SMILES: {smiles}, shape: {bit_matrix}")



Processing SMILES: ['C', 'N', '1', 'C', 'C', 'C', '[C@H]', '1', 'c', '1', 'c', 'c', 'c', 'n', 'c', '1']
SMILES: CN1CCC[C@H]1c1cccnc1, shape: [[0 0 0 0 0 0]
 [1 0 0 0 1 1]
 [1 0 0 1 1 1]
 [0 1 1 1 1 0]
 [1 0 0 0 1 1]
 [1 0 0 0 1 1]
 [1 0 0 0 1 1]
 [1 0 1 1 0 1]
 [0 1 1 1 1 0]
 [1 1 1 0 0 0]
 [0 1 1 1 1 0]
 [1 1 1 0 0 0]
 [1 1 1 0 0 0]
 [1 1 1 0 0 0]
 [1 1 1 0 0 1]
 [1 1 1 0 0 0]
 [0 1 1 1 1 0]
 [1 1 1 1 0 0]]
Processing SMILES: ['c', '1', 'c', 'n', 'c', 'c', '(', '[C@@H]', '2', 'C', 'C', 'C', 'N', '2', ')', 'c', '1']
SMILES: c1cncc([C@@H]2CCCN2)c1, shape: [[0 0 0 0 0 0]
 [1 1 1 0 0 0]
 [0 1 1 1 1 0]
 [1 1 1 0 0 0]
 [1 1 1 0 0 1]
 [1 1 1 0 0 0]
 [1 1 1 0 0 0]
 [0 0 0 0 1 0]
 [1 0 1 0 1 1]
 [0 1 1 1 1 1]
 [1 0 0 0 1 1]
 [1 0 0 0 1 1]
 [1 0 0 0 1 1]
 [1 0 0 1 1 1]
 [0 1 1 1 1 1]
 [0 1 1 1 0 0]
 [1 1 1 0 0 0]
 [0 1 1 1 1 0]
 [1 1 1 1 0 0]]
Processing SMILES: ['C', 'C', '1', '(C)', '[C@H]', '(', 'C', '(', '=', 'O', ')', 'O', ')', 'N', '2', 'C', '(', '=', 'O', ')', 'C', '[C@H]', '2', 'S', '1'

## 2. Obtain the molecular properties of interest

logP (o cx_logp) -> Coeficiente de partición octanol/agua (lipofilia)

QED (quantitative estimate of drug-likeness) -> Escala combinada que evalúa qué tan “drug-like” es una molécula

SAS (Synthetic Accessibility Score) -> Qué tan difícil sería sintetizar la molécula en laboratorio * Needs to be calculated separately!!

MW (peso molecular) -> Masa total, típicamente ≤ 500 Da para buenos fármacos orales


First we get the max and min range values in order to later normalize the properties in the range (0, pi)

In [46]:
min_logp = float('inf')
max_logp = float('-inf')
min_qed = float('inf')
max_qed = float('-inf')
min_mw = float('inf')
max_mw = float('-inf')


# Iterate through the subset of molecules to find min/max properties to normalize them
for mol in molecules_subset:
    logP = mol.get('molecule_properties', {}).get('alogp')
    qed = mol.get('molecule_properties', {}).get('qed_weighted')
    mw = mol.get('molecule_properties', {}).get('mw_freebase')

    if logP is None or qed is None or mw is None:
        continue  # Skip if any property is missing

    logP = float(logP)
    qed = float(qed)
    mw = float(mw)

    if logP < min_logp:
        min_logp = logP
    if logP > max_logp:
        max_logp = logP

    if qed < min_qed:
        min_qed = qed
    if qed > max_qed:
        max_qed = qed

    if mw < min_mw:
        min_mw = mw
    if mw > max_mw:
        max_mw = mw

print(f"LogP range: {min_logp} to {max_logp}")
print(f"QED range: {min_qed} to {max_qed}")
print(f"MW range: {min_mw} to {max_mw}")

LogP range: -1.5 to 3.89
QED range: 0.5 to 0.92
MW range: 107.11 to 298.11


In [47]:
def normalize(value, min_val, max_val, target_max=np.pi):
    ''' Normalize a value to a range [0, target_max] based on min and max values.'''
    norm = (value - min_val) / (max_val - min_val) * target_max
    return float(f"{norm:.3f}")

In [48]:
print("Maximum length of sequences in the subset:", max_len)

Maximum length of sequences in the subset: 31


In [49]:
# Write the structured data to a CSV file
with open("structured_smiles_data.csv", mode="w", newline="") as file:
    writer = csv.writer(file)

    n_tokens = max_len + 2  # +2 for <SOS> and <EOS>
    header = ["logP", "qed", "mw"] + [f"token_{i}" for i in range(n_tokens)]
    writer.writerow(header)

    for mol in molecules_subset:
        smiles = mol.get('molecule_structures', {}).get('canonical_smiles')
        props = mol.get('molecule_properties', {})
        if not smiles:
            continue

        try:
            logP = float(props.get('alogp'))
            qed = float(props.get('qed_weighted'))
            mw = float(props.get('mw_freebase'))
        except (TypeError, ValueError):
            continue

        norm_logp = normalize(logP, min_logp, max_logp)
        norm_qed = normalize(qed, min_qed, max_qed)
        norm_mw = normalize(mw, min_mw, max_mw)

        tokens = tokenize_smiles(smiles)
        if not all(tok in token_to_index for tok in tokens):
            continue

        bit_matrix = smiles_to_bits(tokens, max_len)  # shape (n_tokens, 6)
        token_bits_as_strings = ["".join(map(str, row)) for row in bit_matrix]

        row = [norm_logp, norm_qed, norm_mw] + token_bits_as_strings
        writer.writerow(row)

## 3. Quantum Generative Model

In [50]:
import itertools
def zstring_combos(wires):
    """
    Return an ordered list of wire-tuples for all Z-strings up to order H_LOCAL.
    Order: all 1-local, then all 2-local, ..., up to H_LOCAL.
    """
    L = []
    for k in range(1, H_LOCAL + 1):
        L.extend(itertools.combinations(wires, k))
    return [tuple(c) for c in L]

def num_zstrings(n_wires):
    """Count how many Z-strings up to order H_LOCAL."""
    from math import comb
    return sum(comb(n_wires, k) for k in range(1, H_LOCAL + 1))

In [51]:
jax.config.update("jax_enable_x64", True)

In [52]:
# Device and qubit setup
# BITS_PER_TOKEN number of qubits needed to encode each token
n_prop_qubits = 3  # number of qubits needed to encode properties (logP, QED, MW)
n_ancillas = 3  # number of ancilla qubits that represent the environment
n_total_qubits = n_prop_qubits + BITS_PER_TOKEN + n_ancillas

N_LAYERS = 6  # number of variational layers
H_LOCAL = 3 # h_local sets the maximum number of qubits that can interact in each Z-string term of Σ


# Name them explicitly
prop_wires = [f"prop_{i}" for i in range(n_prop_qubits)]
token_wires = [f"token_{i}" for i in range(BITS_PER_TOKEN)]
ancilla_wires = [f"ancilla_{i}" for i in range(n_ancillas)]
all_wires = prop_wires + token_wires + ancilla_wires

dev = qml.device("default.qubit", wires=all_wires)


def molecular_property_encoder(props):
    """Encode continuous props on property qubits via RY rotations"""
    for wire, val in zip(prop_wires, props):
        qml.RY(val, wires=wire)


def token_encoder(token_bits):
    qml.BasisState(token_bits, wires=token_wires)


def operator_layer(theta_params, theta_prop, wires):
    """
    Variational layer where:
      - theta_params[...] are rotations for token + ancilla qubits
      - theta_prop encodes property→token entanglement
    """
    prop_ws = prop_wires
    token_ancilla_ws = token_wires + ancilla_wires
    
    # Property → token entanglement
    for p, prop_wire in enumerate(prop_ws):
        for t, t_a_wire in enumerate(token_ancilla_ws):
            # RZX + RY + RZ rotation controlled by property qubit
            '''qml.ctrl(qml.Rot, control=prop_wire)(
                theta_prop[p, t, 0],  # alpha
                theta_prop[p, t, 1],  # beta
                theta_prop[p, t, 2],  # gamma
                wires=t_a_wire
            )'''
            qml.CRX(theta_prop[p, t, 0], wires=[prop_wire, t_a_wire])
            qml.CRY(theta_prop[p, t, 1], wires=[prop_wire, t_a_wire])
            '''# Add extra phase coupling
            qml.CRZ(theta_prop[p, t, 2], wires=[prop_wire, t_a_wire])
            qml.IsingZZ(theta_prop[p, t, 3], wires=[prop_wire, t_a_wire])'''

    # Single-qubit rotations
    for i, wire in enumerate(token_ancilla_ws):
        qml.RX(theta_params[i, 0], wires=wire)
        qml.RY(theta_params[i, 1], wires=wire)
        qml.RZ(theta_params[i, 2], wires=wire)

    # Entangle token + ancilla chain
    for i in range(len(token_ancilla_ws) - 1):
        qml.CNOT(wires=[token_ancilla_ws[i], token_ancilla_ws[i + 1]])
    qml.CNOT(wires=[token_ancilla_ws[-1], token_ancilla_ws[0]])


'''def embedding_encoder(embedding_vector, wires):
    """Apply RY rotations based on embedding vector to token wires."""
    for i, w in enumerate(token_wires):
        rx_angle, ry_angle = embedding_vector[i]
        qml.RY(rx_angle, wires=w)
        qml.RX(ry_angle, wires=w)'''


def embedding_encoder_attention(embedding_vector, wires):
    token_ancilla_ws = token_wires + ancilla_wires
    n_wires = len(token_ancilla_ws)

    # Step 1: Local context rotations
    for i, w in enumerate(token_ancilla_ws):
        rx_angle, ry_angle = embedding_vector[i]
        qml.RX(rx_angle, wires=w)
        qml.RY(ry_angle, wires=w)
'''
    # Step 2: Data-controlled entanglement ("attention")
    for i in range(n_wires):
        for j in range(i+1, n_wires):
            control_angle = embedding_vector[i, 0] * embedding_vector[j, 1]
            qml.CRZ(control_angle, wires=[token_ancilla_ws[i], token_ancilla_ws[j]])'''


def Sigma_layer_vec(gamma_vec, wires, time=1.0, combos=None):
    """
    Diagonal multi-Z unitary Σ = exp(i * sum_s gamma_s * Z^{⊗|s|} * t)
    using a flat parameter vector 'gamma_vec' aligned with 'combos'.
    """
    token_ancilla_ws = list(wires)  # pass token+ancilla here
    if combos is None:
        combos = zstring_combos(token_ancilla_ws)

    # Safety: ensure the vector length matches the number of combos
    assert gamma_vec.shape[0] == len(combos), \
        f"gamma_vec has length {gamma_vec.shape[0]} but expected {len(combos)}"

    # MultiRZ(phi) = exp(-i * phi/2 * Z^{⊗k}); choose phi = -2 * gamma * time
    for gamma, combo in zip(gamma_vec, combos):
        qml.MultiRZ(-2.0 * gamma * time, wires=list(combo))


# QNode combining encoding and variational layers, returning expectation values
@qml.qnode(dev, interface="jax")
def autoregressive_model(token_bits, props, theta_params, theta_prop, sigma_params, embedding_vector):
    molecular_property_encoder(props)        # Encode molecular props
    token_encoder(token_bits)               # Encode token bits
    #embedding_encoder_attention(embedding_vector, wires=all_wires)  # Embedding layer

    token_ancilla_ws = token_wires + ancilla_wires
    combos = zstring_combos(token_ancilla_ws)

    for l in range(N_LAYERS):
        # Embedding
        embedding_encoder_attention(embedding_vector[l], wires=all_wires)

        # Forward V(θ)
        operator_layer(theta_params[l], theta_prop[l], wires=all_wires)

        # Diagonal Σ(γ,t): vector API
        Sigma_layer_vec(sigma_params[l], token_ancilla_ws, time=1.0, combos=combos)

        # Backward V(θ)†
        qml.adjoint(operator_layer)(theta_params[l], theta_prop[l], wires=all_wires)

    return qml.probs(wires=token_wires)



In [53]:
def bitstr_to_array(bitstr):
    """Convert a string of bits (e.g., '010101') to a numpy float32 array."""
    return np.array([int(b) for b in bitstr], dtype=np.float32)

def build_training_data(df):
    """
    Build dataset tuples of (input_token_bits, molecular_properties, target_token_bits)
    from a DataFrame.

    Args:
        df (pandas.DataFrame): DataFrame containing molecular properties and token bit strings.
        n_token_cols (int): Number of token columns in the DataFrame.

    Returns:
        list of tuples: Each tuple contains (x_token: np.array, x_props: np.array, y_target: np.array)
    """
    dataset = []

    for _, row in df.iterrows():
        # Extract molecular properties as a numpy float32 array
        props = [row['logP'], row['qed'], row['mw']]
        x_props = np.array(props, dtype=np.float32)

        tokens = row[3:]  # token columns after properties

        # Iterate over token sequence to create input-target pairs
        for i in range(len(tokens) - 1):
            current_token = tokens.iloc[i]
            next_token = tokens.iloc[i + 1]

            # Skip missing or NaN tokens
            if current_token is None or (isinstance(current_token, float) and math.isnan(current_token)):
                continue
            if next_token is None or (isinstance(next_token, float) and math.isnan(next_token)):
                continue

            # Convert token strings (e.g., '01011') to bit arrays
            x_token = bitstr_to_array(current_token)
            y_target = bitstr_to_array(next_token)

            dataset.append((x_token, x_props, y_target))

    return dataset

# Load dataset
token_cols = [f"token_{i}" for i in range(n_tokens)]
df = pd.read_csv("structured_smiles_data.csv", dtype={col: str for col in token_cols})
dataset = build_training_data(df)  # Should return list/array of (x_token, x_props, y_target)

In [54]:
# 1. Define the embedding network function
def embedding_network_fn(x):
    embedding_size = N_LAYERS * 2 * (BITS_PER_TOKEN + n_ancillas)
    mlp = hk.Sequential([
        hk.Linear(32), jax.nn.relu,
        hk.Linear(16), jax.nn.relu,
        hk.Linear(embedding_size),
        jnp.tanh  # restricts output to (-1, 1)
    ])
    return mlp(x) * jnp.pi  # now in (-π, π)


# 2. Transform the function to make it usable in JAX/Haiku
embedding_network = hk.transform(embedding_network_fn)


def get_context_embedding(prev_token_bits, embed_params):
    """
    Returns: vector of shape (BITS_PER_TOKEN,) for embedding_encoder
    """
    embedding_output = embedding_network.apply(embed_params, None, prev_token_bits)
    embedding_vector = embedding_output.reshape((N_LAYERS, BITS_PER_TOKEN + n_ancillas, 2))

    return embedding_vector


In [55]:
def bits_to_index(bits):
    powers = 2 ** jnp.arange(len(bits) - 1, -1, -1)
    return jnp.dot(bits, powers).astype(jnp.int32)


def categorical_crossentropy(pred_probs, target_index):
    epsilon = 1e-10
    return -jnp.log(pred_probs[target_index] + epsilon)


def label_smoothing_crossentropy_normalized(pred_probs, target_index, epsilon=0.1):
    """Cross-entropy loss with label smoothing, normalized to [0,1]."""
    num_classes = pred_probs.shape[0] # Number of classes (tokens)
    
    # Build smoothed target
    smooth_target = jnp.full_like(pred_probs, epsilon / (num_classes - 1))
    smooth_target = smooth_target.at[target_index].set(1.0 - epsilon)
    
    # Compute cross-entropy
    loss = -jnp.sum(smooth_target * jnp.log(pred_probs + 1e-10))
    
    # Normalize to [0,1]
    max_loss = jnp.log(num_classes)  # worst-case (uniform distribution)
    norm_loss = loss / max_loss
    
    return norm_loss

'''
    # Mask invalid tokens (e.g. token 46 does not exist if vocab size is 36)
    valid_mask = jnp.concatenate([
        jnp.ones(VOCABULARY_SIZE), 
        jnp.zeros(pred_probs.shape[0] - VOCABULARY_SIZE)
    ])
    masked_probs = pred_probs * valid_mask
    masked_probs = masked_probs / (jnp.sum(masked_probs) + 1e-10)  # renormalize

    # Label smoothing over valid tokens only
    num_valid = VOCABULARY_SIZE
    smooth_target = jnp.zeros_like(masked_probs)
    smooth_target = smooth_target + (epsilon / (num_valid - 1)) * valid_mask
    smooth_target = smooth_target.at[target_index].set(1.0 - epsilon)

    # Cross-entropy
    loss = -jnp.sum(smooth_target * jnp.log(masked_probs + 1e-10))

    # Normalize loss to [0,1]
    max_loss = jnp.log(num_valid)  # worst case: uniform distribution
    norm_loss = loss / max_loss

    return norm_loss
'''



def compute_accuracy(pred_probs, target_index):
    predicted_index = jnp.argmax(pred_probs)
    return jnp.array(predicted_index == target_index, dtype=jnp.float32)


In [56]:
# Initialize embedding params
rng = jax.random.PRNGKey(0)
dummy_context = jnp.zeros((3 * BITS_PER_TOKEN,), dtype=jnp.float32)  # 3 prev. tokens
embedding_params = embedding_network.init(rng, dummy_context)

# Effective qubit counts in variational layers
n_token_ancilla = BITS_PER_TOKEN + n_ancillas

# PRNG keys
key = jax.random.PRNGKey(42)
key, k_theta, k_theta_prop, k_sigma = jax.random.split(key, 4)

# Precompute Z-string combos once
token_ancilla_ws = token_wires + ancilla_wires
combos = zstring_combos(token_ancilla_ws)
n_strings = len(combos)


combined_params = {
    'theta': jax.random.normal(k_theta, (N_LAYERS, n_token_ancilla, 3)) * 0.1,
    'theta_prop': jax.random.normal(k_theta_prop, (N_LAYERS, n_prop_qubits, n_token_ancilla, 3)) * 0.1,
    'sigma': jax.random.normal(k_sigma, (N_LAYERS, n_strings)) * 0.1,
    'embedding': embedding_params,
}

# Initialize parameters
learning_rate = 0.001
n_epochs = 25

# Optimizer
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(combined_params)

# Define two separate optimizers with different learning rates
'''optimizer = optax.multi_transform(
    {
        'quantum': optax.adam(learning_rate=0.005),
        'embedding': optax.adam(learning_rate=0.001)
    },
    param_labels={
        'theta': 'quantum',
        'theta_prop': 'quantum',
        'sigma': 'quantum',
        'embedding': 'embedding'
    }
)'''

@jax.jit
def training_step(params, opt_state, x_token, x_props, y_target, context_vector):
    def loss_fn(params):
        theta_params = params['theta']
        theta_prop = params['theta_prop']
        sigma_params = params['sigma']
        embedding_params = params['embedding']

        # Get embedding output from context vector and embedding params
        embedding = get_context_embedding(context_vector, embedding_params)

        # Predict using the embedding layer
        pred_probs = autoregressive_model(x_token, x_props, theta_params, theta_prop, sigma_params, embedding)
        index = bits_to_index(y_target)
        
        # Return scalar loss for gradient computation
        return label_smoothing_crossentropy_normalized(pred_probs, index), pred_probs

    # value_and_grad computes both loss and grads in one pass
    (loss, pred_probs), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)

    # Update parameters
    updates, opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Accuracy from pred_probs (already computed!)
    index = bits_to_index(y_target)
    acc = compute_accuracy(pred_probs, index)

    return new_params, loss, opt_state, grads, acc

# Context initialization outside epoch loop
SOS_token = jnp.zeros((BITS_PER_TOKEN,), dtype=jnp.int32)
eos_index = len(alphabet)-1
EOS_token = jnp.array([int(b) for b in format(eos_index, f'0{BITS_PER_TOKEN}b')], dtype=jnp.int32)

prev_token1 = SOS_token
prev_token2 = SOS_token
prev_token3 = SOS_token

for epoch in range(n_epochs):
    total_loss = total_acc = 0.0
    
    for x_token, x_props, y_target in dataset:
        x_token = jnp.array(x_token, dtype=jnp.int32)
        x_props = jnp.array(x_props, dtype=jnp.float32)
        y_target = jnp.array(y_target, dtype=jnp.int32)
      
        context_vector = jnp.concatenate([prev_token1, prev_token2, prev_token3])

        combined_params, loss, opt_state, grads, acc = training_step(combined_params, opt_state, x_token, x_props, y_target, context_vector)
        
        # Update previous tokens
        if jnp.array_equal(x_token, EOS_token):
            prev_token1 = SOS_token
            prev_token2 = SOS_token
            prev_token3 = SOS_token
        else:
            # Shift prev tokens, add current token as the newest previous token
            prev_token3 = prev_token2
            prev_token2 = prev_token1
            prev_token1 = x_token

        total_loss += loss
        total_acc  += acc

    avg_loss = total_loss / len(dataset)
    avg_acc  = total_acc / len(dataset)
    print(f"Epoch {epoch+1} | Loss = {avg_loss:.4f} | Accuracy = {avg_acc:.4f}")


KeyboardInterrupt: 

n_layers = 6
h_local = 3
prob_mask = NO

Epoch 1 | Loss = 0.6503 | Accuracy = 0.4038
Epoch 2 | Loss = 0.5920 | Accuracy = 0.4562
Epoch 3 | Loss = 0.5577 | Accuracy = 0.4904
Epoch 4 | Loss = 0.5467 | Accuracy = 0.5012
Epoch 5 | Loss = 0.5369 | Accuracy = 0.5178
Epoch 6 | Loss = 0.5301 | Accuracy = 0.5250
Epoch 7 | Loss = 0.5227 | Accuracy = 0.5357
Epoch 8 | Loss = 0.5171 | Accuracy = 0.5385
Epoch 9 | Loss = 0.5141 | Accuracy = 0.5373
Epoch 10 | Loss = 0.5103 | Accuracy = 0.5449
Epoch 11 | Loss = 0.5089 | Accuracy = 0.5475
Epoch 12 | Loss = 0.5066 | Accuracy = 0.5506
Epoch 13 | Loss = 0.5065 | Accuracy = 0.5469
Epoch 14 | Loss = 0.5056 | Accuracy = 0.5511
Epoch 15 | Loss = 0.5035 | Accuracy = 0.5558
Epoch 16 | Loss = 0.5017 | Accuracy = 0.5604
Epoch 17 | Loss = 0.5018 | Accuracy = 0.5568
Epoch 18 | Loss = 0.5044 | Accuracy = 0.5546
Epoch 19 | Loss = 0.5032 | Accuracy = 0.5563
Epoch 20 | Loss = 0.5017 | Accuracy = 0.5584
Epoch 21 | Loss = 0.5012 | Accuracy = 0.5572
Epoch 22 | Loss = 0.5002 | Accuracy = 0.5597
Epoch 23 | Loss = 0.5030 | Accuracy = 0.5554
Epoch 24 | Loss = 0.5022 | Accuracy = 0.5557
Epoch 25 | Loss = 0.5014 | Accuracy = 0.5572

h=2

Epoch 1 | Loss = 2.7894 | Accuracy = 0.4019
Epoch 2 | Loss = 2.4482 | Accuracy = 0.4809
Epoch 3 | Loss = 2.3743 | Accuracy = 0.4954
Epoch 4 | Loss = 2.3402 | Accuracy = 0.5056
Epoch 5 | Loss = 2.3080 | Accuracy = 0.5155
Epoch 6 | Loss = 2.2813 | Accuracy = 0.5223
Epoch 7 | Loss = 2.2683 | Accuracy = 0.5263
Epoch 8 | Loss = 2.2563 | Accuracy = 0.5309
Epoch 9 | Loss = 2.2506 | Accuracy = 0.5329
Epoch 10 | Loss = 2.2450 | Accuracy = 0.5294
Epoch 11 | Loss = 2.2452 | Accuracy = 0.5278
Epoch 12 | Loss = 2.2378 | Accuracy = 0.5308
Epoch 13 | Loss = 2.2224 | Accuracy = 0.5372
Epoch 14 | Loss = 2.2186 | Accuracy = 0.5424
Epoch 15 | Loss = 2.2122 | Accuracy = 0.5437
Epoch 16 | Loss = 2.2111 | Accuracy = 0.5386
Epoch 17 | Loss = 2.2118 | Accuracy = 0.5389
Epoch 18 | Loss = 2.2103 | Accuracy = 0.5403
Epoch 19 | Loss = 2.2126 | Accuracy = 0.5410
Epoch 20 | Loss = 2.2138 | Accuracy = 0.5446
Epoch 21 | Loss = 2.2075 | Accuracy = 0.5421
Epoch 22 | Loss = 2.2007 | Accuracy = 0.5436
Epoch 23 | Loss = 2.2041 | Accuracy = 0.5397
Epoch 24 | Loss = 2.2018 | Accuracy = 0.5411
Epoch 25 | Loss = 2.2068 | Accuracy = 0.5357
Epoch 26 | Loss = 2.2117 | Accuracy = 0.5360
Epoch 27 | Loss = 2.1967 | Accuracy = 0.5379
Epoch 28 | Loss = 2.1956 | Accuracy = 0.5440
Epoch 29 | Loss = 2.1966 | Accuracy = 0.5461
Epoch 30 | Loss = 2.1943 | Accuracy = 0.5452

Quantum Circuit:
   prop_0: ──RY(1.93)─────────────────────────────────────────────────────────────────────────── ···
   prop_1: ──RY(0.97)─────────────────────────────────────────────────────────────────────────── ···
   prop_2: ──RY(0.86)─────────────────────────────────────────────────────────────────────────── ···
  token_0: ──RX(0.00)───RY(-0.01)──RX(-0.01)─╭●─────────╭●────────╭●─────────╭●────────╭●─────── ···
  token_1: ──RX(0.00)───RY(0.01)───RX(0.01)──╰RZ(-0.00)─│─────────│──────────│─────────│──────── ···
  token_2: ──RX(3.14)───RY(-0.01)──RX(-0.01)────────────╰RZ(0.00)─│──────────│─────────│──────── ···
  token_3: ──RX(0.00)───RY(0.01)───RX(0.01)───────────────────────╰RZ(-0.00)─│─────────│──────── ···
  token_4: ──RX(3.14)───RY(0.01)───RX(-0.01)─────────────────────────────────╰RZ(0.00)─│──────── ···
  token_5: ──RX(0.00)───RY(0.00)───RX(-0.00)───────────────────────────────────────────╰RZ(0.00) ···
ancilla_0: ──RX(-0.00)──RY(-0.01)──RZ(-0.01)──────────────────────────────────────────────────── ···
ancilla_1: ──RX(0.00)───RY(-0.00)──RZ(0.00)───────────────────────────────────────────────────── ···
ancilla_2: ──RX(0.00)───RY(-0.00)──RZ(0.00)───────────────────────────────────────────────────── ···

   prop_0: ··· ─────────────────────────────────╭●─────────╭●────────────────────────────── ···
   prop_1: ··· ─────────────────────────────────│──────────│─────────────────────────────── ···
   prop_2: ··· ─────────────────────────────────│──────────│─────────────────────────────── ···
  token_0: ··· ──RX(0.00)───RY(0.00)──RZ(-0.01)─╰RX(0.00)──╰RY(0.00)─────────────────────── ···
  token_1: ··· ─╭●─────────╭●────────╭●─────────╭●──────────RX(-0.00)──RY(-0.01)──RZ(-0.00) ···
  token_2: ··· ─╰RZ(-0.00)─│─────────│──────────│──────────╭●─────────╭●─────────╭●──────── ···
  token_3: ··· ────────────╰RZ(0.00)─│──────────│──────────╰RZ(-0.00)─│──────────│───────── ···
  token_4: ··· ──────────────────────╰RZ(-0.00)─│─────────────────────╰RZ(0.00)──│───────── ···
  token_5: ··· ─────────────────────────────────╰RZ(-0.00)───────────────────────╰RZ(0.00)─ ···
ancilla_0: ··· ──────────────────────────────────────────────────────────────────────────── ···
ancilla_1: ··· ──────────────────────────────────────────────────────────────────────────── ···
ancilla_2: ··· ──────────────────────────────────────────────────────────────────────────── ···

   prop_0: ··· ─╭●─────────╭●────────────────────╭●─────────╭●─────────╭●─────────╭●──────── ···
   prop_1: ··· ─│──────────│─────────────────────│──────────│──────────│──────────│───────── ···
   prop_2: ··· ─│──────────│─────────────────────│──────────│──────────│──────────│───────── ···
  token_0: ··· ─│──────────│─────────────────────│──────────│──────────│──────────│───────── ···
  token_1: ··· ─╰RX(-0.01)─╰RY(-0.00)────────────│──────────│──────────│──────────│───────── ···
  token_2: ··· ──RX(-0.01)──RY(-0.02)──RZ(0.00)──╰RX(-0.01)─╰RY(-0.02)─│──────────│───────── ···
  token_3: ··· ─╭●─────────╭●──────────RX(0.00)───RY(-0.00)──RZ(-0.00)─╰RX(-0.02)─╰RY(-0.00) ···
  token_4: ··· ─╰RZ(-0.00)─│──────────╭●──────────RX(0.02)───RY(-0.00)──RZ(0.00)──────────── ···
  token_5: ··· ────────────╰RZ(-0.00)─╰RZ(-0.00)──RX(-0.01)──RY(-0.01)──RZ(0.00)──────────── ···
ancilla_0: ··· ───────────────────────────────────────────────────────────────────────────── ···
ancilla_1: ··· ───────────────────────────────────────────────────────────────────────────── ···
ancilla_2: ··· ───────────────────────────────────────────────────────────────────────────── ···

   prop_0: ··· ─╭●────────╭●─────────╭●────────╭●─────────────────────────────────────── ···
   prop_1: ··· ─│─────────│──────────│─────────│─────────╭●────────╭●─────────╭●──────── ···
   prop_2: ··· ─│─────────│──────────│─────────│─────────│─────────│──────────│───────── ···
  token_0: ··· ─│─────────│──────────│─────────│─────────╰RX(0.01)─╰RY(-0.01)─│───────── ···
  token_1: ··· ─│─────────│──────────│─────────│──────────────────────────────╰RX(-0.00) ···
  token_2: ··· ─│─────────│──────────│─────────│──────────────────────────────────────── ···
  token_3: ··· ─│─────────│──────────│─────────│──────────────────────────────────────── ···
  token_4: ··· ─╰RX(0.00)─╰RY(-0.02)─│─────────│──────────────────────────────────────── ···
  token_5: ··· ──────────────────────╰RX(0.01)─╰RY(0.00)──────────────────────────────── ···
ancilla_0: ··· ───────────────────────────────────────────────────────────────────────── ···
ancilla_1: ··· ───────────────────────────────────────────────────────────────────────── ···
ancilla_2: ··· ───────────────────────────────────────────────────────────────────────── ···

   prop_0: ··· ────────────────────────────────────────────────────────────────────────── ···
   prop_1: ··· ─╭●─────────╭●────────╭●─────────╭●─────────╭●────────╭●─────────╭●─────── ···
   prop_2: ··· ─│──────────│─────────│──────────│──────────│─────────│──────────│──────── ···
  token_0: ··· ─│──────────│─────────│──────────│──────────│─────────│──────────│──────── ···
  token_1: ··· ─╰RY(-0.01)─│─────────│──────────│──────────│─────────│──────────│──────── ···
  token_2: ··· ────────────╰RX(0.01)─╰RY(-0.02)─│──────────│─────────│──────────│──────── ···
  token_3: ··· ─────────────────────────────────╰RX(-0.01)─╰RY(0.01)─│──────────│──────── ···
  token_4: ··· ──────────────────────────────────────────────────────╰RX(-0.00)─╰RY(0.00) ···
  token_5: ··· ────────────────────────────────────────────────────────────────────────── ···
ancilla_0: ··· ────────────────────────────────────────────────────────────────────────── ···
ancilla_1: ··· ────────────────────────────────────────────────────────────────────────── ···
ancilla_2: ··· ────────────────────────────────────────────────────────────────────────── ···

   prop_0: ··· ───────────────────────────────────────────────────────────────────────── ···
   prop_1: ··· ─╭●─────────╭●─────────────────────────────────────────────────────────── ···
   prop_2: ··· ─│──────────│─────────╭●────────╭●────────╭●─────────╭●────────╭●──────── ···
  token_0: ··· ─│──────────│─────────╰RX(0.01)─╰RY(0.02)─│──────────│─────────│───────── ···
  token_1: ··· ─│──────────│─────────────────────────────╰RX(-0.01)─╰RY(0.02)─│───────── ···
  token_2: ··· ─│──────────│──────────────────────────────────────────────────╰RX(-0.00) ···
  token_3: ··· ─│──────────│──────────────────────────────────────────────────────────── ···
  token_4: ··· ─│──────────│──────────────────────────────────────────────────────────── ···
  token_5: ··· ─╰RX(-0.00)─╰RY(0.01)──────────────────────────────────────────────────── ···
ancilla_0: ··· ───────────────────────────────────────────────────────────────────────── ···
ancilla_1: ··· ───────────────────────────────────────────────────────────────────────── ···
ancilla_2: ··· ───────────────────────────────────────────────────────────────────────── ···

   prop_0: ··· ─────────────────────────────────────────────────────────────────────────────── ···
   prop_1: ··· ─────────────────────────────────────────────────────────────────────────────── ···
   prop_2: ··· ─╭●────────╭●─────────╭●────────╭●────────╭●─────────╭●─────────╭●───────────── ···
  token_0: ··· ─│─────────│──────────│─────────│─────────│──────────│──────────│─────────╭●─── ···
  token_1: ··· ─│─────────│──────────│─────────│─────────│──────────│──────────│─────────╰X─╭● ···
  token_2: ··· ─╰RY(0.01)─│──────────│─────────│─────────│──────────│──────────│────────────╰X ···
  token_3: ··· ───────────╰RX(-0.01)─╰RY(0.01)─│─────────│──────────│──────────│────────────── ···
  token_4: ··· ────────────────────────────────╰RX(0.00)─╰RY(-0.01)─│──────────│────────────── ···
  token_5: ··· ─────────────────────────────────────────────────────╰RX(-0.01)─╰RY(0.00)────── ···
ancilla_0: ··· ─────────────────────────────────────────────────────────────────────────────── ···
ancilla_1: ··· ─────────────────────────────────────────────────────────────────────────────── ···
ancilla_2: ··· ─────────────────────────────────────────────────────────────────────────────── ···

   prop_0: ··· ─────────────────────────────────────────────────────────────────────────── ···
   prop_1: ··· ─────────────────────────────────────────────────────────────────────────── ···
   prop_2: ··· ─────────────────────────────────────────────────────────────────────────── ···
  token_0: ··· ───────────────────╭X──RZ(-0.01)─╭X†─────────────────────────────────────── ···
  token_1: ··· ───────────────────│───RZ(-0.03)─│───────────────────────────────────────── ···
  token_2: ··· ─╭●────────────────│───RZ(-0.03)─│───────────────────────────────────────── ···
  token_3: ··· ─╰X─╭●─────────────│───RZ(0.01)──│───────────────────────────────────────── ···
  token_4: ··· ────╰X─╭●──────────│───RZ(-0.01)─│──────────────────────────────╭X†──────── ···
  token_5: ··· ───────╰X─╭●───────│───RZ(-0.00)─│──────────────────╭X†─────────╰X†──────── ···
ancilla_0: ··· ──────────╰X─╭●────│───RZ(-0.00)─│───────╭X†────────╰X†──────────RZ(-0.01)† ···
ancilla_1: ··· ─────────────╰X─╭●─│───RZ(-0.02)─│───╭X†─╰X†─────────RZ(0.00)†───RY(-0.00)† ···
ancilla_2: ··· ────────────────╰X─╰●──RZ(-0.01)─╰X†─╰X†──RZ(0.00)†──RY(-0.00)†──RX(0.00)†─ ···

   prop_0: ··· ───────────────────────────────────────────────────────────────────────── ···
   prop_1: ··· ───────────────────────────────────────────────────────────────────────── ···
   prop_2: ··· ─────────────────────────────────╭(RY(0.00))†─╭(RX(-0.01))†─╭(RY(-0.01))† ···
  token_0: ··· ─────────────────────────────╭X†─│────────────│─────────────│──────────── ···
  token_1: ··· ─────────────────────────╭X†─╰X†─│────────────│─────────────│──────────── ···
  token_2: ··· ─────────────╭X†─────────╰X†─────│────────────│─────────────│──────────── ···
  token_3: ··· ─╭X†─────────╰X†─────────────────│────────────│─────────────│──────────── ···
  token_4: ··· ─╰X†─────────────────────────────│────────────│─────────────╰(RY(-0.01))† ···
  token_5: ··· ─────────────────────────────────╰(RY(0.00))†─╰(RX(-0.01))†────────────── ···
ancilla_0: ··· ──RY(-0.01)†──RX(-0.00)†───────────────────────────────────────────────── ···
ancilla_1: ··· ──RX(0.00)†────────────────────────────────────────────────────────────── ···
ancilla_2: ··· ───────────────────────────────────────────────────────────────────────── ···

   prop_0: ··· ──────────────────────────────────────────────────────────────────────────────── ···
   prop_1: ··· ──────────────────────────────────────────────────────────────────────────────── ···
   prop_2: ··· ─╭(RX(0.00))†─╭(RY(0.01))†─╭(RX(-0.01))†─╭(RY(0.01))†─╭(RX(-0.00))†─╭(RY(0.02))† ···
  token_0: ··· ─│────────────│────────────│─────────────│────────────│─────────────│─────────── ···
  token_1: ··· ─│────────────│────────────│─────────────│────────────│─────────────╰(RY(0.02))† ···
  token_2: ··· ─│────────────│────────────│─────────────╰(RY(0.01))†─╰(RX(-0.00))†───────────── ···
  token_3: ··· ─│────────────╰(RY(0.01))†─╰(RX(-0.01))†──────────────────────────────────────── ···
  token_4: ··· ─╰(RX(0.00))†─────────────────────────────────────────────────────────────────── ···
  token_5: ··· ──────────────────────────────────────────────────────────────────────────────── ···
ancilla_0: ··· ──────────────────────────────────────────────────────────────────────────────── ···
ancilla_1: ··· ──────────────────────────────────────────────────────────────────────────────── ···
ancilla_2: ··· ──────────────────────────────────────────────────────────────────────────────── ···

   prop_0: ··· ──────────────────────────────────────────────────────────────────────────────── ···
   prop_1: ··· ─────────────────────────────────────────╭(RY(0.01))†─╭(RX(-0.00))†─╭(RY(0.00))† ···
   prop_2: ··· ─╭(RX(-0.01))†─╭(RY(0.02))†─╭(RX(0.01))†─│────────────│─────────────│─────────── ···
  token_0: ··· ─│─────────────╰(RY(0.02))†─╰(RX(0.01))†─│────────────│─────────────│─────────── ···
  token_1: ··· ─╰(RX(-0.01))†───────────────────────────│────────────│─────────────│─────────── ···
  token_2: ··· ─────────────────────────────────────────│────────────│─────────────│─────────── ···
  token_3: ··· ─────────────────────────────────────────│────────────│─────────────│─────────── ···
  token_4: ··· ─────────────────────────────────────────│────────────│─────────────╰(RY(0.00))† ···
  token_5: ··· ─────────────────────────────────────────╰(RY(0.01))†─╰(RX(-0.00))†───────────── ···
ancilla_0: ··· ──────────────────────────────────────────────────────────────────────────────── ···
ancilla_1: ··· ──────────────────────────────────────────────────────────────────────────────── ···
ancilla_2: ··· ──────────────────────────────────────────────────────────────────────────────── ···

   prop_0: ··· ──────────────────────────────────────────────────────────────────── ···
   prop_1: ··· ─╭(RX(-0.00))†─╭(RY(0.01))†─╭(RX(-0.01))†─╭(RY(-0.02))†─╭(RX(0.01))† ···
   prop_2: ··· ─│─────────────│────────────│─────────────│─────────────│─────────── ···
  token_0: ··· ─│─────────────│────────────│─────────────│─────────────│─────────── ···
  token_1: ··· ─│─────────────│────────────│─────────────│─────────────│─────────── ···
  token_2: ··· ─│─────────────│────────────│─────────────╰(RY(-0.02))†─╰(RX(0.01))† ···
  token_3: ··· ─│─────────────╰(RY(0.01))†─╰(RX(-0.01))†─────────────────────────── ···
  token_4: ··· ─╰(RX(-0.00))†────────────────────────────────────────────────────── ···
  token_5: ··· ──────────────────────────────────────────────────────────────────── ···
ancilla_0: ··· ──────────────────────────────────────────────────────────────────── ···
ancilla_1: ··· ──────────────────────────────────────────────────────────────────── ···
ancilla_2: ··· ──────────────────────────────────────────────────────────────────── ···

   prop_0: ··· ────────────────────────────────────────────────────────╭(RY(0.00))†─╭(RX(0.01))† ···
   prop_1: ··· ─╭(RY(-0.01))†─╭(RX(-0.00))†─╭(RY(-0.01))†─╭(RX(0.01))†─│────────────│─────────── ···
   prop_2: ··· ─│─────────────│─────────────│─────────────│────────────│────────────│─────────── ···
  token_0: ··· ─│─────────────│─────────────╰(RY(-0.01))†─╰(RX(0.01))†─│────────────│─────────── ···
  token_1: ··· ─╰(RY(-0.01))†─╰(RX(-0.00))†────────────────────────────│────────────│─────────── ···
  token_2: ··· ────────────────────────────────────────────────────────│────────────│─────────── ···
  token_3: ··· ────────────────────────────────────────────────────────│────────────│─────────── ···
  token_4: ··· ────────────────────────────────────────────────────────│────────────│─────────── ···
  token_5: ··· ────────────────────────────────────────────────────────╰(RY(0.00))†─╰(RX(0.01))† ···
ancilla_0: ··· ───────────────────────────────────────────────────────────────────────────────── ···
ancilla_1: ··· ───────────────────────────────────────────────────────────────────────────────── ···
ancilla_2: ··· ───────────────────────────────────────────────────────────────────────────────── ···

   prop_0: ··· ─╭(RY(-0.02))†─╭(RX(0.00))†─╭(RY(-0.00))†─╭(RX(-0.02))†─╭(RY(-0.02))† ···
   prop_1: ··· ─│─────────────│────────────│─────────────│─────────────│──────────── ···
   prop_2: ··· ─│─────────────│────────────│─────────────│─────────────│──────────── ···
  token_0: ··· ─│─────────────│────────────│─────────────│─────────────│──────────── ···
  token_1: ··· ─│─────────────│────────────│─────────────│─────────────│──────────── ···
  token_2: ··· ─│─────────────│────────────│─────────────│─────────────╰(RY(-0.02))† ···
  token_3: ··· ─│─────────────│────────────╰(RY(-0.00))†─╰(RX(-0.02))†──RZ(-0.00)†── ···
  token_4: ··· ─╰(RY(-0.02))†─╰(RX(0.00))†──RZ(0.00)†─────RY(-0.00)†────RX(0.02)†─── ···
  token_5: ··· ──RZ(0.00)†─────RY(-0.01)†───RX(-0.01)†────────────────────────────── ···
ancilla_0: ··· ───────────────────────────────────────────────────────────────────── ···
ancilla_1: ··· ───────────────────────────────────────────────────────────────────── ···
ancilla_2: ··· ───────────────────────────────────────────────────────────────────── ···

   prop_0: ··· ─╭(RX(-0.01))†─╭(RY(-0.00))†─╭(RX(-0.01))†─╭(RY(0.00))†─╭(RX(0.00))†──────────── ···
   prop_1: ··· ─│─────────────│─────────────│─────────────│────────────│─────────────────────── ···
   prop_2: ··· ─│─────────────│─────────────│─────────────│────────────│─────────────────────── ···
  token_0: ··· ─│─────────────│─────────────│─────────────╰(RY(0.00))†─╰(RX(0.00))†──RZ(-0.01)† ···
  token_1: ··· ─│─────────────╰(RY(-0.00))†─╰(RX(-0.01))†──RZ(-0.00)†───RY(-0.01)†───RX(-0.00)† ···
  token_2: ··· ─╰(RX(-0.01))†──RZ(0.00)†─────RY(-0.02)†────RX(-0.01)†────────────────────────── ···
  token_3: ··· ──RY(-0.00)†────RX(0.00)†─────────────────────────────────────────────────────── ···
  token_4: ··· ──────────────────────────────────────────────────────────────────────────────── ···
  token_5: ··· ──────────────────────────────────────────────────────────────────────────────── ···
ancilla_0: ··· ──────────────────────────────────────────────────────────────────────────────── ···
ancilla_1: ··· ──────────────────────────────────────────────────────────────────────────────── ···
ancilla_2: ··· ──────────────────────────────────────────────────────────────────────────────── ···

   prop_0: ··· ───────────────────────┤       
   prop_1: ··· ───────────────────────┤       
   prop_2: ··· ───────────────────────┤       
  token_0: ··· ──RY(0.00)†──RX(0.00)†─┤ ╭Probs
  token_1: ··· ───────────────────────┤ ├Probs
  token_2: ··· ───────────────────────┤ ├Probs
  token_3: ··· ───────────────────────┤ ├Probs
  token_4: ··· ───────────────────────┤ ├Probs
  token_5: ··· ───────────────────────┤ ╰Probs
ancilla_0: ··· ───────────────────────┤       
ancilla_1: ··· ───────────────────────┤       
ancilla_2: ··· ───────────────────────┤       