# LBMD Visualization and Interpretation Demo

This notebook demonstrates advanced visualization techniques and interpretation methods for LBMD analysis results.

## 🎨 What You'll Learn

- **Advanced Visualizations**: Interactive manifold exploration, boundary overlays, statistical plots
- **Result Interpretation**: Understanding boundary patterns, cluster analysis, neuron importance
- **Comparative Analysis**: Comparing results across different images and layers
- **Export Options**: Saving results for publications and further analysis

In [None]:
# Import required libraries
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import pandas as pd
from pathlib import Path

# Set style
# Matplotlib 3.6 renamed its bundled seaborn style sheets to 'seaborn-v0_8*'.
# Pick the first name this installation knows about so the notebook also runs
# on older Matplotlib releases, and fall back to the default style otherwise.
_style_candidates = ('seaborn-v0_8', 'seaborn')
plt.style.use(next((name for name in _style_candidates
                    if name in plt.style.available), 'default'))
sns.set_palette("husl")

print("✅ Visualization libraries loaded successfully!")

## 📊 Step 1: Comprehensive Result Visualization

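Let's create comprehensive visualizations of our LBMD analysis results. The cell below expects two variables to already exist: `sample_data`, a sequence of `(image, mask)` pairs, and `lbmd_results`, one analysis result per image in the same order.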

In [None]:
def _draw_boundary_overlay(ax, image, boundary_mask, image_number):
    """Show the image with boundary pixels tinted red and a coverage label."""
    ax.imshow(image)

    # Mask out everything that is NOT a boundary so only boundary pixels are
    # tinted. `boundary_mask` is boolean here, so `~` is a logical negation.
    overlay = np.ma.masked_where(~boundary_mask, boundary_mask)
    ax.imshow(overlay, alpha=0.6, cmap='Reds')
    ax.set_title(f'Image {image_number}: Detected Boundaries',
                 fontsize=12, fontweight='bold')
    ax.axis('off')

    # Percentage of pixels flagged as boundary.
    coverage = np.mean(boundary_mask) * 100
    ax.text(10, 30, f'Coverage: {coverage:.1f}%',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
            fontsize=10, fontweight='bold')


def _draw_manifold(ax, coords, clusters, is_boundary, n_clusters):
    """Scatter the 2-D embedding coloured by cluster, marking boundary points."""
    ax.scatter(coords[:, 0], coords[:, 1],
               c=clusters,
               cmap='tab10',
               s=30,
               alpha=0.7,
               edgecolors='black',
               linewidth=0.5)

    boundary_points = coords[is_boundary]
    if len(boundary_points) > 0:
        ax.scatter(boundary_points[:, 0], boundary_points[:, 1],
                   s=50, c='red', marker='x', linewidth=2, alpha=0.8,
                   label='Boundary Points')
        # Only request a legend when there is a labelled artist to show.
        ax.legend()

    ax.set_title('Boundary Manifold (t-SNE)', fontsize=12, fontweight='bold')
    ax.set_xlabel('t-SNE Dimension 1')
    ax.set_ylabel('t-SNE Dimension 2')
    ax.grid(True, alpha=0.3)

    n_boundary_points = np.sum(is_boundary)
    ax.text(0.02, 0.98,
            f'Clusters: {n_clusters}\nBoundary Points: {n_boundary_points}',
            transform=ax.transAxes, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
            fontsize=9)


def _draw_responsiveness_scores(ax, boundary_scores, top_n=5):
    """Bar-plot the scores in descending order, highlighting the top ones."""
    sorted_scores = sorted(boundary_scores, reverse=True)
    bars = ax.bar(range(len(sorted_scores)), sorted_scores,
                  alpha=0.7, color='steelblue', edgecolor='black', linewidth=0.5)

    # Slicing copes with fewer than `top_n` bars (including none at all).
    for bar in bars.patches[:top_n]:
        bar.set_color('gold')
        bar.set_edgecolor('darkorange')

    ax.set_title('Boundary Responsiveness Scores', fontsize=12, fontweight='bold')
    ax.set_xlabel('Neuron Rank')
    ax.set_ylabel('Responsiveness Score')
    ax.grid(True, alpha=0.3)

    # A mean line (and its legend) only makes sense when there are scores.
    if sorted_scores:
        mean_score = np.mean(sorted_scores)
        ax.axhline(mean_score, color='red', linestyle='--', alpha=0.7,
                   label=f'Mean: {mean_score:.3f}')
        ax.legend()


def _draw_transition_heatmap(fig, ax, cluster_labels, transition_strengths):
    """Render the symmetric cluster-to-cluster transition strengths as a heatmap."""
    n_clusters = len(cluster_labels)

    # Map each cluster label to a row/column position. Labels are not assumed
    # to be 0..n-1: gaps or negative labels would otherwise index the wrong
    # cell or fall outside the matrix.
    position = {label: idx for idx, label in enumerate(cluster_labels.tolist())}

    matrix = np.zeros((n_clusters, n_clusters))
    for (label_a, label_b), strength in transition_strengths.items():
        if label_a not in position or label_b not in position:
            # Pair refers to a cluster that is absent from `clusters`; there is
            # no cell to draw it in, so it is left out of the heatmap.
            continue
        row, col = position[label_a], position[label_b]
        matrix[row, col] = strength
        matrix[col, row] = strength  # Symmetric

    heatmap = ax.imshow(matrix, cmap='YlOrRd', aspect='auto')

    colorbar = fig.colorbar(heatmap, ax=ax, shrink=0.8)
    colorbar.set_label('Transition Strength', rotation=270, labelpad=15)

    # Annotate only the non-zero cells to keep the plot readable.
    for row, col in np.argwhere(matrix > 0):
        ax.text(col, row, f'{matrix[row, col]:.2f}',
                ha="center", va="center", color="black", fontweight='bold')

    ax.set_title('Cluster Transition Strengths', fontsize=12, fontweight='bold')
    ax.set_xlabel('Cluster ID')
    ax.set_ylabel('Cluster ID')
    ax.set_xticks(range(n_clusters))
    ax.set_yticks(range(n_clusters))
    # Show the real cluster labels rather than matrix positions.
    ax.set_xticklabels(cluster_labels)
    ax.set_yticklabels(cluster_labels)


def create_comprehensive_visualization(sample_data, lbmd_results, save_path=None):
    """Create comprehensive visualization of LBMD results.

    Draws one row of four panels per image: the image with its boundary
    overlay, the 2-D manifold scatter, the sorted boundary responsiveness
    scores, and the cluster transition-strength heatmap.

    Args:
        sample_data: Sequence of ``(image, mask)`` pairs. Only ``image`` is
            drawn; ``mask`` is unpacked but not used by this function.
        lbmd_results: Iterable of result objects, paired with ``sample_data``
            in order. Each must expose ``boundary_mask``, ``manifold_coords``,
            ``clusters``, ``is_boundary``, ``boundary_scores`` and
            ``transition_strengths`` (a mapping from cluster-label pairs to a
            strength value).
        save_path: Optional output location passed to ``savefig``. When it is
            a string or ``Path``, missing parent directories are created.

    Returns:
        None. The figure is displayed with ``plt.show()`` and optionally saved.
    """
    n_images = len(sample_data)
    if n_images == 0:
        # A zero-height figure cannot be created; there is nothing to draw.
        print("⚠️ No images to visualize.")
        return

    fig = plt.figure(figsize=(20, 6 * n_images))

    for row, ((image, _mask), results) in enumerate(zip(sample_data, lbmd_results)):
        first_panel = row * 4 + 1

        # Coerce the masks to bool so `~` and mask-indexing mean "logical not"
        # and "select flagged rows" even if they arrive as 0/1 integers.
        boundary_mask = np.asarray(results.boundary_mask, dtype=bool)
        is_boundary = np.asarray(results.is_boundary, dtype=bool)
        cluster_labels = np.unique(results.clusters)

        # 1. Original image with boundary overlay
        _draw_boundary_overlay(fig.add_subplot(n_images, 4, first_panel),
                               image, boundary_mask, row + 1)

        # 2. Boundary manifold visualization
        _draw_manifold(fig.add_subplot(n_images, 4, first_panel + 1),
                       results.manifold_coords, results.clusters,
                       is_boundary, len(cluster_labels))

        # 3. Boundary responsiveness scores
        _draw_responsiveness_scores(fig.add_subplot(n_images, 4, first_panel + 2),
                                    results.boundary_scores)

        # 4. Transition strength heatmap
        _draw_transition_heatmap(fig, fig.add_subplot(n_images, 4, first_panel + 3),
                                 cluster_labels, results.transition_strengths)

    fig.tight_layout()

    if save_path:
        # Create the output directory first; savefig does not do this itself.
        if isinstance(save_path, (str, Path)):
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✅ Comprehensive visualization saved to {save_path}")

    plt.show()

# Render the four-panel summary for every analysed image and save it as a PNG.
# NOTE(review): `sample_data` and `lbmd_results` are not defined in the visible
# part of this notebook; presumably they come from an earlier analysis step.
create_comprehensive_visualization(
    sample_data,
    lbmd_results,
    save_path='./demo_results/comprehensive_analysis.png',
)