In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import json
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw, BRICS
from krxns.cheminfo import expand_unpaired_cofactors, mcs, draw_molecule, draw_reaction
from krxns.config import filepaths
from krxns.net_construction import SimilarityConnector, extract_compounds
from krxns.utils import str2int
import ipywidgets as widgets
from ipywidgets import interact
import matplotlib.pyplot as plt
from IPython.display import display, SVG
from itertools import product
from collections import defaultdict

In [2]:
# Load known reaction data; JSON object keys are strings, so cast reaction ids back to int
with open(filepaths['data'] / 'sprhea_240310_v3_mapped.json', 'r') as f:
    known_reactions = {int(rid): rxn for rid, rxn in json.load(f).items()}

# Remove reverses: represent each {reaction, reverse} couple as one sorted id tuple,
# then keep only the smaller id of each couple
rids = {tuple(sorted([rid, rxn['reverse']])) for rid, rxn in known_reactions.items()}
keepers = [lo for lo, _hi in rids]
known_reactions = {rid: known_reactions[rid] for rid in keepers}

known_compounds, smi2id = extract_compounds(known_reactions)

cxn_dir = filepaths['connected_reactions']

# Load op connected reactions
with open(cxn_dir / 'sprhea_240310_v3_mapped_operator.json', 'r') as f:
    op_cxn = str2int(json.load(f))

# Load sim connected reactions
with open(cxn_dir / 'sprhea_240310_v3_mapped_similarity.json', 'r') as f:
    sim_cxn = str2int(json.load(f))

# Load per-reaction side counts
with open(cxn_dir / 'sprhea_240310_v3_mapped_side_counts.json', 'r') as f:
    side_counts = str2int(json.load(f))

Spot check operator connected reactions

In [None]:
ocd_options = sorted(op_cxn)
op_cxn_dropdown = widgets.Dropdown(options=ocd_options, value=ocd_options[0])

def show_op_cxn(rid):
    '''
    Print each rule's operator-based reactant/product inlinks for reaction `rid`,
    then draw the reaction.
    '''
    for rule, rule_links in op_cxn[rid].items():
        print(f"Rule: {rule}")
        print(f"Rct inlinks: {rule_links['rct_inlinks']}")
        print(f"Pdt inlinks: {rule_links['pdt_inlinks']}")
    reaction_svg = draw_reaction(known_reactions[rid]['smarts'], sub_img_size=(300, 200))
    display(SVG(reaction_svg))

_ = interact(show_op_cxn, rid=op_cxn_dropdown)

Spot check similarity connected reactions

In [None]:
# Sorted, consistent with the other spot-check dropdowns in this notebook
scd_options = sorted(sim_cxn.keys())
sim_cxn_dropdown = widgets.Dropdown(options=scd_options, value=scd_options[0])

def _translate_inlinks(inlinks: dict) -> dict:
    '''
    Re-key a nested {outer_id: {inner_id: weight}} inlink dict by compound name
    (looked up in known_compounds).

    NOTE(review): two ids sharing a name would collide and the later one wins.
    '''
    return {
        known_compounds[outer]['name']: {
            known_compounds[inner]['name']: weight
            for inner, weight in inner_links.items()
        }
        for outer, inner_links in inlinks.items()
    }

def show_sim_cxn(rid):
    '''
    Print the similarity-based reactant and product inlinks of reaction `rid`
    (keyed by compound name), then draw the reaction.
    '''
    # Translate both sides before printing so a lookup failure prints nothing partial
    rct_inlinks = _translate_inlinks(sim_cxn[rid]['rct_inlinks'])
    pdt_inlinks = _translate_inlinks(sim_cxn[rid]['pdt_inlinks'])
    for header, named_links in (("Rct inlinks", rct_inlinks), ("\nPdt inlinks", pdt_inlinks)):
        print(header)
        for name, links in named_links.items():
            print(name, links)
    display(SVG(draw_reaction(known_reactions[rid]['smarts'], sub_img_size=(300, 200))))

_ = interact(show_sim_cxn, rid=sim_cxn_dropdown)

Look at multiply mapped reactions (reactions connected by more than one operator rule)

In [None]:
# Reactions that more than one rule maps
multiple_imt = {rid: rules for rid, rules in op_cxn.items() if len(rules) > 1}
print(len(multiple_imt))
mimt_opts = sorted(multiple_imt)
mimt_dropdown = widgets.Dropdown(options=mimt_opts, value=mimt_opts[0])

