In [2]:
cd /home/grad/Desktop/pietro/denovo/ab/ab2

/home/grad/Desktop/pietro/denovo/ab/ab2


In [3]:
import tensorflow as tf

print(tf.__version__)
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))  # Should print 0
print("Num CPUs Available: ", len(tf.config.list_physical_devices('CPU')))  # Should print > 0

with tf.device('/cpu:0'):  # Explicitly use CPU (optiona§§l but good practice)
    a = tf.constant([1.0, 2.0, 3.0])
    b = tf.constant([4.0, 5.0, 6.0])
    c = a + b
print(c)

2025-06-24 10:01:20.930370: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-24 10:01:20.944915: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750752080.960858 3063298 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750752080.965642 3063298 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1750752080.977701 3063298 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

2.19.0
Num GPUs Available:  1
Num CPUs Available:  1
tf.Tensor([5. 7. 9.], shape=(3,), dtype=float32)


I0000 00:00:1750752083.425083 3063298 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 1584 MB memory:  -> device: 0, name: NVIDIA RTX A2000 12GB, pci bus id: 0000:65:00.0, compute capability: 8.6


In [4]:
# --- Configurazione Avanzata ---
class AdvancedConfig:
    SMILES_FILE = '/home/grad/Desktop/pietro/denovo/new/CHEMBL25.smi'  # file con circa 26k SMILES
    BATCH_SIZE = 64
    EPOCHS = 20
    EMBED_DIM = 94
    TRANSFORMER_HEADS = 2
    TRANSFORMER_LAYERS = 2
    FF_DIM = 200
    VALID_RATIO = 0.1
    TEMPERATURE = 1.0
    TEMPERATURE_DECAY = 0.97
    GEN_NUM = 1
    WARMUP_EPOCHS = 10  # Per il curriculum learning iniziale
    MAX_RANDOMIZATIONS = 3
    MAX_LENGTH = 100
    PRINT_EVERY = 100
    DROPOUT_RATE = 0.15
    GRADIENT_CLIP = 1.0
    L2_REG = 1e-5
    LOSS_STABILITY_THRESHOLD = 0.01  # variazione relativa sotto il 1%
    CURRICULUM_START_COMPLEXITY = 5   # complessità iniziale (0 = molecole molto semplici)
    CURRICULUM_COMPLEXITY_STEP = 1    # incremento della soglia di complessità
    AUGMENT_PROB = 0.3

config = AdvancedConfig()

In [5]:
import numpy as np
import matplotlib.pyplot as plt
import random
import time
from rdkit.Chem import MolFromSmiles, MolToSmiles
from sklearn.model_selection import train_test_split
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Embedding, GRU, Dense, Bidirectional, 
    LayerNormalization, Dropout, Attention, Concatenate,
    Layer, Multiply, Masking, RepeatVector,  GlobalAveragePooling1D
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (
    EarlyStopping, ModelCheckpoint, 
    LearningRateScheduler, Callback
)
import re
import logging
from typing import List, Optional, Tuple
from rdkit import Chem
from rdkit.Chem import MolFromSmiles, MolToSmiles, SanitizeMol, SanitizeFlags
from rdkit.Chem.rdmolops import AssignStereochemistry
import tensorflow as tf
from tensorflow.keras.layers import Layer, Embedding, Input, LayerNormalization, MultiHeadAttention, Dropout, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import Callback, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from threading import Lock
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # Per grafici 3D

In [6]:
# Configurazione del logger
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")



# --- Funzione per calcolare la complessità di una SMILES ---
def compute_complexity_from_tokens(tokens: List[str]) -> int:
    """
    Calcola la complessità di una SMILES come la somma del numero di anelli (ottenuti con GetSSSR)
    e del numero di ramificazioni (conteggio delle parentesi aperte).
    """
    smiles = ''.join(tokens)
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return len(tokens)  # fallback: usa la lunghezza come complessità
        num_rings = len(Chem.GetSSSR(mol))  # Correzione: ottiene la lunghezza dell'iterabile
        num_branches = smiles.count('(')
        return num_rings + num_branches
    except Exception:
        return len(tokens)  # fallback: usa la lunghezza come complessità


# --- Funzioni di Preprocessing ---
def randomize_smiles(smiles: str, num_versions: int = 3) -> List[str]:
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return []
    randomized = []
    for _ in range(num_versions):
        try:
            new_smiles = Chem.MolToSmiles(mol, doRandom=True, canonical=False)
            if new_smiles:
                randomized.append(new_smiles)
        except Exception as e:
            logger.debug(f"Randomization error: {e}")
            continue
    return randomized

def validate_and_fix_smiles(smiles: str) -> str:
    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=True)
        if mol is None:
            return None
        try:
            Chem.Kekulize(mol, clearAromaticFlags=True)
        except Exception as e:
            logger.debug(f"Kekulization error in SMILES {smiles}: {e}")
            return None
        return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False)
    except Exception as e:
        logger.debug(f"Parsing/sanitization error in SMILES {smiles}: {e}")
        return None

