In [1]:
import dgl
import torch

from tdc.single_pred import ADME
from rdkit import Chem

In [2]:
# Load the TDC single-prediction ADME dataset "BBB_martins"; later cells split
# it and featurise its SMILES strings.
data = ADME(name="BBB_martins")

Found local copy...
Loading...
Done!


In [3]:
dir(data)  # list the dataset object's attributes/methods (e.g. get_split, print_stats)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'balanced',
 'binarize',
 'convert_format',
 'convert_from_log',
 'convert_result',
 'convert_to_log',
 'entity1',
 'entity1_idx',
 'entity1_name',
 'file_format',
 'get_approved_set',
 'get_data',
 'get_label_meaning',
 'get_other_species',
 'get_split',
 'harmonize',
 'label_distribution',
 'label_name',
 'name',
 'path',
 'print_stats',
 'raw_format',
 'y']

In [4]:
# Random 70/10/20 train/valid/test split; seed=999 fixes the shuffle so the
# same partition is produced on every run.
split = data.get_split(method="random", frac=[0.7, 0.1, 0.2], seed=999)

In [5]:
split  # dict of DataFrames keyed 'train' / 'valid' / 'test' (columns Drug_ID, Drug, Y)

{'train':                                 Drug_ID  \
 0                  Terbutylchlorambucil   
 1                                 40730   
 2                           cloxacillin   
 3                          cefoperazone   
 4                      rolitetracycline   
 ...                                 ...   
 1416                         zometapine   
 1417                         licostinel   
 1418  ademetionine(adenosyl-methionine)   
 1419                           mesocarb   
 1420                         tofisoline   
 
                                                    Drug  Y  
 0                CC(C)(C)OC(=O)CCCc1ccc(N(CCCl)CCCl)cc1  1  
 1      CC1COc2c(N3CCN(C)CC3)c(F)cc3c(=O)c(C(=O)O)cn1c23  1  
 2     Cc1onc(-c2ccccc2Cl)c1C(=O)N[C@@H]1C(=O)N2[C@@H...  1  
 3     CCN1CCN(C(=O)N[C@@H](C(=O)N[C@@H]2C(=O)N3C(C(=...  1  
 4     CN(C)[C@@H]1C(=O)/C(=C(/O)NCN2CCCC2)C(=O)[C@@]...  1  
 ...                                                 ... ..  
 1416                CC1=C2

In [6]:
# Unpack the three partitions produced by get_split.
train_set, valid_set, test_set = split['train'], split['valid'], split['test']

In [7]:
# SMILES strings of the training molecules (the 'Drug' column), as a plain list.
smi_list = list(train_set['Drug'])

In [8]:
print(smi_list[:10])  # peek at the first 10 training SMILES

['CC(C)(C)OC(=O)CCCc1ccc(N(CCCl)CCCl)cc1', 'CC1COc2c(N3CCN(C)CC3)c(F)cc3c(=O)c(C(=O)O)cn1c23', 'Cc1onc(-c2ccccc2Cl)c1C(=O)N[C@@H]1C(=O)N2[C@@H](C(=O)O)C(C)(C)S[C@H]12', 'CCN1CCN(C(=O)N[C@@H](C(=O)N[C@@H]2C(=O)N3C(C(=O)O)=C(CSc4nnnn4C)CS[C@H]23)c2ccc(O)cc2)C(=O)C1=O', 'CN(C)[C@@H]1C(=O)/C(=C(/O)NCN2CCCC2)C(=O)[C@@]2(O)C(=O)C3=C(O)c4c(O)cccc4[C@@](C)(O)[C@H]3C[C@@H]12', 'Cc1nccn1CC1CCc2c(c3ccccc3n2C)C1=O', 'NC(N)=NC(=O)c1nc(Cl)c(N)nc1N', 'CN1Cc2c(-c3noc(C(C)(O)CO)n3)ncn2-c2cccc(Cl)c2C1=O', 'Cc1cn([C@H]2C[C@H](F)[C@@H](CO)O2)c(=O)[nH]c1=O', 'ClCCl']


In [9]:
# Target labels (the 'Y' column) of the training molecules, in the same row
# order as smi_list.
label_list = list(train_set['Y'])

In [10]:
print(label_list[:10])  # first 10 labels, aligned with smi_list[:10]

[1, 1, 1, 1, 1, 1, 1, 0, 1, 1]


In [19]:
# Walk through the RDKit calls used for featurisation, on the first training
# molecule; each print corresponds to one line of the cell output.
smi = smi_list[0]
print(smi)
mol = Chem.MolFromSmiles(smi)

## Get atom features
atom_list = mol.GetAtoms() ## the atoms that make up the molecule
print(atom_list)  # prints the sequence object's repr, not the individual atoms
atom = atom_list[0]  # inspect the first atom (index 0)
print(atom.GetSymbol()) ## element symbol of this atom
print(atom.GetDegree()) ## number of atoms directly bonded to this atom
print(atom.GetTotalNumHs()) ## number of hydrogen atoms attached to this atom
print(atom.GetImplicitValence()) ## implicit valence of this atom

## Get bond features
bond_list = mol.GetBonds()  # the bonds of the molecule
bond = bond_list[0]  # inspect the first bond
print(bond.GetBeginAtom().GetSymbol(), bond.GetBeginAtom().GetIdx())  # start atom: symbol, index
print(bond.GetEndAtom().GetSymbol(), bond.GetEndAtom().GetIdx())  # end atom: symbol, index
print(bond.GetBondType())  # e.g. SINGLE / DOUBLE / TRIPLE / AROMATIC
print(bond.IsInRing())  # whether the bond belongs to a ring
print(bond.GetIsConjugated())  # whether the bond is conjugated

CC(C)(C)OC(=O)CCCc1ccc(N(CCCl)CCCl)cc1
<rdkit.Chem._GetAtomsIterator object at 0x7fc60e9f8700>
C
1
3
3
C 0
C 1
SINGLE
False
False


In [None]:
# Element symbols for the atom-type one-hot, in feature-column order (36 entries).
# NOTE(review): one_of_k_encoding sends any symbol missing from this list to
# the LAST entry, so an unlisted element is encoded exactly like 'Pb'.
ATOM_VOCAB = (
    "C N O F S Cl Br H Si P B Li Na K Ca "
    "Fe As Al I Mg Sn Sb Bi Ge Ti Se Zn Cu Au "
    "Ni Cd Mn Cr Pt Hg Pb"
).split()

def one_of_k_encoding(x, vocab):
    """One-hot encode ``x`` against ``vocab`` as a list of floats.

    A value that is not in ``vocab`` is treated as the vocabulary's last
    entry, so the final slot doubles as the catch-all bucket (e.g. a degree
    of 6 encoded against ``[0, 1, 2, 3, 4, 5]`` lands in the ``5`` slot).

    Args:
        x: value to encode; compared with vocabulary entries using ``==``.
        vocab: non-empty sequence of candidate values (an empty one raises
            IndexError).

    Returns:
        A list of ``len(vocab)`` floats: 1.0 where an entry equals the
        (possibly substituted) value, 0.0 elsewhere.
    """
    target = x if x in vocab else vocab[-1]
    return [float(target == entry) for entry in vocab]


def get_atom_feature(atom):
    """Build the per-atom feature list used as the node feature ``'h'``.

    Concatenates ``one_of_k_encoding`` one-hots of, in order:
      * the element symbol over ``ATOM_VOCAB``,
      * the degree over 0..5,
      * the total hydrogen count over 0..4,
      * the implicit valence over 0..5,
    then appends the aromaticity flag as the final entry.

    Args:
        atom: an RDKit atom (anything exposing GetSymbol, GetDegree,
            GetTotalNumHs, GetImplicitValence and GetIsAromatic).

    Returns:
        A flat list of ``len(ATOM_VOCAB) + 6 + 5 + 6 + 1`` values. All are
        floats except the last, which is whatever ``atom.GetIsAromatic()``
        returned (it is not cast to float here).
    """
    categorical_fields = (
        (atom.GetSymbol(), ATOM_VOCAB),
        (atom.GetDegree(), [0, 1, 2, 3, 4, 5]),
        (atom.GetTotalNumHs(), [0, 1, 2, 3, 4]),
        (atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5]),
    )
    features = []
    for value, choices in categorical_fields:
        features.extend(one_of_k_encoding(value, choices))
    features.append(atom.GetIsAromatic())
    return features


def get_bond_feature(bond):
    """Encode an RDKit bond as a 6-element list of booleans.

    Layout: a one-hot of the bond type over (SINGLE, DOUBLE, TRIPLE,
    AROMATIC), followed by the conjugation flag and the ring-membership flag.
    A bond type outside those four gives all-False in the first four slots.

    Args:
        bond: an RDKit bond.

    Returns:
        ``[is_single, is_double, is_triple, is_aromatic, is_conjugated, in_ring]``.
    """
    kind = bond.GetBondType()
    tracked_types = (
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    )
    type_one_hot = [kind == candidate for candidate in tracked_types]
    return type_one_hot + [bond.GetIsConjugated(), bond.IsInRing()]


def get_smi_and_label(dataset):
    """Extract parallel SMILES and label lists from a split of the dataset.

    Args:
        dataset: table with a ``'Drug'`` column of SMILES strings and a
            ``'Y'`` column of labels (e.g. one value of ``data.get_split()``).

    Returns:
        ``(smi_list, label_list)`` as plain Python lists, in row order.
    """
    smiles_column, label_column = dataset['Drug'], dataset['Y']
    return list(smiles_column), list(label_column)


def get_molecular_graph(smi):
    """Convert a SMILES string into a DGL graph with atom and bond features.

    Nodes are the atoms of the RDKit molecule parsed from ``smi``, in RDKit
    atom-index order; ``graph.ndata['h']`` holds one ``get_atom_feature`` row
    per atom. DGL graphs are directed, so every chemical bond i-j is stored as
    two edges, i -> j followed by j -> i, both carrying the same
    ``get_bond_feature`` row in ``graph.edata['e_ij']``. Edge ids ``2k`` and
    ``2k + 1`` therefore both correspond to the k-th RDKit bond.

    Args:
        smi: SMILES string of the molecule.

    Returns:
        A DGL graph with ``mol.GetNumAtoms()`` nodes and
        ``2 * mol.GetNumBonds()`` edges. For a non-empty molecule, node
        features have shape ``(num_atoms, atom_feature_dim)``. Edge features
        always have shape ``(num_edges, 6)``, including 0 edges. Both are
        ``torch.float64``.

    Raises:
        ValueError: if RDKit cannot parse ``smi``.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None. Fail here with
        # a message naming the bad input, not with an AttributeError later.
        raise ValueError(f"RDKit could not parse SMILES string: {smi!r}")

    atom_features = [get_atom_feature(atom) for atom in mol.GetAtoms()]

    src_ids, dst_ids, bond_features = [], [], []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        feature = get_bond_feature(bond)
        # DGL edges are directed: add i -> j and then j -> i, sharing one feature row.
        src_ids += [begin, end]
        dst_ids += [end, begin]
        bond_features += [feature, feature]

    # Build the graph in one call instead of growing it edge by edge.
    graph = dgl.graph(
        (torch.tensor(src_ids, dtype=torch.int64),
         torch.tensor(dst_ids, dtype=torch.int64)),
        num_nodes=mol.GetNumAtoms(),
    )
    # NOTE(review): features stay float64 as before; torch.nn layers default
    # to float32, so a model may need ``graph.ndata['h'].float()``.
    graph.ndata['h'] = torch.tensor(atom_features, dtype=torch.float64)

    # Width of get_bond_feature's output. The explicit reshape keeps edge
    # features 2-D for bond-free molecules (a single atom, disconnected ions).
    # There the list is empty and torch.tensor would return shape (0,),
    # which cannot be batched with other graphs' (E, 6) features.
    num_bond_features = 6
    graph.edata['e_ij'] = torch.tensor(
        bond_features, dtype=torch.float64
    ).reshape(len(bond_features), num_bond_features)

    return graph


def debugging(smi_index=0):
    """Smoke-test the featurisation pipeline on one training molecule.

    Loads the TDC ADME "BBB_Martins" dataset and draws the same random
    70/10/20 split (seed 999) as the exploration cells above. It then builds
    the molecular graph for one training SMILES.

    Args:
        smi_index: position, within the training split, of the molecule to
            convert. Defaults to 0, the first training molecule.

    Returns:
        The graph built by ``get_molecular_graph``. It is also printed
        (together with its SMILES and label) so the check produces visible
        output.
    """
    data = ADME(name="BBB_Martins")
    split = data.get_split(method='random', seed=999, frac=[0.7, 0.1, 0.2])

    # Only the training partition is needed for this check.
    smi_train, label_train = get_smi_and_label(split['train'])
    smi = smi_train[smi_index]
    graph = get_molecular_graph(smi)

    print(smi, label_train[smi_index])
    print(graph)
    return graph

if __name__ == "__main__":
    debugging()