In [None]:
# correlate_trs_with_genes.py
# Purpose:
# 1. Load TRS scores and original gene expression data
# 2. Compute Pearson correlation between each functional gene and TRS
# 3. Identify and visualize top genes correlated with TRS
# 4. Save correlation results for downstream pathway analysis

import os
import pandas as pd
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Start-of-run banner so the console/notebook log shows where this script's output begins.
print("--- Script start: correlate TRS with gene expression ---")

In [None]:
# --- Step 0: Paths and parameters ---
# Every input and output path hangs off one project root; results are written into the
# same folder the upstream scripts wrote their outputs to.
print("\n--- Step 0: Configure paths and parameters ---")
BASE_DIR = r"D:\结直肠癌肝转移Biomarker 诊断\新的策略\Autoencoder"
INPUT_DIR = os.path.join(BASE_DIR, "跨平台模型", "multi_omics_results")  # output folder of previous scripts
OUTPUT_DIR = INPUT_DIR

# Input files: TRS scores sit with the model results; expression matrix and gene list at the root.
TRS_FILE, EXPRESSION_FILE, FUNC_GENES_FILE = (
    os.path.join(INPUT_DIR, "transcriptomic_risk_signatures.csv"),
    os.path.join(BASE_DIR, "expression_data_combat_corrected.csv"),
    os.path.join(BASE_DIR, "functional_genes_620.txt"),
)

# Output files (correlation table and its bar-chart rendering).
CORRELATION_OUTPUT_FILE, CORRELATION_PLOT_FILE = (
    os.path.join(OUTPUT_DIR, file_name)
    for file_name in ("trs_gene_correlations.csv", "trs_gene_correlations_plot.png")
)

print(f"TRS file: {TRS_FILE}")
print(f"Expression data file: {EXPRESSION_FILE}")
print(f"Output correlation file: {CORRELATION_OUTPUT_FILE}")

In [None]:
# --- Step 1: Load data ---
# Loads the expression matrix and the functional gene list.
# The TRS table is deliberately NOT read here: Step 2 always loads trs_df itself (preferring
# the corrected TRS file), so requiring the original TRS file at this point only made the
# script fail when the corrected file exists but the original one does not.
print("\n--- Step 1: Load data ---")
try:
    expression_data = pd.read_csv(EXPRESSION_FILE, index_col=0)
    # 'utf-8-sig' strips a byte-order mark (common for files saved on Windows) that would
    # otherwise be glued onto the first gene symbol and make it unmatchable.
    with open(FUNC_GENES_FILE, 'r', encoding='utf-8-sig') as f:
        # One symbol per line: strip whitespace, skip blank lines, and drop repeated symbols
        # while keeping the file's order (dict preserves insertion order).
        functional_genes = list(dict.fromkeys(line.strip() for line in f if line.strip()))
    print("All input files loaded successfully.")
except FileNotFoundError as e:
    print(f"ERROR: File not found - {e}")
    print("Please check input file paths.")
    raise

In [None]:
# --- Step 2: Align and prepare data ---
print("\n--- Step 2: Align TRS and expression data ---")
# Prefer the corrected TRS predictions when present; otherwise fall back to the original TRS table.
TRS_CORRECTED_FILE = os.path.join(INPUT_DIR, "trs_prediction_results_corrected.csv")
if not os.path.exists(TRS_CORRECTED_FILE):
    # Fallback: original TRS table, indexed by its first column.
    trs_df = pd.read_csv(TRS_FILE, index_col=0)
    print(f"Warning: corrected TRS file not found; loaded original TRS: {TRS_FILE}")
else:
    trs_results = pd.read_csv(TRS_CORRECTED_FILE)
    # Keep only TRS_Score keyed by Sample_ID, renamed to the 'TRS_1' column name read later on.
    trs_df = (
        trs_results
        .set_index('Sample_ID')
        .loc[:, ['TRS_Score']]
        .rename(columns={'TRS_Score': 'TRS_1'})
    )
    print(f"Loaded corrected TRS file: {TRS_CORRECTED_FILE}")

In [None]:
# Ensure sample intersection: keep only samples present in both tables so that TRS scores
# and expression rows pair up one-to-one (both are selected with the same label order).
common_samples = trs_df.index.intersection(expression_data.index)
if len(common_samples) == 0:
    # Without shared samples every correlation in Step 3 would silently come out as NaN.
    raise ValueError(
        "No sample IDs are shared between the TRS table and the expression data; "
        "check that both files index samples by the same identifiers."
    )
trs_aligned = trs_df.loc[common_samples]
expression_aligned = expression_data.loc[common_samples]

