# -----------------------------------------------------------------------------
# Dependency Installation Notes
# -----------------------------------------------------------------------------
 The following dependencies are required. Please install them in your Python environment,
 for example, using pip:

 pip install mediapy
 pip install torch
 pip install safetensors
 pip install numpy

 For advanced optimizations, consider installing the following:
 pip install flash-attn --no-build-isolation
 pip install deepspeed
 pip install accelerate
 pip install xformers

 It's recommended to use a virtual environment.
 -----------------------------------------------------------------------------

This section contains the set-up components for training the CTM model with the byte-level encoder with binary patches, CTM processing with the synapse system set to multi-objective,



and binary patches from the CTM (after 20 rounds of CoT thinking) refined and trained with MCMC to encourage the model to have reasoning steps closely related to the best answer.



Each epoch is saved as a safetensor checkpoint to preserve training progress.

In [None]:
import os

# Remember the directory the notebook started in, before the commands below
# change into the cloned repository.
original_dir = os.getcwd()

# Add the Git LFS APT repository, install Git LFS, then clone the project.
!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash # Must be run as the root user or it will not work.
!apt install git-lfs -y 
!git lfs install
!pip install pickleshare
!git clone https://github.com/viasky657/Arc-AGI-2.git
%cd Arc-AGI-2
!git lfs pull # Download the actual model tensors rather than the Git LFS pointer files.

In [None]:
!git lfs fetch --all
!git lfs checkout

In [None]:
import os

os.chdir("/workspace/Arc-AGI-2")  # Set your root working directory
print("Current directory:", os.getcwd())

In [None]:

!ls -lh *.safetensors #Check for successful LFS pull for the safetensors and the .pt files. 
!git rev-parse --is-inside-work-tree
!git lfs ls-files
!head -n 10 contineous-thought-machines/examples/checkpoints/ctm_arc_agi_2_enhanced_diffusion/arc_output_head_epoch_20.safetensors #The below output should be binary 
#and metadata if it is a real tensor and 
#not a LFS pointer.

In [None]:
#Change Back to Original directory to avoid errors. 
# Safely return to the original directory
%cd $original_di #Probably don't need to use this since the code above automatically creates the correct directory. 

In [None]:
!pip install matplotlib
!pip install mediapy
!pip install numpy
# For advanced optimizations, consider installing the following:
!pip install accelerate
!pip install diffusers
!pip install xformers
!pip install seaborn
!pip install safetensors
!pip install deepspeed
!apt-get update -y && apt-get install -y portaudio19-dev
!pip install pyaudio


Collecting matplotlib
  Downloading matplotlib-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib)
  Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.58.5-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (106 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.9/106.9 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting kiwisolver>=1.3.1 (from matplotlib)
  Downloading kiwisolver-1.4.8-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (6.2 kB)
Downloading matplotlib-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [

In [2]:
!apt-get update && apt-get install -y libaio-dev


Get:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1581 B]
Get:2 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]      
Get:3 http://archive.ubuntu.com/ubuntu jammy InRelease [270 kB]                
Get:4 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  Packages [1804 kB]
Get:5 http://security.ubuntu.com/ubuntu jammy-security/universe amd64 Packages [1266 kB]
Get:6 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease [18.1 kB]
Get:7 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]        
Get:8 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy/main amd64 Packages [32.9 kB]
Get:9 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]      
Get:10 http://security.ubuntu.com/ubuntu jammy-security/restricted amd64 Packages [4795 kB]
Get:11 http://archive.ubuntu.com/ubuntu jammy/multiverse amd64 Packages [266 kB]
Get:12 http://archive.ubuntu.com

In [3]:
!mkdir -p /root/.triton/autotune
#Triton (used by xFormers or FlashAttention) is trying to autotune kernels.

#df command is likely being used internally to check disk usage of Triton cache.

#Safe to ignore unless you rely on persistent Triton tuning across sessions.

#🛠️ Fix (optional): Create the directory manually using the code above.

# Setup Section for the Arc Training

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from typing import Optional, Callable, Tuple, Dict, Any, List, Union
from dataclasses import dataclass, field
import math
import numpy as np
import warnings # For ARCGridOutputSpace warnings
import sys
import os
import json
import glob
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from accelerate import Accelerator

WORKSPACE_ROOT = "/workspace/Arc-AGI-2"

ARC_TRAIN_DIR = os.path.join(WORKSPACE_ROOT, "contineous_thought_machines", "data", "training")
ARC_EVAL_DIR = os.path.join(WORKSPACE_ROOT, "contineous_thought_machines", "data", "evaluation")

def find_json_file(filename, search_dir):
    """
    Walk ``search_dir`` recursively and return the path of the first file
    whose name is exactly ``filename``.

    No extension check is performed here; the match is on the file name only.
    Returns None when no directory under ``search_dir`` contains such a file.
    """
    hits = (
        os.path.join(current_dir, filename)
        for current_dir, _subdirs, names in os.walk(search_dir)
        if filename in names
    )
    return next(hits, None)

def resolve_json_files(directory):
    """
    Collect the JSON files that live directly inside ``directory``.

    Entries are processed in sorted name order so the result is deterministic
    across filesystems. If a listed entry does not actually exist on disk
    (for example a dangling symlink), the workspace is searched for a file of
    the same name and that path is used instead.

    Args:
        directory: Directory to scan (not recursive).

    Returns:
        A list of paths to ``.json`` files. The list is empty when the
        directory is missing or is not a directory; an error is printed in
        that case instead of raising, so the caller can continue.
    """
    try:
        entries = sorted(os.listdir(directory))
    except (FileNotFoundError, NotADirectoryError) as err:
        print(f"[ERROR] Cannot list directory {directory}: {err}")
        return []

    json_files = []
    for file in entries:
        if not file.endswith(".json"):
            continue
        abs_path = os.path.join(directory, file)
        if os.path.exists(abs_path):
            json_files.append(abs_path)
            continue
        # The name was listed but the path does not resolve: look elsewhere.
        print(f"[WARN] File not found in expected path: {abs_path}. Searching workspace...")
        found = find_json_file(file, WORKSPACE_ROOT)
        if found:
            print(f"[INFO] Found {file} at: {found}")
            json_files.append(found)
        else:
            print(f"[ERROR] Could not find {file} anywhere in {WORKSPACE_ROOT}")
    return json_files

# Resolve the task files for both the training and evaluation directories.
# These calls run at import/cell-execution time and touch the filesystem.
train_json_files = resolve_json_files(ARC_TRAIN_DIR)
eval_json_files = resolve_json_files(ARC_EVAL_DIR)

print(f"✅ Found {len(train_json_files)} training JSON files.")
print(f"✅ Found {len(eval_json_files)} evaluation JSON files.")

# Show the first three paths of each list as a quick sanity check.
print("Training files:", train_json_files[:3])
print("Evaluation files:", eval_json_files[:3])


# Grid geometry and symbol constants used by the padding/collate helpers below.
MAX_GRID_SIZE = (30, 30)
NUM_ARC_SYMBOLS = 10
PADDING_VALUE = -1 # A value outside 0-9; intended to be ignored by the loss function (the loss is not defined in this section).

# Number of cells in a fully padded grid (rows * columns of MAX_GRID_SIZE).
ARC_INPUT_FLAT_DIM = MAX_GRID_SIZE[0] * MAX_GRID_SIZE[1]

print(f"Using MAX_GRID_SIZE: {MAX_GRID_SIZE}")
print(f"Using NUM_ARC_SYMBOLS: {NUM_ARC_SYMBOLS}")
print(f"Using ARC_INPUT_FLAT_DIM: {ARC_INPUT_FLAT_DIM}") 

# Put the workspace root at the front of sys.path so that packages located
# under it can be imported by name further down in this cell.
if WORKSPACE_ROOT not in sys.path:
    sys.path.insert(0, WORKSPACE_ROOT)
    print(f"[INFO] Added workspace root to sys.path: {WORKSPACE_ROOT}")

# --- Constants and Configs ---
# MAX_GRID_SIZE, PADDING_VALUE, NUM_ARC_SYMBOLS, ARC_INPUT_FLAT_DIM and
# ARC_TRAIN_DIR are defined once earlier in this cell (ARC_TRAIN_DIR is derived
# from WORKSPACE_ROOT). They are deliberately not repeated here: a second,
# hard-coded copy can silently drift from the first definition.
device = "cuda" if torch.cuda.is_available() else "cpu"  # Prefer the GPU when one is visible to torch.
LEARNING_RATE = 1e-4
CHECKPOINT_DIR = "checkpoints"  # Relative path; resolved against the current working directory.

# Probe for Hugging Face Accelerate. On failure the flag is cleared and
# `Accelerator` is bound to None so later code can test for it.
# NOTE(review): `Accelerator` is also imported unconditionally in the import
# list at the top of this cell, so a missing package would already have raised
# before this guard is reached - confirm which behaviour is intended.
ACCELERATE_AVAILABLE = True
try:
    from accelerate import Accelerator
except ImportError:
    ACCELERATE_AVAILABLE = False
    Accelerator = None

# Check for xformers. The import is only attempted when running on CUDA;
# on CPU the flag stays False without trying the import.
XFORMERS_AVAILABLE = False
if device == "cuda":
    try:
        import xformers
        XFORMERS_AVAILABLE = True
    except ImportError:
        pass

# Check for torch.compile (present as an attribute on sufficiently new torch builds).
TORCH_COMPILE_AVAILABLE = hasattr(torch, 'compile')

# Check for deepspeed. As with xformers, only probed when running on CUDA.
DEEPSPEED_AVAILABLE = False
if device == "cuda":
    try:
        import deepspeed
        DEEPSPEED_AVAILABLE = True
    except ImportError:
        pass

# Default keyword arguments for a DataLoader. They are only populated when
# CUDA is available; otherwise the dict is empty and DataLoader defaults apply.
OPTIMIZED_DATALOADER_CONFIG = {
    "num_workers": 4,
    "pin_memory": True,
    "prefetch_factor": 2
} if torch.cuda.is_available() else {}

# --- Context: 2D Grid Padding (from original code) ---
# This function handles padding at the 2D grid level, before serialization.
def pad_grid(grid_list, max_dims, pad_value):
    """Pad a 2D grid to the specified maximum dimensions.

    The grid is placed in the top-left corner of an ``int32`` array of shape
    ``max_dims``; every other cell is set to ``pad_value``.

    Args:
        grid_list: 2D grid as nested lists or an array. An empty grid
            (``[]`` or ``[[]]``) is accepted and yields an all-padding result.
        max_dims: ``(rows, cols)`` of the padded output.
        pad_value: Value written to cells not covered by the grid.

    Returns:
        np.ndarray of dtype int32 and shape ``max_dims``.

    Raises:
        ValueError: If a non-empty grid is not 2D or does not fit in ``max_dims``.
    """
    padded_grid = np.full(max_dims, pad_value, dtype=np.int32)
    grid_np = np.array(grid_list, dtype=np.int32)

    # An empty list becomes a 1D array of shape (0,), which cannot be unpacked
    # into (h, w); there is nothing to copy, so return the padding as is.
    if grid_np.size == 0:
        return padded_grid

    if grid_np.ndim != 2:
        raise ValueError(f"pad_grid expects a 2D grid, got an array with shape {grid_np.shape}")

    h, w = grid_np.shape
    if h > max_dims[0] or w > max_dims[1]:
        raise ValueError(f"Grid of shape {(h, w)} does not fit into max_dims {tuple(max_dims)}")

    padded_grid[:h, :w] = grid_np
    return padded_grid

# --- Fix: Byte Sequence Padding for the Model --- #
# According to the model explanation, the key step is to pad the *serialized byte sequence*
# to `config.max_sequence_length`. The function below implements this logic.

# Define the model's expected input dimension from the configuration.
MAX_SEQUENCE_LENGTH = 8192
PADDING_BYTE_VALUE = 0

def serialize_and_pad_grid(grid, max_len=MAX_SEQUENCE_LENGTH, pad_value=PADDING_BYTE_VALUE):
    """
    Turn a grid into a fixed-length byte string.

    The grid is converted to ``uint8``, flattened in row-major order and
    dumped to raw bytes. The result is then right-padded with ``pad_value``
    up to ``max_len`` bytes, or cut down to the first ``max_len`` bytes when
    it is longer.

    Args:
        grid (list or np.ndarray): The input ARC grid.
        max_len (int): Length of the returned byte string.
        pad_value (int): Byte value (0-255) appended as padding.

    Returns:
        bytes: The padded or truncated byte sequence.
    """
    # NOTE(review): the cast to uint8 happens before serialization, so values
    # outside 0-255 (e.g. a -1 grid-padding marker in an integer ndarray) are
    # not preserved as-is - confirm this is what the encoder expects.
    raw = np.array(grid, dtype=np.uint8).ravel().tobytes()

    shortfall = max_len - len(raw)
    if shortfall >= 0:
        # Too short (or exact fit): extend with the padding byte.
        return raw + bytes([pad_value] * shortfall)

    # Too long: keep only the leading max_len bytes.
    return raw[:max_len]

class NewCustomARCGridDataset(Dataset):
    """Dataset of ARC tasks read from ``*.json`` files.

    Every item is one whole task: a dict with ``'train'`` and ``'test'`` lists
    of padded input/output grid tensors plus the task's file name under ``'id'``.

    Args:
        data_dir: Directory scanned (non-recursively) for ``*.json`` task files.
            When it contains none, the whole workspace is searched as a fallback.
        max_grid_size: ``(rows, cols)`` every grid is padded to.
        padding_value: Fill value for padded cells.
    """

    def __init__(self, data_dir, max_grid_size=MAX_GRID_SIZE, padding_value=PADDING_VALUE):
        self.data_dir = data_dir
        # Sorted so that task order does not depend on the filesystem.
        self.task_files = sorted(glob.glob(os.path.join(data_dir, "*.json")))
        self.max_grid_size = max_grid_size
        self.padding_value = padding_value
        self.tasks = []
        # File name of each entry in self.tasks, kept in lockstep with it.
        # self.task_files cannot be used for this because files that fail to
        # load are skipped, which shifts the indices.
        self._task_ids = []
        print(f"NewCustomARCGridDataset: Looking for tasks in: {data_dir}")
        if not self.task_files:
            print(f"NewCustomARCGridDataset Warning: No JSON files found in {data_dir}. Attempting fallback search.")
            base_dir = '/workspace/Arc-AGI-2'
            self.task_files = sorted(
                os.path.join(root, file)
                for root, _dirs, files in os.walk(base_dir)
                for file in files
                if file.endswith('.json')
            )
            if self.task_files:
                print(f"Found {len(self.task_files)} JSON files via fallback search in {base_dir}")
            else:
                print(f"No JSON files found via fallback search in {base_dir}. Dataset will be empty.")
        for task_file in self.task_files:
            try:
                with open(task_file, 'r') as f:
                    task = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"NewCustomARCGridDataset Warning: Could not load or parse {task_file}: {e}")
                continue
            if not isinstance(task, dict):
                # The fallback search can pick up unrelated JSON files; a task
                # must be an object so that 'train'/'test' can be looked up.
                print(f"NewCustomARCGridDataset Warning: Skipping {task_file}: top-level JSON value is not an object.")
                continue
            self.tasks.append(task)
            self._task_ids.append(os.path.basename(task_file))
        if not self.tasks:
            print(f"NewCustomARCGridDataset Warning: No tasks successfully loaded from {data_dir}.")
        else:
            print(f"NewCustomARCGridDataset: Loaded {len(self.tasks)} ARC tasks from {data_dir}.")

    def __len__(self):
        """Return the number of successfully loaded tasks."""
        return len(self.tasks)

    @staticmethod
    def _grid_dims(grid_list):
        """Return ``(rows, cols)`` of a nested-list grid, or ``(0, 0)`` if it is empty."""
        if grid_list and grid_list[0]:
            return (len(grid_list), len(grid_list[0]))
        return (0, 0)

    def _pad(self, grid_list):
        """Pad one grid to ``max_grid_size``; a missing/empty grid becomes all padding."""
        if not grid_list:
            # e.g. a test pair without an 'output' entry.
            return np.full(self.max_grid_size, self.padding_value, dtype=np.int32)
        return pad_grid(grid_list, self.max_grid_size, self.padding_value)

    def __getitem__(self, idx):
        """Return task ``idx`` as a dict with keys ``'train'``, ``'test'`` and ``'id'``.

        Each pair in ``'train'``/``'test'`` is a dict holding the padded
        ``'input'`` and ``'output'`` grids as LongTensors and their unpadded
        sizes under ``'original_input_dims'`` / ``'original_output_dims'``.
        """
        task_data = self.tasks[idx]
        task_id = self._task_ids[idx] if idx < len(self._task_ids) else 'unknown_task'
        processed_task = {'train': [], 'test': [], 'id': task_id}

        for pair_type in ['train', 'test']:
            for item in task_data.get(pair_type, []):
                input_grid_list = item.get('input', [])
                output_grid_list = item.get('output', [])

                padded_input_np = self._pad(input_grid_list)
                padded_output_np = self._pad(output_grid_list)

                processed_task[pair_type].append({
                    'input': torch.from_numpy(padded_input_np).long(),
                    'output': torch.from_numpy(padded_output_np).long(),
                    'original_input_dims': self._grid_dims(input_grid_list),
                    'original_output_dims': self._grid_dims(output_grid_list)
                })
        return processed_task

