<a href="https://colab.research.google.com/github/quanghuydsai/Project-III/blob/main/LDMOL_ENCODING.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch transformers pathlib rdkit



In [None]:
from google.colab import drive
from google.colab import files

# Mount Google Drive under /content/drive; later cells read and write files below that path.
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
import re
class regexTokenizer():
    """Regex-driven SMILES tokenizer backed by a one-token-per-line vocabulary file.

    The id of a vocabulary entry is its zero-based line index in the file.
    The vocabulary must contain the ``[CLS]``, ``[SEP]`` and ``[PAD]`` entries.

    :param vocab_path: path to the vocabulary text file; every ``##`` marker
        is removed from the entries.
    :param max_len: fixed number of ids produced for each encoded string.
    """

    def __init__(self,vocab_path='/content/drive/MyDrive/LDMOL/vocab_bpe_300_sc.txt',max_len=127):
        with open(vocab_path, 'r') as f:
            tokens = [line.replace('##', '').strip() for line in f]

        # Longest tokens first so the alternation prefers e.g. "Cl" over "C".
        # Tokens are stripped BEFORE escaping: escaping the raw line and then
        # chopping the escaped newline cut the last character off the final
        # token whenever the file had no trailing newline.
        # Blank entries keep their id but are left out of the pattern, where
        # an empty alternative would match at every position.
        by_length = sorted((tok for tok in tokens if tok), key=len, reverse=True)
        pattern = "(" + "|".join(re.escape(tok) for tok in by_length) + ")"
        self.rg = re.compile(pattern)

        self.idtotok = dict(enumerate(tokens))
        self.vocab_size = len(self.idtotok)
        self.toktoid = {tok: idx for idx, tok in self.idtotok.items()}
        self.max_len = max_len
        self.cls_token_id = self.toktoid['[CLS]']
        self.sep_token_id = self.toktoid['[SEP]']
        self.pad_token_id = self.toktoid['[PAD]']

    def decode_one(self, iter):
        """Turn one 1-D tensor of ids back into a string.

        Everything from the first ``[SEP]`` id onwards is discarded and the
        first remaining id is skipped (it is the leading ``[CLS]`` added by
        `encode_one`).

        :param iter: 1-D tensor of token ids.
        :return: the concatenated token strings.
        """
        ids = iter  # public parameter name kept; avoid using the shadowed builtin below
        sep_positions = (ids == self.sep_token_id).nonzero(as_tuple=True)[0]
        if sep_positions.numel() > 0:
            ids = ids[:sep_positions[0].item()]
        return "".join(self.idtotok[i.item()] for i in ids[1:])

    def decode(self,ids:torch.Tensor):
        """Decode a 1-D tensor or a batch (one row per sequence) of ids.

        :param ids: tensor of token ids.
        :return: list of decoded strings (a single-element list for 1-D input).
        """
        if len(ids.shape)==1:
            return [self.decode_one(ids)]
        return [self.decode_one(row) for row in ids]

    def __len__(self):
        """Return the number of vocabulary entries."""
        return self.vocab_size

    def __call__(self,smis:list, truncation='max_len'):
        """Encode one string or a list of strings into a 2-D LongTensor.

        :param smis: a single string or an iterable of strings.
        :param truncation: ``'max_len'`` returns ``max_len`` columns;
            ``'longest'`` cuts the columns at the longest (untruncated) token
            count in the batch.
        :return: tensor with one row per input string; an empty input gives a
            tensor with zero rows.
        :raises ValueError: if ``truncation`` is neither of the two options.
        """
        # Validate up front so a bad option fails before any encoding work.
        if truncation not in ('max_len', 'longest'):
            raise ValueError('truncation should be either max_len or longest')
        if isinstance(smis, str):
            smis = [smis]

        encoded = [self.encode_one(smi) for smi in smis]
        if not encoded:
            # torch.concat rejects an empty list; return an empty batch instead.
            width = self.max_len if truncation == 'max_len' else 0
            return torch.zeros((0, width), dtype=torch.long)

        output = torch.concat([tensor for _, tensor in encoded], dim=0)
        if truncation == 'max_len':
            return output
        return output[:, :max(length for length, _ in encoded)]

    def encode_one(self, smi):
        """Encode a single string as ``[CLS] tokens [SEP]`` padded to ``max_len``.

        Characters that match no vocabulary token are skipped by the regex.
        Sequences longer than ``max_len`` are cut, in which case the result no
        longer ends with ``[SEP]``.

        :param smi: the string to encode.
        :return: ``(token_count_before_padding_or_cut, tensor of shape (1, max_len))``.
        """
        smi = '[CLS]' + smi + '[SEP]'
        res = [self.toktoid[tok] for tok in self.rg.findall(smi)]
        token_length = len(res)
        if token_length < self.max_len:
            res += [self.pad_token_id] * (self.max_len - token_length)
        else:
            res = res[:self.max_len]
        return token_length, torch.tensor([res], dtype=torch.long)

