In [1]:
import pandas as pd
from PyBioMed.PyProtein import CTD
from tqdm import tqdm
import numpy as np
import json
import torch
import math
import re
from model import GPT, GPTConfig
from utils import sample

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
pro_seq = pd.read_csv("../datasets/transport_pro_seq.txt",sep='\t')
pro_seq.set_index(["uniprot_id"], inplace=True) 

In [121]:
pro_cond=list(CTD.CalculateCTD(pro_seq['seq']["dt_P08183"]).values())

In [122]:
max_len = 147

In [123]:
stoi = json.load(open(f'../datasets/drug_selfies_stoi.json', 'r'))
itos = { i:ch for ch,i in stoi.items() }

In [124]:
stoi

{'<unk>': 0,
 '<blank>': 1,
 '[C]': 2,
 '[Branch1]': 3,
 '[=C]': 4,
 '[O]': 5,
 '[Ring1]': 6,
 '[=Branch1]': 7,
 '[=O]': 8,
 '[N]': 9,
 '[Ring2]': 10,
 '[C@H1]': 11,
 '[Branch2]': 12,
 '[C@@H1]': 13,
 '[#Branch1]': 14,
 '[C@@]': 15,
 '[=N]': 16,
 '[C@]': 17,
 '[S]': 18,
 '[#Branch2]': 19,
 '[=Branch2]': 20,
 '[P]': 21,
 '[#C]': 22,
 '[/C]': 23,
 '[F]': 24,
 '.': 25,
 '[Cl]': 26,
 '[NH1]': 27,
 '[\\C]': 28,
 '[O-1]': 29,
 '[=Ring1]': 30,
 '[I]': 31,
 '[N+1]': 32,
 '[Na+1]': 33,
 '[=N+1]': 34,
 '[/N]': 35,
 '[=S]': 36,
 '[/C@H1]': 37,
 '[#N]': 38,
 '[Br]': 39,
 '[=N-1]': 40,
 '[/O]': 41,
 '[/C@@H1]': 42,
 '[=Ring2]': 43,
 '[\\O]': 44,
 '[Si]': 45,
 '[\\C@@H1]': 46,
 '[=P]': 47,
 '[Ca+2]': 48,
 '[Cl-1]': 49,
 '[Br-1]': 50,
 '[B]': 51,
 '[S+1]': 52,
 '[-/Ring2]': 53,
 '[K+1]': 54,
 '[\\C@H1]': 55,
 '[Mg+2]': 56,
 '[N-1]': 57,
 '[I-1]': 58,
 '[P@@]': 59,
 '[P@]': 60,
 '[2H]': 61,
 '[Branch3]': 62,
 '[Se]': 63,
 '[\\S]': 64,
 '[\\N]': 65,
 '[Al+3]': 66,
 '[11CH3]': 67,
 '[14CH2]': 68,
 '[/S]

In [125]:
pro = pro_cond
pro_len = 147
vocab_size = 128
block_size = 359
n_layer = 8
n_head = 8
n_embd = 256
scaffold = False
lstm = False
lstm_layers = 0
gen_size = 5000
batch_size = 8

In [126]:
mconf = GPTConfig(vocab_size, block_size, pro_len = pro_len,
                       n_layer=n_layer, n_head=n_head, n_embd=n_embd,
                       lstm = lstm, lstm_layers = lstm_layers)
model = GPT(mconf)

In [127]:
model.load_state_dict(torch.load('../result/models/Transport_seq.pt'))
model.to('cuda')
print('Model loaded')

gen_iter = math.ceil(gen_size / batch_size)

Model loaded


In [128]:
pattern = "(\[[^\]]+]|<|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
regex = re.compile(pattern)

In [129]:
context = '[C]'

In [130]:
#pro_based
count = 0
molecules = []
count += 1
for i in tqdm(range(gen_iter)):
    x = torch.tensor(np.random.randint(128), dtype=torch.long)[None,...].repeat(batch_size, 1).to('cuda')
    #x = torch.tensor([stoi[s] for s in regex.findall(context)], dtype=torch.long)[None,...].repeat(batch_size, 1).to('cuda')
    p = torch.tensor(pro_cond).repeat(batch_size, 1).unsqueeze(1).to('cuda')
    y = sample(model, x, block_size, temperature=0.7, sample=True, top_k=None, pro = p)
    for gen_mol in y:
        completion = ''.join([itos[int(i)] for i in gen_mol])
        completion = completion.replace('<blank>', '')
        molecules.append(completion)

100%|██████████| 625/625 [2:10:51<00:00, 12.56s/it]  


In [131]:
len(set(molecules))

3357

In [132]:
len(set(molecules))/len(molecules)

0.6714

In [133]:
mol_df = pd.DataFrame(molecules)
mol_df.to_csv("../result/molecules/P08183_seq_random_mol.csv",index=None)