# Dataset size summaries for frequent-hitters

This notebook computes dataset sizes and filtering statistics used in the paper based on the current assay-cleaning + dataset pipeline outputs. Paths point at the raw HTS inputs, cleaned parquet files, and the four processed model-ready datasets, and we include a quick split-consistency check.

In [10]:
from pathlib import Path
from typing import Optional, Tuple, Dict
import json

import polars as pl
from rdkit import Chem, RDLogger
from rdkit.Chem import Descriptors
from rdkit.Chem.MolStandardize import rdMolStandardize

# Silence RDKit warnings
RDLogger.logger().setLevel(RDLogger.CRITICAL)

# Show Polars version for reproducibility
print("Polars version:", pl.__version__)

Polars version: 1.35.2


In [17]:

# Adjust these paths if you run the notebook from a different working directory.
project_root = Path.cwd().parent

assay_tables_dir = project_root / "assay-etl" / "data" / "assay_tables"
assay_metadata_path = project_root / "assay-etl" / "outputs" / "assay_metadata.csv"
assay_rscores_path = project_root / "assay-etl" / "outputs" / "assay_rscores.parquet"
assay_cleaning_stats_path = project_root / "data" / "clean" / "assay_cleaning_stats.json"

# Paths to assay-cleaning outputs for the full dataset.
clean_biochemical_path = project_root / "data" / "clean" / "biochemical.parquet"
clean_cellular_path = project_root / "data" / "clean" / "cellular.parquet"

# Final processed dataset locations (one parquet per assay format and model type).
processed_dir = project_root / "data" / "processed50"
processed_paths = {
    "biochemical_regression": processed_dir / "biochemical" / "biochemical_regression.parquet",
    "biochemical_multilabel": processed_dir / "biochemical" / "biochemical_multilabel.parquet",
    "cellular_regression": processed_dir / "cellular" / "cellular_regression.parquet",
    "cellular_multilabel": processed_dir / "cellular" / "cellular_multilabel.parquet",
}
scaffold_assignments_path = processed_dir / "scaffold_assignments.parquet"

print("Assay tables dir:", assay_tables_dir)
print("Assay metadata:", assay_metadata_path)
print("Assay rscores:", assay_rscores_path)
print("Assay-cleaning stats:", assay_cleaning_stats_path)
print("Clean biochemical path:", clean_biochemical_path)
print("Clean cellular path:", clean_cellular_path)
print("Processed outputs:")
for name, path in processed_paths.items():
    print(f"  {name}: {path}")


Assay tables dir: /Users/snappi/frequent-hitters/assay-etl/data/assay_tables
Assay metadata: /Users/snappi/frequent-hitters/assay-etl/outputs/assay_metadata.csv
Assay rscores: /Users/snappi/frequent-hitters/assay-etl/outputs/assay_rscores.parquet
Assay-cleaning stats: /Users/snappi/frequent-hitters/data/clean/assay_cleaning_stats.json
Clean biochemical path: /Users/snappi/frequent-hitters/data/clean/biochemical.parquet
Clean cellular path: /Users/snappi/frequent-hitters/data/clean/cellular.parquet
Processed outputs:
  biochemical_regression: /Users/snappi/frequent-hitters/data/processed50/biochemical/biochemical_regression.parquet
  biochemical_multilabel: /Users/snappi/frequent-hitters/data/processed50/biochemical/biochemical_multilabel.parquet
  cellular_regression: /Users/snappi/frequent-hitters/data/processed50/cellular/cellular_regression.parquet
  cellular_multilabel: /Users/snappi/frequent-hitters/data/processed50/cellular/cellular_multilabel.parquet


## 1. Assays with ≥10k substances and unique compounds

Counts based on the pre-filtered assay tables in `assay-etl/data/assay_tables`.

In [18]:

# List assay table parquet files (each file corresponds to one assay with ≥10k substances).
assay_table_files = sorted(assay_tables_dir.glob("aid_*.parquet"))

assay_table_df = pl.DataFrame(
    {
        "assay_id": [int(p.stem.split("_")[1]) for p in assay_table_files],
        "path": [str(p) for p in assay_table_files],
    }
)

num_assays_ge_10k = assay_table_df.height
print(f"Number of assays with ≥10k substances (assay tables): {num_assays_ge_10k}")

assay_level_stats = pl.DataFrame()

