In [1]:
%load_ext autoreload
%autoreload 2

from src.rxn_ctr_mcs import *
from src.utils import load_json, rxn_entry_to_smarts, rm_atom_map_num
from src.pathway_utils import get_reverse_paths_to_starting, create_graph_from_pickaxe
from src.post_processing import *

from minedatabase.pickaxe import Pickaxe
from minedatabase.utils import get_compound_hash

from rdkit.Chem import AllChem

from collections import defaultdict
import pandas as pd
import csv
import pickle
import subprocess

  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at /home/stef/miniconda3/envs/mine/lib/python3.7/site-packages/rxnmapper/models/transformers/albert_heads_8_uspto_all_1310k were not used when initializing AlbertModel: ['predictions.dense.weight', 'predictions.decoder.weight', 'predictions.LayerNorm.bias', 'predictions.LayerNorm.weight', 'predictions.decoder.bias', 'predictions.bias', 'predictions.dense.bias']
- This IS expected if you are initializing AlbertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [65]:
# Load processed expansion

# Which expansion to post-process: starter set, target, number of generations
starters = 'ccm_v0'
targets = 'mvacid'
generations = 4

# Data directories, relative to this notebook
expansion_dir = '../data/processed_expansions/'
pruned_dir = '../data/pruned_expansions/'
thermo_dir = '../data/thermo/'

# Expansion file name; the reactions / paths / pruned-expansion files all share it
fn = f"{starters}_to_{targets}_gen_{generations}_tan_sample_1_n_samples_1000.pk" # Expansion file name
rxns_path = ''.join([expansion_dir, 'predicted_reactions_', fn])
paths_path = ''.join([expansion_dir, 'paths_', fn])
path = ''.join([pruned_dir, fn])

# Load reactions (dict; values expose .smarts / .known_rxns in later cells)
# NOTE: pickle.load executes arbitrary code from the file -- only load trusted files
with open(rxns_path, 'rb') as f:
    pred_rxns = pickle.load(f)

# Load paths (dict keyed by (starter, target) tuples, each value a list of paths)
with open(paths_path, 'rb') as f:
    paths = pickle.load(f)

# Load raw expansion object from the pruned-expansions directory
pk = Pickaxe()
pk.load_pickled_pickaxe(path)

----------------------------------------
Intializing pickaxe object

Done intializing pickaxe object
----------------------------------------

Loading ../data/pruned_expansions/ccm_v0_to_mvacid_gen_4_tan_sample_1_n_samples_1000.pk pickled data.
Loaded 458 compounds
Loaded 478 reactions
Loaded 3604 operators
Loaded 1 targets
Took 3.4294538497924805


In [66]:
# Show how many paths were loaded for each (starter, target) pair
for st_pair in paths:
    print(st_pair, len(paths[st_pair]))

('fumarate', 'mvacid') 557
('succinate', 'mvacid') 17
('acetate', 'mvacid') 2
('pyruvate', 'mvacid') 3


In [67]:
# Load in IMT rule mapping

# Load rules, indexed by the 'Name' column
rules_path = '../data/rules/JN3604IMT_rules.tsv'
rule_df = pd.read_csv(rules_path, delimiter='\t').set_index('Name')

# Load mapping: first CSV column is a known reaction id, remaining columns are
# the rules mapped to it (none if the row has a single column)
rxn2rule = {}
db_names = ['_mc_v21', '_brenda', '_kegg']
suffix = '_imt_rules_enforce_cof.csv'
for name in db_names:
    mapping_path = '../data/mapping/mapping' + name + suffix
    # newline='' lets the csv module do its own newline handling (per csv docs)
    with open(mapping_path, 'r', newline='') as f:
        for row in csv.reader(f):
            if not row:
                # Blank lines come back as [] and would raise IndexError on row[0]
                continue
            rxn2rule[row[0]] = row[1:]  # row[1:] is [] for single-column rows

# Make rule2rxn: invert rxn2rule so each rule lists the known reaction ids it maps
rule2rxn = {}
for rxn_id, rules in rxn2rule.items():
    for rule in rules:
        rule2rxn.setdefault(rule, []).append(rxn_id)

# Load all known reaction json entries into dict
# (later files overwrite earlier ones on key collisions, as dict.update does)
known_rxns = {}
pref = '../data/mapping/'
suffs = ['mc_v21_as_is.json', 'brenda_as_is.json', 'kegg_as_is.json']
for elt in suffs:
    known_rxns.update(load_json(pref + elt))

In [68]:
# Populate reaction objects in rxn dict w/ known reactions

