In [1]:
from typing import Tuple, List, Dict, Any

from rdkit import Chem
from rdkit.Chem import AllChem

import numpy as np
from math import ceil, log2
import re
import pandas as pd
import optax
import csv
import json
import flax.linen as nn
import math
import itertools

import pennylane as qml
from pennylane import qchem
from pennylane.templates import StronglyEntanglingLayers

import jax
import jax.numpy as jnp
from jax.nn.initializers import normal

import haiku as hk

## Data Preparation

In [3]:
# Load metadata from JSON
N_MOLECS = 5000
META_DATA_PATH = f"../data/metadata_selfies_{N_MOLECS}.json"

with open(META_DATA_PATH, "r") as f:
    metadata = json.load(f)

VOCABULARY_SIZE = metadata['vocabulary_size']
BITS_PER_TOKEN = metadata['bits_per_token']
MAX_LEN = metadata['max_sequence_length']
ALPHABET = metadata['alphabet']
min_logp = metadata['min_logP']
max_logp = metadata['max_logP']
min_qed = metadata['min_qed']
max_qed = metadata['max_qed']
min_mw = metadata['min_mw']
max_mw = metadata['max_mw']

print("Vocabulary Size:", VOCABULARY_SIZE)
print("Bits per Token:", BITS_PER_TOKEN)
print("Max Sequence Length:", MAX_LEN)
print("Alphabet:", ALPHABET)


Vocabulary Size: 32
Bits per Token: 5
Max Sequence Length: 31
Alphabet: ['<SOS>', '[#Branch1]', '[#Branch2]', '[#C]', '[#N]', '[=Branch1]', '[=Branch2]', '[=C]', '[=N]', '[=O]', '[=PH1]', '[=P]', '[=Ring1]', '[=S]', '[Br]', '[Branch1]', '[Branch2]', '[C]', '[Cl]', '[F]', '[H]', '[I]', '[NH1]', '[N]', '[O]', '[PH1]', '[P]', '[Ring1]', '[Ring2]', '[S]', '<EOS>', '<PAD>']


In [4]:
# --- Auxiliary functions ---

def normalize(value, min_val, max_val, target_max=np.pi):
    ''' Normalize a value to a range [0, [0, pi] to later encode them as rotation angles'''
    norm = (value - min_val) / (max_val - min_val) * target_max
    return float(f"{norm:.3f}")

def token_to_index(token):
    ''' Map a SELFIES token to its corresponding index in the ALPHABET'''
    if token in ALPHABET:
        return ALPHABET.index(token)
    else:
        return None

def bits_to_index(bits):
    powers = 2 ** jnp.arange(len(bits) - 1, -1, -1)
    return jnp.dot(bits, powers).astype(jnp.int32)

In [5]:
PAD_index = token_to_index("<PAD>")
SOS_index = token_to_index("<SOS>")
EOS_index = token_to_index("<EOS>")

In [7]:
def load_dataset_bits_batch(csv_path, n_bits=BITS_PER_TOKEN):
    """
    Read the dataset from a CSV file and convert it into bit representations for quantum processing.
    Args:
        csv_path (str): Path to the CSV file containing the dataset.
        n_bits (int): Number of bits used to represent each token.
    Returns:
        Tuple containing:
            - X_bits (jnp.ndarray): Input bit representations of shape (N_Molecules, MAX_LEN-1, n_bits).
            - Props (jnp.ndarray): Molecular properties of shape (N_Molecules, 3).
            - Y_indices (jnp.ndarray): Target token indices of shape (N_Molecules, MAX_LEN-1).
    """
    
    df = pd.read_csv(csv_path, dtype=str)
    print("Dataset loaded with shape:", df.shape)
    
    # Properties: logP, qed, mw
    # Shape: (N_Moleculas, 3)
    # Types: float
    props = df.iloc[:, :3].astype(float).values
    
    # Tokens
    # Shape: (N_Moleculas, MAX_LEN)
    # Types: string (e.g., "00101")
    token_cols = df.iloc[:, 3:].values
    
    # Auxiliary function to convert a string of '0's and '1's to a list
    def str_to_bit_list(s):
        return [int(c) for c in s]

    # 3D array to hold all bits
    all_bits = np.array([
        [str_to_bit_list(token) for token in row] 
        for row in token_cols
    ])
    
    # X_bits: Input for the model.
    # Take all tokens except the last one.
    X_bits = all_bits[:, :-1, :] 
    
    # Y_ind: Target indices for the model.
    # Take all tokens except the first one.
    # Convert from bit strings to integer indices (for loss calculation).
    Y_indices = np.array([[int(t, 2) for t in row[1:]] for row in token_cols])

    return jnp.array(X_bits), jnp.array(props), jnp.array(Y_indices)


In [8]:
MOLECS_DATA_PATH = f"../data/structured_data_selfies_{N_MOLECS}.csv"
X_bits, Props, Y_ind = load_dataset_bits_batch(MOLECS_DATA_PATH)

print("Loaded data:")
print("X_bits shape:", X_bits.shape) # (N, L, 5)
print("Props shape:", Props.shape)   # (N, 3)
print("Y_ind shape:", Y_ind.shape)   # (N, L)

Dataset loaded with shape: (4777, 34)
Loaded data:
X_bits shape: (4777, 30, 5)
Props shape: (4777, 3)
Y_ind shape: (4777, 30)


## Quantum Circuit

In [12]:
import itertools
def zstring_combos(wires):
    """
    Return an ordered list of wire-tuples for all Z-strings up to order H_LOCAL.
    Order: all 1-local, then all 2-local, ..., up to H_LOCAL.
    """
    L = []
    for k in range(1, H_LOCAL + 1):
        L.extend(itertools.combinations(wires, k))
    return [tuple(c) for c in L]

def num_zstrings(n_wires):
    """Count how many Z-strings up to order H_LOCAL."""
    from math import comb
    return sum(comb(n_wires, k) for k in range(1, H_LOCAL + 1))

In [13]:
# Device and qubit setup
# BITS_PER_TOKEN number of qubits needed to encode each token
n_prop_qubits = 3  # number of qubits needed to encode properties (logP, QED, MW)
n_ancillas = 3  # number of ancilla qubits that represent the environment
n_total_qubits = n_prop_qubits + BITS_PER_TOKEN + n_ancillas