def show_both(rid):
    '''
    Combined view of reaction `rid`: similarity-based inlinks (by compound name),
    then every rule's operator-based inlinks, then the reaction drawing.
    '''
    def by_name(inlinks):
        # {outer_id: {inner_id: weight}} -> same nesting keyed by compound name
        return {
            known_compounds[outer]['name']: {
                known_compounds[inner]['name']: weight
                for inner, weight in inner_links.items()
            }
            for outer, inner_links in inlinks.items()
        }

    print("SIM CONNECTED")
    sim_links = sim_cxn[rid]
    rct_named = by_name(sim_links['rct_inlinks'])
    pdt_named = by_name(sim_links['pdt_inlinks'])
    print("Rct inlinks")
    for name, links in rct_named.items():
        print(name, links)
    print("\nPdt inlinks")
    for name, links in pdt_named.items():
        print(name, links)

    print("\nOP CONNECTED")
    for rule, rule_links in op_cxn[rid].items():
        print(f"Rule: {rule}")
        print(f"Rct inlinks: {rule_links['rct_inlinks']}")
        print(f"Pdt inlinks: {rule_links['pdt_inlinks']}")

    smarts = known_reactions[rid]['smarts']
    display(SVG(draw_reaction(smarts, sub_img_size=(300, 200))))

_ = interact(show_both, rid=mimt_dropdown)

The two cells below show:
1. Most of the multiply mapped reactions are exactly or directionally the same, so they can be resolved easily by picking one
2. Those that come up as different seem to have to do with water or with ambiguity around pentacovalent intermediates (situations involving phosphates), and so would go away if I first converted to compound IDs, removing unpaired cofactors along the way

In [None]:
# Inlink sides stored under every rule of op_cxn[rid]
INLINK_SIDES = ('rct_inlinks', 'pdt_inlinks')

def _argmax_key(weights: dict):
    '''
    Key with the largest weight (first such key on ties, like np.argmax), or None if empty.

    Returns the *key*, not a position: positions from np.argmax over dict values depend
    on each rule's insertion order and are not comparable across rules.
    '''
    return max(weights, key=weights.get) if weights else None

def _consistent_across_rules(rules: dict, summarize) -> bool:
    '''
    True if, for every inlink side and outer index i, summarize(inlinks[i]) yields the
    same value under every rule in `rules` ({rule: {side: {i: {j: weight}}}}).
    '''
    seen = defaultdict(set)
    for rule_links in rules.values():
        for side in INLINK_SIDES:
            for i, weights in rule_links[side].items():
                seen[(side, i)].add(summarize(weights))
    return all(len(v) == 1 for v in seen.values())

bad_rids = []

for rid in mimt_opts:
    rules = op_cxn[rid]

    # Every rule must link the same number of reactants and products
    full_side_count = {
        (len(rule_links['rct_inlinks']), len(rule_links['pdt_inlinks']))
        for rule_links in rules.values()
    }
    if len(full_side_count) > 1:
        raise ValueError("Number of substrates not consistent")

    # Exactly same: each i maps to an identical {j: weight} dict under every rule
    # (compares the set of j keys as well as their weights)
    exactly_same = _consistent_across_rules(rules, lambda w: tuple(sorted(w.items())))
    print(f"exactly_same: {exactly_same}")

    directionally_same = True
    if not exactly_same:
        # Directionally same: each i's strongest link points at the same j under every rule
        directionally_same = _consistent_across_rules(rules, _argmax_key)
        print(f"Directionally same: {directionally_same}")

    if not exactly_same and not directionally_same:
        bad_rids.append(rid)

In [None]:
bad_opts = sorted(bad_rids)
# bad_rids can legitimately be empty (every multiply mapped reaction agreed);
# Dropdown(value=bad_opts[0]) would then raise IndexError
if bad_opts:
    bad_dropdown = widgets.Dropdown(options=bad_opts, value=bad_opts[0])
    _ = interact(show_both, rid=bad_dropdown)
else:
    print("No multiply mapped reactions with conflicting mappings")

Spot-check reactions with only a similarity-based connection (restricted to a (2, 2) side count)

In [None]:
# Reactions whose loaded side count is (2, 2) and that no operator connected
sim_only = {
    rid: links
    for rid, links in sim_cxn.items()
    if tuple(side_counts[rid]) == (2, 2) and rid not in op_cxn
}
print(len(sim_only))
sim_only_opts = sorted(sim_only)
sim_only_dropdown = widgets.Dropdown(options=sim_only_opts, value=sim_only_opts[0])

_ = interact(show_sim_cxn, rid=sim_only_dropdown)

Histogram of the number of reactants / products for reactions I could not map with operators (after cofactor filtering)

In [9]:

