In [1]:
from typing import Tuple, List, Dict, Any

from rdkit import Chem
from rdkit.Chem import AllChem
import selfies as sf

import numpy as np
from math import ceil, log2
import re
import pandas as pd
import optax
import csv
import flax.linen as nn
import math

import pennylane as qml
from pennylane import qchem
from pennylane.templates import StronglyEntanglingLayers

import jax
import jax.numpy as jnp
from jax.nn.initializers import normal

import haiku as hk

In [2]:
from chembl_webresource_client.new_client import new_client

# Using the ChEMBL API to get the molecules dataset

molecule = new_client.molecule

# Filter for drug-like small molecules interesting for human use
druglike_molecules = molecule.filter(
    molecule_properties__heavy_atoms__lte=15,           # Heavy atoms less than 20
    molecule_properties__alogp__lte=5,                  # LogP less than 5 (Lipophilicity and membrane permeability)
    molecule_properties__mw_freebase__lte=300,          # Molecular weight less than 300 g/mol
    molecule_properties__qed_weighted__gte=0.5,         # QED weighted greater than 0.5 (Drug-likeness)
    molecule_properties__num_ro5_violations__lte=1,     # At most 1 Rule of 5 violation (Drug-likeness filter)

)

print("Training molecules set: ", len(druglike_molecules))  # Check how many molecules match the filter criteria

Training molecules set:  65778


## 1. Get the alphabet used for the SMILEs representation

Build the alphabet considering the structure of some special tokens

In [3]:
# Define the subset of molecules we are going to train the model with
molecules_subset = druglike_molecules[:600]

max_len = 0
alphabet = set()
# We use a subset of molecules to build the alphabet
for mol in molecules_subset:  # Limiting to 1000 molecules for performance
    smiles = mol.get('molecule_structures', {}).get('canonical_smiles')
    selfies = sf.encoder(smiles)
    if selfies:
        # Skip if contains '.'
        if "." in selfies:
            continue
        tokens = list(sf.split_selfies(selfies))
        if max_len < len(tokens):
            max_len = len(tokens)
        alphabet.update(tokens)

alphabet = sorted(alphabet)
alphabet = ['<SOS>'] + alphabet  + ['<EOS>'] # Add Start-of-Secuence, End-of-secuence

In [4]:
print("Alphabet of SMILES characters:", alphabet)

VOCABULARY_SIZE = len(alphabet)
BITS_PER_TOKEN = ceil(log2(VOCABULARY_SIZE))  # n bits por token

print("Total unique characters in SMILES:", VOCABULARY_SIZE)
print("Maximum length of SMILES in dataset:", max_len)
print("Bits per token:", BITS_PER_TOKEN)

Alphabet of SMILES characters: ['<SOS>', '[#Branch1]', '[#Branch2]', '[#C]', '[#N]', '[/C]', '[/Cl]', '[/N]', '[/S]', '[=Branch1]', '[=Branch2]', '[=C]', '[=N+1]', '[=N]', '[=O]', '[=P]', '[=Ring1]', '[=S]', '[Br]', '[Branch1]', '[Branch2]', '[C@@H1]', '[C@@]', '[C@H1]', '[C@]', '[C]', '[Cl]', '[F]', '[I]', '[N+1]', '[NH1]', '[N]', '[O-1]', '[O]', '[PH1]', '[P]', '[Ring1]', '[Ring2]', '[S+1]', '[S]', '[\\C]', '[\\Cl]', '[\\N]', '<EOS>']
Total unique characters in SMILES: 44
Maximum length of SMILES in dataset: 29
Bits per token: 6


•	Símbolos de enlaces y paréntesis: #, (, ), /, \, =

•	Dígitos simples para cierres de anillos: '1', '2', '3', '4', '5'

•	Átomos orgánicos y halógenos comunes, tanto mayúsculas (alifáticos) como minúsculas (aromáticos)

•	Tokens entre corchetes para isótopos, estados de carga, quiralidad, etc.

•	Un token especial '<PAD>' para padding en modelos ML

In [5]:
# Diccionario token → índice
token_to_index = {tok: i for i, tok in enumerate(alphabet)}