N_LAYERS = 6  # number of variational layers
H_LOCAL = 3 # h_local sets the maximum number of qubits that can interact in each Z-string term of Σ


# Name them explicitly
prop_wires = [f"prop_{i}" for i in range(n_prop_qubits)]
token_wires = [f"token_{i}" for i in range(BITS_PER_TOKEN)]
ancilla_wires = [f"ancilla_{i}" for i in range(n_ancillas)]
all_wires = prop_wires + token_wires + ancilla_wires

dev = qml.device("default.qubit", wires=all_wires)


def molecular_property_encoder(props):
    """Encode continuous props on property qubits via RY rotations"""
    for wire, val in zip(prop_wires, props):
        qml.RY(val, wires=wire)


def token_encoder(token_bits):
    qml.BasisState(token_bits, wires=token_wires)


def operator_layer(theta_params, theta_prop, wires):
    """
    Variational layer where:
      - theta_params[...] are rotations for token + ancilla qubits
      - theta_prop encodes property→token entanglement
    """
    prop_ws = prop_wires
    token_ancilla_ws = token_wires + ancilla_wires

    # Property → token entanglement
    for p, prop_wire in enumerate(prop_ws):
        for t, t_a_wire in enumerate(token_ancilla_ws):
            qml.CRX(theta_prop[p, t, 0], wires=[prop_wire, t_a_wire])
            qml.CRY(theta_prop[p, t, 1], wires=[prop_wire, t_a_wire])

    # Single-qubit rotations
        ''' for i, wire in enumerate(token_ancilla_ws):
        qml.RX(theta_params[i, 0], wires=wire)
        qml.RY(theta_params[i, 1], wires=wire)
        qml.RZ(theta_params[i, 2], wires=wire)

    # Entangle token + ancilla chain
    for i in range(len(token_ancilla_ws) - 1):
        qml.CNOT(wires=[token_ancilla_ws[i], token_ancilla_ws[i + 1]])
    qml.CNOT(wires=[token_ancilla_ws[-1], token_ancilla_ws[0]])'''
    
    qml.StronglyEntanglingLayers(
        weights=theta_params[None,:,:],  # shape: (n_token_ancilla, 3)
        wires=token_ancilla_ws
    )

def Sigma_layer_vec(gamma_vec, wires, time=1.0, combos=None):
    """
    Diagonal multi-Z unitary Σ = exp(i * sum_s gamma_s * Z^{⊗|s|} * t)
    using a flat parameter vector 'gamma_vec' aligned with 'combos'.
    """
    token_ancilla_ws = list(wires)  # pass token+ancilla here
    if combos is None:
        combos = zstring_combos(token_ancilla_ws)

    # Safety: ensure the vector length matches the number of combos
    assert gamma_vec.shape[0] == len(combos), \
        f"gamma_vec has length {gamma_vec.shape[0]} but expected {len(combos)}"

    # MultiRZ(phi) = exp(-i * phi/2 * Z^{⊗k}); choose phi = -2 * gamma * time
    for gamma, combo in zip(gamma_vec, combos):
        qml.MultiRZ(-2.0 * gamma * time, wires=list(combo))


# QNode combining encoding and variational layers
@qml.qnode(dev, interface="jax")
def autoregressive_model(token_bits, props, theta_params, theta_prop, sigma_params):
    molecular_property_encoder(props)      # Encode MW, logP, QED
    token_encoder(token_bits)              # Basis-encode token bits

    token_ancilla_ws = token_wires + ancilla_wires
    combos = zstring_combos(token_ancilla_ws)


    for l in range(N_LAYERS):
        # Forward V(θ)
        operator_layer(theta_params[l], theta_prop[l], wires=all_wires)

        # Diagonal Σ(γ,t): vector API
        Sigma_layer_vec(sigma_params[l], token_ancilla_ws, time=1.0, combos=combos)

        # Backward V(θ)†
        qml.adjoint(operator_layer)(theta_params[l], theta_prop[l], wires=all_wires)

    return qml.probs(wires=token_wires), [qml.expval(qml.PauliZ(w)) for w in prop_wires]

In [14]:
'''def bitstr_to_array(bitstr):
    """Convert a string of bits (e.g., '010101') to a numpy float32 array."""
    return np.array([int(b) for b in bitstr], dtype=np.float32)

def build_training_data(df):
    """
    Build dataset tuples of (input_token_bits, molecular_properties, target_token_bits)
    from a DataFrame.

    Args:
        df (pandas.DataFrame): DataFrame containing molecular properties and token bit strings.
        n_token_cols (int): Number of token columns in the DataFrame.

    Returns:
        list of tuples: Each tuple contains (x_token: np.array, x_props: np.array, y_target: np.array)
    """
    dataset = []

    for _, row in df.iterrows():
        # Extract molecular properties as a numpy float32 array
        props = [row['logP'], row['qed'], row['mw']]
        x_props = np.array(props, dtype=np.float32)

        tokens = row[3:]  # token columns after properties

        # Iterate over token sequence to create input-target pairs
        for i in range(len(tokens) - 1):
            current_token = tokens.iloc[i]
            next_token = tokens.iloc[i + 1]

            # Skip missing or NaN tokens
            if current_token is None or (isinstance(current_token, float) and math.isnan(current_token)):
                continue
            if next_token is None or (isinstance(next_token, float) and math.isnan(next_token)):
                continue

            # Convert token strings (e.g., '01011') to bit arrays
            x_token = bitstr_to_array(current_token)
            y_target = bitstr_to_array(next_token)

            dataset.append((x_token, x_props, y_target))

    return dataset


# Load dataset
token_cols = [f"token_{i}" for i in range(n_tokens)]
df = pd.read_csv(DATA_PATH, dtype={col: str for col in token_cols})
dataset = build_training_data(df)  # Should return list/array of (x_token, x_props, y_target)'''

In [15]:
# 1. Define the embedding network function
def embedding_network_fn(x):
    embedding_size = N_LAYERS * (BITS_PER_TOKEN + n_ancillas) * 3 
    mlp = hk.Sequential([
        hk.Linear(32), jax.nn.relu,
        hk.Linear(16), jax.nn.relu,
        hk.Linear(embedding_size),  # Match θ shape
    ])
    return mlp(x)

