In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import squidpy as sq
from scipy import stats
from scipy.cluster.hierarchy import dendrogram, linkage
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import warnings
warnings.filterwarnings('ignore')

In [None]:
## load nichePCA results
# Read the AnnData file; judging by its name it already carries the
# nichePCA domain labels used below (NOTE(review): confirm with the producer).
adata = sc.read_h5ad('xenium_kidney_with_domains.h5ad')

print(f"Data shape: {adata.shape}")
print(f"Samples: {adata.obs['sample'].unique()}")

# One line per sample: how many cells fall into each domain label.
print("nichePCA domains per sample:")
for sample in adata.obs['sample'].unique():
    in_sample = adata.obs['sample'] == sample
    sample_counts = adata.obs.loc[in_sample, 'nichepca_domains'].value_counts().sort_index()
    print(f"{sample}: {dict(sample_counts)}")

In [None]:
## domain-specific gene expression

def calculate_domain_markers(adata, domain_key, sample=None, min_pct=0.1, logfc_threshold=0.25):
    """Find genes enriched in each domain relative to all remaining cells.

    For every label in ``adata.obs[domain_key]`` the cells carrying that label
    are compared one-vs-rest with the other cells. A gene is kept as a marker
    when all of the following hold:

    * its log2 ratio of mean values (domain / rest) exceeds ``logfc_threshold``,
    * it is detected (value > 0) in more than ``min_pct`` of the domain's cells,
    * its detection rate inside the domain is higher than outside.

    Parameters
    ----------
    adata : AnnData-like
        Object exposing ``X`` (dense or sparse), ``obs`` and ``var``.
    domain_key : str
        Column of ``adata.obs`` holding the domain labels.
    sample : optional
        Value of ``adata.obs['sample']`` to restrict the analysis to.
        ``None`` (the default) uses every cell.
    min_pct : float
        Minimum fraction of domain cells in which the gene must be detected.
    logfc_threshold : float
        Minimum log2 fold change.

    Returns
    -------
    dict
        Maps each domain label to a DataFrame with columns ``gene``,
        ``log_fc``, ``pct_in_domain`` and ``pct_out_domain`` sorted by
        decreasing ``log_fc``; an empty DataFrame when no gene qualifies.
    """
    # Compare against None explicitly so falsy sample ids (0, '') still filter.
    if sample is not None:
        adata_subset = adata[adata.obs['sample'] == sample]
    else:
        # Nothing below writes to the object, so no copy is needed.
        adata_subset = adata

    # Read these once; they do not change between domains.
    expression = adata_subset.X
    domain_labels = adata_subset.obs[domain_key]
    gene_index = adata_subset.var.index
    pseudocount = 1e-8  # keeps the ratio finite when a mean is zero

    domain_markers = {}

    for domain in domain_labels.unique():
        in_domain = np.asarray(domain_labels == domain)

        domain_cells = expression[in_domain]
        other_cells = expression[~in_domain]

        if hasattr(domain_cells, 'toarray'):
            domain_cells = domain_cells.toarray()
            other_cells = other_cells.toarray()

        # With nothing to compare against there can be no markers; bail out
        # before taking the mean of an empty array.
        if other_cells.shape[0] == 0:
            domain_markers[domain] = pd.DataFrame()
            continue

        domain_mean = domain_cells.mean(axis=0)
        other_mean = other_cells.mean(axis=0)

        domain_pct = (domain_cells > 0).mean(axis=0)
        other_pct = (other_cells > 0).mean(axis=0)

        log_fc = np.log2((domain_mean + pseudocount) / (other_mean + pseudocount))

        significant_genes = (
            (log_fc > logfc_threshold) &
            (domain_pct > min_pct) &
            (domain_pct > other_pct)
        )

        if significant_genes.sum() > 0:
            gene_indices = np.where(significant_genes)[0]
            # Strongest fold change first.
            gene_indices = gene_indices[np.argsort(log_fc[gene_indices])[::-1]]

            domain_markers[domain] = pd.DataFrame({
                'gene': gene_index[gene_indices].tolist(),
                'log_fc': log_fc[gene_indices],
                'pct_in_domain': domain_pct[gene_indices],
                'pct_out_domain': other_pct[gene_indices]
            })
        else:
            domain_markers[domain] = pd.DataFrame()

    return domain_markers