def print_token_bits(tokens, token_to_index):
    for tok in tokens:
        idx = token_to_index.get(tok, None)
        if idx is None:
            print(f"Token '{tok}' no está en el diccionario.")
            continue
        binary = format(idx, f'0{BITS_PER_TOKEN}b')
        print(f"'{tok}' → index {idx} → {binary}")

print_token_bits(alphabet, token_to_index)

'<SOS>' → index 0 → 000000
'[#Branch1]' → index 1 → 000001
'[#Branch2]' → index 2 → 000010
'[#C]' → index 3 → 000011
'[#N]' → index 4 → 000100
'[/C]' → index 5 → 000101
'[/Cl]' → index 6 → 000110
'[/N]' → index 7 → 000111
'[/S]' → index 8 → 001000
'[=Branch1]' → index 9 → 001001
'[=Branch2]' → index 10 → 001010
'[=C]' → index 11 → 001011
'[=N+1]' → index 12 → 001100
'[=N]' → index 13 → 001101
'[=O]' → index 14 → 001110
'[=P]' → index 15 → 001111
'[=Ring1]' → index 16 → 010000
'[=S]' → index 17 → 010001
'[Br]' → index 18 → 010010
'[Branch1]' → index 19 → 010011
'[Branch2]' → index 20 → 010100
'[C@@H1]' → index 21 → 010101
'[C@@]' → index 22 → 010110
'[C@H1]' → index 23 → 010111
'[C@]' → index 24 → 011000
'[C]' → index 25 → 011001
'[Cl]' → index 26 → 011010
'[F]' → index 27 → 011011
'[I]' → index 28 → 011100
'[N+1]' → index 29 → 011101
'[NH1]' → index 30 → 011110
'[N]' → index 31 → 011111
'[O-1]' → index 32 → 100000
'[O]' → index 33 → 100001
'[PH1]' → index 34 → 100010
'[P]' → index 35 →

The input SMILEs must be the same size, so we need to use the padding to make it uniform

In [6]:
basis_encoded_dataset = []
token_to_index = {tok: i for i, tok in enumerate(alphabet)}

def smiles_to_bits(tokens: list) -> np.ndarray:
    """Convert tokens to a 2D array"""
    padded_tokens = ['<SOS>'] + tokens + ['<EOS>']
    bit_matrix = []
    for tok in padded_tokens:
        idx = token_to_index[tok]
        bits = list(f"{idx:0{BITS_PER_TOKEN}b}")  # length of the binary string depends on the number of bits required to represent the alphabet
        bit_matrix.append([int(b) for b in bits])
    return np.array(bit_matrix)

## 2. Obtain the molecular properties of interest

logP (o cx_logp) -> Coeficiente de partición octanol/agua (lipofilia)

QED (quantitative estimate of drug-likeness) -> Escala combinada que evalúa qué tan “drug-like” es una molécula

SAS (Synthetic Accessibility Score) -> Qué tan difícil sería sintetizar la molécula en laboratorio * Needs to be calculated separately!!

MW (peso molecular) -> Masa total, típicamente ≤ 500 Da para buenos fármacos orales


First we get the max and min range values in order to later normalize the properties in the range (0, pi)

In [7]:
min_logp = float('inf')
max_logp = float('-inf')
min_qed = float('inf')
max_qed = float('-inf')
min_mw = float('inf')
max_mw = float('-inf')


# Iterate through the subset of molecules to find min/max properties to normalize them
for mol in molecules_subset:
    logP = mol.get('molecule_properties', {}).get('alogp')
    qed = mol.get('molecule_properties', {}).get('qed_weighted')
    mw = mol.get('molecule_properties', {}).get('mw_freebase')

    if logP is None or qed is None or mw is None:
        continue  # Skip if any property is missing

    logP = float(logP)
    qed = float(qed)
    mw = float(mw)

    if logP < min_logp:
        min_logp = logP
    if logP > max_logp:
        max_logp = logP

    if qed < min_qed:
        min_qed = qed
    if qed > max_qed:
        max_qed = qed

    if mw < min_mw:
        min_mw = mw
    if mw > max_mw:
        max_mw = mw

print(f"LogP range: {min_logp} to {max_logp}")
print(f"QED range: {min_qed} to {max_qed}")
print(f"MW range: {min_mw} to {max_mw}")