def filter_cofactors(known_reactions: dict[int, dict], cofactors: dict[str, str], smi2id: dict[str, int], paired_cofactors: dict[tuple, float] = None):
    '''
    Filters cofactors out of known reaction dict.

    Each reaction's SMARTS ('lhs>>rhs', species separated by '.') is reduced to two
    lists of compound ids with stoichiometric repeats collapsed and unpaired cofactors
    dropped. Then at most one paired-cofactor couple (one id per side) is removed:
    the couple with the highest score, first one in (lhs x rhs) order on ties.
    Reactions left with an empty side before pair removal are skipped.

    Args:
        known_reactions: {rid: reaction dict with a 'smarts' entry}.
        cofactors: container of cofactor SMILES to drop (membership test only).
        smi2id: SMILES -> compound id.
        paired_cofactors: optional {(id_a, id_b): score}; key order does not matter.

    Returns:
        {rid: (lhs_ids, rhs_ids)} with the shorter side first.
    '''
    if paired_cofactors is None:
        paired_cofactors = {}
    # Normalise pair keys so lookups are orientation-independent
    paired_cofactors = {tuple(sorted(k)): v for k, v in paired_cofactors.items()}

    filtered_krs = {}
    for rid, rxn in known_reactions.items():
        # dict.fromkeys collapses stoichiometric repeats while keeping SMILES order.
        # A set of str iterates in a PYTHONHASHSEED-dependent order, which made the
        # tie-breaking below and the output order differ between runs.
        lhs, rhs = (
            [smi2id[smi] for smi in dict.fromkeys(side.split(".")) if smi not in cofactors]
            for side in rxn['smarts'].split(">>")
        )

        if not lhs or not rhs:
            continue

        to_remove = None
        best_jaccard = 0
        for pair in product(lhs, rhs):
            score = paired_cofactors.get(tuple(sorted(pair)))
            if score is not None and score > best_jaccard:
                to_remove = pair  # Note NOT sorted pair: keeps (lhs, rhs) orientation
                best_jaccard = score

        if to_remove is not None:
            lhs.remove(to_remove[0])
            rhs.remove(to_remove[1])

        # Canonical orientation: smaller side first
        if len(lhs) > len(rhs):
            lhs, rhs = rhs, lhs

        filtered_krs[rid] = (lhs, rhs)

    return filtered_krs

def plot_side_counts(side_counts: dict[tuple, set], title: str = None):
    '''
    Bar chart of how many reactions fall into each side-size bucket, tallest bar first.

    Args:
        side_counts: {(# reactants, # products): set of reaction ids}.
        title: optional axes title.

    Returns:
        The matplotlib Figure (also shown via plt.show()).
    '''
    ranked = sorted(((k, len(v)) for k, v in side_counts.items()), key=lambda x: x[1], reverse=True)
    # zip(*[]) yields nothing to unpack into two names, so handle an empty input explicitly
    x_labels, cts = zip(*ranked) if ranked else ((), ())
    x = np.arange(len(x_labels))

    fig, ax = plt.subplots()
    ax.bar(x, height=cts)
    ax.set_xticks(x)
    ax.set_xticklabels(x_labels, rotation=45)
    ax.set_ylabel("# reactions")
    ax.set_xlabel("(# reactants, # products)")
    if title:
        ax.set_title(title)
    plt.show()

    return fig

In [None]:
# Setup for comparing with the similarity connector: cofactor table and
# compound-compound similarity matrices

# Load unpaired cofs, excluding blacklisted names
unpaired_fp = filepaths['cofactors'] / "unpaired_cofactors_reference.tsv"
name_blacklist = [
    'acetyl-CoA',
]

unpaired_ref = pd.read_csv(
    filepath_or_buffer=unpaired_fp,
    sep='\t'
)

filtered_unpaired = unpaired_ref.loc[~unpaired_ref['Name'].isin(name_blacklist), :]
cofactors = expand_unpaired_cofactors(filtered_unpaired, k=10)

# Extra species added by hand (SMILES -> name)
manual = {
    'N#N': 'N2',
    '[H][H]': 'H2',
    'S': 'hydrogen sulfide',
    '[Cl-]': 'chloride',
    '[Na+]': 'sodium'
}

# Manual entries win on SMILES collisions
cofactors = {**cofactors, **manual}

# Load cc sim mats
cc_sim_mats = {
    'mcs': np.load(filepaths['sim_mats'] / "mcs.npy"),
    'tanimoto': np.load(filepaths['sim_mats'] / "tanimoto.npy")
}

In [None]:
# Reuse the similarity connector's SMILES -> id map and its paired-cofactor list
sc = SimilarityConnector(
    reactions=known_reactions,
    cc_sim_mats=cc_sim_mats,
    cofactors=cofactors,
    k_paired_cofactors=21,
)
smi2id = sc.smi2id

