In [1]:
from typing import Tuple, List, Dict, Any

from rdkit import Chem
from rdkit.Chem import AllChem

import numpy as np
from math import ceil, log2
import re
import pandas as pd
import optax
import csv
import flax.linen as nn
import math

import pennylane as qml
from pennylane import qchem
from pennylane.templates import StronglyEntanglingLayers

import jax
import jax.numpy as jnp
from jax.nn.initializers import normal

import haiku as hk

In [2]:
from chembl_webresource_client.new_client import new_client

# Using the ChEMBL API to get the molecules dataset

molecule = new_client.molecule

# Filter for drug-like small molecules interesting for human use
druglike_molecules = molecule.filter(
    molecule_properties__heavy_atoms__lte=15,           # Heavy atoms less than 20
    molecule_properties__alogp__lte=5,                  # LogP less than 5 (Lipophilicity and membrane permeability)
    molecule_properties__mw_freebase__lte=300,          # Molecular weight less than 300 g/mol
    molecule_properties__qed_weighted__gte=0.5,         # QED weighted greater than 0.5 (Drug-likeness)
    molecule_properties__num_ro5_violations__lte=1,     # At most 1 Rule of 5 violation (Drug-likeness filter)

)

print("Training molecules set: ", len(druglike_molecules))  # Check how many molecules match the filter criteria

Training molecules set:  65778


## 1. Get the alphabet used for the SMILEs representation

Build the alphabet considering the structure of some special tokens

In [3]:
# Regex sin grupos capturantes y con reconocimiento de átomos minúsculos
TOKEN_PATTERN = re.compile(
    r'\[[^\[\]]+\]'     # atoms in brakets (ej. [C@@H])
    r'|Br|Cl'           # 2 letter atoms (ej. Br, Cl)
    r'|[A-Z](?![a-z])'  # capital letters (ej. C, N, O)
    r'|[a-z]'           # lowercase letters (ej. s, p, d)
    r'|\d'              # ring numbers (ej. 1, 2, 3)
    r'|=|#|\/|\\|\(|\)' # birdges and parentheses (ej. =, #, /, \, (, ))
)

def tokenize_smiles(smiles: str):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"SMILES inválido: {smiles}")
    can = Chem.MolToSmiles(mol, canonical=False)
    return TOKEN_PATTERN.findall(can)


In [4]:
# Define the subset of molecules we are going to train the model with
molecules_subset = druglike_molecules[:600]

max_len = 0
alphabet = set()

# We use a subset of molecules to build the alphabet
for mol in molecules_subset:  # Limiting to 1000 molecules for performance
    smiles = mol.get('molecule_structures', {}).get('canonical_smiles')
    if smiles:
        tokens = tokenize_smiles(smiles)
        if max_len < len(tokens):
            max_len = len(tokens)
        alphabet.update(tokens)

alphabet = sorted(alphabet)
alphabet = ['<SOS>'] + alphabet  + ['<EOS>'] # Add Start-of-Secuence, End-of-secuence


In [5]:
print("Alphabet of SMILES characters:", alphabet)

VOCABULARY_SIZE = len(alphabet)
BITS_PER_TOKEN = ceil(log2(VOCABULARY_SIZE))  # n bits por token

print("Total unique characters in SMILES:", VOCABULARY_SIZE)
print("Maximum length of SMILES in dataset:", max_len)
print("Bits per token:", BITS_PER_TOKEN)

Alphabet of SMILES characters: ['<SOS>', '#', '(', ')', '/', '1', '2', '3', '=', 'Br', 'C', 'Cl', 'F', 'I', 'N', 'O', 'P', 'S', '[C@@H]', '[C@@]', '[C@H]', '[C@]', '[Cl-]', '[Li+]', '[N+]', '[Na+]', '[O-]', '[PH]', '[S+]', '[nH]', '\\', 'c', 'n', 'o', 's', '<EOS>']
Total unique characters in SMILES: 36
Maximum length of SMILES in dataset: 33
Bits per token: 6


•	Símbolos de enlaces y paréntesis: #, (, ), /, \, =

•	Dígitos simples para cierres de anillos: '1', '2', '3', '4', '5'

•	Átomos orgánicos y halógenos comunes, tanto mayúsculas (alifáticos) como minúsculas (aromáticos)

•	Tokens entre corchetes para isótopos, estados de carga, quiralidad, etc.

•	Un token especial '<PAD>' para padding en modelos ML

In [6]:
# Diccionario token → índice
token_to_index = {tok: i for i, tok in enumerate(alphabet)}

def print_token_bits(tokens, token_to_index):
    for tok in tokens:
        idx = token_to_index.get(tok, None)
        if idx is None:
            print(f"Token '{tok}' no está en el diccionario.")
            continue
        binary = format(idx, f'0{BITS_PER_TOKEN}b')
        print(f"'{tok}' → index {idx} → {binary}")

