# Dataset size summaries for frequent-hitters

This notebook computes dataset sizes and filtering statistics used in the paper, based on the PubChem HTS data and the downstream cleaning pipeline.


In [None]:
from pathlib import Path
from typing import Optional, Tuple, Dict

import polars as pl
from rdkit import Chem, RDLogger
from rdkit.Chem import Descriptors
from rdkit.Chem.MolStandardize import rdMolStandardize

# Record the Polars version alongside the outputs so results can be tied to an environment.
print("Polars version:", pl.__version__)

# Suppress RDKit log output below CRITICAL severity (i.e. silence RDKit warnings).
RDLogger.logger().setLevel(RDLogger.CRITICAL)


In [None]:
# Adjust these paths if you run the notebook from a different working directory.
# All paths are resolved relative to the parent of the current working directory.
project_root = Path.cwd().parent

# Inputs produced by the pubchem-bioassay pipeline.
assay_tables_dir = project_root.joinpath("pubchem-bioassay", "data", "assay_tables")
assay_metadata_path = project_root.joinpath("pubchem-bioassay", "outputs", "assay_metadata.csv")
assay_rscores_path = project_root.joinpath("pubchem-bioassay", "outputs", "assay_rscores.parquet")

# Clean-split outputs for the full dataset (produced on the cluster).
# Update these to point at the actual cleaned files.
clean_biochemical_path = project_root.joinpath("cleaned", "biochemical_hits.parquet")
clean_cellular_path = project_root.joinpath("cleaned", "cellular_hits.parquet")

# Local testing alternative: point the clean-split paths at the integration subset instead.
# integration_raw = project_root / "integration_artifacts" / "raw"
# clean_biochemical_path = integration_raw / "biochemical_hits_subset.parquet"
# clean_cellular_path = integration_raw / "cellular_hits_subset.parquet"

# Echo the resolved locations, one "<label> <path>" line each.
print(
    "\n".join(
        f"{label} {location}"
        for label, location in (
            ("Assay tables dir:", assay_tables_dir),
            ("Assay metadata:", assay_metadata_path),
            ("Assay rscores:", assay_rscores_path),
            ("Clean biochemical path:", clean_biochemical_path),
            ("Clean cellular path:", clean_cellular_path),
        )
    )
)


## 1. Assays with ≥10k substances and unique compounds

Counts based on the pre-filtered assay tables in `pubchem-bioassay/data/assay_tables`.


In [None]:
# List assay table parquet files (each file corresponds to one assay with ≥10k substances).
assay_table_files = sorted(assay_tables_dir.glob("aid_*.parquet"))
if not assay_table_files:
    # Fail here with an explicit message rather than deferring to the parquet scan below.
    raise FileNotFoundError(f"No aid_*.parquet assay tables found in {assay_tables_dir}")

assay_table_df = pl.DataFrame(
    {
        # File names follow aid_<assay_id>.parquet, so the second "_" field is the assay ID.
        "assay_id": [int(p.stem.split("_")[1]) for p in assay_table_files],
        "path": [str(p) for p in assay_table_files],
    }
)

num_assays_ge_10k = assay_table_df.height
print(f"Number of assays with ≥10k substances (assay tables): {num_assays_ge_10k}")

# Compute total numbers of substances (SIDs) and compounds (CIDs) across all assays.
# Polars' n_unique() counts null as a distinct value, so IDs are de-nulled before the
# distinct counts; rows with no CID are reported in their own column instead.
assay_tables_lf = pl.scan_parquet(str(assay_tables_dir / "aid_*.parquet"))

assay_level_stats = assay_tables_lf.select(
    pl.len().alias("num_rows"),
    pl.col("PUBCHEM_SID").drop_nulls().n_unique().alias("num_unique_substances"),
    pl.col("PUBCHEM_CID").drop_nulls().n_unique().alias("num_unique_compounds"),
    pl.col("PUBCHEM_CID").null_count().alias("num_rows_without_cid"),
).collect()

assay_level_stats


## 2. Assay and compound counts after column selection

Uses `assay_metadata.csv` to count how many assays (and compound–assay screening pairs) remain eligible versus are marked ineligible (`selected_column == "__INELIGIBLE__"`) after selecting a single readout column per assay.

`compounds_screened` counts the number of compound–assay screening pairs (not unique compounds).


In [None]:
assay_meta = pl.read_csv(assay_metadata_path, infer_schema_length=1000)

# Sentinel stored in `selected_column` for assays marked ineligible by column selection.
INELIGIBLE_MARKER = "__INELIGIBLE__"

# eq_missing() yields False (not null) when selected_column is null, so every assay is
# classified as exactly one of ineligible / eligible. The assay counts (eligible =
# total - ineligible) and the compound sums below therefore use the same split.
# NOTE(review): this counts assays with a null selected_column as eligible; confirm intent.
is_ineligible = pl.col("selected_column").eq_missing(INELIGIBLE_MARKER)
is_eligible = ~is_ineligible

total_assays = assay_meta.height
num_ineligible_assays = assay_meta.filter(is_ineligible).height
num_eligible_assays = total_assays - num_ineligible_assays

num_missing_selection = assay_meta["selected_column"].null_count()
if num_missing_selection:
    print(
        f"NOTE: {num_missing_selection} assays have no selected_column value; "
        "they are counted as eligible below."
    )

assay_selection_summary = pl.DataFrame(
    {
        "total_assays": [total_assays],
        "eligible_assays": [num_eligible_assays],
        "ineligible_assays": [num_ineligible_assays],
    }
)

# Only a cell's final expression is rendered automatically, so print this table explicitly.
print("Assay counts after column selection:")
print(assay_selection_summary)

compound_selection_summary = assay_meta.select(
    pl.len().alias("num_assays"),
    pl.col("compounds_screened").sum().alias("total_compounds_screened"),
    pl.col("compounds_screened").filter(is_eligible).sum().alias("compounds_screened_eligible"),
    pl.col("compounds_screened").filter(is_ineligible).sum().alias("compounds_screened_ineligible"),
)

print("Compound (screening) counts based on assay_metadata:")
compound_selection_summary


## 3. Assay format breakdown (biochemical vs cellular vs other)

Counts of assays by `assay_format` from `assay_metadata.csv`, before and after applying the column-selection eligibility filter.


In [None]:
# Single pass over assay_meta: total / eligible / ineligible counts per format side by side.
# eq_missing() treats a null selected_column as "not ineligible", so such assays are counted
# as eligible instead of disappearing from both the eligible and ineligible breakdowns.
is_ineligible_assay = pl.col("selected_column").eq_missing("__INELIGIBLE__")

assay_format_counts = (
    assay_meta
    .group_by("assay_format")
    .agg(
        pl.len().alias("num_assays"),
        (~is_ineligible_assay).sum().alias("num_assays_eligible"),
        is_ineligible_assay.sum().alias("num_assays_ineligible"),
    )
    .sort("assay_format")
)

# Per-split tables for downstream use; formats with zero assays in a split are omitted.
assay_format_total = assay_format_counts.select("assay_format", "num_assays")
assay_format_eligible = (
    assay_format_counts
    .filter(pl.col("num_assays_eligible") > 0)
    .select("assay_format", "num_assays_eligible")
)
assay_format_ineligible = (
    assay_format_counts
    .filter(pl.col("num_assays_ineligible") > 0)
    .select("assay_format", "num_assays_ineligible")
)

print("Assay counts by format (all assays, eligible only, ineligible only):")
assay_format_counts


## 4. Overall impact of clean-split on HTS data

Summaries before/after cleaning using the HTS hits table (`assay_rscores.parquet`) and the biochemical/cellular outputs produced by the `clean-split` CLI.


In [None]:
def summarise_hits(hits_lf: pl.LazyFrame) -> pl.DataFrame:
    """Return a one-row frame of row, assay and unique-compound counts for a hits table.

    Args:
        hits_lf: Lazy hits table containing ``assay_id`` and ``compound_id`` columns.

    Returns:
        DataFrame with columns ``num_rows``, ``num_assays`` and ``num_unique_compounds``.
        Null IDs are excluded from the distinct counts, because Polars' ``n_unique``
        would otherwise count null as one extra value.
    """
    return hits_lf.select(
        pl.len().alias("num_rows"),
        pl.col("assay_id").drop_nulls().n_unique().alias("num_assays"),
        pl.col("compound_id").drop_nulls().n_unique().alias("num_unique_compounds"),
    ).collect()


# Summary of the raw HTS hits table used as input to clean-split.
hts_lf = pl.scan_parquet(assay_rscores_path)
hts_summary = summarise_hits(hts_lf)

# Only a cell's final top-level expression is rendered, so tables are printed explicitly.
print("Raw HTS hits (assay_rscores.parquet):")
print(hts_summary)

# Summaries for the cleaned biochemical / cellular outputs (if available).
clean_paths = {"biochemical": clean_biochemical_path, "cellular": clean_cellular_path}
found_clean_paths = {split: path for split, path in clean_paths.items() if path.is_file()}
missing_clean_splits = [split for split in clean_paths if split not in found_clean_paths]

if found_clean_paths:
    # Stack the splits with pl.concat (LazyFrame has no `union` method). Only the ID columns
    # are needed; selecting them first lets the splits stack even if their other columns
    # differ, and "vertical_relaxed" tolerates dtype differences in the ID columns.
    clean_lfs = [
        pl.scan_parquet(str(path)).select("assay_id", "compound_id")
        for path in found_clean_paths.values()
    ]
    clean_lf = pl.concat(clean_lfs, how="vertical_relaxed")
    clean_summary = summarise_hits(clean_lf)

    print("Cleaned HTS hits after clean-split (combined biochemical + cellular):")
    print(clean_summary)
    if missing_clean_splits:
        print(
            f"WARNING: clean-split output not found for: {', '.join(missing_clean_splits)}; "
            "the 'after' counts cover only the splits found, so the dropped counts are overstated."
        )

    pre_rows = int(hts_summary["num_rows"][0])
    pre_compounds = int(hts_summary["num_unique_compounds"][0])
    post_rows = int(clean_summary["num_rows"][0])
    post_compounds = int(clean_summary["num_unique_compounds"][0])

    drop_summary = pl.DataFrame(
        {
            "num_rows_before": [pre_rows],
            "num_rows_after": [post_rows],
            "num_rows_dropped": [pre_rows - post_rows],
            "num_unique_compounds_before": [pre_compounds],
            "num_unique_compounds_after": [post_compounds],
            "num_unique_compounds_dropped": [pre_compounds - post_compounds],
        }
    )
    print("Overall rows / unique-compound counts before vs after clean-split:")
    print(drop_summary)
else:
    print(
        "Clean-split outputs not found at the configured paths; "
        "skipping pre/post clean-split summary."
    )


## 5. Breakdown of compounds dropped by clean-split

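This section re-implements (rather than imports) the `ChemicalProcessor` logic from `clean_split.cli`, so that each compound can be assigned a reason for being dropped (invalid SMILES, molecular-weight filter, disallowed elements, etc.). If the CLI logic changes, update this copy as well; otherwise the reason counts will no longer match the clean-split outputs.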

Note: this re-runs RDKit-based standardisation on the unique SMILES and is therefore expensive on the full dataset. Run this on the cluster rather than on a local laptop.


In [None]:
# Constants as in clean_split.cli (this notebook keeps its own copy; keep the two in sync).

# Atomic numbers of the elements a retained compound may contain.
ALLOWED_ATOMS = {
    atomic_number
    for _symbol, atomic_number in (
        ("H", 1),
        ("B", 5),
        ("C", 6),
        ("N", 7),
        ("O", 8),
        ("F", 9),
        ("Si", 14),
        ("P", 15),
        ("S", 16),
        ("Cl", 17),
        ("Se", 34),
        ("Br", 35),
        ("I", 53),
    )
}

# Inclusive exact-molecular-weight window applied by clean_smiles_with_reason.
MIN_MW, MAX_MW = 180.0, 900.0

# RDKit standardisation helpers, instantiated once and shared by every cleaning call.
UNCHARGER = rdMolStandardize.Uncharger()
FRAGMENT_CHOOSER = rdMolStandardize.LargestFragmentChooser()
TAUT_ENUMERATOR = rdMolStandardize.TautomerEnumerator()

def clean_smiles_with_reason(smi: Optional[str]) -> Tuple[Optional[str], str]:
    """Standardise one SMILES string and report why it was kept or dropped.

    Intended to follow the same steps as ``clean_split.ChemicalProcessor``
    (keep the two in sync): parse -> uncharge -> largest fragment -> uncharge
    -> canonical tautomer -> exact-MW window [MIN_MW, MAX_MW] -> element
    whitelist (ALLOWED_ATOMS) -> non-isomeric canonical SMILES -> re-parse check.

    Args:
        smi: Input SMILES string, or None.

    Returns:
        ``(canonical_smiles, "retained")`` if every step succeeds; otherwise
        ``(None, reason)``, where ``reason`` labels the first step that failed.
    """
    if smi is None:
        return None, "missing_smiles"

    try:
        parsed = Chem.MolFromSmiles(smi)
    except Exception:
        return None, "invalid_smiles_parse_error"
    if parsed is None:
        return None, "invalid_smiles_none"

    # Standardisation steps, applied in this exact order (uncharge before and after
    # picking the largest fragment, then tautomer canonicalisation).
    try:
        standardised = parsed
        for step in (
            UNCHARGER.uncharge,
            FRAGMENT_CHOOSER.choose,
            UNCHARGER.uncharge,
            TAUT_ENUMERATOR.Canonicalize,
        ):
            standardised = step(standardised)
    except Exception:
        return None, "standardisation_error"
    if standardised is None:
        return None, "standardisation_error"

    try:
        exact_mw = Descriptors.ExactMolWt(standardised)
    except Exception:
        return None, "mw_compute_error"
    # Negated range check, so a NaN weight is rejected as well.
    if not (MIN_MW <= exact_mw <= MAX_MW):
        return None, "molecular_weight_filter"

    if any(atom.GetAtomicNum() not in ALLOWED_ATOMS for atom in standardised.GetAtoms()):
        return None, "disallowed_element"

    try:
        canonical = Chem.MolToSmiles(standardised, canonical=True, isomericSmiles=False)
    except Exception:
        return None, "canonicalisation_error"

    # Round-trip check: the emitted SMILES must itself parse back into a molecule.
    try:
        reparsed = Chem.MolFromSmiles(canonical)
    except Exception:
        reparsed = None
    if reparsed is None:
        return None, "canonicalisation_invalid_smiles"

    return canonical, "retained"


In [None]:
# Build the compound_id -> SMILES mapping exactly as in clean-split.
id_to_smiles = (
    pl.scan_parquet(assay_rscores_path)
    .select("compound_id", "smiles")
    .unique()
    .collect()
)

original_unique_compounds = id_to_smiles.height
print(f"Original unique compounds (before clean-split): {original_unique_compounds:,}")

# The height above counts distinct (compound_id, SMILES) pairs. It equals the number of
# unique compounds only when each compound_id has a single SMILES, so flag any exceptions.
num_ids_with_multiple_smiles = (
    id_to_smiles.filter(pl.col("compound_id").is_duplicated())["compound_id"].n_unique()
)
if num_ids_with_multiple_smiles:
    print(
        f"WARNING: {num_ids_with_multiple_smiles:,} compound_ids map to more than one SMILES; "
        "counts in this section are per (compound_id, SMILES) pair."
    )

# Run the RDKit-based cleaning once per unique SMILES string (the expensive step).
unique_smiles = id_to_smiles["smiles"].unique().to_list()
print(f"Number of unique SMILES to process: {len(unique_smiles):,}")

canonical_map: Dict[str, Tuple[Optional[str], str]] = {
    smi: clean_smiles_with_reason(smi) for smi in unique_smiles
}

# Attach canonical SMILES and reason to the per-compound table with a join rather than
# map_elements. map_elements skips null inputs by default, which would leave compounds
# without a SMILES with a null reason instead of "missing_smiles".
smiles_lookup = pl.DataFrame(
    {
        "smiles": list(canonical_map),
        "canonical_smiles": [canonical for canonical, _ in canonical_map.values()],
        "clean_reason": [reason for _, reason in canonical_map.values()],
    },
    schema={
        "smiles": id_to_smiles.schema["smiles"],
        "canonical_smiles": pl.Utf8,
        "clean_reason": pl.Utf8,
    },
)

id_to_smiles_with_reasons = (
    id_to_smiles
    .join(smiles_lookup, on="smiles", how="left")
    # Null join keys do not match by default, so label missing SMILES explicitly.
    .with_columns(
        pl.when(pl.col("smiles").is_null())
        .then(pl.lit("missing_smiles"))
        .otherwise(pl.col("clean_reason"))
        .alias("clean_reason")
    )
)
# The lookup has one row per SMILES, so the left join must not add or drop rows.
assert id_to_smiles_with_reasons.height == original_unique_compounds

retained_unique_compounds = id_to_smiles_with_reasons.filter(
    pl.col("canonical_smiles").is_not_null()
).height
dropped_unique_compounds = original_unique_compounds - retained_unique_compounds

print(f"Retained unique compounds after clean-split: {retained_unique_compounds:,}")
print(f"Dropped unique compounds after clean-split: {dropped_unique_compounds:,}")

# Breakdown of compounds by clean reason (the "retained" row holds the kept compounds).
reason_counts = (
    id_to_smiles_with_reasons
    .group_by("clean_reason")
    .agg(
        pl.len().alias("num_compounds"),
        pl.col("canonical_smiles").is_not_null().sum().alias("num_retained"),
    )
    .sort("num_compounds", descending=True)
)

reason_counts