LogP range: -1.5 to 3.89
QED range: 0.5 to 0.92
MW range: 107.11 to 298.11


In [8]:
def normalize(value, min_val, max_val, target_max=np.pi):
    ''' Normalize a value to a range [0, [0, pi] to later encode them as rotation angles'''
    norm = (value - min_val) / (max_val - min_val) * target_max
    return float(f"{norm:.3f}")

In [9]:
print("Maximum length of sequences in the subset:", max_len)

Maximum length of sequences in the subset: 29


In [10]:
# Write the structured data to a CSV file
with open("structured_data_selfies.csv", mode="w", newline="") as file:
    writer = csv.writer(file)

    n_tokens = max_len + 2  # +2 for <SOS> and <EOS>
    header = ["logP", "qed", "mw"] + [f"token_{i}" for i in range(n_tokens)]
    writer.writerow(header)


    for mol in molecules_subset:
        smiles = mol.get('molecule_structures', {}).get('canonical_smiles')
        selfies = sf.encoder(smiles)
        props = mol.get('molecule_properties', {})
        if not selfies:
            continue
        if "." in selfies:
            continue
        try:
            logP = float(props.get('alogp'))
            qed = float(props.get('qed_weighted'))
            mw = float(props.get('mw_freebase'))
        except (TypeError, ValueError):
            continue

        norm_logp = normalize(logP, min_logp, max_logp)
        norm_qed = normalize(qed, min_qed, max_qed)
        norm_mw = normalize(mw, min_mw, max_mw)

        tokens = list(sf.split_selfies(selfies))
        if not all(tok in token_to_index for tok in tokens):
            continue

        bit_matrix = smiles_to_bits(tokens)  # shape (n_tokens, 6)
        token_bits_as_strings = ["".join(map(str, row)) for row in bit_matrix]

        row = [norm_logp, norm_qed, norm_mw] + token_bits_as_strings
        writer.writerow(row)

## 3. Quantum Generative Model

In [11]:
import itertools
def zstring_combos(wires):
    """
    Return an ordered list of wire-tuples for all Z-strings up to order H_LOCAL.
    Order: all 1-local, then all 2-local, ..., up to H_LOCAL.
    """
    L = []
    for k in range(1, H_LOCAL + 1):
        L.extend(itertools.combinations(wires, k))
    return [tuple(c) for c in L]

def num_zstrings(n_wires):
    """Count how many Z-strings up to order H_LOCAL."""
    from math import comb
    return sum(comb(n_wires, k) for k in range(1, H_LOCAL + 1))

In [12]:
# Quantum Attention mechanism using SWAP test

# --- Device for attention ---
n_past = 5
attn_dev = qml.device("default.qubit", wires=(BITS_PER_TOKEN+1)*n_past)

@qml.qnode(attn_dev, interface="jax")
def quantum_attention_qnode(Q_vec, K_vecs):
    """
    Q_vec: projected query vector of current token
    K_vecs: list of projected key vectors for past tokens
    Returns: attention scores ⟨q_i | k_j⟩ for each j
    """
    n_tokens = len(K_vecs)
    q_wires = list(range(BITS_PER_TOKEN))

    # Encode Q and all K in parallel (different wire registers)
    def encode_token(angles, wires):
        # Use AngleEmbedding for compactness, then entangle
        qml.templates.AngleEmbedding(angles, wires=wires, rotation="Y")
        for i in range(len(wires)-1):
            qml.CNOT(wires=[wires[i], wires[i+1]])

    # Encode Q
    encode_token(Q_vec, wires=q_wires)

    # Encode K_j
    # Collect expectation values (one per K_j)
    measurements = []
    for j, K_j in enumerate(K_vecs):
        start = (BITS_PER_TOKEN+1) + j*(BITS_PER_TOKEN+1)
        k_wires = list(range(start, start+BITS_PER_TOKEN))
        encode_token(K_j, k_wires)

        # SWAP test
        ancilla = start + BITS_PER_TOKEN
        qml.Hadamard(wires=ancilla)
        for qw, kw in zip(q_wires, k_wires):
            qml.CSWAP(wires=[ancilla, qw, kw])
        qml.Hadamard(wires=ancilla)

        measurements.append(qml.expval(qml.PauliZ(ancilla)))

    # **Return as tuple** so PennyLane converts to JAX array
    return tuple(measurements)