print_token_bits(alphabet, token_to_index)

'<SOS>' → index 0 → 000000
'#' → index 1 → 000001
'(' → index 2 → 000010
')' → index 3 → 000011
'/' → index 4 → 000100
'1' → index 5 → 000101
'2' → index 6 → 000110
'3' → index 7 → 000111
'=' → index 8 → 001000
'Br' → index 9 → 001001
'C' → index 10 → 001010
'Cl' → index 11 → 001011
'F' → index 12 → 001100
'I' → index 13 → 001101
'N' → index 14 → 001110
'O' → index 15 → 001111
'P' → index 16 → 010000
'S' → index 17 → 010001
'[C@@H]' → index 18 → 010010
'[C@@]' → index 19 → 010011
'[C@H]' → index 20 → 010100
'[C@]' → index 21 → 010101
'[Cl-]' → index 22 → 010110
'[Li+]' → index 23 → 010111
'[N+]' → index 24 → 011000
'[Na+]' → index 25 → 011001
'[O-]' → index 26 → 011010
'[PH]' → index 27 → 011011
'[S+]' → index 28 → 011100
'[nH]' → index 29 → 011101
'\' → index 30 → 011110
'c' → index 31 → 011111
'n' → index 32 → 100000
'o' → index 33 → 100001
's' → index 34 → 100010
'<EOS>' → index 35 → 100011


The input SMILEs must be the same size, so we need to use the padding to make it uniform

In [7]:
basis_encoded_dataset = []
token_to_index = {tok: i for i, tok in enumerate(alphabet)}

def smiles_to_bits(tokens: list, max_len: int) -> np.ndarray:
    """Convert tokens to a 2D array of shape (max_len, n_bits)"""
    padded_tokens = ['<SOS>'] + tokens + ['<EOS>']
    bit_matrix = []
    for tok in padded_tokens:
        idx = token_to_index[tok]
        bits = list(f"{idx:0{BITS_PER_TOKEN}b}")  # length of the binary string depends on the number of bits required to represent the alphabet
        bit_matrix.append([int(b) for b in bits])
    return np.array(bit_matrix)  # shape (max_len, n_bits)

for mol in molecules_subset:
    smiles = mol.get('molecule_structures', {}).get('canonical_smiles')
    if smiles:
        tokens = tokenize_smiles(smiles)
        print(f"Processing SMILES: {tokens}")
        if all(tok in token_to_index for tok in tokens):
            bit_matrix = smiles_to_bits(tokens, max_len)
            basis_encoded_dataset.append(bit_matrix)
            print(f"SMILES: {smiles}, shape: {bit_matrix}")



Processing SMILES: ['C', 'N', '1', 'C', 'C', 'C', '[C@H]', '1', 'c', '1', 'c', 'c', 'c', 'n', 'c', '1']
SMILES: CN1CCC[C@H]1c1cccnc1, shape: [[0 0 0 0 0 0]
 [0 0 1 0 1 0]
 [0 0 1 1 1 0]
 [0 0 0 1 0 1]
 [0 0 1 0 1 0]
 [0 0 1 0 1 0]
 [0 0 1 0 1 0]
 [0 1 0 1 0 0]
 [0 0 0 1 0 1]
 [0 1 1 1 1 1]
 [0 0 0 1 0 1]
 [0 1 1 1 1 1]
 [0 1 1 1 1 1]
 [0 1 1 1 1 1]
 [1 0 0 0 0 0]
 [0 1 1 1 1 1]
 [0 0 0 1 0 1]
 [1 0 0 0 1 1]]
Processing SMILES: ['c', '1', 'c', 'n', 'c', 'c', '(', '[C@@H]', '2', 'C', 'C', 'C', 'N', '2', ')', 'c', '1']
SMILES: c1cncc([C@@H]2CCCN2)c1, shape: [[0 0 0 0 0 0]
 [0 1 1 1 1 1]
 [0 0 0 1 0 1]
 [0 1 1 1 1 1]
 [1 0 0 0 0 0]
 [0 1 1 1 1 1]
 [0 1 1 1 1 1]
 [0 0 0 0 1 0]
 [0 1 0 0 1 0]
 [0 0 0 1 1 0]
 [0 0 1 0 1 0]
 [0 0 1 0 1 0]
 [0 0 1 0 1 0]
 [0 0 1 1 1 0]
 [0 0 0 1 1 0]
 [0 0 0 0 1 1]
 [0 1 1 1 1 1]
 [0 0 0 1 0 1]
 [1 0 0 0 1 1]]