for k, v in pred_rxns.items():
    # Every known reaction mapped by any operator of this predicted reaction, as
    # (None, smarts, known reaction id). Slot 0 is filled in later with the
    # reaction-center MCS score.
    # Operators are iterated in sorted order (assumes operator names are mutually
    # comparable, e.g. all strings -- TODO confirm) so the result does not depend
    # on the iteration order of the underlying collection.
    this_known_rxns = [
        (None, rxn_entry_to_smarts(known_rxns[this_id]), this_id)
        for this_rule in sorted(pk.reactions[k]["Operators"])
        for this_id in rule2rxn.get(this_rule, ())
    ]

    # De-duplicate while preserving first-seen order. A plain set() would give an
    # order that changes with string hashing between kernels, which makes the
    # positional indices recorded against known_rxns elsewhere irreproducible.
    v.known_rxns = [list(elt) for elt in dict.fromkeys(this_known_rxns)]


In [69]:
# Sanity check: number of predicted reactions, then path count per (starter, target) pair
print(len(pred_rxns))
for st_pair, st_paths in paths.items():
    print(st_pair, len(st_paths))

478
('fumarate', 'mvacid') 557
('succinate', 'mvacid') 17
('acetate', 'mvacid') 2
('pyruvate', 'mvacid') 3


In [70]:
pr_am_errors = [] # Track predicted rxn am errors: [pred rxn hash]
kr_am_errors = [] # Track known rxn am errors: [(pred rxn hash, known rxn idx, known rxn id)]
rc_errors = [] # Track rxn ctr construction errors: [(pred rxn hash, known rxn idx, known rxn id)]
alignment_issues = [] # Track substrate alignment issues: [(pred rxn hash, known rxn idx, known rxn id)]
norm = 'max atoms' # Normalize MCS atom count by larger molecule


def _stoich(rxn_sma):
    """Return the number of '.'-separated species on each side of a '>>' reaction string."""
    return tuple(len(side.split('.')) for side in rxn_sma.split('>>'))


def _init_rxn(am_rxn_sma):
    """Build and initialize an RDKit reaction object from an atom-mapped reaction string."""
    rxn = AllChem.ReactionFromSmarts(am_rxn_sma, useSmiles=True)
    rxn.Initialize()
    return rxn