def robust_tokenize(smiles: str) -> list:
    pattern = (
        r"(\[[^\[\]]{1,6}\]|"                 # atomi in parentesi quadre
        r"Br|Cl|Si|Na|Mg|Mn|Ca|Fe|Zn|Se|Li|K|Al|B|"  # elementi multi-char
        r"R[0-9]|r[0-9]|a[0-9]|"             # ring labels
        r"[A-Za-z0-9@+\-\\\/\(\)=#\$\.\%,])"  # singoli caratteri, incluso '%'
    )
    tokens = re.findall(pattern, smiles)
    tokens = re.findall(pattern, smiles)
    stack = []
    for t in tokens:
        if t.startswith('['):
            stack.append(t)
        if t.endswith(']') and not stack:
            return []
        if t.endswith(']'):
            stack.pop()
    try:
        if not stack and Chem.MolFromSmiles(''.join(tokens)):
            return tokens
    except Exception as e:
        logger.debug(f"Tokenization error: {e}")
        return []
    return tokens


def process_dataset(data: List[str]) -> Tuple[List[List[str]], List[str], int]:
    processed = []
    all_tokens = set()
    for s in data:
        fixed = validate_and_fix_smiles(s)
        if not fixed:
            continue
        tokens = robust_tokenize(fixed)
        if tokens and 3 <= len(tokens) <= config.MAX_LENGTH - 2:
            processed.append(tokens)
            all_tokens.update(tokens)
    vocab = ['<PAD>', '<START>', '<END>'] + sorted(all_tokens)
    lengths = [len(t) for t in processed]
    max_len = min(int(np.percentile(lengths, 99)) + 2, config.MAX_LENGTH) if processed else config.MAX_LENGTH
    logger.info(f"Processed SMILES: {len(processed)}/{len(data)}")
    logger.info(f"Unique tokens: {len(all_tokens)}")
    logger.info(f"Max length: {max_len}")
    return processed, vocab, max_len

# --- Componenti del Modello ---
class DynamicPositionalEncoding(Layer):
    def __init__(self, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
    def build(self, input_shape):
        max_seq_len = input_shape[1]
        pos = np.arange(max_seq_len)[:, np.newaxis]
        i = np.arange(self.embed_dim)[np.newaxis, :]
        angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(self.embed_dim))
        angle_rads = pos * angle_rates
        angle_rads[:, 0::2] = tf.math.sin(angle_rads[:, 0::2])
        angle_rads[:, 1::2] = tf.math.cos(angle_rads[:, 1::2])
        self.pos_encoding = tf.cast(angle_rads[np.newaxis, ...], dtype=tf.float32)
    def call(self, inputs):
        return inputs + self.pos_encoding[:, :tf.shape(inputs)[1], :]
    def get_config(self):
        base_config = super().get_config()
        base_config.update({"embed_dim": self.embed_dim})
        return base_config

class ImprovedTransformerBlock(Layer):
    def __init__(self, embed_dim, num_heads, ffn_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.rate = rate
        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim,
            kernel_regularizer=tf.keras.regularizers.l2(config.L2_REG),
            dropout=rate
        )
        self.ffn = tf.keras.Sequential([
            Dense(ffn_dim, activation="gelu", kernel_regularizer=tf.keras.regularizers.l2(config.L2_REG)),
            Dense(embed_dim, kernel_regularizer=tf.keras.regularizers.l2(config.L2_REG))
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)
    
    def call(self, inputs, training=False):
        seq_len = tf.shape(inputs)[1]
        causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        attn_output = self.mha(inputs, inputs, attention_mask=causal_mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)
    
    def get_config(self):
        base_config = super().get_config()
        base_config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ffn_dim": self.ffn_dim,
            "rate": self.rate
        })
        return base_config

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, embed_dim, warmup_steps=10000):
        super().__init__()
        self.embed_dim = tf.cast(embed_dim, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)
    def __call__(self, step):
        step = tf.cast(step, tf.float32) + 1e-9
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.embed_dim) * tf.math.minimum(arg1, arg2)
    def get_config(self):
        return {"embed_dim": self.embed_dim.numpy(), "warmup_steps": self.warmup_steps.numpy()}

def build_improved_model(vocab_size: int) -> Model:
    inputs = Input(shape=(config.MAX_LENGTH,))
    x = Embedding(vocab_size, config.EMBED_DIM, mask_zero=True)(inputs)
    x = DynamicPositionalEncoding(config.EMBED_DIM)(x)
    x = Dropout(config.DROPOUT_RATE)(x)
    for _ in range(config.TRANSFORMER_LAYERS):
        x = ImprovedTransformerBlock(config.EMBED_DIM, config.TRANSFORMER_HEADS, config.FF_DIM, rate=config.DROPOUT_RATE)(x)
    outputs = Dense(vocab_size)(x)
    def smoothed_loss(y_true, y_pred):
        # cast a interi perché tf.nn.sparse_* vuole etichette int32 o int64
        y_true_int = tf.cast(y_true, tf.int32)
        mask = tf.cast(tf.math.not_equal(y_true, 0), tf.float32)
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true_int, logits=y_pred)
        return tf.reduce_sum(loss * mask) / (tf.reduce_sum(mask) + 1e-9)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=CustomSchedule(config.EMBED_DIM),
        clipnorm=config.GRADIENT_CLIP
    )
    model = Model(inputs, outputs)
    model.compile(optimizer=optimizer, loss=smoothed_loss)
    return model

