# MICRONS Dataset - Exploratory Data Analysis

This notebook provides visualization and analysis of the MICRONS dataset for large-scale cortical connectomics.

**Dataset Info:**
- Source: MICrONS Consortium (2021)
- Resolution: 4×4×40 nm (typical)
- Content: Mouse visual cortex
- Labels: neuron segmentation, synapses, mitochondria

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import h5py
import tifffile
from pathlib import Path

# Notebook-wide matplotlib defaults applied to every figure below.
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'figure.dpi': 100,
})

# Root directory of the locally downloaded MICRONS subvolumes.
# UPDATE THIS to match where the data lives on your machine.
DATA_ROOT = Path('../data/microns')

## 1. Load Dataset Information

In [None]:
from neurocircuitry.datasets import MICRONSDataset

# Class-level metadata exposed by the NeuroCircuitry dataset wrapper.
rule = "=" * 60
print(rule)
print("MICRONS Dataset Metadata")
print(rule)
print(f"\nPaper: {MICRONSDataset._paper}")
print(f"\nResolution: {MICRONSDataset._resolution}")
for kind in ('base', 'extended'):
    print(f"Labels ({kind}): {getattr(MICRONSDataset, f'_labels_{kind}')}")

## 2. Load Raw Data Files

In [None]:
def load_microns_data(data_root):
    """
    Load MICRONS volumes and labels from a local directory.

    For each of ``volume``, ``segmentation``, ``synapses`` and
    ``mitochondria`` the files ``<name>.h5``, ``.hdf5``, ``.tiff``, ``.tif``
    and ``.nrrd`` are checked in that order. Only the first one that exists
    is read. If reading it fails, the error is printed and that entry is left
    out; later extensions are not tried.

    Args:
        data_root: Directory (str or Path) containing the files.

    Returns:
        dict mapping each successfully loaded name to its array. Missing or
        unreadable entries are simply absent.
    """
    data_root = Path(data_root)

    volumes = {}
    for name in ('volume', 'segmentation', 'synapses', 'mitochondria'):
        for ext in ('.h5', '.hdf5', '.tiff', '.tif', '.nrrd'):
            path = data_root / f"{name}{ext}"
            if not path.exists():
                continue
            print(f"Found: {path}")
            try:
                volumes[name] = _read_microns_file(path, ext)
            except Exception as e:
                # Best-effort loader for interactive use: report the problem
                # and keep going so the remaining volumes can still be inspected.
                print(f"Error loading {path}: {e}")
            break  # the first existing file wins, whether or not it loaded

    return volumes


def _read_microns_file(path, ext):
    """Read one volume file, choosing the reader from ``ext``."""
    if ext in ('.h5', '.hdf5'):
        with h5py.File(path, 'r') as f:
            return f[_first_h5_dataset_name(f, path)][:]
    if ext in ('.tiff', '.tif'):
        return tifffile.imread(str(path))
    if ext == '.nrrd':
        import nrrd  # optional dependency, only needed for NRRD files
        # pynrrd defaults to Fortran index order, i.e. (x, y, z). Request C
        # order so axis 0 is the slowest-varying axis (z), which is how the
        # rest of this notebook indexes volumes (slices and sizes use shape[0] as z).
        data, _ = nrrd.read(str(path), index_order='C')
        return data
    raise ValueError(f"Unsupported extension: {ext}")


def _first_h5_dataset_name(h5file, path):
    """Return the name of the first dataset in ``h5file``, skipping groups."""
    # visititems walks the file recursively and stops at the first non-None
    # return value, so a group that sorts before the data is not picked.
    name = h5file.visititems(
        lambda key, obj: key if isinstance(obj, h5py.Dataset) else None
    )
    if name is None:
        raise ValueError(f"No datasets found in {path}")
    return name

# Load every available volume from DATA_ROOT, or explain how to obtain the data.
if not DATA_ROOT.exists():
    print(f"Data root not found: {DATA_ROOT}")
    print("Please update DATA_ROOT to point to your MICRONS data directory")
    print("\nNote: MICRONS data is typically accessed via CAVEclient API.")
    print("This notebook expects locally downloaded subvolumes.")
    volumes = {}