# Filter functional genes present in expression data. dict.fromkeys removes repeated symbols
# (keeping first-seen order) so each gene selects exactly one column below.
available_genes = list(dict.fromkeys(
    gene for gene in functional_genes if gene in expression_aligned.columns
))

# 'TRS_1' is the corrected TRS when Step 2 found the corrected file, otherwise the original
# file's TRS_1 column. The messages below assume higher TRS = higher metastasis risk;
# NOTE(review): confirm that orientation also holds for the fallback file.
trs_scores = trs_aligned['TRS_1']
expression_subset = expression_aligned[available_genes]

print(f"Data aligned: {len(common_samples)} samples will be analyzed.")
print(f"Analyzing {len(available_genes)} functional genes.")
print(f"TRS score range: {trs_scores.min():.4f} ~ {trs_scores.max():.4f}")
print("TRS semantic: higher values indicate higher metastasis risk.")

In [None]:
# --- Step 3: Compute correlations ---
print("\n--- Step 3: Compute Pearson correlation between each gene and TRS ---")


def _benjamini_hochberg(p_values):
    """Return Benjamini-Hochberg FDR-adjusted p-values in the input order.

    NaN p-values (genes whose correlation could not be computed) stay NaN and are
    excluded from the number of tests ``m``.
    """
    p = np.asarray(p_values, dtype=float)
    adjusted = np.full(p.shape, np.nan)
    valid = ~np.isnan(p)
    m = int(valid.sum())
    if m == 0:
        return adjusted
    p_valid = p[valid]
    order = np.argsort(p_valid)
    # Step-up rule: q_(k) = min over j >= k of p_(j) * m / j, capped at 1.
    scaled = p_valid[order] * m / np.arange(1, m + 1)
    scaled = np.minimum.accumulate(scaled[::-1])[::-1]
    q_valid = np.empty(m)
    q_valid[order] = np.minimum(scaled, 1.0)
    adjusted[valid] = q_valid
    return adjusted


# A repeated symbol in the gene list would make `expression_subset[gene]` a multi-column
# frame and produce duplicate result rows, so correlate each distinct column exactly once.
unique_expression = expression_subset.loc[:, ~expression_subset.columns.duplicated()]

correlation_results = []
failed_genes = []
for gene in unique_expression.columns:
    gene_expression = unique_expression[gene].astype(float)
    try:
        corr, p_value = pearsonr(gene_expression, trs_scores)
    except Exception as exc:  # best-effort scan: record NaN for this gene, report it below
        corr, p_value = (np.nan, np.nan)
        failed_genes.append((gene, exc))
    correlation_results.append({
        'Gene': gene,
        'Correlation': corr,
        'P_value': p_value
    })

# Explicit columns keep the frame well-formed (and the next lines working) even with no genes.
correlation_df = pd.DataFrame(correlation_results, columns=['Gene', 'Correlation', 'P_value'])
correlation_df['Abs_Correlation'] = correlation_df['Correlation'].abs()
# Many genes are tested at once, so raw p-values overstate significance; append BH q-values
# as a new last column (existing column positions are unchanged).
correlation_df['FDR'] = _benjamini_hochberg(correlation_df['P_value'])
correlation_df = correlation_df.sort_values(by='Abs_Correlation', ascending=False)

if failed_genes:
    print(f"Warning: correlation could not be computed for {len(failed_genes)} gene(s); recorded as NaN.")
    for gene, exc in failed_genes[:10]:
        print(f"  - {gene}: {type(exc).__name__}: {exc}")
print("Correlation calculation completed.")

In [None]:
# --- Step 4: Save results ---
print("\n--- Step 4: Save correlation results ---")
os.makedirs(OUTPUT_DIR, exist_ok=True)  # no-op when the folder already exists
correlation_df.to_csv(CORRELATION_OUTPUT_FILE, index=False)
print(f"Correlation results saved to: {CORRELATION_OUTPUT_FILE}")

# --- Step 5: Visualize top correlated genes ---
print("\n--- Step 5: Visualize top correlated genes ---")
N_TOP_POSITIVE = 30  # strongest positive (risk-promoting) correlations to plot
N_TOP_NEGATIVE = 25  # strongest negative (protective) correlations to plot

# Draw each side only from genes of its own sign so the two selections cannot overlap.
# Previously both were taken from the same sorted frame: when fewer than N genes had a given
# sign, the "top" list spilled into the other sign, the same gene was selected twice, and it
# was plotted twice. NaN correlations fail both comparisons and are left out of the plot.
positive_genes = correlation_df[correlation_df['Correlation'] > 0]
negative_genes = correlation_df[correlation_df['Correlation'] < 0]
top_pos_corr = positive_genes.nlargest(N_TOP_POSITIVE, 'Correlation')
top_neg_corr = negative_genes.nsmallest(N_TOP_NEGATIVE, 'Correlation')
top_genes_vis = (
    pd.concat([top_pos_corr, top_neg_corr])
    .drop_duplicates(subset='Gene')  # one bar per gene symbol
    .sort_values('Correlation')
)

