In [1]:
import csv
import numpy as np
from pysmiles import read_smiles
import networkx as nx
from matplotlib import pyplot as plt
import warnings
from tqdm import tqdm
from torch_geometric.utils.convert import from_networkx
from rdkit import Chem
from pandas import read_csv
from rdkit import Chem
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
warnings.filterwarnings("ignore")

In [2]:
train_data = read_csv('data_good/train.csv')
test_data = read_csv('data_good/test.csv')

In [3]:
train_data['Smiles'][0]

'COc1ccc2[nH]cc(CCN)c2c1'

In [4]:
molecules = [Chem.MolFromSmiles(smile) for smile in tqdm(train_data['Smiles'])]
train_data['molecules'] = molecules

molecules = [Chem.MolFromSmiles(smile) for smile in tqdm(test_data['Smiles'])]
test_data['molecules'] = molecules

100%|██████████| 5557/5557 [00:00<00:00, 11603.61it/s]
100%|██████████| 1614/1614 [00:00<00:00, 11367.33it/s]


In [5]:
def get_single_raw_node_features(atom):
    now = []
    now.append(atom.GetAtomicNum())
    now.append(atom.GetChiralTag())
    now.append(atom.GetTotalDegree())
    now.append(atom.GetFormalCharge())
    now.append(atom.GetTotalNumHs())
    now.append(atom.GetNumRadicalElectrons())
    now.append(atom.GetHybridization())
    now.append(atom.GetIsAromatic())
    now.append(atom.IsInRing())
    return now

def get_raw_node_features(mol):
    features = []
    for atom in mol.GetAtoms():
        features.append(get_single_raw_node_features(atom))
    return features

def get_single_raw_bond_features(bond):
    now = []
    now.append(str(bond.GetBondType()))
    now.append(str(bond.GetStereo()))
    now.append(bond.GetIsConjugated())
    return now

def get_raw_bond_features(mol):
    features = []
    for bond in mol.GetBonds():
        features.append(get_single_raw_bond_features(bond))
    return features

In [6]:
class Encoder():
    def __init__(self, features, print_unique = True):
        reshaped = [[el] for el in features[0]]
        for i in range(1, len(features)):
            for j in range(len(reshaped)):
                reshaped[j].append(features[i][j])
        self.unique = [list(np.unique(values)) for values in reshaped]
        if (print_unique):
            for i in range(len(self.unique)):
                print(i, self.unique[i])
                
                
    def transform(self, features, show_progress = True):
        result = []
        for vector in tqdm(features, disable = not show_progress):
            now = []
            for i in range(len(vector)):
                current = np.zeros(len(self.unique[i]))
                current[self.unique[i].index(vector[i])] = 1.0
                now.append(current)
            now = np.concatenate(now, axis = 0)
            result.append(now)
        return np.array(result)

In [7]:
train_raw_bond_features = []
for mol in train_data['molecules']:
    train_raw_bond_features = train_raw_bond_features + get_raw_bond_features(mol)

bond_encoder = Encoder(train_raw_bond_features)

train_bond_features = bond_encoder.transform(train_raw_bond_features)
print(train_bond_features.shape)

0 ['AROMATIC', 'DOUBLE', 'SINGLE', 'TRIPLE']
1 ['STEREOE', 'STEREONONE', 'STEREOZ']
2 [False, True]


100%|██████████| 155186/155186 [00:00<00:00, 182737.19it/s]

(155186, 9)





In [8]:
train_raw_node_features = []
for mol in train_data['molecules']:
    train_raw_node_features = train_raw_node_features + get_raw_node_features(mol)
    
node_encoder = Encoder(train_raw_node_features)

train_node_features = node_encoder.transform(train_raw_node_features)
print(train_node_features.shape)




0 [1, 3, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 19, 20, 30, 33, 34, 35, 38, 47, 53]
1 [0, 1, 2]
2 [0, 1, 2, 3, 4, 6]
3 [-1, 0, 1, 2, 3]
4 [0, 1, 2, 3, 4]
5 [0, 1]
6 [1, 2, 3, 4, 5, 6]
7 [False, True]
8 [False, True]


100%|██████████| 144802/144802 [00:02<00:00, 68139.04it/s]

(144802, 54)