# The first two distinct sample labels are treated as control and disease.
# NOTE(review): this relies on the order in which labels appear in
# adata.obs['sample'] — confirm that the first label really is the control.
sample_labels = adata.obs['sample'].unique()
control_sample = sample_labels[0]
disease_sample = sample_labels[1]

# Per-sample, one-vs-rest domain markers.
control_markers, disease_markers = (
    calculate_domain_markers(adata, 'nichepca_domains', sample=label)
    for label in (control_sample, disease_sample)
)

print(f"Domain markers calculated for {control_sample} and {disease_sample}")

In [None]:
def display_top_markers(domain_markers, sample_name, top_n=5):
    """Print the leading marker genes of every domain for one sample.

    Parameters
    ----------
    domain_markers : dict
        Maps a domain label to a DataFrame with the columns ``gene``,
        ``log_fc``, ``pct_in_domain`` and ``pct_out_domain``; a zero-length
        frame means the domain has no markers.
    sample_name : str
        Label shown in the header line.
    top_n : int
        Number of rows taken from the top of each DataFrame.
    """
    print(f"Top {top_n} marker genes per domain - {sample_name}\n")

    for domain, markers_df in domain_markers.items():
        if len(markers_df) == 0:
            print(f"\nDomain {domain}: No significant markers found")
            continue

        print(f"\nDomain {domain}:")
        leading = markers_df.head(top_n)
        rows = zip(leading['gene'], leading['log_fc'],
                   leading['pct_in_domain'], leading['pct_out_domain'])
        for gene, log_fc, pct_in, pct_out in rows:
            print(f"  {gene}: log_FC={log_fc:.2f}, pct_in={pct_in:.2f}, pct_out={pct_out:.2f}")

# Show the marker tables for control first, then disease.
for _markers, _sample_name in ((control_markers, control_sample),
                               (disease_markers, disease_sample)):
    display_top_markers(_markers, _sample_name)

In [None]:
## domain signatures
def create_domain_signature_matrix(adata_sample, domain_key, domain_markers, top_genes=10):
    """Build a domains x marker-genes matrix of mean expression values.

    The first ``top_genes`` genes of every non-empty marker table are pooled
    (duplicates removed, first-seen order kept). For each domain the mean of
    ``adata_sample.X`` over that domain's cells is taken for those genes.

    Parameters
    ----------
    adata_sample : AnnData-like
        Object exposing ``X`` (dense or sparse), ``obs`` and ``var``.
    domain_key : str
        Column of ``adata_sample.obs`` holding the domain labels.
    domain_markers : dict
        Maps a domain label to a DataFrame with a ``gene`` column.
    top_genes : int
        Number of genes taken from the top of each marker table.

    Returns
    -------
    tuple
        ``(domain_expression, marker_genes, domains)``. Column ``j`` of
        ``domain_expression`` corresponds to ``marker_genes[j]`` and row ``i``
        to ``domains[i]`` (the sorted domain labels). When no usable marker
        gene exists, ``(None, None, None)`` is returned so callers can always
        unpack three values.
    """
    domains = sorted(adata_sample.obs[domain_key].unique())

    # A dict gives de-duplication with a deterministic, first-seen order.
    pooled_genes = {}
    for markers_df in domain_markers.values():
        if len(markers_df) > 0:
            pooled_genes.update(dict.fromkeys(markers_df.head(top_genes)['gene'].tolist()))

    # Column position of each gene name (first occurrence wins).
    first_position = {}
    for position, name in enumerate(adata_sample.var.index):
        first_position.setdefault(name, position)

    # Ignore markers that are absent from this object instead of failing later.
    marker_genes = [gene for gene in pooled_genes if gene in first_position]

    if not marker_genes:
        print("No marker genes found")
        return None, None, None

    # Select columns in exactly the order of `marker_genes`, so labels derived
    # from that list line up with the matrix columns.
    gene_positions = np.asarray([first_position[gene] for gene in marker_genes])
    marker_expression = adata_sample.X[:, gene_positions]
    domain_labels = adata_sample.obs[domain_key]

    domain_expression = np.zeros((len(domains), len(marker_genes)))

    for i, domain in enumerate(domains):
        in_domain = np.asarray(domain_labels == domain)
        if not in_domain.any():
            continue

        domain_block = marker_expression[in_domain]
        if hasattr(domain_block, 'toarray'):
            domain_block = domain_block.toarray()

        domain_expression[i, :] = np.asarray(domain_block.mean(axis=0)).ravel()

    return domain_expression, marker_genes, domains

