In [14]:
import ast
import inspect
import textwrap
from functools import wraps
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn


class ProgressAnalyzer(ast.NodeVisitor):
    """First pass: record every ``for`` loop that should get a progress bar.

    After ``visit(tree)``, ``self.loops`` holds one dict per recorded loop, in
    visiting order (a loop is recorded before the loops nested inside it):

    * ``'var_name'`` - the loop variable name, or ``"item"`` when the target
      is not a plain name (e.g. tuple unpacking).
    * ``'lineno'``   - line number of the ``for`` statement in the parsed source.
    * ``'comment'``  - text of the trailing ``#`` comment on that line, or
      ``None`` when there is none.

    Loops whose variable is ``_`` are not recorded, but loops nested inside
    them still are.

    Args:
        source_lines: The source that was parsed, as a list of lines without
            line terminators; index ``lineno - 1`` must match the AST.
    """

    def __init__(self, source_lines):
        self.loops = []
        self.source_lines = source_lines
        # {lineno: comment text}, or None when the source could not be tokenized.
        self._comments = self._collect_comments(source_lines)

    @staticmethod
    def _collect_comments(source_lines):
        """Map line numbers to comment text using real COMMENT tokens.

        Tokenizing (instead of searching for ``'#'``) keeps a ``#`` that sits
        inside a string literal from being mistaken for a comment.
        Returns ``None`` if the source cannot be tokenized.
        """
        # Imported locally so this class does not rely on extra module-level imports.
        import io
        import tokenize

        comments = {}
        readline = io.StringIO('\n'.join(source_lines)).readline
        try:
            for token in tokenize.generate_tokens(readline):
                if token.type == tokenize.COMMENT:
                    # Drop the leading '#' and surrounding whitespace.
                    comments[token.start[0]] = token.string[1:].strip()
        except (tokenize.TokenError, SyntaxError):
            return None
        return comments

    def _comment_on_line(self, lineno):
        """Return the trailing comment on 1-based line ``lineno``, or ``None``."""
        if self._comments is not None:
            return self._comments.get(lineno)

        # Fallback for untokenizable source: plain text split on the first '#'.
        if lineno - 1 < len(self.source_lines):
            line = self.source_lines[lineno - 1]
            if '#' in line:
                return line.split('#', 1)[1].strip()
        return None

    def visit_For(self, node):
        """Record ``node`` (unless its variable is ``_``) and descend into it."""
        if isinstance(node.target, ast.Name):
            var_name = node.target.id
        else:
            var_name = "item"

        # Skip underscore variables (convention for throwaway variables)
        if var_name != '_':
            self.loops.append({
                'var_name': var_name,
                'lineno': node.lineno,
                'comment': self._comment_on_line(node.lineno),
            })

        # Continue visiting nested loops
        self.generic_visit(node)