In [9]:
def mol_2_pytorch_geometric(mol, node_encoder, bond_encoder):
    node_features = node_encoder.transform(get_raw_node_features(mol), show_progress = False)
    #print("node features: ", node_features.shape)
    
    edge_indices, edge_attrs = [], []
    
    num_bond_features = None
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()

        bond_features = bond_encoder.transform([get_single_raw_bond_features(bond)], show_progress = False)
        num_bond_features = len(bond_features[0])
        #print("bond features: ", bond_features[0].shape)
        edge_indices += [[i, j], [j, i]]
        edge_attrs += [bond_features[0], bond_features[0]]
    
    if len(edge_attrs) > 0:
        
   
        edge_index = torch.tensor(edge_indices)
        edge_index = edge_index.t().to(torch.long).view(2, -1)
        edge_attrs = torch.tensor(edge_attrs, dtype=torch.long).view(-1, num_bond_features)
        ''' print(edge_attrs.shape)
        print(edge_index.shape)
        print(node_features.shape)
        print(edge_index)'''
        # Sort indices.
  
        perm = (edge_index[0] * node_features.shape[0] + edge_index[1]).argsort()
        #print("perm: ", perm.shape)
        #print(edge_index[0] * node_features.shape[0] + edge_index[1])
        edge_index, edge_attrs = edge_index[:, perm], edge_attrs[perm]
        data = Data(x=torch.FloatTensor(node_features), edge_index=edge_index, edge_attr=edge_attrs, empty = False)
    else:
        data = Data(x = torch.FloatTensor(node_features), empty = True)
    #print(edge_index.shape)
    #print(edge_attrs.shape)
    
    return data

In [10]:
first = mol_2_pytorch_geometric(train_data['molecules'][0], node_encoder, bond_encoder)

In [11]:
train_data['graphs'] = [mol_2_pytorch_geometric(molecule, node_encoder, bond_encoder) 
                        for molecule in tqdm(train_data['molecules'])]

test_data['graphs'] = [mol_2_pytorch_geometric(molecule, node_encoder, bond_encoder)
                       for molecule in tqdm(test_data['molecules'])]

100%|██████████| 5557/5557 [00:08<00:00, 630.07it/s]
100%|██████████| 1614/1614 [00:02<00:00, 624.11it/s]


In [12]:
positive_indices = np.arange(len(train_data))[np.array(train_data['Active'])]
negative_indices = np.arange(len(train_data))[np.array(train_data['Active']) == False]

np.random.seed(0)
np.random.shuffle(positive_indices)
np.random.shuffle(negative_indices)
print(positive_indices.shape)
print(negative_indices.shape)

(206,)
(5351,)


In [13]:
TRAIN_PERCENTAGE = 0.66
num_train_positive = int(positive_indices.shape[0] * TRAIN_PERCENTAGE)
print(num_train_positive)
train_positive = positive_indices[:num_train_positive]
val_positive = positive_indices[num_train_positive:]

num_train_negative = int(negative_indices.shape[0] * TRAIN_PERCENTAGE)
print(num_train_negative)
train_negative = negative_indices[:num_train_negative]
val_negative = negative_indices[num_train_negative:]

135
3531


In [14]:
train_indices = np.concatenate([train_positive, train_negative], axis = 0)
val_indices = np.concatenate([val_positive, val_negative], axis = 0)


In [15]:
val_graphs = [train_data['graphs'].to_list()[index] for index in val_indices]
val_labels = train_data['Active'].to_numpy()[val_indices]

train_graphs = [train_data['graphs'].to_list()[index] for index in train_indices]
train_labels = train_data['Active'].to_numpy()[train_indices]

print(train_labels)
for i in range(len(train_graphs)):
    train_graphs[i].y = int(train_labels[i])
for i in range(len(val_graphs)):
    val_graphs[i].y = int(val_labels[i])

    
train_graphs = [graph for graph in train_graphs if not graph.empty]
val_graphs = [graph for graph in val_graphs if not graph.empty]

[ True  True  True ... False False False]


In [16]:
print(len(val_graphs))
print(len(train_graphs))

1891
3665


In [17]:
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from torch_geometric.nn import global_max_pool
from torch import nn

class GCN(torch.nn.Module):
    def __init__(self, num_node_features):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, 128)
        self.conv2 = GCNConv(128, 128)
        self.mlp_first = nn.Sequential(nn.Linear(128, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Linear(128, 128))
        self.mlp = nn.Sequential(nn.Linear(128, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Linear(128, 2))
        
    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p = 0.1, training=self.training)
        x = self.conv2(x, edge_index)
        x = self.mlp_first(x)
        x = global_max_pool(x, batch=data.batch)
        x = self.mlp(x)
        x = F.log_softmax(x)
        return x


In [18]:
BATCH_SIZE = 128
train_loader = DataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size = BATCH_SIZE, shuffle = False)

In [19]:
from sklearn.metrics import f1_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(54).to(device)
optimizer = torch.optim.Adam(model.parameters())
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.5, patience = 1000)

