In [1]:
import tensorflow as tf

print(tf.__version__)
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))  # Should print 0
print("Num CPUs Available: ", len(tf.config.list_physical_devices('CPU')))  # Should print > 0

with tf.device('/cpu:0'):  # Explicitly use CPU (optiona§§l but good practice)
    a = tf.constant([1.0, 2.0, 3.0])
    b = tf.constant([4.0, 5.0, 6.0])
    c = a + b
print(c)

2025-06-11 09:40:32.265587: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-11 09:40:32.280845: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749627632.297694 3709216 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749627632.302877 3709216 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1749627632.316726 3709216 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

2.19.0
Num GPUs Available:  1
Num CPUs Available:  1
tf.Tensor([5. 7. 9.], shape=(3,), dtype=float32)


I0000 00:00:1749627634.885485 3709216 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 1426 MB memory:  -> device: 0, name: NVIDIA RTX A2000 12GB, pci bus id: 0000:65:00.0, compute capability: 8.6


In [2]:
# --- Configurazione Avanzata ---
class AdvancedConfig:
    SMILES_FILE = '/home/grad/Desktop/pietro/denovo/new/CHEMBL25.tar.xz'  # file con circa 26k SMILES
    BATCH_SIZE = 64
    EPOCHS = 200
    EMBED_DIM = 94
    TRANSFORMER_HEADS = 4
    TRANSFORMER_LAYERS = 4
    FF_DIM = 300
    VALID_RATIO = 0.1
    TEMPERATURE = 1.0
    TEMPERATURE_DECAY = 0.97
    GEN_NUM = 1
    WARMUP_EPOCHS = 35  # Per il curriculum learning iniziale
    MAX_RANDOMIZATIONS = 3
    MAX_LENGTH = 100
    PRINT_EVERY = 100
    DROPOUT_RATE = 0.15
    GRADIENT_CLIP = 1.0
    L2_REG = 1e-5
    LOSS_STABILITY_THRESHOLD = 0.01  # variazione relativa sotto il 1%
    CURRICULUM_START_COMPLEXITY = 0   # complessità iniziale (0 = molecole molto semplici)
    CURRICULUM_COMPLEXITY_STEP = 1    # incremento della soglia di complessità
    AUGMENT_PROB = 0.1

config = AdvancedConfig()

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import random
import time
from rdkit.Chem import MolFromSmiles, MolToSmiles
from sklearn.model_selection import train_test_split
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Embedding, GRU, Dense, Bidirectional, 
    LayerNormalization, Dropout, Attention, Concatenate,
    Layer, Multiply, Masking, RepeatVector,  GlobalAveragePooling1D
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (
    EarlyStopping, ModelCheckpoint, 
    LearningRateScheduler, Callback
)
import re
import logging
from typing import List, Optional, Tuple
from rdkit import Chem
from rdkit.Chem import MolFromSmiles, MolToSmiles, SanitizeMol, SanitizeFlags
from rdkit.Chem.rdmolops import AssignStereochemistry
import tensorflow as tf
from tensorflow.keras.layers import Layer, Embedding, Input, LayerNormalization, MultiHeadAttention, Dropout, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import Callback, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from threading import Lock
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # Per grafici 3D

In [4]:


# Configurazione del logger
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")



# --- Funzione per calcolare la complessità di una SMILES ---
def compute_complexity_from_tokens(tokens: List[str]) -> int:
    """
    Calcola la complessità di una SMILES come la somma del numero di anelli (ottenuti con GetSSSR)
    e del numero di ramificazioni (conteggio delle parentesi aperte).
    """
    smiles = ''.join(tokens)
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return float('inf')
        num_rings = Chem.GetSSSR(mol)
        num_branches = smiles.count('(')
        return num_rings + num_branches
    except Exception:
        return float('inf')

# --- Funzioni di Preprocessing ---
def randomize_smiles(smiles: str, num_versions: int = 3) -> List[str]:
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return []
    randomized = []
    for _ in range(num_versions):
        try:
            new_smiles = Chem.MolToSmiles(mol, doRandom=True, canonical=False)
            if new_smiles:
                randomized.append(new_smiles)
        except Exception as e:
            logger.debug(f"Randomization error: {e}")
            continue
    return randomized

def validate_and_fix_smiles(smiles: str) -> str:
    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=True)
        if mol is None:
            return None
        try:
            Chem.Kekulize(mol, clearAromaticFlags=True)
        except Exception as e:
            logger.debug(f"Kekulization error in SMILES {smiles}: {e}")
            return None
        return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False)
    except Exception as e:
        logger.debug(f"Parsing/sanitization error in SMILES {smiles}: {e}")
        return None

