In [None]:
import argparse
import ast
import json
import math
import re

import numpy as np
import pandas as pd
import selfies as sf
import torch
import torch.nn as nn
import wandb
from PyBioMed.PyProtein import CTD
from torch.cuda.amp import GradScaler
from torch.nn import functional as F
from torch.utils.data import DataLoader

from model import GPT, GPTConfig
from seq_embedd import SmilesDataset
from trainer import Trainer, TrainerConfig
from utils import set_seed

In [None]:
# Identifier for this experiment; reused below as the W&B run name and as the
# stem of the checkpoint file name.
run_name = "Transport_seq"

# Fixed seed for repeatable runs (set_seed comes from the local `utils`
# module — presumably it seeds the relevant RNGs; confirm there).
set_seed(42)

In [None]:
# Open the Weights & Biases run that this notebook logs to.
wandb.init(
    name=run_name,
    project="DTproject",
)

In [None]:
# Drug table: read the CSV, discard every row containing a missing value,
# renumber the index, and normalise all column names to lower case.
data = (
    pd.read_csv('../datasets/chemb_drug_selfies.csv')
    .dropna(axis=0)
    .reset_index(drop=True)
    .rename(columns=str.lower)
)

In [None]:
# Preview of the cleaned drug table.
# Rows with missing values were already dropped (and the index reset) when
# `data` was loaded, so repeating dropna()/reset_index() here would only rebind
# `data` to an identical copy; the cell now just displays the first rows.
data.head()

In [None]:
# Protein-sequence table (tab separated). Incomplete rows are removed, the
# index is renumbered, and "uniprot_id" is renamed to "dt" so that it matches
# the join key used when merging with the drug table.
pro_seq = (
    pd.read_csv("../datasets/transport_pro_seq.txt", sep='\t')
    .dropna(axis=0)
    .reset_index(drop=True)
    .rename(columns={"uniprot_id": "dt"})
)

In [None]:
# Quick look at the first five protein-sequence rows.
pro_seq.head(n=5)

In [None]:
# Attach drug rows to each protein via the shared "dt" key. The right join
# keeps every protein row; any row that ends up with a missing value
# (e.g. a protein without a matching drug row) is then discarded.
merge_data = (
    data.merge(pro_seq, how="right", on="dt")
    .dropna(axis=0)
    .reset_index(drop=True)
)

In [None]:
# Display the merged drug/protein table (pandas truncates long frames in its repr).
merge_data

In [None]:
# Build the training and validation subsets from the "split" column.
# Note that the validation subset is made of the rows labelled 'test'.
is_train = merge_data['split'] == 'train'
is_val = merge_data['split'] == 'test'

train_data = merge_data.loc[is_train].reset_index(drop=True)
val_data = merge_data.loc[is_val].reset_index(drop=True)

# SELFIES entries of each subset as plain Python lists.
selfies_list = train_data['selfies'].tolist()
vselfies_list = val_data['selfies'].tolist()

# Report subset sizes: training first, then validation.
for subset in (selfies_list, vselfies_list):
    print(len(subset))

In [None]:
# Collect the full SELFIES token set from the whole drug table
# (`data`, i.e. before the merge with the protein table).

from torchtext.legacy import data as d
from torchtext.vocab import Vectors

BLANK_WORD = "<blank>"


def tokenizer(text):
    """Whitespace tokenizer handed to the torchtext Field."""
    return text.split()


TGT = d.Field(tokenize=tokenizer, pad_token=BLANK_WORD)

all_selfies = data['selfies'].to_list()

# Every entry is expected to be the text form of a token list, e.g.
# "['[C]', '[O]']" (assumption based on the slicing below — verify against the
# CSV). [2:-2] removes the two characters at each end, doubled backslashes are
# collapsed to a single one, and the remainder is split on the "', '" separator.
src = [
    entry[2:-2].replace("\\\\", "\\").split("', '")
    for entry in all_selfies
]

# Token count per molecule; its maximum is used later as the padding length.
src_len = [len(tokens) for tokens in src]

TGT.build_vocab(src)

# All vocabulary entries known to the Field, in stoi key order.
whole_string = list(TGT.vocab.stoi.keys())
print(len(whole_string))

In [None]:
# Load the token -> index table for the SELFIES vocabulary.
# The file is opened through a context manager so the handle is closed as soon
# as parsing finishes (previously the handle returned by open() was never closed).
with open('../datasets/drug_selfies_stoi.json', 'r') as stoi_file:
    stoi = json.load(stoi_file)