if assay_table_files:
    # Inspect the first file to infer column names.
    first_schema_cols = pl.read_parquet(assay_table_files[0], n_rows=0).columns
    print("Example columns (first assay table):", first_schema_cols)

    cid_candidates = [c for c in first_schema_cols if c.lower().endswith("cid") or c.lower() == "pubchem_cid"]
    sid_candidates = [c for c in first_schema_cols if c.lower().endswith("sid") or c.lower() == "pubchem_sid"]

    cid_col = cid_candidates[0] if cid_candidates else None
    sid_col = sid_candidates[0] if sid_candidates else None

    per_file_lfs = []
    has_sid = False
    for path in assay_table_files:
        file_cols = pl.read_parquet(path, n_rows=0).columns
        lf = pl.scan_parquet(path)
        projections = []
        if cid_col and cid_col in file_cols:
            projections.append(pl.col(cid_col).cast(pl.Utf8).alias("_cid"))
        if sid_col and sid_col in file_cols:
            projections.append(pl.col(sid_col).cast(pl.Utf8).alias("_sid"))
            has_sid = True
        if not projections:
            continue
        per_file_lfs.append(lf.select(projections))

    if per_file_lfs:
        assay_tables_lf = pl.concat(per_file_lfs)

        select_exprs = [pl.len().alias("num_rows")]
        if has_sid:
            select_exprs.append(pl.col("_sid").n_unique().alias("num_unique_substances"))
        if cid_col:
            select_exprs.append(pl.col("_cid").n_unique().alias("num_unique_compounds"))

        assay_level_stats = assay_tables_lf.select(select_exprs).collect(engine="streaming")

assay_level_stats


Number of assays with ≥10k substances (assay tables): 1323
Example columns (first assay table): ['PUBCHEM_CID', 'PUBCHEM_RESULT_TAG', 'PUBCHEM_ACTIVITY_SCORE', 'LogGI50_M', 'LogGI50_u', 'LogGI50_V', 'IndnGI50', 'StddevGI50', 'LogTGI_M', 'LogTGI_u', 'LogTGI_V', 'IndnTGI', 'StddevTGI', 'PUBCHEM_EXT_DATASOURCE_SMILES', 'PUBCHEM_ACTIVITY_OUTCOME', '_LogGI50_mean', '_LogTGI_mean']


num_rows,num_unique_compounds
u32,u32
243047406,2955548


## 2. Assay and compound counts after column selection

Uses `assay_metadata.csv` to quantify how many assays/compounds are retained vs ineligible after selecting a single readout column.

`compounds_screened` counts the number of compound–assay screening pairs (not unique compounds).


In [19]:
assay_meta = pl.read_csv(assay_metadata_path, infer_schema_length=1000)

total_assays = assay_meta.height
num_ineligible_assays = assay_meta.filter(pl.col("selected_column") == "__INELIGIBLE__").height
num_eligible_assays = total_assays - num_ineligible_assays

assay_selection_summary = pl.DataFrame(
    {
        "total_assays": [total_assays],
        "eligible_assays": [num_eligible_assays],
        "ineligible_assays": [num_ineligible_assays],
    }
)

print("Assay counts after column selection:")
assay_selection_summary

compound_selection_summary = assay_meta.select(
    pl.len().alias("num_assays"),
    pl.col("compounds_screened").sum().alias("total_compounds_screened"),
    pl.when(pl.col("selected_column") != "__INELIGIBLE__")
    .then(pl.col("compounds_screened"))
    .otherwise(0)
    .sum()
    .alias("compounds_screened_eligible"),
    pl.when(pl.col("selected_column") == "__INELIGIBLE__")
    .then(pl.col("compounds_screened"))
    .otherwise(0)
    .sum()
    .alias("compounds_screened_ineligible"),
)

print("Compound (screening) counts based on assay_metadata:")
compound_selection_summary

Assay counts after column selection:
Compound (screening) counts based on assay_metadata:


num_assays,total_compounds_screened,compounds_screened_eligible,compounds_screened_ineligible
u32,i64,i64,i64
1323,242970045,229726022,13244023


## 3. Assay format breakdown (biochemical vs cellular vs other)

Counts of assays by `assay_format` from `assay_metadata.csv`, before and after applying the column-selection eligibility filter.


In [20]:
assay_format_total = (
    assay_meta
    .group_by("assay_format")
    .agg(pl.len().alias("num_assays"))
    .sort("assay_format")
)