def robust_tokenize(smiles: str) -> list:
    pattern = (
        r"(\[[^\[\]]{1,6}\]|"                 # atomi in parentesi quadre
        r"Br|Cl|Si|Na|Mg|Mn|Ca|Fe|Zn|Se|Li|K|Al|B|"  # elementi multi-char
        r"R[0-9]|r[0-9]|a[0-9]|"             # ring labels
        r"[A-Za-z0-9@+\-\\\/\(\)=#\$\.\%,])"  # singoli caratteri, incluso '%'
    )
    tokens = re.findall(pattern, smiles)
    tokens = re.findall(pattern, smiles)
    stack = []
    for t in tokens:
        if t.startswith('['):
            stack.append(t)
        if t.endswith(']') and not stack:
            return []
        if t.endswith(']'):
            stack.pop()
    try:
        if not stack and Chem.MolFromSmiles(''.join(tokens)):
            return tokens
    except Exception as e:
        logger.debug(f"Tokenization error: {e}")
        return []
    return tokens


def process_dataset(data: List[str]) -> Tuple[List[List[str]], List[str], int]:
    processed = []
    all_tokens = set()
    for s in data:
        fixed = validate_and_fix_smiles(s)
        if not fixed:
            continue
        tokens = robust_tokenize(fixed)
        if tokens and 3 <= len(tokens) <= config.MAX_LENGTH - 2:
            processed.append(tokens)
            all_tokens.update(tokens)
    vocab = ['<PAD>', '<START>', '<END>'] + sorted(all_tokens)
    lengths = [len(t) for t in processed]
    max_len = min(int(np.percentile(lengths, 99)) + 2, config.MAX_LENGTH) if processed else config.MAX_LENGTH
    logger.info(f"Processed SMILES: {len(processed)}/{len(data)}")
    logger.info(f"Unique tokens: {len(all_tokens)}")
    logger.info(f"Max length: {max_len}")
    return processed, vocab, max_len

# --- Componenti del Modello ---
class DynamicPositionalEncoding(Layer):
    def __init__(self, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
    def build(self, input_shape):
        max_seq_len = input_shape[1]
        pos = np.arange(max_seq_len)[:, np.newaxis]
        i = np.arange(self.embed_dim)[np.newaxis, :]
        angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(self.embed_dim))
        angle_rads = pos * angle_rates
        angle_rads[:, 0::2] = tf.math.sin(angle_rads[:, 0::2])
        angle_rads[:, 1::2] = tf.math.cos(angle_rads[:, 1::2])
        self.pos_encoding = tf.cast(angle_rads[np.newaxis, ...], dtype=tf.float32)
    def call(self, inputs):
        return inputs + self.pos_encoding[:, :tf.shape(inputs)[1], :]
    def get_config(self):
        base_config = super().get_config()
        base_config.update({"embed_dim": self.embed_dim})
        return base_config

class ImprovedTransformerBlock(Layer):
    def __init__(self, embed_dim, num_heads, ffn_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.rate = rate
        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim,
            kernel_regularizer=tf.keras.regularizers.l2(config.L2_REG),
            dropout=rate
        )
        self.ffn = tf.keras.Sequential([
            Dense(ffn_dim, activation="gelu", kernel_regularizer=tf.keras.regularizers.l2(config.L2_REG)),
            Dense(embed_dim, kernel_regularizer=tf.keras.regularizers.l2(config.L2_REG))
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)
    
    def call(self, inputs, training=False):
        seq_len = tf.shape(inputs)[1]
        causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        attn_output = self.mha(inputs, inputs, attention_mask=causal_mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)
    
    def get_config(self):
        base_config = super().get_config()
        base_config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ffn_dim": self.ffn_dim,
            "rate": self.rate
        })
        return base_config

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, embed_dim, warmup_steps=10000):
        super().__init__()
        self.embed_dim = tf.cast(embed_dim, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)
    def __call__(self, step):
        step = tf.cast(step, tf.float32) + 1e-9
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.embed_dim) * tf.math.minimum(arg1, arg2)
    def get_config(self):
        return {"embed_dim": self.embed_dim.numpy(), "warmup_steps": self.warmup_steps.numpy()}