NUM_EPOCHS = 100000
for epoch in range(NUM_EPOCHS):
    model.train(True)
    for data in train_loader:
        data.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        #print(outputs)
        #print(data.y)
        loss = F.nll_loss(outputs, data.y)
        #loss = torch.sum(outputs[:, 0])
        loss.backward()
        optimizer.step()


    model.train(False)

    all_outputs = []
    all_targets = []

    for data in val_loader:
        data.to(device)
        outputs = model(data)
        all_outputs.append(outputs.data.cpu().numpy())
        all_targets.append(data.y.cpu().numpy())
    all_outputs = np.concatenate(all_outputs, axis = 0)
    all_targets = np.concatenate(all_targets, axis = 0)

    values = []
    for threshold in np.linspace(np.min(all_outputs), np.max(all_outputs), 1000):
        predictions = all_outputs[:, 1] > threshold
        now = f1_score(all_targets, predictions)
        values.append(now)
    
    lr_scheduler.step(np.max(values))
    
    if epoch % 5 == 0:
        print(epoch, np.max(values), optimizer.param_groups[0]['lr'])
        
   


0 0.07999999999999999 0.001
5 0.265625 0.001
10 0.22834645669291337 0.001
15 0.28571428571428575 0.001
20 0.29906542056074764 0.001
25 0.2923076923076923 0.001
30 0.3503649635036496 0.001
35 0.33928571428571425 0.001
40 0.30769230769230765 0.001
45 0.3652173913043478 0.001
50 0.30270270270270266 0.001
55 0.3559322033898305 0.001
60 0.33576642335766427 0.001
65 0.33027522935779813 0.001
70 0.3157894736842105 0.001
75 0.348993288590604 0.001
80 0.3647798742138365 0.001
85 0.3466666666666667 0.001
90 0.3431952662721893 0.001
95 0.36129032258064514 0.001
100 0.272108843537415 0.001
105 0.3214285714285714 0.001
110 0.3163841807909604 0.001
115 0.31654676258992803 0.001
120 0.3214285714285714 0.001
125 0.3304347826086957 0.001
130 0.32758620689655177 0.001
135 0.32758620689655177 0.001
140 0.33613445378151263 0.001
145 0.31932773109243695 0.001
150 0.2981366459627329 0.001
155 0.31666666666666665 0.001
160 0.3442622950819672 0.001
165 0.3252032520325204 0.001
170 0.3278688524590164 0.001
175

1390 0.31343283582089554 0.0005
1395 0.3087248322147651 0.0005
1400 0.3181818181818182 0.0005
1405 0.3235294117647059 0.0005
1410 0.33576642335766427 0.0005
1415 0.33333333333333337 0.0005
1420 0.3130434782608696 0.0005
1425 0.3119266055045872 0.0005
1430 0.3170731707317073 0.0005
1435 0.3018867924528302 0.0005
1440 0.32061068702290074 0.0005
1445 0.3114754098360656 0.0005
1450 0.3114754098360656 0.0005
1455 0.3140495867768595 0.0005
1460 0.30645161290322576 0.0005
1465 0.3103448275862069 0.0005
1470 0.3125 0.0005
1475 0.31007751937984496 0.0005
1480 0.30645161290322576 0.0005
1485 0.3125 0.0005
1490 0.31496062992125984 0.0005
1495 0.3188405797101449 0.0005
1500 0.3285714285714286 0.0005
1505 0.32558139534883723 0.0005
1510 0.3333333333333333 0.0005
1515 0.32 0.0005
1520 0.3220338983050848 0.0005
1525 0.3252032520325204 0.0005
1530 0.32258064516129037 0.0005
1535 0.3278688524590164 0.0005
1540 0.3181818181818182 0.0005
1545 0.3418803418803419 0.0005
1550 0.3114754098360656 0.0005
1555 

2700 0.31666666666666665 0.00025
2705 0.3140495867768595 0.00025
2710 0.3148148148148148 0.00025
2715 0.31666666666666665 0.00025
2720 0.31666666666666665 0.00025
2725 0.31666666666666665 0.00025
2730 0.31932773109243695 0.00025
2735 0.32758620689655177 0.00025
2740 0.3247863247863248 0.00025
2745 0.31666666666666665 0.00025
2750 0.3247863247863248 0.00025
2755 0.31932773109243695 0.00025
2760 0.3220338983050848 0.00025
2765 0.3220338983050848 0.00025
2770 0.3247863247863248 0.00025
2775 0.3220338983050848 0.00025
2780 0.3076923076923077 0.00025
2785 0.3103448275862069 0.00025
2790 0.3185840707964602 0.00025
2795 0.3185840707964602 0.00025
2800 0.31578947368421056 0.00025
2805 0.31578947368421056 0.00025
2810 0.31578947368421056 0.00025
2815 0.31578947368421056 0.00025
2820 0.3148148148148148 0.00025
2825 0.31578947368421056 0.00025
2830 0.3214285714285714 0.00025
2835 0.3243243243243243 0.00025
2840 0.3214285714285714 0.00025
2845 0.3214285714285714 0.00025
2850 0.3214285714285714 0.0