# 2. Transform the function to make it usable in JAX/Haiku
embedding_network = hk.transform(embedding_network_fn)


# 3. Use the embedding to generate θ parameters
def get_context_embedding(prev_token_bits, embed_params):
    """
    prev_token_bits: jnp.array of shape (12,) — concatenated 2×6 bits
    Returns: reshaped embedding output to match θ shape
    """
    embedding_output = embedding_network.apply(embed_params, None, prev_token_bits)
    return embedding_output.reshape((N_LAYERS, BITS_PER_TOKEN + n_ancillas, 3))  # Shape for operator_layer



In [16]:
def bits_to_index(bits):
    powers = 2 ** jnp.arange(len(bits) - 1, -1, -1)
    return jnp.dot(bits, powers).astype(jnp.int32)
'''
def categorical_crossentropy(pred_probs, target_index):
    epsilon = 1e-10
    return -jnp.log(pred_probs[target_index] + epsilon)

def label_smoothing_crossentropy_normalized(pred_probs, target_index, epsilon=0.1, alpha=0.1):
    """Cross-entropy loss with label smoothing, normalized to [0,1]."""
    num_classes = pred_probs.shape[0] # Number of classes (tokens)
    
    # Build smoothed target
    smooth_target = jnp.full_like(pred_probs, epsilon / (num_classes - 1))
    smooth_target = smooth_target.at[target_index].set(1.0 - epsilon)
    
    # Compute cross-entropy loss
    ce_loss = -jnp.sum(smooth_target * jnp.log(pred_probs + 1e-10)) # in [0, log(num_classes)]
    max_loss = jnp.log(num_classes)
    total_loss = ce_loss / max_loss  # Normalize to [0,1]

    return total_loss'''
def total_loss_fn(pred_probs, prop_expvals, target_index, props, alpha=0.5, epsilon=0.1):
    """Single normalized loss in [0,1]."""
    num_classes = pred_probs.shape[0]

    # --- Cross-entropy with label smoothing (raw, not normalized yet) ---
    smooth_target = jnp.full_like(pred_probs, epsilon / (num_classes - 1))
    smooth_target = smooth_target.at[target_index].set(1.0 - epsilon)
    ce_loss = -jnp.sum(smooth_target * jnp.log(pred_probs + 1e-10))  # in [0, log(num_classes)]

    # --- Property preservation loss (MSE) ---
    prop_expvals = jnp.array(prop_expvals)  # convert list -> JAX array
    prop_loss = jnp.mean((prop_expvals - jnp.cos(props)) ** 2)  # in [0,4]

    # --- Combine ---
    combined_loss = ce_loss + alpha * prop_loss

    # --- Normalize only once ---
    max_loss = jnp.log(num_classes) + alpha * 4.0
    final_loss = combined_loss / max_loss

    return final_loss

def compute_accuracy(pred_probs, target_index):
    predicted_index = jnp.argmax(pred_probs)
    return jnp.array(predicted_index == target_index, dtype=jnp.float32)


In [17]:
# Initialize embedding params
rng = jax.random.PRNGKey(0)
dummy_context = jnp.zeros((3 * BITS_PER_TOKEN,), dtype=jnp.float32)  # 3 prev. tokens
embedding_params = embedding_network.init(rng, dummy_context)

# Effective qubit counts in variational layers
n_token_ancilla = BITS_PER_TOKEN + n_ancillas

# Initialize theta and sigma params
key = jax.random.PRNGKey(42)
key, k_theta, k_theta_prop, k_sigma = jax.random.split(key, 4)

# Precompute Z-string combos once
token_ancilla_ws = token_wires + ancilla_wires
combos = zstring_combos(token_ancilla_ws)
n_strings = len(combos)


combined_params = {
    'theta': jax.random.normal(k_theta, (N_LAYERS, n_token_ancilla, 3)) * 0.1,
    'theta_prop': jax.random.normal(k_theta_prop, (N_LAYERS, n_prop_qubits, n_token_ancilla, 4)) * 0.1,
    'sigma': jax.random.normal(k_sigma, (N_LAYERS, n_strings)) * 0.1,
    'embedding': embedding_params,
}
# Training hyperparams
learning_rate = 0.001
n_epochs = 100

# Optimizer
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(combined_params)

@jax.jit
def training_step(params, opt_state, x_token, x_props, y_target, context_vector):
    def loss_fn(params):
        theta_params = params['theta']
        theta_prop = params['theta_prop']
        sigma_params = params['sigma']
        embedding_params = params['embedding']

        # Embedding → adjustment for theta
        theta_from_embedding = get_context_embedding(context_vector, embedding_params)
        theta_effective = theta_params + theta_from_embedding
        # theta_params || theta_from_embedding

        # Predict
        pred_probs, expval_props = autoregressive_model(x_token, x_props, theta_effective, theta_prop, sigma_params)
        index = bits_to_index(y_target)
        
        # Return scalar loss for gradient computation
        # return label_smoothing_crossentropy_nortotal_loss_fnmalized(pred_probs, index), pred_probs
        return total_loss_fn(pred_probs, expval_props, index, x_props), pred_probs

    # value_and_grad computes both loss and grads in one pass
    (loss, pred_probs), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)

    # Update parameters
    updates, opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Accuracy from pred_probs (already computed!)
    index = bits_to_index(y_target)
    acc = compute_accuracy(pred_probs, index)

    return new_params, loss, opt_state, grads, acc

# Context initialization outside epoch loop
SOS_token = jnp.zeros((BITS_PER_TOKEN,), dtype=jnp.int32)
eos_index = len(alphabet)-1
EOS_token = jnp.array([int(b) for b in format(eos_index, f'0{BITS_PER_TOKEN}b')], dtype=jnp.int32)

prev_token1 = SOS_token
prev_token2 = SOS_token
prev_token3 = SOS_token