class ProgressTransformer(ast.NodeTransformer):
    """Second pass: rewrite ``for`` loops to drive pre-created progress tasks.

    Each loop whose variable is not ``_`` is numbered in visiting order
    (outer loop before the loops nested in it) and rewritten from::

        for x in iterable:
            body

    into::

        _items_N = list(iterable)
        _progress.reset(_task_N, total=len(_items_N))
        for x in _items_N:
            try:
                body
            finally:
                _progress.update(_task_N, advance=1)

    The generated code expects ``_progress`` and ``_task_N`` to be resolvable
    names when it runs. Note that the iterable is fully materialised with
    ``list(...)`` before the loop starts so that its length is known.

    Args:
        num_loops: Number of loops counted by the first pass. Stored on the
            instance; the transformation itself does not read it.
    """

    def __init__(self, num_loops):
        self.loop_counter = 0
        self.num_loops = num_loops

    @staticmethod
    def _name(identifier, ctx=None):
        """Build an ``ast.Name`` node (load context unless ``ctx`` is given)."""
        return ast.Name(id=identifier, ctx=ctx or ast.Load())

    @classmethod
    def _builtin_call(cls, func_name, argument):
        """Build the expression ``func_name(argument)``."""
        return ast.Call(func=cls._name(func_name), args=[argument], keywords=[])

    @classmethod
    def _progress_call(cls, method, task_var, keyword, value):
        """Build the statement ``_progress.<method>(<task_var>, <keyword>=<value>)``."""
        return ast.Expr(
            value=ast.Call(
                func=ast.Attribute(
                    value=cls._name('_progress'),
                    attr=method,
                    ctx=ast.Load(),
                ),
                args=[cls._name(task_var)],
                keywords=[ast.keyword(arg=keyword, value=value)],
            )
        )

    def visit_For(self, node):
        """Rewrite one ``for`` loop; returns a list of statements replacing it."""
        if isinstance(node.target, ast.Name):
            var_name = node.target.id
        else:
            var_name = "item"

        # Skip underscore variables, but still transform loops nested in them.
        if var_name == '_':
            self.generic_visit(node)
            return node

        loop_id = self.loop_counter
        self.loop_counter += 1

        task_var = f'_task_{loop_id}'
        items_var = f'_items_{loop_id}'

        setup = [
            # _items_N = list(original_iterable)
            ast.Assign(
                targets=[self._name(items_var, ast.Store())],
                value=self._builtin_call('list', node.iter),
            ),
            # _progress.reset(_task_N, total=len(_items_N))
            self._progress_call(
                'reset', task_var, 'total',
                self._builtin_call('len', self._name(items_var)),
            ),
        ]

        # _progress.update(_task_N, advance=1)
        update_call = self._progress_call(
            'update', task_var, 'advance', ast.Constant(value=1)
        )

        # Iterate over the materialised items instead of the original iterable.
        node.iter = self._name(items_var)

        # Run the update in a `finally` so that an iteration ending with
        # `continue` still advances the bar (a statement appended to the end
        # of the body would be skipped). It also runs when the iteration ends
        # via `break`, `return` or an exception.
        node.body = [
            ast.Try(
                body=node.body,
                handlers=[],
                orelse=[],
                finalbody=[update_call],
            )
        ]

        # Recursively transform loops nested in the body / else clause.
        self.generic_visit(node)

        # Returning a list makes the parent splice setup + loop in place.
        return setup + [node]

    def visit_FunctionDef(self, node):
        """Transform the loops inside a function definition."""
        # generic_visit already splices list results into statement lists.
        self.generic_visit(node)
        return node


def auto_progress(func):
    """
    Decorator that automatically adds progress bars to all for loops in a function.
    Creates exactly one progress bar per loop that updates in place.

    The function's source is fetched with ``inspect.getsource``, everything
    before its ``def`` line (i.e. the decorator lines) is dropped, and the
    loops are rewritten by ``ProgressAnalyzer`` / ``ProgressTransformer``.
    This preparation happens once, on the first call, and is reused afterwards.

    On every call the rewritten function is re-created in a *copy* of
    ``func.__globals__`` that additionally contains ``_progress`` and one
    ``_task_N`` per loop, then executed inside the ``Progress`` context.
    Because a copy is used and the function is re-created at module level,
    rebinding module globals from inside the function does not reach the real
    module namespace, and closure variables of the original function are not
    available to the rewritten one.

    Args:
        func: The function to instrument. Its source must be retrievable by
            ``inspect.getsource`` and contain a line starting with ``def``.

    Returns:
        A wrapper with the same call signature that returns whatever the
        rewritten function returns.
    """
    # Holds a single (compiled code, loop metadata) tuple once prepared.
    prepared = []

    def _prepare():
        """Fetch, rewrite and compile the function's source."""
        source_lines = inspect.getsource(func).split('\n')

        # Remove the decorator line(s): keep everything from the `def` line on.
        func_start = next(i for i, line in enumerate(source_lines)
                          if line.strip().startswith('def '))
        source = textwrap.dedent('\n'.join(source_lines[func_start:]))

        tree = ast.parse(source)

        # First pass: collect the loops that get a progress bar.
        analyzer = ProgressAnalyzer(source.split('\n'))
        analyzer.visit(tree)

        # Second pass: rewrite those loops.
        transformer = ProgressTransformer(len(analyzer.loops))
        new_tree = transformer.visit(tree)
        ast.fix_missing_locations(new_tree)

        code = compile(new_tree, filename='<ast>', mode='exec')
        return code, analyzer.loops

    @wraps(func)
    def wrapper(*args, **kwargs):
        # The source of `func` never changes, so compile it only once.
        if not prepared:
            prepared.append(_prepare())
        code, loops = prepared[0]

        # A fresh Progress display per call.
        progress = Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TimeRemainingColumn(),
        )

        # Start from the function's globals, then add the injected names
        # last so an identically named module global cannot overwrite them.
        namespace = dict(func.__globals__)
        namespace['_progress'] = progress

        # Pre-create all tasks, one per recorded loop.
        for index, loop_info in enumerate(loops):
            # Build description with comment if available
            desc = loop_info['var_name']
            if loop_info['comment']:
                desc = f"{desc} ({loop_info['comment']})"

            namespace[f'_task_{index}'] = progress.add_task(
                desc,
                total=100,  # Placeholder, will be reset when loop starts
                start=False  # Don't start yet
            )

        # Define the rewritten function inside the prepared namespace.
        exec(code, namespace)
        modified_func = namespace[func.__name__]

        # Run with progress context
        with progress:
            result = modified_func(*args, **kwargs)
            # Force a final refresh to ensure all bars show 100%
            progress.refresh()

        return result

    return wrapper


