In [52]:
from typing import Tuple, List, Dict, Any
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import SaltRemover, GetFormalCharge
import selfies as sf
from math import ceil, log2
import csv
import json
import pickle
import numpy as np
import pandas as pd
import math
from chembl_webresource_client.new_client import new_client

In [2]:
# Fetch the candidate molecule set through the ChEMBL web-resource client
molecule = new_client.molecule

# Property constraints used to select small, drug-like molecules.
# Note that every bound is inclusive (`lte` / `gte`).
druglike_criteria = {
    "molecule_properties__heavy_atoms__lte": 15,        # at most 15 heavy atoms
    "molecule_properties__alogp__lte": 5,               # LogP <= 5 (lipophilicity / membrane permeability)
    "molecule_properties__mw_freebase__lte": 300,       # molecular weight <= 300 g/mol
    "molecule_properties__qed_weighted__gte": 0.5,      # weighted QED >= 0.5 (drug-likeness)
    "molecule_properties__num_ro5_violations__lte": 1,  # at most one Rule-of-5 violation
}
druglike_molecules = molecule.filter(**druglike_criteria)

# Report how many molecules satisfy the criteria
print("Training molecules set: ", len(druglike_molecules))

Training molecules set:  65778


In [55]:
# --- Set up filters ---
remover = SaltRemover.SaltRemover()

# --- Use a larger, cleaner subset ---
molecules_subset = druglike_molecules[:10000]

# Running state for the scan below. Everything is re-initialised here so the
# cell can be re-run without accumulating results from a previous execution.
MAX_LEN = 0                         # longest token sequence seen (specials added later)
alphabet = set()                    # every distinct SELFIES token seen
valid_molecules_for_training = []   # (record, cleaned SMILES, SELFIES string, token list)
total_processed = 0
charged_skipped = 0
selfies_error_skipped = 0
mixture_skipped = 0

print(f"Starting with {len(molecules_subset)} molecules...")

for mol_data in molecules_subset:
    total_processed += 1

    # `.get(key, {})` only covers a *missing* key; if the record carries an
    # explicit null for `molecule_structures` the chained lookup would raise
    # AttributeError, so fall back to an empty mapping for any falsy value.
    structures = mol_data.get('molecule_structures') or {}
    smiles = structures.get('canonical_smiles')
    if not smiles:
        continue

    rdkit_mol = Chem.MolFromSmiles(smiles)
    if rdkit_mol is None:
        continue

    # Remove salts (counter-ions)
    neutral_mol = remover.StripMol(rdkit_mol)

    # Check every atom for formal charge (filters zwitterions)
    has_charge = any(atom.GetFormalCharge() != 0 for atom in neutral_mol.GetAtoms())
    if has_charge:
        charged_skipped += 1
        continue  # Skip this zwitterion/charged molecule

    # Remove ALL stereochemistry
    cleaned_smiles = Chem.MolToSmiles(neutral_mol, isomericSmiles=False)
    if not cleaned_smiles:
        continue

    try:
        selfies = sf.encoder(cleaned_smiles)
    except sf.EncoderError:
        # The SELFIES encoder rejected this structure (e.g. exotic valency)
        selfies_error_skipped += 1
        continue

    if not selfies:
        continue

    # Final check for mixtures: a '.' means more than one fragment is left
    if "." in selfies:
        mixture_skipped += 1
        continue

    tokens = list(sf.split_selfies(selfies))
    MAX_LEN = max(MAX_LEN, len(tokens))
    alphabet.update(tokens)

    # This one is good! Store it.
    valid_molecules_for_training.append((mol_data, cleaned_smiles, selfies, tokens))

# --- Build Final Alphabet ---
# Sorted for a deterministic token order; the special tokens frame the list.
alphabet = sorted(alphabet)
alphabet = ['<SOS>'] + alphabet + ['<EOS>'] + ['<PAD>']

VOCABULARY_SIZE = len(alphabet)
BITS_PER_TOKEN = ceil(log2(VOCABULARY_SIZE))  # bits needed to index the whole alphabet
MAX_LEN += 2  # For <SOS> and <EOS>