def collate_fn_new_custom_arc(batch_of_tasks):
    """Flatten a batch of ARC tasks into stacked per-pair tensors.

    Only the ``'train'`` pairs of each task are used. Entries that are not
    dicts, and pairs lacking an ``'input'`` or ``'output'`` key, are skipped.

    Returns:
        dict with three tensors whose first dimension is the number of pairs kept:
          * ``'input_byte_sequences'``: uint8, one serialized input grid per row.
          * ``'target_byte_sequences_for_diffusion'``: uint8, one serialized output grid per row.
          * ``'original_target_grids_for_ce_loss'``: long, each output grid flattened.
        When no pair is kept, empty tensors with zero rows are returned instead.
    """
    def _to_byte_tensor(grid_tensor):
        # Serialize via numpy, then rebuild a uint8 tensor from the raw bytes.
        # NOTE(review): grids arrive already padded by the dataset, so any
        # padding marker they contain goes through the uint8 cast as well - confirm.
        payload = serialize_and_pad_grid(
            grid_tensor.numpy(), max_len=MAX_SEQUENCE_LENGTH, pad_value=PADDING_BYTE_VALUE
        )
        return torch.tensor(list(payload), dtype=torch.uint8)

    inputs, diffusion_targets, ce_targets = [], [], []

    for task in batch_of_tasks:
        if not isinstance(task, dict):
            continue

        for pair in task.get('train', []):
            usable = isinstance(pair, dict) and 'input' in pair and 'output' in pair
            if not usable:
                continue

            inputs.append(_to_byte_tensor(pair['input']))
            diffusion_targets.append(_to_byte_tensor(pair['output']))
            # Untouched long-valued target, flattened to 1D.
            ce_targets.append(pair['output'].view(-1))

    if inputs:
        return {
            'input_byte_sequences': torch.stack(inputs),
            'target_byte_sequences_for_diffusion': torch.stack(diffusion_targets),
            'original_target_grids_for_ce_loss': torch.stack(ce_targets),
        }

    # Nothing usable in this batch: hand back correctly typed, zero-row tensors.
    return {
        'input_byte_sequences': torch.empty(0, MAX_SEQUENCE_LENGTH, dtype=torch.uint8),
        'target_byte_sequences_for_diffusion': torch.empty(0, MAX_SEQUENCE_LENGTH, dtype=torch.uint8),
        'original_target_grids_for_ce_loss': torch.empty(0, ARC_INPUT_FLAT_DIM, dtype=torch.long),
    }

# --- ARC Training Setup ---
# Size of the flat output head: one value per grid cell per symbol
# (ARC_INPUT_FLAT_DIM cells * NUM_ARC_SYMBOLS symbols).
ARC_OUTPUT_HEAD_DIM = ARC_INPUT_FLAT_DIM * NUM_ARC_SYMBOLS
ARC_TASK_ID = 3  # NOTE(review): meaning of the value 3 is not visible in this section - confirm against the model code.
print(f"ARC Output Head Dim: {ARC_OUTPUT_HEAD_DIM}")

# Placeholders; nothing in this section assigns real objects to them.
ctm_model_arc, optimizer_arc, accelerator_arc = None, None, None

print("\n-----------------------------------------------------------------------------")
print("Initializing Configuration for Integrated Diffusion CTM")
print("-----------------------------------------------------------------------------")
print(f"Using device: {device}")
if device == "cuda":
    # NOTE(review): this only prints a message; no mixed-precision setting is
    # changed here - presumably it is configured elsewhere.
    print("✅ Mixed precision training enabled (BF16) - Expected ~2x speedup")

# Report which optional accelerators were detected by the probes above.
print("\n🚀 OPTIMIZATION STATUS:")
print(f"  ⚡ torch.compile: {'✅' if TORCH_COMPILE_AVAILABLE else '❌'}")
print(f"  📈 Accelerate: {'✅' if ACCELERATE_AVAILABLE else '❌'}")
print(f"  ⚡ xFormers: {'✅' if XFORMERS_AVAILABLE else '❌'}")
print(f"  ⚡ Deepspeed: {'✅' if DEEPSPEED_AVAILABLE else '❌'}")

# From contineous_thought_machines/models/constants.py
# Allowed values for the config's `ctm_neuron_select_type`; the config only
# prints a warning when a value is not in this list.
VALID_NEURON_SELECT_TYPES = [
    'first-last', 'random', 'random-pairing',  # Legacy
    # Biologically-inspired types
    'bio_hebbian', 'bio_plasticity', 'bio_competitive', 'bio_homeostatic',
    'bio_evolutionary', 'bio_stdp', 'bio_criticality', 'bio_multi_objective',
    # Hybrid approaches
    'adaptive_random', 'performance_guided', 'task_aware'
]

# Allowed values for the config's `positional_embedding_type`.
VALID_POSITIONAL_EMBEDDING_TYPES = [
    'learnable-fourier', 'multi-learnable-fourier',
    'custom-rotational'
]

# From contineous_thought_machines/models/ctm_Diffusion_NEWNEW.py
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple, Union, Any, List
import math

from contineous_thought_machines.models.ctm_Diffusion_NEWNEW import EnhancedCTMDiffusion