# Inverse table: index -> token.
itos = {index: token for token, index in stoi.items()}

In [None]:
# Largest token count over every SELFIES entry in the drug table (src_len holds
# one count per entry). Used below as the reference length for padding and as
# the sequence-length argument of the datasets.

max_len = max(src_len)
max_len

In [None]:
# Bring every SELFIES token list to a common length so the model always
# receives inputs of the same dimension.

BLANK_WORD = '<blank>'


def _pad_selfies(token_strings, target_len, pad_token):
    """Parse stringified token lists and right-pad each one to ``target_len``.

    Args:
        token_strings: iterable of strings, each the Python-literal text of a
            list of tokens (as stored in the 'selfies' column).
        target_len: minimum length of every returned list. Lists that are
            already this long or longer are returned unchanged.
        pad_token: token appended to reach ``target_len``.

    Returns:
        A list of token lists, in the same order as ``token_strings``.

    Raises:
        ValueError / SyntaxError: if an entry is not a valid Python literal.
    """
    padded = []
    for raw in token_strings:
        # literal_eval only accepts plain literals, so text coming from the CSV
        # can never execute code the way eval() would.
        tokens = list(ast.literal_eval(raw))
        # A non-positive difference yields an empty list, i.e. no padding.
        tokens.extend([pad_token] * (target_len - len(tokens)))
        padded.append(tokens)
    return padded


# Pad to max_len + 1 rather than max_len so the final token of the longest
# sequence is not lost (presumably the dataset consumes one position when it
# forms input/target pairs — confirm in SmilesDataset).
selfies = _pad_selfies(selfies_list, max_len + 1, BLANK_WORD)
vselfies = _pad_selfies(vselfies_list, max_len + 1, BLANK_WORD)

In [None]:
# Protein sequences used as the conditioning input, one per sample,
# for the training and validation subsets respectively.
pro, vpro = train_data["seq"], val_data["seq"]

# Length of the protein embedding passed to the model configuration.
# NOTE(review): 147 is presumably the size of the descriptor vector produced
# for each sequence (CTD is imported at the top) — confirm in seq_embedd.
pro_len = 147

In [None]:
# Arguments shared by both datasets: vocabulary entries, lookup tables and
# the reference sequence length.
shared_dataset_args = (whole_string, stoi, itos, max_len)

# Augmentation is switched off (aug_prob=0) for training and validation alike.
train_dataset = SmilesDataset(selfies, *shared_dataset_args, aug_prob=0, pro=pro)
valid_dataset = SmilesDataset(vselfies, *shared_dataset_args, aug_prob=0, pro=vpro)

In [None]:
# Hyper-parameters.
# Model size: forwarded to GPTConfig below.
n_layer = 8
n_head = 8
n_embd = 256

# Optimisation schedule: forwarded to TrainerConfig below.
max_epochs = 10
batch_size = 16
learning_rate = 6e-4

In [None]:
# Model configuration: vocabulary size and sequence length are taken from the
# training dataset; the protein embedding length and the network dimensions
# come from the cells above. The LSTM option is disabled.
mconf = GPTConfig(
    train_dataset.vocab_size,
    train_dataset.max_len,
    pro_len=pro_len,
    n_layer=n_layer,
    n_head=n_head,
    n_embd=n_embd,
    lstm=False,
)

In [None]:
# Instantiate the GPT model from the configuration built above.
model = GPT(mconf)

In [None]:
# Number of training rows; both token budgets below scale with it.
n_train_rows = len(train_data)

# Trainer configuration. Learning-rate decay is enabled, with a warm-up budget
# of 10% of one epoch's tokens (rows * max_len) and a final budget covering
# all epochs. The checkpoint is written under the run name.
tconf = TrainerConfig(
    max_epochs=max_epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    lr_decay=True,
    warmup_tokens=0.1 * n_train_rows * max_len,
    final_tokens=max_epochs * n_train_rows * max_len,
    num_workers=0,
    ckpt_path=f'../result/models/{run_name}.pt',
    block_size=train_dataset.max_len,
    generate=False,
)

In [None]:
# Wire the model, both datasets, the trainer configuration and the training
# dataset's token lookup tables into the Trainer.
trainer = Trainer(
    model,
    train_dataset,
    valid_dataset,
    tconf,
    train_dataset.stoi,
    train_dataset.itos,
)

In [None]:
# Run training, handing the wandb module to the trainer; the value returned by
# train() is kept in `df` (its contents are defined in trainer.py).
df = trainer.train(wandb)