print(f"\n--- Filtering Stats ---")
print(f"Total molecules processed: {total_processed}")
print(f"Skipped (charged/zwitterion): {charged_skipped}")
print(f"Skipped (selfies valency error): {selfies_error_skipped}")
print(f"Skipped (mixture/'.'): {mixture_skipped}")
print(f"Kept for training: {len(valid_molecules_for_training)}")

print(f"\n--- Final Results ---")
print(f"Final Alphabet of SELFIES characters: {alphabet}")
print(f"Total unique characters in SELFIES: {VOCABULARY_SIZE}")
print(f"Maximum length of SELFIES in dataset: {MAX_LEN}")
print(f"Bits per token: {BITS_PER_TOKEN}")

# Create token to index mapping
token_to_index = {tok: i for i, tok in enumerate(alphabet)}

Starting with 10000 molecules...

--- Filtering Stats ---
Total molecules processed: 10000
Skipped (charged/zwitterion): 529
Skipped (selfies valency error): 7
Skipped (mixture/'.'): 2
Kept for training: 9462

--- Final Results ---
Final Alphabet of SELFIES characters: ['<SOS>', '[#Branch1]', '[#Branch2]', '[#C]', '[#N]', '[=Branch1]', '[=Branch2]', '[=C]', '[=N]', '[=O]', '[=PH1]', '[=P]', '[=Ring1]', '[=S]', '[Br]', '[Branch1]', '[Branch2]', '[C]', '[Cl]', '[F]', '[H]', '[I]', '[NH1]', '[N]', '[O]', '[PH1]', '[P]', '[Ring1]', '[Ring2]', '[S]', '<EOS>', '<PAD>']
Total unique characters in SELFIES: 32
Maximum length of SELFIES in dataset: 34
Bits per token: 5


In [56]:
# Token → index dictionary (position of each token within `alphabet`)
token_to_index = dict(zip(alphabet, range(len(alphabet))))

def print_token_bits(tokens, token_to_index):
    """Print the index and zero-padded binary code of each token.

    Args:
        tokens: Iterable of tokens to display.
        token_to_index: Mapping from token to its integer index.

    Tokens that are absent from the mapping (or mapped to ``None``) are
    reported with a message and skipped. The binary width is taken from the
    module-level ``BITS_PER_TOKEN``.
    """
    for token in tokens:
        if token_to_index.get(token) is None:
            print(f"Token '{token}' no está en el diccionario.")
        else:
            position = token_to_index[token]
            print(f"'{token}' → index {position} → {position:0{BITS_PER_TOKEN}b}")

# Show the binary code assigned to every entry of the alphabet
print_token_bits(tokens=alphabet, token_to_index=token_to_index)

'<SOS>' → index 0 → 00000
'[#Branch1]' → index 1 → 00001
'[#Branch2]' → index 2 → 00010
'[#C]' → index 3 → 00011
'[#N]' → index 4 → 00100
'[=Branch1]' → index 5 → 00101
'[=Branch2]' → index 6 → 00110
'[=C]' → index 7 → 00111
'[=N]' → index 8 → 01000
'[=O]' → index 9 → 01001
'[=PH1]' → index 10 → 01010
'[=P]' → index 11 → 01011
'[=Ring1]' → index 12 → 01100
'[=S]' → index 13 → 01101
'[Br]' → index 14 → 01110
'[Branch1]' → index 15 → 01111
'[Branch2]' → index 16 → 10000
'[C]' → index 17 → 10001
'[Cl]' → index 18 → 10010
'[F]' → index 19 → 10011
'[H]' → index 20 → 10100
'[I]' → index 21 → 10101
'[NH1]' → index 22 → 10110
'[N]' → index 23 → 10111
'[O]' → index 24 → 11000
'[PH1]' → index 25 → 11001
'[P]' → index 26 → 11010
'[Ring1]' → index 27 → 11011
'[Ring2]' → index 28 → 11100
'[S]' → index 29 → 11101
'<EOS>' → index 30 → 11110
'<PAD>' → index 31 → 11111


