# SLEAP-Roots Processing with sleap-vizmo

This notebook demonstrates how to:
1. Load SLEAP files using sleap-io
2. Split multi-video labels into individual files
3. Save files with correct naming for sleap-roots Series
4. Process with MultipleDicotPipeline to get traits for multiple plants
5. Generate a CSV with all plant associations and traits

In [6]:
# Import required libraries
import sleap_io as sio
import sleap_roots as sr
from sleap_roots.trait_pipelines import MultipleDicotPipeline
from sleap_vizmo.roots_utils import (
    split_labels_by_video,
    save_individual_video_labels,
    validate_series_compatibility,
    create_series_name_from_video
)
from pathlib import Path
from datetime import datetime
import pandas as pd
import json

## 1. Load Test SLEAP Files

In [7]:
# Input label files; the relative path is resolved against the current
# working directory, so run the notebook from the repository root.
test_data_dir = Path("tests/data")
lateral_file, primary_file = (
    test_data_dir / "lateral_root_MK22_Day14_labels.v002.slp",
    test_data_dir / "primary_root_MK22_Day14_labels.v003.slp",
)

# Read both labels files with sleap-io (lateral first, then primary).
print("Loading SLEAP files...")
lateral_labels = sio.load_slp(lateral_file)
primary_labels = sio.load_slp(primary_file)

# Quick sanity check: number of labeled frames and videos in each file.
print(f"Lateral labels: {len(lateral_labels)} frames, {len(lateral_labels.videos)} videos")
print(f"Primary labels: {len(primary_labels)} frames, {len(primary_labels.videos)} videos")

Loading SLEAP files...
Lateral labels: 23 frames, 23 videos
Primary labels: 23 frames, 23 videos


## 2. Validate Series Compatibility

In [8]:
# Check both labels sets against the Series requirements and report the
# result dict returned by `validate_series_compatibility` (keys used here:
# 'is_compatible', 'warnings', 'errors').


def report_compatibility(root_name: str, compat: dict) -> None:
    """Print the compatibility verdict plus any warnings/errors for one labels set.

    Args:
        root_name: Display name for the root type, e.g. "Lateral".
        compat: Result of `validate_series_compatibility` for that labels set.
    """
    print(f"{root_name} labels compatibility:")
    print(f"  Compatible: {compat['is_compatible']}")
    # Only print the optional sections when they are non-empty.
    if compat['warnings']:
        print(f"  Warnings: {compat['warnings']}")
    if compat['errors']:
        print(f"  Errors: {compat['errors']}")


lateral_compat = validate_series_compatibility(lateral_labels)
primary_compat = validate_series_compatibility(primary_labels)

report_compatibility("Lateral", lateral_compat)
print()  # blank line between the two reports
report_compatibility("Primary", primary_compat)

Lateral labels compatibility:
  Compatible: True

Primary labels compatibility:
  Compatible: True


## 3. Split Labels by Video and Save with Proper Naming

In [9]:
# Each run writes into a fresh directory named with a microsecond-resolution
# timestamp, so repeated runs never collide. `timestamp` is reused later.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
output_dir = Path("output").joinpath(f"sleap_roots_processing_{timestamp}")
output_dir.mkdir(parents=True, exist_ok=True)

print(f"Output directory: {output_dir}")

# Break each multi-video labels file into one labels object per video
# (lateral is split first, then primary).
lateral_split, primary_split = map(split_labels_by_video, (lateral_labels, primary_labels))

print(f"\nLateral labels split into {len(lateral_split)} video(s)")
print(f"Primary labels split into {len(primary_split)} video(s)")

Output directory: output/sleap_roots_processing_20250804_075350_702843

Lateral labels split into 23 video(s)
Primary labels split into 23 video(s)


In [10]:
# Save each per-video labels object as "<series_name>.<root_type>.slp" so the
# root type is encoded in the filename for Series loading.
# series_data maps series name -> {'lateral_path': ..., 'primary_path': ...};
# a key is only present for root types that were actually saved.
series_data = {}

for root_type, split in (("lateral", lateral_split), ("primary", primary_split)):
    print(f"\nSaving {root_type} root files...")
    path_key = f"{root_type}_path"
    for video_name, labels in split.items():
        series_name = create_series_name_from_video(video_name)
        paths = series_data.setdefault(series_name, {})

        # Two videos mapping to the same series name would silently overwrite
        # each other's file; make that visible instead of losing data quietly.
        if path_key in paths:
            print(f"  ⚠️ Duplicate {root_type} series name '{series_name}' "
                  f"(video: {video_name}); overwriting {Path(paths[path_key]).name}")

        output_path = output_dir / f"{series_name}.{root_type}.slp"
        labels.save(str(output_path))
        paths[path_key] = str(output_path)
        print(f"  Saved: {output_path.name}")

