In [None]:
# Print the current working directory of the running kernel.
import os
print(os.getcwd())

In [None]:
        !pip install matplotlib
        !pip install mediapy
        !pip install numpy
        # For advanced optimizations, consider installing the following:
        !pip install accelerate
        !pip install diffusers
        !pip install xformers
        !pip install seaborn

This section contains the set-up components for training the CTM model with the byte-level encoder with binary patches, CTM processing with the synapse system set to multi-objective,



and binary patches from the CTM (after 20 rounds of CoT thinking) refined and trained with MCMC to encourage the model to have reasoning steps closely related to the best answer.



Each epoch is saved as a safetensors checkpoint to preserve training progress.

# -----------------------------------------------------------------------------
# Dependency Installation Notes
# -----------------------------------------------------------------------------
 The following dependencies are required. Please install them in your Python environment,
 for example, using pip:

 pip install mediapy
 pip install torch
 pip install safetensors
 pip install numpy

 For advanced optimizations, consider installing the following:
 pip install flash-attn --no-build-isolation
 pip install deepspeed
 pip install accelerate
 pip install xformers

 It's recommended to use a virtual environment.
 -----------------------------------------------------------------------------

In [None]:
# Start-up banner: remind the user which packages this notebook expects.
# Emitted as one joined print; the text on stdout is the same eight lines.
print("\n".join((
    "-----------------------------------------------------------------------------",
    "Dependency Setup & Imports",
    "-----------------------------------------------------------------------------",
    "Please ensure all required dependencies are installed.",
    "Base dependencies: torchaudio, imageio, mediapy, torch, safetensors, numpy",
    "Optional optimization dependencies: flash-attn, deepspeed, accelerate, xformers",
    "See comments at the top of the script for installation commands.",
    "-----------------------------------------------------------------------------",
)))

import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import json
import random
import time
from datetime import datetime
from PIL import Image
import matplotlib.pyplot as plt
import glob
from IPython.display import display, Markdown

# Root folder for checkpoints; created up front so later code can assume it
# exists. Creating it again when it is already there is a no-op.
CHECKPOINT_DIR = "checkpoints"
if not os.path.isdir(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f"Base Checkpoint Directory: {CHECKPOINT_DIR}")

# Probe the optional libraries and record the outcome in module-level flags.
# A missing package is reported on stdout and never aborts the cell.
try:
    from accelerate import Accelerator
    from accelerate.utils import DistributedDataParallelKwargs
except ImportError:
    HAS_ACCELERATE = False
    print("Accelerate library not found. Some features like multi-GPU training might be limited.")
else:
    HAS_ACCELERATE = True
    print("Accelerate library found.")

try:
    import xformers.ops as xops
except ImportError:
    HAS_XFORMERS = False
    print("xFormers library not found.")
else:
    HAS_XFORMERS = True
    print("xFormers library found.")
print("-----------------------------------------------------------------------------")

import numpy as np
import math
import time

# Availability flags for the optional acceleration / preview packages.
# ACCELERATE_AVAILABLE and XFORMERS_AVAILABLE feed the optimization status
# report printed later in this cell.
try:
    from accelerate import Accelerator
except ImportError:
    ACCELERATE_AVAILABLE = False
    print("⚠️ Accelerate not available")
else:
    ACCELERATE_AVAILABLE = True
    print("✅ Accelerate available")

try:
    import xformers
    import xformers.ops
except ImportError:
    XFORMERS_AVAILABLE = False
    print("⚠️ xFormers not available")
else:
    XFORMERS_AVAILABLE = True
    print("✅ xFormers available - Expected 1.5-2x speedup")

# mediapy is optional; without it only the flag changes and a warning is shown.
try:
    import mediapy
except ImportError:
    MEDIAPY_AVAILABLE = False
    print("Warning: mediapy not available. GIF preview will be limited.")
else:
    MEDIAPY_AVAILABLE = True

import matplotlib.pyplot as plt

from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint  # OPTIMIZATION: Gradient checkpointing

# One-line-per-feature summary of which optimizations are usable.
_status_mark = {True: '✅', False: '❌'}
print("\n🚀 OPTIMIZATION STATUS:")
print(f"  ⚡ torch.compile: {_status_mark[hasattr(torch, 'compile')]}")
print(f"  📈 Accelerate: {_status_mark[bool(ACCELERATE_AVAILABLE)]}")
print(f"  ⚡ xFormers: {_status_mark[bool(XFORMERS_AVAILABLE)]}")

# Make the project's packages importable.
# The search roots are resolved relative to this script (or to the current
# working directory when __file__ is undefined, e.g. inside a notebook), so
# '..' must point at the project root that holds the 'models' folder.
print("\n-----------------------------------------------------------------------------")
print("Setting up module paths...")
print("-----------------------------------------------------------------------------")
try:
    current_script_path = os.path.dirname(os.path.abspath(__file__))
except NameError:  # interactive session: no __file__
    current_script_path = os.getcwd()

module_paths = [
    os.path.abspath(candidate)
    for candidate in (
        os.path.join(current_script_path, '..', 'models'),
        os.path.join(current_script_path, '..'),
        'contineous-thought-machines',
    )
]
for path in module_paths:
    if path in sys.path:
        continue  # already registered, keep sys.path free of duplicates
    sys.path.append(path)
    print(f"Added to sys.path: {path}")

# Import Enhanced CTM with diffusion control and all optimizations.
# OPTIMIZED_CTM_CONFIG_ARC will be defined below if imports are successful.
print("\n-----------------------------------------------------------------------------")
print("Importing CTM and Dataloader modules...")
print("-----------------------------------------------------------------------------")
# Pre-bind every imported name to None so that later cells can test for
# availability instead of hitting a NameError when the import fails.
EnhancedCTMDiffusion = None
CTMControlledDiffusionProcessor = None
FrequencyDomainAwareAttention = None
IntegrationFlowHiPASampler = None
CTMIntegrationFlowTrainer = None
ENHANCED_MCMC_AVAILABLE = False  # Initialize

try:
    from models.ctm_Diffusion_NEWNEW import (
        EnhancedCTMDiffusion,
        #EnhancedCTMConfig, #This is turned off since it is included in the notebook set-up phase now to avoid undefined errors.
        CTMControlledDiffusionProcessor,
        FrequencyDomainAwareAttention,
        IntegrationFlowHiPASampler,
        CTMIntegrationFlowTrainer,
    )
except ImportError as e_ctm:
    print(f"❌ Error importing Enhanced CTM or related components: {e_ctm}")
    print("   Please ensure 'models/ctm_Diffusion_NEWNEW.py' components exist and are accessible.")
    # A `from ... import (a, b, ...)` binds names one at a time, so a failure
    # part-way through can leave the earlier names bound. Reset all of them so
    # the sentinels are consistent.
    EnhancedCTMDiffusion = None
    CTMControlledDiffusionProcessor = None
    FrequencyDomainAwareAttention = None
    IntegrationFlowHiPASampler = None
    CTMIntegrationFlowTrainer = None
else:
    print("✓ Successfully imported EnhancedCTMDiffusion with ALL GPU optimizations")
    print("  - Integration Flow one-step generation")
    print("  - Task-Aware HiPA frequency enhancement")
    print("  - CTM-guided diffusion control")
    print("  - GPU memory optimizations")
    print("  - Mixed precision training support")

# -----------------------------------------------------------------------------
# Configuration for Integrated Diffusion CTM
# -----------------------------------------------------------------------------
print("\n-----------------------------------------------------------------------------")
print("Initializing Configuration for Integrated Diffusion CTM")
print("-----------------------------------------------------------------------------")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# 🚀 OPTIMIZATION 1: Enhanced Mixed Precision Training Setup (FP16/BF16)
# Mixed precision is only enabled on CUDA; BF16 is preferred when the GPU
# reports support for it, otherwise FP16 is used.
USE_MIXED_PRECISION = torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and device.type == 'cuda'
USE_BFLOAT16 = USE_MIXED_PRECISION and torch.cuda.is_bf16_supported()

autocast_dtype = torch.float32  # Default
if USE_MIXED_PRECISION:
    # NOTE(review): torch.cuda.amp is deprecated in recent PyTorch releases in
    # favour of torch.amp; the import is kept because other cells may rely on
    # the `GradScaler` / `autocast` names and on the CUDA-specific call
    # signature of this `autocast`.
    from torch.cuda.amp import GradScaler, autocast
    scaler = torch.amp.GradScaler("cuda", enabled=True)
    if USE_BFLOAT16:
        autocast_dtype = torch.bfloat16
        print("✅ Mixed precision training enabled (BF16) - Expected ~2x speedup")
    else:
        autocast_dtype = torch.float16
        print("✅ Mixed precision training enabled (FP16) - Expected ~2x speedup")
else:
    scaler = None

    class dummy_autocast:
        """No-op stand-in for ``autocast`` when mixed precision is unavailable.

        Accepts (and ignores) any constructor arguments so that call sites
        written for the real context manager, e.g.
        ``with autocast(enabled=..., dtype=...):``, also work on CPU.
        """

        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            return None

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Returning False lets exceptions raised in the block propagate.
            return False

    autocast = dummy_autocast
    autocast_dtype = torch.float32
    print("⚠️ Mixed precision training not available (CPU or older GPU or torch.cuda.amp not found)")

# Gradient accumulation / clipping settings.
GRADIENT_ACCUMULATION_STEPS = 4
MAX_GRAD_NORM = 1.0

# DataLoader keyword arguments: at most 8 workers (1 when the CPU count is
# unknown), pinned memory only when CUDA is present. The worker-dependent
# options are derived from the worker count stored in the dict.
OPTIMIZED_DATALOADER_CONFIG = {
    'num_workers': min(8, os.cpu_count() or 1),
    'pin_memory': torch.cuda.is_available(),
}
OPTIMIZED_DATALOADER_CONFIG['persistent_workers'] = OPTIMIZED_DATALOADER_CONFIG['num_workers'] > 0
OPTIMIZED_DATALOADER_CONFIG['prefetch_factor'] = (
    4 if OPTIMIZED_DATALOADER_CONFIG['num_workers'] > 0 else None
)

# Default learning rate; individual training phases may override it.
LEARNING_RATE = 1e-4

-----------------------------------------------------------------------------
Dependency Setup & Imports
-----------------------------------------------------------------------------
Please ensure all required dependencies are installed.
Base dependencies: torchaudio, imageio, mediapy, torch, safetensors, numpy
Optional optimization dependencies: flash-attn, deepspeed, accelerate, xformers
See comments at the top of the script for installation commands.
-----------------------------------------------------------------------------
Base Checkpoint Directory: checkpoints
Accelerate library found.
xFormers library found.
-----------------------------------------------------------------------------
✅ Accelerate available
✅ xFormers available - Expected 1.5-2x speedup
✅ safetensors available for model checkpointing.

🚀 OPTIMIZATION STATUS:
  ⚡ torch.compile: ✅
  📈 Accelerate: ✅
  ⚡ xFormers: ✅

-----------------------------------------------------------------------------
Setting up module pa

# Injected MCMC Components and EnhancedCTMFenchelYoungIntegration Initialization

This notebook contains the Python code for MCMC components, including ARC-specific output spaces and an enhanced Fenchel-Young integration layer, structured for use in a Jupyter environment.

## 1. Imports
All necessary libraries and modules are imported here.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from typing import Optional, Callable, Tuple, Dict, Any, List, Union
from dataclasses import dataclass, field
import math
import numpy as np
import warnings # For ARCGridOutputSpace warnings

## 2. Utility Functions (from `models.utils`)
Helper functions used across various modules.

In [None]:
def add_coord_dim(x, scaled=True):
    """
    Build a coordinate grid matching a ``(B, H, W)`` tensor.

    Returns a ``(B, H, W, 2)`` tensor whose last axis holds the column index
    (position 0) and the row index (position 1) of every cell, created on the
    device of ``x`` from ``x.dtype`` index ranges. The values of ``x`` itself
    are not part of the result.

    With ``scaled=True`` each index is divided by ``size - 1`` so coordinates
    span ``[0, 1]``; an axis of size 1 yields zeros instead of dividing by
    zero.
    """
    batch, height, width = x.shape
    col_index = torch.arange(width, device=x.device, dtype=x.dtype)
    row_index = torch.arange(height, device=x.device, dtype=x.dtype)

    # Expand both index vectors to full (H, W) grids.
    x_coords = col_index.repeat(height, 1)
    y_coords = row_index.unsqueeze(-1).repeat(1, width)

    if scaled:
        if width > 1:
            x_coords = x_coords / (width - 1)
        else:
            x_coords = torch.zeros_like(x_coords)
        if height > 1:
            y_coords = y_coords / (height - 1)
        else:
            y_coords = torch.zeros_like(y_coords)

    grid = torch.stack((x_coords, y_coords), dim=-1)
    return grid.unsqueeze(0).repeat(batch, 1, 1, 1)

def compute_normalized_entropy(logits, reduction='mean'):
    """
    Calculates the normalized entropy of a PyTorch tensor of logits along the
    final dimension.

    The entropy of ``softmax(logits)`` is divided by ``log(num_classes)`` so a
    uniform distribution scores 1. When the final dimension has a single
    class the entropy is defined as 0.

    Args:
        logits: tensor whose last dimension indexes the classes.
        reduction: when ``'mean'`` and ``logits`` has more than two
            dimensions, all dimensions after the first are averaged, giving
            one value per leading index. Any other value leaves the
            per-position entropies unreduced.

    Returns:
        Tensor of normalized entropies, of shape ``logits.shape[:-1]``, or
        ``(logits.shape[0],)`` after the ``'mean'`` reduction.
    """
    preds = F.softmax(logits, dim=-1)
    log_preds = torch.log_softmax(logits, dim=-1)
    entropy = -torch.sum(preds * log_preds, dim=-1)
    num_classes = preds.shape[-1]
    if num_classes > 1:
        # log(num_classes) > 0 here, so no zero-division guard is needed and
        # the tensor never has to be compared on the host.
        max_entropy = torch.log(torch.tensor(num_classes, dtype=torch.float32, device=logits.device))
        normalized_entropy = entropy / max_entropy
    else:
        # Degenerate distribution: avoid dividing by log(1) == 0.
        normalized_entropy = torch.zeros_like(entropy)
    # Apply the reduction on every path so the output shape does not depend
    # on the number of classes.
    if len(logits.shape) > 2 and reduction == 'mean':
        normalized_entropy = normalized_entropy.flatten(1).mean(-1)
    return normalized_entropy

## 3. Core Modules (from `models.modules`)
Custom neural network layers.

In [None]:
class SuperLinear(nn.Module):
    """
    Bank of ``N`` independent linear maps applied in parallel.

    The weight ``w1`` has shape ``(in_dims, out_dims, N)`` and the bias ``b1``
    shape ``(1, N, out_dims)``. ``forward`` contracts the last axis of a
    ``(B, N, in_dims)`` input with the slice of ``w1`` belonging to each of
    the ``N`` positions, giving ``(B, N, out_dims)``; a trailing axis of size
    1 is squeezed away. The result is divided by the learnable scalar ``T``.

    Args:
        in_dims: size of the input feature axis.
        out_dims: size of the output feature axis.
        N: number of parallel linear maps.
        T: initial value of the temperature parameter.
        do_norm: apply ``LayerNorm(in_dims)`` before the projection.
        dropout: dropout probability applied to the input (0 disables it).
    """

    def __init__(self,
                 in_dims,
                 out_dims,
                 N,
                 T=1.0,
                 do_norm=False,
                 dropout=0):
        super().__init__()
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()
        self.in_dims = in_dims
        if do_norm:
            self.layernorm = nn.LayerNorm(in_dims, elementwise_affine=True)
        else:
            self.layernorm = nn.Identity()
        self.do_norm = do_norm

        # Uniform initialisation in [-1/sqrt(fan_in + fan_out), +1/sqrt(...)].
        init_bound = 1 / math.sqrt(in_dims + out_dims)
        initial_weight = torch.empty((in_dims, out_dims, N)).uniform_(-init_bound, init_bound)
        self.w1 = nn.Parameter(initial_weight, requires_grad=True)
        self.b1 = nn.Parameter(torch.zeros((1, N, out_dims)), requires_grad=True)
        self.T = nn.Parameter(torch.Tensor([T]))

    def forward(self, x):
        hidden = self.layernorm(self.dropout(x))
        projected = torch.einsum('BDM,MHD->BDH', hidden, self.w1) + self.b1
        return projected.squeeze(-1) / self.T