# --- Generatori Dati e Utilities ---
class ThreadSafeIterator:
    def __init__(self, iterator):
        self.iterator = iterator
        self.lock = Lock()
    def __iter__(self):
        return self
    def __next__(self):
        with self.lock:
            return next(self.iterator)

def threadsafe_generator(func):
    def wrapper(*args, **kwargs):
        return ThreadSafeIterator(func(*args, **kwargs))
    return wrapper

class CurriculumSmilesGenerator:
    """
    Versione semplificata e robusta del generatore con curriculum learning
    """
    def __init__(self, tokenized_smiles: List[List[str]], vocab: List[str]):
        self.char2idx = {c: i for i, c in enumerate(vocab)}
        self.idx2char = {i: c for c, i in self.char2idx.items()}
        
        # Calcola la complessità di ogni SMILES
        logger.info("Calcolo complessità delle molecole...")
        self.data_with_complexity = []
        complexities = []
        
        for tokens in tokenized_smiles:
            complexity = compute_complexity_from_tokens(tokens)
            self.data_with_complexity.append((tokens, complexity))
            complexities.append(complexity)
        
        # Calcola statistiche
        self.min_complexity = min(complexities) if complexities else 0
        self.max_complexity = max(complexities) if complexities else 0
        
        # Inizializza al minimo
        self.current_complexity = config.CURRICULUM_START_COMPLEXITY
        self.train_smiles = {''.join(tokens) for tokens, _ in self.data_with_complexity}
        
        # Log iniziali
        logger.info(f"Curriculum inizializzato: min={self.min_complexity}, max={self.max_complexity}")
        logger.info(f"Complessità iniziale: {self.current_complexity}")
        
        # Crea indici per accesso rapido
        self._prepare_complexity_indices()
    
    def _prepare_complexity_indices(self):
        """Prepara gli indici per accesso rapido per complessità"""
        self.complexity_indices = {}
        for i, (_, comp) in enumerate(self.data_with_complexity):
            if comp not in self.complexity_indices:
                self.complexity_indices[comp] = []
            self.complexity_indices[comp].append(i)
        
        # Prepara gli indici disponibili
        self.available_indices = []
        for comp in range(self.min_complexity, self.current_complexity + 1):
            if comp in self.complexity_indices:
                self.available_indices.extend(self.complexity_indices[comp])
    
    def update_complexity(self, epoch, loss_diff=None):
        """Aggiorna la complessità in base all'epoca o alla convergenza"""
        old_complexity = self.current_complexity
        
        if epoch < config.WARMUP_EPOCHS:
            # Durante il warmup, aumenta linearmente
            progress = (epoch + 1) / config.WARMUP_EPOCHS
            target = self.min_complexity + progress * (self.max_complexity - self.min_complexity)
            new_complexity = round(target)
            self.current_complexity = max(min(new_complexity, self.max_complexity), self.current_complexity)
            reason = f"warmup ({epoch+1}/{config.WARMUP_EPOCHS})"
        elif loss_diff is not None and loss_diff < config.LOSS_STABILITY_THRESHOLD:
            # Dopo il warmup, aumenta se la loss è stabile
            new_complexity = min(self.current_complexity + config.CURRICULUM_COMPLEXITY_STEP, self.max_complexity)
            if new_complexity > self.current_complexity:
                self.current_complexity = new_complexity
                reason = f"convergenza (diff={loss_diff:.5f})"
            else:
                reason = "già alla massima complessità"
        else:
            reason = "nessun cambiamento"
        
        # Aggiorna gli indici disponibili
        self.available_indices = []
        for comp in range(self.min_complexity, self.current_complexity + 1):
            if comp in self.complexity_indices:
                self.available_indices.extend(self.complexity_indices[comp])
        
        # Log
        if old_complexity != self.current_complexity:
            available_percent = len(self.available_indices) / len(self.data_with_complexity)
            logger.info(f"Curriculum: complessità {old_complexity} -> {self.current_complexity} ({reason})")
            logger.info(f"Esempi disponibili: {len(self.available_indices)}/{len(self.data_with_complexity)} ({available_percent:.1%})")
        elif epoch % 5 == 0:
            available_percent = len(self.available_indices) / len(self.data_with_complexity)
            logger.info(f"[Epoca {epoch}] Curriculum status: complessità={self.current_complexity}/{self.max_complexity}, "
                       f"esempi={len(self.available_indices)}/{len(self.data_with_complexity)} ({available_percent:.1%})")
    
    def __call__(self):
        """Generatore di batch per l'addestramento"""
        while True:
            inputs = np.full((config.BATCH_SIZE, config.MAX_LENGTH), self.char2idx['<PAD>'], dtype=np.int32)
            targets = np.full_like(inputs, self.char2idx['<PAD>'])
            
            # Verifica che ci siano indici disponibili
            if not self.available_indices:
                logger.warning(f"Nessun esempio disponibile alla complessità {self.current_complexity}!")
                indices = list(range(len(self.data_with_complexity)))
            else:
                indices = self.available_indices
            
            for i in range(config.BATCH_SIZE):
                # Seleziona un esempio casuale
                idx = random.choice(indices)
                tokens = self.data_with_complexity[idx][0]
                
                # Data augmentation (randomizzazione)
                if random.random() < config.AUGMENT_PROB:
                    try:
                        smiles = ''.join(tokens)
                        mol = Chem.MolFromSmiles(smiles)
                        if mol:
                            new_smiles = Chem.MolToSmiles(mol, doRandom=True, canonical=False)
                            if new_smiles:
                                new_tokens = robust_tokenize(new_smiles)
                                if new_tokens:
                                    tokens = new_tokens
                    except Exception:
                        pass  # In caso di errore, usa i token originali
                
                # Prepara la sequenza
                seq = ['<START>'] + tokens + ['<END>']
                padded = (seq + ['<PAD>'] * config.MAX_LENGTH)[:config.MAX_LENGTH]
                
                # Converti in indici
                inputs[i] = [self.char2idx.get(t, self.char2idx['<PAD>']) for t in padded]
                targets[i, :-1] = inputs[i][1:]
                targets[i, -1] = self.char2idx['<PAD>']
            
            yield inputs, targets
    
    def get_dataset(self):
        return tf.data.Dataset.from_generator(
            self.__call__,
            output_signature=(
                tf.TensorSpec(shape=(config.BATCH_SIZE, config.MAX_LENGTH), dtype=tf.int32),
                tf.TensorSpec(shape=(config.BATCH_SIZE, config.MAX_LENGTH), dtype=tf.int32)
            )
        ).prefetch(tf.data.AUTOTUNE)