# One heatmap per sample: rows are domains, columns are pooled marker genes.
fig, axes = plt.subplots(1, 2, figsize=(20, 8))

for i, (sample, markers) in enumerate([(control_sample, control_markers),
                                      (disease_sample, disease_markers)]):
    sample_mask = adata.obs['sample'] == sample
    sample_data = adata[sample_mask]

    signature = create_domain_signature_matrix(
        sample_data, 'nichepca_domains', markers, top_genes=8
    )

    # Inspect the first element before unpacking: when no markers are
    # available the helper signals it with None in that slot, and unpacking
    # three names straight away would raise if fewer values came back.
    expr_matrix = signature[0]
    if expr_matrix is None:
        continue

    _, marker_genes, domains = signature

    sns.heatmap(expr_matrix,
               xticklabels=marker_genes,
               yticklabels=[f'Domain {d}' for d in domains],
               cmap='RdYlBu_r', center=0, ax=axes[i])
    axes[i].set_title(f'Domain Gene Signatures - {sample}')
    axes[i].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

In [None]:
## domain comparisons between conditions

def compare_domain_profiles(adata, domain_key):
    """Match each domain of the first sample to its most similar domain in the second.

    The first two values of ``adata.obs['sample'].unique()`` are taken as the
    two samples to compare. For every domain, the per-gene mean of ``X`` over
    its cells is used as the domain's profile, and profiles are compared with
    the Pearson correlation from ``np.corrcoef``.

    Returns
    -------
    dict
        Maps each first-sample domain to a dict with ``best_match`` (the
        second-sample domain with the highest correlation, or ``None`` if no
        correlation exceeds -1), ``best_correlation`` and
        ``all_correlations`` (second-sample domain -> correlation).
    """
    sample_ids = adata.obs['sample'].unique()
    reference = adata[adata.obs['sample'] == sample_ids[0]]
    target = adata[adata.obs['sample'] == sample_ids[1]]

    def _mean_profile(data, domain):
        # Per-gene mean over the cells of one domain, flattened to 1-D when
        # the mean comes back as a matrix-like object exposing `.A1`.
        profile = data.X[data.obs[domain_key] == domain].mean(axis=0)
        return profile.A1 if hasattr(profile, 'A1') else profile

    comparison_results = {}

    for control_domain in reference.obs[domain_key].unique():
        control_profile = _mean_profile(reference, control_domain)

        correlations = {}
        best_match, best_correlation = None, -1

        for disease_domain in target.obs[domain_key].unique():
            disease_profile = _mean_profile(target, disease_domain)
            score = np.corrcoef(control_profile, disease_profile)[0, 1]
            correlations[disease_domain] = score

            # Strict '>' keeps the earliest domain on ties and skips NaN scores.
            if score > best_correlation:
                best_match, best_correlation = disease_domain, score

        comparison_results[control_domain] = {
            'best_match': best_match,
            'best_correlation': best_correlation,
            'all_correlations': correlations
        }

    return comparison_results

domain_comparisons = compare_domain_profiles(adata, 'nichepca_domains')

print("Domain matching between control and disease:")
for control_domain, match_info in domain_comparisons.items():
    print(f"Control Domain {control_domain} <-> Disease Domain {match_info['best_match']} "
          f"(correlation: {match_info['best_correlation']:.3f})")