@dataclass
class EnhancedCTMConfig: # Renamed from ContinualLearningConfig for consistency in the target file
    """Enhanced configuration for continual learning CTM-diffusion model,
    incorporating binary processing, multi-task learning, and advanced CTM features."""
    
    # Model architecture (General Transformer/Diffusion settings)
    d_model: int = 512  # Main model dimensionality
    n_heads: int = 8
    n_layers: int = 24
    max_sequence_length: int = 8192 # Max input sequence length in terms of bytes or patches
    dropout: float = 0.1
    
    # --- Byte Processing Options ---
    patch_embedding_dim: int = 256         # <<< NEW: Output embedding dimension per patch from patcher
    patch_encoder_cnn_channels: int = 64   # <<< NEW: Intermediate channels for CNN patch encoder

    # --- Dynamic Entropy Patching Options (Inspired by BLT paper) ---
    use_dynamic_entropy_patcher: bool = True # Flag to enable dynamic entropy-based patching
    entropy_patcher_threshold_type: str = "global"  # 'global' or 'relative_monotonic'
    entropy_patcher_global_threshold: float = 0.75 # Entropy threshold for 'global' type
    entropy_patcher_relative_threshold: float = 0.1 # Entropy diff threshold for 'relative_monotonic'
    entropy_patcher_min_patch_size: int = 4      # Minimum number of bytes in a dynamic patch
    entropy_patcher_max_patch_size: int = 128    # Maximum number of bytes in a dynamic patch (for CNN encoder)
    
    # --- Learnable Entropy Model Parameters (for _EntropyProxyModel) ---
    entropy_model_byte_vocab_size: int = 256
    entropy_model_embedding_dim: int = 64
    entropy_model_hidden_dim: int = 128
    entropy_model_num_layers: int = 1
    entropy_model_dropout: float = 0.1
    entropy_model_loss_weight: float = 0.1 # Weight for its auxiliary loss contribution
    # Note: These parameters are used if use_dynamic_entropy_patcher is True,
    # as LearnedBytePatcherEncoder now instantiates the learnable _EntropyProxyModel.
    
    # Fallback if not using learned_patch_encoder or dynamic_entropy_patcher
    byte_embedding_dim: int = 256
    multi_granularity: bool = False # Default to False if patcher is preferred
    # multi_granularity_output_dim is complex to predefine, MGP should expose its output dim.
    # For now, if multi_granularity is True AND use_learned_patch_encoder is False, this would be used.
    multi_granularity_output_dim: int = 256 # Placeholder if MGP is used.
    
    hierarchical_processing: bool = True # General flag, could apply to patcher or MGP
    
    # CTM Core Parameters (Specific to the OriginalCTMCore module)
    # These are prefixed with 'ctm_' to distinguish from general model params
    ctm_iterations: int = 5  # Original 'iterations'
    ctm_d_model: int = 512   # Original 'd_model' for CTM's internal latent space
    ctm_input_dim: int = 256 # Dimensionality of inputs to CTM (e.g., from byte embeddings or other features)
                             # This was 'd_input' in OriginalCTMCore if it took external features.
                             # If CTM processes outputs of byte_embedding, this might be byte_embedding_dim.
    ctm_heads: int = 8       # Attention heads within CTM
    ctm_n_synch_out: int = 64
    ctm_n_synch_action: int = 64
    ctm_synapse_depth: int = 3
    ctm_memory_length: int = 10
    ctm_deep_nlms: bool = True
    ctm_memory_hidden_dims: int = 2048
    ctm_do_layernorm_nlm: bool = False
    ctm_out_dims: int = 512  # Output dimension of CTM's own projector
    ctm_prediction_reshaper: list = field(default_factory=lambda: [-1])
    ctm_dropout: float = 0.1
    ctm_dropout_nlm: Optional[float] = None
    # Neuron selection strategy. Available options:
    # Legacy: 'first-last', 'random', 'random-pairing'
    # Biologically-inspired: 'bio_hebbian', 'bio_plasticity', 'bio_competitive',
    #                        'bio_homeostatic', 'bio_evolutionary', 'bio_stdp',
    #                        'bio_criticality', 'bio_multi_objective'
    # Hybrid: 'adaptive_random', 'performance_guided', 'task_aware'
    ctm_neuron_select_type: str = 'bio_multi_objective'
    ctm_n_random_pairing_self: int = 0
    
    #Inferred_Latent_Dimensions Set to Avoid runtime errors but it does not functionally do anything in the model or program processing. 
    inferred_task_latent_dim=512

    # Diffusion Parameters
    diffusion_steps: int = 1000
    noise_schedule: str = "cosine" # e.g., "linear", "cosine"
    diffusion_beta_start: float = 0.0001
    diffusion_beta_end: float = 0.02
    diffusion_timesteps: int = 1000 # Number of timesteps for the diffusion process
    ctm_diffusion_coupling_strength: float = 0.8 # How CTM influences diffusion
    adaptive_scheduling: bool = True  # CTM-adaptive diffusion timestep scheduling
    iterative_refinement: bool = True # Iterative CTM-diffusion refinement for sampling
    
    # Training Efficiency
    mixed_precision: bool = True
    gradient_checkpointing: bool = True
    sparse_attention: bool = True  # Now implemented with BinarySparseAttention
    adaptive_depth: bool = False   # Defaulting to False, can be enabled if implemented
    use_activity_plasticity: bool = True # To enable/disable plasticity updates; Needs to be set to TRUE
    ctm_use_internal_feedback: bool = True # Enable self-modulating feedback within the CTM core

    # --- Bidirectional Reasoning Parameters ---
    enable_bidirectional_reasoning: bool = True # Allows CTM to move forward/backward in its thought process
    reasoning_step_gating_threshold: float = 0.7 # Confidence threshold for the reasoning controller to terminate
    max_reasoning_steps: int = 15 # Max total steps in a bidirectional reasoning loop to prevent infinite loops
    
    # Sparse Attention Parameters
    sparse_attention_ratio: float = 0.1  # Keep only 10% of attention connections
    binary_pattern_size: int = 8  # Size of binary patterns to detect

    # Attention Mechanism Type
    attention_type: str = "WINA"  # Options: "standard", "binary_sparse", "WINA" #Need to use WINA attention in place of "standard"
    control_dim: int = 64 # Dimension for WINA control mechanism

    # Positional Embedding Parameters
    positional_embedding_type: Optional[str] = 'multi-learnable-fourier' # e.g., 'custom-rotational', 'learnable-fourier', multi-learnable-fourier' #Can set the value here.
    positional_embedding_dim: Optional[int] = None  # Dimension of the positional embedding, defaults to ctm_input_dim if None
    reshape_patch_sequence_to_grid: bool = True # If True, reshape patch sequence to a 2D grid for 2D PEs. Must set to true if using 2D Grid for Positional Embeddings.
    patch_grid_width: Optional[int] = None       # Desired width of the patch grid if reshaping

    # --- Hierarchical Reasoning Model (HRM) Parameters ---
    use_hrm_core: bool = True # Set to True to use the HierarchicalCTM core
    hrm_high_level_cycles: int = 4 # N: Number of high-level cycles
    hrm_low_level_timesteps: int = 8 # T: Number of low-level timesteps per high-level cycle
    program_vocab_size: int = 1024 # Vocabulary size for the program synthesizer
    program_synth_n_heads: int = 4
    program_synth_n_layers: int = 3
    program_synth_d_ff: int = 1024
    ltm_size: int = 2048 # Size of the long-term memory
    ltm_surprise_threshold: float = 0.6 # Surprise threshold for storing in LTM
    ltm_top_k: int = 5 # Top-k for memory retrieval
    replay_batch_size: int = 4 # Batch size for memory replay
    replay_policy: str = "surprise_weighted_replay" # "simple_replay", "surprise_weighted_replay", "usefulness_replay"

    # Pipeline Parallelism Parameters
    enable_pipeline_parallelism: bool = True
    pipeline_stages: int = 3  # CTM, Diffusion prep, Diffusion exec
    pipeline_overlap_ratio: float = 0.7  # Target overlap ratio
    
    # Adaptive Batch Sizing Parameters
    enable_adaptive_batching: bool = True
    initial_batch_size: int = 32
    min_batch_size: int = 8
    max_batch_size: int = 256
    batch_adaptation_frequency: int = 100
    memory_threshold_high: float = 0.85
    memory_threshold_low: float = 0.6
    
    # Smart Data Sampling Parameters
    enable_smart_sampling: bool = True
    sample_importance_weight: float = 0.6
    sample_diversity_weight: float = 0.4
    initial_sample_ratio: float = 0.3
    complexity_analysis_enabled: bool = True
    
    # Multi-input/output parameters
    num_inputs: int = 1  # Number of input streams
    num_outputs: int = 1  # Number of output heads
    output_dims: List[int] = field(default_factory=lambda: [64])  # Dimensions for each output head
    
    # Self-supervised learning
    ssl_dim: int = 128  # Dimension for self-supervised projection
    ssl_weight: float = 0.1  # Weight for self-supervised loss
    ssl_temperature: float = 0.07  # Temperature for contrastive loss
    ssl_noise_std: float = 0.1  # Noise standard deviation for contrastive augmentation
    
    # Spatiotemporal Processing
    use_spatial: bool = True  # Enable spatial processing for image/video data
    
    # WINA Attention
    use_wina_attention: bool = True  # Enable WINA sparse attention
    
    # Multi-task Learning Parameters
    max_tasks: int = 50  # Maximum number of tasks for continual learning
    # Added to resolve TypeError for unexpected keyword arguments
    vocab_size: Optional[int] = None
    output_audio_bytes: bool = False
    inferred_task_latent_dim: Optional[int] = None # Default to None, __post_init__ handles it
    use_hipa_attention: bool = False # Default to False
    hipa_num_heads: Optional[int] = None # Default to None
    audio_output_dtype_str: Optional[str] = "float32" # Default as per __post_init__ logic
    unet_input_feature_dim: Optional[int] = None # Default to None, __post_init__ calculates it

    # --- JEPA Training Parameters (Integrated with LearnedBytePatcherEncoder) ---
    use_jepa_training: bool = False
    # jepa_embed_dim will be derived from patch_embedding_dim if dynamic_entropy_patcher is used
    jepa_predictor_hidden_dim: int = 512 # Hidden dimension of JEPA predictor MLP
    jepa_mask_ratio_min: float = 0.15 # Min proportion of patch sequence to mask for target
    jepa_mask_ratio_max: float = 0.75 # Max proportion of patch sequence to mask for target
    jepa_context_scale_min: float = 0.3 # Min proportion of patches for context
    jepa_context_scale_max: float = 0.7 # Max proportion of patches for context
    jepa_momentum_beta: float = 0.996 # Momentum for target encoder update
    jepa_loss_weight: float = 0.1 # Weight for the JEPA loss component
    jepa_num_target_blocks: int = 1 # Number of target blocks to predict

    # --- Global Plasticity Loss Parameters ---
    local_hebbian_loss_weight: float = 0.01 # New weight for backprop-based hebbian loss

    # --- Basal Ganglia Parameters --- #Controls action suppression so that the model's unwanted first unrelated thoughts are suppressed which helps with model safety. Is needed for action suppresion.
    ctm_enable_basal_ganglia: bool = True
    ctm_bg_dopamine_dim: int = 32

    # --- Synaptic Empathy Parameters ---
    enable_synaptic_empathy: bool = True # Set to True to use the new SynapticEmpathy module
    synaptic_empathy_reward_weight: float = 0.1

    # --- Mirror Neuron / High-Level Empathy Parameters ---
    enable_mirror_neurons: bool = True # Set to True to use the high-level MirrorNeuronLayer
    num_emotion_dim: int = 4 # Dimensionality of the emotion state vector
    goal_dim: int = 8 # Dimensionality of the predicted goal vector
    mirror_reward_weight: float = 0.2 # Weight for the selfless reward signal


    # --- Confidence Thresholding Parameters ---
    confidence_threshold: float = 0.0 # Confidence threshold for abstaining. If > 0, model can abstain.
 
    # --- Consciousness Controller Parameters ---
    enable_consciousness_controller: bool = True
    consciousness_max_attention_steps: int = 100

    # --- Recursion Parameters ---
    max_recursion: int = 3
    early_stop_threshold: float = 1e-3
    
    def __post_init__(self):
        """Validate the configuration and fill in derived defaults.

        The checks run in a fixed order; the first failing hard check raises,
        while softer inconsistencies only print a warning.

        Side effects on ``self``:
            * ``ctm_dropout_nlm`` falls back to ``ctm_dropout`` when it is None.
            * ``inferred_task_latent_dim`` falls back to 512 when missing/None.
            * ``audio_output_item_size`` is derived from ``audio_output_dtype_str``.
            * ``unet_input_feature_dim`` is computed from ``max_sequence_length``
              and ``audio_output_item_size`` when it is None.

        Raises:
            ValueError: if any of the hard validations below fails.
        """
        # Validate output dimensions
        if len(self.output_dims) != self.num_outputs:
            raise ValueError(f"output_dims length ({len(self.output_dims)}) must match num_outputs ({self.num_outputs})")

        # Merged content from the second __post_init__
        # NOTE(review): this branch is a no-op placeholder; nothing happens for
        # the ctm_prediction_reshaper == [-1] / vocab_size combination.
        if hasattr(self, 'ctm_prediction_reshaper') and self.ctm_prediction_reshaper == [-1] and self.vocab_size is not None:
            pass
        # Default the NLM dropout to the general CTM dropout when unset.
        if hasattr(self, 'ctm_dropout_nlm') and self.ctm_dropout_nlm is None and hasattr(self, 'ctm_dropout'):
            self.ctm_dropout_nlm = self.ctm_dropout
        
        # Unknown neuron-select types only produce a warning, not an error.
        if hasattr(self, 'ctm_neuron_select_type') and \
           VALID_NEURON_SELECT_TYPES is not None and self.ctm_neuron_select_type not in VALID_NEURON_SELECT_TYPES:
            print(f"Warning: ctm_neuron_select_type '{self.ctm_neuron_select_type}' is not in VALID_NEURON_SELECT_TYPES ({VALID_NEURON_SELECT_TYPES}).")

        # Positional-embedding checks only run when a type is configured; note
        # that the grid-reshape checks below are nested inside this condition.
        if hasattr(self, 'positional_embedding_type') and self.positional_embedding_type is not None:
            if VALID_POSITIONAL_EMBEDDING_TYPES is None: # Fallback if import failed
                print(f"Warning: VALID_POSITIONAL_EMBEDDING_TYPES not available for validation.")
            elif self.positional_embedding_type not in VALID_POSITIONAL_EMBEDDING_TYPES:
                print(f"Warning: positional_embedding_type '{self.positional_embedding_type}' is not in VALID_POSITIONAL_EMBEDDING_TYPES ({VALID_POSITIONAL_EMBEDDING_TYPES}).")
            if self.positional_embedding_dim is not None and self.positional_embedding_dim <= 0:
                raise ValueError("positional_embedding_dim must be positive if set.")
            
            # Reshaping the patch sequence to a grid needs a valid grid width;
            # a PE type outside the listed ones only triggers a warning.
            if self.reshape_patch_sequence_to_grid:
                if self.patch_grid_width is None or self.patch_grid_width <= 0:
                    raise ValueError("patch_grid_width must be a positive integer if reshape_patch_sequence_to_grid is True.")
                if self.positional_embedding_type not in ['learnable-fourier', 'multi-learnable-fourier', 'custom-rotational']:
                    print(f"Warning: reshape_patch_sequence_to_grid is True, but positional_embedding_type ('{self.positional_embedding_type}') is not a typical 2D PE. Ensure compatibility.")

        # Validations for new patch encoder
        if self.use_dynamic_entropy_patcher:
            if self.patch_embedding_dim <= 0:
                raise ValueError("patch_embedding_dim must be positive if use_dynamic_entropy_patcher is True.")
            if self.entropy_patcher_min_patch_size <= 0:
                raise ValueError("entropy_patcher_min_patch_size must be positive.")
            if self.entropy_patcher_max_patch_size < self.entropy_patcher_min_patch_size:
                raise ValueError("entropy_patcher_max_patch_size must be >= entropy_patcher_min_patch_size.")
            if self.entropy_patcher_threshold_type not in ["global", "relative_monotonic"]:
                raise ValueError("entropy_patcher_threshold_type must be 'global' or 'relative_monotonic'.")
        elif self.multi_granularity and self.multi_granularity_output_dim <= 0:
            # Only reached when the entropy patcher is disabled.
            print("Warning: multi_granularity_output_dim might not be correctly set for validation if not using a patcher and MGP is active.")
        
        # Missing/None latent dim is repaired with a default instead of failing.
        if not hasattr(self, 'inferred_task_latent_dim') or self.inferred_task_latent_dim is None:
            print("Warning: inferred_task_latent_dim not found or is None in config, defaulting to 512.")
            self.inferred_task_latent_dim = 512
        elif self.inferred_task_latent_dim <= 0: # This check is now safe
            raise ValueError("inferred_task_latent_dim must be positive.")
 
        if hasattr(self, 'use_hipa_attention') and self.use_hipa_attention and \
            (not hasattr(self, 'hipa_num_heads') or self.hipa_num_heads <= 0):
             raise ValueError("hipa_num_heads must be positive if use_hipa_attention is True.")
 
        # Derive the per-sample byte width of the audio output:
        # "float32" -> 4, "int16" -> 2. Any other dtype string is an error when
        # output_audio_bytes is set, otherwise the width falls back to 4.
        if hasattr(self, 'audio_output_dtype_str'):
            if self.audio_output_dtype_str == "float32":
                self.audio_output_item_size = 4
            elif self.audio_output_dtype_str == "int16":
                self.audio_output_item_size = 2
            else:
                if hasattr(self, 'output_audio_bytes') and self.output_audio_bytes:
                    raise ValueError(f"Unsupported audio_output_dtype_str: {self.audio_output_dtype_str} when output_audio_bytes is True.")
                else:
                    self.audio_output_item_size = 4
        elif hasattr(self, 'output_audio_bytes') and self.output_audio_bytes:
            # This branch is only reached when audio_output_dtype_str is absent,
            # so the inner condition is always true and the branch always raises.
            if not hasattr(self, 'audio_output_dtype_str') or self.audio_output_dtype_str is None:
                raise ValueError("audio_output_dtype_str must be defined in config if output_audio_bytes is True.")
        else:
            self.audio_output_item_size = 4

        # Calculate unet_input_feature_dim if not set
        # (integer division of the byte sequence length by the item size derived above).
        if self.unet_input_feature_dim is None:
            if self.max_sequence_length <= 0 or self.audio_output_item_size <= 0:
                raise ValueError("max_sequence_length and audio_output_item_size must be positive to calculate unet_input_feature_dim.")
            self.unet_input_feature_dim = self.max_sequence_length // self.audio_output_item_size
            if self.unet_input_feature_dim <= 0:
                raise ValueError(f"Calculated unet_input_feature_dim ({self.unet_input_feature_dim}) must be positive. Check max_sequence_length and audio_output_item_size.")
        elif self.unet_input_feature_dim <= 0:
            raise ValueError("unet_input_feature_dim, if set, must be positive.")

        # JEPA settings are only validated when JEPA training is enabled.
        if self.use_jepa_training:
            # Both ratios must lie strictly inside (0, 1) with min <= max.
            if not (0 < self.jepa_mask_ratio_min < 1 and 0 < self.jepa_mask_ratio_max < 1 and self.jepa_mask_ratio_min <= self.jepa_mask_ratio_max):
                raise ValueError("JEPA mask ratios must be between 0 and 1, with min <= max.")
            if not (0 < self.jepa_context_scale_min < 1 and 0 < self.jepa_context_scale_max < 1 and self.jepa_context_scale_min <= self.jepa_context_scale_max):
                raise ValueError("JEPA context scales must be between 0 and 1, with min <= max.")
            # Momentum accepts 0 but not 1 (half-open interval [0, 1)).
            if not (0 <= self.jepa_momentum_beta < 1):
                raise ValueError("jepa_momentum_beta must be between 0 and 1.")
            if self.jepa_num_target_blocks <= 0:
                raise ValueError("jepa_num_target_blocks must be positive.")
            if not self.use_dynamic_entropy_patcher:
                print("Warning: JEPA training is enabled but use_dynamic_entropy_patcher is False. JEPA relies on the patch embeddings from LearnedBytePatcherEncoder.")
        

        # Validations for recursion parameters
        if self.max_recursion < 1:
            raise ValueError("max_recursion must be at least 1")
        if self.early_stop_threshold <= 0:
            raise ValueError("early_stop_threshold must be positive.")

# --- Model Configuration ---
# Single EnhancedCTMConfig instance used by the ARC model, both training loops
# and the diffusion timestep sampling (config_arc_diffusion.diffusion_steps).
# NOTE(review): the recorded cell output further down shows this call failing
# with "unexpected keyword argument 'enable_consciousness_controller'" — confirm
# that the EnhancedCTMConfig in scope is the version that declares that field.
config_arc_diffusion = EnhancedCTMConfig(
    # Core model size
    d_model=512,
    n_heads=8,
    n_layers=24,
    max_sequence_length=MAX_SEQUENCE_LENGTH,
    dropout=0.1,
    # Dynamic entropy patcher / patch encoder
    use_dynamic_entropy_patcher=True,
    patch_embedding_dim=256,
    patch_grid_width=16,
    patch_encoder_cnn_channels=64,
    entropy_patcher_threshold_type="global",
    entropy_patcher_global_threshold=0.75,
    entropy_patcher_relative_threshold=0.1,
    entropy_patcher_min_patch_size=4,
    entropy_patcher_max_patch_size=128,
    # Entropy model
    entropy_model_byte_vocab_size=256,
    entropy_model_embedding_dim=64,
    entropy_model_hidden_dim=128,
    entropy_model_num_layers=1,
    entropy_model_dropout=0.1,
    entropy_model_loss_weight=0.1,
    # CTM core (ctm_input_dim equals patch_embedding_dim above)
    ctm_input_dim=256,
    ctm_d_model=512,
    ctm_iterations=5,
    ctm_heads=8,
    ctm_out_dims=512,
    ctm_neuron_select_type='bio_multi_objective',
    # Positional embedding; reshape_patch_sequence_to_grid requires patch_grid_width > 0
    positional_embedding_type='multi-learnable-fourier',
    positional_embedding_dim=None,
    reshape_patch_sequence_to_grid=True,
    # Pipeline parallelism
    enable_pipeline_parallelism=True,
    pipeline_stages=4,
    pipeline_overlap_ratio=0.7,
    # Adaptive batching
    enable_adaptive_batching=True,
    initial_batch_size=32,
    min_batch_size=8,
    max_batch_size=256,
    batch_adaptation_frequency=100,
    memory_threshold_high=0.85,
    memory_threshold_low=0.6,
    # Smart sampling
    enable_smart_sampling=True,
    sample_importance_weight=0.6,
    sample_diversity_weight=0.4,
    initial_sample_ratio=0.3,
    complexity_analysis_enabled=True,
    # Inputs / outputs (len(output_dims) must equal num_outputs)
    num_inputs=1,
    num_outputs=1,
    output_dims=[64],
    # Self-supervised learning
    ssl_dim=128,
    ssl_weight=0.1,
    ssl_temperature=0.07,
    ssl_noise_std=0.1,
    use_spatial=False,
    use_wina_attention=True,
    max_tasks=50,
    # Diffusion
    diffusion_steps=1000,
    ctm_diffusion_coupling_strength=0.8,
    vocab_size=None,
    output_audio_bytes=True,
    unet_input_feature_dim=MAX_SEQUENCE_LENGTH // 4, # Calculated based on float32 audio (4 bytes per sample)
    local_hebbian_loss_weight=0.01,
    # Consciousness controller
    enable_consciousness_controller=True,
    consciousness_max_attention_steps=100,
    use_hrm_core=True,
    attention_type="WINA",
    inferred_task_latent_dim=512 # Author's note: not used by model training; kept as a placeholder to avoid possible errors when initializing Torch for training.
)
print("✓ EnhancedCTMConfig for ARC (config_arc_diffusion) created.")
    
