In [19]:
from typing import Tuple, List, Dict, Any

from rdkit import Chem
from rdkit.Chem import AllChem
import selfies as sf

import numpy as np
from math import ceil, log2
import re
import pandas as pd
import optax
import csv
import flax.linen as nn
import math

import pennylane as qml
from pennylane import qchem
from pennylane.templates import StronglyEntanglingLayers

import jax
import jax.numpy as jnp
from jax.nn.initializers import normal

import haiku as hk

In [20]:
from chembl_webresource_client.new_client import new_client

# Using the ChEMBL API to get the molecules dataset

molecule = new_client.molecule

# Filter for drug-like small molecules interesting for human use
druglike_molecules = molecule.filter(
    molecule_properties__heavy_atoms__lte=15,           # Heavy atoms less than 20
    molecule_properties__alogp__lte=5,                  # LogP less than 5 (Lipophilicity and membrane permeability)
    molecule_properties__mw_freebase__lte=300,          # Molecular weight less than 300 g/mol
    molecule_properties__qed_weighted__gte=0.5,         # QED weighted greater than 0.5 (Drug-likeness)
    molecule_properties__num_ro5_violations__lte=1,     # At most 1 Rule of 5 violation (Drug-likeness filter)

)

print("Training molecules set: ", len(druglike_molecules))  # Check how many molecules match the filter criteria

Training molecules set:  65778


## 1. Get the alphabet used for the SMILEs representation

Build the alphabet considering the structure of some special tokens

In [21]:
# Define the subset of molecules we are going to train the model with
molecules_subset = druglike_molecules[:500]

max_len = 0
alphabet = set()
# We use a subset of molecules to build the alphabet
for mol in molecules_subset:  # Limiting to 1000 molecules for performance
    smiles = mol.get('molecule_structures', {}).get('canonical_smiles')
    selfies = sf.encoder(smiles)
    if selfies:
        # Skip if contains '.'
        if "." in selfies:
            continue
        tokens = list(sf.split_selfies(selfies))
        if max_len < len(tokens):
            max_len = len(tokens)
        alphabet.update(tokens)

alphabet = sorted(alphabet)
alphabet = ['<SOS>'] + alphabet  + ['<EOS>'] # Add Start-of-Secuence, End-of-secuence

In [22]:
print("Alphabet of SELFIES characters:", alphabet)

VOCABULARY_SIZE = len(alphabet)
BITS_PER_TOKEN = ceil(log2(VOCABULARY_SIZE))  # n bits por token

print("Total unique characters in SELFIES:", VOCABULARY_SIZE)
print("Maximum length of SELFIES in dataset:", max_len)
print("Bits per token:", BITS_PER_TOKEN)

Alphabet of SMILES characters: ['<SOS>', '[#Branch1]', '[#Branch2]', '[#C]', '[/C]', '[/N]', '[/S]', '[=Branch1]', '[=Branch2]', '[=C]', '[=N+1]', '[=N]', '[=O]', '[=P]', '[=Ring1]', '[=S]', '[Br]', '[Branch1]', '[Branch2]', '[C@@H1]', '[C@@]', '[C@H1]', '[C@]', '[C]', '[Cl]', '[F]', '[I]', '[N+1]', '[NH1]', '[N]', '[O-1]', '[O]', '[PH1]', '[P]', '[Ring1]', '[Ring2]', '[S+1]', '[S]', '[\\C]', '[\\Cl]', '[\\N]', '<EOS>']
Total unique characters in SMILES: 42
Maximum length of SMILES in dataset: 29
Bits per token: 6


•	Símbolos de enlaces y paréntesis: #, (, ), /, \, =

•	Dígitos simples para cierres de anillos: '1', '2', '3', '4', '5'

•	Átomos orgánicos y halógenos comunes, tanto mayúsculas (alifáticos) como minúsculas (aromáticos)

•	Tokens entre corchetes para isótopos, estados de carga, quiralidad, etc.

•	Un token especial '<PAD>' para padding en modelos ML

In [23]:
# Diccionario token → índice
token_to_index = {tok: i for i, tok in enumerate(alphabet)}

def print_token_bits(tokens, token_to_index):
    for tok in tokens:
        idx = token_to_index.get(tok, None)
        if idx is None:
            print(f"Token '{tok}' no está en el diccionario.")
            continue
        binary = format(idx, f'0{BITS_PER_TOKEN}b')
        print(f"'{tok}' → index {idx} → {binary}")

print_token_bits(alphabet, token_to_index)

'<SOS>' → index 0 → 000000
'[#Branch1]' → index 1 → 000001
'[#Branch2]' → index 2 → 000010
'[#C]' → index 3 → 000011
'[/C]' → index 4 → 000100
'[/N]' → index 5 → 000101
'[/S]' → index 6 → 000110
'[=Branch1]' → index 7 → 000111
'[=Branch2]' → index 8 → 001000
'[=C]' → index 9 → 001001
'[=N+1]' → index 10 → 001010
'[=N]' → index 11 → 001011
'[=O]' → index 12 → 001100
'[=P]' → index 13 → 001101
'[=Ring1]' → index 14 → 001110
'[=S]' → index 15 → 001111
'[Br]' → index 16 → 010000
'[Branch1]' → index 17 → 010001
'[Branch2]' → index 18 → 010010
'[C@@H1]' → index 19 → 010011
'[C@@]' → index 20 → 010100
'[C@H1]' → index 21 → 010101
'[C@]' → index 22 → 010110
'[C]' → index 23 → 010111
'[Cl]' → index 24 → 011000
'[F]' → index 25 → 011001
'[I]' → index 26 → 011010
'[N+1]' → index 27 → 011011
'[NH1]' → index 28 → 011100
'[N]' → index 29 → 011101
'[O-1]' → index 30 → 011110
'[O]' → index 31 → 011111
'[PH1]' → index 32 → 100000
'[P]' → index 33 → 100001
'[Ring1]' → index 34 → 100010
'[Ring2]' → index