3945 0.3448275862068965 0.000125
3950 0.3423423423423423 0.000125
3955 0.33928571428571425 0.000125
3960 0.3423423423423423 0.000125
3965 0.34862385321100914 0.000125
3970 0.3418803418803419 0.000125
3975 0.3423423423423423 0.000125
3980 0.33962264150943394 0.000125
3985 0.33962264150943394 0.000125
3990 0.33628318584070793 0.000125
3995 0.3454545454545454 0.000125
4000 0.3442622950819672 0.000125
4005 0.3508771929824561 6.25e-05
4010 0.36206896551724144 6.25e-05
4015 0.36206896551724144 6.25e-05
4020 0.3652173913043478 6.25e-05
4025 0.3652173913043478 6.25e-05
4030 0.3652173913043478 6.25e-05
4035 0.3652173913043478 6.25e-05
4040 0.3652173913043478 6.25e-05
4045 0.3652173913043478 6.25e-05
4050 0.3652173913043478 6.25e-05
4055 0.34710743801652894 6.25e-05
4060 0.3529411764705882 6.25e-05
4065 0.3529411764705882 6.25e-05
4070 0.3539823008849558 6.25e-05
4075 0.35185185185185186 6.25e-05
4080 0.35000000000000003 6.25e-05
4085 0.3559322033898305 6.25e-05
4090 0.35000000000000003 6.25e-05

5180 0.36036036036036034 3.125e-05
5185 0.36036036036036034 3.125e-05
5190 0.36036036036036034 3.125e-05
5195 0.36036036036036034 3.125e-05
5200 0.36036036036036034 3.125e-05
5205 0.36036036036036034 3.125e-05
5210 0.36036036036036034 3.125e-05
5215 0.36036036036036034 3.125e-05
5220 0.36036036036036034 3.125e-05
5225 0.36036036036036034 3.125e-05
5230 0.36036036036036034 3.125e-05
5235 0.36036036036036034 3.125e-05
5240 0.36036036036036034 3.125e-05
5245 0.36036036036036034 3.125e-05
5250 0.36036036036036034 3.125e-05
5255 0.36036036036036034 3.125e-05
5260 0.36036036036036034 3.125e-05
5265 0.36036036036036034 3.125e-05
5270 0.3571428571428571 3.125e-05
5275 0.36036036036036034 3.125e-05
5280 0.3571428571428571 3.125e-05
5285 0.3571428571428571 3.125e-05
5290 0.3571428571428571 3.125e-05
5295 0.3571428571428571 3.125e-05
5300 0.3571428571428571 3.125e-05
5305 0.3571428571428571 3.125e-05
5310 0.3571428571428571 3.125e-05
5315 0.36036036036036034 3.125e-05
5320 0.3571428571428571 3.12

6360 0.34782608695652173 1.5625e-05
6365 0.34782608695652173 1.5625e-05
6370 0.34862385321100914 1.5625e-05
6375 0.34862385321100914 1.5625e-05
6380 0.34782608695652173 1.5625e-05
6385 0.34862385321100914 1.5625e-05
6390 0.34862385321100914 1.5625e-05
6395 0.34782608695652173 1.5625e-05
6400 0.34862385321100914 1.5625e-05
6405 0.34862385321100914 1.5625e-05
6410 0.34782608695652173 1.5625e-05
6415 0.34862385321100914 1.5625e-05
6420 0.34782608695652173 1.5625e-05
6425 0.34862385321100914 1.5625e-05
6430 0.34862385321100914 1.5625e-05
6435 0.34862385321100914 1.5625e-05
6440 0.34782608695652173 1.5625e-05
6445 0.34862385321100914 1.5625e-05
6450 0.34862385321100914 1.5625e-05
6455 0.34782608695652173 1.5625e-05
6460 0.34862385321100914 1.5625e-05
6465 0.34782608695652173 1.5625e-05
6470 0.34862385321100914 1.5625e-05
6475 0.3454545454545454 1.5625e-05
6480 0.34782608695652173 1.5625e-05
6485 0.34782608695652173 1.5625e-05
6490 0.34782608695652173 1.5625e-05
6495 0.34862385321100914 1.56

KeyboardInterrupt: 