In [1]:
# ============================================================================
# MODEL AND VOCABULARY LOADING
# ============================================================================
# Filesystem locations of the trained artefacts (EDIT THESE PATHS for your machine!)
MODEL_PATH = "/home/grad/Desktop/pietro/denovo/2/final_hybrid_model.keras"
CHAR2IDX_PATH = "/home/grad/Desktop/pietro/denovo/2/char2idx.pkl"
IDX2CHAR_PATH = "/home/grad/Desktop/pietro/denovo/2/idx2char.pkl"
VOCAB_PATH = "/home/grad/Desktop/pietro/denovo/2/vocab.json"
# Sequence length handed to the generation functions; Config.MAX_LENGTH below holds the same value.
MAX_LENGTH = 140
TRAINING_FILE = "/home/grad/Desktop/pietro/denovo/s4-for-de-novo-drug-design/s4_loro/gen_mio/eval_out_fxr/train.smi"  # Training SMILES, read to build the novelty reference set

In [2]:
# ============================================================================
# CONDITIONAL GENERATION: SMARTS and/or SCAFFOLD -> SMILES
# ============================================================================
import os
import sys
import pickle
import json
import numpy as np
import tensorflow as tf
from rdkit import Chem
from rdkit.Chem import QED
import csv
import time
import re
from typing import List, Optional

# Reproducibility: seed NumPy's global RNG (np.random.choice draws the sampled
# tokens during generation) and TensorFlow's global RNG.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# GPU configuration: expose only the first GPU to TensorFlow, or report CPU mode.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.set_visible_devices(gpus[0], 'GPU')
        print(f"‚úì GPU attiva: {gpus[0]}")
    except RuntimeError as e:
        # NOTE(review): presumably raised when devices were already initialised
        # earlier in this kernel session -- confirm against the TensorFlow docs.
        print(f"‚ö† Errore GPU: {e}")
else:
    print("‚Üí CPU mode attivo")

# ============================================================================
# CONFIGURATION
# ============================================================================
class Config:
    """Hyperparameters read by the custom layers and by build_improved_model."""
    EMBED_DIM = 1024  # embedding width; also passed as key_dim to MultiHeadAttention
    TRANSFORMER_LAYERS = 6  # number of ImprovedTransformerBlock layers stacked
    TRANSFORMER_HEADS = 6  # attention heads per block
    FF_DIM = 512  # hidden width of the feed-forward sub-network
    DROPOUT_RATE = 0.1  # dropout after the positional encoding and inside each block
    L2_REG = 1e-4  # L2 kernel-regularisation factor for attention and Dense kernels
    MAX_LENGTH = 140  # Use the same value as in training
    GRADIENT_CLIP = 1.0  # passed as clipnorm to the Adam optimizer

# Module-level instance read by the layer/model definitions below.
config = Config()

# ============================================================================
# CUSTOM OBJECTS (same definitions as in training)
# ============================================================================
from tensorflow.keras.layers import Layer, Dense, Dropout, Embedding, LayerNormalization, MultiHeadAttention, Input
from tensorflow.keras.models import load_model, Model

def smoothed_loss(y_true, y_pred):
    """Masked sparse softmax cross-entropy computed from logits.

    Positions where ``y_true`` equals 0 are excluded; the summed loss of the
    remaining positions is divided by their count (plus 1e-9 so an all-zero
    target does not divide by zero). Despite the name, this definition applies
    no label smoothing.

    Args:
        y_true: target token ids.
        y_pred: unnormalised logits with one extra trailing class axis.

    Returns:
        Scalar tensor with the mean loss over the non-zero target positions.
    """
    labels = tf.cast(y_true, tf.int32)
    keep = tf.cast(tf.math.not_equal(y_true, 0), tf.float32)
    per_token = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=y_pred)
    masked_total = tf.reduce_sum(per_token * keep)
    kept_count = tf.reduce_sum(keep)
    return masked_total / (kept_count + 1e-9)