The input SMILEs must be the same size, so we need to use the padding to make it uniform

In [24]:
basis_encoded_dataset = []
token_to_index = {tok: i for i, tok in enumerate(alphabet)}

def smiles_to_bits(tokens: list) -> np.ndarray:
    """Convert tokens to a 2D array"""
    padded_tokens = ['<SOS>'] + tokens + ['<EOS>']
    bit_matrix = []
    for tok in padded_tokens:
        idx = token_to_index[tok]
        bits = list(f"{idx:0{BITS_PER_TOKEN}b}")  # length of the binary string depends on the number of bits required to represent the alphabet
        bit_matrix.append([int(b) for b in bits])
    return np.array(bit_matrix)

## 2. Obtain the molecular properties of interest

logP (o cx_logp) -> Coeficiente de partición octanol/agua (lipofilia)

QED (quantitative estimate of drug-likeness) -> Escala combinada que evalúa qué tan “drug-like” es una molécula

SAS (Synthetic Accessibility Score) -> Qué tan difícil sería sintetizar la molécula en laboratorio * Needs to be calculated separately!!

MW (peso molecular) -> Masa total, típicamente ≤ 500 Da para buenos fármacos orales


First we get the max and min range values in order to later normalize the properties in the range (0, pi)

In [25]:
min_logp = float('inf')
max_logp = float('-inf')
min_qed = float('inf')
max_qed = float('-inf')
min_mw = float('inf')
max_mw = float('-inf')


# Iterate through the subset of molecules to find min/max properties to normalize them
for mol in molecules_subset:
    logP = mol.get('molecule_properties', {}).get('alogp')
    qed = mol.get('molecule_properties', {}).get('qed_weighted')
    mw = mol.get('molecule_properties', {}).get('mw_freebase')

    if logP is None or qed is None or mw is None:
        continue  # Skip if any property is missing

    logP = float(logP)
    qed = float(qed)
    mw = float(mw)

    if logP < min_logp:
        min_logp = logP
    if logP > max_logp:
        max_logp = logP

    if qed < min_qed:
        min_qed = qed
    if qed > max_qed:
        max_qed = qed

    if mw < min_mw:
        min_mw = mw
    if mw > max_mw:
        max_mw = mw

print(f"LogP range: {min_logp} to {max_logp}")
print(f"QED range: {min_qed} to {max_qed}")
print(f"MW range: {min_mw} to {max_mw}")

LogP range: -1.5 to 3.89
QED range: 0.5 to 0.92
MW range: 111.14 to 298.11


In [26]:
def normalize(value, min_val, max_val, target_max=np.pi):
    ''' Normalize a value to a range [0, [0, pi] to later encode them as rotation angles'''
    norm = (value - min_val) / (max_val - min_val) * target_max
    return float(f"{norm:.3f}")

In [27]:
print("Maximum length of sequences in the subset:", max_len)

Maximum length of sequences in the subset: 29


In [28]:
# Write the structured data to a CSV file
DATA_PATH = "../data/structured_data_selfies.csv"
with open(DATA_PATH, mode="w", newline="") as file:
    writer = csv.writer(file)

    n_tokens = max_len + 2  # +2 for <SOS> and <EOS>
    header = ["logP", "qed", "mw"] + [f"token_{i}" for i in range(n_tokens)]
    writer.writerow(header)


    for mol in molecules_subset:
        smiles = mol.get('molecule_structures', {}).get('canonical_smiles')
        selfies = sf.encoder(smiles)
        props = mol.get('molecule_properties', {})
        if not selfies:
            continue
        if "." in selfies:
            continue
        try:
            logP = float(props.get('alogp'))
            qed = float(props.get('qed_weighted'))
            mw = float(props.get('mw_freebase'))
        except (TypeError, ValueError):
            continue

        norm_logp = normalize(logP, min_logp, max_logp)
        norm_qed = normalize(qed, min_qed, max_qed)
        norm_mw = normalize(mw, min_mw, max_mw)

        tokens = list(sf.split_selfies(selfies))
        if not all(tok in token_to_index for tok in tokens):
            continue

        bit_matrix = smiles_to_bits(tokens)  # shape (n_tokens, 6)
        token_bits_as_strings = ["".join(map(str, row)) for row in bit_matrix]

        row = [norm_logp, norm_qed, norm_mw] + token_bits_as_strings
        writer.writerow(row)

## 3. Quantum Generative Model

In [29]:
import itertools
def zstring_combos(wires):
    """
    Return an ordered list of wire-tuples for all Z-strings up to order H_LOCAL.
    Order: all 1-local, then all 2-local, ..., up to H_LOCAL.
    """
    L = []
    for k in range(1, H_LOCAL + 1):
        L.extend(itertools.combinations(wires, k))
    return [tuple(c) for c in L]

def num_zstrings(n_wires):
    """Count how many Z-strings up to order H_LOCAL."""
    from math import comb
    return sum(comb(n_wires, k) for k in range(1, H_LOCAL + 1))

In [30]:
# Quantum Attention mechanism using SWAP test

# --- Device for attention ---
n_past = 5
attn_dev = qml.device("default.qubit", wires=(BITS_PER_TOKEN+1)*n_past)