Processing SMILES: ['C', 'C', '1', '(', 'C', ')', '[C@H]', '(', 'C', '(', '=', 'O', ')', 'O', ')', 'N', '2', 'C', '(', '=', 'O', ')', 'C', '[C@H]', '2', 

## 2. Obtain the molecular properties of interest

logP (o cx_logp) -> Coeficiente de partición octanol/agua (lipofilia)

QED (quantitative estimate of drug-likeness) -> Escala combinada que evalúa qué tan “drug-like” es una molécula

SAS (Synthetic Accessibility Score) -> Qué tan difícil sería sintetizar la molécula en laboratorio * Needs to be calculated separately!!

MW (peso molecular) -> Masa total, típicamente ≤ 500 Da para buenos fármacos orales


First we get the max and min range values in order to later normalize the properties in the range (0, pi)

In [8]:
min_logp = float('inf')
max_logp = float('-inf')
min_qed = float('inf')
max_qed = float('-inf')
min_mw = float('inf')
max_mw = float('-inf')


# Iterate through the subset of molecules to find min/max properties to normalize them
for mol in molecules_subset:
    logP = mol.get('molecule_properties', {}).get('alogp')
    qed = mol.get('molecule_properties', {}).get('qed_weighted')
    mw = mol.get('molecule_properties', {}).get('mw_freebase')

    if logP is None or qed is None or mw is None:
        continue  # Skip if any property is missing

    logP = float(logP)
    qed = float(qed)
    mw = float(mw)

    if logP < min_logp:
        min_logp = logP
    if logP > max_logp:
        max_logp = logP

    if qed < min_qed:
        min_qed = qed
    if qed > max_qed:
        max_qed = qed

    if mw < min_mw:
        min_mw = mw
    if mw > max_mw:
        max_mw = mw

print(f"LogP range: {min_logp} to {max_logp}")
print(f"QED range: {min_qed} to {max_qed}")
print(f"MW range: {min_mw} to {max_mw}")

LogP range: -1.5 to 3.89
QED range: 0.5 to 0.92
MW range: 107.11 to 298.11


In [9]:
def normalize(value, min_val, max_val, target_max=np.pi):
    ''' Normalize a value to a range [0, [0, pi] to later encode them as rotation angles'''
    norm = (value - min_val) / (max_val - min_val) * target_max
    return float(f"{norm:.3f}")

In [10]:
print("Maximum length of sequences in the subset:", max_len)

Maximum length of sequences in the subset: 33


In [11]:
# Write the structured data to a CSV file
DATA_PATH = "../data/structured_smiles_data.csv"
with open(DATA_PATH, mode="w", newline="") as file:
    writer = csv.writer(file)

    n_tokens = max_len + 2  # +2 for <SOS> and <EOS>
    header = ["logP", "qed", "mw"] + [f"token_{i}" for i in range(n_tokens)]
    writer.writerow(header)

    for mol in molecules_subset:
        smiles = mol.get('molecule_structures', {}).get('canonical_smiles')
        props = mol.get('molecule_properties', {})
        if not smiles:
            continue

        try:
            logP = float(props.get('alogp'))
            qed = float(props.get('qed_weighted'))
            mw = float(props.get('mw_freebase'))
        except (TypeError, ValueError):
            continue

        norm_logp = normalize(logP, min_logp, max_logp)
        norm_qed = normalize(qed, min_qed, max_qed)
        norm_mw = normalize(mw, min_mw, max_mw)

        tokens = tokenize_smiles(smiles)
        if not all(tok in token_to_index for tok in tokens):
            continue

        bit_matrix = smiles_to_bits(tokens, max_len)  # shape (n_tokens, 6)
        token_bits_as_strings = ["".join(map(str, row)) for row in bit_matrix]

        row = [norm_logp, norm_qed, norm_mw] + token_bits_as_strings
        writer.writerow(row)

## 3. Quantum Generative Model

In [12]:
import itertools
def zstring_combos(wires):
    """
    Return an ordered list of wire-tuples for all Z-strings up to order H_LOCAL.
    Order: all 1-local, then all 2-local, ..., up to H_LOCAL.
    """
    L = []
    for k in range(1, H_LOCAL + 1):
        L.extend(itertools.combinations(wires, k))
    return [tuple(c) for c in L]

def num_zstrings(n_wires):
    """Count how many Z-strings up to order H_LOCAL."""
    from math import comb
    return sum(comb(n_wires, k) for k in range(1, H_LOCAL + 1))

In [13]:
# Device and qubit setup
# BITS_PER_TOKEN number of qubits needed to encode each token
n_prop_qubits = 3  # number of qubits needed to encode properties (logP, QED, MW)
n_ancillas = 3  # number of ancilla qubits that represent the environment
n_total_qubits = n_prop_qubits + BITS_PER_TOKEN + n_ancillas

N_LAYERS = 6  # number of variational layers
H_LOCAL = 3 # h_local sets the maximum number of qubits that can interact in each Z-string term of Σ


# Name them explicitly
prop_wires = [f"prop_{i}" for i in range(n_prop_qubits)]
token_wires = [f"token_{i}" for i in range(BITS_PER_TOKEN)]
ancilla_wires = [f"ancilla_{i}" for i in range(n_ancillas)]
all_wires = prop_wires + token_wires + ancilla_wires

dev = qml.device("default.qubit", wires=all_wires)


def molecular_property_encoder(props):
    """Encode continuous props on property qubits via RY rotations"""
    for wire, val in zip(prop_wires, props):
        qml.RY(val, wires=wire)


def token_encoder(token_bits):
    qml.BasisState(token_bits, wires=token_wires)


def operator_layer(theta_params, theta_prop, wires):
    """
    Variational layer where:
      - theta_params[...] are rotations for token + ancilla qubits
      - theta_prop encodes property→token entanglement
    """
    prop_ws = prop_wires
    token_ancilla_ws = token_wires + ancilla_wires

    # Property → token entanglement
    for p, prop_wire in enumerate(prop_ws):
        for t, t_a_wire in enumerate(token_ancilla_ws):
            qml.CRX(theta_prop[p, t, 0], wires=[prop_wire, t_a_wire])
            qml.CRY(theta_prop[p, t, 1], wires=[prop_wire, t_a_wire])

    # Single-qubit rotations
        ''' for i, wire in enumerate(token_ancilla_ws):
        qml.RX(theta_params[i, 0], wires=wire)
        qml.RY(theta_params[i, 1], wires=wire)
        qml.RZ(theta_params[i, 2], wires=wire)

    # Entangle token + ancilla chain
    for i in range(len(token_ancilla_ws) - 1):
        qml.CNOT(wires=[token_ancilla_ws[i], token_ancilla_ws[i + 1]])
    qml.CNOT(wires=[token_ancilla_ws[-1], token_ancilla_ws[0]])'''
    
    qml.StronglyEntanglingLayers(
        weights=theta_params[None,:,:],  # shape: (n_token_ancilla, 3)
        wires=token_ancilla_ws
    )

def Sigma_layer_vec(gamma_vec, wires, time=1.0, combos=None):
    """
    Diagonal multi-Z unitary Σ = exp(i * sum_s gamma_s * Z^{⊗|s|} * t)
    using a flat parameter vector 'gamma_vec' aligned with 'combos'.
    """
    token_ancilla_ws = list(wires)  # pass token+ancilla here
    if combos is None:
        combos = zstring_combos(token_ancilla_ws)

    # Safety: ensure the vector length matches the number of combos
    assert gamma_vec.shape[0] == len(combos), \
        f"gamma_vec has length {gamma_vec.shape[0]} but expected {len(combos)}"

    # MultiRZ(phi) = exp(-i * phi/2 * Z^{⊗k}); choose phi = -2 * gamma * time
    for gamma, combo in zip(gamma_vec, combos):
        qml.MultiRZ(-2.0 * gamma * time, wires=list(combo))


# QNode combining encoding and variational layers
@qml.qnode(dev, interface="jax")
def autoregressive_model(token_bits, props, theta_params, theta_prop, sigma_params):
    molecular_property_encoder(props)      # Encode MW, logP, QED
    token_encoder(token_bits)              # Basis-encode token bits

    token_ancilla_ws = token_wires + ancilla_wires
    combos = zstring_combos(token_ancilla_ws)


    for l in range(N_LAYERS):
        # Forward V(θ)
        operator_layer(theta_params[l], theta_prop[l], wires=all_wires)

        # Diagonal Σ(γ,t): vector API
        Sigma_layer_vec(sigma_params[l], token_ancilla_ws, time=1.0, combos=combos)

        # Backward V(θ)†
        qml.adjoint(operator_layer)(theta_params[l], theta_prop[l], wires=all_wires)

    return qml.probs(wires=token_wires), [qml.expval(qml.PauliZ(w)) for w in prop_wires]

In [14]:
def bitstr_to_array(bitstr):
    """Convert a string of bits (e.g., '010101') to a numpy float32 array."""
    return np.array([int(b) for b in bitstr], dtype=np.float32)

def build_training_data(df):
    """
    Build dataset tuples of (input_token_bits, molecular_properties, target_token_bits)
    from a DataFrame.

    Args:
        df (pandas.DataFrame): DataFrame containing molecular properties and token bit strings.
        n_token_cols (int): Number of token columns in the DataFrame.

    Returns:
        list of tuples: Each tuple contains (x_token: np.array, x_props: np.array, y_target: np.array)
    """
    dataset = []

    for _, row in df.iterrows():
        # Extract molecular properties as a numpy float32 array
        props = [row['logP'], row['qed'], row['mw']]
        x_props = np.array(props, dtype=np.float32)

        tokens = row[3:]  # token columns after properties

        # Iterate over token sequence to create input-target pairs
        for i in range(len(tokens) - 1):
            current_token = tokens.iloc[i]
            next_token = tokens.iloc[i + 1]

            # Skip missing or NaN tokens
            if current_token is None or (isinstance(current_token, float) and math.isnan(current_token)):
                continue
            if next_token is None or (isinstance(next_token, float) and math.isnan(next_token)):
                continue

            # Convert token strings (e.g., '01011') to bit arrays
            x_token = bitstr_to_array(current_token)
            y_target = bitstr_to_array(next_token)

            dataset.append((x_token, x_props, y_target))

    return dataset


# Load dataset
token_cols = [f"token_{i}" for i in range(n_tokens)]
df = pd.read_csv(DATA_PATH, dtype={col: str for col in token_cols})
dataset = build_training_data(df)  # Should return list/array of (x_token, x_props, y_target)

In [15]:
# 1. Define the embedding network function
def embedding_network_fn(x):
    embedding_size = N_LAYERS * (BITS_PER_TOKEN + n_ancillas) * 3 
    mlp = hk.Sequential([
        hk.Linear(32), jax.nn.relu,
        hk.Linear(16), jax.nn.relu,
        hk.Linear(embedding_size),  # Match θ shape
    ])
    return mlp(x)

# 2. Transform the function to make it usable in JAX/Haiku
embedding_network = hk.transform(embedding_network_fn)


# 3. Use the embedding to generate θ parameters
def get_context_embedding(prev_token_bits, embed_params):
    """
    prev_token_bits: jnp.array of shape (12,) — concatenated 2×6 bits
    Returns: reshaped embedding output to match θ shape
    """
    embedding_output = embedding_network.apply(embed_params, None, prev_token_bits)
    return embedding_output.reshape((N_LAYERS, BITS_PER_TOKEN + n_ancillas, 3))  # Shape for operator_layer



In [16]:
def bits_to_index(bits):
    powers = 2 ** jnp.arange(len(bits) - 1, -1, -1)
    return jnp.dot(bits, powers).astype(jnp.int32)

def categorical_crossentropy(pred_probs, target_index):
    epsilon = 1e-10
    return -jnp.log(pred_probs[target_index] + epsilon)

'''def label_smoothing_crossentropy_normalized(pred_probs, target_index, epsilon=0.1, alpha=0.1):
    """Cross-entropy loss with label smoothing, normalized to [0,1]."""
    num_classes = pred_probs.shape[0] # Number of classes (tokens)
    
    # Build smoothed target
    smooth_target = jnp.full_like(pred_probs, epsilon / (num_classes - 1))
    smooth_target = smooth_target.at[target_index].set(1.0 - epsilon)
    
    # Compute cross-entropy loss
    ce_loss = -jnp.sum(smooth_target * jnp.log(pred_probs + 1e-10)) # in [0, log(num_classes)]
    max_loss = jnp.log(num_classes)
    total_loss = ce_loss / max_loss  # Normalize to [0,1]

    return total_loss'''
def total_loss_fn(pred_probs, prop_expvals, target_index, props, alpha=0.1, epsilon=0.1):
    """Single normalized loss in [0,1]."""
    num_classes = pred_probs.shape[0]

    # --- Cross-entropy with label smoothing (raw, not normalized yet) ---
    smooth_target = jnp.full_like(pred_probs, epsilon / (num_classes - 1))
    smooth_target = smooth_target.at[target_index].set(1.0 - epsilon)
    ce_loss = -jnp.sum(smooth_target * jnp.log(pred_probs + 1e-10))  # in [0, log(num_classes)]

    # --- Property preservation loss (MSE) ---
    prop_expvals = jnp.array(prop_expvals)  # convert list -> JAX array
    prop_loss = jnp.mean((prop_expvals - jnp.cos(props)) ** 2)  # in [0,4]

    # --- Combine ---
    combined_loss = ce_loss + alpha * prop_loss

    # --- Normalize only once ---
    max_loss = jnp.log(num_classes) + alpha * 4.0
    final_loss = combined_loss / max_loss

    return final_loss

def compute_accuracy(pred_probs, target_index):
    predicted_index = jnp.argmax(pred_probs)
    return jnp.array(predicted_index == target_index, dtype=jnp.float32)


In [17]:
# Initialize embedding params
rng = jax.random.PRNGKey(0)
dummy_context = jnp.zeros((3 * BITS_PER_TOKEN,), dtype=jnp.float32)  # 3 prev. tokens
embedding_params = embedding_network.init(rng, dummy_context)

# Effective qubit counts in variational layers
n_token_ancilla = BITS_PER_TOKEN + n_ancillas

# Initialize theta and sigma params
key = jax.random.PRNGKey(42)
key, k_theta, k_theta_prop, k_sigma = jax.random.split(key, 4)

# Precompute Z-string combos once
token_ancilla_ws = token_wires + ancilla_wires
combos = zstring_combos(token_ancilla_ws)
n_strings = len(combos)


combined_params = {
    'theta': jax.random.normal(k_theta, (N_LAYERS, n_token_ancilla, 3)) * 0.1,
    'theta_prop': jax.random.normal(k_theta_prop, (N_LAYERS, n_prop_qubits, n_token_ancilla, 4)) * 0.1,
    'sigma': jax.random.normal(k_sigma, (N_LAYERS, n_strings)) * 0.1,
    'embedding': embedding_params,
}
# Training hyperparams
learning_rate = 0.001
n_epochs = 25

# Optimizer
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(combined_params)

@jax.jit
def training_step(params, opt_state, x_token, x_props, y_target):
    def loss_fn(params):
        theta_params = params['theta']
        theta_prop = params['theta_prop']
        sigma_params = params['sigma']
        embedding_params = params['embedding']

        # Embedding → adjustment for theta
        '''theta_from_embedding = get_context_embedding(context_vector, embedding_params)
        theta_effective = theta_params + theta_from_embedding'''
        # theta_params || theta_from_embedding

        # Predict
        pred_probs, expval_props = autoregressive_model(x_token, x_props, theta_params, theta_prop, sigma_params)
        index = bits_to_index(y_target)
        
        # Return scalar loss for gradient computation
        # return label_smoothing_crossentropy_nortotal_loss_fnmalized(pred_probs, index), pred_probs
        return total_loss_fn(pred_probs, expval_props, index, x_props), pred_probs

    # value_and_grad computes both loss and grads in one pass
    (loss, pred_probs), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)

    # Update parameters
    updates, opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Accuracy from pred_probs (already computed!)
    index = bits_to_index(y_target)
    acc = compute_accuracy(pred_probs, index)

    return new_params, loss, opt_state, grads, acc

