In [1]:
from typing import Tuple, List, Dict, Any

from rdkit import Chem
from rdkit.Chem import AllChem
import selfies as sf

import numpy as np
from math import ceil, log2
import re
import pandas as pd
import optax
import csv
import json
import flax.linen as nn
import math
import itertools

import pennylane as qml
from pennylane import qchem
from pennylane.templates import StronglyEntanglingLayers

import jax
import jax.numpy as jnp
from jax.nn.initializers import normal

import haiku as hk

## 1. Get the alphabet used for the SMILEs representation

Build the alphabet considering the structure of some special tokens

•	Símbolos de enlaces y paréntesis: #, (, ), /, \, =

•	Dígitos simples para cierres de anillos: '1', '2', '3', '4', '5'

•	Átomos orgánicos y halógenos comunes, tanto mayúsculas (alifáticos) como minúsculas (aromáticos)

•	Tokens entre corchetes para isótopos, estados de carga, quiralidad, etc.

•	Un token especial '<PAD>' para padding en modelos ML

The input SMILEs must be the same size, so we need to use the padding to make it uniform

## 2. Obtain the molecular properties of interest

logP (o cx_logp) -> Coeficiente de partición octanol/agua (lipofilia)

QED (quantitative estimate of drug-likeness) -> Escala combinada que evalúa qué tan “drug-like” es una molécula

SAS (Synthetic Accessibility Score) -> Qué tan difícil sería sintetizar la molécula en laboratorio * Needs to be calculated separately!!

MW (peso molecular) -> Masa total, típicamente ≤ 500 Da para buenos fármacos orales


First we get the max and min range values in order to later normalize the properties in the range (0, pi)

In [2]:
# Load metadata from JSON
N_MOLECS = 500
META_DATA_PATH = f"../data/metadata_selfies_{N_MOLECS}.json"

with open(META_DATA_PATH, "r") as f:
    metadata = json.load(f)

VOCABULARY_SIZE = metadata['vocabulary_size']
BITS_PER_TOKEN = metadata['bits_per_token']
MAX_LEN = metadata['max_sequence_length']
ALPHABET = metadata['alphabet']

print("Vocabulary Size:", VOCABULARY_SIZE)
print("Bits per Token:", BITS_PER_TOKEN)
print("Max Sequence Length:", MAX_LEN)


Vocabulary Size: 29
Bits per Token: 5
Max Sequence Length: 31


In [3]:
# Load training data from CSV
DATA_PATH = f"../data/training_data_selfies_{N_MOLECS}.pickle"
dataset = pd.read_pickle(DATA_PATH)

In [4]:
# Auxiliary functions

def normalize(value, min_val, max_val, target_max=np.pi):
    ''' Normalize a value to a range [0, [0, pi] to later encode them as rotation angles'''
    norm = (value - min_val) / (max_val - min_val) * target_max
    return float(f"{norm:.3f}")

def token_to_index(token):
    ''' Map a SELFIES token to its corresponding index in the ALPHABET'''
    if token in ALPHABET:
        return ALPHABET.index(token)
    else:
        return None

def bits_to_index(bits):
    powers = 2 ** jnp.arange(len(bits) - 1, -1, -1)
    return jnp.dot(bits, powers).astype(jnp.int32)

## 3. Quantum Generative Model

In [5]:
def zstring_combos(wires):
    """
    Return an ordered list of wire-tuples for all Z-strings up to order H_LOCAL.
    Order: all 1-local, then all 2-local, ..., up to H_LOCAL.
    """
    L = []
    for k in range(1, H_LOCAL + 1):
        L.extend(itertools.combinations(wires, k))
    return [tuple(c) for c in L]

def num_zstrings(n_wires):
    """Count how many Z-strings up to order H_LOCAL."""
    from math import comb
    return sum(comb(n_wires, k) for k in range(1, H_LOCAL + 1))

In [6]:
# Quantum Attention mechanism using SWAP test

# --- Device for attention ---
n_past = 5
attn_dev = qml.device("default.qubit", wires=(BITS_PER_TOKEN+1)*n_past)

@qml.qnode(attn_dev, interface="jax")
def quantum_attention_qnode(Q_vec, K_vecs):
    """
    Q_vec: projected query vector of current token
    K_vecs: list of projected key vectors for past tokens
    Returns: attention scores ⟨q_i | k_j⟩ for each j
    """
    n_tokens = len(K_vecs)
    q_wires = list(range(BITS_PER_TOKEN))

    # Encode Q and all K in parallel (different wire registers)
    def encode_token(angles, wires):
        # Use AngleEmbedding for compactness, then entangle
        qml.templates.AngleEmbedding(angles, wires=wires, rotation="Y")
        for i in range(len(wires)-1):
            qml.CNOT(wires=[wires[i], wires[i+1]])

    # Encode Q
    encode_token(Q_vec, wires=q_wires)

    # Encode K_j
    # Collect expectation values (one per K_j)
    measurements = []
    for j, K_j in enumerate(K_vecs):
        start = (BITS_PER_TOKEN+1) + j*(BITS_PER_TOKEN+1)
        k_wires = list(range(start, start+BITS_PER_TOKEN))
        encode_token(K_j, k_wires)

        # SWAP test
        ancilla = start + BITS_PER_TOKEN
        qml.Hadamard(wires=ancilla)
        for qw, kw in zip(q_wires, k_wires):
            qml.CSWAP(wires=[ancilla, qw, kw])
        qml.Hadamard(wires=ancilla)

        measurements.append(qml.expval(qml.PauliZ(ancilla)))

    # **Return as tuple** so PennyLane converts to JAX array
    return tuple(measurements)