# --- Callback per il monitoraggio ---
class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        lr = self.model.optimizer.learning_rate
        if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
            logs['lr'] = lr(epoch).numpy()
        else:
            logs['lr'] = lr.numpy()
        super().on_epoch_end(epoch, logs)

class EnhancedTrainingMonitor(Callback):
    def __init__(self, val_gen: CurriculumSmilesGenerator):
        super().__init__()
        self.val_gen = val_gen
        self.best_val_loss = np.inf
        self.prev_val_loss = None
    def generate_smiles(self, num: int) -> Tuple[List[str], List[str]]:
        generated, valid = [], []
        if self.model is None:
            raise ValueError("Il modello non è stato assegnato al callback!")
        input_seq = np.full((1, config.MAX_LENGTH), self.val_gen.char2idx['<PAD>'], dtype=np.int32)
        input_seq[0, 0] = self.val_gen.char2idx['<START>']
        for _ in range(num):
            for t in range(1, config.MAX_LENGTH):
                logits = self.model(input_seq, training=False)[0, t-1]
                probs = tf.nn.softmax(logits / config.TEMPERATURE).numpy()
                valid_indices = [i for i, tok in self.val_gen.idx2char.items() 
                                 if re.match(r'^([A-Za-z0-9@#\[\]()+\-\\/%=:.,]|<\w+>)$', tok)]
                probs[[i for i in range(len(probs)) if i not in valid_indices]] = 0
                if np.sum(probs) == 0:
                    break
                probs /= np.sum(probs)
                sampled = np.random.choice(len(probs), p=probs)
                input_seq[0, t] = sampled
                if sampled == self.val_gen.char2idx['<END>']:
                    break
            raw = ''.join([self.val_gen.idx2char[i] for i in input_seq[0]
                           if i not in {self.val_gen.char2idx.get('<PAD>'), self.val_gen.char2idx['<END>']}][1:])
            final = validate_and_fix_smiles(raw) or raw
            generated.append(final)
            if Chem.MolFromSmiles(final):
                valid.append(final)
            input_seq = np.full((1, config.MAX_LENGTH), self.val_gen.char2idx['<PAD>'], dtype=np.int32)
            input_seq[0, 0] = self.val_gen.char2idx['<START>']
        return generated, valid
    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % config.PRINT_EVERY == 0:
            generated, valid = self.generate_smiles(config.GEN_NUM)
            validity = len(valid) / config.GEN_NUM
            unique = len(set(valid))
            novel = len([s for s in valid if s not in self.val_gen.train_smiles])
            logger.info(f"\n🧪 Epoca {epoch+1}:")
            logger.info(f"Loss Training: {logs.get('loss', 'N/A'):.4f} - Loss Validation: {logs.get('val_loss', 'N/A'):.4f}")
            logger.info(f"Validità: {validity:.1%}")
            logger.info(f"Unicità: {unique}/{len(valid)}")
            logger.info(f"Novità: {novel}/{len(valid)}")
            if valid:
                logger.info("Esempi:")
                for s in valid[:3]:
                    logger.info(f"- {s}")
            with open("valid_generated_smiles.txt", "w") as f:
                for s in valid:
                    f.write(s + "\n")