for epoch in range(n_epochs):
    total_loss = total_acc = 0.0
    
    for x_token, x_props, y_target in dataset:
        x_token = jnp.array(x_token, dtype=jnp.int32)
        x_props = jnp.array(x_props, dtype=jnp.float32)
        y_target = jnp.array(y_target, dtype=jnp.int32)

        combined_params, loss, opt_state, grads, acc = training_step(combined_params, opt_state, x_token, x_props, y_target)

        total_loss += loss
        total_acc  += acc

    avg_loss = total_loss / len(dataset)
    avg_acc  = total_acc / len(dataset)
    
    print(f"Epoch {epoch+1} | Loss = {avg_loss:.4f} | Accuracy = {avg_acc:.4f}")

  return lax_numpy.astype(self, dtype, copy=copy, device=device)
  return lax_numpy.astype(self, dtype, copy=copy, device=device)


Epoch 1 | Loss = 0.6227 | Accuracy = 0.3137
Epoch 2 | Loss = 0.5602 | Accuracy = 0.3704
Epoch 3 | Loss = 0.5513 | Accuracy = 0.3800
Epoch 4 | Loss = 0.5481 | Accuracy = 0.3804
Epoch 5 | Loss = 0.5478 | Accuracy = 0.3872
Epoch 6 | Loss = 0.5428 | Accuracy = 0.3868
Epoch 7 | Loss = 0.5396 | Accuracy = 0.3832
Epoch 8 | Loss = 0.5403 | Accuracy = 0.3844
Epoch 9 | Loss = 0.5381 | Accuracy = 0.3933
Epoch 10 | Loss = 0.5384 | Accuracy = 0.3903
Epoch 11 | Loss = 0.5365 | Accuracy = 0.3832
Epoch 12 | Loss = 0.5321 | Accuracy = 0.3829
Epoch 13 | Loss = 0.5296 | Accuracy = 0.3870
Epoch 14 | Loss = 0.5272 | Accuracy = 0.3829
Epoch 15 | Loss = 0.5280 | Accuracy = 0.3841
Epoch 16 | Loss = 0.5264 | Accuracy = 0.3849
Epoch 17 | Loss = 0.5255 | Accuracy = 0.3897
Epoch 18 | Loss = 0.5243 | Accuracy = 0.3883
Epoch 19 | Loss = 0.5246 | Accuracy = 0.3891
Epoch 20 | Loss = 0.5221 | Accuracy = 0.3873
Epoch 21 | Loss = 0.5222 | Accuracy = 0.3903
Epoch 22 | Loss = 0.5151 | Accuracy = 0.3996
Epoch 23 | Loss = 0