for epoch in range(n_epochs):
    total_loss = total_acc = 0.0
    
    for x_token, x_props, y_target in dataset:
        x_token = jnp.array(x_token, dtype=jnp.int32)
        x_props = jnp.array(x_props, dtype=jnp.float32)
        y_target = jnp.array(y_target, dtype=jnp.int32)
      
        context_vector = jnp.concatenate([prev_token1, prev_token2, prev_token3])

        combined_params, loss, opt_state, grads, acc = training_step(combined_params, opt_state, x_token, x_props, y_target, context_vector)
        
        # Update previous tokens
        if jnp.array_equal(x_token, EOS_token):
            prev_token1 = SOS_token
            prev_token2 = SOS_token
            prev_token3 = SOS_token
        else:
            # Shift prev tokens, add current token as the newest previous token
            prev_token3 = prev_token2
            prev_token2 = prev_token1
            prev_token1 = x_token

        total_loss += loss
        total_acc  += acc

    avg_loss = total_loss / len(dataset)
    avg_acc  = total_acc / len(dataset)
    
    print(f"Epoch {epoch+1} | Loss = {avg_loss:.4f} | Accuracy = {avg_acc:.4f}")

  return lax_numpy.astype(self, dtype, copy=copy, device=device)
  return lax_numpy.astype(self, dtype, copy=copy, device=device)


Epoch 1 | Loss = 0.5476 | Accuracy = 0.4380
Epoch 2 | Loss = 0.5016 | Accuracy = 0.4766
Epoch 3 | Loss = 0.4801 | Accuracy = 0.5086
Epoch 4 | Loss = 0.4711 | Accuracy = 0.5244
Epoch 5 | Loss = 0.4642 | Accuracy = 0.5353
Epoch 6 | Loss = 0.4627 | Accuracy = 0.5375
Epoch 7 | Loss = 0.4607 | Accuracy = 0.5381
Epoch 8 | Loss = 0.4559 | Accuracy = 0.5449
Epoch 9 | Loss = 0.4504 | Accuracy = 0.5526
Epoch 10 | Loss = 0.4500 | Accuracy = 0.5562
Epoch 11 | Loss = 0.4484 | Accuracy = 0.5641
Epoch 12 | Loss = 0.4450 | Accuracy = 0.5685
Epoch 13 | Loss = 0.4430 | Accuracy = 0.5686
Epoch 14 | Loss = 0.4428 | Accuracy = 0.5658
Epoch 15 | Loss = 0.4396 | Accuracy = 0.5706
Epoch 16 | Loss = 0.4393 | Accuracy = 0.5747
Epoch 17 | Loss = 0.4414 | Accuracy = 0.5708
Epoch 18 | Loss = 0.4397 | Accuracy = 0.5716
Epoch 19 | Loss = 0.4398 | Accuracy = 0.5763
Epoch 20 | Loss = 0.4399 | Accuracy = 0.5767
Epoch 21 | Loss = 0.4415 | Accuracy = 0.5756
Epoch 22 | Loss = 0.4367 | Accuracy = 0.5796
Epoch 23 | Loss = 0

KeyboardInterrupt: 

In [None]:

@jax.jit
def training_step(params, opt_state, x_token, x_props, y_target, context_vector):
    def loss_fn(params):
        theta_params = params['theta']
        theta_prop = params['theta_prop']
        sigma_params = params['sigma']
        embedding_params = params['embedding']

        # Embedding → adjustment for theta
        theta_from_embedding = get_context_embedding(context_vector, embedding_params)
        theta_effective = theta_params + theta_from_embedding
        # theta_params || theta_from_embedding

        # Predict
        pred_probs, expval_props = autoregressive_model(x_token, x_props, theta_effective, theta_prop, sigma_params)
        index = bits_to_index(y_target)
        
        # Return scalar loss for gradient computation
        # return label_smoothing_crossentropy_nortotal_loss_fnmalized(pred_probs, index), pred_probs
        return total_loss_fn(pred_probs, expval_props, index, x_props), pred_probs

    # value_and_grad computes both loss and grads in one pass
    (loss, pred_probs), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)

    # Update parameters
    updates, opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Accuracy from pred_probs (already computed!)
    index = bits_to_index(y_target)
    acc = compute_accuracy(pred_probs, index)

    return new_params, loss, opt_state, grads, acc

# Context initialization outside epoch loop
SOS_token = jnp.zeros((BITS_PER_TOKEN,), dtype=jnp.int32)
eos_index = len(alphabet)-1
EOS_token = jnp.array([int(b) for b in format(eos_index, f'0{BITS_PER_TOKEN}b')], dtype=jnp.int32)

prev_token1 = SOS_token
prev_token2 = SOS_token
prev_token3 = SOS_token

for epoch in range(n_epochs):
    total_loss = total_acc = 0.0
    
    for x_token, x_props, y_target in dataset:
        x_token = jnp.array(x_token, dtype=jnp.int32)
        x_props = jnp.array(x_props, dtype=jnp.float32)
        y_target = jnp.array(y_target, dtype=jnp.int32)
      
        context_vector = jnp.concatenate([prev_token1, prev_token2, prev_token3])

        combined_params, loss, opt_state, grads, acc = training_step(combined_params, opt_state, x_token, x_props, y_target, context_vector)
        
        # Update previous tokens
        if jnp.array_equal(x_token, EOS_token):
            prev_token1 = SOS_token
            prev_token2 = SOS_token
            prev_token3 = SOS_token
        else:
            # Shift prev tokens, add current token as the newest previous token
            prev_token3 = prev_token2
            prev_token2 = prev_token1
            prev_token1 = x_token

        total_loss += loss
        total_acc  += acc

    avg_loss = total_loss / len(dataset)
    avg_acc  = total_acc / len(dataset)
    
    print(f"Epoch {epoch+1} | Loss = {avg_loss:.4f} | Accuracy = {avg_acc:.4f}")

In [23]:
import os

# Define the directory path (relative to your current location, assuming you are in /content/QGen-Mol/code)
target_dir = '../data/params/'

# Create the directory if it doesn't exist
os.makedirs(target_dir, exist_ok=True)

# Now, run your pickle code:
with open(os.path.join(target_dir, 'embedding_theta_params.pkl'), "wb") as f:
    pickle.dump(combined_params, f)

In [35]:
# Generate new molecules
from jax import random

