In [1]:
from cgr.filepaths import filepaths
from cgr.draw import draw_molecule, draw_reaction
from cgr.cheminfo import MorganFingerprinter, extract_subgraph
import json
from IPython.display import SVG
from rdkit import Chem
import numpy as np
import pandas as pd
import ipywidgets as widgets
from ipywidgets import interact
from scipy.stats import entropy
import matplotlib.pyplot as plt
from collections import defaultdict
from collections import Counter
from itertools import chain

In [2]:
def pick_leaves(embeds, leaf, translator=None):
    """Return the row indices of ``embeds`` that satisfy every answer in ``leaf``.

    Args:
        embeds: 2-D array; a row is kept when ``embeds[row, q] == a`` for every
            ``(q, a)`` pair in ``leaf``.
        leaf: sequence of groups of ``(feature_idx, answer)`` pairs (the leaf
            format returned by ``find_feature_clusters``).
        translator: optional 1-D index map (array or list). When given, each
            ``feature_idx`` is replaced by ``translator[feature_idx]`` before
            indexing ``embeds`` (e.g. nonzero-feature index -> full bit index).

    Returns:
        1-D integer array of matching row indices in ascending order. A leaf
        with no questions matches every row.
    """
    pairs = list(chain.from_iterable(leaf))
    if not pairs:
        return np.arange(embeds.shape[0])

    asked, answered = zip(*pairs)
    asked = np.asarray(asked, dtype=int)

    # `is not None`: truth-testing an ndarray translator would raise.
    if translator is not None:
        # Index with an integer array; a tuple index would be read by NumPy as a
        # multi-dimensional index.
        asked = np.asarray(translator)[asked]

    row_mask = np.all(embeds[:, asked] == np.asarray(answered), axis=1)
    return np.flatnonzero(row_mask)

In [None]:
# Load the mapped reaction set and keep the reactions whose minimal rule is rule0024.
krs_path = filepaths.data / "raw" / "sprhea_240310_v3_mapped_no_subunits.json"
with open(krs_path, 'r', encoding='utf-8') as f:  # JSON is UTF-8; don't depend on the OS locale
    krs = json.load(f)

# NOTE(review): the name suggests rule0024 is the decarboxylation rule — confirm.
decarb = {k: v for k, v in krs.items() if v['min_rule'] == 'rule0024'}
print(len(decarb))

In [4]:
# Fingerprint the reactant side of each reaction around its reaction center,
# record which substructure (and radius) set each bit, then assign every bit
# to a single radius to build a radius-separated embedding stack.
max_hops = 3
vec_len = 2**12
mfper = MorganFingerprinter(radius=max_hops, length=vec_len, allocate_ao=True)
rc_dist_ub = None
n_samples = len(decarb)

full_embeds = []
subgraph_cts = defaultdict(lambda: defaultdict(int))  # {bit_idx: {(subgraph_smiles, radius): count}}
bit_examples = defaultdict(list)  # {bit_idx: [dict(mol, rid, central_aidx, radius), ...]}
row2rid = list(decarb.keys())  # row i of full_embeds <-> reaction id row2rid[i]
for rid, rxn in decarb.items():
    rc = rxn['reaction_center'][0]
    smiles = rxn['smarts'].split('>>')[0]  # reactant side only
    mol = Chem.MolFromSmiles(smiles)

    full_embeds.append(mfper.fingerprint(mol, reaction_center=rc, rc_dist_ub=rc_dist_ub))
    # Read right after fingerprint(); presumably describes that call only — TODO confirm.
    bim = mfper.bit_info_map

    for bit_idx, examples in bim.items():
        for central_aidx, radius in examples:
            bit_examples[bit_idx].append(
                {
                    'mol': mol,
                    'rid': rid,
                    'central_aidx': central_aidx,
                    'radius': radius
                }
            )

            _, _, sub_smi = extract_subgraph(mol, central_aidx, radius)
            subgraph_cts[bit_idx][(sub_smi, radius)] += 1

full_embeds = np.vstack(full_embeds)

# Assign each bit to the radius of its most frequent (subgraph, radius) key,
# counted over every occurrence in every reaction. Ties go to the key seen first.
r2bits = defaultdict(list)  # {radius: [bit idxs]}
for bit_idx, occurrences in subgraph_cts.items():
    (_, modal_radius), _ = max(occurrences.items(), key=lambda item: item[1])
    r2bits[modal_radius].append(bit_idx)

