<a href="https://colab.research.google.com/github/shirgalor/caiman_local/blob/main/CNMF_Colab_Runner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CNMF Analysis in Google Colab
## Your Custom CaImAn CNMF Project

This notebook runs your specific CNMF project with all custom modifications:
- Custom debug stage saving
- Error handling for spatial/temporal updates  
- YrA error filtering
- Manual step-by-step CNMF execution

**Runtime Requirements:** A standard CPU runtime is sufficient. CaImAn's CNMF runs on the CPU, so a GPU runtime does not speed it up. For large videos, a runtime with more RAM helps more.

## 📦 Setup: Install Dependencies

In [None]:
# Install CaImAn and dependencies
!pip install caiman[complete] -q
!pip install tifffile matplotlib -q

# Install additional packages for visualization
!pip install seaborn plotly -q

print("✅ All dependencies installed!")
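Because `-q` hides most pip output, a failed install may only surface later as an `ImportError`. Below is a minimal sketch of a check you could run right after installing. It assumes the import names listed below are the ones this notebook needs. `find_missing_modules` is a hypothetical helper, not part of CaImAn:

```python
import importlib.util


def find_missing_modules(module_names):
    """Return the top-level module names that Python cannot locate.

    Locating a module does not guarantee that importing it succeeds (its own
    dependencies may still be missing), so treat this as a quick first check.
    """
    return [name for name in module_names if importlib.util.find_spec(name) is None]


required = ["caiman", "tifffile", "matplotlib", "seaborn", "plotly", "psutil"]
missing = find_missing_modules(required)
if missing:
    print("❌ Missing modules:", missing)
else:
    print("✅ All required modules found")
```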

## 💾 Setup: Mount Google Drive & Create Directories

In [None]:
from google.colab import drive, files
import os
import shutil

# Mount Google Drive
drive.mount('/content/drive')

# Create output directories
output_dir = '/content/drive/MyDrive/cnmf_colab_output'
local_output = '/content/cnmf_local_output'
os.makedirs(output_dir, exist_ok=True)
os.makedirs(local_output, exist_ok=True)
os.makedirs('/content/data', exist_ok=True)

print(f"📁 Output directories created:")
print(f"  - Google Drive: {output_dir}")
print(f"  - Local (faster): {local_output}")
print(f"  - Data directory: /content/data")

## 📹 Upload Your Video File

**Choose one option:**
1. **Upload directly** (run the cell below)
2. **Use from Google Drive** (modify the path in the next cell)

In [None]:
# Option 1: Upload video file directly to Colab
print("📹 Please upload your TIF video file:")
uploaded = files.upload()

# Move uploaded file to data directory
video_filename = None
for filename in uploaded.keys():
    if filename.lower().endswith(('.tif', '.tiff')):
        video_path = f'/content/data/{filename}'
        shutil.move(filename, video_path)
        video_filename = filename
        print(f"✅ Video uploaded: {video_path}")
        print(f"📊 File size: {os.path.getsize(video_path) / (1024**3):.2f} GB")
        break

if video_filename is None:
    print("❌ No TIF file found in upload. Please upload a TIF/TIFF file.")

In [None]:
# Option 2: Use video file from Google Drive
# Uncomment and modify the path below if you have the video in Google Drive

# drive_video_path = '/content/drive/MyDrive/your_video_file.tif'  # MODIFY THIS PATH
# 
# if os.path.exists(drive_video_path):
#     video_path = '/content/data/video_file.tif'
#     shutil.copy(drive_video_path, video_path)
#     print(f"✅ Video copied from Drive: {video_path}")
#     print(f"📊 File size: {os.path.getsize(video_path) / (1024**3):.2f} GB")
# else:
#     print(f"❌ Video file not found at: {drive_video_path}")
#     print("📁 Available files in your Drive:")
#     !find '/content/drive/MyDrive' -name '*.tif' -o -name '*.TIF' | head -10
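If you are unsure of the exact path, you can search in pure Python instead of relying on shell `find` syntax. This is a sketch with a hypothetical helper, `find_tiffs`. It assumes your videos end in `.tif` or `.tiff`, in any letter case:

```python
from pathlib import Path


def find_tiffs(root, limit=10):
    """Return up to `limit` file paths under `root` ending in .tif/.tiff (any case), sorted."""
    matches = []
    for path in Path(root).rglob("*"):
        if path.is_file() and path.suffix.lower() in {".tif", ".tiff"}:
            matches.append(str(path))
            if len(matches) >= limit:
                break  # stop early: walking all of My Drive can be slow
    return sorted(matches)


# Example (in Colab, after mounting Drive; a subfolder is faster than all of My Drive):
# print(find_tiffs('/content/drive/MyDrive'))
```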

## 🔍 System Check: Memory & Video Info

In [None]:
import psutil
import numpy as np
from tifffile import imread

def check_system_resources():
    """Check available system resources"""
    memory = psutil.virtual_memory()
    cpu_count = psutil.cpu_count()
    
    print("🖥️ System Resources:")
    print(f"  💾 Total RAM: {memory.total / (1024**3):.1f} GB")
    print(f"  💾 Available RAM: {memory.available / (1024**3):.1f} GB")
    print(f"  🔧 CPU Cores: {cpu_count}")
    
    return memory.available / (1024**3)

def check_video_info(video_path):
    """Check video file information and memory requirements"""
    try:
        # Load only the first 10 frames to get the frame size and dtype
        print("📹 Analyzing video file...")
        Y_sample = imread(video_path, key=slice(0, 10))
        sample_frames, d1, d2 = Y_sample.shape
        
        # Estimate total frames from the file size (approximate: ignores TIFF
        # header overhead and is wrong for compressed files)
        file_size = os.path.getsize(video_path)
        bytes_per_frame = Y_sample.nbytes / sample_frames
        estimated_total_frames = int(file_size / bytes_per_frame)
        
        print(f"📊 Video Information:")
        print(f"  📐 Dimensions: {d1} x {d2} pixels")
        print(f"  🎬 Estimated frames: ~{estimated_total_frames}")
        print(f"  💾 Estimated RAM needed: ~{(estimated_total_frames * d1 * d2 * 4) / (1024**3):.1f} GB")
        
        return estimated_total_frames, d1, d2
        
    except Exception as e:
        print(f"❌ Error reading video: {e}")
        return None, None, None

# Check system and video
available_ram = check_system_resources()

if 'video_path' in locals() and os.path.exists(video_path):
    est_frames, d1, d2 = check_video_info(video_path)
    
    if est_frames and est_frames > 800:
        print(f"\n⚠️ Large dataset detected ({est_frames} frames)")
        print(f"💡 Consider limiting to first 600-800 frames for Colab")
        print(f"💡 You can modify this in the next cell")
else:
    print("❌ Video file not found. Please upload or set the correct path.")
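The RAM estimate above is frames × height × width × 4 bytes, because the movie is converted to float32 before CNMF runs. Inverting that arithmetic gives a rough upper bound for `limit_frames`. This is only a sketch: the helper name is hypothetical, and the 50% safety factor is an assumption rather than a CaImAn recommendation (CNMF needs working memory beyond the raw movie):

```python
def max_frames_for_budget(d1, d2, available_ram_gb, bytes_per_value=4, safety_fraction=0.5):
    """Rough upper bound on frames whose float32 copy fits in a fraction of available RAM."""
    bytes_per_frame = d1 * d2 * bytes_per_value              # one float32 frame
    budget_bytes = available_ram_gb * (1024 ** 3) * safety_fraction
    return int(budget_bytes // bytes_per_frame)


# A 512 x 512 float32 frame is exactly 1 MiB, so half of 12 GiB allows 6144 frames
print(max_frames_for_budget(512, 512, 12))
```

In this notebook, you could call `max_frames_for_budget(d1, d2, available_ram)` with the values produced by the cell above.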

## 🧠 Your Custom CNMF Code

This cell contains your exact CNMF implementation with all modifications:

In [None]:
import os
import numpy as np
from tifffile import imread
import warnings
import sys
from contextlib import redirect_stderr
from io import StringIO
import gc

# Simple error suppression for YrA plotting issues
def suppress_YrA_errors():
    """Suppress common YrA visualization errors"""
    warnings.filterwarnings("ignore", message=".*cannot reshape array.*")
    warnings.filterwarnings("ignore", message=".*Failed to plot YrA.*")

suppress_YrA_errors()

from caiman.source_extraction.cnmf import cnmf, params
from caiman.mmapping import load_memmap, save_memmap

# Create a custom output filter for YrA messages only
class YrAErrorFilter:
    def __init__(self, original_stream):
        self.original_stream = original_stream
        
    def write(self, text):
        # Only filter out specific YrA error messages, let everything else through
        if ("Failed to plot YrA" in text and "cannot reshape array" in text):
            # Replace YrA error with a cleaner message
            self.original_stream.write("⚠️ YrA visualization skipped (shape mismatch)\n")
        else:
            self.original_stream.write(text)
    
    def flush(self):
        self.original_stream.flush()
        
    def __getattr__(self, name):
        return getattr(self.original_stream, name)

# Manual debug saver that works without visualization
class ManualDebugSaver:
    def __init__(self, output_dir):
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
    
    def save_stage(self, cnmf_obj, stage_name):
        """Save CNMF outputs at each stage without problematic visualization"""
        print(f"📸 Saving stage: {stage_name}")
        
        def save_array(name, arr, stage):
            if arr is not None:
                try:
                    if hasattr(arr, 'toarray'):
                        arr = arr.toarray()
                    filename = os.path.join(self.output_dir, f"{name}_{stage}.npy")
                    np.save(filename, arr)
                    print(f"  ✅ Saved {name} shape {arr.shape} to {filename}")
                except Exception as e:
                    print(f"  ❌ Failed to save {name}: {e}")
        
        # Save all matrices
        if hasattr(cnmf_obj, 'estimates'):
            save_array("A", cnmf_obj.estimates.A, stage_name)
            save_array("C", cnmf_obj.estimates.C, stage_name)
            save_array("b", cnmf_obj.estimates.b, stage_name)
            save_array("f", cnmf_obj.estimates.f, stage_name)
            
            # Save YrA properly as temporal data (don't try to visualize as image)
            if hasattr(cnmf_obj.estimates, "YrA") and cnmf_obj.estimates.YrA is not None:
                save_array("YrA", cnmf_obj.estimates.YrA, stage_name)
                print(f"  📊 YrA shape: {cnmf_obj.estimates.YrA.shape} (K components × T timepoints)")
        
        # Save metadata
        metadata_path = os.path.join(self.output_dir, f"metadata_{stage_name}.txt")
        with open(metadata_path, "w") as f:
            f.write(f"Stage: {stage_name}\n")
            if hasattr(cnmf_obj, 'estimates'):
                f.write(f"A shape: {getattr(cnmf_obj.estimates.A, 'shape', 'None')}\n")
                f.write(f"C shape: {getattr(cnmf_obj.estimates.C, 'shape', 'None')}\n")
                f.write(f"YrA shape: {getattr(cnmf_obj.estimates.YrA, 'shape', 'None')}\n")
            f.write(f"dims: {getattr(cnmf_obj, 'dims', 'None')}\n")
        print(f"  📝 Saved metadata to {metadata_path}")

# ---------- Wrapper that runs CNMF step by step and saves each stage ----------

class CNMFWrapper:
    def __init__(self, cnmf_obj, mmap_path, dims, output_dir):
        self.cnmf = cnmf_obj
        self.mmap_path = mmap_path
        self.dims = dims
        self.output_dir = output_dir  # used by run() to save the final HDF5 file
        self.debug_saver = ManualDebugSaver(output_dir)

    def run(self):
        # Filter YrA error messages during CNMF execution
        original_stdout = sys.stdout
        original_stderr = sys.stderr
        sys.stdout = YrAErrorFilter(original_stdout)
        sys.stderr = YrAErrorFilter(original_stderr)
        
        try:
            # Load memmap data correctly
            print(f"📖 Loading memmap from: {self.mmap_path}")
            
            # Try CaImAn's load_memmap first (it parses dims and T from the file name)
            try:
                Yr, dims, T = load_memmap(self.mmap_path)
                print(f"📖 Loaded with load_memmap: Yr shape {Yr.shape}, dims={dims}, T={T}")
            except Exception:
                # Fallback: the file name does not follow CaImAn's naming scheme,
                # so read the raw float32 buffer and infer T from its size
                print("📖 Using direct memmap loading...")
                fp_in = np.memmap(self.mmap_path, dtype=np.float32, mode='r')
                n_pixels = self.dims[0] * self.dims[1]
                T = fp_in.shape[0] // n_pixels
                
                # Fortran order: the file must have been written so that each
                # frame is one contiguous column of the (pixels, T) matrix
                Yr = fp_in.reshape((n_pixels, T), order='F')
                dims = self.dims
                print(f"📖 Direct load: Yr shape {Yr.shape}, dims={dims}, T={T}")
            
            self.cnmf.dims = self.dims
            
            # Ensure debug visualization is completely disabled
            self.cnmf.debug_visualize = False
            if hasattr(self.cnmf, '_debug_image'):
                delattr(self.cnmf, '_debug_image')
            
            # Check if only initialization
            only_init = self.cnmf.params.get('patch', 'only_init')
            if only_init is None:
                only_init = False
            print(f"🔧 Running CNMF with only_init = {only_init}")
            
            if only_init:
                # For only_init=True, just do initialization
                print("🚀 Starting initialization only...")
                self.cnmf.fit_file(self.mmap_path)
                self.debug_saver.save_stage(self.cnmf, "after_init_only")
            else:
                # Force manual step-by-step approach to save all debug stages
                print("🚀 Using manual step-by-step approach to save all debug stages...")
                
                # Manual approach for complete debugging
                # Preprocess: estimates per-pixel noise (sn) used by initialization
                print("📖 Loading and preprocessing data...")
                Yr = self.cnmf.preprocess(Yr)
                
                # Initialization expects the movie as (d1, d2, T), not (pixels, T)
                print("🎯 Running initialization...")
                Y = np.reshape(Yr, tuple(self.dims) + (T,), order='F')
                self.cnmf.initialize(Y)
                self.debug_saver.save_stage(self.cnmf, "after_initialize")
                
                # Spatial update 1
                print("🗺️ Running spatial update 1...")
                try:
                    self.cnmf.update_spatial(Yr.reshape((-1, T), order='F'))
                    self.debug_saver.save_stage(self.cnmf, "after_spatial_1")
                except (IndexError, ValueError) as e:
                    print(f"⚠️ Spatial update 1 failed: {e}")
                    print("⚠️ Skipping spatial update 1...")
                    self.debug_saver.save_stage(self.cnmf, "after_spatial_1_failed")
                
                # Temporal update 1
                print("⏰ Running temporal update 1...")
                try:
                    self.cnmf.update_temporal(Yr.reshape((-1, T), order='F'))
                    self.debug_saver.save_stage(self.cnmf, "after_temporal_1")
                except (IndexError, ValueError) as e:
                    print(f"⚠️ Temporal update 1 failed: {e}")
                    print("⚠️ Skipping temporal update 1...")
                    self.debug_saver.save_stage(self.cnmf, "after_temporal_1_failed")
                
                # Merging (if enabled)
                do_merge = self.cnmf.params.get('merging', 'do_merge')
                if do_merge is None:
                    do_merge = True
                if do_merge:
                    print("🔗 Running component merging...")
                    try:
                        self.cnmf.merge_comps(Yr.reshape((-1, T), order='F'))
                        self.debug_saver.save_stage(self.cnmf, "after_merge")
                    except (IndexError, ValueError) as e:
                        print(f"⚠️ Merging failed: {e}")
                        print("⚠️ Continuing without merging...")
                        # Save the stage anyway without merging
                        self.debug_saver.save_stage(self.cnmf, "after_merge_failed")
                
                # Spatial update 2
                print("🗺️ Running spatial update 2...")
                try:
                    self.cnmf.update_spatial(Yr.reshape((-1, T), order='F'))
                    self.debug_saver.save_stage(self.cnmf, "after_spatial_2")
                except (IndexError, ValueError) as e:
                    print(f"⚠️ Spatial update 2 failed: {e}")
                    print("⚠️ Skipping spatial update 2...")
                    self.debug_saver.save_stage(self.cnmf, "after_spatial_2_failed")
                
                # Temporal update 2
                print("⏰ Running temporal update 2...")
                try:
                    self.cnmf.update_temporal(Yr.reshape((-1, T), order='F'))
                    self.debug_saver.save_stage(self.cnmf, "after_temporal_2")
                except (IndexError, ValueError) as e:
                    print(f"⚠️ Temporal update 2 failed: {e}")
                    print("⚠️ Skipping temporal update 2...")
                    self.debug_saver.save_stage(self.cnmf, "after_temporal_2_failed")
                
                # Final residuals computation
                print("🧮 Computing final residuals...")
                try:
                    self.cnmf.compute_residuals(Yr.reshape((-1, T), order='F'))
                    self.debug_saver.save_stage(self.cnmf, "final")
                except (IndexError, ValueError) as e:
                    print(f"⚠️ Final residuals computation failed: {e}")
                    print("⚠️ Skipping residuals computation...")
                    self.debug_saver.save_stage(self.cnmf, "final_no_residuals")
            
            # Save final output
            output_path = os.path.join(self.output_dir, "cnmf_final_output.hdf5")
            self.cnmf.save(output_path)
            print(f"💾 Saved final CNMF output to {output_path}")
            
            # A may still be dense if every spatial update failed
            A_est = self.cnmf.estimates.A
            A_final = A_est.toarray() if hasattr(A_est, 'toarray') else np.asarray(A_est)
            print("A shape:", A_final.shape)
            print("C shape:", self.cnmf.estimates.C.shape)
            print("Number of ROIs:", A_final.shape[1])
            print("Non-zero in A:", np.count_nonzero(A_final))
            
            return self.cnmf
            
        finally:
            # Restore original streams
            sys.stdout = original_stdout
            sys.stderr = original_stderr

print("✅ CNMF code loaded successfully!")
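The direct-load fallback in `CNMFWrapper.run()` reshapes the flat file with `order='F'`. This matches CaImAn's convention that `Yr` is a (pixels × time) matrix whose columns are frames flattened in Fortran (column-major) order. The small NumPy round trip below uses a made-up movie. It shows that the file is only read back correctly when it was also written in Fortran order:

```python
import os
import tempfile

import numpy as np


def make_demo_movie(T=5, d1=3, d2=4):
    """Made-up movie of shape (T, d1, d2) with distinct values."""
    return np.arange(T * d1 * d2, dtype=np.float32).reshape(T, d1, d2)


def movie_to_Yr(movie):
    """CaImAn-style Yr: column t is frame t flattened in Fortran order."""
    return movie.reshape(movie.shape[0], -1, order='F').T


def roundtrip(movie, write_order):
    """Write frames column by column with the given memmap order, read back with order='F'."""
    T, d1, d2 = movie.shape
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "demo.mmap")
        out = np.memmap(path, dtype=np.float32, mode='w+',
                        shape=(d1 * d2, T), order=write_order)
        for t in range(T):
            out[:, t] = movie[t].ravel(order='F')
        out.flush()
        del out
        flat = np.memmap(path, dtype=np.float32, mode='r')
        result = np.array(flat.reshape((d1 * d2, T), order='F'))  # copy before closing
        del flat
    return result


demo = make_demo_movie()
demo_Yr = movie_to_Yr(demo)
print(np.array_equal(roundtrip(demo, 'F'), demo_Yr))  # True: write and read orders agree
print(np.array_equal(roundtrip(demo, 'C'), demo_Yr))  # False: pixels and frames get scrambled

# The (d1, d2, T) movie passed to initialization is a Fortran-order reshape of Yr
demo_Y = demo_Yr.reshape(demo.shape[1:] + (demo.shape[0],), order='F')
print(np.array_equal(demo_Y[:, :, 2], demo[2]))       # True
```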

## ⚙️ Configure Parameters & Run CNMF

In [None]:
def run_cnmf_analysis(video_path, output_dir, limit_frames=None):
    """Run your complete CNMF analysis"""
    
    # Load video data
    print(f"📖 Loading video from: {video_path}")
    Y = imread(video_path)
    
    # Optional: limit number of frames for Colab memory management
    if limit_frames and Y.shape[0] > limit_frames:
        print(f"⚠️ Limiting frames from {Y.shape[0]} to {limit_frames} for Colab")
        Y = Y[:limit_frames]
    
    T, d1, d2 = Y.shape
    dims = (d1, d2)
    
    print(f"📊 Data info: {T} frames, {d1}x{d2} pixels")
    print(f"📊 Data range: {Y.min()} to {Y.max()}")
    print(f"📊 Data type: {Y.dtype}")
    print(f"📊 Memory usage: ~{(Y.nbytes / 1024**3):.1f} GB")

    # Create memmap file
    print("📝 Creating a float32 (pixels, time) memmap...")
    
    mmap_base = f"colab_d1_{d1}_d2_{d2}_frames_{T}"
    mmap_path = os.path.join(output_dir, mmap_base + '.mmap')
    
    # CaImAn's Yr is a (pixels, time) float32 matrix in Fortran order, so each
    # frame is one contiguous column. order='F' must match the order='F'
    # reshape used when CNMFWrapper.run() reads the file back.
    fp_out = np.memmap(mmap_path, dtype=np.float32, mode='w+',
                       shape=(d1 * d2, T), order='F')
    
    # Convert one frame at a time to avoid holding a full float32 copy of the movie
    for t in range(T):
        fp_out[:, t] = Y[t].astype(np.float32).ravel(order='F')
    
    fp_out.flush()
    del fp_out, Y
    gc.collect()  # Force garbage collection
    
    print(f"📁 Created memmap: {mmap_path}")

    # Biological parameters (your exact settings)
    fr = 1.08                               # imaging rate in frames per second
    decay_time = 20                         # length of a typical transient in seconds 
    dxy = (1.243, 1.243)                    # spatial resolution in x and y in (um per pixel)
    cell_diameter = 10                      # in microns
    d_px = int(cell_diameter // dxy[0])     # convert microns to pixels

    # CNMF parameters (optimized for Colab)
    p = 1                                   # order of the autoregressive system (1 is more stable than 2)
    gnb = 2                                 # number of global background components
    merge_thr = 0.7                         # merging threshold, max correlation allowed
    bas_nonneg = True                       # enforce nonnegativity constraint on calcium traces
    rf = None                               # No patches - analyze full image at once
    stride_cnmf = None                      # No patches
    K = 150                                 # Reduced for Colab memory (was 200)
    gSig = np.array([0.5*d_px, 0.5*d_px])  # expected half-width of neurons in pixels
    gSiz = 2*gSig + 1                       # Gaussian kernel width and height
    method_init = 'greedy_roi'              # initialization method (more stable than corr_pnr)
    ssub = 1                                # spatial subsampling during initialization
    tsub = 1                                # temporal subsampling during initialization

    # parameters for component evaluation
    min_SNR = 1.2               # signal to noise ratio for more sensitivity
    rval_thr = 0.7              # space correlation threshold for more sensitivity
    cnn_thr = 0.99              # threshold for CNN based classifier
    cnn_lowest = 0.1            # neurons with cnn probability lower than this value are rejected
    
    print(f"🔬 Biological parameters:")
    print(f"  Cell diameter: {cell_diameter} μm = {d_px} pixels")
    print(f"  No patches - analyzing full image")
    print(f"  K (total components): {K}")
    print(f"  gSig: {gSig}")

    # parameters dictionary
    parameter_dict = {
        'fnames': [mmap_path],
        'fr': fr,
        'dxy': dxy,
        'decay_time': decay_time,
        'p': p,
        'nb': gnb,
        'rf': rf,
        'K': K,
        'gSig': gSig,
        'gSiz': gSiz,
        'stride': stride_cnmf,
        'method_init': method_init,
        'rolling_sum': True,
        'only_init': False,
        'ssub': ssub,
        'tsub': tsub,
        'merge_thr': merge_thr,
        'bas_nonneg': bas_nonneg,
        'min_SNR': min_SNR,
        'rval_thr': rval_thr,
        'use_cnn': False,
        'min_cnn_thr': cnn_thr,
        'cnn_lowest': cnn_lowest,
        'dims': dims,
        'is3D': False,
        'data_format': 'mmap',
        'n_pixels_per_process': dims[0] * dims[1],
        'do_merge': True
    }

    # Create CNMFParams
    opts = params.CNMFParams(params_dict=parameter_dict)
    print(f"🚀 CNMF Parameters loaded successfully!")

    # Create CNMF object
    cnm = cnmf.CNMF(n_processes=1, params=opts)
    cnm.debug_visualize = False

    # Run your custom CNMF wrapper
    wrapper = CNMFWrapper(cnmf_obj=cnm, mmap_path=mmap_path, dims=dims, output_dir=output_dir)
    
    print("\n🚀 Starting CNMF analysis...")
    cnm_result = wrapper.run()
    
    return cnm_result, dims, T

# Run the analysis
if 'video_path' in locals() and os.path.exists(video_path):
    # You can adjust limit_frames based on your Colab memory
    # Set to None to use all frames, or a number like 600 to limit
    limit_frames = 600  # Adjust this based on your memory needs
    
    cnmf_result, dims, total_frames = run_cnmf_analysis(
        video_path=video_path,
        output_dir=local_output,  # Use local for speed, will copy to Drive later
        limit_frames=limit_frames
    )
    
    print("\n✅ CNMF Analysis completed!")
    
else:
    print("❌ Video file not found. Please upload your video file first.")
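The spatial parameters above are derived from the cell size and the pixel size. It helps to work the arithmetic through for this notebook's values (10 µm cells at 1.243 µm per pixel) before adapting them to another microscope. `spatial_params` is a hypothetical helper that repeats the same formulas used in `run_cnmf_analysis`:

```python
import numpy as np


def spatial_params(cell_diameter_um, um_per_pixel):
    """Repeat the notebook's conversion from cell diameter to gSig/gSiz (in pixels)."""
    d_px = int(cell_diameter_um // um_per_pixel)      # 10 // 1.243 -> 8.0 -> 8
    gSig = np.array([0.5 * d_px, 0.5 * d_px])         # expected half-width of a neuron
    gSiz = 2 * gSig + 1                               # kernel width and height
    return d_px, gSig, gSiz


d_px, gSig, gSiz = spatial_params(10, 1.243)
print(d_px, gSig, gSiz)   # 8 [4. 4.] [9. 9.]

# decay_time is in seconds; at fr frames per second a transient spans decay_time * fr frames
fr, decay_time = 1.08, 20
print(round(decay_time * fr, 1))  # 21.6
```

Note that `//` rounds down, so the pixel diameter is slightly smaller than the physical cell size whenever the division is not exact.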

## 📊 Results & Visualization

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_cnmf_results(cnmf_result, dims, output_dir):
    """Create comprehensive visualizations of CNMF results"""
    
    if cnmf_result is None:
        print("❌ No CNMF results to visualize")
        return
    
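    # Get results (A may be sparse or dense depending on which stages succeeded)
    A_raw = cnmf_result.estimates.A
    A_dense = A_raw.toarray() if hasattr(A_raw, 'toarray') else np.asarray(A_raw)
    A = A_dense.reshape(dims + (-1,), order='F')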
    C = cnmf_result.estimates.C
    
    n_components = A.shape[2]
    print(f"📊 Visualizing {n_components} components")
    
    # 1. Spatial Components Overview
    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    fig.suptitle('Spatial Components (First 12)', fontsize=16, fontweight='bold')
    
    for i in range(min(12, n_components)):
        ax = axes[i//4, i%4]
        im = ax.imshow(A[:,:,i], cmap='hot', interpolation='nearest')
        ax.set_title(f'Component {i+1}', fontsize=12)
        ax.axis('off')
        plt.colorbar(im, ax=ax, shrink=0.8)
    
    # Hide empty subplots
    for i in range(min(12, n_components), 12):
        axes[i//4, i%4].axis('off')
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'spatial_components.png'), dpi=150, bbox_inches='tight')
    plt.show()
    
    # 2. Temporal Traces
    fig, ax = plt.subplots(figsize=(15, 8))
    
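    # Plot the first 10 traces, each rescaled to [0, 1] and offset vertically
    # (raw C values can differ by orders of magnitude between components)
    for i in range(min(10, n_components)):
        trace = C[i]
        span = np.ptp(trace)
        trace_norm = (trace - trace.min()) / span if span > 0 else np.zeros_like(trace)
        offset = i * 1.2  # vertical offset between normalized traces
        ax.plot(trace_norm + offset, label=f'Component {i+1}', linewidth=1.5)
    
    ax.set_title('Temporal Activity Traces (First 10 Components)', fontsize=16, fontweight='bold')
    ax.set_xlabel('Frame Number', fontsize=14)
    ax.set_ylabel('Normalized Fluorescence + Offset', fontsize=14)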
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'temporal_traces.png'), dpi=150, bbox_inches='tight')
    plt.show()
    
    # 3. Component Statistics
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    
    # Activity levels
    activity_levels = np.mean(C, axis=1)
    axes[0,0].hist(activity_levels, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
    axes[0,0].set_title('Distribution of Component Activity Levels')
    axes[0,0].set_xlabel('Mean Fluorescence')
    axes[0,0].set_ylabel('Count')
    
    # Spatial extent
    spatial_extent = [np.count_nonzero(A[:,:,i]) for i in range(n_components)]
    axes[0,1].hist(spatial_extent, bins=20, alpha=0.7, color='lightcoral', edgecolor='black')
    axes[0,1].set_title('Distribution of Component Spatial Extent')
    axes[0,1].set_xlabel('Number of Non-zero Pixels')
    axes[0,1].set_ylabel('Count')
    
    # Activity vs Spatial Extent
    axes[1,0].scatter(spatial_extent, activity_levels, alpha=0.6, s=50)
    axes[1,0].set_title('Activity vs Spatial Extent')
    axes[1,0].set_xlabel('Spatial Extent (pixels)')
    axes[1,0].set_ylabel('Mean Activity')
    
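    # Crude quality proxy: peak of each trace divided by its standard deviation
    # (flat traces with zero std get 0 instead of inf/NaN)
    max_activities = np.max(C, axis=1)
    std_activities = np.std(C, axis=1)
    snr_estimates = np.divide(max_activities, std_activities,
                              out=np.zeros_like(max_activities, dtype=float),
                              where=std_activities > 0)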
    
    axes[1,1].hist(snr_estimates, bins=20, alpha=0.7, color='lightgreen', edgecolor='black')
    axes[1,1].set_title('Estimated SNR Distribution')
    axes[1,1].set_xlabel('Max/Std Ratio')
    axes[1,1].set_ylabel('Count')
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'component_statistics.png'), dpi=150, bbox_inches='tight')
    plt.show()
    
    # 4. Summary Statistics
    print("\n📈 CNMF Analysis Summary:")
    print(f"  🔢 Total components found: {n_components}")
    print(f"  📐 Image dimensions: {dims[0]} x {dims[1]} pixels")
    print(f"  📊 Mean activity level: {np.mean(activity_levels):.3f} ± {np.std(activity_levels):.3f}")
    print(f"  📏 Mean spatial extent: {np.mean(spatial_extent):.1f} ± {np.std(spatial_extent):.1f} pixels")
    print(f"  📶 Mean estimated SNR: {np.mean(snr_estimates):.2f} ± {np.std(snr_estimates):.2f}")
    
    # Save summary statistics
    summary_stats = {
        'n_components': n_components,
        'image_dims': dims,
        'mean_activity': float(np.mean(activity_levels)),
        'std_activity': float(np.std(activity_levels)),
        'mean_spatial_extent': float(np.mean(spatial_extent)),
        'std_spatial_extent': float(np.std(spatial_extent)),
        'mean_snr': float(np.mean(snr_estimates)),
        'std_snr': float(np.std(snr_estimates))
    }
    
    import json
    with open(os.path.join(output_dir, 'summary_statistics.json'), 'w') as f:
        json.dump(summary_stats, f, indent=2)
    
    print(f"\n💾 Visualizations saved to: {output_dir}")

# Create visualizations
if 'cnmf_result' in locals() and cnmf_result is not None:
    visualize_cnmf_results(cnmf_result, dims, local_output)
else:
    print("❌ No CNMF results available for visualization")
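The plots above show footprints as images. For a quick table of where each component sits, you can compute the weighted centroid of each column of `A` after the same Fortran-order reshape. The sketch below uses a hypothetical helper, `component_centroids`. It assumes a dense (pixels × K) array, for example `cnmf_result.estimates.A.toarray()` when `A` is sparse:

```python
import numpy as np


def component_centroids(A, dims):
    """Weighted (row, col) centroid of each spatial component.

    A: dense array of shape (d1 * d2, K) with pixels flattened in Fortran order.
    Returns an array of shape (K, 2); components with zero total weight get NaN.
    """
    d1, d2 = dims
    footprints = A.reshape((d1, d2, -1), order='F')    # (d1, d2, K)
    rows, cols = np.mgrid[0:d1, 0:d2]
    weights = footprints.sum(axis=(0, 1))               # total weight per component
    with np.errstate(invalid='ignore', divide='ignore'):
        r = np.tensordot(rows, footprints, axes=([0, 1], [0, 1])) / weights
        c = np.tensordot(cols, footprints, axes=([0, 1], [0, 1])) / weights
    return np.stack([r, c], axis=1)


# Example in this notebook:
# centroids = component_centroids(cnmf_result.estimates.A.toarray(), dims)
```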

## 📁 Stage Files Analysis

Check all the debug stage files that were saved:

In [None]:
# List all saved stage files
print("📁 Saved debug stage files:")
stage_files = []
for file in os.listdir(local_output):
    if file.endswith('.npy') and ('_after_' in file or '_final' in file):
        stage_files.append(file)
        file_size = os.path.getsize(os.path.join(local_output, file)) / (1024**2)
        print(f"  {file:<30} ({file_size:.1f} MB)")

# Group by stage: files are named {matrix}_{stage}.npy (e.g. A_after_spatial_1.npy),
# and the matrix names (A, C, b, f, YrA) contain no underscore
stages = set()
for file in stage_files:
    stages.add(file.split('_', 1)[1][:-len('.npy')])

print(f"\n🎯 Available stages for napari viewer: {sorted(stages)}")
print(f"📊 Total files: {len(stage_files)}")
print(f"💾 Total size: {sum(os.path.getsize(os.path.join(local_output, f)) for f in stage_files) / (1024**2):.1f} MB")
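`ManualDebugSaver` stores each stage as one `.npy` file per matrix, named `{matrix}_{stage}.npy`. Reloading a whole stage into a dict makes it easy to compare stages, for example to see how many components merging removed. This is a sketch: `load_stage` is a hypothetical helper, and the default matrix names are the ones `ManualDebugSaver` writes:

```python
import os

import numpy as np


def load_stage(directory, stage, names=("A", "C", "b", "f", "YrA")):
    """Load every matrix saved for `stage` into a dict; missing matrices are skipped."""
    arrays = {}
    for name in names:
        path = os.path.join(directory, f"{name}_{stage}.npy")
        if os.path.exists(path):
            arrays[name] = np.load(path)
    return arrays


# Example: component counts before and after merging
# init = load_stage(local_output, "after_initialize")
# merged = load_stage(local_output, "after_merge")
# print(init["A"].shape[1], "->", merged["A"].shape[1])
```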

## ☁️ Copy Results to Google Drive

In [None]:
# Copy all results to Google Drive for persistence
print("☁️ Copying results to Google Drive...")

for file in os.listdir(local_output):
    src = os.path.join(local_output, file)
    dst = os.path.join(output_dir, file)
    
    if os.path.isfile(src):
        shutil.copy2(src, dst)
        file_size = os.path.getsize(src) / (1024**2)
        print(f"  ✅ Copied {file} ({file_size:.1f} MB)")

print(f"\n💾 All results saved to Google Drive: {output_dir}")
print("🔗 You can access these files from any device via Google Drive")

# List final contents
print("\n📁 Final output directory contents:")
!ls -lh "$output_dir"
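The loop above copies everything in `local_output`, including the `.mmap` movie. That file is about frames × height × width × 4 bytes and can use a large share of your Drive quota. If you only need the results, a variant that skips some extensions is a small change. `copy_results` and its `skip_suffixes` default are assumptions, not part of the original notebook:

```python
import os
import shutil


def copy_results(src_dir, dst_dir, skip_suffixes=(".mmap",)):
    """Copy regular files from src_dir to dst_dir, skipping the given (lowercase) suffixes.

    Returns the sorted list of copied file names.
    """
    os.makedirs(dst_dir, exist_ok=True)
    copied = []
    for name in sorted(os.listdir(src_dir)):
        src = os.path.join(src_dir, name)
        if not os.path.isfile(src) or name.lower().endswith(skip_suffixes):
            continue
        shutil.copy2(src, os.path.join(dst_dir, name))
        copied.append(name)
    return copied


# Example: copy_results(local_output, output_dir)
```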

## ⬇️ Download Key Results

In [None]:
# Download key result files
print("⬇️ Downloading key result files...")

# Download the final HDF5 file
hdf5_file = os.path.join(local_output, 'cnmf_final_output.hdf5')
if os.path.exists(hdf5_file):
    files.download(hdf5_file)
    print("✅ Downloaded: cnmf_final_output.hdf5")

# Download final matrices
final_files = ['A_final.npy', 'C_final.npy', 'YrA_final.npy']
for filename in final_files:
    filepath = os.path.join(local_output, filename)
    if os.path.exists(filepath):
        files.download(filepath)
        print(f"✅ Downloaded: {filename}")

# Download visualizations
viz_files = ['spatial_components.png', 'temporal_traces.png', 'component_statistics.png']
for filename in viz_files:
    filepath = os.path.join(local_output, filename)
    if os.path.exists(filepath):
        files.download(filepath)
        print(f"✅ Downloaded: {filename}")

# Download summary statistics
summary_file = os.path.join(local_output, 'summary_statistics.json')
if os.path.exists(summary_file):
    files.download(summary_file)
    print("✅ Downloaded: summary_statistics.json")

print("\n🎉 Download complete! Check your Downloads folder.")
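Browsers may block, or ask permission for, several downloads started by one cell. Packing the results into a single archive first means only one `files.download` call is needed. The sketch below uses `shutil.make_archive` from the standard library. The helper name and the default archive location are assumptions:

```python
import os
import shutil
import tempfile


def zip_results(results_dir, archive_base=None):
    """Zip the contents of results_dir and return the path of the .zip file."""
    if archive_base is None:
        # Write the archive outside results_dir so it does not include itself
        archive_base = os.path.join(tempfile.gettempdir(), "cnmf_results")
    return shutil.make_archive(archive_base, "zip", root_dir=results_dir)


# In Colab (uses `files` and `local_output` from earlier cells):
# files.download(zip_results(local_output))
```

The archive includes everything in the folder, so move or delete the large `.mmap` file first if you do not need it.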

## 🔍 Note: Napari Viewer

**Your napari viewer code won't run directly in Colab** because Colab doesn't support interactive GUI applications. However:

1. **All debug stage files are saved**: you can download them and use them with your local napari viewer
2. **Stage switching will work locally** once the files are on your computer
3. **All intermediate stages are preserved**: `after_initialize`, `after_spatial_1`, etc.

### To use napari locally with these results:
1. Download all the `.npy` files from Google Drive
2. Update your `napari_runner.py` to point to the downloaded directory
3. Run napari locally with the stage switching functionality

### Colab Alternative:
The matplotlib visualizations above provide a good overview of your results without needing interactive napari.
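Once the stage files are local, one simple way to view a stage is as a single image: the maximum of the spatial footprints over all components. The sketch below computes that projection in NumPy. The helper name is hypothetical. The napari calls in the comments (`napari.Viewer`, `add_image`, `napari.run`) are only a usage hint and are not run in Colab. The image dimensions for each stage are recorded in the `metadata_{stage}.txt` files:

```python
import numpy as np


def footprint_max_projection(A, dims):
    """Max over components of the spatial footprints, as a (d1, d2) image.

    A: dense (d1 * d2, K) array, e.g. np.load("A_after_initialize.npy").
    """
    if A.shape[1] == 0:
        return np.zeros(dims, dtype=A.dtype)
    return A.max(axis=1).reshape(dims, order='F')


# Local usage with napari (not run in Colab):
# import napari
# viewer = napari.Viewer()
# for stage in ["after_initialize", "after_merge", "final"]:
#     A = np.load(f"A_{stage}.npy")
#     viewer.add_image(footprint_max_projection(A, (d1, d2)), name=stage)
# napari.run()
```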

## 🧹 Memory Cleanup

In [None]:
# Clean up memory
import gc

# Delete large variables
if 'cnmf_result' in locals():
    del cnmf_result
if 'A' in locals():
    del A
if 'C' in locals():
    del C

# Force garbage collection
gc.collect()

# Check final memory usage
memory = psutil.virtual_memory()
print(f"🧹 Memory after cleanup: {memory.percent}% used ({memory.used/1024**3:.1f}GB/{memory.total/1024**3:.1f}GB)")
print("✅ Cleanup complete!")