# Pre-declare the training handles. Later code tests these names for
# truthiness (`if accelerator_arc:`, `all([ctm_model_arc, optimizer_arc, ...])`),
# so they must exist even when the model import failed or Accelerate is not
# installed; otherwise those checks raise NameError instead of skipping cleanly.
ctm_model_arc = None
optimizer_arc = None
accelerator_arc = None

if 'EnhancedCTMDiffusion' in globals() and EnhancedCTMDiffusion is not None:
    ctm_model_arc = EnhancedCTMDiffusion(config=config_arc_diffusion).to(device)
    print("✓ EnhancedCTMDiffusion model for ARC (ctm_model_arc) initialized.")

    # The new EnhancedCTMDiffusion model is end-to-end and does not require an external output head.
    print("✓ ARC Output Head is disabled as it's not needed for the new model.")

    # MCMC integration is disabled as per new model requirements.

    arc_trainable_params = list(ctm_model_arc.parameters()) # EnhancedCTMDiffusion parameters

    # Only parameters that require gradients are handed to the optimizer.
    optimizer_arc = optim.AdamW([p for p in arc_trainable_params if p.requires_grad], lr=LEARNING_RATE, weight_decay=1e-4)

    if ACCELERATE_AVAILABLE:
        accelerator_arc = Accelerator()
        # Only the main model and optimizer need to be prepared.
        ctm_model_arc, optimizer_arc = accelerator_arc.prepare(ctm_model_arc, optimizer_arc)
        print("✓ ARC models (EnhancedCTMDiffusion) and optimizer prepared with Accelerate.")
else:
    print("⚠️ EnhancedCTMDiffusion model or its config for ARC-AGI-2 could not be initialized. Check imports.")

# Checkpoints for this run go into their own sub-directory of CHECKPOINT_DIR.
CHECKPOINT_DIR_ARC = os.path.join(CHECKPOINT_DIR, "ctm_arc_agi_2_enhanced_diffusion")
os.makedirs(CHECKPOINT_DIR_ARC, exist_ok=True)
print(f"ARC Checkpoints will be saved to: {CHECKPOINT_DIR_ARC}")

NUM_EPOCHS_ARC = 20
ARC_BATCH_SIZE = 16

arc_train_dataset = NewCustomARCGridDataset(ARC_TRAIN_DIR)
arc_eval_dataset = NewCustomARCGridDataset(ARC_EVAL_DIR)

# Loaders stay None when the matching dataset is empty.
arc_train_loader = None
arc_eval_loader = None

# Training loader: shuffled, ARC_BATCH_SIZE tasks per batch.
if not arc_train_dataset or len(arc_train_dataset) <= 0:
    print("⚠️ ARC Training DataLoader could not be initialized.")
else:
    arc_train_loader = DataLoader(
        arc_train_dataset,
        batch_size=ARC_BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn_new_custom_arc,
        **OPTIMIZED_DATALOADER_CONFIG
    )
    if accelerator_arc:
        arc_train_loader = accelerator_arc.prepare(arc_train_loader)
    print(f"✓ ARC Training DataLoader initialized with {len(arc_train_dataset)} tasks.")

# Evaluation loader: one task per batch, fixed order.
if not arc_eval_dataset or len(arc_eval_dataset) <= 0:
    print("⚠️ ARC Evaluation DataLoader could not be initialized.")
else:
    arc_eval_loader = DataLoader(
        arc_eval_dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=collate_fn_new_custom_arc,
        **OPTIMIZED_DATALOADER_CONFIG
    )
    if accelerator_arc:
        arc_eval_loader = accelerator_arc.prepare(arc_eval_loader)
    print(f"✓ ARC Evaluation DataLoader initialized with {len(arc_eval_dataset)} tasks.")

# No separate cross-entropy criterion is created: the model reports its own loss.
print("\n✓ ARC-AGI-2 Setup Complete.")

ARC Output Head Dim: 9000

-----------------------------------------------------------------------------
Initializing Configuration and Model for ARC with EnhancedCTMDiffusion
-----------------------------------------------------------------------------


TypeError: EnhancedCTMConfig.__init__() got an unexpected keyword argument 'enable_consciousness_controller'

# Training ARC-AGI-2 and Principles Phase

In [None]:
# --- ARC-AGI-2 Meta-Learning Training Loop ---
import os
import torch
import torch.distributed as dist
from safetensors.torch import save_file
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from torch.cuda.amp import autocast, GradScaler
import glob
import json
from dataclasses import dataclass, field
from typing import List, Optional, Any
import math

# Diagnose CUDA errors with synchronous kernel launches.
# A bare Python variable named CUDA_LAUNCH_BLOCKING is never seen by the CUDA
# runtime; the setting only works as an environment variable. setdefault keeps
# any value the user already exported.
# NOTE(review): the variable is presumably read when CUDA is initialised, so for
# it to take effect this must run before the first CUDA call in the process
# (ideally in the very first cell) — confirm and move it there if needed.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
CUDA_LAUNCH_BLOCKING = 1  # kept so existing references to the module-level name still work
# --- FIX: Define NUM_ARC_SYMBOLS globally for DataLoader workers ---
# The standard ARC task has 10 symbols (0-9).
NUM_ARC_SYMBOLS = 10
import numpy as np

print("\n" + "="*60)
print(f"🚀 STARTING PHASE 4: ARC-AGI-2 Meta-Learning Training")
print(f"   Epochs: {NUM_EPOCHS_ARC}, Batch Size: {ARC_BATCH_SIZE}, Task ID: {ARC_TASK_ID}")
print(f"   Device: {device if not accelerator_arc else accelerator_arc.device}")
print("="*60 + "\n")

# --- Principles Training Configuration ---
NUM_EPOCHS_PRINCIPLES = 3 #Can be lowered due to new DPPM++ Solver converging 10 epoch sooner.

# --- Training Configuration ---
USE_MIXED_PRECISION = torch.cuda.is_available()
# Only query bf16 support when CUDA is actually present: on CPU-only hosts the
# query can raise on some PyTorch versions, and float16 is the fallback anyway.
autocast_dtype = torch.bfloat16 if (USE_MIXED_PRECISION and torch.cuda.is_bf16_supported()) else torch.float16
GRADIENT_ACCUMULATION_STEPS = 2
MAX_GRAD_NORM = 1.0
# The scaler is a pass-through when mixed precision is disabled.
scaler = torch.amp.GradScaler('cuda',enabled=USE_MIXED_PRECISION)

# --- Context: 2D Grid Padding (from original code) ---
def pad_grid(grid_list, max_dims, pad_value):
    """Pad a 2D grid to ``max_dims`` with ``pad_value``.

    Args:
        grid_list: Nested list (or array-like) of integers describing the grid.
            An empty grid (``[]`` or ``[[]]``) is accepted.
        max_dims: ``(rows, cols)`` shape of the padded result.
        pad_value: Integer written into every cell the grid does not cover.

    Returns:
        ``np.ndarray`` of dtype ``int32`` and shape ``max_dims`` with the grid
        in the top-left corner. An empty grid yields an all-padding array.

    Raises:
        ValueError: if the grid is non-empty but not 2D, or does not fit
            inside ``max_dims``.
    """
    padded_grid = np.full(max_dims, pad_value, dtype=np.int32)
    grid_np = np.array(grid_list, dtype=np.int32)

    # A missing grid (e.g. a pair without 'output') arrives as [] and has a
    # 1-D shape of (0,); there is nothing to copy, so return pure padding
    # instead of failing on the shape unpacking below.
    if grid_np.size == 0:
        return padded_grid

    if grid_np.ndim != 2:
        raise ValueError(f"pad_grid expects a 2D grid, got an array with shape {grid_np.shape}.")

    h, w = grid_np.shape
    if h > padded_grid.shape[0] or w > padded_grid.shape[1]:
        # Fail with an explicit message rather than an opaque broadcast error.
        raise ValueError(f"Grid of shape {(h, w)} does not fit into max_dims {tuple(padded_grid.shape)}.")

    padded_grid[:h, :w] = grid_np
    return padded_grid

# --- Fix: Byte Sequence Padding for the Model --- #
MAX_SEQUENCE_LENGTH = 8192
PADDING_BYTE_VALUE = 0


def serialize_and_pad_grid(grid, max_len=MAX_SEQUENCE_LENGTH, pad_value=PADDING_BYTE_VALUE):
    """Flatten ``grid`` to uint8 bytes and force the result to exactly ``max_len`` bytes.

    Sequences longer than ``max_len`` are cut off at the end; shorter ones are
    right-padded with ``pad_value``.
    """
    raw = np.array(grid, dtype=np.uint8).ravel().tobytes()
    shortfall = max_len - len(raw)

    if shortfall >= 0:
        # Room to spare (or an exact fit): append the filler bytes.
        return raw + bytes([pad_value] * shortfall)

    # Too long: keep only the leading max_len bytes.
    return raw[:max_len]

from contineous_thought_machines.models.ctm_Diffusion_NEWNEW import batched_numeric_tensor_to_bytes

class PrinciplesDataset(Dataset):
    """Dataset of text principles, one per non-blank line of a UTF-8 file.

    Each item is a uint8 tensor of exactly ``max_len`` elements built from the
    principle's UTF-8 bytes, an 8-byte separator and the byte form of an
    all-zero audio template, then truncated or padded with ``pad_value``.
    A file that cannot be read only produces a warning; the dataset then holds
    whatever lines were read before the failure (possibly none).
    """

    def __init__(self, file_path, max_len=MAX_SEQUENCE_LENGTH, pad_value=PADDING_BYTE_VALUE, audio_duration_seconds=2.0, sample_rate=16000):
        self.max_len = max_len
        self.pad_value = pad_value
        self.audio_duration_seconds = audio_duration_seconds
        self.sample_rate = sample_rate
        self.principles = []
        try:
            with open(file_path, 'r', encoding='utf-8') as handle:
                for raw_line in handle:
                    stripped = raw_line.strip()
                    if not stripped:
                        # Blank lines carry no principle.
                        continue
                    self.principles.append(stripped)
            print(f"PrinciplesDataset: Loaded {len(self.principles)} principles from {file_path}.")
        except Exception as e:
            print(f"PrinciplesDataset Warning: Could not load or parse {file_path}: {e}")

    def __len__(self):
        return len(self.principles)

    def __getitem__(self, idx):
        # Text part: the UTF-8 bytes of the selected principle.
        encoded_text = self.principles[idx].encode('utf-8')
        text_part = torch.tensor(list(encoded_text), dtype=torch.uint8)

        # Audio part: an all-zero template (batch of 1) built on the CPU and
        # converted to bytes by the project helper.
        sample_count = int(self.audio_duration_seconds * self.sample_rate)
        silent_audio = torch.zeros(1, sample_count)
        audio_part = batched_numeric_tensor_to_bytes(silent_audio, source_dtype=np.float32).squeeze(0)

        # Fixed marker placed between the text bytes and the audio bytes.
        marker = torch.tensor([255, 0, 255, 0, 255, 0, 255, 0], dtype=torch.uint8)
        joined = torch.cat([text_part, marker, audio_part])

        # Force the sequence to exactly max_len elements.
        shortfall = self.max_len - len(joined)
        if shortfall >= 0:
            filler = torch.full((shortfall,), self.pad_value, dtype=torch.uint8)
            return torch.cat([joined, filler])
        return joined[:self.max_len]

def collate_fn_principles(batch):
    """Stack the per-item tensors of ``batch`` into one batch tensor.

    Returns a dict with the single key ``'input_byte_sequences'``.
    """
    stacked = torch.stack(batch)
    return dict(input_byte_sequences=stacked)

class NewCustomARCGridDataset(Dataset):
    """ARC tasks loaded from the ``*.json`` files of a directory.

    Every item is a dict with the keys ``'train'``, ``'test'`` and ``'id'``.
    The two lists hold one dict per input/output pair with the padded grids as
    long tensors (``'input'``, ``'output'``) and the unpadded ``(rows, cols)``
    sizes (``'original_input_dims'``, ``'original_output_dims'``).

    Attributes:
        task_files: All JSON paths found in ``data_dir`` (sorted).
        tasks: Parsed content of the files that loaded successfully.
        task_ids: File base names, index-aligned with ``tasks``.
    """

    def __init__(self, data_dir, max_grid_size=MAX_GRID_SIZE, padding_value=PADDING_VALUE):
        self.data_dir = data_dir
        # glob returns files in arbitrary order; sorting makes the
        # index -> task mapping reproducible across runs and machines.
        self.task_files = sorted(glob.glob(os.path.join(data_dir, "*.json")))
        self.max_grid_size = max_grid_size
        self.padding_value = padding_value
        self.tasks = []
        # Kept in lock-step with self.tasks. Indexing self.task_files by the
        # task index would give the wrong name as soon as one file fails to load.
        self.task_ids = []
        print(f"NewCustomARCGridDataset: Looking for tasks in: {data_dir}")
        if not self.task_files:
            print(f"NewCustomARCGridDataset Warning: No JSON files found in {data_dir}. Dataset will be empty.")
        for task_file in self.task_files:
            try:
                with open(task_file, 'r', encoding='utf-8') as f:
                    task = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers malformed JSON and undecodable bytes; a bad
                # file is skipped so the remaining tasks are still usable.
                print(f"NewCustomARCGridDataset Warning: Could not load or parse {task_file}: {e}")
                continue
            self.tasks.append(task)
            self.task_ids.append(os.path.basename(task_file))
        if not self.tasks:
            print(f"NewCustomARCGridDataset Warning: No tasks successfully loaded from {data_dir}.")
        else:
            print(f"NewCustomARCGridDataset: Loaded {len(self.tasks)} ARC tasks from {data_dir}.")

    def __len__(self):
        return len(self.tasks)

    @staticmethod
    def _grid_dims(grid):
        """Return ``(rows, cols)`` of a nested-list grid, ``(0, 0)`` if it is empty."""
        if not grid or not grid[0]:
            return (0, 0)
        return (len(grid), len(grid[0]))

    def __getitem__(self, idx):
        task_data = self.tasks[idx]
        task_id = self.task_ids[idx] if idx < len(self.task_ids) else 'unknown_task'
        processed_task = {'train': [], 'test': [], 'id': task_id}

        for pair_type in ['train', 'test']:
            for item in task_data.get(pair_type, []):
                input_grid_list = item.get('input', [])
                output_grid_list = item.get('output', [])

                # Sizes before padding; always a flat (rows, cols) tuple.
                original_input_dims = self._grid_dims(input_grid_list)
                original_output_dims = self._grid_dims(output_grid_list)

                padded_input_np = pad_grid(input_grid_list, self.max_grid_size, self.padding_value)
                padded_output_np = pad_grid(output_grid_list, self.max_grid_size, self.padding_value)

                processed_task[pair_type].append({
                    'input': torch.from_numpy(padded_input_np).long(),
                    'output': torch.from_numpy(padded_output_np).long(),
                    'original_input_dims': original_input_dims,
                    'original_output_dims': original_output_dims
                })
        return processed_task