LAST

Epoch 1 | Loss = 0.5476 | Accuracy = 0.4380
Epoch 2 | Loss = 0.5016 | Accuracy = 0.4766
Epoch 3 | Loss = 0.4801 | Accuracy = 0.5086
Epoch 4 | Loss = 0.4711 | Accuracy = 0.5244
Epoch 5 | Loss = 0.4642 | Accuracy = 0.5353
Epoch 6 | Loss = 0.4627 | Accuracy = 0.5375
Epoch 7 | Loss = 0.4607 | Accuracy = 0.5381
Epoch 8 | Loss = 0.4559 | Accuracy = 0.5449
Epoch 9 | Loss = 0.4504 | Accuracy = 0.5526
Epoch 10 | Loss = 0.4500 | Accuracy = 0.5562
Epoch 11 | Loss = 0.4484 | Accuracy = 0.5641
Epoch 12 | Loss = 0.4450 | Accuracy = 0.5685
Epoch 13 | Loss = 0.4430 | Accuracy = 0.5686
Epoch 14 | Loss = 0.4428 | Accuracy = 0.5658
Epoch 15 | Loss = 0.4396 | Accuracy = 0.5706
Epoch 16 | Loss = 0.4393 | Accuracy = 0.5747
Epoch 17 | Loss = 0.4414 | Accuracy = 0.5708
Epoch 18 | Loss = 0.4397 | Accuracy = 0.5716
Epoch 19 | Loss = 0.4398 | Accuracy = 0.5763
Epoch 20 | Loss = 0.4399 | Accuracy = 0.5767
Epoch 21 | Loss = 0.4415 | Accuracy = 0.5756
Epoch 22 | Loss = 0.4367 | Accuracy = 0.5796
Epoch 23 | Loss = 0.4352 | Accuracy = 0.5825
Epoch 24 | Loss = 0.4356 | Accuracy = 0.5808
Epoch 25 | Loss = 0.4349 | Accuracy = 0.5789

