# CREMI3D Dataset - Exploratory Data Analysis

This notebook provides visualization and analysis of the CREMI3D dataset for neuron and synapse segmentation.

**Dataset Info:**
- Source: MICCAI 2016 CREMI Challenge
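- Resolution: 4 × 4 × 40 nm per voxel (x × y × z; anisotropic, z is 10× coarser than x/y)
- Samples: A, B, C (each 1250 × 1250 × 125 voxels in x × y × z; loaded as arrays of shape (z, y, x) = (125, 1250, 1250), so axis 0 indexes the 125 sections)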
- Labels: neuron instances, synaptic clefts

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import h5py
from pathlib import Path

# Default figure geometry for every plot in this notebook.
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'figure.dpi': 100,
})

# Directory holding the CREMI HDF5 files -- UPDATE THIS for your setup.
DATA_ROOT = Path('..') / 'data' / 'cremi'

# Matplotlib colormap names: neuron instance labels / synaptic clefts.
SEG_CMAP = 'PiYG'
SYN_CMAP = 'hot'

## 1. Load Dataset Information

In [None]:
from neurocircuitry.datasets import CREMI3DDataset

# Show the class-level metadata the dataset implementation exposes
# (these are private attributes of CREMI3DDataset).
banner = "=" * 60
print(banner)
print("CREMI3D Dataset Metadata")
print(banner)
for prefix, caption, attr_name in (
    ("\n", "Paper", "_paper"),
    ("\n", "Resolution", "_resolution"),
    ("", "Labels (base)", "_labels_base"),
    ("", "Labels (with synapse)", "_labels_with_synapse"),
):
    print(f"{prefix}{caption}: {getattr(CREMI3DDataset, attr_name)}")

## 2. Load Raw Data Files

In [None]:
def load_cremi_sample(data_root, sample='A'):
    """
    Load the image and label volumes of one CREMI sample into memory.

    The first file that exists among ``sample_<S>_20160501.hdf``,
    ``sample_<S>_20160501.h5``, ``sample_<S>.hdf`` and ``sample_<S>.h5`` under
    `data_root` (checked in that order) is opened. Its dataset tree is printed.
    Then whichever of 'volumes/raw', 'volumes/labels/neuron_ids' and
    'volumes/labels/clefts' are present are read completely into numpy arrays
    under the keys 'raw', 'neuron_ids' and 'clefts'.

    Args:
        data_root: directory to search (str or Path).
        sample: sample letter, e.g. 'A'.

    Returns:
        dict of loaded arrays (empty if none of the datasets exist), or None
        when no candidate file was found.
    """
    root = Path(data_root)

    file_names = (
        f"sample_{sample}_20160501.hdf",
        f"sample_{sample}_20160501.h5",
        f"sample_{sample}.hdf",
        f"sample_{sample}.h5",
    )
    # Lazily probe candidates in order and stop at the first existing one.
    found = next((root / name for name in file_names if (root / name).exists()), None)
    if found is None:
        print(f"Sample {sample} not found in {root}")
        return None

    print(f"Found: {found}")

    # HDF5 dataset path -> key in the returned dict (insertion order matters
    # for the order of the returned keys).
    h5_to_key = {
        'volumes/raw': 'raw',
        'volumes/labels/neuron_ids': 'neuron_ids',
        'volumes/labels/clefts': 'clefts',
    }

    def _show_dataset(name, obj):
        # Returns None so that visititems keeps walking the whole tree.
        if isinstance(obj, h5py.Dataset):
            print(f"  {name}: shape={obj.shape}, dtype={obj.dtype}")

    with h5py.File(found, 'r') as h5:
        print(f"\nFile structure for sample {sample}:")
        h5.visititems(_show_dataset)
        loaded = {
            key: h5[h5_path][:]
            for h5_path, key in h5_to_key.items()
            if h5_path in h5
        }
    return loaded

# Load every CREMI sample (A, B, C) that can be found under DATA_ROOT.
# Samples whose file is missing, or that contain none of the expected
# datasets, are left out of `samples_data`.
samples_data = {}
if not DATA_ROOT.exists():
    print(f"Data root not found: {DATA_ROOT}")
    print("Please update DATA_ROOT to point to your CREMI data directory")
else:
    divider = '=' * 40
    for sample in ('A', 'B', 'C'):
        print(f"\n{divider}")
        print(f"Loading Sample {sample}")
        print(divider)
        data = load_cremi_sample(DATA_ROOT, sample)
        if data:
            samples_data[sample] = data

## 3. Volume Statistics

In [None]:
def print_cremi_stats(samples_data):
    """Print per-volume statistics for every loaded CREMI sample.

    Every volume is reported with its shape, dtype and min/max value. The
    label volumes ('neuron_ids', 'clefts') also report how many distinct
    values they contain. For 'clefts', the percentage of voxels > 0 is also
    printed.

    Args:
        samples_data: mapping of sample name -> {volume name: numpy array}.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("CREMI Volume Statistics")
    print(rule)

    label_volumes = ('neuron_ids', 'clefts')
    for sample_name in samples_data:
        print(f"\n--- Sample {sample_name} ---")
        for vol_name, arr in samples_data[sample_name].items():
            print(f"\n  {vol_name}:")
            print(f"    Shape: {arr.shape}")
            print(f"    Dtype: {arr.dtype}")
            print(f"    Min: {arr.min()}, Max: {arr.max()}")

            # Only label volumes get the extra label statistics.
            if vol_name not in label_volumes:
                continue
            print(f"    Unique labels: {np.unique(arr).size}")
            if vol_name == 'clefts':
                coverage_pct = (arr > 0).sum() / arr.size * 100
                print(f"    Synapse coverage: {coverage_pct:.2f}%")

# Print the per-volume statistics only if at least one sample was loaded.
if samples_data:
    print_cremi_stats(samples_data)

## 4. Visualize Image-Segmentation Pairs Per Sample

For each sample, show evenly spaced z-slices with the EM image, neuron segmentation, synaptic clefts, and an overlay of all three side by side.

In [None]:
def visualize_cremi_pairs(data, sample_name, num_slices=5, cmap_seg=SEG_CMAP, cmap_syn=SYN_CMAP):
    """
    Visualize CREMI image-segmentation pairs for a single sample.

    Draws `num_slices` evenly spaced slices taken along axis 0 of ``data['raw']``,
    one row per slice, with the columns Raw | Neurons | Clefts | Overlay. The
    Neurons / Clefts columns only appear when ``'neuron_ids'`` / ``'clefts'`` are
    present in `data`. The overlay shows the neuron labels at 40% opacity and
    paints voxels with ``clefts > 0`` red.

    Args:
        data: dict of volumes for one sample, as returned by ``load_cremi_sample``.
        sample_name: sample identifier used in titles and messages.
        num_slices: number of rows to draw; clipped to the number of slices.
        cmap_seg: matplotlib colormap for the neuron labels.
        cmap_syn: matplotlib colormap for the cleft labels.
    """
    if 'raw' not in data:
        print(f"No raw data for sample {sample_name}")
        return

    raw = data['raw']
    n_total = raw.shape[0]
    # More rows than slices would just repeat slices, and zero rows would make
    # plt.subplots fail, so clip and bail out on an empty volume.
    num_slices = min(num_slices, n_total)
    if num_slices < 1:
        print(f"No slices to show for sample {sample_name}")
        return
    indices = np.linspace(0, n_total - 1, num_slices, dtype=int)

    has_neurons = 'neuron_ids' in data
    has_clefts = 'clefts' in data

    # Raw + overlay are always drawn; neurons / clefts only when available.
    n_cols = 2 + int(has_neurons) + int(has_clefts)

    # squeeze=False keeps `axes` 2-D even for a single row, so the
    # axes[i, col] indexing below also works when num_slices == 1.
    fig, axes = plt.subplots(
        num_slices, n_cols, figsize=(5 * n_cols, 4 * num_slices), squeeze=False
    )

    for i, idx in enumerate(indices):
        col = 0

        # Raw image
        axes[i, col].imshow(raw[idx], cmap='gray')
        axes[i, col].set_title(f'Slice {idx} - EM Image')
        axes[i, col].axis('off')
        col += 1

        # Neuron segmentation
        if has_neurons:
            axes[i, col].imshow(data['neuron_ids'][idx], cmap=cmap_seg)
            axes[i, col].set_title(f'Slice {idx} - Neurons')
            axes[i, col].axis('off')
            col += 1

        # Synaptic clefts
        if has_clefts:
            axes[i, col].imshow(data['clefts'][idx], cmap=cmap_syn)
            axes[i, col].set_title(f'Slice {idx} - Synapses')
            axes[i, col].axis('off')
            col += 1

        # Overlay: EM image, translucent neuron labels, then 80%-opaque red clefts.
        axes[i, col].imshow(raw[idx], cmap='gray')
        if has_neurons:
            axes[i, col].imshow(data['neuron_ids'][idx], cmap=cmap_seg, alpha=0.4)
        if has_clefts:
            # NOTE(review): treats 0 as "no cleft"; confirm this matches the
            # background value used in your CREMI files.
            cleft_mask = data['clefts'][idx] > 0
            overlay = np.zeros((*cleft_mask.shape, 4))  # RGBA, fully transparent
            overlay[cleft_mask] = [1, 0, 0, 0.8]  # Red for synapses
            axes[i, col].imshow(overlay)
        axes[i, col].set_title(f'Slice {idx} - Overlay')
        axes[i, col].axis('off')

    plt.suptitle(f'CREMI Sample {sample_name} - Image-Segmentation Pairs', fontsize=16, y=1.01)
    plt.tight_layout()
    plt.show()

# Visualize each sample
for sample_name, data in samples_data.items():
    visualize_cremi_pairs(data, sample_name)

## 5. All Samples Comparison

Compare the same slice index across all samples.

In [None]:
def compare_samples_at_slice(samples_data, slice_idx, cmap_seg=SEG_CMAP, cmap_syn=SYN_CMAP):
    """
    Compare the same slice across all CREMI samples.

    Draws one row per sample with the columns EM | Neurons | Synapses | Overlay.
    `slice_idx` indexes axis 0 and is clamped to each sample's last slice. The
    overlay shows neuron labels at 40% opacity and voxels with ``clefts > 0``
    in red. Samples without ``'raw'`` get a blank row.

    Args:
        samples_data: mapping sample name -> dict of volumes.
        slice_idx: slice index along axis 0.
        cmap_seg: matplotlib colormap for the neuron labels.
        cmap_syn: matplotlib colormap for the cleft labels.
    """
    n_samples = len(samples_data)
    if n_samples == 0:
        print("No samples to compare")
        return

    # squeeze=False keeps `axes` 2-D even when there is a single sample.
    fig, axes = plt.subplots(n_samples, 4, figsize=(20, 5 * n_samples), squeeze=False)

    for i, (sample_name, data) in enumerate(samples_data.items()):
        if 'raw' not in data:
            # Hide the frames instead of leaving four empty axes with ticks.
            axes[i, 0].set_title(f'Sample {sample_name} - no raw data')
            for ax in axes[i]:
                ax.axis('off')
            continue

        idx = min(slice_idx, data['raw'].shape[0] - 1)

        # Raw
        axes[i, 0].imshow(data['raw'][idx], cmap='gray')
        axes[i, 0].set_title(f'Sample {sample_name} - EM (Slice {idx})')
        axes[i, 0].axis('off')

        # Neurons
        if 'neuron_ids' in data:
            axes[i, 1].imshow(data['neuron_ids'][idx], cmap=cmap_seg)
            axes[i, 1].set_title(f'Sample {sample_name} - Neurons')
        axes[i, 1].axis('off')

        # Clefts
        if 'clefts' in data:
            axes[i, 2].imshow(data['clefts'][idx], cmap=cmap_syn)
            axes[i, 2].set_title(f'Sample {sample_name} - Synapses')
        axes[i, 2].axis('off')

        # Overlay: EM image, translucent neuron labels, red cleft voxels.
        axes[i, 3].imshow(data['raw'][idx], cmap='gray')
        if 'neuron_ids' in data:
            axes[i, 3].imshow(data['neuron_ids'][idx], cmap=cmap_seg, alpha=0.4)
        if 'clefts' in data:
            # NOTE(review): treats 0 as "no cleft"; confirm for your CREMI files.
            cleft_mask = data['clefts'][idx] > 0
            overlay = np.zeros((*cleft_mask.shape, 4))  # RGBA, fully transparent
            overlay[cleft_mask] = [1, 0, 0, 0.8]  # Red for synapses
            axes[i, 3].imshow(overlay)
        axes[i, 3].set_title(f'Sample {sample_name} - Overlay')
        axes[i, 3].axis('off')

    plt.suptitle(f'CREMI - All Samples at Slice {slice_idx}', fontsize=16, y=1.01)
    plt.tight_layout()
    plt.show()

# Compare at the middle slice, derived from the loaded volumes instead of
# assuming a fixed number of slices per sample.
if samples_data:
    depths = [d['raw'].shape[0] for d in samples_data.values() if 'raw' in d]
    mid_slice = min(depths) // 2 if depths else 0
    compare_samples_at_slice(samples_data, slice_idx=mid_slice)

## 6. Label Distribution Analysis

In [None]:
def analyze_cremi_labels(samples_data):
    """
    Analyze label distribution across all CREMI samples.

    For each sample this prints:
      * 'neuron_ids': the number of distinct labels > 0, plus their size range
        and mean size in voxels;
      * 'clefts': the number of distinct labels > 0, the number of voxels > 0,
        and those voxels as a percentage of the volume.
    It finishes with the neuron and cleft counts summed over all samples.

    Args:
        samples_data: mapping sample name -> dict of numpy volumes.
    """
    print("\n" + "=" * 60)
    print("CREMI Label Analysis")
    print("=" * 60)

    summary = {'neurons': [], 'clefts': []}

    for sample, data in samples_data.items():
        print(f"\n--- Sample {sample} ---")

        if 'neuron_ids' in data:
            # A single np.unique pass yields both the label set and the
            # per-label voxel counts. The volumes are large, so avoid sorting twice.
            unique, counts = np.unique(data['neuron_ids'], return_counts=True)
            # Label 0 is treated as background. Counting labels > 0 (instead of
            # len(unique) - 1) stays correct when a volume has no 0 voxels.
            instance_sizes = counts[unique > 0]
            n_neurons = len(instance_sizes)

            print(f"\nNeuron Instances: {n_neurons}")
            summary['neurons'].append(n_neurons)

            if n_neurons > 0:
                print(f"  Size range: {instance_sizes.min():,} - {instance_sizes.max():,} voxels")
                print(f"  Mean size: {instance_sizes.mean():,.0f} voxels")

        if 'clefts' in data:
            clefts = data['clefts']
            # NOTE(review): assumes 0 marks "no cleft". CREMI cleft volumes may
            # use 0xffffffffffffffff (uint64 max) as the no-cleft value instead,
            # which would be counted here as a cleft. Confirm against your files.
            unique_clefts = np.unique(clefts)
            # Same "> 0" criterion as the voxel coverage below; also safe for an
            # empty volume, where unique_clefts has no elements.
            n_clefts = int((unique_clefts > 0).sum())

            print(f"\nSynaptic Clefts: {n_clefts}")
            summary['clefts'].append(n_clefts)

            cleft_voxels = (clefts > 0).sum()
            coverage = 100 * cleft_voxels / clefts.size if clefts.size else 0.0
            print(f"  Total synapse voxels: {cleft_voxels:,}")
            print(f"  Coverage: {coverage:.3f}%")

    # Summary (printed whenever any sample had neuron or cleft labels)
    if summary['neurons'] or summary['clefts']:
        print(f"\n{'='*40}")
        print("Summary Across Samples")
        print(f"{'='*40}")
        print(f"Total neuron instances: {sum(summary['neurons'])}")
        print(f"Total synaptic clefts: {sum(summary['clefts'])}")

# Run the label analysis only when at least one sample was loaded.
if samples_data:
    analyze_cremi_labels(samples_data)

In [None]:
def plot_cremi_size_distribution(samples_data):
    """
    Plot instance size distributions for all samples.

    Each sample gets one histogram (50 bins) of the voxel counts of every
    'neuron_ids' label > 0, with a dashed red line at the mean size. Samples
    without neuron labels, or with no labels > 0, get a blank panel with a note
    in the title.

    Args:
        samples_data: mapping sample name -> dict of numpy volumes.
    """
    n_samples = len(samples_data)
    if n_samples == 0:
        return

    # squeeze=False always yields a 2-D array, so one sample needs no special case.
    fig, axes = plt.subplots(1, n_samples, figsize=(6 * n_samples, 5), squeeze=False)

    for ax, (sample, data) in zip(axes[0], samples_data.items()):
        if 'neuron_ids' not in data:
            ax.set_title(f'Sample {sample}\n(no neuron labels)')
            ax.axis('off')
            continue

        unique, counts = np.unique(data['neuron_ids'], return_counts=True)
        instance_sizes = counts[unique > 0]
        if len(instance_sizes) == 0:
            # The mean of an empty array is NaN, so there is nothing to plot.
            ax.set_title(f'Sample {sample}\n(0 instances)')
            ax.axis('off')
            continue

        mean_size = instance_sizes.mean()
        ax.hist(instance_sizes, bins=50, edgecolor='black', alpha=0.7, color='#8e6c8a')
        ax.set_xlabel('Instance Size (voxels)')
        ax.set_ylabel('Count')
        ax.set_title(f'Sample {sample}\n({len(instance_sizes)} instances)')
        ax.axvline(mean_size, color='r', linestyle='--', label=f'Mean: {mean_size:.0f}')
        ax.legend()

    plt.suptitle('CREMI - Neuron Instance Size Distribution', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()

if samples_data:
    plot_cremi_size_distribution(samples_data)

## 7. Cross-Sample Comparison Table

In [None]:
def compare_samples(samples_data):
    """
    Compare statistics across CREMI samples.

    Prints one table row per sample with these columns:
      * 'Image Mean' / 'Image Std': statistics of 'raw', to 1 decimal;
      * 'Neurons': the number of 'neuron_ids' labels > 0;
      * 'Avg Neuron Size': the mean voxel count of those labels;
      * 'Synapse Coverage (%)': the percentage of 'clefts' voxels > 0, to 3 decimals.
    Volumes missing from a sample are shown as 'N/A'. At least two samples are
    required; otherwise a message is printed instead.

    Args:
        samples_data: mapping sample name -> dict of numpy volumes.
    """
    if len(samples_data) < 2:
        print("Need at least 2 samples for comparison")
        return

    metrics = {
        'Sample': [],
        'Image Mean': [],
        'Image Std': [],
        'Neurons': [],
        'Avg Neuron Size': [],
        'Synapse Coverage (%)': []
    }

    for sample, data in samples_data.items():
        metrics['Sample'].append(sample)

        if 'raw' in data:
            metrics['Image Mean'].append(f"{data['raw'].mean():.1f}")
            metrics['Image Std'].append(f"{data['raw'].std():.1f}")
        else:
            metrics['Image Mean'].append('N/A')
            metrics['Image Std'].append('N/A')

        if 'neuron_ids' in data:
            unique, counts = np.unique(data['neuron_ids'], return_counts=True)
            # Count labels > 0 rather than len(unique) - 1, so that a volume
            # without any background (0) voxels is not undercounted by one.
            instance_sizes = counts[unique > 0]
            n_neurons = len(instance_sizes)
            avg_size = instance_sizes.mean() if n_neurons > 0 else 0
            metrics['Neurons'].append(n_neurons)
            metrics['Avg Neuron Size'].append(f"{avg_size:.0f}")
        else:
            metrics['Neurons'].append('N/A')
            metrics['Avg Neuron Size'].append('N/A')

        if 'clefts' in data:
            coverage = 100 * (data['clefts'] > 0).sum() / data['clefts'].size
            metrics['Synapse Coverage (%)'].append(f"{coverage:.3f}")
        else:
            metrics['Synapse Coverage (%)'].append('N/A')

    # Print as table
    print("\n" + "=" * 80)
    print("CREMI Cross-Sample Comparison")
    print("=" * 80)

    # Each column is at least 18 characters wide but grows to fit its header and
    # values. 'Synapse Coverage (%)' is 20 characters, so a fixed width of 18
    # would push the header out of line with the rows.
    widths = {
        name: max(18, len(name), *(len(str(v)) for v in values))
        for name, values in metrics.items()
    }
    header = " | ".join(f"{name:^{widths[name]}}" for name in metrics)
    print(header)
    print("-" * len(header))

    for row in zip(*metrics.values()):
        print(" | ".join(f"{value:^{widths[name]}}" for name, value in zip(metrics, row)))

# Build the cross-sample table; compare_samples itself requires >= 2 samples.
if samples_data:
    compare_samples(samples_data)

## 8. Using NeuroCircuitry Dataset Class

In [None]:
# Smoke-test the packaged CREMI3DDataset class on sample A
# (train split, synapses included, no caching). Any failure, including a
# missing package, is reported rather than raised so the notebook keeps running.
if not DATA_ROOT.exists():
    print("Update DATA_ROOT to test the NeuroCircuitry dataset class")
else:
    try:
        from neurocircuitry.datasets import CREMI3DDataset

        dataset = CREMI3DDataset(
            root_dir=str(DATA_ROOT),
            split='train',
            samples=['A'],
            include_synapses=True,
            cache_rate=0.0,
        )

        print("\nDataset loaded successfully!")
        print(dataset)

        # Fetch the first item and report which arrays it carries.
        sample = dataset[0]
        print(f"\nSample keys: {sample.keys()}")
        print(f"Image shape: {sample['image'].shape}")
        for key, caption in (('label', 'Label'), ('clefts', 'Clefts')):
            if key in sample:
                print(f"{caption} shape: {sample[key].shape}")
    except Exception as e:
        print(f"Could not load dataset: {e}")