def collate_fn_new_custom_arc(batch_of_tasks):
    """Turn a batch of ARC task dicts into stacked byte-sequence tensors.

    Only the ``'train'`` pairs of each task are used. Entries that are not
    dicts, or pairs without both ``'input'`` and ``'output'``, are ignored.

    Returns a dict with:
        'input_byte_sequences': uint8 tensor, one serialized input grid per row.
        'target_byte_sequences_for_diffusion': uint8 tensor, serialized outputs.
        'original_target_grids_for_ce_loss': long tensor of flattened output
            grids, clamped to ``[0, NUM_ARC_SYMBOLS - 1]``.
    When no usable pair exists, all three tensors have zero rows.
    """
    def _grid_to_byte_tensor(grid_tensor):
        # Serialize one grid tensor to a fixed-length uint8 tensor.
        raw = serialize_and_pad_grid(grid_tensor.numpy(), max_len=MAX_SEQUENCE_LENGTH, pad_value=PADDING_BYTE_VALUE)
        return torch.tensor(list(raw), dtype=torch.uint8)

    inputs, diffusion_targets, flat_targets = [], [], []

    for task in batch_of_tasks:
        if not isinstance(task, dict):
            continue

        for pair in task.get('train', []):
            usable = isinstance(pair, dict) and 'input' in pair and 'output' in pair
            if not usable:
                continue

            inputs.append(_grid_to_byte_tensor(pair['input']))
            diffusion_targets.append(_grid_to_byte_tensor(pair['output']))
            flat_targets.append(pair['output'].view(-1))

    if not inputs:
        # Nothing usable in this batch: hand back correctly-typed empty tensors.
        return {
            'input_byte_sequences': torch.empty(0, MAX_SEQUENCE_LENGTH, dtype=torch.uint8),
            'target_byte_sequences_for_diffusion': torch.empty(0, MAX_SEQUENCE_LENGTH, dtype=torch.uint8),
            'original_target_grids_for_ce_loss': torch.empty(0, ARC_INPUT_FLAT_DIM, dtype=torch.long),
        }

    stacked_inputs = torch.stack(inputs)
    stacked_diffusion_targets = torch.stack(diffusion_targets)
    stacked_flat_targets = torch.stack(flat_targets)

    # Class indices outside [0, NUM_ARC_SYMBOLS - 1] (for example a padding
    # value) would trip a device-side assert in a cross-entropy loss, so the
    # targets are clamped into the valid range.
    stacked_flat_targets.clamp_(min=0, max=NUM_ARC_SYMBOLS - 1)

    return {
        'input_byte_sequences': stacked_inputs,
        'target_byte_sequences_for_diffusion': stacked_diffusion_targets,
        'original_target_grids_for_ce_loss': stacked_flat_targets,
    }

# Build the ARC datasets and wrap the non-empty ones in DataLoaders.
arc_train_dataset = NewCustomARCGridDataset(ARC_TRAIN_DIR)
arc_eval_dataset = NewCustomARCGridDataset(ARC_EVAL_DIR)

# Loaders stay None when the matching dataset is empty.
arc_train_loader = None
arc_eval_loader = None

# Training loader: shuffled, ARC_BATCH_SIZE tasks per batch.
if not arc_train_dataset or len(arc_train_dataset) <= 0:
    print("⚠️ ARC Training DataLoader could not be initialized.")
else:
    arc_train_loader = DataLoader(
        arc_train_dataset,
        batch_size=ARC_BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn_new_custom_arc,
        **OPTIMIZED_DATALOADER_CONFIG
    )
    if accelerator_arc:
        arc_train_loader = accelerator_arc.prepare(arc_train_loader)
    print(f"✓ ARC Training DataLoader initialized with {len(arc_train_dataset)} tasks.")

# Evaluation loader: one task per batch, fixed order.
if not arc_eval_dataset or len(arc_eval_dataset) <= 0:
    print("⚠️ ARC Evaluation DataLoader could not be initialized.")
else:
    arc_eval_loader = DataLoader(
        arc_eval_dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=collate_fn_new_custom_arc,
        **OPTIMIZED_DATALOADER_CONFIG
    )
    if accelerator_arc:
        arc_eval_loader = accelerator_arc.prepare(arc_eval_loader)
    print(f"✓ ARC Evaluation DataLoader initialized with {len(arc_eval_dataset)} tasks.")


# --- Principles Dataset and DataLoader ---
# NOTE(review): this path uses hyphens while the package import in this file
# uses underscores ("contineous_thought_machines") — verify the directory name.
PRINCIPLES_FILE_PATH = "contineous-thought-machines/models/Principles/principles.txt"
principles_dataset = PrinciplesDataset(PRINCIPLES_FILE_PATH)

# Stays None when no principle could be loaded.
principles_loader = None

if not principles_dataset or len(principles_dataset) <= 0:
    print("⚠️ Principles DataLoader could not be initialized.")
else:
    principles_loader = DataLoader(
        principles_dataset,
        batch_size=ARC_BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn_principles,
        **OPTIMIZED_DATALOADER_CONFIG
    )
    if accelerator_arc:
        principles_loader = accelerator_arc.prepare(principles_loader)
    print(f"✓ Principles DataLoader initialized with {len(principles_dataset)} principles.")


# === DEBUG + RANK CHECK ===
def get_rank_debug():
    """Print and return ``(rank, world_size)`` of the current process.

    Falls back to ``(0, 1)`` when torch.distributed is unavailable or has not
    been initialized.
    """
    distributed_active = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if distributed_active else 0
    world_size = dist.get_world_size() if distributed_active else 1

    print(f"[DEBUG] Rank {rank} out of {world_size} total ranks")
    return rank, world_size


# ARC-AGI-2 training: per epoch run the training batches with gradient
# accumulation, evaluate on the eval loader (if any) and checkpoint every
# fifth epoch.
if not all([ctm_model_arc, optimizer_arc, arc_train_loader]):
    print("⚠️ Skipping ARC-AGI-2 training due to missing components.")
else:
    print("✓ All components ready for ARC training.")

    arc_device = accelerator_arc.device if accelerator_arc else device
    num_train_batches = len(arc_train_loader)

    for epoch in range(NUM_EPOCHS_ARC):
        ctm_model_arc.train()
        if hasattr(ctm_model_arc, 'wake_up'):
            ctm_model_arc.wake_up()

        total_epoch_loss = 0

        # Gradients are cleared once per epoch and after every optimizer step.
        # Clearing them at the top of every batch would throw away the
        # gradients collected so far and defeat gradient accumulation.
        optimizer_arc.zero_grad()

        progress_bar = tqdm(enumerate(arc_train_loader), total=num_train_batches, desc=f"ARC Epoch {epoch + 1}")

        for batch_idx, batch_data in progress_bar:
            if not batch_data or batch_data['input_byte_sequences'].numel() == 0:
                print(f"Skipping empty batch {batch_idx}")
                continue

            input_bytes = batch_data['input_byte_sequences'].to(arc_device)
            target_bytes_for_diffusion = batch_data['target_byte_sequences_for_diffusion'].to(arc_device)

            current_batch_size = input_bytes.size(0)

            autocast_context = accelerator_arc.autocast() if accelerator_arc else autocast(enabled=USE_MIXED_PRECISION, dtype=autocast_dtype)

            with autocast_context:
                model_output_dict = ctm_model_arc(
                    byte_sequence=input_bytes,
                    target_diffusion_output=target_bytes_for_diffusion,
                    mode='ctm_controlled_diffusion',
                    timestep=torch.randint(0, config_arc_diffusion.diffusion_steps, (current_batch_size,), device=input_bytes.device).long()
                )

                total_loss = model_output_dict.get('total_loss', torch.tensor(0.0, device=input_bytes.device))

            if torch.isnan(total_loss) or torch.isinf(total_loss):
                print(f"[NaN or Inf Loss Detected] at Epoch {epoch+1}, Batch {batch_idx+1}. Skipping backward pass.")
                continue

            # Average over the accumulation window so the summed gradient has
            # the same scale as a single-batch gradient.
            loss_for_backward = total_loss / GRADIENT_ACCUMULATION_STEPS
            # Step at the end of each window and on the final batch, so a
            # trailing partial window is not silently dropped.
            should_step = (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0 or (batch_idx + 1) == num_train_batches

            if accelerator_arc:
                accelerator_arc.backward(loss_for_backward)
                if should_step:
                    if accelerator_arc.sync_gradients:
                        accelerator_arc.clip_grad_norm_(ctm_model_arc.parameters(), MAX_GRAD_NORM)
                    optimizer_arc.step()
                    optimizer_arc.zero_grad()
            else:
                scaler.scale(loss_for_backward).backward()
                if should_step:
                    # Unscale first so clipping sees the true gradient norm.
                    scaler.unscale_(optimizer_arc)
                    torch.nn.utils.clip_grad_norm_(ctm_model_arc.parameters(), MAX_GRAD_NORM)
                    scaler.step(optimizer_arc)
                    scaler.update()
                    optimizer_arc.zero_grad()

            total_epoch_loss += total_loss.item()
            progress_bar.set_postfix({
                'loss': total_loss.item(),
                'avg_loss': total_epoch_loss / (batch_idx + 1)
            })

        avg_epoch_loss = total_epoch_loss / num_train_batches if num_train_batches > 0 else 0
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS_ARC}] completed. Average Loss: {avg_epoch_loss:.4f}")

        # --- Evaluation Step ---
        ctm_model_arc.eval()
        if arc_eval_loader is not None:
            total_eval_loss = 0
            with torch.no_grad():
                for eval_batch_data in arc_eval_loader:
                    if not eval_batch_data or eval_batch_data['input_byte_sequences'].numel() == 0:
                        # Same rule as in training: batches without usable pairs are skipped.
                        continue

                    input_bytes = eval_batch_data['input_byte_sequences'].to(arc_device)
                    target_bytes = eval_batch_data['target_byte_sequences_for_diffusion'].to(arc_device)

                    # A fresh context per batch: the context manager returned by
                    # accelerator.autocast() cannot be entered a second time.
                    eval_autocast_context = accelerator_arc.autocast() if accelerator_arc else autocast(enabled=USE_MIXED_PRECISION, dtype=autocast_dtype)

                    with eval_autocast_context:
                        eval_output = ctm_model_arc(
                            byte_sequence=input_bytes,
                            target_diffusion_output=target_bytes,
                            mode='ctm_controlled_diffusion',
                            timestep=torch.randint(0, config_arc_diffusion.diffusion_steps, (input_bytes.size(0),), device=input_bytes.device).long()
                        )
                        eval_loss = eval_output.get('total_loss', torch.tensor(0.0, device=input_bytes.device))
                    total_eval_loss += eval_loss.item()

            avg_eval_loss = total_eval_loss / len(arc_eval_loader) if len(arc_eval_loader) > 0 else 0
            print(f"Epoch [{epoch+1}/{NUM_EPOCHS_ARC}] Evaluation Loss: {avg_eval_loss:.4f}")
        else:
            print(f"Epoch [{epoch+1}/{NUM_EPOCHS_ARC}] Evaluation skipped: no evaluation DataLoader available.")

        # --- Checkpointing ---
        if (epoch + 1) % 5 == 0: # Save every 5 epochs
            checkpoint_path = os.path.join(CHECKPOINT_DIR_ARC, f"ctm_model_arc_epoch_{epoch+1}.safetensors")
            if accelerator_arc:
                # The barrier must be reached by every process; calling it only
                # on the main process would leave that process waiting forever.
                accelerator_arc.wait_for_everyone()

                # --- DeepSpeed Check ---
                if getattr(accelerator_arc.state, 'deepspeed_plugin', None) is not None:
                    # DeepSpeed handles checkpointing via accelerator.save_state,
                    # which is invoked from every process.
                    accelerator_arc.save_state(os.path.join(CHECKPOINT_DIR_ARC, f"epoch_{epoch+1}"))
                elif accelerator_arc.is_main_process:
                    # For other setups, save with safetensors on the main process only.
                    unwrapped_model = accelerator_arc.unwrap_model(ctm_model_arc)
                    save_file(unwrapped_model.state_dict(), checkpoint_path)

                if accelerator_arc.is_main_process:
                    print(f"✓ Checkpoint saved for epoch {epoch+1} to {CHECKPOINT_DIR_ARC}")
            else:
                # Without Accelerate there is a single process; save directly so
                # training progress is not lost.
                save_file(ctm_model_arc.state_dict(), checkpoint_path)
                print(f"✓ Checkpoint saved for epoch {epoch+1} to {CHECKPOINT_DIR_ARC}")

        if hasattr(ctm_model_arc, 'sleep_down'):
            ctm_model_arc.sleep_down()

    print("\n🎉 ARC-AGI-2 Meta-Learning Training Phase Completed!")


# --- Principles Alignment Training Loop ---
# Self-supervised pass over the principles: each byte sequence is both the
# model input and the diffusion target.
if principles_loader and NUM_EPOCHS_PRINCIPLES > 0 and 'ctm_model_arc' in globals() and ctm_model_arc is not None \
        and globals().get('optimizer_arc') is not None:
    print("\n" + "="*60)
    print(f"🚀 STARTING PHASE: Principles Alignment Training")
    print(f"   Epochs: {NUM_EPOCHS_PRINCIPLES}")
    print("="*60 + "\n")

    principles_device = accelerator_arc.device if accelerator_arc else device
    num_principles_batches = len(principles_loader)

    for epoch in range(NUM_EPOCHS_PRINCIPLES):
        ctm_model_arc.train()
        if hasattr(ctm_model_arc, 'wake_up'):
            ctm_model_arc.wake_up()

        total_epoch_loss_principles = 0

        # Gradients are cleared once per epoch and after every optimizer step.
        # Clearing them at the top of every batch would throw away the
        # gradients collected so far and defeat gradient accumulation.
        optimizer_arc.zero_grad()

        progress_bar_principles = tqdm(enumerate(principles_loader), total=num_principles_batches, desc=f"Principles Epoch {epoch + 1}")

        for batch_idx, batch_data in progress_bar_principles:
            if not batch_data or batch_data['input_byte_sequences'].numel() == 0:
                print(f"Skipping empty principles batch {batch_idx}")
                continue

            input_bytes = batch_data['input_byte_sequences'].to(principles_device)

            current_batch_size = input_bytes.size(0)

            autocast_context = accelerator_arc.autocast() if accelerator_arc else autocast(enabled=USE_MIXED_PRECISION, dtype=autocast_dtype)

            with autocast_context:
                # For principles, the input is the target. The model learns to reconstruct the principles.
                model_output_dict = ctm_model_arc(
                    byte_sequence=input_bytes,
                    target_diffusion_output=input_bytes, # Self-supervision
                    mode='ctm_controlled_diffusion',
                    timestep=torch.randint(0, config_arc_diffusion.diffusion_steps, (current_batch_size,), device=input_bytes.device).long()
                )

                total_loss = model_output_dict.get('total_loss', torch.tensor(0.0, device=input_bytes.device))

            if torch.isnan(total_loss) or torch.isinf(total_loss):
                print(f"[NaN or Inf Loss Detected] in Principles training at Epoch {epoch+1}, Batch {batch_idx+1}. Skipping backward pass.")
                continue

            # Average over the accumulation window so the summed gradient has
            # the same scale as a single-batch gradient.
            loss_for_backward = total_loss / GRADIENT_ACCUMULATION_STEPS
            # Step at the end of each window and on the final batch, so a
            # trailing partial window is not silently dropped.
            should_step = (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0 or (batch_idx + 1) == num_principles_batches

            if accelerator_arc:
                accelerator_arc.backward(loss_for_backward)
                if should_step:
                    if accelerator_arc.sync_gradients:
                        accelerator_arc.clip_grad_norm_(ctm_model_arc.parameters(), MAX_GRAD_NORM)
                    optimizer_arc.step()
                    optimizer_arc.zero_grad()
            else:
                scaler.scale(loss_for_backward).backward()
                if should_step:
                    # Unscale first so clipping sees the true gradient norm.
                    scaler.unscale_(optimizer_arc)
                    torch.nn.utils.clip_grad_norm_(ctm_model_arc.parameters(), MAX_GRAD_NORM)
                    scaler.step(optimizer_arc)
                    scaler.update()
                    optimizer_arc.zero_grad()

            total_epoch_loss_principles += total_loss.item()
            progress_bar_principles.set_postfix({
                'loss': total_loss.item(),
                'avg_loss': total_epoch_loss_principles / (batch_idx + 1)
            })

        avg_epoch_loss_principles = total_epoch_loss_principles / num_principles_batches if num_principles_batches > 0 else 0
        print(f"Principles Epoch [{epoch+1}/{NUM_EPOCHS_PRINCIPLES}] completed. Average Loss: {avg_epoch_loss_principles:.4f}")

        # --- Checkpointing for Principles Training ---
        # Save every 5 epochs and always after the final epoch; with fewer than
        # five epochs the "every 5" rule alone would never write a checkpoint.
        is_final_epoch = (epoch + 1) == NUM_EPOCHS_PRINCIPLES
        if (epoch + 1) % 5 == 0 or is_final_epoch:
            checkpoint_dir = os.path.join(CHECKPOINT_DIR_ARC, "principles_checkpoints")
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, f"ctm_model_arc_epoch_{epoch+1}.safetensors")

            if accelerator_arc:
                # The barrier must be reached by every process; calling it only
                # on the main process would leave that process waiting forever.
                accelerator_arc.wait_for_everyone()

                if getattr(accelerator_arc.state, 'deepspeed_plugin', None) is not None:
                    # DeepSpeed checkpoints go through save_state on every process.
                    accelerator_arc.save_state(os.path.join(checkpoint_dir, f"epoch_{epoch+1}"))
                elif accelerator_arc.is_main_process:
                    unwrapped_model = accelerator_arc.unwrap_model(ctm_model_arc)
                    save_file(unwrapped_model.state_dict(), checkpoint_path)

                if accelerator_arc.is_main_process:
                    print(f"✓ Principles checkpoint saved for epoch {epoch+1} to {checkpoint_dir}")
            else:
                # Without Accelerate there is a single process; save directly.
                save_file(ctm_model_arc.state_dict(), checkpoint_path)
                print(f"✓ Principles checkpoint saved for epoch {epoch+1} to {checkpoint_dir}")

        if hasattr(ctm_model_arc, 'sleep_down'):
            ctm_model_arc.sleep_down()

    print("\n🎉 Principles Alignment Training Phase Completed!")