## 4. MCMC Interpretability Solver Components (from `models.mcmc_interpretability_solver`)
Dataclasses and hooks for tracking and interpreting MCMC processes and solver states.

In [None]:
@dataclass
class ThoughtStep:
    """Record describing one step of a tracked reasoning process.

    Only ``step_id`` and ``layer_name`` are required. Every tensor field
    defaults to ``None`` and every dict field to a fresh empty dict, so a
    step can be filled in incrementally.
    """
    # Required identification of the step.
    step_id: int
    layer_name: str
    # Optional tensors captured for this step (None when not recorded).
    input_state: Optional[torch.Tensor] = None
    output_state: Optional[torch.Tensor] = None
    attention_weights: Optional[torch.Tensor] = None
    mcmc_samples: Optional[torch.Tensor] = None
    confidence_score: float = 0.0
    reasoning_vector: Optional[torch.Tensor] = None
    # default_factory gives each instance its own dict (no shared mutable default).
    energy_landscape: Dict[str, float] = field(default_factory=dict)
    correction_ratio: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

@dataclass
class ReasoningChain:
    """Container for an ordered list of ``ThoughtStep`` records plus summary data.

    All fields are optional: tensors default to ``None``, collections to
    fresh empty containers and ``reasoning_summary`` to an empty string.
    """
    input_data: Optional[torch.Tensor] = None
    # default_factory gives each chain its own list/dict instances.
    thought_steps: List[ThoughtStep] = field(default_factory=list)
    final_output: Optional[torch.Tensor] = None
    confidence_trajectory: List[float] = field(default_factory=list)
    decision_points: List[int] = field(default_factory=list)
    reasoning_summary: str = ""
    convergence_metrics: Dict[str, float] = field(default_factory=dict)
    solver_diagnostics: List[Dict[str, Any]] = field(default_factory=list)

class MCMCInterpretabilityHook:
    """Collects activations and MCMC-related state from a hooked module.

    ``forward_hook`` and ``backward_hook`` have the signatures PyTorch expects
    for module forward / backward hooks. Every call appends detached copies
    to the lists below, so the lists grow until they are cleared by the
    owner.

    Args:
        layer_name: label stored with each recorded activation.
    """

    def __init__(self, layer_name: str):
        self.layer_name = layer_name
        self.activations: List[Dict[str, Any]] = []
        self.gradients: List[torch.Tensor] = []
        self.attention_maps: List[Optional[torch.Tensor]] = []
        self.mcmc_states: List[Optional[torch.Tensor]] = []
        self.energy_values: List[float] = []
        self.correction_ratios: List[Optional[float]] = []
        self.solver_diagnostics: List[Dict[str, Any]] = []

    def forward_hook(self, module, input_data, output_data):
        """Record the first positional input, the output and optional module state.

        Optional state is read from the module only when the attribute exists
        and is non-empty / not None: ``mcmc_samples``, ``attention_weights``,
        the last entry of ``correction_ratios_log`` and the last entry of
        ``solver_diagnostics_log`` (whose ``'last_objective_value'``, when
        present, is also appended to ``energy_values``).
        """
        if isinstance(input_data, tuple):
            # PyTorch passes the positional arguments as a tuple; it is empty
            # when the module was called with keyword arguments only.
            input_tensor = input_data[0] if input_data else None
        else:
            input_tensor = input_data
        self.activations.append({
            'input': input_tensor.detach().clone() if torch.is_tensor(input_tensor) else input_tensor,
            'output': output_data.detach().clone() if torch.is_tensor(output_data) else output_data,
            'layer': self.layer_name,
            # Index of this record (length of the list before appending).
            'timestamp': len(self.activations)
        })
        if getattr(module, 'mcmc_samples', None) is not None:
            self.mcmc_states.append(module.mcmc_samples.detach().clone())
        if getattr(module, 'attention_weights', None) is not None:
            self.attention_maps.append(module.attention_weights.detach().clone())
        if hasattr(module, 'correction_ratios_log') and module.correction_ratios_log:
            self.correction_ratios.append(module.correction_ratios_log[-1])
        if hasattr(module, 'solver_diagnostics_log') and module.solver_diagnostics_log:
            diagnostics = module.solver_diagnostics_log[-1]
            self.solver_diagnostics.append(diagnostics)
            if 'last_objective_value' in diagnostics:
                self.energy_values.append(diagnostics['last_objective_value'])

    def backward_hook(self, module, grad_input, grad_output):
        """Store a detached copy of the first output gradient, if there is one."""
        if grad_output and grad_output[0] is not None:
            self.gradients.append(grad_output[0].detach().clone())

class BlackBoxSolver:
    """Attaches ``MCMCInterpretabilityHook`` recorders to selected sub-modules.

    Args:
        model: module whose named sub-modules are scanned for hook targets.
        device: stored on the instance as given (defaults to ``'cpu'``).
    """

    def __init__(self, model: nn.Module, device: str = 'cpu'):
        self.model = model
        self.device = device
        self.hooks: Dict[str, MCMCInterpretabilityHook] = {}
        self.reasoning_chains: List[ReasoningChain] = []

    def _register_hooks(self):
        """Hook every sub-module whose name contains one of the target keywords.

        A module name that already has an entry in ``self.hooks`` is skipped,
        so calling this repeatedly does not register duplicate hooks.
        """
        target_keywords = ('mcmc', 'enhanced', 'correction', 'fenchel', 'oracle')
        for name, module in self.model.named_modules():
            lowered = name.lower()
            if not any(keyword in lowered for keyword in target_keywords):
                continue
            if name in self.hooks:
                continue
            recorder = MCMCInterpretabilityHook(name)
            self.hooks[name] = recorder
            module.register_forward_hook(recorder.forward_hook)

    def clear_hooks_data(self):
        """Empty every buffer of every registered hook (hooks stay attached)."""
        buffer_names = (
            'activations',
            'gradients',
            'attention_maps',
            'mcmc_states',
            'energy_values',
            'correction_ratios',
            'solver_diagnostics',
        )
        for recorder in self.hooks.values():
            for buffer_name in buffer_names:
                getattr(recorder, buffer_name).clear()

## 5. Base MCMC Components (from `models.fenchel_young_mcmc`)
Core classes for MCMC sampling, including configuration, temperature scheduling, and output space representation.

In [None]:
@dataclass
class MCMCConfig:
    """Hyper-parameters for the MCMC samplers defined in this notebook."""
    num_chains: int = 3
    # Also passed as total_steps to the "linear" temperature schedule.
    chain_length: int = 20
    burn_in: int = 5
    # "geometric" and "linear" select the matching TemperatureScheduler factory;
    # any other value yields a constant schedule at final_temp.
    temperature_schedule: str = "geometric"
    initial_temp: float = 10.0
    final_temp: float = 1.0
    # Per-step multiplier used by the geometric schedule.
    decay_rate: float = 0.995
    neighborhood_radius: int = 1 # This is a general parameter, interpretation depends on OutputSpace
    initialization_method: str = "persistent"

class TemperatureScheduler:
    """Factories returning ``step -> temperature`` callables."""

    @staticmethod
    def geometric(initial_temp: float, decay_rate: float, final_temp: float):
        """Exponential decay ``initial_temp * decay_rate**step``, floored at ``final_temp``."""
        return lambda step: max(initial_temp * (decay_rate ** step), final_temp)

    @staticmethod
    def linear(initial_temp: float, final_temp: float, total_steps: int):
        """Linear interpolation from ``initial_temp`` to ``final_temp`` over ``total_steps``.

        Steps beyond ``total_steps`` stay at ``final_temp``; a non-positive
        ``total_steps`` returns ``final_temp`` for every step.
        """
        def temperature_at(step: int) -> float:
            if total_steps > 0:
                fraction = min(step / total_steps, 1.0)
            else:
                fraction = 1.0
            return initial_temp * (1 - fraction) + final_temp * fraction
        return temperature_at

    @staticmethod
    def constant(temperature: float):
        """Schedule that ignores the step and always returns ``temperature``."""
        return lambda step: temperature

class DiscreteOutputSpace:
    """Base class describing a discrete set of candidate output states.

    Subclasses either enumerate the space in ``_generate_space`` or sample
    from it via ``_generate_random_member_directly``, and must provide the
    neighbourhood methods. For ``dimension <= 4`` the constructor tries to
    enumerate the space eagerly; failures are swallowed and leave
    ``output_space`` empty.
    """

    def __init__(self, dimension: int):
        self.dimension = dimension
        self._full_output_space_generated = False
        self.output_space: List[torch.Tensor] = []
        if self.dimension > 4:
            return
        try:
            self.output_space = self._generate_space()
        except (NotImplementedError, ValueError):
            self.output_space = []
        else:
            self._full_output_space_generated = True

    def _generate_space(self) -> List[torch.Tensor]:
        raise NotImplementedError("Subclasses must implement _generate_space or rely on _generate_random_member_directly")

    def get_available_neighborhood_strategies(self, state: Optional[torch.Tensor] = None) -> List[str]:
        raise NotImplementedError("Subclasses must implement get_available_neighborhood_strategies")

    def get_neighbors(self, state: torch.Tensor, strategy_name: str, **strategy_params) -> List[torch.Tensor]:
        raise NotImplementedError("Subclasses must implement get_neighbors")

    def get_proposal_prob(self, current_state: torch.Tensor, proposed_state: torch.Tensor, strategy_name: str, **strategy_params) -> float:
        """Uniform proposal probability over ``get_neighbors(current_state, ...)``.

        Returns ``1 / len(neighbors)`` when ``proposed_state`` matches one of
        the neighbours (``torch.allclose``), otherwise ``0.0``.
        """
        candidates = self.get_neighbors(current_state, strategy_name, **strategy_params)
        if not candidates:
            return 0.0
        for candidate in candidates:
            if torch.allclose(candidate, proposed_state):
                return 1.0 / len(candidates)
        return 0.0

    def _generate_random_member_directly(self) -> Optional[torch.Tensor]:
        return None

    def random_state(self) -> torch.Tensor:
        """Return a random member of the space.

        Direct sampling is preferred; otherwise a clone of a random entry of
        the enumerated space is returned, enumerating it lazily on first use.

        Raises:
            RuntimeError: if the space can neither be sampled nor enumerated,
                or the enumeration is empty.
        """
        sampled = self._generate_random_member_directly()
        if sampled is not None:
            return sampled

        # Lazy enumeration, attempted only while nothing has been generated yet.
        if not self.output_space and not self._full_output_space_generated:
            try:
                self.output_space = self._generate_space()
                self._full_output_space_generated = True
            except (NotImplementedError, ValueError) as e:
                raise RuntimeError(f"Cannot generate random_state for {self.__class__.__name__} (dim {self.dimension}). Error: {e}")

        if not self.output_space:
             raise RuntimeError(f"Output space empty for {self.__class__.__name__} (dim {self.dimension}). Cannot sample random_state.")
        return random.choice(self.output_space).clone()

## 6. ARC Grid Output Space
A specific implementation of `DiscreteOutputSpace` for ARC-like grid environments.