def generate_molecule_stochastic(key, props, combined_params, temperature=1.0, max_length=max_len + 2):
    """
    Generates a molecule stochastically using a PRNG key for sampling.
    """
    theta_params = combined_params['theta']
    theta_prop = combined_params['theta_prop']
    sigma_params = combined_params['sigma']
    embedding_params = combined_params['embedding']

    generated_bits = []
    current_token = jnp.zeros((BITS_PER_TOKEN,), dtype=jnp.int32) # Start with <SOS>
    
    # Context tokens logic (reusing your implementation)
    prev_token1 = SOS_token
    prev_token2 = SOS_token
    prev_token3 = SOS_token

    for t in range(max_length):
        # *** Critical step: Split the key for inner-loop stochasticity ***
        key, subkey = random.split(key)

        context_vector = jnp.concatenate([prev_token1, prev_token2, prev_token3])
        theta_from_embedding = get_context_embedding(context_vector, embedding_params)
        theta_effective = theta_params + theta_from_embedding

        # The autoregressive model here uses the QNode to get probabilities
        pred_probs, _ = autoregressive_model(current_token, props, theta_effective, theta_prop, sigma_params)
        
        # 1. Convert probabilities to logits (log-probabilities) for scaling
        #    Add a small epsilon (1e-10) for numerical stability (avoids log(0))
        logits = jnp.log(pred_probs[:VOCABULARY_SIZE] + 1e-10)

        # 2. Scale logits by Temperature (logits / T)
        tempered_logits = logits / temperature

        # 3. Apply softmax to get the new, tempered probability distribution
        tempered_probs = nn.softmax(tempered_logits)
        
        # 4. Sample the next index from the Tempered probability distribution
        token_indices = jnp.arange(VOCABULARY_SIZE)
        next_index = random.choice(
            subkey, 
            token_indices, 
            p=tempered_probs 
        )

        next_bits_str = format(int(next_index), f'0{BITS_PER_TOKEN}b')
        next_bits = jnp.array([int(b) for b in next_bits_str], dtype=jnp.int32)

        generated_bits.append(next_bits)

        if next_index == eos_index:
            break

        # Update previous tokens for the next step
        prev_token3 = prev_token2
        prev_token2 = prev_token1
        prev_token1 = current_token
        current_token = next_bits

    return generated_bits

# Convert generated bits to SMILES
def bits_to_smiles(generated_bits):
    tokens = []
    for bits in generated_bits:
        index = int("".join(map(str, bits)), 2)
        token = alphabet[index]
        if token == '<EOS>':
            break
        tokens.append(token)
    return ''.join(tokens)

In [41]:
N_MOLECS = 100
valid_molecs = []
# Initialize a master key once. Use a fixed seed for reproducibility.
master_key = random.PRNGKey(1) 

# Normalize desired properties once
desired_logp = normalize(2.0, min_logp, max_logp)
desired_qed = normalize(0.7, min_qed, max_qed)
desired_mw = normalize(250.0, min_mw, max_mw)
props = jnp.array([desired_logp, desired_qed, desired_mw], dtype=jnp.float32)

for n in range(N_MOLECS):
    # *** Critical step: Split the master key for the current molecule's stochastic process ***
    master_key, key_for_mol = random.split(master_key)
    
    # Use the new stochastic function
    generated_bits = generate_molecule_stochastic(key_for_mol, props, combined_params)
    generated_smiles = bits_to_smiles(generated_bits) 
    
    # Now that the output is stochastic, you can proceed to validation.
    
    # ... Validation Code Here ... 
    # (See my previous suggestion for a full validation function)
    # -----------------------------
    mol = Chem.MolFromSmiles(generated_smiles)
    if mol is not None:
        validity = "Valid"
        valid_molecs.append(generated_smiles)
        # Obtain properties of valid molecules
        print(f" - Generated Molecule {n+1}: {generated_smiles} (Status: {validity})")
    else:
        validity = "Invalid"
        