# --- Main ---
if __name__ == "__main__":
    logger.info("🚀 Avvio Training Transformer per SMILES con tokenizzazione BPE, data augmentation e curriculum learning basato sulla complessità")
    with open(config.SMILES_FILE) as f:
        raw_smiles = [line.strip() for line in f if line.strip()]
    logger.info("🔍 Validazione SMILES...")
    valid_smiles = []
    for idx, s in enumerate(raw_smiles):
        if idx % 5000 == 0:
            logger.info(f"Processati {idx}/{len(raw_smiles)}")
        fixed = validate_and_fix_smiles(s)
        if fixed and 3 <= len(fixed) <= config.MAX_LENGTH:
            valid_smiles.append(fixed)
    processed, vocab, max_len = process_dataset(valid_smiles)
    config.MAX_LENGTH = max_len
    logger.info("Vocabolario creato:")
    logger.info(vocab)
    
    import pickle
    with open('char2idx.pkl', 'wb') as f:
       pickle.dump({char: idx for idx, char in enumerate(vocab)}, f)
    with open('idx2char.pkl', 'wb') as f:
       pickle.dump({idx: char for idx, char in enumerate(vocab)}, f)
    import json
    with open("vocab.json", "w") as f:
        json.dump(vocab, f)
    logger.info("Mappature e vocabolario salvati su disco.")
    
    stratify_labels = [min(len(t), 20) for t in processed]
    try:
        counts = np.bincount(stratify_labels)
        if np.min(counts[np.nonzero(counts)]) < 2:
            logger.warning("Stratificazione disabilitata: alcune classi hanno meno di 2 esempi.")
            stratify_param = None
        else:
            stratify_param = stratify_labels
    except Exception as e:
        logger.warning(f"Errore nella stratificazione: {e}. Disabilito stratificazione.")
        stratify_param = None

    train_data, val_data = train_test_split(
        processed,
        test_size=config.VALID_RATIO,
        stratify=stratify_param,
        random_state=42
    )

2025-06-24 10:01:26,094 [INFO] 🚀 Avvio Training Transformer per SMILES con tokenizzazione BPE, data augmentation e curriculum learning basato sulla complessità
2025-06-24 10:01:26,367 [INFO] 🔍 Validazione SMILES...
2025-06-24 10:01:26,367 [INFO] Processati 0/1213651
2025-06-24 10:01:27,706 [INFO] Processati 5000/1213651
2025-06-24 10:01:29,524 [INFO] Processati 10000/1213651
2025-06-24 10:01:31,066 [INFO] Processati 15000/1213651
2025-06-24 10:01:32,510 [INFO] Processati 20000/1213651
2025-06-24 10:01:33,842 [INFO] Processati 25000/1213651
2025-06-24 10:01:35,147 [INFO] Processati 30000/1213651
2025-06-24 10:01:36,516 [INFO] Processati 35000/1213651
2025-06-24 10:01:37,975 [INFO] Processati 40000/1213651
2025-06-24 10:01:39,260 [INFO] Processati 45000/1213651
2025-06-24 10:01:40,628 [INFO] Processati 50000/1213651
2025-06-24 10:01:42,031 [INFO] Processati 55000/1213651
2025-06-24 10:01:43,455 [INFO] Processati 60000/1213651
2025-06-24 10:01:44,873 [INFO] Processati 65000/1213651
2025-0

In [7]:
import matplotlib.pyplot as plt
import numpy as np
import os

# Crea una cartella per i grafici se non esiste
os.makedirs('training_plots', exist_ok=True)

# Inizializza liste per memorizzare le metriche
train_losses = []
val_losses = []
epochs_list = []
learning_rates = []
complexity_values = []
coverage_percentages = []

model = build_improved_model(len(vocab))
train_gen = CurriculumSmilesGenerator(train_data, vocab)
val_gen = CurriculumSmilesGenerator(val_data, vocab)

train_gen.update_complexity(0)
val_gen.update_complexity(0)

monitor_callback = EnhancedTrainingMonitor(val_gen)

callbacks = [
    monitor_callback,
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6),
    CustomTensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint('diff_scaffold_2.h5', save_best_only=True, monitor='val_loss')
]

steps_per_epoch = len(train_data) // config.BATCH_SIZE
val_steps = len(val_data) // config.BATCH_SIZE
total_epochs = config.EPOCHS
start_training = time.time()