# TODO(review) — original note: "This is wrong?". A bit that occurs at several
# radii is written only into its modal radius's slice, so its occurrences at
# other radii are dropped from embed_stack.
embed_stack = np.zeros(shape=(n_samples, vec_len, max_hops + 1))
for r, bits in r2bits.items():
    embed_stack[:, bits, r] = full_embeds[:, bits]

# Per (bit, radius): P(bit = 1) over reactions and its Bernoulli entropy in bits.
# np.stack keeps p1 two-dimensional (vec_len, max_hops + 1) even when an axis has
# length 1, unlike squeeze().
p1 = embed_stack.sum(axis=0) / n_samples
probas = np.stack((p1, 1 - p1))  # (2, vec_len, max_hops + 1)
H = entropy(pk=probas, axis=0, base=2)

# Keep only the examples drawn at each bit's assigned radius.
bit_examples = {
    idx: [eg for eg in bit_examples[idx] if eg['radius'] == r]
    for r, idxs in r2bits.items()
    for idx in idxs
}

resolved_embeds = embed_stack.sum(axis=-1)

Scale-separated local ECFP embeddings

In [None]:
# Browse the highest-ranked bits at a single scale (khop), drawing an example
# molecule with each bit's substructure highlighted.
# rng = np.random.default_rng(seed=1234) # TODO seed=1234, resampling same integer for first hash / substruct
n_egs = 1  # examples drawn per bit
topk = 100  # number of bits shown
khop = 1 # Scale desired - how many hops from central atom
sort_by = p1  # ranking criterion, e.g. p1 (frequency) or H (entropy)

srt_idx = np.argsort(sort_by[:, khop])[::-1]
# Drop bits that never occur at this scale. They carry no information here, and
# bits that were never set have no entry in bit_examples at all.
srt_idx = srt_idx[p1[srt_idx, khop] > 0]

for idx in srt_idx[:topk]:
    egs = bit_examples.get(idx, [])[:n_egs]
    print(f"Bit idx: {idx}")
    print(f"Entropy = {H[idx, khop]} bits")
    print(f"Probability: {p1[idx, khop]:.2f}")
    for eg in egs:
        mol = eg['mol']
        aidx = eg['central_aidx']
        r = eg['radius']

        sub_idxs, _, _ = extract_subgraph(mol, aidx, r)

        display(SVG(draw_molecule(mol, size=(300, 300), hilite_atoms=tuple(sub_idxs))))

    print('-' * 50)
    

In [None]:
# Collapse the radius axis and keep only informative features: present in some,
# but not all, reactions.
ft_proba_mass = resolved_embeds.sum(axis=0) / n_samples
nonzero_features = np.where((ft_proba_mass > 0) & (ft_proba_mass < 1))[0]
nonzero_embeds = resolved_embeds[:, nonzero_features]

# Per-feature P(feature = 1) and Bernoulli entropy (bits) over all reactions.
resolved_p1 = nonzero_embeds.sum(axis=0) / n_samples
resolved_probas = np.vstack((resolved_p1, (1 - resolved_p1)))
resolved_H = entropy(pk=resolved_probas, axis=0, base=2)

n_nonzero = nonzero_embeds.shape[1]
fig, ax = plt.subplots()
ax.plot(np.arange(1, n_nonzero + 1), np.sort(resolved_p1)[::-1])
ax.set_title("Informative feature frequencies, ranked")
ax.set_ylabel("P(feature = 1)")
ax.set_xlabel("Feature rank")
ax.hlines(1 / n_samples, xmin=1, xmax=n_nonzero, color='black', linestyles='--', label="1 / # samples")
ax.legend()
plt.show()

In [None]:
# Fraction of informative features present in at most `scl` reactions,
# i.e. with P(feature = 1) <= scl / n_samples. Returned as a scalar.
scl = 5
(resolved_p1 <= scl / n_samples).mean()

Cluster structural features

In [8]:
def prune_embeds(embeds: np.ndarray, asked: list, not_asked: list, answered: list):
    """Restrict ``embeds`` to rows consistent with the answers and to unasked columns.

    Args:
        embeds: 2-D array (rows = reactions, columns = features).
        asked: column indices already asked.
        not_asked: column indices to keep in the result.
        answered: answer for each entry of ``asked``; a row is kept when
            ``embeds[row, q] == a`` for every ``(q, a)`` pair.

    Returns:
        The sub-array of matching rows and ``not_asked`` columns. With no
        answered questions every row is kept.
    """
    row_conditions = [embeds[:, q] == a for q, a in zip(asked, answered)]
    if not row_conditions:
        return embeds[:, not_asked]

    keep_rows = np.all(np.vstack(row_conditions), axis=0)
    return embeds[keep_rows][:, not_asked]