def quantum_attention(Q_vec, K_vecs, V_vecs):
    # Make non-traced (concrete) copies for the QNode
    Q_safe = jax.lax.stop_gradient(Q_vec)
    K_safe = [jax.lax.stop_gradient(k) for k in K_vecs]

    raw_expvals = quantum_attention_qnode(Q_safe, K_safe)
    raw_expvals = jnp.asarray(raw_expvals)
    # Convert from expectation values (in [-1,1]) to probabilities [0,1]
    overlaps = (1.0 - raw_expvals) / 2.0
    return overlaps


def classical_attention(Q_vec, K_vecs, V_vecs, mask, scale=True):
    """
    Calculates classical dot-product attention with padding mask.

    Args:
        Q_vec: Query vector, shape (proj_dim,)
        K_vecs: Key matrix, shape (n_past, proj_dim)
        V_vecs: Value matrix, shape (n_past, proj_dim)
        mask: Padding mask, shape (n_past,). 
              Contains 0.0 for valid tokens and -jnp.inf for padding.
        scale: Whether to scale scores by sqrt(proj_dim).
        
    Returns:
        Output context vector, shape (proj_dim,)
    """
    if len(K_vecs) == 0:
        return V_vecs[0] if len(V_vecs) > 0 else jnp.zeros_like(Q_vec)

    K_mat = jnp.stack(K_vecs)  # shape (n_past, proj_dim)
    V_mat = jnp.stack(V_vecs)  # shape (n_past, proj_dim)
    # dot-product softmax (fully differentiable)
    scores = jnp.dot(K_mat, Q_vec)  # shape (n_past,)
    if scale:
        scores = scores / jnp.sqrt(Q_vec.shape[0])
    # Causal mask: only past tokens (K_vecs are already past tokens)
    # Adding -inf makes the softmax probability of pad tokens zero.
    scores = scores + mask
    weights = jax.nn.softmax(scores)

    # Weighted sum over classical V
    output = jnp.dot(weights, V_mat)
    return output

In [7]:
# Device and qubit setup
# BITS_PER_TOKEN number of qubits needed to encode each token
n_prop_qubits = 3  # number of qubits needed to encode properties (logP, QED, MW)
n_ancillas = 1  # number of ancilla qubits that represent the environment
n_total_qubits = n_prop_qubits + BITS_PER_TOKEN + n_ancillas

N_LAYERS = 2  # number of variational layers
H_LOCAL = 2 # h_local sets the maximum number of qubits that can interact in each Z-string term of Σ


# Name them explicitly
prop_wires = [f"prop_{i}" for i in range(n_prop_qubits)]
token_wires = [f"token_{i}" for i in range(BITS_PER_TOKEN)]
ancilla_wires = [f"ancilla_{i}" for i in range(n_ancillas)]
all_wires = prop_wires + token_wires + ancilla_wires

#dev = qml.device("default.qubit", wires=all_wires)
dev = qml.device("lightning.qubit", wires=all_wires)

def molecular_property_encoder(props):
    """Encode continuous props on property qubits via RY rotations"""
    for wire, val in zip(prop_wires, props):
        qml.RY(val, wires=wire)

def token_encoder(token_bits):
    """Basis-encode token bits on token qubits"""
    qml.BasisState(token_bits, wires=token_wires)


def operator_layer(theta_params, theta_prop, wires):
    """
    Variational layer where:
      - theta_params[...] are rotations for token + ancilla qubits
      - theta_prop encodes property→token entanglement
    """
    token_ancilla_ws = token_wires + ancilla_wires

    # Property → token entanglement 
    for p, prop_wire in enumerate(prop_wires):
        for t, t_a_wire in enumerate(token_ancilla_ws):
            qml.CRX(theta_prop[p, t, 0], wires=[prop_wire, t_a_wire])
            qml.CRY(theta_prop[p, t, 1], wires=[prop_wire, t_a_wire])

 
    qml.StronglyEntanglingLayers(
        weights=theta_params[None,:,:],  # shape: (n_token_ancilla, 3)
        wires=token_ancilla_ws
    )

