# In [1]:
import torch
from utils.train_utils import seed_all
import os
import argparse
from dataset import SmileDataset, SmileCollator
from torch.utils.data import DataLoader
from tokenizer import SmilesTokenizer
from model import GPTConfig, GPT
import time
import datasets
from rdkit import Chem
from utils.train_utils import get_mol
from utils.chem_utils import reconstruct
from tqdm import tqdm

def _collect_answer(res_y, y, prefix, tokenizer, stream):
    """Drain a ``model.generate`` iterator into one decoded string.

    Args:
        res_y: iterator returned by ``model.generate``; ``y`` has already
            been taken from it.
        y: first item taken from ``res_y``; each item is decoded via
            ``tokenizer.decode(item[0].tolist())``.
        prefix: text the result starts with (the decoded prompt).
        tokenizer: object exposing ``decode(ids) -> str``.
        stream: when False, stop after the first usable decoded chunk.

    Returns:
        ``prefix`` followed by the newly decoded text.

    Raises:
        Any exception raised by the generator other than StopIteration
        propagates instead of being silently treated as end-of-output.
    """
    complete_answer = prefix
    history_idx = 0
    while y is not None:
        answer = tokenizer.decode(y[0].tolist())
        # Skip empty decodes and decodes ending in U+FFFD (a multi-byte
        # character not yet complete); wait for the next item instead.
        if not answer or answer[-1] == '\ufffd':
            y = next(res_y, None)
            continue
        # Only the part beyond the previously seen length is appended, i.e.
        # each item is assumed to decode the cumulative output so far.
        complete_answer += answer[history_idx:]
        y = next(res_y, None)
        history_idx = len(answer)
        if not stream:
            break
    return complete_answer


def _split_fragments(complete_answer):
    """Remove spaces and [BOS]/[EOS] markers, then split on [SEP] into fragment strings."""
    cleaned = complete_answer.replace(" ", "").replace("[BOS]", "").replace("[EOS]", "")
    return cleaned.split('[SEP]')


def _reconstruct_smiles(frag_list):
    """Rebuild one molecule from fragment SMILES; return its SMILES, or None if that fails."""
    try:
        frag_mol = [Chem.MolFromSmiles(s) for s in frag_list]
        mol = reconstruct(frag_mol)[0]
        if not mol:
            return None
        return Chem.MolToSmiles(mol)
    except Exception:
        # Best effort: any failure while parsing/reconstructing marks the
        # sample as invalid. Narrower than a bare ``except:`` so Ctrl-C and
        # SystemExit still interrupt the run.
        return None


def _write_lines(path, items):
    """Write ``str(item)`` for each item on its own line, overwriting ``path``."""
    with open(path, "w", encoding="utf-8") as fh:
        for item in items:
            fh.write(str(item))
            fh.write("\n")


