In [48]:
import json
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from src.utils import construct_sparse_adj_mat
from pathlib import Path
from collections import Counter, defaultdict
from itertools import chain
import numpy as np

In [18]:
# Seeded generator: every random draw in this cell goes through `rng`, so a
# fresh kernel run reproduces the same synthetic labels and predictions.
rng = np.random.default_rng(42)
n = 1000
ps = 0.8
# Ground truth: n zeros followed by n ones.
y = np.concatenate([np.zeros(n), np.ones(n)])


def get_y_pred(n, ps):
    """Return a boolean array of length `n`; each entry is True with probability `ps`."""
    return rng.random(n) < ps


# First half (y == 0) is predicted 1 with probability ps; the second half
# (y == 1) uses the negated draw, so it is predicted 1 with probability 1 - ps.
# NOTE(review): this makes the expected accuracy 1 - ps (0.2 here), not ps —
# confirm that is the intended setup.
y_pred = np.concatenate([get_y_pred(n, ps), ~get_y_pred(n, ps)]).astype(int)

In [19]:
# Bootstrap estimate of the accuracy's sampling variability.
# NOTE(review): 20 resamples is very few for 2.5/97.5 percentile bounds;
# consider raising n_bootstraps if the interval itself matters.
n_bootstraps = 20
get_acc = lambda y, y_pred: (y == y_pred).mean()
# Point estimate on the full sample.
mean_acc = get_acc(y, y_pred)
bootstrap_accs = []
for _ in range(n_bootstraps):
    # Resample indices with replacement from the seeded generator (not the
    # global np.random state) so the interval is reproducible on re-run.
    indices = rng.choice(len(y), len(y), replace=True)
    bootstrap_accs.append(get_acc(y[indices], y_pred[indices]))

# Percentile interval and spread of the resampled accuracies.
ci_lo = np.percentile(bootstrap_accs, 2.5)
ci_hi = np.percentile(bootstrap_accs, 97.5)
sttdev = np.std(bootstrap_accs)
print(f"Accuracy: {mean_acc:.3f} ({ci_lo:.3f}, {ci_hi:.3f})")
print(f"Plus or minus {sttdev:.3f}")

Accuracy: 0.200 (0.183, 0.218)
Plus or minus 0.009


In [20]:
def foo(generator, n):
    """Print the result of ``generator.random(n)``.

    Args:
        generator: Object exposing a ``random(n)`` method.
        n: Value forwarded to ``generator.random``.
    """
    samples = generator.random(n)
    print(samples)

In [31]:
# Call foo twice on the same generator; the generator is shared, so the second
# call continues from where the first one left off.
for _ in range(2):
    foo(rng, 10)

[0.10774095 0.91601185 0.23021399 0.03741256 0.55485247 0.37092228
 0.82978974 0.80825147 0.31713889 0.9528994 ]
[0.29091784 0.51505713 0.25596509 0.93604357 0.16460782 0.04491062
 0.43509706 0.99237556 0.89167727 0.74860802]


In [34]:
# Directory holding the reaction JSON files. Uses a dedicated name rather than
# `dir`, which would shadow the builtin.
sprhea_dir = Path("/home/stef/quest_data/hiec/data/sprhea")

# JSON is read explicitly as UTF-8 instead of relying on the locale default.
with open(sprhea_dir / "v3_folded_pt_ns.json", 'r', encoding="utf-8") as f:
    krs = json.load(f)

with open(sprhea_dir / "v3_folded_pt_ns_arc_unobserved_reactions.json", 'r', encoding="utf-8") as f:
    unobs_rxns = json.load(f)

# Combine both reaction dicts; for a key present in both, the entry from
# unobs_rxns replaces the one from krs.
rxns = dict(krs)
rxns.update(unobs_rxns)

# Group reaction keys by their sorted tuple of 'min_rules'.
clusters = defaultdict(set)
for rxn_id, rxn in rxns.items():
    clusters[tuple(sorted(rxn['min_rules']))].add(rxn_id)


def get_n_combos(n):
    """Number of unordered pairs that can be formed from n items."""
    return n * (n - 1) // 2


# Sum the within-cluster pair counts over all clusters.
n_combos = sum(map(get_n_combos, map(len, clusters.values())))
print(f"Number of reaction pairs from same rule cluster: {n_combos:,}")


Number of reaction pairs from same rule cluster: 556,339,628


In [36]:
dir = Path("/home/stef/quest_data/hiec/scratch/sprhea_v3_folded_pt_ns/rcmcs/3fold")

# Read train_val_0..2.parquet and stack them into a single frame.
fold_frames = [pd.read_parquet(dir / f"train_val_{fold}.parquet") for fold in range(3)]
train_data = pd.concat(fold_frames)

train_data.head()

