In [1]:
import json
import pandas as pd
from typing import List
from rxntools import reaction, utils
from collections import Counter

In [2]:
reported_rxns_df = pd.read_parquet("../data/interim/enzymemap_MetaCyc_JN_mapped.parquet")
JN_rules_df = pd.read_csv('../data/raw/JN1224MIN_rules.tsv', delimiter='\t')

with open('../data/raw/cofactors.json') as f:
    cofactors_dict = json.load(f)

all_cofactor_codes: List[str] = list(cofactors_dict.keys())
cofactors_list: List[str] = [cofactors_dict[key] for key in cofactors_dict.keys()]
cofactors_df = pd.read_csv('../data/raw/all_cofactors.csv')


In [3]:
def get_cofactor_SMARTS_from_JN_rule(cofactor_code, reactant_codes: str, product_codes: str, rxn_SMARTS: str, rxn_side: str):
    lhs_side_SMARTS, rhs_side_SMARTS = rxn_SMARTS.split('>>')
    lhs_side_SMARTS_list = lhs_side_SMARTS.split('.')
    rhs_side_SMARTS_list = rhs_side_SMARTS.split('.')
    
    reactant_codes_list = reactant_codes.split(';')
    product_codes_list = product_codes.split(';')

    if rxn_side == 'lhs':
        cofactor_idx = reactant_codes_list.index(cofactor_code)
        return lhs_side_SMARTS_list[cofactor_idx]
    
    elif rxn_side == 'rhs':
        cofactor_idx = product_codes_list.index(cofactor_code)
        return rhs_side_SMARTS_list[cofactor_idx]

In [None]:
rxn_side = 'lhs'
get_cofactor_SMARTS_from_JN_rule('NAD_CoF', reactant_codes, product_codes, rxn_SMARTS, rxn_side)

'[#6:3]1:[#6:4]:[#6:5]:[#6:6]:[#7+:7]:[#6:8]:1'

In [10]:
query_rule = 'rule0002'
radius = 4
include_stereo = True

# initialize a counter to keep track of the number of reactions from which rules were not extracted
rxns_skipped_count = 0

# extract the reaction SMARTS for this reaction under Joseph's rule as well as reactant and product codes
reactant_codes = JN_rules_df[JN_rules_df['Name']==query_rule]['Reactants'].to_list()[0]
product_codes = JN_rules_df[JN_rules_df['Name']==query_rule]['Products'].to_list()[0]
rxn_SMARTS = JN_rules_df[JN_rules_df['Name']==query_rule]['SMARTS'].to_list()[0]

query_df = reported_rxns_df[reported_rxns_df['top_mapped_operator']==query_rule]
atom_mapped_rxns_list: List[str] = query_df['mapped'].to_list()
all_rxn_templates: List[str] = [] # initialize list to store all reaction templates extracted