def Sigma_layer_vec(gamma_vec, token_ancilla_ws, time=1.0, combos=None):
    """
    Diagonal multi-Z unitary Σ = exp(i * sum_s gamma_s * Z^{⊗|s|} * t)
    using a flat parameter vector 'gamma_vec' aligned with 'combos'.
    """
    #token_ancilla_ws = list(wires)  # pass token+ancilla here
    if combos is None:
        combos = zstring_combos(token_ancilla_ws)

    # Safety: ensure the vector length matches the number of combos
    assert gamma_vec.shape[0] == len(combos), \
        f"gamma_vec has length {gamma_vec.shape[0]} but expected {len(combos)}"

    # MultiRZ(phi) = exp(-i * phi/2 * Z^{⊗k}); choose phi = -2 * gamma * time
    for gamma, combo in zip(gamma_vec, combos):
        qml.MultiRZ(-2.0 * gamma * time, wires=list(combo))


# QNode combining encoding and variational layers
@qml.qnode(dev, interface="jax")
def autoregressive_model(token_bits, props, theta_params, theta_prop, sigma_params, output_i):
    molecular_property_encoder(props)      # Encode MW, logP, QED
    token_encoder(token_bits)              # Basis-encode token bits

    # --- Encode output_i embedding safely ---
    for i, val in enumerate(output_i):
        qml.RY(val, wires=token_wires[i])

    token_ancilla_ws = token_wires + ancilla_wires
    combos = zstring_combos(token_ancilla_ws)

    for l in range(N_LAYERS):
        # Forward V(θ)
        operator_layer(theta_params[l], theta_prop[l], wires=all_wires)

        # Diagonal Σ(γ,t): vector API
        Sigma_layer_vec(sigma_params[l], token_ancilla_ws, time=1.0, combos=combos)

        # Backward V(θ)†
        # qml.adjoint(operator_layer)(theta_params[l], theta_prop[l], wires=all_wires)

    return qml.probs(wires=token_wires), [qml.expval(qml.PauliZ(w)) for w in prop_wires]


In [8]:
def categorical_crossentropy(pred_probs, target_index):
    epsilon = 1e-10
    return -jnp.log(pred_probs[target_index] + epsilon)

def total_loss_fn(pred_probs, prop_expvals, target_index, props, alpha=0.5, epsilon=0.1):
    """Cross-entropy loss with label smoothing, normalized to [0,1]."""
    num_classes = pred_probs.shape[0] # Number of classes (tokens)
    
    # Build smoothed target
    smooth_target = jnp.full_like(pred_probs, epsilon / (num_classes - 1))
    smooth_target = smooth_target.at[target_index].set(1.0 - epsilon)
    
    # --- Compute cross-entropy loss
    ce_loss = -jnp.sum(smooth_target * jnp.log(pred_probs + 1e-10))
    
    # --- Property preservation loss (MSE) ---
    prop_expvals = jnp.array(prop_expvals)  # convert list -> JAX array
    prop_loss = jnp.mean((prop_expvals - jnp.cos(props)) ** 2)  # in [0,4]

    # --- Combine ---
    combined_loss = ce_loss + alpha * prop_loss

    # --- Normalize only once ---
    max_loss = jnp.log(num_classes) + alpha * 4.0
    final_loss = combined_loss / max_loss
    
    return final_loss


def compute_accuracy(pred_probs, target_index):
    predicted_index = jnp.argmax(pred_probs)
    return jnp.array(predicted_index == target_index, dtype=jnp.float32)


In [9]:
# Token embedding
key = jax.random.PRNGKey(42)
EMBEDDING_SIZE = BITS_PER_TOKEN + n_ancillas      # size of embeddings
key, k_emb = jax.random.split(key)
embedding_table = jax.random.normal(k_emb, (VOCABULARY_SIZE, EMBEDDING_SIZE)) * 0.1

# Projection matrices
key = jax.random.PRNGKey(42)
proj_dim = BITS_PER_TOKEN  # number of qubits for quantum attention
key, k_WQ, k_WK, k_WV = jax.random.split(key, 4)
W_Q = jax.random.normal(k_WQ, (EMBEDDING_SIZE, proj_dim)) * 0.1
W_K = jax.random.normal(k_WK, (EMBEDDING_SIZE, proj_dim)) * 0.1
W_V = jax.random.normal(k_WV, (EMBEDDING_SIZE, proj_dim)) * 0.1


# Effective qubit counts in variational layers
n_token_ancilla = BITS_PER_TOKEN + n_ancillas

# Initialize theta and sigma params
key = jax.random.PRNGKey(42)
key, k_theta, k_theta_prop, k_sigma = jax.random.split(key, 4)

# Precompute Z-string combos once
token_ancilla_ws = token_wires + ancilla_wires
combos = zstring_combos(token_ancilla_ws)
n_strings = len(combos)

