# In [None]:
# ARC-AGI 2025 Submission - VAE Decoder Approach

# This notebook implements a sophisticated VAE decoder approach for solving ARC-AGI tasks.

# ## Architecture Overview
# - Multi-tensor system for handling 5D data (examples, colors, directions, x, y)
# - VAE decoder with KL divergence loss
# - Directional processing layers (cummax, shift)
# - Solution selection based on uncertainty scoring

# Import required libraries
import os
import sys
import time
import json
import gc
import traceback
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Make the CompressARC solver modules from the 'compressarc' input dataset
# importable.
sys.path.append('/kaggle/input/compressarc')

print("Importing ARC solver modules...")
try:
    import multitensor_systems
    import preprocessing
    import arc_compressor
    import initializers
    import layers
    import solution_selection
    import train
    import solve_task
except ImportError as e:
    # Best effort: report the problem and keep going. Code that later
    # references a module that failed to import will fail at that point.
    print(f"✗ Import error: {e}")
    print("Make sure the 'compressarc' dataset is added as input")
else:
    print("✓ All modules imported successfully")

# Setup environment and check GPU availability
print("Setting up environment...")

# Fixed seeds for torch and NumPy so repeated runs draw the same random numbers
torch.manual_seed(0)
np.random.seed(0)

# Choose the default tensor device: the first CUDA GPU when one is visible,
# otherwise the CPU.
if not torch.cuda.is_available():
    torch.set_default_device('cpu')
    print("⚠ Using CPU (CUDA not available)")
else:
    device = torch.cuda.get_device_name(0)
    memory = torch.cuda.get_device_properties(0).total_memory / 1e9  # bytes -> GB
    torch.set_default_device('cuda')
    torch.set_default_dtype(torch.float32)
    print(f"✓ Using CUDA: {device}")
    print(f"✓ GPU Memory: {memory:.1f} GB")

print(f"✓ PyTorch version: {torch.__version__}")
print(f"✓ NumPy version: {np.__version__}")

# Load and explore test data
TEST_CHALLENGES_PATH = '/kaggle/input/arc-prize-2025/arc-agi_test_challenges.json'


def summarize_tasks(challenges):
    """Collect per-task example counts and grid shapes for a challenge dict.

    Args:
        challenges: mapping of task name -> task dict holding 'train' and
            'test' lists of examples. Each example has an 'input' grid and
            optionally an 'output' grid (a list of rows).

    Returns:
        dict with 'total_tasks' (int), 'train_examples' and 'test_examples'
        (lists with one example count per task), and 'grid_sizes' (list of
        (rows, cols) tuples, one per input/output grid seen).
    """
    stats = {
        'total_tasks': len(challenges),
        'train_examples': [],
        'test_examples': [],
        'grid_sizes': [],
    }
    for task_data in challenges.values():
        train_examples = task_data.get('train', [])
        test_examples = task_data.get('test', [])
        stats['train_examples'].append(len(train_examples))
        stats['test_examples'].append(len(test_examples))

        for example in train_examples + test_examples:
            for key in ('input', 'output'):
                grid = example.get(key)
                if grid is None:
                    continue
                # len() instead of np.array(grid).shape: a ragged or empty
                # grid must not make the (purely informational) stats raise.
                stats['grid_sizes'].append((len(grid), len(grid[0]) if grid else 0))
    return stats


print("Loading test challenges...")

try:
    with open(TEST_CHALLENGES_PATH, 'r') as f:
        test_challenges = json.load(f)
    print(f"✓ Found {len(test_challenges)} test tasks")
except Exception as e:
    print(f"✗ Error loading test data: {e}")
    test_challenges = {}

