In [1]:
import asyncio
import time
import signal
import threading
from typing import List, Dict, Any, Optional, Set, Tuple

from starfish import data_factory
from starfish.common.env_loader import load_env_file
from starfish.data_factory.utils.mock import mock_llm_call

# Load environment variables up front, before the job function below is defined or run.
# NOTE(review): presumably reads a local .env file — confirm against starfish.common.env_loader.
load_env_file()


In [2]:

# Apply nest_asyncio for use in Jupyter notebooks.
# NOTE(review): presumably needed because the notebook kernel already runs an event loop
# while job.run()/job.resume() are called synchronously further down — confirm.
try:
    import nest_asyncio
    nest_asyncio.apply()
except ImportError:
    # Optional dependency: warn and carry on instead of failing the cell.
    print("nest_asyncio not found, skipping. This may cause issues if run in a notebook.")


In [3]:

# Define the data factory function at module level
@data_factory(max_concurrency=10)
async def mock_llm_processor(city_name: str, num_records_per_city: int):
    """Stand-in for a slow LLM call on a single city.

    Waits half a second so the overall job runs long enough to be interrupted,
    then delegates to ``mock_llm_call`` with ``fail_rate=0``.
    """
    # Artificial latency, purely for demonstration purposes.
    await asyncio.sleep(0.5)
    llm_output = await mock_llm_call(
        city_name=city_name,
        num_records_per_city=num_records_per_city,
        fail_rate=0,
    )
    return llm_output