def plot_training_progress(save_path='training_plots/current_progress.png'):
    """Genera e salva un grafico con l'andamento del training"""
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))
    
    # Plot delle loss
    ax1.plot(epochs_list, train_losses, 'b-', label='Training Loss')
    ax1.plot(epochs_list, val_losses, 'r-', label='Validation Loss')
    ax1.set_title('Model Loss During Training')
    ax1.set_ylabel('Loss')
    ax1.set_xlabel('Epoch')
    ax1.legend()
    ax1.grid(True)
    
    # Plot secondario: Learning rate e complessità curriculum
    ax2.set_title('Learning Rate and Curriculum Complexity')
    ax2.set_xlabel('Epoch')
    
    # Learning rate (asse y sinistro)
    color = 'tab:green'
    ax2.set_ylabel('Learning Rate', color=color)
    ax2.plot(epochs_list, learning_rates, color=color)
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.set_yscale('log')  # Scala logaritmica per LR
    
    # Complessità curriculum (asse y destro)
    if complexity_values:
        ax3 = ax2.twinx()
        color = 'tab:orange'
        ax3.set_ylabel('Complexity / Coverage %', color=color)
        ax3.plot(epochs_list, complexity_values, 'o-', color=color, label='Complexity')
        if coverage_percentages:
            ax3.plot(epochs_list, coverage_percentages, 's-', color='tab:purple', label='Coverage %')
        ax3.tick_params(axis='y', labelcolor=color)
        ax3.legend(loc='upper right')
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=200)
    plt.close()

for epoch in range(1, total_epochs + 1):
    epoch_start = time.time()
    current_history = model.fit(
        train_gen.get_dataset(),
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
        epochs=1,
        validation_data=val_gen.get_dataset(),
        validation_steps=val_steps,
        verbose=2,
    )
    epoch_duration = time.time() - epoch_start
    
    # Salva le metriche di questa epoca
    current_train_loss = current_history.history['loss'][0]
    current_val_loss = current_history.history['val_loss'][0]
    
    # Ottieni il learning rate corrente
    if hasattr(model.optimizer, 'learning_rate'):
        if isinstance(model.optimizer.learning_rate, tf.keras.optimizers.schedules.LearningRateSchedule):
            current_lr = model.optimizer.learning_rate(epoch).numpy()
        else:
            current_lr = model.optimizer.learning_rate.numpy()
    else:
        current_lr = 0.001  # Default fallback
    
    # Aggiorna le liste per il plotting
    train_losses.append(current_train_loss)
    val_losses.append(current_val_loss)
    learning_rates.append(current_lr)
    epochs_list.append(epoch)
    
    # Aggiungi informazioni sul curriculum learning
    if hasattr(train_gen, 'current_complexity') and hasattr(train_gen, 'max_complexity'):
        complexity_values.append(train_gen.current_complexity)
        if hasattr(train_gen, 'available_indices') and hasattr(train_gen, 'data_with_complexity'):
            coverage = len(train_gen.available_indices) / len(train_gen.data_with_complexity) * 100
            coverage_percentages.append(coverage)
    
    # Calcola la differenza di loss per il curriculum learning
    if epoch > 1:
        loss_diff = abs(current_val_loss - prev_val_loss) / prev_val_loss
    else:
        loss_diff = None
    prev_val_loss = current_val_loss

    # Aggiorna la complessità
    train_gen.update_complexity(epoch, loss_diff=loss_diff)
    val_gen.update_complexity(epoch, loss_diff=loss_diff)
    
    # Genera e salva il grafico ogni 5 epoche o all'ultima epoca
    if epoch % 5 == 0 or epoch == total_epochs:
        # Salva un grafico con il progresso corrente
        plot_training_progress(f'training_plots/progress_epoch_{epoch}.png')
        
        # Salva anche un grafico sempre con lo stesso nome per facile accesso
        plot_training_progress()
        
        print(f"Saved training plots at epoch {epoch}")

# Fine training - salva il grafico finale
plot_training_progress('training_plots/final_training_progress.png')

total_duration = time.time() - start_training
print(f"Training completed in {total_duration/3600:.2f} hours")
print(f"Best validation loss: {min(val_losses):.4f} at epoch {val_losses.index(min(val_losses))+1}")


I0000 00:00:1750753042.022811 3063298 cuda_executor.cc:479] failed to allocate 1.55GiB (1661075456 bytes) from device: RESOURCE_EXHAUSTED: : CUDA_ERROR_OUT_OF_MEMORY: out of memory
2025-06-24 10:17:23,405 [INFO] Calcolo complessità delle molecole...
2025-06-24 10:20:45,233 [INFO] Curriculum inizializzato: min=0, max=25
2025-06-24 10:20:45,234 [INFO] Complessità iniziale: 5
2025-06-24 10:20:45,524 [INFO] Calcolo complessità delle molecole...
2025-06-24 10:21:08,483 [INFO] Curriculum inizializzato: min=0, max=22
2025-06-24 10:21:08,484 [INFO] Complessità iniziale: 5
2025-06-24 10:21:08,522 [INFO] [Epoca 0] Curriculum status: complessità=5/25, esempi=153839/1081448 (14.2%)
2025-06-24 10:21:08,523 [INFO] [Epoca 0] Curriculum status: complessità=5/22, esempi=17029/120161 (14.2%)
I0000 00:00:1750753277.327925 3063666 service.cc:152] XLA service 0x7c5a34002f90 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1750753277.327944 3063666 service.