# The statistics are informational only. They get their own try so that a
# malformed grid cannot discard tasks that were loaded successfully above.
try:
    task_stats = summarize_tasks(test_challenges)

    if task_stats['total_tasks']:
        print(f"Train examples per task: {np.mean(task_stats['train_examples']):.1f} ± {np.std(task_stats['train_examples']):.1f}")
        print(f"Test examples per task: {np.mean(task_stats['test_examples']):.1f} ± {np.std(task_stats['test_examples']):.1f}")

    if task_stats['grid_sizes']:
        grid_sizes = np.array(task_stats['grid_sizes'])
        print(f"Grid sizes - X: {grid_sizes[:, 0].min()}-{grid_sizes[:, 0].max()}, Y: {grid_sizes[:, 1].min()}-{grid_sizes[:, 1].max()}")
except Exception as e:
    print(f"⚠ Could not compute task statistics: {e}")

# Configuration parameters: per-task training budget, early-stopping checks,
# housekeeping intervals and Adam optimizer settings.
CONFIG = dict(
    max_iterations=800,               # Training iterations per task
    time_limit_per_task=45,           # Seconds per task
    early_stopping_threshold=100,     # Steps before checking convergence
    convergence_check_interval=50,    # Steps between convergence checks
    memory_cleanup_interval=20,       # Tasks between memory cleanup
    progress_report_interval=10,      # Tasks between progress reports
    learning_rate=0.01,
    adam_betas=(0.5, 0.9),
)

