<a href="https://colab.research.google.com/github/yonabimanyu/cancer-boolean-network/blob/main/StepMiner_algorithm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
import os
import math
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
drive.mount('/content/drive')

In [None]:
# Print the name of every entry in the input folder to check which files are available.
folder = "data/"
files = os.listdir(folder)
for f in files:
    print(f)

In [None]:
# Preview the expression table before running StepMiner.
# 'exp_upbrca1' is only an example dataset name; replace it with your own.
# The '-expr' suffix is mandatory: the loader builds the input file name as
# "<prefix>-expr.txt".

df = pd.read_table("data/exp_upbrca1-expr.txt")
df.head()

In [None]:
# Path prefix shared by the StepMiner input ("<prefix>-expr.txt") and the
# output files ("<prefix>-thr.txt", "<prefix>-info.txt", "<prefix>-bv.txt").
# NOTE(review): with this value the expression file is read from
# "results/exp_upbrca1-expr.txt", while the preview cell reads it from "data/" -
# confirm the file is available under "results/" as well.
prefix = "results/exp_upbrca1"

In [None]:
def stepminer(prefix):
    """
    Run the complete StepMiner workflow for one dataset.

    The expression table is read from '<prefix>-expr.txt', one threshold record
    is computed per gene, and the records are written to '<prefix>-thr.txt',
    '<prefix>-info.txt' and '<prefix>-bv.txt'.

    Args:
        prefix (str): Path prefix shared by the input and output files
            (e.g., 'exp_upbrca1' for 'exp_upbrca1-expr.txt')

    Returns:
        list: Threshold dictionaries, one per processed gene, or None when the
        expression file could not be loaded
    """
    expression = _load_expression_file(prefix)
    if expression is None:
        # The loader has already printed the reason
        return None

    # Optional companion files are only probed; their contents are not used here
    _load_optional_files(prefix)

    print(f"Processing {len(expression)} genes...")
    results = find_thresholds(expression)

    # Persist the three result tables
    write_thr_txt(results, prefix)
    write_info_txt(results, prefix)
    write_bv_txt(results, prefix)

    print("StepMiner analysis completed. Output files:")
    for kind in ("thr", "info", "bv"):
        print(f"  - {prefix}-{kind}.txt")

    return results

In [None]:
def _load_expression_file(prefix):
    """
    Read '<prefix>-expr.txt' and check that it looks like an expression table.

    The table must contain the columns 'ensembl_id' and 'gene_symbol' and at
    least one further column after the first two.

    Args:
        prefix (str): Path prefix of the expression file

    Returns:
        DataFrame: The loaded table, or None when the file is missing, cannot
        be read, or fails validation (a message is printed in those cases)
    """
    path = f"{prefix}-expr.txt"
    try:
        table = pd.read_table(path)

        # Both identifier columns have to be present
        absent = [name for name in ('ensembl_id', 'gene_symbol')
                  if name not in table.columns]
        if absent:
            raise ValueError(f"Missing required columns: {absent}")

        # Everything after the first two columns is treated as sample data
        n_samples = len(table.columns[2:])
        if n_samples == 0:
            raise ValueError("No sample columns found in expression file")

        print(f"Loaded expression file with {len(table)} genes and {n_samples} samples")
    except FileNotFoundError:
        print(f"ERROR: Expression file '{path}' not found")
        return None
    except Exception as e:
        # Validation failures raised above are reported here as well
        print(f"ERROR loading expression file: {e}")
        return None

    return table

In [None]:
def _load_optional_files(prefix):
    """
    Try to read the optional companion files and report the outcome of each.

    The files '<prefix>-idx.txt', '<prefix>-ih.txt' and '<prefix>-survival.txt'
    are read one after another; the parsed tables are discarded and only a
    status line is printed per file.

    Args:
        prefix (str): Path prefix shared by the companion files
    """
    labels_by_suffix = {
        "-idx.txt": "index",
        "-ih.txt": "ih",
        "-survival.txt": "survival",
    }

    for suffix, file_type in labels_by_suffix.items():
        path = f"{prefix}{suffix}"
        try:
            pd.read_table(path)
            print(f"Loaded {file_type} file")
        except FileNotFoundError:
            # A missing optional file is not an error
            print(f"Optional {file_type} file not found (skipping)")
        except Exception as e:
            print(f"Warning: Error loading {file_type} file: {e}")