In [None]:
## classical DE
def find_disease_vs_control_de(adata, domain_key, control_domain, disease_domain, 
                              control_sample, disease_sample):
    """Compare one disease domain with one control domain, gene by gene.

    Cells are selected by sample and domain label; for every gene the log2
    ratio of the group means (disease / control, with a 1e-8 pseudocount) and
    a two-sample t-test p-value (``scipy.stats.ttest_ind`` defaults) are
    computed.

    Parameters
    ----------
    adata : AnnData-like
        Object exposing ``X`` (dense or sparse), ``obs`` and ``var``.
    domain_key : str
        Column of ``adata.obs`` holding the domain labels.
    control_domain, disease_domain
        Domain label to use within the control / disease sample.
    control_sample, disease_sample
        Values of ``adata.obs['sample']`` identifying the two samples.

    Returns
    -------
    pandas.DataFrame
        One row per gene with columns ``gene``, ``log_fc``, ``p_value``,
        ``control_mean``, ``disease_mean`` and ``abs_log_fc``, sorted by
        decreasing ``abs_log_fc``. Genes for which the test is undefined
        (e.g. no variance, or a group with fewer than two cells) get
        ``p_value`` 1.0. The p-values are not corrected for multiple testing.
    """
    obs = adata.obs
    control_mask = np.asarray((obs['sample'] == control_sample) & (obs[domain_key] == control_domain))
    disease_mask = np.asarray((obs['sample'] == disease_sample) & (obs[domain_key] == disease_domain))

    control_cells = adata.X[control_mask]
    disease_cells = adata.X[disease_mask]

    if hasattr(control_cells, 'toarray'):
        control_cells = control_cells.toarray()
        disease_cells = disease_cells.toarray()

    control_mean = control_cells.mean(axis=0)
    disease_mean = disease_cells.mean(axis=0)

    log_fc = np.log2((disease_mean + 1e-8) / (control_mean + 1e-8))

    if min(control_cells.shape[0], disease_cells.shape[0]) < 2:
        # No variance estimate is possible; report "no evidence" for all genes.
        p_values = np.ones(control_cells.shape[1])
    else:
        # One vectorised call tests every gene (column) at once.
        _, p_values = stats.ttest_ind(disease_cells, control_cells, axis=0)
        p_values = np.asarray(p_values, dtype=float)
        # ttest_ind yields NaN rather than raising when the statistic is
        # undefined (e.g. a gene that is constant in both groups).
        p_values = np.where(np.isnan(p_values), 1.0, p_values)

    de_results = pd.DataFrame({
        'gene': adata.var.index,
        'log_fc': log_fc,
        'p_value': p_values,
        'control_mean': control_mean,
        'disease_mean': disease_mean
    })

    de_results['abs_log_fc'] = np.abs(de_results['log_fc'])
    de_results = de_results.sort_values('abs_log_fc', ascending=False)

    return de_results

# Run the DE comparison for every control domain against its best-matching
# disease domain; results are keyed by the control domain.
de_analyses = {}

for control_domain, match_info in domain_comparisons.items():
    disease_domain = match_info['best_match']
    de_analyses[f"Domain_{control_domain}"] = find_disease_vs_control_de(
        adata, 'nichepca_domains', control_domain, disease_domain,
        control_sample, disease_sample
    )

print("Differential expression analysis completed for matched domains")

# The tables are sorted by |log_fc|, so head(5) of each filtered view lists
# the five largest changes in that direction.
for domain_pair, de_df in de_analyses.items():
    directional_hits = (
        ("upregulated", de_df[de_df['log_fc'] > 0.5].head(5)),
        ("downregulated", de_df[de_df['log_fc'] < -0.5].head(5)),
    )
    for direction, hits in directional_hits:
        print(f"\n{domain_pair} - Top {direction} genes in disease:")
        for _, gene_info in hits.iterrows():
            print(f"  {gene_info['gene']}: log_FC={gene_info['log_fc']:.2f}, p={gene_info['p_value']:.3f}")

In [None]:
## volcanos

# Volcano plots for at most six matched domain pairs on a 2 x 3 grid.
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for i, (domain_pair, de_df) in enumerate(list(de_analyses.items())[:6]):
    ax = axes[i]

    x = de_df['log_fc']
    # The small offset keeps log10 finite for p-values of exactly zero.
    y = -np.log10(de_df['p_value'] + 1e-10)

    ax.scatter(x, y, s=8, alpha=0.6, color='gray')

    passes_p = de_df['p_value'] < 0.05
    significant_up = (x > 0.5) & passes_p
    significant_down = (x < -0.5) & passes_p

    # Highlight genes beyond both the fold-change and the p-value cut-off.
    for selection, colour, label in ((significant_up, 'red', 'Up'),
                                     (significant_down, 'blue', 'Down')):
        ax.scatter(x[selection], y[selection], s=8, alpha=0.8, color=colour, label=label)

    # Dashed guides mark the thresholds used above.
    ax.axhline(-np.log10(0.05), color='black', linestyle='--', alpha=0.5)
    for cutoff in (0.5, -0.5):
        ax.axvline(cutoff, color='black', linestyle='--', alpha=0.5)

    ax.set_xlabel('log2(Fold Change)')
    ax.set_ylabel('-log10(p-value)')
    ax.set_title(f'{domain_pair}')
    ax.legend()