In [57]:
# Empty container; nothing in the visible cells appends to it
basis_encoded_dataset = list()
# Token → index lookup, rebuilt from the current `alphabet`
token_to_index = dict(zip(alphabet, range(len(alphabet))))

def smiles_to_bits(tokens: list) -> np.ndarray:
    """Encode a token list as a 2D array of bits, one row per token.

    The sequence is framed with '<SOS>' and '<EOS>' before encoding, so the
    result has ``len(tokens) + 2`` rows. Each row is the token's index in the
    module-level ``token_to_index``, written in binary and zero-padded to
    ``BITS_PER_TOKEN`` digits. A token missing from the mapping raises KeyError.
    """
    framed = ['<SOS>'] + tokens + ['<EOS>']
    rows = [
        [int(digit) for digit in format(token_to_index[symbol], f"0{BITS_PER_TOKEN}b")]
        for symbol in framed
    ]
    return np.array(rows)

In [58]:
# Observed range of each property, used afterwards for min-max normalisation.
# Starting from +/- infinity lets the first valid value set both bounds.
min_logp = float('inf')
max_logp = float('-inf')
min_qed = float('inf')
max_qed = float('-inf')
min_mw = float('inf')
max_mw = float('-inf')


# Iterate through the subset of molecules to find min/max properties to normalize them
for mol in molecules_subset:
    # `.get(key, {})` does not protect against an explicit null value for
    # `molecule_properties`; `or {}` handles both the missing and the null case.
    props = mol.get('molecule_properties') or {}
    logP = props.get('alogp')
    qed = props.get('qed_weighted')
    mw = props.get('mw_freebase')

    if logP is None or qed is None or mw is None:
        continue  # Skip if any property is missing

    logP = float(logP)
    qed = float(qed)
    mw = float(mw)

    min_logp = min(min_logp, logP)
    max_logp = max(max_logp, logP)

    min_qed = min(min_qed, qed)
    max_qed = max(max_qed, qed)

    min_mw = min(min_mw, mw)
    max_mw = max(max_mw, mw)

print(f"LogP range: {min_logp} to {max_logp}")
print(f"QED range: {min_qed} to {max_qed}")
print(f"MW range: {min_mw} to {max_mw}")

LogP range: -2.51 to 4.89
QED range: 0.5 to 0.94
MW range: 73.14 to 299.09


In [59]:
def normalize(value, min_val, max_val, target_max=np.pi):
    """Min-max scale ``value`` from [min_val, max_val] onto [0, target_max].

    The default upper bound is pi so the result can later be encoded as a
    rotation angle.

    Args:
        value: The number to scale.
        min_val: Lower bound of the source range (maps to 0).
        max_val: Upper bound of the source range (maps to ``target_max``).
        target_max: Upper bound of the target range.

    Returns:
        float: The scaled value rounded to three decimal places. When
        ``min_val == max_val`` the range is degenerate and 0.0 is returned
        instead of dividing by zero.
    """
    span = max_val - min_val
    if span == 0:
        # Every sample shares the same value, so there is nothing to scale.
        return 0.0
    norm = (value - min_val) / span * target_max
    return float(f"{norm:.3f}")