[18:09:44] SMILES Parse Error: syntax error while parsing: c2ccc(CC([nH]22[C@H][C@@]1(=[Li+][C@@]([C@@H][C@][S+]#)[C@]O)BrPO[O-]\[nH]c
[18:09:44] SMILES Parse Error: check for mistakes around position 55:
[18:09:44] C@@]([C@@H][C@][S+]#)[C@]O)BrPO[O-]\[nH]c
[18:09:44] ~~~~~~~~~~~~~~~~~~~~^
[18:09:44] SMILES Parse Error: Failed parsing SMILES 'c2ccc(CC([nH]22[C@H][C@@]1(=[Li+][C@@]([C@@H][C@][S+]#)[C@]O)BrPO[O-]\[nH]c' for input: 'c2ccc(CC([nH]22[C@H][C@@]1(=[Li+][C@@]([C@@H][C@][S+]#)[C@]O)BrPO[O-]\[nH]c'
[18:10:08] SMILES Parse Error: syntax error while parsing: N(CC(S#)[Cl-][C@@]oCl[Cl-]\2C1
[18:10:08] SMILES Parse Error: check for mistakes around position 8:
[18:10:08] N(CC(S#)[Cl-][C@@]oCl[Cl-]\2C1
[18:10:08] ~~~~~~~^
[18:10:08] SMILES Parse Error: Failed parsing SMILES 'N(CC(S#)[Cl-][C@@]oCl[Cl-]\2C1' for input: 'N(CC(S#)[Cl-][C@@]oCl[Cl-]\2C1'
[18:10:55] SMILES Parse Error: syntax error while parsing: [nH]1n2N[C@@][nH]2[C@H][C@@]#CCNCC()c3cccc(S(N)(N)/ClBr
[18:10:55] SMILES Parse

 - Generated Molecule 4:  (Status: Valid)


[18:11:43] SMILES Parse Error: syntax error while parsing: cc12([C@@H]1(=O)=O)=[O-]#[Cl-][N+](O[C@@H][N+]oO[C@@H][O-]<SOS>S(=O)CC
[18:11:43] SMILES Parse Error: check for mistakes around position 59:
[18:11:43] @@H][N+]oO[C@@H][O-]<SOS>S(=O)CC
[18:11:43] ~~~~~~~~~~~~~~~~~~~~^
[18:11:43] SMILES Parse Error: Failed parsing SMILES 'cc12([C@@H]1(=O)=O)=[O-]#[Cl-][N+](O[C@@H][N+]oO[C@@H][O-]<SOS>S(=O)CC' for input: 'cc12([C@@H]1(=O)=O)=[O-]#[Cl-][N+](O[C@@H][N+]oO[C@@H][O-]<SOS>S(=O)CC'
[18:12:00] SMILES Parse Error: syntax error while parsing: CC1CCCI<SOS>=cIF[Li+]
[18:12:00] SMILES Parse Error: check for mistakes around position 8:
[18:12:00] CC1CCCI<SOS>=cIF[Li+]
[18:12:00] ~~~~~~~^
[18:12:00] SMILES Parse Error: Failed parsing SMILES 'CC1CCCI<SOS>=cIF[Li+]' for input: 'CC1CCCI<SOS>=cIF[Li+]'
[18:12:08] SMILES Parse Error: unclosed ring for input: 'c1[PH][PH]n'
[18:12:14] Explicit valence for atom # 2 O, 2, is greater than permitted
[18:12:31] SMILES Parse Error: extra close parentheses 

 - Generated Molecule 10:  (Status: Valid)


[18:12:49] SMILES Parse Error: syntax error while parsing: 2CN=O)c1c1(O
[18:12:49] SMILES Parse Error: check for mistakes around position 1:
[18:12:49] 2CN=O)c1c1(O
[18:12:49] ^
[18:12:49] SMILES Parse Error: Failed parsing SMILES '2CN=O)c1c1(O' for input: '2CN=O)c1c1(O'
[18:12:55] SMILES Parse Error: extra open parentheses while parsing: N(N1
[18:12:55] SMILES Parse Error: check for mistakes around position 2:
[18:12:55] N(N1
[18:12:55] ~^
[18:12:55] SMILES Parse Error: Failed parsing SMILES 'N(N1' for input: 'N(N1'
[18:13:32] SMILES Parse Error: extra close parentheses while parsing: cc12OC)Ccc1/[Na+][Na+][C@@][O-][C@H][N+]BrPP[C@@H]1(=O)
[18:13:32] SMILES Parse Error: check for mistakes around position 7:
[18:13:32] cc12OC)Ccc1/[Na+][Na+][C@@][O-][C@H][N+]B
[18:13:32] ~~~~~~^
[18:13:32] SMILES Parse Error: Failed parsing SMILES 'cc12OC)Ccc1/[Na+][Na+][C@@][O-][C@H][N+]BrPP[C@@H]1(=O)' for input: 'cc12OC)Ccc1/[Na+][Na+][C@@][O-][C@H][N+]BrPP[C@@H]1(=O)'
[18:14:18] SMILES Parse Error:

 - Generated Molecule 36:  (Status: Valid)


[18:23:41] SMILES Parse Error: extra close parentheses while parsing: C)=#1CCC(C(C()cc1n
[18:23:41] SMILES Parse Error: check for mistakes around position 2:
[18:23:41] C)=#1CCC(C(C()cc1n
[18:23:41] ~^
[18:23:41] SMILES Parse Error: Failed parsing SMILES 'C)=#1CCC(C(C()cc1n' for input: 'C)=#1CCC(C(C()cc1n'
[18:23:58] Can't kekulize mol.  Unkekulized atoms: 1 4 5
[18:24:44] SMILES Parse Error: extra close parentheses while parsing: N)=O1C3Cc1c[C@@H]\[Na+]n([Na+][nH]cc2c1CCN)Cc1nnF[C@@H][Cl-]
[18:24:44] SMILES Parse Error: check for mistakes around position 2:
[18:24:44] N)=O1C3Cc1c[C@@H]\[Na+]n([Na+][nH]cc2c1CC
[18:24:44] ~^
[18:24:44] SMILES Parse Error: Failed parsing SMILES 'N)=O1C3Cc1c[C@@H]\[Na+]n([Na+][nH]cc2c1CCN)Cc1nnF[C@@H][Cl-]' for input: 'N)=O1C3Cc1c[C@@H]\[Na+]n([Na+][nH]cc2c1CCN)Cc1nnF[C@@H][Cl-]'
[18:25:18] SMILES Parse Error: syntax error while parsing: N[Cl-][C@H][Cl-]<SOS>c1[C@]C1=CCCC1[C@H][O-][PH][O-][Cl-][C@@]1C1
[18:25:18] SMILES Parse Error: check for mistakes aro

 - Generated Molecule 56:  (Status: Valid)


[18:33:54] SMILES Parse Error: syntax error while parsing: (=C2[nH]ccc[C@@][PH][C@@]o(=C3scc2)nn1CCCCCC2)
[18:33:54] SMILES Parse Error: check for mistakes around position 1:
[18:33:54] (=C2[nH]ccc[C@@][PH][C@@]o(=C3scc2)nn1CCC
[18:33:54] ^
[18:33:54] SMILES Parse Error: Failed parsing SMILES '(=C2[nH]ccc[C@@][PH][C@@]o(=C3scc2)nn1CCCCCC2)' for input: '(=C2[nH]ccc[C@@][PH][C@@]o(=C3scc2)nn1CCCCCC2)'
[18:34:12] SMILES Parse Error: unclosed ring for input: 'c1ccc(=O)/ss'
[18:34:47] SMILES Parse Error: extra close parentheses while parsing: N1cc1F1[Li+][PH]nn=Cl)F[Cl-]P\[Na+]/CNC#Cl
[18:34:47] SMILES Parse Error: check for mistakes around position 22:
[18:34:47] 1cc1F1[Li+][PH]nn=Cl)F[Cl-]P\[Na+]/CNC#Cl
[18:34:47] ~~~~~~~~~~~~~~~~~~~~^
[18:34:47] SMILES Parse Error: Failed parsing SMILES 'N1cc1F1[Li+][PH]nn=Cl)F[Cl-]P\[Na+]/CNC#Cl' for input: 'N1cc1F1[Li+][PH]nn=Cl)F[Cl-]P\[Na+]/CNC#Cl'
[18:34:53] SMILES Parse Error: extra close parentheses while parsing: N)
[18:34:53] SMILES Parse Error:

 - Generated Molecule 62:  (Status: Valid)


[18:35:56] SMILES Parse Error: extra open parentheses while parsing: c1nc([O-][N+]1
[18:35:56] SMILES Parse Error: check for mistakes around position 5:
[18:35:56] c1nc([O-][N+]1
[18:35:56] ~~~~^
[18:35:56] SMILES Parse Error: Failed parsing SMILES 'c1nc([O-][N+]1' for input: 'c1nc([O-][N+]1'
[18:36:04] SMILES Parse Error: extra open parentheses while parsing: N(CC2
[18:36:04] SMILES Parse Error: check for mistakes around position 2:
[18:36:04] N(CC2
[18:36:04] ~^
[18:36:04] SMILES Parse Error: Failed parsing SMILES 'N(CC2' for input: 'N(CC2'
[18:36:59] SMILES Parse Error: syntax error while parsing: CC(O(FCF)c1ccc2C(=N<SOS>[O-][Na+]I[C@H][C@@][PH][PH][O-][C@H][PH]c1ccc
[18:36:59] SMILES Parse Error: check for mistakes around position 20:
[18:36:59] CC(O(FCF)c1ccc2C(=N<SOS>[O-][Na+]I[C@H][C
[18:36:59] ~~~~~~~~~~~~~~~~~~~~^
[18:36:59] SMILES Parse Error: Failed parsing SMILES 'CC(O(FCF)c1ccc2C(=N<SOS>[O-][Na+]I[C@H][C@@][PH][PH][O-][C@H][PH]c1ccc' for input: 'CC(O(FCF)c1ccc2C(=N<SOS>[O-

 - Generated Molecule 86:  (Status: Valid)


[18:48:01] SMILES Parse Error: syntax error while parsing: C1[C@@][PH]#[nH]ccc2sI/)=2[Cl-])=O)<SOS>CC)cc1[Na+]cc2ccc
[18:48:01] SMILES Parse Error: check for mistakes around position 24:
[18:48:01] C@@][PH]#[nH]ccc2sI/)=2[Cl-])=O)<SOS>CC)c
[18:48:01] ~~~~~~~~~~~~~~~~~~~~^
[18:48:01] SMILES Parse Error: Failed parsing SMILES 'C1[C@@][PH]#[nH]ccc2sI/)=2[Cl-])=O)<SOS>CC)cc1[Na+]cc2ccc' for input: 'C1[C@@][PH]#[nH]ccc2sI/)=2[Cl-])=O)<SOS>CC)cc1[Na+]cc2ccc'
[18:48:41] SMILES Parse Error: syntax error while parsing: cc1ccc(S[O-]c1ccc(S[C@H][S+]\[nH]<SOS>N\[Na+]cc\2c#
[18:48:41] SMILES Parse Error: check for mistakes around position 34:
[18:48:41] 1ccc(S[C@H][S+]\[nH]<SOS>N\[Na+]cc\2c#
[18:48:41] ~~~~~~~~~~~~~~~~~~~~^
[18:48:41] SMILES Parse Error: Failed parsing SMILES 'cc1ccc(S[O-]c1ccc(S[C@H][S+]\[nH]<SOS>N\[Na+]cc\2c#' for input: 'cc1ccc(S[O-]c1ccc(S[C@H][S+]\[nH]<SOS>N\[Na+]cc\2c#'
[18:49:11] SMILES Parse Error: extra close parentheses while parsing: c1=O)cOO)o[C@H])[Li+][C@]3[nH]cc[S+]P

 - Generated Molecule 93:  (Status: Valid)


[18:51:15] SMILES Parse Error: extra close parentheses while parsing: [C@@H]cc2)cc1
[18:51:15] SMILES Parse Error: check for mistakes around position 10:
[18:51:15] [C@@H]cc2)cc1
[18:51:15] ~~~~~~~~~^
[18:51:15] SMILES Parse Error: Failed parsing SMILES '[C@@H]cc2)cc1' for input: '[C@@H]cc2)cc1'
[18:51:24] SMILES Parse Error: syntax error while parsing: )=3(Cl
[18:51:24] SMILES Parse Error: check for mistakes around position 1:
[18:51:24] )=3(Cl
[18:51:24] ^
[18:51:24] SMILES Parse Error: Failed parsing SMILES ')=3(Cl' for input: ')=3(Cl'
[18:52:08] SMILES Parse Error: syntax error while parsing: 2CN(F[C@@H]F[C@@]#2<SOS>O=C1S[C@@]\cc1N[C@]NSc[Cl-][O-][Na+]I<SOS>c1cc
[18:52:08] SMILES Parse Error: check for mistakes around position 1:
[18:52:08] 2CN(F[C@@H]F[C@@]#2<SOS>O=C1S[C@@]\cc1N[C
[18:52:08] ^
[18:52:08] SMILES Parse Error: Failed parsing SMILES '2CN(F[C@@H]F[C@@]#2<SOS>O=C1S[C@@]\cc1N[C@]NSc[Cl-][O-][Na+]I<SOS>c1cc' for input: '2CN(F[C@@H]F[C@@]#2<SOS>O=C1S[C@@]\cc1N[C@]NSc[Cl-][


Total valid molecules generated: 7 out of 100
None
None
None
None
None
None
None


n_layers = 6
h_local = 3
prob_mask = NO MASK

Epoch 1 | Loss = 0.6208 | Accuracy = 0.4185
Epoch 2 | Loss = 0.5437 | Accuracy = 0.4909
Epoch 3 | Loss = 0.5238 | Accuracy = 0.5092
Epoch 4 | Loss = 0.5112 | Accuracy = 0.5222
Epoch 5 | Loss = 0.5034 | Accuracy = 0.5330
Epoch 6 | Loss = 0.5005 | Accuracy = 0.5372
Epoch 7 | Loss = 0.4984 | Accuracy = 0.5426
Epoch 8 | Loss = 0.4958 | Accuracy = 0.5470
Epoch 9 | Loss = 0.4954 | Accuracy = 0.5492
Epoch 10 | Loss = 0.4933 | Accuracy = 0.5497
Epoch 11 | Loss = 0.4887 | Accuracy = 0.5585
Epoch 12 | Loss = 0.4865 | Accuracy = 0.5625
Epoch 13 | Loss = 0.4855 | Accuracy = 0.5594
Epoch 14 | Loss = 0.4844 | Accuracy = 0.5634
Epoch 15 | Loss = 0.4821 | Accuracy = 0.5645
Epoch 16 | Loss = 0.4787 | Accuracy = 0.5690
Epoch 17 | Loss = 0.4776 | Accuracy = 0.5752
Epoch 18 | Loss = 0.4783 | Accuracy = 0.5689
Epoch 19 | Loss = 0.4774 | Accuracy = 0.5783
Epoch 20 | Loss = 0.4756 | Accuracy = 0.5784
Epoch 21 | Loss = 0.4753 | Accuracy = 0.5751
Epoch 22 | Loss = 0.4741 | Accuracy = 0.5789
Epoch 23 | Loss = 0.4762 | Accuracy = 0.5767
Epoch 24 | Loss = 0.4747 | Accuracy = 0.5771
Epoch 25 | Loss = 0.4732 | Accuracy = 0.5787
Epoch 26 | Loss = 0.4715 | Accuracy = 0.5852
Epoch 27 | Loss = 0.4725 | Accuracy = 0.5805
Epoch 28 | Loss = 0.4746 | Accuracy = 0.5792

/Users/ter/Apps/anaconda3/envs/tfm/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:122: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
/Users/ter/Apps/anaconda3/envs/tfm/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:122: UserWarning: Explicitly requested dtype <class 'jax.numpy.complex128'> requested in astype is not available, and will be truncated to dtype complex64. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
Epoch 1 | Loss = 0.5950 | Accuracy = 0.4449
Epoch 2 | Loss = 0.5483 | Accuracy = 0.4865
Epoch 3 | Loss = 0.5334 | Accuracy = 0.5034
Epoch 4 | Loss = 0.5206 | Accuracy = 0.5211
Epoch 5 | Loss = 0.5140 | Accuracy = 0.5264
Epoch 6 | Loss = 0.5065 | Accuracy = 0.5418
Epoch 7 | Loss = 0.5021 | Accuracy = 0.5492
Epoch 8 | Loss = 0.5007 | Accuracy = 0.5510
Epoch 9 | Loss = 0.4965 | Accuracy = 0.5588
Epoch 10 | Loss = 0.4942 | Accuracy = 0.5562
Epoch 11 | Loss = 0.4921 | Accuracy = 0.5642
Epoch 12 | Loss = 0.4876 | Accuracy = 0.5738
Epoch 13 | Loss = 0.4844 | Accuracy = 0.5769
Epoch 14 | Loss = 0.4819 | Accuracy = 0.5794
Epoch 15 | Loss = 0.4814 | Accuracy = 0.5753
Epoch 16 | Loss = 0.4833 | Accuracy = 0.5763
Epoch 17 | Loss = 0.4813 | Accuracy = 0.5841
Epoch 18 | Loss = 0.4829 | Accuracy = 0.5808
Epoch 19 | Loss = 0.4831 | Accuracy = 0.5795
Epoch 20 | Loss = 0.4797 | Accuracy = 0.5832
Epoch 21 | Loss = 0.4781 | Accuracy = 0.5861
Epoch 22 | Loss = 0.4751 | Accuracy = 0.5866
Epoch 23 | Loss = 0.4750 | Accuracy = 0.5891
Epoch 24 | Loss = 0.4762 | Accuracy = 0.5872
Epoch 25 | Loss = 0.4748 | Accuracy = 0.5912

h=2

Epoch 1 | Loss = 2.6945 | Accuracy = 0.4050 
Epoch 2 | Loss = 2.4173 | Accuracy = 0.4816 
Epoch 3 | Loss = 2.3229 | Accuracy = 0.5047 
Epoch 4 | Loss = 2.2768 | Accuracy = 0.5220
Epoch 5 | Loss = 2.2505 | Accuracy = 0.5267 
Epoch 6 | Loss = 2.2214 | Accuracy = 0.5399 
Epoch 7 | Loss = 2.2051 | Accuracy = 0.5448 
Epoch 8 | Loss = 2.1918 | Accuracy = 0.5523 
Epoch 9 | Loss = 2.1790 | Accuracy = 0.5544 
Epoch 10 | Loss = 2.1683 | Accuracy = 0.5577 
Epoch 11 | Loss = 2.1654 | Accuracy = 0.5582 
Epoch 12 | Loss = 2.1519 | Accuracy = 0.5617 
Epoch 13 | Loss = 2.1452 | Accuracy = 0.5647 
Epoch 14 | Loss = 2.1371 | Accuracy = 0.5670 
Epoch 15 | Loss = 2.1380 | Accuracy = 0.5630 
Epoch 16 | Loss = 2.1330 | Accuracy = 0.5654 
Epoch 17 | Loss = 2.1251 | Accuracy = 0.5661 
Epoch 18 | Loss = 2.1189 | Accuracy = 0.5703 
Epoch 19 | Loss = 2.1176 | Accuracy = 0.5718 
Epoch 20 | Loss = 2.1152 | Accuracy = 0.5698 
Epoch 21 | Loss = 2.1162 | Accuracy = 0.5682 
Epoch 22 | Loss = 2.1117 | Accuracy = 0.5704 
Epoch 23 | Loss = 2.1126 | Accuracy = 0.5729 
Epoch 24 | Loss = 2.1097 | Accuracy = 0.5735 
Epoch 25 | Loss = 2.1040 | Accuracy = 0.5756 
Epoch 26 | Loss = 2.1001 | Accuracy = 0.5751 
Epoch 27 | Loss = 2.0959 | Accuracy = 0.5794 
Epoch 28 | Loss = 2.0881 | Accuracy = 0.5808 
Epoch 29 | Loss = 2.0891 | Accuracy = 0.5806 
Epoch 30 | Loss = 2.0869 | Accuracy = 0.5808 
Epoch 31 | Loss = 2.0843 | Accuracy = 0.5855 
Epoch 32 | Loss = 2.0859 | Accuracy = 0.5839 
Epoch 33 | Loss = 2.0841 | Accuracy = 0.5830