In [None]:
from flash_ansr import FlashANSRDataset, get_path, SkeletonPool
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from collections import defaultdict

In [2]:
# Model whose `dataset_train.yaml` config drives this notebook.
# Previously also run with 'v22.4-60M'; swap the ID below to compare.
MODEL = 'v23.0-120M'

In [3]:
# Build the training dataset for MODEL from its config file; its skeleton pool
# is what the holdout measurement below probes.
dataset = FlashANSRDataset.from_config(
    get_path('configs', MODEL, 'dataset_train.yaml')
)

Compiling Skeletons: 100%|██████████| 107/107 [00:00<00:00, 20366.24it/s]
Compiling Skeletons: 100%|██████████| 1000/1000 [00:00<00:00, 15661.73it/s]
Compiling Skeletons: 100%|██████████| 107/107 [00:00<00:00, 19027.84it/s]


In [4]:
# Sample budget (2**10 = 1,024).
# NOTE(review): not referenced by any later cell shown here — TODO confirm it is still needed.
N_SAMPLES = 1 << 10
print(format(N_SAMPLES, ','))

1,024


In [None]:
from types import MethodType
from flash_ansr.expressions.skeleton_pool import NoValidSampleFoundError

def measure_holdout_rejection_rate(pool: SkeletonPool, n_trials: int = 2000):
    """Estimate how often the pool's holdout check rejects freshly sampled skeletons.

    Temporarily wraps ``pool.is_held_out`` with a counting shim, then calls
    ``pool.sample_skeleton(new=True, decontaminate=True)`` ``n_trials`` times and
    tallies every holdout check made during sampling.

    Parameters
    ----------
    pool : SkeletonPool
        Pool to probe. Its ``is_held_out`` attribute is patched for the duration
        of the call. The prior state is restored afterwards, even if sampling
        raises.
    n_trials : int
        Number of ``sample_skeleton`` calls to make.

    Returns
    -------
    counts : collections.defaultdict[str, int]
        ``holdout_checks``: total calls to ``is_held_out``.
        ``held_out``: calls that returned a truthy value.
        ``accepted``: trials where ``sample_skeleton`` returned normally.
        ``failures``: trials that raised ``NoValidSampleFoundError``.
        Keys never incremented read as 0.
    rejection_rate : float
        ``held_out / holdout_checks``. This is a rate per *check*, not per
        trial, because one trial may check several candidates. It is ``nan``
        if no check was made.
    """
    counts = defaultdict(int)
    # Record whether is_held_out already lived on the instance, so the finally
    # block restores the exact prior state. Otherwise it would leave a
    # bound-method copy shadowing the class method.
    had_instance_attr = "is_held_out" in getattr(pool, "__dict__", {})
    original_is_held_out = pool.is_held_out

    def counting_is_held_out(self, *args, **kwargs):
        # Forward arguments verbatim so the shim does not depend on the exact
        # signature sample_skeleton uses when calling is_held_out.
        result = original_is_held_out(*args, **kwargs)
        counts["holdout_checks"] += 1
        counts["held_out"] += int(bool(result))
        return result

    pool.is_held_out = MethodType(counting_is_held_out, pool)
    try:
        for _ in tqdm(range(n_trials)):
            try:
                pool.sample_skeleton(new=True, decontaminate=True)
                counts["accepted"] += 1
            except NoValidSampleFoundError:
                counts["failures"] += 1
    finally:
        if had_instance_attr:
            pool.is_held_out = original_is_held_out
        else:
            # Remove the shim so attribute lookup falls back to the class method.
            del pool.is_held_out

    total_checked = counts["holdout_checks"]
    rejection_rate = counts["held_out"] / total_checked if total_checked else float("nan")
    return counts, rejection_rate

In [7]:
# Probe the training pool's holdout filter over N_TRIALS fresh samples.
N_TRIALS = 10_000
counts, rejection_rate = measure_holdout_rejection_rate(
    dataset.skeleton_pool,
    n_trials=N_TRIALS,
)

# Summary: per-check rejection rate, then per-trial outcomes.
print(
    f"Holdout rejections: {counts['held_out']}/{counts['holdout_checks']} ({rejection_rate:.2%})\n"
    f"Accepted skeletons: {counts['accepted']}\n"
    f"Failed attempts (max retries hit): {counts['failures']}"
)

100%|██████████| 10000/10000 [00:22<00:00, 452.56it/s]

Holdout rejections: 305/10169 (3.00%)
Accepted skeletons: 9864
Failed attempts (max retries hit): 136