16897/16897 - 520s - 31ms/step - loss: 1.0121 - val_loss: 0.7626 - learning_rate: 7.9347e-04 - lr: 7.9347e-04




16897/16897 - 473s - 28ms/step - loss: 0.7792 - val_loss: 0.7296 - learning_rate: 5.6107e-04 - lr: 5.6107e-04


2025-06-24 10:37:41,378 [INFO] Curriculum: complessità 5 -> 8 (warmup (3/10))
2025-06-24 10:37:41,379 [INFO] Esempi disponibili: 603447/1081448 (55.8%)
2025-06-24 10:37:41,380 [INFO] Curriculum: complessità 5 -> 7 (warmup (3/10))
2025-06-24 10:37:41,381 [INFO] Esempi disponibili: 49364/120161 (41.1%)


16897/16897 - 498s - 29ms/step - loss: 0.7561 - val_loss: 0.7141 - learning_rate: 4.5811e-04 - lr: 4.5811e-04


2025-06-24 10:45:59,360 [INFO] Curriculum: complessità 8 -> 10 (warmup (4/10))
2025-06-24 10:45:59,361 [INFO] Esempi disponibili: 849896/1081448 (78.6%)
2025-06-24 10:45:59,364 [INFO] Curriculum: complessità 7 -> 9 (warmup (4/10))
2025-06-24 10:45:59,365 [INFO] Esempi disponibili: 82042/120161 (68.3%)


16897/16897 - 524s - 31ms/step - loss: 0.7441 - val_loss: 0.7002 - learning_rate: 3.9674e-04 - lr: 3.9674e-04


2025-06-24 10:54:43,715 [INFO] Curriculum: complessità 10 -> 12 (warmup (5/10))
2025-06-24 10:54:43,716 [INFO] Esempi disponibili: 986766/1081448 (91.2%)
2025-06-24 10:54:43,718 [INFO] Curriculum: complessità 9 -> 11 (warmup (5/10))
2025-06-24 10:54:43,719 [INFO] Esempi disponibili: 103254/120161 (85.9%)


16897/16897 - 559s - 33ms/step - loss: 0.7400 - val_loss: 0.6920 - learning_rate: 3.5485e-04 - lr: 3.5485e-04


2025-06-24 11:04:02,996 [INFO] Curriculum: complessità 12 -> 15 (warmup (6/10))
2025-06-24 11:04:02,997 [INFO] Esempi disponibili: 1062679/1081448 (98.3%)
2025-06-24 11:04:02,999 [INFO] Curriculum: complessità 11 -> 13 (warmup (6/10))
2025-06-24 11:04:03,000 [INFO] Esempi disponibili: 113705/120161 (94.6%)


Saved training plots at epoch 5




16897/16897 - 590s - 35ms/step - loss: 0.7405 - val_loss: 0.6917 - learning_rate: 3.2393e-04 - lr: 3.2393e-04


2025-06-24 11:13:54,915 [INFO] Curriculum: complessità 15 -> 18 (warmup (7/10))
2025-06-24 11:13:54,916 [INFO] Esempi disponibili: 1079121/1081448 (99.8%)
2025-06-24 11:13:54,918 [INFO] Curriculum: complessità 13 -> 15 (warmup (7/10))
2025-06-24 11:13:54,919 [INFO] Esempi disponibili: 118073/120161 (98.3%)


16897/16897 - 498s - 29ms/step - loss: 0.7384 - val_loss: 0.6910 - learning_rate: 2.9990e-04 - lr: 2.9990e-04


2025-06-24 11:22:13,104 [INFO] Curriculum: complessità 18 -> 20 (warmup (8/10))
2025-06-24 11:22:13,105 [INFO] Esempi disponibili: 1081133/1081448 (100.0%)
2025-06-24 11:22:13,108 [INFO] Curriculum: complessità 15 -> 18 (warmup (8/10))
2025-06-24 11:22:13,109 [INFO] Esempi disponibili: 119913/120161 (99.8%)


16897/16897 - 529s - 31ms/step - loss: 0.7344 - val_loss: 0.6910 - learning_rate: 2.8053e-04 - lr: 2.8053e-04


2025-06-24 11:31:01,814 [INFO] Curriculum: complessità 20 -> 22 (warmup (9/10))
2025-06-24 11:31:01,814 [INFO] Esempi disponibili: 1081440/1081448 (100.0%)
2025-06-24 11:31:01,817 [INFO] Curriculum: complessità 18 -> 20 (warmup (9/10))
2025-06-24 11:31:01,818 [INFO] Esempi disponibili: 120124/120161 (100.0%)