# The Mixed Context training is not needed because the Program Synthesizer is not being used; the CTM Neuron Network is used instead.
# --- Mixed Context Training ---
'''
import random

class MixedContextDataset(Dataset):
    def __init__(self, num_samples=1000, short_len=256, long_len=4096, vocab_size=256):
        self.num_samples = num_samples
        self.short_len = short_len
        self.long_len = long_len
        self.vocab_size = vocab_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        if random.random() < 0.5:
            # Short sequence, pack multiple
            num_packed = self.long_len // self.short_len
            packed = []
            mask = torch.zeros(self.long_len, self.long_len)
            pos = 0
            for i in range(num_packed):
                seq = torch.randint(0, self.vocab_size, (self.short_len,))
                packed.append(seq)
                # Causal mask for this segment
                segment_mask = torch.tril(torch.ones(self.short_len, self.short_len))
                mask[pos:pos+self.short_len, pos:pos+self.short_len] = segment_mask
                pos += self.short_len
            sequence = torch.cat(packed)[:self.long_len]
            is_long = False
        else:
            # Long sequence
            sequence = torch.randint(0, self.vocab_size, (self.long_len,))
            mask = torch.tril(torch.ones(self.long_len, self.long_len))
            is_long = True

        return {'sequence': sequence, 'mask': mask, 'is_long': is_long}

#mixed_dataset = MixedContextDataset()

#mixed_loader = DataLoader(mixed_dataset, batch_size=4, shuffle=True)

# Mixed training loop
 for epoch in range(5): 
     ctm_model_arc.train()
     total_loss = 0
     for batch in mixed_loader:
        sequence = batch['sequence'].to(device)
        attn_mask = batch['mask'].to(device)
        is_long = batch['is_long']

         Assuming model has train_forward that computes loss
        loss = ctm_model_arc.train_forward(sequence, attn_mask, use_rescaled_rope=is_long)

        optimizer_arc.zero_grad()
        loss.backward()
        optimizer_arc.step()
        total_loss += loss.item()

    print(f"Mixed Context Epoch {epoch+1} Avg Loss: {total_loss / len(mixed_loader)}")

print("\n🎉 Mixed Context Training Completed!")
'''


🚀 STARTING PHASE 4: ARC-AGI-2 Meta-Learning Training


NameError: name 'NUM_EPOCHS_ARC' is not defined

In [20]:
print(f"ARC_EVAL_DIR: {ARC_EVAL_DIR}")
print("Exists?", os.path.exists(ARC_EVAL_DIR))

ARC_EVAL_DIR: /workspace/Arc-AGI-2/contineous-thought-machines/examples/contineous-thought-machines/data/evaluation
Exists? False


In [21]:
ARC_EVAL_DIR = "/workspace/Arc-AGI-2/contineous-thought-machines/data/evaluation" 

In [22]:
print(f"ARC_EVAL_DIR: {ARC_EVAL_DIR}")
print("Exists?", os.path.exists(ARC_EVAL_DIR))

ARC_EVAL_DIR: /workspace/Arc-AGI-2/contineous-thought-machines/data/evaluation
Exists? True


In [16]:
ls /workspace/Arc-AGI-2/contineous-thought-machines/data/evaluation

0934a4d8.json  332f06d7.json  65b59efc.json  8e5c0c38.json  c7f57c3e.json
135a2760.json  35ab12c3.json  67e490f4.json  8f215267.json  cb2d8a2c.json
136b0064.json  36a08778.json  6e453dd6.json  8f3a5a89.json  cbebaa4b.json
13e47133.json  38007db0.json  6e4f6532.json  9385bd28.json  d35bdbdc.json
142ca369.json  3a25b0d8.json  6ffbe589.json  97d7923e.json  d59b0160.json
16b78196.json  3dc255db.json  71e489b6.json  981571dc.json  d8e07eb2.json
16de56c4.json  3e6067c3.json  7491f3cf.json  9aaea919.json  da515329.json
1818057f.json  409aa875.json  7666fa5d.json  9bbf930d.json  db0c5428.json
195c6913.json  446ef5d2.json  78332cb0.json  a251c730.json  db695cfb.json
1ae2feb7.json  45a5af55.json  7b0280bc.json  a25697e4.json  dbff022c.json
20270e3b.json  4a21e3da.json  7b3084d4.json  a32d8b75.json  dd6b8c4b.json
20a9e565.json  4c3d4a41.json  7b5033c1.json  a395ee82.json  de809cff.json
21897d95.json  4c416de3.json  7b80bb43.json  a47bf94d.json  dfadab01.json
221dfab4.json  4c7dc4dd.json  7c66cb00

# Evaluation Arc_AGI_2 Phase

In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import glob
import json
import traceback
import random
from dataclasses import dataclass, field
from typing import List, Optional, Any
import math

# --- De-noising and Meta-Learning Components ---

def perform_online_update(model, optimizer, scheduler, input_bytes, corrected_grid_np: np.ndarray, device):
    """
    Run one online fine-tuning step of `model` toward a corrected output grid.

    The corrected grid is serialized to a fixed-length byte string, wrapped in a
    batch-of-one uint8 tensor and handed to the model as the diffusion target. The
    loss is read from the ``'total_loss'`` entry of the dict the model returns; when
    that entry is missing or non-finite the optimizer/scheduler are left untouched.

    Args:
        model: Callable module accepting the keyword arguments used below and
            returning a dict-like object with an optional ``'total_loss'``.
        optimizer: Optimizer stepped once when the loss is usable.
        scheduler: LR scheduler stepped right after the optimizer.
        input_bytes: Passed through unchanged as ``byte_sequence``.
        corrected_grid_np: Grid that the model output should be pulled toward.
        device: Device on which the target tensor and timestep are placed.

    The model is switched to train mode on entry and back to eval mode on exit.
    """
    model.train()
    optimizer.zero_grad()

    # serialize_and_pad_grid pads/truncates to exactly MAX_SEQUENCE_LENGTH bytes;
    # unsqueeze(0) then adds the leading batch dimension.
    serialized_target = serialize_and_pad_grid(
        corrected_grid_np,
        max_len=MAX_SEQUENCE_LENGTH,
        pad_value=PADDING_BYTE_VALUE,
    )
    # frombuffer yields a read-only view over the bytes object, hence the copy.
    target_array = np.frombuffer(serialized_target, dtype=np.uint8).copy()
    target_tensor = torch.from_numpy(target_array).to(torch.uint8).unsqueeze(0).to(device)

    # Online updates always use timestep 0.
    zero_timestep = torch.zeros(1, device=device).long()

    forward_kwargs = {
        'byte_sequence': input_bytes,
        'mode': 'ctm_controlled_diffusion',
        'target_diffusion_output': target_tensor,
        'timestep': zero_timestep,
        'task_name': "ARC_AGI_2_ONLINE_LEARN",
    }
    output_dict = model(**forward_kwargs)

    loss = output_dict.get('total_loss')

    if loss is None or not torch.isfinite(loss):
        print("  > Skipping online update due to invalid loss.")
    else:
        loss.backward()
        # Clip to unit norm before stepping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        print(f"  > Model updated with loss: {loss.item():.4f}. LR: {scheduler.get_last_lr()[0]:.6f}")

    model.eval()

# Setup module paths based on user-provided successful import logic
print("--- Setting up module paths ---")
# The repository root was hard-coded as '/workspaces/Arc-AGI-2', while other paths in this
# notebook (e.g. ARC_EVAL_DIR and the initial os.chdir) use '/workspace/Arc-AGI-2'. Keep the
# original root as first choice, but fall back to the other spelling when only that one
# exists on disk. If neither exists, behave exactly as before.
_PROJECT_ROOT_CANDIDATES = ('/workspaces/Arc-AGI-2', '/workspace/Arc-AGI-2')
project_root = next(
    (candidate for candidate in _PROJECT_ROOT_CANDIDATES if os.path.isdir(candidate)),
    _PROJECT_ROOT_CANDIDATES[0],
)
module_path = os.path.join(project_root, 'contineous-thought-machines')

if module_path not in sys.path:
    sys.path.append(module_path)
    print(f"Added to sys.path: {module_path}")

# safetensors is optional at import time: without it, define a stub that fails loudly
# only when a .safetensors file is actually requested.
try:
    from safetensors.torch import load_file
except ImportError:
    print("Warning: safetensors not found. Loading .safetensors will fail.")
    def load_file(path, device="cpu"):
        raise ImportError(f"safetensors is not installed, cannot load {path}")

print("\n--- Statically importing EnhancedCTMDiffusion model ---")
# Stays None when the import below fails; later code checks `is not None`.
EnhancedCTMDiffusion = None
try:
    from contineous_thought_machines.models.ctm_Diffusion_NEWNEW import EnhancedCTMDiffusion
    print(" -> Successfully imported EnhancedCTMDiffusion from models package.")
except ImportError as e_direct:
    print(f"FATAL: Import from models package failed. Last error: {e_direct}")

# Accelerate is optional; ACCELERATE_AVAILABLE gates its use further down.
try:
    from accelerate import Accelerator
    ACCELERATE_AVAILABLE = True
except ImportError:
    print("Warning: Hugging Face Accelerate not found. Will run on a single device.")
    ACCELERATE_AVAILABLE = False
    Accelerator = None

# --- Constants and Configuration ---
# Compute device: CUDA when torch reports it as available, CPU otherwise.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Default (rows, cols) canvas used by ARCEvalDataset / pad_grid, and its fill value.
MAX_GRID_SIZE = (30, 30)
PADDING_VALUE = -1
ARC_INPUT_FLAT_DIM = MAX_GRID_SIZE[0] * MAX_GRID_SIZE[1]

# Defaults for serialize_and_pad_grid: fixed byte length and the byte used to pad.
MAX_SEQUENCE_LENGTH = 8 * 1024
PADDING_BYTE_VALUE = 0

NUM_ARC_SYMBOLS = 10

# Learning rate handed to the AdamW optimizer for the ARC model.
LEARNING_RATE = 0.0001

# --- Data Handling ---
def pad_grid(grid_list, max_dims, pad_value):
    """Place a 2-D grid in the top-left corner of a `max_dims` canvas filled with `pad_value`.

    Args:
        grid_list: Rectangular nested list (or array-like) of integers.
        max_dims: ``(rows, cols)`` of the returned array.
        pad_value: Value written to every cell not covered by the grid.

    Returns:
        An ``int32`` numpy array of shape `max_dims`.

    Raises:
        ValueError: If the grid is not 2-D or does not fit inside `max_dims`.
    """
    padded_grid = np.full(max_dims, pad_value, dtype=np.int32)
    grid_np = np.array(grid_list, dtype=np.int32)

    # An empty grid ([] or [[]]) has nothing to copy; callers such as ARCEvalDataset
    # already treat it as a 0x0 grid, so return the all-padding canvas.
    if grid_np.size == 0:
        return padded_grid

    if grid_np.ndim != 2:
        raise ValueError(f"pad_grid expects a 2-D grid, got an array with shape {grid_np.shape}.")

    h, w = grid_np.shape
    if h > max_dims[0] or w > max_dims[1]:
        raise ValueError(f"Grid of shape {(h, w)} does not fit into max_dims {tuple(max_dims)}.")

    padded_grid[:h, :w] = grid_np
    return padded_grid

def serialize_and_pad_grid(grid, max_len=MAX_SEQUENCE_LENGTH, pad_value=PADDING_BYTE_VALUE):
    """Flatten `grid` to uint8 bytes and force the result to exactly `max_len` bytes.

    Sequences longer than `max_len` are cut off at the end; shorter ones are
    right-padded with the byte `pad_value`.
    """
    raw = np.array(grid, dtype=np.uint8).flatten().tobytes()
    missing = max_len - len(raw)
    if missing >= 0:
        return raw + bytes([pad_value] * missing)
    # Too long: keep only the leading max_len bytes.
    return raw[:max_len]

