In [1]:
import pandas as pd
from PyBioMed.PyProtein import CTD
from tqdm import tqdm
import numpy as np
import json
import torch
import math
import re
from model import GPT, GPTConfig
from utils import sample

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
pro_seq = pd.read_csv("../datasets/transport_pro_seq.txt",sep='\t')
pro_seq.set_index(["uniprot_id"], inplace=True) 

In [3]:
stoi = json.load(open(f'../datasets/drug_selfies_stoi.json', 'r'))
itos = { i:ch for ch,i in stoi.items() }

In [59]:
pro = list(CTD.CalculateCTD(pro_seq['seq']["dt_Q01650"]).values())

In [60]:
#获取KG嵌入
entity_rescal_em = pd.read_csv("../datasets/kg_embedding/RESCAL_entity_embedding.csv")
entity_rescal_em.set_index(["ent_name"], inplace=True)
pro_kg = eval(entity_rescal_em["ent_embedding"]["dt_Q01650"])

In [61]:
pro_cond=pro+pro_kg

In [62]:
pro_len = len(pro_cond)
vocab_size = 128
block_size = 359
n_layer = 8
n_head = 8
n_embd = 256
scaffold = False
lstm = False
lstm_layers = 0
gen_size = 5000
batch_size = 8

In [63]:
mconf = GPTConfig(vocab_size, block_size, pro_len = pro_len,
                       n_layer=n_layer, n_head=n_head, n_embd=n_embd,
                       lstm = lstm, lstm_layers = lstm_layers)
model = GPT(mconf)

In [64]:
model.load_state_dict(torch.load('../result/models/Transport_seq_KG_embedding.pt'))
model.to('cuda')
print('Model loaded')

gen_iter = math.ceil(gen_size / batch_size)

Model loaded


In [73]:
context = '[C]'
pattern = "(\[[^\]]+]|<|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
regex = re.compile(pattern)

In [92]:
#prop_based
count = 0
molecules = []
count += 1
for i in tqdm(range(gen_iter)):
    #x = torch.tensor(np.random.randint(128), dtype=torch.long)[None,...].repeat(batch_size, 1).to('cuda')
    x = torch.tensor([stoi[s] for s in regex.findall(context)], dtype=torch.long)[None,...].repeat(batch_size, 1).to('cuda')
    p = torch.tensor(pro_cond).repeat(batch_size, 1).unsqueeze(1).to('cuda')
    y = sample(model, x, block_size, temperature=0.7, sample=True, top_k=None, pro = p)
    for gen_mol in y: 
        completion = ''.join([itos[int(i)] for i in gen_mol])
        completion = completion.replace('<blank>', '')
        molecules.append(completion)

100%|██████████| 625/625 [1:24:55<00:00,  8.15s/it]


In [93]:
len(set(molecules))

1957

In [94]:
len(set(molecules))/len(molecules)

0.3914

In [72]:
mol_df = pd.DataFrame(molecules)
mol_df.to_csv("../result/molecules/Q01650_seq_kg_fixed_mol.csv",index=None)