In [None]:
from rdkit import RDLogger
# Turn off RDKit log output for every logger matching 'rdApp.*'.
RDLogger.DisableLog('rdApp.*')

In [None]:
from abc import abstractmethod
from random import shuffle
from typing import Any
from typing import Iterable
from typing import List
from typing import Union

import numpy as np
from rdkit import Chem

In [None]:
# Short alias for the RDKit molecule class, used in the type hints of the augmenter classes.
Mol = Chem.Mol
class Augmenter:
    """Base class for augmenters; subclasses supply the actual transformation.

    Subclasses override `_augment`. An instance can be called with either a
    single item or an iterable of items; the input is normalised to a
    collection and handed to `augment`, which returns the processed list.
    Setting the boolean ``active`` attribute to False turns augmentation off.

    :param active: whether augmentation is applied, defaults to True.
    :param augment_prob: stored for subclasses, which may use it to skip
        augmentation for individual items when it is lower than 1.
    """

    def __init__(self, active: bool = True, augment_prob: float = 1.0) -> None:
        self.augment_prob = augment_prob
        self.active = active

    def __call__(self, data: Union[Iterable[Any], Any]) -> List[Any]:
        """Augment a single item or an iterable of items.

        :param data: one item, or an iterable of items, to be augmented.
        :return: a list of (possibly augmented) items.
        """
        # A string is iterable but stands for ONE item (e.g. a SMILES string),
        # so it is wrapped exactly like any non-iterable value.
        is_single_item = isinstance(data, str) or not isinstance(data, Iterable)
        batch = [data] if is_single_item else data
        return self.augment(batch)

    @abstractmethod
    def _augment(self, data: Iterable[Any]) -> List[Any]:
        """Transformation hook; the base implementation always raises."""
        raise NotImplementedError()

    def augment(self, data: Iterable[Any]) -> List[Any]:
        """Apply `_augment` when active, otherwise return the items unchanged.

        :param data: the items to be augmented.
        :return: a list of augmented items, or a plain list copy of the input
            when the augmenter is inactive.
        """
        if not self.active:
            return list(data)
        return self._augment(data)
class MolAugmenter(Augmenter):
    """Augmenter for RDKit ``Mol`` objects that permutes their atom numbering."""

    def randomize_mols_restricted(self, mols: Iterable[Mol]) -> List[Mol]:
        """Randomize the atom ordering of every molecule in ``mols``.

        :param mols: RDKit molecules to be augmented.
        :return: list with one (possibly renumbered) molecule per input, in input order.
        """
        return [self.randomize_mol_restricted(single) for single in mols]

    def randomize_mol_restricted(self, mol: Mol) -> Mol:
        """Return ``mol`` with a randomly permuted atom order.

        One uniform draw decides whether the molecule is returned untouched;
        that happens when the draw exceeds ``augment_prob``.

        :param mol: RDKit molecule to renumber.
        :return: the renumbered molecule, or ``mol`` itself when skipped.
        """
        if np.random.rand() > self.augment_prob:
            return mol
        # Original author's note: the standard-library shuffle measured about
        # 35% slower than numpy's here (not re-verified).
        new_order: List[int] = list(range(mol.GetNumAtoms()))
        np.random.shuffle(new_order)
        return Chem.RenumberAtoms(mol, new_order)

    def _augment(self, data: Iterable[Mol]) -> List[Mol]:
        """Augmentation hook: shuffle the atom order of each molecule.

        :param data: RDKit molecules to be randomized.
        :return: list of randomized molecules.
        """
        return self.randomize_mols_restricted(data)

In [None]:
from torch.utils.data import Dataset
import random
from rdkit import Chem, RDLogger
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers

In [None]:
class SMILESDataset_pretrain(Dataset):
    """Line-based SMILES dataset that samples stereoisomers on access.

    Each line of the file is one record; only the text before the first tab
    is used as the SMILES string.

    :param data_path: path of the text file to read.
    :param data_length: optional ``(start, stop)`` pair of zero-based line
        indices; only lines ``start <= i < stop`` are loaded. Reading stops at
        the end of the file, so a ``stop`` beyond it no longer pads the
        dataset with empty records.
    :param shuffle: shuffle the loaded records once, in place.
    :param is_train: when True, items may be returned as a pair of two
        different stereoisomers.
    """

    def __init__(self, data_path, data_length=None, shuffle=False, is_train=True):
        with open(data_path, 'r') as f:
            if data_length is None:
                lines = f.readlines()
            else:
                start, stop = data_length[0], data_length[1]
                lines = []
                # Iterating the file ends at EOF; readline() past EOF would
                # keep returning '' and create bogus empty records.
                for line_no, line in enumerate(f):
                    if line_no >= stop:
                        break
                    if line_no >= start:
                        lines.append(line)

        self.data = [l.strip() for l in lines]

        if shuffle:
            random.shuffle(self.data)

        self.train = is_train

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the record at ``index`` as a ``[CLS]``-prefixed string.

        In training mode, when the molecule has more than one enumerated
        stereoisomer, two distinct isomers are sampled and returned as
        ``'[CLS]' + smiles + 'Q[CLS]' + smiles2`` (if their SMILES differ).
        Otherwise a single randomly chosen isomer is returned. If RDKit
        processing fails, the values computed so far are used; for a string
        that cannot be parsed this is the raw string from the file.
        """
        smiles = self.data[index].split('\t')[0]
        smiles2 = smiles
        # Probability gate for the stereoisomer sampling; with a threshold of
        # 0 it is effectively always taken.
        if random.random() > 0.:
            try:
                mol = Chem.MolFromSmiles(smiles)
                sc_list = list(EnumerateStereoisomers(mol))
                if self.train and len(sc_list) > 1:
                    mol, mol2 = random.sample(sc_list, k=2)
                    smiles = Chem.MolToSmiles(mol, canonical=True, isomericSmiles=True)
                    smiles2 = Chem.MolToSmiles(mol2, canonical=True, isomericSmiles=True)
                else:
                    mol = random.choice(sc_list)
                    smiles = Chem.MolToSmiles(mol, canonical=True, isomericSmiles=True)
            except Exception:
                # Best effort: keep whatever was computed so far. `Exception`
                # (not a bare except) so Ctrl-C / SystemExit still propagate.
                pass
        if self.train and smiles2 != smiles:
            return '[CLS]'+smiles+'Q[CLS]'+smiles2
        return '[CLS]' + smiles

In [None]:
from transformers.models.bert.configuration_bert import BertConfig

In [None]:
!pip install tqdm torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m34.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0


In [None]:
from torch_geometric.data import InMemoryDataset
class PubChemDataset(InMemoryDataset):
    """In-memory dataset restored from a single file written with ``torch.save``."""

    def __init__(self, path):
        """Load the ``(data, slices)`` pair stored at ``path``.

        ``weights_only=False`` lets ``torch.load`` unpickle arbitrary Python
        objects, so only trusted files should be opened with this class.
        """
        super().__init__()
        loaded = torch.load(path, weights_only=False)
        self.data, self.slices = loaded

    def __getitem__(self, idx):
        """Return the item at ``idx`` by delegating directly to ``get``."""
        return self.get(idx)

In [None]:
# Load the pre-training split from Google Drive.
raw_dataset = PubChemDataset('/content/drive/MyDrive/LDMOL/pretrain.pt')

In [None]:
# Reduce every record to a plain dict holding only its SMILES string and its text.
ldmol_dataset = [
    {'smiles': record.smiles, 'text': record.text}
    for record in raw_dataset
]

In [None]:
# Sanity check: show the first two converted records.
print(ldmol_dataset[0:2])

[{'smiles': 'CC(=O)OC(CC(=O)[O-])C[N+](C)(C)C', 'text': 'The molecule is an O-acylcarnitine having acetyl as the acyl substituent. It has a role as a human metabolite. It is functionally related to an acetic acid. It is a conjugate base of an O-acetylcarnitinium.\nThe molecule is a natural product found in Pseudo-nitzschia multistriata, Euglena gracilis, and other organisms with data available.\nThe molecule is a metabolite found in or produced by Saccharomyces cerevisiae.\nAn acetic acid ester of CARNITINE that facilitates movement of ACETYL COA into the matrices of mammalian MITOCHONDRIA during the oxidation of FATTY ACIDS.'}, {'smiles': 'C(CCl)C(F)(F)F', 'text': '3-chloro-1,1,1-trifluoropropane appears as a colorless odorless nonflammable liquid. Poisonous by inhalation. Emits toxic fumes of chlorine and fluorine when heated to decomposition.'}]


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
from transformers import BertModel
from rdkit import Chem
import tqdm

In [None]:
# Keyword arguments for transformers' BertConfig; LDMolEncoder unpacks this
# dict with ** when it builds its encoders.
# NOTE(review): "vocab_size" and "pad_token_id" presumably have to agree with
# the tokenizer vocabulary file — confirm against the vocab actually used.
BERT_CONFIG_DICT = {
    "attention_probs_dropout_prob": 0.1,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 1024,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "layer_norm_eps": 1e-12,
    "max_position_embeddings": 512,
    "model_type": "bert",
    "num_attention_heads": 16,
    "num_hidden_layers": 12,
    "pad_token_id": 0,
    "type_vocab_size": 2,
    "vocab_size": 300
}

In [None]:
class ListDataset(Dataset):
    """Thin ``Dataset`` adapter around an already-materialised sequence."""

    def __init__(self, data_list):
        """Keep a reference to ``data_list`` (no copy is made)."""
        self.data = data_list

    def __len__(self) -> int:
        """Return how many items the wrapped sequence holds."""
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        item = self.data[idx]
        return item

In [None]:
# Hyper-parameters shared by the encoder and the training loop.
TRAIN_CONFIG = {
    'embed_dim': 256,  # output size of the projection heads and row count of the feature queues
    'batch_size': 32, # Lower this if the GPU runs out of memory on Colab
    'temp': 0.07,  # initial value of the learnable temperature
    'queue_size': 16384,  # number of feature columns kept in each queue
    'momentum': 0.995,  # moving-average coefficient for the momentum encoder
    'alpha': 0.4,  # weight of the momentum-model soft targets in the loss
    'lr': 1e-4,  # AdamW learning rate
    'weight_decay': 0.02,  # AdamW weight decay
    'epochs': 5,  # number of epochs in the training schedule
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'  # 'cuda' when available, otherwise 'cpu'
}

In [None]:
class LDMolEncoder(nn.Module):
    """Contrastive encoder with a momentum copy and two feature queues.

    A BERT encoder followed by a linear projection embeds two tokenized views
    of the same batch. `forward` returns a symmetric contrastive loss whose
    negatives are the momentum features of the current batch plus the
    features stored in the queues.
    """
    def __init__(self, config, bert_config_dict):
        """Build the online encoder, its momentum copy and the queues.

        :param config: mapping read for 'embed_dim', 'temp' and 'queue_size'
            here, and for 'momentum' / 'queue_size' in the helper methods.
        :param bert_config_dict: keyword arguments unpacked into `BertConfig`.
        """
        super().__init__()
        self.config = config

        # Build the BertConfig from the dict
        bert_config = BertConfig(**bert_config_dict)
        self.text_encoder = BertModel(config=bert_config)

        text_width = self.text_encoder.config.hidden_size
        self.text_proj = nn.Linear(text_width, config['embed_dim'])
        # Augmenter instance exposed to callers as `model.aug`.
        self.aug = MolAugmenter()

        # Momentum models: same architecture as the online encoder/projection
        self.text_encoder_m = BertModel(config=bert_config)
        self.text_proj_m = nn.Linear(text_width, config['embed_dim'])

        # Freeze gradients of the momentum models
        for p in self.text_encoder_m.parameters(): p.requires_grad = False
        for p in self.text_proj_m.parameters(): p.requires_grad = False

        # Start the momentum models from the online weights.
        self.copy_params()

        # Contrastive queues: one column per stored feature, randomly
        # initialised and normalised along the feature dimension below.
        self.temp = nn.Parameter(torch.ones([]) * config['temp'])
        self.register_buffer("text1_queue", torch.randn(config['embed_dim'], config['queue_size']))
        self.register_buffer("text2_queue", torch.randn(config['embed_dim'], config['queue_size']))
        # Index of the next queue column to overwrite.
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

        self.text1_queue = F.normalize(self.text1_queue, dim=0)
        self.text2_queue = F.normalize(self.text2_queue, dim=0)

    @torch.no_grad()
    def copy_params(self):
        """Copy the online encoder and projection weights into the momentum models."""
        for p, p_m in zip(self.text_encoder.parameters(), self.text_encoder_m.parameters()):
            p_m.data.copy_(p.data)
        for p, p_m in zip(self.text_proj.parameters(), self.text_proj_m.parameters()):
            p_m.data.copy_(p.data)

    @torch.no_grad()
    def _momentum_update(self):
        """Move the momentum weights towards the online weights.

        Each momentum parameter becomes ``m * momentum + (1 - m) * online``
        with ``m = config['momentum']``.
        """
        m = self.config['momentum']
        for p, p_m in zip(self.text_encoder.parameters(), self.text_encoder_m.parameters()):
            p_m.data = p_m.data * m + p.data * (1. - m)
        for p, p_m in zip(self.text_proj.parameters(), self.text_proj_m.parameters()):
            p_m.data = p_m.data * m + p.data * (1. - m)

    def forward(self, ids1, mask1, ids2, mask2, alpha=0):
        """Compute the symmetric contrastive loss between two views.

        Side effects on every call: the momentum models are updated, the
        temperature is clamped in place to [0.07, 0.5] and the momentum
        features of this batch are written into the queues.

        :param ids1: token ids of the first view, passed to the encoders
            (assumes shape (batch, seq_len) — TODO confirm).
        :param mask1: attention mask matching ``ids1``.
        :param ids2: token ids of the second view.
        :param mask2: attention mask matching ``ids2``.
        :param alpha: weight of the momentum-model softmax in the targets;
            ``1 - alpha`` weights the one-hot targets.
        :return: mean of the two directional losses.
        """
        # Update the momentum model
        self._momentum_update()

        # Features from the online encoder: hidden state at position 0,
        # projected and L2-normalised.
        out1 = self.text_encoder(ids1, attention_mask=mask1).last_hidden_state[:, 0, :]
        feat1 = F.normalize(self.text_proj(out1), dim=-1)

        out2 = self.text_encoder(ids2, attention_mask=mask2).last_hidden_state[:, 0, :]
        feat2 = F.normalize(self.text_proj(out2), dim=-1)

        with torch.no_grad():
            self.temp.clamp_(0.07, 0.5)

            # Momentum features, concatenated column-wise with the queue so
            # the current batch occupies the first `batch_size` columns.
            m_out1 = self.text_encoder_m(ids1, attention_mask=mask1).last_hidden_state[:, 0, :]
            feat1_m = F.normalize(self.text_proj_m(m_out1), dim=-1)
            feat1_all = torch.cat([feat1_m.t(), self.text1_queue.clone().detach()], dim=1)

            m_out2 = self.text_encoder_m(ids2, attention_mask=mask2).last_hidden_state[:, 0, :]
            feat2_m = F.normalize(self.text_proj_m(m_out2), dim=-1)
            feat2_all = torch.cat([feat2_m.t(), self.text2_queue.clone().detach()], dim=1)

            # Targets
            batch_size = feat1.size(0)
            # Target matrix of shape [batch_size, batch_size + queue_size]
            sim_targets = torch.zeros(batch_size, feat1_all.size(1)).to(feat1.device)
            # Put 1 on the diagonal of the current-batch part: sample i's
            # positive is column i.
            sim_targets.fill_diagonal_(1)

            # Soft targets: blend of the momentum model's softmax and the one-hot targets.
            sim_21_m = feat2_m @ feat1_all / self.temp
            sim_21_targets = alpha * F.softmax(sim_21_m, dim=1) + (1 - alpha) * sim_targets

            sim_12_m = feat1_m @ feat2_all / self.temp
            sim_12_targets = alpha * F.softmax(sim_12_m, dim=1) + (1 - alpha) * sim_targets

        # Compute the loss: cross-entropy between the online similarities and the targets
        sim_21 = feat2 @ feat1_all / self.temp
        sim_12 = feat1 @ feat2_all / self.temp

        loss_21 = -torch.sum(F.log_softmax(sim_21, dim=1) * sim_21_targets, dim=1).mean()
        loss_12 = -torch.sum(F.log_softmax(sim_12, dim=1) * sim_12_targets, dim=1).mean()

        loss_ita = (loss_21 + loss_12) / 2

        # Update the queues with this batch's momentum features
        self._dequeue_and_enqueue(feat1_m, feat2_m)

        return loss_ita

    @torch.no_grad()
    def _dequeue_and_enqueue(self, feat1, feat2):
        """Write the batch features into the queues at the current pointer.

        :param feat1: features of the first view, one row per sample.
        :param feat2: features of the second view, one row per sample.
        """
        batch_size = feat1.shape[0]
        ptr = int(self.queue_ptr)

        # If the batch does not fit into the remaining queue space, keep only
        # the part that fits; the remaining rows of the batch are dropped.
        space = self.config['queue_size'] - ptr
        actual_batch = min(batch_size, space)

        self.text1_queue[:, ptr:ptr + actual_batch] = feat1[:actual_batch].T
        self.text2_queue[:, ptr:ptr + actual_batch] = feat2[:actual_batch].T

        # Advance the pointer, wrapping to 0 once the end of the queue is reached.
        ptr = (ptr + actual_batch) % self.config['queue_size']
        self.queue_ptr[0] = ptr

In [None]:
import math
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

def _trailing_means(values, window):
    """Return the mean of every full run of ``window`` consecutive values, oldest first."""
    if window <= 0 or len(values) < window:
        return []
    running = sum(values[:window])
    means = [running / window]
    # Slide the window one element at a time instead of re-summing it.
    for i in range(window, len(values)):
        running += values[i] - values[i - window]
        means.append(running / window)
    return means


def plot_metrics(steps, losses, lrs, epoch, save_path=None):
    """Plot the loss and learning-rate curves, optionally saving the figure.

    :param steps: x values (global step numbers), one per recorded loss.
    :param losses: recorded loss values, same length as ``steps``.
    :param lrs: recorded learning rates, same length as ``steps``.
    :param epoch: epoch number shown in the plot titles.
    :param save_path: if truthy, the figure is also written to this path.
    """
    plt.figure(figsize=(15, 5))

    # Loss panel
    plt.subplot(1, 2, 1)
    plt.plot(steps, losses, color='#1f77b4', label='Train Loss', alpha=0.8)
    # Moving average to make the trend easier to see. Each mean is drawn at
    # the step where its window ENDS (index window - 1 onwards), so the trend
    # line is not shifted by one step and includes the most recent window.
    window = 50
    means = _trailing_means(losses, window)
    if means:
        plt.plot(steps[window - 1:], means, color='orange', label='Trend (Moving Avg)')

    plt.title(f'Loss Progression (Epoch {epoch})')
    plt.xlabel('Global Step')
    plt.ylabel('Loss')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.legend()

    # Learning-rate panel
    plt.subplot(1, 2, 2)
    plt.plot(steps, lrs, color='#d62728', label='Learning Rate')
    plt.title(f'LR Schedule (Epoch {epoch})')
    plt.xlabel('Global Step')
    plt.ylabel('LR')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.legend()

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()

In [None]:
def _paired_smiles_views(smiles_list, augmenter):
    """Return two aligned lists: canonical SMILES and atom-order-randomized SMILES."""
    canonical_views, randomized_views = [], []
    for sm in smiles_list:
        try:
            mol = Chem.MolFromSmiles(sm)
            if mol is None:
                continue
            # NOTE(review): '[CLS]' is prepended here and again inside
            # regexTokenizer.encode_one, so sequences start with two [CLS]
            # tokens — kept as-is; confirm this is intended.
            canonical = '[CLS]' + Chem.MolToSmiles(mol, canonical=True)
            randomized = '[CLS]' + Chem.MolToSmiles(augmenter([mol])[0], canonical=False)
        except Exception:
            # Skip molecules RDKit cannot process, without leaving one view
            # appended and the other missing.
            continue
        # Append both views together so the two lists can never get out of step.
        canonical_views.append(canonical)
        randomized_views.append(randomized)
    return canonical_views, randomized_views


def train_with_best_checkpoint(ldmol_dataset, tokenizer_path='/content/drive/MyDrive/LDMOL/vocab_bpe_300_sc.txt'):
    """Train an `LDMolEncoder` contrastively and keep the best-epoch checkpoint.

    Every batch is turned into two views per molecule (canonical SMILES and a
    SMILES with randomized atom order). After each epoch the learning-rate
    scheduler is stepped, the model is saved if the mean epoch loss improved,
    and the metric curves are plotted.

    :param ldmol_dataset: sequence of dicts with a 'smiles' entry.
    :param tokenizer_path: vocabulary file handed to `regexTokenizer`.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Mixed precision is only used on CUDA; on CPU the scaler and autocast are
    # explicitly disabled instead of relying on them disabling themselves.
    use_amp = device.type == 'cuda'
    save_path = '/content/drive/MyDrive/LDMOL/best_encoder.pt'
    tokenizer = regexTokenizer(vocab_path=tokenizer_path, max_len=127)
    train_ds = ListDataset(ldmol_dataset)
    loader = DataLoader(train_ds, batch_size=TRAIN_CONFIG['batch_size'], shuffle=True, drop_last=True)

    model = LDMolEncoder(TRAIN_CONFIG, BERT_CONFIG_DICT).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=TRAIN_CONFIG['lr'], weight_decay=TRAIN_CONFIG['weight_decay'])

    # Scheduler: 1 warm-up epoch followed by cosine annealing over the rest.
    # The epoch budget comes from TRAIN_CONFIG (5 => 1 warm-up + 4 cosine) and
    # the cosine period matches the epochs the cosine phase actually runs.
    total_epochs = TRAIN_CONFIG['epochs']
    warmup_epochs = 1
    main_epochs = max(1, total_epochs - warmup_epochs)
    # Warm-up starts at 5e-5 (capped at the base learning rate).
    start_factor = min(1.0, 5e-5 / TRAIN_CONFIG['lr'])
    scheduler_warmup = LinearLR(optimizer, start_factor=start_factor, end_factor=1.0, total_iters=warmup_epochs)
    scheduler_cosine = CosineAnnealingLR(optimizer, T_max=main_epochs, eta_min=1e-5)
    scheduler = SequentialLR(optimizer, schedulers=[scheduler_warmup, scheduler_cosine], milestones=[warmup_epochs])

    scaler = GradScaler(enabled=use_amp)

    # Tracking variables
    history_steps, history_losses, history_lrs = [], [], []
    global_step = 0
    best_loss = float('inf')  # best mean epoch loss seen so far

    model.train()

    for epoch in range(total_epochs):
        epoch_total_loss = 0
        num_batches = 0
        pbar = tqdm.tqdm(loader, desc=f"Epoch {epoch+1}/{total_epochs}")

        for batch_idx, batch in enumerate(pbar):
            text1_input, text2_input = _paired_smiles_views(batch['smiles'], model.aug)
            if not text1_input: continue

            # Training step. The attention mask is 0 on padding and 1
            # elsewhere, using the tokenizer's own pad id.
            ids1 = tokenizer(text1_input, truncation='longest').to(device)
            mask1 = (ids1 != tokenizer.pad_token_id).long()
            ids2 = tokenizer(text2_input, truncation='longest').to(device)
            mask2 = (ids2 != tokenizer.pad_token_id).long()

            # Ramp alpha linearly from 0 during the first epoch, constant afterwards.
            alpha = TRAIN_CONFIG['alpha'] * min(1.0, batch_idx / len(loader)) if epoch == 0 else TRAIN_CONFIG['alpha']

            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                loss = model(ids1, mask1, ids2, mask2, alpha=alpha)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Logging (read the loss once; each .item() call syncs with the device)
            loss_value = loss.item()
            global_step += 1
            current_lr = optimizer.param_groups[0]['lr']
            history_steps.append(global_step)
            history_losses.append(loss_value)
            history_lrs.append(current_lr)

            epoch_total_loss += loss_value
            num_batches += 1

            if global_step % 100 == 0:
                print(f" > Step {global_step} | Loss: {loss_value:.4f} | LR: {current_lr:.2e}")

            pbar.set_postfix({'loss': f"{loss_value:.4f}", 'lr': f"{current_lr:.2e}"})

        # End of epoch
        scheduler.step()
        if num_batches == 0:
            # No usable batch (e.g. dataset smaller than one batch): there is
            # no mean loss to compare, so skip checkpointing and plotting.
            print(f"--- Epoch {epoch+1}: không có batch hợp lệ nào, bỏ qua bước lưu checkpoint ---")
            continue
        avg_epoch_loss = epoch_total_loss / num_batches

        # Save the model when the mean epoch loss improves
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss,
            }, save_path)
            print(f"--- [MỚI] Đã lưu model tốt nhất tại Epoch {epoch+1} với Avg Loss: {best_loss:.4f} ---")

        # Plot the curves collected so far
        plot_metrics(history_steps, history_losses, history_lrs, epoch + 1, save_path=f"metrics_epoch_{epoch+1}.png")

    print(f"Huấn luyện hoàn tất! Loss tốt nhất đạt được: {best_loss:.4f}")