# Combine all trainable parameters into a single dictionary
combined_params = {
    'theta': jax.random.normal(k_theta, (N_LAYERS, n_token_ancilla, 3)) * 0.1,
    'theta_prop': jax.random.normal(k_theta_prop, (N_LAYERS, n_prop_qubits, n_token_ancilla, 4)) * 0.1,
    'sigma': jax.random.normal(k_sigma, (N_LAYERS, n_strings)) * 0.1,
    'embedding_table': embedding_table,
    'W_Q': W_Q,
    'W_K': W_K,
    'W_V': W_V
}
# Training hyperparams
learning_rate = 0.001
n_epochs = 100

# Optimizer
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(combined_params)


PAD_index = token_to_index("<PAD>")
SOS_index = token_to_index("<SOS>")
EOS_index = token_to_index("<EOS>")

In [10]:
@jax.jit
def training_step(params, opt_state, x_token, x_props, y_target, past_token_indices=None):
    """
    Performs a single, JIT-compiled training step.
    
    Args:
        params: Dictionary of all model parameters.
        opt_state: Current state of the optimizer.
        x_token: JAX array for the current token, shape (BITS_PER_TOKEN,)
        x_props: JAX array for molecular properties, shape (3,)
        y_target: JAX array for the target token, shape (BITS_PER_TOKEN,)
        past_token_indices: JAX array of *fixed size*, shape (N_PAST,)
    """
    
    def loss_fn(params):
        # --- Unpack parameters ---
        theta_params = params['theta']
        theta_prop = params['theta_prop']
        sigma_params = params['sigma']
        embedding_table = params['embedding_table']
        W_Q = params['W_Q']
        W_K = params['W_K']
        W_V = params['W_V']

        # --- Embedding lookup ---
        #start = time.time()
        token_index = bits_to_index(x_token)
        x_i = embedding_table[token_index]
        #print("Embedding lookup:", time.time() - start, "seconds")
        
        # Positional encoding (sin/cos)
        # We calculate position by counting non-pad tokens
        position = jnp.sum(past_token_indices != PAD_index)
        dim_indices = jnp.arange(EMBEDDING_SIZE)
        pos_enc = jnp.where(
            dim_indices % 2 == 0,
            jnp.sin(position / (10000 ** (dim_indices / EMBEDDING_SIZE))),
            jnp.cos(position / (10000 ** ((dim_indices-1) / EMBEDDING_SIZE)))
        )

        # Combine token embedding + positional encoding
        x_i_pos = x_i + pos_enc

        # --- Q projection for current token ---
        #start = time.time()
        Q_i = x_i_pos @ W_Q
        
        # --- Parallel K, V projections for past tokens (JIT-friendly) ---
        
        # Use jax.vmap to apply the embedding lookup and projection
        # in parallel across the (N_PAST,) array. This is extremely fast.
        past_embeddings = jax.vmap(lambda idx: embedding_table[idx])(past_token_indices)
        
        K_vecs = jax.vmap(lambda x: x @ W_K)(past_embeddings) # shape (N_PAST, proj_dim)
        V_vecs = jax.vmap(lambda x: x @ W_V)(past_embeddings) # shape (N_PAST, proj_dim)

        # --- Create padding mask ---
        # Create a mask of 0.0 for real tokens and -inf for <PAD> tokens
        mask = jnp.where(past_token_indices == PAD_index, -jnp.inf, 0.0)

        # --- Classical Attention (JIT-friendly) ---
        # We pass the mask to the attention function
        output_i = classical_attention(Q_i, K_vecs, V_vecs, mask)

        # --- Variational model ---
        #start = time.time()
        pred_probs, expval_props = autoregressive_model(
            x_token, x_props, theta_params, theta_prop, sigma_params, output_i
        )
        #print("Variational model:", time.time() - start, "seconds")

        # --- Loss computation ---
        target_index = bits_to_index(y_target)
        # Return scalar loss for gradient computation
        return total_loss_fn(pred_probs, expval_props, target_index, x_props), pred_probs

    # value_and_grad computes both loss and grads in one pass
    (loss, pred_probs), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)

    # --- Update parameters ---
    #start = time.time()
    updates, opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    #print("Update params:", time.time() - start, "seconds")
    
    # --- Compute accuracy ---
    target_index = bits_to_index(y_target)
    acc = compute_accuracy(pred_probs, target_index)

    return new_params, loss, opt_state, grads, acc


In [11]:
import jax.lax
import time

N_PAST = 10  # Define a fixed context window size