assay_format_eligible = (
    assay_meta
    .filter(pl.col("selected_column") != "__INELIGIBLE__")
    .group_by("assay_format")
    .agg(pl.len().alias("num_assays_eligible"))
    .sort("assay_format")
)

assay_format_ineligible = (
    assay_meta
    .filter(pl.col("selected_column") == "__INELIGIBLE__")
    .group_by("assay_format")
    .agg(pl.len().alias("num_assays_ineligible"))
    .sort("assay_format")
)

print("Assay counts by format (all assays):")
print(assay_format_total)

print("\nAssay counts by format (eligible only):")
print(assay_format_eligible)

print("\nAssay counts by format (ineligible only):")
print(assay_format_ineligible)

Assay counts by format (all assays):
shape: (4, 2)
┌──────────────┬────────────┐
│ assay_format ┆ num_assays │
│ ---          ┆ ---        │
│ str          ┆ u32        │
╞══════════════╪════════════╡
│ null         ┆ 94         │
│ biochemical  ┆ 471        │
│ cellular     ┆ 389        │
│ other        ┆ 369        │
└──────────────┴────────────┘

Assay counts by format (eligible only):
shape: (3, 2)
┌──────────────┬─────────────────────┐
│ assay_format ┆ num_assays_eligible │
│ ---          ┆ ---                 │
│ str          ┆ u32                 │
╞══════════════╪═════════════════════╡
│ biochemical  ┆ 455                 │
│ cellular     ┆ 361                 │
│ other        ┆ 289                 │
└──────────────┴─────────────────────┘

Assay counts by format (ineligible only):
shape: (4, 2)
┌──────────────┬───────────────────────┐
│ assay_format ┆ num_assays_ineligible │
│ ---          ┆ ---                   │
│ str          ┆ u32                   │
╞══════════════╪══════

## 4. Overall impact of assay-cleaning on HTS data

Summaries before/after cleaning using the full HTS hits table (`assay_rscores.parquet`) and the cleaned biochemical / cellular outputs in `data/clean`. With `assay_cleaning_stats.json` present, we load those stats instead of re-scanning the large parquet files.

In [21]:

# Prefer precomputed stats from the cleaning run to avoid re-scanning 200M-row tables locally.
if assay_cleaning_stats_path.is_file():
    stats = json.loads(assay_cleaning_stats_path.read_text())

    input_summary = pl.DataFrame([
        {"stage": "raw_input", **stats["input"]}
    ])

    clean_summary = pl.DataFrame(
        [{"stage": f"clean_{k}", **v} for k, v in stats["output"].items()]
    )

    drop_reason_rows = [
        {"reason": reason, "count": count}
        for reason, count in stats["structure_cleaning"]["drop_reasons"].items()
    ]
    drop_reason_summary = pl.DataFrame(drop_reason_rows).sort("count", descending=True)

    print("Assay-cleaning stats loaded from", assay_cleaning_stats_path)
    print("Input summary:")
    input_summary
    print("Cleaned output summary:")
    clean_summary
    print("Structure cleaning summary:")
    pl.DataFrame(
        [
            {
                "unique_smiles": stats["structure_cleaning"]["unique_smiles"],
                "valid_canonical_smiles": stats["structure_cleaning"]["valid_canonical_smiles"],
                "dropped": stats["structure_cleaning"]["dropped"],
                "skip_tautomers": stats["structure_cleaning"]["skip_tautomers"],
            }
        ]
    )
    print("Drop reasons (unique SMILES):")
    drop_reason_summary
