In [26]:
%load_ext autoreload
%autoreload 2

from src.rxn_ctr_mcs import *
from src.utils import load_json, rxn_entry_to_smarts, rm_atom_map_num
from src.pathway_utils import get_reverse_paths_to_starting, create_graph_from_pickaxe
from src.post_processing import *

from minedatabase.pickaxe import Pickaxe
from minedatabase.utils import get_compound_hash

from rdkit.Chem import AllChem

from collections import defaultdict
import pandas as pd
import csv
import pickle

Some weights of the model checkpoint at /home/stef/miniconda3/envs/mine/lib/python3.7/site-packages/rxnmapper/models/transformers/albert_heads_8_uspto_all_1310k were not used when initializing AlbertModel: ['predictions.LayerNorm.bias', 'predictions.decoder.bias', 'predictions.decoder.weight', 'predictions.LayerNorm.weight', 'predictions.dense.weight', 'predictions.dense.bias', 'predictions.bias']
- This IS expected if you are initializing AlbertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Expansion parameters: together these select which pickled run gets loaded

expansion_dir = '../data/raw_expansions/'
starters, targets = 'succinate', 'mvacid'
generations = 4
# Expansion file name
fn = f"{starters}_to_{targets}_gen_{generations}_tan_sample_1_n_samples_1000.pk"

# Restore the raw expansion into a fresh Pickaxe object
path = ''.join((expansion_dir, fn))
pk = Pickaxe()
pk.load_pickled_pickaxe(path)

----------------------------------------
Intializing pickaxe object

Done intializing pickaxe object
----------------------------------------

Loading ../data/raw_expansions/succinate_to_mvacid_gen_4_tan_sample_1_n_samples_1000.pk pickled data.
Loaded 89758 compounds
Loaded 106853 reactions
Loaded 3604 operators
Loaded 1 targets
Took 6.473435878753662


In [4]:
# Create the initial graph

# Build the graph from the loaded expansion. The second and third return
# values are not used in this cell; they keep their original names so any
# other cell that reads them still works.
DG, rxn, edge = create_graph_from_pickaxe(pk, "Biology")

starting_nodes = []  # nodes whose "Type" attribute is "Starting Compound"
bad_nodes = []       # nodes that carry no "Type" attribute at all
for n in DG.nodes():
    try:
        node_type = DG.nodes()[n]["Type"]
    except KeyError:
        # Only a missing "Type" attribute marks a node as bad. A bare
        # `except:` here would also swallow KeyboardInterrupt and hide
        # unrelated failures.
        # NOTE(review): assumes node attributes are dict-like (DG looks like
        # a networkx graph) -- confirm against create_graph_from_pickaxe.
        bad_nodes.append(n)
        continue
    if node_type == "Starting Compound":
        starting_nodes.append(n)

RDKit ERROR: [12:35:21] Can't kekulize mol.  Unkekulized atoms: 19 23 26 27 28 29 30 31 33 34
[12:35:21] Can't kekulize mol.  Unkekulized atoms: 19 23 26 27 28 29 30 31 33 34

RDKit ERROR: 
RDKit ERROR: [12:35:22] Can't kekulize mol.  Unkekulized atoms: 19 23 26 27 28 29 30 31 33 34
RDKit ERROR: 
[12:35:22] Can't kekulize mol.  Unkekulized atoms: 19 23 26 27 28 29 30 31 33 34

RDKit ERROR: [12:35:30] Can't kekulize mol.  Unkekulized atoms: 19 23 26 27 28 29 30 31 33 34
[12:35:30] Can't kekulize mol.  Unkekulized atoms: 19 23 26 27 28 29 30 31 33 34

RDKit ERROR: 
RDKit ERROR: [12:35:31] Can't kekulize mol.  Unkekulized atoms: 19 23 26 27 28 29 30 31 33 34
RDKit ERROR: 
[12:35:31] Can't kekulize mol.  Unkekulized atoms: 19 23 26 27 28 29 30 31 33 34

RDKit ERROR: [12:35:31] Can't kekulize mol.  Unkekulized atoms: 19 23 26 27 28 29 30 31 33 34
RDKit ERROR: 
[12:35:31] Can't kekulize mol.  Unkekulized atoms: 19 23 26 27 28 29 30 31 33 34