n_layers = 6
h_local = 3
prob_mask = NO MASK

Epoch 1 | Loss = 0.6208 | Accuracy = 0.4185
Epoch 2 | Loss = 0.5437 | Accuracy = 0.4909
Epoch 3 | Loss = 0.5238 | Accuracy = 0.5092
Epoch 4 | Loss = 0.5112 | Accuracy = 0.5222
Epoch 5 | Loss = 0.5034 | Accuracy = 0.5330
Epoch 6 | Loss = 0.5005 | Accuracy = 0.5372
Epoch 7 | Loss = 0.4984 | Accuracy = 0.5426
Epoch 8 | Loss = 0.4958 | Accuracy = 0.5470
Epoch 9 | Loss = 0.4954 | Accuracy = 0.5492
Epoch 10 | Loss = 0.4933 | Accuracy = 0.5497
Epoch 11 | Loss = 0.4887 | Accuracy = 0.5585
Epoch 12 | Loss = 0.4865 | Accuracy = 0.5625
Epoch 13 | Loss = 0.4855 | Accuracy = 0.5594
Epoch 14 | Loss = 0.4844 | Accuracy = 0.5634
Epoch 15 | Loss = 0.4821 | Accuracy = 0.5645
Epoch 16 | Loss = 0.4787 | Accuracy = 0.5690
Epoch 17 | Loss = 0.4776 | Accuracy = 0.5752
Epoch 18 | Loss = 0.4783 | Accuracy = 0.5689
Epoch 19 | Loss = 0.4774 | Accuracy = 0.5783
Epoch 20 | Loss = 0.4756 | Accuracy = 0.5784
Epoch 21 | Loss = 0.4753 | Accuracy = 0.5751
Epoch 22 | Loss = 0.4741 | Accuracy = 0.5789
Epoch 23 | Loss = 0.4762 | Accuracy = 0.5767
Epoch 24 | Loss = 0.4747 | Accuracy = 0.5771
Epoch 25 | Loss = 0.4732 | Accuracy = 0.5787
Epoch 26 | Loss = 0.4715 | Accuracy = 0.5852
Epoch 27 | Loss = 0.4725 | Accuracy = 0.5805
Epoch 28 | Loss = 0.4746 | Accuracy = 0.5792