else:
    print("assay_cleaning_stats.json not found; computing summaries (expensive)...")

    # Summary of the raw HTS hits table used as input to assay-cleaning.
    hts_lf = pl.scan_parquet(assay_rscores_path)

    hts_summary = hts_lf.select(
        pl.len().alias("num_rows"),
        pl.col("assay_id").n_unique().alias("num_assays"),
        pl.col("compound_id").n_unique().alias("num_unique_compounds"),
    ).collect()

    print("Raw HTS hits (assay_rscores.parquet):")
    hts_summary

    # Summaries for the cleaned biochemical / cellular outputs (if available).
    clean_lfs = []
    if clean_biochemical_path.is_file():
        clean_lfs.append(pl.scan_parquet(str(clean_biochemical_path)))
    if clean_cellular_path.is_file():
        clean_lfs.append(pl.scan_parquet(str(clean_cellular_path)))

    if clean_lfs:
        clean_lf = clean_lfs[0]
        for lf in clean_lfs[1:]:
            clean_lf = clean_lf.union(lf)

        clean_summary = clean_lf.select(
            pl.len().alias("num_rows"),
            pl.col("assay_id").n_unique().alias("num_assays"),
            pl.col("compound_id").n_unique().alias("num_unique_compounds"),
        ).collect()

        print("Cleaned HTS hits after assay-cleaning (combined biochemical + cellular):")
        clean_summary

        pre_rows = int(hts_summary["num_rows"][0])
        pre_compounds = int(hts_summary["num_unique_compounds"][0])
        post_rows = int(clean_summary["num_rows"][0])
        post_compounds = int(clean_summary["num_unique_compounds"][0])

        drop_summary = pl.DataFrame(
            {
                "num_rows_before": [pre_rows],
                "num_rows_after": [post_rows],
                "num_rows_dropped": [pre_rows - post_rows],
                "num_unique_compounds_before": [pre_compounds],
                "num_unique_compounds_after": [post_compounds],
                "num_unique_compounds_dropped": [pre_compounds - post_compounds],
            }
        )
        print("Overall rows / unique-compound counts before vs after assay-cleaning:")
        drop_summary
    else:
        print(
            "Clean-split outputs not found at the configured paths; "
            "skipping pre/post assay-cleaning summary."
        )

Assay-cleaning stats loaded from /Users/snappi/frequent-hitters/data/clean/assay_cleaning_stats.json
Input summary:
Cleaned output summary:
Structure cleaning summary:
Drop reasons (unique SMILES):


## 5. Final processed datasets and split consistency

Summaries for the four model-ready parquet files (biochemical vs cellular × regression vs multitask). Each file contains multiple split columns (`split1`, `split2`, …); the test compounds should match between regression and multitask datasets for a given seed.

In [22]:
processed_stats = []
split_rows = []

for dataset_name, path in processed_paths.items():
    lf = pl.scan_parquet(path)
    schema_names = lf.collect_schema().names()
    split_cols = sorted([c for c in schema_names if c.startswith("split")])
    assay_cols = [c for c in schema_names if c.isdigit()]

    base_counts = (
        lf.select(
            pl.len().alias("num_rows"),
            pl.col("compound_id").n_unique().alias("unique_compounds"),
            pl.col("scaffold_smiles").n_unique().alias("unique_scaffolds"),
        )
        .collect()
        .to_dicts()[0]
    )

    processed_stats.append(
        {
            "dataset": dataset_name,
            "path": str(path),
            "num_rows": base_counts["num_rows"],
            "unique_compounds": base_counts["unique_compounds"],
            "unique_scaffolds": base_counts["unique_scaffolds"],
            "num_assay_tasks": len(assay_cols),
            "split_columns": ", ".join(split_cols),
        }
    )

    for split_col in split_cols:
        counts = (
            lf.group_by(pl.col(split_col))
            .agg(
                pl.len().alias("num_rows"),
                pl.col("compound_id").n_unique().alias("unique_compounds"),
                pl.col("scaffold_smiles").n_unique().alias("unique_scaffolds"),
            )
            .collect()
        )
        for row in counts.to_dicts():
            split_rows.append(
                {
                    "dataset": dataset_name,
                    "split_column": split_col,
                    "split": row[split_col],
                    "num_rows": row["num_rows"],
                    "unique_compounds": row["unique_compounds"],
                    "unique_scaffolds": row["unique_scaffolds"],
                }
            )

processed_summary = pl.DataFrame(processed_stats)
split_value_summary = pl.DataFrame(split_rows)

print("Processed dataset overview:")
print(processed_summary)

print()
print("Split counts per dataset/seed:")
print(split_value_summary.sort(["dataset", "split_column", "split"]))

# Check that regression and multitask test sets match for each seed.
split_pairs = {
    "biochemical": ("biochemical_regression", "biochemical_multilabel"),
    "cellular": ("cellular_regression", "cellular_multilabel"),
}

