# Which allele frequency (AF) values should we choose for our simulated data?
- look at the distribution of AF values from the germline and somatic mutation matrices (samples x gene-level; binary genotype matrix)
- choose values near the extremes (can always fill in values later)

In the paper, we will refer to the "control frequency" instead of allele frequency; we define it as the proportion of samples in class 0 (primary cancer) that have an event (mutation).


In [None]:
import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from pnet.data_processing import prostate_data_loaders

# Setup Logging and Configuration
# Timestamped, INFO-level records. force=True lets this cell be re-run and
# still take effect, since basicConfig is otherwise a no-op once the root
# logger has handlers.
_LOG_FORMAT = "%(asctime)s %(levelname)-8s [%(name)s] %(message)s"
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATEFMT, force=True)
logger = logging.getLogger(__name__)

## Loading genetic data

In [None]:
# Load the raw somatic / germline mutation inputs and the response labels.
# Earlier paired-input locations, kept for reference:
# data_dir = "/mnt/disks/gmiller_data1/pnet_germline/data/prostate/paired_germline_somatic/"
# data_dir = "/mnt/disks/gmiller_data1/pnet_germline/processed/wandb-group-data_prep_germline_tier12_and_somatic/converted-IDs-to-somatic_imputed-germline_True_imputed-somatic_False_paired-samples-True/wandb-run-id-q151d0zw"
# somatic_f = os.path.join(data_dir, "somatic_mut.csv")
# germline_f = os.path.join(data_dir, "germline_rare_common_lof_missense.csv")

somatic_datadir = "/mnt/disks/gmiller_data1/pnet_germline/data/pnet_database/prostate/processed"
germline_datadir = "/mnt/disks/gmiller_data1/pnet_germline/data"

somatic_mut_path = os.path.join(somatic_datadir, "P1000_final_analysis_set_cross_important_only.csv")
somatic_df = prostate_data_loaders.get_somatic_mutation(somatic_mut_path)

# Unfiltered germline data, accidentally used in my previous analysis
# germline_df = prostate_data_loaders.get_germline_mutation(
#             os.path.join(
#                 germline_datadir,
#                 "prostate/prostate_germline_vcf_subset_to_germline_tier_12_and_somatic_passed-universal-filters_rare_common_high-impact_moderate-impact.txt",
#             ))

# Germline calls restricted to pathogenic variants ("patho-vars-only" file).
germline_mut_path = os.path.join(
    germline_datadir,
    "prostate/prostate_germline_vcf_subset_to_germline_tier_12_and_somatic_passed-universal-filters_patho-vars-only_rare_common_high-impact_moderate-impact.txt",
)
germline_df = prostate_data_loaders.get_germline_mutation(germline_mut_path)

# Response table, indexed by its "id" column.
y = pd.read_csv(os.path.join(somatic_datadir, "response_paper.csv")).set_index("id")

In [None]:
# Re-load somatic, germline and label tables from the harmonized directory;
# this overwrites somatic_df, germline_df and y from the previous cell.
harmonized_data_dir = "/mnt/disks/gmiller_data1/pnet_germline/processed/wandb-group-data_prep_germline_tier12_and_somatic/converted-IDs-to-somatic_imputed-germline_True_imputed-somatic_False_paired-samples-True/wandb-run-id-u5yt90p1"

# Keep only genes (columns) with at least one non-zero entry.
somatic_df = pd.read_csv(os.path.join(harmonized_data_dir, "somatic_mut.csv"), index_col=0)
somatic_has_event = somatic_df.ne(0).any(axis=0)
somatic_df = somatic_df.loc[:, somatic_has_event]

germline_df = pd.read_csv(os.path.join(harmonized_data_dir, "germline_rare_common_lof_missense.csv"), index_col=0)
germline_has_event = germline_df.ne(0).any(axis=0)
germline_df = germline_df.loc[:, germline_has_event]

# Labels: expose the "is_met" column under the name "class".
y = pd.read_csv(os.path.join(harmonized_data_dir, "y.csv"), index_col=0)
y = y.rename(columns={"is_met": "class"})

print("somatic_df.shape", somatic_df.shape)
print("germline_df.shape", germline_df.shape)
print("y.shape", y.shape)

In [None]:
# Quick sanity check: shapes plus the first few column (gene) and row (sample) labels.
somatic_shape, germline_shape = somatic_df.shape, germline_df.shape
print("Somatic data shape:", somatic_shape)
print("Germline data shape:", germline_shape)

somatic_first_cols = somatic_df.columns[:5].tolist()
germline_first_cols = germline_df.columns[:5].tolist()
print("First 5 columns of somatic data:", somatic_first_cols)
print("First 5 columns of germline data:", germline_first_cols)

somatic_first_rows = somatic_df.index[:5].tolist()
germline_first_rows = germline_df.index[:5].tolist()
print("First 5 rows of somatic data:", somatic_first_rows)
print("First 5 rows of germline data:", germline_first_rows)

## Allele frequencies by class and dataset

In [None]:
def compute_afs(binary_df):
    """Return the per-column (gene) mean of a samples x genes binary matrix.

    For a 0/1 matrix this is the fraction of rows with an event in each column.
    """
    per_gene_frequency = binary_df.mean(axis="index")
    return per_gene_frequency

def compute_deciles(afs):
    """Tabulate the quantiles of ``afs`` at 0.0, 0.1, ..., 1.0.

    Note that the minimum (0.0) and maximum (1.0) are included, so the result
    has 11 rows.

    Returns:
        pd.DataFrame: columns ``'Decile'`` (quantile level) and ``'AF'`` (value).
    """
    quantile_levels = [step / 10 for step in range(11)]
    quantiles = afs.quantile(quantile_levels)
    return quantiles.rename_axis('Decile').reset_index(name='AF')