print("Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in CONFIG.items()))

# Define enhanced task solving function with monitoring
def _placeholder_solutions(n_test):
    """Return `n_test` placeholder predictions, each a pair of 1x1 zero grids."""
    return [{'attempt_1': [[0]], 'attempt_2': [[0]]} for _ in range(n_test)]


def _count_test_inputs(problem_data):
    """Return the number of test inputs in a raw task dict, or 1 if unreadable.

    Used on the failure path, where no parsed task object may exist, so the
    fallback still provides one prediction per test input.
    """
    try:
        return len(problem_data['test'])
    except (TypeError, KeyError):
        return 1


def _attempt_grid(candidates, example_num):
    """Extract one test example's grid from a logger solution as plain lists.

    Args:
        candidates: per-test-example sequence of grids (rows of cell values)
            as stored on the logger, or None.
        example_num: index of the test example to extract.

    Returns:
        The grid as a list of lists of Python ints, or [[0]] when
        `candidates` is None or has no entry for `example_num`.
    """
    if candidates is None or len(candidates) <= example_num:
        return [[0]]
    # int() turns NumPy/torch scalar cells into plain ints, so serializing
    # the submission with json cannot fail on them.
    return [[int(cell) for cell in row] for row in candidates[example_num]]


def solve_task_with_monitoring(task_name, problem_data, config):
    """
    Solve a single ARC task with detailed monitoring and logging.

    Trains a fresh ``arc_compressor.ARCCompressor`` on the task with Adam for
    up to ``config['max_iterations']`` steps or ``config['time_limit_per_task']``
    seconds. It then reads the two candidate predictions per test input from
    the ``solution_selection.Logger``.

    Args:
        task_name: task identifier (key in the challenges JSON).
        problem_data: raw task dict with 'train' and 'test' example lists.
        config: dict providing 'learning_rate', 'adam_betas',
            'max_iterations', 'time_limit_per_task',
            'early_stopping_threshold' and 'convergence_check_interval'.

    Returns:
        (solutions, metrics).

        ``solutions`` holds one ``{'attempt_1': grid, 'attempt_2': grid}``
        dict per test input.

        ``metrics`` holds 'steps_completed', 'final_loss',
        'convergence_step', 'memory_used' (peak CUDA GB for this task, 0 on
        CPU) and 'elapsed_time'. It also holds 'error' when an exception was
        caught.

        Any ``Exception`` is caught. In that case every test input gets a
        [[0]] placeholder for both attempts.
    """
    start_time = time.time()
    use_cuda = torch.cuda.is_available()
    task = model = optimizer = logger = None

    try:
        if use_cuda:
            # CUDA peak-memory counters are process-wide; reset them so that
            # 'memory_used' reflects this task rather than the run so far.
            torch.cuda.reset_peak_memory_stats()

        # Create task object
        task = preprocessing.Task(task_name, problem_data, None)

        # Initialize model and optimizer
        model = arc_compressor.ARCCompressor(task)
        optimizer = torch.optim.Adam(
            model.weights_list,
            lr=config['learning_rate'],
            betas=config['adam_betas']
        )
        logger = solution_selection.Logger(task)

        # Seed both attempts with a 2x2 zero grid per test example so there is
        # always something to return.
        default_solution = tuple(((0, 0), (0, 0)) for _ in range(task.n_test))
        logger.solution_most_frequent = default_solution
        logger.solution_second_most_frequent = default_solution

        # Training metrics
        metrics = {
            'steps_completed': 0,
            'final_loss': None,
            'convergence_step': None,
            'memory_used': 0
        }

        # Training loop
        for step in range(config['max_iterations']):
            if time.time() - start_time > config['time_limit_per_task']:
                print(f"  Time limit reached at step {step}")
                break

            train.take_step(task, model, optimizer, step, logger)
            metrics['steps_completed'] = step + 1

            if len(logger.loss_curve) > 0:
                metrics['final_loss'] = logger.loss_curve[-1]

            # "Convergence" here means: past the warm-up threshold, stop at the
            # first check where the most frequent solution is no longer the
            # placeholder seeded above. The second-most-frequent field was
            # also seeded, so its None check matters only if take_step
            # resets it.
            if (step > config['early_stopping_threshold'] and
                    step % config['convergence_check_interval'] == 0):
                if (logger.solution_most_frequent is not None and
                        logger.solution_second_most_frequent is not None and
                        logger.solution_most_frequent != default_solution):
                    metrics['convergence_step'] = step
                    print(f"  Converged at step {step}")
                    break

        # Extract solutions: one pair of attempts per test input
        solutions = [
            {
                'attempt_1': _attempt_grid(logger.solution_most_frequent, example_num),
                'attempt_2': _attempt_grid(logger.solution_second_most_frequent, example_num),
            }
            for example_num in range(task.n_test)
        ]

        if use_cuda:
            metrics['memory_used'] = torch.cuda.max_memory_allocated() / 1e9  # GB

        metrics['elapsed_time'] = time.time() - start_time
        return solutions, metrics

    except Exception as e:
        print(f"  Error: {e}")
        metrics = {
            'steps_completed': 0,
            'final_loss': None,
            'convergence_step': None,
            'memory_used': 0,
            'elapsed_time': time.time() - start_time,
            'error': str(e)
        }
        # One placeholder per test input keeps the submission shape valid
        # even for tasks with several test inputs.
        return _placeholder_solutions(_count_test_inputs(problem_data)), metrics

    finally:
        # Runs on success and failure alike (e.g. after a CUDA OOM). Drop the
        # references held by this frame first, so the cache flush and the
        # collection can actually release the model/optimizer memory.
        task = model = optimizer = logger = None
        if use_cuda:
            torch.cuda.empty_cache()
        gc.collect()

# Main solving loop with comprehensive monitoring
print("\n" + "="*60)
print("STARTING ARC-AGI TASK SOLVING")
print("="*60)

# Results, counters and the run clock are created unconditionally. The
# verification, submission and summary code that follows then also works when
# no tasks were loaded, instead of failing with NameError.
all_solutions = {}
all_metrics = {}
successful_tasks = 0
failed_tasks = 0
total_steps = 0
total_memory = 0
start_time = time.time()

if not test_challenges:
    print("No test challenges loaded. Exiting.")
else:
    total_tasks = len(test_challenges)

    print(f"Processing {total_tasks} tasks...\n")

    # The context manager closes the progress bar even if the loop raises.
    with tqdm(test_challenges.items(), desc="Solving tasks", unit="task") as pbar:
        for i, (task_name, problem_data) in enumerate(pbar):
            pbar.set_description(f"Solving {task_name}")
            tasks_done = i + 1

            # Solve task; errors are caught inside and flagged via 'error'
            solutions, metrics = solve_task_with_monitoring(task_name, problem_data, CONFIG)

            # Store results
            all_solutions[task_name] = solutions
            all_metrics[task_name] = metrics

            # Update statistics
            if 'error' in metrics:
                failed_tasks += 1
            else:
                successful_tasks += 1

            total_steps += metrics['steps_completed']
            total_memory += metrics['memory_used']

            elapsed_total = time.time() - start_time
            avg_time = elapsed_total / tasks_done

            # Avg_time is the running mean over all tasks so far, not the
            # duration of the last task.
            pbar.set_postfix({
                'Success': f"{successful_tasks}/{tasks_done}",
                'Avg_time': f"{avg_time:.1f}s",
                'Steps': metrics['steps_completed'],
            })

            # Periodic progress report
            if tasks_done % CONFIG['progress_report_interval'] == 0:
                remaining = avg_time * total_tasks - elapsed_total

                print(f"\nProgress Update [{tasks_done}/{total_tasks}]:")
                print(f"  Successful: {successful_tasks}, Failed: {failed_tasks}")
                print(f"  Avg time per task: {avg_time:.1f}s")
                print(f"  Estimated remaining: {remaining/60:.1f} minutes")
                print(f"  Avg steps per task: {total_steps/tasks_done:.0f}")
                print(f"  Avg memory per task: {total_memory/tasks_done:.2f} GB")

            # Periodic memory cleanup
            if tasks_done % CONFIG['memory_cleanup_interval'] == 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    print("\nTask solving completed!")

# Verify solutions and generate final statistics
print("\n" + "="*60)
print("SOLUTION VERIFICATION & STATISTICS")
print("="*60)

# Verify all tasks have solutions. Backfill gaps with one placeholder per test
# input so the submission still has the expected shape.
missing_tasks = []
for task_name, task_data in test_challenges.items():
    if task_name not in all_solutions:
        missing_tasks.append(task_name)
        print(f"⚠ Missing solution for {task_name}, adding default")
        n_test = len(task_data.get('test', [])) or 1
        all_solutions[task_name] = [
            {'attempt_1': [[0]], 'attempt_2': [[0]]} for _ in range(n_test)
        ]

if missing_tasks:
    print(f"\nAdded default solutions for {len(missing_tasks)} missing tasks")
else:
    print("✓ All tasks have solutions!")

# Calculate final statistics
total_elapsed = time.time() - start_time
metrics_values = list(all_metrics.values())
# Denominator for per-task ratios; never 0, so an empty run doesn't crash.
n_tasks_for_ratio = max(len(all_solutions), 1)

print("\nFinal Statistics:")
print(f"  Total tasks: {len(all_solutions)}")
print(f"  Successful tasks: {successful_tasks}")
print(f"  Failed tasks: {failed_tasks}")
print(f"  Success rate: {successful_tasks/n_tasks_for_ratio*100:.1f}%")
print(f"  Total time: {total_elapsed/60:.1f} minutes")
print(f"  Average time per task: {total_elapsed/n_tasks_for_ratio:.1f} seconds")

if metrics_values:
    avg_steps = np.mean([m['steps_completed'] for m in metrics_values])
    avg_memory = np.mean([m['memory_used'] for m in metrics_values])
    convergence_rate = sum(1 for m in metrics_values if m.get('convergence_step') is not None) / len(metrics_values)

    print(f"  Average steps per task: {avg_steps:.0f}")
    print(f"  Average memory per task: {avg_memory:.2f} GB")
    print(f"  Convergence rate: {convergence_rate*100:.1f}%")

# Analyze solution characteristics: (rows, cols) of every predicted grid
solution_sizes = []
for task_solutions in all_solutions.values():
    for example_solution in task_solutions:
        for attempt in ['attempt_1', 'attempt_2']:
            grid = example_solution[attempt]
            solution_sizes.append((len(grid), len(grid[0]) if grid else 0))

solution_sizes = np.array(solution_sizes)
print("\nSolution Grid Sizes:")
if solution_sizes.size:
    print(f"  X dimension: {solution_sizes[:, 0].min()}-{solution_sizes[:, 0].max()} (avg: {solution_sizes[:, 0].mean():.1f})")
    print(f"  Y dimension: {solution_sizes[:, 1].min()}-{solution_sizes[:, 1].max()} (avg: {solution_sizes[:, 1].mean():.1f})")
else:
    # An empty array has shape (0,), so 2-D column indexing would raise.
    print("  (no solutions)")

# Generate and save submission file
print("\n" + "="*60)
print("GENERATING SUBMISSION")
print("="*60)

try:
    # Save submission file
    submission_file = 'submission.json'
    print(f"Saving submission to {submission_file}...")

    # Serialize before opening the file. json.dump writes incrementally, so a
    # value that fails to serialize part-way would leave a truncated
    # submission.json on disk.
    submission_json = json.dumps(all_solutions, separators=(',', ':'))  # Compact format
    with open(submission_file, 'w') as f:
        f.write(submission_json)

    # Verify submission file by reading it back
    with open(submission_file, 'r') as f:
        verification = json.load(f)

    # Validation checks
    print("\nSubmission Validation:")

    if len(verification) == len(test_challenges):
        print(f"✓ Task count: {len(verification)} (matches expected)")
    else:
        print(f"✗ Task count: {len(verification)} (expected {len(test_challenges)})")

    format_valid = True

    # Every challenge task needs an entry; a matching count alone doesn't
    # guarantee matching task names.
    missing_from_submission = [name for name in test_challenges if name not in verification]
    if missing_from_submission:
        print(f"✗ Missing {len(missing_from_submission)} task(s), e.g. {missing_from_submission[0]}")
        format_valid = False

    # Check format
    for task_name, solutions in verification.items():
        if not isinstance(solutions, list):
            print(f"✗ Invalid format for {task_name}: not a list")
            format_valid = False
            break

        # One prediction dict is required per test input of the task.
        task_data = test_challenges.get(task_name)
        if task_data is not None:
            expected_count = len(task_data.get('test', []))
            if len(solutions) != expected_count:
                print(f"✗ {task_name}: {len(solutions)} prediction(s) for {expected_count} test input(s)")
                format_valid = False

        for i, solution in enumerate(solutions):
            if not isinstance(solution, dict):
                print(f"✗ Invalid format for {task_name}[{i}]: not a dict")
                format_valid = False
                break

            if 'attempt_1' not in solution or 'attempt_2' not in solution:
                print(f"✗ Missing attempts for {task_name}[{i}]")
                format_valid = False
                break

    if format_valid:
        print("✓ Format validation passed")

    # File size check
    file_size = os.path.getsize(submission_file) / (1024 * 1024)  # MB
    print(f"✓ File size: {file_size:.2f} MB")

    print(f"\n🎉 Submission file '{submission_file}' generated successfully!")

except Exception as e:
    print(f"✗ Error generating submission: {e}")
    traceback.print_exc()

# Optional: Save detailed metrics for analysis
print("\nSaving detailed metrics...")

try:
    metrics_file = 'detailed_metrics.json'
    # Serialize first: json.dump streams chunks into the open file, so a
    # non-serializable value part-way through would leave a truncated file.
    metrics_json = json.dumps(all_metrics, indent=2)
    with open(metrics_file, 'w') as f:
        f.write(metrics_json)
    print(f"✓ Metrics saved to {metrics_file}")
except Exception as e:
    print(f"⚠ Could not save metrics: {e}")

# Final summary
print("\n" + "="*60)
print("SUBMISSION COMPLETE")
print("="*60)
print("📁 Submission file: submission.json")
print(f"📊 Tasks solved: {len(all_solutions)}")
print(f"⏱️ Total time: {total_elapsed/60:.1f} minutes")
# max(..., 1) keeps an empty run (no tasks loaded) from dividing by zero.
print(f"🎯 Success rate: {successful_tasks/max(len(all_solutions), 1)*100:.1f}%")
print("\n🚀 Ready for submission to ARC Prize 2025!")
print("="*60)