In [None]:
def find_thresholds(expr_df):
    """
    Find optimal thresholds for all genes using StepMiner algorithm.

    Each row of the table is one gene. Rows whose sample values fail
    _validate_expression_values() are skipped (the validator prints the reason).
    For every remaining row the result of fit_step_miner() is extended with the
    keys "ensembl_id", "gene_symbol", "stat" (F-statistic, rounded to 3
    decimals), "thr-0.5" and "thr+0.5".

    Args:
        expr_df (DataFrame): Expression data with ensembl_id, gene_symbol, and sample columns

    Returns:
        list: List of dictionaries containing threshold information for each gene
    """
    # Every column except the two identifier columns is a sample column.
    # Selecting by name (not by position) keeps an identifier column out of the
    # numeric data even when it is not one of the first two columns.
    id_columns = ("ensembl_id", "gene_symbol")
    sample_columns = [col for col in expr_df.columns if col not in id_columns]
    thresholds = []

    for _, row in expr_df.iterrows():
        ensembl_id = row["ensembl_id"]
        gene_symbol = row["gene_symbol"]
        values = row[sample_columns].tolist()

        # Validate expression values
        if not _validate_expression_values(values, ensembl_id):
            continue

        # The validator accepts numeric strings (a single non-numeric cell turns
        # the whole column into text when the file is parsed). Convert those to
        # floats so that sorting and arithmetic in fit_step_miner() do not fail
        # on a mix of str and float. Values that are already numeric are kept
        # untouched so the output formatting of integer data does not change.
        values = [float(val) if isinstance(val, str) else val for val in values]

        # Apply StepMiner algorithm (CORE MATHEMATICAL COMPUTATION - UNCHANGED)
        thr_vals = fit_step_miner(values)

        # Add gene identifiers
        thr_vals["ensembl_id"] = ensembl_id
        thr_vals["gene_symbol"] = gene_symbol

        # Calculate F-statistic (MATHEMATICAL COMPUTATION - UNCHANGED)
        stat = f_statistic(thr_vals["sse"], thr_vals["sstot"], thr_vals["n"])
        thr_vals["stat"] = round(stat, 3)

        # Calculate threshold bounds
        thr_vals["thr-0.5"] = round(thr_vals["thr"] - 0.5, 3)
        thr_vals["thr+0.5"] = round(thr_vals["thr"] + 0.5, 3)

        thresholds.append(thr_vals)

    print(f"Successfully processed {len(thresholds)} genes")
    return thresholds

In [None]:
def _validate_expression_values(values, gene_id):
    """
    Decide whether a gene's expression values can be processed.

    The checks run in this order: no missing values, every value convertible
    to float, at least four values. The first failing check prints a warning
    naming the gene and ends the validation.

    Args:
        values (list): Expression values of one gene
        gene_id: Identifier used in the warning messages

    Returns:
        bool: True when all checks pass, False otherwise
    """
    # Missing values
    for val in values:
        if pd.isna(val):
            print(f"Warning: Gene {gene_id} has missing values, skipping")
            return False

    # Non-numeric values
    try:
        converted = list(map(float, values))
    except (ValueError, TypeError):
        print(f"Warning: Gene {gene_id} has non-numeric values, skipping")
        return False

    # Too few data points
    count = len(converted)
    if count < 4:
        print(f"Warning: Gene {gene_id} has insufficient data points ({count} < 4), skipping")
        return False

    return True

