preprocess.py
from multiprocessing import Pool
import random
import pickle
import argparse
from functools import partial

import torch
import rdkit
import pandas as pd

# MolGraph, common_atom_vocab, and PairVocab come from this repository's
# ggpm package.
from ggpm import MolGraph, common_atom_vocab, PairVocab


def to_numpy(tensors):
    # Move every torch.Tensor in the nested batch structure to numpy so the
    # result pickles compactly and loads without any torch device state.
    convert = lambda x: x.numpy() if isinstance(x, torch.Tensor) else x
    a, b, c, d, e, f = tensors
    c = [convert(x) for x in c[0]], [convert(x) for x in c[1]]
    return a, b, c, d, e, f
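
# Illustration (hypothetical values) of what convert does inside to_numpy:
#   convert(torch.tensor([1, 2]))  ->  array([1, 2])
#   convert('c1ccccc1')            ->  'c1ccccc1'   (non-tensors pass through)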


def tensorize(mol_batch, vocab):
    # Tensorize one batch of molecules into hierarchical graph tensors.
    x = MolGraph.tensorize(mol_batch, vocab, common_atom_vocab)
    return to_numpy(x)


if __name__ == "__main__":
    # Silence RDKit warnings emitted while parsing SMILES.
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)

    parser = argparse.ArgumentParser()
    parser.add_argument('--train', required=True)
    parser.add_argument('--vocab', required=True)
    parser.add_argument('--batch_size', type=int, default=20)
    parser.add_argument('--ncpu', type=int, default=1)
    args = parser.parse_args()
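
    # Usage sketch (paths are hypothetical):
    #   python preprocess.py --train data/train.csv --vocab data/vocab.txt \
    #       --batch_size 20 --ncpu 8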
    # Each vocab line holds two SMILES tokens plus a True/False fragment
    # flag; compare the flag as a string rather than eval-ing file contents.
    with open(args.vocab) as f:
        vocab = [x.strip("\r\n ").split() for x in f]
    MolGraph.load_fragments([x[0] for x in vocab if x[-1] == 'True'])
    args.vocab = PairVocab([(x, y) for x, y, _ in vocab], cuda=False)

    random.seed(1)

    if args.train.endswith('.csv'):
        # Drop rows with missing HOMO/LUMO values, then convert to a list of
        # rows so random.shuffle below behaves correctly (shuffling a 2-D
        # numpy array with random.shuffle can corrupt rows via view aliasing).
        data = pd.read_csv(args.train)
        data = data.dropna().reset_index(drop=True).to_numpy().tolist()
    else:
        # Plain-text input: one "SMILES HOMO LUMO" triple per line.
        with open(args.train) as f:
            data = [[x, float(h), float(l)]
                    for x, h, l in (line.strip("\r\n ").split() for line in f)]

    random.shuffle(data)
    batches = [data[i : i + args.batch_size] for i in range(0, len(data), args.batch_size)]
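    # e.g. 105 molecules at batch_size=20 give 6 batches, the last of size 5.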

    if args.ncpu == 1:
        all_data = [tensorize(b, args.vocab) for b in batches]
    else:
        # Fan the batches out across worker processes; the with-block closes
        # the pool once the map completes.
        func = partial(tensorize, vocab=args.vocab)
        with Pool(args.ncpu) as pool:
            all_data = pool.map(func, batches)

    # Split the tensorized batches across pickle shards of roughly 1000
    # batches each so no single file grows too large.
    if len(all_data) < 1000:
        with open('tensors-%d.pkl' % 0, 'wb') as f:
            pickle.dump(all_data, f, pickle.HIGHEST_PROTOCOL)
    else:
        num_splits = len(all_data) // 1000
        le = (len(all_data) + num_splits - 1) // num_splits  # ceil(batches / splits)
        for split_id in range(num_splits):
            st = split_id * le
            sub_data = all_data[st : st + le]
            with open('tensors-%d.pkl' % split_id, 'wb') as f:
                pickle.dump(sub_data, f, pickle.HIGHEST_PROTOCOL)
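
# Loading sketch for the emitted shards (assumes the single-shard case above;
# multi-shard runs write tensors-0.pkl, tensors-1.pkl, ...):
#   import pickle
#   with open('tensors-0.pkl', 'rb') as f:
#       batches = pickle.load(f)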