@qml.qnode(attn_dev, interface="jax")
def quantum_attention_qnode(Q_vec, K_vecs):
    """
    Q_vec: projected query vector of current token
    K_vecs: list of projected key vectors for past tokens
    Returns: attention scores ⟨q_i | k_j⟩ for each j
    """
    n_tokens = len(K_vecs)
    q_wires = list(range(BITS_PER_TOKEN))

    # Encode Q and all K in parallel (different wire registers)
    def encode_token(angles, wires):
        # Use AngleEmbedding for compactness, then entangle
        qml.templates.AngleEmbedding(angles, wires=wires, rotation="Y")
        for i in range(len(wires)-1):
            qml.CNOT(wires=[wires[i], wires[i+1]])

    # Encode Q
    encode_token(Q_vec, wires=q_wires)

    # Encode K_j
    # Collect expectation values (one per K_j)
    measurements = []
    for j, K_j in enumerate(K_vecs):
        start = (BITS_PER_TOKEN+1) + j*(BITS_PER_TOKEN+1)
        k_wires = list(range(start, start+BITS_PER_TOKEN))
        encode_token(K_j, k_wires)

        # SWAP test
        ancilla = start + BITS_PER_TOKEN
        qml.Hadamard(wires=ancilla)
        for qw, kw in zip(q_wires, k_wires):
            qml.CSWAP(wires=[ancilla, qw, kw])
        qml.Hadamard(wires=ancilla)

        measurements.append(qml.expval(qml.PauliZ(ancilla)))

    # **Return as tuple** so PennyLane converts to JAX array
    return tuple(measurements)


def quantum_attention(Q_vec, K_vecs, V_vecs):
    # Make non-traced (concrete) copies for the QNode
    Q_safe = jax.lax.stop_gradient(Q_vec)
    K_safe = [jax.lax.stop_gradient(k) for k in K_vecs]

    raw_expvals = quantum_attention_qnode(Q_safe, K_safe)
    raw_expvals = jnp.asarray(raw_expvals)
    # Convert from expectation values (in [-1,1]) to probabilities [0,1]
    overlaps = (1.0 - raw_expvals) / 2.0
    return overlaps


def classical_attention(Q_vec, K_vecs, V_vecs, scale=True):
    # Q_vec: (proj_dim,) ; K_vecs: list of (proj_dim,) ; V_vecs: list of (proj_dim,)
    if len(K_vecs) == 0:
        return V_vecs[0] if len(V_vecs) > 0 else jnp.zeros_like(Q_vec)

    K_mat = jnp.stack(K_vecs)  # shape (n_past, proj_dim)
    # dot-product softmax (fully differentiable)
    scores = jnp.dot(K_mat, Q_vec)  # shape (n_past,)
    if scale:
        scores = scores / jnp.sqrt(Q_vec.shape[0])
    # Causal mask: only past tokens (K_vecs are already past tokens)
    weights = jax.nn.softmax(scores)
    # Weighted sum over classical V
    output = jnp.sum(weights[:, None] * jnp.stack(V_vecs), axis=0)
    return output

In [31]:
# Device and qubit setup
# BITS_PER_TOKEN number of qubits needed to encode each token
n_prop_qubits = 3  # number of qubits needed to encode properties (logP, QED, MW)
n_ancillas = 3  # number of ancilla qubits that represent the environment
n_total_qubits = n_prop_qubits + BITS_PER_TOKEN + n_ancillas

N_LAYERS = 6  # number of variational layers
H_LOCAL = 3 # h_local sets the maximum number of qubits that can interact in each Z-string term of Σ


# Name them explicitly
prop_wires = [f"prop_{i}" for i in range(n_prop_qubits)]
token_wires = [f"token_{i}" for i in range(BITS_PER_TOKEN)]
ancilla_wires = [f"ancilla_{i}" for i in range(n_ancillas)]
all_wires = prop_wires + token_wires + ancilla_wires

dev = qml.device("default.qubit", wires=all_wires)


def molecular_property_encoder(props):
    """Encode continuous props on property qubits via RY rotations"""
    for wire, val in zip(prop_wires, props):
        qml.RY(val, wires=wire)

def token_encoder(token_bits):
    """Basis-encode token bits on token qubits"""
    qml.BasisState(token_bits, wires=token_wires)


def operator_layer(theta_params, theta_prop, wires):
    """
    Variational layer where:
      - theta_params[...] are rotations for token + ancilla qubits
      - theta_prop encodes property→token entanglement
    """
    token_ancilla_ws = token_wires + ancilla_wires

    # Property → token entanglement 
    for p, prop_wire in enumerate(prop_wires):
        for t, t_a_wire in enumerate(token_ancilla_ws):
            qml.CRX(theta_prop[p, t, 0], wires=[prop_wire, t_a_wire])
            qml.CRY(theta_prop[p, t, 1], wires=[prop_wire, t_a_wire])

 
    qml.StronglyEntanglingLayers(
        weights=theta_params[None,:,:],  # shape: (n_token_ancilla, 3)
        wires=token_ancilla_ws
    )

def Sigma_layer_vec(gamma_vec, token_ancilla_ws, time=1.0, combos=None):
    """
    Diagonal multi-Z unitary Σ = exp(i * sum_s gamma_s * Z^{⊗|s|} * t)
    using a flat parameter vector 'gamma_vec' aligned with 'combos'.
    """
    #token_ancilla_ws = list(wires)  # pass token+ancilla here
    if combos is None:
        combos = zstring_combos(token_ancilla_ws)

    # Safety: ensure the vector length matches the number of combos
    assert gamma_vec.shape[0] == len(combos), \
        f"gamma_vec has length {gamma_vec.shape[0]} but expected {len(combos)}"

    # MultiRZ(phi) = exp(-i * phi/2 * Z^{⊗k}); choose phi = -2 * gamma * time
    for gamma, combo in zip(gamma_vec, combos):
        qml.MultiRZ(-2.0 * gamma * time, wires=list(combo))


# QNode combining encoding and variational layers
@qml.qnode(dev, interface="jax")
def autoregressive_model(token_bits, props, theta_params, theta_prop, sigma_params, output_i):
    molecular_property_encoder(props)      # Encode MW, logP, QED
    token_encoder(token_bits)              # Basis-encode token bits

    # --- Encode output_i embedding safely ---
    for i, val in enumerate(output_i):
        qml.RY(val, wires=token_wires[i])

    token_ancilla_ws = token_wires + ancilla_wires
    combos = zstring_combos(token_ancilla_ws)

    for l in range(N_LAYERS):
        # Forward V(θ)
        operator_layer(theta_params[l], theta_prop[l], wires=all_wires)

        # Diagonal Σ(γ,t): vector API
        Sigma_layer_vec(sigma_params[l], token_ancilla_ws, time=1.0, combos=combos)

        # Backward V(θ)†
        qml.adjoint(operator_layer)(theta_params[l], theta_prop[l], wires=all_wires)

    return qml.probs(wires=token_wires), [qml.expval(qml.PauliZ(w)) for w in prop_wires]


