A GRN dataset holds the following objects:
- Gene Expression Matrix $\in \mathbb{R}^{M \times N} $
- Reference Network: a collection of $R \in \mathbb{N}$ 3-tuples of the form $\text{(Regulating Gene, Target Gene, Importance Score)}$
- Transcription factors $\subseteq \mathcal{K}$ (optional)

In [None]:
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
from numpy.typing import NDArray


def load_beeline(
    root: Path = Path("../data/raw/BEELINE"),
) -> Tuple[pd.DataFrame, pd.Series, pd.DataFrame]:
    """Load a BEELINE dataset rooted at ``root``.

    Args:
        root: Directory expected to contain the ``BEELINE-data`` and
            ``BEELINE-Networks`` sub-directories.

    Returns:
        A dict with the keys ``"gene_expressions"``,
        ``"transcription_factors"`` and ``"ref_network"``.

    Raises:
        AssertionError: If ``root`` does not exist.

    NOTE(review): the return annotation declares a ``Tuple`` but the body
    returns a dict -- confirm which one callers rely on.
    NOTE(review): ``gene_expressions``, ``transcription_factors`` and
    ``ref_network`` are never assigned inside this function as shown, and
    ``data_dir`` / ``network_dir`` are unused; the file-loading step appears
    to be missing. As written the names resolve to notebook globals (or raise
    ``NameError``) -- TODO restore the loading code.
    """
    # Fail early when the dataset root is missing.
    assert root.exists(), f"Path {root} does not exist"
    # Sub-directories for expression data and reference networks
    # (not used further in the visible body).
    data_dir = root / "BEELINE-data"
    network_dir = root / "BEELINE-Networks"

    # Name columns in accordance with GENIE3
    ref_network.columns = ["regulator_gene", "target_gene", "ground_truth"]

    return {
        "gene_expressions": gene_expressions,
        "transcription_factors": transcription_factors,
        "ref_network": ref_network,
    }


def _get_transcription_factor_indices(
    transcription_factors: pd.Series,
) -> List[int]:
    assert len(set(transcription_factors)) == len(
        transcription_factors
    ), "Transcription factors are not unique"
    transcription_factor_indices = list(range(len(transcription_factors)))
    return transcription_factor_indices


# Load DREAM5 network 3 and unpack the values of the returned mapping.
# NOTE(review): `load_dream5` is not defined in the visible part of this
# notebook; the positional unpacking assumes it yields exactly these four
# values in this order -- confirm against its definition.
gene_expressions, transcription_factors, ref_network, gene_ids = load_dream5(
    network_id=3
).values()
# Plain array view of the expression table, used as model input.
inputs: NDArray = gene_expressions.values
# Positional indices of the candidate regulators.
transcription_factor_indices = _get_transcription_factor_indices(
    transcription_factors
)

In [None]:
from fedgenie3.modeling import GENIE3

# Tree-ensemble choice passed to GENIE3 ("RF" presumably selects a random
# forest -- confirm against fedgenie3.modeling).
tree_method = "RF"
# Keyword arguments handed to GENIE3 as `tree_init_kwargs`; presumably
# forwarded to the underlying tree estimator's constructor -- TODO confirm.
tree_init_kwargs = {
    "n_estimators": 100,
    "max_features": "sqrt",
    "random_state": 42,  # fixed seed for reproducible runs
    "n_jobs": -1,
}
genie3 = GENIE3(tree_method=tree_method, tree_init_kwargs=tree_init_kwargs)

In [None]:
# Compute the importance matrix from the expression values and the
# transcription factor indices.
importance_matrix = genie3.compute_importances(
    inputs, transcription_factor_indices
)

In [None]:
# Turn the importance matrix into a gene ranking table.
gene_ranking = genie3.get_gene_ranking(
    importance_matrix, transcription_factor_indices
)
# Bare expression: displays the ranking as the notebook cell output.
gene_ranking

In [None]:
def fn(x, gene_expressions):
    """Look up gene names for positional indices.

    Indexes the column labels of ``gene_expressions`` with ``x`` and returns
    the selected label(s).
    """
    return gene_expressions.columns[x]


# Columns of the ranking that hold gene references.
gene_cols = ["regulator_gene", "target_gene"]
# Replace the values in both gene columns with the corresponding column
# labels of `gene_expressions` (applied column-wise via axis=0).
# NOTE(review): assumes these columns hold integer positions into
# `gene_expressions.columns` -- confirm against `get_gene_ranking`.
gene_ranking[gene_cols] = gene_ranking[gene_cols].apply(
    lambda x: fn(x, gene_expressions), axis=0
)
# Bare expression: displays the updated ranking as the cell output.
gene_ranking

In [None]:
from fedgenie3.eval import evaluate

# Evaluate the predicted ranking against the reference network; the result is
# shown as the cell output.
evaluate(gene_ranking, ref_network)