def build_improved_model(vocab_size: int) -> Model:
    inputs = Input(shape=(config.MAX_LENGTH,))
    x = Embedding(vocab_size, config.EMBED_DIM, mask_zero=True)(inputs)
    x = DynamicPositionalEncoding(config.EMBED_DIM)(x)
    x = Dropout(config.DROPOUT_RATE)(x)
    for _ in range(config.TRANSFORMER_LAYERS):
        x = ImprovedTransformerBlock(config.EMBED_DIM, config.TRANSFORMER_HEADS, config.FF_DIM, rate=config.DROPOUT_RATE)(x)
    outputs = Dense(vocab_size)(x)
    def smoothed_loss(y_true, y_pred):
        # cast a interi perché tf.nn.sparse_* vuole etichette int32 o int64
        y_true_int = tf.cast(y_true, tf.int32)
        mask = tf.cast(tf.math.not_equal(y_true, 0), tf.float32)
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true_int, logits=y_pred)
        return tf.reduce_sum(loss * mask) / (tf.reduce_sum(mask) + 1e-9)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=CustomSchedule(config.EMBED_DIM),
        clipnorm=config.GRADIENT_CLIP
    )
    model = Model(inputs, outputs)
    model.compile(optimizer=optimizer, loss=smoothed_loss)
    return model

# --- Generatori Dati e Utilities ---
class ThreadSafeIterator:
    def __init__(self, iterator):
        self.iterator = iterator
        self.lock = Lock()
    def __iter__(self):
        return self
    def __next__(self):
        with self.lock:
            return next(self.iterator)

def threadsafe_generator(func):
    def wrapper(*args, **kwargs):
        return ThreadSafeIterator(func(*args, **kwargs))
    return wrapper

class CurriculumSmilesGenerator:
    """
    Generatore che implementa il curriculum learning basato sulla complessità:
      - Durante le prime WARMUP_EPOCHS la soglia di complessità aumenta.
      - Se la loss si stabilizza (variazione < LOSS_STABILITY_THRESHOLD), la soglia viene incrementata.
      - Applica data augmentation con probabilità AUGMENT_PROB.
      - Le molecole non valide vengono scartate.
    """
    def __init__(self, tokenized_smiles: List[List[str]], vocab: List[str]):
        self.char2idx = {c: i for i, c in enumerate(vocab)}
        self.idx2char = {i: c for c, i in self.char2idx.items()}
        self.original_data = []
        for tokens in tokenized_smiles:
            comp = compute_complexity_from_tokens(tokens)
            fixed = validate_and_fix_smiles(''.join(tokens))
            if fixed is None:
                continue
            self.original_data.append((tokens, comp))
        if self.original_data:
            valid_comps = [comp for _, comp in self.original_data if comp != float('inf')]
            self.max_complexity = max(valid_comps) if valid_comps else 0
        else:
            self.max_complexity = 0
        self.current_complexity = config.CURRICULUM_START_COMPLEXITY
        self.available_data = self._filter_data()
        self.train_smiles = {''.join(tokens) for tokens, _ in self.original_data}
        self.lock = Lock()
    
    def _filter_data(self):
        filtered = [tokens for tokens, comp in self.original_data if comp <= self.current_complexity]
        return filtered or [tokens for tokens, comp in self.original_data]
    
    def update_complexity(self, epoch: int, loss_diff: float = None):
        with self.lock:
            if loss_diff is not None and loss_diff < config.LOSS_STABILITY_THRESHOLD:
                self.current_complexity = min(self.current_complexity + config.CURRICULUM_COMPLEXITY_STEP, self.max_complexity)
            else:
                if epoch <= config.WARMUP_EPOCHS:
                    increment = int((self.max_complexity - config.CURRICULUM_START_COMPLEXITY) * (epoch / config.WARMUP_EPOCHS))
                    self.current_complexity = config.CURRICULUM_START_COMPLEXITY + increment
                else:
                    self.current_complexity = self.max_complexity
            self.available_data = self._filter_data()
            if not self.available_data:
                self.available_data = [tokens for tokens, comp in self.original_data]
                logger.warning("Reset available_data to original")
    
    @threadsafe_generator
    def __call__(self):
        while True:
            inputs = np.full((config.BATCH_SIZE, config.MAX_LENGTH), self.char2idx['<PAD>'], dtype=np.int32)
            targets = np.full_like(inputs, self.char2idx['<PAD>'])
            for i in range(config.BATCH_SIZE):
                with self.lock:
                    try:
                        tokens = random.choice(self.available_data)
                    except IndexError:
                        self.available_data = [tokens for tokens, comp in self.original_data]
                        tokens = random.choice(self.available_data)
                if random.random() < config.AUGMENT_PROB:
                    try:
                        augmented = randomize_smiles(''.join(tokens))
                        if augmented:
                            new_tokens = robust_tokenize(random.choice(augmented)) or tokens
                            tokens = new_tokens
                    except Exception as e:
                        logger.debug(f"Augmentation error: {e}")
                seq = ['<START>'] + tokens + ['<END>']
                padded = (seq + ['<PAD>'] * config.MAX_LENGTH)[:config.MAX_LENGTH]
                inputs[i] = [self.char2idx.get(t, self.char2idx['<PAD>']) for t in padded]
                targets[i, :-1] = inputs[i][1:]
                targets[i, -1] = self.char2idx['<PAD>']
            yield inputs, targets

    def get_dataset(self):
        return tf.data.Dataset.from_generator(
            self.__call__,
            output_signature=(
                tf.TensorSpec(shape=(config.BATCH_SIZE, config.MAX_LENGTH), dtype=tf.int32),
                tf.TensorSpec(shape=(config.BATCH_SIZE, config.MAX_LENGTH), dtype=tf.int32)
            )
        ).prefetch(tf.data.AUTOTUNE)