In [32]:
def bitstr_to_array(bitstr):
    """Convert a string of bits (e.g., '010101') to a numpy float32 array."""
    return np.array([int(b) for b in bitstr], dtype=np.float32)

def build_training_data(df):
    """
    Build dataset tuples of (input_token_bits, molecular_properties, target_token_bits)
    from a DataFrame.

    Args:
        df (pandas.DataFrame): DataFrame containing molecular properties and token bit strings.
        n_token_cols (int): Number of token columns in the DataFrame.

    Returns:
        list of tuples: Each tuple contains (x_token: np.array, x_props: np.array, y_target: np.array)
    """
    dataset = []

    for _, row in df.iterrows():
        # Extract molecular properties as a numpy float32 array
        props = [row['logP'], row['qed'], row['mw']]
        x_props = np.array(props, dtype=np.float32)

        tokens = row[3:]  # token columns after properties

        # Iterate over token sequence to create input-target pairs
        for i in range(len(tokens) - 1):
            current_token = tokens.iloc[i]
            next_token = tokens.iloc[i + 1]

            # Skip missing or NaN tokens
            if current_token is None or (isinstance(current_token, float) and math.isnan(current_token)):
                continue
            if next_token is None or (isinstance(next_token, float) and math.isnan(next_token)):
                continue

            # Convert token strings (e.g., '01011') to bit arrays
            x_token = bitstr_to_array(current_token)
            y_target = bitstr_to_array(next_token)

            dataset.append((x_token, x_props, y_target))

    return dataset


# Load dataset
token_cols = [f"token_{i}" for i in range(n_tokens)]
df = pd.read_csv(DATA_PATH, dtype={col: str for col in token_cols})
dataset = build_training_data(df)  # Should return list/array of (x_token, x_props, y_target)

In [44]:
def bits_to_index(bits):
    powers = 2 ** jnp.arange(len(bits) - 1, -1, -1)
    return jnp.dot(bits, powers).astype(jnp.int32)

def categorical_crossentropy(pred_probs, target_index):
    epsilon = 1e-10
    return -jnp.log(pred_probs[target_index] + epsilon)

def total_loss_fn(pred_probs, prop_expvals, target_index, props, alpha=0.1, epsilon=0.1):
    """Cross-entropy loss with label smoothing, normalized to [0,1]."""
    num_classes = pred_probs.shape[0] # Number of classes (tokens)
    
    # Build smoothed target
    smooth_target = jnp.full_like(pred_probs, epsilon / (num_classes - 1))
    smooth_target = smooth_target.at[target_index].set(1.0 - epsilon)
    
    # --- Compute cross-entropy loss
    ce_loss = -jnp.sum(smooth_target * jnp.log(pred_probs + 1e-10))
    
    # --- Property preservation loss (MSE) ---
    prop_expvals = jnp.array(prop_expvals)  # convert list -> JAX array
    prop_loss = jnp.mean((prop_expvals - jnp.cos(props)) ** 2)  # in [0,4]

    # --- Combine ---
    combined_loss = ce_loss + alpha * prop_loss

    # --- Normalize only once ---
    max_loss = jnp.log(num_classes) + alpha * 4.0
    final_loss = combined_loss / max_loss
    
    return final_loss


def compute_accuracy(pred_probs, target_index):
    predicted_index = jnp.argmax(pred_probs)
    return jnp.array(predicted_index == target_index, dtype=jnp.float32)


In [47]:
# Token embedding
key = jax.random.PRNGKey(42)
EMBEDDING_SIZE = BITS_PER_TOKEN + n_ancillas      # size of embeddings
key, k_emb = jax.random.split(key)
embedding_table = jax.random.normal(k_emb, (VOCABULARY_SIZE, EMBEDDING_SIZE)) * 0.1

# Projection matrices
key = jax.random.PRNGKey(42)
proj_dim = BITS_PER_TOKEN  # number of qubits for quantum attention
key, k_WQ, k_WK, k_WV = jax.random.split(key, 4)
W_Q = jax.random.normal(k_WQ, (EMBEDDING_SIZE, proj_dim)) * 0.1
W_K = jax.random.normal(k_WK, (EMBEDDING_SIZE, proj_dim)) * 0.1
W_V = jax.random.normal(k_WV, (EMBEDDING_SIZE, proj_dim)) * 0.1


# Effective qubit counts in variational layers
n_token_ancilla = BITS_PER_TOKEN + n_ancillas

# Initialize theta and sigma params
key = jax.random.PRNGKey(42)
key, k_theta, k_theta_prop, k_sigma = jax.random.split(key, 4)

# Precompute Z-string combos once
token_ancilla_ws = token_wires + ancilla_wires
combos = zstring_combos(token_ancilla_ws)
n_strings = len(combos)

# Combine all trainable parameters into a single dictionary
combined_params = {
    'theta': jax.random.normal(k_theta, (N_LAYERS, n_token_ancilla, 3)) * 0.1,
    'theta_prop': jax.random.normal(k_theta_prop, (N_LAYERS, n_prop_qubits, n_token_ancilla, 4)) * 0.1,
    'sigma': jax.random.normal(k_sigma, (N_LAYERS, n_strings)) * 0.1,
    'embedding_table': embedding_table,
    'W_Q': W_Q,
    'W_K': W_K,
    'W_V': W_V
}
# Training hyperparams
learning_rate = 0.001
n_epochs = 100

# Optimizer
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(combined_params)


