In [1]:
import pandas as pd
import argparse
from utils import set_seed
import numpy as np
import wandb
import math
import re

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torch.cuda.amp import GradScaler

from model import GPT, GPTConfig
from trainer import Trainer, TrainerConfig

from seq_kg_embedd import SmilesDataset
import selfies as sf
from PyBioMed.PyProtein import CTD

import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed the project's RNGs via utils.set_seed (which libraries it covers: confirm in utils)
# and name this run; run_name is reused for the wandb run and the checkpoint path.
SEED = 42
set_seed(SEED)

run_name = "Transport_seq_KG_embedding"

In [3]:
# Start a Weights & Biases run in the "DTproject" project; the wandb module is later passed to trainer.train.
wandb.init(name=run_name, project="DTproject")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mzoey_chen[0m (use `wandb login --relogin` to force relogin)


In [4]:
# Load the SELFIES table, drop every row with a missing value, and lower-case the column names
# (the displayed columns are dt, selfies, split).
data = (
    pd.read_csv('../datasets/chemb_drug_selfies.csv')
    .dropna(axis=0)
    .reset_index(drop=True)
)
data.columns = data.columns.str.lower()

In [5]:
# Preview the cleaned table. The previous cell already dropped NaN rows and reset the index,
# so repeating dropna/reset_index here was a redundant no-op and has been removed.
data.head(3)

Unnamed: 0,dt,selfies,split
0,dt_Q96A29,"['[C]', '[C@@H1]', '[O]', '[C@H1]', '[Branch2]...",train
1,dt_P08183,"['[C]', '[C]', '[Branch1]', '[C]', '[C]', '[O]...",train
2,dt_Q9Y6L6,"['[C]', '[O]', '[C@@H1]', '[C]', '[C@H1]', '[B...",train


In [6]:
# Load the tab-separated protein-sequence table, drop incomplete rows, and rename
# uniprot_id -> dt so it can be joined against `data`.
pro_seq = (
    pd.read_csv("../datasets/transport_pro_seq.txt", sep='\t')
    .dropna(axis=0)
    .reset_index(drop=True)
    .rename(columns={"uniprot_id": "dt"})
)

In [7]:
# Attach protein sequences on `dt`. The right join is followed by dropna, so only rows that
# matched in both tables (and have no missing values) survive.
merge_data = (
    data.merge(pro_seq, how="right", on="dt")
    .dropna(axis=0)
    .reset_index(drop=True)
)

In [8]:
# Attach each target's RESCAL knowledge-graph entity embedding by matching dt to ent_name.
entity_rescal_em = pd.read_csv("../datasets/kg_embedding/RESCAL_entity_embedding.csv")

# A right join followed by dropna keeps only matched, complete rows; the join key ent_name
# duplicates dt and is dropped afterwards.
merge_data2 = (
    pd.merge(merge_data, entity_rescal_em, how="right", left_on="dt", right_on="ent_name")
    .dropna(axis=0)
    .reset_index(drop=True)
    .drop(columns=['ent_name'])
)

In [9]:
# Preview the first five merged rows (dt, selfies, split, seq, ent_embedding).
merge_data2.head(5)

Unnamed: 0,dt,selfies,split,seq,ent_embedding
0,dt_Q13183,"['[O]', '.', '[O]', '.', '[O]', '.', '[O]', '....",train,MATCWQALWAYRSYLIVFFVPILLLPLPILVPSKEAYCAYAIILMA...,"[-0.1436040699481964, -0.08548112958669662, -0..."
1,dt_Q13183,"['[O]', '[=C]', '[Branch1]', '[C]', '[O]', '[C...",train,MATCWQALWAYRSYLIVFFVPILLLPLPILVPSKEAYCAYAIILMA...,"[-0.1436040699481964, -0.08548112958669662, -0..."
2,dt_Q13183,"['[O]', '[=S]', '[=Branch1]', '[C]', '[=O]', '...",train,MATCWQALWAYRSYLIVFFVPILLLPLPILVPSKEAYCAYAIILMA...,"[-0.1436040699481964, -0.08548112958669662, -0..."
3,dt_Q13183,"['[O]', '[=S]', '[=Branch1]', '[C]', '[=O]', '...",train,MATCWQALWAYRSYLIVFFVPILLLPLPILVPSKEAYCAYAIILMA...,"[-0.1436040699481964, -0.08548112958669662, -0..."
4,dt_Q13183,"['[O]', '[=S]', '[=Branch1]', '[C]', '[=O]', '...",train,MATCWQALWAYRSYLIVFFVPILLLPLPILVPSKEAYCAYAIILMA...,"[-0.1436040699481964, -0.08548112958669662, -0..."


In [10]:
# Split the merged table by its `split` label and pull out the raw (stringified) SELFIES lists.
# NOTE(review): rows labelled 'test' are used as the validation set here; if 'test' is meant
# to be a held-out evaluation split, this leaks it into model selection. Confirm intent.
is_train = merge_data2['split'] == 'train'
is_val = merge_data2['split'] == 'test'

train_data = merge_data2.loc[is_train].reset_index(drop=True)
val_data = merge_data2.loc[is_val].reset_index(drop=True)

selfies_list = list(train_data['selfies'])
vselfies_list = list(val_data['selfies'])

print(len(selfies_list))
print(len(vselfies_list))

52883
13221


In [11]:
# Build the SELFIES token vocabulary over every row of `data` (before the merges and split),
# and record each row's token count so the common padding length can be computed later.

from torchtext.legacy import data as d
from torchtext.vocab import Vectors


all_selfies = data['selfies'].to_list()
BLANK_WORD = "<blank>"


def tokenizer(x):
    return x.split()


TGT = d.Field(tokenize=tokenizer, pad_token=BLANK_WORD)

# Each entry is the text of a Python list such as "['[C]', '[O]']": strip the outer "['"/"']",
# collapse doubled backslashes (presumably introduced by the list repr), and split on "', '".
src = [
    raw[2:-2].replace("\\\\", "\\").split("', '")
    for raw in all_selfies
]
src_len = [len(tokens) for tokens in src]

# The lists are already tokenized, so they go straight into build_vocab.
TGT.build_vocab(src)


# Every vocabulary token, in the iteration order of TGT.vocab.stoi.
whole_string = list(TGT.vocab.stoi.keys())
print(len(whole_string))


152


In [12]:
# Load the token -> index map from JSON and build its inverse (index -> token).
# A `with` block closes the file handle, which the bare open() inside json.load never did.
# NOTE(review): this map is loaded independently of `whole_string` (built from the torchtext
# vocab above); presumably the two must agree on the token set. Confirm.
with open('../datasets/drug_selfies_stoi.json', 'r') as stoi_file:
    stoi = json.load(stoi_file)
itos = {index: token for token, index in stoi.items()}

In [13]:
# Common sequence length: the longest tokenized SELFIES over the whole dataset (`data`),
# so padding every sequence to it gives the model fixed-size inputs.

max_len = max(n_tokens for n_tokens in src_len)

In [14]:
# Parse each stringified SELFIES token list and right-pad it with BLANK_WORD so every
# train/validation sequence shares the same length (max_len + 1).
import ast

BLANK_WORD = '<blank>'


def _pad_selfies(raw_lists, target_len, pad_token=BLANK_WORD):
    """Parse stringified token lists and right-pad each one to ``target_len``.

    Args:
        raw_lists: iterable of strings, each holding a Python list literal of tokens,
            e.g. "['[C]', '[O]']".
        target_len: minimum length of every returned list; lists already at least this
            long are returned unpadded.
        pad_token: token appended to short lists.

    Returns:
        A new list of token lists, one per input string.

    Raises:
        ValueError / SyntaxError (among others): if an entry is not a plain Python literal.
    """
    padded = []
    for raw in raw_lists:
        # literal_eval accepts only literals; the previous eval() would have executed
        # arbitrary code stored in the CSV.
        tokens = ast.literal_eval(raw)
        # A negative repeat count yields [], so long lists are left unchanged.
        tokens.extend([pad_token] * (target_len - len(tokens)))
        padded.append(tokens)
    return padded


# Pad one slot beyond max_len so the final real token is never lost
# (presumably SmilesDataset shifts input/target by one position; confirm).
selfies = _pad_selfies(selfies_list, max_len + 1)
vselfies = _pad_selfies(vselfies_list, max_len + 1)

In [15]:
# Per-row conditioning inputs for each split: the protein sequence and the KG entity embedding
# (the embedding column appears to hold list-like strings; see the preview above).

pro, vpro = train_data["seq"], val_data["seq"]

embedding, vembedding = train_data["ent_embedding"], val_data["ent_embedding"]

In [16]:
# Wrap each split in SmilesDataset. Both share the vocabulary, the stoi/itos maps and max_len;
# augmentation is disabled (aug_prob=0) and the protein sequences are passed via `pro`.
train_dataset = SmilesDataset(
    selfies, whole_string, stoi, itos, embedding, max_len,
    aug_prob=0, pro=pro,
)
valid_dataset = SmilesDataset(
    vselfies, whole_string, stoi, itos, vembedding, max_len,
    aug_prob=0, pro=vpro,
)

data has 52883 smiles, 152 unique characters.
data has 13221 smiles, 152 unique characters.


In [17]:
# Protein-condition length handed to GPTConfig(pro_len=...).
# NOTE(review): hard-coded magic number. Presumably the longest sequence in
# transport_pro_seq.txt or a chosen truncation length; confirm, and consider deriving it from `pro`.
pro_len = 947

In [18]:
# Transformer size: layers, attention heads, embedding width; the optional LSTM branch gets 0 layers.
n_layer, n_head, n_embd, lstm_layers = 8, 8, 256, 0

In [19]:
# Model configuration: vocabulary size and sequence length come from the training dataset,
# pro_len sets the protein-condition length, and the LSTM branch is switched off.
mconf = GPTConfig(
    train_dataset.vocab_size,
    train_dataset.max_len,
    pro_len=pro_len,
    n_layer=n_layer,
    n_head=n_head,
    n_embd=n_embd,
    lstm=False,
    lstm_layers=lstm_layers,
)

In [20]:
# Instantiate the GPT model from the configuration built above.
model = GPT(mconf)

In [24]:
# Optimisation settings: number of epochs, mini-batch size, peak learning rate.
max_epochs, batch_size, learning_rate = 10, 16, 6e-4

In [25]:
# Trainer configuration. warmup_tokens / final_tokens are token counts (rows x max_len),
# presumably consumed by a token-based LR schedule in trainer.py: warmup spans ~10% of one
# epoch's tokens and decay spans max_epochs epochs' worth. Checkpoints go to ../result/models/.
tconf = TrainerConfig(
    max_epochs=max_epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    lr_decay=True,
    warmup_tokens=0.1*len(train_data)*max_len,
    final_tokens=max_epochs*len(train_data)*max_len,
    num_workers=0,
    ckpt_path=f'../result/models/{run_name}.pt',
    block_size=train_dataset.max_len,
    generate=False,
)

In [26]:
# Assemble the trainer from the model, both datasets and the config; the training dataset's
# stoi/itos maps are passed through as well.
trainer = Trainer(
    model,
    train_dataset,
    valid_dataset,
    tconf,
    train_dataset.stoi,
    train_dataset.itos,
)

In [27]:
# Run training. The wandb module is handed to the trainer (presumably for metric logging;
# confirm in trainer.Trainer.train). The captured log below shows a checkpoint save after
# every epoch. NOTE(review): what `df` holds depends on Trainer.train's return value.
df = trainer.train(wandb)

epoch 1 iter 3305: train loss 0.04914. lr 5.878964e-04: 100%|██████████| 3306/3306 [12:15<00:00,  4.49it/s]


Saving at epoch 1


epoch 2 iter 3305: train loss 0.05534. lr 5.472984e-04: 100%|██████████| 3306/3306 [12:33<00:00,  4.39it/s]


Saving at epoch 2


epoch 3 iter 3305: train loss 0.04312. lr 4.820944e-04: 100%|██████████| 3306/3306 [12:23<00:00,  4.45it/s]


Saving at epoch 3


epoch 4 iter 3305: train loss 0.02616. lr 3.987721e-04: 100%|██████████| 3306/3306 [12:13<00:00,  4.51it/s]


Saving at epoch 4


epoch 5 iter 3305: train loss 0.03184. lr 3.056219e-04: 100%|██████████| 3306/3306 [12:13<00:00,  4.50it/s]


Saving at epoch 5


epoch 6 iter 3305: train loss 0.02703. lr 2.119125e-04: 100%|██████████| 3306/3306 [13:35<00:00,  4.05it/s]


Saving at epoch 6


epoch 7 iter 3305: train loss 0.04490. lr 1.269677e-04: 100%|██████████| 3306/3306 [13:09<00:00,  4.19it/s]


Saving at epoch 7


epoch 8 iter 3305: train loss 0.02254. lr 6.000000e-05: 100%|██████████| 3306/3306 [12:27<00:00,  4.42it/s]


Saving at epoch 8


epoch 9 iter 3305: train loss 0.02978. lr 6.000000e-05: 100%|██████████| 3306/3306 [12:17<00:00,  4.48it/s]


Saving at epoch 9


epoch 10 iter 3305: train loss 0.01465. lr 6.000000e-05: 100%|██████████| 3306/3306 [12:07<00:00,  4.54it/s]


Saving at epoch 10