RDKit ERROR: [12:35:31] Can't kekulize mol.  Unkek

In [5]:
# Get pathways
max_depth = generations * 2
paths = defaultdict(list)  # (starter ID, target ID) -> list of pathway objects

# Specify Targets / Starting Cpds: compound hash and ID of every loaded target
target_cids, target_names = [], []
for v in pk.targets.values():
    target_cids.append(get_compound_hash(v['SMILES'])[0])
    target_names.append(v['ID'])

starting_cpds = [get_compound_hash(val["SMILES"])[0] for val in pk.compounds.values() if val["Type"].startswith("Start")]
# Set copy for the membership tests below; the list itself is still what gets
# passed to get_reverse_paths_to_starting, exactly as before.
starting_cpd_set = set(starting_cpds)

# Loop through targets and get pathways from targets to starting compounds
for i, this_target in enumerate(target_cids):
    this_paths = get_reverse_paths_to_starting(DG, begin_node=this_target, end_nodes=starting_cpds, max_depth=max_depth)
    if not this_paths:
        continue

    # Reverse each path so it runs in the opposite direction to the search and
    # keep every second node (path[1::2]).
    # presumably those are the reaction nodes -- the elements are used as keys
    # into pk.reactions below.
    # dict.fromkeys de-duplicates while preserving first-seen order, so the
    # order of the stored pathways is reproducible between kernel sessions
    # (set() order depends on per-process string hashing).
    this_paths = list(dict.fromkeys(
        tuple(list(reversed(ind_path))[1::2]) for ind_path in this_paths
    ))

    t_name = target_names[i]
    for elt in this_paths:
        # A pathway is recorded once for each reactant of its first reaction
        # that is a starting compound.
        for r in pk.reactions[elt[0]]["Reactants"]:
            if r[-1] in starting_cpd_set:
                s_name = pk.compounds[r[-1]]["ID"]
                paths[(s_name, t_name)].append(pathway(rhashes=elt, starter_hash=r[-1], target_hash=this_target))

In [6]:
# Make predicted reaction dict

pred_rxns = {}  # reaction hash -> reaction object, one per unique hash
# Occurrence counter for hashes seen more than once. The default of 1 stands
# for the first sighting, so the first repeat brings the count to 2.
degen_rhashes = defaultdict(lambda : 1)
for st_pair, pathway_list in paths.items():
    for this_pathway in pathway_list:
        for rh in this_pathway.rhashes:
            if rh in pred_rxns:
                degen_rhashes[rh] += 1
                continue
            pred_rxns[rh] = reaction(rh, rxn_hash_2_rxn_sma(rh, pk))


In [7]:
# Number of pathways found for each (starter, target) pair
for st_pair in paths:
    print(st_pair, len(paths[st_pair]))

('succinate', 'mvacid') 41


In [8]:
# Load in IMT rule mapping

# Load rules, indexed by the 'Name' column
rules_path = '../data/rules/JN3604IMT_rules.tsv'
rule_df = pd.read_csv(rules_path, delimiter='\t').set_index('Name')

# Load mapping: first CSV field is the reaction id, the remaining fields are
# the rules mapped to it (an id with no further fields gets an empty list).
# Ids repeated across files are overwritten by the later file, as before.
rxn2rule = {}
db_names = ['_mc_v21', '_brenda', '_kegg']
suffix = '_imt_rules_enforce_cof.csv'
for name in db_names:
    mapping_path = '../data/mapping/mapping' + name + suffix
    # newline='' is what the csv module asks for when handing it a file object
    with open(mapping_path, 'r', newline='') as f:
        for row in csv.reader(f):
            if not row:
                # A blank line yields an empty row; row[0] would raise IndexError
                continue
            rxn2rule[row[0]] = row[1:]

# Make rule2rxn: invert the mapping so each rule lists the reaction ids mapped to it
rule2rxn = {}
for rxn_id, rule_names in rxn2rule.items():
    for rule_name in rule_names:
        rule2rxn.setdefault(rule_name, []).append(rxn_id)

# Load all known reaction json entries into dict
known_rxns = {}
pref = '../data/mapping/'
suffs = ['mc_v21_as_is.json', 'brenda_as_is.json', 'kegg_as_is.json']
for elt in suffs:
    known_rxns.update(load_json(pref + elt))