else:
    volumes = load_microns_data(DATA_ROOT)
    print(f"\nLoaded {len(volumes)} volumes")
    for vol_name, vol_array in volumes.items():
        print(f"  {vol_name}: shape={vol_array.shape}, dtype={vol_array.dtype}")

## 3. Volume Statistics

In [None]:
def print_volume_stats(volumes):
    """
    Print summary statistics for every array in ``volumes``.

    Every entry gets its shape, dtype, min/max, mean/std and memory footprint.
    3D arrays also get a physical extent in micrometres, derived from the
    per-axis voxel size in ``MICRONSDataset._resolution`` (in nm). The
    annotation entries ('segmentation', 'synapses', 'mitochondria') also
    report the number of distinct labels, the number of instances (non-zero
    labels) and the percentage of voxels > 0.
    """
    header = "=" * 60
    print("\n" + header)
    print("MICRONS Volume Statistics")
    print(header)

    annotation_keys = ('segmentation', 'synapses', 'mitochondria')

    for key in volumes:
        arr = volumes[key]
        print(f"\n{key}:")
        print(f"  Shape: {arr.shape}")
        print(f"  Dtype: {arr.dtype}")
        print(f"  Min: {arr.min()}, Max: {arr.max()}")
        print(f"  Mean: {arr.mean():.2f}, Std: {arr.std():.2f}")
        print(f"  Memory: {arr.nbytes / 1024**2:.2f} MB")

        # Voxel size is in nm per axis; dividing by 1000 gives micrometres.
        res = MICRONSDataset._resolution
        if arr.ndim == 3:
            depth_um = arr.shape[0] * res['z'] / 1000
            height_um = arr.shape[1] * res['y'] / 1000
            width_um = arr.shape[2] * res['x'] / 1000
            print(f"  Physical size: {depth_um:.1f} x {height_um:.1f} x {width_um:.1f} μm")

        if key not in annotation_keys:
            continue

        label_ids = np.unique(arr)
        has_background = bool(np.any(label_ids == 0))
        print(f"  Unique labels: {len(label_ids)}")
        print(f"  Instances (excl. background): {len(label_ids) - int(has_background)}")

        fg_ratio = (arr > 0).sum() / arr.size * 100
        print(f"  Foreground coverage: {fg_ratio:.2f}%")

# Skip the report when nothing was loaded.
if len(volumes) > 0:
    print_volume_stats(volumes)

## 4. Visualize Multiple Slices

In [None]:
def visualize_microns_slices(volume, title, num_slices=9, cmap='gray', is_label=False):
    """
    Show ``num_slices`` evenly spaced slices along axis 0 of a 3D volume.

    Args:
        volume: 3D array. Other ranks are reported and skipped.
        title: Figure title.
        num_slices: Number of slices to draw, three panels per row.
        cmap: Colormap for intensity images; each panel gets a colour bar.
            Ignored when ``is_label`` is True.
        is_label: Treat ``volume`` as integer instance labels. Each label ID
            is drawn in a pseudo-random colour that stays the same on every
            slice, and label 0 (background) is black.
    """
    if volume.ndim != 3:
        print(f"Expected 3D volume, got shape {volume.shape}")
        return

    indices = np.linspace(0, volume.shape[0] - 1, num_slices, dtype=int)

    if is_label:
        # Palette is capped at 10000 entries; larger IDs wrap around (modulo).
        # A local Generator keeps colours reproducible without reseeding
        # NumPy's global random state.
        n_colors = max(1, min(int(volume.max()) + 1, 10000))
        palette = np.random.default_rng(42).random((n_colors, 3))

        def to_rgb(label_slice):
            # Index the palette by label ID directly. A ListedColormap is only
            # applied after imshow rescales each slice to its own min/max, so
            # one ID would change colour from slice to slice.
            ids = np.asarray(label_slice)
            rgb = palette[(ids % n_colors).astype(np.intp)]
            rgb[ids == 0] = 0.0  # background is always black
            return rgb

    cols = 3
    rows = (num_slices + cols - 1) // cols
    _, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows))
    axes = axes.flatten()

    for ax, idx in zip(axes, indices):
        if is_label:
            ax.imshow(to_rgb(volume[idx]))
        else:
            im = ax.imshow(volume[idx], cmap=cmap)
        ax.set_title(f'Slice {idx}')
        ax.axis('off')
        if not is_label:
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Hide the unused panels in the last row.
    for ax in axes[len(indices):]:
        ax.axis('off')

    plt.suptitle(title, fontsize=16, y=1.02)
    plt.tight_layout()
    plt.show()