# Populate pred_rxns, known rxn prc-mcs slot
for x, h in enumerate(pred_rxns):
    pred_rxn = pred_rxns[h]
    rxn_sma1 = pred_rxn.smarts

    # Skip pred reactions that trigger RXNMapper atom mapping errors
    # (Exception, not bare except, so a kernel interrupt still stops the loop)
    try:
        am_rxn_sma1 = atom_map(rxn_sma1)
    except Exception:
        pr_am_errors.append(h)
        continue

    stoich1 = _stoich(rxn_sma1)
    a = 0 # Number known rxns analyzed
    for z, kr in enumerate(pred_rxn.known_rxns):
        rxn_sma2 = kr[1]

        # Catch stoichiometry mismatches stemming from pickaxe, early post-processing
        if _stoich(rxn_sma2) != stoich1:
            print(x, z, 'stoich_error')
            continue

        # Skip known reactions that trigger RXNMapper atom mapping errors
        try:
            am_rxn_sma2 = atom_map(rxn_sma2)
        except Exception:
            kr_am_errors.append((h, z, kr[-1]))
            continue

        # Construct reaction objects: [predicted, known]
        rxns = [_init_rxn(am_rxn_sma1), _init_rxn(am_rxn_sma2)]

        rc_atoms = [elt.GetReactingAtoms() for elt in rxns] # Get reaction center atom idxs

        # Construct rxn ctr mol objs, one list of reactant rxn ctrs per reaction
        try: # REMOVE after addressing KekulizationException in get_sub_mol
            rcs = [
                [get_sub_mol(t_mol, rc_atoms[i][j]) for j, t_mol in enumerate(t_rxn.GetReactants())]
                for i, t_rxn in enumerate(rxns)
            ]
        except Exception:
            rc_errors.append((h, z, kr[-1])) # Record instead of silently dropping
            continue

        # Align substrates of the 2 reactions
        rc_idxs = [] # Each element: (idx for rxn 1, idx for rxn 2)
        remaining = [list(range(len(elt))) for elt in rcs]
        while remaining[0] and remaining[1]:
            idx_pair = align_substrates(rcs, remaining)
            if idx_pair is None:
                break

            rc_idxs.append(idx_pair)
            remaining[0].remove(idx_pair[0])
            remaining[1].remove(idx_pair[1])

        # Skip if you haven't aligned every reactant pred to known
        if len(rc_idxs) < stoich1[0]:
            alignment_issues.append((h, z, kr[-1]))
            continue

        # For reaction 2 (known reaction) Re-order rcs, rc_atoms,
        # internal order of reactants in the reaction object in rxns
        # and the smarts stored in the known_reactions attribute of the
        # associated predicted reaction

        # Sort reaction 2 rc_idxs by reaction 1 rc_idxs
        rxn_1_rc_idxs, rxn_2_rc_idxs = list(zip(*rc_idxs))
        if rxn_1_rc_idxs != rxn_2_rc_idxs:
            rxn_2_rc_idxs, rxn_1_rc_idxs = sort_x_by_y(rxn_2_rc_idxs, rxn_1_rc_idxs)

            # Re-order atom-mapped smarts string, and then update known_rxns entry
            # with de-atom-mapped version of this string because atom mapper changes
            # reactant order and it's this order that rcs, rc_atoms, rc_idxs all come from
            sides2 = am_rxn_sma2.split('>>')
            am_ro_sma2 = sides2[0].split('.') # Get list of reactant strings
            am_ro_sma2 = '.'.join([am_ro_sma2[elt] for elt in rxn_2_rc_idxs]) # Re-join in new order
            am_rxn_sma2 = am_ro_sma2 + '>>' + sides2[1] # Join with products side

            # Re-construct reaction object from re-ordered, am smarts
            rxns[1] = _init_rxn(am_rxn_sma2)

            rc_atoms[1] = rxns[1].GetReactingAtoms() # Update rc_atoms
            rcs[1] = [get_sub_mol(elt, rc_atoms[1][i]) for i, elt in enumerate(rxns[1].GetReactants())] # Update rc mol obj

        pred_rxn.known_rxns[z][1] = rm_atom_map_num(am_rxn_sma2) # Update known_reaction entry w/ de-am smarts
        rxns = align_atom_map_nums(rxns, rcs, rc_atoms)

        # Compute MCS seeded by reaction center
        prc_mcs = get_prc_mcs(rxns, rcs, rc_atoms, norm=norm)
        pred_rxn.known_rxns[z][0] = prc_mcs # Update pred_rxns

        a += 1 # Count known rxn analyzed

    # Update pred_rxn smarts w/ de-am smarts, only if at least one known rxn was
    # analyzed (same condition as before; the value does not depend on the known rxn)
    if a > 0:
        pred_rxn.smarts = rm_atom_map_num(am_rxn_sma1)

    # Report fraction of known rxns analyzed. Uses the list length rather than the
    # loop index, which is undefined or stale when there are no known rxns.
    n_known = len(pred_rxn.known_rxns)
    print(x, ':', a / n_known if n_known else 0.0, 'of', n_known)

0 : 0.52 of 50
1 : 0.13636363636363635 of 88
2 : 1.0 of 9
3 : 0.6086956521739131 of 23
4 : 1.0 of 15
5 : 0.13636363636363635 of 88
6 : 1.0 of 9
7 : 0.6086956521739131 of 23
8 : 1.0 of 18
9 : 0.15 of 40
10 : 0.775 of 40
11 : 0.13636363636363635 of 88
12 : 0.0 of 38
13 : 0.13636363636363635 of 88
14 : 0.0 of 88
15 : 0.6086956521739131 of 23
16 : 0.13636363636363635 of 88
17 : 0.0 of 8
18 : 0.13636363636363635 of 88
19 : 0.15 of 40
20 : 0.3333333333333333 of 12
21 : 0.13636363636363635 of 88
22 : 0.0 of 27
23 : 0.9 of 10
24 : 0.9 of 10
25 : 0.13636363636363635 of 88
26 : 0.13636363636363635 of 88
27 : 0.13636363636363635 of 88
28 : 1.0 of 15
29 : 0.13636363636363635 of 88
30 : 1.0 of 15
31 : 0.13636363636363635 of 88
32 : 0.13636363636363635 of 88
33 : 0.0 of 18
34 : 1.0 of 8
35 : 0.723404255319149 of 47
36 : 0.13636363636363635 of 88
37 : 0.3333333333333333 of 12
38 : 0.0 of 88
39 : 1.0 of 16
40 : 0.13636363636363635 of 88
41 : 1.0 of 15
42 : 0.9 of 10
43 : 1.0 of 15
44 : 0.1363636363636