In [None]:
def fit_step_miner(values):
    """
    CORE STEPMINER ALGORITHM - MATHEMATICAL COMPUTATION UNCHANGED

    Fit a one-step function to the values and derive threshold statistics.

    The values are sorted and every split of the sorted list into a lower and
    an upper group is evaluated; the split with the smallest sum of squared
    errors (SSE) around the two group means is kept. The reported threshold is
    the midpoint between the two group means of that split.

    Args:
        values (list): Expression values for a gene. Must be non-empty (the
            overall mean divides by the number of values).

    Returns:
        dict: Dictionary with the keys
            "mean", "min", "max", "sd": summary statistics of the values
            "thr": threshold (midpoint of the two group means)
            "mean-thr": overall mean minus threshold
            "sse": smallest SSE found; "sstot": total sum of squares
            "n": number of values
            "BV": string with one character per input value, in input order
                ('2' above thr + 0.5, '0' below thr - 0.5, '1' otherwise)
            "hi", "lo", "int": counts of '2', '0' and '1' in "BV"
            "perc": signed fraction of values lying between mean and threshold
            "thrNum": number of values less than or equal to the threshold
        Floating-point entries are rounded to 3 decimals.
    """
    l = sorted(values)  # Sorted copy; the input order is still needed for the BitVector

    # Initialize variables for optimization
    n = len(l)
    n1 = 0  # Size of the lower group (starts empty)
    n2 = n  # Size of the upper group (starts with all values)
    min_m1 = m1 = 0
    m2 = sum(l) / n
    min_m2 = m = m2  # m stays fixed at the overall mean
    min_i = -1
    sse = sum([(x - m2)**2 for x in l])
    sstot = sse  # Total sum of squares around the overall mean
    min_sse = sse
    ssr = 0
    thr = -float('inf')

    # CORE OPTIMIZATION LOOP - UNCHANGED
    # Each iteration moves the next sorted value from the upper to the lower
    # group and updates both group means incrementally.
    for i, x in enumerate(l):
        m1 = ((m1 * n1) + x) / (n1 + 1)
        n1 += 1
        if n1 == n:
            # Upper group is empty after the last value has been moved
            m2 = 0
            n2 = 0
        else:
            m2 = ((m2 * n2) - x) / (n2 - 1)
            n2 = n2 - 1
        # Between-group sum of squares; the remainder of sstot is the SSE of this split
        ssr = (n1 * ((m1 - m)**2)) + (n2 * ((m2 - m)**2))
        sse = (sstot - ssr)
        if min_sse > sse:
            # Strictly better split: remember its SSE and group means.
            # thr and min_i are recorded here but not read again in this function.
            min_sse = sse
            min_m1 = m1
            min_m2 = m2
            thr = x
            min_i = i

    # Calculate basic statistics
    # NOTE(review): if no split strictly lowers the SSE, min_m1 keeps its
    # initial 0 and min_m2 the overall mean, so the threshold becomes mean / 2
    # - confirm this fallback is intended.
    ret_vals = {}
    l_mean = m  # Overall mean
    l_thr = (min_m1 + min_m2) / 2  # Optimal threshold

    ret_vals["mean"] = round(l_mean, 3)
    ret_vals["min"] = round(min(l), 3)
    ret_vals["max"] = round(max(l), 3)
    ret_vals["mean-thr"] = round(l_mean - l_thr, 3)
    ret_vals["sd"] = round(np.std(l), 3)  # np.std called with its default settings
    ret_vals["thr"] = round(l_thr, 3)
    ret_vals["sse"] = round(min_sse, 3)
    ret_vals["sstot"] = round(sstot, 3)
    ret_vals["n"] = n

    # Generate BitVector based on threshold +/- 0.5 (uses the unrounded
    # threshold and the values in their original order)
    upperlim = l_thr + 0.5
    lowerlim = l_thr - 0.5
    BVlist = []
    for val in values:
        if val > upperlim:
            BVlist.append('2')  # High expression
        elif val < lowerlim:
            BVlist.append('0')  # Low expression
        else:
            BVlist.append('1')  # Intermediate expression

    BVstring = ''.join(BVlist)
    ret_vals["BV"] = BVstring

    # Count expression categories
    hi = sum([1 for i in BVlist if i == '2'])
    low = sum([1 for i in BVlist if i == '0'])
    inter = len(BVlist) - hi - low
    ret_vals["hi"] = hi
    ret_vals["lo"] = low
    ret_vals["int"] = inter

    # Calculate percentage metric: share of values between mean and threshold,
    # relative to the values on the threshold's side that contains the mean.
    # The sign is negative when the threshold lies above the mean.
    if l_thr > l_mean:
        numb_betw = sum([1 for i in l if l_mean <= i <= l_thr])
        numb_below = sum([1 for i in l if i <= l_thr])
        perc = -1 * numb_betw / numb_below if numb_below > 0 else 0
    else:
        numb_betw = sum([1 for i in l if l_thr <= i <= l_mean])
        numb_above = sum([1 for i in l if i >= l_thr])
        perc = numb_betw / numb_above if numb_above > 0 else 0
    ret_vals["perc"] = round(perc, 3)

    # Calculate threshold number (count of values at or below the threshold)
    thrnum = sum([1 for i in l if i <= l_thr])
    ret_vals["thrNum"] = thrnum

    return ret_vals

In [None]:
def f_statistic(sse, sstot, n):
    """
    MATHEMATICAL COMPUTATION - UNCHANGED

    Compute the F-statistic of a step fit from its sums of squares.

    The explained sum of squares (sstot - sse) and the residual sum of squares
    (sse) are each divided by their degrees of freedom: 3 and n - 4 when
    n > 4, otherwise 2 and 1. The statistic is the ratio of the two means.

    Args:
        sse (float): Sum of squared errors
        sstot (float): Total sum of squares
        n (int): Number of data points

    Returns:
        float: F-statistic value; 0.0 when sse or the mean squared error is zero
    """
    explained = sstot - sse

    if n > 4:
        dof_model, dof_error = 3, n - 4
    else:
        dof_model, dof_error = 2, 1

    mean_sq_model = explained / dof_model

    # Edge case: no residual error, the ratio would divide by zero
    if sse == 0:
        return 0.0

    mean_sq_error = sse / dof_error
    if mean_sq_error == 0:
        return 0.0

    return mean_sq_model / mean_sq_error