In [None]:
# Start training on the converted pre-training records, using the vocabulary file stored on Drive.
train_with_best_checkpoint(ldmol_dataset, tokenizer_path='/content/drive/MyDrive/LDMOL/vocab_bpe_300_sc.txt')

  scaler = GradScaler()
  with autocast():
Epoch 1/25:   1%|          | 95/9315 [01:00<1:37:47,  1.57it/s, loss=5.3512, lr=5.00e-05]


KeyboardInterrupt: 

In [None]:
# Load the test split from Google Drive.
raw_testset = PubChemDataset('/content/drive/MyDrive/LDMOL/test.pt')

In [None]:
# Reduce every test record to a plain dict holding only its SMILES string and its text.
ldmol_testset = [
    {'smiles': record.smiles, 'text': record.text}
    for record in raw_testset
]

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import tqdm
import numpy as np

In [None]:
@torch.no_grad()
def evaluate_encoding(ldmol_testset, model_path='/content/drive/MyDrive/LDMOL/best_encoder.pt', tokenizer_path='/content/drive/MyDrive/LDMOL/vocab_bpe_300_sc.txt'):
    """Measure how close the embeddings of two views of each molecule are.

    For every parsable SMILES the canonical form and a form with randomized
    atom order are embedded with the online encoder (not the momentum model);
    the cosine similarity of each pair is collected.

    :param ldmol_testset: sequence of dicts with a 'smiles' entry.
    :param model_path: checkpoint file; either a dict with a
        'model_state_dict' entry or a bare state dict.
    :param tokenizer_path: vocabulary file handed to `regexTokenizer`.
    :return: list of per-pair cosine similarities, or ``None`` when the
        checkpoint cannot be loaded.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 1. Load the tokenizer
    tokenizer = regexTokenizer(vocab_path=tokenizer_path, max_len=127)

    # 2. Build the model and load the weights
    # TRAIN_CONFIG and BERT_CONFIG_DICT must already be defined.
    model = LDMolEncoder(TRAIN_CONFIG, BERT_CONFIG_DICT).to(device)

    try:
        checkpoint = torch.load(model_path, map_location=device)
        # Accept both the dict saved by the training loop and a bare state dict
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
        print(f"--- Đã tải mô hình thành công từ {model_path} ---")
    except Exception as e:
        print(f"--- Lỗi khi tải mô hình: {e} ---")
        return

    model.eval()

    test_ds = ListDataset(ldmol_testset)
    loader = DataLoader(test_ds, batch_size=32, shuffle=False)

    similarities = []

    print("Bắt đầu đánh giá trên bộ test...")
    for batch in tqdm.tqdm(loader):
        smiles_list = batch['smiles']

        text1_input, text2_input = [], []
        for sm in smiles_list:
            try:
                mol = Chem.MolFromSmiles(sm)
                if mol is None:
                    continue
                # Build both views first so a failure in the second one
                # cannot leave the two lists misaligned.
                canonical = '[CLS]' + Chem.MolToSmiles(mol, canonical=True)
                randomized = '[CLS]' + Chem.MolToSmiles(model.aug([mol])[0], canonical=False)
            except Exception:
                continue
            text1_input.append(canonical)
            text2_input.append(randomized)

        if not text1_input: continue

        # Tokenize; the attention mask is 0 on padding (the tokenizer's own
        # pad id) and 1 elsewhere.
        ids1 = tokenizer(text1_input, truncation='longest').to(device)
        mask1 = (ids1 != tokenizer.pad_token_id).long()
        ids2 = tokenizer(text2_input, truncation='longest').to(device)
        mask2 = (ids2 != tokenizer.pad_token_id).long()

        # Features from the online encoder: hidden state at position 0,
        # passed through the projection head and L2-normalised.
        feat1_out = model.text_encoder(ids1, attention_mask=mask1).last_hidden_state[:, 0, :]
        feat1 = F.normalize(model.text_proj(feat1_out), dim=-1)

        feat2_out = model.text_encoder(ids2, attention_mask=mask2).last_hidden_state[:, 0, :]
        feat2 = F.normalize(model.text_proj(feat2_out), dim=-1)

        # Pair-wise cosine similarity: the features are normalised, so the
        # dot product of matching rows is the cosine similarity.
        cosine_sim = torch.sum(feat1 * feat2, dim=-1)
        similarities.extend(cosine_sim.cpu().numpy())

    # 3. Summarise the results
    if not similarities:
        # Nothing could be evaluated; avoid the NaN (and warning) that
        # np.mean / np.std produce on an empty list.
        print(f"Số lượng mẫu kiểm tra: {len(similarities)}")
        return similarities

    avg_sim = np.mean(similarities)
    std_sim = np.std(similarities)

    print("\n" + "="*30)
    print(f"KẾT QUẢ ĐÁNH GIÁ (ENCODING EVAL):")
    print(f"Số lượng mẫu kiểm tra: {len(similarities)}")
    print(f"Độ tương đồng Cosine trung bình: {avg_sim:.4f}")
    print(f"Độ lệch chuẩn: {std_sim:.4f}")
    print("="*30)

    return similarities

In [None]:
# Run the evaluation on the converted test records with the function's default checkpoint and vocabulary paths.
result = evaluate_encoding(ldmol_testset)