Unnamed: 0,protein_idx,reaction_idx,pid,rid,protein_embedding,am_smarts,reaction_center,y
0,222,187,Q9KJ20,7694,"[-0.012849729, 0.16863914, -0.07288795, -0.055...",[S+:1]([CH2:2][CH2:4][CH:6]([NH2:9])[C:10](=[O...,"[[[0, 1], [1]], [[14], [0, 1]]]",1
1,223,187,Q83WC3,7694,"[-0.020623509, 0.14122067, -0.12404032, -0.027...",[S+:1]([CH2:2][CH2:4][CH:6]([NH2:9])[C:10](=[O...,"[[[0, 1], [1]], [[14], [0, 1]]]",1
2,224,187,Q9KJ21,7694,"[-0.01765835, 0.14142956, -0.0953328, -0.06397...",[S+:1]([CH2:2][CH2:4][CH:6]([NH2:9])[C:10](=[O...,"[[[0, 1], [1]], [[14], [0, 1]]]",1
3,13363,187,Q9KJ22,7694,"[-0.044674527, 0.17813939, -0.057687543, -0.07...",[S+:1]([CH2:2][CH2:4][CH:6]([NH2:9])[C:10](=[O...,"[[[0, 1], [1]], [[14], [0, 1]]]",1
4,13364,187,Q7U4Z8,7694,"[-0.029929712, 0.16648583, -0.016743593, -0.03...",[S+:1]([CH2:2][CH2:4][CH:6]([NH2:9])[C:10](=[O...,"[[[0, 1], [1]], [[14], [0, 1]]]",1


In [37]:
# Read test.parquet from the same directory as the train/val files.
test_path = dir / "test.parquet"
test_data = pd.read_parquet(test_path)
test_data.head()

Unnamed: 0,protein_idx,reaction_idx,pid,rid,protein_embedding,am_smarts,reaction_center,y
0,16113,5226,Q09KQ6,16449,"[0.038376495, 0.18099314, 0.10217904, 0.086344...",[C:1](=[CH:3][CH:5]=[CH:8][C:9](=[O:10])[OH:11...,"[[[6, 7], [0]], [[1, 0], [0]]]",1
1,21408,5226,Q9CZU6,16449,"[0.046930913, 0.19558673, 0.035590347, -0.0470...",[C:1](=[CH:3][CH:5]=[CH:8][C:9](=[O:10])[OH:11...,"[[[6, 7], [0]], [[1, 0], [0]]]",0
2,7098,5226,Q5XI78,16449,"[-0.012881524, 0.28901568, 0.061466977, -0.024...",[C:1](=[CH:3][CH:5]=[CH:8][C:9](=[O:10])[OH:11...,"[[[6, 7], [0]], [[1, 0], [0]]]",0
3,11061,5226,Q8RXV3,16449,"[-0.19607657, 0.31087485, -0.117286794, -0.082...",[C:1](=[CH:3][CH:5]=[CH:8][C:9](=[O:10])[OH:11...,"[[[6, 7], [0]], [[1, 0], [0]]]",0
4,24151,5226,O30418,16449,"[0.0678145, 0.28062484, -0.051406916, 0.019343...",[C:1](=[CH:3][CH:5]=[CH:8][C:9](=[O:10])[OH:11...,"[[[6, 7], [0]], [[1, 0], [0]]]",0


In [38]:
# Distinct 'rid' values seen in the train and test frames.
train_rxns, test_rxns = (set(frame['rid']) for frame in (train_data, test_data))

In [None]:
# Rule signatures (sorted 'min_rules' tuples) of the krs entries whose key is
# in the train / test reaction-id sets.
train_rules = {
    tuple(sorted(rxn['min_rules']))
    for rxn_id, rxn in krs.items()
    if rxn_id in train_rxns
}
test_rules = {
    tuple(sorted(rxn['min_rules']))
    for rxn_id, rxn in krs.items()
    if rxn_id in test_rxns
}


In [51]:
# Rule signatures exclusive to one split, then the size of the overlap.
train_only_rules = train_rules.difference(test_rules)
test_only_rules = test_rules.difference(train_rules)
print(len(train_only_rules), len(test_only_rules))
print(len(train_rules.intersection(test_rules)))

216 43
132


In [47]:
# Group the unobserved reaction keys by their sorted tuple of 'min_rules'.
unobs_clusters = defaultdict(set)
for rxn_id, rxn in unobs_rxns.items():
    unobs_clusters[tuple(sorted(rxn['min_rules']))].add(rxn_id)

In [49]:
# Union of the unobserved-reaction clusters whose rule signature is exclusive
# to the train (resp. test) split. The membership check avoids inserting empty
# entries into the defaultdict for signatures it does not contain.
train_only_unobs = set(
    chain.from_iterable(
        unobs_clusters[rule] for rule in train_only_rules if rule in unobs_clusters
    )
)
test_only_unobs = set(
    chain.from_iterable(
        unobs_clusters[rule] for rule in test_only_rules if rule in unobs_clusters
    )
)
print(len(train_only_unobs), len(test_only_unobs))

4867 400