# Hide the panels that received no data (the range is empty for six or more pairs).
for j in range(len(de_analyses), 6):
    axes[j].set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
## Domain spatial organizations

def analyze_domain_spatial_organization(adata_sample, domain_key):
    """Summarise where each domain sits in the coordinates of ``obsm['spatial']``.

    Columns 0 and 1 of the coordinate array are used as x and y.

    Returns
    -------
    tuple
        ``(domain_centroids, domain_boundaries, inter_domain_distances)``:

        * ``domain_centroids`` - domain -> mean coordinate of its cells,
        * ``domain_boundaries`` - domain -> dict with ``x_range``, ``y_range``
          and ``area`` of the axis-aligned bounding box of its cells,
        * ``inter_domain_distances`` - ``"<d1>-<d2>"`` -> Euclidean distance
          between the two centroids, one entry per unordered pair.
    """
    coords = adata_sample.obsm['spatial']
    domains = adata_sample.obs[domain_key]
    domain_list = list(domains.unique())

    domain_centroids = {}
    domain_boundaries = {}

    for domain in domain_list:
        member_coords = coords[domains == domain]
        domain_centroids[domain] = member_coords.mean(axis=0)

        x_range = member_coords[:, 0].max() - member_coords[:, 0].min()
        y_range = member_coords[:, 1].max() - member_coords[:, 1].min()
        domain_boundaries[domain] = {
            'x_range': x_range,
            'y_range': y_range,
            'area': x_range * y_range
        }

    inter_domain_distances = {}

    # Visit every unordered pair once by pairing each domain with later ones.
    for i, domain1 in enumerate(domain_list):
        for domain2 in domain_list[i + 1:]:
            offset = domain_centroids[domain1] - domain_centroids[domain2]
            inter_domain_distances[f"{domain1}-{domain2}"] = np.sqrt(np.sum(offset ** 2))

    return domain_centroids, domain_boundaries, inter_domain_distances

# One scatter panel per sample, coloured by domain, with centroids marked.
fig, axes = plt.subplots(1, 2, figsize=(16, 8))

# Keep each sample's analysis so the text report below can reuse it instead
# of running the same computation a second time.
spatial_results = {}

for i, sample in enumerate(adata.obs['sample'].unique()):
    sample_mask = adata.obs['sample'] == sample
    sample_data = adata[sample_mask]
    
    centroids, boundaries, distances = analyze_domain_spatial_organization(sample_data, 'nichepca_domains')
    spatial_results[sample] = (centroids, boundaries, distances)
    
    coords = sample_data.obsm['spatial']
    domains = sample_data.obs['nichepca_domains']
    
    unique_domains = sorted(domains.unique())
    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_domains)))
    
    for j, domain in enumerate(unique_domains):
        domain_mask = domains == domain
        axes[i].scatter(coords[domain_mask, 0], coords[domain_mask, 1], 
                       c=[colors[j]], s=1, alpha=0.6, label=f'Domain {domain}')
        
        # Mark and label the domain centroid on top of its cells.
        centroid = centroids[domain]
        axes[i].scatter(centroid[0], centroid[1], c='black', s=100, marker='x')
        axes[i].annotate(f'D{domain}', (centroid[0], centroid[1]), 
                        xytext=(5, 5), textcoords='offset points', fontsize=8, fontweight='bold')
    
    axes[i].set_title(f'Domain Spatial Organization - {sample}')
    axes[i].set_xlabel('X coordinate')
    axes[i].set_ylabel('Y coordinate')
    axes[i].axis('equal')

plt.tight_layout()
plt.show()

# Text report; 'area' is the bounding-box area computed by the helper.
print("\nDomain spatial characteristics:")
for sample, (centroids, boundaries, distances) in spatial_results.items():
    print(f"\n{sample}:")
    print("Domain areas:")
    for domain, boundary_info in boundaries.items():
        print(f"  Domain {domain}: {boundary_info['area']:.0f} spatial units")

In [None]:
## Disease-specific changes