/Users/ter/Apps/anaconda3/envs/tfm/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:122: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
/Users/ter/Apps/anaconda3/envs/tfm/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:122: UserWarning: Explicitly requested dtype <class 'jax.numpy.complex128'> requested in astype is not available, and will be truncated to dtype complex64. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
Epoch 1 | Loss = 0.5950 | Accuracy = 0.4449
Epoch 2 | Loss = 0.5483 | Accuracy = 0.4865
Epoch 3 | Loss = 0.5334 | Accuracy = 0.5034
Epoch 4 | Loss = 0.5206 | Accuracy = 0.5211
Epoch 5 | Loss = 0.5140 | Accuracy = 0.5264
Epoch 6 | Loss = 0.5065 | Accuracy = 0.5418
Epoch 7 | Loss = 0.5021 | Accuracy = 0.5492
Epoch 8 | Loss = 0.5007 | Accuracy = 0.5510
Epoch 9 | Loss = 0.4965 | Accuracy = 0.5588
Epoch 10 | Loss = 0.4942 | Accuracy = 0.5562
Epoch 11 | Loss = 0.4921 | Accuracy = 0.5642
Epoch 12 | Loss = 0.4876 | Accuracy = 0.5738
Epoch 13 | Loss = 0.4844 | Accuracy = 0.5769
Epoch 14 | Loss = 0.4819 | Accuracy = 0.5794
Epoch 15 | Loss = 0.4814 | Accuracy = 0.5753
Epoch 16 | Loss = 0.4833 | Accuracy = 0.5763
Epoch 17 | Loss = 0.4813 | Accuracy = 0.5841
Epoch 18 | Loss = 0.4829 | Accuracy = 0.5808
Epoch 19 | Loss = 0.4831 | Accuracy = 0.5795
Epoch 20 | Loss = 0.4797 | Accuracy = 0.5832
Epoch 21 | Loss = 0.4781 | Accuracy = 0.5861
Epoch 22 | Loss = 0.4751 | Accuracy = 0.5866
Epoch 23 | Loss = 0.4750 | Accuracy = 0.5891
Epoch 24 | Loss = 0.4762 | Accuracy = 0.5872
Epoch 25 | Loss = 0.4748 | Accuracy = 0.5912

