# S2F Performance Analysis

Use this notebook to explore the prediction outputs produced by different S2F runs and compare their behaviour. Each run writes a `prediction.df` file inside `<installation_directory>/output/<alias>/`. Update the configuration cells below with the aliases you want to analyse.

## Imports and configuration

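The cell below reads `s2f.conf` to locate the shared output directory and the filtered GOA file, and checks that `go.obo` is present in the working directory. Adjust `CONFIG_PATH` (and `DEFAULT_OBO_PATH`) if you are running the notebook from a different location.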

In [None]:
import configparser
import csv
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 10)
pd.set_option('display.width', 120)

sns.set_theme(style='whitegrid')

# Shared S2F configuration, resolved relative to the notebook's working directory.
CONFIG_PATH = Path('s2f.conf')
if not CONFIG_PATH.exists():
    raise FileNotFoundError(f'Configuration file not found: {CONFIG_PATH.resolve()}. Update CONFIG_PATH to match your setup.')
config = configparser.ConfigParser()
config.read(CONFIG_PATH)
# Fail with an actionable message instead of a bare KeyError('directories').
if not config.has_option('directories', 'installation_directory'):
    raise KeyError(f"'installation_directory' is missing from the [directories] section of {CONFIG_PATH.resolve()}.")

# Each run writes its files to <installation_directory>/output/<alias>/.
BASE_OUTPUT = Path(config.get('directories', 'installation_directory')).expanduser() / 'output'
if not BASE_OUTPUT.exists():
    raise FileNotFoundError(f'Output directory not found: {BASE_OUTPUT}. Make sure the S2F runs have been executed.')

# GOA file used as evaluation ground truth; falls back to 'filtered_goa' relative to the working directory.
FILTERED_GOA_PATH = Path(config.get('databases', 'filtered_goa', fallback='filtered_goa')).expanduser()
if not FILTERED_GOA_PATH.exists():
    raise FileNotFoundError(f'Filtered GOA file not found: {FILTERED_GOA_PATH}. Set filtered_goa under [databases] in {CONFIG_PATH} or assign FILTERED_GOA_PATH in this cell.')

DEFAULT_OBO_PATH = Path('go.obo').resolve()
if not DEFAULT_OBO_PATH.exists():
    raise FileNotFoundError(f'GO ontology file not found: {DEFAULT_OBO_PATH}. Update DEFAULT_OBO_PATH to point at your go.obo file.')

BASE_OUTPUT

## Discover available runs

This cell lists the aliases that currently have diffusion outputs. If the list is long, slice or filter it as needed.

In [None]:
# Only directories that actually contain a diffusion output count as runs;
# other sub-folders (logs, unfinished runs) would make load_predictions fail later.
AVAILABLE_RUNS = sorted(
    run_dir.name
    for run_dir in BASE_OUTPUT.iterdir()
    if run_dir.is_dir() and (run_dir / 'prediction.df').is_file()
)
print(f"{len(AVAILABLE_RUNS)} run(s) found under {BASE_OUTPUT}")
pd.DataFrame({'alias': AVAILABLE_RUNS})

## Select runs to compare

Edit `RUNS_TO_COMPARE` to focus on the runs you are interested in. By default the cell keeps only aliases that actually exist in `AVAILABLE_RUNS`.

In [None]:
RUNS_TO_COMPARE = [
    'test_223283',
    'test_223283_new',
    # 'test_223283_new_2',
]

# Remove duplicates (keeping the first occurrence) so no run is loaded twice,
# and report aliases without outputs instead of dropping them silently.
_requested_runs = list(dict.fromkeys(RUNS_TO_COMPARE))
_missing_runs = [alias for alias in _requested_runs if alias not in AVAILABLE_RUNS]
if _missing_runs:
    print(f"Ignoring alias(es) not found under {BASE_OUTPUT}: {', '.join(_missing_runs)}")
RUNS_TO_COMPARE = [alias for alias in _requested_runs if alias not in _missing_runs]
if not RUNS_TO_COMPARE:
    raise ValueError('Update RUNS_TO_COMPARE with at least one valid alias.')

RUNS_TO_COMPARE

## Load prediction tables

Helper functions to read the prediction scores and the index files written by each run.

In [None]:
def load_predictions(alias: str, base_path: Path = BASE_OUTPUT) -> pd.DataFrame:
    """Load the diffusion output of a single run as a tidy DataFrame.

    ``<base_path>/<alias>/prediction.df`` is read as a header-less,
    tab-separated file with one ``protein_id``, ``term_id``, ``score`` triple
    per line. An ``alias`` column is added to tag the run.

    Raises
    ------
    FileNotFoundError
        If the run has no ``prediction.df``.
    """
    path = base_path / alias / 'prediction.df'
    if not path.exists():
        raise FileNotFoundError(f'Prediction file not found for {alias}: {path}')
    df = pd.read_csv(
        path,
        sep='\t',  # explicit escape rather than an invisible literal tab character
        header=None,
        names=['protein_id', 'term_id', 'score'],
        # Keep identifiers as strings so they match the string-typed GOA
        # identifiers used during evaluation, even if some look numeric.
        dtype={'protein_id': str, 'term_id': str},
    )
    df['alias'] = alias
    return df


def load_terms(alias: str, base_path: Path = BASE_OUTPUT) -> pd.DataFrame:
    """Return the pickled GO term index (``terms.df``) written by run ``alias``.

    Raises
    ------
    FileNotFoundError
        If ``<base_path>/<alias>/terms.df`` does not exist.
    """
    terms_file = base_path / alias / 'terms.df'
    if terms_file.exists():
        # NOTE: unpickling can execute arbitrary code — only load files produced by your own S2F runs.
        return pd.read_pickle(terms_file)
    raise FileNotFoundError(f'GO term index not found for {alias}: {terms_file}')


def load_proteins(alias: str, base_path: Path = BASE_OUTPUT) -> pd.DataFrame:
    """Return the pickled protein index (``proteins.df``) written by run ``alias``.

    Raises
    ------
    FileNotFoundError
        If ``<base_path>/<alias>/proteins.df`` does not exist.
    """
    proteins_file = base_path / alias / 'proteins.df'
    if proteins_file.exists():
        # NOTE: unpickling can execute arbitrary code — only load files produced by your own S2F runs.
        return pd.read_pickle(proteins_file)
    raise FileNotFoundError(f'Protein index not found for {alias}: {proteins_file}')


In [None]:
# Stack one frame per selected run into a single long table keyed by `alias`.
predictions_df = pd.concat(
    (load_predictions(alias) for alias in RUNS_TO_COMPARE),
    ignore_index=True,
)

n_runs_loaded = predictions_df['alias'].nunique()
print(f"Loaded {len(predictions_df):,} scored annotations spanning {n_runs_loaded} run(s).")
predictions_df.head()

## Summary statistics per run

In [None]:
# Output column -> (source column, aggregation) for pandas named aggregation.
_SUMMARY_AGGREGATIONS = {
    'proteins': ('protein_id', 'nunique'),
    'go_terms': ('term_id', 'nunique'),
    'annotations': ('term_id', 'size'),
    'min_score': ('score', 'min'),
    'median_score': ('score', 'median'),
    'mean_score': ('score', 'mean'),
    'max_score': ('score', 'max'),
}
summary = predictions_df.groupby('alias').agg(**_SUMMARY_AGGREGATIONS).sort_index()
summary

## Score distribution (log scale)

Scores tend to be very small, so plotting their log<sub>10</sub> values highlights differences between runs. Scores of zero are clipped to 10<sup>-15</sup> before taking the logarithm.

In [None]:
MIN_SCORE_FOR_LOG = 1e-15  # floor applied to zero/negative scores before log10

# Build a plotting frame instead of adding a column to `predictions_df`, so the
# shared table stays exactly as loaded no matter how often this cell is re-run.
score_plot_df = predictions_df[['alias', 'score']].assign(
    score_log10=lambda frame: np.log10(frame['score'].clip(lower=MIN_SCORE_FOR_LOG))
)

fig, ax = plt.subplots(figsize=(9, 4))
sns.histplot(
    data=score_plot_df,
    x='score_log10',
    hue='alias',
    bins=60,
    element='step',
    stat='density',
    common_norm=False,  # normalise each run separately so run sizes don't skew the shapes
    alpha=0.35,
    ax=ax,
)
ax.set(xlabel='log10(score)', ylabel='density', title='Score distribution per run (log scale)')
fig.tight_layout()
plt.show()

## Focus on top-N predictions per protein

Restrict to the highest-scoring annotations per protein to study agreement between runs.

In [None]:
TOP_N = 5  # number of highest-scoring annotations kept per (run, protein)

# Within each (alias, protein) pair, order annotations from best to worst score
# so that `head` picks the top-N.
_top_sorted = predictions_df.sort_values(
    by=['alias', 'protein_id', 'score'],
    ascending=[True, True, False],
)
top_predictions = _top_sorted.groupby(['alias', 'protein_id'], as_index=False).head(TOP_N)
top_predictions = top_predictions.reset_index(drop=True)

top_predictions.head()

## Pairwise comparison (first two runs)

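The cell below contrasts the top-N predictions of the first two aliases in `RUNS_TO_COMPARE`. To compare a different pair, change the `BASELINE, VARIANT = ...` assignment at the top of the cell; the overlap cell further down assigns them again. A score of 0 means the annotation is not among that run's top-N predictions.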

In [None]:
if len(RUNS_TO_COMPARE) >= 2:
    BASELINE, VARIANT = RUNS_TO_COMPARE[0], RUNS_TO_COMPARE[1]
    baseline_col, variant_col = f'score_{BASELINE}', f'score_{VARIANT}'

    baseline_top = (
        top_predictions[top_predictions['alias'] == BASELINE]
        [['protein_id', 'term_id', 'score']]
        .set_index(['protein_id', 'term_id'])
    )
    variant_top = (
        top_predictions[top_predictions['alias'] == VARIANT]
        [['protein_id', 'term_id', 'score']]
        .set_index(['protein_id', 'term_id'])
    )

    # The outer join keeps annotations that reach the top-N in only one run.
    # Both inputs have a 'score' column, so the suffixes always produce both
    # score columns. A missing value means "not in that run's top-N" and is
    # filled with 0, so the delta reflects entering or leaving the top-N.
    pairwise = baseline_top.join(
        variant_top,
        how='outer',
        lsuffix=f'_{BASELINE}',
        rsuffix=f'_{VARIANT}'
    )
    pairwise = pairwise.fillna({baseline_col: 0.0, variant_col: 0.0})
    pairwise['score_delta'] = pairwise[variant_col] - pairwise[baseline_col]
    pairwise = pairwise.sort_values('score_delta', ascending=False)

    print(f'Comparing top-{TOP_N} predictions between {BASELINE} and {VARIANT}.')
    # A bare expression inside an `if` block is not rendered; display explicitly.
    display(pairwise.head(20))
else:
    print('Add at least two aliases to RUNS_TO_COMPARE to compute pairwise deltas.')

## Largest improvements and regressions

Filter the `pairwise` table to highlight the strongest positive or negative score shifts.

In [None]:
if len(RUNS_TO_COMPARE) >= 2 and 'pairwise' in globals():
    # nlargest/nsmallest don't depend on how `pairwise` was sorted elsewhere,
    # and list the most extreme shifts first in both tables.
    gains = pairwise[pairwise['score_delta'] > 0].nlargest(20, 'score_delta')
    losses = pairwise[pairwise['score_delta'] < 0].nsmallest(20, 'score_delta')

    print('Top gains:')
    display(gains)

    print('Top losses:')
    display(losses)
else:
    print('Pairwise deltas are not available — add at least two aliases and run the previous cell first.')

## Per-protein overlap statistics

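Compute how much the top-N annotations of each protein overlap between the first two runs, measured as the Jaccard index: the number of shared terms divided by the size of the union of terms.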

In [None]:
if len(RUNS_TO_COMPARE) >= 2:
    BASELINE, VARIANT = RUNS_TO_COMPARE[0], RUNS_TO_COMPARE[1]

    # One set of top-N GO terms per protein, for each of the two runs.
    baseline_sets = (
        top_predictions[top_predictions['alias'] == BASELINE]
        .groupby('protein_id')['term_id']
        .agg(set)
    )
    variant_sets = (
        top_predictions[top_predictions['alias'] == VARIANT]
        .groupby('protein_id')['term_id']
        .agg(set)
    )
    shared_proteins = baseline_sets.index.intersection(variant_sets.index)

    overlaps = []
    for protein_id, baseline_terms, variant_terms in zip(
        shared_proteins,
        baseline_sets.loc[shared_proteins],
        variant_sets.loc[shared_proteins],
    ):
        union = baseline_terms | variant_terms  # never empty: each protein has >= 1 term per run
        shared = baseline_terms & variant_terms
        overlaps.append({
            'protein_id': protein_id,
            'shared': len(shared),
            'union': len(union),
            'jaccard': len(shared) / len(union)
        })

    # Explicit columns keep sort_values working even when no protein overlaps.
    overlap_df = pd.DataFrame(overlaps, columns=['protein_id', 'shared', 'union', 'jaccard'])
    print(f'Computed overlaps for {len(overlap_df)} proteins present in both runs.')
    # A bare expression inside an `if` block is not rendered; display explicitly.
    display(overlap_df.sort_values('jaccard', ascending=False).head(20))
else:
    print('Add at least two aliases to RUNS_TO_COMPARE to evaluate overlap.')

## Optional: attach GO term metadata

If you need GO term names or namespaces, use `load_terms` and merge on `term_id`. The snippet below shows how to expand the pairwise comparison with GO labels.

In [None]:
# Example: enrich pairwise table with GO term metadata (adjust `ALIAS_FOR_METADATA` as needed)
if RUNS_TO_COMPARE:
    ALIAS_FOR_METADATA = RUNS_TO_COMPARE[0]
    try:
        terms_lookup = load_terms(ALIAS_FOR_METADATA)
    except FileNotFoundError as exc:
        print(exc)
    else:
        # reset_index exposes the term identifier whether it is stored as the
        # index or as a column; the index file may call it 'term id'.
        terms_lookup = terms_lookup.reset_index().rename(columns={'term id': 'term_id'})
        if 'pairwise' not in globals():
            print('Run the pairwise comparison cell first to create the `pairwise` table.')
        elif 'term_id' not in terms_lookup.columns:
            print(f"Cannot merge: terms.df has no 'term_id' column (columns: {list(terms_lookup.columns)}).")
        else:
            decorated = pairwise.reset_index().merge(terms_lookup, on='term_id', how='left')
            # A bare expression inside an `if` block is not rendered; display explicitly.
            display(decorated.head())
else:
    print('RUNS_TO_COMPARE is empty; nothing to decorate.')

## Ground truth configuration

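Use the helpers below to discover metadata for the available runs and set up GOA-based evaluation parameters. Run configurations are discovered from files matching `run*.conf` in the working directory. The taxon identifier is inferred from the last run of digits in each run's FASTA file name, so check the resulting table before evaluating.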

In [None]:
import re

def _parse_taxon_from_fasta(fasta_path):
    """Infer a taxon identifier from the FASTA file name."""
    stem = Path(fasta_path).stem
    match = re.search(r'(\d+)$', stem)
    if match:
        return match.group(1)
    digits = re.findall(r'\d+', stem)
    return digits[-1] if digits else None


def discover_run_metadata(run_config_paths):
    """Map run aliases to metadata parsed from S2F run configuration files.

    Only files with a ``[configuration]`` section and a non-blank ``alias``
    are used. Each alias maps to a dict with ``run_config`` (the path as
    given), plus ``fasta`` and ``taxon_id`` when they can be determined. If
    several files share an alias, later files update the same entry.
    """
    metadata = {}
    for conf_path in run_config_paths:
        run_conf = configparser.ConfigParser()
        run_conf.read(conf_path)
        alias = (
            run_conf.get('configuration', 'alias', fallback='').strip()
            if run_conf.has_section('configuration')
            else ''
        )
        if not alias:
            continue

        entry = metadata.setdefault(alias, {})
        entry['run_config'] = conf_path
        if not run_conf.has_option('configuration', 'fasta'):
            continue
        fasta_path = run_conf.get('configuration', 'fasta')
        entry['fasta'] = fasta_path
        taxon = _parse_taxon_from_fasta(fasta_path)
        if taxon:
            entry['taxon_id'] = taxon
    return metadata


# Run configuration files are expected next to the notebook, named run*.conf.
RUN_CONFIG_PATHS = sorted(Path('.').glob('run*.conf'))
RUN_METADATA = discover_run_metadata(RUN_CONFIG_PATHS)

if RUN_METADATA:
    _metadata_rows = [
        {
            'alias': alias,
            'run_config': str(meta.get('run_config')),
            'fasta': meta.get('fasta'),
            'taxon_id': meta.get('taxon_id'),
        }
        for alias, meta in RUN_METADATA.items()
    ]
    metadata_table = pd.DataFrame(_metadata_rows).sort_values('alias').reset_index(drop=True)
else:
    metadata_table = pd.DataFrame(columns=['alias', 'run_config', 'fasta', 'taxon_id'])

metadata_table

In [None]:
# Evidence codes accepted as ground truth; read from [options] evidence_codes in s2f.conf when present.
DEFAULT_EVIDENCE_CODES = [
    code.strip()
    for code in config.get('options', 'evidence_codes', fallback='EXP,IDA,IPI,IMP,IGI,IEP,TAS,IC').split(',')
    if code.strip()
]

# One evaluation setup per discovered run whose taxon could be inferred.
GROUND_TRUTH_CONFIG = {
    alias: {
        'taxon_id': meta.get('taxon_id'),
        'goa_path': FILTERED_GOA_PATH,
        'obo_path': DEFAULT_OBO_PATH,
        # Each alias gets its own list, so editing one entry cannot leak into the others.
        'evidence_codes': list(DEFAULT_EVIDENCE_CODES),
        'min_annotations_per_protein': 3,
        'term_frequency_range': (1, 1_000_000),
    }
    for alias, meta in RUN_METADATA.items()
    if meta.get('taxon_id')
}

# Override or add entries here as needed, for example:
# GROUND_TRUTH_CONFIG['custom_alias'] = {
#     'taxon_id': '123456',
#     'goa_path': Path('/path/to/filtered_goa_subset.gaf'),
#     'obo_path': Path('/path/to/go.obo'),
#     'evidence_codes': ['EXP', 'IDA'],
#     'min_annotations_per_protein': 3,
#     'term_frequency_range': (1, 1_000_000),
# }

GROUND_TRUTH_CONFIG

In [None]:
from Measures.measures import HX_py
from GOTool.GeneOntology import GeneOntology

# Parsed GOA subsets keyed by (file, taxon, evidence codes, negation filter).
# Scanning a full GOA file is slow, so repeated evaluations reuse the result.
_GOA_CACHE = {}


def load_goa_annotations_for_taxon(taxon_id, goa_path, evidence_codes, chunk_size=200_000, exclude_negated=True):
    """Load GOA annotations for a single taxon from a potentially large GAF file.

    Parameters
    ----------
    taxon_id : str or int
        NCBI taxon identifier, matched literally against ``taxon:<id>``
        entries of the GAF Taxon column. That column may list several taxa
        separated by ``|``.
    goa_path : str or Path
        Tab-separated GAF file. Lines whose first field starts with ``!`` are
        treated as header comments.
    evidence_codes : iterable of str or None
        Evidence codes to keep. An empty value keeps every evidence code.
    chunk_size : int
        Rows read per chunk, to bound memory use on large files.
    exclude_negated : bool
        Drop annotations whose Qualifier contains ``NOT``. These state that
        the protein is *not* associated with the term, so they must not be
        counted as positives.

    Returns
    -------
    pandas.DataFrame
        Unique ``Protein`` / ``GO ID`` pairs with a constant ``Score`` of 1.0.
        A copy of the cached table is returned, so callers may modify it.

    Raises
    ------
    FileNotFoundError
        If ``goa_path`` does not exist.
    """
    goa_path = Path(goa_path)
    if not goa_path.exists():
        raise FileNotFoundError(f'GOA file not found: {goa_path}')
    # Materialise once: a generator would otherwise be consumed by the cache key.
    codes = tuple(sorted(evidence_codes)) if evidence_codes else ()
    cache_key = (goa_path.resolve(), str(taxon_id), codes, bool(exclude_negated))
    if cache_key in _GOA_CACHE:
        return _GOA_CACHE[cache_key].copy()

    column_names = [
        'DB', 'DB Object ID', 'DB Object Symbol', 'Qualifier', 'GO ID',
        'DB Reference', 'Evidence Code', 'With', 'Aspect', 'DB Object Name',
        'Synonym', 'DB Object Type', 'Taxon', 'Date', 'Assigned By',
        'Annotation Extension', 'Gene Product Form ID'
    ]
    dtype = {name: str for name in column_names}
    taxon_pattern = fr'(?:^|\|)taxon:{re.escape(str(taxon_id))}(?:$|\|)'
    negation_pattern = r'(?:^|\|)NOT(?:$|\|)'
    annotations = []

    with pd.read_csv(
        goa_path,
        sep='\t',
        header=None,
        names=column_names,
        dtype=dtype,
        chunksize=chunk_size,
        low_memory=False,
        # GAF free-text fields may contain '"' and '!'. Disable quote handling,
        # and drop '!' header lines explicitly below: `comment='!'` would also
        # truncate any data line at its first '!'.
        quoting=csv.QUOTE_NONE,
    ) as reader:
        for chunk in reader:
            is_header = chunk['DB'].fillna('').str.startswith('!')
            mask = ~is_header & chunk['Taxon'].fillna('').str.contains(taxon_pattern, regex=True)
            if exclude_negated:
                mask &= ~chunk['Qualifier'].fillna('').str.contains(negation_pattern, regex=True)
            if codes:
                mask &= chunk['Evidence Code'].isin(codes)
            if not mask.any():
                continue
            trimmed = chunk.loc[mask, ['DB Object ID', 'GO ID']].rename(columns={'DB Object ID': 'Protein'})
            annotations.append(trimmed)

    if annotations:
        # Deduplicate after concatenation so repeats spread across chunks collapse too.
        result = pd.concat(annotations, ignore_index=True).drop_duplicates(ignore_index=True)
        result['Score'] = 1.0
    else:
        result = pd.DataFrame(columns=['Protein', 'GO ID', 'Score'])

    _GOA_CACHE[cache_key] = result
    return result.copy()


def build_prediction_and_gold_matrices(predictions, annotations):
    """Align predicted scores and ground-truth annotations on one dense grid.

    Rows cover the union of predicted and annotated proteins, and columns the
    union of predicted and annotated GO terms, both in sorted order.

    Parameters
    ----------
    predictions : pandas.DataFrame
        Needs ``protein_id``, ``term_id`` and ``score`` columns.
    annotations : pandas.DataFrame
        Needs ``Protein`` and ``GO ID`` columns.

    Returns
    -------
    tuple
        ``(prediction_matrix, gold_matrix, protein_to_idx, term_to_idx)``.
        Both matrices are float32 arrays of shape (n_proteins, n_terms). The
        gold matrix holds 1.0 at every annotated cell and 0.0 elsewhere.

    Raises
    ------
    ValueError
        If either input table is empty.
    """
    if predictions.empty:
        raise ValueError('Prediction table is empty for the requested alias.')
    if annotations.empty:
        raise ValueError('Ground-truth annotations are empty for the requested taxon.')

    all_proteins = set(predictions['protein_id']).union(annotations['Protein'])
    all_terms = set(predictions['term_id']).union(annotations['GO ID'])
    protein_to_idx = {protein: position for position, protein in enumerate(sorted(all_proteins))}
    term_to_idx = {term: position for position, term in enumerate(sorted(all_terms))}
    grid_shape = (len(protein_to_idx), len(term_to_idx))

    prediction_matrix = np.zeros(grid_shape, dtype=np.float32)
    prediction_matrix[
        predictions['protein_id'].map(protein_to_idx).to_numpy(),
        predictions['term_id'].map(term_to_idx).to_numpy(),
    ] = predictions['score'].to_numpy(dtype=float)

    gold_matrix = np.zeros(grid_shape, dtype=np.float32)
    gold_matrix[
        annotations['Protein'].map(protein_to_idx).to_numpy(),
        annotations['GO ID'].map(term_to_idx).to_numpy(),
    ] = 1.0

    return prediction_matrix, gold_matrix, protein_to_idx, term_to_idx


def compute_information_content(term_to_idx, ontology, organism_name):
    """Return a float32 vector of information content, one entry per term index.

    ``ontology.find_term(term).information_content(organism_name)`` supplies
    each value. Terms for which either call raises ``KeyError`` get 0.0.
    """
    values = np.zeros(len(term_to_idx), dtype=np.float32)
    for go_id, position in term_to_idx.items():
        try:
            term = ontology.find_term(go_id)
            values[position] = term.information_content(organism_name)
        except KeyError:
            values[position] = 0.0
    return values


def filter_matrices(gold_matrix, prediction_matrix, ic_vector, min_annotations_per_protein=3, term_frequency_range=(1, 1_000_000)):
    """Drop sparsely annotated proteins and out-of-range GO terms before scoring.

    Parameters
    ----------
    gold_matrix : numpy.ndarray
        2-D ground-truth matrix (proteins x terms). Row and column sums are
        used as annotation counts.
    prediction_matrix : numpy.ndarray
        Predicted scores with the same shape as ``gold_matrix``.
    ic_vector : numpy.ndarray
        Per-term information content, aligned with the matrix columns.
    min_annotations_per_protein : int
        Proteins (rows) with fewer ground-truth annotations are removed.
    term_frequency_range : tuple
        Inclusive ``(lower, upper)`` bounds on the number of annotated
        proteins per term (column). ``None`` disables that bound.

    Returns
    -------
    tuple
        ``(filtered_pred, filtered_gold, filtered_ic, row_mask, col_mask)``.
        The masks are boolean arrays over the original rows and columns.
    """
    lower, upper = term_frequency_range
    lower = float('-inf') if lower is None else lower
    upper = float('inf') if upper is None else upper

    annotations_per_protein = gold_matrix.sum(axis=1)
    proteins_per_term = gold_matrix.sum(axis=0)

    row_mask = annotations_per_protein >= min_annotations_per_protein
    col_mask = (proteins_per_term >= lower) & (proteins_per_term <= upper)

    # np.ix_ selects rows and columns in one step. Chained boolean indexing
    # would first copy every column of the kept rows, which is costly for
    # large dense matrices.
    keep = np.ix_(row_mask, col_mask)
    filtered_pred = prediction_matrix[keep]
    filtered_gold = gold_matrix[keep]
    filtered_ic = ic_vector[col_mask]

    return filtered_pred, filtered_gold, filtered_ic, row_mask, col_mask


def evaluate_alias(alias, predictions_df, alias_cfg):
    """Score one run's predictions against up-propagated GOA ground truth.

    Parameters
    ----------
    alias : str
        Run to evaluate. Rows are selected from ``predictions_df`` by its
        ``alias`` column.
    predictions_df : pandas.DataFrame
        Combined prediction table with ``alias``, ``protein_id``, ``term_id``
        and ``score`` columns.
    alias_cfg : dict
        Entry of ``GROUND_TRUTH_CONFIG``. Keys other than ``taxon_id`` fall
        back to the notebook defaults.

    Returns
    -------
    tuple of (dict, dict)
        ``results`` holds the 'overall', 'per_gene' and 'per_term' outputs of
        ``HX_py``. ``context`` describes the evaluated matrix (counts and shape).

    Raises
    ------
    ValueError
        When predictions, the taxon id, GOA annotations, or the post-filter
        overlap are missing.
    """
    run_rows = predictions_df[predictions_df['alias'] == alias].copy()
    if run_rows.empty:
        raise ValueError('No predictions found for this alias in predictions_df.')

    taxon_id = alias_cfg.get('taxon_id')
    if not taxon_id:
        raise ValueError('Taxon identifier is missing from GROUND_TRUTH_CONFIG.')

    raw_annotations = load_goa_annotations_for_taxon(
        taxon_id,
        alias_cfg.get('goa_path', FILTERED_GOA_PATH),
        alias_cfg.get('evidence_codes', DEFAULT_EVIDENCE_CODES),
    )
    if raw_annotations.empty:
        raise ValueError('No annotations retrieved from GOA for the selected taxon and evidence codes.')

    # Register the ground truth under a per-run organism key, then let GOTool
    # up-propagate it through the ontology before building the matrices.
    gt_key = f'gt_{alias}'
    ontology = GeneOntology(str(alias_cfg.get('obo_path', DEFAULT_OBO_PATH)), verbose=False)
    ontology.build_structure()
    ontology.load_annotations(raw_annotations, gt_key)
    ontology.up_propagate_annotations(gt_key)
    gold_annotations = ontology.get_annotations(gt_key)

    pred_matrix, gold_matrix, _protein_index, term_index = build_prediction_and_gold_matrices(run_rows, gold_annotations)
    ic_vector = compute_information_content(term_index, ontology, gt_key)

    pred_kept, gold_kept, ic_kept, kept_rows, kept_cols = filter_matrices(
        gold_matrix,
        pred_matrix,
        ic_vector,
        min_annotations_per_protein=alias_cfg.get('min_annotations_per_protein', 3),
        term_frequency_range=alias_cfg.get('term_frequency_range', (1, 1_000_000)),
    )
    if gold_kept.size == 0 or gold_kept.sum() == 0:
        raise ValueError('No overlapping annotations left after filtering criteria were applied.')

    # Coarsen scores when there are very many distinct values. Presumably this
    # bounds the number of thresholds HX_py sweeps — TODO confirm in Measures.
    if np.unique(pred_kept).size > 10000:
        pred_kept = np.around(pred_kept, decimals=4)

    measure = HX_py(pred_kept, ic_kept, organism_id=alias, verbose=False)
    results = {
        'overall': measure.compute_overall(gold_kept),
        'per_gene': measure.compute_per_gene(gold_kept),
        'per_term': measure.compute_per_term(gold_kept),
    }
    context = {
        'proteins_considered': int(kept_rows.sum()),
        'terms_considered': int(kept_cols.sum()),
        'total_annotations': int(gold_kept.sum()),
        'matrix_shape': gold_kept.shape,
    }
    return results, context


def flatten_metrics(nested_metrics):
    """Flatten ``{block: {metric: value}}`` into one ``{label: float}`` dict.

    Every scalar numeric or boolean value is kept under a
    ``'<block>::<metric>'`` label (e.g. ``'per_gene::fmax'``), so metrics with
    the same name in different blocks no longer overwrite one another. The
    following values are skipped: lists, tuples, dicts, arrays with one or
    more dimensions, and non-numeric scalars such as strings or None.
    """
    flat = {}
    for block, metrics in nested_metrics.items():
        for key, value in metrics.items():
            if isinstance(value, (list, tuple, dict)):
                continue
            arr = np.asarray(value)
            if arr.ndim != 0:
                continue
            # float() would raise on strings/None; only numeric or bool scalars fit a metric column.
            if not (np.issubdtype(arr.dtype, np.number) or np.issubdtype(arr.dtype, np.bool_)):
                continue
            flat[f'{block}::{key}'] = float(arr)
    return flat

In [None]:
EVALUATION_RESULTS = {}
EVALUATION_CONTEXT = {}
EVALUATION_SUMMARY = []
EVALUATION_ERRORS = []

for alias in RUNS_TO_COMPARE:
    cfg = GROUND_TRUTH_CONFIG.get(alias)
    if cfg is None:
        print(f'Skipping {alias}: no ground truth configuration available.')
        continue
    try:
        metrics, context = evaluate_alias(alias, predictions_df, cfg)
    except Exception as exc:
        # Best effort: keep evaluating the other runs and list failures in the error table.
        print(f'⚠️ {alias}: {exc}')
        EVALUATION_ERRORS.append({'alias': alias, 'error': str(exc)})
    else:
        EVALUATION_RESULTS[alias] = metrics
        EVALUATION_CONTEXT[alias] = context
        EVALUATION_SUMMARY.append({'alias': alias, **flatten_metrics(metrics)})

if EVALUATION_SUMMARY:
    metrics_df = pd.DataFrame(EVALUATION_SUMMARY).set_index('alias')
else:
    metrics_df = pd.DataFrame()
metrics_df

In [None]:
# One row per successfully evaluated run, describing the matrix that was scored.
if EVALUATION_CONTEXT:
    context_df = pd.DataFrame.from_dict(EVALUATION_CONTEXT, orient='index')
else:
    context_df = pd.DataFrame()
context_df

In [None]:
# Runs whose evaluation raised; the fixed columns keep the empty table readable.
(
    pd.DataFrame(EVALUATION_ERRORS)
    if EVALUATION_ERRORS
    else pd.DataFrame(columns=['alias', 'error'])
)