In [9]:
# Populate reaction objects in rxn dict w/ known reactions

for k, v in pred_rxns.items():
    # Collect the distinct (reaction id, smarts) pairs of every known reaction
    # mapped to any operator of this predicted reaction.
    unique_known = set()
    for this_rule in pk.reactions[k]["Operators"]:
        for this_id in rule2rxn.get(this_rule, ()):
            this_sma = rxn_entry_to_smarts(known_rxns[this_id])
            unique_known.add((this_id, this_sma))

    # Each entry keeps the layout [prc_mcs placeholder, smarts, reaction id];
    # slot 0 starts as None.
    # Sorting makes the order -- and therefore any index into known_rxns that
    # gets recorded later -- identical between kernel sessions; iterating the
    # raw set would follow per-process string hashing instead.
    v.known_rxns = [[None, this_sma, this_id] for this_id, this_sma in sorted(unique_known)]


In [10]:
# Summary: number of unique predicted reactions, then pathways per (starter, target) pair
n_pred_rxns = len(pred_rxns)
print(n_pred_rxns)
for st_pair, pathway_list in paths.items():
    print(st_pair, len(pathway_list))

33
('succinate', 'mvacid') 41


In [27]:
# For every predicted reaction, compare it against each of its known reactions:
# atom-map both, extract reaction-center sub-molecules, align the reactants of
# the two reactions, and store the reaction-center-seeded MCS result in slot 0
# of the known-reaction entry.
#
# NOTE(review): the output recorded with this cell ends in an RDKit
# "RuntimeError: Pre-condition Violation -- getExplicitValence() called without
# call to calcExplicitValence()". Which helper raises it cannot be determined
# from this notebook. TODO: trace it in src.rxn_ctr_mcs. Presumably a molecule
# taken from a reaction template has no computed valence when it is used --
# confirm before adding a fix.

pr_am_errors = [] # Track predicted rxn am errors
kr_am_errors = [] # Track known rxn am errors
alignment_issues = [] # Track substrate alignment issues