def plot_af_histogram(afs, title="Allele Frequency Distribution"):
    """Show a 30-bin histogram (with KDE overlay) of ``afs`` on a new figure."""
    _, ax = plt.subplots()
    sns.histplot(afs, bins=30, kde=True, ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Allele Frequency")
    ax.set_ylabel("Number of Genes")
    plt.show()


def plot_af_boxplot(afs, title="Allele Frequency Boxplot"):
    """Show a horizontal boxplot of ``afs`` on a new figure."""
    _, ax = plt.subplots()
    ax.boxplot(afs, vert=False)
    ax.set_title(title)
    ax.set_xlabel("Allele Frequency")
    plt.show()


def analyze_af_distribution(binary_df, label="Dataset"):
    """Compute, tabulate and plot the allele-frequency distribution of one dataset.

    Args:
        binary_df: Binary matrix passed to ``compute_afs``.
        label (str): Dataset name used in the plot titles and in the name of
            the value column of the decile table (``f'{label} AF'``).

    Returns:
        tuple: ``(afs, deciles)`` - the per-gene frequencies and their decile table.
    """
    afs = compute_afs(binary_df)
    value_column = f'{label} AF'
    deciles = compute_deciles(afs).rename(columns={'AF': value_column})
    plot_af_histogram(afs, title=f"{label} - Histogram")
    plot_af_boxplot(afs, title=f"{label} - Boxplot")
    return afs, deciles


In [None]:
# Plot and tabulate the AF distribution of each dataset, then line up the
# two decile tables side by side on their shared 'Decile' column.
afs_somatic, deciles_somatic = analyze_af_distribution(somatic_df, label="Somatic")
afs_germline, deciles_germline = analyze_af_distribution(germline_df, label="Germline")
for decile_table in (deciles_somatic, deciles_germline):
    display(decile_table)

merged_df = deciles_somatic.merge(deciles_germline, on='Decile')
merged_df

In [None]:
def calculate_proportions(somatic_mut, y, gene_list):
    """Compute, per class, the proportion of samples with an event in each gene.

    Args:
        somatic_mut (pd.DataFrame): Binary samples x genes matrix.
        y (pd.Series or pd.DataFrame): Class labels indexed like ``somatic_mut``.
            A DataFrame must either contain a ``'class'`` column or have
            exactly one column.
        gene_list (iterable of str): Genes (columns of ``somatic_mut``) to summarise.

    Returns:
        pd.DataFrame: genes as rows, one column per class label; each value is
        the mean of the binary column within that class.

    Raises:
        ValueError: If ``y`` is a multi-column DataFrame without a ``'class'`` column.
        KeyError: If a sample of ``somatic_mut`` is missing from ``y`` or a gene
            is missing from ``somatic_mut``.
    """
    logger.info("Group by class and compute mean proportions")

    # Align labels to the row order of the mutation matrix.
    y_aligned = y.loc[somatic_mut.index]
    if isinstance(y_aligned, pd.DataFrame):
        # Assigning a multi-column DataFrame to a single column raises, so
        # reduce the labels to one Series explicitly.
        if 'class' in y_aligned.columns:
            y_aligned = y_aligned['class']
        elif y_aligned.shape[1] == 1:
            y_aligned = y_aligned.iloc[:, 0]
        else:
            raise ValueError(
                "y must be a Series, a single-column DataFrame, or contain a 'class' column"
            )

    # Subset to the genes of interest and attach the labels.
    combined = somatic_mut.loc[:, list(gene_list)].copy()
    combined['class'] = y_aligned

    # Mean of a binary column within a class == proportion of samples with an event.
    return combined.groupby('class').mean().T

def calculate_proportions_with_rr(somatic_mut, y, gene_list, metric="rr"):
    """Per-class event proportions plus a class 1 vs class 0 effect measure.

    Args:
        somatic_mut (pd.DataFrame): Binary samples x genes matrix.
        y: Class labels, forwarded to ``calculate_proportions``.
        gene_list (iterable of str): Genes to summarise.
        metric (str): ``"rr"`` (default) for the relative risk
            ``p1 / p0``; ``"or"`` for the odds ratio
            ``(p1 / (1 - p1)) / (p0 / (1 - p0))``.

    Returns:
        pd.DataFrame: Output of ``calculate_proportions`` with one extra column.
        The column is always named ``'OR'`` so that existing code sorting on it
        keeps working, even when it holds the relative risk. A zero
        denominator yields ``inf``/``NaN`` rather than an exception.

    Raises:
        ValueError: If ``metric`` is not ``"rr"`` or ``"or"``.
    """
    if metric not in ("rr", "or"):
        raise ValueError(f"metric must be 'rr' or 'or', got {metric!r}")

    proportions = calculate_proportions(somatic_mut, y, gene_list)
    p_class0 = proportions[0]
    p_class1 = proportions[1]

    if metric == "rr":
        logger.warning("Compute OR (class 1 proportion / class 0 proportion) <-- warning this is relative risk not OR!!")
        proportions['OR'] = p_class1 / p_class0
    else:
        logger.info("Compute OR (class 1 odds / class 0 odds)")
        proportions['OR'] = (p_class1 / (1 - p_class1)) / (p_class0 / (1 - p_class0))

    return proportions


# Alias so that callers using the `_or` spelling resolve to the same function.
calculate_proportions_with_or = calculate_proportions_with_rr



In [None]:
# AR, TP53, PTEN, APC, GNAS, RAC1, OBSCN, MAML3, MUC4
# check AF / OR of genes P-NET finds important for somatic data
# NOTE: the helper is defined as `calculate_proportions_with_rr`; its 'OR'
# column holds class 1 proportion / class 0 proportion (a relative risk).
pnet_important_genes = ['AR', 'TP53', 'PTEN', 'APC', 'GNAS', 'RAC1', 'OBSCN', 'MAML3', 'MUC4']
af_by_class = calculate_proportions_with_rr(somatic_df, y, pnet_important_genes)
af_by_class.sort_values(by='OR', ascending=False).round(4)


In [None]:
# Preview the 20 germline genes with the highest allele frequency.
print("Top 20 germline by AF")
afs_germline.nlargest(20)

In [None]:
# Top 10 genes with the highest allele frequency in each dataset.
# NOTE: the `top_5_*` variable names are historical; each holds the 10 largest values.
top_5_somatic = afs_somatic.nlargest(10)
top_5_germline = afs_germline.nlargest(10)
print("Top 10 somatic genes with highest allele frequency:")
print(top_5_somatic)
print("Top 10 germline genes with highest allele frequency:")
print(top_5_germline)

In [None]:
# Germline genes with AF == 0, i.e. genes that have no mutations at all.
# Maybe we should exclude these genes from the analysis? But we'd zero-impute them anyway when combining with somatic.
afs_germline[afs_germline == 0.]

## Explore gene-gene correlation matrix by class and dataset
- class: met vs primary (1 vs 0)
- dataset: somatic vs germline

In [None]:
def hcluster_corr_matrix(corr_matrix):
    """
    Perform hierarchical clustering on a correlation matrix and plot the dendrogram.

    The distance between two items is ``1 - correlation``; clusters are merged
    with average linkage.

    Args:
        corr_matrix (pd.DataFrame): Square correlation matrix; its column
            labels are used as the leaf labels.
    """
    import numpy as np
    from scipy.cluster.hierarchy import dendrogram, linkage
    from scipy.spatial.distance import squareform

    # Compute the distance matrix
    dist_matrix = 1 - np.asarray(corr_matrix, dtype=float)
    # squareform's validation requires exact symmetry and an exactly-zero
    # diagonal; correlations computed in floating point can miss both by an
    # ulp, so clean the matrix up and skip the check.
    dist_matrix = (dist_matrix + dist_matrix.T) / 2
    np.fill_diagonal(dist_matrix, 0.0)
    condensed = squareform(dist_matrix, checks=False)

    # Perform hierarchical clustering
    Z = linkage(condensed, method='average')

    # Plot the dendrogram (labels as a plain list rather than a pandas Index)
    plt.figure(figsize=(10, 7))
    dendrogram(Z, labels=list(corr_matrix.columns), leaf_rotation=90)
    plt.title("Hierarchical Clustering Dendrogram")
    plt.xlabel("Samples")
    plt.ylabel("Distance")
    plt.show()

### Calculate correlation matrices

In [None]:
def get_lowrank_matrix(X, top=20):
    """Return the rank-``top`` approximation of a symmetric matrix.

    Keeps the ``top`` largest eigenvalues of ``X`` (from ``np.linalg.eigh``)
    and their eigenvectors and rebuilds ``L @ L.T`` with
    ``L = V * sqrt(lambda)``.

    Args:
        X (array-like): Symmetric matrix, e.g. a correlation matrix.
        top (int): Number of leading eigen-components to keep; values larger
            than the matrix size keep everything.

    Returns:
        np.ndarray: The low-rank reconstruction, same shape as ``X``.

    Raises:
        ValueError: If ``top`` is smaller than 1.
    """
    if top < 1:
        # A slice of [-0:] would silently select *all* components.
        raise ValueError(f"top must be a positive integer, got {top!r}")

    eigvals, eigvecs = np.linalg.eigh(X)  # eigenvalues in ascending order
    # Round-off can make eigenvalues of a (near-)singular matrix slightly
    # negative; clip them so the square root does not produce NaN.
    top_vals = np.clip(eigvals[-top:], 0.0, None)
    L = eigvecs[:, -top:] * np.sqrt(top_vals)
    return L @ L.T

def get_scree_plot(X, max_num_eigenvalues=20, title='Scree Plot'):
    """
    Plot a scree plot for the symmetric matrix ``X``.

    Each eigenvalue is divided by the sum of all eigenvalues; the
    ``max_num_eigenvalues`` largest shares are plotted against their rank.
    """
    ascending_eigvals = np.linalg.eigh(X)[0]
    descending_eigvals = ascending_eigvals[::-1]

    # Share of the total for every eigenvalue, truncated to the leading ones
    variance_share = (descending_eigvals / descending_eigvals.sum())[:max_num_eigenvalues]
    ranks = np.arange(1, len(variance_share) + 1)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(ranks, variance_share, marker='o')
    ax.set_xlabel('Eigenvalue Rank')
    ax.set_ylabel('Proportion of Variance Explained')
    ax.set_title(title)
    ax.grid(True)
    fig.tight_layout()
    plt.show()

def get_cumulative_variance_plot(X, max_num_eigenvalues=20, title='Cumulative Variance Plot'):
    """
    Plot the cumulative share of the eigenvalue sum of the symmetric matrix ``X``.

    Eigenvalues are sorted in descending order, normalised by their total and
    accumulated; the first ``max_num_eigenvalues`` cumulative values are
    plotted against their rank.
    """
    ascending_eigvals = np.linalg.eigh(X)[0]
    descending_eigvals = ascending_eigvals[::-1]

    # Running total of each eigenvalue's share, truncated to the leading ranks
    variance_share = descending_eigvals / descending_eigvals.sum()
    cumulative_variance = np.cumsum(variance_share)[:max_num_eigenvalues]
    ranks = np.arange(1, len(cumulative_variance) + 1)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(ranks, cumulative_variance, marker='o')
    ax.set_xlabel('Eigenvalue Rank')
    ax.set_ylabel('Cumulative Variance Explained')
    ax.set_title(title)
    ax.grid(True)
    fig.tight_layout()
    plt.show()

def num_components_for_variance(eigvals, threshold=0.8):
    """
    Returns the minimum number of eigenvalues needed to reach a given cumulative variance threshold.

    Parameters:
        eigvals (array-like): Eigenvalues (not necessarily sorted).
        threshold (float): Cumulative variance threshold (between 0 and 1).

    Returns:
        int: Number of components required to reach the threshold. Never
        larger than ``len(eigvals)``: if the threshold is not reached (for
        example because of floating-point round-off at ``threshold=1.0``, or a
        threshold above 1), all components are returned. ``0`` for empty input.
    """
    eigvals_sorted = np.sort(np.asarray(eigvals, dtype=float))[::-1]  # descending order
    cumulative = np.cumsum(eigvals_sorted) / eigvals_sorted.sum()
    # Index of the first cumulative value >= threshold (== len when none is).
    first_reaching = int(np.searchsorted(cumulative, threshold))
    return min(first_reaching + 1, int(eigvals_sorted.size))

In [None]:
X = somatic_df.copy()
denoise = False  # Set to True to denoise the correlation matrices by low-rank approximation

# Split the samples by class label
X_class0 = X.loc[y[y["class"] == 0].index]
X_class1 = X.loc[y[y["class"] == 1].index]

# Drop genes with all 0s in their respective classes
X_class0 = X_class0.loc[:, X_class0.sum(axis=0) > 0]
X_class1 = X_class1.loc[:, X_class1.sum(axis=0) > 0]

# Ensure both classes have the same set of genes
common_genes = X_class0.columns.intersection(X_class1.columns)
X_class0 = X_class0[common_genes]
X_class1 = X_class1[common_genes]

# Gene-gene correlation within each class (columns are the variables)
corr_class0 = pd.DataFrame(np.corrcoef(X_class0, rowvar=False), columns=X_class0.columns, index=X_class0.columns)
corr_class1 = pd.DataFrame(np.corrcoef(X_class1, rowvar=False), columns=X_class1.columns, index=X_class1.columns)

if denoise:
    # DataFrame.values cannot be assigned to, so rebuild each frame around
    # its denoised array while keeping the gene labels.
    corr_class0 = pd.DataFrame(
        get_lowrank_matrix(corr_class0.to_numpy(), top=10),
        index=corr_class0.index,
        columns=corr_class0.columns,
    )
    corr_class1 = pd.DataFrame(
        get_lowrank_matrix(corr_class1.to_numpy(), top=10),
        index=corr_class1.index,
        columns=corr_class1.columns,
    )

corr_diff = corr_class1 - corr_class0

In [None]:
# X = somatic_df.copy()
# denoise=False # Set to True to denoise the correlation matrices by low-rank approximation


# X_class0 = X.loc[y[y["class"] == 0].index]
# X_class1 = X.loc[y[y["class"] == 1].index]

# print('MDM4' in X_class0.columns.tolist())

# # get the mean of MDM4 column in class 0
# mean_MDM4_class0 = X_class0.mean(axis=0)['MDM4']
# # get the mean of MDM4 column in class 1
# mean_MDM4_class1 = X_class1.mean(axis=0)['MDM4']
# print(f"Mean MDM4 in class 0: {mean_MDM4_class0}")
# print(f"Mean MDM4 in class 1: {mean_MDM4_class1}")

### Eigenvalue structure of the class-specific gene-gene correlation matrices

In [None]:
# How many leading eigenvalues are needed to reach 60% of the eigenvalue sum
# of each class-specific correlation matrix?
variance_threshold = 0.6

eigvals = np.linalg.eigvalsh(corr_class0)
k = num_components_for_variance(eigvals, threshold=variance_threshold)
print(f"Number of components to explain 60% of variance in class 0 correlation matrix: {k}")

eigvals = np.linalg.eigvalsh(corr_class1)
k = num_components_for_variance(eigvals, threshold=variance_threshold)
print(f"Number of components to explain 60% of variance in class 1 correlation matrix: {k}")




In [None]:
# Scree plot for each class-specific correlation matrix.
for class_corr, plot_title in (
    (corr_class0, 'Scree Plot of Class 0'),
    (corr_class1, 'Scree Plot of Class 1'),
):
    get_scree_plot(class_corr, title=plot_title)

In [None]:
# Cumulative variance plot for each class-specific correlation matrix.
for class_corr, plot_title in (
    (corr_class0, 'Cumulative Variance Plot of Class 0'),
    (corr_class1, 'Cumulative Variance Plot of Class 1'),
):
    get_cumulative_variance_plot(class_corr, title=plot_title)

### Correlation structure of genes that P-NET assigns high importance score?

In [None]:
# Identifiers for the P-NET sweep whose feature importances are examined below:
# PROJECT_NAME and SWEEP_ID are passed to fetch_wandb_runs, GROUP_NAME is used
# to build the results paths in fetch_feature_importance_paths.
PROJECT_NAME = "millergw/prostate_met_status"
GROUP_NAME = "pnet_somatic_and_germline_exp_004"
SWEEP_ID = "rv4lm363" # germline filtered for pathogenicity
# Define directories for saving results
# FIGDIR = f"../figures/{GROUP_NAME}/"
# RESULTS_DIR = f"../results/{GROUP_NAME}/"

In [None]:
# Fetch the runs
def fetch_wandb_runs(project_name, sweep_id):
    """Return the finished W&B runs of one sweep.

    Queries ``wandb.Api().runs`` for ``project_name`` with a filter on the
    sweep id and on ``state == "finished"``.

    NOTE(review): no ``import wandb`` is visible in this file - confirm it is
    imported before this function is called.
    """
    run_filters = {"sweep": sweep_id, "state": "finished"}
    return wandb.Api().runs(project_name, filters=run_filters)

def fetch_feature_importance_paths(runs, group_name, who="validation"):
    """Build a table of the importance-file paths expected for each run.

    Paths are only constructed here, not checked for existence.

    Args:
        runs (iterable): Run objects exposing ``.id`` and a dict-like ``.config``.
        group_name (str): Results sub-directory under ``../results/``.
        who (str): Evaluation set name used in the directory and file names.

    Returns:
        pd.DataFrame: Indexed by ``run_id`` with columns ``model_type``,
        ``datasets``, ``feature_importances_path`` and ``gene_importances_path``.
        Empty (but with those columns) when ``runs`` is empty.
    """
    columns = [
        "run_id",
        "model_type",
        "datasets",
        "feature_importances_path",
        "gene_importances_path",
    ]

    records = []
    for run in runs:
        model_type = run.config.get("model_type", "unknown")
        datasets = run.config.get("datasets", "unknown")
        run_dir = f'../results/{group_name}/{model_type}_eval_set_{who}/wandbID_{run.id}'
        records.append({
            "run_id": run.id,
            "model_type": model_type,
            "datasets": datasets,
            "feature_importances_path": f'{run_dir}/{who}_gene_feature_importances.csv',
            "gene_importances_path": f'{run_dir}/{who}_gene_importances.csv',
        })

    # One frame from all records; passing `columns` keeps the schema even with
    # zero runs, where concatenating per-run frames would raise.
    return pd.DataFrame(records, columns=columns).set_index("run_id")


def load_feature_importances(importances_path):
    """
    Read a feature-importance CSV and index it by its unnamed first column.

    Args:
        importances_path (str): Path to the feature importance file.

    Returns:
        pd.DataFrame: Feature importances, indexed by the ``'Unnamed: 0'`` column.

    Raises:
        FileNotFoundError: If ``importances_path`` does not exist.
    """
    if os.path.exists(importances_path):
        importances = pd.read_csv(importances_path)
        return importances.set_index('Unnamed: 0')
    raise FileNotFoundError(f"File not found: {importances_path}")


def process_importances(imps, response_df):
    """
    Compute the between-class difference of mean feature importances.

    ``imps`` is joined with ``response_df`` on the index, averaged within each
    value of the ``'response'`` column, and the second class's means minus the
    first class's means (classes in sorted order) are returned.

    Args:
        imps (pd.DataFrame): Feature importance DataFrame (samples x features).
        response_df (pd.DataFrame): Response variable DataFrame with a
            ``'response'`` column, indexed like ``imps``.

    Returns:
        pd.Series: Per-feature difference of class means.

    Raises:
        IndexError: If fewer than two response classes are present.
    """
    # Compute the class means once; the debug output reuses the same table.
    class_means = imps.join(response_df).groupby('response').mean()
    # %-style arguments are only formatted when DEBUG logging is enabled.
    logger.debug("head of imps.join(response_df).groupby('response').mean(): %s", class_means.head())
    logger.debug("shape of imps.join(response_df).groupby('response').mean(): %s", class_means.shape)
    # Row 1 of the row-wise diff is (second class mean - first class mean).
    return class_means.diff(axis=0).iloc[1]

def process_feature_importances(df_feature_importance_paths, response_df):
    """
    Load and summarise feature importances for every run, grouped by dataset combination.

    For each row, the importance file is loaded, reduced to between-class
    differences with ``process_importances`` and ranked by absolute value
    (rank 1 = largest). Runs whose file is missing are reported with ``print``
    and skipped.

    Args:
        df_feature_importance_paths (pd.DataFrame): Rows with
            ``'feature_importances_path'`` and ``'datasets'`` columns.
        response_df (pd.DataFrame): Response variable DataFrame.

    Returns:
        tuple: ``(df_imps_by_key, df_ranks_by_key)`` - dicts mapping each
        ``datasets`` value to a DataFrame with one row per run.
    """
    imps_by_key = {}
    ranks_by_key = {}

    for _, run_row in df_feature_importance_paths.iterrows():
        try:
            raw_imps = load_feature_importances(run_row['feature_importances_path'])
            class_diff = process_importances(raw_imps, response_df)
            abs_ranks = class_diff.abs().rank(ascending=False)

            dataset_key = run_row['datasets']
            if dataset_key not in imps_by_key:
                imps_by_key[dataset_key] = []
            if dataset_key not in ranks_by_key:
                ranks_by_key[dataset_key] = []
            imps_by_key[dataset_key].append(class_diff)
            ranks_by_key[dataset_key].append(abs_ranks)
        except FileNotFoundError as missing_file:
            print(missing_file)

    df_imps_by_key = {}
    for dataset_key, per_run_imps in imps_by_key.items():
        df_imps_by_key[dataset_key] = pd.DataFrame(per_run_imps)

    df_ranks_by_key = {}
    for dataset_key, per_run_ranks in ranks_by_key.items():
        df_ranks_by_key[dataset_key] = pd.DataFrame(per_run_ranks)

    return df_imps_by_key, df_ranks_by_key


def load_response_variable(response_path='../../pnet_germline/data/pnet_database/prostate/processed/response_paper.csv'):
    """Read the response CSV and index it by sample barcode.

    The ``'id'`` column is renamed to ``'Tumor_Sample_Barcode'`` and used as the index.
    """
    return (
        pd.read_csv(response_path)
        .rename(columns={'id': "Tumor_Sample_Barcode"})
        .set_index('Tumor_Sample_Barcode')
    )


def extract_top_features_from_df(df_per_dataset, top_n=10, keep_smallest_n=True, index_label=None):
    """
    Extract the top N features by mean value (e.g. mean rank) for each dataset.

    Args:
        df_per_dataset (dict): Dictionary containing feature-level pd DataFrames
            (rows = runs, columns = features) for each dataset.
        top_n (int): Number of top features to extract.
        keep_smallest_n (bool): Whether to sort in ascending order (lower rank is better).
        index_label (str, optional): Name to give the 1-based position index.

    Returns:
        pd.DataFrame: One column per dataset holding its top feature names,
        indexed 1..k. ``k`` is ``top_n`` unless every dataset has fewer
        features; shorter columns are padded with NaN. Empty when
        ``df_per_dataset`` is empty.
    """
    top_by_dataset = {}
    for dataset, df in df_per_dataset.items():
        # Average each feature over runs, then keep the best `top_n` names.
        ordered = df.mean(axis=0).sort_values(ascending=keep_smallest_n)
        top_by_dataset[dataset] = pd.Series(ordered.index[:top_n])

    # Building from Series aligns on position, so datasets with fewer than
    # `top_n` features no longer cause a length mismatch.
    top_features_df = pd.DataFrame(top_by_dataset)
    top_features_df.index = range(1, len(top_features_df) + 1)

    if index_label is not None:
        top_features_df.index.name = index_label

    return top_features_df

def extract_top_features_from_series(series_per_dataset, top_n=10, keep_smallest_n=True, index_label=None):
    """
    Extract the top N features by value for each dataset.

    Args:
        series_per_dataset (dict): Dictionary containing feature-level pd Series for each dataset.
        top_n (int): Number of top features to extract.
        keep_smallest_n (bool): Whether to sort in ascending order (lower rank is better).
        index_label (str, optional): Name to give the 1-based position index.

    Returns:
        pd.DataFrame: One column per dataset holding its top feature names,
        indexed 1..k. ``k`` is ``top_n`` unless every dataset has fewer
        features; shorter columns are padded with NaN. Empty when
        ``series_per_dataset`` is empty.
    """
    top_by_dataset = {}
    for dataset, values in series_per_dataset.items():
        ordered = values.sort_values(ascending=keep_smallest_n)
        top_by_dataset[dataset] = pd.Series(ordered.index[:top_n])

    # Building from Series aligns on position, so datasets with fewer than
    # `top_n` features no longer cause a length mismatch.
    top_features_df = pd.DataFrame(top_by_dataset)
    top_features_df.index = range(1, len(top_features_df) + 1)

    if index_label is not None:
        top_features_df.index.name = index_label

    return top_features_df

In [None]:
# Look up the finished runs of the sweep and keep only the P-NET models.
logger.info("Get information on the sweep runs")
MODEL_TYPE = "pnet"
runs = fetch_wandb_runs(PROJECT_NAME, SWEEP_ID)
df_feature_importance_paths = fetch_feature_importance_paths(runs, GROUP_NAME)

logger.info(f"Filter to runs where model_type is {MODEL_TYPE}")
is_selected_model = df_feature_importance_paths['model_type'].eq(MODEL_TYPE)
df_feature_importance_paths = df_feature_importance_paths[is_selected_model]
df_feature_importance_paths.head()

In [None]:
# Display the full table of per-run importance-file paths.
df_feature_importance_paths

In [None]:
# Rank features of the somatic-only P-NET runs and keep the top N by mean rank.
SOMATIC_DATASETS = "somatic_amp somatic_del somatic_mut"

logger.info("Get information on the sweep runs")
MODEL_TYPE = "pnet"
runs = fetch_wandb_runs(PROJECT_NAME, SWEEP_ID)
df_feature_importance_paths = fetch_feature_importance_paths(runs, GROUP_NAME)

logger.info(f"Filter to runs where model_type is {MODEL_TYPE}")
is_selected_model = df_feature_importance_paths['model_type'].eq(MODEL_TYPE)
df_feature_importance_paths = df_feature_importance_paths[is_selected_model]

logger.info("Filter to just somatic datasets (somatic_amp, somatic_del, and somatic_mut)")
is_somatic_only = df_feature_importance_paths['datasets'].eq(SOMATIC_DATASETS)
df_feature_importance_paths = df_feature_importance_paths[is_somatic_only]
display(df_feature_importance_paths.head())

logger.info("Get the feature importances and ranks for each dataset combination")
response_df = load_response_variable()
df_imps_by_key, df_ranks_by_key = process_feature_importances(df_feature_importance_paths, response_df)

logger.info("Example df_imps_by_key DF:")
display(df_imps_by_key[SOMATIC_DATASETS].head())

N = 40  # Number of top features to select
top_N_features_by_rank = extract_top_features_from_df(
    df_ranks_by_key,
    top_n=N,
    keep_smallest_n=True,
    index_label="rank",
)

top_genes = top_N_features_by_rank[SOMATIC_DATASETS][:N].tolist()


In [None]:

# Feature names look like "<gene>_<suffix>": keep the part before the first
# underscore, drop repeats while preserving order, then restrict to genes
# present in the class-specific correlation matrices.
gene_names = (feature.split("_")[0] for feature in top_genes)
top_genes = list(dict.fromkeys(gene_names))
print(len(top_genes), "unique genes in top_genes")
top_genes = [gene for gene in top_genes if gene in common_genes]
print(len(top_genes), "unique genes in top_genes when filtered to common genes")


In [None]:
def plot_corr_matrix(corr_matrix, title='Correlation Matrix', figsize=(10, 8)):
    """
    Plot a correlation matrix as a clustered seaborn heatmap.

    Rows and columns are reordered by average-linkage clustering on Euclidean
    distance; both dendrograms and the x tick labels are hidden.

    Args:
        corr_matrix (pd.DataFrame): Correlation matrix to plot.
        title (str): Title of the plot.
        figsize (tuple): Size of the figure.
    """
    g = sns.clustermap(
        corr_matrix,
        method='average',
        metric='euclidean',
        cmap='vlag',
        center=0,
        figsize=figsize,
    )

    g.ax_row_dendrogram.set_visible(False)
    g.ax_col_dendrogram.set_visible(False)
    g.ax_heatmap.set_xticks([])
    # plt.title() acts on the current axes, which after clustermap() is
    # typically the colorbar rather than the heatmap, so title the figure.
    g.ax_heatmap.figure.suptitle(title)
    plt.show()

# Plot the correlation matrix for the difference
# Subset: rows = top_genes, columns = all genes
subset_corr_diff = corr_diff.loc[top_genes]
heatmap_height = len(top_genes) * 0.5  # scale figure height with the number of rows
plot_corr_matrix(
    subset_corr_diff,
    title='Top Genes vs All Genes (Class 1 - Class 0)',
    figsize=(12, heatmap_height),
)

# subset_corr_class1 = corr_class1.loc[top_genes, :]
# plot_corr_matrix(subset_corr_class1, title='Top Genes vs All Genes (Class 1)', figsize=(12, len(top_genes) * 0.5))

# subset_corr_class0 = corr_class0.loc[top_genes, :]
# plot_corr_matrix(subset_corr_class0, title='Top Genes vs All Genes (Class 0)', figsize=(12, len(top_genes) * 0.5))

In [None]:
# Deciles (0th to 100th percentile in steps of 10) over all entries of subset_corr_diff
flat_corr_diff = subset_corr_diff.to_numpy().ravel()
deciles = np.percentile(flat_corr_diff, np.arange(0, 101, 10))
# Display as a table; note that the "Decile" column holds the percentile values themselves
deciles_df = pd.DataFrame({"Decile": deciles})
deciles_df

### Difference in "AF" between classes for genes assigned high importance by P-NET?
For the same `top_genes` whose gene-gene correlation matrices we just examined, how do the rates of events differ between classes (met vs primary)?
- Are there any genes that don't have a difference in rate between classes? If so, do these show different gene-gene correlation between classes?


In [None]:
import pandas as pd
from scipy.stats import fisher_exact, false_discovery_control

# Step 1: Compute frequencies (proportion of 1s) for each gene
freq_class0 = X_class0.mean(axis=0)
freq_class1 = X_class1.mean(axis=0)

# Step 2: Compute frequency difference
freq_diff = freq_class1 - freq_class0

# Step 3: Perform statistical test (Fisher's exact test) for each gene
n_class1 = len(X_class1)
n_class0 = len(X_class0)
p_values = []
o_ratios = []
r_risks = []
for gene in common_genes:
    # Build contingency table:
    # [[#1s in class1, #0s in class1],
    #  [#1s in class0, #0s in class0]]
    a = X_class1[gene].sum()
    b = n_class1 - a
    c = X_class0[gene].sum()
    d = n_class0 - c
    contingency = [[a, b], [c, d]]

    oratio, p = fisher_exact(contingency, alternative='two-sided')  # or 'greater', 'less'
    o_ratios.append(oratio)
    p_values.append(p)

    # Relative Risk calculation
    risk_class1 = a / (a + b) if (a + b) > 0 else float('nan')
    risk_class0 = c / (c + d) if (c + d) > 0 else float('nan')
    rr = (risk_class1 / risk_class0) if risk_class0 > 0 else float('inf')
    r_risks.append(rr)

# Step 4: Assemble into DataFrame
result_df = pd.DataFrame({
    "gene": common_genes,
    "freq_class0": freq_class0.values,
    "freq_class1": freq_class1.values,
    "freq_diff": freq_diff.values,
    "odds_ratio": o_ratios,
    "relative_risk": r_risks,
    "p_value": p_values
})

# Optional: adjust p-values (e.g., Benjamini-Hochberg FDR)
adjusted_pvals = false_discovery_control(p_values, method='bh')
result_df["adj_p_value"] = adjusted_pvals

# Sort by largest absolute difference. Keep full precision in result_df:
# rounding the stored table to 3 decimals would turn small p-values into 0.0
# for any later use, so round only what is displayed.
result_df = result_df.sort_values(by="freq_diff", key=abs, ascending=False)
result_df.round(3)


In [None]:
# Restrict the per-gene test results to top_genes, ordered by increasing |odds ratio|
is_top_gene = result_df['gene'].isin(top_genes)
result_df_top_genes = (
    result_df[is_top_gene]
    .sort_values(by="odds_ratio", key=abs, ascending=True)
    .round(3)
)
print("Top genes with frequency differences and p-values:")
display(result_df_top_genes)

### Corr matrix plots

In [None]:
# Clustered heatmap of the class 0 gene-gene correlation matrix.
# Assume corr_class0 is a cleaned DataFrame (no NaNs)
g = sns.clustermap(
    corr_class0,
    method='average',     # linkage method: 'average', 'complete', etc.
    metric='euclidean',   # or 'correlation', 'cityblock', etc.
    cmap='vlag',
    center=0,
    figsize=(10, 10),
)
g.ax_heatmap.set_xticks([])
g.ax_heatmap.set_yticks([])
# plt.title() acts on the current axes, which after clustermap() is typically
# the colorbar rather than the heatmap, so title the figure instead.
g.ax_heatmap.figure.suptitle("Clustered Gene-Gene Correlation (Class 0)")
plt.show()

In [None]:
# Clustered heatmap of the class 1 gene-gene correlation matrix.
# Assume corr_class1 is a cleaned DataFrame (no NaNs)
g = sns.clustermap(
    corr_class1,
    method='average',     # linkage method: 'average', 'complete', etc.
    metric='euclidean',   # or 'correlation', 'cityblock', etc.
    cmap='vlag',
    center=0,
    figsize=(10, 10),
)
g.ax_heatmap.set_xticks([])
g.ax_heatmap.set_yticks([])
# plt.title() acts on the current axes, which after clustermap() is typically
# the colorbar rather than the heatmap, so title the figure instead.
g.ax_heatmap.figure.suptitle("Clustered Gene-Gene Correlation (Class 1)")
plt.show()


In [None]:
# Clustered heatmap of the between-class difference in gene-gene correlation.
# Assume corr_diff is a cleaned DataFrame (no NaNs)
g = sns.clustermap(
    corr_diff,
    method='average',     # linkage method: 'average', 'complete', etc.
    metric='euclidean',   # or 'correlation', 'cityblock', etc.
    cmap='vlag',
    center=0,
    figsize=(10, 10),
)
g.ax_row_dendrogram.set_visible(False) #suppress row dendrogram
g.ax_col_dendrogram.set_visible(False) #suppress column dendrogram
g.ax_heatmap.set_xticks([])
g.ax_heatmap.set_yticks([])
# plt.title() acts on the current axes, which after clustermap() is typically
# the colorbar rather than the heatmap, so title the figure instead.
g.ax_heatmap.figure.suptitle("Clustered Gene-Gene Correlation (Class 1 - Class 0)")
plt.show()


In [None]:

from scipy.cluster.hierarchy import fcluster

# Cut the row dendrogram of the previous clustermap (`g`, built on corr_diff)
# into flat clusters. With criterion='maxclust', `t` is the maximum NUMBER of
# clusters, not a distance cut-off (use criterion='distance' for that).
max_clusters = 3
distance_threshold = max_clusters  # legacy name, kept in case other cells read it

cluster_labels = fcluster(g.dendrogram_row.linkage, t=max_clusters, criterion='maxclust')

# Get reordered gene list
reordered_genes = corr_diff.index[g.dendrogram_row.reordered_ind]

# Create cluster color labels (fcluster labels follow the original row order)
cluster_df = pd.DataFrame({
    'Gene': corr_diff.index,
    'Cluster': cluster_labels
})
cluster_df.set_index('Gene', inplace=True)

# Assign a color to each cluster
unique_clusters = cluster_df['Cluster'].unique()
palette = sns.color_palette("hsv", len(unique_clusters))
lut = dict(zip(unique_clusters, palette))
cluster_colors = cluster_df['Cluster'].map(lut)

# Assume corr_diff is a cleaned DataFrame (no NaNs)
g = sns.clustermap(
    corr_diff,
    row_colors=cluster_colors,
    col_colors=cluster_colors,
    method='average',     # linkage method: 'average', 'complete', etc.
    metric='euclidean',   # or 'correlation', 'cityblock', etc.
    cmap='vlag',
    center=0,
    figsize=(10, 10),
)
g.ax_row_dendrogram.set_visible(False) #suppress row dendrogram
g.ax_col_dendrogram.set_visible(False) #suppress column dendrogram
g.ax_heatmap.set_xticks([])
g.ax_heatmap.set_yticks([])
# plt.title() acts on the current axes, which after clustermap() is typically
# the colorbar rather than the heatmap, so title the figure instead.
g.ax_heatmap.figure.suptitle("Clustered Differential Gene-Gene Correlation with Cluster Annotations")
plt.show()



The vast majority of gene-gene correlations do not differ between classes. Only the top 20% or so seem at all relevant. Currently, there are around 285 genes per decile (because of how I filtered the data). If I first applied some dimensionality/noise reduction to each of my class-specific correlation matrices, I might get even cleaner results here.

We see that these genes in cluster 2 are very different between class 1 and class 0.

In [None]:
# Genes assigned to cluster 2
in_cluster_2 = cluster_df["Cluster"] == 2
cluster_df.index[in_cluster_2].tolist()

In [None]:
# One tenth of the number of genes in corr_diff
len(corr_diff) * 0.1

In [None]:
# Use np.triu_indices_from to extract the upper triangle without the diagonal
def get_upper_triangle_values(corr_matrix):
    """Return the strictly-upper-triangle entries of ``corr_matrix`` as a 1-D array.

    Entries come out in row-major order; the diagonal is excluded (``k=1``).
    """
    row_idx, col_idx = np.triu_indices_from(corr_matrix, k=1)
    return corr_matrix.values[row_idx, col_idx]

# Deciles of the pairwise (off-diagonal, each pair counted once) correlation differences
correlations = get_upper_triangle_values(corr_diff)
percentile_levels = np.arange(0, 101, 10)
deciles = np.percentile(correlations, percentile_levels)

# Wrap in a readable format
decile_labels = [f'{i}th' for i in range(0, 101, 10)]
gene_corr_decile_df = pd.DataFrame({'Decile': decile_labels, 'Correlation': deciles}).round(3)
gene_corr_decile_df