16897/16897 - 488s - 29ms/step - loss: 0.7314 - val_loss: 0.6880 - learning_rate: 2.6449e-04 - lr: 2.6449e-04


2025-06-24 11:39:09,970 [INFO] Curriculum: complessità 22 -> 25 (warmup (10/10))
2025-06-24 11:39:09,971 [INFO] Esempi disponibili: 1081448/1081448 (100.0%)
2025-06-24 11:39:09,974 [INFO] Curriculum: complessità 20 -> 22 (warmup (10/10))
2025-06-24 11:39:09,974 [INFO] Esempi disponibili: 120161/120161 (100.0%)


16897/16897 - 512s - 30ms/step - loss: 0.7281 - val_loss: 0.6857 - learning_rate: 2.5092e-04 - lr: 2.5092e-04


2025-06-24 11:47:42,317 [INFO] [Epoca 10] Curriculum status: complessità=25/25, esempi=1081448/1081448 (100.0%)
2025-06-24 11:47:42,320 [INFO] [Epoca 10] Curriculum status: complessità=22/22, esempi=120161/120161 (100.0%)


Saved training plots at epoch 10




16897/16897 - 516s - 31ms/step - loss: 0.7259 - val_loss: 0.6840 - learning_rate: 2.3924e-04 - lr: 2.3924e-04




16897/16897 - 507s - 30ms/step - loss: 0.7237 - val_loss: 0.6818 - learning_rate: 2.2906e-04 - lr: 2.2906e-04




16897/16897 - 453s - 27ms/step - loss: 0.7218 - val_loss: 0.6806 - learning_rate: 2.2007e-04 - lr: 2.2007e-04




16897/16897 - 453s - 27ms/step - loss: 0.7203 - val_loss: 0.6790 - learning_rate: 2.1206e-04 - lr: 2.1206e-04




16897/16897 - 452s - 27ms/step - loss: 0.7190 - val_loss: 0.6776 - learning_rate: 2.0487e-04 - lr: 2.0487e-04


2025-06-24 12:27:24,953 [INFO] [Epoca 15] Curriculum status: complessità=25/25, esempi=1081448/1081448 (100.0%)
2025-06-24 12:27:24,955 [INFO] [Epoca 15] Curriculum status: complessità=22/22, esempi=120161/120161 (100.0%)


Saved training plots at epoch 15
16897/16897 - 452s - 27ms/step - loss: 0.7178 - val_loss: 0.6779 - learning_rate: 1.9837e-04 - lr: 1.9837e-04
16897/16897 - 453s - 27ms/step - loss: 0.7161 - val_loss: 0.6778 - learning_rate: 1.9245e-04 - lr: 1.9245e-04




16897/16897 - 451s - 27ms/step - loss: 0.7157 - val_loss: 0.6753 - learning_rate: 1.8702e-04 - lr: 1.8702e-04




16897/16897 - 452s - 27ms/step - loss: 0.7144 - val_loss: 0.6744 - learning_rate: 1.8203e-04 - lr: 1.8203e-04




16897/16897 - 452s - 27ms/step - loss: 0.7138 - val_loss: 0.6731 - learning_rate: 1.7743e-04 - lr: 1.7743e-04


2025-06-24 13:05:06,222 [INFO] [Epoca 20] Curriculum status: complessità=25/25, esempi=1081448/1081448 (100.0%)
2025-06-24 13:05:06,224 [INFO] [Epoca 20] Curriculum status: complessità=22/22, esempi=120161/120161 (100.0%)


Saved training plots at epoch 20
Training completed in 2.73 hours
Best validation loss: 0.6731 at epoch 20


In [None]:
# Dopo il training, per generare SMILES e calcolare le metriche:
generated, valid = monitor_callback.generate_smiles(10000)  # genera 1000 SMILES

# Calcola validità
validity = len(valid) / 10000

# Calcola unicità (solo tra quelle valide)
unique_valid = len(set(valid))
uniqueness = unique_valid / len(valid) if valid else 0

print(f"Validità: {validity*100:.2f}%")
print(f"Unicità: {uniqueness*100:.2f}%")

# Per la diversità, puoi per esempio calcolare le fingerprint e misurare la similarità media.
# Questo è un esempio con RDKit (assicurati di avere rdkit installato):
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs

fps = []
for s in valid:
    mol = Chem.MolFromSmiles(s)
    if mol:
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2)
        fps.append(fp)

# Calcola la similarità media tra tutte le coppie
similarities = []
for i in range(len(fps)):
    for j in range(i+1, len(fps)):
        sim = DataStructs.TanimotoSimilarity(fps[i], fps[j])
        similarities.append(sim)
        
avg_similarity = np.mean(similarities) if similarities else 0
diversity = 1 - avg_similarity

print(f"Diversità (1 - similarità media): {diversity:.2f}")

In [8]:
pwd

'/home/grad/Desktop/pietro/denovo/ab/ab2'