In [2]:
import pickle
import pandas as pd
from datasets import Dataset, DatasetDict
import multiprocessing
import copy
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch,transformers
from reference.llm_loader import reload_model_and_tokenizer
from peft import LoraConfig, prepare_model_for_kbit_training
from transformers import TrainingArguments
from trl import SFTTrainer
from reference.utils import get_tokenizer
from tqdm import tqdm
import logging
import os
import statistics
import random
# Seed Python's global RNG once so later random.sample / random.shuffle calls are reproducible.
SEED = 42
random.seed(SEED)

# Fine-tuned checkpoint directory; its tokenizer is loaded with get_tokenizer() further down.
reload_path = "./models/second-stage-llama3.2-3b/checkpoints/long_texts_match/checkpoint-1011"


🌟 构造smiles_to_struct_code的映射

🌟 构造训练/验证/测试数据

In [2]:
# Each entry: task id -> (question shown in the prompt, positive option, negative option).
# Every task here is binary, so the two options are always "True" / "False".
_BINARY_OPTIONS = ("True", "False")

# BBBP: blood-brain barrier penetration.
map1 = {
    task: (question, *_BINARY_OPTIONS)
    for task, question in [
        ("p_np", "Does this molecule have blood-brain barrier permeability (BBB penetration)? True for BBB permeable and False for not BBB permeable."),
    ]
}

# Tox21: all twelve assays share the same trailing answer hint.
_TOX21_HINT = " True for active and False for inactive."
map2 = {
    task: (question + _TOX21_HINT, *_BINARY_OPTIONS)
    for task, question in [
        ("NR-AR", "Does this molecule activate the androgen receptor (NR-AR)?"),
        ("NR-AR-LBD", "Does this molecule activate the ligand-binding domain of the androgen receptor (NR-AR-LBD)?"),
        ("NR-AhR", "Does this molecule activate the aryl hydrocarbon receptor (NR-AhR)?"),
        ("NR-Aromatase", "Does this molecule inhibit the aromatase enzyme (NR-Aromatase)?"),
        ("NR-ER", "Does this molecule activate the estrogen receptor (NR-ER)?"),
        ("NR-ER-LBD", "Does this molecule activate the ligand-binding domain of the estrogen receptor (NR-ER-LBD)?"),
        ("NR-PPAR-gamma", "Does this molecule activate the peroxisome proliferator-activated receptor gamma (NR-PPAR-gamma)?"),
        ("SR-ARE", "Does this molecule activate the antioxidant response element (SR-ARE)?"),
        ("SR-ATAD5", "Does this molecule activate ATAD5 signaling (SR-ATAD5)?"),
        ("SR-HSE", "Does this molecule activate the heat shock element (SR-HSE)?"),
        ("SR-MMP", "Does this molecule activate the mitochondrial membrane potential stress response (SR-MMP)?"),
        ("SR-p53", "Does this molecule activate the p53 stress response pathway (SR-p53)?"),
    ]
}

# ClinTox. NOTE(review): these questions say "false" (lowercase) while the option is "False";
# kept verbatim so prompts stay identical to any data/checkpoints already built from them.
map3 = {
    task: (question, *_BINARY_OPTIONS)
    for task, question in [
        ("FDA_APPROVED", "Has this molecule been approved by the FDA? True for FDA approved and false for non-approved."),
        ("CT_TOX", "Is this molecule associated with clinical toxicity? True for clinically toxic and false for non-toxic."),
    ]
}

# HIV.
map4 = {
    task: (question, *_BINARY_OPTIONS)
    for task, question in [
        ("HIV_active", "Does this molecule inhibit HIV replication? True for active molecules which can inhibit HIV and false for inactive ones."),
    ]
}

# BACE. NOTE(review): "the molecular" reads like a typo for "the molecule"; kept verbatim for prompt compatibility.
map5 = {
    task: (question, *_BINARY_OPTIONS)
    for task, question in [
        ("Class", "Is the binding result of the molecular on beta-secretase 1 true or false?"),
    ]
}

# Merge the per-dataset mappings into a single task-id lookup (the task ids above do not collide).
all_mapping = {**map1, **map2, **map3, **map4, **map5}

# Task id -> (question, positive option, negative option); process() looks tasks up under this name.
label_desc_map = all_mapping

# Structure-token codes per split. seal_datasets() assigns codes_map[split].values() to graphs
# purely by position, dataset after dataset.
save_path = "./phase1/datasets/all/codes_map.pt"
codes_map = torch.load(save_path, weights_only=True)

# Processing order; presumably must match the order used when codes_map was built — TODO confirm.
datasets_to_load = ['BBBP', 'Tox21', 'ClinTox', 'HIV', 'BACE']

# User prompt for every sample; the placeholders are filled in process().
prompt_template = """
You are a chemistry expert. Classify the given molecule into the correct category based on its molecular structure (token) and smiles expression. Each structure token represents a unique graph pattern (e.g., a kind of similar molecular graphs).
[Molecule] {text_attribute}
[Structure Token] <struct>{code}</struct>
[Task] {task_desc} Output the complete correct answer from the following two options:
1. {category_1}
2. {category_2}
[Answer]"""

# Tokenizer used to render the chat template; assistant_name is the role of the answer message.
tool, assistant_name = get_tokenizer(tokenizer_path=reload_path)
def process(example):
    """Turn one single-task graph record into a chat-formatted SFT sample.

    Reads ``smiles``, ``task``, ``code`` and ``label`` from `example` and looks up
    ``(task description, positive option, negative option)`` in ``label_desc_map``.
    It then adds three fields:

    * ``instruction`` – ``prompt_template`` filled with the SMILES, the structure
      token ``<gstruct_{code}>`` and the task description/options;
    * ``response`` – the positive option when ``label == 1``, otherwise the
      negative option;
    * ``text`` – the tokenizer's chat-template rendering of the user/assistant
      exchange followed by ``"<|end_of_text|>"``.

    Intended for ``Dataset.map``; returns the updated `example`.
    """
    task_desc, positive_option, negative_option = label_desc_map[example['task']]
    example['instruction'] = prompt_template.format(
        text_attribute=example['smiles'],
        code=f"<gstruct_{example['code']}>",
        task_desc=task_desc,
        category_1=positive_option,
        category_2=negative_option,
    )
    # Anything other than 1 (0, but also e.g. a missing/NaN label) maps to the negative option.
    example['response'] = positive_option if example['label'] == 1 else negative_option
    messages = [
        {"role": "user", "content": example['instruction']},
        {"role": assistant_name, "content": example['response']},
    ]
    # NOTE(review): explicit end-of-text marker appended after the chat template; confirm the
    # template does not already emit it for this tokenizer.
    example["text"] = tool.apply_chat_template(messages, tokenize=False) + "<|end_of_text|>"
    return example

def _upsample(samples, target_len, rng):
    """Return exactly `target_len` items drawn from the non-empty list `samples`.

    `samples` is repeated whole ``target_len // len(samples)`` times and the
    remainder is filled with ``rng.sample`` (a random subset without replacement).
    """
    repeat_factor, remainder = divmod(target_len, len(samples))
    return samples * repeat_factor + rng.sample(samples, remainder)


def _balance_train_graphs(graphs, dataset_name, rng, minority_ratio, full_balance_datasets):
    """Oversample the minority class of one single-task training list and shuffle it.

    The minority class is grown to ``majority // minority_ratio`` samples, or to the
    full majority size when `dataset_name` is in `full_balance_datasets`. It is never
    shrunk below its original size. Records whose label is neither 0 nor 1 are
    dropped once rebalancing happens. `graphs` is returned unchanged when both
    classes have the same size or one class is empty.
    """
    pos_samples = [g for g in graphs if g['label'] == 1]
    neg_samples = [g for g in graphs if g['label'] == 0]
    num_pos, num_neg = len(pos_samples), len(neg_samples)
    # "原始" = "original" class counts.
    print(f"原始: Total={num_pos + num_neg}, Pos={num_pos} , Neg={num_neg}")

    if num_pos == num_neg:
        return graphs
    if num_pos == 0 or num_neg == 0:
        # Nothing to replicate (this previously raised ZeroDivisionError).
        print(f"Skip balancing {dataset_name}: one class is empty")
        return graphs

    minority_is_pos = num_pos < num_neg
    minority, majority = (pos_samples, neg_samples) if minority_is_pos else (neg_samples, pos_samples)
    if dataset_name in full_balance_datasets:
        target_len = len(majority)
    else:
        target_len = len(majority) // minority_ratio
    # Oversampling must never discard minority samples: without this clamp a target below
    # len(minority) made rng.sample return only a subset of the minority class.
    target_len = max(target_len, len(minority))

    extended = _upsample(minority, target_len, rng)
    # "复制正样本" / "复制负样本" = "replicated positive / negative samples".
    print(f"复制{'正' if minority_is_pos else '负'}样本: {len(minority)} -> {len(extended)}")
    balanced = (extended + neg_samples) if minority_is_pos else (pos_samples + extended)
    rng.shuffle(balanced)
    return balanced


def seal_datasets(split='train', seed=42, minority_ratio=5,
                  full_balance_datasets=('BBBP',), skip_balance_datasets=('BACE',)):
    """Build one prompt-formatted HF ``Dataset`` per (dataset, task) for `split`.

    For every dataset in ``datasets_to_load``:

    1. The graphs are read from ``./phase1/datasets/<name>/graph_attributes.pkl``.
    2. Each graph is paired, by position, with the next ``len(graphs)`` entries of
       ``codes_map[split].values()``.
    3. The graphs are expanded into one single-label copy per task.
    4. On the train split, datasets not listed in `skip_balance_datasets` have
       their minority class oversampled.
    5. The result is mapped through ``process``.

    Args:
        split: key into the pickled graph dict and ``codes_map`` ('train', 'valid', 'test').
        seed: seed of a local ``random.Random`` used for oversampling/shuffling, so that
            re-running the cell reproduces the same data. ``None`` falls back to the
            global ``random`` module state.
        minority_ratio: the minority class is grown to ``majority // minority_ratio`` samples.
        full_balance_datasets: datasets whose minority class is grown to the majority size.
        skip_balance_datasets: datasets that are never rebalanced.

    Returns:
        dict mapping ``f"{dataset_name}_{task}"`` to the mapped ``Dataset``.
    """
    rng = random if seed is None else random.Random(seed)
    all_codes = list(codes_map[split].values())  # hoisted: previously rebuilt per dataset
    datasets = {}
    start = 0
    for dataset_name in datasets_to_load:
        dataset_path = f"./phase1/datasets/{dataset_name}/graph_attributes.pkl"
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(dataset_path, "rb") as f:
            graphs = pickle.load(f)
        tar_graphs = graphs[split]

        tar_len = len(tar_graphs)
        tar_codes = all_codes[start: start + tar_len]
        start += tar_len
        assert len(tar_codes) == tar_len, 'Length mismatch between tar_graphs and tar_codes'
        for graph, code in zip(tar_graphs, tar_codes):
            graph['code'] = code
        if not tar_graphs:
            print(f"Skip {dataset_name}: empty {split} split")
            continue

        # Task names come from the first graph (previously read from a loop variable leaked out of the loop above).
        for task in tar_graphs[0]['label'].keys():
            # Shallow per-task copy: only 'label' and 'task' differ, so deep-copying every graph is unnecessary.
            new_graphs = [{**graph, 'label': graph['label'][task], 'task': task} for graph in tar_graphs]

            if split == 'train' and dataset_name not in skip_balance_datasets:
                print('Balancing training split for dataset: ', dataset_name, 'task', task)
                balanced_graphs = _balance_train_graphs(
                    new_graphs, dataset_name, rng, minority_ratio, full_balance_datasets)
            else:
                balanced_graphs = new_graphs

            # The pickled graphs store the SMILES string under 'text'; process() reads it as 'smiles'.
            raw_dataset = Dataset.from_list(balanced_graphs).rename_column("text", "smiles")

            print(f"Processing {dataset_name}_{task} split: {split}")

            datasets[f"{dataset_name}_{task}"] = raw_dataset.map(
                process,
                num_proc=multiprocessing.cpu_count(),
                load_from_cache_file=False,
                remove_columns=raw_dataset.column_names,
            )
            # Report the class balance actually used for training.
            if split == 'train':
                labels = raw_dataset["label"]
                total = len(labels)
                pos = sum(1 for y in labels if y == 1)
                neg = sum(1 for y in labels if y == 0)
                pos_ratio = pos / total if total > 0 else 0
                neg_ratio = neg / total if total > 0 else 0
                print(f"{dataset_name}_{task}(TRAIN): Total={total}, Pos={pos} ({pos_ratio:.2%}), Neg={neg} ({neg_ratio:.2%})")

    return datasets

splits = ['train', 'valid', 'test']

# Build every split in order: {split: {f"{dataset}_{task}": Dataset}}.
ds_dict = {split: seal_datasets(split) for split in splits}

# Persist the nested dict; create ./corpus first so open() does not fail on a fresh checkout.
out_path = "./corpus/ds_dict_balanced.pkl"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "wb") as f:
    pickle.dump(ds_dict, f)


Balancing training split for dataset:  BBBP task p_np
原始: Total=1631, Pos=1341 , Neg=290
复制负样本: 290 -> 1341
Processing BBBP_p_np split: train


Map (num_proc=48): 100%|██████████| 2682/2682 [00:02<00:00, 1113.90 examples/s]


BBBP_p_np(TRAIN): Total=2682, Pos=1341 (50.00%), Neg=1341 (50.00%)
Balancing training split for dataset:  Tox21 task NR-AR
原始: Total=6258, Pos=250 , Neg=6008
复制正样本: 250 -> 1201
Processing Tox21_NR-AR split: train


Map (num_proc=48): 100%|██████████| 7209/7209 [00:03<00:00, 2256.43 examples/s]


Tox21_NR-AR(TRAIN): Total=7209, Pos=1201 (16.66%), Neg=6008 (83.34%)
Balancing training split for dataset:  Tox21 task NR-AR-LBD
原始: Total=6258, Pos=193 , Neg=6065
复制正样本: 193 -> 1213
Processing Tox21_NR-AR-LBD split: train


Map (num_proc=48): 100%|██████████| 7278/7278 [00:03<00:00, 2256.36 examples/s]


Tox21_NR-AR-LBD(TRAIN): Total=7278, Pos=1213 (16.67%), Neg=6065 (83.33%)
Balancing training split for dataset:  Tox21 task NR-AhR
原始: Total=6258, Pos=589 , Neg=5669
复制正样本: 589 -> 1133
Processing Tox21_NR-AhR split: train


Map (num_proc=48): 100%|██████████| 6802/6802 [00:03<00:00, 2157.46 examples/s]


Tox21_NR-AhR(TRAIN): Total=6802, Pos=1133 (16.66%), Neg=5669 (83.34%)
Balancing training split for dataset:  Tox21 task NR-Aromatase
原始: Total=6258, Pos=208 , Neg=6050
复制正样本: 208 -> 1210
Processing Tox21_NR-Aromatase split: train


Map (num_proc=48): 100%|██████████| 7260/7260 [00:03<00:00, 2354.34 examples/s]


Tox21_NR-Aromatase(TRAIN): Total=7260, Pos=1210 (16.67%), Neg=6050 (83.33%)
Balancing training split for dataset:  Tox21 task NR-ER
原始: Total=6258, Pos=646 , Neg=5612
复制正样本: 646 -> 1122
Processing Tox21_NR-ER split: train


Map (num_proc=48): 100%|██████████| 6734/6734 [00:03<00:00, 2212.11 examples/s]


Tox21_NR-ER(TRAIN): Total=6734, Pos=1122 (16.66%), Neg=5612 (83.34%)
Balancing training split for dataset:  Tox21 task NR-ER-LBD
原始: Total=6258, Pos=299 , Neg=5959
复制正样本: 299 -> 1191
Processing Tox21_NR-ER-LBD split: train


Map (num_proc=48): 100%|██████████| 7150/7150 [00:03<00:00, 2312.12 examples/s]


Tox21_NR-ER-LBD(TRAIN): Total=7150, Pos=1191 (16.66%), Neg=5959 (83.34%)
Balancing training split for dataset:  Tox21 task NR-PPAR-gamma
原始: Total=6258, Pos=132 , Neg=6126
复制正样本: 132 -> 1225
Processing Tox21_NR-PPAR-gamma split: train


Map (num_proc=48): 100%|██████████| 7351/7351 [00:03<00:00, 2326.98 examples/s]


Tox21_NR-PPAR-gamma(TRAIN): Total=7351, Pos=1225 (16.66%), Neg=6126 (83.34%)
Balancing training split for dataset:  Tox21 task SR-ARE
原始: Total=6258, Pos=718 , Neg=5540
复制正样本: 718 -> 1108
Processing Tox21_SR-ARE split: train


Map (num_proc=48): 100%|██████████| 6648/6648 [00:03<00:00, 2170.57 examples/s]


Tox21_SR-ARE(TRAIN): Total=6648, Pos=1108 (16.67%), Neg=5540 (83.33%)
Balancing training split for dataset:  Tox21 task SR-ATAD5
原始: Total=6258, Pos=196 , Neg=6062
复制正样本: 196 -> 1212
Processing Tox21_SR-ATAD5 split: train


Map (num_proc=48): 100%|██████████| 7274/7274 [00:03<00:00, 2022.84 examples/s]


Tox21_SR-ATAD5(TRAIN): Total=7274, Pos=1212 (16.66%), Neg=6062 (83.34%)
Balancing training split for dataset:  Tox21 task SR-HSE
原始: Total=6258, Pos=281 , Neg=5977
复制正样本: 281 -> 1195
Processing Tox21_SR-HSE split: train


Map (num_proc=48): 100%|██████████| 7172/7172 [00:03<00:00, 2329.39 examples/s]


Tox21_SR-HSE(TRAIN): Total=7172, Pos=1195 (16.66%), Neg=5977 (83.34%)
Balancing training split for dataset:  Tox21 task SR-MMP
原始: Total=6258, Pos=711 , Neg=5547
复制正样本: 711 -> 1109
Processing Tox21_SR-MMP split: train


Map (num_proc=48): 100%|██████████| 6656/6656 [00:03<00:00, 2183.11 examples/s]


Tox21_SR-MMP(TRAIN): Total=6656, Pos=1109 (16.66%), Neg=5547 (83.34%)
Balancing training split for dataset:  Tox21 task SR-p53
原始: Total=6258, Pos=276 , Neg=5982
复制正样本: 276 -> 1196
Processing Tox21_SR-p53 split: train


Map (num_proc=48): 100%|██████████| 7178/7178 [00:03<00:00, 2379.05 examples/s]


Tox21_SR-p53(TRAIN): Total=7178, Pos=1196 (16.66%), Neg=5982 (83.34%)
Balancing training split for dataset:  ClinTox task FDA_APPROVED
原始: Total=1184, Pos=1105 , Neg=79
复制负样本: 79 -> 221
Processing ClinTox_FDA_APPROVED split: train


Map (num_proc=48): 100%|██████████| 1326/1326 [00:01<00:00, 716.84 examples/s]


ClinTox_FDA_APPROVED(TRAIN): Total=1326, Pos=1105 (83.33%), Neg=221 (16.67%)
Balancing training split for dataset:  ClinTox task CT_TOX
原始: Total=1184, Pos=95 , Neg=1089
复制正样本: 95 -> 217
Processing ClinTox_CT_TOX split: train


Map (num_proc=48): 100%|██████████| 1306/1306 [00:01<00:00, 660.82 examples/s]


ClinTox_CT_TOX(TRAIN): Total=1306, Pos=217 (16.62%), Neg=1089 (83.38%)
Balancing training split for dataset:  HIV task HIV_active
原始: Total=32896, Pos=1232 , Neg=31664
复制正样本: 1232 -> 6332
Processing HIV_HIV_active split: train


Map (num_proc=48): 100%|██████████| 37996/37996 [00:09<00:00, 3981.86 examples/s]


HIV_HIV_active(TRAIN): Total=37996, Pos=6332 (16.66%), Neg=31664 (83.34%)
Processing BACE_Class split: train


Map (num_proc=48): 100%|██████████| 1210/1210 [00:01<00:00, 660.64 examples/s]


BACE_Class(TRAIN): Total=1210, Pos=515 (42.56%), Neg=695 (57.44%)
Processing BBBP_p_np split: valid


Map (num_proc=48): 100%|██████████| 204/204 [00:01<00:00, 127.00 examples/s]


Processing Tox21_NR-AR split: valid


Map (num_proc=48): 100%|██████████| 782/782 [00:01<00:00, 434.71 examples/s]


Processing Tox21_NR-AR-LBD split: valid


Map (num_proc=48): 100%|██████████| 782/782 [00:01<00:00, 427.28 examples/s]


Processing Tox21_NR-AhR split: valid


Map (num_proc=48): 100%|██████████| 782/782 [00:01<00:00, 445.59 examples/s]


Processing Tox21_NR-Aromatase split: valid


Map (num_proc=48): 100%|██████████| 782/782 [00:01<00:00, 439.32 examples/s]


Processing Tox21_NR-ER split: valid


Map (num_proc=48): 100%|██████████| 782/782 [00:01<00:00, 435.06 examples/s]


Processing Tox21_NR-ER-LBD split: valid


Map (num_proc=48): 100%|██████████| 782/782 [00:01<00:00, 441.05 examples/s]


Processing Tox21_NR-PPAR-gamma split: valid


Map (num_proc=48): 100%|██████████| 782/782 [00:01<00:00, 425.19 examples/s]


Processing Tox21_SR-ARE split: valid


Map (num_proc=48): 100%|██████████| 782/782 [00:01<00:00, 440.59 examples/s]


Processing Tox21_SR-ATAD5 split: valid


Map (num_proc=48): 100%|██████████| 782/782 [00:01<00:00, 449.71 examples/s]


Processing Tox21_SR-HSE split: valid


Map (num_proc=48): 100%|██████████| 782/782 [00:01<00:00, 441.45 examples/s]


Processing Tox21_SR-MMP split: valid


Map (num_proc=48): 100%|██████████| 782/782 [00:01<00:00, 446.55 examples/s]


Processing Tox21_SR-p53 split: valid


Map (num_proc=48): 100%|██████████| 782/782 [00:01<00:00, 398.48 examples/s]


Processing ClinTox_FDA_APPROVED split: valid


Map (num_proc=48): 100%|██████████| 148/148 [00:01<00:00, 88.26 examples/s] 


Processing ClinTox_CT_TOX split: valid


Map (num_proc=48): 100%|██████████| 148/148 [00:01<00:00, 86.02 examples/s] 


Processing HIV_HIV_active split: valid


Map (num_proc=48): 100%|██████████| 4112/4112 [00:02<00:00, 1517.13 examples/s]


Processing BACE_Class split: valid


Map (num_proc=48): 100%|██████████| 151/151 [00:01<00:00, 88.27 examples/s] 


Processing BBBP_p_np split: test


Map (num_proc=48): 100%|██████████| 204/204 [00:01<00:00, 116.85 examples/s]


Processing Tox21_NR-AR split: test


Map (num_proc=48): 100%|██████████| 783/783 [00:01<00:00, 431.65 examples/s]


Processing Tox21_NR-AR-LBD split: test


Map (num_proc=48): 100%|██████████| 783/783 [00:01<00:00, 441.28 examples/s]


Processing Tox21_NR-AhR split: test


Map (num_proc=48): 100%|██████████| 783/783 [00:02<00:00, 339.74 examples/s]


Processing Tox21_NR-Aromatase split: test


Map (num_proc=48): 100%|██████████| 783/783 [00:01<00:00, 412.84 examples/s]


Processing Tox21_NR-ER split: test


Map (num_proc=48): 100%|██████████| 783/783 [00:01<00:00, 416.41 examples/s]


Processing Tox21_NR-ER-LBD split: test


Map (num_proc=48): 100%|██████████| 783/783 [00:01<00:00, 410.66 examples/s]


Processing Tox21_NR-PPAR-gamma split: test


Map (num_proc=48): 100%|██████████| 783/783 [00:01<00:00, 439.77 examples/s]


Processing Tox21_SR-ARE split: test


Map (num_proc=48): 100%|██████████| 783/783 [00:01<00:00, 438.78 examples/s]


Processing Tox21_SR-ATAD5 split: test


Map (num_proc=48): 100%|██████████| 783/783 [00:01<00:00, 393.45 examples/s]


Processing Tox21_SR-HSE split: test


Map (num_proc=48): 100%|██████████| 783/783 [00:02<00:00, 379.78 examples/s]


Processing Tox21_SR-MMP split: test


Map (num_proc=48): 100%|██████████| 783/783 [00:01<00:00, 422.58 examples/s]


Processing Tox21_SR-p53 split: test


Map (num_proc=48): 100%|██████████| 783/783 [00:01<00:00, 434.96 examples/s]


Processing ClinTox_FDA_APPROVED split: test


Map (num_proc=48): 100%|██████████| 148/148 [00:01<00:00, 86.62 examples/s] 


Processing ClinTox_CT_TOX split: test


Map (num_proc=48): 100%|██████████| 148/148 [00:01<00:00, 83.23 examples/s] 


Processing HIV_HIV_active split: test


Map (num_proc=48): 100%|██████████| 4112/4112 [00:02<00:00, 1649.87 examples/s]


Processing BACE_Class split: test


Map (num_proc=48): 100%|██████████| 152/152 [00:01<00:00, 94.11 examples/s] 


In [3]:
# Each entry: task id -> (question shown in the prompt, positive option, negative option).
# Every task here is binary, so the two options are always "True" / "False".
_BINARY_OPTIONS = ("True", "False")

# BBBP: blood-brain barrier penetration.
map1 = {
    task: (question, *_BINARY_OPTIONS)
    for task, question in [
        ("p_np", "Does this molecule have blood-brain barrier permeability (BBB penetration)? True for BBB permeable and False for not BBB permeable."),
    ]
}

# Tox21: all twelve assays share the same trailing answer hint.
_TOX21_HINT = " True for active and False for inactive."
map2 = {
    task: (question + _TOX21_HINT, *_BINARY_OPTIONS)
    for task, question in [
        ("NR-AR", "Does this molecule activate the androgen receptor (NR-AR)?"),
        ("NR-AR-LBD", "Does this molecule activate the ligand-binding domain of the androgen receptor (NR-AR-LBD)?"),
        ("NR-AhR", "Does this molecule activate the aryl hydrocarbon receptor (NR-AhR)?"),
        ("NR-Aromatase", "Does this molecule inhibit the aromatase enzyme (NR-Aromatase)?"),
        ("NR-ER", "Does this molecule activate the estrogen receptor (NR-ER)?"),
        ("NR-ER-LBD", "Does this molecule activate the ligand-binding domain of the estrogen receptor (NR-ER-LBD)?"),
        ("NR-PPAR-gamma", "Does this molecule activate the peroxisome proliferator-activated receptor gamma (NR-PPAR-gamma)?"),
        ("SR-ARE", "Does this molecule activate the antioxidant response element (SR-ARE)?"),
        ("SR-ATAD5", "Does this molecule activate ATAD5 signaling (SR-ATAD5)?"),
        ("SR-HSE", "Does this molecule activate the heat shock element (SR-HSE)?"),
        ("SR-MMP", "Does this molecule activate the mitochondrial membrane potential stress response (SR-MMP)?"),
        ("SR-p53", "Does this molecule activate the p53 stress response pathway (SR-p53)?"),
    ]
}

# ClinTox. NOTE(review): these questions say "false" (lowercase) while the option is "False";
# kept verbatim so prompts stay identical to any data/checkpoints already built from them.
map3 = {
    task: (question, *_BINARY_OPTIONS)
    for task, question in [
        ("FDA_APPROVED", "Has this molecule been approved by the FDA? True for FDA approved and false for non-approved."),
        ("CT_TOX", "Is this molecule associated with clinical toxicity? True for clinically toxic and false for non-toxic."),
    ]
}

# HIV.
map4 = {
    task: (question, *_BINARY_OPTIONS)
    for task, question in [
        ("HIV_active", "Does this molecule inhibit HIV replication? True for active molecules which can inhibit HIV and false for inactive ones."),
    ]
}

# BACE. NOTE(review): "the molecular" reads like a typo for "the molecule"; kept verbatim for prompt compatibility.
map5 = {
    task: (question, *_BINARY_OPTIONS)
    for task, question in [
        ("Class", "Is the binding result of the molecular on beta-secretase 1 true or false?"),
    ]
}

# Merge the per-dataset mappings into a single task-id lookup (the task ids above do not collide).
all_mapping = {**map1, **map2, **map3, **map4, **map5}

# Task id -> (question, positive option, negative option); process() looks tasks up under this name.
label_desc_map = all_mapping

# Structure-token codes per split. seal_datasets() assigns codes_map[split].values() to graphs
# purely by position, dataset after dataset.
save_path = "./phase1/datasets/all/codes_map.pt"
codes_map = torch.load(save_path, weights_only=True)

# Processing order; presumably must match the order used when codes_map was built — TODO confirm.
datasets_to_load = ['BBBP', 'Tox21', 'ClinTox', 'HIV', 'BACE']

# User prompt for every sample; the placeholders are filled in process().
prompt_template = """
You are a chemistry expert. Classify the given molecule into the correct category based on its molecular structure (token) and smiles expression. Each structure token represents a unique graph pattern (e.g., a kind of similar molecular graphs).
[Molecule] {text_attribute}
[Structure Token] <struct>{code}</struct>
[Task] {task_desc} Output the complete correct answer from the following two options:
1. {category_1}
2. {category_2}
[Answer]"""

# Tokenizer used to render the chat template; assistant_name is the role of the answer message.
tool, assistant_name = get_tokenizer(tokenizer_path=reload_path)
def process(example):
    """Turn one single-task graph record into a chat-formatted SFT sample.

    Reads ``smiles``, ``task``, ``code`` and ``label`` from `example` and looks up
    ``(task description, positive option, negative option)`` in ``label_desc_map``.
    It then adds three fields:

    * ``instruction`` – ``prompt_template`` filled with the SMILES, the structure
      token ``<gstruct_{code}>`` and the task description/options;
    * ``response`` – the positive option when ``label == 1``, otherwise the
      negative option;
    * ``text`` – the tokenizer's chat-template rendering of the user/assistant
      exchange followed by ``"<|end_of_text|>"``.

    Intended for ``Dataset.map``; returns the updated `example`.
    """
    task_desc, positive_option, negative_option = label_desc_map[example['task']]
    example['instruction'] = prompt_template.format(
        text_attribute=example['smiles'],
        code=f"<gstruct_{example['code']}>",
        task_desc=task_desc,
        category_1=positive_option,
        category_2=negative_option,
    )
    # Anything other than 1 (0, but also e.g. a missing/NaN label) maps to the negative option.
    example['response'] = positive_option if example['label'] == 1 else negative_option
    messages = [
        {"role": "user", "content": example['instruction']},
        {"role": assistant_name, "content": example['response']},
    ]
    # NOTE(review): explicit end-of-text marker appended after the chat template; confirm the
    # template does not already emit it for this tokenizer.
    example["text"] = tool.apply_chat_template(messages, tokenize=False) + "<|end_of_text|>"
    return example

def seal_datasets(split='train'):
    """Build one prompt-formatted HF ``Dataset`` per (dataset, task) for `split`, with no resampling.

    For every dataset in ``datasets_to_load``:

    1. The graphs are read from ``./phase1/datasets/<name>/graph_attributes.pkl``.
    2. Each graph is paired, by position, with the next ``len(graphs)`` entries of
       ``codes_map[split].values()``.
    3. The graphs are expanded into one single-label copy per task.
    4. The result is mapped through ``process``.

    The class distribution is kept as-is on every split.

    Returns:
        dict mapping ``f"{dataset_name}_{task}"`` to the mapped ``Dataset``.
    """
    all_codes = list(codes_map[split].values())  # hoisted: previously rebuilt per dataset
    datasets = {}
    start = 0
    for dataset_name in datasets_to_load:
        dataset_path = f"./phase1/datasets/{dataset_name}/graph_attributes.pkl"
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(dataset_path, "rb") as f:
            graphs = pickle.load(f)
        tar_graphs = graphs[split]

        tar_len = len(tar_graphs)
        tar_codes = all_codes[start: start + tar_len]
        start += tar_len
        assert len(tar_codes) == tar_len, 'Length mismatch between tar_graphs and tar_codes'
        for graph, code in zip(tar_graphs, tar_codes):
            graph['code'] = code
        if not tar_graphs:
            print(f"Skip {dataset_name}: empty {split} split")
            continue

        # Task names come from the first graph (previously read from a loop variable leaked out of the loop above).
        for task in tar_graphs[0]['label'].keys():
            # Shallow per-task copy: only 'label' and 'task' differ, so deep-copying every graph is unnecessary.
            task_graphs = [{**graph, 'label': graph['label'][task], 'task': task} for graph in tar_graphs]

            # The pickled graphs store the SMILES string under 'text'; process() reads it as 'smiles'.
            raw_dataset = Dataset.from_list(task_graphs).rename_column("text", "smiles")

            datasets[f"{dataset_name}_{task}"] = raw_dataset.map(
                process,
                num_proc=multiprocessing.cpu_count(),
                load_from_cache_file=False,
                remove_columns=raw_dataset.column_names,
            )

    return datasets

splits = ['train', 'valid', 'test']

# Build every split in order: {split: {f"{dataset}_{task}": Dataset}}.
ds_dict = {split: seal_datasets(split) for split in splits}

# Persist the nested dict; create ./corpus first so open() does not fail on a fresh checkout.
out_path = "./corpus/ds_dict_ori.pkl"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "wb") as f:
    pickle.dump(ds_dict, f)


Map (num_proc=128): 100%|██████████| 1631/1631 [00:02<00:00, 580.43 examples/s]
Map (num_proc=128): 100%|██████████| 6258/6258 [00:02<00:00, 2265.03 examples/s]
Map (num_proc=128): 100%|██████████| 6258/6258 [00:02<00:00, 2224.12 examples/s]
Map (num_proc=128): 100%|██████████| 6258/6258 [00:01<00:00, 5100.24 examples/s]
Map (num_proc=128): 100%|██████████| 6258/6258 [00:01<00:00, 5387.31 examples/s]
Map (num_proc=128): 100%|██████████| 6258/6258 [00:01<00:00, 5414.65 examples/s]
Map (num_proc=128): 100%|██████████| 6258/6258 [00:01<00:00, 5428.28 examples/s]
Map (num_proc=128): 100%|██████████| 6258/6258 [00:01<00:00, 4725.94 examples/s]
Map (num_proc=128): 100%|██████████| 6258/6258 [00:01<00:00, 4941.34 examples/s]
Map (num_proc=128): 100%|██████████| 6258/6258 [00:02<00:00, 2504.56 examples/s]
Map (num_proc=128): 100%|██████████| 6258/6258 [00:02<00:00, 2476.35 examples/s]
Map (num_proc=128): 100%|██████████| 6258/6258 [00:01<00:00, 5214.93 examples/s]
Map (num_proc=128): 100%|████

In [None]:
# Each entry: task id -> (question shown in the prompt, positive option, negative option).
# Every task here is binary, so the two options are always "True" / "False".
_BINARY_OPTIONS = ("True", "False")

# BBBP: blood-brain barrier penetration.
map1 = {
    task: (question, *_BINARY_OPTIONS)
    for task, question in [
        ("p_np", "Does this molecule have blood-brain barrier permeability (BBB penetration)? True for BBB permeable and False for not BBB permeable."),
    ]
}

# Tox21: all twelve assays share the same trailing answer hint.
_TOX21_HINT = " True for active and False for inactive."
map2 = {
    task: (question + _TOX21_HINT, *_BINARY_OPTIONS)
    for task, question in [
        ("NR-AR", "Does this molecule activate the androgen receptor (NR-AR)?"),
        ("NR-AR-LBD", "Does this molecule activate the ligand-binding domain of the androgen receptor (NR-AR-LBD)?"),
        ("NR-AhR", "Does this molecule activate the aryl hydrocarbon receptor (NR-AhR)?"),
        ("NR-Aromatase", "Does this molecule inhibit the aromatase enzyme (NR-Aromatase)?"),
        ("NR-ER", "Does this molecule activate the estrogen receptor (NR-ER)?"),
        ("NR-ER-LBD", "Does this molecule activate the ligand-binding domain of the estrogen receptor (NR-ER-LBD)?"),
        ("NR-PPAR-gamma", "Does this molecule activate the peroxisome proliferator-activated receptor gamma (NR-PPAR-gamma)?"),
        ("SR-ARE", "Does this molecule activate the antioxidant response element (SR-ARE)?"),
        ("SR-ATAD5", "Does this molecule activate ATAD5 signaling (SR-ATAD5)?"),
        ("SR-HSE", "Does this molecule activate the heat shock element (SR-HSE)?"),
        ("SR-MMP", "Does this molecule activate the mitochondrial membrane potential stress response (SR-MMP)?"),
        ("SR-p53", "Does this molecule activate the p53 stress response pathway (SR-p53)?"),
    ]
}

# ClinTox. NOTE(review): these questions say "false" (lowercase) while the option is "False";
# kept verbatim so prompts stay identical to any data/checkpoints already built from them.
map3 = {
    task: (question, *_BINARY_OPTIONS)
    for task, question in [
        ("FDA_APPROVED", "Has this molecule been approved by the FDA? True for FDA approved and false for non-approved."),
        ("CT_TOX", "Is this molecule associated with clinical toxicity? True for clinically toxic and false for non-toxic."),
    ]
}

# HIV.
map4 = {
    task: (question, *_BINARY_OPTIONS)
    for task, question in [
        ("HIV_active", "Does this molecule inhibit HIV replication? True for active molecules which can inhibit HIV and false for inactive ones."),
    ]
}

# BACE. NOTE(review): "the molecular" reads like a typo for "the molecule"; kept verbatim for prompt compatibility.
map5 = {
    task: (question, *_BINARY_OPTIONS)
    for task, question in [
        ("Class", "Is the binding result of the molecular on beta-secretase 1 true or false?"),
    ]
}

# Merge the per-dataset mappings into a single task-id lookup (the task ids above do not collide).
all_mapping = {**map1, **map2, **map3, **map4, **map5}

# Task id -> (question, positive option, negative option); process() looks tasks up under this name.
label_desc_map = all_mapping

# Structure-token codes per split. seal_datasets() assigns codes_map[split].values() to graphs
# purely by position, dataset after dataset.
save_path = "./phase1/datasets/all/codes_map.pt"
codes_map = torch.load(save_path, weights_only=True)

# Processing order; presumably must match the order used when codes_map was built — TODO confirm.
datasets_to_load = ['BBBP', 'Tox21', 'ClinTox', 'HIV', 'BACE']

# User prompt for every sample; the placeholders are filled in process().
prompt_template = """
You are a chemistry expert. Classify the given molecule into the correct category based on its molecular structure (token) and smiles expression. Each structure token represents a unique graph pattern (e.g., a kind of similar molecular graphs).
[Molecule] {text_attribute}
[Structure Token] <struct>{code}</struct>
[Task] {task_desc} Output the complete correct answer from the following two options:
1. {category_1}
2. {category_2}
[Answer]"""

# Tokenizer used to render the chat template; assistant_name is the role of the answer message.
tool, assistant_name = get_tokenizer(tokenizer_path=reload_path)
def process(example):
    """Turn one single-task graph record into a chat-formatted SFT sample.

    Reads ``smiles``, ``task``, ``code`` and ``label`` from `example` and looks up
    ``(task description, positive option, negative option)`` in ``label_desc_map``.
    It then adds three fields:

    * ``instruction`` – ``prompt_template`` filled with the SMILES, the structure
      token ``<gstruct_{code}>`` and the task description/options;
    * ``response`` – the positive option when ``label == 1``, otherwise the
      negative option;
    * ``text`` – the tokenizer's chat-template rendering of the user/assistant
      exchange followed by ``"<|end_of_text|>"``.

    Intended for ``Dataset.map``; returns the updated `example`.
    """
    task_desc, positive_option, negative_option = label_desc_map[example['task']]
    example['instruction'] = prompt_template.format(
        text_attribute=example['smiles'],
        code=f"<gstruct_{example['code']}>",
        task_desc=task_desc,
        category_1=positive_option,
        category_2=negative_option,
    )
    # Anything other than 1 (0, but also e.g. a missing/NaN label) maps to the negative option.
    example['response'] = positive_option if example['label'] == 1 else negative_option
    messages = [
        {"role": "user", "content": example['instruction']},
        {"role": assistant_name, "content": example['response']},
    ]
    # NOTE(review): explicit end-of-text marker appended after the chat template; confirm the
    # template does not already emit it for this tokenizer.
    example["text"] = tool.apply_chat_template(messages, tokenize=False) + "<|end_of_text|>"
    return example

def _oversample(samples, target_size):
    """Grow `samples` to exactly `target_size` items by replication.

    The list is repeated whole ``target_size // len(samples)`` times. The
    remainder is topped up with ``random.sample`` (no replacement), so each
    item appears k or k+1 times. `samples` must be non-empty and
    ``target_size >= len(samples)``.
    """
    repeat_factor, remainder = divmod(target_size, len(samples))
    return samples * repeat_factor + random.sample(samples, remainder)


def _balance_train_graphs(new_graphs):
    """Oversample the minority class (label 1 vs label 0) to a 1:1 ratio and shuffle.

    `new_graphs` is returned unchanged when both classes already have the same
    size, or when one class is empty (there is nothing to replicate). Otherwise
    the result holds only label-1 and label-0 records. Positives are placed
    before negatives, then the list is shuffled in place with `random`.
    """
    pos_samples = [g for g in new_graphs if g['label'] == 1]
    neg_samples = [g for g in new_graphs if g['label'] == 0]
    num_pos, num_neg = len(pos_samples), len(neg_samples)

    if num_pos == num_neg:
        return new_graphs
    if num_pos == 0 or num_neg == 0:
        # Replicating an empty class is impossible (and divmod by 0 would raise).
        print(f"Skip oversampling: one class is empty (pos={num_pos}, neg={num_neg})")
        return new_graphs

    if num_pos < num_neg:
        # Replicate positives up to the negative count.
        print(f"复制正样本: {num_pos} -> {num_neg}")
        balanced_graphs = _oversample(pos_samples, num_neg) + neg_samples
    else:
        # Replicate negatives up to the positive count.
        print(f"复制负样本: {num_neg} -> {num_pos}")
        balanced_graphs = pos_samples + _oversample(neg_samples, num_pos)
    random.shuffle(balanced_graphs)
    return balanced_graphs


def seal_datasets(split='train'):
    """Build one instruction-tuning Dataset per (benchmark, task) pair for `split`.

    For each benchmark in `datasets_to_load`:
    - load its graph records from ``./phase1/datasets/<name>/graph_attributes.pkl``;
    - attach a structure code to each record by slicing ``codes_map[split]``
      at a running offset, so `datasets_to_load` must follow the order the
      codes were stored in;
    - for every task key in the records' ``label`` dict, make a deep copy
      whose ``label`` is that task's value;
    - oversample to 1:1 on the train split only;
    - render prompts with `process`.

    Returns:
        dict mapping "<benchmark>_<task>" to a Dataset holding only the columns
        added by `process` (the source columns are removed).
    """
    datasets = {}
    split_codes = list(codes_map[split].values())
    start = 0
    for dataset_name in datasets_to_load:
        dataset_path = f"./phase1/datasets/{dataset_name}/graph_attributes.pkl"
        # pickle.load executes arbitrary code; only use with trusted local artifacts.
        with open(dataset_path, "rb") as f:
            graphs = pickle.load(f)
        tar_graphs = graphs[split]

        tar_len = len(tar_graphs)
        tar_codes = split_codes[start: start + tar_len]
        start += tar_len
        assert len(tar_codes) == tar_len, 'Length mismatch between tar_graphs and tar_codes'
        if not tar_graphs:
            continue
        for g, code in zip(tar_graphs, tar_codes):
            g['code'] = code

        # Task names come from the last record's label dict. Index explicitly
        # rather than relying on the loop variable leaking out of the loop above.
        # Assumes all records of a benchmark share the same task keys.
        tasks = list(tar_graphs[-1]['label'].keys())
        for task in tasks:
            new_graphs = copy.deepcopy(tar_graphs)
            for g in new_graphs:
                g['label'] = g['label'][task]
                g['task'] = task

            balanced_graphs = _balance_train_graphs(new_graphs) if split == 'train' else new_graphs

            raw_dataset = Dataset.from_list(balanced_graphs)
            # The records keep their SMILES under 'text'; rename it so `process`
            # can write the rendered chat sample into 'text'.
            raw_dataset = raw_dataset.rename_column("text", "smiles")

            datasets[dataset_name + '_' + task] = raw_dataset.map(
                process,
                num_proc=max(1, min(multiprocessing.cpu_count(), len(raw_dataset))),
                load_from_cache_file=False,
                remove_columns=raw_dataset.column_names,
            )

            # Class-balance report for this (benchmark, task) pair.
            labels = [g['label'] for g in balanced_graphs]
            total = len(labels)
            pos = sum(1 for l in labels if l == 1)
            neg = sum(1 for l in labels if l == 0)
            print(f"{dataset_name}_{task}: Total={total}, Pos={pos} ({pos / total:.2%}), Neg={neg} ({neg / total:.2%})")

    return datasets

# Build every split: {split: {"<benchmark>_<task>": Dataset}}.
splits = ['train', 'valid', 'test']
ds_dict = {split_name: seal_datasets(split_name) for split_name in splits}

# Persist the nested dict for later training / evaluation cells.
with open("./corpus/ds_dict.pkl", "wb") as out_file:
    pickle.dump(ds_dict, out_file)

Map (num_proc=128): 100%|██████████| 1631/1631 [00:01<00:00, 1145.76 examples/s]


BBBP_p_np: Total=1631, Pos=1341 (82.22%), Neg=290 (17.78%)


Map (num_proc=128): 100%|██████████| 12016/12016 [00:02<00:00, 4882.88 examples/s]


Tox21_NR-AR: Total=12016, Pos=6008 (50.00%), Neg=6008 (50.00%)


Map (num_proc=128): 100%|██████████| 12130/12130 [00:02<00:00, 4126.86 examples/s]


Tox21_NR-AR-LBD: Total=12130, Pos=6065 (50.00%), Neg=6065 (50.00%)


Map (num_proc=128): 100%|██████████| 11338/11338 [00:01<00:00, 5674.02 examples/s]


Tox21_NR-AhR: Total=11338, Pos=5669 (50.00%), Neg=5669 (50.00%)


Map (num_proc=128): 100%|██████████| 12100/12100 [00:02<00:00, 4865.07 examples/s]


Tox21_NR-Aromatase: Total=12100, Pos=6050 (50.00%), Neg=6050 (50.00%)


Map (num_proc=128): 100%|██████████| 11224/11224 [00:01<00:00, 6946.97 examples/s]


Tox21_NR-ER: Total=11224, Pos=5612 (50.00%), Neg=5612 (50.00%)


Map (num_proc=128): 100%|██████████| 11918/11918 [00:01<00:00, 6605.21 examples/s]


Tox21_NR-ER-LBD: Total=11918, Pos=5959 (50.00%), Neg=5959 (50.00%)


Map (num_proc=128): 100%|██████████| 12252/12252 [00:01<00:00, 7028.15 examples/s]


Tox21_NR-PPAR-gamma: Total=12252, Pos=6126 (50.00%), Neg=6126 (50.00%)


Map (num_proc=128): 100%|██████████| 11080/11080 [00:01<00:00, 5798.56 examples/s]


Tox21_SR-ARE: Total=11080, Pos=5540 (50.00%), Neg=5540 (50.00%)


Map (num_proc=128): 100%|██████████| 12124/12124 [00:01<00:00, 6919.88 examples/s]


Tox21_SR-ATAD5: Total=12124, Pos=6062 (50.00%), Neg=6062 (50.00%)


Map (num_proc=128): 100%|██████████| 11954/11954 [00:01<00:00, 6615.87 examples/s]


Tox21_SR-HSE: Total=11954, Pos=5977 (50.00%), Neg=5977 (50.00%)


Map (num_proc=128): 100%|██████████| 11094/11094 [00:01<00:00, 6274.20 examples/s]


Tox21_SR-MMP: Total=11094, Pos=5547 (50.00%), Neg=5547 (50.00%)


Map (num_proc=128): 100%|██████████| 11964/11964 [00:01<00:00, 6320.36 examples/s]


Tox21_SR-p53: Total=11964, Pos=5982 (50.00%), Neg=5982 (50.00%)


Map (num_proc=128): 100%|██████████| 1184/1184 [00:01<00:00, 813.71 examples/s] 


ClinTox_FDA_APPROVED: Total=1184, Pos=1105 (93.33%), Neg=79 (6.67%)


Map (num_proc=128): 100%|██████████| 2178/2178 [00:01<00:00, 1515.96 examples/s]


ClinTox_CT_TOX: Total=2178, Pos=1089 (50.00%), Neg=1089 (50.00%)


Map (num_proc=128): 100%|██████████| 63328/63328 [00:08<00:00, 7092.14 examples/s]


HIV_HIV_active: Total=63328, Pos=31664 (50.00%), Neg=31664 (50.00%)


Map (num_proc=128): 100%|██████████| 1390/1390 [00:01<00:00, 1008.63 examples/s]


BACE_Class: Total=1390, Pos=695 (50.00%), Neg=695 (50.00%)


Map (num_proc=128): 100%|██████████| 204/204 [00:01<00:00, 157.22 examples/s]


BBBP_p_np: Total=204, Pos=112 (54.90%), Neg=92 (45.10%)


Map (num_proc=128): 100%|██████████| 782/782 [00:01<00:00, 607.46 examples/s] 


Tox21_NR-AR: Total=782, Pos=31 (3.96%), Neg=751 (96.04%)


Map (num_proc=128): 100%|██████████| 782/782 [00:01<00:00, 588.83 examples/s] 


Tox21_NR-AR-LBD: Total=782, Pos=25 (3.20%), Neg=757 (96.80%)


Map (num_proc=128): 100%|██████████| 782/782 [00:01<00:00, 593.62 examples/s]


Tox21_NR-AhR: Total=782, Pos=87 (11.13%), Neg=695 (88.87%)


Map (num_proc=128): 100%|██████████| 782/782 [00:01<00:00, 554.95 examples/s] 


Tox21_NR-Aromatase: Total=782, Pos=45 (5.75%), Neg=737 (94.25%)


Map (num_proc=128): 100%|██████████| 782/782 [00:01<00:00, 531.48 examples/s]


Tox21_NR-ER: Total=782, Pos=75 (9.59%), Neg=707 (90.41%)


Map (num_proc=128): 100%|██████████| 782/782 [00:01<00:00, 582.86 examples/s] 


Tox21_NR-ER-LBD: Total=782, Pos=29 (3.71%), Neg=753 (96.29%)


Map (num_proc=128): 100%|██████████| 782/782 [00:01<00:00, 489.00 examples/s]


Tox21_NR-PPAR-gamma: Total=782, Pos=32 (4.09%), Neg=750 (95.91%)


Map (num_proc=128): 100%|██████████| 782/782 [00:01<00:00, 605.68 examples/s] 


Tox21_SR-ARE: Total=782, Pos=106 (13.55%), Neg=676 (86.45%)


Map (num_proc=128): 100%|██████████| 782/782 [00:01<00:00, 603.89 examples/s] 


Tox21_SR-ATAD5: Total=782, Pos=35 (4.48%), Neg=747 (95.52%)


Map (num_proc=128): 100%|██████████| 782/782 [00:01<00:00, 559.67 examples/s] 


Tox21_SR-HSE: Total=782, Pos=44 (5.63%), Neg=738 (94.37%)


Map (num_proc=128): 100%|██████████| 782/782 [00:01<00:00, 588.83 examples/s] 


Tox21_SR-MMP: Total=782, Pos=111 (14.19%), Neg=671 (85.81%)


Map (num_proc=128): 100%|██████████| 782/782 [00:01<00:00, 579.29 examples/s] 


Tox21_SR-p53: Total=782, Pos=75 (9.59%), Neg=707 (90.41%)


Map (num_proc=128): 100%|██████████| 148/148 [00:01<00:00, 101.74 examples/s]


ClinTox_FDA_APPROVED: Total=148, Pos=142 (95.95%), Neg=6 (4.05%)


Map (num_proc=128): 100%|██████████| 148/148 [00:01<00:00, 102.35 examples/s]


ClinTox_CT_TOX: Total=148, Pos=7 (4.73%), Neg=141 (95.27%)


Map (num_proc=128): 100%|██████████| 4112/4112 [00:02<00:00, 1793.53 examples/s]


HIV_HIV_active: Total=4112, Pos=81 (1.97%), Neg=4031 (98.03%)


Map (num_proc=128): 100%|██████████| 151/151 [00:01<00:00, 95.30 examples/s] 


BACE_Class: Total=151, Pos=84 (55.63%), Neg=67 (44.37%)


Map (num_proc=128): 100%|██████████| 204/204 [00:01<00:00, 128.67 examples/s]


BBBP_p_np: Total=204, Pos=107 (52.45%), Neg=97 (47.55%)


Map (num_proc=128): 100%|██████████| 783/783 [00:01<00:00, 507.45 examples/s]


Tox21_NR-AR: Total=783, Pos=27 (3.45%), Neg=756 (96.55%)


Map (num_proc=128): 100%|██████████| 783/783 [00:01<00:00, 530.28 examples/s] 


Tox21_NR-AR-LBD: Total=783, Pos=19 (2.43%), Neg=764 (97.57%)


Map (num_proc=128): 100%|██████████| 783/783 [00:01<00:00, 498.76 examples/s]


Tox21_NR-AhR: Total=783, Pos=92 (11.75%), Neg=691 (88.25%)


Map (num_proc=128): 100%|██████████| 783/783 [00:01<00:00, 503.50 examples/s]


Tox21_NR-Aromatase: Total=783, Pos=47 (6.00%), Neg=736 (94.00%)


Map (num_proc=128): 100%|██████████| 783/783 [00:01<00:00, 499.75 examples/s]


Tox21_NR-ER: Total=783, Pos=70 (8.94%), Neg=713 (91.06%)


Map (num_proc=128): 100%|██████████| 783/783 [00:01<00:00, 518.53 examples/s]


Tox21_NR-ER-LBD: Total=783, Pos=21 (2.68%), Neg=762 (97.32%)


Map (num_proc=128): 100%|██████████| 783/783 [00:01<00:00, 581.24 examples/s] 


Tox21_NR-PPAR-gamma: Total=783, Pos=22 (2.81%), Neg=761 (97.19%)


Map (num_proc=128): 100%|██████████| 783/783 [00:01<00:00, 559.86 examples/s] 


Tox21_SR-ARE: Total=783, Pos=118 (15.07%), Neg=665 (84.93%)


Map (num_proc=128): 100%|██████████| 783/783 [00:01<00:00, 521.21 examples/s] 


Tox21_SR-ATAD5: Total=783, Pos=33 (4.21%), Neg=750 (95.79%)


Map (num_proc=128): 100%|██████████| 783/783 [00:01<00:00, 543.27 examples/s] 


Tox21_SR-HSE: Total=783, Pos=47 (6.00%), Neg=736 (94.00%)


Map (num_proc=128): 100%|██████████| 783/783 [00:01<00:00, 441.35 examples/s]


Tox21_SR-MMP: Total=783, Pos=96 (12.26%), Neg=687 (87.74%)


Map (num_proc=128): 100%|██████████| 783/783 [00:01<00:00, 623.24 examples/s] 


Tox21_SR-p53: Total=783, Pos=72 (9.20%), Neg=711 (90.80%)


Map (num_proc=128): 100%|██████████| 148/148 [00:01<00:00, 107.18 examples/s]


ClinTox_FDA_APPROVED: Total=148, Pos=139 (93.92%), Neg=9 (6.08%)


Map (num_proc=128): 100%|██████████| 148/148 [00:01<00:00, 97.62 examples/s] 


ClinTox_CT_TOX: Total=148, Pos=10 (6.76%), Neg=138 (93.24%)


Map (num_proc=128): 100%|██████████| 4112/4112 [00:01<00:00, 2272.41 examples/s]


HIV_HIV_active: Total=4112, Pos=130 (3.16%), Neg=3982 (96.84%)


Map (num_proc=128): 100%|██████████| 152/152 [00:01<00:00, 106.31 examples/s]


BACE_Class: Total=152, Pos=92 (60.53%), Neg=60 (39.47%)


In [None]:
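# ds_dict_ori: original class ratio
# ds_dict_balanced: class ratio after balancing to 1:5
# ds_dict: class ratio after balancing to 1:1 (oversampling is applied to the train split only; valid/test keep the original ratio)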


In [None]:
from collections import defaultdict
import torch
import pickle

save_path = "./phase1/datasets/all/codes_map.pt"
codes_map = torch.load(save_path, weights_only=True)

# codes_map[split] is consumed by cumulative offsets, one benchmark after the
# other. This list must therefore use the same order as the corpus-building
# cell (seal_datasets): BBBP, Tox21, ClinTox, HIV, BACE. This cell's recorded
# output also shows that order. Any other order silently pairs each SMILES with
# a code that belongs to a different molecule.
datasets_to_load = ['BBBP', 'Tox21', 'ClinTox', 'HIV', 'BACE']

# Build {dataset_name: {smiles_string: code}} over the train/valid/test splits.
# A SMILES that occurs in several splits keeps the code from the last split processed.
smiles_to_code = defaultdict(dict)
for split in ['train', 'valid', 'test']:
    split_codes = list(codes_map[split].values())
    start = 0
    for dataset_name in datasets_to_load:
        dataset_path = f"./phase1/datasets/{dataset_name}/graph_attributes.pkl"
        # pickle.load executes arbitrary code; only load trusted local artifacts.
        with open(dataset_path, "rb") as f:
            graphs = pickle.load(f)  # indexed by split name

        tar_graphs = graphs[split]
        tar_len = len(tar_graphs)
        print(f"for {split} of {dataset_name}: length -- {tar_len}; start -- {start}.")

        tar_codes = split_codes[start: start + tar_len]
        start += tar_len

        assert len(tar_codes) == tar_len, \
            f"Length mismatch: graphs={tar_len}, codes={len(tar_codes)} for dataset {dataset_name}"

        for g, code in zip(tar_graphs, tar_codes):
            # Each graph record stores its SMILES string under the 'text' key.
            smiles_to_code[dataset_name][g['text']] = code



for train of BBBP: length -- 1631; start -- 0.
for train of Tox21: length -- 6258; start -- 1631.
for train of ClinTox: length -- 1184; start -- 7889.
for train of HIV: length -- 32896; start -- 9073.
for train of BACE: length -- 1210; start -- 41969.
for valid of BBBP: length -- 204; start -- 0.
for valid of Tox21: length -- 782; start -- 204.
for valid of ClinTox: length -- 148; start -- 986.
for valid of HIV: length -- 4112; start -- 1134.
for valid of BACE: length -- 151; start -- 5246.
for test of BBBP: length -- 204; start -- 0.
for test of Tox21: length -- 783; start -- 204.
for test of ClinTox: length -- 148; start -- 987.
for test of HIV: length -- 4112; start -- 1135.
for test of BACE: length -- 152; start -- 5247.


In [None]:
save_path = './corpus/smiles_to_code.pkl'

# Persist the per-benchmark SMILES -> structure-code lookup.
with open(save_path, "wb") as out_file:
    pickle.dump(smiles_to_code, out_file)
print(f"SMILES->code 映射已保存到 {save_path}")

# Read it straight back to confirm the pickle round-trips.
with open(save_path, "rb") as in_file:
    smiles_to_code = pickle.load(in_file)
print(f"SMILES->code 映射已加载自 {save_path}")

# Number of distinct BACE SMILES across all splits (shown as the cell output).
len(smiles_to_code['BACE'])


SMILES->code 映射已保存到 ./corpus/smiles_to_code.pkl
SMILES->code 映射已加载自 ./corpus/smiles_to_code.pkl


1513

🌟 构造ablation study的3种语料：去掉 struct_code；struct_code 固定替换；struct_code 随机替换

In [None]:
import re
import pickle

# Corpus with the original class ratio; the ablation only touches its test split.
# pickle.load executes arbitrary code, so only load trusted local artifacts.
data_path = "corpus/ds_dict_ori.pkl"
with open(data_path, "rb") as pkl_file:
    ds_dict = pickle.load(pkl_file)
test_datasets = ds_dict['test']

# Group 1: the text before the code digits; group 2: the closing tag plus an optional newline.
pattern = r'(\[Structure Token\] <struct><gstruct_)\d+(></struct>\n?)'


def replace_with_fixed_number(match):
    """re.sub callback: keep the surrounding tag text and replace the code digits with 999."""
    prefix, suffix = match.group(1, 2)
    return prefix + "999" + suffix


def clean(example):
    """Overwrite the structure-token code with the constant 999 in both prompt fields."""
    for field in ('instruction', 'text'):
        example[field] = re.sub(pattern, replace_with_fixed_number, example[field])
    return example

# Replace every test dataset's structure code with the constant 999 and save the
# ablation corpus. Assigning to existing keys while iterating is safe because the
# dict's size does not change. It also updates ds_dict['test'], the same dict object.
for k, td in test_datasets.items():
    test_datasets[k] = td.map(clean)
with open("./corpus/test_datasets_999.pkl", "wb") as f:
    pickle.dump(test_datasets, f)

In [None]:
import re
import pickle
import random
# Fixed seed so the random structure codes are reproducible across runs.
random.seed(42)

# Corpus with the original class ratio; the ablation only touches its test split.
# pickle.load executes arbitrary code, so only load trusted local artifacts.
data_path = "corpus/ds_dict_ori.pkl"
with open(data_path, "rb") as pkl_file:
    ds_dict = pickle.load(pkl_file)
test_datasets = ds_dict['test']

# Group 1: the text before the code digits; group 2: the closing tag plus an optional newline.
pattern = r'(\[Structure Token\] <struct><gstruct_)\d+(></struct>\n?)'


def replace_with_random(match, rand_num=None):
    """re.sub callback: replace the structure-token code digits with a random id.

    Args:
        match: a match of `pattern`.
        rand_num: the code to insert. When None, a fresh integer in
            [1, 999999] is drawn from the module-level `random` generator.
    """
    if rand_num is None:
        rand_num = random.randint(1, 999999)
    return f"{match.group(1)}{rand_num}{match.group(2)}"


def clean(example):
    """Replace the structure token with one random code shared by both prompt fields.

    The number is drawn once per example so `instruction` and `text` stay
    consistent. Substituting each field with its own draw would give the same
    molecule two different codes.
    """
    rand_num = random.randint(1, 999999)

    def _sub(match):
        return replace_with_random(match, rand_num)

    example['instruction'] = re.sub(pattern, _sub, example['instruction'])
    example['text'] = re.sub(pattern, _sub, example['text'])
    return example

# Replace every test dataset's structure code with a random one and save the
# ablation corpus. Assigning to existing keys while iterating is safe because the
# dict's size does not change. It also updates ds_dict['test'], the same dict object.
for k, td in test_datasets.items():
    test_datasets[k] = td.map(clean)
with open("./corpus/test_datasets_random.pkl", "wb") as f:
    pickle.dump(test_datasets, f)

Map:   0%|          | 0/204 [00:00<?, ? examples/s]

Map: 100%|██████████| 204/204 [00:00<00:00, 2646.27 examples/s]
Map: 100%|██████████| 783/783 [00:00<00:00, 5328.77 examples/s]
Map: 100%|██████████| 783/783 [00:00<00:00, 5867.87 examples/s]
Map: 100%|██████████| 783/783 [00:00<00:00, 5883.17 examples/s]
Map: 100%|██████████| 783/783 [00:00<00:00, 5858.95 examples/s]
Map: 100%|██████████| 783/783 [00:00<00:00, 5734.29 examples/s]
Map: 100%|██████████| 783/783 [00:00<00:00, 5820.95 examples/s]
Map: 100%|██████████| 783/783 [00:00<00:00, 5853.15 examples/s]
Map: 100%|██████████| 783/783 [00:00<00:00, 5868.97 examples/s]
Map: 100%|██████████| 783/783 [00:00<00:00, 5887.07 examples/s]
Map: 100%|██████████| 783/783 [00:00<00:00, 5968.87 examples/s]
Map: 100%|██████████| 783/783 [00:00<00:00, 5941.56 examples/s]
Map: 100%|██████████| 783/783 [00:00<00:00, 5951.45 examples/s]
Map: 100%|██████████| 148/148 [00:00<00:00, 4881.47 examples/s]
Map: 100%|██████████| 148/148 [00:00<00:00, 4873.15 examples/s]
Map: 100%|██████████| 4112/4112 [00:00<0

{'instruction': '\nYou are a chemistry expert. Classify the given molecule into the correct category based on its molecular structure (token) and smiles expression. Each structure token represents a unique graph pattern (e.g., a kind of similar molecular graphs).\n[Molecule] O1CCC(CC1)CNC(=O)C(Cc1cc2cc(ccc2nc1N)-c1ccccc1C)C\n[Structure Token] <struct><gstruct_63661></struct>\n[Task] Is the binding result of the molecular on beta-secretase 1 true or false? Output the complete correct answer from the following two options:\n1. True\n2. False\n[Answer]',
 'response': 'True',
 'text': '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful and reliable assistant.<|eot_id|><|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nYou are a chemistry expert. Classify the given molecule into the correct category based on its molecular structure (token) and smiles expression. Each structure token represents a unique graph pattern (e.g., a kind of similar molecular 

In [None]:
import pickle
import re

# Matches the whole structure-token line, including its trailing newline if present.
pattern = r'\[Structure Token\] <struct><gstruct_\d+></struct>\n?'


def clean(example):
    """Delete the '[Structure Token] ...' line from both prompt fields (pure-text ablation)."""
    example.update(
        {field: re.sub(pattern, '', example[field]) for field in ('instruction', 'text')}
    )
    return example

# pickle.load executes arbitrary code, so only load trusted local artifacts.
data_path = "corpus/ds_dict_ori.pkl"
with open(data_path, "rb") as pkl_file:
    ds_dict = pickle.load(pkl_file)

# Strip the structure token from every dataset of every split. The per-split
# dicts are updated in place, so ds_dict_puretext shares them with ds_dict.
ds_dict_puretext = dict()
for split_name, split_datasets in ds_dict.items():
    for name in list(split_datasets):
        split_datasets[name] = split_datasets[name].map(clean)
    ds_dict_puretext[split_name] = split_datasets

with open("./corpus/ds_dict_puretext.pkl", "wb") as out_file:
    pickle.dump(ds_dict_puretext, out_file)

Map: 100%|██████████| 1631/1631 [00:00<00:00, 6622.64 examples/s]
Map: 100%|██████████| 6258/6258 [00:00<00:00, 6787.17 examples/s]
Map: 100%|██████████| 6258/6258 [00:00<00:00, 6787.66 examples/s]
Map: 100%|██████████| 6258/6258 [00:00<00:00, 6804.58 examples/s]
Map: 100%|██████████| 6258/6258 [00:00<00:00, 6781.66 examples/s]
Map: 100%|██████████| 6258/6258 [00:00<00:00, 6750.49 examples/s]
Map: 100%|██████████| 6258/6258 [00:00<00:00, 6739.13 examples/s]
Map: 100%|██████████| 6258/6258 [00:00<00:00, 6710.72 examples/s]
Map: 100%|██████████| 6258/6258 [00:00<00:00, 6830.97 examples/s]
Map: 100%|██████████| 6258/6258 [00:00<00:00, 6770.06 examples/s]
Map: 100%|██████████| 6258/6258 [00:00<00:00, 6719.14 examples/s]
Map: 100%|██████████| 6258/6258 [00:00<00:00, 6683.74 examples/s]
Map: 100%|██████████| 6258/6258 [00:00<00:00, 6770.34 examples/s]
Map: 100%|██████████| 1184/1184 [00:00<00:00, 6568.05 examples/s]
Map: 100%|██████████| 1184/1184 [00:00<00:00, 6481.98 examples/s]
Map: 100%|