for epoch in range(n_epochs):
    total_loss = total_acc = 0.0
    # Initialize a fixed-size array for past indices
    past_token_indices = jnp.full((N_PAST,), PAD_index, dtype=jnp.int32)
    epoch_start_time = time.time()

    for i, (x_token, x_props, y_target) in enumerate(dataset):
        # Convert to JAX arrays
        x_token = jnp.array(x_token, dtype=jnp.int32)
        x_props = jnp.array(x_props, dtype=jnp.float32)
        y_target = jnp.array(y_target, dtype=jnp.int32)
        '''
        if jnp.all(x_token == 0):
            # If current token is <SOS>, reset past tokens
            past_token_indices = jnp.full((N_PAST,), PAD_index, dtype=jnp.int32)
        '''
        # Check if current token is <SOS>
        is_sos = (bits_to_index(x_token) == 0)
        past_token_indices = jax.lax.cond(
            is_sos,
            lambda: jnp.full((N_PAST,), PAD_index, dtype=jnp.int32), # true_fun (reset)
            lambda: past_token_indices                               # false_fun (keep)
        )

        # Perform a training step
        combined_params, loss, opt_state, grads, acc = training_step(
            combined_params, 
            opt_state, 
            x_token, 
            x_props, 
            y_target, 
            past_token_indices
        )

        # Update loss and accuracy
        total_loss += loss
        total_acc  += acc

        # Update the history: roll and add the new token index
        epoch_time = time.time() - epoch_start_time
        current_token_index = bits_to_index(x_token)
        past_token_indices = jnp.roll(past_token_indices, shift=-1).at[-1].set(current_token_index)

    avg_loss = total_loss / len(dataset)
    avg_acc  = total_acc / len(dataset)
    
    print(f"Epoch {epoch+1} | Loss = {avg_loss:.4f} | Accuracy = {avg_acc:.4f} | Time = {epoch_time:.2f}s")

  return lax_numpy.astype(self, dtype, copy=copy, device=device)


KeyboardInterrupt: 

Epoch 1 | Loss = 0.5735 | Accuracy = 0.4376
Epoch 2 | Loss = 0.5484 | Accuracy = 0.4612
Epoch 3 | Loss = 0.5449 | Accuracy = 0.4657
Epoch 4 | Loss = 0.5455 | Accuracy = 0.4658
Epoch 5 | Loss = 0.5433 | Accuracy = 0.4705
Epoch 6 | Loss = 0.5404 | Accuracy = 0.4729
Epoch 7 | Loss = 0.5363 | Accuracy = 0.4798
Epoch 8 | Loss = 0.5343 | Accuracy = 0.4819
Epoch 9 | Loss = 0.5319 | Accuracy = 0.4833
Epoch 10 | Loss = 0.5321 | Accuracy = 0.4859