def quantum_attention(Q_vec, K_vecs, V_vecs):
    # Make non-traced (concrete) copies for the QNode
    Q_safe = jax.lax.stop_gradient(Q_vec)
    K_safe = [jax.lax.stop_gradient(k) for k in K_vecs]

    raw_expvals = quantum_attention_qnode(Q_safe, K_safe)
    raw_expvals = jnp.asarray(raw_expvals)
    # Convert from expectation values (in [-1,1]) to probabilities [0,1]
    overlaps = (1.0 - raw_expvals) / 2.0
    return overlaps


def classical_attention(Q_vec, K_vecs, V_vecs, scale=True):
    # Q_vec: (proj_dim,) ; K_vecs: list of (proj_dim,) ; V_vecs: list of (proj_dim,)
    if len(K_vecs) == 0:
        return V_vecs[0] if len(V_vecs) > 0 else jnp.zeros_like(Q_vec)

    K_mat = jnp.stack(K_vecs)  # shape (n_past, proj_dim)
    # dot-product softmax (fully differentiable)
    scores = jnp.dot(K_mat, Q_vec)  # shape (n_past,)
    if scale:
        scores = scores / jnp.sqrt(Q_vec.shape[0])
    # Causal mask: only past tokens (K_vecs are already past tokens)
    weights = jax.nn.softmax(scores)
    # Weighted sum over classical V
    output = jnp.sum(weights[:, None] * jnp.stack(V_vecs), axis=0)
    return output

def batched_classical_attention(Q_batch, K_batch_list, V_batch_list, scale=True):
    """
    Q_batch: (batch_size, proj_dim)
    K_batch_list: list of lists of past K vectors per batch element
    V_batch_list: list of lists of past V vectors per batch element
    """
    def attention_single(Q_vec, K_vecs, V_vecs):
        if len(K_vecs) == 0:
            return x_i @ W_V if len(V_vecs) > 0 else jnp.zeros_like(Q_vec)
        K_mat = jnp.stack(K_vecs)  # shape (n_past, proj_dim)
        scores = jnp.dot(K_mat, Q_vec)
        if scale:
            scores = scores / jnp.sqrt(Q_vec.shape[0])
        weights = jax.nn.softmax(scores)
        return jnp.sum(weights[:, None] * jnp.stack(V_vecs), axis=0)

    return jax.vmap(attention_single)(Q_batch, K_batch_list, V_batch_list)


In [13]:
# Device and qubit setup
# BITS_PER_TOKEN number of qubits needed to encode each token
n_prop_qubits = 3  # number of qubits needed to encode properties (logP, QED, MW)
n_ancillas = 3  # number of ancilla qubits that represent the environment
n_total_qubits = n_prop_qubits + BITS_PER_TOKEN + n_ancillas

N_LAYERS = 6  # number of variational layers
H_LOCAL = 3 # h_local sets the maximum number of qubits that can interact in each Z-string term of Σ


# Name them explicitly
prop_wires = [f"prop_{i}" for i in range(n_prop_qubits)]
token_wires = [f"token_{i}" for i in range(BITS_PER_TOKEN)]
ancilla_wires = [f"ancilla_{i}" for i in range(n_ancillas)]
all_wires = prop_wires + token_wires + ancilla_wires

dev = qml.device("default.qubit", wires=all_wires)


def molecular_property_encoder(props):
    """Encode continuous props on property qubits via RY rotations"""
    for wire, val in zip(prop_wires, props):
        qml.RY(val, wires=wire)


def token_encoder(token_bits):
    """Basis-encode token bits on token qubits"""
    qml.BasisState(token_bits, wires=token_wires)