@jax.jit
def training_step(params, opt_state, x_token, x_props, y_target, past_token_indices=None):
    if past_token_indices is None:
        past_token_indices = []

    def loss_fn(params):
        theta_params = params['theta']
        theta_prop = params['theta_prop']
        sigma_params = params['sigma']
        embedding_table = params['embedding_table']
        W_Q = params['W_Q']
        W_K = params['W_K']
        W_V = params['W_V']

        # --- 1. Embedding lookup ---
        #start = time.time()
        token_index = bits_to_index(x_token)
        x_i = embedding_table[token_index]
        #print("Embedding lookup:", time.time() - start, "seconds")
        
        # Positional encoding (sin/cos)
        position = len(past_token_indices)   # current position in the sequence
        dim_indices = jnp.arange(EMBEDDING_SIZE)
        pos_enc = jnp.where(
            dim_indices % 2 == 0,
            jnp.sin(position / (10000 ** (dim_indices / EMBEDDING_SIZE))),
            jnp.cos(position / (10000 ** ((dim_indices-1) / EMBEDDING_SIZE)))
        )

        # Combine token embedding + positional encoding
        x_i_pos = x_i + pos_enc
        # --- 2. Q/K/V projections ---
        #start = time.time()
        Q_i = x_i_pos @ W_Q
        
        if len(past_token_indices) == 0:
            # No past tokens: use V projection of current token as output
            output_i = x_i_pos @ W_V
        else:
            # Normal attention over past tokens
            past_embeddings = [embedding_table[idx] for idx in past_token_indices]
            K_vecs = [x @ W_K for x in past_embeddings]
            V_vecs = [x @ W_V for x in past_embeddings]
            # --- 3. Quantum attention ---
            # Use classical attention during training (differentiable)
            output_i = classical_attention(Q_i, K_vecs, V_vecs)
        #print("Attention:", time.time() - start, "seconds")

        # --- 4. Variational model ---
        # Before calling autoregressive_model, make output non-traced
        #start = time.time()
        pred_probs, expval_props = autoregressive_model(x_token, x_props, theta_params, theta_prop, sigma_params, output_i)
        #print("Variational model:", time.time() - start, "seconds")

        target_index = bits_to_index(y_target)
        # Return scalar loss for gradient computation
        return total_loss_fn(pred_probs, expval_props, target_index, x_props), pred_probs

    # value_and_grad computes both loss and grads in one pass
    (loss, pred_probs), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)

    # Update parameters
    #start = time.time()
    updates, opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    #print("Update params:", time.time() - start, "seconds")
    
    # Accuracy from pred_probs (already computed!)
    target_index = bits_to_index(y_target)
    acc = compute_accuracy(pred_probs, target_index)

    return new_params, loss, opt_state, grads, acc


In [48]:
for epoch in range(n_epochs):
    total_loss = total_acc = 0.0
    past_token_indices = [] # Reset past tokens at the start of each epoch

    for x_token, x_props, y_target in dataset:
        x_token = jnp.array(x_token, dtype=jnp.int32)
        x_props = jnp.array(x_props, dtype=jnp.float32)
        y_target = jnp.array(y_target, dtype=jnp.int32)

        if jnp.all(x_token == 0):
            # If current token is <SOS>, reset past tokens
            past_token_indices = []

        combined_params, loss, opt_state, grads, acc = training_step(combined_params, opt_state, x_token, x_props, y_target, past_token_indices)

        total_loss += loss
        total_acc  += acc

        past_token_indices.append(bits_to_index(x_token))

    avg_loss = total_loss / len(dataset)
    avg_acc  = total_acc / len(dataset)
    
    print(f"Epoch {epoch+1} | Loss = {avg_loss:.4f} | Accuracy = {avg_acc:.4f}")

  return lax_numpy.astype(self, dtype, copy=copy, device=device)
  return lax_numpy.astype(self, dtype, copy=copy, device=device)


Epoch 1 | Loss = 0.6477 | Accuracy = 0.3118
Epoch 2 | Loss = 0.5822 | Accuracy = 0.3786
Epoch 3 | Loss = 0.5650 | Accuracy = 0.3998
Epoch 4 | Loss = 0.5578 | Accuracy = 0.4121
Epoch 5 | Loss = 0.5515 | Accuracy = 0.4201
Epoch 6 | Loss = 0.5460 | Accuracy = 0.4265
Epoch 7 | Loss = 0.5449 | Accuracy = 0.4255
Epoch 8 | Loss = 0.5456 | Accuracy = 0.4230
Epoch 9 | Loss = 0.5442 | Accuracy = 0.4215
Epoch 10 | Loss = 0.5412 | Accuracy = 0.4266
Epoch 11 | Loss = 0.5383 | Accuracy = 0.4258
Epoch 12 | Loss = 0.5382 | Accuracy = 0.4179
Epoch 13 | Loss = 0.5359 | Accuracy = 0.4187
Epoch 14 | Loss = 0.5346 | Accuracy = 0.4224
Epoch 15 | Loss = 0.5306 | Accuracy = 0.4264
Epoch 16 | Loss = 0.5274 | Accuracy = 0.4250
Epoch 17 | Loss = 0.5232 | Accuracy = 0.4365
Epoch 18 | Loss = 0.5178 | Accuracy = 0.4450
Epoch 19 | Loss = 0.5158 | Accuracy = 0.4438
Epoch 20 | Loss = 0.5136 | Accuracy = 0.4475
Epoch 21 | Loss = 0.5134 | Accuracy = 0.4498
Epoch 22 | Loss = 0.5122 | Accuracy = 0.4468
Epoch 23 | Loss = 0

In [51]:
import os
import pickle

# Define the directory path (relative to your current location, assuming you are in /content/QGen-Mol/code)
target_dir = '../data/params/'

# Create the directory if it doesn't exist
os.makedirs(target_dir, exist_ok=True)

# Now, run your pickle code:
with open(os.path.join(target_dir, 'selfies_params.pkl'), "wb") as f:
    pickle.dump(combined_params, f)