def find_feature_clusters(embeds: np.ndarray, scl_lb: int = 1, leaves: "list | None" = None):
    """Greedily split reactions into clusters by asking about their most common features.

    Starting from all rows, the most frequent remaining feature is asked together
    with every remaining feature that is perfectly redundant with it (Jaccard
    similarity 1 over the remaining rows). The rows are split on the answer
    (0 then 1) and each branch is recursed on. A branch stops and is recorded as
    a leaf when:

    - fewer than two rows remain,
    - no features remain, or
    - the most frequent remaining feature is present in at most ``scl_lb`` rows.

    Recursion depth grows by one per question group asked.

    Args:
        embeds: 2-D array compared against answers 0/1 (rows = reactions,
            columns = features).
        scl_lb: minimum support. A feature is asked only if it is present in more
            than ``scl_lb`` of the remaining rows.
        leaves: optional list that each recursive branch result is appended to
            (it may also receive ``None`` entries). A fresh list is used when
            omitted, so separate calls never share state.

    Returns:
        List of leaves. Each leaf is a list of question groups; each group is a
        tuple of ``(feature_idx, answer)`` pairs with ``feature_idx`` indexing the
        columns of ``embeds``.
    """
    if leaves is None:
        leaves = []  # not a mutable default: that list would persist across calls
    n_features = embeds.shape[1]

    def bts(qna=None):
        qna = [] if qna is None else qna
        if qna:
            asked, answered = zip(*chain.from_iterable(qna))
        else:
            asked, answered = (), ()

        asked_set = set(asked)
        not_asked = np.array([i for i in range(n_features) if i not in asked_set], dtype=int)

        remaining_embeds = prune_embeds(embeds, asked, not_asked, answered)
        n_remaining_rxns, n_remaining_fts = remaining_embeds.shape

        # Stop on too few rows, or when every feature has been asked (argmax of an
        # empty array would raise).
        if n_remaining_rxns < 2 or n_remaining_fts == 0:
            return qna

        remaining_p1 = remaining_embeds.sum(axis=0) / n_remaining_rxns
        pivot = np.argmax(remaining_p1)

        if remaining_p1[pivot] <= (scl_lb / n_remaining_rxns):
            return qna

        # Jaccard similarity of every remaining feature with the pivot feature.
        pivot_col = remaining_embeds[:, pivot]
        dots = remaining_embeds.T @ pivot_col  # (n_remaining_fts,)
        jaccards = dots / (pivot_col.sum() + remaining_embeds.sum(axis=0) - dots)
        redundant = np.where(jaccards == 1)[0]  # completely redundant features (incl. the pivot)
        next_question = [int(q) for q in not_asked[redundant]]  # back to full feature indices

        for ans in range(2):
            next_qna = tuple((q, ans) for q in next_question)
            leaves.append(bts(qna=qna + [next_qna]))

    bts()
    return [leaf for leaf in leaves if leaf is not None]

In [None]:
# Toy check: feature 0 is the most frequent and no other feature is perfectly
# redundant with it. The search therefore asks only about feature 0 and stops in
# both branches: one row is left for answer 0, and no feature clears the support
# threshold for answer 1.
test_embeds = np.array(
    [
        [1, 1, 1],
        [1, 0, 0],
        [0, 0, 0]
    ]
)

test_leaves = []
test_result = find_feature_clusters(test_embeds, leaves=test_leaves)
assert test_result == [[((0, 0),)], [((0, 1),)]], test_result
test_result

In [10]:
# Cluster reactions on the informative resolved features. A feature must be
# present in more than 5 of the remaining reactions to be asked.
nonzero_leaves = find_feature_clusters(nonzero_embeds, scl_lb=5, leaves=[])

In [None]:
sum(1 for leaf in nonzero_leaves if leaf is not None)  # number of clusters found

In [None]:
# Depth (number of question groups) and size (number of reactions) of each leaf.
n_levels = [len(leaf) for leaf in nonzero_leaves]
n_rxns = [len(pick_leaves(nonzero_embeds, leaf)) for leaf in nonzero_leaves]

print(n_levels)
print(n_rxns)

In [None]:
# Inspect one leaf: draw an example for each asked feature group, then draw
# every reaction that falls in the cluster.
i = 5  # which leaf to inspect
n_egs = 1  # examples drawn per asked feature
leaves = nonzero_leaves

leaf = leaves[i]