class DynamicPositionalEncoding(Layer):
    """Adds a fixed (non-trainable) sinusoidal positional encoding to its input.

    The table is computed once in ``build`` for ``config.MAX_LENGTH`` positions
    and ``embed_dim`` channels; ``call`` adds its first ``seq_len`` rows, where
    ``seq_len`` is the size of axis 1 of the input.

    Args:
        embed_dim: number of channels of the encoding table.
    """
    def __init__(self, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim

    def build(self, input_shape):
        """Precompute the encoding table; ``input_shape`` itself is not used."""
        # The table length comes from the global config, not from input_shape.
        max_len = config.MAX_LENGTH
        pos = np.arange(max_len)[:, np.newaxis]
        i = np.arange(self.embed_dim)[np.newaxis, :]
        # Channels 2k and 2k+1 share one rate: 1 / 10000^(2k / embed_dim).
        angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(self.embed_dim))
        angle_rads = pos * angle_rates
        # Even channels hold the sine, odd channels the cosine (written in place).
        angle_rads[:, 0::2] = tf.math.sin(angle_rads[:, 0::2])
        angle_rads[:, 1::2] = tf.math.cos(angle_rads[:, 1::2])
        # Leading axis of size 1 so the table broadcasts over the batch.
        self.pos_encoding = tf.cast(angle_rads[np.newaxis, ...], dtype=tf.float32)

    def call(self, inputs):
        """Return ``inputs`` plus the encoding rows for its first ``seq_len`` positions."""
        seq_len = tf.shape(inputs)[1]
        return inputs + self.pos_encoding[:, :seq_len, :]

    def get_config(self):
        """Return the base layer config extended with ``embed_dim``."""
        return {**super().get_config(), "embed_dim": self.embed_dim}