print(f"\nTotal series to process: {len(series_data)}")


Saving lateral root files...
  Saved: F_Ac_set1_day14_20250527_102755_001.lateral.slp
  Saved: F_Cp_set1_day14_20250527_102755_002.lateral.slp
  Saved: F_De_set1_day14_20250527_102755_003.lateral.slp
  Saved: F_DhA_set1_day14_20250527_102755_004.lateral.slp
  Saved: F_DhD_set1_day14_20250527_102755_005.lateral.slp
  Saved: F_Fo_set1_day14_20250527_102755_006.lateral.slp
  Saved: F_Gr_set1_day14_20250527_102955_007.lateral.slp
  Saved: no_peptide1_set1_day14_20250527_102955_010.lateral.slp
  Saved: no_peptide2_set1_day14_20250527_102955_011.lateral.slp
  Saved: OG_Ac_set2_day14_20250527_103422_014.lateral.slp
  Saved: OG_Cp_set2_day14_20250527_103422_015.lateral.slp
  Saved: OG_De_set2_day14_20250527_103422_016.lateral.slp
  Saved: OG_DhA_set2_day14_20250527_103422_017.lateral.slp
  Saved: OG_DhB_set2_day14_20250527_103422_018.lateral.slp
  Saved: OG_DhD_set2_day14_20250527_103618_019.lateral.slp
  Saved: OG_Fo_set1_day14_20250527_102955_012.lateral.slp
  Saved: OG_Gr_set2_day14_202505

## 4. Load Series and Process with MultipleDicotPipeline

In [None]:
# Collect every .slp file written above and group them into sleap-roots
# Series (h5s=False: presumably skips pairing with .h5 image files — the
# output directory contains only .slp files at this point).
all_slps = sr.find_all_slp_paths(output_dir)
all_series = sr.load_series_from_slps(
    slp_paths=all_slps,
    h5s=False,
)

print(f"Loaded {len(all_series)} series")
all_series

## Create Expected Count CSV for MultipleDicotPipeline

The MultipleDicotPipeline needs the number of plants in each cylinder, supplied as a CSV. We build that CSV here by counting the primary-root instances (one per plant) in each series' primary labels file.

In [None]:
# Build the expected-count CSV (one row per loaded series) that tells
# MultipleDicotPipeline how many plants each cylinder contains.


def _count_plants(labels) -> int:
    """Return the plant count for one series' primary-root labels.

    The count is the largest number of instances in any single labeled frame,
    or 0 when `labels` is None or contains no frames.

    NOTE(review): assumes every frame of a series shows the same plants with
    one primary-root instance per plant. Under that assumption the per-frame
    maximum is the plant count. The total across frames would over-count
    multi-frame series. For the one-frame-per-video splits made above, the
    two values are identical.
    """
    if labels is None:
        return 0
    return max((len(lf.instances) for lf in labels), default=0)


def _parse_series_name(name: str):
    """Split a series name such as "F_Ac_set1_day14_20250527_102755_001".

    Returns:
        (genotype, replicate) where genotype is the first two underscore-
        separated parts (e.g. "F_Ac"), or the whole name if it has no
        underscore, and replicate is the integer in the first parseable
        "set<N>" part (e.g. 1), defaulting to 1.
    """
    parts = name.split('_')
    genotype = "_".join(parts[:2]) if len(parts) > 1 else name
    replicate = 1
    for part in parts:
        if part.startswith('set'):
            try:
                replicate = int(part[len('set'):])
                break
            except ValueError:
                # e.g. "settings" — not a replicate marker; keep looking.
                continue
    return genotype, replicate


expected_counts = []

for series in all_series:
    # The series name doubles as the plant QR code in the expected-count CSV.
    plant_qr_code = series.series_name
    num_plants = _count_plants(getattr(series, 'primary_labels', None))

    # Saved file paths (from the save cell) for documentation; None if missing.
    paths = series_data.get(plant_qr_code, {})
    primary_path = paths.get('primary_path') or None
    lateral_path = paths.get('lateral_path') or None

    genotype, replicate = _parse_series_name(plant_qr_code)

    expected_counts.append({
        'plant_qr_code': plant_qr_code,
        'genotype': genotype,
        'replicate': replicate,
        'path': primary_path,  # Using primary path as the main path
        'qc_cylinder': 0,  # Default value
        'qc_code': None,  # Will be NaN in CSV
        'number_of_plants_cylinder': num_plants,
        'primary_root_proofread': primary_path,
        'lateral_root_proofread': lateral_path,
    })
    print(f"{plant_qr_code}: {num_plants} plants detected")