@dataclass
class EnhancedCTMConfig: # Renamed from ContinualLearningConfig for consistency in the target file
    """Enhanced configuration for continual learning CTM-diffusion model,
    incorporating binary processing, multi-task learning, and advanced CTM features.

    All fields have defaults. ``__post_init__`` validates field combinations, fills in
    ``ctm_dropout_nlm``, ``inferred_task_latent_dim`` and ``unet_input_feature_dim`` when
    they are left as ``None``, and sets the derived, non-field attribute
    ``audio_output_item_size``.

    NOTE(review): with the defaults below (``reshape_patch_sequence_to_grid=True``,
    ``patch_grid_width=None``) ``__post_init__`` raises ``ValueError``; callers must pass a
    positive ``patch_grid_width`` or disable the reshape. Confirm the intended default.
    """

    # Model architecture (General Transformer/Diffusion settings)
    d_model: int = 512  # Main model dimensionality
    n_heads: int = 8
    n_layers: int = 24
    max_sequence_length: int = 8192 # Max input sequence length in terms of bytes or patches
    dropout: float = 0.1

    # --- Byte Processing Options ---
    patch_embedding_dim: int = 256         # Output embedding dimension per patch from patcher
    patch_encoder_cnn_channels: int = 64   # Intermediate channels for CNN patch encoder

    # --- Dynamic Entropy Patching Options (Inspired by BLT paper) ---
    use_dynamic_entropy_patcher: bool = True # Flag to enable dynamic entropy-based patching
    entropy_patcher_threshold_type: str = "global"  # 'global' or 'relative_monotonic'
    entropy_patcher_global_threshold: float = 0.75 # Entropy threshold for 'global' type
    entropy_patcher_relative_threshold: float = 0.1 # Entropy diff threshold for 'relative_monotonic'
    entropy_patcher_min_patch_size: int = 4      # Minimum number of bytes in a dynamic patch
    entropy_patcher_max_patch_size: int = 128    # Maximum number of bytes in a dynamic patch (for CNN encoder)

    # --- Learnable Entropy Model Parameters (for _EntropyProxyModel) ---
    entropy_model_byte_vocab_size: int = 256
    entropy_model_embedding_dim: int = 64
    entropy_model_hidden_dim: int = 128
    entropy_model_num_layers: int = 1
    entropy_model_dropout: float = 0.1
    entropy_model_loss_weight: float = 0.1 # Weight for its auxiliary loss contribution
    # Note: These parameters are used if use_dynamic_entropy_patcher is True,
    # as LearnedBytePatcherEncoder now instantiates the learnable _EntropyProxyModel.

    # Fallback if not using learned_patch_encoder or dynamic_entropy_patcher
    byte_embedding_dim: int = 256
    multi_granularity: bool = False # Default to False if patcher is preferred
    # multi_granularity_output_dim is complex to predefine, MGP should expose its output dim.
    # For now, if multi_granularity is True AND use_learned_patch_encoder is False, this would be used.
    multi_granularity_output_dim: int = 256 # Placeholder if MGP is used.

    hierarchical_processing: bool = True # General flag, could apply to patcher or MGP

    # CTM Core Parameters (Specific to the OriginalCTMCore module)
    # These are prefixed with 'ctm_' to distinguish from general model params
    ctm_iterations: int = 5  # Original 'iterations'
    ctm_d_model: int = 512   # Original 'd_model' for CTM's internal latent space
    ctm_input_dim: int = 256 # Dimensionality of inputs to CTM (e.g., from byte embeddings or other features)
                             # This was 'd_input' in OriginalCTMCore if it took external features.
                             # If CTM processes outputs of byte_embedding, this might be byte_embedding_dim.
    ctm_heads: int = 8       # Attention heads within CTM
    ctm_n_synch_out: int = 64
    ctm_n_synch_action: int = 64
    ctm_synapse_depth: int = 3
    ctm_memory_length: int = 10
    ctm_deep_nlms: bool = True
    ctm_memory_hidden_dims: int = 2048
    ctm_do_layernorm_nlm: bool = False
    ctm_out_dims: int = 512  # Output dimension of CTM's own projector
    ctm_prediction_reshaper: list = field(default_factory=lambda: [-1])
    ctm_dropout: float = 0.1
    ctm_dropout_nlm: Optional[float] = None  # None -> __post_init__ copies ctm_dropout
    # Neuron selection strategy. Available options:
    # Legacy: 'first-last', 'random', 'random-pairing'
    # Biologically-inspired: 'bio_hebbian', 'bio_plasticity', 'bio_competitive',
    #                        'bio_homeostatic', 'bio_evolutionary', 'bio_stdp',
    #                        'bio_criticality', 'bio_multi_objective'
    # Hybrid: 'adaptive_random', 'performance_guided', 'task_aware'
    ctm_neuron_select_type: str = 'bio_multi_objective'
    ctm_n_random_pairing_self: int = 0

    # Diffusion Parameters
    diffusion_steps: int = 1000
    noise_schedule: str = "cosine" # e.g., "linear", "cosine"
    diffusion_beta_start: float = 0.0001
    diffusion_beta_end: float = 0.02
    diffusion_timesteps: int = 1000 # Number of timesteps for the diffusion process
    ctm_diffusion_coupling_strength: float = 0.8 # How CTM influences diffusion
    adaptive_scheduling: bool = True  # CTM-adaptive diffusion timestep scheduling
    iterative_refinement: bool = True # Iterative CTM-diffusion refinement for sampling

    # Training Efficiency
    mixed_precision: bool = True
    gradient_checkpointing: bool = True
    sparse_attention: bool = True  # Now implemented with BinarySparseAttention
    adaptive_depth: bool = False   # Defaulting to False, can be enabled if implemented
    use_activity_plasticity: bool = True # To enable/disable plasticity updates; Needs to be set to TRUE
    ctm_use_internal_feedback: bool = True # Enable self-modulating feedback within the CTM core

    # --- Bidirectional Reasoning Parameters ---
    enable_bidirectional_reasoning: bool = True # Allows CTM to move forward/backward in its thought process
    reasoning_step_gating_threshold: float = 0.7 # Confidence threshold for the reasoning controller to terminate
    max_reasoning_steps: int = 15 # Max total steps in a bidirectional reasoning loop to prevent infinite loops

    # Sparse Attention Parameters
    sparse_attention_ratio: float = 0.1  # Keep only 10% of attention connections
    binary_pattern_size: int = 8  # Size of binary patterns to detect

    # Attention Mechanism Type
    attention_type: str = "WINA"  # Options: "standard", "binary_sparse", "WINA" #Need to use WINA attention in place of "standard"

    positional_embedding_type: Optional[str] = 'multi-learnable-fourier' # e.g., 'custom-rotational-1d', 'learnable-fourier', 'multi-learnable-fourier'
    positional_embedding_dim: Optional[int] = None  # Dimension of the positional embedding, defaults to ctm_input_dim if None
    reshape_patch_sequence_to_grid: bool = True # If True, reshape patch sequence to a 2D grid for 2D PEs. Must set to true if using 2D Grid for Positional Embeddings.
    patch_grid_width: Optional[int] = None       # Desired width of the patch grid if reshaping

    # --- Hierarchical Reasoning Model (HRM) Parameters ---
    use_hrm_core: bool = True # Set to True to use the HierarchicalCTM core
    hrm_high_level_cycles: int = 4 # N: Number of high-level cycles
    hrm_low_level_timesteps: int = 8 # T: Number of low-level timesteps per high-level cycle
    program_vocab_size: int = 1024 # Vocabulary size for the program synthesizer
    program_synth_n_heads: int = 4
    program_synth_n_layers: int = 3
    program_synth_d_ff: int = 1024
    ltm_size: int = 2048 # Size of the long-term memory
    ltm_surprise_threshold: float = 0.6 # Surprise threshold for storing in LTM
    replay_batch_size: int = 4 # Batch size for memory replay
    replay_policy: str = "surprise_weighted_replay" # "simple_replay", "surprise_weighted_replay", "usefulness_replay"

    # Pipeline Parallelism Parameters
    enable_pipeline_parallelism: bool = True
    pipeline_stages: int = 3  # CTM, Diffusion prep, Diffusion exec
    pipeline_overlap_ratio: float = 0.7  # Target overlap ratio

    # Adaptive Batch Sizing Parameters
    enable_adaptive_batching: bool = True
    initial_batch_size: int = 32
    min_batch_size: int = 8
    max_batch_size: int = 256
    batch_adaptation_frequency: int = 100
    memory_threshold_high: float = 0.85
    memory_threshold_low: float = 0.6

    # Smart Data Sampling Parameters
    enable_smart_sampling: bool = True
    sample_importance_weight: float = 0.6
    sample_diversity_weight: float = 0.4
    initial_sample_ratio: float = 0.3
    complexity_analysis_enabled: bool = True

    # Multi-input/output parameters
    num_inputs: int = 1  # Number of input streams
    num_outputs: int = 1  # Number of output heads
    output_dims: List[int] = field(default_factory=lambda: [64])  # Dimensions for each output head

    # Self-supervised learning
    ssl_dim: int = 128  # Dimension for self-supervised projection
    ssl_weight: float = 0.1  # Weight for self-supervised loss
    ssl_temperature: float = 0.07  # Temperature for contrastive loss
    ssl_noise_std: float = 0.1  # Noise standard deviation for contrastive augmentation

    # Spatiotemporal Processing
    use_spatial: bool = True  # Enable spatial processing for image/video data

    # WINA Attention
    use_wina_attention: bool = True  # Enable WINA sparse attention

    # Multi-task Learning Parameters
    max_tasks: int = 50  # Maximum number of tasks for continual learning
    # Added to resolve TypeError for unexpected keyword arguments
    vocab_size: Optional[int] = None
    output_audio_bytes: bool = False
    inferred_task_latent_dim: Optional[int] = None # None -> __post_init__ sets it to 512
    use_hipa_attention: bool = False # Default to False
    hipa_num_heads: Optional[int] = None # Must be a positive int when use_hipa_attention is True
    audio_output_dtype_str: Optional[str] = "float32" # "float32" or "int16"; determines audio_output_item_size
    unet_input_feature_dim: Optional[int] = None # None -> __post_init__ calculates it

    # --- JEPA Training Parameters (Integrated with LearnedBytePatcherEncoder) ---
    use_jepa_training: bool = False
    # jepa_embed_dim will be derived from patch_embedding_dim if dynamic_entropy_patcher is used
    jepa_predictor_hidden_dim: int = 512 # Hidden dimension of JEPA predictor MLP
    jepa_mask_ratio_min: float = 0.15 # Min proportion of patch sequence to mask for target
    jepa_mask_ratio_max: float = 0.75 # Max proportion of patch sequence to mask for target
    jepa_context_scale_min: float = 0.3 # Min proportion of patches for context
    jepa_context_scale_max: float = 0.7 # Max proportion of patches for context
    jepa_momentum_beta: float = 0.996 # Momentum for target encoder update
    jepa_loss_weight: float = 0.1 # Weight for the JEPA loss component
    jepa_num_target_blocks: int = 1 # Number of target blocks to predict

    # --- Global Plasticity Loss Parameters ---
    local_hebbian_loss_weight: float = 0.01 # New weight for backprop-based hebbian loss

    # --- Basal Ganglia Parameters ---
    # Controls action suppression so that the model's unwanted first unrelated thoughts are
    # suppressed, which helps with model safety. Is needed for action suppression.
    ctm_enable_basal_ganglia: bool = True
    ctm_bg_dopamine_dim: int = 32

    # --- Synaptic Empathy Parameters ---
    enable_synaptic_empathy: bool = True # Set to True to use the new SynapticEmpathy module
    synaptic_empathy_reward_weight: float = 0.1

    # --- Mirror Neuron / High-Level Empathy Parameters ---
    enable_mirror_neurons: bool = True # Set to True to use the high-level MirrorNeuronLayer
    num_emotion_dim: int = 4 # Dimensionality of the emotion state vector
    goal_dim: int = 8 # Dimensionality of the predicted goal vector
    mirror_reward_weight: float = 0.2 # Weight for the selfless reward signal


    # --- Confidence Thresholding Parameters ---
    confidence_threshold: float = 0.0 # Confidence threshold for abstaining. If > 0, model can abstain.

    # --- Consciousness Controller Parameters ---
    enable_consciousness_controller: bool = True
    consciousness_max_attention_steps: int = 100

    def __post_init__(self):
        """Validate the configuration and fill in derived values.

        Raises:
            ValueError: On any inconsistent or out-of-range combination checked below.
        """
        # Validate output dimensions
        if len(self.output_dims) != self.num_outputs:
            raise ValueError(f"output_dims length ({len(self.output_dims)}) must match num_outputs ({self.num_outputs})")

        # NLM dropout inherits the general CTM dropout unless set explicitly.
        if self.ctm_dropout_nlm is None:
            self.ctm_dropout_nlm = self.ctm_dropout

        # The validation tables are module-level names that may not exist where this class
        # is defined (e.g. a notebook cell that never imported them). Look them up
        # defensively so a missing table means "skip validation" rather than NameError.
        valid_neuron_select_types = globals().get('VALID_NEURON_SELECT_TYPES')
        valid_positional_embedding_types = globals().get('VALID_POSITIONAL_EMBEDDING_TYPES')

        if valid_neuron_select_types is not None and self.ctm_neuron_select_type not in valid_neuron_select_types:
            print(f"Warning: ctm_neuron_select_type '{self.ctm_neuron_select_type}' is not in VALID_NEURON_SELECT_TYPES ({valid_neuron_select_types}).")

        if self.positional_embedding_type is not None:
            if valid_positional_embedding_types is None: # Fallback if import failed
                print(f"Warning: VALID_POSITIONAL_EMBEDDING_TYPES not available for validation.")
            elif self.positional_embedding_type not in valid_positional_embedding_types:
                print(f"Warning: positional_embedding_type '{self.positional_embedding_type}' is not in VALID_POSITIONAL_EMBEDDING_TYPES ({valid_positional_embedding_types}).")
            if self.positional_embedding_dim is not None and self.positional_embedding_dim <= 0:
                raise ValueError("positional_embedding_dim must be positive if set.")

            if self.reshape_patch_sequence_to_grid:
                if self.patch_grid_width is None or self.patch_grid_width <= 0:
                    raise ValueError("patch_grid_width must be a positive integer if reshape_patch_sequence_to_grid is True.")
                if self.positional_embedding_type not in ['learnable-fourier', 'multi-learnable-fourier', 'custom-rotational']:
                    print(f"Warning: reshape_patch_sequence_to_grid is True, but positional_embedding_type ('{self.positional_embedding_type}') is not a typical 2D PE. Ensure compatibility.")

        # Validations for new patch encoder
        if self.use_dynamic_entropy_patcher:
            if self.patch_embedding_dim <= 0:
                raise ValueError("patch_embedding_dim must be positive if use_dynamic_entropy_patcher is True.")
            if self.entropy_patcher_min_patch_size <= 0:
                raise ValueError("entropy_patcher_min_patch_size must be positive.")
            if self.entropy_patcher_max_patch_size < self.entropy_patcher_min_patch_size:
                raise ValueError("entropy_patcher_max_patch_size must be >= entropy_patcher_min_patch_size.")
            if self.entropy_patcher_threshold_type not in ["global", "relative_monotonic"]:
                raise ValueError("entropy_patcher_threshold_type must be 'global' or 'relative_monotonic'.")
        elif self.multi_granularity and self.multi_granularity_output_dim <= 0:
            print("Warning: multi_granularity_output_dim might not be correctly set for validation if not using a patcher and MGP is active.")

        if self.inferred_task_latent_dim is None:
            # The message previously announced 64 while the code assigned 512; the
            # assigned value is kept and the message now reports it.
            print("Warning: inferred_task_latent_dim not found or is None in config, defaulting to 512.")
            self.inferred_task_latent_dim = 512
        elif self.inferred_task_latent_dim <= 0:
            raise ValueError("inferred_task_latent_dim must be positive.")

        # hipa_num_heads defaults to None; check for that explicitly so the user gets the
        # intended ValueError instead of a TypeError from `None <= 0`.
        if self.use_hipa_attention and (self.hipa_num_heads is None or self.hipa_num_heads <= 0):
            raise ValueError("hipa_num_heads must be positive if use_hipa_attention is True.")

        # Bytes per audio sample, derived from the dtype name. An unrecognised dtype is
        # only an error when audio bytes are actually produced; otherwise assume 4.
        if self.audio_output_dtype_str == "float32":
            self.audio_output_item_size = 4
        elif self.audio_output_dtype_str == "int16":
            self.audio_output_item_size = 2
        elif self.output_audio_bytes:
            raise ValueError(f"Unsupported audio_output_dtype_str: {self.audio_output_dtype_str} when output_audio_bytes is True.")
        else:
            self.audio_output_item_size = 4

        # Calculate unet_input_feature_dim if not set
        if self.unet_input_feature_dim is None:
            if self.max_sequence_length <= 0 or self.audio_output_item_size <= 0:
                raise ValueError("max_sequence_length and audio_output_item_size must be positive to calculate unet_input_feature_dim.")
            self.unet_input_feature_dim = self.max_sequence_length // self.audio_output_item_size
            if self.unet_input_feature_dim <= 0:
                raise ValueError(f"Calculated unet_input_feature_dim ({self.unet_input_feature_dim}) must be positive. Check max_sequence_length and audio_output_item_size.")
        elif self.unet_input_feature_dim <= 0:
            raise ValueError("unet_input_feature_dim, if set, must be positive.")

        if self.use_jepa_training:
            if not (0 < self.jepa_mask_ratio_min < 1 and 0 < self.jepa_mask_ratio_max < 1 and self.jepa_mask_ratio_min <= self.jepa_mask_ratio_max):
                raise ValueError("JEPA mask ratios must be between 0 and 1, with min <= max.")
            if not (0 < self.jepa_context_scale_min < 1 and 0 < self.jepa_context_scale_max < 1 and self.jepa_context_scale_min <= self.jepa_context_scale_max):
                raise ValueError("JEPA context scales must be between 0 and 1, with min <= max.")
            if not (0 <= self.jepa_momentum_beta < 1):
                raise ValueError("jepa_momentum_beta must be between 0 and 1.")
            if self.jepa_num_target_blocks <= 0:
                raise ValueError("jepa_num_target_blocks must be positive.")
            if not self.use_dynamic_entropy_patcher:
                print("Warning: JEPA training is enabled but use_dynamic_entropy_patcher is False. JEPA relies on the patch embeddings from LearnedBytePatcherEncoder.")