In [None]:
class ARCGridOutputSpace(DiscreteOutputSpace):
    """Output space of flattened 2-D grids whose cells hold symbol ids.

    States are float tensors of length ``dimension`` obtained by flattening a
    ``grid_shape`` grid with integer values in ``[0, num_symbols)``.

    Raises:
        ValueError: if ``dimension`` differs from
            ``grid_shape[0] * grid_shape[1]``.
    """

    def __init__(self, dimension: int, grid_shape: Tuple[int, int], num_symbols: int):
        super().__init__(dimension)
        self.grid_shape = grid_shape
        self.num_symbols = num_symbols
        expected_cells = grid_shape[0] * grid_shape[1]
        if dimension != expected_cells:
            raise ValueError(f"Dimension ({dimension}) must match grid_shape ({grid_shape[0]}*{grid_shape[1]}={grid_shape[0]*grid_shape[1]})")

    def _generate_random_member_directly(self) -> Optional[torch.Tensor]:
        """Sample a uniformly random grid and return it flattened as floats."""
        symbols = torch.randint(0, self.num_symbols, self.grid_shape, dtype=torch.long)
        return symbols.view(-1).float()

    def get_available_neighborhood_strategies(self, state: Optional[torch.Tensor] = None) -> List[str]:
        return ["flip_one_cell_value", "swap_two_cells"]

    def get_neighbors(self, state: torch.Tensor, strategy_name: str, **strategy_params) -> List[torch.Tensor]:
        """Generate random neighbours of ``state``.

        ``"flip_one_cell_value"`` changes one random cell to a different
        symbol (``num_flips`` neighbours, default ``min(5, dimension)``).
        ``"swap_two_cells"`` exchanges two distinct random cells
        (``num_swaps`` neighbours, default ``min(5, dimension // 2)``).
        Any other strategy emits a warning and returns an empty list.
        """
        state_grid = state.view(self.grid_shape).long()
        last_row = self.grid_shape[0] - 1
        last_col = self.grid_shape[1] - 1
        neighbors = []

        if strategy_name == "flip_one_cell_value":
            flip_count = strategy_params.get('num_flips', min(5, self.dimension))
            for _ in range(flip_count):
                candidate = state_grid.clone()
                row = random.randint(0, last_row)
                col = random.randint(0, last_col)
                old_symbol = candidate[row, col].item()

                new_symbol = old_symbol
                if self.num_symbols > 1:
                    # Redraw until the symbol actually changes.
                    new_symbol = random.randint(0, self.num_symbols - 1)
                    while new_symbol == old_symbol:
                        new_symbol = random.randint(0, self.num_symbols - 1)
                candidate[row, col] = new_symbol
                neighbors.append(candidate.view(-1).float())

        elif strategy_name == "swap_two_cells":
            default_swaps = min(5, self.dimension // 2 if self.dimension >= 2 else 0)
            swap_count = strategy_params.get('num_swaps', default_swaps)
            for _ in range(swap_count):
                if self.dimension < 2:
                    break  # a swap needs two distinct cells
                candidate = state_grid.clone()
                first = (random.randint(0, last_row), random.randint(0, last_col))
                second = (random.randint(0, last_row), random.randint(0, last_col))
                while first == second:
                    second = (random.randint(0, last_row), random.randint(0, last_col))

                held = candidate[first].item()
                candidate[first] = candidate[second].item()
                candidate[second] = held
                neighbors.append(candidate.view(-1).float())

        else:
            warnings.warn(f"Unknown strategy: {strategy_name} for ARCGridOutputSpace. Returning empty neighbor list.")

        return neighbors

    def _generate_space(self) -> List[torch.Tensor]:
        """Refuse to enumerate grids with more than 6 cells; otherwise defer to the base class."""
        if self.dimension <= 6:
            return super()._generate_space()
        warnings.warn(f"Full space generation for ARCGridOutputSpace with dimension {self.dimension} is too large. Returning empty list.")
        return []

## 7. Enhanced MCMC Layers and Fenchel-Young Integration
Includes `ExactOptimizationOracle`, MCMC samplers (`CorrectionRatioMCMC`, `LargeNeighborhoodSearchMCMC`), and the main `EnhancedCTMFenchelYoungIntegration` module.

In [None]:

class ExactOptimizationOracle:
    """Exhaustive arg-max search over a list of candidate states.

    For a parameter vector ``theta`` the oracle scores each candidate ``y``
    as ``<theta, y>`` plus ``phi_network(y)`` when a phi network is set, and
    returns the best-scoring candidate. Results of the most recent call are
    kept in ``solver_state``.

    Args:
        output_space: source of candidates when ``solve`` is called without
            an explicit neighborhood.
        phi_network: optional module adding a per-candidate score term.
        model: accepted but not used by the code in this class.
    """

    def __init__(self, output_space: DiscreteOutputSpace, phi_network: Optional[nn.Module] = None, model: Optional[nn.Module] = None):
        self.output_space = output_space
        self.phi_network = phi_network
        # Bookkeeping about the most recent solve() call plus a running history.
        self.solver_state: Dict[str, Any] = {
            'last_solution': None, 'last_objective_value': None,
            'num_evaluations': 0, 'optimization_history': []
        }

    def solve(self, theta: torch.Tensor, neighborhood: Optional[List[torch.Tensor]] = None) -> Optional[torch.Tensor]:
        """Return the candidate maximizing the objective for ``theta``.

        Args:
            theta: score vector. A 2-D tensor is treated as a batch
                (``theta.ndim == 2``); anything else goes through
                ``torch.dot`` and is therefore expected to be 1-D.
            neighborhood: candidates to search. When ``None`` the search
                space is either up to ``min(dimension, 20)`` random members
                drawn from the output space (when it has no enumerated
                ``output_space``) or the enumerated ``output_space`` itself.

        Returns:
            Unbatched: the best candidate (converted to ``theta``'s device
            and dtype), or ``None`` if the search space is empty.
            Batched: a stacked tensor with one best candidate per batch row,
            or ``None`` if any row found no candidate.

        Side effects:
            Resets ``solver_state['num_evaluations']`` and counts the
            candidates evaluated; updates ``last_solution`` /
            ``last_objective_value`` and appends to ``optimization_history``.
            In batched mode only the first batch row is recorded.
        """
        search_space = neighborhood
        if search_space is None:
            if hasattr(self.output_space, '_generate_random_member_directly') and \
               (not hasattr(self.output_space, 'output_space') or not self.output_space.output_space):
                # Try to generate a small random search space if the full one is not available/too large
                # Ensure getattr has a default for 'dimension' if it might be missing
                dimension_val = getattr(self.output_space, 'dimension', 20)
                search_space = [self.output_space._generate_random_member_directly() for _ in range(min(dimension_val, 20))]
                # Direct sampling may return None (the base class does); drop those.
                search_space = [s for s in search_space if s is not None]
            elif hasattr(self.output_space, 'output_space'): # Check if output_space attribute exists
                search_space = self.output_space.output_space
            else: # Fallback if no way to get/generate search space
                search_space = []

        # Initialize num_evaluations at the beginning of the method.
        # It's reset per call to solve.
        self.solver_state['num_evaluations'] = 0

        # Ensure optimization_history is initialized if it's not already present
        if 'optimization_history' not in self.solver_state:
            self.solver_state['optimization_history'] = []

        # Nothing to search: record the failure and bail out.
        if not search_space:
            self.solver_state['last_solution'] = None
            self.solver_state['last_objective_value'] = float('-inf')
            self.solver_state['optimization_history'].append({
                'solution': None,
                'value': float('-inf'),
                'search_space_size': 0
            })
            return None

        is_batched = theta.ndim == 2
        batch_size = theta.shape[0] if is_batched else 1

        if is_batched:
            # Stores the best candidate tensor for each item in the batch
            best_solution_list: List[Optional[torch.Tensor]] = [None] * batch_size
            best_value_tensor = torch.full((batch_size,), float('-inf'), device=theta.device, dtype=theta.dtype)
        else:
            best_solution_single: Optional[torch.Tensor] = None
            best_value_scalar = float('-inf')

        for candidate_state_maybe_none in search_space:
            if candidate_state_maybe_none is None:
                continue
            # Ensure candidate is on the same device as theta and has the same dtype
            candidate = candidate_state_maybe_none.to(device=theta.device, dtype=theta.dtype)

            current_objective_value: Union[torch.Tensor, float]

            if is_batched:
                # theta is [B, D'], candidate is [D'] -> objective_value_batch is [B]
                current_objective_value = torch.mv(theta, candidate)
            else:
                # theta is [D'], candidate is [D'] -> objective_value_scalar is scalar
                current_objective_value = torch.dot(theta, candidate)

            if self.phi_network is not None:
                # candidate is [D'], phi_network expects [N, D']
                phi_input = candidate.unsqueeze(0) # [1, D']
                phi_val = self.phi_network(phi_input) # Output [1, 1] or [1]

                # Squeeze to make it a scalar or 1D tensor if it was [1,1] or [1]
                phi_val_squeezed = phi_val.squeeze()

                # Ensure phi_val_squeezed is a scalar tensor before adding
                if phi_val_squeezed.ndim > 0 and phi_val_squeezed.numel() == 1:
                    phi_val_squeezed = phi_val_squeezed.squeeze()

                # Add scalar phi_val to objective_value (scalar or [B] tensor)
                # This works due to broadcasting if current_objective_value is [B]
                current_objective_value = current_objective_value + phi_val_squeezed # Ensure it's an assignment

            if is_batched:
                # current_objective_value is a tensor of shape [B]
                # best_value_tensor is a tensor of shape [B]
                # Strict '>' keeps the earliest candidate on ties.
                improved_mask = current_objective_value > best_value_tensor
                best_value_tensor[improved_mask] = current_objective_value[improved_mask]
                for i in range(batch_size):
                    if improved_mask[i]:
                        best_solution_list[i] = candidate.clone()
            else: # not batched, current_objective_value is a scalar float or 0-dim tensor
                obj_val_float: float
                if isinstance(current_objective_value, torch.Tensor): # Ensure it's a Python float for comparison
                    obj_val_float = current_objective_value.item()
                else:
                    # This case should ideally not happen if operations are tensor-based
                    obj_val_float = float(current_objective_value)

                if obj_val_float > best_value_scalar:
                    best_value_scalar = obj_val_float
                    best_solution_single = candidate.clone()

            self.solver_state['num_evaluations'] += 1

        # Update solver_state and determine return value
        final_return_solution: Optional[torch.Tensor] = None

        if is_batched:
            # For solver_state, use the first item of the batch as a compromise
            actual_best_solution_for_state = best_solution_list[0] if best_solution_list and best_solution_list[0] is not None else None
            actual_best_value_for_state = float(best_value_tensor[0].item()) if best_value_tensor.numel() > 0 else float('-inf')

            self.solver_state['last_solution'] = actual_best_solution_for_state.clone() if actual_best_solution_for_state is not None else None
            self.solver_state['last_objective_value'] = actual_best_value_for_state
            self.solver_state['optimization_history'].append({
                'solution': actual_best_solution_for_state.clone().cpu().numpy() if actual_best_solution_for_state is not None else None,
                'value': actual_best_value_for_state,
                'search_space_size': len(search_space)
            })

            # For return value: if any item in batch failed to find a solution, return None. Otherwise, stack.
            if any(s is None for s in best_solution_list):
                final_return_solution = None
            else:
                # All solutions are tensors, safe to stack.
                # Need to cast best_solution_list to List[torch.Tensor] for stack
                final_return_solution = torch.stack([s for s in best_solution_list if s is not None])

        else: # not batched
            self.solver_state['last_solution'] = best_solution_single.clone() if best_solution_single is not None else None
            self.solver_state['last_objective_value'] = float(best_value_scalar)
            self.solver_state['optimization_history'].append({
                'solution': best_solution_single.clone().cpu().numpy() if best_solution_single is not None else None,
                'value': float(best_value_scalar),
                'search_space_size': len(search_space)
            })
            final_return_solution = best_solution_single

        return final_return_solution

    def get_solver_state(self) -> Dict[str, Any]:
        """Return a shallow copy of ``solver_state`` (nested objects are shared)."""
        return self.solver_state.copy()

    def set_solver_parameters(self, params: Dict[str, Any]) -> None:
        """Apply solver options; only a truthy ``'reset_history'`` is honored.

        It clears ``optimization_history`` and zeroes ``num_evaluations``.
        """
        if 'reset_history' in params and params['reset_history']:
            self.solver_state['optimization_history'] = []
            self.solver_state['num_evaluations'] = 0


class CorrectionRatioMCMC(nn.Module):
    """Metropolis-Hastings style sampler over a discrete output space.

    Chains move between states proposed by the neighborhood strategies of
    ``output_space``. A proposal ``y'`` made from ``y`` is accepted with
    probability ``min(1, c * exp((s(y') - s(y)) / T))`` where
    ``s(y) = <theta, y> + phi(y)``, ``T`` is the scheduled temperature and
    ``c = q(y | y') / q(y' | y)`` is the proposal correction ratio.

    When an ``exact_oracle`` is supplied, ``LargeNeighborhoodSearchMCMC``
    instances can interleave large-neighborhood-search (LNS) proposals.
    """

    def __init__(self,
                 output_space: DiscreteOutputSpace,
                 config: MCMCConfig,
                 phi_network: Optional[nn.Module] = None,
                 exact_oracle: Optional[ExactOptimizationOracle] = None):
        """Store collaborators and build the temperature schedule.

        Args:
            output_space: Provides random states, neighbors and proposal
                probabilities.
            config: Sampler hyper-parameters (chain count/length, burn-in,
                temperature schedule, initialization method, ...).
            phi_network: Optional module adding a learned term to the score.
            exact_oracle: Optional solver used for LNS proposals.
        """
        super().__init__()
        self.output_space = output_space
        self.config = config
        self.phi_network = phi_network
        self.exact_oracle = exact_oracle
        self.temp_scheduler = self._create_temperature_scheduler()
        # One slot per chain; filled with the last state of each chain so that
        # "persistent" initialization can resume from it.
        self.persistent_states: Optional[List[Optional[torch.Tensor]]] = None
        self.correction_ratios_log: List[float] = []
        self.solver_diagnostics_log: List[Dict[str, Any]] = []
        self.step_count = 0

    def _create_temperature_scheduler(self) -> Callable[[int], float]:
        """Return the step -> temperature callable selected by the config.

        Unknown schedule names fall back to a constant schedule at
        ``config.final_temp``.
        """
        if self.config.temperature_schedule == "geometric":
            return TemperatureScheduler.geometric(self.config.initial_temp, self.config.decay_rate, self.config.final_temp)
        elif self.config.temperature_schedule == "linear":
            return TemperatureScheduler.linear(self.config.initial_temp, self.config.final_temp, self.config.chain_length)
        else:
            return TemperatureScheduler.constant(self.config.final_temp)

    def phi_function(self, state: torch.Tensor) -> torch.Tensor:
        """Evaluate the optional learned score term for ``state``.

        Returns a zero scalar on ``state.device`` when no ``phi_network`` was
        given. A 1D state is given a leading batch axis before the network is
        called, and the network output is squeezed.
        """
        if self.phi_network is None:
            return torch.tensor(0.0, device=state.device)
        # The previous rank test compared state.dim() with
        # dimension.bit_length(), which has no relation to tensor rank; only
        # the 1D -> (1, D) promotion is meaningful.
        state_for_phi = state.unsqueeze(0) if state.dim() == 1 else state
        return self.phi_network(state_for_phi).squeeze()

    def compute_correction_ratio(self, current: torch.Tensor, proposal: torch.Tensor, theta: torch.Tensor,
                                 strategy_name: str, strategy_params: Dict[str, Any]) -> float:
        """Return the proposal correction ``q(current | proposal) / q(proposal | current)``.

        Args:
            current: State the chain is in.
            proposal: Candidate next state.
            theta: Unused here; kept for interface compatibility.
            strategy_name: Neighborhood strategy that generated ``proposal``.
            strategy_params: Keyword arguments that were passed to the strategy.

        Returns:
            ``1.0`` for LNS proposals, ``0.0`` when either direction of the
            move has zero proposal probability, otherwise the ratio.
        """
        if strategy_name == "LNS":
            # For LNS, the proposal mechanism is different and typically involves an oracle.
            # A common simplification is to assume the correction ratio is 1.0,
            # effectively treating the LNS proposal as symmetric for the correction term.
            # A more rigorous treatment would require defining q_LNS(y'|y) and q_LNS(y|y')
            # based on the LNS oracle's behavior.
            return 1.0

        q_proposal_given_current = self.output_space.get_proposal_prob(current, proposal, strategy_name, **strategy_params)
        q_current_given_proposal = self.output_space.get_proposal_prob(proposal, current, strategy_name, **strategy_params)

        # A move that cannot be proposed, or cannot be reversed by the same
        # proposal mechanism, gets a zero correction so it is never accepted.
        if q_proposal_given_current == 0 or q_current_given_proposal == 0:
            return 0.0

        return q_current_given_proposal / q_proposal_given_current

    def enhanced_acceptance_ratio(self, current: torch.Tensor, proposal: torch.Tensor, theta: torch.Tensor,
                                temperature: float, strategy_name: str, strategy_params: Dict[str, Any]) -> float:
        """Return the (unclamped) acceptance term for moving to ``proposal``.

        The term is ``correction * exp((s(proposal) - s(current)) / temperature)``
        with ``s(y) = <theta, y> + phi(y)``; callers clamp it with
        ``min(1.0, ...)``. Every computed correction ratio is appended to
        ``self.correction_ratios_log``.
        """
        current_energy = torch.dot(theta, current.squeeze()) + self.phi_function(current.squeeze())
        proposal_energy = torch.dot(theta, proposal.squeeze()) + self.phi_function(proposal.squeeze())
        energy_diff = proposal_energy - current_energy

        correction_factor = self.compute_correction_ratio(current, proposal, theta, strategy_name, strategy_params)
        self.correction_ratios_log.append(correction_factor)

        if correction_factor < 0:
            correction_factor = 0.0
        if temperature <= 1e-9:
            # Zero-temperature limit of correction * exp(energy_diff / T):
            # the term diverges for non-negative energy_diff and vanishes
            # otherwise. (The sign test used to be inverted relative to the
            # finite-temperature formula below.)
            return float('inf') if energy_diff >= 0 and correction_factor > 1e-9 else 0.0

        exp_term = torch.exp(energy_diff / temperature)
        acceptance_term_pk = float(correction_factor * exp_term)
        return max(0.0, acceptance_term_pk)

    def large_neighborhood_search_step(self, current_state: torch.Tensor, theta: torch.Tensor,
                                     neighborhood_size: int = 5) -> Optional[torch.Tensor]:
        """Propose a state by searching a larger neighborhood of ``current_state``.

        Without an oracle this degrades to picking one random neighbor from a
        randomly chosen strategy. With an oracle, candidates are gathered from
        the available strategies, topped up with random states until
        ``neighborhood_size`` candidates exist, and the oracle picks among them.

        Returns:
            The chosen state, or ``current_state`` when nothing better is
            available.
        """
        if self.exact_oracle is None:
            available_strategies = self.output_space.get_available_neighborhood_strategies(current_state)
            if not available_strategies:
                return current_state
            chosen_strategy = random.choice(available_strategies)
            if 'radius' in chosen_strategy:
                s_params = {'radius': 1}
            elif 'flip' in chosen_strategy:
                s_params = {'num_flips': 1}
            else:
                s_params = {}

            neighbors = self.output_space.get_neighbors(current_state, chosen_strategy, **s_params)
            return random.choice(neighbors) if neighbors else current_state

        large_neighborhood: List[torch.Tensor] = []
        strat_params = {'num_flips': neighborhood_size // 2, 'num_swaps': neighborhood_size // 2}
        for strat_name in self.output_space.get_available_neighborhood_strategies(current_state):
            large_neighborhood.extend(self.output_space.get_neighbors(current_state, strat_name, **strat_params))
            if len(large_neighborhood) >= neighborhood_size:
                break

        # Top up with distinct random states.
        # NOTE(review): this loop assumes the output space has at least
        # `neighborhood_size` distinct states; otherwise it would not terminate.
        while len(large_neighborhood) < neighborhood_size:
            random_s = self.output_space.random_state()
            if not any(torch.allclose(random_s, existing) for existing in large_neighborhood):
                large_neighborhood.append(random_s)

        # Cap the candidate list handed to the oracle at twice the requested size.
        large_neighborhood = large_neighborhood[:min(len(large_neighborhood), neighborhood_size * 2)]

        best_solution = self.exact_oracle.solve(theta, large_neighborhood)
        if self.exact_oracle.solver_state:
             self.solver_diagnostics_log.append(self.exact_oracle.get_solver_state())
        return best_solution if best_solution is not None else current_state

    def sample_chain_corrected(self, theta: torch.Tensor, chain_id: int = 0,
                               target_y: Optional[torch.Tensor] = None, #target_state was changed to target_y to avoid errors.
                               use_large_neighborhood_step_flag: bool = False
                               ) -> Tuple[List[torch.Tensor], Dict[str, float]]:
        """Run one chain for ``burn_in + chain_length`` steps.

        Args:
            theta: Score vector for one batch item (the callers in this class
                pass a 1D tensor).
            chain_id: Index of the chain; selects the persistent-state slot.
            target_y: State used as the starting point when
                ``config.initialization_method == "data_based"``.
            use_large_neighborhood_step_flag: Enables periodic LNS proposals
                (only honoured on ``LargeNeighborhoodSearchMCMC`` instances
                that have an oracle).

        Returns:
            ``(samples, stats)`` where ``samples`` holds one state per
            post-burn-in step and ``stats`` has ``acceptance_rate``,
            ``final_temperature`` and ``chain_length_collected``.
        """
        # theta is expected to be a 1D tensor for the current chain/batch item.
        if self.config.initialization_method == "persistent" and self.persistent_states is not None and \
           chain_id < len(self.persistent_states) and self.persistent_states[chain_id] is not None:
            current_state = self.persistent_states[chain_id].clone().to(theta.device)
        elif self.config.initialization_method == "data_based" and target_y is not None:
            current_state = target_y.clone().to(theta.device)
        else:
            current_state = self.output_space.random_state().to(theta.device)

        samples = []
        acceptances = 0
        total_steps_for_chain = self.config.chain_length + self.config.burn_in

        # Defined up front so the stats are valid even when the loop runs zero times.
        temperature = self.config.initial_temp

        for step_idx in range(total_steps_for_chain):
            temperature = self.temp_scheduler(step_idx)
            proposal = None
            chosen_strategy_name = "unknown"
            strategy_params: Dict[str, Any] = {}

            perform_lns_this_iteration = False
            if use_large_neighborhood_step_flag and isinstance(self, LargeNeighborhoodSearchMCMC) and self.exact_oracle:
                lns_freq = getattr(self, 'lns_frequency', 10)
                if lns_freq > 0 and (step_idx + 1) % lns_freq == 0:
                    perform_lns_this_iteration = True

            if perform_lns_this_iteration:
                lns_hood_size = getattr(self, 'lns_neighborhood_size', 5)
                proposal = self.large_neighborhood_search_step(current_state, theta, lns_hood_size)
                chosen_strategy_name = "LNS"
                strategy_params = {'lns_generated': True}
            else:
                available_strategies = self.output_space.get_available_neighborhood_strategies(current_state)
                if available_strategies:
                    chosen_strategy_name = random.choice(available_strategies)

                    if "radius" in chosen_strategy_name:
                        strategy_params['radius'] = self.config.neighborhood_radius
                    elif "flip" in chosen_strategy_name:
                        strategy_params['num_flips'] = 1
                    elif "swap" in chosen_strategy_name:
                        strategy_params['num_swaps'] = 1
                    else:
                        strategy_params['radius'] = self.config.neighborhood_radius

                    neighbors = self.output_space.get_neighbors(current_state, chosen_strategy_name, **strategy_params)
                    if neighbors:
                        proposal = random.choice(neighbors)

            # When no proposal could be generated the chain simply stays put
            # for this step; the state is still recorded below.
            if proposal is not None:
                proposal = proposal.to(theta.device)
                acceptance_term_pk = self.enhanced_acceptance_ratio(
                    current_state, proposal, theta, temperature, chosen_strategy_name, strategy_params)

                if random.random() < min(1.0, acceptance_term_pk):
                    current_state = proposal
                    acceptances += 1

            if step_idx >= self.config.burn_in:
                samples.append(current_state.clone())

        if self.persistent_states is None or len(self.persistent_states) != self.config.num_chains:
             self.persistent_states = [None for _ in range(self.config.num_chains)] # Should match num_chains for indexing
        if chain_id < len(self.persistent_states): # Ensure chain_id is a valid index
            self.persistent_states[chain_id] = current_state.clone()
        else:
            # This case should ideally not be reached if chain_id is always < self.config.num_chains
            warnings.warn(f"chain_id {chain_id} is out of bounds for persistent_states (len {len(self.persistent_states)}). Skipping persistence update for this chain.")

        stats = {
            'acceptance_rate': acceptances / total_steps_for_chain if total_steps_for_chain > 0 else 0.0,
            'final_temperature': temperature,
            'chain_length_collected': len(samples)
        }
        return samples, stats

    def estimate_expectation_with_corrections(self, theta_batch: torch.Tensor, target_state_batch: Optional[torch.Tensor] = None,
                                              use_large_neighborhood: bool = False
                                              ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Estimate the expected state for every item of ``theta_batch``.

        For each batch item, ``config.num_chains`` chains are run and the mean
        of all collected samples is used as the expectation.

        Args:
            theta_batch: ``(batch_size, feature_dim)`` scores; a 1D tensor is
                treated as a batch of one.
            target_state_batch: Optional per-item target states. A target is
                only forwarded to the chains when its leading dimension
                matches the batch size (or it is 1D for a batch of one).
            use_large_neighborhood: Request LNS proposals; only effective on
                ``LargeNeighborhoodSearchMCMC`` instances.

        Returns:
            ``(expectations, stats)`` where ``expectations`` stacks one
            expectation per batch item and ``stats`` contains
            ``batch_item_stats``, ``overall_avg_acceptance_rate`` and
            ``total_samples_collected`` (plus ``error`` for an empty batch).
        """
        if theta_batch.ndim == 1: # If a single theta is passed, unsqueeze to make it a batch of 1
            theta_batch = theta_batch.unsqueeze(0)
            if target_state_batch is not None and target_state_batch.ndim == 1:
                target_state_batch = target_state_batch.unsqueeze(0)

        batch_size = theta_batch.shape[0]
        all_batch_expectations: List[torch.Tensor] = []
        all_batch_stats_collector: List[Dict[str, Any]] = []

        # persistent_states is a list of length num_chains. Each call to
        # sample_chain_corrected for a given chain_id uses/updates its slot,
        # so the slots are shared across the items of a batch.
        if self.persistent_states is None or len(self.persistent_states) != self.config.num_chains:
            self.persistent_states = [None for _ in range(self.config.num_chains)]

        is_lns_sampler = isinstance(self, LargeNeighborhoodSearchMCMC)

        for i in range(batch_size):
            current_theta_item = theta_batch[i]  # 1D: (feature_dim,)
            current_target_state_item = None
            if target_state_batch is not None:
                if target_state_batch.shape[0] == batch_size:
                    current_target_state_item = target_state_batch[i]
                elif batch_size == 1 and target_state_batch.ndim == 1:
                    current_target_state_item = target_state_batch
            # A target whose shape does not line up with the batch is ignored
            # (the chains then fall back to their other initialization).

            item_all_samples: List[torch.Tensor] = []
            item_all_chain_stats: List[Dict[str, float]] = []

            for chain_id in range(self.config.num_chains):
                samples, chain_stats_for_chain = self.sample_chain_corrected(
                    current_theta_item,
                    chain_id,
                    current_target_state_item,
                    use_large_neighborhood_step_flag=(use_large_neighborhood and is_lns_sampler)
                )
                item_all_samples.extend(samples)
                item_all_chain_stats.append(chain_stats_for_chain)

            if not item_all_samples:
                # Fallback for this specific batch item.
                item_expectation = torch.zeros(self.output_space.dimension, device=current_theta_item.device, dtype=current_theta_item.dtype)
                warnings.warn(f"No MCMC samples collected for batch item {i}. Returning zeros for this item.")
                item_stats_dict = {'error': f'No samples collected for batch item {i}', 'num_samples': 0, 'avg_acceptance_rate': 0.0, 'sample_entropy': 0.0, 'chain_stats': []}
            else:
                stacked_samples = torch.stack(item_all_samples)
                item_expectation = torch.mean(stacked_samples.float(), dim=0)
                avg_acceptance_item = np.mean([s['acceptance_rate'] for s in item_all_chain_stats if 'acceptance_rate' in s]) if item_all_chain_stats else 0.0
                sample_entropy_item = compute_normalized_entropy(stacked_samples.detach().cpu())
                item_stats_dict = {
                    'num_samples': len(item_all_samples),
                    'avg_acceptance_rate': float(avg_acceptance_item),
                    'sample_entropy': sample_entropy_item.tolist() if isinstance(sample_entropy_item, torch.Tensor) else float(sample_entropy_item),
                    'chain_stats': item_all_chain_stats
                }

            all_batch_expectations.append(item_expectation)
            all_batch_stats_collector.append(item_stats_dict)

        if not all_batch_expectations: # Handles batch_size = 0
            fallback_dim = theta_batch.shape[1] if theta_batch.ndim == 2 and theta_batch.shape[0] == 0 else self.output_space.dimension
            final_expectation = torch.empty(0, fallback_dim, device=theta_batch.device, dtype=theta_batch.dtype)
            combined_summary_stats = {'error': 'No batch items processed or all failed', 'batch_item_stats': [], 'overall_avg_acceptance_rate': 0.0, 'total_samples_collected': 0}
            warnings.warn(f"No expectations computed for any batch item. Returning empty tensor.")
            return final_expectation, combined_summary_stats

        final_expectation = torch.stack(all_batch_expectations) # Stack to get (B, D)

        # Aggregate statistics; items without samples do not contribute to the mean rate.
        overall_avg_acceptance = 0.0
        rates = [s['avg_acceptance_rate'] for s in all_batch_stats_collector if s.get('num_samples', 0) > 0]
        if rates:
            overall_avg_acceptance = np.mean(rates)
        total_samples = sum(s.get('num_samples', 0) for s in all_batch_stats_collector)

        combined_summary_stats = {
            'batch_item_stats': all_batch_stats_collector, # Detailed stats per item
            'overall_avg_acceptance_rate': float(overall_avg_acceptance),
            'total_samples_collected': total_samples
        }
        return final_expectation, combined_summary_stats

    def forward(self, theta: torch.Tensor, target: torch.Tensor, use_large_neighborhood: bool = False
               ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Module entry point; delegates to ``estimate_expectation_with_corrections``."""
        # The estimator's parameter is named `target_state_batch`; passing the
        # target as `target_y=` raised a TypeError on every call.
        expectation, stats = self.estimate_expectation_with_corrections(
            theta, target_state_batch=target, use_large_neighborhood=use_large_neighborhood)
        return expectation, stats


class LargeNeighborhoodSearchMCMC(CorrectionRatioMCMC):
    """``CorrectionRatioMCMC`` that always carries an exact oracle for LNS proposals.

    The oracle is built from the same ``output_space`` and ``phi_network`` that
    the sampler uses. ``lns_frequency`` and ``lns_neighborhood_size`` are stored
    as attributes for the sampling loop to read.
    """

    def __init__(self,
                 output_space: DiscreteOutputSpace,
                 config: MCMCConfig,
                 phi_network: Optional[nn.Module] = None,
                 lns_frequency: int = 10,
                 lns_neighborhood_size: int = 20):
        super().__init__(
            output_space,
            config,
            phi_network,
            exact_oracle=ExactOptimizationOracle(output_space, phi_network),
        )
        self.lns_frequency = lns_frequency
        self.lns_neighborhood_size = lns_neighborhood_size


class EnhancedCTMFenchelYoungIntegration(nn.Module):
    """Couples an MLP score network with an MCMC sampler under a Fenchel-Young style loss.

    ``thought_network`` maps inputs to a score vector ``theta`` over the output
    space; ``phi_network`` is handed to the sampler as its learned score term.
    The sampler is a ``LargeNeighborhoodSearchMCMC`` when
    ``use_large_neighborhood_search`` is set, otherwise a plain
    ``CorrectionRatioMCMC``.
    """

    def __init__(self,
                 input_dim: int,
                 output_space: DiscreteOutputSpace,
                 mcmc_config: MCMCConfig,
                 hidden_dim: int = 256,
                 num_thought_steps: int = 5,
                 use_large_neighborhood_search: bool = True,
                 lns_frequency: int = 10,
                 lns_neighborhood_size: int = 20):
        super().__init__()
        out_dim = output_space.dimension
        self.output_space_dim = out_dim

        # Three-layer MLP: input -> hidden -> hidden -> output-space scores.
        thought_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.thought_network = nn.Sequential(*thought_layers)

        # Two-layer MLP producing a single scalar per state.
        phi_hidden = hidden_dim // 2
        phi_layers = [
            nn.Linear(out_dim, phi_hidden),
            nn.ReLU(),
            nn.Linear(phi_hidden, 1),
        ]
        self.phi_network = nn.Sequential(*phi_layers)

        common_sampler_kwargs = {
            'output_space': output_space,
            'config': mcmc_config,
            'phi_network': self.phi_network,
        }
        if not use_large_neighborhood_search:
            self.mcmc_sampler: Union[LargeNeighborhoodSearchMCMC, CorrectionRatioMCMC] = CorrectionRatioMCMC(
                **common_sampler_kwargs
            )
        else:
            self.mcmc_sampler = LargeNeighborhoodSearchMCMC(
                lns_frequency=lns_frequency,
                lns_neighborhood_size=lns_neighborhood_size,
                **common_sampler_kwargs,
            )
        self.num_thought_steps = num_thought_steps

    def forward(self, x: torch.Tensor, target_y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """Return ``(loss, expectation_y, mcmc_stats)`` for inputs ``x`` and targets ``target_y``.

        The loss is ``sum(theta * (E[y] - target_y))`` with the sampled
        expectation detached, so its gradient with respect to ``theta`` is
        ``E[y] - target_y``.
        """
        sampler = self.mcmc_sampler
        theta = self.thought_network(x)

        wants_lns = isinstance(sampler, LargeNeighborhoodSearchMCMC)
        expectation_y, mcmc_stats = sampler.estimate_expectation_with_corrections(
            theta, target_y, use_large_neighborhood=wants_lns
        )

        # Detach the expectation: gradients flow only through theta.
        residual = expectation_y.detach() - target_y
        loss = torch.sum(theta * residual)

        return loss, expectation_y, mcmc_stats

## 8. Instantiation and Configuration
This section defines necessary global configuration variables and then instantiates the core components.

In [None]:
# --- Configuration Variables ---
# Global constants describing the ARC grids and the MCMC sampler settings used
# by the cells below.
# NOTE: You must provide the paths to your ARC dataset directories.
ARC_TRAIN_DIR = "../data/training" # <<< IMPORTANT: SET THIS PATH
ARC_EVAL_DIR = "../data/evaluation"   # <<< IMPORTANT: SET THIS PATH

MAX_GRID_SIZE = (30, 30)  # Shape every grid is padded to (used as the array shape by pad_grid)
NUM_ARC_SYMBOLS = 10
PADDING_VALUE = -1 # A value not in 0-9 to be ignored by the loss function
# NOTE(review): padded grids are later serialized with dtype uint8, which
# cannot hold -1 as-is -- confirm the intended byte value for padded cells.
MAX_DEMO_PAIRS = 5 # Max number of demonstration pairs to consider for context

# Configuration for ARC-AGI-2 Training (shared constants)
# Number of cells in one padded grid once flattened.
ARC_INPUT_FLAT_DIM = MAX_GRID_SIZE[0] * MAX_GRID_SIZE[1]

# MCMC Configuration for ARC
MCMC_OUTPUT_SPACE_DIM = ARC_INPUT_FLAT_DIM
# Each chain runs burn_in + chain_length steps and keeps one sample per
# post-burn-in step; num_chains chains are run per batch item.
# "geometric" selects TemperatureScheduler.geometric(initial_temp, decay_rate, final_temp).
MCMC_CONFIG_ARC = MCMCConfig(
    num_chains=3, 
    chain_length=20,
    burn_in=5,
    initial_temp=5.0,
    final_temp=1.0,
    temperature_schedule="geometric",
    decay_rate=0.95,
    neighborhood_radius=1
)
# Master switch read by the instantiation cell to build (or skip) the MCMC module.
ENABLE_CTM_MCMC_INTEGRATION_FOR_ARC = True

# Echo the effective settings so they are visible in the notebook output.
print(f"Using MAX_GRID_SIZE: {MAX_GRID_SIZE}")
print(f"Using NUM_ARC_SYMBOLS: {NUM_ARC_SYMBOLS}")
print(f"Using ARC_INPUT_FLAT_DIM: {ARC_INPUT_FLAT_DIM}")
print(f"Using MCMC_CONFIG_ARC: chains={MCMC_CONFIG_ARC.num_chains}, length={MCMC_CONFIG_ARC.chain_length}")

Using MAX_GRID_SIZE: (30, 30)
Using NUM_ARC_SYMBOLS: 10
Using ARC_INPUT_FLAT_DIM: 900
Using MCMC_CONFIG_ARC: chains=3, length=20


In [None]:
# --- Instantiation ---
# Output space over flattened ARC grids, shared with the MCMC sampler.
arc_grid_output_space = ARCGridOutputSpace(
    dimension=ARC_INPUT_FLAT_DIM,
    grid_shape=MAX_GRID_SIZE,
    num_symbols=NUM_ARC_SYMBOLS,
)

# The integration module consumes encoder features of the flattened-grid size.
ctm_encoder_output_dim = ARC_INPUT_FLAT_DIM

# Stays None when the integration is switched off.
enhanced_ctm_mcmc = None
if not ENABLE_CTM_MCMC_INTEGRATION_FOR_ARC:
    print("\nMCMC Integration is disabled for ARC.")
    ENHANCED_MCMC_AVAILABLE = False
else:
    enhanced_ctm_mcmc = EnhancedCTMFenchelYoungIntegration(
        input_dim=ctm_encoder_output_dim,
        output_space=arc_grid_output_space,
        mcmc_config=MCMC_CONFIG_ARC,
        use_large_neighborhood_search=True,
        lns_frequency=5,
        lns_neighborhood_size=10,
    )

    # Report what was built.
    print("\nEnhancedCTMFenchelYoungIntegration module initialized.")
    print(f"  Output space dimension: {enhanced_ctm_mcmc.output_space_dim}")
    if not isinstance(enhanced_ctm_mcmc.mcmc_sampler, LargeNeighborhoodSearchMCMC):
        print("  MCMC sampler type: CorrectionRatioMCMC")
    else:
        print("  MCMC sampler type: LargeNeighborhoodSearchMCMC")
        print(f"    LNS Frequency: {enhanced_ctm_mcmc.mcmc_sampler.lns_frequency}")
        print(f"    LNS Neighborhood Size: {enhanced_ctm_mcmc.mcmc_sampler.lns_neighborhood_size}")
    ENHANCED_MCMC_AVAILABLE = True
    print(f"\nENHANCED_MCMC_AVAILABLE set to: {ENHANCED_MCMC_AVAILABLE}")


EnhancedCTMFenchelYoungIntegration module initialized.
  Output space dimension: 900
  MCMC sampler type: LargeNeighborhoodSearchMCMC
    LNS Frequency: 5
    LNS Neighborhood Size: 10

ENHANCED_MCMC_AVAILABLE set to: True


# --- ARC Dataset and Dataloader Logic ---

All module paths should be resolvable at this point, because modules that are not already on the import path are added to it automatically.

# Note: `pad_grid` is called by `NewCustomARCGridDataset` but was not defined in the original source.
# The dataset cannot function without it, so a definition must be available before the dataset is used.
# A placeholder implementation is provided in the cell below; replace it if your pipeline needs different padding behavior.

In [None]:
# --- Context: 2D Grid Padding (from original code) ---
# This function handles padding at the 2D grid level, before serialization.
def pad_grid(grid_list, max_dims, pad_value):
    """Pad a 2D grid to ``max_dims``, filling unused cells with ``pad_value``.

    The grid is placed in the top-left corner of the result.

    Args:
        grid_list: 2D grid (nested list or array) of integer cell values.
            An empty grid (``[]`` or ``[[]]``) yields an all-padding result.
        max_dims: ``(rows, cols)`` shape of the returned array.
        pad_value: Integer written to every cell not covered by the grid.

    Returns:
        np.ndarray: ``int32`` array of shape ``max_dims``.

    Raises:
        ValueError: If the grid is not 2D or does not fit into ``max_dims``.
    """
    padded_grid = np.full(max_dims, pad_value, dtype=np.int32)
    grid_np = np.array(grid_list, dtype=np.int32)

    # An empty grid has shape (0,) and cannot be unpacked into (h, w);
    # there is nothing to copy, so the all-padding array is the answer.
    if grid_np.size == 0:
        return padded_grid

    if grid_np.ndim != 2:
        raise ValueError(f"pad_grid expects a 2D grid, got an array with shape {grid_np.shape}")

    h, w = grid_np.shape
    if h > padded_grid.shape[0] or w > padded_grid.shape[1]:
        raise ValueError(f"Grid of shape {grid_np.shape} does not fit into max_dims {tuple(padded_grid.shape)}")

    padded_grid[:h, :w] = grid_np
    return padded_grid

# --- Fix: Byte Sequence Padding for the Model --- #
# According to the model explanation, the key step is to pad the *serialized byte sequence*
# to `config.max_sequence_length`. The function below implements this logic.

# Define the model's expected input dimension from the configuration.
MAX_SEQUENCE_LENGTH = 8192
PADDING_BYTE_VALUE = 0


def serialize_and_pad_grid(grid, max_len=MAX_SEQUENCE_LENGTH, pad_value=PADDING_BYTE_VALUE):
    """
    Serializes a grid into a byte sequence and pads it to a fixed length.

    The grid is converted to ``uint8``, flattened in row-major order and
    emitted as raw bytes. The result is then truncated or right-padded so
    that it is exactly ``max_len`` bytes long.

    Args:
        grid (list or np.ndarray): The input ARC grid.
        max_len (int): The target length of the byte sequence
                       (defaults to ``MAX_SEQUENCE_LENGTH``).
        pad_value (int): The byte value to use for padding (0-255).

    Returns:
        bytes: The byte sequence of length `max_len`.

    Raises:
        ValueError: If ``max_len`` is negative, or ``pad_value`` is not a
            valid byte value.
    """
    if max_len < 0:
        # A negative length would otherwise silently slice bytes off the end.
        raise ValueError(f"max_len must be non-negative, got {max_len}")

    # NOTE(review): values outside 0-255 (e.g. a -1 grid padding value in an
    # integer array) do not survive the uint8 conversion unchanged -- confirm
    # that callers only pass byte-range values or rely on the wrapped value.
    # tobytes() emits the elements in row-major (C) order, i.e. flattened.
    byte_sequence = np.array(grid, dtype=np.uint8).tobytes()

    padding_len = max_len - len(byte_sequence)
    if padding_len < 0:
        # Too long: keep the leading max_len bytes.
        return byte_sequence[:max_len]

    # Repeat a single validated pad byte instead of building a Python list of
    # padding_len integers on every call.
    return byte_sequence + bytes([pad_value]) * padding_len

class NewCustomARCGridDataset(Dataset):
    """Dataset of ARC tasks loaded from the ``*.json`` files of a directory.

    Every item is one task: a dict with ``'train'`` and ``'test'`` lists of
    padded input/output grid tensors plus the task's file name as ``'id'``.
    Files that cannot be read or parsed are reported and skipped.
    """

    def __init__(self, data_dir, max_grid_size=MAX_GRID_SIZE, padding_value=PADDING_VALUE):
        """Load all parseable task files found directly inside ``data_dir``.

        Args:
            data_dir: Directory searched (non-recursively) for ``*.json`` files.
            max_grid_size: Shape every grid is padded to.
            padding_value: Cell value used for padding.
        """
        self.data_dir = data_dir
        # glob returns files in arbitrary order; sort so that dataset indices
        # are reproducible across runs and machines.
        self.task_files = sorted(glob.glob(os.path.join(data_dir, "*.json")))
        self.max_grid_size = max_grid_size
        self.padding_value = padding_value
        self.tasks = []
        # File name of each successfully loaded task, index-aligned with
        # self.tasks (self.task_files also contains files that failed to load).
        self._task_names = []
        print(f"NewCustomARCGridDataset: Looking for tasks in: {data_dir}")
        if not self.task_files:
            print(f"NewCustomARCGridDataset Warning: No JSON files found in {data_dir}. Dataset will be empty.")
        for task_file in self.task_files:
            try:
                with open(task_file, 'r', encoding='utf-8') as f:
                    task = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and decoding errors.
                print(f"NewCustomARCGridDataset Warning: Could not load or parse {task_file}: {e}")
                continue
            self.tasks.append(task)
            self._task_names.append(os.path.basename(task_file))
        if not self.tasks:
            print(f"NewCustomARCGridDataset Warning: No tasks successfully loaded from {data_dir}.")
        else:
            print(f"NewCustomARCGridDataset: Loaded {len(self.tasks)} ARC tasks from {data_dir}.")

    def __len__(self):
        """Return the number of successfully loaded tasks."""
        return len(self.tasks)

    @staticmethod
    def _grid_dims(grid_list):
        """Return ``(rows, cols)`` of a nested-list grid, or ``(0, 0)`` if it is empty."""
        if grid_list and grid_list[0]:
            return (len(grid_list), len(grid_list[0]))
        return (0, 0)

    def _pad(self, grid_list):
        """Pad ``grid_list`` to ``max_grid_size``; an empty grid becomes all padding."""
        if not grid_list or not grid_list[0]:
            # e.g. a pair without an 'output' entry: nothing to copy.
            return np.full(self.max_grid_size, self.padding_value, dtype=np.int32)
        return pad_grid(grid_list, self.max_grid_size, self.padding_value)

    def __getitem__(self, idx):
        """Return task ``idx`` with all grids padded and converted to long tensors.

        Returns:
            dict: ``{'train': [...], 'test': [...], 'id': <file name>}`` where
            each list element holds ``'input'``, ``'output'``,
            ``'original_input_dims'`` and ``'original_output_dims'``.
        """
        task_data = self.tasks[idx]
        task_id = self._task_names[idx] if idx < len(self._task_names) else 'unknown_task'
        processed_task = {'train': [], 'test': [], 'id': task_id}

        for pair_type in ['train', 'test']:
            for item in task_data.get(pair_type, []):
                input_grid_list = item.get('input', [])
                output_grid_list = item.get('output', [])

                processed_task[pair_type].append({
                    'input': torch.from_numpy(self._pad(input_grid_list)).long(),
                    'output': torch.from_numpy(self._pad(output_grid_list)).long(),
                    'original_input_dims': self._grid_dims(input_grid_list),
                    'original_output_dims': self._grid_dims(output_grid_list)
                })
        return processed_task

def collate_fn_new_custom_arc(batch_of_tasks):
    """Collate ARC tasks into byte-sequence tensors, one row per training pair.

    Only the ``'train'`` pairs of each task are used. Entries of the batch that
    are not dicts, and pairs lacking ``'input'`` or ``'output'``, are skipped.

    Returns:
        dict with
        ``'input_byte_sequences'`` (uint8, one serialized input grid per row),
        ``'target_byte_sequences_for_diffusion'`` (uint8, serialized outputs) and
        ``'original_target_grids_for_ce_loss'`` (long, flattened output grids).
        All three are empty (zero rows) when no usable pair was found.
    """
    def _grid_to_byte_tensor(grid_tensor):
        # Serialize an already padded 2D LongTensor into a fixed-length uint8 row.
        raw = serialize_and_pad_grid(grid_tensor.numpy(), max_len=MAX_SEQUENCE_LENGTH, pad_value=PADDING_BYTE_VALUE)
        return torch.tensor(list(raw), dtype=torch.uint8)

    input_rows, diffusion_target_rows, ce_target_rows = [], [], []

    for task in batch_of_tasks:
        if not isinstance(task, dict):
            continue

        for pair in task.get('train', []):
            is_usable = isinstance(pair, dict) and 'input' in pair and 'output' in pair
            if not is_usable:
                continue

            input_rows.append(_grid_to_byte_tensor(pair['input']))
            diffusion_target_rows.append(_grid_to_byte_tensor(pair['output']))
            # The cross-entropy target keeps the original cell values, flattened.
            ce_target_rows.append(pair['output'].view(-1))

    if input_rows:
        return {
            'input_byte_sequences': torch.stack(input_rows),
            'target_byte_sequences_for_diffusion': torch.stack(diffusion_target_rows),
            'original_target_grids_for_ce_loss': torch.stack(ce_target_rows),
        }

    # Nothing usable in the batch: return correctly typed, zero-row tensors.
    return {
        'input_byte_sequences': torch.empty(0, MAX_SEQUENCE_LENGTH, dtype=torch.uint8),
        'target_byte_sequences_for_diffusion': torch.empty(0, MAX_SEQUENCE_LENGTH, dtype=torch.uint8),
        'original_target_grids_for_ce_loss': torch.empty(0, ARC_INPUT_FLAT_DIM, dtype=torch.long),
    }

# --- ARC Training Setup ---
# One logit per symbol for every cell of the flattened grid.
ARC_OUTPUT_HEAD_DIM = ARC_INPUT_FLAT_DIM * NUM_ARC_SYMBOLS
ARC_TASK_ID = 3
print(f"ARC Output Head Dim: {ARC_OUTPUT_HEAD_DIM}")

# Placeholders; the training components are assigned further below.
ctm_model_arc = arc_output_head = optimizer_arc = ctm_mcmc_integration_arc = accelerator_arc = None

print(
    "\n-----------------------------------------------------------------------------",
    "Initializing Configuration and Model for ARC with EnhancedCTMDiffusion",
    "-----------------------------------------------------------------------------",
    sep="\n",
)

'''
You do not need to add any of the variables again from the ctm_Diffusion_NEWNEW.py file to your config_arc_diffusion in the Arc_AGI_2_Final.ipynb file. All the parameters you listed are already explicitly defined when config_arc_diffusion is created (between lines 2200 and 2378 approximately).

The EnhancedCTMConfig class in ctm_Diffusion_NEWNEW.py provides default values for its fields. When you create an instance like config_arc_diffusion, any parameters you explicitly set will override these defaults. Since all the parameters in your list are already set in your notebook, those are the values that will be used for training.

For example:

attention_type is set to "subquadratic" on line 2260.
positional_embedding_type is set to 'multi-learnable-fourier' on line 2276.
enable_pipeline_parallelism is set to True on line 2288.
And so on for all the other parameters you mentioned.
If you wish to change any of these settings, you should modify their values directly in the existing config_arc_diffusion definition within your Arc_AGI_2_Final.ipynb file.
'''

# Define EnhancedCTMConfig for ARC with EnhancedCTMDiffusion
# Assuming EnhancedCTMConfig is a defined class and MAX_SEQUENCE_LENGTH is a defined variable
# For example:
# from your_model_library import EnhancedCTMConfig
# MAX_SEQUENCE_LENGTH = 8192

# From contineous-thought-machines/models/constants.py
# Accepted names for the neuron-selection strategy, grouped by family.
# NOTE(review): presumably checked against the corresponding config field --
# confirm where the validation happens.
VALID_NEURON_SELECT_TYPES = [
    'first-last', 'random', 'random-pairing',  # Legacy
    # Biologically-inspired types
    'bio_hebbian', 'bio_plasticity', 'bio_competitive', 'bio_homeostatic',
    'bio_evolutionary', 'bio_stdp', 'bio_criticality', 'bio_multi_objective',
    # Hybrid approaches
    'adaptive_random', 'performance_guided', 'task_aware'
]

# Accepted names for the positional-embedding type.
VALID_POSITIONAL_EMBEDDING_TYPES = [
    'learnable-fourier', 'multi-learnable-fourier',
    'custom-rotational', 'custom-rotational-1d'
]

# From contineous-thought-machines/models/ctm_Diffusion_NEWNEW.py
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple, Union, Any, List

@dataclass
class EnhancedCTMConfig: # Renamed from ContinualLearningConfig for consistency in the target file
    """Enhanced configuration for continual learning CTM-diffusion model,
    incorporating binary processing, multi-task learning, and advanced CTM features."""
    
    # Model architecture (General Transformer/Diffusion settings)
    d_model: int = 512  # Main model dimensionality
    n_heads: int = 8
    n_layers: int = 24
    max_sequence_length: int = 8192 # Max input sequence length in terms of bytes or patches
    dropout: float = 0.1
    
    # --- Byte Processing Options ---
    patch_embedding_dim: int = 256         # <<< NEW: Output embedding dimension per patch from patcher
    patch_encoder_cnn_channels: int = 64   # <<< NEW: Intermediate channels for CNN patch encoder

    # --- Dynamic Entropy Patching Options (Inspired by BLT paper) ---
    use_dynamic_entropy_patcher: bool = True # Flag to enable dynamic entropy-based patching
    entropy_patcher_threshold_type: str = "global"  # 'global' or 'relative_monotonic'
    entropy_patcher_global_threshold: float = 0.75 # Entropy threshold for 'global' type
    entropy_patcher_relative_threshold: float = 0.1 # Entropy diff threshold for 'relative_monotonic'
    entropy_patcher_min_patch_size: int = 4      # Minimum number of bytes in a dynamic patch
    entropy_patcher_max_patch_size: int = 128    # Maximum number of bytes in a dynamic patch (for CNN encoder)
    
    # --- Learnable Entropy Model Parameters (for _EntropyProxyModel) ---
    entropy_model_byte_vocab_size: int = 256
    entropy_model_embedding_dim: int = 64
    entropy_model_hidden_dim: int = 128
    entropy_model_num_layers: int = 1
    entropy_model_dropout: float = 0.1
    entropy_model_loss_weight: float = 0.1 # Weight for its auxiliary loss contribution
    # Note: These parameters are used if use_dynamic_entropy_patcher is True,
    # as LearnedBytePatcherEncoder now instantiates the learnable _EntropyProxyModel.
    
    # Fallback if not using learned_patch_encoder or dynamic_entropy_patcher
    byte_embedding_dim: int = 256
    multi_granularity: bool = False # Default to False if patcher is preferred
    # multi_granularity_output_dim is complex to predefine, MGP should expose its output dim.
    # For now, if multi_granularity is True AND use_learned_patch_encoder is False, this would be used.
    multi_granularity_output_dim: int = 256 # Placeholder if MGP is used.
    
    hierarchical_processing: bool = True # General flag, could apply to patcher or MGP
    
    # CTM Core Parameters (Specific to the OriginalCTMCore module)
    # These are prefixed with 'ctm_' to distinguish from general model params
    ctm_iterations: int = 5  # Original 'iterations'
    ctm_d_model: int = 512   # Original 'd_model' for CTM's internal latent space
    ctm_input_dim: int = 256 # Dimensionality of inputs to CTM (e.g., from byte embeddings or other features)
                             # This was 'd_input' in OriginalCTMCore if it took external features.
                             # If CTM processes outputs of byte_embedding, this might be byte_embedding_dim.
    ctm_heads: int = 8       # Attention heads within CTM
    ctm_n_synch_out: int = 64
    ctm_n_synch_action: int = 64
    ctm_synapse_depth: int = 3
    ctm_memory_length: int = 10
    ctm_deep_nlms: bool = True
    ctm_memory_hidden_dims: int = 2048
    ctm_do_layernorm_nlm: bool = False
    ctm_out_dims: int = 512  # Output dimension of CTM's own projector
    ctm_prediction_reshaper: list = field(default_factory=lambda: [-1])
    ctm_dropout: float = 0.1
    ctm_dropout_nlm: Optional[float] = None
    # Neuron selection strategy. Available options:
    # Legacy: 'first-last', 'random', 'random-pairing'
    # Biologically-inspired: 'bio_hebbian', 'bio_plasticity', 'bio_competitive',
    #                        'bio_homeostatic', 'bio_evolutionary', 'bio_stdp',
    #                        'bio_criticality', 'bio_multi_objective'
    # Hybrid: 'adaptive_random', 'performance_guided', 'task_aware'
    ctm_neuron_select_type: str = 'bio_multi_objective'
    ctm_n_random_pairing_self: int = 0
    
    # Diffusion Parameters
    diffusion_steps: int = 1000
    noise_schedule: str = "cosine" # e.g., "linear", "cosine"
    diffusion_beta_start: float = 0.0001
    diffusion_beta_end: float = 0.02
    diffusion_timesteps: int = 1000 # Number of timesteps for the diffusion process
    ctm_diffusion_coupling_strength: float = 0.8 # How CTM influences diffusion
    adaptive_scheduling: bool = True  # CTM-adaptive diffusion timestep scheduling
    iterative_refinement: bool = True # Iterative CTM-diffusion refinement for sampling
    

    
    # Training Efficiency
    mixed_precision: bool = True
    gradient_checkpointing: bool = True
    sparse_attention: bool = True  # Now implemented with BinarySparseAttention
    adaptive_depth: bool = False   # Defaulting to False, can be enabled if implemented
    
    # Sparse Attention Parameters
    sparse_attention_ratio: float = 0.1  # Keep only 10% of attention connections
    binary_pattern_size: int = 8  # Size of binary patterns to detect

    # Attention Mechanism Type
    attention_type: str = "subquadratic"  # Options: "standard", "binary_sparse", "subquadratic"
    
    # Subquadratic Attention Parameters (if attention_type is "subquadratic")
    subquadratic_attn_epsilon: float = 1e-6
    subquadratic_attn_poly_degree: int = 5
    attention_qkv_bias: bool = True # General QKV bias for attention mechanisms like Subquadratic or standard MHA
    # attn_drop and proj_drop for subquadratic_attn will be mapped from ctm_dropout

    # Positional Embedding Parameters
    positional_embedding_type: Optional[str] = 'multi-learnable-fourier' # e.g., 'custom-rotational-1d', 'learnable-fourier', multi-learnable-fourier' #Can set the value here. 
    positional_embedding_dim: Optional[int] = None  # Dimension of the positional embedding, defaults to ctm_input_dim if None
    reshape_patch_sequence_to_grid: bool = True # If True, reshape patch sequence to a 2D grid for 2D PEs. Must set to true if using 2D Grid for Positional Embeddings.
    patch_grid_width: Optional[int] = None       # Desired width of the patch grid if reshaping

    # Pipeline Parallelism Parameters
    enable_pipeline_parallelism: bool = True
    pipeline_stages: int = 4  # CTM, MCMC, Diffusion prep, Diffusion exec
    pipeline_overlap_ratio: float = 0.7  # Target overlap ratio
    
    # Adaptive Batch Sizing Parameters
    enable_adaptive_batching: bool = True
    initial_batch_size: int = 32
    min_batch_size: int = 8
    max_batch_size: int = 256
    batch_adaptation_frequency: int = 100
    memory_threshold_high: float = 0.85
    memory_threshold_low: float = 0.6
    
    # Smart Data Sampling Parameters
    enable_smart_sampling: bool = True
    sample_importance_weight: float = 0.6
    sample_diversity_weight: float = 0.4
    initial_sample_ratio: float = 0.3
    complexity_analysis_enabled: bool = True
    
    # Multi-input/output parameters
    num_inputs: int = 1  # Number of input streams
    num_outputs: int = 1  # Number of output heads
    output_dims: List[int] = field(default_factory=lambda: [64])  # Dimensions for each output head
    
    # Self-supervised learning
    ssl_dim: int = 128  # Dimension for self-supervised projection
    ssl_weight: float = 0.1  # Weight for self-supervised loss
    ssl_temperature: float = 0.07  # Temperature for contrastive loss
    ssl_noise_std: float = 0.1  # Noise standard deviation for contrastive augmentation
    
    # Spatiotemporal Processing
    use_spatial: bool = True  # Enable spatial processing for image/video data
    
    # WINA Attention
    use_wina_attention: bool = True  # Enable WINA sparse attention
    
    # Multi-task Learning Parameters
    max_tasks: int = 50  # Maximum number of tasks for continual learning
    # Added to resolve TypeError for unexpected keyword arguments
    vocab_size: Optional[int] = None
    output_audio_bytes: bool = False
    inferred_task_latent_dim: Optional[int] = None # Default to None, __post_init__ handles it
    use_hipa_attention: bool = False # Default to False
    hipa_num_heads: Optional[int] = None # Default to None
    audio_output_dtype_str: Optional[str] = "float32" # Default as per __post_init__ logic
    unet_input_feature_dim: Optional[int] = None # Default to None, __post_init__ calculates it

    # --- JEPA Training Parameters (Integrated with LearnedBytePatcherEncoder) ---
    use_jepa_training: bool = False
    # jepa_embed_dim will be derived from patch_embedding_dim if dynamic_entropy_patcher is used
    jepa_predictor_hidden_dim: int = 512 # Hidden dimension of JEPA predictor MLP
    jepa_mask_ratio_min: float = 0.15 # Min proportion of patch sequence to mask for target
    jepa_mask_ratio_max: float = 0.75 # Max proportion of patch sequence to mask for target
    jepa_context_scale_min: float = 0.3 # Min proportion of patches for context
    jepa_context_scale_max: float = 0.7 # Max proportion of patches for context
    jepa_momentum_beta: float = 0.996 # Momentum for target encoder update
    jepa_loss_weight: float = 0.1 # Weight for the JEPA loss component
    jepa_num_target_blocks: int = 1 # Number of target blocks to predict

    # --- Knowledge Store Parameters ---

    def __post_init__(self):
        # Validate output dimensions
        if len(self.output_dims) != self.num_outputs:
            raise ValueError(f"output_dims length ({len(self.output_dims)}) must match num_outputs ({self.num_outputs})")

        # Merged content from the second __post_init__
        if hasattr(self, 'ctm_prediction_reshaper') and self.ctm_prediction_reshaper == [-1] and self.vocab_size is not None:
            pass
        if hasattr(self, 'ctm_dropout_nlm') and self.ctm_dropout_nlm is None and hasattr(self, 'ctm_dropout'):
            self.ctm_dropout_nlm = self.ctm_dropout
        if hasattr(self, 'mcmc_output_space_dim') and self.mcmc_output_space_dim is None and hasattr(self, 'ctm_out_dims'):
            self.mcmc_output_space_dim = self.ctm_out_dims
        
        if hasattr(self, 'ctm_neuron_select_type') and \
           VALID_NEURON_SELECT_TYPES is not None and self.ctm_neuron_select_type not in VALID_NEURON_SELECT_TYPES:
            print(f"Warning: ctm_neuron_select_type '{self.ctm_neuron_select_type}' is not in VALID_NEURON_SELECT_TYPES ({VALID_NEURON_SELECT_TYPES}).")

        if hasattr(self, 'positional_embedding_type') and self.positional_embedding_type is not None:
            if VALID_POSITIONAL_EMBEDDING_TYPES is None: # Fallback if import failed
                print(f"Warning: VALID_POSITIONAL_EMBEDDING_TYPES not available for validation.")
            elif self.positional_embedding_type not in VALID_POSITIONAL_EMBEDDING_TYPES:
                print(f"Warning: positional_embedding_type '{self.positional_embedding_type}' is not in VALID_POSITIONAL_EMBEDDING_TYPES ({VALID_POSITIONAL_EMBEDDING_TYPES}).")
            if self.positional_embedding_dim is not None and self.positional_embedding_dim <= 0:
                raise ValueError("positional_embedding_dim must be positive if set.")
            
            if self.reshape_patch_sequence_to_grid:
                if self.patch_grid_width is None or self.patch_grid_width <= 0:
                    raise ValueError("patch_grid_width must be a positive integer if reshape_patch_sequence_to_grid is True.")
                if self.positional_embedding_type not in ['learnable-fourier', 'multi-learnable-fourier', 'custom-rotational']:
                    print(f"Warning: reshape_patch_sequence_to_grid is True, but positional_embedding_type ('{self.positional_embedding_type}') is not a typical 2D PE. Ensure compatibility.")

        # Validations for new patch encoder
        if self.use_dynamic_entropy_patcher:
            if self.patch_embedding_dim <= 0:
                raise ValueError("patch_embedding_dim must be positive if use_dynamic_entropy_patcher is True.")
            if self.entropy_patcher_min_patch_size <= 0:
                raise ValueError("entropy_patcher_min_patch_size must be positive.")
            if self.entropy_patcher_max_patch_size < self.entropy_patcher_min_patch_size:
                raise ValueError("entropy_patcher_max_patch_size must be >= entropy_patcher_min_patch_size.")
            if self.entropy_patcher_threshold_type not in ["global", "relative_monotonic"]:
                raise ValueError("entropy_patcher_threshold_type must be 'global' or 'relative_monotonic'.")
        elif self.multi_granularity and self.multi_granularity_output_dim <= 0:
            print("Warning: multi_granularity_output_dim might not be correctly set for validation if not using a patcher and MGP is active.")
        
        if not hasattr(self, 'inferred_task_latent_dim') or self.inferred_task_latent_dim is None:
            print("Warning: inferred_task_latent_dim not found or is None in config, defaulting to 64.")
            self.inferred_task_latent_dim = 512
        elif self.inferred_task_latent_dim <= 0: # This check is now safe
            raise ValueError("inferred_task_latent_dim must be positive.")
 
        if hasattr(self, 'use_hipa_attention') and self.use_hipa_attention and \
            (not hasattr(self, 'hipa_num_heads') or self.hipa_num_heads <= 0):
             raise ValueError("hipa_num_heads must be positive if use_hipa_attention is True.")
 
        if hasattr(self, 'audio_output_dtype_str'):
            if self.audio_output_dtype_str == "float32":
                self.audio_output_item_size = 4
            elif self.audio_output_dtype_str == "int16":
                self.audio_output_item_size = 2
            else:
                if hasattr(self, 'output_audio_bytes') and self.output_audio_bytes:
                    raise ValueError(f"Unsupported audio_output_dtype_str: {self.audio_output_dtype_str} when output_audio_bytes is True.")
                else:
                    self.audio_output_item_size = 4
        elif hasattr(self, 'output_audio_bytes') and self.output_audio_bytes:
            if not hasattr(self, 'audio_output_dtype_str') or self.audio_output_dtype_str is None:
                raise ValueError("audio_output_dtype_str must be defined in config if output_audio_bytes is True.")
        else:
            self.audio_output_item_size = 4

        # Calculate unet_input_feature_dim if not set
        if self.unet_input_feature_dim is None:
            if self.max_sequence_length <= 0 or self.audio_output_item_size <= 0:
                raise ValueError("max_sequence_length and audio_output_item_size must be positive to calculate unet_input_feature_dim.")
            self.unet_input_feature_dim = self.max_sequence_length // self.audio_output_item_size
            if self.unet_input_feature_dim <= 0:
                raise ValueError(f"Calculated unet_input_feature_dim ({self.unet_input_feature_dim}) must be positive. Check max_sequence_length and audio_output_item_size.")
        elif self.unet_input_feature_dim <= 0:
            raise ValueError("unet_input_feature_dim, if set, must be positive.")

        if self.use_jepa_training:
            if not (0 < self.jepa_mask_ratio_min < 1 and 0 < self.jepa_mask_ratio_max < 1 and self.jepa_mask_ratio_min <= self.jepa_mask_ratio_max):
                raise ValueError("JEPA mask ratios must be between 0 and 1, with min <= max.")
            if not (0 < self.jepa_context_scale_min < 1 and 0 < self.jepa_context_scale_max < 1 and self.jepa_context_scale_min <= self.jepa_context_scale_max):
                raise ValueError("JEPA context scales must be between 0 and 1, with min <= max.")
            if not (0 <= self.jepa_momentum_beta < 1):
                raise ValueError("jepa_momentum_beta must be between 0 and 1.")
            if self.jepa_num_target_blocks <= 0:
                raise ValueError("jepa_num_target_blocks must be positive.")
            if not self.use_dynamic_entropy_patcher:
                print("Warning: JEPA training is enabled but use_dynamic_entropy_patcher is False. JEPA relies on the patch embeddings from LearnedBytePatcherEncoder.")

# ARC configuration for EnhancedCTMDiffusion.
# The overrides are collected in one mapping, grouped by subsystem, and then
# unpacked into EnhancedCTMConfig. Intentionally NOT set here:
#   * inferred_task_latent_dim - left to the config's own fallback.
#   * enable_enhanced_mcmc / mcmc_config - this notebook uses its own external
#     MCMC integration, so the config must not receive these.
_arc_config_overrides = {
    # General transformer / diffusion backbone
    "d_model": 512,
    "n_heads": 8,
    "n_layers": 24,
    "max_sequence_length": MAX_SEQUENCE_LENGTH,
    "dropout": 0.1,

    # Dynamic entropy byte patcher
    "use_dynamic_entropy_patcher": True,
    "patch_embedding_dim": 256,
    "patch_grid_width": 16,  # also consumed by the 2D positional-embedding reshape below
    "patch_encoder_cnn_channels": 64,
    "entropy_patcher_threshold_type": "global",
    "entropy_patcher_global_threshold": 0.75,
    "entropy_patcher_relative_threshold": 0.1,
    "entropy_patcher_min_patch_size": 4,
    "entropy_patcher_max_patch_size": 128,

    # Learnable entropy model inside LearnedBytePatcherEncoder
    "entropy_model_byte_vocab_size": 256,
    "entropy_model_embedding_dim": 64,
    "entropy_model_hidden_dim": 128,
    "entropy_model_num_layers": 1,
    "entropy_model_dropout": 0.1,
    "entropy_model_loss_weight": 0.1,

    # CTM core
    "ctm_input_dim": 256,
    "ctm_d_model": 512,
    "ctm_iterations": 5,
    "ctm_heads": 8,
    "ctm_out_dims": 512,
    "ctm_neuron_select_type": 'bio_multi_objective',

    # Attention mechanism ("standard", "binary_sparse" or "subquadratic")
    "attention_type": "subquadratic",
    "subquadratic_attn_epsilon": 1e-6,
    "subquadratic_attn_poly_degree": 5,
    "attention_qkv_bias": True,

    # Positional embeddings
    "positional_embedding_type": 'multi-learnable-fourier',
    "positional_embedding_dim": None,
    "reshape_patch_sequence_to_grid": True,

    # Pipeline parallelism
    "enable_pipeline_parallelism": True,
    "pipeline_stages": 4,
    "pipeline_overlap_ratio": 0.7,

    # Adaptive batch sizing
    "enable_adaptive_batching": True,
    "initial_batch_size": 32,
    "min_batch_size": 8,
    "max_batch_size": 256,
    "batch_adaptation_frequency": 100,
    "memory_threshold_high": 0.85,
    "memory_threshold_low": 0.6,

    # Smart data sampling
    "enable_smart_sampling": True,
    "sample_importance_weight": 0.6,
    "sample_diversity_weight": 0.4,
    "initial_sample_ratio": 0.3,
    "complexity_analysis_enabled": True,

    # Input streams / output heads
    "num_inputs": 1,
    "num_outputs": 1,
    "output_dims": [64],

    # Self-supervised learning
    "ssl_dim": 128,
    "ssl_weight": 0.1,
    "ssl_temperature": 0.07,
    "ssl_noise_std": 0.1,

    # Spatial processing and WINA attention
    "use_spatial": True,
    "use_wina_attention": True,

    # Multi-task learning and diffusion coupling
    "max_tasks": 50,
    "diffusion_steps": 1000,
    "ctm_diffusion_coupling_strength": 0.8,
    "vocab_size": None,
    "output_audio_bytes": False,
}

config_arc_diffusion = EnhancedCTMConfig(**_arc_config_overrides)
del _arc_config_overrides  # scratch mapping only; keep the notebook namespace unchanged

print("✓ EnhancedCTMConfig for ARC (config_arc_diffusion) created.")

# --- Model, output head, optional external MCMC module and optimizer for ARC ---
#
# The names below are read unconditionally later in the notebook, but are only
# assigned inside conditional branches of this cell. Give each one a None
# default when no earlier cell has defined it, so a skipped branch leads to the
# "missing components" path instead of a NameError. Existing values are kept.
globals().setdefault('ctm_model_arc', None)
globals().setdefault('arc_output_head', None)
globals().setdefault('optimizer_arc', None)
globals().setdefault('ctm_mcmc_integration_arc', None)
globals().setdefault('accelerator_arc', None)

if 'enhanced_ctm_mcmc' not in globals():
    print("Warning: 'enhanced_ctm_mcmc' not found in globals. Defaulting to None. Ensure the cell defining it (approx. lines 1820-1866) was run successfully.")
    enhanced_ctm_mcmc = None

if 'EnhancedCTMDiffusion' in globals() and EnhancedCTMDiffusion is not None:
    ctm_model_arc = EnhancedCTMDiffusion(config=config_arc_diffusion).to(device)
    print("✓ EnhancedCTMDiffusion model for ARC (ctm_model_arc) initialized.")

    # The external ARC output head is sized from the first configured output dim.
    arc_output_head_input_dim = config_arc_diffusion.output_dims[0]
    arc_output_head = nn.Linear(arc_output_head_input_dim, ARC_OUTPUT_HEAD_DIM).to(device)
    print(f"✓ ARC Output Head initialized (input_dim: {arc_output_head_input_dim}, output_dim: {ARC_OUTPUT_HEAD_DIM}).")

    # Handle external MCMC integration if enabled
    if ENABLE_CTM_MCMC_INTEGRATION_FOR_ARC and enhanced_ctm_mcmc:
        # Rebuild the external MCMC module when its first layer does not accept
        # the feature size the ARC head (and therefore the MCMC input) uses.
        if enhanced_ctm_mcmc.thought_network[0].in_features != config_arc_diffusion.output_dims[0]:
            print(f"Re-initializing external enhanced_ctm_mcmc for new input_dim {config_arc_diffusion.output_dims[0]}")
            enhanced_ctm_mcmc = EnhancedCTMFenchelYoungIntegration(
                input_dim=config_arc_diffusion.output_dims[0],
                output_space=arc_grid_output_space,
                mcmc_config=MCMC_CONFIG_ARC,
                use_large_neighborhood_search=True,
                lns_frequency=5,
                lns_neighborhood_size=10
            )
        ctm_mcmc_integration_arc = enhanced_ctm_mcmc.to(device) if enhanced_ctm_mcmc else None
        print(f"✓ External MCMC Integration for ARC is {'enabled' if ctm_mcmc_integration_arc else 'FAILED to enable'}.")

    # Everything that should receive gradient updates goes into one optimizer.
    arc_trainable_params = list(ctm_model_arc.parameters())
    if arc_output_head:
        arc_trainable_params.extend(arc_output_head.parameters())
    if ctm_mcmc_integration_arc:
        arc_trainable_params.extend(ctm_mcmc_integration_arc.parameters())

    optimizer_arc = optim.AdamW([p for p in arc_trainable_params if p.requires_grad], lr=LEARNING_RATE, weight_decay=1e-4)

    if ACCELERATE_AVAILABLE:
        accelerator_arc = Accelerator()
        models_to_prepare = [ctm_model_arc]
        if arc_output_head: models_to_prepare.append(arc_output_head)
        if ctm_mcmc_integration_arc: models_to_prepare.append(ctm_mcmc_integration_arc)

        # prepare() returns its arguments in the order given: models first,
        # optimizer last.
        *prepared_models_tuple, optimizer_arc = accelerator_arc.prepare(*models_to_prepare, optimizer_arc)

        ctm_model_arc = prepared_models_tuple[0]
        model_idx = 1
        if arc_output_head:
            arc_output_head = prepared_models_tuple[model_idx]
            model_idx += 1
        if ctm_mcmc_integration_arc:
            ctm_mcmc_integration_arc = prepared_models_tuple[model_idx]
        print("✓ ARC models (EnhancedCTMDiffusion) and optimizer prepared with Accelerate.")
else:
    print("⚠️ EnhancedCTMDiffusion model or its config for ARC-AGI-2 could not be initialized. Check imports.")

# --- ARC checkpoint directory, datasets, DataLoaders and loss criterion ---
CHECKPOINT_DIR_ARC = os.path.join(CHECKPOINT_DIR, "ctm_arc_agi_2_enhanced_diffusion") # New checkpoint dir
os.makedirs(CHECKPOINT_DIR_ARC, exist_ok=True)
print(f"ARC Checkpoints will be saved to: {CHECKPOINT_DIR_ARC}")

NUM_EPOCHS_ARC = 20
ARC_BATCH_SIZE = 8

arc_train_dataset = NewCustomARCGridDataset(ARC_TRAIN_DIR)
arc_eval_dataset = NewCustomARCGridDataset(ARC_EVAL_DIR)

# accelerator_arc is only assigned when the Accelerate set-up branch ran.
# Default it to None if it was never defined so the checks below (and in the
# following cells) do not raise NameError.
if 'accelerator_arc' not in globals():
    accelerator_arc = None


def _build_arc_loader(dataset, batch_size, shuffle):
    """Return a DataLoader for `dataset`, or None when the dataset is missing or empty.

    The loader uses the ARC collate function and the shared dataloader options,
    and is passed through `accelerator_arc.prepare` when an accelerator is set.
    """
    if not dataset or len(dataset) == 0:
        return None
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle,
        collate_fn=collate_fn_new_custom_arc, **OPTIMIZED_DATALOADER_CONFIG
    )
    if accelerator_arc:
        loader = accelerator_arc.prepare(loader)
    return loader


arc_train_loader = _build_arc_loader(arc_train_dataset, ARC_BATCH_SIZE, shuffle=True)
if arc_train_loader is not None:
    print(f"✓ ARC Training DataLoader initialized with {len(arc_train_dataset)} tasks.")
else:
    print("⚠️ ARC Training DataLoader could not be initialized.")

# Evaluation runs one task at a time, in dataset order.
arc_eval_loader = _build_arc_loader(arc_eval_dataset, 1, shuffle=False)
if arc_eval_loader is not None:
    print(f"✓ ARC Evaluation DataLoader initialized with {len(arc_eval_dataset)} tasks.")
else:
    print("⚠️ ARC Evaluation DataLoader could not be initialized.")

arc_criterion = nn.CrossEntropyLoss(ignore_index=PADDING_VALUE)
print("\n✓ ARC-AGI-2 Setup Complete.")

In [None]:
# --- ARC-AGI-2 Training Loop ---
#
# Per batch: run EnhancedCTMDiffusion (its 'total_loss' is the base loss), feed
# the CTM core features to the external ARC head for a cross-entropy term and,
# when enabled, to the external MCMC module for an additional term. Gradients
# are accumulated over GRADIENT_ACCUMULATION_STEPS batches before each
# optimizer step. After every epoch the model and head weights and the
# optimizer state are checkpointed into CHECKPOINT_DIR_ARC.

print("\n" + "="*60)
print(f"🚀 STARTING PHASE 4: ARC-AGI-2 Training")
print(f"   Epochs: {NUM_EPOCHS_ARC}, Batch Size: {ARC_BATCH_SIZE}, Task ID: {ARC_TASK_ID}")
print(f"   Device: {device if not accelerator_arc else accelerator_arc.device}")
print("="*60 + "\n")

# The evaluation cell loads "<name>_epoch_<n>.safetensors" files, so the model
# weights are written in that format. torch.save (.pt) is only the fallback
# when safetensors cannot be imported.
try:
    from safetensors.torch import save_model as _save_safetensors_model
except ImportError:
    _save_safetensors_model = None

if not all([ctm_model_arc, arc_output_head, optimizer_arc, arc_train_loader, arc_criterion]):
    print("⚠️ Skipping ARC-AGI-2 training due to missing components.")
else:
    # Resolve the target device once instead of on every tensor transfer.
    arc_run_device = accelerator_arc.device if accelerator_arc else device

    for epoch in range(NUM_EPOCHS_ARC):
        ctm_model_arc.train()
        arc_output_head.train()
        if ctm_mcmc_integration_arc: ctm_mcmc_integration_arc.train()

        total_arc_loss = 0
        processed_batches = 0
        batches_in_epoch = len(arc_train_loader)

        # Gradients are cleared here and after each optimizer step only.
        # Clearing them on every batch would throw away what has been
        # accumulated when GRADIENT_ACCUMULATION_STEPS > 1.
        optimizer_arc.zero_grad()

        for batch_idx, batch_data in enumerate(arc_train_loader):
            if not batch_data or batch_data['input_byte_sequences'].numel() == 0:
                print(f"Skipping empty batch {batch_idx}")
                continue

            # Get data from the updated collate_fn
            input_bytes = batch_data['input_byte_sequences'].to(arc_run_device)
            target_bytes_for_diffusion = batch_data['target_byte_sequences_for_diffusion'].to(arc_run_device)
            original_target_grids_for_ce = batch_data['original_target_grids_for_ce_loss'].to(arc_run_device)

            current_batch_size = input_bytes.size(0)

            with autocast(enabled=USE_MIXED_PRECISION, dtype=autocast_dtype) if not accelerator_arc else accelerator_arc.autocast():
                # Forward pass through EnhancedCTMDiffusion with a random
                # diffusion timestep per sample.
                model_output_dict = ctm_model_arc(
                    byte_sequence=input_bytes,
                    target_diffusion_output=target_bytes_for_diffusion, # Provide target for diffusion loss component
                    mode='ctm_controlled_diffusion', # Ensure diffusion part is active for loss calculation
                    timestep=torch.randint(0, config_arc_diffusion.diffusion_steps, (current_batch_size,), device=input_bytes.device).long(),
                    target_mcmc_output=None, # Internal MCMC is disabled in config_arc_diffusion
                    task_name="ARC_AGI_2", # Optional task name
                    current_epoch=epoch # Pass current epoch
                )

                # Base loss reported by the model; zero when the key is absent.
                enhanced_ctm_loss = model_output_dict.get('total_loss', torch.tensor(0.0, device=input_bytes.device))
                loss = enhanced_ctm_loss

                # Get CTM core output for the external ARC head
                ctm_core_output_data = model_output_dict.get('ctm_core_data')
                ctm_backbone_output = None
                if ctm_core_output_data and 'final_sync_out' in ctm_core_output_data:
                    ctm_backbone_output = ctm_core_output_data['final_sync_out']
                elif ctm_core_output_data and 'ctm_latent_representation' in ctm_core_output_data: # Fallback key
                    ctm_backbone_output = ctm_core_output_data['ctm_latent_representation']
                else:
                    print("Warning: CTM core output ('final_sync_out' or 'ctm_latent_representation') not found. Using zeros for ARC head input.")
                    # The placeholder must have the feature size the ARC head
                    # was built with (output_dims[0]), not ctm_out_dims.
                    ctm_backbone_output = torch.zeros(current_batch_size, config_arc_diffusion.output_dims[0], device=input_bytes.device)

                # External ARC Output Head for CrossEntropy loss on original grid prediction
                if arc_output_head and ctm_backbone_output is not None:
                    # Features with more than two dims are averaged over dim 1
                    # before entering the linear head.
                    if ctm_backbone_output.ndim > 2 and ctm_backbone_output.shape[1] > 0:
                        ctm_features_for_head = ctm_backbone_output.mean(dim=1)
                    else:
                        ctm_features_for_head = ctm_backbone_output

                    predicted_logits = arc_output_head(ctm_features_for_head)
                    predicted_logits_reshaped = predicted_logits.view(current_batch_size * ARC_INPUT_FLAT_DIM, NUM_ARC_SYMBOLS)
                    target_grids_reshaped = original_target_grids_for_ce.view(current_batch_size * ARC_INPUT_FLAT_DIM)
                    ce_loss = arc_criterion(predicted_logits_reshaped, target_grids_reshaped)
                    # Out-of-place add: do not mutate the tensor the model returned.
                    loss = loss + ce_loss

                # External MCMC Integration (if enabled)
                if ctm_mcmc_integration_arc and ctm_backbone_output is not None:
                    target_grids_for_mcmc = (original_target_grids_for_ce > 0).float()
                    # Detached: the MCMC term does not backpropagate into the CTM.
                    mcmc_input_features = ctm_backbone_output.detach()
                    if mcmc_input_features.ndim > 2 and mcmc_input_features.shape[1] > 0:
                        mcmc_input_features = mcmc_input_features.mean(dim=1)

                    mcmc_loss_val, _, _ = ctm_mcmc_integration_arc(
                        x=mcmc_input_features,
                        target_y=target_grids_for_mcmc
                    )
                    loss = loss + mcmc_loss_val

            # Scale for accumulation so the summed gradient is an average over
            # the accumulated batches; `loss` itself stays unscaled for logging.
            loss_for_backward = loss / GRADIENT_ACCUMULATION_STEPS
            # Step on every full accumulation window, and on the final batch so
            # a trailing partial window is not silently dropped.
            is_update_step = (
                (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0
                or (batch_idx + 1) == batches_in_epoch
            )

            if scaler: # Mixed precision (manual, without Accelerate)
                scaler.scale(loss_for_backward).backward()
                if is_update_step:
                    scaler.unscale_(optimizer_arc)
                    torch.nn.utils.clip_grad_norm_(ctm_model_arc.parameters(), MAX_GRAD_NORM)
                    scaler.step(optimizer_arc)
                    scaler.update()
                    optimizer_arc.zero_grad()
            elif accelerator_arc: # Using Hugging Face Accelerate
                accelerator_arc.backward(loss_for_backward)
                if is_update_step:
                    optimizer_arc.step()
                    optimizer_arc.zero_grad()
            else: # Standard training
                loss_for_backward.backward()
                if is_update_step:
                    torch.nn.utils.clip_grad_norm_(ctm_model_arc.parameters(), MAX_GRAD_NORM)
                    optimizer_arc.step()
                    optimizer_arc.zero_grad()

            total_arc_loss += loss.item()
            processed_batches += 1

            if (batch_idx + 1) % 50 == 0:
                print(f"  Epoch [{epoch+1}/{NUM_EPOCHS_ARC}], Batch [{batch_idx+1}/{len(arc_train_loader)}], Loss: {loss.item():.4f}")

        avg_epoch_loss = total_arc_loss / processed_batches if processed_batches > 0 else 0
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS_ARC}] completed. Average Loss: {avg_epoch_loss:.4f}")

        if CHECKPOINT_DIR_ARC:
            model_to_save_ctm = accelerator_arc.unwrap_model(ctm_model_arc) if accelerator_arc else ctm_model_arc
            model_to_save_head = accelerator_arc.unwrap_model(arc_output_head) if accelerator_arc else arc_output_head

            if _save_safetensors_model is not None:
                # Same file names the evaluation cell looks for.
                _save_safetensors_model(model_to_save_ctm, os.path.join(CHECKPOINT_DIR_ARC, f"ctm_model_arc_epoch_{epoch+1}.safetensors"))
                _save_safetensors_model(model_to_save_head, os.path.join(CHECKPOINT_DIR_ARC, f"arc_output_head_epoch_{epoch+1}.safetensors"))
            else:
                torch.save(model_to_save_ctm.state_dict(), os.path.join(CHECKPOINT_DIR_ARC, f"ctm_model_arc_epoch_{epoch+1}.pt"))
                torch.save(model_to_save_head.state_dict(), os.path.join(CHECKPOINT_DIR_ARC, f"arc_output_head_epoch_{epoch+1}.pt"))
            # Optimizer state is a nested dict, not a flat tensor mapping, so it
            # stays a torch.save checkpoint.
            torch.save(optimizer_arc.state_dict(), os.path.join(CHECKPOINT_DIR_ARC, f"optimizer_arc_epoch_{epoch+1}.pt"))
            # NOTE(review): the external MCMC module is trained but not checkpointed here - confirm whether that is intended.

            print(f"  ✓ Checkpoint saved for epoch {epoch+1} to {CHECKPOINT_DIR_ARC}")

    print("\n🎉 ARC-AGI-2 Training Phase Completed!")

In [None]:
# ## ARC-AGI-2 Evaluation

import traceback

print("\n" + "="*60)
print(f"🔬 STARTING ARC-AGI-2 Evaluation")
print("="*60 + "\n")


def _load_arc_eval_checkpoint(model, checkpoint_dir, file_stem, epoch, target_device):
    """Load the weights saved for ``epoch`` into ``model``, if a checkpoint exists.

    The ARC training loop in this file writes ``<file_stem>_epoch_<epoch>.pt``
    with ``torch.save(model.state_dict(), ...)``, so that file is tried first.
    A ``.safetensors`` file with the same stem is accepted as a fallback and is
    loaded through ``load_model``.

    Args:
        model: Unwrapped module that receives the weights (modified in place).
        checkpoint_dir: Directory holding the checkpoints; may be empty/None.
        file_stem: File name prefix, e.g. ``"ctm_model_arc"``.
        epoch: Epoch number embedded in the checkpoint file name.
        target_device: Device the loaded tensors are mapped to.

    Returns:
        The path of the checkpoint that was loaded, or ``None`` when no usable
        checkpoint was found (the model is left untouched in that case).
    """
    if not checkpoint_dir:
        # No checkpoint directory configured: nothing to look up.
        return None

    base_path = os.path.join(checkpoint_dir, f"{file_stem}_epoch_{epoch}")

    pt_path = base_path + ".pt"
    if os.path.exists(pt_path):
        state_dict = torch.load(pt_path, map_location=target_device)
        model.load_state_dict(state_dict)
        return pt_path

    safetensors_path = base_path + ".safetensors"
    if os.path.exists(safetensors_path) and callable(load_model):
        load_model(model, safetensors_path, device=target_device)
        return safetensors_path

    return None


if not all([ctm_model_arc is not None, arc_output_head is not None, arc_eval_loader is not None]):
    print("⚠️ Skipping ARC-AGI-2 evaluation due to missing components.")
else:
    latest_epoch = NUM_EPOCHS_ARC

    try:
        # Single source of truth for where inputs and loaded weights should live.
        eval_device = accelerator_arc.device if accelerator_arc else device

        # --- Restore the weights written at the end of training ---------------
        unwrapped_ctm_model = accelerator_arc.unwrap_model(ctm_model_arc) if accelerator_arc else ctm_model_arc
        ctm_checkpoint_path_eval = _load_arc_eval_checkpoint(
            unwrapped_ctm_model, CHECKPOINT_DIR_ARC, "ctm_model_arc", latest_epoch, eval_device
        )
        if ctm_checkpoint_path_eval is not None:
            # Re-flatten the entropy model's LSTM weights after loading so they
            # are laid out contiguously again.
            patcher = getattr(unwrapped_ctm_model, 'dynamic_entropy_patcher', None)
            entropy_model = getattr(patcher, 'entropy_model', None)
            lstm = getattr(entropy_model, 'lstm', None)
            if isinstance(lstm, torch.nn.LSTM):
                print("    > Flattening LSTM parameters for ctm_model_arc after loading...")
                lstm.flatten_parameters()

            print(f"✓ Loaded CTM checkpoint from epoch {latest_epoch}.")
        else:
            print(f"⚠️ CTM Checkpoint not found. Evaluating with current model state.")

        unwrapped_head_model = accelerator_arc.unwrap_model(arc_output_head) if accelerator_arc else arc_output_head
        head_checkpoint_path_eval = _load_arc_eval_checkpoint(
            unwrapped_head_model, CHECKPOINT_DIR_ARC, "arc_output_head", latest_epoch, eval_device
        )
        if head_checkpoint_path_eval is not None:
            print(f"✓ Loaded ARC Output Head checkpoint from epoch {latest_epoch}.")
        else:
            print(f"⚠️ ARC Output Head Checkpoint not found. Evaluating with current model state.")

        ctm_model_arc.eval()
        arc_output_head.eval()
        if ctm_mcmc_integration_arc:
            ctm_mcmc_integration_arc.eval()

        total_tasks = 0
        solved_tasks = 0
        num_eval_tasks = len(arc_eval_loader)

        with torch.inference_mode():
            for task_idx, task_batch in enumerate(arc_eval_loader):
                if not task_batch:
                    continue

                current_task_data = task_batch  # Dataloader batch_size=1, so task_batch is the task dict
                task_id = current_task_data.get('id', 'N/A')

                if 'test' not in current_task_data or not current_task_data['test']:
                    # A skipped task is not counted as evaluated, so it does not
                    # drag the reported accuracy down.
                    print(f"Task {task_idx + 1} ({task_id}): No test cases found. Skipping.")
                    continue

                total_tasks += 1
                task_solved_overall = True

                for test_pair_idx, test_pair in enumerate(current_task_data['test']):
                    # Input for evaluation is a single grid; serialize it to a padded byte sequence.
                    # .cpu() keeps this working when the dataset tensors are not on the CPU.
                    input_grid_np_eval = test_pair['input'].cpu().numpy()
                    input_bytes_eval_single = serialize_and_pad_grid(input_grid_np_eval, max_len=MAX_SEQUENCE_LENGTH, pad_value=PADDING_BYTE_VALUE)
                    input_bytes_eval = torch.tensor(list(input_bytes_eval_single), dtype=torch.uint8).unsqueeze(0).to(eval_device)

                    target_grid_np = test_pair['output'].cpu().numpy()
                    original_dims = test_pair['original_output_dims']
                    h, w = original_dims
                    final_target = target_grid_np[:h, :w]

                    # These do not change between trials, so build them once per test pair.
                    current_batch_size_eval = input_bytes_eval.size(0)  # Should be 1 for evaluation
                    # Assuming timestep 0 is appropriate for one-step or final-step generation
                    eval_timestep = torch.zeros(current_batch_size_eval, device=input_bytes_eval.device).long()

                    test_input_solved = False
                    for trial in range(3):  # ARC rules allow 3 trials
                        # Forward pass using CTM-controlled diffusion for generation
                        eval_model_output_dict = ctm_model_arc(
                            byte_sequence=input_bytes_eval,
                            mode='ctm_controlled_diffusion',  # Use CTM-controlled diffusion
                            target_diffusion_output=None,     # No target during generation
                            timestep=eval_timestep,
                            task_name="ARC_AGI_2_EVAL_DIFFUSION"
                        )

                        # ASSUMPTION: The generated output is a byte sequence under the key 'diffusion_output_pred'
                        # The shape is expected to be (batch_size, MAX_SEQUENCE_LENGTH)
                        predicted_byte_sequence = eval_model_output_dict.get('diffusion_output_pred')

                        if predicted_byte_sequence is None:
                            print("Warning: Key 'diffusion_output_pred' not found in model output. Trying 'generated_output'.")
                            predicted_byte_sequence = eval_model_output_dict.get('generated_output')  # Common alternative

                        if predicted_byte_sequence is None:
                            print("Warning: Generated output key not found. Using zeros as prediction.")
                            # Fallback: an all-zero grid of the padded size when no output is found
                            preds_grid = np.zeros(MAX_GRID_SIZE, dtype=int)
                        else:
                            # Ensure the sequence has the batch dimension (should be 1)
                            if predicted_byte_sequence.ndim == 1 and current_batch_size_eval == 1:
                                predicted_byte_sequence = predicted_byte_sequence.unsqueeze(0)

                            # Extract the part of the sequence corresponding to the flattened grid
                            # ARC_INPUT_FLAT_DIM = MAX_GRID_SIZE[0] * MAX_GRID_SIZE[1]
                            if predicted_byte_sequence.shape[1] >= ARC_INPUT_FLAT_DIM:
                                preds_flat_bytes = predicted_byte_sequence[0, :ARC_INPUT_FLAT_DIM]
                                # reshape (not view) so a non-contiguous model output is still accepted
                                preds_grid = preds_flat_bytes.reshape(MAX_GRID_SIZE).long().cpu().numpy()
                            else:
                                print(f"Warning: Generated byte sequence too short ({predicted_byte_sequence.shape[1]} vs {ARC_INPUT_FLAT_DIM}). Using zeros.")
                                preds_grid = np.zeros(MAX_GRID_SIZE, dtype=int)

                        # Unpad to original dimensions before comparing with the target
                        final_pred = preds_grid[:h, :w]

                        if np.array_equal(final_pred, final_target):
                            test_input_solved = True
                            break

                    if not test_input_solved:
                        # One unsolved test input fails the whole task; no need to try the rest.
                        task_solved_overall = False
                        break

                if task_solved_overall:
                    solved_tasks += 1
                    print(f"  Task {task_idx + 1}/{num_eval_tasks} ({task_id}): SOLVED")
                else:
                    print(f"  Task {task_idx + 1}/{num_eval_tasks} ({task_id}): FAILED")

        if total_tasks > 0:
            accuracy = (solved_tasks / total_tasks) * 100
            summary = f"ARC-AGI-2 Evaluation Summary:\n  Total tasks evaluated: {total_tasks}\n  Tasks solved: {solved_tasks}\n  Accuracy: {accuracy:.2f}%"
            print(f"\n{summary}")
            with open('arc_agi_2_evaluation_summary.txt', 'w', encoding='utf-8') as f:
                f.write(summary)
        else:
            print("\nARC-AGI-2 Evaluation: No tasks were evaluated.")

    except Exception as e:
        # Top-level boundary for the evaluation cell: report and keep the notebook running.
        print(f"❌ Error during ARC-AGI-2 evaluation: {e}")
        traceback.print_exc()

print("\n🔬 ARC-AGI-2 Evaluation Phase Completed.")
