In [1]:
# Pre-trained artefacts: model weights, pickled token<->index mappings and the JSON vocabulary.
model_path = "/home/grad/Desktop/pietro/denovo/new/attachments_3/2/diff_scaffold.h5"
char2idx_path = "/home/grad/Desktop/pietro/denovo/new/attachments_3/2/char2idx.pkl"
idx2char_path = "/home/grad/Desktop/pietro/denovo/new/attachments_3/2/idx2char.pkl"
vocab_path    = "/home/grad/Desktop/pietro/denovo/new/attachments_3/2/vocab.json"
# Destination of the model saved after the last fine-tuning epoch.
model_out_path = "/home/grad/Desktop/pietro/denovo/new/risultati/fine/gba/model_final_fxar_pp.h5"

In [2]:
# --- CONFIGURATION ---
class AdvancedConfig:
    """Hyper-parameters and file locations read by the fine-tuning code in this notebook."""
    SMILES_FILE = '/home/grad/Desktop/pietro/denovo/s4-for-de-novo-drug-design/datasets/fxar/fine.txt'   # SMILES file used for fine-tuning
    BATCH_SIZE = 10
    EPOCHS = 50                  # Number of fine-tuning epochs
    EMBED_DIM = 94             # Must match the value used by the pre-trained model
    TRANSFORMER_HEADS = 4
    TRANSFORMER_LAYERS = 4
    FF_DIM = 300
    VALID_RATIO = 0.1            # Fraction of the data held out for validation
    TEMPERATURE = 1
    TEMPERATURE_DECAY = 0.97
    GEN_NUM = 5
    WARMUP_EPOCHS = 10
    MAX_RANDOMIZATIONS = 10
    MAX_LENGTH = 94              # Maximum sequence length (may be updated later)
    PRINT_EVERY = 5              # No longer used for SMILES generation
    DROPOUT_RATE = 0.15
    GRADIENT_CLIP = 1.0
    L2_REG = 1e-5
    LOSS_STABILITY_THRESHOLD = 0.01
    CURRICULUM_START_COMPLEXITY = 0
    CURRICULUM_COMPLEXITY_STEP = 1
    AUGMENT_PROB = 0.1

In [3]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script di fine-tuning aggiornato:
  - Suddivide il dataset in train e validation.
  - Valuta la loss sul validation set ad ogni epoca.
  - Non genera SMILES ogni 5 epoche.
