Leave-one-out EC classification.

Get:
- n_rxns x n_rxns reaction similarity matrix
- reaction index to EC number lookup table

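For each reaction (leave-one-out), get the top-k most similar other reactions under each similarity metric.
Compare the neighbours' EC numbers with the held-out reaction's EC numbers and score the match.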

In [13]:
from hydra import initialize, compose
import numpy as np
import polars as pl
from pathlib import Path
import json
from src.utils import construct_sparse_adj_mat

In [8]:
# Compose the "base" filepath config from ../configs/filepaths; `fps` supplies the
# data, results and artifacts directories used throughout this notebook.
with initialize(config_path="../configs/filepaths", version_base=None):
    fps = compose(config_name='base')

In [20]:
# Tab-separated sprhea v3_folded_pt_ns table.
toc = pl.read_csv(Path(fps.data) / "sprhea" / "v3_folded_pt_ns.csv", separator='\t')

# Sparse adjacency matrix built from the same file, plus its two index lookups.
adj, idx_sample, idx_feature = construct_sparse_adj_mat(Path(fps.data) / "sprhea" / "v3_folded_pt_ns.csv")

# Invert the feature lookup: integer-cast value -> key.
# NOTE(review): presumably this is reaction id -> matrix column index; confirm
# against construct_sparse_adj_mat.
rid2idx = {int(rid): idx for idx, rid in idx_feature.items()}

# Per-reaction details, re-keyed by integer reaction id (JSON object keys are strings).
with open(Path(fps.data) / "sprhea" / "v3_folded_pt_ns.json", 'r') as f:
    rxn_details = {int(rid): details for rid, details in json.load(f).items()}

# Map all stable rhea ids at 250915 to single unique idx for convenience:
# the master, left-to-right, right-to-left and bidirectional ids on one row of
# rhea-directions.tsv all share that row's position as their working index.
rhea_directions = pl.read_csv(Path(fps.data) / 'rhea-directions.tsv', separator='\t')
any_rhea_to_working_idx = {}
for i, row in enumerate(rhea_directions.iter_rows(named=True)):
    for col in ('RHEA_ID_MASTER', 'RHEA_ID_LR', 'RHEA_ID_RL', 'RHEA_ID_BI'):
        any_rhea_to_working_idx[row[col]] = i

# Working index -> EC string. Rows whose RHEA_ID is not in rhea-directions.tsv are dropped.
# NOTE(review): the name `rhea2ec` is first a DataFrame and then a dict; and if several
# rows resolve to the same working index, only the last row's EC is kept.
rhea2ec = pl.read_csv(Path(fps.data) / 'rhea2ec.tsv', separator='\t')
rhea2ec = {
    any_rhea_to_working_idx[row['RHEA_ID']]: row['ID']
    for row in rhea2ec.iter_rows(named=True)
    if row['RHEA_ID'] in any_rhea_to_working_idx
}

Constructing v3_folded_pt_ns sparse adjacency matrix


In [49]:
# Primary EC assignment: for each reaction, take the EC of the first of its Rhea ids
# that resolves to a working index with an entry in `rhea2ec`.
# Reactions for which no Rhea id yields an EC are collected in `_issues`.
idx2ec = {}
_issues = []
for rid, elt in rxn_details.items():
    # Reset for every reaction. Without this, a reaction with no Rhea ids (or none
    # that resolve to a working index) would silently inherit the EC found for the
    # previous reaction, or raise NameError if it is the first one.
    ec = None
    for rhea_id in elt['rhea_ids']:
        working_idx = any_rhea_to_working_idx.get(rhea_id)
        if working_idx is None:
            continue

        ec = rhea2ec.get(working_idx)
        if ec is not None:
            break

    if ec is None:
        _issues.append(rid)
    else:
        # Stored as a one-element set of component tuples, e.g. {('1', '1', '1', '1')}.
        idx2ec[rid2idx[rid]] = {tuple(ec.split('.'))}

print(len(idx2ec))

3190


In [54]:
# Fallback for reactions left in `_issues`: collect every EC number annotated on the
# reaction's enzymes. An enzyme's 'ec' field may hold several ';'-separated ECs or be
# None; a reaction whose enzymes carry no EC at all ends up with an empty set.
for rid in _issues:
    enzyme_ecs = set()
    for enzyme in rxn_details[rid]['enzymes']:
        annotation = enzyme['ec']
        if annotation is not None:
            enzyme_ecs.update(
                tuple(part.strip().split('.')) for part in annotation.split(';')
            )

    idx2ec[rid2idx[rid]] = enzyme_ecs

# Both counts should agree once every reaction has an entry.
print(len(idx2ec))
print(len(rxn_details))

6460
6460


In [55]:
def get_top_knn(sim_mat: np.ndarray, top_k: int) -> dict[int, list[int]]:
    """Return, for every row of a similarity matrix, its most similar other rows.

    Args:
        sim_mat: Square similarity matrix; entry (i, j) is the similarity of item i
            to item j. Larger values mean more similar.
        top_k: Maximum number of neighbours to return per item.

    Returns:
        Dict mapping each row index i to a list of at most `top_k` column indices,
        ordered by decreasing similarity and never containing i itself. The list is
        shorter than `top_k` only when the matrix has fewer than `top_k + 1` columns.
    """
    knn_dict = {}
    for i in range(sim_mat.shape[0]):
        order = np.argsort(-sim_mat[i, :])
        # Remove the query itself *before* truncating. Truncating to top_k + 1 first
        # returns one neighbour too many whenever ties (e.g. other items with the same
        # similarity as the diagonal) push the query out of that window.
        knn_dict[i] = order[order != i][:top_k].tolist()
    return knn_dict

# Precomputed similarity matrices, one per metric (RCMCS and DRFP).
# NOTE(review): assumed to be square n_rxns x n_rxns arrays whose row/column order
# matches the indices stored in `rid2idx` — confirm against the code that wrote them.
S_rcmcs = np.load(Path(fps.results) / 'similarity_matrices' / 'sprhea_v3_folded_pt_ns_rcmcs.npy')
S_drfp = np.load(Path(fps.results) / 'similarity_matrices' / 'sprhea_v3_folded_pt_ns_drfp.npy')

In [56]:
# Neighbourhood size used when precomputing the neighbour lists for both metrics;
# smaller cut-offs are obtained later by slicing these lists.
k = 10
knn_rcmcs, knn_drfp = (get_top_knn(sim, top_k=k) for sim in (S_rcmcs, S_drfp))

In [59]:
# Leave-one-out sanity check: no reaction may appear in its own neighbour list,
# for either similarity metric.
for knn in (knn_rcmcs, knn_drfp):
    for query_idx, nbrs in knn.items():
        assert query_idx not in nbrs

In [57]:
def compare_ecs(ec_set_1: set[tuple[str, str, str, str]], ec_set_2: set[tuple[str, str, str, str]], level: int) -> bool:
    """Tell whether two EC sets agree on at least one EC number up to `level`.

    Each EC number is a tuple of its components; only the first `level` components
    of every tuple are compared. Returns False if either set is empty.
    """
    prefixes_1 = {ec[:level] for ec in ec_set_1}
    return any(ec[:level] in prefixes_1 for ec in ec_set_2)

def top_k_accuracy(knn_dict: dict[int, list[int]], idx2ec: dict[int, set[tuple[str, str, str, str]]], level: int) -> float:
    """Fraction of queries with at least one neighbour sharing an EC number at `level`.

    Args:
        knn_dict: Query index -> list of neighbour indices.
        idx2ec: Index -> set of EC numbers, each a tuple of components.
        level: Number of leading EC components that must match.

    Returns:
        correct / total, where `total` counts the queries present in `idx2ec` and
        `correct` counts those with a matching neighbour. Queries missing from
        `idx2ec` are skipped entirely; neighbours missing from `idx2ec` are ignored.
        A query with an empty EC set is still counted and can never match.
        Returns NaN when no query is present in `idx2ec`, instead of raising
        ZeroDivisionError, because the accuracy is undefined in that case.
    """
    correct = 0
    total = 0
    for idx, neighbors in knn_dict.items():
        if idx not in idx2ec:
            continue
        total += 1
        ec_set = idx2ec[idx]
        # any() stops at the first matching neighbour, like the explicit break did.
        if any(
            neighbor in idx2ec and compare_ecs(ec_set, idx2ec[neighbor], level)
            for neighbor in neighbors
        ):
            correct += 1
    if total == 0:
        return float('nan')
    return correct / total

In [61]:
ks = [1, 2, 3, 4]
levels = [1, 2, 3]

# The neighbour lists were precomputed with `k` neighbours per reaction; a larger
# cut-off would be silently truncated and reported under the wrong top-k label.
assert max(ks) <= k, f"requested top-k cut-offs {ks} exceed the precomputed k={k}"

# One record per (EC level, top-k) pair: [level, top_k, rcmcs accuracy, drfp accuracy].
data = []
for level in levels:
    print(f"EC Level {level} Accuracy:")
    # The loop variable is deliberately not called `k`: that would overwrite the
    # notebook-level `k` used to build the neighbour lists.
    for top_k in ks:
        acc_rcmcs = top_k_accuracy({idx: neighbors[:top_k] for idx, neighbors in knn_rcmcs.items()}, idx2ec, level)
        acc_drfp = top_k_accuracy({idx: neighbors[:top_k] for idx, neighbors in knn_drfp.items()}, idx2ec, level)
        print(f"  Top-{top_k} RCMCS: {acc_rcmcs:.4f}, DRFP: {acc_drfp:.4f}")
        data.append([level, top_k, acc_rcmcs, acc_drfp])

EC Level 1 Accuracy:
  Top-1 RCMCS: 0.9704, DRFP: 0.8935
  Top-2 RCMCS: 0.9786, DRFP: 0.9272
  Top-3 RCMCS: 0.9796, DRFP: 0.9467
  Top-4 RCMCS: 0.9842, DRFP: 0.9560
EC Level 2 Accuracy:
  Top-1 RCMCS: 0.9489, DRFP: 0.8068
  Top-2 RCMCS: 0.9590, DRFP: 0.8506
  Top-3 RCMCS: 0.9613, DRFP: 0.8740
  Top-4 RCMCS: 0.9642, DRFP: 0.8867
EC Level 3 Accuracy:
  Top-1 RCMCS: 0.9237, DRFP: 0.7658
  Top-2 RCMCS: 0.9420, DRFP: 0.8119
  Top-3 RCMCS: 0.9457, DRFP: 0.8367
  Top-4 RCMCS: 0.9485, DRFP: 0.8515


In [68]:
# Column names and dtypes, in the same order as the fields of each record in `data`.
schema = dict(
    ec_level=pl.Int64,
    top_k=pl.Int64,
    rcmcs_accuracy=pl.Float64,
    drfp_accuracy=pl.Float64,
)

# Each element of `data` is one row, hence orient='row'.
results = pl.DataFrame(data, orient='row', schema=schema)
results

ec_level,top_k,rcmcs_accuracy,drfp_accuracy
i64,i64,f64,f64
1,1,0.970433,0.893498
1,2,0.978638,0.927245
1,3,0.979567,0.946749
1,4,0.984211,0.956037
2,1,0.948916,0.806811
…,…,…,…
2,4,0.964241,0.886687
3,1,0.923684,0.765789
3,2,0.94195,0.81192
3,3,0.945666,0.836687


In [69]:
# Persist the accuracy table (one row per EC level / top-k pair) to the artifacts directory.
results.write_csv(Path(fps.artifacts) / 'leave_one_out_ec_classification.csv')

In [72]:
# Serialise the index -> EC lookup. JSON needs string keys and has no set/tuple types,
# so each EC tuple is re-joined with '.' and each set becomes a list.
# The list is sorted because set iteration order for string tuples is not reproducible
# across interpreter runs; sorting keeps the artifact byte-stable between runs.
with open(Path(fps.artifacts) / 'sprhea_folded_pt_ns_idx2ec.json', 'w') as f:
    json.dump(
        {str(idx): sorted('.'.join(ec_tuple) for ec_tuple in ecs) for idx, ecs in idx2ec.items()},
        f,
    )