# Paired cofactor (presumably compound-id pair) -> jaccard similarity
paired_cofactors = {
    pair: sc.cc_sim_mats['jaccard'][pair]
    for pair in sc.paired_cofactors
}

# Reactions that no operator connected
missed_by_op = {rid: rxn for rid, rxn in known_reactions.items() if rid not in op_cxn}

filtered_krs = filter_cofactors(missed_by_op, cofactors, smi2id, paired_cofactors)

# Bucket reaction ids by the sorted sizes of their two filtered sides.
# NOTE(review): this rebinds `side_counts`, shadowing the per-reaction counts loaded
# from JSON earlier; re-running the sim-only spot-check cell after this one misbehaves.
side_counts = defaultdict(set)
for rid, sides in filtered_krs.items():
    side_counts[tuple(sorted(len(side) for side in sides))].add(rid)

new = plot_side_counts(side_counts)

In [None]:
# Side-size buckets in which either side holds exactly one compound
trivial_cases = [bucket for bucket in side_counts if 1 in bucket]
sum(len(side_counts[bucket]) for bucket in trivial_cases)

Check for degeneracy when translating the operator-connected, reactant/product-index-based adjacency matrices into compound-ID-based ones

In [39]:
# Inlink side -> (index into [lhs, rhs] for outer key i, index for inner key j).
# Stated explicitly rather than inferred from each rule's side-dict iteration order.
SIDE_TO_SMILES_IDX = {'rct_inlinks': (0, 1), 'pdt_inlinks': (1, 0)}

# (rid, rule, side, iid, jid) -> distinct weights that collapsed onto one compound-id pair
corner_cases = {}
for rid, rule_dict in op_cxn.items():
    smiles = [elt.split('.') for elt in known_reactions[rid]['smarts'].split('>>')]
    for rule, side_dict in rule_dict.items():
        for side, outer_dict in side_dict.items():
            if side not in SIDE_TO_SMILES_IDX:
                continue
            outer_idx, inner_idx = SIDE_TO_SMILES_IDX[side]

            weights_by_id_pair = defaultdict(set)
            for i, inner_dict in outer_dict.items():
                ismi = smiles[outer_idx][i]
                if ismi in cofactors:
                    continue
                iid = smi2id[ismi]

                for j, weight in inner_dict.items():
                    jsmi = smiles[inner_idx][j]
                    if jsmi in cofactors:
                        continue
                    weights_by_id_pair[(iid, smi2id[jsmi])].add(weight)

            # Key on the degenerate pair itself; the loop-leftover iid/jid would make
            # every degenerate pair of this side overwrite the same entry
            for (pair_iid, pair_jid), weights in weights_by_id_pair.items():
                if len(weights) > 1:
                    corner_cases[(rid, rule, side, pair_iid, pair_jid)] = weights

In [None]:
# zip(*{}) cannot be unpacked into five names, so guard the no-corner-case case
if corner_cases:
    rids, rules, sides, iids, jids = zip(*corner_cases.keys())
else:
    rids = rules = sides = iids = jids = ()

cc_options = sorted(set(rids))
if cc_options:
    cc_dropdown = widgets.Dropdown(options=cc_options, value=cc_options[0])
    # Pass the dropdown built above (previously the raw list was passed and it went unused)
    _ = interact(show_both, rid=cc_dropdown)
else:
    print("No corner cases found")

In [None]:
len({*rids})  # number of distinct reactions with corner cases

In [None]:
# Inspect the corner cases of one reaction. Keys are (rid, rule, side, iid, jid); match the
# rid slot only, since `159 in k` would also hit rule or compound ids equal to 159.
DEBUG_RID = 159
{k: v for k, v in corner_cases.items() if k[0] == DEBUG_RID}

In [None]:
# Sanity check of the side-flip trick: XOR with 0 keeps a side index, XOR with 1 flips it
print(0 ^ 0, 0 ^ 1)  # 0 1
print(1 ^ 0, 1 ^ 1)  # 1 0 (previously repeated `1 ^ 0`, so the flip of side 1 was never shown)

In [67]:
from collections import Counter

# Scratch: does any value in `foo` occur more than once?
foo = {0: 18, 1: 11, 2: 3}
stoich_counter = Counter(foo.values())
if any(count > 1 for count in stoich_counter.values()):
    print('yes')

In [69]:
from itertools import product

In [None]:
[*product(['A1', 'A2'], ['B1', 'B2'], ['Z'])]  # scratch: cartesian product of three lists

In [None]:
# Scratch: cartesian product across the value lists of a dict
foo = {'id1': ['A1', 'A2'], 'id2': ['B1', 'B2'], 'id3': ['Z']}
list(product(*(foo[key] for key in foo)))