/Users/ter/Apps/anaconda3/envs/tfm/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:122: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
/Users/ter/Apps/anaconda3/envs/tfm/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:122: UserWarning: Explicitly requested dtype <class 'jax.numpy.complex128'> requested in astype is not available, and will be truncated to dtype complex64. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
Epoch 1 | Loss = 0.6477 | Accuracy = 0.3118
Epoch 2 | Loss = 0.5822 | Accuracy = 0.3786
Epoch 3 | Loss = 0.5650 | Accuracy = 0.3998
Epoch 4 | Loss = 0.5578 | Accuracy = 0.4121
Epoch 5 | Loss = 0.5515 | Accuracy = 0.4201
Epoch 6 | Loss = 0.5460 | Accuracy = 0.4265
Epoch 7 | Loss = 0.5449 | Accuracy = 0.4255
Epoch 8 | Loss = 0.5456 | Accuracy = 0.4230
Epoch 9 | Loss = 0.5442 | Accuracy = 0.4215
Epoch 10 | Loss = 0.5412 | Accuracy = 0.4266
Epoch 11 | Loss = 0.5383 | Accuracy = 0.4258
Epoch 12 | Loss = 0.5382 | Accuracy = 0.4179
Epoch 13 | Loss = 0.5359 | Accuracy = 0.4187
Epoch 14 | Loss = 0.5346 | Accuracy = 0.4224
Epoch 15 | Loss = 0.5306 | Accuracy = 0.4264
Epoch 16 | Loss = 0.5274 | Accuracy = 0.4250
Epoch 17 | Loss = 0.5232 | Accuracy = 0.4365
Epoch 18 | Loss = 0.5178 | Accuracy = 0.4450
Epoch 19 | Loss = 0.5158 | Accuracy = 0.4438
Epoch 20 | Loss = 0.5136 | Accuracy = 0.4475
Epoch 21 | Loss = 0.5134 | Accuracy = 0.4498
Epoch 22 | Loss = 0.5122 | Accuracy = 0.4468
Epoch 23 | Loss = 0.5103 | Accuracy = 0.4546
Epoch 24 | Loss = 0.5092 | Accuracy = 0.4526
Epoch 25 | Loss = 0.5074 | Accuracy = 0.4580
Epoch 26 | Loss = 0.5094 | Accuracy = 0.4573
Epoch 27 | Loss = 0.5053 | Accuracy = 0.4572
Epoch 28 | Loss = 0.5036 | Accuracy = 0.4560
Epoch 29 | Loss = 0.5023 | Accuracy = 0.4598
Epoch 30 | Loss = 0.5006 | Accuracy = 0.4628
Epoch 31 | Loss = 0.5003 | Accuracy = 0.4611
Epoch 32 | Loss = 0.4995 | Accuracy = 0.4638
Epoch 33 | Loss = 0.5002 | Accuracy = 0.4606
Epoch 34 | Loss = 0.4991 | Accuracy = 0.4589
Epoch 35 | Loss = 0.4979 | Accuracy = 0.4584
Epoch 36 | Loss = 0.4965 | Accuracy = 0.4657
Epoch 37 | Loss = 0.4954 | Accuracy = 0.4635
Epoch 38 | Loss = 0.4958 | Accuracy = 0.4634
Epoch 39 | Loss = 0.4948 | Accuracy = 0.4638
Epoch 40 | Loss = 0.4946 | Accuracy = 0.4664
Epoch 41 | Loss = 0.4945 | Accuracy = 0.4644
Epoch 42 | Loss = 0.4948 | Accuracy = 0.4633
Epoch 43 | Loss = 0.4935 | Accuracy = 0.4682
Epoch 44 | Loss = 0.4941 | Accuracy = 0.4703
Epoch 45 | Loss = 0.4945 | Accuracy = 0.4681
Epoch 46 | Loss = 0.4930 | Accuracy = 0.4688
Epoch 47 | Loss = 0.4925 | Accuracy = 0.4685
Epoch 48 | Loss = 0.4925 | Accuracy = 0.4657
Epoch 49 | Loss = 0.4922 | Accuracy = 0.4670
Epoch 50 | Loss = 0.4912 | Accuracy = 0.4697
Epoch 51 | Loss = 0.4910 | Accuracy = 0.4695
Epoch 52 | Loss = 0.4908 | Accuracy = 0.4680
Epoch 53 | Loss = 0.4905 | Accuracy = 0.4717
Epoch 54 | Loss = 0.4905 | Accuracy = 0.4697
Epoch 55 | Loss = 0.4916 | Accuracy = 0.4666
Epoch 56 | Loss = 0.4899 | Accuracy = 0.4687
Epoch 57 | Loss = 0.4920 | Accuracy = 0.4661
Epoch 58 | Loss = 0.4910 | Accuracy = 0.4704
Epoch 59 | Loss = 0.4913 | Accuracy = 0.4729
Epoch 60 | Loss = 0.4913 | Accuracy = 0.4698
Epoch 61 | Loss = 0.4912 | Accuracy = 0.4732
Epoch 62 | Loss = 0.4895 | Accuracy = 0.4721
Epoch 63 | Loss = 0.4901 | Accuracy = 0.4688
Epoch 64 | Loss = 0.4898 | Accuracy = 0.4704
Epoch 65 | Loss = 0.4889 | Accuracy = 0.4740
Epoch 66 | Loss = 0.4883 | Accuracy = 0.4729
Epoch 67 | Loss = 0.4888 | Accuracy = 0.4735
Epoch 68 | Loss = 0.4886 | Accuracy = 0.4709
Epoch 69 | Loss = 0.4888 | Accuracy = 0.4742
Epoch 70 | Loss = 0.4884 | Accuracy = 0.4718
Epoch 71 | Loss = 0.4884 | Accuracy = 0.4713
Epoch 72 | Loss = 0.4880 | Accuracy = 0.4723
Epoch 73 | Loss = 0.4888 | Accuracy = 0.4721
Epoch 74 | Loss = 0.4894 | Accuracy = 0.4732
Epoch 75 | Loss = 0.4939 | Accuracy = 0.4699
Epoch 76 | Loss = 0.4911 | Accuracy = 0.4727
Epoch 77 | Loss = 0.4895 | Accuracy = 0.4768
Epoch 78 | Loss = 0.4901 | Accuracy = 0.4724
Epoch 79 | Loss = 0.4889 | Accuracy = 0.4751
Epoch 80 | Loss = 0.4892 | Accuracy = 0.4737
Epoch 81 | Loss = 0.4883 | Accuracy = 0.4749
Epoch 82 | Loss = 0.4878 | Accuracy = 0.4740
Epoch 83 | Loss = 0.4877 | Accuracy = 0.4738
Epoch 84 | Loss = 0.4878 | Accuracy = 0.4735
Epoch 85 | Loss = 0.4874 | Accuracy = 0.4744
Epoch 86 | Loss = 0.4884 | Accuracy = 0.4724
Epoch 87 | Loss = 0.4884 | Accuracy = 0.4739
Epoch 88 | Loss = 0.4875 | Accuracy = 0.4729
Epoch 89 | Loss = 0.4872 | Accuracy = 0.4760
Epoch 90 | Loss = 0.4869 | Accuracy = 0.4754
Epoch 91 | Loss = 0.4870 | Accuracy = 0.4730
Epoch 92 | Loss = 0.4879 | Accuracy = 0.4747
Epoch 93 | Loss = 0.4873 | Accuracy = 0.4735
Epoch 94 | Loss = 0.4874 | Accuracy = 0.4723
Epoch 95 | Loss = 0.4869 | Accuracy = 0.4745
Epoch 96 | Loss = 0.4864 | Accuracy = 0.4719
Epoch 97 | Loss = 0.4862 | Accuracy = 0.4717
Epoch 98 | Loss = 0.4864 | Accuracy = 0.4715
Epoch 99 | Loss = 0.4860 | Accuracy = 0.4734
Epoch 100 | Loss = 0.4864 | Accuracy = 0.4737