class ImprovedTransformerBlock(Layer):
    """Causally masked self-attention followed by a feed-forward network.

    Each sub-layer output goes through dropout, is added to its input, and the
    sum is layer-normalised (normalisation is applied after the residual add).

    Args:
        embed_dim: output width of the feed-forward network; also used as the
            attention ``key_dim``.
        num_heads: number of attention heads.
        ffn_dim: hidden width of the feed-forward network.
        rate: dropout rate used for both dropout layers.
    """
    def __init__(self, embed_dim, num_heads, ffn_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.rate = rate

        # key_dim is set to embed_dim itself, not to embed_dim // num_heads.
        self.mha = MultiHeadAttention(
            num_heads=num_heads, 
            key_dim=embed_dim, 
            kernel_regularizer=tf.keras.regularizers.l2(config.L2_REG)
        )
        self.ffn = tf.keras.Sequential([
            Dense(ffn_dim, activation="gelu", kernel_regularizer=tf.keras.regularizers.l2(config.L2_REG)),
            Dense(embed_dim, kernel_regularizer=tf.keras.regularizers.l2(config.L2_REG))
        ])
        self.ln1 = LayerNormalization(epsilon=1e-6)
        self.ln2 = LayerNormalization(epsilon=1e-6)
        self.d1 = Dropout(rate)
        self.d2 = Dropout(rate)

    def call(self, inputs, training=False, mask=None):
        """Apply the block to ``inputs``.

        ``training`` is forwarded to the two dropout layers only. ``mask`` is
        accepted but not used anywhere in this body.
        """
        seq_len = tf.shape(inputs)[1]
        # Lower-triangular matrix of ones: position t attends to positions <= t.
        causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        attn_output = self.mha(inputs, inputs, attention_mask=causal_mask)
        out1 = self.ln1(inputs + self.d1(attn_output, training=training))
        ffn_output = self.ffn(out1)
        return self.ln2(out1 + self.d2(ffn_output, training=training))

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        return {**super().get_config(), "embed_dim": self.embed_dim, "num_heads": self.num_heads, "ffn_dim": self.ffn_dim, "rate": self.rate}

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule with a linear warm-up followed by 1/sqrt(step) decay.

    With ``s = step + 1`` the rate is
    ``rsqrt(embed_dim) * min(rsqrt(s), s * warmup_steps ** -1.5)``.

    Args:
        embed_dim: model width used to scale the rate.
        warmup_steps: number of steps over which the rate grows linearly.
    """

    def __init__(self, embed_dim, warmup_steps=4000):
        super().__init__()
        # Both values are stored as float32 tensors so __call__ stays in float32.
        self.embed_dim = tf.cast(embed_dim, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step):
        """Return the learning rate for optimizer step ``step``."""
        shifted_step = tf.cast(step, tf.float32) + 1.0  # +1 avoids rsqrt(0) at step 0
        decay_term = tf.math.rsqrt(shifted_step)
        warmup_term = shifted_step * (self.warmup_steps ** -1.5)
        width_scale = tf.math.rsqrt(self.embed_dim)
        return width_scale * tf.math.minimum(decay_term, warmup_term)

    def get_config(self):
        """Return the constructor arguments as plain Python floats."""
        return {
            field: float(getattr(self, field).numpy())
            for field in ("embed_dim", "warmup_steps")
        }

def build_improved_model(vocab_size: int) -> Model:
    """Build and compile the decoder-only transformer described by ``config``.

    Pipeline: token embedding -> sinusoidal positional encoding -> dropout ->
    ``config.TRANSFORMER_LAYERS`` causal transformer blocks -> Dense projection
    to ``vocab_size`` logits per position.

    Args:
        vocab_size: number of tokens; used as embedding input size and as the
            width of the output projection.

    Returns:
        A Keras ``Model`` compiled with Adam (``CustomSchedule`` learning rate,
        gradient-norm clipping) and ``smoothed_loss``.
    """
    token_ids = Input(shape=(config.MAX_LENGTH,))

    hidden = Embedding(vocab_size, config.EMBED_DIM, mask_zero=False)(token_ids)
    hidden = DynamicPositionalEncoding(config.EMBED_DIM)(hidden)
    hidden = Dropout(config.DROPOUT_RATE)(hidden)

    # One fresh block instance per layer; the blocks share no weights.
    for _layer_index in range(config.TRANSFORMER_LAYERS):
        block = ImprovedTransformerBlock(
            config.EMBED_DIM,
            config.TRANSFORMER_HEADS,
            config.FF_DIM,
            rate=config.DROPOUT_RATE,
        )
        hidden = block(hidden)

    logits = Dense(vocab_size)(hidden)

    optimizer = tf.keras.optimizers.Adam(
        learning_rate=CustomSchedule(config.EMBED_DIM),
        clipnorm=config.GRADIENT_CLIP,
    )

    network = Model(token_ids, logits)
    network.compile(optimizer=optimizer, loss=smoothed_loss)
    return network

# ============================================================================
# GENERATION FUNCTIONS
# ============================================================================
def robust_tokenize(smiles: str) -> list:
    """Split a SMILES string into the tokens recognised by the regex below.

    Alternatives are tried in this order: a bracket expression holding 1-6
    characters, a fixed list of multi-character symbols (``Br``, ``Cl``, ...),
    then any single character of the allowed set. Characters matching none of
    the alternatives are skipped silently.

    Args:
        smiles: the string to tokenize.

    Returns:
        List of token strings in order of appearance.
    """
    token_regex = re.compile(
        r"(\[[^\[\]]{1,6}\]|"
        r"Br|Cl|Si|Na|Mg|Mn|Ca|Fe|Zn|Se|Li|K|Al|B|R[0-9]|r[0-9]|a[0-9]|"
        r"[A-Za-z0-9@+\-\\\/\(\)=#$.])"
    )
    # The single capture group spans the whole pattern, so group(0) is the token.
    return [found.group(0) for found in token_regex.finditer(smiles)]

from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit import Chem

def user_scaffold_to_training_murcko_smiles(scaffold_smiles: str):
    """
    Convert a SMILES string to the Murcko-scaffold SMILES form used in training.

    Steps (replicating the training pipeline):
    1) parse + sanitize
    2) Kekulize(clearAromaticFlags=True) on the WHOLE molecule (best effort)
    3) MolToSmiles(canonical=True, isomericSmiles=False)
    4) re-parse that SMILES (as done in training)
    5) MurckoScaffold.GetScaffoldForMol
    6) MolToSmiles(canonical=True, isomericSmiles=False)

    Args:
        scaffold_smiles: SMILES of a scaffold or of a full molecule.

    Returns:
        The scaffold SMILES string, or None when the input is empty, cannot be
        parsed, or any RDKit step fails.
    """
    if not scaffold_smiles:
        return None

    try:
        mol = Chem.MolFromSmiles(scaffold_smiles, sanitize=True)
        if not mol:
            return None

        # Same normalisation as the training-side validate/fix step. A
        # kekulization failure is tolerated: the molecule is then used as is.
        try:
            Chem.Kekulize(mol, clearAromaticFlags=True)
        except Exception:
            pass

        fixed = Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False)
        if not fixed:
            return None

        # As in training: reload the molecule from the normalised SMILES.
        mol2 = Chem.MolFromSmiles(fixed, sanitize=True)
        if not mol2:
            return None

        scaf = MurckoScaffold.GetScaffoldForMol(mol2)
        if not scaf:
            return None

        return Chem.MolToSmiles(scaf, canonical=True, isomericSmiles=False)
    except Exception:
        # `Exception` rather than a bare except: a kernel interrupt
        # (KeyboardInterrupt) must still stop long scaffold-filtering loops.
        return None


def encode_prompt(
    char2idx: dict,
    smart_names: Optional[List[str]] = None,
    scaffold_smiles: Optional[str] = None
) -> List[int]:
    """
    Encode SMARTS names and/or a scaffold into the token indices of a prompt.

    Layout: ``<START>``, the SMARTS names (if any), ``<SEP>``; then, only when a
    scaffold is given and converts to a non-empty Murcko SMILES, the scaffold
    tokens followed by a second ``<SEP>``. The first ``<SEP>`` is always present.

    Args:
        char2idx: token -> index dictionary (must contain '<PAD>')
        smart_names: list of SMARTS names (e.g. ["Alcohol", "Ketone"])
        scaffold_smiles: SMILES string of the scaffold

    Returns:
        List of token indices for the prompt. Tokens missing from ``char2idx``
        map to the '<UNK>' index, or to the '<PAD>' index when '<UNK>' is absent.
    """
    prompt_tokens = ['<START>', *(smart_names or []), '<SEP>']

    murcko_smiles = (
        user_scaffold_to_training_murcko_smiles(scaffold_smiles)
        if scaffold_smiles
        else None
    )
    if murcko_smiles:
        prompt_tokens += robust_tokenize(murcko_smiles)
        prompt_tokens.append('<SEP>')

    # Convert tokens to indices, falling back for out-of-vocabulary tokens.
    fallback_index = char2idx.get('<UNK>', char2idx['<PAD>'])
    return [char2idx.get(token, fallback_index) for token in prompt_tokens]

def generate_smiles_batch_conditioned(model, char2idx, idx2char, max_length, batch_size=64, temperature=1.0,
                                      prompt_indices: Optional[List[int]] = None):
    """Sample a batch of token sequences autoregressively and decode them to strings.

    Every row starts from the same prompt (or from '<START>' alone when no
    prompt is given). At each step the whole padded batch is fed to ``model``
    and one token per unfinished row is drawn from the temperature-scaled
    softmax of the logits at the previous position.

    Args:
        model: callable ``model(input_seqs, training=False)`` returning logits
            indexed as ``[batch, position, vocab]``.
        char2idx: token -> index dictionary; must contain '<PAD>', '<END>' and
            (for unconditional generation) '<START>'.
        idx2char: index -> token dictionary used for decoding.
        max_length: width of the padded sequence fed to the model.
        batch_size: number of sequences sampled in parallel.
        temperature: softmax temperature; must be > 0.
        prompt_indices: token indices shared by all rows; None or an empty list
            means unconditional generation.

    Returns:
        List of ``batch_size`` raw strings: the tokens after the last '<SEP>'
        (or all tokens except '<START>' when there is no '<SEP>'), cut at the
        first '<END>'. The strings are not validated as SMILES.

    Raises:
        ValueError: if ``temperature`` is not > 0, or the prompt does not leave
            room for at least one generated token.
    """
    # Reject zero/negative/NaN temperatures up front instead of dividing by them.
    if not temperature > 0:
        raise ValueError(f"temperature must be > 0 (got {temperature})")

    PAD_IDX = char2idx['<PAD>']
    END_IDX = char2idx['<END>']

    input_seqs = np.full((batch_size, max_length), PAD_IDX, dtype=np.int32)

    if not prompt_indices:
        # No prompt (None or empty): start from <START> only. An empty prompt
        # would otherwise make the first step read logits at position -1.
        input_seqs[:, 0] = char2idx['<START>']
        start_step = 1
    else:
        prompt_len = len(prompt_indices)
        if prompt_len >= max_length:
            raise ValueError(f"Prompt troppo lungo ({prompt_len}) rispetto a max_length ({max_length})")
        # Broadcasting writes the same prompt into every row.
        input_seqs[:, :prompt_len] = prompt_indices
        start_step = prompt_len

    finished = np.zeros(batch_size, dtype=bool)

    for t in range(start_step, max_length):
        # Logits at position t-1 predict the token at position t.
        logits = model(input_seqs, training=False)[:, t-1, :]
        step_probs = tf.nn.softmax(logits / temperature).numpy()

        # Rows are visited in index order so the NumPy RNG stream is consumed
        # in the same order as one draw per unfinished row.
        for i in np.flatnonzero(~finished):
            row_probs = step_probs[i]
            total = np.sum(row_probs)
            # A degenerate (near-zero or NaN/inf) distribution cannot be
            # sampled; the row is closed instead of crashing the whole batch.
            if not np.isfinite(total) or total < 1e-6:
                finished[i] = True
                continue
            sampled = np.random.choice(len(row_probs), p=row_probs)
            input_seqs[i, t] = sampled
            if sampled == END_IDX:
                finished[i] = True

        if finished.all():
            break

    raw_smiles_list = []
    for seq in input_seqs:
        tokens = [idx2char.get(idx, '') for idx in seq if idx != PAD_IDX]
        sep_indices = [j for j, x in enumerate(tokens) if x == '<SEP>']

        if sep_indices:
            # Everything up to and including the last <SEP> is prompt material.
            target_tokens = tokens[sep_indices[-1] + 1:]
        else:
            target_tokens = [tok for tok in tokens if tok != '<START>']

        raw_smiles_list.append("".join(target_tokens).split('<END>')[0])

    return raw_smiles_list

def generate_from_prompt(
    model,
    char2idx: dict,
    idx2char: dict,
    max_length: int,
    smart_names: Optional[List[str]] = None,
    scaffold_smiles: Optional[str] = None,
    batch_size: int = 1,
    temperature: float = 1.0
) -> List[str]:
    """
    Generate SMILES conditioned on SMARTS names and/or a scaffold.

    When both ``smart_names`` and ``scaffold_smiles`` are empty or None no
    prompt is built and generation is unconditional.

    Args:
        model: trained model
        char2idx: token -> index dictionary
        idx2char: index -> token dictionary
        max_length: maximum sequence length
        smart_names: list of SMARTS names (optional)
        scaffold_smiles: scaffold SMILES string (optional)
        batch_size: number of molecules to generate
        temperature: sampling temperature

    Returns:
        List of generated SMILES strings (raw, not validated)
    """
    is_conditioned = bool(smart_names) or bool(scaffold_smiles)
    prompt_indices = (
        encode_prompt(char2idx, smart_names, scaffold_smiles)
        if is_conditioned
        else None
    )

    return generate_smiles_batch_conditioned(
        model,
        char2idx,
        idx2char,
        max_length,
        batch_size=batch_size,
        temperature=temperature,
        prompt_indices=prompt_indices,
    )



# Load the vocabulary and the token <-> index dictionaries.
# NOTE(security): pickle.load can execute arbitrary code; only point
# CHAR2IDX_PATH / IDX2CHAR_PATH at files you created yourself.
try:
    with open(CHAR2IDX_PATH, "rb") as f:
        char2idx = pickle.load(f)
    with open(IDX2CHAR_PATH, "rb") as f:
        idx2char = pickle.load(f)
    with open(VOCAB_PATH, "r") as f:
        vocab = json.load(f)
    print(f"‚úì Vocabolario caricato: {len(vocab)} token")
except Exception as e:
    # Nothing below can run without the dictionaries: report and stop.
    print(f"‚ùå Errore caricamento dizionari: {e}")
    sys.exit(1)

# Load the training set used for the novelty check. Each parsable line is
# stored as canonical, non-isomeric SMILES -- the same form given to generated
# molecules before they are compared against this set.
training_smiles_set = set()
if os.path.exists(TRAINING_FILE):
    print(f"‚Üí Caricamento training set per novelty: {TRAINING_FILE}")
    with open(TRAINING_FILE, "r") as f:
        for line in f:
            smi = line.strip()
            if not smi:
                continue
            try:
                mol = Chem.MolFromSmiles(smi)
                if mol:
                    canon_smi = Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False)
                    training_smiles_set.add(canon_smi)
            except Exception:
                # Skip lines RDKit cannot handle. `Exception` (not a bare
                # except) so that a kernel interrupt still stops the loop.
                continue
    print(f"‚úì {len(training_smiles_set)} SMILES di training caricati")
else:
    # The set stays empty, so every generated molecule will count as novel.
    print(f"‚ö† File di training non trovato, novelty disabilitata")

# Load the model. The saved file references these custom classes/functions by
# name, so they are handed to load_model through `custom_objects`.
custom_objects = {
    "DynamicPositionalEncoding": DynamicPositionalEncoding,
    "ImprovedTransformerBlock": ImprovedTransformerBlock,
    "CustomSchedule": CustomSchedule,
    "smoothed_loss": smoothed_loss,
}

try:
    print(f"‚Üí Caricamento modello da: {MODEL_PATH}")
    model = load_model(MODEL_PATH, custom_objects=custom_objects)
    # Re-compile with a plain 'adam' optimizer and smoothed_loss; generation
    # below only calls the model for inference.
    model.compile(optimizer='adam', loss=smoothed_loss)
    print("‚úì Modello caricato e pronto")
except Exception as e:
    # Generation is impossible without the model: report and stop.
    print(f"‚ùå Errore caricamento modello: {e}")
    sys.exit(1)

# ============================================================================
# BATCH GENERATION WITH REPORT (MODIFIED: adds scaffold filtering)
# ============================================================================
def evaluate_and_save_batches_conditional(
    model, char2idx, idx2char, max_length, training_smiles_set,
    smart_names: Optional[List[str]] = None,
    scaffold_smiles: Optional[str] = None,
    out_csv_path: str = "generated_conditional.csv",
    num_batches: int = 10,
    batch_size: int = 64,
    temperature: float = 1.0,
    print_progress: bool = True
):
    """
    Generate molecules in batches, validate and deduplicate them, optionally
    keep only those whose Murcko scaffold equals the requested one, compute
    metrics and write the novel survivors to a CSV file.

    Args:
        model, char2idx, idx2char, max_length: forwarded to generate_from_prompt.
        training_smiles_set: canonical SMILES of the training set; generated
            molecules found in it are not novel and are not saved.
        smart_names: optional SMARTS names used for conditioning.
        scaffold_smiles: optional scaffold; when given, only molecules whose
            Murcko scaffold string equals the converted target are kept.
        out_csv_path: destination CSV (one "SMILES" column); missing parent
            directories are created.
        num_batches: number of generation calls.
        batch_size: molecules per generation call.
        temperature: sampling temperature.
        print_progress: print status messages and the final report.

    Returns:
        dict with keys total_input, valid_count, validity_rate, unique_count,
        uniqueness_rate, scaffold_match_rate, novelty_rate, avg_qed, avg_sa
        and output_file. QED/SA averages and novelty are computed on the
        scaffold-filtered molecules.
    """
    all_raw_generated = []
    start_time = time.time()

    # Progress bar, with a plain per-batch print as fallback. An explicit flag
    # is needed: both `range` and tqdm objects define __len__, so probing the
    # iterator for it cannot tell them apart.
    try:
        from tqdm import tqdm
        iterator = tqdm(range(num_batches), desc="Generazione")
        has_progress_bar = True
    except ImportError:
        iterator = range(num_batches)
        has_progress_bar = False
        print("Installa tqdm per vedere la progress bar: pip install tqdm")

    # Generate the batches
    for b in iterator:
        generated_raw = generate_from_prompt(
            model, char2idx, idx2char, max_length,
            smart_names=smart_names,
            scaffold_smiles=scaffold_smiles,
            batch_size=batch_size,
            temperature=temperature
        )
        all_raw_generated.extend(generated_raw)
        if not has_progress_bar:
            print(f"Batch {b+1}/{num_batches}: {len(generated_raw)} SMILES")

    total_inference_time = time.time() - start_time

    # Post-processing: keep parsable molecules in canonical, non-isomeric form
    if print_progress:
        print("\n‚Üí Post-processing (validazione, canonicalization)...")

    valid_smiles = []
    for smi in all_raw_generated:
        if not smi:
            continue
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol:
                valid_smiles.append(Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False))
        except Exception:
            # An unparsable string simply does not count as valid.
            continue

    # Deduplicate while keeping first-seen order, so the output file is
    # reproducible for a fixed seed (set iteration order is not).
    unique_smiles = list(dict.fromkeys(valid_smiles))

    # --- Scaffold filtering ---
    scaffold_match_rate = 0.0
    filtered_smiles = unique_smiles  # default: without a scaffold keep everything

    if scaffold_smiles:
        if print_progress:
            print("‚Üí Verifica corrispondenza Scaffold...")

        # 1. Canonical Murcko form of the requested scaffold
        target_scaffold_canon = user_scaffold_to_training_murcko_smiles(scaffold_smiles)

        if target_scaffold_canon:
            # 2. Extract each molecule's scaffold with the same function and
            # 3. keep the molecule on exact string identity.
            matching_smiles = [
                smi for smi in unique_smiles
                if user_scaffold_to_training_murcko_smiles(smi) == target_scaffold_canon
            ]

            if unique_smiles:
                scaffold_match_rate = len(matching_smiles) / len(unique_smiles)

            # From here on only the matching molecules are analysed and saved.
            filtered_smiles = matching_smiles

            if print_progress:
                print(f"‚úì Target Scaffold: {target_scaffold_canon}")
                print(f"‚úì Match Rate: {len(filtered_smiles)}/{len(unique_smiles)} ({scaffold_match_rate*100:.2f}%)")
        else:
            if print_progress:
                print("‚ö† Attenzione: Impossibile calcolare Murcko Scaffold dall'input fornito. Salto il filtro.")

    # Property calculation (on the FILTERED molecules)
    qed_list, sa_list = [], []
    try:
        from rdkit.Chem import RDConfig
        sa_score_dir = os.path.join(RDConfig.RDContribDir, 'SA_Score')
        # Avoid growing sys.path on every call of this function.
        if sa_score_dir not in sys.path:
            sys.path.append(sa_score_dir)
        import sascorer
        HAS_SA = True
    except Exception:
        HAS_SA = False
        if print_progress:
            print("‚ö† SA_Score non disponibile")

    for smi in filtered_smiles:
        mol = Chem.MolFromSmiles(smi)
        if not mol:
            continue
        # Each score is best effort: a failure skips that score for this
        # molecule and leaves it out of the corresponding average.
        try:
            qed_list.append(QED.qed(mol))
        except Exception:
            pass
        if HAS_SA:
            try:
                sa_list.append(sascorer.calculateScore(mol))
            except Exception:
                pass

    # Save the results (ONLY novel AND filtered molecules)
    only_novel_smiles = [smi for smi in filtered_smiles if smi not in training_smiles_set]

    os.makedirs(os.path.dirname(out_csv_path) or '.', exist_ok=True)
    with open(out_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["SMILES"])  # header row; remove it for a raw file
        for smi in only_novel_smiles:
            writer.writerow([smi])

    # Final metrics
    total_generated_count = len(all_raw_generated)
    validity = len(valid_smiles) / total_generated_count if total_generated_count else 0
    uniqueness = len(unique_smiles) / len(valid_smiles) if valid_smiles else 0

    # Novelty is computed on the filtered molecules
    novelty_rate = len(only_novel_smiles) / len(filtered_smiles) if filtered_smiles else 0

    avg_qed = np.mean(qed_list) if qed_list else 0
    avg_sa = np.mean(sa_list) if sa_list else 0

    time_per_batch = total_inference_time / num_batches if num_batches > 0 else 0
    time_per_molecule = total_inference_time / total_generated_count if total_generated_count else 0

    # Description of the conditioning for the report title
    condition_desc = []
    if smart_names:
        condition_desc.append(f"SMARTS={smart_names}")
    if scaffold_smiles:
        condition_desc.append(f"SCAFFOLD={scaffold_smiles}")
    condition_str = " | ".join(condition_desc) if condition_desc else "Incondizionata"

    scaffold_info_str = ""
    if scaffold_smiles:
        scaffold_info_str = f"  Scaffold Match Rate:      {scaffold_match_rate*100:.2f}% (su {len(unique_smiles)} unici)"

    if print_progress:
        print(f"""
{'='*70}
üìä GENERATION REPORT - {condition_str}
{'='*70}
  Total generated (raw):    {total_generated_count:,}
  Valid SMILES:             {len(valid_smiles):,} ({validity*100:.2f}%)
  Unique SMILES:            {len(unique_smiles):,} (uniqueness: {uniqueness*100:.2f}%)
{scaffold_info_str}
  Saved (Filtered & Novel): {len(only_novel_smiles):,}
  Novelty (on filtered):    {novelty_rate*100:.2f}%
  Average QED (filtered):   {avg_qed:.4f}
  Average SA (filtered):    {avg_sa:.4f}
  