# Visualize EM volume
if 'volume' in volumes:
    visualize_microns_slices(volumes['volume'], 'MICRONS EM Volume')

In [None]:
# Neuron segmentation is an instance-label volume, so draw it with random label colours.
if 'segmentation' in volumes:
    visualize_microns_slices(
        volumes['segmentation'],
        title='MICRONS Neuron Segmentation',
        is_label=True,
    )

In [None]:
# Visualize synapses. With is_label=True the slices are drawn with random
# per-label colours and any cmap argument is ignored, so none is passed.
if 'synapses' in volumes:
    visualize_microns_slices(volumes['synapses'], 'MICRONS Synapses', is_label=True)

In [None]:
# Mitochondria are instance labels as well; draw them with random label colours.
if 'mitochondria' in volumes:
    visualize_microns_slices(
        volumes['mitochondria'],
        title='MICRONS Mitochondria',
        is_label=True,
    )

## 5. Multi-Channel Overlay

In [None]:
def visualize_multichannel_overlay(volumes, slice_idx=None):
    """
    Show one EM slice next to each available annotation layer, plus a
    combined overlay.

    Args:
        volumes: dict of arrays; must contain 'volume'. 'segmentation',
            'synapses' and 'mitochondria' are drawn when present. Annotation
            arrays are indexed with the same ``slice_idx``, which assumes they
            are aligned with the EM volume (not checked here).
        slice_idx: Index along axis 0. Defaults to the middle slice.

    Overlay colours: segmentation uses per-label random colours at 30%
    opacity with the background left transparent; synapses are red and
    mitochondria green.
    """
    if 'volume' not in volumes:
        print("No volume data available")
        return

    vol = volumes['volume']
    if slice_idx is None:
        slice_idx = vol.shape[0] // 2

    # EM panel + one panel per annotation + the combined overlay.
    n_panels = 2 + sum(1 for k in ['segmentation', 'synapses', 'mitochondria'] if k in volumes)
    _, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))
    panels = iter(axes)

    def show(image, panel_title, **imshow_kwargs):
        # Draw ``image`` on the next free panel and return that Axes.
        ax = next(panels)
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(panel_title)
        ax.axis('off')
        return ax

    em_slice = vol[slice_idx]
    show(em_slice, 'EM Image', cmap='gray')

    seg_overlay = None
    if 'segmentation' in volumes:
        seg = volumes['segmentation']
        seg_slice = np.asarray(seg[slice_idx])
        # Colour label IDs directly rather than through a ListedColormap.
        # imshow would rescale the slice to its own min/max, so colours would
        # not be tied to IDs. The palette is capped at 10000 entries (IDs wrap
        # via modulo), and a local Generator avoids reseeding NumPy's global RNG.
        n_colors = max(1, min(int(seg.max()) + 1, 10000))
        palette = np.random.default_rng(42).random((n_colors, 3))
        seg_rgb = palette[(seg_slice % n_colors).astype(np.intp)]
        seg_rgb[seg_slice == 0] = 0.0
        show(seg_rgb, 'Neuron Segmentation')
        # Same colours for the overlay, but the background is fully
        # transparent so it does not darken the EM image.
        seg_overlay = np.dstack([seg_rgb, np.where(seg_slice != 0, 0.3, 0.0)])

    if 'synapses' in volumes:
        show(volumes['synapses'][slice_idx], 'Synapses', cmap='hot')

    if 'mitochondria' in volumes:
        show(volumes['mitochondria'][slice_idx], 'Mitochondria', cmap='viridis')

    # Combined overlay: EM underneath, annotation layers stacked on top.
    ax_overlay = show(em_slice, 'Combined Overlay', cmap='gray')
    if seg_overlay is not None:
        ax_overlay.imshow(seg_overlay)

    for key, rgba in (('synapses', (1, 0, 0, 0.8)), ('mitochondria', (0, 1, 0, 0.6))):
        if key in volumes:
            mask = volumes[key][slice_idx] > 0
            layer = np.zeros((*mask.shape, 4))
            layer[mask] = rgba
            ax_overlay.imshow(layer)

    plt.suptitle(f'MICRONS Multi-Channel View - Slice {slice_idx}', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()

if volumes:
    visualize_multichannel_overlay(volumes)

## 6. Label Distribution Analysis

In [None]:
def analyze_microns_labels(volumes):
    """
    Print label statistics for each annotation volume present in ``volumes``.

    'segmentation', 'synapses' and 'mitochondria' are reported, in that order.
    For each one this prints:
    - the number of distinct labels;
    - the number of instances (distinct non-zero labels);
    - background (label 0) and foreground voxel counts, with percentages;
    - min/max/mean/median voxel counts over labels > 0.

    Other keys are ignored.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("MICRONS Label Analysis")
    print(banner)

    for label_type in ('segmentation', 'synapses', 'mitochondria'):
        if label_type not in volumes:
            continue

        labels = volumes[label_type]
        ids, voxel_counts = np.unique(labels, return_counts=True)
        n_total = labels.size
        is_bg = ids == 0

        print(f"\n--- {label_type.capitalize()} ---")
        print(f"Total unique labels: {len(ids)}")
        print(f"Number of instances: {len(ids) - int(is_bg.any())}")

        # np.unique yields each value once, so at most one entry is background.
        bg_pixels = voxel_counts[is_bg][0] if is_bg.any() else 0
        fg_pixels = n_total - bg_pixels
        print(f"Background: {bg_pixels:,} voxels ({100 * bg_pixels / n_total:.1f}%)")
        print(f"Foreground: {fg_pixels:,} voxels ({100 * fg_pixels / n_total:.1f}%)")

        sizes = voxel_counts[ids > 0]
        if sizes.size:
            print("\nInstance size statistics:")
            print(f"  Min: {sizes.min():,}")
            print(f"  Max: {sizes.max():,}")
            print(f"  Mean: {sizes.mean():,.0f}")
            print(f"  Median: {np.median(sizes):,.0f}")

# Label analysis only makes sense once something has been loaded.
if len(volumes) > 0:
    analyze_microns_labels(volumes)

In [None]:
def plot_size_distributions(volumes):
    """
    Draw one histogram per annotation volume of the voxel counts of its
    labels > 0 (50 bins), with a dashed red line at the mean size.
    Volumes with no positive labels get a 'No instances' placeholder.
    """
    present = [name for name in ('segmentation', 'synapses', 'mitochondria') if name in volumes]
    if len(present) == 0:
        print("No label volumes to analyze")
        return

    n_plots = len(present)
    # squeeze=False always returns a 2D grid, even for a single panel.
    _, grid = plt.subplots(1, n_plots, figsize=(6 * n_plots, 5), squeeze=False)

    for ax, label_type in zip(grid[0], present):
        ids, voxel_counts = np.unique(volumes[label_type], return_counts=True)
        sizes = voxel_counts[ids > 0]

        if sizes.size == 0:
            ax.text(0.5, 0.5, 'No instances', ha='center', va='center')
            continue

        mean_size = np.mean(sizes)
        ax.hist(sizes, bins=50, edgecolor='black', alpha=0.7)
        ax.set_xlabel('Instance Size (voxels)')
        ax.set_ylabel('Count')
        ax.set_title(f'{label_type.capitalize()}\n({len(sizes)} instances)')
        ax.axvline(mean_size, color='r', linestyle='--', label=f'Mean: {mean_size:.0f}')
        ax.legend()

    plt.suptitle('MICRONS - Instance Size Distributions', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()

# Only plot when something was loaded.
if len(volumes) > 0:
    plot_size_distributions(volumes)

## 7. Per-Slice Analysis

In [None]:
def analyze_per_slice_microns(volumes):
    """
    Plot the number of distinct non-zero labels in each slice (axis 0) of
    every annotation volume present, with the mean as a dashed red line.

    Each volume is scanned over its own number of slices, so annotation
    volumes do not need to have the same depth.
    """
    label_types = [k for k in ['segmentation', 'synapses', 'mitochondria'] if k in volumes]

    if not label_types:
        print("No label volumes to analyze")
        return

    _, axes = plt.subplots(len(label_types), 1, figsize=(12, 4 * len(label_types)))
    if len(label_types) == 1:
        axes = [axes]

    for ax, label_type in zip(axes, label_types):
        labels = volumes[label_type]

        # Iterate over this volume's own slices. A shared slice count would
        # index past the end of shallower volumes and truncate deeper ones.
        instances_per_slice = np.array(
            [np.count_nonzero(np.unique(label_slice) > 0) for label_slice in labels],
            dtype=int,
        )
        # Guard against an empty volume, whose mean would be NaN.
        mean_count = instances_per_slice.mean() if instances_per_slice.size else 0.0

        ax.plot(instances_per_slice, marker='o', markersize=3)
        ax.set_xlabel('Slice Index')
        ax.set_ylabel('Number of Instances')
        ax.set_title(f'{label_type.capitalize()} per Slice')
        ax.axhline(mean_count, color='r', linestyle='--',
                   label=f'Mean: {mean_count:.1f}')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

if volumes:
    analyze_per_slice_microns(volumes)

## 8. Using NeuroCircuitry Dataset Class

In [None]:
# Build the NeuroCircuitry MICRONSDataset on the same directory and inspect one sample.
if not (DATA_ROOT.exists() and volumes):
    print("Update DATA_ROOT and ensure data files exist to test the NeuroCircuitry dataset class")
else:
    try:
        from neurocircuitry.datasets import MICRONSDataset

        dataset = MICRONSDataset(
            root_dir=str(DATA_ROOT),
            split='train',
            include_synapses=('synapses' in volumes),
            include_mitochondria=('mitochondria' in volumes),
            cache_rate=0.0,
        )

        print("\nDataset loaded successfully!")
        print(dataset)

        # Inspect the first sample and report the shape of each field it carries.
        sample = dataset[0]
        print(f"\nSample keys: {sample.keys()}")
        print(f"Image shape: {sample['image'].shape}")
        optional_fields = (
            ('label', 'Label'),
            ('synapses', 'Synapses'),
            ('mitochondria', 'Mitochondria'),
        )
        for field, display_name in optional_fields:
            if field in sample:
                print(f"{display_name} shape: {sample[field].shape}")

    except Exception as e:
        print(f"Could not load dataset: {e}")

## 9. Dataset Summary

In [None]:
def print_dataset_summary(volumes):
    """
    Print a compact overview of the loaded MICRONS data.

    The overview has three parts:
    - If an EM 'volume' is present: its shape, the voxel size from
      ``MICRONSDataset._resolution`` (printed in nm as x×y×z) and the
      physical extent in micrometres.
    - For each annotation type: either its instance count (distinct non-zero
      labels) and the percentage of voxels > 0, or "Not available".
    - The combined memory of all arrays, in MB.
    """
    divider = "=" * 60
    print("\n" + divider)
    print("MICRONS Dataset Summary")
    print(divider)

    if 'volume' in volumes:
        em = volumes['volume']
        res = MICRONSDataset._resolution
        print(f"\nVolume Dimensions: {em.shape}")
        print(f"Resolution: {res['x']}×{res['y']}×{res['z']} nm")
        z_um, y_um, x_um = (
            em.shape[0] * res['z'] / 1000,
            em.shape[1] * res['y'] / 1000,
            em.shape[2] * res['x'] / 1000,
        )
        print(f"Physical Size: {z_um:.1f} × {y_um:.1f} × {x_um:.1f} μm")

    print("\nAvailable Annotations:")
    for label_type in ('segmentation', 'synapses', 'mitochondria'):
        pretty = label_type.capitalize()
        if label_type not in volumes:
            print(f"  - {pretty}: Not available")
            continue
        labels = volumes[label_type]
        ids = np.unique(labels)
        n_instances = len(ids) - int(np.any(ids == 0))
        coverage = 100 * (labels > 0).sum() / labels.size
        print(f"  - {pretty}: {n_instances} instances ({coverage:.2f}% coverage)")

    total_mb = sum(arr.nbytes for arr in volumes.values()) / 1024**2
    print(f"\nTotal Memory: {total_mb:.1f} MB")

# Final overview, shown only when at least one volume was loaded.
if len(volumes) > 0:
    print_dataset_summary(volumes)