# In [None]:
"""
PCA Feature Analysis for Morphological Features
Analyzes feature distributions, separability, and informativeness
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import warnings
warnings.filterwarnings('ignore')

def analyze_features(X, y, feature_names=None, n_components=10,
                     output_dir='/mnt/user-data/outputs'):
    """
    Comprehensive PCA analysis of features

    Prints a seven-section text report (feature statistics, standardization,
    PCA, LDA, feature-target correlation, visualizations, summary) and saves
    two PNG figures into ``output_dir``: ``pca_feature_analysis.png`` and,
    when at least three principal components are available,
    ``pca_pairplot.png``. Figures are left open after saving.

    Parameters:
    -----------
    X : array-like, shape (n_samples, n_features)
        Feature matrix. Must be numeric and free of NaN/inf values.
    y : array-like, shape (n_samples,)
        Target labels (any dtype accepted by ``np.unique``).
    feature_names : list, optional
        Names of features. Defaults to the DataFrame columns when X is a
        DataFrame, otherwise to ``Feature_0 .. Feature_{n-1}``.
    n_components : int
        Number of PCA components to compute (clamped to
        ``min(n_samples, n_features)``).
    output_dir : str, optional
        Directory the PNG files are written to; created if missing.

    Returns:
    --------
    dict
        Keys ``'pca'`` (fitted PCA), ``'lda'`` (fitted LDA or None),
        ``'X_pca'``, ``'X_lda'`` (array or None), ``'scaler'`` (fitted
        StandardScaler) and ``'feature_stats'`` (DataFrame).

    Raises:
    -------
    ValueError
        If X is not a non-empty 2-D numeric array, contains non-finite
        values, disagrees with y or feature_names in length, or if
        n_components is smaller than 1.
    """
    import os
    from itertools import combinations

    # --- Normalize inputs to numpy arrays ---------------------------------
    if isinstance(X, pd.DataFrame):
        if feature_names is None:
            feature_names = X.columns.tolist()
        X = X.values
    X = np.asarray(X, dtype=float)
    # asarray also covers Series and plain lists; with a list, `y == c`
    # would otherwise evaluate to a single bool instead of a mask.
    y = np.asarray(y).ravel()

    # --- Validate before any expensive work -------------------------------
    if X.ndim != 2:
        raise ValueError(
            f"X must be 2-dimensional (n_samples, n_features); got shape {X.shape}")
    n_samples, n_features = X.shape
    if n_samples == 0 or n_features == 0:
        raise ValueError(
            f"X must contain at least one sample and one feature; got shape {X.shape}")
    if y.shape[0] != n_samples:
        raise ValueError(
            "X and y have inconsistent numbers of samples: "
            f"X has {n_samples}, y has {y.shape[0]}")
    if not np.isfinite(X).all():
        raise ValueError(
            "X contains NaN or infinite values; impute or drop them before analysis")
    if n_components < 1:
        raise ValueError(f"n_components must be >= 1; got {n_components}")

    if feature_names is None:
        feature_names = [f"Feature_{i}" for i in range(n_features)]
    else:
        feature_names = list(feature_names)
        if len(feature_names) != n_features:
            raise ValueError(
                f"feature_names has {len(feature_names)} entries but X has "
                f"{n_features} features")

    # Integer codes (0..k-1) give label-type-independent class masks, and
    # the counts line up with `classes` even for labels such as {1, 2} or
    # strings, which np.bincount could not handle.
    classes, y_codes, class_counts = np.unique(
        y, return_inverse=True, return_counts=True)
    n_classes = len(classes)
    is_binary = n_classes == 2
    # Numeric labels keep their real values on the colorbar; other label
    # types are colored by class code because matplotlib needs numbers.
    color_values = y if np.issubdtype(y.dtype, np.number) else y_codes

    print("="*80)
    print("FEATURE ANALYSIS REPORT")
    print("="*80)
    print(f"\nDataset shape: {n_samples} samples, {n_features} features")
    print(f"Classes: {classes} (counts: {class_counts})")

    # 1. Basic feature statistics
    print("\n" + "="*80)
    print("1. FEATURE STATISTICS")
    print("="*80)
    col_min = X.min(axis=0)
    col_max = X.max(axis=0)
    feature_stats = pd.DataFrame({
        'feature': feature_names,
        'mean': X.mean(axis=0),
        'std': X.std(axis=0),
        'min': col_min,
        'max': col_max,
        'range': col_max - col_min,
        'zeros_%': (X == 0).mean(axis=0) * 100
    })

    # Check for problematic features
    low_variance = feature_stats[feature_stats['std'] < 1e-6]
    if len(low_variance) > 0:
        print(f"\n‚ö†Ô∏è  WARNING: {len(low_variance)} features have near-zero variance!")
        print(low_variance[['feature', 'std']])

    constant_features = feature_stats[feature_stats['range'] < 1e-6]
    if len(constant_features) > 0:
        print(f"\n‚ö†Ô∏è  WARNING: {len(constant_features)} features are constant!")

    high_zeros = feature_stats[feature_stats['zeros_%'] > 90]
    if len(high_zeros) > 0:
        print(f"\n‚ö†Ô∏è  WARNING: {len(high_zeros)} features are >90% zeros!")
        print(high_zeros[['feature', 'zeros_%']])

    # 2. Standardize features
    print("\n" + "="*80)
    print("2. STANDARDIZING FEATURES")
    print("="*80)
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    print("‚úì Features standardized (mean=0, std=1)")

    # 3. PCA Analysis
    print("\n" + "="*80)
    print("3. PCA ANALYSIS")
    print("="*80)
    n_components = min(n_components, n_features, n_samples)
    pca = PCA(n_components=n_components)
    X_pca = pca.fit_transform(X_scaled)
    var_ratio = pca.explained_variance_ratio_
    cum_ratio = np.cumsum(var_ratio)

    print(f"\nVariance explained by first {n_components} components:")
    for i, (var, cum_var) in enumerate(zip(var_ratio, cum_ratio)):
        print(f"  PC{i+1}: {var*100:.2f}% (cumulative: {cum_var*100:.2f}%)")

    cum_var_80 = np.where(cum_ratio >= 0.80)[0]
    if len(cum_var_80) > 0:
        print(f"\nüí° {cum_var_80[0]+1} components explain 80% of variance")

    cum_var_95 = np.where(cum_ratio >= 0.95)[0]
    if len(cum_var_95) > 0:
        print(f"üí° {cum_var_95[0]+1} components explain 95% of variance")

    # 4. LDA Analysis (class separability)
    print("\n" + "="*80)
    print("4. LINEAR DISCRIMINANT ANALYSIS (Class Separability)")
    print("="*80)

    lda = None
    X_lda = None
    lda_components = min(n_classes - 1, n_features)

    if lda_components > 0:
        # Best-effort: an LDA failure is reported and the rest of the
        # analysis (and the LDA panel placeholder) still runs.
        try:
            lda = LinearDiscriminantAnalysis(n_components=lda_components)
            X_lda = lda.fit_transform(X_scaled, y)

            print(f"\nLDA variance explained (max {lda_components} discriminant axes):")
            for i, var in enumerate(lda.explained_variance_ratio_):
                print(f"  LD{i+1}: {var*100:.2f}%")

            # Calculate separation between class means in LDA space
            class_means = [X_lda[y_codes == k].mean(axis=0) for k in range(n_classes)]
            distances = [np.linalg.norm(a - b) for a, b in combinations(class_means, 2)]
            print(f"\nüí° Average distance between class means in LDA space: {np.mean(distances):.3f}")
            print(f"   (Higher is better for separability)")
        except Exception as e:
            print(f"‚ö†Ô∏è  Could not perform LDA: {e}")
            lda = None
            X_lda = None
    else:
        print("‚ö†Ô∏è  Not enough classes for LDA")

    # 5. Feature correlation with target
    print("\n" + "="*80)
    print("5. FEATURE-TARGET RELATIONSHIPS")
    print("="*80)

    # For binary classification, compute point-biserial correlation
    correlations = None
    if is_binary:
        # |r| does not depend on how the two classes are coded, so the
        # 0/1 class codes work for any label dtype.
        y_numeric = y_codes.astype(float)
        with np.errstate(divide='ignore', invalid='ignore'):
            raw_corr = np.array([
                np.corrcoef(X_scaled[:, i], y_numeric)[0, 1]
                for i in range(n_features)
            ])
        correlations = np.abs(raw_corr)
        # Constant features give NaN; count them as uninformative (0) so
        # they neither top the ranking nor turn max() into NaN.
        correlations = np.where(np.isnan(correlations), 0.0, correlations)
        top_indices = np.argsort(correlations)[::-1][:10]

        print("\nTop 10 features by correlation with target:")
        for idx in top_indices:
            print(f"  {feature_names[idx]}: {correlations[idx]:.4f}")

        if correlations.max() < 0.1:
            print("\n‚ö†Ô∏è  WARNING: No features have correlation >0.1 with target!")
            print("   This suggests features may not be informative for this classification.")

    # 6. Create visualizations
    print("\n" + "="*80)
    print("6. GENERATING VISUALIZATIONS")
    print("="*80)

    os.makedirs(output_dir, exist_ok=True)

    fig, axes = plt.subplots(2, 3, figsize=(20, 12))
    ax_scree, ax_cum, ax_pca = axes[0]
    ax_load, ax_lda, ax_std = axes[1]
    component_numbers = range(1, len(var_ratio) + 1)

    # Plot 1: Scree plot
    ax_scree.plot(component_numbers, var_ratio, 'bo-')
    ax_scree.set_xlabel('Principal Component', fontsize=12)
    ax_scree.set_ylabel('Variance Explained', fontsize=12)
    ax_scree.set_title('Scree Plot - Variance per Component', fontsize=14, fontweight='bold')
    ax_scree.grid(True, alpha=0.3)

    # Plot 2: Cumulative variance
    ax_cum.plot(component_numbers, cum_ratio, 'ro-')
    ax_cum.axhline(y=0.8, color='g', linestyle='--', label='80% threshold')
    ax_cum.axhline(y=0.95, color='b', linestyle='--', label='95% threshold')
    ax_cum.set_xlabel('Number of Components', fontsize=12)
    ax_cum.set_ylabel('Cumulative Variance Explained', fontsize=12)
    ax_cum.set_title('Cumulative Variance Explained', fontsize=14, fontweight='bold')
    ax_cum.legend()
    ax_cum.grid(True, alpha=0.3)

    # Plot 3: PCA scatter (PC1 vs PC2), or a PC1 histogram when the data
    # only supports a single component.
    if X_pca.shape[1] >= 2:
        scatter = ax_pca.scatter(X_pca[:, 0], X_pca[:, 1], c=color_values, cmap='viridis',
                                 alpha=0.6, edgecolors='k', linewidth=0.5)
        ax_pca.set_xlabel(f'PC1 ({var_ratio[0]*100:.1f}%)', fontsize=12)
        ax_pca.set_ylabel(f'PC2 ({var_ratio[1]*100:.1f}%)', fontsize=12)
        ax_pca.set_title('PCA: First Two Components', fontsize=14, fontweight='bold')
        fig.colorbar(scatter, ax=ax_pca, label='Class')
    else:
        for k, class_val in enumerate(classes):
            ax_pca.hist(X_pca[y_codes == k, 0], alpha=0.5, label=f'Class {class_val}', bins=30)
        ax_pca.set_xlabel(f'PC1 ({var_ratio[0]*100:.1f}%)', fontsize=12)
        ax_pca.set_ylabel('Frequency', fontsize=12)
        ax_pca.set_title('PCA: First Component', fontsize=14, fontweight='bold')
        ax_pca.legend()
    ax_pca.grid(True, alpha=0.3)

    # Plot 4: Feature importance in PC1 (at most 10 bars; fewer when the
    # dataset has fewer features)
    pc1_importance = np.abs(pca.components_[0])
    n_top = min(10, n_features)
    top_idx = np.argsort(pc1_importance)[-n_top:]
    bar_positions = list(range(n_top))
    ax_load.barh(bar_positions, pc1_importance[top_idx])
    ax_load.set_yticks(bar_positions)
    ax_load.set_yticklabels([feature_names[i] for i in top_idx], fontsize=9)
    ax_load.set_xlabel('Absolute Loading', fontsize=12)
    ax_load.set_title('Top 10 Features in PC1', fontsize=14, fontweight='bold')
    ax_load.grid(True, alpha=0.3, axis='x')

    # Plot 5: LDA plot if available
    if X_lda is not None and X_lda.shape[1] >= 2:
        scatter = ax_lda.scatter(X_lda[:, 0], X_lda[:, 1], c=color_values, cmap='viridis',
                                 alpha=0.6, edgecolors='k', linewidth=0.5)
        ax_lda.set_xlabel(f'LD1 ({lda.explained_variance_ratio_[0]*100:.1f}%)', fontsize=12)
        ax_lda.set_ylabel(f'LD2 ({lda.explained_variance_ratio_[1]*100:.1f}%)', fontsize=12)
        ax_lda.set_title('LDA: Class Separability', fontsize=14, fontweight='bold')
        fig.colorbar(scatter, ax=ax_lda, label='Class')
        ax_lda.grid(True, alpha=0.3)
    elif X_lda is not None and X_lda.shape[1] == 1:
        for k, class_val in enumerate(classes):
            ax_lda.hist(X_lda[y_codes == k, 0], alpha=0.5, label=f'Class {class_val}', bins=30)
        ax_lda.set_xlabel('LD1', fontsize=12)
        ax_lda.set_ylabel('Frequency', fontsize=12)
        ax_lda.set_title('LDA: 1D Projection', fontsize=14, fontweight='bold')
        ax_lda.legend()
        ax_lda.grid(True, alpha=0.3)
    else:
        ax_lda.text(0.5, 0.5, 'LDA not available', ha='center', va='center', fontsize=14)
        ax_lda.axis('off')

    # Plot 6: Feature variance distribution
    feature_stds = X_scaled.std(axis=0)
    try:
        # After scaling the stds differ only by rounding error, so the
        # automatic bin-width estimate can request an unusable number of
        # bins; fall back to a bar chart in that case.
        ax_std.hist(feature_stds, bins='auto', edgecolor='black')
    except (ValueError, MemoryError, OverflowError):
        ax_std.clear()
        ax_std.bar(range(len(feature_stds)), sorted(feature_stds, reverse=True))
        ax_std.set_xlabel('Feature Index (sorted)', fontsize=12)
        ax_std.set_ylabel('Standard Deviation (after scaling)', fontsize=12)
    else:
        ax_std.set_xlabel('Standard Deviation (after scaling)', fontsize=12)
        ax_std.set_ylabel('Number of Features', fontsize=12)
    ax_std.set_title('Distribution of Feature Variances', fontsize=14, fontweight='bold')
    ax_std.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'pca_feature_analysis.png'),
                dpi=150, bbox_inches='tight')
    print("‚úì Saved: pca_feature_analysis.png")

    # Additional plot: Pairplot of top PCA components
    if n_components >= 3:
        fig2, pair_axes = plt.subplots(3, 3, figsize=(15, 15))

        for i in range(3):
            for j in range(3):
                ax = pair_axes[i, j]
                if i == j:
                    # Histogram on diagonal
                    for k, class_val in enumerate(classes):
                        ax.hist(X_pca[y_codes == k, i], alpha=0.5,
                                label=f'Class {class_val}', bins=30)
                    ax.set_ylabel('Frequency')
                    if i == 0:
                        ax.legend()
                else:
                    # Scatter plot
                    ax.scatter(X_pca[:, j], X_pca[:, i], c=color_values,
                               cmap='viridis', alpha=0.5, s=10)

                # Only the outer edge of the grid carries PC labels.
                if i == 2:
                    ax.set_xlabel(f'PC{j+1} ({var_ratio[j]*100:.1f}%)')
                if j == 0:
                    ax.set_ylabel(f'PC{i+1} ({var_ratio[i]*100:.1f}%)')

                ax.grid(True, alpha=0.3)

        fig2.suptitle('PCA Components Pairplot', fontsize=16, fontweight='bold', y=1.00)
        fig2.tight_layout()
        fig2.savefig(os.path.join(output_dir, 'pca_pairplot.png'),
                     dpi=150, bbox_inches='tight')
        print("‚úì Saved: pca_pairplot.png")

    # 7. Summary and recommendations
    print("\n" + "="*80)
    print("7. SUMMARY & RECOMMENDATIONS")
    print("="*80)

    # Check for issues
    issues = []
    if var_ratio[0] > 0.9:
        issues.append("‚ö†Ô∏è  First PC explains >90% variance - features may be highly redundant")

    if len(low_variance) > 0:
        issues.append(f"‚ö†Ô∏è  {len(low_variance)} features have near-zero variance")

    if is_binary:
        if correlations.max() < 0.1:
            issues.append("‚ö†Ô∏è  No strong feature-target correlations (<0.1)")
        elif correlations.max() < 0.2:
            issues.append("‚ö†Ô∏è  Weak feature-target correlations (<0.2)")

    cum_var_3 = np.sum(var_ratio[:3])
    if cum_var_3 < 0.5:
        issues.append(f"‚ö†Ô∏è  First 3 PCs only explain {cum_var_3*100:.1f}% variance - high dimensionality")

    if len(issues) > 0:
        print("\nüî¥ Issues Detected:")
        for issue in issues:
            print(f"  {issue}")
    else:
        print("\n‚úÖ No major issues detected!")

    print("\nüìã Recommendations:")
    print("  1. Check the PCA scatter plot - do classes visually separate?")
    print("  2. Check the LDA plot - better separation = more informative features")
    print("  3. If separation is poor, consider:")
    print("     - Engineering new features (ratios, differences, interactions)")
    print("     - Checking if features are computed correctly from images")
    print("     - Trying different morphological features")
    print("     - Examining misclassified samples visually")

    return {
        'pca': pca,
        'lda': lda,
        'X_pca': X_pca,
        'X_lda': X_lda,
        'scaler': scaler,
        'feature_stats': feature_stats
    }


# Example driver. Guarded so that importing this file does not read the
# CSV files or run the analysis; in a notebook cell or when executed as a
# script, __name__ is "__main__" and the example runs as before.
if __name__ == "__main__":
    print("\n" + "="*80)
    print("EXAMPLE: How to use this script")
    print("="*80)
    # Load your data
    df = pd.read_csv("/mnt/towbin.data/shared/spsalmon/towbinlab_classification_database/datasets/10x_pharynx_qc/pharynx/features.csv")
    # NOTE(review): `target` is loaded but not used below; the labels are
    # taken from the 'target' column of features.csv instead. Confirm which
    # source is intended.
    target = pd.read_csv("/mnt/towbin.data/shared/spsalmon/towbinlab_classification_database/datasets/10x_pharynx_qc/pharynx/processed_annotations.csv")

    # Separate features and target
    X = df.drop('target', axis=1)  # or specify feature columns
    y = df['target']

    # Run analysis. analyze_features is defined earlier in this file, so it
    # is called directly rather than re-imported from a module of its own.
    results = analyze_features(X, y, feature_names=X.columns, n_components=10)

    # Access results
    # results['pca'] - fitted PCA object
    # results['X_pca'] - transformed data
    # results['feature_stats'] - feature statistics