# --- Callback per il monitoraggio ---
class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        lr = self.model.optimizer.learning_rate
        if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
            logs['lr'] = lr(epoch).numpy()
        else:
            logs['lr'] = lr.numpy()
        super().on_epoch_end(epoch, logs)

class EnhancedTrainingMonitor(Callback):
    def __init__(self, val_gen: CurriculumSmilesGenerator):
        super().__init__()
        self.val_gen = val_gen
        self.best_val_loss = np.inf
        self.prev_val_loss = None
    def generate_smiles(self, num: int) -> Tuple[List[str], List[str]]:
        generated, valid = [], []
        if self.model is None:
            raise ValueError("Il modello non è stato assegnato al callback!")
        input_seq = np.full((1, config.MAX_LENGTH), self.val_gen.char2idx['<PAD>'], dtype=np.int32)
        input_seq[0, 0] = self.val_gen.char2idx['<START>']
        for _ in range(num):
            for t in range(1, config.MAX_LENGTH):
                logits = self.model(input_seq, training=False)[0, t-1]
                probs = tf.nn.softmax(logits / config.TEMPERATURE).numpy()
                valid_indices = [i for i, tok in self.val_gen.idx2char.items() 
                                 if re.match(r'^([A-Za-z0-9@#\[\]()+\-\\/%=:.,]|<\w+>)$', tok)]
                probs[[i for i in range(len(probs)) if i not in valid_indices]] = 0
                if np.sum(probs) == 0:
                    break
                probs /= np.sum(probs)
                sampled = np.random.choice(len(probs), p=probs)
                input_seq[0, t] = sampled
                if sampled == self.val_gen.char2idx['<END>']:
                    break
            raw = ''.join([self.val_gen.idx2char[i] for i in input_seq[0]
                           if i not in {self.val_gen.char2idx.get('<PAD>'), self.val_gen.char2idx['<END>']}][1:])
            final = validate_and_fix_smiles(raw) or raw
            generated.append(final)
            if Chem.MolFromSmiles(final):
                valid.append(final)
            input_seq = np.full((1, config.MAX_LENGTH), self.val_gen.char2idx['<PAD>'], dtype=np.int32)
            input_seq[0, 0] = self.val_gen.char2idx['<START>']
        return generated, valid
    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % config.PRINT_EVERY == 0:
            generated, valid = self.generate_smiles(config.GEN_NUM)
            validity = len(valid) / config.GEN_NUM
            unique = len(set(valid))
            novel = len([s for s in valid if s not in self.val_gen.train_smiles])
            logger.info(f"\n🧪 Epoca {epoch+1}:")
            logger.info(f"Loss Training: {logs.get('loss', 'N/A'):.4f} - Loss Validation: {logs.get('val_loss', 'N/A'):.4f}")
            logger.info(f"Validità: {validity:.1%}")
            logger.info(f"Unicità: {unique}/{len(valid)}")
            logger.info(f"Novità: {novel}/{len(valid)}")
            if valid:
                logger.info("Esempi:")
                for s in valid[:3]:
                    logger.info(f"- {s}")
            with open("valid_generated_smiles.txt", "w") as f:
                for s in valid:
                    f.write(s + "\n")