In [None]:
import os
import pickle

# Define the directory path (relative to your current location, assuming you are in /content/QGen-Mol/code)
target_dir = '../data/params/'

# Create the directory if it doesn't exist
os.makedirs(target_dir, exist_ok=True)

# Now, run your pickle code:
with open(os.path.join(target_dir, 'selfies_params.pkl'), "wb") as f:
    pickle.dump(combined_params, f)

In [None]:
from jax import random
MAX_LEN = 32  # maximum length of generated SELFIES (including <SOS> and <EOS>)
def generate_molecule_selfies_stochastic(key, props, combined_params, temperature=1.0):
    """
    Generates a molecule stochastically using Attention weights for context,
    suitable for the SELFIES model.
    """
    # --- 1. Parameter Unpacking (Match SELFIES/Attention model) ---
    embedding_table = combined_params['embedding_table']
    W_Q = combined_params['W_Q']
    W_K = combined_params['W_K']
    W_V = combined_params['W_V']
    theta_params = combined_params['theta']
    theta_prop = combined_params['theta_prop']
    sigma_params = combined_params['sigma']
    
    generated_bits = []
    past_token_indices = []
    
    # Start with <SOS>
    current_token_index = SOS_index
    # Create the <SOS> bit array (jnp.int32 array of shape (BITS_PER_TOKEN,))
    current_token_bits = jnp.array(list(map(int, format(SOS_index, f'0{BITS_PER_TOKEN}b'))), dtype=jnp.int32)
    
    # Store the local RNG state
    local_rng = key 
    
    # --- 2. Generation Loop ---
    for t in range(MAX_LEN):
        # Split the key for inner-loop stochasticity
        local_rng, subkey = random.split(local_rng)

        # Map current token bits to index and lookup embedding
        x_token = current_token_bits 
        current_token_index = bits_to_index(x_token)
        x_i = embedding_table[current_token_index]

        # Apply Positional Encoding
        position = len(past_token_indices)
        dim_indices = jnp.arange(EMBEDDING_SIZE)
        pos_enc = jnp.where(
            dim_indices % 2 == 0,
            jnp.sin(position / (10000 ** (dim_indices / EMBEDDING_SIZE))),
            jnp.cos(position / (10000 ** ((dim_indices-1) / EMBEDDING_SIZE)))
        )
        x_i_pos = x_i + pos_enc

        # --- Attention Calculation (Replaces theta-Embedding MLP) ---
        Q_i = x_i_pos @ W_Q

        if len(past_token_indices) == 0:
            # First token (after <SOS>), no past context yet
            output_i = x_i_pos @ W_V
        else:
            # Calculate K and V vectors for all past tokens
            past_embeddings = [embedding_table[idx] for idx in past_token_indices]
            K_vecs = [x @ W_K for x in past_embeddings]
            V_vecs = [x @ W_V for x in past_embeddings]
            output_i = classical_attention(Q_i, K_vecs, V_vecs)
        # -------------------------------------------------------------
        
        # 3. Predict probabilities (theta_effective is just theta_params here)
        # Note: Your SELFIES training used constant theta_params (no dynamic embedding).
        # We pass the Attention output as the final parameter, output_i.
        pred_probs, _ = autoregressive_model(x_token, props, theta_params, theta_prop, sigma_params, output_i)
        
        # 4. Tempering and Sampling
        # Ensure we only consider probabilities for valid tokens (0 to 41)
        logits = jnp.log(pred_probs[:VOCABULARY_SIZE] + 1e-10)

        # Scale logits by Temperature
        tempered_logits = logits / temperature

        # Apply softmax to get the new, tempered probability distribution
        tempered_probs = nn.softmax(tempered_logits)
        
        # Sample the next index from the Tempered probability distribution
        token_indices = jnp.arange(VOCABULARY_SIZE)
        next_index = random.choice(
            subkey, 
            token_indices, 
            p=tempered_probs 
        )

        # Convert index to bits 
        next_bits_str = format(int(next_index), f'0{BITS_PER_TOKEN}b')
        next_bits = jnp.array([int(b) for b in next_bits_str], dtype=jnp.int32)

        # 5. Check for <EOS>
        if int(next_index) == EOS_index:
            break

        # Update previous tokens for the next step
        generated_bits.append(next_bits)
        past_token_indices.append(current_token_index) # Store index of the token just consumed
        current_token_bits = next_bits # Set the new token for the next iteration

    return generated_bits