In [104]:
from jax import random

def generate_molecule_selfies_stochastic(key, props, combined_params, temperature=1.0, max_length=29):
    """
    Generates a molecule stochastically using Attention weights for context,
    suitable for the SELFIES model.
    """
    # --- 1. Parameter Unpacking (Match SELFIES/Attention model) ---
    embedding_table = combined_params['embedding_table']
    W_Q = combined_params['W_Q']
    W_K = combined_params['W_K']
    W_V = combined_params['W_V']
    theta_params = combined_params['theta']
    theta_prop = combined_params['theta_prop']
    sigma_params = combined_params['sigma']
    
    generated_bits = []
    past_token_indices = []
    
    # Start with <SOS>
    SOS_index = token_to_index['<SOS>']
    current_token_index = SOS_index
    # Create the <SOS> bit array (jnp.int32 array of shape (BITS_PER_TOKEN,))
    current_token_bits = jnp.array(list(map(int, format(SOS_index, f'0{BITS_PER_TOKEN}b'))), dtype=jnp.int32)
    
    # Store the local RNG state
    local_rng = key 
    
    # --- 2. Generation Loop ---
    for t in range(max_length):
        # Split the key for inner-loop stochasticity
        local_rng, subkey = random.split(local_rng)

        # Map current token bits to index and lookup embedding
        x_token = current_token_bits 
        current_token_index = bits_to_index(x_token)
        x_i = embedding_table[current_token_index]

        # Apply Positional Encoding
        position = len(past_token_indices)
        dim_indices = jnp.arange(EMBEDDING_SIZE)
        pos_enc = jnp.where(
            dim_indices % 2 == 0,
            jnp.sin(position / (10000 ** (dim_indices / EMBEDDING_SIZE))),
            jnp.cos(position / (10000 ** ((dim_indices-1) / EMBEDDING_SIZE)))
        )
        x_i_pos = x_i + pos_enc

        # --- Attention Calculation (Replaces theta-Embedding MLP) ---
        Q_i = x_i_pos @ W_Q

        if len(past_token_indices) == 0:
            # First token (after <SOS>), no past context yet
            output_i = x_i_pos @ W_V
        else:
            # Calculate K and V vectors for all past tokens
            past_embeddings = [embedding_table[idx] for idx in past_token_indices]
            K_vecs = [x @ W_K for x in past_embeddings]
            V_vecs = [x @ W_V for x in past_embeddings]
            output_i = classical_attention(Q_i, K_vecs, V_vecs)
        # -------------------------------------------------------------
        
        # 3. Predict probabilities (theta_effective is just theta_params here)
        # Note: Your SELFIES training used constant theta_params (no dynamic embedding).
        # We pass the Attention output as the final parameter, output_i.
        pred_probs, _ = autoregressive_model(x_token, props, theta_params, theta_prop, sigma_params, output_i)
        
        # 4. Tempering and Sampling
        # Ensure we only consider probabilities for valid tokens (0 to 41)
        logits = jnp.log(pred_probs[:VOCABULARY_SIZE] + 1e-10)

        # Scale logits by Temperature
        tempered_logits = logits / temperature

        # Apply softmax to get the new, tempered probability distribution
        tempered_probs = nn.softmax(tempered_logits)
        
        # Sample the next index from the Tempered probability distribution
        token_indices = jnp.arange(VOCABULARY_SIZE)
        next_index = random.choice(
            subkey, 
            token_indices, 
            p=tempered_probs 
        )

        # Convert index to bits 
        next_bits_str = format(int(next_index), f'0{BITS_PER_TOKEN}b')
        next_bits = jnp.array([int(b) for b in next_bits_str], dtype=jnp.int32)

        # 5. Check for <EOS>
        if int(next_index) == token_to_index['<EOS>']:
            break

        # Update previous tokens for the next step
        generated_bits.append(next_bits)
        past_token_indices.append(current_token_index) # Store index of the token just consumed
        current_token_bits = next_bits # Set the new token for the next iteration

    return generated_bits

# Convert generated bits to SELFIES string (adjusted for SELFIES alphabet)
def bits_to_selfies_smiles(generated_bits):
    selfies_tokens = []
    for bits in generated_bits:
        index = int("".join(map(str, bits)), 2)
        
        # Safety clip, though the sampling process above should enforce this
        if index >= VOCABULARY_SIZE or index == 0:
            continue

        token = alphabet[int(index)]
        if token == '<EOS>':
            break
        selfies_tokens.append(token)
    
    selfies_str = ''.join(selfies_tokens)
    smiles_str = sf.decoder(selfies_str)
    return selfies_str, smiles_str

In [113]:
N_MOLECS = 50
MASTER_KEY = jr.PRNGKey(50)  # Fixed key for reproducibility

# Target properties (mid-range example normalized to [0, pi])
desired_logp = 1.2
desired_qed = 0.71
desired_mw = 205.0

norm_logp = normalize(desired_logp, min_logp, max_logp)
norm_qed = normalize(desired_qed, min_qed, max_qed)
norm_mw = normalize(desired_mw, min_mw, max_mw)
desired_props = jnp.array([norm_logp, norm_qed, norm_mw], dtype=jnp.float32)

print(f"Target Properties (Normalized to [0, pi]):")
print(f"   LogP: {norm_logp:.3f}, QED: {norm_qed:.3f}, MW: {norm_mw:.3f}\n")


def generate_molecules(props, params):
    selfies_list = []
    smiles_list = []
    keys = jr.split(MASTER_KEY, N_MOLECS)
    for i in range(N_MOLECS):
        rng_i = keys[i]
        generated_bits = generate_molecule_selfies_stochastic(rng_i, props, params)
        generated_selfies, generated_smiles = bits_to_selfies_smiles(generated_bits)
        selfies_list.append(generated_selfies)
        smiles_list.append(generated_smiles)
    return selfies_list, smiles_list