# Example usage
# Example usage
if __name__ == "__main__":
    import time
    import random

    # Fixed seed so the simulated values (and the printed average) are
    # reproducible from one run to the next.
    random.seed(42)

    @auto_progress
    def complex_data_processing(num_datasets, samples_per_dataset):
        """
        Simulates a complex data processing pipeline with nested and sequential loops.

        The trailing comments on the ``for`` lines below are picked up by
        ``auto_progress`` and shown as the progress bar descriptions.

        Args:
            num_datasets: Number of simulated datasets to generate.
            samples_per_dataset: Number of simulated samples in each dataset.

        Returns:
            A list with one summary dict (``'total'``, ``'average'``,
            ``'count'``) per dataset whose average is greater than zero.
        """
        all_results = []

        # Phase 1: Process multiple datasets (outer loop)
        for dataset_idx in range(num_datasets): # Processing datasets
            dataset_results = []

            # Phase 1a: Process samples in each dataset (nested loop)
            for sample_idx in range(samples_per_dataset): # Analyzing samples
                # Simulate expensive computation
                value = 0
                for _ in range(10000):
                    value += random.random() * random.random()

                dataset_results.append({
                    'dataset': dataset_idx,
                    'sample': sample_idx,
                    'value': value
                })

                # Simulate I/O delay
                time.sleep(0.01)

            all_results.append(dataset_results)

        # Phase 2: Post-processing (sequential loop)
        processed_data = []
        for dataset in all_results: # Aggregating results
            total = sum(item['value'] for item in dataset)
            # An empty dataset (samples_per_dataset == 0) has no meaningful
            # average; report 0.0 instead of dividing by zero.
            avg = total / len(dataset) if dataset else 0.0
            processed_data.append({
                'total': total,
                'average': avg,
                'count': len(dataset)
            })
            time.sleep(0.05)

        # Phase 3: Final validation loop
        validated = []
        for idx in range(len(processed_data)): # Validating data
            item = processed_data[idx]
            if item['average'] > 0:
                validated.append(item)
            time.sleep(0.02)

        return validated

    print("Starting complex data processing with fixed progress bars...\n")

    results = complex_data_processing(num_datasets=5, samples_per_dataset=30)

    print("\n✓ Processing complete!")
    print(f"  Final results: {len(results)} datasets validated")
    # Only report the overall average when at least one dataset was validated.
    if results:
        print(f"  Average value across all: {sum(r['average'] for r in results) / len(results):.2f}")

Output()

Starting complex data processing with fixed progress bars...




✓ Processing complete!
  Final results: 5 datasets validated
  Average value across all: 2498.62


In [7]:
%pip install rich

Collecting rich
  Downloading rich-14.2.0-py3-none-any.whl (243 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m243.4/243.4 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting markdown-it-py>=2.2.0
  Downloading markdown_it_py-4.0.0-py3-none-any.whl (87 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.3/87.3 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
Collecting mdurl~=0.1
  Downloading mdurl-0.1.2-py3-none-any.whl (10.0 kB)
Installing collected packages: mdurl, markdown-it-py, rich
Successfully installed markdown-it-py-4.0.0 mdurl-0.1.2 rich-14.2.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip3.10 install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