plt.style.use('seaborn-v0_8-whitegrid')
fig, ax = plt.subplots(figsize=(12, 8))

In [None]:
# Red = risk-promoting (positive r), Blue = protective (non-positive r)
colors = ['#c23616' if c > 0 else '#192a56' for c in top_genes_vis['Correlation']]
# Plot against explicit row positions 0..n-1 instead of letting matplotlib treat gene names
# as categories: a categorical axis collapses repeated labels onto one row, which breaks the
# value labels that are placed at y = row index afterwards.
y_positions = np.arange(len(top_genes_vis))
ax.barh(y_positions, top_genes_vis['Correlation'].to_numpy(), color=colors)
ax.set_yticks(y_positions)
ax.set_yticklabels(list(top_genes_vis['Gene']))

ax.set_xlabel('Pearson Correlation with TRS Score (Metastasis Risk)', fontsize=12, fontweight='bold')
ax.set_ylabel('Gene', fontsize=12, fontweight='bold')
ax.set_title('Top Genes Most Correlated with TRS (Corrected)\nRed = Risk-promoting genes; Blue = Protective genes\nHigher TRS = Higher Metastasis Risk', fontsize=15, fontweight='bold', pad=20)
ax.tick_params(axis='y', labelsize=11)
ax.grid(axis='x', linestyle='--', alpha=0.6)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

In [None]:
# Annotate each bar with its r value: just right of positive bars, just left of the others.
for row_idx, r_value in enumerate(top_genes_vis['Correlation']):
    if r_value > 0:
        x_text, alignment = r_value + 0.01, 'left'
    else:
        x_text, alignment = r_value - 0.01, 'right'
    ax.text(x_text, row_idx, f'{r_value:.3f}',
            ha=alignment, va='center', fontweight='medium', fontsize=10)

# Finalize layout, write the PNG at 300 dpi, then display it.
plt.tight_layout()
plt.savefig(CORRELATION_PLOT_FILE, dpi=300, bbox_inches='tight')
print(f"Correlation plot saved to: {CORRELATION_PLOT_FILE}")
plt.show()

In [None]:
# --- Step 6: Query specific gene information (example: ACMSD) ---
print("\n--- Step 6: Query specific gene info (example: 'ACMSD') ---")
gene_to_find = "ACMSD"

if 'correlation_df' in locals() and isinstance(correlation_df, pd.DataFrame):
    correlation_df_ranked = correlation_df.sort_values(by='Abs_Correlation', ascending=False).reset_index(drop=True)
    correlation_df_ranked['Rank'] = correlation_df_ranked.index + 1

    gene_info = correlation_df_ranked[correlation_df_ranked['Gene'] == gene_to_find]

    if not gene_info.empty:
        gene_stats = gene_info.iloc[0]
        print(f"\nAnalysis result for gene '{gene_to_find}' (based on corrected TRS):")
        print("-" * 40)
        print(f"  - Correlation with TRS: {gene_stats['Correlation']:.6f}")
        print(f"  - Importance rank: {int(gene_stats['Rank'])} / {len(correlation_df_ranked)}")
        print(f"  - P-value: {gene_stats['P_value']:.4e}")
        if gene_stats['Correlation'] > 0:
            print(f"  - Interpretation: '{gene_to_find}' is associated with increased metastasis risk (risk-promoting).")
            print("    Higher expression -> higher TRS -> increased liver metastasis risk.")
        else:
            print(f"  - Interpretation: '{gene_to_find}' is associated with protection against metastasis.")
            print("    Higher expression -> lower TRS -> decreased liver metastasis risk.")

        # Significance level
        pv = gene_stats['P_value']
        if pd.isna(pv):
            sig_level = "n.s."
        elif pv < 1e-3:
            sig_level = "p < 0.001 (*** )"
        elif pv < 1e-2:
            sig_level = "p < 0.01 (**) "
        elif pv < 0.05:
            sig_level = "p < 0.05 (*)"
        else:
            sig_level = "n.s."
        print(f"  - Statistical significance: {sig_level}")
    else:
        print(f"Gene '{gene_to_find}' was not found in the functional gene list. Please check the gene symbol.")
else:
    print("Correlation data not found. Please run the analysis steps above first.")            