selfies_list, smiles_list = generate_molecules(desired_props, combined_params)



Target Properties (Normalized to [0, pi]):
   LogP: 1.574, QED: 1.571, MW: 1.577



  return lax_numpy.astype(self, dtype, copy=copy, device=device)


In [115]:
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import QED

def analyze_molecule_properties(selfies_list, smiles_list, target_logp, target_qed, target_mw):
    """Calculates and prints the physicochemical properties for generated molecules,
       comparing them directly against the denormalized targets."""
    
    results = []
    
    # --- Print Denormalized Target Properties ---
    print(f"\n--- Target Properties ---")
    print(f"LogP: {target_logp:.2f}")
    print(f"QED: {target_qed:.2f}")
    print(f"MW: {target_mw:.2f} g/mol")
    print("-" * 35)

    for i, (smiles, selfies) in enumerate(zip(smiles_list, selfies_list)):
        # Handle empty/invalid SMILES from generation failure
        if not smiles:
            results.append({"Molecule": i+1, "SMILES": "N/A", "LogP": np.nan, "QED": np.nan, "MW": np.nan})
            continue

        mol = Chem.MolFromSmiles(smiles)
        
        if mol:
            try:
                logp = Descriptors.MolLogP(mol)
                qed_score = QED.qed(mol)
                mw = Descriptors.ExactMolWt(mol)

                results.append({
                    "Molecule": i+1,
                    "SELFIES": selfies,
                    "SMILES": smiles,
                    "LogP": logp,
                    "QED": qed_score,
                    "MW": mw
                })
            except Exception:
                 results.append({"Molecule": i+1, "SELFIES": selfies, "SMILES": smiles, "LogP": np.nan, "QED": np.nan, "MW": np.nan})

    # Create and display the DataFrame
    df = pd.DataFrame(results)
    
    # Add a row for the target properties for easy comparison
    target_row = pd.Series({
        "Molecule": "TARGET", 
        "SELFIES": "TARGET",
        "SMILES": "TARGET", 
        "LogP": target_logp, 
        "QED": target_qed, 
        "MW": target_mw
    }, name="TARGET").to_frame().T
    
    # Concatenate the target row and the results for visual comparison
    df_styled = pd.concat([target_row.set_index('Molecule'), df.set_index('Molecule')])
    
    # Format numerical columns for presentation
    df_styled = df_styled.apply(pd.to_numeric, errors='ignore').round(2)
    
    return df_styled

df_styled = analyze_molecule_properties(selfies_list, smiles_list, desired_logp, desired_qed, desired_mw)


--- Target Properties ---
LogP: 1.20
QED: 0.71
MW: 205.00 g/mol
-----------------------------------


  df_styled = df_styled.apply(pd.to_numeric, errors='ignore').round(2)


-------------------------------------------------------------

Epoch 1 | Loss = 1.2955 | Accuracy = 0.0841
Epoch 2 | Loss = 0.9395 | Accuracy = 0.1449
Epoch 3 | Loss = 0.8566 | Accuracy = 0.2056
Epoch 4 | Loss = 0.8215 | Accuracy = 0.2150
Epoch 5 | Loss = 0.8026 | Accuracy = 0.2196
Epoch 6 | Loss = 0.7900 | Accuracy = 0.2430
Epoch 7 | Loss = 0.7800 | Accuracy = 0.2430
Epoch 8 | Loss = 0.7710 | Accuracy = 0.2383
Epoch 9 | Loss = 0.7621 | Accuracy = 0.2570
Epoch 10 | Loss = 0.7531 | Accuracy = 0.2664
Epoch 11 | Loss = 0.7447 | Accuracy = 0.2804
Epoch 12 | Loss = 0.7367 | Accuracy = 0.2944
Epoch 13 | Loss = 0.7290 | Accuracy = 0.3318
Epoch 14 | Loss = 0.7221 | Accuracy = 0.3458
Epoch 15 | Loss = 0.7160 | Accuracy = 0.3738
Epoch 16 | Loss = 0.7108 | Accuracy = 0.3925
Epoch 17 | Loss = 0.7063 | Accuracy = 0.3785
Epoch 18 | Loss = 0.7022 | Accuracy = 0.3879
Epoch 19 | Loss = 0.6984 | Accuracy = 0.3972
Epoch 20 | Loss = 0.6950 | Accuracy = 0.4065
Epoch 21 | Loss = 0.6917 | Accuracy = 0.4112
Epoch 22 | Loss = 0.6885 | Accuracy = 0.4065
Epoch 23 | Loss = 0.6854 | Accuracy = 0.4065
Epoch 24 | Loss = 0.6824 | Accuracy = 0.4112
Epoch 25 | Loss = 0.6795 | Accuracy = 0.4112
...
Epoch 97 | Loss = 0.5662 | Accuracy = 0.6308
Epoch 98 | Loss = 0.5655 | Accuracy = 0.6308
Epoch 99 | Loss = 0.5647 | Accuracy = 0.6355
Epoch 100 | Loss = 0.5640 | Accuracy = 0.6355

In [None]:
'''# Without JAX JIT for circuit representation and debugging
def training_step(params, opt_state, x_token, x_props, y_target):
    def loss_fn(params):
        theta_params = params['theta']
        theta_prop = params['theta_prop']
        sigma_params = params['sigma']

        # Predict using theta_effective and sigma_params
        pred_probs = autoregressive_model(x_token, x_props, theta_effective, theta_prop, sigma_params)
        index = bits_to_index(y_target)
        return categorical_crossentropy(pred_probs, index)

    grads = jax.grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    print("\nQuantum Circuit:")
    print(qml.draw(autoregressive_model)(
        x_token, x_props, 
        theta_effective, 
        new_params['theta_prop'], 
        new_params['sigma']
    ))    #print("Target index:", int(jax.device_get(bits_to_index(y_target))))
    loss = loss_fn(new_params)
    return new_params, loss, opt_state, grads'''