RDKit ERROR: [14:00:37] non-ring atom 0 marked aromatic
[14:00:37] non-ring atom 0 marked aromatic
RDKit ERROR: [14:00:37] non-ring atom 0 marked aromatic
[14:00:37] non-ring atom 0 marked aromatic
RDKit ERROR: [14:00:37] non-ring atom 0 marked aromatic
[14:00:37] non-ring atom 0 marked aromatic
RDKit ERROR: [14:00:37] non-ring atom 0 marked aromatic
[14:00:37] non-ring atom 0 marked aromatic
RDKit ERROR: [14:00:37] non-ring atom 0 marked aromatic
[14:00:37] non-ring atom 0 marked aromatic
RDKit ERROR: [14:00:37] non-ring atom 0 marked aromatic
[14:00:37] non-ring atom 0 marked aromatic
RDKit ERROR: [14:00:37] non-ring atom 0 marked aromatic
[14:00:37] non-ring atom 0 marked aromatic
RDKit ERROR: [14:00:37] non-ring atom 0 marked aromatic
[14:00:37] non-ring atom 0 marked aromatic
RDKit ERROR: [14:00:37] non-ring atom 0 marked aromatic
[14:00:37] non-ring atom 0 marked aromatic
RDKit ERROR: [14:00:37] non-ring atom 0 marked aromatic
[14:00:37] non-ring atom 0 marked aromatic
RDKit ERRO

132 : 0.0 of 88
133 : 1.0 of 3
134 : 0.0 of 88
135 : 0.13636363636363635 of 88
136 : 0.13636363636363635 of 88
137 : 0.0 of 67
138 : 0.13636363636363635 of 88
139 : 0.13636363636363635 of 88
140 : 0.13636363636363635 of 88
141 : 1.0 of 18
142 : 0.0 of 88
143 : 0.13636363636363635 of 88
144 : 1.0 of 18
145 : 1.0 of 21
146 : 1.0 of 7
147 : 0.13636363636363635 of 88
148 : 0.13636363636363635 of 88
149 : 0.13636363636363635 of 88
150 : 0.13636363636363635 of 88
151 : 1.0 of 9
152 : 0.0 of 76
153 : 0.9 of 10
154 : 0.13636363636363635 of 88
155 : 0.13636363636363635 of 88
156 : 0.13636363636363635 of 88
157 : 0.13636363636363635 of 88
158 : 0.13636363636363635 of 88
159 : 0.13636363636363635 of 88
160 : 0.13636363636363635 of 88
161 : 0.9 of 10
162 : 0.13636363636363635 of 88
163 : 0.13636363636363635 of 88
164 : 0.13636363636363635 of 88
165 : 0.13636363636363635 of 88
166 : 0.13636363636363635 of 88
167 : 0.0 of 88
168 : 0.13636363636363635 of 88
169 : 0.13636363636363635 of 88
170 : 0.136

In [71]:
# Thermo

# Reference only (not run here):
# starters = 'succinate'
# targets = 'mvacid'
# generations = 4
# args = ['-s', f"{starters}", '-t', f"{targets}", '-g', str(generations)]
# command = f"source activate /home/stef/miniconda3/envs/thermo && python /home/stef/pickaxe_thermodynamics/path_mdf.py {' '.join(args)}"
# subprocess.run(command, shell=True)

# Thermo JSON: keys are '>'-joined strings that split into the tuple keys of
# `paths`; each value is a list of per-path dicts, matched to paths by position
thermo = load_json(thermo_dir + fn)
thermo_fields = ('mdf', 'dG_opt', 'dG_err', 'conc_opt')
for k, v in thermo.items():
    st = tuple(k.split('>'))
    for i, elt in enumerate(v):
        # Copy each thermo value onto the i-th path of this (starter, target) pair
        for field in thermo_fields:
            value = elt[field]
            setattr(paths[st][i], field, value)


In [72]:
# Save reactions dict and paths list (ultimately will replace with expansion object)
# NOTE: save_dir and the file names match those the reactions / paths were loaded
# from, so this overwrites the input pickles in place.

save_dir = '../data/processed_expansions/'
rxns_fn = 'predicted_reactions_' + fn
paths_fn = 'paths_' + fn
rxns_path = save_dir + rxns_fn
paths_path = save_dir + paths_fn

# Reactions first, then paths
for out_path, obj in ((rxns_path, pred_rxns), (paths_path, paths)):
    with open(out_path, 'wb') as f:
        pickle.dump(obj, f)

In [73]:
# Counts of: substrate alignment issues, predicted-rxn atom-mapping errors, known-rxn atom-mapping errors
tuple(map(len, (alignment_issues, pr_am_errors, kr_am_errors)))

(25738, 1, 2357)