In [60]:
# Write the structured data to a CSV file
DATA_PATH = "../data/structured_data_selfies.csv"
with open(DATA_PATH, mode="w", newline="") as file:
    writer = csv.writer(file)

    header = ["logP", "qed", "mw"] + [f"token_{i}" for i in range(MAX_LEN)]
    writer.writerow(header)

    # Each stored entry is (record, cleaned SMILES, SELFIES string, token list).
    # Only the record (for its properties) and the token list are needed here.
    for mol_data, _cleaned_smiles, _selfies_string, tokens in valid_molecules_for_training:
        if not tokens or "." in tokens:
            continue

        # `or {}` also covers an explicit null `molecule_properties`, which
        # would otherwise raise an AttributeError the except below does not catch.
        props = mol_data.get('molecule_properties') or {}
        try:
            logP = float(props.get('alogp'))
            qed = float(props.get('qed_weighted'))
            mw = float(props.get('mw_freebase'))
        except (TypeError, ValueError):
            continue  # a property is missing or not numeric

        norm_logp = normalize(logP, min_logp, max_logp)
        norm_qed = normalize(qed, min_qed, max_qed)
        norm_mw = normalize(mw, min_mw, max_mw)

        if not all(tok in token_to_index for tok in tokens):
            continue

        bit_matrix = smiles_to_bits(tokens)  # shape (len(tokens) + 2, BITS_PER_TOKEN)
        token_bits_as_strings = ["".join(map(str, bit_row)) for bit_row in bit_matrix]

        # Rows are not padded: sequences shorter than MAX_LEN yield fewer
        # fields than the header has columns.
        csv_row = [norm_logp, norm_qed, norm_mw] + token_bits_as_strings
        writer.writerow(csv_row)

print("Maximum length of sequences in the subset:", MAX_LEN)

Maximum length of sequences in the subset: 34


In [61]:
def bitstr_to_array(bitstr):
    """Turn a bit string such as '010101' into a 1-D float32 numpy array.

    Each character is converted with ``int`` and stored as one element.
    """
    return np.fromiter(map(int, bitstr), dtype=np.float32)

def build_training_data(df):
    """
    Build next-token training samples from a DataFrame of encoded molecules.

    Every pair of adjacent token columns in a row yields one sample, so a row
    with k consecutive tokens contributes k - 1 samples.

    Args:
        df (pandas.DataFrame): Must contain the columns 'logP', 'qed' and 'mw'.
            Every column from position 3 onwards is treated as a token bit
            string, in sequence order (this assumes the three property columns
            come first). Missing token cells are allowed.

    Returns:
        list of tuples: Each tuple contains
            (x_token: np.array, x_props: np.array, y_target: np.array),
            all float32. Samples from the same row share one `x_props` array.
    """
    dataset = []

    for _, row in df.iterrows():
        # Extract molecular properties as a numpy float32 array
        x_props = np.array([row['logP'], row['qed'], row['mw']], dtype=np.float32)

        # Explicit positional slice: token columns after the three properties
        tokens = row.iloc[3:]

        # Iterate over token sequence to create input-target pairs
        for i in range(len(tokens) - 1):
            current_token = tokens.iloc[i]
            next_token = tokens.iloc[i + 1]

            # Skip pairs touching a missing cell. pd.isna recognises None,
            # float NaN and pd.NA, so nullable string columns are handled too.
            if pd.isna(current_token) or pd.isna(next_token):
                continue

            x_token = bitstr_to_array(current_token)
            y_target = bitstr_to_array(next_token)

            dataset.append((x_token, x_props, y_target))

    return dataset


In [62]:
# --- Load dataset
# Token columns are read as strings so the bit patterns are not parsed as numbers.
token_cols = [f"token_{i}" for i in range(MAX_LEN)]
df = pd.read_csv(DATA_PATH, dtype=dict.fromkeys(token_cols, str))
dataset = build_training_data(df)  # list of (x_token, x_props, y_target) tuples

# --- Save dataset as a pickle file
DATASET_PATH = "../data/training_data_selfies.pickle"
with open(DATASET_PATH, "wb") as pickle_file:
    pickle.dump(dataset, pickle_file)
print(f"Training dataset saved to {DATASET_PATH} with {len(dataset)} samples.")

# --- Save metadata to a JSON file
metadata = dict(
    vocabulary_size=VOCABULARY_SIZE,
    bits_per_token=BITS_PER_TOKEN,
    alphabet=alphabet,
    max_sequence_length=MAX_LEN,
)
METADATA_PATH = "../data/metadata_selfies.json"
with open(METADATA_PATH, 'w') as json_file:
    json.dump(metadata, json_file, indent=4)
print(f"Metadata saved to {METADATA_PATH}.")

Training dataset saved to ../data/training_data_selfies.pickle with 198650 samples.
Metadata saved to ../data/metadata_selfies.json.