# --- Main ---
if __name__ == "__main__":
    logger.info("🚀 Avvio Training Transformer per SMILES con tokenizzazione BPE, data augmentation e curriculum learning basato sulla complessità")
    with open(config.SMILES_FILE) as f:
        raw_smiles = [line.strip() for line in f if line.strip()]
    logger.info("🔍 Validazione SMILES...")
    valid_smiles = []
    for idx, s in enumerate(raw_smiles):
        if idx % 5000 == 0:
            logger.info(f"Processati {idx}/{len(raw_smiles)}")
        fixed = validate_and_fix_smiles(s)
        if fixed and 3 <= len(fixed) <= config.MAX_LENGTH:
            valid_smiles.append(fixed)
    processed, vocab, max_len = process_dataset(valid_smiles)
    config.MAX_LENGTH = max_len
    logger.info("Vocabolario creato:")
    logger.info(vocab)
    
    import pickle
    with open('char2idx.pkl', 'wb') as f:
       pickle.dump({char: idx for idx, char in enumerate(vocab)}, f)
    with open('idx2char.pkl', 'wb') as f:
       pickle.dump({idx: char for idx, char in enumerate(vocab)}, f)
    import json
    with open("vocab.json", "w") as f:
        json.dump(vocab, f)
    logger.info("Mappature e vocabolario salvati su disco.")
    
    stratify_labels = [min(len(t), 20) for t in processed]
    try:
        counts = np.bincount(stratify_labels)
        if np.min(counts[np.nonzero(counts)]) < 2:
            logger.warning("Stratificazione disabilitata: alcune classi hanno meno di 2 esempi.")
            stratify_param = None
        else:
            stratify_param = stratify_labels
    except Exception as e:
        logger.warning(f"Errore nella stratificazione: {e}. Disabilito stratificazione.")
        stratify_param = None

    train_data, val_data = train_test_split(
        processed,
        test_size=config.VALID_RATIO,
        stratify=stratify_param,
        random_state=42
    )

2025-06-11 09:40:35,217 [INFO] 🚀 Avvio Training Transformer per SMILES con tokenizzazione BPE, data augmentation e curriculum learning basato sulla complessità


UnicodeDecodeError: 'utf-8' codec can't decode byte 0xfd in position 0: invalid start byte

In [None]:
    model = build_improved_model(len(vocab))
    train_gen = CurriculumSmilesGenerator(train_data, vocab)
    val_gen = CurriculumSmilesGenerator(val_data, vocab)
    
    train_gen.update_complexity(0)
    val_gen.update_complexity(0)
    
    monitor_callback = EnhancedTrainingMonitor(val_gen)
    
    callbacks = [
        monitor_callback,
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6),
        CustomTensorBoard(log_dir='./logs'),
        tf.keras.callbacks.ModelCheckpoint('diff_scaffold_2.h5', save_best_only=True, monitor='val_loss')
    ]
    
    steps_per_epoch = len(train_data) // config.BATCH_SIZE
    val_steps = len(val_data) // config.BATCH_SIZE
    total_epochs = config.EPOCHS
    start_training = time.time()
    
    for epoch in range(1, total_epochs + 1):
        epoch_start = time.time()
        current_history = model.fit(
            train_gen.get_dataset(),
            steps_per_epoch=steps_per_epoch,
            callbacks=callbacks,
            epochs=1,
            validation_data=val_gen.get_dataset(),
            validation_steps=val_steps,
            verbose=2,
        )
        epoch_duration = time.time() - epoch_start
        current_val_loss = current_history.history['val_loss'][0]
        if epoch > 1:
            loss_diff = abs(current_val_loss - prev_val_loss) / prev_val_loss
        else:
            loss_diff = None
        prev_val_loss = current_val_loss

        train_gen.update_complexity(epoch, loss_diff=loss_diff)
        val_gen.update_complexity(epoch, loss_diff=loss_diff)

In [None]:
# Dopo il training, per generare SMILES e calcolare le metriche:
generated, valid = monitor_callback.generate_smiles(10000)  # genera 1000 SMILES

# Calcola validità
validity = len(valid) / 10000

# Calcola unicità (solo tra quelle valide)
unique_valid = len(set(valid))
uniqueness = unique_valid / len(valid) if valid else 0

print(f"Validità: {validity*100:.2f}%")
print(f"Unicità: {uniqueness*100:.2f}%")

# Per la diversità, puoi per esempio calcolare le fingerprint e misurare la similarità media.
# Questo è un esempio con RDKit (assicurati di avere rdkit installato):
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs

fps = []
for s in valid:
    mol = Chem.MolFromSmiles(s)
    if mol:
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2)
        fps.append(fp)

# Calcola la similarità media tra tutte le coppie
similarities = []
for i in range(len(fps)):
    for j in range(i+1, len(fps)):
        sim = DataStructs.TanimotoSimilarity(fps[i], fps[j])
        similarities.append(sim)
        
avg_similarity = np.mean(similarities) if similarities else 0
diversity = 1 - avg_similarity

print(f"Diversità (1 - similarità media): {diversity:.2f}")