def operator_layer(theta_params, theta_prop, wires):
    """
    Variational layer where:
      - theta_params[...] are rotations for token + ancilla qubits
      - theta_prop encodes property→token entanglement
    """
    token_ancilla_ws = token_wires + ancilla_wires

    # Property → token entanglement 
    for p, prop_wire in enumerate(prop_wires):
        for t, t_a_wire in enumerate(token_ancilla_ws):
            qml.CRX(theta_prop[p, t, 0], wires=[prop_wire, t_a_wire])
            qml.CRY(theta_prop[p, t, 1], wires=[prop_wire, t_a_wire])

 
    qml.StronglyEntanglingLayers(
        weights=theta_params[None,:,:],  # shape: (n_token_ancilla, 3)
        wires=token_ancilla_ws
    )

def Sigma_layer_vec(gamma_vec, token_ancilla_ws, time=1.0, combos=None):
    """
    Diagonal multi-Z unitary Σ = exp(i * sum_s gamma_s * Z^{⊗|s|} * t)
    using a flat parameter vector 'gamma_vec' aligned with 'combos'.
    """
    #token_ancilla_ws = list(wires)  # pass token+ancilla here
    if combos is None:
        combos = zstring_combos(token_ancilla_ws)

    # Safety: ensure the vector length matches the number of combos
    assert gamma_vec.shape[0] == len(combos), \
        f"gamma_vec has length {gamma_vec.shape[0]} but expected {len(combos)}"

    # MultiRZ(phi) = exp(-i * phi/2 * Z^{⊗k}); choose phi = -2 * gamma * time
    for gamma, combo in zip(gamma_vec, combos):
        qml.MultiRZ(-2.0 * gamma * time, wires=list(combo))


# QNode combining encoding and variational layers
@qml.qnode(dev, interface="jax")
def single_qnode(token_bits, props, theta_params, theta_prop, sigma_params, output_i):
    molecular_property_encoder(props)
    token_encoder(token_bits)

    for i, val in enumerate(output_i):
        qml.RY(val, wires=token_wires[i])

    token_ancilla_ws = token_wires + ancilla_wires
    combos = zstring_combos(token_ancilla_ws)

    for l in range(N_LAYERS):
        operator_layer(theta_params[l], theta_prop[l], wires=all_wires)
        Sigma_layer_vec(sigma_params[l], token_ancilla_ws, time=1.0, combos=combos)
        qml.adjoint(operator_layer)(theta_params[l], theta_prop[l], wires=all_wires)

    return qml.probs(wires=token_wires)

# Vectorize over batch
batched_qnode = jax.vmap(
    single_qnode,
    in_axes=(0, 0, None, None, None, 0)  # batch token_bits, props, output_i; params are shared
)

In [14]:
def bitstr_to_array(bitstr):
    """Convert a string of bits (e.g., '010101') to a numpy float32 array."""
    return np.array([int(b) for b in bitstr], dtype=np.float32)

def build_training_data(df):
    """
    Build dataset tuples of (input_token_bits, molecular_properties, target_token_bits)
    from a DataFrame.

    Args:
        df (pandas.DataFrame): DataFrame containing molecular properties and token bit strings.
        n_token_cols (int): Number of token columns in the DataFrame.

    Returns:
        list of tuples: Each tuple contains (x_token: np.array, x_props: np.array, y_target: np.array)
    """
    dataset = []

    for _, row in df.iterrows():
        # Extract molecular properties as a numpy float32 array
        props = [row['logP'], row['qed'], row['mw']]
        x_props = np.array(props, dtype=np.float32)

        tokens = row[3:]  # token columns after properties

        # Iterate over token sequence to create input-target pairs
        for i in range(len(tokens) - 1):
            current_token = tokens.iloc[i]
            next_token = tokens.iloc[i + 1]

            # Skip missing or NaN tokens
            if current_token is None or (isinstance(current_token, float) and math.isnan(current_token)):
                continue
            if next_token is None or (isinstance(next_token, float) and math.isnan(next_token)):
                continue

            # Convert token strings (e.g., '01011') to bit arrays
            x_token = bitstr_to_array(current_token)
            y_target = bitstr_to_array(next_token)

            dataset.append((x_token, x_props, y_target))

    return dataset


# Load dataset
token_cols = [f"token_{i}" for i in range(n_tokens)]
df = pd.read_csv("structured_data_selfies.csv", dtype={col: str for col in token_cols})
dataset = build_training_data(df)  # Should return list/array of (x_token, x_props, y_target)