‚è±Ô∏è  INFERENCE TIME
  Total time:               {total_inference_time:.2f} seconds
  Time per batch:           {time_per_batch:.3f} seconds
  Time per molecule:        {time_per_molecule*1000:.2f} ms
{'='*70}
""")

    return {
        "total_input": total_generated_count,
        "valid_count": len(valid_smiles),
        "validity_rate": validity,
        "unique_count": len(unique_smiles),
        "uniqueness_rate": uniqueness,
        "scaffold_match_rate": scaffold_match_rate,
        "novelty_rate": novelty_rate,
        "avg_qed": avg_qed,
        "avg_sa": avg_sa,
        "output_file": out_csv_path
    }



2026-02-16 12:06:13.218438: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-16 12:06:13.232665: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1771239973.249867 1255999 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1771239973.255097 1255999 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1771239973.268343 1255999 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

‚úì GPU attiva: PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')
‚úì Vocabolario caricato: 370 token
‚Üí Caricamento training set per novelty: /home/grad/Desktop/pietro/denovo/s4-for-de-novo-drug-design/s4_loro/gen_mio/eval_out_fxr/train.smi
‚úì 702 SMILES di training caricati
‚Üí Caricamento modello da: /home/grad/Desktop/pietro/denovo/2/final_hybrid_model.keras


I0000 00:00:1771239975.945674 1255999 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9025 MB memory:  -> device: 0, name: NVIDIA RTX A2000 12GB, pci bus id: 0000:65:00.0, compute capability: 8.6


‚úì Modello caricato e pronto


In [None]:
# ============================================================================
# INTERACTIVE USAGE EXAMPLES
# ============================================================================
if __name__ == "__main__":
    def _ask_positive_number(prompt, cast, default=None):
        """Ask until the reply parses with `cast` and is > 0; an empty reply returns `default` when one is given."""
        while True:
            reply = input(prompt).strip()
            if not reply and default is not None:
                return default
            try:
                value = cast(reply)
            except ValueError:
                print("Invalid number, please try again.")
                continue
            if value > 0:
                return value
            print("Value must be greater than 0, please try again.")

    def _ask_smarts_names(prompt):
        """Read a comma-separated list of SMARTS names; an empty reply returns None."""
        reply = input(prompt).strip()
        return [name.strip() for name in reply.split(',')] if reply else None

    print("\n" + "="*70)
    print("CONDITIONAL GENERATION: SMARTS and/or SCAFFOLD ‚Üí SMILES")
    print("="*70)
    from rdkit import RDLogger
    # Silence RDKit parse errors: invalid generated SMILES are expected here.
    RDLogger.DisableLog('rdApp.*')
    while True:
        print("\n--- Menu ---")
        print("1. Generate from SMARTS")
        print("2. Generate from SCAFFOLD")
        print("3. Generate from SMARTS + SCAFFOLD")
        print("4. Unconditional generation")
        print("5. BATCH generation with full report")
        print("0. Exit")

        choice = input("\nChoose an option (0-5): ").strip()

        if choice == "0":
            print("Exiting...")
            break

        # Option 5: Batch generation with report
        if choice == "5":
            print("\n--- Multi-Batch Generation with Report ---")
            batch_size = _ask_positive_number("Molecules per batch (default 64): ", int, 64)
            num_batches = _ask_positive_number("Number of batches (default 10): ", int, 10)
            temperature = _ask_positive_number("Temperature (default 1.0): ", float, 1.0)

            smart_names = None
            scaffold_smiles = None

            # Ask whether to use conditioning. The prompt advertises y/n, so
            # "y"/"yes" must be accepted; the Italian "s"/"si" is kept as well.
            use_condition = input("Use conditioning? (y/n, default n): ").strip().lower()
            if use_condition in ('y', 'yes', 's', 'si'):
                cond_type = input("Type (smarts/scaffold/both): ").strip().lower()
                if cond_type in ['smarts', 'both']:
                    smart_names = _ask_smarts_names("Enter SMARTS names separated by comma: ")
                if cond_type in ['scaffold', 'both']:
                    scaffold_smiles = input("Enter scaffold SMILES: ").strip()

            output_file = input("Output file path (default ./generated_batch.csv): ").strip() or "generated_batch.csv"

            # Run batch generation
            metrics = evaluate_and_save_batches_conditional(
                model, char2idx, idx2char, MAX_LENGTH, training_smiles_set,
                smart_names=smart_names,
                scaffold_smiles=scaffold_smiles,
                out_csv_path=output_file,
                num_batches=num_batches,
                batch_size=batch_size,
                temperature=temperature
            )

            print(f"\n‚úì Generation completed!")
            continue

        # Reject unknown options before asking any further question.
        if choice not in ("1", "2", "3", "4"):
            print("Invalid option!")
            continue

        # Options 1-4: Single/interactive generation
        batch_size = _ask_positive_number("Number of molecules to generate (1-20): ", int)
        temperature = _ask_positive_number("Temperature (0.1-2.0, default 1.0): ", float, 1.0)

        smart_names = None
        scaffold_smiles = None

        if choice in ("1", "3"):
            smart_names = _ask_smarts_names("Enter SMARTS names separated by comma (e.g.: Alcohol,Ketone): ")
        if choice in ("2", "3"):
            scaffold_smiles = input("Enter scaffold SMILES (e.g.: c1ccccc1): ").strip()
        # Option 4 (unconditional) needs no further input.

        # Generate
        print("\n‚Üí Generating...")
        generated_smiles = generate_from_prompt(
            model, char2idx, idx2char, MAX_LENGTH,
            smart_names=smart_names,
            scaffold_smiles=scaffold_smiles,
            batch_size=batch_size,
            temperature=temperature
        )

        # Print results
        print("\n‚úì Results:")
        for i, smi in enumerate(generated_smiles, 1):
            mol = Chem.MolFromSmiles(smi)
            if mol:
                print(f"  {i}. {smi} (QED: {QED.qed(mol):.3f})")
            else:
                print(f"  {i}. {smi} (INVALID)")



CONDITIONAL GENERATION: SMARTS and/or SCAFFOLD ‚Üí SMILES

--- Menu ---
1. Generate from SMARTS
2. Generate from SCAFFOLD
3. Generate from SMARTS + SCAFFOLD
4. Unconditional generation
5. BATCH generation with full report
0. Exit



Choose an option (0-5):  5



--- Multi-Batch Generation with Report ---


Molecules per batch (default 64):  10
Number of batches (default 10):  10
Temperature (default 1.0):  1
Use conditioning? (y/n, default n):  n
Output file path (default ./generated_batch.csv):  ok.csv


Generazione:  20%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè                        | 2/10 [00:20<01:20, 10.09s/it]