print([group[0][1] for group in leaf])  # the answer given at each level
for redundant_group in leaf:
    q, a = redundant_group[0]  # one representative of the redundant group
    idx = nonzero_features[q]  # nonzero-feature index -> full bit index
    egs = bit_examples[idx][:n_egs]
    print(f"Bit idx: {idx}. Present? {a}")
    print(f"Entropy = {resolved_H[q]} bits")
    print(f"Probability: {resolved_p1[q]:.2f}")
    for eg in egs:
        eg_mol = eg['mol']
        sub_idxs, _, _ = extract_subgraph(eg_mol, eg['central_aidx'], eg['radius'])
        display(SVG(draw_molecule(eg_mol, size=(300, 300), hilite_atoms=tuple(sub_idxs))))

    print('-' * 50)

reaction_rows = pick_leaves(nonzero_embeds, leaf)
print(f"# cluster reactions: {len(reaction_rows)}")
for row in reaction_rows:
    rxn = decarb[row2rid[row]]
    print(rxn['imt_rules'])
    print(rxn['rhea_ids'])
    reactants = rxn['smarts'].split('>>')[0]
    display(SVG(draw_molecule(reactants, hilite_atoms=rxn['reaction_center'][0], size=(300, 300))))

In [None]:
# Show every cluster: its depth, its size and its member reactions.
for leaf in nonzero_leaves:
    members = pick_leaves(nonzero_embeds, leaf)
    print(f"# levels: {len(leaf)}")
    print(f"# cluster reactions: {len(members)}")
    for member_row in members:
        member = decarb[row2rid[member_row]]
        print(member['imt_rules'])
        print(member['rhea_ids'])
        reactants = member['smarts'].split('>>')[0]
        display(SVG(draw_molecule(reactants, hilite_atoms=member['reaction_center'][0], size=(300, 300))))
    print("-" * 50)


Correlation & anti-correlation

In [15]:
# Pairwise agreement between informative features of the full (all-radii)
# fingerprints. With features mapped to +/-1, entry (i, j) is the mean product:
# +1 = always equal, -1 = always opposite.
# NOTE(review): this rebinds ft_proba_mass / nonzero_features / nonzero_embeds,
# which earlier cells built from resolved_embeds. Re-running the leaf cells after
# this one would mix the two feature spaces.
embeds = full_embeds  # alternative: embed_stack[:, :, 0]

ft_proba_mass = embeds.sum(axis=0) / n_samples
is_informative = (ft_proba_mass > 0) & (ft_proba_mass < 1)
nonzero_features = np.flatnonzero(is_informative)
nonzero_embeds = embeds[:, nonzero_features]

directed_embeds = 2 * (nonzero_embeds - 0.5)  # {0, 1} -> {-1.0, +1.0}
n_rows = directed_embeds.shape[0]
hamming_corr = (directed_embeds.T @ directed_embeds) / n_rows

triu_idxs = np.triu_indices_from(hamming_corr, k=1)  # strictly-upper pairs i < j
hamming_corr_upper = hamming_corr[triu_idxs]
feature_pairs = list(zip(triu_idxs[0], triu_idxs[1]))


In [16]:
# For every feature pair i < j: the fraction of reactions in which 0, 1 or 2 of
# the two features are present, i.e. interaction_triple[i, j, k] = P(x_i + x_j = k).
# The diagonal and lower triangle stay zero.
# Computed from exact integer co-occurrence counts (one matrix product) instead
# of a Python double loop over all pairs. Assumes 0/1 entries — the sum
# x_i + x_j is used as an index into {0, 1, 2}.
binary_embeds = nonzero_embeds.astype(np.int64)  # exact counts; no small-int/bool matmul overflow
n_rows, n_feats = binary_embeds.shape
both_on = binary_embeds.T @ binary_embeds  # rows with x_i = x_j = 1
on_totals = binary_embeds.sum(axis=0)
pair_totals = on_totals[:, np.newaxis] + on_totals[np.newaxis, :]  # count(x_i=1) + count(x_j=1)

interaction_triple = np.zeros(shape=(n_feats, n_feats, 3))
interaction_triple[:, :, 0] = n_rows - pair_totals + both_on  # x_i + x_j = 0
interaction_triple[:, :, 1] = pair_totals - 2 * both_on  # x_i + x_j = 1
interaction_triple[:, :, 2] = both_on  # x_i + x_j = 2
interaction_triple /= n_samples
interaction_triple[np.tril_indices(n_feats)] = 0  # keep only pairs i < j