# for each fully atom-mapped reaction
for rxn_SMARTS in atom_mapped_rxns_list:

    # create an instance of the reaction.mapped reaction class then extract atoms undergoing bond changes
    mapped_rxn = reaction.mapped_reaction(rxn_SMARTS) 
    changed_atoms, broken_bonds, formed_bonds = mapped_rxn.get_all_changed_atoms(include_cofactors=False, # set to False since we dont want changed cofactor atoms
                                                                                 consider_stereo=True,
                                                                                 cofactors_list=cofactors_list)
    
    substrates_list = mapped_rxn.get_substrates(cofactors_list=cofactors_list, consider_stereo=False)
    products_list = mapped_rxn.get_products(cofactors_list=cofactors_list, consider_stereo=False)
    lhs_cofactors_list = mapped_rxn.get_lhs_cofactors(cofactors_list=cofactors_list, consider_stereo=False)
    rhs_cofactors_list = mapped_rxn.get_rhs_cofactors(cofactors_list=cofactors_list, consider_stereo=False)

    # extract cofactor codes (leave out H+)
    lhs_cofactor_codes = [utils.get_cofactor_CoF_code(cofactor_smiles, cofactors_df) for cofactor_smiles in lhs_cofactors_list]
    rhs_cofactor_codes = [utils.get_cofactor_CoF_code(cofactor_smiles, cofactors_df) for cofactor_smiles in rhs_cofactors_list]

    lhs_cofactor_codes = [x for x in lhs_cofactor_codes if x!='H+']
    rhs_cofactor_codes = [x for x in rhs_cofactor_codes if x!='H+']

    # quick check to see if correct cofactor pairs are present according to JN1224MIN generalized reaction rules
    try:
        lhs_list = lhs_cofactor_codes
        rhs_list = [x for x in JN_rules_df[JN_rules_df['Name']==query_rule]['Reactants'].to_list()[0].split(';') if x.strip('_CoF') in all_cofactor_codes]
        assert Counter(lhs_list) == Counter(rhs_list)

        lhs_list = rhs_cofactor_codes
        rhs_list = [x for x in JN_rules_df[JN_rules_df['Name']==query_rule]['Products'].to_list()[0].split(';') if x.strip('_CoF') in all_cofactor_codes]
        assert Counter(lhs_list) == Counter(rhs_list)

        # initialize empty str to start building the reaction template for this rxn
        rxn_template = ''

        # check if reaction involves only a single substrate and single product
        if len(substrates_list)==1 and len(products_list)==1:
    
            # extract a template around the substrate
            substrate_template = mapped_rxn.get_template_around_rxn_site(atom_mapped_substrate_smarts=substrates_list[0],
                                                                         reactive_atom_indices=list(changed_atoms),
                                                                         radius=radius,
                                                                         include_stereo=include_stereo)
        
            # atoms maps for different substrates can have different atom indices depending on order so we reset
            substrate_template = utils.reset_atom_map(substrate_template)
        
            rxn_template += substrate_template

            # add cofactors to template
            if lhs_cofactor_codes != []:
                for cofactor_code in lhs_cofactor_codes:
                    cofactor_SMARTS = utils.get_cofactor_SMARTS_from_JN_rule(cofactor_code=cofactor_code, 
                                                                            reactant_codes=reactant_codes, 
                                                                            product_codes=product_codes,
                                                                            rxn_SMARTS=rxn_SMARTS,
                                                                            rxn_side='lhs')
                    
                    rxn_template += '.'

            # remove extra period added to the end
            rxn_template = rxn_template.rstrip('.')
            rxn_template += '>>' # separator between LHS and RHS of rxn
         
            # extract a template around the product
            product_template = mapped_rxn.get_template_around_rxn_site(atom_mapped_substrate_smarts=products_list[0],                                           
                                                                       reactive_atom_indices=list(changed_atoms),
                                                                       radius=radius,
                                                                       include_stereo=include_stereo)
        
            # atom maps for different products can have different atom indices depending on order so we reset
            product_template = utils.reset_atom_map(product_template)
        
            rxn_template += product_template
            all_rxn_templates.append(rxn_template) # store when finished

    except Exception as e:
        rxns_skipped_count += 1



In [11]:
all_rxn_templates

['>>',
 '[c:1][c&H1:2][c:3]([C@@&H1:4]([O&H1:5])[C&H2:6][O&H1:7])[c&H1:8][c&H1:9]>>[c:1][c&H1:2][c:3]([C@@&H1:4]([O&H1:5])[C&H1:6]=[O:7])[c&H1:8][c&H1:9]',
 '[C&H3:1][C@@&H1:2]([O&H1:3])[C&H2:4][C&H2:5][O&H1:6]>>[C&H3:1][C@@&H1:2]([O&H1:3])[C&H2:4][C&H1:5]=[O:6]',
 '>>',
 '[C&H3:1][C&H1:2]([C&H3:3])[C&H2:4][C&H2:5][O&H1:6]>>[C&H3:1][C&H1:2]([C&H3:3])[C&H2:4][C&H1:5]=[O:6]',
 '[O&H1:1][C&H2:2][C&H2:3][c:4]([c&H1:5][c&H1:6])[c&H1:7][c&H1:8]>>[O:1]=[C&H1:2][C&H2:3][c:4]([c&H1:5][c&H1:6])[c&H1:7][c&H1:8]',
 '[C&H3:1]/[C:2](=[C&H1:3]/[C&H2:4][O&H1:5])[C&H2:6][C&H2:7]>>[C&H3:1]/[C:2](=[C&H1:3]/[C&H1:4]=[O:5])[C&H2:6][C&H2:7]',
 '[O&H1:1][C&H2:2][C@@&H1:3]([O&H1:4])[C@&H1:5]([O&H1:6])[C@@&H1:7]([O&H1:8])[C&H2:9][O&H1:10]>>[O&H1:1][C&H2:2][C:3](=[O:4])[C@&H1:5]([O&H1:6])[C@@&H1:7]([O&H1:8])[C&H2:9][O&H1:10]',
 '[C&H3:1][C@@&H1:2]([O&H1:3])[C@&H1:4]([N&H3&+:5])[C:6](=[O:7])[O&-:8]>>[C&H3:1][C:2](=[O:3])[C@&H1:4]([N&H3&+:5])[C:6](=[O:7])[O&-:8]',
 '[C&H3:1][C@@&H1:2]([O&H1:3])[C@&H1:4]([N&H3&+:5

In [6]:
len(all_rxn_templates)

80

In [7]:
len(set(all_rxn_templates))

61

In [8]:
rxns_skipped_count

0