In [63]:
import json
from pathlib import Path
from src.cross_validation import parse_cd_hit_clusters
from rdkit import Chem
from tqdm import tqdm
from collections import Counter
import numpy as np

In [2]:
# Run configuration: where the cluster-assignment JSON files live and which run they belong to.
# NOTE: `dir` shadows the builtin `dir()`; the name is kept because later cells read it.
dir = "/home/stef/hiec/artifacts/clustering"
dataset = 'sprhea'
toc = 'v3_folded_pt_ns'

# Single source of truth pairing each similarity strategy with the cutoffs evaluated for it.
# Two hand-maintained parallel lists can drift out of alignment, and zip() over them would
# silently drop the unmatched tail instead of failing.
cutoffs_by_strat = {
    'rcmcs': (80, 60, 40),
    'esm': (99, 98, 97, 95),
    'gsi': (35, 30, 25),
    'blosum': (80, 60, 40),
}
# Derived parallel views (dict order is insertion order), kept for cells that use zip(strats, cutoffs).
strats = list(cutoffs_by_strat)
cutoffs = list(cutoffs_by_strat.values())

In [3]:
# For every (strategy, cutoff) split, load its JSON file and print how many distinct
# values it contains. The JSON is presumably a mapping from item id to cluster label,
# in which case this is the number of clusters; confirm against the clustering step.
for strat, cuts in zip(strats, cutoffs):
    for c in cuts:
        path = Path(dir).joinpath(f"{dataset}_{toc}_{strat}_{c}.json")
        data = json.loads(path.read_text())
        print(f"{strat} {c}: {len(set(data.values()))}")

rcmcs 80: 2528
rcmcs 60: 1351
rcmcs 40: 758
esm 99: 9382
esm 98: 5768
esm 97: 3310
esm 95: 691
gsi 35: 8332
gsi 30: 2656
gsi 25: 334
blosum 80: 716
blosum 60: 72
blosum 40: 69


In [78]:
def score(x : np.ndarray) -> float:
    """Return the arithmetic mean of ``x``.

    With 0/1 input this is the fraction of ones.

    Args:
        x: Values to average; any non-empty array-like accepted by ``np.mean``.

    Returns:
        The mean along the first axis (a scalar for 1-D input).

    Raises:
        ValueError: If ``x`` is empty.
    """
    if len(x) == 0:
        raise ValueError("score() requires a non-empty input")
    # np.mean is vectorised, whereas builtin sum() iterates element-by-element in Python.
    # axis=0 keeps the old sum(x) / len(x) semantics for multi-dimensional input.
    return np.mean(x, axis=0)

def bootstrap_score(
    metric: callable,
    n_boot: int = 5,
    rng: "np.random.Generator | int | None" = None,
) -> callable:
    """Wrap ``metric`` so it also reports a bootstrap estimate of its spread.

    Args:
        metric: Function called as ``metric(x, *args)`` that returns a scalar.
        n_boot: Number of bootstrap resamples. The default of 5 is kept for
            backward compatibility, but it gives a very noisy spread estimate;
            pass a larger value (e.g. 1000) for numbers you intend to report.
        rng: Source of randomness for resampling. ``None`` uses NumPy's legacy
            global RNG, so ``np.random.seed`` controls it. An int seed or a
            ``np.random.Generator`` gives an isolated, reproducible stream.

    Returns:
        A function ``f(x, *args) -> (point_estimate, spread)``:

        - ``point_estimate`` is ``metric(x, *args)`` on the full ``x``.
        - ``spread`` is the population standard deviation (``np.std``,
          ddof=0) of ``metric`` over ``n_boot`` resamples of ``x`` drawn with
          replacement.

    Raises:
        ValueError: If ``n_boot`` is less than 1.
    """
    if n_boot < 1:
        raise ValueError(f"n_boot must be >= 1, got {n_boot}")
    # Both the np.random module and np.random.Generator expose .choice(a, size, replace).
    # default_rng() returns a passed-in Generator unaltered.
    sampler = np.random if rng is None else np.random.default_rng(rng)

    def f(x : np.ndarray, *args) -> tuple[float, float]:
        x = np.asarray(x)
        n = len(x)
        # NOTE: only x is resampled. Extra *args are passed through unchanged, so any
        # per-sample arrays in args are NOT resampled in step with x.
        scores = [
            metric(x[sampler.choice(n, n, replace=True)], *args)
            for _ in range(n_boot)
        ]
        # The point estimate must see the same *args as the resamples.
        return metric(x, *args), np.std(scores)
    return f

# Sanity check on a toy binary vector. The first element is the fraction of ones, and
# the second is the spread across bootstrap resamples. Seed NumPy's global RNG, which the
# resampling draws from by default, so the displayed spread is reproducible on re-run.
np.random.seed(0)
bs_score = bootstrap_score(score)

x = np.array([1, 0, 1, 1, 0, 1, 0, 1, 1, 0])
bs_score(x)
            

(np.float64(0.6), np.float64(0.10198039027185568))