def perform_simple_pathway_analysis(marker_genes, all_genes):
    """Score the overlap between a marker list and a few hand-picked gene sets.

    Each built-in gene set is first restricted to the genes present in
    ``all_genes``; gene sets with no gene left are omitted from the result.
    The score is the fraction of the remaining gene-set members that occur in
    ``marker_genes``. This is a plain overlap ratio, not a statistical test.

    Parameters
    ----------
    marker_genes : iterable of str
        Genes to look up in each gene set.
    all_genes : iterable of str
        Genes available in the dataset.

    Returns
    -------
    dict
        Maps a gene-set name to a dict with ``overlap`` (count),
        ``pathway_size`` (gene-set members present in ``all_genes``),
        ``enrichment_score`` (``overlap / pathway_size``) and
        ``overlapping_genes`` (listed in gene-set order).
    """
    kidney_pathways = {
        'Immune_Response': ['CD68', 'CD3E', 'CD3D', 'CD8A', 'CD4', 'PTPRC', 'LYZ', 'CSF1R'],
        'Epithelial_Function': ['EPCAM', 'CDH1', 'KRT8', 'KRT18', 'KRT19', 'CLDN1', 'OCLN', 'TJP1'],
        'Fibrosis': ['COL1A1', 'COL1A2', 'COL3A1', 'FN1', 'ACTA2', 'TGFB1', 'PDGFRB'],
        'Angiogenesis': ['VEGFA', 'VEGFB', 'VEGFC', 'FLT1', 'KDR', 'ANGPT1', 'ANGPT2', 'PDGFRB'],
        'Inflammation': ['IL1B', 'IL6', 'TNF', 'CXCL1', 'CXCL2', 'CCL2', 'CCL3', 'PTGS2'],
        'Complement': ['C1QA', 'C1QB', 'C1QC', 'C3', 'C4A', 'C4B', 'CFB', 'CFD']
    }

    # Sets make each membership test constant-time instead of a list scan.
    gene_universe = set(all_genes)
    marker_set = set(marker_genes)

    enrichment_results = {}

    for pathway, pathway_genes in kidney_pathways.items():
        available_pathway_genes = [g for g in pathway_genes if g in gene_universe]
        if not available_pathway_genes:
            continue

        # Filtering the ordered list (rather than intersecting sets) keeps the
        # reported gene order identical from run to run.
        overlapping_genes = [g for g in available_pathway_genes if g in marker_set]
        overlap = len(overlapping_genes)

        enrichment_results[pathway] = {
            'overlap': overlap,
            'pathway_size': len(available_pathway_genes),
            'enrichment_score': overlap / len(available_pathway_genes),
            'overlapping_genes': overlapping_genes
        }

    return enrichment_results

all_genes = adata.var.index.tolist()
pathway_enrichments = {}

# Score the top 20 markers of every domain that has any, per sample.
for sample in (control_sample, disease_sample):
    markers = control_markers if sample == control_sample else disease_markers
    pathway_enrichments[sample] = {
        domain: perform_simple_pathway_analysis(markers_df.head(20)['gene'].tolist(), all_genes)
        for domain, markers_df in markers.items()
        if len(markers_df) > 0
    }

print("Pathway enrichment analysis:")
for sample, sample_results in pathway_enrichments.items():
    print(f"\n=== {sample} ===")
    for domain, domain_enrichments in sample_results.items():
        print(f"\nDomain {domain}:")
        for pathway, enrichment_info in domain_enrichments.items():
            # Only gene sets with at least one hit are reported.
            if enrichment_info['overlap'] <= 0:
                continue
            print(f"  {pathway}: {enrichment_info['overlap']}/{enrichment_info['pathway_size']} genes "
                  f"(score: {enrichment_info['enrichment_score']:.2f})")
            if enrichment_info['overlapping_genes']:
                print(f"    Genes: {', '.join(enrichment_info['overlapping_genes'])}")