In [42]:
def bits_to_index(bits):
    powers = 2 ** jnp.arange(BITS_PER_TOKEN - 1, -1, -1)
    return jnp.dot(bits, powers).astype(jnp.int32)

def categorical_crossentropy(pred_probs, target_index):
    epsilon = 1e-10
    return -jnp.log(pred_probs[target_index] + epsilon)

def label_smoothing_crossentropy_normalized(pred_probs, target_index, epsilon=0.1):
    """Cross-entropy loss with label smoothing, normalized to [0,1]."""
    num_classes = pred_probs.shape[0] # Number of classes (tokens)
    
    # Build smoothed target
    smooth_target = jnp.full_like(pred_probs, epsilon / (num_classes - 1))
    smooth_target = smooth_target.at[target_index].set(1.0 - epsilon)
    
    # Compute cross-entropy
    loss = -jnp.sum(smooth_target * jnp.log(pred_probs + 1e-10))
    
    # Normalize to [0,1]
    max_loss = jnp.log(num_classes)  # worst-case (uniform distribution)
    norm_loss = loss / max_loss
    
    return norm_loss


def compute_accuracy(pred_probs_batch, y_targets):
    """
    pred_probs_batch: (batch_size, vocab_size)
    y_targets: (batch_size,) integer indices
    """
    pred_indices = jnp.argmax(pred_probs_batch, axis=1)  # predicted token index
    correct = pred_indices == y_targets
    return jnp.mean(correct)  # fraction correct

In [43]:
# Token embedding
key = jax.random.PRNGKey(42)
EMBEDDING_SIZE = 8      # size of embeddings
key, k_emb = jax.random.split(key)
embedding_table = jax.random.normal(k_emb, (VOCABULARY_SIZE, EMBEDDING_SIZE)) * 0.1

# Projection matrices
key = jax.random.PRNGKey(42)
proj_dim = BITS_PER_TOKEN  # number of qubits for quantum attention
key, k_WQ, k_WK, k_WV = jax.random.split(key, 4)
W_Q = jax.random.normal(k_WQ, (EMBEDDING_SIZE, proj_dim)) * 0.1
W_K = jax.random.normal(k_WK, (EMBEDDING_SIZE, proj_dim)) * 0.1
W_V = jax.random.normal(k_WV, (EMBEDDING_SIZE, proj_dim)) * 0.1


# Effective qubit counts in variational layers
n_token_ancilla = BITS_PER_TOKEN + n_ancillas

# Initialize theta and sigma params
key = jax.random.PRNGKey(42)
key, k_theta, k_theta_prop, k_sigma = jax.random.split(key, 4)

# Precompute Z-string combos once
token_ancilla_ws = token_wires + ancilla_wires
combos = zstring_combos(token_ancilla_ws)
n_strings = len(combos)

# Combine all trainable parameters into a single dictionary
combined_params = {
    'theta': jax.random.normal(k_theta, (N_LAYERS, n_token_ancilla, 3)) * 0.1,
    'theta_prop': jax.random.normal(k_theta_prop, (N_LAYERS, n_prop_qubits, n_token_ancilla, 4)) * 0.1,
    'sigma': jax.random.normal(k_sigma, (N_LAYERS, n_strings)) * 0.1,
    'embedding_table': embedding_table,
    'W_Q': W_Q,
    'W_K': W_K,
    'W_V': W_V
}
# Training hyperparams
learning_rate = 0.001
n_epochs = 25

# Optimizer
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(combined_params)

MAX_PAST = 10  # maximum number of past tokens to keep

# --- Prepare past tokens as fixed-size array ---
def prepare_past_tokens(past_token_indices):
    """Convert dynamic list of past token indices to fixed-size array for JAX"""
    arr = jnp.zeros(MAX_PAST, dtype=jnp.int32)
    n_past = min(len(past_token_indices), MAX_PAST)
    if n_past > 0:
        arr = arr.at[:n_past].set(jnp.array(past_token_indices[-n_past:], dtype=jnp.int32))
    return arr, n_past



In [44]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad

def get_valid_embeddings(indices, table, W):
    """
    Gather embeddings while ignoring zero-padding indices.
    indices: (max_history,) int array
    table: embedding table (VOCAB_SIZE, EMBED_DIM)
    W: projection matrix (EMBED_DIM, proj_dim)
    """
    # indices shape: (max_history,)
    embeddings = table[indices] @ W         # shape: (max_history, proj_dim)
    mask = (indices != 0).astype(jnp.float32)[:, None]  # (max_history, 1)
    embeddings = embeddings * mask          # zero out padding
    return jnp.sum(embeddings, axis=0)     # sum valid embeddings -> (proj_dim,)

def compute_output_i(q, x_i, past_idx, embedding_table, W_K, W_V):
    """
    Compute output_i for one element.
    q: (proj_dim,)
    x_i: (embed_dim,)
    past_idx: (max_history,)
    """
    def no_past(_):
        return x_i @ W_V

    def has_past(past_idx_inner):
        K_vecs_sum = get_valid_embeddings(past_idx_inner, embedding_table, W_K)
        V_vecs_sum = get_valid_embeddings(past_idx_inner, embedding_table, W_V)
        return classical_attention(q, K_vecs_sum, V_vecs_sum)

    return jax.lax.cond(jnp.all(past_idx == 0), no_past, has_past, operand=past_idx)

def batched_training_step(params, opt_state, x_tokens, x_props, y_targets, past_tokens_batch):
    """
    Batched training step.
    x_tokens: (batch_size, BITS_PER_TOKEN)
    x_props:  (batch_size, n_props)
    y_targets:(batch_size,)
    past_tokens_batch: (batch_size, max_history)
    """
    theta_params = params['theta']
    theta_prop   = params['theta_prop']
    sigma_params = params['sigma']
    embedding_table = params['embedding_table']
    W_Q = params['W_Q']
    W_K = params['W_K']
    W_V = params['W_V']

    batch_size = x_tokens.shape[0]

    def loss_fn(params):
        theta_params = params['theta']
        theta_prop   = params['theta_prop']
        sigma_params = params['sigma']

        # --- Embeddings ---
        token_indices = vmap(bits_to_index)(x_tokens)       # (batch,)
        x_i = embedding_table[token_indices]               # (batch, embed_dim)

        # --- Q/K/V projections ---
        Q = x_i @ W_Q                                      # (batch, proj_dim)

        # --- Compute output_i for each batch element ---
        output_i = vmap(
            compute_output_i,
            in_axes=(0, 0, 0, None, None, None)
        )(Q, x_i, past_tokens_batch, embedding_table, W_K, W_V)  # (batch, proj_dim)

        # --- QNode for each batch element ---
        def qnode_per_example(x_token, x_prop, out_i):
            return autoregressive_model(x_token, x_prop, theta_params, theta_prop, sigma_params, out_i)

        pred_probs = vmap(qnode_per_example)(x_tokens, x_props, output_i)  # (batch, num_classes)

        # --- Loss & accuracy ---
        target_indices = vmap(bits_to_index)(y_targets)
        losses = vmap(label_smoothing_crossentropy_normalized)(pred_probs, target_indices)
        avg_loss = jnp.mean(losses)

        acc = jnp.mean(jnp.argmax(pred_probs, axis=1) == target_indices)
        return avg_loss, (acc, pred_probs)

    (avg_loss, (avg_acc, pred_probs)), grads = value_and_grad(loss_fn, has_aux=True)(params)

    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    return params, avg_loss, opt_state, grads, avg_acc


In [45]:
for epoch in range(n_epochs):
    total_loss = total_acc = 0.0
    past_token_indices = []

    for x_token, x_props, y_target in dataset:  # each batch
        past_tokens_arr = prepare_past_tokens(past_token_indices)
        combined_params, loss, opt_state, grads, acc = batched_training_step(
            combined_params,
            opt_state,
            x_token,
            x_props,
            y_target,
            past_tokens_arr
        )
        total_loss += loss * x_token.shape[0]
        total_acc += acc * x_token.shape[0]

        # Update past_token_indices (append new batch)
        past_token_indices = update_past_tokens(past_token_indices, x_token)

    avg_loss = total_loss / dataset_size
    avg_acc  = total_acc / dataset_size
    print(f"Epoch {epoch+1} | Loss = {avg_loss:.4f} | Accuracy = {avg_acc:.4f}")


ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())