class TestRunner:
    """Exercise stop/resume behaviour of a data-factory job and validate its progress.

    The runner starts the module-level ``mock_llm_processor`` job, interrupts it
    with SIGINT every ``checkpoint_interval`` seconds, resumes it, and after every
    phase validates what ``job.get_index_completed()`` reports. Problems are
    collected rather than raised, so a single run reports everything that went
    wrong.
    """

    # Number of inputs submitted to the job; completed indices are expected to be
    # exactly 0 .. EXPECTED_TOTAL - 1 once the job has finished.
    EXPECTED_TOTAL = 100

    def __init__(self, total_time_limit=15, checkpoint_interval=3, max_checkpoints=10):
        """
        Initialize the test runner

        Args:
            total_time_limit: Maximum time allowed for the whole test in seconds
            checkpoint_interval: Time between checkpoints (stop/resume) in seconds
            max_checkpoints: Maximum number of checkpoints before forced termination
        """
        self.total_time_limit = total_time_limit
        self.checkpoint_interval = checkpoint_interval
        self.max_checkpoints = max_checkpoints
        self.errors = []  # Each item will be a tuple of (step, error_message)
        self.results = None
        self.stop_events = []  # Every threading.Timer started by this runner
        self.job = None
        self.timeout_triggered = False
        self.all_checkpoint_errors = {}  # step name -> list of error messages
        self._interrupt_timer = None  # Most recently scheduled checkpoint interrupt

    def add_error(self, step: str, error_message: str):
        """Record an error both in the flat list and under its step name."""
        self.errors.append((step, error_message))
        self.all_checkpoint_errors.setdefault(step, []).append(error_message)

    def validate_completion_indices(self, indices, step_name="Validation", is_final=False):
        """
        Validate that the completion indices are correct

        Uniqueness and range (0 .. EXPECTED_TOTAL - 1) are always checked. With
        ``is_final`` the sequence must additionally have EXPECTED_TOTAL entries,
        none of them ``None``, and no index may be missing.

        Args:
            indices: The indices to validate (``None`` entries mean "not completed")
            step_name: The name of the step for error reporting
            is_final: Whether this is the final validation (expecting all indices to be complete)

        Returns:
            List of errors found (each one is also recorded via ``add_error``)
        """
        errors = []

        def record(message):
            self.add_error(step_name, message)
            errors.append(message)

        # Safety check for None
        if indices is None:
            record("Indices are None")
            return errors

        total = self.EXPECTED_TOTAL
        completed_values = [idx for idx in indices if idx is not None]
        completed_count = len(completed_values)

        # For final validation, everything must be present and completed
        if is_final:
            if len(indices) != total:
                record(f"Expected {total} indices total, but found {len(indices)}")
            if completed_count != total:
                record(f"Expected {total} completed indices, but found {completed_count}")

        # Check for uniqueness among completed indices (always important)
        unique_indices = set(completed_values)
        if len(unique_indices) != completed_count:
            # Single pass instead of calling indices.count() once per value.
            seen = set()
            repeated = set()
            for value in completed_values:
                if value in seen:
                    repeated.add(value)
                else:
                    seen.add(value)
            duplicates = [idx for idx in unique_indices if idx in repeated]
            record(f"Found duplicate values: {duplicates}")

        # Check range of indices
        expected_range = set(range(total))
        extra = unique_indices - expected_range
        if extra:
            record(f"Unexpected indices: {sorted(extra)}")

        # For final validation, check if any indices are missing
        if is_final:
            missing = expected_range - unique_indices
            if missing:
                record(f"Missing indices: {sorted(missing)}")

        return errors

    def interrupt_execution(self):
        """Schedule an interruption after the checkpoint interval"""
        # If the previous phase ended before its interrupt fired, that timer is
        # still pending; cancel it so two SIGINTs never stack up.
        stale = self._interrupt_timer
        if stale is not None:
            stale.cancel()

        print(f"⏱️ Scheduling interruption in {self.checkpoint_interval} seconds")
        timer = threading.Timer(self.checkpoint_interval, self.raise_interrupt)
        self._interrupt_timer = timer
        self.stop_events.append(timer)
        timer.start()

    def raise_interrupt(self):
        """Raise a KeyboardInterrupt to stop the execution"""
        print("🛑 Raising interruption signal")
        signal.raise_signal(signal.SIGINT)

    def setup_timeout(self):
        """Set up the overall timeout for the test"""
        print(f"⏱️ Setting up timeout limit of {self.total_time_limit} seconds")
        timeout_timer = threading.Timer(self.total_time_limit, self.handle_timeout)
        self.stop_events.append(timeout_timer)
        timeout_timer.start()

    def handle_timeout(self):
        """Handle the timeout by setting a flag instead of forcefully exiting"""
        print("⏰ Timeout reached! Stopping the job gracefully.")
        self.add_error("Timeout", f"Test exceeded maximum time limit of {self.total_time_limit} seconds")
        # Set a flag instead of hard exiting - this is more Jupyter-friendly
        self.timeout_triggered = True
        # Signal the main thread to stop
        signal.raise_signal(signal.SIGINT)

    def cleanup_timers(self):
        """Cancel every timer this runner started; safe to call more than once."""
        for timer in self.stop_events:
            # cancel() is harmless on a timer that has already fired.
            timer.cancel()
        self.stop_events = []
        self._interrupt_timer = None

    def check_progress_and_validate(self, checkpoint_name):
        """
        Check the current progress and validate indices for the current checkpoint

        Args:
            checkpoint_name: Label used for error reporting and console output

        Returns:
            Tuple of (progress_info, completed) where progress_info is a
            "<done>/<total>" string ("Unknown" if it could not be determined)
        """
        progress_info = "Unknown"
        completed = False
        total = self.EXPECTED_TOTAL

        try:
            # Safely get job status - avoid calling methods directly on potentially None objects
            if callable(getattr(self.job, 'get_index_completed', None)):
                indices = self.job.get_index_completed()

                if indices is not None:
                    completed_count = len([i for i in indices if i is not None])
                    # Treat the check as final only once everything reports complete
                    is_final = completed_count == total

                    validation_errors = self.validate_completion_indices(
                        indices,
                        checkpoint_name + " Validation",
                        is_final=is_final
                    )
                    if validation_errors:
                        print(f"❌ {checkpoint_name} validation failed:")
                        for err in validation_errors:
                            print(f"  - {err}")
                    elif is_final:
                        print(f"✅ {checkpoint_name} validation passed: All indices are correct")
                    else:
                        print(f"✅ {checkpoint_name} partial validation passed: {completed_count} indices processed")

                    progress_info = f"{completed_count}/{total}"
                    completed = is_final
                else:
                    self.add_error(checkpoint_name, "Failed to get indices: indices is None")
                    print(f"⚠️ {checkpoint_name}: Failed to get indices: indices is None")
            else:
                self.add_error(checkpoint_name, "Job does not have get_index_completed method")
                print(f"⚠️ {checkpoint_name}: Job does not have get_index_completed method")

        except Exception as e:
            self.add_error(checkpoint_name, f"Error getting indices: {str(e)}")
            print(f"❌ {checkpoint_name}: Error getting indices: {str(e)}")

        return progress_info, completed

    def run_test(self):
        """Run the complete test with interruptions and resumptions

        Returns:
            The summary dict produced by ``_finish_test``
        """
        try:
            return self._run_all_phases()
        finally:
            # Never leave a timer behind: if an exception escapes, a stray timer
            # would otherwise send SIGINT to the kernel some seconds later.
            self.cleanup_timers()

    def _run_all_phases(self):
        """Do the initial run, then resume until complete or a limit is hit."""
        # Create input data by cycling through five city names
        base_cities = ["New York", "London", "Tokyo", "Paris", "Sydney"]
        cities = [base_cities[i % len(base_cities)] for i in range(self.EXPECTED_TOTAL)]

        print("=== Starting Initial Run ===")
        start_time = time.time()

        try:
            # Setup timers
            self.setup_timeout()
            self.interrupt_execution()

            # Start initial run - use the module level decorated function
            self.job = mock_llm_processor

            try:
                self.results = self.job.run(city_name=cities, num_records_per_city=1)
                print("✅ Initial run completed without interruption")

                # Check progress and validate after initial run
                progress_info, completed = self.check_progress_and_validate("Initial Run")
                if completed:
                    print("✅ All tasks completed in initial run")
                    return self._finish_test(start_time)

            except Exception as e:
                self.add_error("Initial Run", f"Error: {str(e)}")
                print(f"❌ Error in Initial Run: {str(e)}")
                # Don't return here, continue with checkpoint attempts
        except KeyboardInterrupt:
            print("⚠️ Initial run interrupted")

            # Check progress and validate after interruption
            progress_info, completed = self.check_progress_and_validate("Initial Run (Interrupted)")
            if completed:
                print("✅ All tasks completed after initial interruption")
                return self._finish_test(start_time)

        except Exception as e:
            self.add_error("Initial Run Setup", f"Error: {str(e)}")
            print(f"❌ Error in Initial Run setup: {str(e)}")
            # Don't return here, continue with checkpoint attempts

        # Resume until complete
        checkpoint_count = 1

        # Add a safety counter to prevent infinite loops
        while checkpoint_count <= self.max_checkpoints:
            checkpoint_name = f"Checkpoint {checkpoint_count}"

            # Check if timeout was triggered
            if self.timeout_triggered:
                print("⏰ Test timed out - stopping testing loop")
                break

            # Check if we have reached the total time limit
            if time.time() - start_time >= self.total_time_limit:
                self.add_error(checkpoint_name, f"Test exceeded maximum time limit of {self.total_time_limit} seconds")
                print(f"⏰ Test timed out after {self.total_time_limit} seconds")
                break

            # Check if we've hit the max checkpoint count
            if checkpoint_count == self.max_checkpoints:
                self.add_error(checkpoint_name, f"Test reached maximum checkpoint count of {self.max_checkpoints}")
                print(f"⚠️ Test reached maximum checkpoint count of {self.max_checkpoints}")
                break

            # Check if we have a job to resume
            if self.job is None:
                self.add_error(checkpoint_name, "Cannot continue: job is None")
                print("❌ Cannot continue: job is None")
                break

            # Check progress before resuming
            progress_info, completed = self.check_progress_and_validate(f"Before {checkpoint_name}")
            if completed:
                print(f"✅ All tasks completed before {checkpoint_name}")
                break

            print(f"=== Starting {checkpoint_name} ({progress_info}) ===")

            # Resume the job
            try:
                # Setup interruption for the next checkpoint
                self.interrupt_execution()

                # Try to resume if the method exists
                if callable(getattr(self.job, 'resume', None)):
                    try:
                        self.results = self.job.resume()
                        print(f"✅ {checkpoint_name} completed without interruption")

                        # Check progress after resumption
                        progress_info, completed = self.check_progress_and_validate(f"After {checkpoint_name}")
                        if completed:
                            print(f"✅ All tasks completed after {checkpoint_name}")
                            break

                    except Exception as e:
                        self.add_error(checkpoint_name, f"Error: {str(e)}")
                        print(f"❌ Error in {checkpoint_name}: {str(e)}")
                        # Continue to the next checkpoint
                else:
                    self.add_error(checkpoint_name, "Job does not have resume method")
                    print("⚠️ Job does not have resume method")
                    break  # Can't continue without resume method

            except KeyboardInterrupt:
                print(f"⚠️ {checkpoint_name} interrupted")

                # Check progress after interruption
                progress_info, completed = self.check_progress_and_validate(f"After {checkpoint_name} (Interrupted)")
                if completed:
                    print(f"✅ All tasks completed after {checkpoint_name} interruption")
                    break

            checkpoint_count += 1

        # Finish the test
        return self._finish_test(start_time)

    def _print_phase_errors(self, title, phases):
        """Print one summary section: ``title`` followed by the errors of each phase."""
        if not phases:
            return
        print(title)
        for phase in sorted(phases):
            phase_errors = self.all_checkpoint_errors.get(phase)
            if phase_errors:
                print(f"  {phase}:")
                for err in phase_errors:
                    print(f"    - {err}")

    def _finish_test(self, start_time):
        """Clean up, run the final validation, print the summary and build the result.

        Args:
            start_time: ``time.time()`` value taken when the test started

        Returns:
            Dict with keys ``success``, ``errors``, ``errors_by_checkpoint``,
            ``total_time`` and ``results``
        """
        # Clean up timers
        self.cleanup_timers()

        # Final validation if we have a job
        if self.job and hasattr(self.job, 'get_index_completed'):
            try:
                final_indices = self.job.get_index_completed()
                # Always perform full validation in the final step, so an
                # incomplete job can never be reported as "all indices correct".
                validation_errors = self.validate_completion_indices(final_indices, "Final Validation", is_final=True)
                if validation_errors:
                    print("❌ Final validation failed:")
                    for err in validation_errors:
                        print(f"  - {err}")
                else:
                    print("✅ Final validation passed: All indices are correct")
            except Exception as e:
                self.add_error("Final Validation", f"Error getting final indices: {str(e)}")
                print(f"❌ Error in final validation: {str(e)}")

        # Report final status
        total_time = time.time() - start_time
        print(f"\n=== Test Summary ===")
        print(f"Total execution time: {total_time:.2f} seconds")

        # Group errors by phase type for summary
        by_phase = self.all_checkpoint_errors
        validation_phases = [p for p in by_phase if "Validation" in p]
        checkpoint_phases = [p for p in by_phase if "Checkpoint" in p and "Validation" not in p]
        timeout_phases = [p for p in by_phase if "Timeout" in p]
        other_phases = [p for p in by_phase
                        if p not in validation_phases and p not in checkpoint_phases and p not in timeout_phases]

        # Report errors by category
        print("\n=== Errors by Phase ===")

        # Show timeout errors first (flat list, no per-phase heading)
        if timeout_phases:
            print("\nTimeout Errors:")
            for phase in timeout_phases:
                for err in by_phase[phase]:
                    print(f"  - {err}")

        self._print_phase_errors("\nCheckpoint Execution Errors:", checkpoint_phases)
        self._print_phase_errors("\nValidation Errors:", validation_phases)
        self._print_phase_errors("\nOther Errors:", other_phases)

        if not self.errors:
            print("\n✅ Test completed successfully with no errors")
        else:
            validation_error_count = sum(len(by_phase[p]) for p in validation_phases)
            checkpoint_error_count = sum(len(by_phase[p]) for p in checkpoint_phases)
            timeout_error_count = sum(len(by_phase[p]) for p in timeout_phases)
            other_error_count = sum(len(by_phase[p]) for p in other_phases)

            print(f"\n❌ Test completed with {len(self.errors)} errors:")
            print(f"   - {timeout_error_count} timeout errors")
            print(f"   - {checkpoint_error_count} checkpoint execution errors")
            print(f"   - {validation_error_count} validation errors")
            print(f"   - {other_error_count} other errors")

        return {
            "success": len(self.errors) == 0,
            "errors": self.errors,
            "errors_by_checkpoint": self.all_checkpoint_errors,
            "total_time": total_time,
            "results": self.results
        }