expected_count_df = pd.DataFrame(expected_counts)

# Empty trailing columns to match the expected format (presumably mirroring a
# spreadsheet export of the original template — confirm if still required).
for col in ['Unnamed: 9', 'Unnamed: 10', 'Unnamed: 11', 'Unnamed: 12', 'Instructions']:
    expected_count_df[col] = None

expected_count_path = output_dir / "expected_plant_counts.csv"
expected_count_df.to_csv(expected_count_path, index=False)

print(f"\n✅ Expected count CSV saved to: {expected_count_path}")
print(f"Total series: {len(expected_count_df)}")
print(f"Total plants across all series: {expected_count_df['number_of_plants_cylinder'].sum()}")

display(expected_count_df[['plant_qr_code', 'genotype', 'replicate', 'number_of_plants_cylinder']].head(10))

In [None]:
# Verification view of the expected-count table: its size, the first rows,
# and how many series were assigned each plant count.
print("Expected Count CSV Preview:")
print(f"Shape: {expected_count_df.shape}")

print("\nFirst 10 rows:")
display(expected_count_df.head(10))

print("\nPlant count distribution:")
print(expected_count_df.number_of_plants_cylinder.value_counts().sort_index())

In [6]:
# Load each series explicitly with `sr.Series.load`, passing the primary and
# (when saved) lateral file for that series.
# NOTE: this rebinds `all_series`, replacing the list produced earlier by
# `sr.load_series_from_slps`; the trait-computation cell uses this list.
all_series = []
failed_series = []

for series_name, paths in series_data.items():
    print(f"\nProcessing series: {series_name}")

    # Only pass the root types that were actually saved for this series.
    load_kwargs = {'series_name': series_name}
    for root_type in ('primary', 'lateral'):
        key = f'{root_type}_path'
        if key in paths:
            load_kwargs[key] = paths[key]
            print(f"  {root_type.capitalize()}: {Path(paths[key]).name}")

    try:
        series = sr.Series.load(**load_kwargs)
    except Exception as e:
        # Best-effort: report and skip this series instead of aborting the run.
        print(f"  ✗ Error loading series: {e}")
        failed_series.append(series_name)
        continue

    all_series.append(series)
    print("  ✓ Series loaded successfully")

print(f"\nLoaded {len(all_series)} of {len(series_data)} series")
if failed_series:
    print(f"Failed to load: {failed_series}")


Processing series: F_Ac_set1_day14_20250527_102755_001
  Primary: F_Ac_set1_day14_20250527_102755_001.primary.slp
  Lateral: F_Ac_set1_day14_20250527_102755_001.lateral.slp
  ✓ Series loaded successfully

Processing series: F_Cp_set1_day14_20250527_102755_002
  Primary: F_Cp_set1_day14_20250527_102755_002.primary.slp
  ✓ Series loaded successfully

Processing series: F_De_set1_day14_20250527_102755_003
  Primary: F_De_set1_day14_20250527_102755_003.primary.slp
  ✓ Series loaded successfully

Processing series: F_DhA_set1_day14_20250527_102755_004
  Primary: F_DhA_set1_day14_20250527_102755_004.primary.slp
  ✓ Series loaded successfully

Processing series: F_DhD_set1_day14_20250527_102755_005
  Primary: F_DhD_set1_day14_20250527_102755_005.primary.slp
  ✓ Series loaded successfully

Processing series: F_Fo_set1_day14_20250527_102755_006
  Primary: F_Fo_set1_day14_20250527_102755_006.primary.slp
  ✓ Series loaded successfully

Processing series: F_Gr_set1_day14_20250527_102955_007
  Pri

In [None]:
# Compute traits for every series with MultipleDicotPipeline.
pipeline = MultipleDicotPipeline()
print(f"\nUsing pipeline: {pipeline.__class__.__name__}")

# NOTE(review): in the sleap-roots API the pipeline reads each series'
# expected plant count through `Series.expected_count`. That property looks
# the series up by `plant_qr_code` in the CSV at `Series.csv_path`.
# `compute_multiple_dicots_traits` takes a single Series and accepts neither
# `output_dir` nor `expected_count_csv_path`. The original call therefore
# failed with a TypeError. Instead, attach the CSV to each series and use the
# batch method. Confirm against the installed sleap-roots version.
print(f"Using expected count CSV: {expected_count_path}")
for series in all_series:
    series.csv_path = str(expected_count_path)

traits_csv_path = output_dir / "batch_all_plants_traits.csv"