def Test(model, tokenizer, max_seq_len, temperature, top_k, stream, rp, kv_cache, is_simulation, device,
         output_file_path, seed, num_samples=1000, bos_token_id=1):
    """Sample fragment sequences from ``model`` and write them to disk.

    Each sample starts from a single-token prompt. The decoded output is
    split on ``[SEP]`` into fragments, which ``reconstruct`` tries to join
    into one molecule. Successful reconstructions count as valid.

    Args:
        model: model with ``eval()`` and a ``generate(...)`` method returning
            an iterator of token-id tensors.
        tokenizer: tokenizer used to decode generated ids.
        max_seq_len, temperature, top_k, stream, rp, kv_cache, is_simulation:
            forwarded to ``model.generate`` (``max_seq_len`` as
            ``max_new_tokens``). ``stream`` also controls how the output
            iterator is drained.
        device: torch device the prompt tensor is placed on.
        output_file_path: directory for the result files (created if missing).
        seed: only used to name the output files.
        num_samples: number of samples to draw.
        bos_token_id: id of the prompt token; presumably the [BOS] id in the
            vocab — confirm.

    Writes ``complete_answer_{seed}`` (one fragment list per line, for every
    sample that produced output) and ``valid_answer_{seed}`` (one
    reconstructed SMILES per line) into ``output_file_path``.
    """
    complete_answer_list = []
    valid_answer_list = []
    model.eval()
    for _ in tqdm(range(num_samples)):
        # Prompt of shape (1, 1) holding just the start token.
        x = torch.tensor([bos_token_id], dtype=torch.int64).unsqueeze(0).to(device)
        # Generation happens lazily inside next(), so draining the iterator
        # must also stay inside no_grad.
        with torch.no_grad():
            res_y = model.generate(x, tokenizer, max_new_tokens=max_seq_len,
                                   temperature=temperature, top_k=top_k, stream=stream, rp=rp,
                                   kv_cache=kv_cache, is_simulation=is_simulation)
            try:
                y = next(res_y)
            except StopIteration:
                print("No answer")
                continue
            complete_answer = _collect_answer(res_y, y, f"{tokenizer.decode(x[0])}", tokenizer, stream)

        frag_list = _split_fragments(complete_answer)
        generate_smiles = _reconstruct_smiles(frag_list)
        if generate_smiles is not None:
            valid_answer_list.append(generate_smiles)
        complete_answer_list.append(frag_list)

    total = len(complete_answer_list)
    # Guard against division by zero when no sample produced any output.
    ratio = len(valid_answer_list) / total if total else 0.0
    print(f"valid ratio:{len(valid_answer_list)}/{total}={ratio}")

    os.makedirs(output_file_path, exist_ok=True)
    _write_lines(os.path.join(output_file_path, f'complete_answer_{seed}'), complete_answer_list)
    _write_lines(os.path.join(output_file_path, f'valid_answer_{seed}'), valid_answer_list)


# In [11]:
def main_test(args):
    """Load the pretrained checkpoint and sample molecules with ``Test``.

    Args:
        args: parsed CLI namespace; reads ``args.seed`` and ``args.device``
            (CUDA device id(s) such as ``"0"`` or ``"0,1"``, or ``"cpu"``).
    """
    # CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is
    # initialised, so set it before seeding or probing for GPUs.
    use_cuda = args.device != 'cpu'
    if use_cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    # Set the random seed for reproducible sampling.
    seed_all(args.seed)

    if use_cuda and not torch.cuda.is_available():
        print("CUDA is not available; falling back to CPU")
        use_cuda = False
    # After masking with CUDA_VISIBLE_DEVICES the first visible GPU is index 0.
    device = torch.device('cuda:0' if use_cuda else 'cpu')

    tokenizer = SmilesTokenizer('./vocabs/vocab.txt')
    tokenizer.bos_token = "[BOS]"
    tokenizer.bos_token_id = tokenizer.convert_tokens_to_ids("[BOS]")
    tokenizer.eos_token = "[EOS]"
    tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("[EOS]")

    mconf = GPTConfig(vocab_size=tokenizer.vocab_size, n_layer=12, n_head=12, n_embd=768)
    model = GPT(mconf).to(device)
    # map_location loads the tensors onto the selected device regardless of
    # the device they were saved from.
    checkpoint = torch.load('./weights/fragpt.pt', map_location=device, weights_only=True)
    model.load_state_dict(checkpoint)

    start_time = time.perf_counter()
    Test(model, tokenizer, max_seq_len=1024, temperature=1.0, top_k=None, stream=False, rp=1., kv_cache=True,
         is_simulation=True, device=device, output_file_path="./output", seed=args.seed)
    elapsed_time = time.perf_counter() - start_time

    print(f"运行时间: {elapsed_time:.4f} 秒")


if __name__ == '__main__':
    # Command-line entry point: sample molecules from the pretrained checkpoint.
    parser = argparse.ArgumentParser(description='sample molecules from a pretrained GPT checkpoint')
    parser.add_argument('--device', default='1', help='device id (i.e. 0 or 0,1 or cpu)')
    # type=int so the seeding helpers receive an integer, not the string '42'.
    parser.add_argument('--seed', type=int, default=42, help='seed')

    opt = parser.parse_args()

    main_test(opt)