In [None]:

# Run the stop/resume test and fail the cell loudly if anything was recorded.
runner = TestRunner(total_time_limit=20, checkpoint_interval=3, max_checkpoints=10)
result = runner.run_test()
if not result["success"]:
    errors_by_phase = result["errors_by_checkpoint"]

    # Bucket the phase names; a phase is "other" when no bucket claims it.
    validation_phases = [p for p in errors_by_phase if "Validation" in p]
    checkpoint_phases = [p for p in errors_by_phase if "Checkpoint" in p and "Validation" not in p]
    timeout_phases = [p for p in errors_by_phase if "Timeout" in p]
    other_phases = [
        p for p in errors_by_phase
        if p not in validation_phases and p not in checkpoint_phases and p not in timeout_phases
    ]

    error_parts = []

    # Timeout errors come first, as a flat list without per-phase headings.
    if timeout_phases:
        error_parts.append("\n=== TIMEOUT ERRORS ===")
        error_parts.extend(
            f"- {err}" for phase in timeout_phases for err in errors_by_phase[phase]
        )

    # The remaining sections share one layout: heading, then each phase with its errors.
    sections = (
        ("\n=== CHECKPOINT EXECUTION ERRORS ===", checkpoint_phases),
        ("\n=== VALIDATION ERRORS ===", validation_phases),
        ("\n=== OTHER ERRORS ===", other_phases),
    )
    for heading, phases in sections:
        if not phases:
            continue
        error_parts.append(heading)
        for phase in sorted(phases):
            phase_errors = errors_by_phase.get(phase)
            if phase_errors:
                error_parts.append(f"\n-- {phase} --")
                error_parts.extend(f"- {err}" for err in phase_errors)

    error_message = "\n".join(error_parts)

    validation_error_count = sum(len(errors_by_phase[p]) for p in validation_phases)
    checkpoint_error_count = sum(len(errors_by_phase[p]) for p in checkpoint_phases)
    timeout_error_count = sum(len(errors_by_phase[p]) for p in timeout_phases)
    other_error_count = sum(len(errors_by_phase[p]) for p in other_phases)

    raise RuntimeError(f"Test failed with {len(result['errors'])} total errors ({timeout_error_count} timeout, {checkpoint_error_count} execution, {validation_error_count} validation, {other_error_count} other):{error_message}")