# Focused Learning: Test Case Generation Pipeline & Sandboxed Execution

## Learning Objectives
1. Master the art of automatic test case generation for code problems
2. Understand sandboxed execution environments for safe code evaluation
3. Learn to handle special data structures (trees, linked lists) in testing
4. Implement robust test generation with edge case detection

## Concept Source
- **Paper Section**: Section 2.1 (Data Collection) - Test Case Generation subsection
- **Key Figures**: Figure 4 (Input Generation Prompt), Figure 5 (Complex Input Generation)
- **Critical Quote**: "By applying both approaches multiple times, we construct an average of over 100 inputs per problem, including many complex cases, significantly reducing the risk of false positives." (Page 3)

## 1. The Challenge of Test Case Generation

### Why is this complex?

Generating test cases for competitive programming problems involves:
1. **Understanding constraints** from natural language descriptions
2. **Generating valid inputs** that respect problem constraints
3. **Computing correct outputs** using canonical solutions
4. **Handling special data structures** (trees, graphs, linked lists)
5. **Ensuring coverage** of edge cases and corner cases

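The paper uses LLMs to generate test inputs, combining simple and complex cases. The cells in this notebook approximate that behaviour with rule-based generators so the pipeline can be run without a model.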
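
Because the quoted approach applies both generation strategies several times, the resulting batches are likely to overlap. The following is a minimal sketch of merging them, assuming every input is a JSON-serializable keyword-argument dict; the helper name `merge_unique_inputs` is hypothetical and does not come from the paper.

```python
import json
from typing import Any, Dict, Iterable, List


def merge_unique_inputs(*batches: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Merge several batches of inputs, dropping exact duplicates (order preserved)."""
    seen = set()
    merged = []
    for batch in batches:
        for test_input in batch:
            # Canonical JSON (sorted keys) gives equal dicts the same string key
            key = json.dumps(test_input, sort_keys=True)
            if key not in seen:
                seen.add(key)
                merged.append(test_input)
    return merged


simple_batch = [{"arr": [1, 2, 4]}, {"arr": [5, 7, 11]}]
complex_batch = [{"arr": [5, 7, 11]}, {"arr": [0, 0, 0]}]
unique_inputs = merge_unique_inputs(simple_batch, complex_batch)
print(len(unique_inputs))  # 3
```

Deduplicating on a canonical string keeps the check cheap and works for nested lists, which are not hashable on their own.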

In [None]:
import ast
import json
import subprocess
import tempfile
import os
import sys
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
import numpy as np
from collections import deque
import re
import traceback

# For visualization
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.patches import Rectangle

## 2. Core Components of Test Case Generation

### 2.1 Entry Point Identification

First, we need to identify the function to test from the starter code:

In [None]:
class EntryPointExtractor:
    """Extract function entry points from starter code"""
    
    @staticmethod
    def extract_entry_point(starter_code: str) -> Optional[Dict[str, Any]]:
        """Extract function name and parameter information"""
        try:
            # Parse the code
            tree = ast.parse(starter_code)
            
            # Find function definitions
            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef):
                    # Extract function name
                    func_name = node.name
                    
                    # Extract parameters
                    params = []
                    for arg in node.args.args:
                        param_name = arg.arg
                        
                        # Try to extract type annotation (ast.unparse requires Python 3.9+)
                        param_type = None
                        if arg.annotation:
                            param_type = ast.unparse(arg.annotation)
                        
                        params.append({
                            'name': param_name,
                            'type': param_type
                        })
                    
                    # Extract return type
                    return_type = None
                    if node.returns:
                        return_type = ast.unparse(node.returns)
                    
                    return {
                        'function_name': func_name,
                        'parameters': params,
                        'return_type': return_type,
                        'is_class_method': len(params) > 0 and params[0]['name'] == 'self'
                    }
            
            return None
            
        except Exception as e:
            print(f"Error parsing code: {e}")
            return None

# Test the extractor
test_starter_code = """
class Solution:
    def missingNumber(self, arr: List[int]) -> int:
        pass
"""

entry_point = EntryPointExtractor.extract_entry_point(test_starter_code)
print("Extracted entry point:")
print(json.dumps(entry_point, indent=2))
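
`ast.walk` returns the first function definition it encounters, which can be a module-level helper or `__init__` rather than the method under test. The sketch below narrows the search, assuming the entry point is the first public method defined directly in a class named `Solution`. Both that convention and the helper name `find_solution_method` are assumptions made for illustration.

```python
import ast
from typing import Optional


def find_solution_method(starter_code: str, class_name: str = "Solution") -> Optional[str]:
    """Return the first public method name defined directly in `class_name`."""
    tree = ast.parse(starter_code)
    for node in tree.body:  # top-level statements only
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            for item in node.body:
                # Skip dunder/private methods such as __init__
                if isinstance(item, ast.FunctionDef) and not item.name.startswith("_"):
                    return item.name
    return None


starter = '''
def helper(x):
    return x

class Solution:
    def __init__(self):
        self.cache = {}

    def missingNumber(self, arr: List[int]) -> int:
        pass
'''
print(find_solution_method(starter))  # missingNumber
```

The code is only parsed, never executed, so the unresolved `List` annotation in the starter code is harmless here.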

### 2.2 Constraint Extraction from Problem Description

Extract constraints to guide test case generation:

In [None]:
@dataclass
class Constraint:
    """Represents a problem constraint"""
    variable: str
    min_value: Optional[float]
    max_value: Optional[float]
    constraint_type: str  # 'range', 'length', 'value'
    
class ConstraintExtractor:
    """Extract constraints from problem descriptions"""
    
    @staticmethod
    def extract_constraints(problem_description: str) -> List[Constraint]:
        """Extract numerical constraints using regex patterns"""
        constraints = []
        
        # Common constraint patterns
        patterns = [
            # Pattern: "1 <= arr.length <= 1000" (the upper bound may also be written as 10^5)
            r'(\d+)\s*<=\s*(\w+\.?\w*)\s*<=\s*(\d+\^?\d*)',
            # Pattern: "0 <= arr[i] <= 10^5"
            r'(\d+)\s*<=\s*(\w+)\[i\]\s*<=\s*(\d+\^?\d*)',
            # Pattern: "n = arr.length"
            r'(\w+)\s*=\s*(\w+)\.length',
            # Pattern: "The array has at least 3 elements"
            r'at least (\d+) elements',
        ]
        
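        for pattern in patterns:
            matches = re.findall(pattern, problem_description)
            for match in matches:
                # re.findall yields tuples only for multi-group patterns; single-group
                # patterns yield plain strings (e.g. "100"), which must not be indexed here
                if isinstance(match, tuple) and len(match) == 3 and match[0].isdigit():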
                    # Range constraint
                    min_val = int(match[0])
                    var_name = match[1]
                    max_val_str = match[2]
                    
                    # Handle exponential notation
                    if '^' in max_val_str:
                        base, exp = max_val_str.split('^')
                        max_val = int(base) ** int(exp)
                    else:
                        max_val = int(max_val_str)
                    
                    constraint_type = 'length' if '.length' in var_name else 'value'
                    constraints.append(Constraint(
                        variable=var_name,
                        min_value=min_val,
                        max_value=max_val,
                        constraint_type=constraint_type
                    ))
        
        return constraints

# Test constraint extraction
test_problem = """
Given an array arr, return the missing number.

Constraints:
3 <= arr.length <= 1000
0 <= arr[i] <= 10^5
The array has at least 3 elements.
"""

constraints = ConstraintExtractor.extract_constraints(test_problem)
print("Extracted constraints:")
for c in constraints:
    print(f"  {c.variable}: [{c.min_value}, {c.max_value}] (type: {c.constraint_type})")
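
Constraint bounds in problem statements appear in several notations, such as `1000`, `10^5`, or `2 * 10^4`. The sketch below parses those forms into integers. The helper name `parse_bound` and the set of supported notations are assumptions; real problem text may need more cases.

```python
import re


def parse_bound(text: str) -> int:
    """Parse bounds such as '1000', '-100', '10^5', or '2 * 10^4' into an int."""
    cleaned = text.replace(" ", "")
    match = re.fullmatch(r"(-?)(\d+)(?:\*(\d+)\^(\d+)|\^(\d+))?", cleaned)
    if match is None:
        raise ValueError(f"Unrecognised bound: {text!r}")
    sign_text, lead, base, exp_after_base, exp_direct = match.groups()
    if base is not None:            # form: a * b^c
        value = int(lead) * int(base) ** int(exp_after_base)
    elif exp_direct is not None:    # form: b^c
        value = int(lead) ** int(exp_direct)
    else:                           # plain integer
        value = int(lead)
    # Apply the sign last so that "-10^4" means -(10^4), not (-10)^4
    return -value if sign_text else value


print(parse_bound("10^5"))      # 100000
print(parse_bound("2 * 10^4"))  # 20000
print(parse_bound("-10^4"))     # -10000
```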

## 3. Intelligent Test Input Generation

### 3.1 Simple Input Generation

The paper generates simple inputs with one-shot LLM prompting. The cell below mimics that step with a constraint-aware random generator:

In [None]:
class TestInputGenerator:
    """Generate test inputs using constraint-aware strategies"""
    
    def __init__(self, constraints: List[Constraint]):
        self.constraints = constraints
        
    def generate_simple_inputs(self, param_info: Dict, num_cases: int = 10) -> List[Dict]:
        """Generate simple test inputs based on constraints"""
        inputs = []
        
        # Extract parameter type
        param_name = param_info['name']
        param_type = param_info['type']
        
        if 'List[int]' in str(param_type):
            # Generate array inputs
            length_constraint = next((c for c in self.constraints 
                                    if 'length' in c.constraint_type), None)
            value_constraint = next((c for c in self.constraints 
                                   if 'value' in c.constraint_type), None)
            
            min_len = length_constraint.min_value if length_constraint else 1
            # Cap sizes at 20 to keep "simple" cases small, but never below the minimum length
            max_len = max(min_len, min(length_constraint.max_value if length_constraint else 100, 20))
            min_val = value_constraint.min_value if value_constraint else 0
            # Cap values at 1000, but never below the minimum value
            max_val = max(min_val, min(value_constraint.max_value if value_constraint else 1000, 1000))
            
            # Generate diverse cases
            for i in range(num_cases):
                if i == 0:
                    # Minimum size
                    length = min_len
                elif i == 1:
                    # Maximum size (capped)
                    length = max_len
                else:
                    # Random sizes (np.random.randint excludes the upper bound)
                    length = np.random.randint(min_len, max_len + 1)
                
                # Generate array values
                if i < 3:
                    # Simple patterns
                    arr = list(range(min_val, min_val + length))
                else:
                    # Random values in [min_val, max_val]
                    arr = [int(np.random.randint(min_val, max_val + 1))
                          for _ in range(length)]
                
                inputs.append({param_name: arr})
        
        elif 'int' in str(param_type):
            # Generate integer inputs
            value_constraint = next((c for c in self.constraints 
                                   if param_name in c.variable), None)
            
            min_val = value_constraint.min_value if value_constraint else -1000
            max_val = value_constraint.max_value if value_constraint else 1000
            
            # Edge cases + random
            edge_cases = [min_val, max_val, 0, 1, -1]
            for val in edge_cases[:num_cases//2]:
                if min_val <= val <= max_val:
                    inputs.append({param_name: val})
            
            # Random values
            for _ in range(num_cases - len(inputs)):
                val = np.random.randint(min_val, max_val + 1)
                inputs.append({param_name: val})
        
        return inputs
    
    def generate_complex_inputs(self, param_info: Dict, 
                              simple_inputs: List[Dict], 
                              num_cases: int = 10) -> List[Dict]:
        """Generate complex inputs based on simple examples"""
        complex_inputs = []
        param_name = param_info['name']
        
        if 'List[int]' in str(param_info['type']):
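            # Pick a random stress strategy per case. simple_inputs is not consulted in this
            # rule-based stand-in, and strategies such as 'negative_values' ignore
            # self.constraints, so filter the results against the problem constraints before use.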
            for _ in range(num_cases):
                strategy = np.random.choice([
                    'large_values',
                    'negative_values', 
                    'duplicates',
                    'sorted_desc',
                    'all_same',
                    'alternating'
                ])
                
                length = np.random.randint(10, 50)
                
                if strategy == 'large_values':
                    arr = [np.random.randint(10000, 100000) for _ in range(length)]
                elif strategy == 'negative_values':
                    arr = [np.random.randint(-1000, 0) for _ in range(length)]
                elif strategy == 'duplicates':
                    unique_vals = np.random.randint(1, length//2)
                    vals = [np.random.randint(0, 100) for _ in range(unique_vals)]
                    arr = np.random.choice(vals, size=length).tolist()
                elif strategy == 'sorted_desc':
                    arr = sorted([np.random.randint(0, 1000) for _ in range(length)], 
                               reverse=True)
                elif strategy == 'all_same':
                    val = np.random.randint(0, 100)
                    arr = [val] * length
                else:  # alternating
                    arr = [i if i % 2 == 0 else -i for i in range(length)]
                
                complex_inputs.append({param_name: arr})
        
        return complex_inputs

# Test input generation
generator = TestInputGenerator(constraints)
param_info = {'name': 'arr', 'type': 'List[int]'}

simple_inputs = generator.generate_simple_inputs(param_info, num_cases=5)
print("Simple inputs:")
for i, inp in enumerate(simple_inputs):
    print(f"  {i+1}: {inp}")

complex_inputs = generator.generate_complex_inputs(param_info, simple_inputs, num_cases=5)
print("\nComplex inputs:")
for i, inp in enumerate(complex_inputs):
    print(f"  {i+1}: {inp['arr'][:10]}{'...' if len(inp['arr']) > 10 else ''}")
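
Generated inputs are only useful if they respect the problem constraints, and reproducible runs make failures easier to debug. The sketch below pairs a seeded generator with a validity check. The function names and the flat `(min_len, max_len, min_val, max_val)` parameters are hypothetical simplifications of the `Constraint` objects used above.

```python
import random
from typing import List


def satisfies_constraints(arr: List[int], min_len: int, max_len: int,
                          min_val: int, max_val: int) -> bool:
    """Check an integer array against length and value bounds (all inclusive)."""
    return min_len <= len(arr) <= max_len and all(min_val <= v <= max_val for v in arr)


def generate_valid_arrays(num_cases: int, min_len: int, max_len: int,
                          min_val: int, max_val: int, seed: int = 0) -> List[List[int]]:
    """Generate arrays inside the bounds using a private, seeded RNG."""
    rng = random.Random(seed)  # does not disturb the global random state
    arrays = []
    for _ in range(num_cases):
        length = rng.randint(min_len, max_len)  # random.randint is inclusive on both ends
        arrays.append([rng.randint(min_val, max_val) for _ in range(length)])
    return arrays


bounds = dict(min_len=3, max_len=20, min_val=0, max_val=10**5)
candidates = generate_valid_arrays(5, seed=42, **bounds) + [[-1, 2, 3]]
valid = [arr for arr in candidates if satisfies_constraints(arr, **bounds)]
print(len(candidates), len(valid))  # 6 5
```

The same `satisfies_constraints` filter can be applied to inputs from any source, including model-generated ones, before they are sent to the canonical solution.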

## 4. Sandboxed Code Execution

### 4.1 Basic Sandbox Implementation

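The paper emphasizes safe execution of potentially untrusted code. The cell below runs each solution in a separate Python process with a timeout and an address-space limit. This is process isolation rather than a hardened sandbox, and the `resource` module it relies on is only available on Unix-like systems: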

In [None]:
class CodeSandbox:
    """Sandboxed environment for safe code execution"""
    
    def __init__(self, timeout: int = 5, memory_limit_mb: int = 256):
        self.timeout = timeout
        self.memory_limit_mb = memory_limit_mb
        
    def execute_code(self, code: str, test_input: Dict, 
                    entry_point: str) -> Tuple[bool, Any, Optional[str]]:
        """Execute code with input and return (success, output, error); error is None on success"""
        
        # Create execution template
        execution_template = f"""
import sys
import resource
from typing import List, Optional
import json

# Set memory limit
resource.setrlimit(resource.RLIMIT_AS, ({self.memory_limit_mb} * 1024 * 1024, -1))

# User code
{code}

# Test execution
try:
    solution = Solution()
    test_input = {test_input}
    result = solution.{entry_point}(**test_input)
    print(json.dumps({{'success': True, 'output': result}}))
except Exception as e:
    print(json.dumps({{'success': False, 'error': str(e)}}))
"""
        
        try:
            # Write to temporary file
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
                f.write(execution_template)
                f.flush()
                temp_file = f.name
            
            # Execute with timeout
            result = subprocess.run(
                [sys.executable, temp_file],
                capture_output=True,
                text=True,
                timeout=self.timeout
            )
            
            # Parse output
            if result.returncode == 0:
                output_data = json.loads(result.stdout.strip())
                if output_data['success']:
                    return True, output_data['output'], None
                else:
                    return False, None, output_data['error']
            else:
                return False, None, result.stderr
                
        except subprocess.TimeoutExpired:
            return False, None, "Execution timeout"
        except json.JSONDecodeError:
            return False, None, f"Invalid output: {result.stdout}"
        except Exception as e:
            return False, None, str(e)
        finally:
            # Cleanup
            if 'temp_file' in locals():
                os.unlink(temp_file)
    
    def batch_execute(self, code: str, test_inputs: List[Dict], 
                     entry_point: str) -> List[Dict]:
        """Execute code on multiple test inputs"""
        results = []
        
        for i, test_input in enumerate(test_inputs):
            success, output, error = self.execute_code(code, test_input, entry_point)
            results.append({
                'test_id': i,
                'input': test_input,
                'success': success,
                'output': output,
                'error': error
            })
        
        return results

# Test the sandbox
test_solution = """
class Solution:
    def missingNumber(self, arr: List[int]) -> int:
        # Calculate expected sum of arithmetic progression
        n = len(arr)
        expected_sum = (n + 1) * (arr[0] + arr[-1]) // 2
        actual_sum = sum(arr)
        return expected_sum - actual_sum
"""

sandbox = CodeSandbox(timeout=2)
test_inputs = [
    {'arr': [5, 7, 11, 13]},  # Missing 9
    {'arr': [15, 13, 12]},    # Missing 14  
]

results = sandbox.batch_execute(test_solution, test_inputs, 'missingNumber')
for result in results:
    print(f"Input: {result['input']['arr']}")
    if result['success']:
        print(f"Output: {result['output']}")
    else:
        print(f"Error: {result['error']}")
    print()
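
`CodeSandbox` interpolates the input's `repr` into the generated script and parses all of stdout as JSON, so a solution that prints debug output breaks the parse. The sketch below shows one alternative, under the assumption that inputs and outputs are JSON-serializable: arguments travel over stdin and the result is tagged with a marker line. The names `run_with_json_input`, `RUNNER_TEMPLATE`, and the `__RESULT__` marker are hypothetical, and no memory limit is applied here.

```python
import json
import os
import subprocess
import sys
import tempfile
from typing import Any, Dict

RUNNER_TEMPLATE = '''
import json
import sys
from typing import List, Optional

{code}

if __name__ == "__main__":
    kwargs = json.loads(sys.stdin.read())
    result = getattr(Solution(), {entry_point!r})(**kwargs)
    # Marker line lets the parent ignore anything the solution itself printed
    print("__RESULT__" + json.dumps(result))
'''


def run_with_json_input(code: str, entry_point: str, kwargs: Dict[str, Any],
                        timeout: float = 5.0) -> Any:
    """Run Solution().<entry_point>(**kwargs) in a child process and return its result."""
    script = RUNNER_TEMPLATE.format(code=code, entry_point=entry_point)
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as handle:
        handle.write(script)
        path = handle.name
    try:
        proc = subprocess.run([sys.executable, path], input=json.dumps(kwargs),
                              capture_output=True, text=True, timeout=timeout)
    finally:
        os.unlink(path)  # always remove the temporary script
    if proc.returncode != 0:
        raise RuntimeError(proc.stderr.strip())
    for line in reversed(proc.stdout.splitlines()):
        if line.startswith("__RESULT__"):
            return json.loads(line[len("__RESULT__"):])
    raise RuntimeError("No result line found in child output")


noisy_solution = '''
class Solution:
    def missingNumber(self, arr: List[int]) -> int:
        print("debug output")  # would corrupt a parse of the whole stdout
        n = len(arr)
        return (n + 1) * (arr[0] + arr[-1]) // 2 - sum(arr)
'''
answer = run_with_json_input(noisy_solution, "missingNumber", {"arr": [5, 7, 11, 13]})
print(answer)  # 9
```

A timeout still surfaces as `subprocess.TimeoutExpired`, which the caller can catch in the same way `CodeSandbox.execute_code` does.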

## 5. Handling Special Data Structures

### 5.1 Binary Trees

The paper provides specific handling for tree structures (Figure 7):

In [None]:
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

class TreeHandler:
    """Handle binary tree serialization/deserialization for testing"""
    
    @staticmethod
    def tree_node(values: List[Any]) -> Optional[TreeNode]:
        """Convert list representation to TreeNode (from paper)"""
        if not values:
            return None
            
        root = TreeNode(values[0])
        i = 1
        queue = deque()
        queue.append(root)
        
        while queue:
            node = queue.popleft()
            if i < len(values) and values[i] is not None:
                node.left = TreeNode(values[i])
                queue.append(node.left)
            i += 1
            
            if i < len(values) and values[i] is not None:
                node.right = TreeNode(values[i])
                queue.append(node.right)
            i += 1
            
        return root
    
    @staticmethod
    def tree_node_to_list(root: Optional[TreeNode]) -> List[Any]:
        """Convert TreeNode to list representation (from paper)"""
        if not root:
            return []
            
        result = []
        queue = deque()
        queue.append(root)
        
        while queue:
            node = queue.popleft()
            if node:
                result.append(node.val)
                queue.append(node.left)
                queue.append(node.right)
            else:
                result.append(None)
        
        # Remove trailing None values
        while result and result[-1] is None:
            result.pop()
            
        return result
    
    @staticmethod
    def visualize_tree(root: Optional[TreeNode]):
        """Visualize binary tree structure"""
        if not root:
            print("Empty tree")
            return
            
        # Create graph
        G = nx.DiGraph()
        pos = {}
        labels = {}
        
        def add_nodes(node, x=0, y=0, layer=1):
            if node:
                node_id = id(node)
                G.add_node(node_id)
                pos[node_id] = (x, y)
                labels[node_id] = str(node.val)
                
                # Add children
                spacing = 2 ** (4 - layer)
                if node.left:
                    left_id = id(node.left)
                    G.add_edge(node_id, left_id)
                    add_nodes(node.left, x - spacing, y - 1, layer + 1)
                if node.right:
                    right_id = id(node.right)
                    G.add_edge(node_id, right_id)
                    add_nodes(node.right, x + spacing, y - 1, layer + 1)
        
        add_nodes(root)
        
        plt.figure(figsize=(10, 8))
        nx.draw(G, pos, labels=labels, with_labels=True, 
               node_color='lightblue', node_size=1500,
               font_size=16, font_weight='bold',
               arrows=True, arrowsize=20)
        plt.title("Binary Tree Visualization")
        plt.axis('off')
        plt.show()

# Test tree handling
tree_values = [1, 2, 3, None, 4, 5, None, None, None, 6]
tree = TreeHandler.tree_node(tree_values)
reconstructed = TreeHandler.tree_node_to_list(tree)

print(f"Original: {tree_values}")
print(f"Reconstructed: {reconstructed}")

# Visualize
TreeHandler.visualize_tree(tree)
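
Test generation for tree problems also needs random trees in the same level-order list format that `tree_node` consumes. The sketch below produces such a list with exactly `num_nodes` non-`None` values. The function name, the 0.6 child probability, and the value range are arbitrary assumptions, not details from the paper.

```python
import random
from typing import List, Optional


def random_tree_values(num_nodes: int, rng: random.Random,
                       max_val: int = 100) -> List[Optional[int]]:
    """Return a level-order list (None marks a missing child) with num_nodes nodes."""
    if num_nodes <= 0:
        return []
    values: List[Optional[int]] = [rng.randint(0, max_val)]
    pending = 1                # nodes whose two child slots are not yet written
    remaining = num_nodes - 1  # nodes still to be placed
    while remaining > 0:
        pending -= 1           # write both child slots of the next node in BFS order
        for slot in range(2):
            # Force a child in the last slot if the frontier would otherwise be empty
            force = slot == 1 and pending == 0
            if remaining > 0 and (force or rng.random() < 0.6):
                values.append(rng.randint(0, max_val))
                pending += 1
                remaining -= 1
            else:
                values.append(None)
    while values[-1] is None:  # the format omits trailing None entries
        values.pop()
    return values


tree_list = random_tree_values(7, random.Random(0))
print(sum(v is not None for v in tree_list))  # 7
```

Passing an explicit `random.Random` instance keeps generation reproducible without touching the global random state.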

### 5.2 Linked Lists

The paper handles linked lists in the same way (Figure 6):

In [None]:
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

class LinkedListHandler:
    """Handle linked list operations for testing"""
    
    @staticmethod
    def list_node(values: List[int]) -> Optional[ListNode]:
        """Convert list to linked list (from paper)"""
        if not values:
            return None
            
        head = ListNode(values[0])
        p = head
        for val in values[1:]:
            node = ListNode(val)
            p.next = node
            p = node
        return head
    
    @staticmethod
    def linked_list_to_list(head: Optional[ListNode]) -> List[int]:
        """Convert linked list to list (from paper)"""
        result = []
        current = head
        while current:
            result.append(current.val)
            current = current.next
        return result
    
    @staticmethod
    def visualize_linked_list(head: Optional[ListNode]):
        """Visualize linked list structure"""
        if not head:
            print("Empty list")
            return
            
        fig, ax = plt.subplots(figsize=(12, 3))
        
        # Count nodes
        count = 0
        current = head
        while current:
            count += 1
            current = current.next
        
        # Draw nodes
        node_width = 1.5
        node_height = 1
        spacing = 0.5
        
        current = head
        x = 1
        y = 1
        
        while current:
            # Draw node box
            rect = Rectangle((x, y), node_width, node_height, 
                           facecolor='lightblue', edgecolor='black', linewidth=2)
            ax.add_patch(rect)
            
            # Add value text
            ax.text(x + node_width/2, y + node_height/2, str(current.val),
                   ha='center', va='center', fontsize=14, fontweight='bold')
            
            # Draw arrow to next
            if current.next:
                ax.arrow(x + node_width, y + node_height/2, 
                        spacing - 0.1, 0,
                        head_width=0.2, head_length=0.1, 
                        fc='black', ec='black')
            
            x += node_width + spacing
            current = current.next
        
        # Add NULL at the end
        ax.text(x - spacing/2, y + node_height/2, 'NULL',
               ha='center', va='center', fontsize=12, style='italic')
        
        ax.set_xlim(0, x + 1)
        ax.set_ylim(0, 3)
        ax.axis('off')
        ax.set_title('Linked List Visualization', fontsize=16)
        plt.show()

# Test linked list handling
list_values = [1, 2, 3, 4, 5]
linked_list = LinkedListHandler.list_node(list_values)
reconstructed = LinkedListHandler.linked_list_to_list(linked_list)

print(f"Original: {list_values}")
print(f"Reconstructed: {reconstructed}")

# Visualize
LinkedListHandler.visualize_linked_list(linked_list)
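
The execution template in `CodeSandbox` calls `json.dumps(result)`, which raises a `TypeError` when a solution returns a `ListNode` or `TreeNode`. The sketch below converts such results to plain lists first. The helper name `to_jsonable` is hypothetical, and the node classes are repeated only to keep the example self-contained.

```python
import json
from collections import deque
from typing import Any


class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def to_jsonable(result: Any) -> Any:
    """Convert node-based results to lists; return other values unchanged."""
    if isinstance(result, ListNode):
        values, seen, node = [], set(), result
        while node is not None and id(node) not in seen:  # stop if the list is cyclic
            seen.add(id(node))
            values.append(node.val)
            node = node.next
        return values
    if isinstance(result, TreeNode):
        values, queue = [], deque([result])
        while queue:
            node = queue.popleft()
            if node is None:
                values.append(None)
            else:
                values.append(node.val)
                queue.append(node.left)
                queue.append(node.right)
        while values and values[-1] is None:  # drop trailing None entries
            values.pop()
        return values
    return result


head = ListNode(1, ListNode(2, ListNode(3)))
root = TreeNode(1, None, TreeNode(2))
print(json.dumps(to_jsonable(head)))  # [1, 2, 3]
print(json.dumps(to_jsonable(root)))  # [1, null, 2]
```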

## 6. Complete Test Case Generation Pipeline

Now let's integrate everything into a complete pipeline:

In [None]:
class TestCaseGenerationPipeline:
    """Complete pipeline for generating test cases"""
    
    def __init__(self):
        self.sandbox = CodeSandbox(timeout=5)
        self.special_handlers = {
            'TreeNode': TreeHandler,
            'ListNode': LinkedListHandler
        }
    
    def generate_test_cases(self, problem: Dict, 
                          canonical_solution: str,
                          num_simple: int = 50,
                          num_complex: int = 50) -> Tuple[List[Dict], Dict]:
        """Generate test cases for a problem; returns (test_cases, quality_report)"""
        
        # Step 1: Extract entry point
        entry_point_info = EntryPointExtractor.extract_entry_point(problem['starter_code'])
        if not entry_point_info:
            raise ValueError("Could not extract entry point")
        
        function_name = entry_point_info['function_name']
        
        # Step 2: Extract constraints
        constraints = ConstraintExtractor.extract_constraints(problem['description'])
        
        # Step 3: Generate inputs
        generator = TestInputGenerator(constraints)
        test_inputs = []
        
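        # Inputs are built one parameter at a time, so each dict holds a single argument;
        # functions with several parameters need these merged into one kwargs dict.
        for param in entry_point_info['parameters'][1:]:  # Skip 'self'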
            # Generate simple inputs
            simple_inputs = generator.generate_simple_inputs(param, num_simple)
            
            # Generate complex inputs
            complex_inputs = generator.generate_complex_inputs(
                param, simple_inputs, num_complex
            )
            
            test_inputs.extend(simple_inputs + complex_inputs)
        
        # Step 4: Execute canonical solution to get outputs
        print(f"Generating outputs for {len(test_inputs)} test inputs...")
        
        # Add necessary imports to canonical solution
        enhanced_solution = self._enhance_solution_with_imports(canonical_solution)
        
        # Execute and collect results
        results = self.sandbox.batch_execute(
            enhanced_solution, test_inputs, function_name
        )
        
        # Step 5: Create test cases
        test_cases = []
        for result in results:
            if result['success']:
                test_cases.append({
                    'input': result['input'],
                    'output': result['output']
                })
        
        print(f"Successfully generated {len(test_cases)} test cases")
        
        # Step 6: Validate test case quality
        quality_report = self._validate_test_quality(test_cases)
        
        return test_cases, quality_report
    
    def _enhance_solution_with_imports(self, solution: str) -> str:
        """Add necessary imports based on code analysis"""
        imports = [
            "from typing import List, Optional, Dict, Tuple, Set",
            "import math",
            "import heapq",
            "from collections import defaultdict, deque, Counter"
        ]
        
        # Check if special data structures are used
        if 'TreeNode' in solution:
            imports.append(self._get_tree_imports())
        if 'ListNode' in solution:
            imports.append(self._get_linked_list_imports())
        
        return "\n".join(imports) + "\n\n" + solution
    
    def _get_tree_imports(self) -> str:
        """Get tree-related imports (from paper Figure 7)"""
        return """
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
"""
    
    def _get_linked_list_imports(self) -> str:
        """Get linked list imports (from paper Figure 6)"""
        return """
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next
"""
    
    def _validate_test_quality(self, test_cases: List[Dict]) -> Dict:
        """Validate the quality of generated test cases"""
        report = {
            'total_cases': len(test_cases),
            'input_diversity': 0,
            'output_diversity': 0,
            'edge_cases_covered': [],
            'warnings': []
        }
        
        if not test_cases:
            report['warnings'].append("No test cases generated")
            return report
        
        # Analyze input diversity
        input_hashes = set()
        for tc in test_cases:
            input_str = json.dumps(tc['input'], sort_keys=True)
            input_hashes.add(hash(input_str))
        
        report['input_diversity'] = len(input_hashes) / len(test_cases)
        
        # Analyze output diversity
        unique_outputs = set(str(tc['output']) for tc in test_cases)
        report['output_diversity'] = len(unique_outputs) / len(test_cases)
        
        # Check for edge cases
        for tc in test_cases:
            for key, value in tc['input'].items():
                if isinstance(value, list):
                    if len(value) == 0:
                        report['edge_cases_covered'].append('empty_array')
                    elif len(value) == 1:
                        report['edge_cases_covered'].append('single_element')
                    elif all(v == value[0] for v in value):
                        report['edge_cases_covered'].append('all_same_elements')
                elif isinstance(value, int):
                    if value == 0:
                        report['edge_cases_covered'].append('zero_value')
                    elif value < 0:
                        report['edge_cases_covered'].append('negative_value')
        
        report['edge_cases_covered'] = list(set(report['edge_cases_covered']))
        
        # Warnings
        if report['input_diversity'] < 0.8:
            report['warnings'].append("Low input diversity - consider more varied inputs")
        if report['output_diversity'] < 0.3:
            report['warnings'].append("Low output diversity - may need more complex test cases")
        
        return report

# Test the complete pipeline
test_problem = {
    'starter_code': """
class Solution:
    def missingNumber(self, arr: List[int]) -> int:
        pass
""",
    'description': """
Given an arithmetic progression array with one missing element.
Constraints:
3 <= arr.length <= 1000
0 <= arr[i] <= 10^5
"""
}

canonical_solution = """
class Solution:
    def missingNumber(self, arr: List[int]) -> int:
        n = len(arr)
        total_diff = arr[-1] - arr[0]
        common_diff = total_diff // n
        
        for i in range(n - 1):
            expected_next = arr[i] + common_diff
            if arr[i + 1] != expected_next:
                return expected_next
        
        return arr[0] + common_diff
"""

pipeline = TestCaseGenerationPipeline()
test_cases, quality_report = pipeline.generate_test_cases(
    test_problem, 
    canonical_solution,
    num_simple=10,
    num_complex=10
)

print("\nQuality Report:")
print(json.dumps(quality_report, indent=2))

print("\nSample test cases:")
for i, tc in enumerate(test_cases[:5]):
    print(f"Test {i+1}: Input={tc['input']['arr']}, Output={tc['output']}")
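
The pipeline above generates inputs per parameter, which is sufficient for the single-argument `missingNumber` example. For functions with several parameters, the per-parameter candidates have to be combined into complete keyword-argument dicts. The sketch below does this by independent sampling. The helper name `combine_parameter_inputs` is hypothetical, and independent sampling ignores constraints that link parameters (for example `k <= len(nums)`), so a validity filter would still be needed.

```python
import random
from typing import Any, Dict, List


def combine_parameter_inputs(per_param: Dict[str, List[Any]], num_cases: int,
                             seed: int = 0) -> List[Dict[str, Any]]:
    """Sample one candidate per parameter to build complete kwargs dicts."""
    rng = random.Random(seed)
    names = list(per_param)
    return [
        {name: rng.choice(per_param[name]) for name in names}
        for _ in range(num_cases)
    ]


per_param_candidates = {
    "nums": [[1, 2, 3], [5], [0, 0]],
    "k": [1, 2, 3],
}
combined = combine_parameter_inputs(per_param_candidates, num_cases=4)
for case in combined:
    print(case)
```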

## 7. Advanced Topics: False Positive Prevention

The paper constructs an average of over 100 inputs per problem to reduce the risk of false positives:

In [None]:
class FalsePositiveDetector:
    """Detect and prevent false positives in test evaluation"""
    
    @staticmethod
    def analyze_solution_coverage(solution_code: str, 
                                test_cases: List[Dict]) -> Dict:
        """Analyze how well test cases cover the solution space"""
        
        # Common incorrect patterns to check
        incorrect_patterns = [
            {
                'name': 'always_return_constant',
                'code': 'return {constant}',
                'description': 'Solution always returns same value'
            },
            {
                'name': 'return_first_element',
                'code': 'return arr[0]',
                'description': 'Solution returns first element'
            },
            {
                'name': 'return_input_length',
                'code': 'return len(arr)',
                'description': 'Solution returns array length'
            }
        ]
        
        vulnerabilities = []
        
        # Check each incorrect pattern. Only the constant-return heuristic is
        # implemented here; the other patterns are listed for reference.
        for pattern in incorrect_patterns:
            # For constant returns, check if any output appears too frequently
            if pattern['name'] == 'always_return_constant' and test_cases:
                from collections import Counter
                # Stringify outputs so unhashable values (e.g. lists) can be counted
                outputs = [str(tc['output']) for tc in test_cases]
                output_counts = Counter(outputs)
                most_common = output_counts.most_common(1)[0]
                
                if most_common[1] / len(outputs) > 0.3:
                    vulnerabilities.append({
                        'pattern': pattern['name'],
                        'risk': 'high',
                        'details': f"Output {most_common[0]} appears in {most_common[1]}/{len(outputs)} cases"
                    })
        
        # Calculate coverage metrics
        coverage_report = {
            'total_test_cases': len(test_cases),
            'unique_outputs': len(set(str(tc['output']) for tc in test_cases)),
            'vulnerabilities': vulnerabilities,
            'recommendations': []
        }
        
        # Recommendations
        if coverage_report['unique_outputs'] < len(test_cases) * 0.5:
            coverage_report['recommendations'].append(
                "Add more diverse test cases to increase output variety"
            )
        
        if len(test_cases) < 100:
            coverage_report['recommendations'].append(
                f"Paper averages 100+ inputs per problem. Current: {len(test_cases)}"
            )
        
        return coverage_report
    
    @staticmethod
    def generate_adversarial_cases(problem_type: str, 
                                 existing_cases: List[Dict]) -> List[Dict]:
        """Generate adversarial test cases to catch common mistakes"""
        adversarial_cases = []
        
        if problem_type == 'array_manipulation':
            # Edge cases that often break naive solutions
            adversarial_inputs = [
                {'arr': [1]},  # Single element
                {'arr': [1, 1, 1, 1, 1]},  # All same
                {'arr': [-1000000, 1000000]},  # Extreme values
                {'arr': list(range(1000, 0, -1))},  # Reverse sorted
                {'arr': [0] * 100},  # All zeros
            ]
            
            # Tag each case with the incorrect patterns it is meant to expose.
            # Expected outputs still have to be computed with the canonical solution.
            for inp in adversarial_inputs:
                adversarial_cases.append({
                    'input': inp,
                    'is_adversarial': True,
                    'targets': ['always_return_constant', 'return_first_element']
                })
        
        return adversarial_cases

# Analyze our generated test cases
detector = FalsePositiveDetector()
coverage_report = detector.analyze_solution_coverage(canonical_solution, test_cases)

print("Coverage Analysis:")
print(json.dumps(coverage_report, indent=2))

# Generate adversarial cases
adversarial = detector.generate_adversarial_cases('array_manipulation', test_cases)
print(f"\nGenerated {len(adversarial)} adversarial test cases")
for i, case in enumerate(adversarial[:3]):
    print(f"  Adversarial {i+1}: {case['input']} targets {case['targets']}")
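
The output-frequency heuristic above only approximates the risk. A more direct check is to run deliberately wrong solutions against the suite and measure how many cases each one passes. The following is a minimal sketch, assuming test cases shaped like `{'input': {'arr': [...]}, 'output': ...}`; the names `WRONG_BASELINES` and `baseline_pass_rates` and the toy data are hypothetical and not part of the pipeline above.

```python
from typing import Any, Callable, Dict, List

# Hypothetical wrong "solutions" mirroring the incorrect patterns listed above.
WRONG_BASELINES: Dict[str, Callable[[List[int]], Any]] = {
    "always_return_constant": lambda arr: 0,
    "return_first_element": lambda arr: arr[0] if arr else None,
    "return_input_length": lambda arr: len(arr),
}

def baseline_pass_rates(test_cases: List[Dict[str, Any]]) -> Dict[str, float]:
    """Fraction of test cases each wrong baseline passes (0.0 for an empty suite)."""
    rates = {}
    for name, solve in WRONG_BASELINES.items():
        passed = sum(
            1 for tc in test_cases if solve(tc["input"]["arr"]) == tc["output"]
        )
        rates[name] = passed / len(test_cases) if test_cases else 0.0
    return rates

# Toy suite for a "maximum element" problem (illustrative data only).
toy_cases = [
    {"input": {"arr": [3, 1, 2]}, "output": 3},
    {"input": {"arr": [1, 5, 4]}, "output": 5},
    {"input": {"arr": [7]}, "output": 7},
    {"input": {"arr": [0, 0]}, "output": 0},
]
rates = baseline_pass_rates(toy_cases)
for name, rate in rates.items():
    print(f"{name}: passes {rate:.0%} of cases")
```

On this toy suite the `return_first_element` baseline passes three of the four cases, which suggests the suite contains too many inputs whose maximum happens to come first. In a real pipeline the baselines would run inside the sandbox rather than in-process.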

## 8. Performance Optimization for Large-Scale Generation

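When generating test cases for 2,869 problems with 100+ cases each, processing problems one at a time becomes the bottleneck, so the generator below distributes the work across worker processes: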

In [None]:
import concurrent.futures
import time
from functools import partial

class OptimizedTestGenerator:
    """Optimized test generation for large-scale processing"""
    
    def __init__(self, max_workers: int = 4):
        self.max_workers = max_workers
        self.cache = {}  # Cache for similar problems
        
    def batch_generate_test_cases(self, problems: List[Dict], 
                                canonical_solutions: Dict[str, str]) -> Dict:
        """Generate test cases for multiple problems in parallel"""
        
        start_time = time.time()
        results = {}
        
        # Group problems by similarity for cache efficiency
        problem_groups = self._group_similar_problems(problems)
        
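        # Note: ProcessPoolExecutor requires worker processes to be able to import
        # __main__, so this may fail when the class is defined in an interactive
        # session or notebook cell; move it into a .py module in that case.
        with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit tasks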
            future_to_problem = {}
            
            for group_id, group_problems in problem_groups.items():
                # Use cached patterns for similar problems
                base_pattern = self._get_base_pattern(group_id)
                
                for problem in group_problems:
                    if problem['id'] in canonical_solutions:
                        future = executor.submit(
                            self._generate_single_problem_tests,
                            problem,
                            canonical_solutions[problem['id']],
                            base_pattern
                        )
                        future_to_problem[future] = problem['id']
            
            # Collect results
            # Only problems with a canonical solution were submitted
            total = len(future_to_problem)
            completed = 0
            for future in concurrent.futures.as_completed(future_to_problem):
                problem_id = future_to_problem[future]
                try:
                    results[problem_id] = future.result()
                except Exception as e:
                    print(f"Failed to generate tests for {problem_id}: {e}")
                    results[problem_id] = []
                
                # Count failures too, so progress reporting does not stall
                completed += 1
                if completed % 100 == 0:
                    elapsed = time.time() - start_time
                    rate = completed / elapsed if elapsed > 0 else 0.0
                    print(f"Progress: {completed}/{total} "
                          f"({rate:.1f} problems/sec)")
        
        total_time = time.time() - start_time
        print(f"\nCompleted {len(results)} problems in {total_time:.1f} seconds")
        if results:  # avoid dividing by zero when nothing was submitted
            print(f"Average: {total_time/len(results):.2f} seconds per problem")
        
        return results
    
    def _group_similar_problems(self, problems: List[Dict]) -> Dict[str, List[Dict]]:
        """Group problems by type for efficient processing"""
        groups = {}
        
        for problem in problems:
            # Simple grouping by tags
            tags = problem.get('tags', [])
            group_key = '-'.join(sorted(tags[:2]))  # Use first 2 tags
            
            if group_key not in groups:
                groups[group_key] = []
            groups[group_key].append(problem)
        
        return groups
    
    def _get_base_pattern(self, group_id: str) -> Dict:
        """Get cached test patterns for problem group"""
        if group_id in self.cache:
            return self.cache[group_id]
        
        # Create base pattern for group
        base_pattern = {
            'input_strategies': ['edge_cases', 'random', 'patterns'],
            'size_distribution': [0.2, 0.6, 0.2]  # small, medium, large
        }
        
        self.cache[group_id] = base_pattern
        return base_pattern
    
    def _generate_single_problem_tests(self, problem: Dict, 
                                     solution: str, 
                                     base_pattern: Dict) -> List[Dict]:
        """Generate tests for a single problem (runs in separate process)"""
        # This would use the full pipeline we built earlier
        # Simplified for demonstration
        return [
            {'input': {'arr': [1, 2, 3]}, 'output': 0},
            {'input': {'arr': [5, 7, 11]}, 'output': 9}
        ]
    
    def estimate_processing_time(self, num_problems: int) -> Dict:
        """Estimate time required for processing"""
        # Scale taken from the paper: 2,869 problems, 100+ tests each.
        # The per-problem time is an assumed figure, not a measurement.
        avg_time_per_problem = 2.5  # seconds (assumed)
        total_serial_time = num_problems * avg_time_per_problem
        
        parallel_time = total_serial_time / self.max_workers
        overhead = parallel_time * 0.1  # 10% overhead (assumed)
        total_parallel_time = parallel_time + overhead
        
        return {
            'problems': num_problems,
            'serial_time_hours': total_serial_time / 3600,
            'parallel_time_hours': total_parallel_time / 3600,
            # Guard against division by zero when num_problems is 0
            'speedup': total_serial_time / total_parallel_time if total_parallel_time else 1.0,
            'recommendation': self._get_recommendation(num_problems)
        }
    
    def _get_recommendation(self, num_problems: int) -> str:
        if num_problems < 100:
            return "Use single-threaded processing"
        elif num_problems < 1000:
            return "Use 4-8 workers for optimal performance"
        else:
            return "Consider distributed processing or cloud compute"

# Demonstrate optimization
optimizer = OptimizedTestGenerator(max_workers=4)

# Estimate for paper's scale
estimate = optimizer.estimate_processing_time(2869)
print("Processing Time Estimate for LeetCodeDataset Scale:")
print(json.dumps(estimate, indent=2))
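
To make the estimate concrete, the model in `estimate_processing_time` can be written out for the 2,869-problem case with $w = 4$ workers. The 2.5 s per problem and the 10% overhead are the illustrative constants hard-coded in the method, not measurements from the paper.

$$
\begin{aligned}
T_{\text{serial}} &= N \cdot t = 2869 \times 2.5\ \text{s} = 7172.5\ \text{s} \approx 1.99\ \text{h} \\
T_{\text{parallel}} &= 1.1 \cdot \frac{T_{\text{serial}}}{w} = 1.1 \times \frac{7172.5\ \text{s}}{4} \approx 1972.4\ \text{s} \approx 0.55\ \text{h} \\
\text{speedup} &= \frac{T_{\text{serial}}}{T_{\text{parallel}}} = \frac{w}{1.1} = \frac{4}{1.1} \approx 3.64
\end{aligned}
$$

Because $N$ cancels in the last line, this simple model reports the same speedup for any number of problems. It does not capture startup costs or diminishing returns as workers are added.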

## 9. Key Takeaways and Best Practices

### Critical Insights from the Paper:

1. **100+ Test Cases**: The paper averages over 100 inputs per problem, which significantly reduces the risk of false positives
2. **Two-Stage Generation**: Simple cases + complex cases for comprehensive coverage
3. **Special Data Structures**: Require dedicated serialization/deserialization
4. **Sandboxed Execution**: Critical for safe evaluation of untrusted code

### Best Practices:

1. **Constraint-Aware Generation**: Extract and respect problem constraints
2. **Diversity Metrics**: Monitor input/output diversity to ensure coverage
3. **Adversarial Testing**: Include cases that target common mistakes
4. **Parallel Processing**: Essential for large-scale generation
5. **Caching Strategies**: Reuse patterns for similar problems
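
As one way to monitor diversity, the sketch below computes the ratio of distinct inputs and outputs and the share of the most common output. It assumes test cases are dicts with JSON-like `input` and `output` values; `diversity_metrics` and its metric names are hypothetical and chosen for illustration.

```python
import json
from collections import Counter
from typing import Any, Dict, List

def _key(value: Any) -> str:
    """Serialize a value so lists and dicts (which are unhashable) can be counted."""
    return json.dumps(value, sort_keys=True, default=repr)

def diversity_metrics(test_cases: List[Dict[str, Any]]) -> Dict[str, float]:
    """Distinct-input ratio, distinct-output ratio, and share of the top output."""
    n = len(test_cases)
    if n == 0:
        return {
            "unique_input_ratio": 0.0,
            "unique_output_ratio": 0.0,
            "top_output_share": 0.0,
        }
    inputs = Counter(_key(tc["input"]) for tc in test_cases)
    outputs = Counter(_key(tc["output"]) for tc in test_cases)
    return {
        "unique_input_ratio": len(inputs) / n,
        "unique_output_ratio": len(outputs) / n,
        "top_output_share": outputs.most_common(1)[0][1] / n,
    }

# Illustrative data: one duplicated case out of four.
cases = [
    {"input": {"arr": [1, 2]}, "output": [2, 1]},
    {"input": {"arr": [3]}, "output": [3]},
    {"input": {"arr": [3]}, "output": [3]},
    {"input": {"arr": []}, "output": []},
]
metrics = diversity_metrics(cases)
print(json.dumps(metrics, indent=2))
```

A low `unique_input_ratio` points to duplicated inputs, while a high `top_output_share` means a constant-return solution would pass a large part of the suite.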

### Implementation Checklist:

- [ ] Entry point extraction with AST parsing
- [ ] Constraint extraction from natural language
- [ ] Multi-stage input generation (simple + complex)
- [ ] Sandboxed execution with timeout and memory limits
- [ ] Special handling for trees, linked lists, graphs
- [ ] Quality validation and false positive detection
- [ ] Performance optimization for scale

### Future Improvements:

1. **LLM-Guided Generation**: Drive input generation with an LLM (e.g. GPT-4 or Claude), as the paper's pipeline does
2. **Mutation Testing**: Generate variants to test solution robustness
3. **Coverage Analysis**: Ensure all code paths are tested
4. **Automated Debugging**: Help identify why solutions fail specific tests
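
To illustrate the mutation testing idea, the sketch below flips comparison operators in a solution's AST and checks whether a test suite notices the change. It is a minimal example under stated assumptions: `FlipComparisons`, `load_function`, `suite_kills_mutant`, and the `count_below` problem are all hypothetical, only one mutation operator is shown, and the mutant runs in-process with `exec`, whereas a real pipeline would run it in the sandbox.

```python
import ast
from typing import Any, Callable, Dict, List

class FlipComparisons(ast.NodeTransformer):
    """Mutation operator: swap < with <= and > with >= (and vice versa)."""
    SWAPS = {ast.Lt: ast.LtE, ast.LtE: ast.Lt, ast.Gt: ast.GtE, ast.GtE: ast.Gt}

    def visit_Compare(self, node: ast.Compare) -> ast.Compare:
        self.generic_visit(node)
        # Operators without a swap entry (==, !=, in, ...) are recreated unchanged.
        node.ops = [self.SWAPS.get(type(op), type(op))() for op in node.ops]
        return node

def load_function(tree: ast.Module, name: str) -> Callable:
    """Compile a module AST and return the named function defined in it."""
    namespace: Dict[str, Any] = {}
    exec(compile(ast.fix_missing_locations(tree), "<solution>", "exec"), namespace)
    return namespace[name]

def suite_kills_mutant(source: str, name: str,
                       test_cases: List[Dict[str, Any]]) -> bool:
    """Return True if at least one test case exposes the mutated solution."""
    mutant = load_function(FlipComparisons().visit(ast.parse(source)), name)
    for tc in test_cases:
        try:
            if mutant(**tc["input"]) != tc["output"]:
                return True
        except Exception:
            return True  # a crash also exposes the mutant
    return False

SOURCE = """
def count_below(arr, limit):
    return sum(1 for x in arr if x < limit)
"""

# The weak suite never places an element exactly on the boundary.
weak_suite = [{"input": {"arr": [1, 2, 9], "limit": 5}, "output": 2}]
strong_suite = weak_suite + [{"input": {"arr": [5, 5, 1], "limit": 5}, "output": 1}]

print("weak suite kills mutant:", suite_kills_mutant(SOURCE, "count_below", weak_suite))
print("strong suite kills mutant:", suite_kills_mutant(SOURCE, "count_below", strong_suite))
```

A surviving mutant, as with the weak suite here, indicates a missing boundary case: the suite cannot tell `x < limit` from `x <= limit`.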