* C ( = O ) C ( = C ) F [SEP] * N 1 C [C@H] 2 C C [C@@H] ( C 1 ) [NH+] 2 * [SEP] * C * [SEP] * c 1 c c c c ( = O ) [nH] 1 [EOS]* C C [SEP] * O * [SEP] * c 1 c c c ( O C C N 2 C C O C C 2 ) c ( * ) c 1 [SEP] * C C N [EOS]* O C [SEP] * c 1 c c c ( * ) c ( O C ) c 1 [SEP] * C [NH2+] * [SEP] * C * [SEP] * c 1 c c ( Cl ) c c ( Cl ) c 1 * [SEP] * O C [EOS]* C C ( C ) ( C ) [C@@H] ( O ) C ( C ) C [SEP] * N * [SEP] * C ( * ) = O [SEP] * c 1 c c c ( Br ) c ( Cl ) c 1 [EOS]* N 1 C C C C ( C ) ( C ) C C 1 [SEP] * c 1 n n c ( * ) n 1 C [C@@H] 1 C C O C 1 [SEP] * C O * [SEP] * C C ( F ) ( F ) F [EOS]* N 1 C C C ( C ) ( O ) C C 1 [SEP] * C ( * ) = O [SEP] * [C@H] 1 C C [C@@H] ( * ) C C 1 [SEP] * N * [SEP] * C ( * ) = O [SEP] * O * [SEP] * C ( C ) ( C ) C [EOS]* c 1 c c c n c 1 [SEP] * O * [SEP] * c 1 c c c ( * ) c c 1 [SEP] * c 1 c n n ( * ) c 1 [SEP] * C C * [SEP] * O * [SEP] * [C@@H] 1 C C C C O 1 [EOS]* c 1 n c 2 c c c c c 2 c ( = O ) n 1 C c 1 c c c n c 1 [SEP] * S * [SEP] * C * [SEP] * c 1 n c 

* C C [SEP] * N ( * ) c 1 c c n c ( N = N C 2 C C c 3 c ( F ) c c c c 3 2 ) n 1 [SEP] * C C [EOS]* C C C C C # C [SEP] * N * [SEP] * C ( * ) = O [SEP] * N 1 C C C [C@H] ( S ( C ) ( = O ) = O ) C C 1 [EOS]* c 1 c c c ( C ) c c 1 [SEP] * c 1 c n c ( * ) o 1 [SEP] * C C C ( * ) = O [SEP] * N * [SEP] * [C@@H] 1 C [C@H] 2 C [C@H] 1 [C@@H] 1 C C C [C@@H] 2 1 [EOS]* c 1 c n ( C ) c c c 1 = O [SEP] * C ( * ) = O [SEP] * N * [SEP] * N 1 C [C@@H] 2 C C C [C@] 2 ( * ) C 1 [SEP] * C ( = O ) C 1 = C C C C 1 [EOS]* N ( C ) C ( = O ) C C C C ( F ) ( F ) F [SEP] * C C * [SEP] * N * [SEP] * C ( * ) = O [SEP] * c 1 c o c ( * ) c 1 [SEP] * C ( N ) = O [EOS]* = C C C C C C C C C [SEP] * = C C C = * [SEP] * = C C C C C N = C ( [O-] ) C C C C C C C C = * [SEP] * = C C C = * [SEP] * = C C C C C C [EOS]* C C [SEP] * O * [SEP] * C ( = O ) [C@@H] ( * ) C 1 C C 1 [SEP] * N * [SEP] * C C [C@@H] ( * ) C [SEP] * N * [SEP] * C ( = O ) C C ( = * ) C [SEP] * = C ( C ) C [EOS]* C C [SEP] * c 1 c c c ( * ) c ( C C ) c 1