# Populate pred_rxns, known rxn prc-mcs slot
# Snapshot the keys once; the loop only mutates the stored objects, never the dict
for x, h in enumerate(list(pred_rxns)):
    pred = pred_rxns[h]
    rxn_sma1 = pred.smarts

    # Skip pred reactions that trigger RXNMapper atom mapping errors.
    # `except Exception` (not a bare except) so Ctrl-C still interrupts the loop.
    try:
        am_rxn_sma1 = atom_map(rxn_sma1)
    except Exception:
        pr_am_errors.append(h)
        continue

    # (number of reactants, number of products) of the predicted reaction
    stoich1 = tuple(len(side.split('.')) for side in rxn_sma1.split('>>'))

    a = 0 # Number known rxns analyzed
    for z, kr in enumerate(pred.known_rxns):
        rxn_sma2 = kr[1]

        # Catch stoichiometry mismatches stemming from pickaxe, early post-processing
        stoich2 = tuple(len(side.split('.')) for side in rxn_sma2.split('>>'))
        if stoich2 != stoich1:
            print(x, z, 'stoich_error')
            continue

        # Skip known reactions that trigger RXNMapper atom mapping errors
        try:
            am_rxn_sma2 = atom_map(rxn_sma2)
        except Exception:
            kr_am_errors.append((h, z, kr[-1]))
            continue

        # Construct reaction objects (index 0: predicted, index 1: known).
        # Both are rebuilt for every known reaction.
        rxns = []
        for elt in [am_rxn_sma1, am_rxn_sma2]:
            temp = AllChem.ReactionFromSmarts(elt, useSmiles=True)
            temp.Initialize()
            rxns.append(temp)

        rc_atoms = [elt.GetReactingAtoms() for elt in rxns] # Get reaction center atom idxs

        # Construct rxn ctr mol objs
        try: # REMOVE after addressing KekulizationException in get_sub_mol
            rcs = [
                [get_sub_mol(t_mol, rc_atoms[i][j]) for j, t_mol in enumerate(t_rxn.GetReactants())]
                for i, t_rxn in enumerate(rxns)
            ]
        except Exception:
            # Known limitation (see REMOVE note above): this pair is skipped
            # without being recorded in any of the error lists.
            continue

        # Align substrates of the 2 reactions
        rc_idxs = [] # Each element: (idx for rxn 1, idx for rxn 2)
        remaining = [list(range(len(elt))) for elt in rcs]
        while remaining[0] and remaining[1]:
            idx_pair = align_substrates(rcs, remaining)
            if idx_pair is None:
                break
            rc_idxs.append(idx_pair)
            remaining[0].remove(idx_pair[0])
            remaining[1].remove(idx_pair[1])

        # Skip if you haven't aligned every reactant pred to known
        if len(rc_idxs) < stoich1[0]:
            alignment_issues.append((h, z, kr[-1]))
            continue

        # For reaction 2 (known reaction) Re-order rcs, rc_atoms,
        # internal order of reactants in the reaction object in rxns
        # and the smarts stored in the known_reactions attribute of the
        # associated predicted reaction

        # Sort reaction 2 rc_idxs by reaction 1 rc_idxs
        rxn_1_rc_idxs, rxn_2_rc_idxs = list(zip(*rc_idxs))
        if rxn_1_rc_idxs != rxn_2_rc_idxs:
            rxn_2_rc_idxs, rxn_1_rc_idxs = sort_x_by_y(rxn_2_rc_idxs, rxn_1_rc_idxs)

            # Re-order atom-mapped smarts string, and then update known_rxns entry
            # with de-atom-mapped version of this string because atom mapper changes
            # reactant order and it's this order that rcs, rc_atoms, rc_idxs all come from
            reactants2, products2 = am_rxn_sma2.split('>>')[0].split('.'), am_rxn_sma2.split('>>')[1]
            am_ro_sma2 = '.'.join([reactants2[elt] for elt in rxn_2_rc_idxs]) # Re-join in new order
            am_rxn_sma2 = am_ro_sma2 + '>>' + products2 # Join with products side

            # Re-construct reaction object from re-ordered, am smarts
            temp = AllChem.ReactionFromSmarts(am_rxn_sma2, useSmiles=True)
            temp.Initialize()
            rxns[1] = temp

            rc_atoms[1] = rxns[1].GetReactingAtoms() # Update rc_atoms
            rcs[1] = [get_sub_mol(elt, rc_atoms[1][i]) for i, elt in enumerate(rxns[1].GetReactants())] # Update rc mol obj

        pred.known_rxns[z][1] = rm_atom_map_num(am_rxn_sma2) # Update known_reaction entry w/ de-am smarts
        rxns = align_atom_map_nums(rxns, rcs, rc_atoms)

        # Compute MCS seeded by reaction center
        prc_mcs = get_prc_mcs(rxns, rcs, rc_atoms)
        pred.known_rxns[z][0] = prc_mcs # Update pred_rxns

        a += 1 # Count known rxn analyzed
        pred.smarts = rm_atom_map_num(am_rxn_sma1) # Update pred_rxn smarts w/ de-am smarts

    # Fraction of this reaction's known rxns that were fully analyzed. Uses the
    # list length rather than the loop variable `z`, which is undefined (or
    # stale from the previous predicted reaction) when there are no known rxns.
    n_known = len(pred.known_rxns)
    print(x, ':', a / n_known if n_known else 0.0, 'of', n_known)

RuntimeError: Pre-condition Violation
	getExplicitValence() called without call to calcExplicitValence()
	Violation occurred on line 199 in file Code/GraphMol/Atom.cpp
	Failed Expression: d_explicitValence > -1
	RDKIT: 2020.09.1
	BOOST: 1_73


In [None]:
# Save reactions dict and paths list (ultimately will replace with expansion object)

save_dir = '../data/processed_expansions/'

# Output file names are derived from the expansion file name
rxns_fn = 'predicted_reactions_' + fn
paths_fn = 'paths_' + fn
rxns_path = save_dir + rxns_fn
paths_path = save_dir + paths_fn

# Pickle each object to its own file, reactions dict first
for out_path, obj in ((rxns_path, pred_rxns), (paths_path, paths)):
    with open(out_path, 'wb') as f:
        pickle.dump(obj, f)

In [None]:
# Counts of: substrate alignment issues, predicted-rxn atom-mapping errors, known-rxn atom-mapping errors
tuple(len(tracked) for tracked in (alignment_issues, pr_am_errors, kr_am_errors))