"""

import numpy as np
import random
import re
import logging
import pickle
import json
import tensorflow as tf
from tensorflow.keras.layers import Layer, Embedding, Input, LayerNormalization, MultiHeadAttention, Dropout, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import Callback
from rdkit import Chem
from rdkit.Chem import MolFromSmiles, MolToSmiles
from threading import Lock
from sklearn.model_selection import train_test_split

# Logger configuration
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")


# Shared configuration instance read by the functions and classes below.
config = AdvancedConfig()

# --- PREPROCESSING FUNCTIONS ---
def validate_and_fix_smiles(smiles: str) -> str:
    """
    Parse, kekulize and re-serialise a SMILES string.

    Args:
        smiles: Raw SMILES string.

    Returns:
        The canonical, non-isomeric SMILES of the kekulized molecule
        (aromatic flags are cleared before writing), or ``None`` when the
        string cannot be parsed, sanitized or kekulized.
    """
    try:
        molecule = Chem.MolFromSmiles(smiles, sanitize=True)
        if molecule is not None:
            try:
                Chem.Kekulize(molecule, clearAromaticFlags=True)
            except Exception as kekulize_error:
                logger.debug(f"Kekulization error in SMILES {smiles}: {kekulize_error}")
            else:
                # Only reached when kekulization succeeded.
                return Chem.MolToSmiles(molecule, canonical=True, isomericSmiles=False)
    except Exception as parse_error:
        logger.debug(f"Parsing/sanitization error in SMILES {smiles}: {parse_error}")
    # Every failure path (unparsable, kekulization error, unexpected exception) ends here.
    return None

def robust_tokenize(smiles: str) -> list:
    """
    Split a SMILES string into tokens, rejecting strings that cannot be tokenized safely.

    Args:
        smiles: SMILES string to tokenize.

    Returns:
        The list of tokens, or ``[]`` when
          * the string is empty,
          * the tokens do not reproduce the input exactly (some character,
            e.g. an unmatched bracket or an over-long bracket atom, is not
            covered by the pattern), or
          * RDKit cannot parse the string.

    Previously the RDKit validity check had no effect because the function
    returned the tokens regardless of its outcome, and characters skipped by
    ``re.findall`` were dropped silently, yielding tokens of a different
    string than the one passed in.
    """
    pattern = (
        r"(\[[^\[\]]{1,6}\]|"                 # bracket atoms
        r"Br|Cl|Si|Na|Mg|Mn|Ca|Fe|Zn|Se|Li|K|Al|B|"  # multi-character elements
        r"R[0-9]|r[0-9]|a[0-9]|"             # ring labels
        r"[A-Za-z0-9@+\-\\\/\(\)=#\$\.\%,])"  # single characters, including '%'
    )
    tokens = re.findall(pattern, smiles)
    if not tokens:
        return []
    # re.findall skips anything the pattern does not match; a lossy
    # tokenization would silently describe a different molecule.
    if ''.join(tokens) != smiles:
        logger.debug(f"Tokenization error: pattern does not cover all of {smiles!r}")
        return []
    try:
        if Chem.MolFromSmiles(smiles) is None:
            return []
    except Exception as e:
        logger.debug(f"Tokenization error: {e}")
        return []
    return tokens

def process_dataset(data: list) -> tuple:
    """
    Validate, canonicalise and tokenize a list of raw SMILES strings.

    Args:
        data: Raw SMILES strings.

    Returns:
        ``(processed, vocab, max_len)`` where ``processed`` is the list of
        token lists that passed validation and have between 3 and
        ``config.MAX_LENGTH - 2`` tokens, ``vocab`` is the three special
        tokens followed by the sorted set of tokens seen, and ``max_len`` is
        the 99th-percentile token count plus 2, capped at
        ``config.MAX_LENGTH`` (or ``config.MAX_LENGTH`` when nothing passed).
    """
    longest_allowed = config.MAX_LENGTH - 2  # leave room for <START> and <END>
    kept_sequences = []
    seen_tokens = set()
    for raw_smiles in data:
        canonical = validate_and_fix_smiles(raw_smiles)
        if canonical:
            pieces = robust_tokenize(canonical)
            if pieces and 3 <= len(pieces) <= longest_allowed:
                kept_sequences.append(pieces)
                seen_tokens.update(pieces)

    vocab = ['<PAD>', '<START>', '<END>']
    vocab.extend(sorted(seen_tokens))

    if kept_sequences:
        percentile_99 = int(np.percentile([len(seq) for seq in kept_sequences], 99))
        max_len = min(percentile_99 + 2, config.MAX_LENGTH)
    else:
        max_len = config.MAX_LENGTH

    logger.info(f"Processed SMILES: {len(kept_sequences)}/{len(data)}")
    logger.info(f"Unique tokens: {len(seen_tokens)}")
    logger.info(f"Max length: {max_len}")
    return kept_sequences, vocab, max_len

def randomize_smiles(smiles: str, num_versions: int = 3) -> list:
    """
    Produce randomized (non-canonical) SMILES strings for one molecule.

    Args:
        smiles: SMILES string of the molecule.
        num_versions: Number of randomized strings to attempt.

    Returns:
        A list with up to ``num_versions`` randomized SMILES; empty when the
        input cannot be parsed. Attempts that raise are logged and skipped.
    """
    variants = []
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return variants
    for _attempt in range(num_versions):
        try:
            candidate = Chem.MolToSmiles(mol, doRandom=True, canonical=False)
        except Exception as e:
            logger.debug(f"Randomization error: {e}")
        else:
            if candidate:
                variants.append(candidate)
    return variants

def compute_complexity_from_tokens(tokens: list) -> int:
    """
    Score a tokenized SMILES as (number of SSSR rings) + (number of branches).

    Args:
        tokens: Token list that joins back into a SMILES string.

    Returns:
        The integer complexity score, or ``float('inf')`` when the SMILES
        cannot be parsed or scoring raises.
    """
    smiles = ''.join(tokens)
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return float('inf')
        sssr = Chem.GetSSSR(mol)
        # Depending on the RDKit version GetSSSR returns either the ring count
        # or the sequence of rings. Adding a ring sequence to an int raises,
        # which the except below would turn into 'inf' for every molecule.
        num_rings = sssr if isinstance(sssr, int) else len(sssr)
        num_branches = smiles.count('(')  # every '(' opens one branch
        return num_rings + num_branches
    except Exception:
        return float('inf')

# --- Model components ---
class DynamicPositionalEncoding(Layer):
    """
    Adds a fixed sinusoidal positional encoding to its input.

    The encoding table is computed once in ``build`` from the static
    sequence length of the input shape and has no trainable weights.
    """

    def __init__(self, embed_dim, **kwargs):
        """Store the embedding width; extra kwargs go to the Keras ``Layer``."""
        super().__init__(**kwargs)
        self.embed_dim = embed_dim

    def build(self, input_shape):
        """Pre-compute the (1, seq_len, embed_dim) float32 encoding table."""
        seq_len = input_shape[1]
        positions = np.arange(seq_len)[:, np.newaxis]
        dims = np.arange(self.embed_dim)[np.newaxis, :]
        # Each sin/cos pair (2k, 2k+1) shares the same frequency.
        exponents = (2 * (dims // 2)) / np.float32(self.embed_dim)
        table = positions * (1 / np.power(10000, exponents))
        table[:, 0::2] = tf.math.sin(table[:, 0::2])  # even columns: sine
        table[:, 1::2] = tf.math.cos(table[:, 1::2])  # odd columns: cosine
        self.pos_encoding = tf.cast(table[np.newaxis, ...], dtype=tf.float32)

    def call(self, inputs):
        """Return ``inputs`` plus the encoding, sliced to the runtime sequence length."""
        runtime_len = tf.shape(inputs)[1]
        return inputs + self.pos_encoding[:, :runtime_len, :]

    def get_config(self):
        """Return the layer config including ``embed_dim`` for serialization."""
        layer_config = super().get_config()
        layer_config["embed_dim"] = self.embed_dim
        return layer_config

class ImprovedTransformerBlock(Layer):
    """
    Transformer block with causal self-attention followed by a two-layer
    feed-forward network; each sub-block has dropout, a residual connection
    and layer normalization applied after the residual sum.

    Kernel L2 regularization strength is read from the module-level
    ``config.L2_REG`` when the block is constructed.
    """
    def __init__(self, embed_dim, num_heads, ffn_dim, rate=0.1, **kwargs):
        """
        Args:
            embed_dim: Width of the block's input/output; also used as the
                per-head ``key_dim`` of the attention layer.
            num_heads: Number of attention heads.
            ffn_dim: Hidden width of the feed-forward network.
            rate: Dropout rate used inside attention and after each sub-block.
            **kwargs: Forwarded to the Keras ``Layer`` constructor.
        """
        super().__init__(**kwargs)
        # Kept so get_config() can serialize the constructor arguments.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.rate = rate
        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim,
            kernel_regularizer=tf.keras.regularizers.l2(config.L2_REG),
            dropout=rate
        )
        self.ffn = tf.keras.Sequential([
            Dense(ffn_dim, activation="gelu", kernel_regularizer=tf.keras.regularizers.l2(config.L2_REG)),
            Dense(embed_dim, kernel_regularizer=tf.keras.regularizers.l2(config.L2_REG))
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)
    
    def call(self, inputs, training=False):
        """Apply causal self-attention and the feed-forward network to ``inputs``."""
        seq_len = tf.shape(inputs)[1]
        # Lower-triangular matrix of ones: position i may attend to positions <= i only.
        causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        # NOTE(review): `training` is not passed explicitly to self.mha; confirm that
        # Keras propagates it so the attention dropout is active during training.
        attn_output = self.mha(inputs, inputs, attention_mask=causal_mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)
    
    def get_config(self):
        """Return the layer config including the constructor arguments."""
        base_config = super().get_config()
        base_config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ffn_dim": self.ffn_dim,
            "rate": self.rate
        })
        return base_config

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """
    Warm-up / inverse-square-root learning-rate schedule.

    lr(step) = embed_dim**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    """

    def __init__(self, embed_dim, warmup_steps=10000):
        """Store both hyper-parameters as float32 tensors."""
        super().__init__()
        self.embed_dim = tf.cast(embed_dim, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step):
        """Return the learning rate for optimizer step ``step``."""
        # The small offset avoids rsqrt(0) at the very first step.
        step_f = tf.cast(step, tf.float32) + 1e-9
        decay_branch = tf.math.rsqrt(step_f)
        warmup_branch = step_f * (self.warmup_steps ** -1.5)
        scale = tf.math.rsqrt(self.embed_dim)
        return scale * tf.math.minimum(decay_branch, warmup_branch)

    def get_config(self):
        """Return the constructor arguments as plain numpy scalars."""
        return dict(
            embed_dim=self.embed_dim.numpy(),
            warmup_steps=self.warmup_steps.numpy(),
        )

def build_improved_model(vocab_size: int) -> Model:
    """
    Build and compile the causal Transformer language model.

    The network is: embedding (index 0 masked) -> positional encoding ->
    dropout -> ``config.TRANSFORMER_LAYERS`` transformer blocks -> dense
    projection to ``vocab_size`` logits. All sizes come from ``config``.

    Args:
        vocab_size: Number of tokens in the vocabulary.

    Returns:
        The compiled Keras model (Adam with ``CustomSchedule`` and gradient
        norm clipping, padding-masked cross-entropy loss).
    """
    def smoothed_loss(y_true, y_pred):
        """Mean sparse cross-entropy over non-padding targets (no label smoothing is applied)."""
        # tf.nn.sparse_* requires int32 or int64 labels.
        labels = tf.cast(y_true, tf.int32)
        per_token = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=y_pred)
        # Target id 0 is ignored by the loss.
        keep = tf.cast(tf.math.not_equal(y_true, 0), tf.float32)
        return tf.reduce_sum(per_token * keep) / (tf.reduce_sum(keep) + 1e-9)

    token_ids = Input(shape=(config.MAX_LENGTH,))
    hidden = Embedding(vocab_size, config.EMBED_DIM, mask_zero=True)(token_ids)
    hidden = DynamicPositionalEncoding(config.EMBED_DIM)(hidden)
    hidden = Dropout(config.DROPOUT_RATE)(hidden)
    for _layer_index in range(config.TRANSFORMER_LAYERS):
        block = ImprovedTransformerBlock(
            config.EMBED_DIM,
            config.TRANSFORMER_HEADS,
            config.FF_DIM,
            rate=config.DROPOUT_RATE,
        )
        hidden = block(hidden)
    logits = Dense(vocab_size)(hidden)

    model = Model(token_ids, logits)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=CustomSchedule(config.EMBED_DIM),
            clipnorm=config.GRADIENT_CLIP,
        ),
        loss=smoothed_loss,
    )
    return model

# --- GENERATORS ---
class ThreadSafeIterator:
    """Iterator wrapper that serialises ``next()`` calls on the wrapped iterator with a lock."""

    def __init__(self, iterator):
        """Wrap ``iterator`` and create the lock that guards it."""
        self.lock = Lock()
        self.iterator = iterator

    def __iter__(self):
        """Return the wrapper itself, as the iterator protocol requires."""
        return self

    def __next__(self):
        """Advance the wrapped iterator while holding the lock."""
        self.lock.acquire()
        try:
            item = next(self.iterator)
        finally:
            # Released on StopIteration (or any other error) as well.
            self.lock.release()
        return item

def threadsafe_generator(func):
    """Decorator: wrap the iterator returned by ``func`` in a ``ThreadSafeIterator``."""
    def wrapper(*args, **kwargs):
        raw_iterator = func(*args, **kwargs)
        return ThreadSafeIterator(raw_iterator)
    return wrapper

class CurriculumSmilesGenerator:
    """
    Training-batch generator combining curriculum learning with SMILES augmentation.

    Every tokenized molecule is scored with ``compute_complexity_from_tokens``.
    Batches are sampled only from molecules whose score does not exceed
    ``current_complexity``; ``update_complexity`` raises that threshold.
    With probability ``config.AUGMENT_PROB`` a sampled molecule is replaced
    by a randomized SMILES of the same molecule.
    """
    def __init__(self, tokenized_smiles: list, vocab: list):
        """
        Args:
            tokenized_smiles: Token lists, one per molecule.
            vocab: Vocabulary list; a token's id is its position in this list.
                Must contain '<PAD>', '<START>' and '<END>'.
        """
        self.char2idx = {c: i for i, c in enumerate(vocab)}
        self.idx2char = {i: c for c, i in self.char2idx.items()}
        # (tokens, complexity) pairs for every molecule that still validates.
        self.original_data = []
        for tokens in tokenized_smiles:
            comp = compute_complexity_from_tokens(tokens)
            fixed = validate_and_fix_smiles(''.join(tokens))
            if fixed is None:
                continue
            self.original_data.append((tokens, comp))
        # Unscorable molecules carry float('inf') and never define the maximum.
        valid_comps = [comp for _, comp in self.original_data if comp != float('inf')]
        self.max_complexity = max(valid_comps) if valid_comps else 0
        self.current_complexity = config.CURRICULUM_START_COMPLEXITY
        self.available_data = self._filter_data()
        self.train_smiles = {''.join(tokens) for tokens, _ in self.original_data}
        self.lock = Lock()

    def _filter_data(self):
        """Return the token lists within the current complexity, or all of them if none qualify."""
        filtered = [tokens for tokens, comp in self.original_data if comp <= self.current_complexity]
        return filtered or [tokens for tokens, comp in self.original_data]

    def update_complexity(self, epoch: int, loss_diff: float = None):
        """
        Advance the curriculum threshold and refresh the sampling pool.

        Args:
            epoch: Epoch number used by the warm-up schedule.
            loss_diff: Optional loss change; when given and below
                ``config.LOSS_STABILITY_THRESHOLD`` the threshold is raised by
                one ``config.CURRICULUM_COMPLEXITY_STEP`` instead of following
                the epoch schedule.
        """
        with self.lock:
            if loss_diff is not None and loss_diff < config.LOSS_STABILITY_THRESHOLD:
                self.current_complexity = min(
                    self.current_complexity + config.CURRICULUM_COMPLEXITY_STEP,
                    self.max_complexity,
                )
            elif config.WARMUP_EPOCHS > 0 and epoch <= config.WARMUP_EPOCHS:
                # Linear ramp from the start complexity to the maximum over the
                # warm-up epochs. The WARMUP_EPOCHS > 0 guard avoids a division
                # by zero when warm-up is disabled.
                span = self.max_complexity - config.CURRICULUM_START_COMPLEXITY
                increment = int(span * (epoch / config.WARMUP_EPOCHS))
                self.current_complexity = config.CURRICULUM_START_COMPLEXITY + increment
            else:
                self.current_complexity = self.max_complexity
            self.available_data = self._filter_data()
            if not self.available_data:
                self.available_data = [tokens for tokens, comp in self.original_data]
                logger.warning("Reset available_data to original")

    def _augment(self, tokens: list) -> list:
        """
        Return a randomized tokenization of the same molecule, or ``tokens`` unchanged.

        The randomized version is rejected when it cannot be tokenized, when it
        would not fit in ``config.MAX_LENGTH`` together with <START>/<END>, or
        when it contains a token missing from the vocabulary. Such tokens used
        to be encoded as <PAD>, silently corrupting the training example.
        """
        try:
            augmented = randomize_smiles(''.join(tokens))
            if not augmented:
                return tokens
            candidate = robust_tokenize(random.choice(augmented))
        except Exception as e:
            logger.debug(f"Augmentation error: {e}")
            return tokens
        if not candidate or len(candidate) > config.MAX_LENGTH - 2:
            return tokens
        if any(token not in self.char2idx for token in candidate):
            return tokens
        return candidate

    @threadsafe_generator
    def __call__(self):
        """
        Yield ``(inputs, targets)`` int32 batches of shape
        (config.BATCH_SIZE, config.MAX_LENGTH) forever.

        ``targets`` is ``inputs`` shifted left by one position and padded with
        <PAD>. Sequences longer than ``config.MAX_LENGTH`` are truncated.
        """
        while True:
            inputs = np.full((config.BATCH_SIZE, config.MAX_LENGTH), self.char2idx['<PAD>'], dtype=np.int32)
            targets = np.full_like(inputs, self.char2idx['<PAD>'])
            for i in range(config.BATCH_SIZE):
                # The pool may be swapped by update_complexity from another thread.
                with self.lock:
                    try:
                        tokens = random.choice(self.available_data)
                    except IndexError:
                        self.available_data = [tokens for tokens, comp in self.original_data]
                        tokens = random.choice(self.available_data)
                if random.random() < config.AUGMENT_PROB:
                    tokens = self._augment(tokens)
                seq = ['<START>'] + tokens + ['<END>']
                padded = (seq + ['<PAD>'] * config.MAX_LENGTH)[:config.MAX_LENGTH]
                inputs[i] = [self.char2idx.get(t, self.char2idx['<PAD>']) for t in padded]
                targets[i, :-1] = inputs[i][1:]
                targets[i, -1] = self.char2idx['<PAD>']
            yield inputs, targets

    def get_dataset(self):
        """Wrap the generator in a prefetching ``tf.data.Dataset`` of int32 batch pairs."""
        return tf.data.Dataset.from_generator(
            self.__call__,
            output_signature=(
                tf.TensorSpec(shape=(config.BATCH_SIZE, config.MAX_LENGTH), dtype=tf.int32),
                tf.TensorSpec(shape=(config.BATCH_SIZE, config.MAX_LENGTH), dtype=tf.int32)
            )
        ).prefetch(tf.data.AUTOTUNE)

# Validation uses a plain generator: no data augmentation and no curriculum.
def simple_smiles_generator(tokenized_smiles: list, vocab: list):
    """
    Yield ``(inputs, targets)`` int32 validation batches of shape
    (config.BATCH_SIZE, config.MAX_LENGTH) forever.

    The molecules are visited in order and the list is cycled when exhausted,
    so every example is used equally often and the batches are reproducible.
    The previous random sampling with replacement made the validation loss
    noisy, which undermines selecting the best checkpoint on ``val_loss``.

    Args:
        tokenized_smiles: Token lists, one per molecule (must not be empty).
        vocab: Vocabulary list; a token's id is its position in this list.

    Raises:
        ValueError: On first iteration, if ``tokenized_smiles`` is empty.
    """
    if not tokenized_smiles:
        raise ValueError("simple_smiles_generator needs at least one tokenized SMILES")
    char2idx = {c: i for i, c in enumerate(vocab)}
    pad_id = char2idx['<PAD>']
    total = len(tokenized_smiles)
    cursor = 0  # position of the next molecule to emit
    while True:
        inputs = np.full((config.BATCH_SIZE, config.MAX_LENGTH), pad_id, dtype=np.int32)
        targets = np.full_like(inputs, pad_id)
        for i in range(config.BATCH_SIZE):
            # Wrap around when the validation set is smaller than what is requested.
            tokens = tokenized_smiles[cursor % total]
            cursor += 1
            seq = ['<START>'] + tokens + ['<END>']
            padded = (seq + ['<PAD>'] * config.MAX_LENGTH)[:config.MAX_LENGTH]
            inputs[i] = [char2idx.get(t, pad_id) for t in padded]
            # Targets are the inputs shifted left by one position.
            targets[i, :-1] = inputs[i][1:]
            targets[i, -1] = pad_id
        yield inputs, targets

def get_simple_dataset(tokenized_smiles: list, vocab: list):
    """
    Wrap ``simple_smiles_generator`` in a prefetching ``tf.data.Dataset``.

    Args:
        tokenized_smiles: Token lists, one per molecule.
        vocab: Vocabulary list passed through to the generator.

    Returns:
        A dataset of ``(inputs, targets)`` int32 pairs, each of shape
        (config.BATCH_SIZE, config.MAX_LENGTH).
    """
    batch_spec = tf.TensorSpec(shape=(config.BATCH_SIZE, config.MAX_LENGTH), dtype=tf.int32)

    def make_generator():
        # from_generator needs a zero-argument callable.
        return simple_smiles_generator(tokenized_smiles, vocab)

    dataset = tf.data.Dataset.from_generator(
        make_generator,
        output_signature=(batch_spec, batch_spec),
    )
    return dataset.prefetch(tf.data.AUTOTUNE)

# --- Monitoring callbacks ---
class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
    """TensorBoard callback that also records the current learning rate as ``lr``."""

    def on_epoch_end(self, epoch, logs=None):
        """Add the learning rate to ``logs`` and delegate to the base callback."""
        logs = logs or {}
        optimizer = self.model.optimizer
        lr = optimizer.learning_rate
        if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
            # A schedule is a function of the optimizer step, not of the epoch
            # index. Evaluating it at `epoch` logged a value unrelated to the
            # rate actually in use.
            lr = lr(optimizer.iterations)
        logs['lr'] = lr.numpy()
        super().on_epoch_end(epoch, logs)

class EnhancedTrainingMonitor(Callback):
    """
    Monitoring callback that only logs the training and validation loss
    at the end of every epoch (it no longer generates SMILES).
    """

    @staticmethod
    def _format_loss(value):
        """Format a loss with four decimals, or return 'N/A' when it is missing or non-numeric."""
        try:
            return f"{float(value):.4f}"
        except (TypeError, ValueError):
            return "N/A"

    def on_epoch_end(self, epoch, logs=None):
        """Log both losses; a missing loss is reported as 'N/A' instead of raising."""
        logs = logs or {}
        # Applying ':.4f' directly to the 'N/A' default raised ValueError
        # whenever a loss was absent from `logs`.
        train_loss = self._format_loss(logs.get('loss'))
        val_loss = self._format_loss(logs.get('val_loss'))
        logger.info(f"Epoca {epoch+1}: Loss Training = {train_loss}, Loss Validation = {val_loss}")



2025-06-11 09:48:17.757039: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-11 09:48:17.771081: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749628097.788029 3718496 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749628097.793231 3718496 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1749628097.805974 3718496 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [None]:
# --- FINE-TUNING SCRIPT ---
if __name__ == "__main__":
    # Load the vocabulary and the saved token mappings.
    # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
    logger.info("Loading vocabulary and saved mappings...")
    with open(char2idx_path, 'rb') as f:
        char2idx = pickle.load(f)
    with open(idx2char_path, 'rb') as f:
        idx2char = pickle.load(f)
    with open(vocab_path, "r") as f:
        vocab = json.load(f)
    logger.info(f"Vocabulary loaded: {len(vocab)} tokens")

    # The generators derive token ids from the position in `vocab`. If the saved
    # mapping disagrees, the pre-trained embedding would receive wrong ids.
    if char2idx != {c: i for i, c in enumerate(vocab)}:
        logger.warning("Saved char2idx mapping does not match the order of the loaded vocabulary.")

    # Load the fine-tuning dataset.
    logger.info("Loading new dataset for fine-tuning...")
    with open(config.SMILES_FILE, "r") as f:
        raw_smiles = [line.strip() for line in f if line.strip()]

    # Preprocess the dataset.
    processed, new_vocab, max_len = process_dataset(raw_smiles)
    if not processed:
        raise ValueError(f"No valid SMILES left after preprocessing {config.SMILES_FILE}")
    config.MAX_LENGTH = max_len  # Update the maximum length if needed

    # MAX_LENGTH is now the 99th-percentile length, so the longest sequences no
    # longer fit with <START>/<END>. They would be truncated and lose <END>.
    fitting = [t for t in processed if len(t) + 2 <= config.MAX_LENGTH]
    if len(fitting) < len(processed):
        logger.warning(f"Dropping {len(processed) - len(fitting)} SMILES longer than {config.MAX_LENGTH - 2} tokens")
        processed = fitting

    # Tokens missing from the pre-trained vocabulary are encoded as <PAD> by the generators.
    unknown_tokens = sorted({t for seq in processed for t in seq} - set(vocab))
    if unknown_tokens:
        logger.warning(f"Tokens not in the pre-trained vocabulary (will be encoded as <PAD>): {unknown_tokens}")
    logger.info("Dataset processed for fine-tuning.")

    # Split the dataset into training and validation sets.
    stratify_labels = [min(len(t), 20) for t in processed]  # Simplified labels for stratification
    try:
        counts = np.bincount(stratify_labels)
        if np.min(counts[np.nonzero(counts)]) < 2:
            logger.warning("Layering disabled: Some classes have less than 2 examples.")
            stratify_param = None
        else:
            stratify_param = stratify_labels
    except Exception as e:
        logger.warning(f"Errore nella stratificazione: {e}. Disabilito stratificazione.")
        stratify_param = None

    try:
        train_data, val_data = train_test_split(
            processed,
            test_size=config.VALID_RATIO,
            stratify=stratify_param,
            random_state=42
        )
    except ValueError as e:
        if stratify_param is None:
            raise
        # A stratified split can still fail, e.g. when the validation set is
        # smaller than the number of classes. Fall back to a plain split.
        logger.warning(f"Stratified split failed: {e}. Falling back to an unstratified split.")
        train_data, val_data = train_test_split(
            processed,
            test_size=config.VALID_RATIO,
            random_state=42
        )
    logger.info(f"Training set: {len(train_data)} SMILES, Validation set: {len(val_data)} SMILES")

    # Training generator (curriculum learning).
    curriculum_generator = CurriculumSmilesGenerator(tokenized_smiles=train_data, vocab=vocab)
    train_dataset = curriculum_generator.get_dataset()

    # Validation generator (plain, no curriculum).
    val_dataset = get_simple_dataset(val_data, vocab)

    # Build the model and load the pre-trained weights for fine-tuning.
    logger.info("Model construction and loading of pre-trained weights...")
    model = build_improved_model(len(vocab))
    try:
        model.load_weights(model_path)
        logger.info("Pre-trained weights loaded correctly.")
    except Exception as e:
        logger.warning(f"Unable to load pre-trained weights: {e}. Proceeding with random weights.")

    # Training callbacks.
    from tensorflow.keras.callbacks import ModelCheckpoint  # Extra import for checkpointing
    tensorboard_cb = CustomTensorBoard(log_dir="logs_finetune")
    monitor_cb = EnhancedTrainingMonitor()
    checkpoint_cb = ModelCheckpoint(
        filepath="best_model.h5",  # Keeps the model with the lowest validation loss
        monitor="val_loss",
        save_best_only=True,
        verbose=1
    )
    # Nothing else advances the curriculum: without this callback the generator
    # stays at CURRICULUM_START_COMPLEXITY for the whole run. `epoch` is
    # zero-based, so epoch + 1 is the number of completed epochs.
    curriculum_cb = tf.keras.callbacks.LambdaCallback(
        on_epoch_end=lambda epoch, logs: curriculum_generator.update_complexity(epoch + 1)
    )

    # Start fine-tuning (validation is evaluated at every epoch).
    logger.info("Starting fine-tuning...")
    steps_per_epoch = max(1, len(train_data) // config.BATCH_SIZE)
    val_steps = max(1, len(val_data) // config.BATCH_SIZE)

    model.fit(
        train_dataset,
        epochs=config.EPOCHS,
        steps_per_epoch=steps_per_epoch,
        validation_data=val_dataset,
        validation_steps=val_steps,
        callbacks=[tensorboard_cb, monitor_cb, checkpoint_cb, curriculum_cb]
    )

    # (Optional) Save the final model after the last epoch.
    final_model_path = model_out_path
    model.save(final_model_path)
    logger.info(f"Fine-tuned model saved in {final_model_path}")
    logger.info("Best model based on validation loss saved in best_model.h5")