In [None]:
def write_thr_txt(thresholds, prefix):
    """
    Write the threshold summary table to '<prefix>-thr.txt'.

    The file is tab-separated with a header line followed by one line per gene
    holding the Ensembl id, the threshold, the F-statistic and the threshold
    minus / plus 0.5. Any error is reported with a printed message instead of
    being raised.

    Args:
        thresholds (list): Threshold dictionaries with the keys 'ensembl_id',
            'thr', 'stat', 'thr-0.5' and 'thr+0.5'
        prefix (str): Path prefix of the output file
    """
    filename = f"{prefix}-thr.txt"
    keys = ("ensembl_id", "thr", "stat", "thr-0.5", "thr+0.5")
    try:
        with open(filename, 'w') as out:
            # Header uses descriptive names rather than the dictionary keys
            out.write("ensembl_id\tthreshold\tf_statistic\tthr_minus_0.5\tthr_plus_0.5\n")

            for record in thresholds:
                out.write("\t".join(str(record[key]) for key in keys) + "\n")

        print(f"Written {len(thresholds)} thresholds to {filename}")
    except Exception as e:
        print(f"Error writing {filename}: {e}")

In [None]:
def write_info_txt(thresholds, prefix):
    """
    Write the detailed per-gene statistics table to '<prefix>-info.txt'.

    The file is tab-separated; the header line lists the dictionary keys that
    are exported and each following line holds those values for one gene. Any
    error is reported with a printed message instead of being raised.

    Args:
        thresholds (list): Threshold dictionaries containing every exported key
        prefix (str): Path prefix of the output file
    """
    filename = f"{prefix}-info.txt"
    columns = ("ensembl_id", "gene_symbol", "thr", "mean", "mean-thr",
               "perc", "min", "max", "sd", "thrNum", "hi", "int", "lo")
    try:
        with open(filename, 'w') as out:
            out.write("\t".join(columns) + "\n")

            # One line per gene, fields in header order
            out.writelines(
                "\t".join(str(record[name]) for name in columns) + "\n"
                for record in thresholds
            )

        print(f"Written detailed info for {len(thresholds)} genes to {filename}")
    except Exception as e:
        print(f"Error writing {filename}: {e}")

In [None]:
def write_bv_txt(thresholds, prefix):
    """
    Write the BitVector table to '<prefix>-bv.txt'.

    The file is tab-separated with a header line followed by one line per gene
    holding the Ensembl id, the gene symbol and the BitVector string. Any error
    is reported with a printed message instead of being raised.

    Args:
        thresholds (list): Threshold dictionaries with the keys 'ensembl_id',
            'gene_symbol' and 'BV'
        prefix (str): Path prefix of the output file
    """
    filename = f"{prefix}-bv.txt"
    try:
        with open(filename, 'w') as out:
            out.write("ensembl_id\tgene_symbol\tBitVector\n")

            for record in thresholds:
                fields = (record['ensembl_id'], record['gene_symbol'], record['BV'])
                out.write("\t".join(map(str, fields)) + "\n")

        print(f"Written BitVectors for {len(thresholds)} genes to {filename}")
    except Exception as e:
        print(f"Error writing {filename}: {e}")

In [None]:
# Legacy function names for backward compatibility
def findThresholds(expr_df):
    """
    Legacy wrapper for find_thresholds().

    Keeps the old camelCase name callable; the argument is passed through
    unchanged and the result of find_thresholds() is returned.
    """
    return find_thresholds(expr_df)

def writeThrTxt(thresholds, prefix):
    """
    Legacy wrapper for write_thr_txt().

    Keeps the old camelCase name callable; both arguments are passed through
    unchanged and the result of write_thr_txt() is returned.
    """
    return write_thr_txt(thresholds, prefix)

def writeInfoTxt(thresholds, prefix):
    """
    Legacy wrapper for write_info_txt().

    Keeps the old camelCase name callable; both arguments are passed through
    unchanged and the result of write_info_txt() is returned.
    """
    return write_info_txt(thresholds, prefix)

def writeBVTxt(thresholds, prefix):
    """
    Legacy wrapper for write_bv_txt().

    Keeps the old camelCase name callable; both arguments are passed through
    unchanged and the result of write_bv_txt() is returned.
    """
    return write_bv_txt(thresholds, prefix)

In [None]:
# Run the full StepMiner pipeline for the dataset selected by `prefix`.
stepminer(prefix)