# NOTE(review): EnhancedCTMConfig as defined in this cell defaults to
# reshape_patch_sequence_to_grid=True with patch_grid_width=None, a combination its
# __post_init__ rejects with ValueError — confirm whether patch_grid_width must be passed here.
config_arc_diffusion = EnhancedCTMConfig(
    enable_consciousness_controller=True,
    consciousness_max_attention_steps=100
)

print("✓ EnhancedCTMConfig for ARC (config_arc_diffusion) created.")

# Bind all three names up front so that every path leaves them defined. Previously
# accelerator_arc was never assigned when the model imported fine but Accelerate was
# missing, so later `if accelerator_arc ...` checks would raise NameError.
ctm_model_arc, optimizer_arc, accelerator_arc = None, None, None

if EnhancedCTMDiffusion is not None:
    ctm_model_arc = EnhancedCTMDiffusion(config=config_arc_diffusion).to(device)
    print("✓ EnhancedCTMDiffusion model for ARC (ctm_model_arc) initialized.")
    optimizer_arc = optim.AdamW(ctm_model_arc.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

    if ACCELERATE_AVAILABLE:
        print(" -> Preparing components with Hugging Face Accelerate...")
        accelerator_arc = Accelerator()
        # prepare() returns the wrapped model/optimizer, which replace the raw ones.
        ctm_model_arc, optimizer_arc = accelerator_arc.prepare(ctm_model_arc, optimizer_arc)
        print("✓ ARC model and optimizer prepared with Accelerate.")
else:
    print("⚠️ EnhancedCTMDiffusion model could not be initialized.")


class ARCEvalDataset(Dataset):
    """Dataset of ARC tasks read from the ``*.json`` files of one directory.

    All task files are parsed eagerly in ``__init__``. Indexing returns a dict
    holding the task's file name under ``'id'`` plus its ``'train'`` and
    ``'test'`` pairs, each grid padded through ``pad_grid``.

    Args:
        data_path: Directory that is searched (non-recursively) for ``*.json``.
        max_grid_size: Forwarded to ``pad_grid`` as the target grid size.
        padding_value: Forwarded to ``pad_grid`` as the fill value.
    """

    def __init__(self, data_path, max_grid_size=MAX_GRID_SIZE, padding_value=PADDING_VALUE):
        # glob returns matches in arbitrary order; sorting keeps the mapping
        # from dataset index to task file reproducible between runs.
        self.task_files = sorted(glob.glob(os.path.join(data_path, "*.json")))
        if not self.task_files:
            print(f"Warning: No .json files found at path: {data_path}")
        self.max_grid_size = max_grid_size
        self.padding_value = padding_value
        self.tasks = [self._read_task(path) for path in self.task_files]
        print(f"Loaded {len(self.tasks)} tasks from {data_path}.")

    @staticmethod
    def _read_task(path):
        """Parse one task file, closing the handle as soon as it is read."""
        with open(path, encoding="utf-8") as handle:
            return json.load(handle)

    @staticmethod
    def _grid_dims(grid):
        """Return ``(rows, cols)`` of a nested-list grid; an empty grid has 0 cols."""
        return (len(grid), len(grid[0]) if grid else 0)

    def __len__(self):
        """Number of loaded tasks."""
        return len(self.tasks)

    def __getitem__(self, idx):
        """Return the padded train/test pairs of task ``idx``.

        Returns:
            A dict with keys ``'train'`` and ``'test'`` (lists of pair dicts)
            and ``'id'`` (base name of the task file). Each pair dict carries
            the padded ``'input'`` / ``'output'`` grids as long tensors and the
            pre-padding sizes under ``'original_input_dims'`` /
            ``'original_output_dims'``.

        Raises:
            KeyError: If a pair lacks an ``'input'`` or ``'output'`` entry.
        """
        task_data = self.tasks[idx]
        processed_task = {'train': [], 'test': [], 'id': os.path.basename(self.task_files[idx])}
        for pair_type in ('train', 'test'):
            for item in task_data.get(pair_type, []):
                input_grid = item['input']
                output_grid = item['output']
                padded_input = pad_grid(input_grid, self.max_grid_size, self.padding_value)
                padded_output = pad_grid(output_grid, self.max_grid_size, self.padding_value)
                processed_task[pair_type].append({
                    'input': torch.from_numpy(padded_input).long(),
                    'output': torch.from_numpy(padded_output).long(),
                    'original_input_dims': self._grid_dims(input_grid),
                    'original_output_dims': self._grid_dims(output_grid),
                })
        return processed_task

ARC_EVAL_DIR = "/workspace/Arc-AGI-2/contineous-thought-machines/data/evaluation"

# NOTE(review): the checkpoint directory was spelled with a "/workspaces/" root
# while ARC_EVAL_DIR uses "/workspace/". Which one the training run wrote to is
# not visible here, so the original spelling stays the first choice and the
# "/workspace/" variant is used only when it is the one that exists on disk.
_CHECKPOINT_SUBDIR_ARC = "Arc-AGI-2/contineous-thought-machines/examples/checkpoints/ctm_arc_agi_2_enhanced_diffusion"
_CHECKPOINT_DIR_CANDIDATES = (
    "/workspaces/" + _CHECKPOINT_SUBDIR_ARC,
    "/workspace/" + _CHECKPOINT_SUBDIR_ARC,
)
CHECKPOINT_DIR_ARC = next(
    (candidate for candidate in _CHECKPOINT_DIR_CANDIDATES if os.path.isdir(candidate)),
    _CHECKPOINT_DIR_CANDIDATES[0],  # nothing exists yet: keep the original path
)
CHECKPOINT_DIR_PRINCIPLES = os.path.join(CHECKPOINT_DIR_ARC, "principles_checkpoints")
NUM_EPOCHS_ARC = 20
NUM_EPOCHS_PRINCIPLES = 10 # Should match the value in training.py

print("\n--- Initializing Evaluation Dataloader ---")
# Stays None unless the directory exists AND yields at least one task; the
# evaluation code below treats None as "skip evaluation".
arc_eval_loader = None
if os.path.exists(ARC_EVAL_DIR):
    arc_eval_dataset = ARCEvalDataset(data_path=ARC_EVAL_DIR)
    if len(arc_eval_dataset) > 0:
        # One task per batch, in dataset order.
        arc_eval_loader = DataLoader(arc_eval_dataset, batch_size=1, shuffle=False)
        print(f"✓ Evaluation DataLoader initialized with {len(arc_eval_dataset)} tasks.")
else:
    print(f"⚠️ Evaluation directory not found: '{ARC_EVAL_DIR}'")

# --- Main Evaluation Logic ---


def _decode_prediction_grid(output_bytes):
    """Convert the sampler's byte output into a ``MAX_GRID_SIZE`` integer grid.

    The tensor's raw buffer is read as ``uint8`` and copied, in flat order,
    into a grid pre-filled with ``PADDING_VALUE``; at most
    ``ARC_INPUT_FLAT_DIM`` values are used. ``None`` or an empty tensor yields
    an all-padding grid.
    """
    grid = np.full(MAX_GRID_SIZE, PADDING_VALUE, dtype=int)
    if output_bytes is None or output_bytes.numel() == 0:
        return grid
    # NOTE(review): tobytes() reinterprets the raw buffer, so this presumably
    # expects a uint8 tensor from the sampler -- confirm against the model.
    flat_values = np.frombuffer(output_bytes.squeeze(0).cpu().numpy().tobytes(), dtype=np.uint8)
    usable = min(len(flat_values), ARC_INPUT_FLAT_DIM)
    grid.flat[:usable] = flat_values[:usable]
    return grid


def _sample_prediction_grid(model, input_bytes, num_steps=50):
    """Run one gradient-free sampling pass and return the decoded grid."""
    with torch.no_grad():
        sample_output = model.iterative_ctm_diffusion_sample(
            shape=input_bytes.shape,
            initial_byte_sequence_for_inference=input_bytes,
            num_steps=num_steps,
        )
    # The first element of the sampler's return value holds the output bytes.
    return _decode_prediction_grid(sample_output[0])


print("\n" + "="*60)
print(f"🔬 STARTING ARC-AGI-2 Evaluation on device '{device}'")
print("="*60 + "\n")

# Explicit None checks: relying on truthiness would call __len__/__bool__ on
# the model, optimizer and loader objects.
if any(component is None for component in (ctm_model_arc, optimizer_arc, arc_eval_loader)):
    print("⚠️ Skipping evaluation due to missing components.")
else:
    latest_epoch = NUM_EPOCHS_PRINCIPLES
    ctm_checkpoint_path_eval = os.path.join(CHECKPOINT_DIR_PRINCIPLES, f"ctm_model_arc_epoch_{latest_epoch}.safetensors")

    try:
        if os.path.exists(ctm_checkpoint_path_eval):
            print(f"  > Loading CTM checkpoint from {ctm_checkpoint_path_eval}...")
            state_dict_ctm = load_file(ctm_checkpoint_path_eval, device="cpu")
            # Load into the unwrapped module when Accelerate wrapped the model.
            model_to_load_ctm = accelerator_arc.unwrap_model(ctm_model_arc) if ACCELERATE_AVAILABLE else ctm_model_arc
            # strict=False: missing or unexpected keys do not raise.
            model_to_load_ctm.load_state_dict(state_dict_ctm, strict=False)
            print(f"✓ Loaded CTM checkpoint from epoch {latest_epoch}.")
        else:
            # Evaluation still proceeds, using whatever weights the model holds.
            print(f"⚠️ CTM Checkpoint not found at {ctm_checkpoint_path_eval}.")

        ctm_model_arc.eval()
        # Optional hook: only called when the model exposes it.
        if hasattr(ctm_model_arc, 'wake_up'):
            ctm_model_arc.wake_up()
        total_tasks = 0
        solved_tasks = 0
        num_tasks_in_loader = len(arc_eval_loader)

        # Create a scheduler for the online updates.
        scheduler_arc = optim.lr_scheduler.StepLR(optimizer_arc, step_size=5, gamma=0.9)

        for task_idx, task_batch in enumerate(arc_eval_loader):
            if not task_batch:
                continue

            # With batch_size=1 and the default collate function the batch is
            # {'id': [name], 'test': [pair_dict, ...]} where every tensor in a
            # pair dict has a leading batch dimension of 1.
            task_id = task_batch['id'][0]
            # Iterate the list of pairs itself (indexing [0] would select a
            # single pair dict and iterate its keys). Only tensors are
            # squeezed; the *_dims entries are sequences of 1-element tensors.
            test_pairs = [
                {key: (value.squeeze(0) if torch.is_tensor(value) else value) for key, value in pair.items()}
                for pair in task_batch['test']
            ]

            if not test_pairs:
                print(f"Task {task_idx + 1} ({task_id}): No test cases found. Skipping.")
                continue

            # Count the task only once it is known to be evaluable, so skipped
            # tasks do not lower the reported accuracy.
            total_tasks += 1
            task_solved_overall = True

            for test_pair_idx, test_pair in enumerate(test_pairs):
                input_grid_np_eval = test_pair['input'].cpu().numpy()
                input_bytes_eval = torch.from_numpy(np.frombuffer(serialize_and_pad_grid(input_grid_np_eval), dtype=np.uint8)).to(torch.uint8).unsqueeze(0).to(device)

                target_grid_np = test_pair['output'].cpu().numpy()
                h, w = test_pair['original_output_dims']
                # int() accepts both plain ints and single-element tensors.
                original_dims = (int(h), int(w))
                # Compare only the un-padded region of the grids.
                final_target = target_grid_np[:original_dims[0], :original_dims[1]]

                test_input_solved = False

                # --- First Attempt: Standard Prediction ---
                print(f"  > Attempt 1 for test pair {test_pair_idx + 1}...")
                preds_grid = _sample_prediction_grid(ctm_model_arc, input_bytes_eval)
                final_pred = preds_grid[:original_dims[0], :original_dims[1]]

                if np.array_equal(final_pred, final_target):
                    print("    - Solved on first attempt.")
                    test_input_solved = True
                else:
                    print("    - Failed on first attempt. Trying online update.")
                    # --- Second Attempt: Fine-tune and Predict Again ---
                    perform_online_update(
                        model=ctm_model_arc,
                        optimizer=optimizer_arc,
                        scheduler=scheduler_arc,
                        input_bytes=input_bytes_eval,
                        corrected_grid_np=target_grid_np, # Use the full target grid for the update
                        device=device
                    )

                    print(f"  > Attempt 2 for test pair {test_pair_idx + 1} (post-update)...")
                    preds_grid_2 = _sample_prediction_grid(ctm_model_arc, input_bytes_eval)
                    final_pred_2 = preds_grid_2[:original_dims[0], :original_dims[1]]

                    if np.array_equal(final_pred_2, final_target):
                        print("    - Solved on second attempt after fine-tuning.")
                        test_input_solved = True
                    else:
                        print("    - Failed on second attempt.")

                # A task counts as solved only if every test pair is solved,
                # so stop at the first failure.
                if not test_input_solved:
                    task_solved_overall = False
                    break

            if task_solved_overall:
                solved_tasks += 1
                print(f"  Task {task_idx + 1}/{num_tasks_in_loader} ({task_id}): SOLVED")
            else:
                print(f"  Task {task_idx + 1}/{num_tasks_in_loader} ({task_id}): FAILED")

        if total_tasks > 0:
            accuracy = (solved_tasks / total_tasks) * 100
            summary = f"ARC-AGI-2 Evaluation Summary:\n  Total tasks evaluated: {total_tasks}\n  Tasks solved: {solved_tasks}\n  Accuracy: {accuracy:.2f}%"
            print(f"\n{summary}")
            with open('arc_agi_2_evaluation_summary.txt', 'w', encoding='utf-8') as f:
                f.write(summary)
        else:
            print("\nARC-AGI-2 Evaluation: No tasks were evaluated.")

        # Optional hook: only called when the model exposes it.
        if hasattr(ctm_model_arc, 'sleep_down'):
            ctm_model_arc.sleep_down()

    except FileNotFoundError as e:
        print(f"❌ Checkpoint file not found: {e}. Please ensure paths are correct.")
    except Exception as e:
        # Top-level boundary of the evaluation cell: report and keep the
        # notebook alive instead of propagating.
        print(f"❌ Error during ARC-AGI-2 evaluation: {e}")
        traceback.print_exc()

    print("\n🔬 ARC-AGI-2 Evaluation Phase Completed.")