h=2

Epoch 1 | Loss = 2.6945 | Accuracy = 0.4050 
Epoch 2 | Loss = 2.4173 | Accuracy = 0.4816 
Epoch 3 | Loss = 2.3229 | Accuracy = 0.5047 
Epoch 4 | Loss = 2.2768 | Accuracy = 0.5220
Epoch 5 | Loss = 2.2505 | Accuracy = 0.5267 
Epoch 6 | Loss = 2.2214 | Accuracy = 0.5399 
Epoch 7 | Loss = 2.2051 | Accuracy = 0.5448 
Epoch 8 | Loss = 2.1918 | Accuracy = 0.5523 
Epoch 9 | Loss = 2.1790 | Accuracy = 0.5544 
Epoch 10 | Loss = 2.1683 | Accuracy = 0.5577 
Epoch 11 | Loss = 2.1654 | Accuracy = 0.5582 
Epoch 12 | Loss = 2.1519 | Accuracy = 0.5617 
Epoch 13 | Loss = 2.1452 | Accuracy = 0.5647 
Epoch 14 | Loss = 2.1371 | Accuracy = 0.5670 
Epoch 15 | Loss = 2.1380 | Accuracy = 0.5630 
Epoch 16 | Loss = 2.1330 | Accuracy = 0.5654 
Epoch 17 | Loss = 2.1251 | Accuracy = 0.5661 
Epoch 18 | Loss = 2.1189 | Accuracy = 0.5703 
Epoch 19 | Loss = 2.1176 | Accuracy = 0.5718 
Epoch 20 | Loss = 2.1152 | Accuracy = 0.5698 
Epoch 21 | Loss = 2.1162 | Accuracy = 0.5682 
Epoch 22 | Loss = 2.1117 | Accuracy = 0.5704 
Epoch 23 | Loss = 2.1126 | Accuracy = 0.5729 
Epoch 24 | Loss = 2.1097 | Accuracy = 0.5735 
Epoch 25 | Loss = 2.1040 | Accuracy = 0.5756 
Epoch 26 | Loss = 2.1001 | Accuracy = 0.5751 
Epoch 27 | Loss = 2.0959 | Accuracy = 0.5794 
Epoch 28 | Loss = 2.0881 | Accuracy = 0.5808 
Epoch 29 | Loss = 2.0891 | Accuracy = 0.5806 
Epoch 30 | Loss = 2.0869 | Accuracy = 0.5808 
Epoch 31 | Loss = 2.0843 | Accuracy = 0.5855 
Epoch 32 | Loss = 2.0859 | Accuracy = 0.5839 
Epoch 33 | Loss = 2.0841 | Accuracy = 0.5830