'# Without JAX JIT for circuit representation and debugging\ndef training_step(params, opt_state, x_token, x_props, y_target):\n    def loss_fn(params):\n        theta_params = params[\'theta\']\n        theta_prop = params[\'theta_prop\']\n        sigma_params = params[\'sigma\']\n\n        # Predict using theta_effective and sigma_params\n        pred_probs = autoregressive_model(x_token, x_props, theta_effective, theta_prop, sigma_params)\n        index = bits_to_index(y_target)\n        return categorical_crossentropy(pred_probs, index)\n\n    grads = jax.grad(loss_fn)(params)\n    updates, opt_state = optimizer.update(grads, opt_state, params)\n    new_params = optax.apply_updates(params, updates)\n    print("\nQuantum Circuit:")\n    print(qml.draw(autoregressive_model)(\n        x_token, x_props, \n        theta_effective, \n        new_params[\'theta_prop\'], \n        new_params[\'sigma\']\n    ))    #print("Target index:", int(jax.device_get(bits_to_index(y_target))))\n    lo

n_layers = 6
h_local = 3
prob_mask = NO MASK

Epoch 1 | Loss = 0.6208 | Accuracy = 0.4185
Epoch 2 | Loss = 0.5437 | Accuracy = 0.4909
Epoch 3 | Loss = 0.5238 | Accuracy = 0.5092
Epoch 4 | Loss = 0.5112 | Accuracy = 0.5222
Epoch 5 | Loss = 0.5034 | Accuracy = 0.5330
Epoch 6 | Loss = 0.5005 | Accuracy = 0.5372
Epoch 7 | Loss = 0.4984 | Accuracy = 0.5426
Epoch 8 | Loss = 0.4958 | Accuracy = 0.5470
Epoch 9 | Loss = 0.4954 | Accuracy = 0.5492
Epoch 10 | Loss = 0.4933 | Accuracy = 0.5497
Epoch 11 | Loss = 0.4887 | Accuracy = 0.5585
Epoch 12 | Loss = 0.4865 | Accuracy = 0.5625
Epoch 13 | Loss = 0.4855 | Accuracy = 0.5594
Epoch 14 | Loss = 0.4844 | Accuracy = 0.5634
Epoch 15 | Loss = 0.4821 | Accuracy = 0.5645
Epoch 16 | Loss = 0.4787 | Accuracy = 0.5690
Epoch 17 | Loss = 0.4776 | Accuracy = 0.5752
Epoch 18 | Loss = 0.4783 | Accuracy = 0.5689
Epoch 19 | Loss = 0.4774 | Accuracy = 0.5783
Epoch 20 | Loss = 0.4756 | Accuracy = 0.5784
Epoch 21 | Loss = 0.4753 | Accuracy = 0.5751
Epoch 22 | Loss = 0.4741 | Accuracy = 0.5789
Epoch 23 | Loss = 0.4762 | Accuracy = 0.5767
Epoch 24 | Loss = 0.4747 | Accuracy = 0.5771
Epoch 25 | Loss = 0.4732 | Accuracy = 0.5787
Epoch 26 | Loss = 0.4715 | Accuracy = 0.5852
Epoch 27 | Loss = 0.4725 | Accuracy = 0.5805
Epoch 28 | Loss = 0.4746 | Accuracy = 0.5792

h=2

Epoch 1 | Loss = 2.6945 | Accuracy = 0.4050 
Epoch 2 | Loss = 2.4173 | Accuracy = 0.4816 
Epoch 3 | Loss = 2.3229 | Accuracy = 0.5047 
Epoch 4 | Loss = 2.2768 | Accuracy = 0.5220
Epoch 5 | Loss = 2.2505 | Accuracy = 0.5267 
Epoch 6 | Loss = 2.2214 | Accuracy = 0.5399 
Epoch 7 | Loss = 2.2051 | Accuracy = 0.5448 
Epoch 8 | Loss = 2.1918 | Accuracy = 0.5523 
Epoch 9 | Loss = 2.1790 | Accuracy = 0.5544 
Epoch 10 | Loss = 2.1683 | Accuracy = 0.5577 
Epoch 11 | Loss = 2.1654 | Accuracy = 0.5582 
Epoch 12 | Loss = 2.1519 | Accuracy = 0.5617 
Epoch 13 | Loss = 2.1452 | Accuracy = 0.5647 
Epoch 14 | Loss = 2.1371 | Accuracy = 0.5670 
Epoch 15 | Loss = 2.1380 | Accuracy = 0.5630 
Epoch 16 | Loss = 2.1330 | Accuracy = 0.5654 
Epoch 17 | Loss = 2.1251 | Accuracy = 0.5661 
Epoch 18 | Loss = 2.1189 | Accuracy = 0.5703 
Epoch 19 | Loss = 2.1176 | Accuracy = 0.5718 
Epoch 20 | Loss = 2.1152 | Accuracy = 0.5698 
Epoch 21 | Loss = 2.1162 | Accuracy = 0.5682 
Epoch 22 | Loss = 2.1117 | Accuracy = 0.5704 
Epoch 23 | Loss = 2.1126 | Accuracy = 0.5729 
Epoch 24 | Loss = 2.1097 | Accuracy = 0.5735 
Epoch 25 | Loss = 2.1040 | Accuracy = 0.5756 
Epoch 26 | Loss = 2.1001 | Accuracy = 0.5751 
Epoch 27 | Loss = 2.0959 | Accuracy = 0.5794 
Epoch 28 | Loss = 2.0881 | Accuracy = 0.5808 
Epoch 29 | Loss = 2.0891 | Accuracy = 0.5806 
Epoch 30 | Loss = 2.0869 | Accuracy = 0.5808 
Epoch 31 | Loss = 2.0843 | Accuracy = 0.5855 
Epoch 32 | Loss = 2.0859 | Accuracy = 0.5839 
Epoch 33 | Loss = 2.0841 | Accuracy = 0.5830