# Convert generated bits to SELFIES string (adjusted for SELFIES alphabet)
def bits_to_selfies_smiles(generated_bits):
    selfies_tokens = []
    for bits in generated_bits:
        index = int("".join(map(str, bits)), 2)
        
        # Safety clip, though the sampling process above should enforce this
        if index >= VOCABULARY_SIZE or index == 0:
            break


        token = alphabet[int(index)]
        if token == '<EOS>':
            break
        selfies_tokens.append(token)
    
    selfies_str = ''.join(selfies_tokens)
    smiles_str = sf.decoder(selfies_str)
    return selfies_str, smiles_str

In [None]:
N_MOLECS = 5
MASTER_KEY = jr.PRNGKey(42)  # Fixed key for reproducibility

# Target properties (mid-range example normalized to [0, pi])
desired_logp = 1.2
desired_qed = 0.71
desired_mw = 205.0

norm_logp = normalize(desired_logp, min_logp, max_logp)
norm_qed = normalize(desired_qed, min_qed, max_qed)
norm_mw = normalize(desired_mw, min_mw, max_mw)
desired_props = jnp.array([norm_logp, norm_qed, norm_mw], dtype=jnp.float32)

print(f"Target Properties (Normalized to [0, pi]):")
print(f"   LogP: {norm_logp:.3f}, QED: {norm_qed:.3f}, MW: {norm_mw:.3f}\n")


def generate_molecules(props, params):
    selfies_list = []
    smiles_list = []
    keys = jr.split(MASTER_KEY, N_MOLECS)
    for i in range(N_MOLECS):
        rng_i = keys[i]
        generated_bits = generate_molecule_selfies_stochastic(rng_i, props, params)
        generated_selfies, generated_smiles = bits_to_selfies_smiles(generated_bits)
        selfies_list.append(generated_selfies)
        smiles_list.append(generated_smiles)
    return selfies_list, smiles_list

selfies_list, smiles_list = generate_molecules(desired_props, combined_params)

Target Properties (Normalized to [0, pi]):
   LogP: 1.574, QED: 1.571, MW: 1.577



  return lax_numpy.astype(self, dtype, copy=copy, device=device)


In [None]:
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import QED

def analyze_molecule_properties(selfies_list, smiles_list, target_logp, target_qed, target_mw):
    """Calculates and prints the physicochemical properties for generated molecules,
       comparing them directly against the denormalized targets."""
    
    results = []
    
    # --- Print Denormalized Target Properties ---
    print(f"\n--- Target Properties ---")
    print(f"LogP: {target_logp:.2f}")
    print(f"QED: {target_qed:.2f}")
    print(f"MW: {target_mw:.2f} g/mol")
    print("-" * 35)

    for i, (smiles, selfies) in enumerate(zip(smiles_list, selfies_list)):
        # Handle empty/invalid SMILES from generation failure
        if not smiles:
            results.append({"Molecule": i+1, "SMILES": "N/A", "LogP": np.nan, "QED": np.nan, "MW": np.nan})
            continue

        mol = Chem.MolFromSmiles(smiles)

        if mol is not None:
            try:
                logp = Descriptors.MolLogP(mol)
                qed_score = QED.qed(mol)
                mw = Descriptors.ExactMolWt(mol)

                results.append({
                    "Molecule": i+1,
                    "SELFIES": selfies,
                    "SMILES": smiles,
                    "LogP": logp,
                    "QED": qed_score,
                    "MW": mw,
                    "Validity": "Valid"
                })
            except Exception:
                 results.append({"Molecule": i+1, "SELFIES": selfies, "SMILES": smiles, "LogP": np.nan, "QED": np.nan, "MW": np.nan, "Validity": "Error"})
        else:
            results.append({"Molecule": i+1, "SELFIES": selfies, "SMILES": smiles, "LogP": np.nan, "QED": np.nan, "MW": np.nan, "Validity": "Invalid"})

    # Create and display the DataFrame
    df = pd.DataFrame(results)
    
    # Add a row for the target properties for easy comparison
    target_row = pd.Series({
        "Molecule": "TARGET", 
        "SELFIES": "TARGET",
        "SMILES": "TARGET", 
        "LogP": target_logp, 
        "QED": target_qed, 
        "MW": target_mw
    }, name="TARGET").to_frame().T
    
    # Concatenate the target row and the results for visual comparison
    df_styled = pd.concat([target_row.set_index('Molecule'), df.set_index('Molecule')])
    
    # Format numerical columns for presentation
    df_styled = df_styled.apply(pd.to_numeric, errors='ignore').round(2)
    
    return df_styled

df_styled = analyze_molecule_properties(selfies_list, smiles_list, desired_logp, desired_qed, desired_mw)

# Save the results to a CSV file
output_csv_path = "../generation/generated_selfies.csv"
df_styled.to_csv(output_csv_path)


--- Target Properties ---
LogP: 1.20
QED: 0.71
MW: 205.00 g/mol
-----------------------------------


  df_styled = df_styled.apply(pd.to_numeric, errors='ignore').round(2)