In [None]:
def compare_pathway_enrichments(control_enrichments, disease_enrichments, domain_comparisons):
    """Compute score differences between matched control and disease domains.

    Parameters
    ----------
    control_enrichments, disease_enrichments : dict
        Domain -> {pathway -> dict containing ``enrichment_score``}.
    domain_comparisons : dict
        Control domain -> dict whose ``best_match`` entry names the disease
        domain to compare with.

    Returns
    -------
    dict
        ``"Domain_<control domain>"`` -> {pathway -> dict with
        ``control_score``, ``disease_score`` and ``change`` (disease minus
        control)}. A pathway missing on one side counts as score 0 there.
        Pairs where either domain has no enrichment entry are left out.
    """
    pathway_changes = {}

    for control_domain, match_info in domain_comparisons.items():
        disease_domain = match_info['best_match']

        # Both sides need an enrichment entry, otherwise the pair is skipped.
        if control_domain not in control_enrichments or disease_domain not in disease_enrichments:
            continue

        control_pathways = control_enrichments[control_domain]
        disease_pathways = disease_enrichments[disease_domain]

        domain_changes = {}
        for pathway in set(control_pathways) | set(disease_pathways):
            control_score = control_pathways.get(pathway, {}).get('enrichment_score', 0)
            disease_score = disease_pathways.get(pathway, {}).get('enrichment_score', 0)
            domain_changes[pathway] = {
                'control_score': control_score,
                'disease_score': disease_score,
                'change': disease_score - control_score
            }

        pathway_changes[f"Domain_{control_domain}"] = domain_changes

    return pathway_changes

pathway_changes = compare_pathway_enrichments(
    pathway_enrichments[control_sample],
    pathway_enrichments[disease_sample],
    domain_comparisons
)

print("Pathway enrichment changes in disease:")
for domain_pair, changes in pathway_changes.items():
    print(f"\n=== {domain_pair} ===")

    # Keep score shifts larger than 0.1 in magnitude, biggest first.
    significant_changes = sorted(
        ((pathway, info) for pathway, info in changes.items() if abs(info['change']) > 0.1),
        key=lambda item: abs(item[1]['change']),
        reverse=True,
    )

    if not significant_changes:
        print("  No significant pathway changes detected")
        continue

    for pathway, change_info in significant_changes:
        direction = "↑" if change_info['change'] > 0 else "↓"
        print(f"  {pathway}: {change_info['control_score']:.2f} → {change_info['disease_score']:.2f} "
              f"({direction}{abs(change_info['change']):.2f})")

In [None]:
## Summary

def generate_summary(adata, domain_key, pathway_changes, de_analyses):
    """Collect headline numbers of the analysis into a single dictionary.

    The first two values of ``adata.obs['sample'].unique()`` are reported as
    control and disease.

    Returns
    -------
    dict
        * ``sample_info`` - the two sample labels and how many distinct
          domain labels each sample has,
        * ``key_findings`` - one entry per domain pair that has at least one
          pathway with ``|change| > 0.15``, listing those pathways as
          ``(pathway, info)`` tuples,
        * ``de_genes_summary`` - per domain pair, the number of genes with
          ``p_value < 0.05`` and ``log_fc`` above 0.5 (``upregulated``) or
          below -0.5 (``downregulated``).
    """
    samples = adata.obs['sample'].unique()
    control, disease = samples[0], samples[1]

    def _count_domains(sample_label):
        # Number of distinct domain labels among the cells of one sample.
        return len(adata[adata.obs['sample'] == sample_label].obs[domain_key].unique())

    sample_info = {
        'control': control,
        'disease': disease,
        'control_domains': _count_domains(control),
        'disease_domains': _count_domains(disease)
    }

    key_findings = []
    for domain_pair, changes in pathway_changes.items():
        significant_pathways = [(p, info) for p, info in changes.items()
                                if abs(info['change']) > 0.15]
        if significant_pathways:
            key_findings.append({
                'domain': domain_pair,
                'significant_pathways': significant_pathways
            })

    de_genes_summary = {}
    for domain_pair, de_df in de_analyses.items():
        passes_p = de_df['p_value'] < 0.05
        de_genes_summary[domain_pair] = {
            'upregulated': len(de_df[(de_df['log_fc'] > 0.5) & passes_p]),
            'downregulated': len(de_df[(de_df['log_fc'] < -0.5) & passes_p])
        }

    return {
        'sample_info': sample_info,
        'key_findings': key_findings,
        'de_genes_summary': de_genes_summary
    }

# Call the function defined in this cell; the name previously used here,
# generate_interpretation_summary, is not defined anywhere in the notebook.
summary = generate_summary(
    adata, 'nichepca_domains', pathway_changes, de_analyses
)