# TimeFlies: Comprehensive Single-Cell RNA-seq Analysis

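This notebook provides a comprehensive analysis of single-cell RNA sequencing data from Drosophila tissue samples using the TimeFlies framework. If the configuration file, the dataset, or a trained model cannot be found, it falls back to a minimal configuration and synthetic data (and skips model evaluation) so that every cell still runs. It demonstrates: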

1. **Data Loading and Configuration**: Using the modern YAML-based configuration system
2. **Exploratory Data Analysis**: Quality control, cell type distribution, and basic statistics
3. **Gene Expression Analysis**: Temporal patterns, sex-specific differences, and identification of highly variable genes
4. **Dimensionality Reduction**: PCA and UMAP visualization
5. **Model Predictions**: Loading trained models and analyzing predictions
6. **Feature Importance**: SHAP-based interpretation of model decisions is recommended as a next step (it is not run in this notebook)

## Setup and Configuration

In [None]:
# Setup cell: imports, the project source path, and global plotting / scanpy
# settings that every later cell relies on.

# Import necessary libraries
# NOTE(review): os, yaml, PCA and Interpreter appear unused in this notebook —
# verify before removing.
import os
import sys
import yaml
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
from scipy.sparse import issparse
from sklearn.decomposition import PCA
from sklearn.metrics import classification_report, confusion_matrix
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel session
# (including AnnData/scanpy data warnings); consider filtering specific categories.
warnings.filterwarnings('ignore')

# Add src to path for imports
# '../src' is resolved against the kernel's current working directory, so this
# assumes the notebook is started from its own folder — TODO confirm.
sys.path.append('../src')

from timeflies.core.config_manager import ConfigManager
from timeflies.data.loaders import DataLoader
from timeflies.models.model import ModelLoader
from timeflies.evaluation.interpreter import Interpreter
from timeflies.utils.path_manager import PathManager

# Configure matplotlib and seaborn
# The palette is set after style.use('default') so the style reset does not
# overwrite it.
plt.style.use('default')
sns.set_palette("husl")
%matplotlib inline

# Configure scanpy
sc.settings.verbosity = 1  # Reduce verbosity
# Default figure resolution (80 dpi) and a white background for figures.
sc.settings.set_figure_params(dpi=80, facecolor='white')

In [None]:
# Load the analysis configuration from YAML. When the file does not exist, build
# a minimal in-memory configuration instead so the rest of the notebook still runs.
config_path = '../configs/head_cnn_config.yaml'

try:
    config = ConfigManager.from_yaml(config_path)
except FileNotFoundError:
    print(f"⚠️  Config file not found at {config_path}")
    print("Creating a minimal configuration for demonstration...")

    # Minimal demo configuration: head tissue, CNN model, age as prediction target.
    config_dict = dict(
        data=dict(
            tissue='head',
            model_type='CNN',
            encoding_variable='age',
            sex_type='all',
            cell_type='all',
            batch_correction=dict(enabled=False),
        ),
        file_locations=dict(
            training_file='fly_train.h5ad',
            evaluation_file='fly_eval.h5ad',
            original_file='fly_original.h5ad',
        ),
        feature_importance=dict(
            run_interpreter=True,
            run_visualization=True,
        ),
    )
    config = ConfigManager.from_dict(config_dict)
    print("✅ Minimal configuration created")
else:
    print(f"✅ Configuration loaded successfully from {config_path}")
    print(f"Analysis target: {config.data.tissue} tissue, {config.data.encoding_variable} prediction")
    print(f"Model type: {config.data.model_type}")

## Data Loading and Initial Exploration

In [None]:
# Load the training, evaluation and original AnnData objects via the project's
# DataLoader. On any failure, fall back to a synthetic dataset with the same
# obs/var columns the later cells use, so the notebook remains runnable.
try:
    data_loader = DataLoader(config)
    adata, adata_eval, adata_original = data_loader.load_data()
    print("✅ Data loaded successfully")
    print(f"Training data: {adata.n_obs} cells, {adata.n_vars} genes")
    print(f"Evaluation data: {adata_eval.n_obs} cells, {adata_eval.n_vars} genes")
    print(f"Original data: {adata_original.n_obs} cells, {adata_original.n_vars} genes")
except Exception as e:
    # Deliberately broad: any loading problem switches the notebook to demo mode.
    print(f"⚠️  Error loading data: {e}")
    print("This notebook requires the TimeFlies dataset to be properly set up.")
    print("Please ensure your data is in the correct directory structure.")
    print("Creating synthetic data for demonstration...")

    from anndata import AnnData

    # Seeded generator so the synthetic demo is identical on every Restart & Run All.
    rng = np.random.default_rng(42)

    n_obs, n_vars = 10000, 2000
    X = rng.negative_binomial(5, 0.3, size=(n_obs, n_vars))

    obs = pd.DataFrame(
        {
            'age': rng.choice([1, 5, 10, 20], n_obs),
            'sex': rng.choice(['male', 'female'], n_obs),
            'tissue': 'head',
            'afca_annotation_broad': rng.choice(
                ['CNS neuron', 'muscle cell', 'epithelial cell', 'glial cell'], n_obs
            ),
        },
        # String cell IDs: AnnData expects a string obs index.
        index=[f'cell_{i}' for i in range(n_obs)],
    )

    var = pd.DataFrame({
        'gene_name': [f'gene_{i}' for i in range(n_vars)],
        'highly_variable': rng.choice([True, False], n_vars, p=[0.3, 0.7])
    }, index=[f'gene_{i}' for i in range(n_vars)])

    adata_original = AnnData(X=X, obs=obs, var=var)
    # Disjoint halves: evaluation cells must never also be training cells.
    adata = adata_original[::2].copy()
    adata_eval = adata_original[1::2].copy()

    print(f"✅ Synthetic data created: {adata.n_obs} cells, {adata.n_vars} genes")

In [None]:
# Text overview of the training data: matrix size and sparsity, available
# cell-level metadata, and cell counts per age, sex and (if annotated) cell type.
print("=== Dataset Overview ===")
for label, count in (("Observations (cells)", adata.n_obs), ("Variables (genes)", adata.n_vars)):
    print(f"{label}: {count:,}")
nonzero_fraction = (adata.X > 0).sum() / (adata.n_obs * adata.n_vars)
print(f"Data matrix density: {nonzero_fraction:.2%}")

print("\n=== Cell Metadata ===")
print(list(adata.obs.columns))

print("\n=== Age Distribution ===")
age_counts = adata.obs['age'].value_counts().sort_index()
print(age_counts)

print("\n=== Sex Distribution ===")
sex_counts = adata.obs['sex'].value_counts()
print(sex_counts)

if 'afca_annotation_broad' in adata.obs.columns:
    print("\n=== Cell Type Distribution ===")
    cell_type_counts = adata.obs['afca_annotation_broad'].value_counts()
    print(f"Number of cell types: {cell_type_counts.size}")
    print("\nTop 10 cell types:")
    print(cell_type_counts.iloc[:10])

## Quality Control and Visualization

In [None]:
# Six-panel overview: metadata balance on the top row, basic count-matrix
# statistics on the bottom row.
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Age distribution. Age is treated as discrete timepoints (as in the age-by-sex
# panel), so plot the number of cells per age value as bars; a fixed 20-bin
# histogram can split or merge ages arbitrarily.
adata.obs['age'].value_counts().sort_index().plot(kind='bar', ax=axes[0, 0], alpha=0.7)
axes[0, 0].set_title('Age Distribution')
axes[0, 0].set_xlabel('Age')
axes[0, 0].set_ylabel('Number of Cells')
axes[0, 0].tick_params(axis='x', rotation=0)

# Sex distribution (recomputed here so this cell does not rely on earlier state)
sex_counts = adata.obs['sex'].value_counts()
sex_counts.plot(kind='bar', ax=axes[0, 1], alpha=0.7)
axes[0, 1].set_title('Sex Distribution')
axes[0, 1].set_xlabel('Sex')
axes[0, 1].set_ylabel('Number of Cells')
axes[0, 1].tick_params(axis='x', rotation=0)

# Age by Sex
age_sex_crosstab = pd.crosstab(adata.obs['age'], adata.obs['sex'])
age_sex_crosstab.plot(kind='bar', ax=axes[0, 2], alpha=0.7)
axes[0, 2].set_title('Age Distribution by Sex')
axes[0, 2].set_xlabel('Age')
axes[0, 2].set_ylabel('Number of Cells')
axes[0, 2].legend(title='Sex')

# Per-gene mean expression; np.asarray(...).ravel() also flattens the
# np.matrix that scipy sparse reductions return.
gene_means = np.asarray(adata.X.mean(axis=0)).ravel()
axes[1, 0].hist(gene_means, bins=50, alpha=0.7)
axes[1, 0].set_title('Gene Expression Means')
axes[1, 0].set_xlabel('Mean Expression')
axes[1, 0].set_ylabel('Number of Genes')
axes[1, 0].set_yscale('log')

# Total counts per cell
cell_totals = np.asarray(adata.X.sum(axis=1)).ravel()
axes[1, 1].hist(cell_totals, bins=50, alpha=0.7)
axes[1, 1].set_title('Total Counts per Cell')
axes[1, 1].set_xlabel('Total Counts')
axes[1, 1].set_ylabel('Number of Cells')

# Cell type distribution (if available)
if 'afca_annotation_broad' in adata.obs.columns:
    top_cell_types = adata.obs['afca_annotation_broad'].value_counts().head(10)
    top_cell_types.plot(kind='barh', ax=axes[1, 2], alpha=0.7)
    axes[1, 2].set_title('Top 10 Cell Types')
    axes[1, 2].set_xlabel('Number of Cells')
else:
    axes[1, 2].text(0.5, 0.5, 'Cell type\ninformation\nnot available',
                    ha='center', va='center', transform=axes[1, 2].transAxes)
    axes[1, 2].set_title('Cell Type Distribution')

plt.tight_layout()
plt.show()

print("\n=== Expression Statistics ===")
print(f"Mean expression per gene: {gene_means.mean():.2f} ± {gene_means.std():.2f}")
print(f"Mean total counts per cell: {cell_totals.mean():.0f} ± {cell_totals.std():.0f}")
print(f"Genes with zero expression: {(gene_means == 0).sum():,} ({(gene_means == 0).mean():.1%})")

## Dimensionality Reduction and Visualization

In [None]:
# PCA on a copy of the training data, so `adata` itself is left untouched.
PCA_N_COMPS = 50
adata_vis = adata.copy()

# Basic preprocessing: PCA on raw counts is dominated by a few highly expressed
# genes. If a sample of the matrix still looks like raw non-negative integer
# counts, library-size normalise and log-transform first; data that was already
# normalised upstream is used as-is.
count_probe = adata_vis.X[:1000]
probe_values = count_probe.data if issparse(count_probe) else np.asarray(count_probe)
looks_like_raw_counts = (
    probe_values.size > 0
    and probe_values.min() >= 0
    and np.array_equal(probe_values, np.round(probe_values))
)
if looks_like_raw_counts:
    print("Raw counts detected: applying normalize_total + log1p before PCA...")
    adata_vis.X = adata_vis.X.astype(np.float32)
    sc.pp.normalize_total(adata_vis, target_sum=1e4)
    sc.pp.log1p(adata_vis)

# Keep n_comps below min(n_obs, n_vars) (required by truncated/ARPACK solvers),
# so small inputs do not fail.
n_pca_comps = min(PCA_N_COMPS, min(adata_vis.n_obs, adata_vis.n_vars) - 1)
print("Performing PCA...")
sc.pp.pca(adata_vis, n_comps=n_pca_comps)

# Plot PCA variance explained
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Variance ratio per component
var_ratio = adata_vis.uns['pca']['variance_ratio']
axes[0].plot(range(1, len(var_ratio) + 1), var_ratio, 'o-')
axes[0].set_title('PCA Variance Explained')
axes[0].set_xlabel('Principal Component')
axes[0].set_ylabel('Variance Ratio')
axes[0].grid(True, alpha=0.3)

# Cumulative variance
cumvar = np.cumsum(var_ratio)
axes[1].plot(range(1, len(cumvar) + 1), cumvar, 'o-')
axes[1].axhline(y=0.8, color='r', linestyle='--', alpha=0.7, label='80% variance')
axes[1].set_title('Cumulative Variance Explained')
axes[1].set_xlabel('Principal Component')
axes[1].set_ylabel('Cumulative Variance Ratio')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Report against the number of components actually computed (may be < 10 / < 50).
n_shown = min(10, len(cumvar))
print(f"First {n_shown} PCs explain {cumvar[n_shown - 1]:.1%} of variance")
print(f"First {len(cumvar)} PCs explain {cumvar[-1]:.1%} of variance")

In [None]:
# Neighbourhood graph and UMAP embedding built on the PCA of `adata_vis`.
print("Computing neighborhood graph and UMAP...")
# Use up to 30 PCs, but never more than the PCA step actually produced
# (scanpy raises if n_pcs exceeds the available components).
n_pcs_for_neighbors = min(30, adata_vis.obsm['X_pca'].shape[1])
sc.pp.neighbors(adata_vis, n_pcs=n_pcs_for_neighbors)
sc.tl.umap(adata_vis)

# Create UMAP plots
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# UMAP colored by age
sc.pl.umap(adata_vis, color='age', ax=axes[0], show=False,
           title='UMAP colored by Age', size=20)

# UMAP colored by sex
sc.pl.umap(adata_vis, color='sex', ax=axes[1], show=False,
           title='UMAP colored by Sex', size=20)

# UMAP colored by cell type (if available)
if 'afca_annotation_broad' in adata_vis.obs.columns:
    sc.pl.umap(adata_vis, color='afca_annotation_broad', ax=axes[2], show=False,
               title='UMAP colored by Cell Type', size=20, legend_loc='right margin')
else:
    axes[2].text(0.5, 0.5, 'Cell type\ninformation\nnot available',
                 ha='center', va='center', transform=axes[2].transAxes)
    axes[2].set_title('UMAP colored by Cell Type')

plt.tight_layout()
plt.show()

## Gene Expression Analysis

In [None]:
# Rank genes by how variable their expression is across cells.
print("Identifying highly variable genes...")

# Per-gene mean and population variance (ddof=0). scipy sparse matrices have no
# .var() method, so for sparse X the variance is computed as E[X^2] - E[X]^2.
X_expr = adata.X
mean_expr = np.asarray(X_expr.mean(axis=0, dtype=np.float64)).ravel()
if issparse(X_expr):
    X_sq = X_expr.astype(np.float64)  # float copy: squaring integer counts could overflow
    X_sq.data **= 2
    mean_sq = np.asarray(X_sq.mean(axis=0)).ravel()
    var_expr = np.maximum(mean_sq - mean_expr ** 2, 0.0)  # clip round-off negatives
    del X_sq
else:
    var_expr = np.asarray(X_expr).var(axis=0, dtype=np.float64)

gene_df = pd.DataFrame({
    'mean_expr': mean_expr,
    'var_expr': var_expr,
    'gene_name': adata.var_names
})

# Coefficient of variation = std / mean. The variance-to-mean ratio
# (Fano factor / dispersion) is kept alongside for reference.
gene_df['cv'] = np.sqrt(gene_df['var_expr']) / (gene_df['mean_expr'] + 1e-6)
gene_df['dispersion'] = gene_df['var_expr'] / (gene_df['mean_expr'] + 1e-6)

# Ignore very lowly expressed genes, whose CV is dominated by sampling noise.
gene_df_filtered = gene_df[gene_df['mean_expr'] > 0.1]

# Select top variable genes
top_variable_genes = gene_df_filtered.nlargest(20, 'cv')

print("Top 10 most variable genes:")
for rank, (_, gene) in enumerate(top_variable_genes.head(10).iterrows(), start=1):
    print(f"{rank:2d}. {gene['gene_name']} (CV: {gene['cv']:.2f}, Mean: {gene['mean_expr']:.2f})")

In [None]:
# Analyze age-related expression patterns
def analyze_age_patterns(adata, gene_list, max_genes=12, fallback_genes=None):
    """Plot mean ± std expression of selected genes across age, split by sex.

    Parameters
    ----------
    adata : AnnData
        Data whose ``obs`` contains ``'age'`` and ``'sex'`` columns.
    gene_list : iterable of str
        Candidate gene names; only those present in ``adata.var_names`` are used.
    max_genes : int, default 12
        Maximum number of genes (subplots) to draw.
    fallback_genes : iterable of str, optional
        Genes to use when none of ``gene_list`` are found. Defaults to the
        ``gene_name`` column of the notebook-level ``top_variable_genes`` table.

    Returns
    -------
    pandas.DataFrame
        One row per cell with the expression of each plotted gene plus ``age``
        and ``sex``. If no gene can be plotted, only ``age`` and ``sex``.
    """
    # Genes present in the dataset, in order, without duplicates (a duplicated
    # column would make expr_df[gene] return a DataFrame instead of a Series).
    available_genes = list(dict.fromkeys(g for g in gene_list if g in adata.var_names))
    if not available_genes:
        print("⚠️  None of the specified genes found in dataset")
        if fallback_genes is None:
            # Default: the notebook-level table of top variable genes.
            fallback_genes = top_variable_genes['gene_name']
        available_genes = list(
            dict.fromkeys(g for g in fallback_genes if g in adata.var_names)
        )[:max_genes]
        print(f"Using top {len(available_genes)} variable genes instead")

    genes_to_plot = available_genes[:max_genes]
    if not genes_to_plot:
        # Nothing to draw: return the metadata instead of failing on an empty grid.
        print("⚠️  No genes available to plot")
        return pd.DataFrame({'age': adata.obs['age'].values, 'sex': adata.obs['sex'].values})

    # Per-cell expression of the selected genes (densified if stored sparse).
    subset_X = adata[:, genes_to_plot].X
    if issparse(subset_X):
        subset_X = subset_X.toarray()
    expr_df = pd.DataFrame(subset_X, columns=genes_to_plot)
    expr_df['age'] = adata.obs['age'].values
    expr_df['sex'] = adata.obs['sex'].values

    # Grid with 4 columns; squeeze=False always returns a 2-D array of axes.
    n_genes = len(genes_to_plot)
    n_cols = 4
    n_rows = -(-n_genes // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows), squeeze=False)
    axes = axes.ravel()

    # Fixed colours for male/female; any other sex labels present in the data
    # get matplotlib's default colour cycle.
    colors = {'male': '#1f77b4', 'female': '#ff7f0e'}
    observed_sexes = [s for s in pd.unique(expr_df['sex'].astype(object)) if pd.notna(s)]
    sexes = [s for s in colors if s in observed_sexes] + [s for s in observed_sexes if s not in colors]

    for i, (ax, gene) in enumerate(zip(axes, genes_to_plot)):
        # Mean and standard deviation of this gene for every observed (age, sex) group.
        grouped = (
            expr_df.groupby(['age', 'sex'], observed=True)[gene]
            .agg(['mean', 'std'])
            .reset_index()
        )

        for sex in sexes:
            sex_data = grouped[grouped['sex'] == sex]
            ax.errorbar(sex_data['age'], sex_data['mean'],
                        yerr=sex_data['std'], label=str(sex),
                        color=colors.get(sex), marker='o', capsize=3)

        ax.set_title(f'{gene}', fontsize=10)
        ax.set_xlabel('Age')
        ax.set_ylabel('Expression')
        ax.grid(True, alpha=0.3)

        if i == 0:
            ax.legend()

    # Hide unused subplots in the last row.
    for ax in axes[n_genes:]:
        ax.set_visible(False)

    fig.suptitle('Gene Expression Patterns Across Age and Sex', fontsize=14, y=1.02)
    fig.tight_layout()
    plt.show()

    return expr_df

# Candidate genes grouped by the biology they represent; analyze_age_patterns
# keeps only those present in the dataset's var_names.
sex_determination_genes = ['roX1', 'roX2', 'Sxl', 'fru', 'dsx']
aging_related_genes = ['Hsp70', 'Hsp83', 'InR', 'foxo']
neuronal_genes = ['unc-13', 'Sh', 'para', 'Syt1']
interesting_genes = sex_determination_genes + aging_related_genes + neuronal_genes

expr_patterns = analyze_age_patterns(adata, interesting_genes)

## Model Analysis (if models are available)

In [None]:
# Attempt to load a previously trained model together with its saved
# preprocessing artefacts. Any failure (e.g. nothing trained yet) leaves
# model_available False, which makes the next cell skip model analysis.
model_available = False
try:
    print("Attempting to load trained model...")
    # NOTE(review): path_manager is not used later in this notebook.
    path_manager = PathManager(config)
    model_loader = ModelLoader(config)

    # Artefacts saved with the model, in the order the loader returns them.
    (label_encoder, scaler, is_scaler_fit, highly_variable_genes,
     num_features, history, mix_included, reference_data) = model_loader.load_model_components()

    model = model_loader.load_model()

    print("✅ Model loaded successfully!")
    print(f"Model type: {config.data.model_type}")
    print(f"Target classes: {label_encoder.classes_}")
    print(f"Number of features: {num_features}")

    model_available = True
except Exception as e:
    print(f"⚠️  Could not load model: {e}")
    print("This is expected if no model has been trained yet.")

In [None]:
# Evaluate the loaded model on the held-out evaluation cells: accuracy, a
# per-class report, a confusion matrix, and prediction-confidence plots.
if model_available:
    try:
        from sklearn.metrics import accuracy_score
        from timeflies.data.preprocessing.data_processor import DataPreprocessor

        # Preprocess the evaluation data the same way as the training data.
        # NOTE(review): `adata` is passed for both data arguments — confirm against
        # DataPreprocessor's signature whether the second should be another object.
        processor = DataPreprocessor(config, adata, adata)
        eval_processed = processor.process_adata(adata_eval)

        # Restrict to the genes the model was trained on. An explicit length check
        # (rather than `if highly_variable_genes:`) also works when the loader
        # returns a numpy array or pandas Index, whose truth value is ambiguous.
        if highly_variable_genes is not None and len(highly_variable_genes) > 0:
            eval_data = eval_processed[:, highly_variable_genes].X
        else:
            eval_data = eval_processed.X
        if issparse(eval_data):
            eval_data = eval_data.toarray()

        # Ground-truth labels, encoded with the label encoder saved at training time.
        true_labels = eval_processed.obs[config.data.encoding_variable].astype(str)
        true_labels_encoded = label_encoder.transform(true_labels)

        # Apply the training-time scaler, if one was fitted.
        if is_scaler_fit and scaler is not None:
            eval_data = scaler.transform(eval_data)

        # CNN/MLP models return class probabilities from predict(); other model
        # types are queried via predict_proba() and predict().
        model_type = config.data.model_type.lower()
        if model_type in ['cnn', 'mlp']:
            if model_type == 'cnn' and eval_data.ndim == 2:
                # Adds a length-1 axis: (cells, genes) -> (cells, 1, genes).
                # NOTE(review): confirm this matches the CNN's expected input layout.
                eval_data = eval_data.reshape(eval_data.shape[0], 1, eval_data.shape[1])
            predictions_proba = model.predict(eval_data)
            predictions = np.argmax(predictions_proba, axis=1)
        else:
            predictions_proba = model.predict_proba(eval_data)
            predictions = model.predict(eval_data)

        # Score against the full list of encoded classes so the report and the
        # confusion-matrix axes stay aligned with label_encoder.classes_ even when
        # a class is absent from this evaluation split or from the predictions.
        class_names = [str(c) for c in label_encoder.classes_]
        class_ids = np.arange(len(class_names))

        accuracy = accuracy_score(true_labels_encoded, predictions)

        print("\n=== Model Performance ===")
        print(f"Accuracy: {accuracy:.3f}")
        print("\nClassification Report:")
        print(classification_report(true_labels_encoded, predictions,
                                    labels=class_ids, target_names=class_names,
                                    zero_division=0))

        # Confusion matrix (rows: actual class, columns: predicted class).
        cm = confusion_matrix(true_labels_encoded, predictions, labels=class_ids)
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_title('Confusion Matrix')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        plt.show()

        # Prediction confidence = highest class probability for each cell.
        max_proba = np.max(predictions_proba, axis=1)

        fig, (ax_hist, ax_box) = plt.subplots(1, 2, figsize=(12, 4))

        ax_hist.hist(max_proba, bins=30, alpha=0.7, edgecolor='black')
        ax_hist.set_title('Prediction Confidence Distribution')
        ax_hist.set_xlabel('Maximum Probability')
        ax_hist.set_ylabel('Number of Predictions')
        ax_hist.axvline(max_proba.mean(), color='red', linestyle='--',
                        label=f'Mean: {max_proba.mean():.3f}')
        ax_hist.legend()

        confidence_by_class = pd.DataFrame({
            'true_class': np.asarray(class_names, dtype=object)[true_labels_encoded],
            'confidence': max_proba,
            'correct': true_labels_encoded == predictions
        })
        sns.boxplot(data=confidence_by_class, x='true_class', y='confidence',
                    hue='correct', ax=ax_box)
        ax_box.set_title('Prediction Confidence by Class')
        ax_box.tick_params(axis='x', rotation=45)
        ax_box.legend(title='Correct Prediction')

        fig.tight_layout()
        plt.show()

        print(f"\nMean prediction confidence: {max_proba.mean():.3f} ± {max_proba.std():.3f}")
        print(f"Predictions with >90% confidence: {(max_proba > 0.9).mean():.1%}")
        print(f"Predictions with <60% confidence: {(max_proba < 0.6).mean():.1%}")

    except Exception as e:
        print(f"⚠️  Error in model analysis: {e}")
        print("This might be due to data preprocessing differences.")
else:
    print("Skipping model analysis - no trained model available.")

## Summary and Conclusions

In [None]:
# Final text summary of the dataset and, if one was evaluated above, the model.
print("=== Analysis Summary ===")
print(f"Dataset: {config.data.tissue} tissue from Drosophila")
print(f"Total cells analyzed: {adata.n_obs:,}")
print(f"Total genes: {adata.n_vars:,}")
print(f"Prediction target: {config.data.encoding_variable}")

print("\n=== Key Findings ===")

# Age distribution insights
age_counts = adata.obs['age'].value_counts().sort_index()
most_common_age = age_counts.idxmax()
print(f"• Most represented age group: {most_common_age} ({age_counts[most_common_age]:,} cells)")

# Sex distribution insights. Report n/a, rather than a fabricated 1:1 ratio, when
# either sex is missing (categorical columns can also yield zero counts).
sex_counts = adata.obs['sex'].value_counts()
if sex_counts.get('female', 0) > 0 and sex_counts.get('male', 0) > 0:
    sex_ratio = sex_counts['female'] / sex_counts['male']
    print(f"• Sex ratio (F:M): {sex_ratio:.2f}:1")
else:
    sex_ratio = None
    print("• Sex ratio (F:M): n/a (male and/or female cells not found)")

# Cell type diversity (if available)
if 'afca_annotation_broad' in adata.obs.columns:
    cell_type_counts = adata.obs['afca_annotation_broad'].value_counts()
    n_cell_types = adata.obs['afca_annotation_broad'].nunique()
    # Holds the *count* of the most abundant type (name kept for compatibility).
    dominant_cell_type = cell_type_counts.iloc[0]
    print(f"• Cell type diversity: {n_cell_types} different types identified")
    print(f"• Most abundant cell type: {cell_type_counts.index[0]} ({dominant_cell_type:,} cells)")

# Expression characteristics (expressed = count > 0); the mask is computed once.
expressed = adata.X > 0
mean_genes_per_cell = np.asarray(expressed.sum(axis=1)).mean()
mean_cells_per_gene = np.asarray(expressed.sum(axis=0)).mean()
print(f"• Average genes expressed per cell: {mean_genes_per_cell:.0f}")
print(f"• Average cells expressing each gene: {mean_cells_per_gene:.0f}")

# Model performance (if the evaluation cell produced it in this session).
if model_available and 'accuracy' in globals():
    print(f"• Model accuracy: {accuracy:.1%}")
    if 'max_proba' in globals():
        print(f"• Mean prediction confidence: {max_proba.mean():.1%}")

print("\n=== Recommendations for Further Analysis ===")
print("• Investigate age-related expression changes in highly variable genes")
print("• Perform differential expression analysis between age groups")
print("• Analyze cell type-specific aging patterns")
if model_available:
    print("• Use SHAP analysis to understand model decision-making")
    print("• Investigate misclassified samples for biological insights")
print("• Consider batch effects and technical variation")
print("• Validate findings with independent datasets")

print("\n✅ Analysis completed successfully!")
print("This notebook demonstrated the TimeFlies framework capabilities for")
print("comprehensive single-cell RNA-seq analysis and aging research.")