test_checks = []
for assay_format, (reg_key, mt_key) in split_pairs.items():
    reg_path = processed_paths[reg_key]
    mt_path = processed_paths[mt_key]

    reg_schema = pl.scan_parquet(reg_path).collect_schema().names()
    mt_schema = pl.scan_parquet(mt_path).collect_schema().names()

    reg_split_cols = sorted([c for c in reg_schema if c.startswith("split")])
    mt_split_cols = sorted([c for c in mt_schema if c.startswith("split")])

    if reg_split_cols != mt_split_cols:
        raise ValueError(
            f"Split columns mismatch for {assay_format}: {reg_split_cols} vs {mt_split_cols}"
        )

    for split_col in reg_split_cols:
        reg_test = (
            pl.scan_parquet(reg_path)
            .filter(pl.col(split_col) == "test")
            .select("compound_id")
            .unique()
        )
        mt_test = (
            pl.scan_parquet(mt_path)
            .filter(pl.col(split_col) == "test")
            .select("compound_id")
            .unique()
        )

        reg_count = reg_test.select(pl.len()).collect().item()
        mt_count = mt_test.select(pl.len()).collect().item()
        reg_not_mt = (
            reg_test.join(mt_test, on="compound_id", how="anti")
            .select(pl.len())
            .collect()
            .item()
        )
        mt_not_reg = (
            mt_test.join(reg_test, on="compound_id", how="anti")
            .select(pl.len())
            .collect()
            .item()
        )

        test_checks.append(
            {
                "assay_format": assay_format,
                "split_column": split_col,
                "regression_count": reg_count,
                "multitask_count": mt_count,
                "regression_only": reg_not_mt,
                "multitask_only": mt_not_reg,
                "test_sets_match": reg_not_mt == 0 and mt_not_reg == 0 and reg_count == mt_count,
            }
        )

test_split_consistency = pl.DataFrame(test_checks)
print()
print("Test split consistency (regression vs multitask):")
print(test_split_consistency)

Processed dataset overview:
shape: (4, 7)
┌──────────────┬──────────────┬──────────┬──────────────┬──────────────┬─────────────┬─────────────┐
│ dataset      ┆ path         ┆ num_rows ┆ unique_compo ┆ unique_scaff ┆ num_assay_t ┆ split_colum │
│ ---          ┆ ---          ┆ ---      ┆ unds         ┆ olds         ┆ asks        ┆ ns          │
│ str          ┆ str          ┆ i64      ┆ ---          ┆ ---          ┆ ---         ┆ ---         │
│              ┆              ┆          ┆ i64          ┆ i64          ┆ i64         ┆ str         │
╞══════════════╪══════════════╪══════════╪══════════════╪══════════════╪═════════════╪═════════════╡
│ biochemical_ ┆ /Users/snapp ┆ 368371   ┆ 368352       ┆ 103100       ┆ 444         ┆ split1,     │
│ regression   ┆ i/frequent-h ┆          ┆              ┆              ┆             ┆ split2,     │
│              ┆ itters…      ┆          ┆              ┆              ┆             ┆ split3,     │
│              ┆              ┆          ┆       

## 6. Structure-cleaning summary

Use the precomputed `assay_cleaning_stats.json` emitted by the cleaning CLI instead of re-running RDKit in the notebook.

In [23]:

# Load structure-cleaning stats produced during assay-cleaning
if not assay_cleaning_stats_path.is_file():
    raise FileNotFoundError(f"Missing stats file at {assay_cleaning_stats_path}")

stats = json.loads(assay_cleaning_stats_path.read_text())
structure_stats = stats.get("structure_cleaning", {})

structure_summary = pl.DataFrame(
    [
        {
            "unique_smiles": structure_stats.get("unique_smiles"),
            "valid_canonical_smiles": structure_stats.get("valid_canonical_smiles"),
            "dropped": structure_stats.get("dropped"),
            "skip_tautomers": structure_stats.get("skip_tautomers"),
        }
    ]
)

reason_rows = [
    {"reason": reason, "count": count}
    for reason, count in structure_stats.get("drop_reasons", {}).items()
]

reason_summary = (
    pl.DataFrame(reason_rows)
    .sort("count", descending=True)
    .with_columns(
        (pl.col("count") / structure_stats.get("unique_smiles", 1) * 100).alias("fraction_pct")
    )
)

print("Structure cleaning summary (precomputed):")
structure_summary

print("Drop reasons (unique SMILES):")
reason_summary


Structure cleaning summary (precomputed):
Drop reasons (unique SMILES):


reason,count,fraction_pct
str,i64,f64
"""retained""",2549457,98.489011
"""molecular_weight_filter""",36915,1.426077
"""forbidden_atom""",2156,0.083289
"""canonical_validation_failed""",34,0.001313
"""invalid_smiles""",7,0.00027
"""uncharger_error""",1,3.9e-05