try:
    traits = pipeline.compute_batch_multiple_dicots_traits(
        all_series,
        write_csv=True,
        csv_path=str(traits_csv_path),
    )

    # Normalize to a DataFrame in case a list of records is returned.
    all_traits_df = traits if isinstance(traits, pd.DataFrame) else pd.DataFrame(traits)
    print(f"✓ Computed traits for {len(all_traits_df)} plants")

except Exception as e:
    # Keep the notebook running so the summary cells still execute, but show
    # the full traceback for debugging.
    print(f"✗ Error computing traits: {e}")
    import traceback
    traceback.print_exc()
    all_traits_df = pd.DataFrame()  # Empty dataframe on error

## 5. Combine All Traits into Final CSV

In [None]:
# Save and preview the final traits table if the previous cell produced one.
if 'all_traits_df' in locals() and len(all_traits_df) > 0:
    # Use a separate name: the run-level `timestamp` names `output_dir` and is
    # recorded in the processing summary, so it must not be overwritten here.
    csv_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    final_csv_path = output_dir / f"all_plants_traits_{csv_timestamp}.csv"
    all_traits_df.to_csv(final_csv_path, index=False)

    print(f"\n✅ Final CSV saved: {final_csv_path}")
    print(f"Total plants processed: {len(all_traits_df)}")
    print("\nColumns in final CSV:")
    for col in all_traits_df.columns:
        print(f"  - {col}")

    print("\nFirst 5 rows of the final DataFrame:")
    display(all_traits_df.head())
else:
    print("\n⚠️ No traits were computed successfully")
    print("Check that:")
    print("  1. The expected count CSV was created correctly")
    print("  2. The series have valid primary root instances")
    print("  3. The MultipleDicotPipeline can process the data")

## 6. Summary and Validation

In [None]:
# Record a JSON summary of this run next to the other outputs.
has_traits = 'all_traits_df' in locals() and len(all_traits_df) > 0

summary = {
    "timestamp": timestamp,
    "output_directory": str(output_dir),
    "input_files": {
        "lateral": str(lateral_file),
        "primary": str(primary_file)
    },
    "series_processed": len(all_series),
    "total_plants": len(all_traits_df) if has_traits else 0,
    "pipeline_used": "MultipleDicotPipeline",
    "expected_count_csv": str(expected_count_path),
    # pandas .sum() returns a numpy.int64, which json.dump cannot serialize;
    # cast to a plain int.
    "expected_total_plants": int(expected_count_df['number_of_plants_cylinder'].sum()),
    "trait_columns": [str(col) for col in all_traits_df.columns] if has_traits else []
}

summary_path = output_dir / "processing_summary.json"
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2)

print(f"\n📊 Processing Summary:")
print(f"  - Series processed: {summary['series_processed']}")
print(f"  - Expected plants: {summary['expected_total_plants']}")
print(f"  - Actual plants processed: {summary['total_plants']}")
print(f"  - Output directory: {summary['output_directory']}")
print(f"  - Summary saved to: {summary_path.name}")

In [None]:
# Optional: histogram the first (up to) six numeric traits in a 2 x 3 grid and
# save the figure next to the other outputs.
if 'all_traits_df' in locals() and len(all_traits_df) > 0:
    import contextlib
    import matplotlib.pyplot as plt

    # 'number' covers every numeric dtype (e.g. int32/float32/nullable ints),
    # not only int64/float64.
    numeric_cols = all_traits_df.select_dtypes(include='number').columns

    if len(numeric_cols) > 0:
        n_traits = min(6, len(numeric_cols))  # 6 = panels in the 2 x 3 grid

        # Scope the style to this figure and skip it if this matplotlib
        # version does not ship the style name.
        style_name = 'seaborn-v0_8-whitegrid'
        style_ctx = (plt.style.context(style_name)
                     if style_name in plt.style.available
                     else contextlib.nullcontext())

        with style_ctx:
            fig, axes = plt.subplots(2, 3, figsize=(15, 10))
            axes = axes.flatten()

            for ax, col in zip(axes, numeric_cols[:n_traits]):
                # astype(float) so nullable/extension dtypes plot cleanly.
                values = all_traits_df[col].dropna().astype(float)
                ax.hist(values, bins=20, edgecolor='black', alpha=0.7)
                ax.set_title(str(col).replace('_', ' ').title())
                ax.set_xlabel('Value')
                ax.set_ylabel('Count')

            # Hide any unused panels.
            for ax in axes[n_traits:]:
                ax.set_visible(False)

            fig.suptitle('Distribution of Plant Traits', fontsize=16)
            fig.tight_layout()

            fig_path = output_dir / "trait_distributions.png"
            fig.savefig(fig_path, dpi=150, bbox_inches='tight')
            plt.show()

        print(f"\n📈 Trait distribution plot saved to: {fig_path.name}")
else:
    print("\n📊 No traits data available for visualization")