/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json

Things to fix:  Agents need to cover entire grid and be able to transform their section into ints 0 to 9.

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os

# Ensure the correct backend is used for interactive plots
import matplotlib
matplotlib.use('inline')
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=9):  # For a 3x3 grid area
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)  # First hidden layer
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 32)  # Second hidden layer
        self.fc3 = nn.Linear(32, 1)  # Output layer
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))  # Output between 0 and 1

def distribute_agents(grid_size, agent_scope=3):
    agents = []
    for i in range(0, grid_size[0], agent_scope):
        for j in range(0, grid_size[1], agent_scope):
            agents.append((i, j))  # Record agent positions
    return agents

def train_agents(fig, ax1, ax2, agents, inputs, outputs, epochs=100, learning_rate=0.01, device='cuda', plot_dir=None):
    criterion = nn.MSELoss()
    
    print("Moving models to the specified device...")
    for i, (position, model) in enumerate(zip(agents, agent_models)):
        model.to(device)
        print(f"Agent at position {position} moved to {device}")
    
    # Flatten list of lists into a single list of parameters
    params_to_optimize = [param for _, model in zip(agents, agent_models) for param in model.parameters()]
    print("Parameters to optimize:", params_to_optimize)
    
    optimizer = torch.optim.Adam(params_to_optimize, lr=learning_rate)
    
    losses = []  # Initialize an empty list to store all epoch losses
    
    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        total_loss = 0
        num_samples = len(inputs)
        
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            optimizer.zero_grad()
            
            agent_inputs = []
            predictions = []
            
            for position, model in zip(agents, agent_models):
                x_start = max(0, position[0] - 1)
                y_start = max(0, position[1] - 1)
                x_end = min(x.size(0), position[0] + 2)
                y_end = min(x.size(1), position[1] + 2)
              
                local_x = x[x_start:x_end, y_start:y_end].flatten()
                
                if len(local_x) != 9:
                    print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                agent_inputs.append(local_x)
                
                # Forward pass for the current agent
                prediction = model(agent_inputs[-1].unsqueeze(0).to(device))
                print(f"Agent at position {position} prediction shape: {prediction.shape}")
                predictions.append(prediction.squeeze())
            
            combined_pred = torch.mean(torch.stack(predictions))
            
            # Use .clone().detach() to avoid issues with tensor conversion
            target_value = y[position[0], position[1]].clone().detach().to(device)
            loss = criterion(combined_pred, target_value.float())
            
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
        
        average_loss = total_loss / num_samples
        losses.append(average_loss)  # Append the current epoch's average loss
        
        # Update and display the grid visualization
        transformed = transform_grid(input_grid, agents, agent_models_by_position, device=device)
        update_plots(fig, ax1, ax2, epoch + 1, losses, transformed, plot_dir)
        
        print(f'Epoch {epoch+1}, Loss: {average_loss:.4f}')

def update_plots(fig, ax1, ax2, epoch, losses, transformed, plot_dir):
    # Clear previous plots
    ax1.clear()
    ax2.clear()

    # Update the loss plot
    ax1.plot(range(1, len(losses) + 1), losses, marker='o', color='blue')
    ax1.set_title(f'Loss Over Epochs (Epoch {epoch})')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')

    # Update the grid visualization
    cax = ax2.imshow(transformed, cmap='viridis', interpolation='nearest')
    ax2.set_title(f'Grid Transformation (Epoch {epoch})')
    
    # Save the plot to a file with a unique name based on the epoch number
    loss_plot_path = os.path.join(plot_dir, f'loss_epoch_{epoch}.png')
    grid_plot_path = os.path.join(plot_dir, f'grid_epoch_{epoch}.png')
    
    fig.savefig(loss_plot_path)
    print(f"Saved loss plot for epoch {epoch} to {loss_plot_path}")
    
    # Save only the grid plot without colorbar
    fig2, ax3 = plt.subplots(figsize=(7, 6))
    cax2 = ax3.imshow(transformed, cmap='viridis', interpolation='nearest')
    ax3.set_title(f'Grid Transformation (Epoch {epoch})')
    fig2.savefig(grid_plot_path)
    print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)
    plt.close(fig2)

def transform_grid(grid, agents, agent_models_by_position, device='cuda'):
    transformed = np.array(grid)  # Initialize with input grid
    
    # Convert grid to tensor and move it to the specified device
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
    
    for position, model in zip(agents, agent_models):
        x_start = max(0, position[0] - 1)
        y_start = max(0, position[1] - 1)
        x_end = min(len(transformed), position[0] + 2)
        y_end = min(len(transformed[0]), position[1] + 2)
        
        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        
        if len(local_area) != 9:
            print(f"Padded local_area shape: {local_area.shape}")
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Convert local_area to tensor and move to device
        local_area_tensor = torch.tensor(local_area, dtype=torch.float32).to(device)
        
        prediction = model(local_area_tensor.unsqueeze(0)).cpu().item()  # Move back to CPU for numpy conversion
        
        if prediction > 0.5:
            transformed[position[0]][position[1]] = 4
            
    return transformed

# Function to load ARC JSON data
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = np.array(task['input'])
        output_grid = np.array(task['output'])
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the plot directory
plot_dir = '/home/xaqmusic/plots'

# Create the directory if it doesn't exist
os.makedirs(plot_dir, exist_ok=True)

# Initialize plots at the beginning of your script
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
loss_plot, = ax1.plot([], [], marker='o')
ax1.set_title('Loss Over Epochs')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = np.array(inputs[i])
        output_grid = np.array(outputs[i])
        
        grid_size = input_grid.shape
        
        agents = distribute_agents(grid_size)
        
        print("Creating agent models...")
        num_agents = len(agents)
        agent_models = [SimpleAgent() for _ in range(num_agents)]
        
        print("Mapping each model to its corresponding position")
        agent_position_to_model_idx = {agent: idx for idx, agent in enumerate(agents)}
        agent_models_by_position = {agent: agent_models[idx] for idx, agent in enumerate(agents)}
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        train_inputs = [torch.tensor(input_grid, dtype=torch.float32).to(device)]
        train_outputs = [torch.tensor(output_grid, dtype=torch.float32).to(device)]
        
        print("Training Agents...")
        train_agents(fig, ax1, ax2, agents, train_inputs, train_outputs, epochs=100, device=device, plot_dir=plot_dir)
        
        print("\nTransforming Grid...")
        transformed = transform_grid(input_grid, agents, agent_models_by_position, device=device)
        print("\nInput Grid:")
        print(input_grid)
        print("\nOutput Grid after Transformation (Desired):")
        print(output_grid)
        print("\nOutput Grid after Transformation (Model's Prediction):")
        print(transformed)


Still not working.  The transformed grids are not changing for each epoch during the validation step and the network is not yet learning the solution.

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os

# Ensure the correct backend is used for interactive plots
import matplotlib
matplotlib.use('inline')
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=9):  # For a 3x3 grid area
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)  # First hidden layer
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 32)  # Second hidden layer
        self.fc3 = nn.Linear(32, 10)  # Output layer
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)  # Softmax for multiple classes

class ComplexAgent(nn.Module):
    def __init__(self, input_size):
        super(ComplexAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)  # Softmax for multiple classes

def distribute_agents(grid_size, agent_scope=3):
    agents = []
    for i in range(0, grid_size[0], agent_scope):
        for j in range(0, grid_size[1], agent_scope):
            if i + agent_scope <= grid_size[0] and j + agent_scope <= grid_size[1]:
                agents.append((i, j))  # Record agent positions
    return agents

def train_agents(fig, ax1, ax2, simple_agents, complex_agent, inputs, outputs, epochs=100, learning_rate=0.01, device='cuda', plot_dir=None):
    criterion = nn.NLLLoss()
    
    print("Moving models to the specified device...")
    for position in simple_agents:
        model = simple_agent_models[position]
        model.to(device)
        print(f"Simple Agent at position {position} moved to {device}")
    complex_agent.to(device)
    print(f"Complex Agent moved to {device}")

    # Flatten list of lists into a single list of parameters
    params_to_optimize = [param for _, model in simple_agent_models.items() for param in model.parameters()] + list(complex_agent.parameters())
    
    optimizer = torch.optim.Adam(params_to_optimize, lr=learning_rate)
    
    losses = []  # Initialize an empty list to store all epoch losses
    
    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        clear_output(wait=True)  # Clear the output before logging new information
        print(f"Epoch {epoch + 1}/{epochs}")
        
        total_loss = 0
        num_samples = len(inputs)
        
        # Training phase
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            optimizer.zero_grad()
            
            simple_predictions = []  # Initialize this list at the start of each iteration
            simple_targets = []  # Reset simple targets
            
            # Collect predictions from simple agents
            for position in simple_agents:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                # Extract local input and flatten
                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                # Pad the local area if necessary to ensure it's of size 9
                if len(local_x) != 9:
                    print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Ensure local_x_batched has the correct shape (1, 9)
                local_x_batched = local_x.view(1, -1)
                
                # Forward pass for the current simple agent
                model = simple_agent_models[position]
                prediction = model(local_x_batched)
                print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            complex_prediction = complex_agent(x.flatten().unsqueeze(0).to(device)).squeeze()
            
            # Combine predictions
            combined_pred = torch.stack(simple_predictions + [complex_prediction.cpu()])
            
            # Prepare target values for each prediction
            y_flat = y.flatten().long().to(device)
            complex_target = y_flat[-1]  # Assuming the last element is the target for the complex agent
            
            for position in simple_agents:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                # Extract local target and flatten
                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                # Pad the local area if necessary to ensure it's of size 9
                if len(local_y) != 9:
                    print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Ensure local_y_batched has the correct shape (1, 9)
                local_y_batched = local_y.view(1, -1)
                
                # Check if local_y_batched is empty
                if local_y_batched.shape[1] != 9:
                    print(f"Error: local_y_batched has an unexpected shape {local_y_batched.shape} for position {position}")
                    continue
                
                # Use argmax on the batch dimension to get the target class
                simple_target = local_y_batched.argmax(dim=1).cpu()
                
                # Append the target to the list (ensure it's a scalar)
                if len(simple_target) == 0:
                    print(f"Warning: Simple Agent at position {position} has an empty target tensor.")
                    continue
                
                simple_targets.append(simple_target.item())
            
            # Check if we have valid targets for all simple agents
            if len(simple_targets) != len(simple_agents):
                print("Skipping this sample due to mismatched number of simple targets")
                continue
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.tensor(simple_targets + [complex_target], dtype=torch.long)
            print(f"All Targets shape: {all_targets.shape}")
            
            # Calculate loss for each prediction separately
            total_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                loss = criterion(pred.unsqueeze(0), target.unsqueeze(0))
                total_loss_this_sample += loss.item()
                loss.backward(retain_graph=True)  # Retain graph to allow backpropagation through all agents
            
            total_loss += total_loss_this_sample
            optimizer.step()
        
        average_loss = total_loss / num_samples
        losses.append(average_loss)  # Append the current epoch's average loss
        
        # Validation phase
        validation_loss = 0
        transformed_grid = np.zeros_like(inputs[0].cpu().numpy())
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            simple_predictions = []  # Initialize this list at the start of each iteration
            simple_targets = []  # Reset simple targets
            
            # Collect predictions from simple agents
            for position in simple_agents:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                # Extract local input and flatten
                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                # Pad the local area if necessary to ensure it's of size 9
                if len(local_x) != 9:
                    print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Ensure local_x_batched has the correct shape (1, 9)
                local_x_batched = local_x.view(1, -1)
                
                # Forward pass for the current simple agent
                model = simple_agent_models[position]
                prediction = model(local_x_batched)
                print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            complex_prediction = complex_agent(x.flatten().unsqueeze(0).to(device)).squeeze()
            
            # Combine predictions
            combined_pred = torch.stack(simple_predictions + [complex_prediction.cpu()])
            
            # Prepare target values for each prediction (validation phase)
            y_flat = y.flatten().long().to(device)
            complex_target = y_flat[-1]  # Assuming the last element is the target for the complex agent
            
            for position in simple_agents:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                # Extract local target and flatten
                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                # Pad the local area if necessary to ensure it's of size 9
                if len(local_y) != 9:
                    print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Ensure local_y_batched has the correct shape (1, 9)
                local_y_batched = local_y.view(1, -1)
                
                # Check if local_y_batched is empty
                if local_y_batched.shape[1] != 9:
                    print(f"Error: local_y_batched has an unexpected shape {local_y_batched.shape} for position {position}")
                    continue
                
                # Use argmax on the batch dimension to get the target class
                simple_target = local_y_batched.argmax(dim=1).cpu()
                
                # Append the target to the list (ensure it's a scalar)
                if len(simple_target) == 0:
                    print(f"Warning: Simple Agent at position {position} has an empty target tensor.")
                    continue
                
                simple_targets.append(simple_target.item())
            
            # Check if we have valid targets for all simple agents
            if len(simple_targets) != len(simple_agents):
                print("Skipping this sample due to mismatched number of simple targets")
                continue
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.tensor(simple_targets + [complex_target], dtype=torch.long)
            print(f"All Targets shape: {all_targets.shape}")
            
            # Calculate validation loss for each prediction separately
            total_validation_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                loss = criterion(pred.unsqueeze(0), target.unsqueeze(0))
                total_validation_loss_this_sample += loss.item()
            
            validation_loss += total_validation_loss_this_sample
            
            # Update transformed grid based on predictions (if needed)
            for position, pred in zip(simple_agents, simple_predictions):
                if pred.argmax(dim=0) == 4:
                    transformed_grid[position[0]][position[1]] = 4

        average_validation_loss = validation_loss / num_samples
        print(f'Epoch {epoch+1}, Training Loss: {average_loss:.4f}, Validation Loss: {average_validation_loss:.4f}')

        # Update and display the grid visualization
        transformed = transform_grid(input_grid, simple_agents, complex_agent, device=device)
        update_plots(fig, ax1, ax2, epoch + 1, losses, transformed, plot_dir)

def update_plots(fig, ax1, ax2, epoch, losses, transformed, plot_dir):
    # Clear previous plots
    ax1.clear()
    ax2.clear()

    # Update the loss plot
    ax1.plot(range(1, len(losses) + 1), losses, marker='o', color='blue')
    ax1.set_title(f'Loss Over Epochs (Epoch {epoch})')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')

    # Update the grid visualization
    cax = ax2.imshow(transformed, cmap='viridis', interpolation='nearest')
    ax2.set_title(f'Grid Transformation (Epoch {epoch})')
    
    # Save the plot to a file with a unique name based on the epoch number
    loss_plot_path = os.path.join(plot_dir, f'loss_epoch_{epoch}.png')
    grid_plot_path = os.path.join(plot_dir, f'grid_epoch_{epoch}.png')
    
    fig.savefig(loss_plot_path)
    print(f"Saved loss plot for epoch {epoch} to {loss_plot_path}")
    
    # Save only the grid plot without colorbar
    fig2, ax3 = plt.subplots(figsize=(7, 6))
    cax2 = ax3.imshow(transformed, cmap='viridis', interpolation='nearest')
    ax3.set_title(f'Grid Transformation (Epoch {epoch})')
    fig2.savefig(grid_plot_path)
    print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)
    plt.close(fig2)

def transform_grid(grid, simple_agents, complex_agent, device='cuda'):
    transformed = np.array(grid)  # Initialize with input grid
    
    # Convert grid to tensor and move it to the specified device
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
    
    for position in simple_agents:
        x_start = max(0, position[0])
        y_start = max(0, position[1])
        x_end = min(len(transformed), position[0] + 3)
        y_end = min(len(transformed[0]), position[1] + 3)

        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        
        # Pad the local area if necessary
        if len(local_area) != 9:
            print(f"Padded local_area shape: {local_area.shape}")
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Convert local_area to tensor and move to device
        local_area_tensor = torch.tensor(local_area, dtype=torch.float32).to(device)
        
        prediction = simple_agent_models[position](local_area_tensor.view(1, -1)).cpu().argmax(dim=1).item()  # Add batch dimension
        
        if prediction == 4:
            transformed[position[0]][position[1]] = 4
            
    complex_prediction = complex_agent(grid_tensor.view(-1).unsqueeze(0)).cpu().argmax(dim=1).item()
    
    return transformed

# Function to load ARC JSON data
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = np.array(task['input'])
        output_grid = np.array(task['output'])
        
        # Print the input and desired output grids
        print("\nInput Grid:")
        print(input_grid)
        print("\nDesired Output Grid:")
        print(output_grid)
        
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the plot directory
plot_dir = '/home/xaqmusic/plots'

# Create the directory if it doesn't exist
os.makedirs(plot_dir, exist_ok=True)

# Initialize plots at the beginning of your script
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
loss_plot, = ax1.plot([], [], marker='o')
ax1.set_title('Loss Over Epochs')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = np.array(inputs[i])
        output_grid = np.array(outputs[i])
        
        grid_size = input_grid.shape
        
        simple_agents = distribute_agents(grid_size)
        
        print("Creating agent models...")
        num_simple_agents = len(simple_agents)
        simple_agent_models = {agent: SimpleAgent() for agent in simple_agents}
        
        complex_agent = ComplexAgent(input_size=grid_size[0] * grid_size[1])
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        train_inputs = [torch.tensor(input_grid, dtype=torch.float32).to(device)]
        train_outputs = [torch.tensor(output_grid, dtype=torch.long).to(device)]
        
        print("Training Agents...")
        train_agents(fig, ax1, ax2, simple_agents, complex_agent, train_inputs, train_outputs, epochs=100, device=device, plot_dir=plot_dir)
        
        print("\nTransforming Grid...")
        transformed = transform_grid(input_grid, simple_agents, complex_agent, device=device)
        print("\nInput Grid:")
        print(input_grid)
        print("\nOutput Grid after Transformation (Desired):")
        print(output_grid)
        print("\nOutput Grid after Transformation (Model's Prediction):")
        print(transformed)


the simple agents are too far apart so they are only updating based on blank pixels.  The boundries are not being addressed.

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os
from datetime import datetime  # Import datetime to generate timestamp

# Ensure the correct backend is used for interactive plots
import matplotlib
matplotlib.use('inline')
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=9):  # For a 3x3 grid area
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)  # First hidden layer
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)  # Second hidden layer
        self.fc3 = nn.Linear(64, 10)  # Output layer
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)  # Softmax for multiple classes

class ComplexAgent(nn.Module):
    def __init__(self, input_size):
        super(ComplexAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)  # Softmax for multiple classes

def distribute_agents(grid_size, agent_scope=3):
    agents = []
    for i in range(0, grid_size[0], agent_scope):
        for j in range(0, grid_size[1], agent_scope):
            if i + agent_scope <= grid_size[0] and j + agent_scope <= grid_size[1]:
                agents.append((i, j))  # Record agent positions
    return agents

def custom_loss(predicted_grid, target_grid, threshold=0.95):
    """
    Calculate a custom loss that rewards grids which are closer in terms of patterns.
    
    Parameters:
    - predicted_grid: The grid generated by the model (numpy array).
    - target_grid: The desired output grid (numpy array).
    - threshold: A value between 0 and 1 to determine how similar two values need to be.
    
    Returns:
    - loss: The calculated loss.
    """
    # Convert grids to numpy arrays if they are not already
    predicted_grid = np.array(predicted_grid)
    target_grid = np.array(target_grid)
    
    # Ensure both grids have the same shape
    assert predicted_grid.shape == target_grid.shape, "Grids must have the same shape"
    
    # Calculate the absolute difference between the two grids
    diff = np.abs(predicted_grid - target_grid)
    
    # Count how many cells are within the threshold of being correct
    within_threshold_count = np.sum(diff < threshold)
    
    # Calculate the custom loss as the percentage of cells not within the threshold
    loss = 1.0 - (within_threshold_count / predicted_grid.size)
    
    return loss

def train_agents(fig, ax1, ax2, simple_agents, complex_agent, inputs, outputs, epochs=100, learning_rate=0.001, device='cuda', plot_dir=None):
    criterion = nn.CrossEntropyLoss()  # Use Cross-Entropy Loss for classification
    
    print("Moving models to the specified device...")
    for position in simple_agents:
        model = simple_agent_models[position]
        model.to(device)
        print(f"Simple Agent at position {position} moved to {device}")
    complex_agent.to(device)
    print(f"Complex Agent moved to {device}")

    # Flatten list of lists into a single list of parameters
    params_to_optimize = [param for _, model in simple_agent_models.items() for param in model.parameters()] + list(complex_agent.parameters())
    
    optimizer = torch.optim.Adam(params_to_optimize, lr=learning_rate)
    
    losses = []  # Initialize an empty list to store all epoch losses
    
    best_loss = float('inf')
    patience = 10  # Number of epochs to wait for improvement
    counter = 0
    
    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        clear_output(wait=True)  # Clear the output before logging new information
        print(f"Epoch {epoch + 1}/{epochs}")
        
        total_loss = 0
        num_samples = len(inputs)
        
        # Training phase
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            optimizer.zero_grad()
            
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            simple_targets = []  # Initialize this list for targets within the current training sample
            
            # Collect predictions from simple agents
            for position in simple_agents:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                # Extract local input and flatten
                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                # Pad the local area if necessary to ensure it's of size 9
                if len(local_x) != 9:
                    print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Ensure local_x_batched has the correct shape (1, 9)
                local_x_batched = local_x.view(1, -1)
                
                # Forward pass for the current simple agent
                model = simple_agent_models[position]
                prediction = model(local_x_batched)
                print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            complex_prediction = complex_agent(x.flatten().unsqueeze(0).to(device)).squeeze()
            
            # Combine predictions
            combined_pred = torch.stack(simple_predictions + [complex_prediction.cpu()])
            
            # Prepare target values for each prediction
            y_flat = y.flatten().long().to(device)
            complex_target = y_flat[-1]  # Assuming the last element is the target for the complex agent
            
            for position in simple_agents:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                # Extract local target and flatten
                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                # Pad the local area if necessary to ensure it's of size 9
                if len(local_y) != 9:
                    print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Ensure local_y_batched has the correct shape (1, 9)
                local_y_batched = local_y.view(1, -1)
                
                # Check if local_y_batched is empty
                if local_y_batched.shape[1] != 9:
                    print(f"Error: local_y_batched has an unexpected shape {local_y_batched.shape} for position {position}")
                    continue
                
                # Use argmax on the batch dimension to get the target class
                simple_target = local_y_batched.argmax(dim=1).cpu()
                
                # Append the target to the list (ensure it's a scalar)
                if len(simple_target) == 0:
                    print(f"Warning: Simple Agent at position {position} has an empty target tensor.")
                    continue
                
                simple_targets.append(simple_target.item())
            
            # Check if we have valid targets for all simple agents
            if len(simple_targets) != len(simple_agents):
                print("Skipping this sample due to mismatched number of simple targets")
                continue
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.tensor(simple_targets + [complex_target], dtype=torch.long)
            print(f"All Targets shape: {all_targets.shape}")
            
            # Calculate loss for each prediction separately
            total_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                loss = criterion(pred.unsqueeze(0), target.unsqueeze(0))
                total_loss_this_sample += loss.item()
                loss.backward(retain_graph=True)  # Retain graph to allow backpropagation through all agents
            
            total_loss += total_loss_this_sample
            optimizer.step()
        
        average_loss = total_loss / num_samples
        losses.append(average_loss)  # Append the current epoch's average loss
        
        # Validation phase
        validation_loss = 0
        transformed_grid = np.zeros_like(inputs[0].cpu().numpy())
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            simple_targets = []  # Initialize this list for targets within the current validation sample
            
            # Collect predictions from simple agents
            for position in simple_agents:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                # Extract local input and flatten
                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                # Pad the local area if necessary to ensure it's of size 9
                if len(local_x) != 9:
                    print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Ensure local_x_batched has the correct shape (1, 9)
                local_x_batched = local_x.view(1, -1)
                
                # Forward pass for the current simple agent
                model = simple_agent_models[position]
                prediction = model(local_x_batched)
                print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            complex_prediction = complex_agent(x.flatten().unsqueeze(0).to(device)).squeeze()
            
            # Combine predictions
            combined_pred = torch.stack(simple_predictions + [complex_prediction.cpu()])
            
            # Prepare target values for each prediction (validation phase)
            y_flat = y.flatten().long().to(device)
            complex_target = y_flat[-1]  # Assuming the last element is the target for the complex agent
            
            for position in simple_agents:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                # Extract local target and flatten
                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                # Pad the local area if necessary to ensure it's of size 9
                if len(local_y) != 9:
                    print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Ensure local_y_batched has the correct shape (1, 9)
                local_y_batched = local_y.view(1, -1)
                
                # Check if local_y_batched is empty
                if local_y_batched.shape[1] != 9:
                    print(f"Error: local_y_batched has an unexpected shape {local_y_batched.shape} for position {position}")
                    continue
                
                # Use argmax on the batch dimension to get the target class
                simple_target = local_y_batched.argmax(dim=1).cpu()
                
                # Append the target to the list (ensure it's a scalar)
                if len(simple_target) == 0:
                    print(f"Warning: Simple Agent at position {position} has an empty target tensor.")
                    continue
                
                simple_targets.append(simple_target.item())
            
            # Check if we have valid targets for all simple agents
            if len(simple_targets) != len(simple_agents):
                print("Skipping this sample due to mismatched number of simple targets")
                continue
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.tensor(simple_targets + [complex_target], dtype=torch.long)
            print(f"All Targets shape: {all_targets.shape}")
            
            # Calculate validation loss for each prediction separately
            total_validation_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                loss = criterion(pred.unsqueeze(0), target.unsqueeze(0))
                total_validation_loss_this_sample += loss.item()
            
            validation_loss += total_validation_loss_this_sample
            
            # Update transformed grid based on predictions (if needed)
            for position, pred in zip(simple_agents, simple_predictions):
                probabilities = torch.nn.functional.softmax(pred, dim=0)
                predicted_class = torch.argmax(probabilities).item()
                transformed_grid[position[0]][position[1]] = predicted_class

        average_validation_loss = validation_loss / num_samples
        print(f'Epoch {epoch+1}, Training Loss: {average_loss:.4f}, Validation Loss: {average_validation_loss:.4f}')

        # Update and display the grid visualization
        transformed = transform_grid(input_grid, simple_agents, complex_agent, device=device)
        update_plots(fig, ax1, ax2, epoch + 1, losses, transformed, plot_dir)

        # Early stopping logic
        if average_validation_loss < best_loss:
            best_loss = average_validation_loss
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping: Validation loss did not improve for {} epochs.".format(patience))
                break

def update_plots(fig, ax1, ax2, epoch, losses, transformed, plot_dir):
    # Clear previous plots
    ax1.clear()
    ax2.clear()

    # Update the loss plot
    ax1.plot(range(1, len(losses) + 1), losses, marker='o', color='blue')
    ax1.set_title(f'Loss Over Epochs (Epoch {epoch})')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')

    # Update the grid visualization
    cax = ax2.imshow(transformed, cmap='viridis', interpolation='nearest')
    ax2.set_title(f'Grid Transformation (Epoch {epoch})')
    
    # Save the plot to a file with a unique name based on the epoch number
    loss_plot_path = os.path.join(plot_dir, f'{prefix}_loss_epoch_{epoch}.png')  # Prefix added here
    grid_plot_path = os.path.join(plot_dir, f'{prefix}_grid_epoch_{epoch}.png')  # Prefix added here
    
    fig.savefig(loss_plot_path)
    print(f"Saved loss plot for epoch {epoch} to {loss_plot_path}")
    
    # Save only the grid plot without colorbar
    fig2, ax3 = plt.subplots(figsize=(7, 6))
    cax2 = ax3.imshow(transformed, cmap='viridis', interpolation='nearest')
    ax3.set_title(f'Grid Transformation (Epoch {epoch})')
    fig2.savefig(grid_plot_path)
    print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)
    plt.close(fig2)

def transform_grid(grid, simple_agents, complex_agent, device='cuda'):
    transformed = np.zeros_like(grid)  # Initialize with zeros
    
    # Convert grid to tensor and move it to the specified device
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
    
    for position in simple_agents:
        x_start = max(0, position[0])
        y_start = max(0, position[1])
        x_end = min(len(transformed), position[0] + 3)
        y_end = min(len(transformed[0]), position[1] + 3)

        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        
        # Pad the local area if necessary
        if len(local_area) != 9:
            print(f"Padded local_area shape: {local_area.shape}")
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Convert local_area to tensor and move to device
        local_area_tensor = torch.tensor(local_area, dtype=torch.float32).to(device)
        
        prediction = simple_agent_models[position](local_area_tensor.unsqueeze(0)).cpu()
        probabilities = torch.nn.functional.softmax(prediction, dim=1)
        
        # Log the probabilities for debugging
        print(f"Simple Agent at position {position} - Predicted Probabilities: {probabilities.squeeze().tolist()}")
        
        # Get the predicted class (integer between 0 and 9)
        predicted_class = torch.argmax(probabilities).item()
        transformed[position[0]][position[1]] = predicted_class
    
    complex_prediction = complex_agent(grid_tensor.view(-1).unsqueeze(0)).cpu()
    complex_probabilities = torch.nn.functional.softmax(complex_prediction, dim=1)
    
    # Log the complex agent's probabilities for debugging
    print(f"Complex Agent - Predicted Probabilities: {complex_probabilities.squeeze().tolist()}")
    
    # Get the predicted class (integer between 0 and 9)
    complex_predicted_class = torch.argmax(complex_probabilities).item()
    
    return transformed

# Function to load ARC JSON data
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = np.array(task['input'])
        output_grid = np.array(task['output'])
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the plot directory
plot_dir = '/home/xaqmusic/plots'

# Create the directory if it doesn't exist
os.makedirs(plot_dir, exist_ok=True)

# Generate a unique prefix for this run using timestamp
prefix = datetime.now().strftime("%Y%m%d_%H%M%S")  # Example: "20231005_143000"

# Initialize plots at the beginning of your script
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
loss_plot, = ax1.plot([], [], marker='o')
ax1.set_title('Loss Over Epochs')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = np.array(inputs[i])
        output_grid = np.array(outputs[i])
        
        grid_size = input_grid.shape
        
        simple_agents = distribute_agents(grid_size)
        
        print("Creating agent models...")
        num_simple_agents = len(simple_agents)
        simple_agent_models = {agent: SimpleAgent() for agent in simple_agents}
        
        complex_agent = ComplexAgent(input_size=grid_size[0] * grid_size[1])
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        train_inputs = [torch.tensor(input_grid, dtype=torch.float32).to(device)]
        train_outputs = [torch.tensor(output_grid, dtype=torch.long).to(device)]
        
        print("Training Agents...")
        train_agents(fig, ax1, ax2, simple_agents, complex_agent, train_inputs, train_outputs, epochs=100, device=device, plot_dir=plot_dir)
        
        print("\nTransforming Grid...")
        transformed = transform_grid(input_grid, simple_agents, complex_agent, device=device)
        print("\nInput Grid:")
        print(input_grid)
        print("\nOutput Grid after Transformation (Desired):")
        print(output_grid)
        print("\nOutput Grid after Transformation (Model's Prediction):")
        print(transformed)


Removing the ComplexAgent (wasn't doing anything useful) and increasing the number of simple agents to cover the entire grid (no gaps)

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os
from datetime import datetime

# Ensure the correct backend is used for interactive plots
import matplotlib
matplotlib.use('inline')
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=9):  # For a 3x3 grid area
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)  # First hidden layer
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 32)  # Second hidden layer
        self.fc3 = nn.Linear(32, 10)  # Output layer
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)  # Softmax for multiple classes

def distribute_agents(grid_size, agent_scope=3):
    agents = []
    for i in range(0, grid_size[0], agent_scope - 1):  # Adjust step size to ensure overlap
        for j in range(0, grid_size[1], agent_scope - 1):
            if i + agent_scope <= grid_size[0] and j + agent_scope <= grid_size[1]:
                agents.append((i, j))  # Record agent positions
    return agents

def train_agents(fig, ax1, ax2, simple_agents, inputs, outputs, epochs=50, learning_rate=0.1, device='cuda', plot_dir=None):
    criterion = nn.NLLLoss()
    
    print("Moving models to the specified device...")
    for position in simple_agents:
        model = simple_agent_models[position]
        model.to(device)
        print(f"Simple Agent at position {position} moved to {device}")

    # Flatten list of lists into a single list of parameters
    params_to_optimize = [param for _, model in simple_agent_models.items() for param in model.parameters()]
    
    optimizer = torch.optim.Adam(params_to_optimize, lr=learning_rate)
    
    losses = []  # Initialize an empty list to store all epoch losses
    
    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        clear_output(wait=True)  # Clear the output before logging new information
        print(f"Epoch {epoch + 1}/{epochs}")
        total_loss = 0
        num_samples = len(inputs)
        
        # Training phase
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            optimizer.zero_grad()
            
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position in simple_agents:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                    print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Forward pass for the current simple agent
                model = simple_agent_models[position]
                prediction = model(local_x.unsqueeze(0).to(device))
                print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agents:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate loss for each prediction separately
            total_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                loss = criterion(pred.unsqueeze(0), target.unsqueeze(0))
                total_loss_this_sample += loss.item()
                loss.backward(retain_graph=True)  # Retain graph to allow backpropagation through all agents
            
            total_loss += total_loss_this_sample
            optimizer.step()
        
        average_loss = total_loss / num_samples
        losses.append(average_loss)  # Append the current epoch's average loss
        
        # Validation phase (same as training phase)
        validation_loss = 0
        transformed_grid = np.zeros_like(inputs[0].cpu().numpy())
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position in simple_agents:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                    print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Forward pass for the current simple agent
                model = simple_agent_models[position]
                prediction = model(local_x.unsqueeze(0).to(device))
                print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction (validation phase)
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agents:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate validation loss for each prediction separately
            total_validation_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                loss = criterion(pred.unsqueeze(0), target.unsqueeze(0))
                total_validation_loss_this_sample += loss.item()
            
            validation_loss += total_validation_loss_this_sample
            
            # Update transformed grid based on predictions (if needed)
            for j, position in enumerate(simple_agents):
                if combined_pred[j].argmax(dim=0).item() == 4:
                    transformed_grid[position[0]][position[1]] = 4

        average_validation_loss = validation_loss / num_samples
        print(f'Epoch {epoch+1}, Training Loss: {average_loss:.4f}, Validation Loss: {average_validation_loss:.4f}')

        # Update and display the grid visualization
        transformed = transform_grid(input_grid, simple_agents, device=device)
        update_plots(fig, ax1, ax2, epoch + 1, losses, transformed, plot_dir)

def update_plots(fig, ax1, ax2, epoch, losses, transformed, plot_dir):
    # Clear previous plots
    ax1.clear()
    ax2.clear()

    # Update the loss plot
    ax1.plot(range(1, len(losses) + 1), losses, marker='o', color='blue')
    ax1.set_title(f'Loss Over Epochs (Epoch {epoch})')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')

    # Update the grid visualization
    cax = ax2.imshow(transformed, cmap='viridis', interpolation='nearest')
    ax2.set_title(f'Grid Transformation (Epoch {epoch})')
    
    # Save the plot to a file with a unique name based on the epoch number
    loss_plot_path = os.path.join(plot_dir, f'{prefix}_loss_epoch_{epoch}.png')
    grid_plot_path = os.path.join(plot_dir, f'{prefix}_grid_epoch_{epoch}.png')
    
    fig.savefig(loss_plot_path)
    print(f"Saved loss plot for epoch {epoch} to {loss_plot_path}")
    
    # Save only the grid plot without colorbar
    fig2, ax3 = plt.subplots(figsize=(7, 6))
    cax2 = ax3.imshow(transformed, cmap='viridis', interpolation='nearest')
    ax3.set_title(f'Grid Transformation (Epoch {epoch})')
    fig2.savefig(grid_plot_path)
    print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)
    plt.close(fig2)

def transform_grid(grid, simple_agents, device='cuda'):
    transformed = np.array(grid)  # Initialize with input grid
    
    # Convert grid to tensor and move it to the specified device
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
    
    for position in simple_agents:
        x_start = max(0, position[0])
        y_start = max(0, position[1])
        x_end = min(len(transformed), position[0] + 3)
        y_end = min(len(transformed[0]), position[1] + 3)

        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        
        if len(local_area) != 9:
            print(f"Padded local_area shape: {local_area.shape}")
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Convert local_area to tensor and move to device
        local_area_tensor = torch.tensor(local_area, dtype=torch.float32).to(device)
        
        prediction = simple_agent_models[position](local_area_tensor.unsqueeze(0)).cpu().argmax(dim=1).item()  # Move back to CPU for numpy conversion
        
        if prediction == 4:
            transformed[position[0]][position[1]] = 4
            
    return transformed

# Function to load ARC JSON data
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = np.array(task['input'])
        output_grid = np.array(task['output'])
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the plot directory
plot_dir = '/home/xaqmusic/plots'

# Create the directory if it doesn't exist
os.makedirs(plot_dir, exist_ok=True)

# Generate a unique prefix for this run using timestamp
prefix = datetime.now().strftime("%Y%m%d_%H%M%S")  # Example: "20231005_143000"

# Initialize plots at the beginning of your script
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
loss_plot, = ax1.plot([], [], marker='o')
ax1.set_title('Loss Over Epochs')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = np.array(inputs[i])
        output_grid = np.array(outputs[i])
        
        grid_size = input_grid.shape
        
        simple_agents = distribute_agents(grid_size)
        
        print("Creating agent models...")
        num_simple_agents = len(simple_agents)
        simple_agent_models = {agent: SimpleAgent() for agent in simple_agents}
        
        complex_agent = ComplexAgent(input_size=grid_size[0] * grid_size[1])
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        train_inputs = [torch.tensor(input_grid, dtype=torch.float32).to(device)]
        train_outputs = [torch.tensor(output_grid, dtype=torch.long).to(device)]
        
        print("Training Agents...")
        train_agents(fig, ax1, ax2, simple_agents, train_inputs, train_outputs, epochs=50, device=device, plot_dir=plot_dir)
        
        print("\nTransforming Grid...")
        transformed = transform_grid(input_grid, simple_agents, device=device)
        print("\nInput Grid:")
        print(input_grid)
        print("\nOutput Grid after Transformation (Desired):")
        print(output_grid)
        print("\nOutput Grid after Transformation (Model's Prediction):")
        print(transformed)


In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os
from datetime import datetime

# Ensure the correct backend is used for interactive plots
import matplotlib
matplotlib.use('inline')
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=9):  # For a 3x3 grid area
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)  # First hidden layer
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 32)  # Second hidden layer
        self.fc3 = nn.Linear(32, 10)  # Output layer
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)  # Softmax for multiple classes

def distribute_agents(grid_size, agent_scope=3):
    agents = []
    for i in range(0, grid_size[0], agent_scope - 1):  # Adjust step size to ensure overlap
        for j in range(0, grid_size[1], agent_scope - 1):
            if i + agent_scope <= grid_size[0] and j + agent_scope <= grid_size[1]:
                agents.append((i, j))  # Record agent positions
    return agents

def train_agents(fig, ax1, ax2, simple_agent_models, inputs, outputs, epochs=10, learning_rate=0.001, device='cuda', plot_dir=None):
    criterion = nn.NLLLoss()
    
    print("Moving models to the specified device...")
    for position, model in simple_agent_models.items():
        model.to(device)
        print(f"Simple Agent at position {position} moved to {device}")

    # Flatten list of lists into a single list of parameters
    params_to_optimize = [param for _, model in simple_agent_models.items() for param in model.parameters()]
    
    optimizer = torch.optim.Adam(params_to_optimize, lr=learning_rate)
    
    losses = []  # Initialize an empty list to store all epoch losses
    
    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        print(f"Epoch {epoch + 1}/{epochs}")
        total_loss = 0
        num_samples = len(inputs)
        
        # Training phase
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            optimizer.zero_grad()
            
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position, model in simple_agent_models.items():
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                    print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Forward pass for the current simple agent
                prediction = model(local_x.unsqueeze(0).to(device))
                print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agent_models:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate loss for each prediction separately
            total_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                loss = criterion(pred.unsqueeze(0), target.unsqueeze(0))
                total_loss_this_sample += loss.item()
                loss.backward(retain_graph=True)  # Retain graph to allow backpropagation through all agents
            
            total_loss += total_loss_this_sample
            optimizer.step()
        
        average_loss = total_loss / num_samples
        losses.append(average_loss)  # Append the current epoch's average loss
        
        # Validation phase (same as training phase)
        validation_loss = 0
        transformed_grid = np.zeros_like(inputs[0].cpu().numpy())
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position, model in simple_agent_models.items():
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                    print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Forward pass for the current simple agent
                prediction = model(local_x.unsqueeze(0).to(device))
                print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction (validation phase)
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agent_models:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate validation loss for each prediction separately
            total_validation_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                loss = criterion(pred.unsqueeze(0), target.unsqueeze(0))
                total_validation_loss_this_sample += loss.item()
            
            validation_loss += total_validation_loss_this_sample
            
            # Update transformed grid based on predictions (if needed)
            for j, position in enumerate(simple_agent_models):
                if combined_pred[j].argmax(dim=0).item() == 4:
                    transformed_grid[position[0]][position[1]] = 4

        average_validation_loss = validation_loss / num_samples
        print(f'Epoch {epoch+1}, Training Loss: {average_loss:.4f}, Validation Loss: {average_validation_loss:.4f}')

        # Update and display the grid visualization
        transformed = transform_grid(input_grid, simple_agent_models, device=device)
        update_plots(fig, ax1, ax2, epoch + 1, losses, transformed, plot_dir)

        # Check if both training and validation loss are below the threshold
        if average_loss < 0.00001 and average_validation_loss < 0.00001:
            print("Training complete: Both training and validation loss are below 0.00001.")
            break

        # Update the grid for the next epoch
        inputs[0] = torch.tensor(transformed_grid, dtype=torch.float32).to(device)
        clear_output(wait=True)  # Clear the output before logging new information

def update_plots(fig, ax1, ax2, epoch, losses, transformed, plot_dir):
    # Clear previous plots
    ax1.clear()
    ax2.clear()

    # Update the loss plot
    ax1.plot(range(1, len(losses) + 1), losses, marker='o', color='blue')
    ax1.set_title(f'Loss Over Epochs (Epoch {epoch})')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')

    # Update the grid visualization
    cax = ax2.imshow(transformed, cmap='viridis', interpolation='nearest')
    ax2.set_title(f'Grid Transformation (Epoch {epoch})')
    
    # Save the plot to a file with a unique name based on the epoch number
    loss_plot_path = os.path.join(plot_dir, f'{prefix}_loss_epoch_{epoch}.png')
    grid_plot_path = os.path.join(plot_dir, f'{prefix}_grid_epoch_{epoch}.png')
    
    fig.savefig(loss_plot_path)
    print(f"Saved loss plot for epoch {epoch} to {loss_plot_path}")
    
    # Save only the grid plot without colorbar
    fig2, ax3 = plt.subplots(figsize=(7, 6))
    cax2 = ax3.imshow(transformed, cmap='viridis', interpolation='nearest')
    ax3.set_title(f'Grid Transformation (Epoch {epoch})')
    fig2.savefig(grid_plot_path)
    print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)
    plt.close(fig2)

def transform_grid(grid, simple_agent_models, device='cuda'):
    transformed = np.array(grid)  # Initialize with input grid
    
    # Convert grid to tensor and move it to the specified device
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
    
    for position, model in simple_agent_models.items():
        x_start = max(0, position[0])
        y_start = max(0, position[1])
        x_end = min(len(transformed), position[0] + 3)
        y_end = min(len(transformed[0]), position[1] + 3)

        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        
        if len(local_area) != 9:
            print(f"Padded local_area shape: {local_area.shape}")
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Convert local_area to tensor and move to device
        local_area_tensor = torch.tensor(local_area, dtype=torch.float32).to(device)
        
        prediction = model(local_area_tensor.unsqueeze(0)).cpu().argmax(dim=1).item()  # Move back to CPU for numpy conversion
        
        if prediction == 4:
            transformed[position[0]][position[1]] = 4
            
    return transformed

# Function to load ARC JSON data
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = np.array(task['input'])
        output_grid = np.array(task['output'])
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the plot directory
plot_dir = '/home/xaqmusic/plots'

# Create the directory if it doesn't exist
os.makedirs(plot_dir, exist_ok=True)

# Generate a unique prefix for this run using timestamp
prefix = datetime.now().strftime("%Y%m%d_%H%M%S")  # Example: "20231005_143000"

# Initialize plots at the beginning of your script
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
loss_plot, = ax1.plot([], [], marker='o')
ax1.set_title('Loss Over Epochs')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = np.array(inputs[i])
        output_grid = np.array(outputs[i])
        
        grid_size = input_grid.shape
        
        simple_agents = distribute_agents(grid_size)
        
        print("Creating agent models...")
        num_simple_agents = len(simple_agents)
        simple_agent_models = {agent: SimpleAgent() for agent in simple_agents}
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        train_inputs = [torch.tensor(input_grid, dtype=torch.float32).to(device)]
        train_outputs = [torch.tensor(output_grid, dtype=torch.long).to(device)]
        
        print("Training Agents...")
        train_agents(fig, ax1, ax2, simple_agent_models, train_inputs, train_outputs, epochs=50, device=device, plot_dir=plot_dir)
        
        print("\nTransforming Grid...")
        transformed = transform_grid(input_grid, simple_agent_models, device=device)
        print("\nInput Grid:")
        print(input_grid)
        print("\nOutput Grid after Transformation (Desired):")
        print(output_grid)
        print("\nOutput Grid after Transformation (Model's Prediction):")
        print(transformed)


working but there are still gaps between the agents

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os
from datetime import datetime

# Ensure the correct backend is used for interactive plots
import matplotlib
matplotlib.use('inline')
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=9):  # For a 3x3 grid area
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)  # First hidden layer
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 32)  # Second hidden layer
        self.fc3 = nn.Linear(32, 10)  # Output layer
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)  # Softmax for multiple classes

def distribute_agents(grid_size, agent_scope=3):
    agents = []
    for i in range(0, grid_size[0], agent_scope - 1):  # Adjust step size to ensure overlap
        for j in range(0, grid_size[1], agent_scope - 1):
            if i + agent_scope <= grid_size[0] and j + agent_scope <= grid_size[1]:
                agents.append((i, j))  # Record agent positions
    return agents

def train_agents(fig, ax1, ax2, simple_agent_models, inputs, outputs, epochs=50, learning_rate=0.001, device='cuda', plot_dir=None):
    criterion = nn.MSELoss()  # Use MSE Loss
    
    print("Moving models to the specified device...")
    for position, model in simple_agent_models.items():
        model.to(device)
        print(f"Simple Agent at position {position} moved to {device}")

    # Flatten list of lists into a single list of parameters
    params_to_optimize = [param for _, model in simple_agent_models.items() for param in model.parameters()]
    
    optimizer = torch.optim.Adam(params_to_optimize, lr=learning_rate)
    
    losses = []  # Initialize an empty list to store all epoch losses
    
    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        clear_output(wait=True)  # Clear the output before logging new information
        print(f"Epoch {epoch + 1}/{epochs}")
        total_loss = 0
        num_samples = len(inputs)
        
        # Training phase
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            optimizer.zero_grad()
            
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position, model in simple_agent_models.items():
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                    print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Forward pass for the current simple agent
                prediction = model(local_x.unsqueeze(0).to(device))
                print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agent_models:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate loss for each prediction separately
            total_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                # Convert target to float before calculating loss
                target_float = target.float()
                loss = criterion(pred.unsqueeze(0), target_float.unsqueeze(0))
                total_loss_this_sample += loss.item()
                loss.backward(retain_graph=True)  # Retain graph to allow backpropagation through all agents
            
            total_loss += total_loss_this_sample
            optimizer.step()
        
        average_loss = total_loss / num_samples
        losses.append(average_loss)  # Append the current epoch's average loss
        
        # Validation phase (same as training phase)
        validation_loss = 0
        # Initialize transformed_grid with zeros for consistency
        max_value = inputs[0].max().item()  # Convert tensor to scalar
        transformed_grid_tensor = torch.zeros_like(inputs[0])  # Use torch.rand_like for similar shape and dtype
        
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position, model in simple_agent_models.items():
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                    print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Forward pass for the current simple agent
                prediction = model(local_x.unsqueeze(0).to(device))
                print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction (validation phase)
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agent_models:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate validation loss for each prediction separately
            total_validation_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                # Convert target to float before calculating loss
                target_float = target.float()
                loss = criterion(pred.unsqueeze(0), target_float.unsqueeze(0))
                total_validation_loss_this_sample += loss.item()
            
            validation_loss += total_validation_loss_this_sample
            
            # Update transformed grid based on predictions (if needed)
            for j, position in enumerate(simple_agents):
                if combined_pred[j].argmax(dim=0).item() == 4:
                    transformed_grid_tensor[position[0]][position[1]] = max_value

        average_validation_loss = validation_loss / num_samples
        print(f'Epoch {epoch+1}, Training Loss: {average_loss:.4f}, Validation Loss: {average_validation_loss:.4f}')

        # Update and display the grid visualization
        transformed = transform_grid(inputs[0].cpu(), simple_agents, device=device)
        update_plots(fig, ax1, ax2, epoch + 1, losses, transformed, plot_dir)

        # Check if both training and validation loss are below the threshold
        if average_loss < 0.00001 and average_validation_loss < 0.00001:
            print("Training complete: Both training and validation loss are below 0.00001.")
            break

        # Update the grid for the next epoch
        inputs[0] = transformed_grid_tensor.to(device)  # Ensure it's on the correct device before updating

def update_plots(fig, ax1, ax2, epoch, losses, transformed, plot_dir):
    # Clear previous plots
    ax1.clear()
    ax2.clear()

    # Update the loss plot
    ax1.plot(range(1, len(losses) + 1), losses, marker='o', color='blue')
    ax1.set_title(f'Loss Over Epochs (Epoch {epoch})')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')

    # Ensure transformed is a NumPy array before passing to imshow
    if isinstance(transformed, torch.Tensor):
        transformed = transformed.cpu().numpy()
    
    # Update the grid visualization
    cax = ax2.imshow(transformed, cmap='viridis', interpolation='nearest')  # Convert tensor to numpy array
    ax2.set_title(f'Grid Transformation (Epoch {epoch})')
    
    # Save the plot to a file with a unique name based on the epoch number
    loss_plot_path = os.path.join(plot_dir, f'{prefix}_loss_epoch_{epoch}.png')
    grid_plot_path = os.path.join(plot_dir, f'{prefix}_grid_epoch_{epoch}.png')
    
    fig.savefig(loss_plot_path)
    print(f"Saved loss plot for epoch {epoch} to {loss_plot_path}")
    
    # Save only the grid plot without colorbar
    fig2, ax3 = plt.subplots(figsize=(7, 6))
    cax2 = ax3.imshow(transformed, cmap='viridis', interpolation='nearest')
    ax3.set_title(f'Grid Transformation (Epoch {epoch})')
    fig2.savefig(grid_plot_path)
    print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)
    plt.close(fig2)

def transform_grid(grid, simple_agents, device='cuda'):
    transformed = np.array(grid)  # Initialize with input grid
    
    # Convert grid to tensor and move it to the specified device
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
    
    for position in simple_agents:
        x_start = max(0, position[0])
        y_start = max(0, position[1])
        x_end = min(len(transformed), position[0] + 3)
        y_end = min(len(transformed[0]), position[1] + 3)

        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        
        if len(local_area) != 9:
            print(f"Padded local_area shape: {local_area.shape}")
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Convert local_area to tensor and move to device
        local_area_tensor = torch.tensor(local_area, dtype=torch.float32).to(device)
        
        prediction = simple_agent_models[position](local_area_tensor.unsqueeze(0)).cpu().argmax(dim=1).item()  # Move back to CPU for numpy conversion
        
        if prediction == 4:
            transformed[position[0]][position[1]] = grid.max().item()
            
    return transformed

def save_initial_grid(grid, prefix, plot_dir):
    fig, ax = plt.subplots(figsize=(7, 6))
    cax = ax.imshow(grid.cpu().numpy(), cmap='viridis', interpolation='nearest')
    ax.set_title(f'Initial Grid (Prefix: {prefix})')
    
    # Save the plot to a file
    initial_grid_plot_path = os.path.join(plot_dir, f'{prefix}_initial_grid.png')
    plt.savefig(initial_grid_plot_path)
    print(f"Saved initial grid plot to {initial_grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)

# Function to load ARC JSON data and convert to tensors
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = torch.tensor(task['input'], dtype=torch.float32)
        output_grid = torch.tensor(task['output'], dtype=torch.float32)
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the plot directory
plot_dir = '/home/xaqmusic/plots'

# Create the directory if it doesn't exist
os.makedirs(plot_dir, exist_ok=True)

# Generate a unique prefix for this run using Unix time (epoch seconds since 1970)
prefix = str(int(datetime.now().timestamp()))  # Example: "1693425180"

# Initialize plots at the beginning of your script
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
loss_plot, = ax1.plot([], [], marker='o')
ax1.set_title('Loss Over Epochs')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data and convert to tensors
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = inputs[i]
        output_grid = outputs[i]
        
        grid_size = input_grid.shape
        
        simple_agents = distribute_agents(grid_size)
        
        print("Creating agent models...")
        num_simple_agents = len(simple_agents)
        simple_agent_models = {agent: SimpleAgent() for agent in simple_agents}
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        train_inputs = [input_grid.to(device)]
        train_outputs = [output_grid.to(device)]
        
        # Save the initial grid before training
        save_initial_grid(input_grid, prefix, plot_dir)
        
        print("Training Agents...")
        train_agents(fig, ax1, ax2, simple_agent_models, train_inputs, train_outputs, epochs=50, device=device, plot_dir=plot_dir)
        
        print("\nTransforming Grid...")
        transformed = transform_grid(input_grid.cpu(), simple_agents, device=device)
        print("\nInput Grid:")
        print(input_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Desired):")
        print(output_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Model's Prediction):")
        print(transformed)


## fix the grid so all squares can be affected by agents

all cells are changing and something is happening!

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os
from datetime import datetime

# Ensure the correct backend is used for interactive plots
import matplotlib
matplotlib.use('inline')
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=9):  # For a 3x3 grid area
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 18)  # First hidden layer
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(18, 16)  # Second hidden layer
        self.fc3 = nn.Linear(16, 10)  # Output layer
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)  # Softmax for multiple classes

def distribute_agents(grid_size, agent_scope=3):
    agents = []
    step_size = 1  # Place an agent at every cell
    for i in range(0, grid_size[0], step_size):  
        for j in range(0, grid_size[1], step_size):
            agents.append((i, j))  # Record agent positions
    return agents

def train_agents(fig, ax1, ax2, simple_agent_models, inputs, outputs, epochs=30, learning_rate=0.0001, device='cuda', plot_dir=None):
    criterion = nn.MSELoss()  # Use MSE Loss
    
    print("Moving models to the specified device...")
    for position, model in simple_agent_models.items():
        model.to(device)
        print(f"Simple Agent at position {position} moved to {device}")

    # Flatten list of lists into a single list of parameters
    params_to_optimize = [param for _, model in simple_agent_models.items() for param in model.parameters()]
    
    optimizer = torch.optim.Adam(params_to_optimize, lr=learning_rate)
    
    losses = []  # Initialize an empty list to store all epoch losses
    
    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        print(f"Epoch {epoch + 1}/{epochs}")
        total_loss = 0
        num_samples = len(inputs)
        
        # Training phase
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            optimizer.zero_grad()
            
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position, model in simple_agent_models.items():
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                     # print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Forward pass for the current simple agent
                prediction = model(local_x.unsqueeze(0).to(device))
                # print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agent_models:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    # print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate loss for each prediction separately
            total_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                # Convert target to float before calculating loss
                target_float = target.float()
                loss = criterion(pred.unsqueeze(0), target_float.unsqueeze(0))
                total_loss_this_sample += loss.item()
                loss.backward(retain_graph=True)  # Retain graph to allow backpropagation through all agents
            
            total_loss += total_loss_this_sample
            optimizer.step()
        
        average_loss = total_loss / num_samples
        losses.append(average_loss)  # Append the current epoch's average loss
        
        # Validation phase (same as training phase)
        validation_loss = 0
        # Initialize transformed_grid with zeros for consistency
        max_value = inputs[0].max().item()  # Convert tensor to scalar
        transformed_grid_tensor = torch.zeros_like(inputs[0])  # Use torch.rand_like for similar shape and dtype
        
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position, model in simple_agent_models.items():
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                    # print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Forward pass for the current simple agent
                prediction = model(local_x.unsqueeze(0).to(device))
                # print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction (validation phase)
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agent_models:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    # print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate validation loss for each prediction separately
            total_validation_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                # Convert target to float before calculating loss
                target_float = target.float()
                loss = criterion(pred.unsqueeze(0), target_float.unsqueeze(0))
                total_validation_loss_this_sample += loss.item()
            
            validation_loss += total_validation_loss_this_sample
            
            # Update transformed grid based on predictions (if needed)
            for j, position in enumerate(simple_agents):
                prediction = combined_pred[j].argmax(dim=0).item()
                # print(f"Agent at position {position} predicted: {prediction}")
                
                transformed_grid_tensor[position[0]][position[1]] = prediction

        average_validation_loss = validation_loss / num_samples
        print(f'Epoch {epoch+1}, Training Loss: {average_loss:.4f}, Validation Loss: {average_validation_loss:.4f}')

        # Update and display the grid visualization
        transformed = transform_grid(inputs[0].cpu(), simple_agents, device=device)
        
        # Debug: Print the updated grid values
        # print(f"Updated Grid after Epoch {epoch + 1}:\n{transformed}")
        
        update_plots(fig, ax1, ax2, epoch + 1, losses, transformed, plot_dir)

        # Check if both training and validation loss are below the threshold
        if average_loss < 0.00001 and average_validation_loss < 0.00001:
            print("Training complete: Both training and validation loss are below 0.00001.")
            break

        # Update the grid for the next epoch
        inputs[0] = transformed_grid_tensor.to(device)  # Ensure it's on the correct device before updating
        clear_output(wait=True)  # Clear the output before logging new information
        print(f"Updated Grid after Epoch {epoch + 1}:\n{transformed}")

def update_plots(fig, ax1, ax2, epoch, losses, transformed, plot_dir):
    # Clear previous plots
    ax1.clear()
    ax2.clear()

    # Update the loss plot
    ax1.plot(range(1, len(losses) + 1), losses, marker='o', color='blue')
    ax1.set_title(f'Loss Over Epochs (Epoch {epoch})')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')

    # Ensure transformed is a NumPy array before passing to imshow
    if isinstance(transformed, torch.Tensor):
        transformed = transformed.cpu().numpy()
    
    # Update the grid visualization
    cax = ax2.imshow(transformed, cmap='viridis', interpolation='nearest')  # Convert tensor to numpy array
    ax2.set_title(f'Grid Transformation (Epoch {epoch})')
    
    # Save the plot to a file with a unique name based on the epoch number
    loss_plot_path = os.path.join(plot_dir, f'{prefix}_loss_epoch_{epoch}.png')
    grid_plot_path = os.path.join(plot_dir, f'{prefix}_grid_epoch_{epoch}.png')
    
    fig.savefig(loss_plot_path)
    print(f"Saved loss plot for epoch {epoch} to {loss_plot_path}")
    
    # Save only the grid plot without colorbar
    fig2, ax3 = plt.subplots(figsize=(7, 6))
    cax2 = ax3.imshow(transformed, cmap='viridis', interpolation='nearest')
    ax3.set_title(f'Grid Transformation (Epoch {epoch})')
    fig2.savefig(grid_plot_path)
    print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)
    plt.close(fig2)

def transform_grid(grid, simple_agents, device='cuda'):
    transformed = np.array(grid)  # Initialize with input grid
    
    # Convert grid to tensor and move it to the specified device
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
    
    for position in simple_agents:
        x_start = max(0, position[0])
        y_start = max(0, position[1])
        x_end = min(len(transformed), position[0] + 3)
        y_end = min(len(transformed[0]), position[1] + 3)

        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        
        if len(local_area) != 9:
            # print(f"Padded local_area shape: {local_area.shape}")
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Convert local_area to tensor and move to device
        local_area_tensor = torch.tensor(local_area, dtype=torch.float32).to(device)
        
        prediction = simple_agent_models[position](local_area_tensor.unsqueeze(0)).cpu().argmax(dim=1).item()  # Move back to CPU for numpy conversion
        
        # print(f"Agent at position {position} predicted: {prediction}")
        
        transformed[position[0]][position[1]] = prediction
            
    return transformed

def save_initial_grid(grid, prefix, plot_dir):
    # Generate a random grid with values between 0 and 9
    initial_grid = torch.randint(10, size=grid.shape).float()
    
    fig, ax = plt.subplots(figsize=(7, 6))
    cax = ax.imshow(initial_grid.numpy(), cmap='viridis', interpolation='nearest')
    ax.set_title(f'Initial Grid (Prefix: {prefix})')
    
    # Save the plot to a file
    initial_grid_plot_path = os.path.join(plot_dir, f'{prefix}_initial_grid.png')
    plt.savefig(initial_grid_plot_path)
    print(f"Saved initial grid plot to {initial_grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)

# Function to load ARC JSON data and convert to tensors
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = torch.tensor(task['input'], dtype=torch.float32)
        output_grid = torch.tensor(task['output'], dtype=torch.float32)
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the plot directory
plot_dir = '/home/xaqmusic/plots'

# Create the directory if it doesn't exist
os.makedirs(plot_dir, exist_ok=True)

# Generate a unique prefix for this run using Unix time (epoch seconds since 1970)
prefix = str(int(datetime.now().timestamp()))  # Example: "1693425180"

# Initialize plots at the beginning of your script
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
loss_plot, = ax1.plot([], [], marker='o')
ax1.set_title('Loss Over Epochs')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data and convert to tensors
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = inputs[i]
        output_grid = outputs[i]
        
        grid_size = input_grid.shape
        
        simple_agents = distribute_agents(grid_size)
        
        print("Creating agent models...")
        num_simple_agents = len(simple_agents)
        simple_agent_models = {agent: SimpleAgent() for agent in simple_agents}
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        # **** Generate a random initial grid with values between 0 and 9
        # initial_grid = torch.randint(10, size=grid_size).float().to(device)
        # **** or just the input grid to start
        ititial_grid = input_grid
        
        # train_inputs = [initial_grid]
        train_inputs = [input_grid]
        train_outputs = [output_grid.to(device)]
        
        # Save the initial grid before training
        save_initial_grid(initial_grid.cpu(), prefix, plot_dir)
        
        print("Training Agents...")
        train_agents(fig, ax1, ax2, simple_agent_models, train_inputs, train_outputs, epochs=200, device=device, plot_dir=plot_dir)
        
        print("\nTransforming Grid...")
        transformed = transform_grid(input_grid.cpu(), simple_agents, device=device)
        print("\nInput Grid:")
        print(input_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Desired):")
        print(output_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Model's Prediction):")
        print(transformed)
        input("Press Enter to continue...")


In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os
from datetime import datetime

# Ensure the correct backend is used for interactive plots
import matplotlib
matplotlib.use('inline')
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=18):  # Updated to handle concatenated input size of 18
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 36)  # First hidden layer with larger capacity for the epoch input
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(36, 10)  # Second hidden layer
        self.fc3 = nn.Linear(10, 10)  # Output layer
        
    def forward(self, x):
        # print(f"Input shape before concatenation: {x.shape}")
        x = torch.cat((x[:, :-1], x[:, -1].unsqueeze(1).repeat(1, 9)), dim=1)  # Repeat the epoch value to match grid input
        # print(f"Input shape after concatenation: {x.shape}")
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)

def distribute_agents(grid_size, agent_scope=3):
    agents = []
    step_size = 1  # Place an agent at every cell
    for i in range(0, grid_size[0], step_size):  
        for j in range(0, grid_size[1], step_size):
            agents.append((i, j))  # Record agent positions
    return agents

def train_agents(fig, ax1, ax2, simple_agent_models, inputs, outputs, epochs=30, learning_rate=0.00001, device='cuda', plot_dir=None):
    criterion = nn.MSELoss()  # Use MSE Loss
    
    print("Moving models to the specified device...")
    for position, model in simple_agent_models.items():
        model.to(device)
        print(f"Simple Agent at position {position} moved to {device}")

    # Flatten list of lists into a single list of parameters
    params_to_optimize = [param for _, model in simple_agent_models.items() for param in model.parameters()]
    
    optimizer = torch.optim.Adam(params_to_optimize, lr=learning_rate)
    
    losses = []  # Initialize an empty list to store all epoch losses
    
    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        print(f"Epoch {epoch + 1}/{epochs}")
        total_loss = 0
        num_samples = len(inputs)
        
        # Training phase
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            optimizer.zero_grad()
            
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position, model in simple_agent_models.items():
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                     # print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Add epoch to the input
                local_x_with_epoch = torch.cat((local_x, torch.tensor([epoch]).float().to(device)), dim=0)
                
                # Forward pass for the current simple agent
                prediction = model(local_x_with_epoch.unsqueeze(0).to(device))
                # print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agent_models:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    # print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate loss for each prediction separately
            total_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                # Convert target to float before calculating loss
                target_float = target.float()
                loss = criterion(pred.unsqueeze(0), target_float.unsqueeze(0))
                total_loss_this_sample += loss.item()
                loss.backward(retain_graph=True)  # Retain graph to allow backpropagation through all agents
            
            total_loss += total_loss_this_sample
            optimizer.step()
        
        average_loss = total_loss / num_samples
        losses.append(average_loss)  # Append the current epoch's average loss
        
        # Validation phase (same as training phase)
        validation_loss = 0
        # Initialize transformed_grid with zeros for consistency
        max_value = inputs[0].max().item()  # Convert tensor to scalar
        transformed_grid_tensor = torch.zeros_like(inputs[0])  # Use torch.rand_like for similar shape and dtype
        
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position, model in simple_agent_models.items():
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                    # print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Add epoch to the input
                local_x_with_epoch = torch.cat((local_x, torch.tensor([epoch]).float().to(device)), dim=0)
                
                # Forward pass for the current simple agent
                prediction = model(local_x_with_epoch.unsqueeze(0).to(device))
                # print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction (validation phase)
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agent_models:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    # print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate validation loss for each prediction separately
            total_validation_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                # Convert target to float before calculating loss
                target_float = target.float()
                loss = criterion(pred.unsqueeze(0), target_float.unsqueeze(0))
                total_validation_loss_this_sample += loss.item()
            
            validation_loss += total_validation_loss_this_sample
            
            # Update transformed grid based on predictions (if needed)
            for j, position in enumerate(simple_agents):
                prediction = combined_pred[j].argmax(dim=0).item()
                # print(f"Agent at position {position} predicted: {prediction}")
                
                transformed_grid_tensor[position[0]][position[1]] = prediction

        average_validation_loss = validation_loss / num_samples
        print(f'Epoch {epoch+1}, Training Loss: {average_loss:.4f}, Validation Loss: {average_validation_loss:.4f}')

        # Update and display the grid visualization
        transformed = transform_grid(inputs[0].cpu(), simple_agents, device=device)
        
        # Debug: Print the updated grid values
        # print(f"Updated Grid after Epoch {epoch + 1}:\n{transformed}")
        
        update_plots(fig, ax1, ax2, epoch + 1, losses, transformed, plot_dir)

        # Check if both training and validation loss are below the threshold
        if average_loss < 0.01 and average_validation_loss < 0.01:
            print("Training complete: Both training and validation loss are below 0.00001.")
            break

        # Update the grid for the next epoch
        inputs[0] = transformed_grid_tensor.to(device)  # Ensure it's on the correct device before updating
        clear_output(wait=True)  # Clear the output before logging new information
        print(f"Updated Grid after Epoch {epoch + 1}:\n{transformed}")

def update_plots(fig, ax1, ax2, epoch, losses, transformed, plot_dir):
    # Clear previous plots
    ax1.clear()
    ax2.clear()

    # Update the loss plot
    ax1.plot(range(1, len(losses) + 1), losses, marker='o', color='blue')
    ax1.set_title(f'Loss Over Epochs (Epoch {epoch})')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')

    # Ensure transformed is a NumPy array before passing to imshow
    if isinstance(transformed, torch.Tensor):
        transformed = transformed.cpu().numpy()
    
    # Update the grid visualization
    cax = ax2.imshow(transformed, cmap='viridis', interpolation='nearest')  # Convert tensor to numpy array
    ax2.set_title(f'Grid Transformation (Epoch {epoch})')
    
    # Save the plot to a file with a unique name based on the epoch number
    loss_plot_path = os.path.join(plot_dir, f'{prefix}_loss_epoch_{epoch}.png')
    grid_plot_path = os.path.join(plot_dir, f'{prefix}_grid_epoch_{epoch}.png')
    
    fig.savefig(loss_plot_path)
    print(f"Saved loss plot for epoch {epoch} to {loss_plot_path}")
    
    # Save only the grid plot without colorbar
    fig2, ax3 = plt.subplots(figsize=(7, 6))
    cax2 = ax3.imshow(transformed, cmap='viridis', interpolation='nearest')
    ax3.set_title(f'Grid Transformation (Epoch {epoch})')
    fig2.savefig(grid_plot_path)
    print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)
    plt.close(fig2)

def transform_grid(grid, simple_agents, device='cuda'):
    transformed = np.array(grid)  # Initialize with input grid
    
    # Convert grid to tensor and move it to the specified device
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
    
    for position in simple_agents:
        x_start = max(0, position[0])
        y_start = max(0, position[1])
        x_end = min(len(transformed), position[0] + 3)
        y_end = min(len(transformed[0]), position[1] + 3)

        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        
        if len(local_area) != 9:
            # print(f"Padded local_area shape: {local_area.shape}")
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Convert local_area to tensor and move to device
        local_area_tensor = torch.tensor(local_area, dtype=torch.float32).to(device)
        
        prediction = simple_agent_models[position](torch.cat((local_area_tensor.unsqueeze(0), torch.tensor([epoch]).float().unsqueeze(0).to(device)), dim=1)).cpu().argmax(dim=1).item()  # Move back to CPU for numpy conversion
        
        # print(f"Agent at position {position} predicted: {prediction}")
        
        transformed[position[0]][position[1]] = prediction
            
    return transformed

def save_initial_grid(grid, prefix, plot_dir):
    # Generate a random grid with values between 0 and 9
    initial_grid = torch.randint(10, size=grid.shape).float()
    
    fig, ax = plt.subplots(figsize=(7, 6))
    cax = ax.imshow(initial_grid.numpy(), cmap='viridis', interpolation='nearest')
    ax.set_title(f'Initial Grid (Prefix: {prefix})')
    
    # Save the plot to a file
    initial_grid_plot_path = os.path.join(plot_dir, f'{prefix}_initial_grid.png')
    plt.savefig(initial_grid_plot_path)
    print(f"Saved initial grid plot to {initial_grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)

# Function to load ARC JSON data and convert to tensors
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = torch.tensor(task['input'], dtype=torch.float32)
        output_grid = torch.tensor(task['output'], dtype=torch.float32)
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the plot directory
plot_dir = '/home/xaqmusic/plots'

# Create the directory if it doesn't exist
os.makedirs(plot_dir, exist_ok=True)

# Generate a unique prefix for this run using Unix time (epoch seconds since 1970)
prefix = str(int(datetime.now().timestamp()))  # Example: "1693425180"

# Initialize plots at the beginning of your script
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
loss_plot, = ax1.plot([], [], marker='o')
ax1.set_title('Loss Over Epochs')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data and convert to tensors
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = inputs[i]
        output_grid = outputs[i]
        
        grid_size = input_grid.shape
        
        simple_agents = distribute_agents(grid_size)
        
        print("Creating agent models...")
        num_simple_agents = len(simple_agents)
        simple_agent_models = {agent: SimpleAgent() for agent in simple_agents}
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        # **** Generate a random initial grid with values between 0 and 9
        # initial_grid = torch.randint(10, size=grid_size).float().to(device)
        # **** or just the input grid to start
        ititial_grid = input_grid
        
        # train_inputs = [initial_grid]
        train_inputs = [input_grid]
        train_outputs = [output_grid.to(device)]
        
        # Save the initial grid before training
        save_initial_grid(initial_grid.cpu(), prefix, plot_dir)
        
        print("Training Agents...")
        train_agents(fig, ax1, ax2, simple_agent_models, train_inputs, train_outputs, epochs=200, device=device, plot_dir=plot_dir)
        
        print("\nTransforming Grid...")
        transformed = transform_grid(input_grid.cpu(), simple_agents, device=device)
        print("\nInput Grid:")
        print(input_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Desired):")
        print(output_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Model's Prediction):")
        print(transformed)
        input("Press Enter to continue...")


Working!   ...as the fanciest random grid generator ever.  That validation step is suspicious.

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os
from datetime import datetime

# Ensure the correct backend is used for interactive plots
import matplotlib
matplotlib.use('inline')
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=18):  # Updated to handle concatenated input size of 18
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 36)  # First hidden layer with larger capacity for the epoch input
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(36, 36)  # Second hidden layer
        self.fc3 = nn.Linear(36, 10)  # Output layer
        
    def forward(self, x):
        # print(f"Input shape before concatenation: {x.shape}")
        x = torch.cat((x[:, :-1], x[:, -1].unsqueeze(1).repeat(1, 9)), dim=1)  # Repeat the epoch value to match grid input
        # print(f"Input shape after concatenation: {x.shape}")
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)

def distribute_agents(grid_size, agent_scope=3):
    agents = []
    step_size = 1  # Place an agent at every cell
    for i in range(0, grid_size[0], step_size):  
        for j in range(0, grid_size[1], step_size):
            agents.append((i, j))  # Record agent positions
    return agents

def train_agents(fig, ax1, ax2, simple_agent_models, inputs, outputs, epochs=4, learning_rate=0.01, device='cuda', plot_dir=None, model_dir=None):
    criterion = nn.MSELoss()  # Use MSE Loss
    
    print("Moving models to the specified device...")
    for position, model in simple_agent_models.items():
        model.to(device)
        print(f"Simple Agent at position {position} moved to {device}")

    # Flatten list of lists into a single list of parameters
    params_to_optimize = [param for _, model in simple_agent_models.items() for param in model.parameters()]
    
    optimizer = torch.optim.Adam(params_to_optimize, lr=learning_rate)
    
    losses = []  # Initialize an empty list to store all epoch losses
    
    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        print(f"Epoch {epoch + 1}/{epochs}")
        total_loss = 0
        num_samples = len(inputs)
        
        # Training phase
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            optimizer.zero_grad()
            
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position, model in simple_agent_models.items():
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                     # print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Add epoch to the input
                local_x_with_epoch = torch.cat((local_x, torch.tensor([epoch]).float().to(device)), dim=0)
                
                # Forward pass for the current simple agent
                prediction = model(local_x_with_epoch.unsqueeze(0).to(device))
                # print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agent_models:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    # print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate loss for each prediction separately
            total_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                # Convert target to float before calculating loss
                target_float = target.float()
                loss = criterion(pred.unsqueeze(0), target_float.unsqueeze(0))
                total_loss_this_sample += loss.item()
                loss.backward(retain_graph=True)  # Retain graph to allow backpropagation through all agents
            
            total_loss += total_loss_this_sample
            optimizer.step()
        
        average_loss = total_loss / num_samples
        losses.append(average_loss)  # Append the current epoch's average loss
        
        # Validation phase (same as training phase)
        validation_loss = 0
        # Initialize transformed_grid with zeros for consistency
        max_value = inputs[0].max().item()  # Convert tensor to scalar
        transformed_grid_tensor = torch.zeros_like(inputs[0])  # Use torch.rand_like for similar shape and dtype
        
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position, model in simple_agent_models.items():
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                    # print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Add epoch to the input
                local_x_with_epoch = torch.cat((local_x, torch.tensor([epoch]).float().to(device)), dim=0)
                
                # Forward pass for the current simple agent
                prediction = model(local_x_with_epoch.unsqueeze(0).to(device))
                # print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction (validation phase)
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agent_models:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    # print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate validation loss for each prediction separately
            total_validation_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                # Convert target to float before calculating loss
                target_float = target.float()
                loss = criterion(pred.unsqueeze(0), target_float.unsqueeze(0))
                total_validation_loss_this_sample += loss.item()
            
            validation_loss += total_validation_loss_this_sample
            
            # Update transformed grid based on predictions (if needed)
            for j, position in enumerate(simple_agents):
                prediction = combined_pred[j].argmax(dim=0).item()
                # print(f"Agent at position {position} predicted: {prediction}")
                
                transformed_grid_tensor[position[0]][position[1]] = prediction

        average_validation_loss = validation_loss / num_samples
        print(f'Epoch {epoch+1}, Training Loss: {average_loss:.4f}, Validation Loss: {average_validation_loss:.4f}')

        # Update and display the grid visualization
        transformed = transform_grid(inputs[0].cpu(), simple_agents, device=device)
        
        # Debug: Print the updated grid values
        # print(f"Updated Grid after Epoch {epoch + 1}:\n{transformed}")
        
        update_plots(fig, ax1, ax2, epoch + 1, losses, transformed, plot_dir)

        # Check if both training and validation loss are below the threshold
        if average_loss < 0.001 and average_validation_loss < 0.001:
            print("Training complete: Both training and validation loss are below 0.00001.")
            break

        # Update the grid for the next epoch
        inputs[0] = transformed_grid_tensor.to(device)  # Ensure it's on the correct device before updating
        clear_output(wait=True)  # Clear the output before logging new information
        print(f"Updated Grid after Epoch {epoch + 1}:\n{transformed}")
    
    # Save the model after training
    for position, model in simple_agent_models.items():
        model_path = os.path.join(model_dir, f'simple_agent_{position[0]}_{position[1]}.pth')
        torch.save(model.state_dict(), model_path)
        print(f"Saved model at position {position} to {model_path}")

def update_plots(fig, ax1, ax2, epoch, losses, transformed, plot_dir):
    # Clear previous plots
    ax1.clear()
    ax2.clear()

    # Update the loss plot
    ax1.plot(range(1, len(losses) + 1), losses, marker='o', color='blue')
    ax1.set_title(f'Loss Over Epochs (Epoch {epoch})')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')

    # Ensure transformed is a NumPy array before passing to imshow
    if isinstance(transformed, torch.Tensor):
        transformed = transformed.cpu().numpy()
    
    # Update the grid visualization
    cax = ax2.imshow(transformed, cmap='viridis', interpolation='nearest')  # Convert tensor to numpy array
    ax2.set_title(f'Grid Transformation (Epoch {epoch})')
    
    # Save the plot to a file with a unique name based on the epoch number with linux timestamp
    prefix = str(int(datetime.now().timestamp()))
    loss_plot_path = os.path.join(plot_dir, f'{prefix}_loss_epoch_{epoch}.png')
    grid_plot_path = os.path.join(plot_dir, f'{prefix}_grid_epoch_{epoch}.png')
    
    fig.savefig(loss_plot_path)
    print(f"Saved loss plot for epoch {epoch} to {loss_plot_path}")
    
    # ******** Save only the grid plot without colorbar
    # fig2, ax3 = plt.subplots(figsize=(7, 6))
    # cax2 = ax3.imshow(transformed, cmap='viridis', interpolation='nearest')
    # ax3.set_title(f'Grid Transformation (Epoch {epoch})')
    # fig2.savefig(grid_plot_path)
    # print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)
    # plt.close(fig2)

def transform_grid(grid, simple_agents, device='cuda'):
    transformed = np.array(grid)  # Initialize with input grid
    
    # Convert grid to tensor and move it to the specified device
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
    
    for position in simple_agents:
        x_start = max(0, position[0])
        y_start = max(0, position[1])
        x_end = min(len(transformed), position[0] + 3)
        y_end = min(len(transformed[0]), position[1] + 3)

        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        
        if len(local_area) != 9:
            # print(f"Padded local_area shape: {local_area.shape}")
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Convert local_area to tensor and move to device
        local_area_tensor = torch.tensor(local_area, dtype=torch.float32).to(device)
        
        prediction = simple_agent_models[position](torch.cat((local_area_tensor.unsqueeze(0), torch.tensor([epoch]).float().unsqueeze(0).to(device)), dim=1)).cpu().argmax(dim=1).item()  # Move back to CPU for numpy conversion
        
        # print(f"Agent at position {position} predicted: {prediction}")
        
        transformed[position[0]][position[1]] = prediction
            
    return transformed

def save_initial_grid(grid, prefix, plot_dir):
    # Generate a random grid with values between 0 and 9
    # initial_grid = torch.randint(10, size=grid.shape).float()
    
    fig, ax = plt.subplots(figsize=(7, 6))
    cax = ax.imshow(initial_grid.numpy(), cmap='viridis', interpolation='nearest')
    ax.set_title(f'Initial Grid (Prefix: {prefix})')
    
    # Save the plot to a file
    initial_grid_plot_path = os.path.join(plot_dir, f'{prefix}_initial_grid.png')
    plt.savefig(initial_grid_plot_path)
    print(f"Saved initial grid plot to {initial_grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)

# Function to load ARC JSON data and convert to tensors
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = torch.tensor(task['input'], dtype=torch.float32)
        output_grid = torch.tensor(task['output'], dtype=torch.float32)
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the plot directory and model directory
plot_dir = '/home/xaqmusic/plots'
model_dir = '/home/xaqmusic/models'

# Create the directories if they don't exist
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

# Generate a unique prefix for this run using Unix time (epoch seconds since 1970)
prefix = str(int(datetime.now().timestamp()))  # Example: "1693425180"

# Initialize plots at the beginning of your script
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
loss_plot, = ax1.plot([], [], marker='o')
ax1.set_title('Loss Over Epochs')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data and convert to tensors
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = inputs[i]
        output_grid = outputs[i]
        
        grid_size = input_grid.shape
        
        simple_agents = distribute_agents(grid_size)
        
        print("Creating agent models...")
        num_simple_agents = len(simple_agents)
        simple_agent_models = {}
        for position in simple_agents:
            model_path = os.path.join(model_dir, f'simple_agent_{position[0]}_{position[1]}.pth')
            if os.path.exists(model_path):
                print(f"Loading model from {model_path}")
                model = SimpleAgent()
                model.load_state_dict(torch.load(model_path))
            else:
                model = SimpleAgent()
            simple_agent_models[position] = model
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        # **** Generate a random initial grid with values between 0 and 9
        # initial_grid = torch.randint(10, size=grid_size).float().to(device)
        # **** or just the input grid to start
        ititial_grid = input_grid
        print(f"\nInitial Grid:\n{ititial_grid}")
        
        # train_inputs = [initial_grid]
        train_inputs = [input_grid.to(device)]
        print(f"\nInput Grid:\n{train_inputs}")
        train_outputs = [output_grid.to(device)]
        print(f"\nOutput Grid:\n{train_outputs}")
        
        # Save the initial grid before training
        # save_initial_grid(initial_grid.cpu(), prefix, plot_dir)
        # pause for debugging
        input("Press Enter to continue...")
        
        print("Training Agents...")
        train_agents(fig, ax1, ax2, simple_agent_models, train_inputs, train_outputs, epochs=500, device=device, plot_dir=plot_dir, model_dir=model_dir)
        
        print("\nTransforming Grid...")
        transformed = transform_grid(input_grid.cpu(), simple_agents, device=device)
        print("\nInput Grid:")
        print(input_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Desired):")
        print(output_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Model's Prediction):")
        print(transformed)
        # pause for debugging
        input("Press Enter to continue...")


# 3x3 SimpleAgents Working on ARC Puzzles POC

## 0.1 Basically Working

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os
from datetime import datetime

# Ensure the correct backend is used for interactive plots
import matplotlib
matplotlib.use('inline')
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=18):  # Updated to handle concatenated input size of 18
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 36)  # First hidden layer with larger capacity for the epoch input
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(36, 36)  # Second hidden layer
        self.fc3 = nn.Linear(36, 10)  # Output layer
        
    def forward(self, x):
        # print(f"Input shape before concatenation: {x.shape}")
        x = torch.cat((x[:, :-1], x[:, -1].unsqueeze(1).repeat(1, 9)), dim=1)  # Repeat the epoch value to match grid input
        # print(f"Input shape after concatenation: {x.shape}")
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)

def distribute_agents(grid_size, agent_scope=3):
    agents = []
    step_size = 1  # Place an agent at every cell
    for i in range(0, grid_size[0], step_size):  
        for j in range(0, grid_size[1], step_size):
            agents.append((i, j))  # Record agent positions
    return agents

def train_agents(fig, ax1, ax2, simple_agent_models, inputs, outputs, epochs=4, learning_rate=0.1, device='cuda', plot_dir=None, model_dir=None):
    criterion = nn.MSELoss()  # Use MSE Loss
    
    print("Moving models to the specified device...")
    for position, model in simple_agent_models.items():
        model.to(device)
        print(f"Simple Agent at position {position} moved to {device}")

    # Flatten list of lists into a single list of parameters
    params_to_optimize = [param for _, model in simple_agent_models.items() for param in model.parameters()]
    
    optimizer = torch.optim.Adam(params_to_optimize, lr=learning_rate)
    
    losses = []  # Initialize an empty list to store all epoch losses
    
    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        print(f"Epoch {epoch + 1}/{epochs}")
        total_loss = 0
        num_samples = len(inputs)
        
        # Training phase
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            optimizer.zero_grad()
            
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position, model in simple_agent_models.items():
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                     # print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Add epoch to the input
                local_x_with_epoch = torch.cat((local_x, torch.tensor([epoch]).float().to(device)), dim=0)
                
                # Forward pass for the current simple agent
                prediction = model(local_x_with_epoch.unsqueeze(0).to(device))
                # print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agent_models:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    # print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate loss for each prediction separately
            total_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                # Convert target to float before calculating loss
                target_float = target.float()
                loss = criterion(pred.unsqueeze(0), target_float.unsqueeze(0))
                total_loss_this_sample += loss.item()
                loss.backward(retain_graph=True)  # Retain graph to allow backpropagation through all agents
            
            total_loss += total_loss_this_sample
            optimizer.step()
        
        average_loss = total_loss / num_samples
        losses.append(average_loss)  # Append the current epoch's average loss
        
        # Validation phase (same as training phase)
        validation_loss = 0
        # Initialize transformed_grid with zeros for consistency
        max_value = inputs[0].max().item()  # Convert tensor to scalar
        transformed_grid_tensor = torch.zeros_like(inputs[0])  # Use torch.rand_like for similar shape and dtype
        
        with torch.no_grad():
            for i, (x, y) in enumerate(zip(inputs, outputs)):
                simple_predictions = []  # Initialize this list at the start of each iteration
                
                # Collect predictions from simple agents
                for position, model in simple_agent_models.items():
                    x_start = max(0, position[0])
                    y_start = max(0, position[1])
                    x_end = min(x.size(0), position[0] + 3)
                    y_end = min(x.size(1), position[1] + 3)

                    local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                    
                    if len(local_x) != 9:
                        local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                    
                    # Add epoch to the input
                    local_x_with_epoch = torch.cat((local_x, torch.tensor([epoch]).float().to(device)), dim=0)
                    
                    # Forward pass for the current simple agent
                    prediction = model(local_x_with_epoch.unsqueeze(0).to(device))
                    
                    # Append the prediction to the list (ensure it's on CPU)
                    simple_predictions.append(prediction.squeeze().cpu())
                
                combined_pred = torch.stack(simple_predictions)
                
                # Prepare target values for each prediction (validation phase)
                y_flat = y.flatten().long().to(device)
                simple_targets = []
                
                for position in simple_agent_models:
                    x_start = max(0, position[0])
                    y_start = max(0, position[1])
                    x_end = min(x.size(0), position[0] + 3)
                    y_end = min(x.size(1), position[1] + 3)

                    local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                    
                    if len(local_y) != 9:
                        local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                    
                    # Correct the dimension for argmax
                    simple_targets.append(local_y.argmax(dim=0).cpu())
                
                # Stack all targets and ensure they match the batch size of combined_pred
                all_targets = torch.stack(simple_targets)
                
                # Calculate validation loss for each prediction separately
                total_validation_loss_this_sample = 0
                for pred, target in zip(combined_pred, all_targets):
                    # Convert target to float before calculating loss
                    target_float = target.float()
                    loss = criterion(pred.unsqueeze(0), target_float.unsqueeze(0))
                    total_validation_loss_this_sample += loss.item()
                
                validation_loss += total_validation_loss_this_sample
                
                # Update transformed grid based on predictions (if needed)
                for j, position in enumerate(simple_agents):
                    prediction = combined_pred[j].argmax(dim=0).item()
                    
                    transformed_grid_tensor[position[0]][position[1]] = prediction

        average_validation_loss = validation_loss / num_samples
        # print(f'Epoch {epoch+1}, Training Loss: {average_loss:.4f}, Validation Loss: {average_validation_loss:.4f}')

        # Update and display the grid visualization
        transformed = transform_grid(inputs[0].cpu(), simple_agents, epoch=epoch, device=device)
        
        # Debug: Print the updated grid values
        # print(f"Updated Grid after Epoch {epoch + 1}:\n{transformed}")
        
        update_plots(fig, ax1, ax2, epoch + 1, losses, transformed, plot_dir)

        # Check if both training and validation loss are below the threshold
        if average_loss < 0.001 and average_validation_loss < 0.001:
            print("Training complete: Both training and validation loss are below 0.00001.")
            break

        # Update the grid for the next epoch
        inputs[0] = transformed_grid_tensor.to(device)  # Ensure it's on the correct device before updating
        clear_output(wait=True)  # Clear the output before logging new information then display last epoch info
        print(f"Updated Grid after Epoch {epoch + 1}:\n{transformed}")
        print(f'Epoch {epoch+1}, Training Loss: {average_loss:.4f}, Validation Loss: {average_validation_loss:.4f}')
    
    # Save the model after training
    for position, model in simple_agent_models.items():
        model_path = os.path.join(model_dir, f'simple_agent_{position[0]}_{position[1]}.pth')
        torch.save(model.state_dict(), model_path)
        # print(f"Saved model at position {position} to {model_path}")

def update_plots(fig, ax1, ax2, epoch, losses, transformed, plot_dir):
    # Clear previous plots
    ax1.clear()
    ax2.clear()

    # Update the loss plot
    ax1.plot(range(1, len(losses) + 1), losses, marker='o', color='blue')
    ax1.set_title(f'Loss Over Epochs (Epoch {epoch})')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')

    # Ensure transformed is a NumPy array before passing to imshow
    if isinstance(transformed, torch.Tensor):
        transformed = transformed.cpu().numpy()
    
    # Update the grid visualization
    cax = ax2.imshow(transformed, cmap='viridis', interpolation='nearest')  # Convert tensor to numpy array
    ax2.set_title(f'Grid Transformation (Epoch {epoch})')
    
    # Save the plot to a file with a unique name based on the epoch number with linux timestamp
    prefix = str(int(datetime.now().timestamp()))
    loss_plot_path = os.path.join(plot_dir, f'{prefix}_loss_epoch_{epoch}.png')
    grid_plot_path = os.path.join(plot_dir, f'{prefix}_grid_epoch_{epoch}.png')
    
    fig.savefig(loss_plot_path)
    print(f"Saved loss plot for epoch {epoch} to {loss_plot_path}")
    
    # ******** Save only the grid plot without colorbar
    # fig2, ax3 = plt.subplots(figsize=(7, 6))
    # cax2 = ax3.imshow(transformed, cmap='viridis', interpolation='nearest')
    # ax3.set_title(f'Grid Transformation (Epoch {epoch})')
    # fig2.savefig(grid_plot_path)
    # print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)
    # plt.close(fig2)

def transform_grid(grid, simple_agents, epoch, device='cuda'):
    transformed = np.array(grid)  # Initialize with input grid
    
    # Convert grid to tensor and move it to the specified device
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
    
    for position in simple_agents:
        x_start = max(0, position[0])
        y_start = max(0, position[1])
        x_end = min(len(transformed), position[0] + 3)
        y_end = min(len(transformed[0]), position[1] + 3)

        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        
        if len(local_area) != 9:
            # print(f"Padded local_area shape: {local_area.shape}")
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Convert local_area to tensor and move to device
        local_area_tensor = torch.tensor(local_area, dtype=torch.float32).to(device)
        
        prediction = simple_agent_models[position](torch.cat((local_area_tensor.unsqueeze(0), torch.tensor([epoch]).float().unsqueeze(0).to(device)), dim=1)).cpu().argmax(dim=1).item()  # Move back to CPU for numpy conversion
        
        # print(f"Agent at position {position} predicted: {prediction}")
        
        transformed[position[0]][position[1]] = prediction
            
    return transformed

def save_initial_grid(grid, prefix, plot_dir):
    # Generate a random grid with values between 0 and 9
    # initial_grid = torch.randint(10, size=grid.shape).float()
    initial_grid = grid
    
    fig, ax = plt.subplots(figsize=(7, 6))
    cax = ax.imshow(initial_grid.numpy(), cmap='viridis', interpolation='nearest')
    ax.set_title(f'Initial Grid (Prefix: {prefix})')
    
    # Save the plot to a file
    initial_grid_plot_path = os.path.join(plot_dir, f'{prefix}_initial_grid.png')
    plt.savefig(initial_grid_plot_path)
    print(f"Saved initial grid plot to {initial_grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)

# Function to load ARC JSON data and convert to tensors
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = torch.tensor(task['input'], dtype=torch.float32)
        output_grid = torch.tensor(task['output'], dtype=torch.float32)
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the plot directory and model directory
plot_dir = '/home/xaqmusic/plots'
model_dir = '/home/xaqmusic/models'

# Create the directories if they don't exist
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

# Generate a unique prefix for this run using Unix time (epoch seconds since 1970)
prefix = str(int(datetime.now().timestamp()))  # Example: "1693425180"

# Initialize plots at the beginning of your script
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
loss_plot, = ax1.plot([], [], marker='o')
ax1.set_title('Loss Over Epochs')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data and convert to tensors
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = inputs[i]
        output_grid = outputs[i]
        
        grid_size = input_grid.shape
        
        simple_agents = distribute_agents(grid_size)
        
        print("Creating agent models...")
        num_simple_agents = len(simple_agents)
        simple_agent_models = {}
        for position in simple_agents:
            model_path = os.path.join(model_dir, f'simple_agent_{position[0]}_{position[1]}.pth')
            if os.path.exists(model_path):
                print(f"Loading model from {model_path}")
                model = SimpleAgent()
                model.load_state_dict(torch.load(model_path))
            else:
                model = SimpleAgent()
            simple_agent_models[position] = model
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        # **** Generate a random initial grid with values between 0 and 9
        # initial_grid = torch.randint(10, size=grid_size).float().to(device)
        # **** or just the input grid to start
        initial_grid = input_grid
        print(f"\nInitial Grid:\n{initial_grid}")
        
        train_inputs = [initial_grid]
        # train_inputs = [input_grid.to(device)]
        print(f"\nInput Grid:\n{train_inputs}")
        train_outputs = [output_grid.to(device)]
        print(f"\nOutput Grid:\n{train_outputs}")
        
        # Save the initial grid before training
        save_initial_grid(initial_grid.cpu(), prefix, plot_dir)
        # pause for debugging
        # input("Press Enter to continue...")
        
        print("Training Agents...")
        train_agents(fig, ax1, ax2, simple_agent_models, train_inputs, train_outputs, epochs=1000, device=device, plot_dir=plot_dir, model_dir=model_dir)
        
        print("\nTransforming Grid...")
        transformed = transform_grid(input_grid.cpu(), simple_agents, epoch=0, device=device)
        print("\nInput Grid:")
        print(input_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Desired):")
        print(output_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Model's Prediction):")
        print(transformed)
        # pause for debugging
        # input("Press Enter to continue...")


## 0.2 Normalized epoch input and step through network for vaidation

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os
from datetime import datetime

# Ensure the correct backend is used for interactive plots
import matplotlib
matplotlib.use('inline')
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=18):  # Updated to handle concatenated input size of 18
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 36)  # First hidden layer with larger capacity for the epoch input
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(36, 36)  # Second hidden layer
        self.fc3 = nn.Linear(36, 10)  # Output layer
        
    def forward(self, x):
        # print(f"Input shape before concatenation: {x.shape}")
        x = torch.cat((x[:, :-1], x[:, -1].unsqueeze(1).repeat(1, 9)), dim=1)  # Repeat the epoch value to match grid input
        # print(f"Input shape after concatenation: {x.shape}")
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)

def distribute_agents(grid_size, agent_scope=3):
    agents = []
    step_size = 1  # Place an agent at every cell
    for i in range(0, grid_size[0], step_size):  
        for j in range(0, grid_size[1], step_size):
            agents.append((i, j))  # Record agent positions
    return agents

def train_agents(fig, ax1, ax2, simple_agent_models, inputs, outputs, epochs=4, learning_rate=0.1, device='cuda', plot_dir=None, model_dir=None):
    criterion = nn.MSELoss()  # Use MSE Loss
    
    print("Moving models to the specified device...")
    for position, model in simple_agent_models.items():
        model.to(device)
        print(f"Simple Agent at position {position} moved to {device}")

    # Flatten list of lists into a single list of parameters
    params_to_optimize = [param for _, model in simple_agent_models.items() for param in model.parameters()]
    
    optimizer = torch.optim.Adam(params_to_optimize, lr=learning_rate)
    
    losses = []  # Initialize an empty list to store all epoch losses
    
    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        print(f"Epoch {epoch + 1}/{epochs}")
        total_loss = 0
        num_samples = len(inputs)
        
        # Training phase
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            optimizer.zero_grad()
            
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position, model in simple_agent_models.items():
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                     # print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Normalize the epoch
                normalized_epoch = float(epoch) / epochs
                
                # Add epoch to the input
                local_x_with_epoch = torch.cat((local_x, torch.tensor([normalized_epoch]).float().to(device)), dim=0)
                
                # Forward pass for the current simple agent
                prediction = model(local_x_with_epoch.unsqueeze(0).to(device))
                # print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agent_models:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    # print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate loss for each prediction separately
            total_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                # Convert target to float before calculating loss
                target_float = target.float()
                loss = criterion(pred.unsqueeze(0), target_float.unsqueeze(0))
                total_loss_this_sample += loss.item()
                loss.backward(retain_graph=True)  # Retain graph to allow backpropagation through all agents
            
            total_loss += total_loss_this_sample
            optimizer.step()
        
        average_loss = total_loss / num_samples
        losses.append(average_loss)  # Append the current epoch's average loss
        
        # Validation phase (same as training phase)
        validation_loss = 0
        transformed_grid_tensor = torch.clone(inputs[0]).to(device)  # Start with the input grid

        with torch.no_grad():
            for _ in range(10):  # Iterate the network for 10 steps
                simple_predictions = []  # Initialize this list at the start of each iteration
                
                # Collect predictions from simple agents
                for position, model in simple_agent_models.items():
                    x_start = max(0, position[0])
                    y_start = max(0, position[1])
                    x_end = min(transformed_grid_tensor.size(0), position[0] + 3)
                    y_end = min(transformed_grid_tensor.size(1), position[1] + 3)

                    local_x = transformed_grid_tensor[x_start:x_end, y_start:y_end].flatten().to(device)
                    
                    if len(local_x) != 9:
                        local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                    
                    # Normalize the epoch
                    normalized_epoch = float(epoch) / epochs
                    
                    # Add epoch to the input
                    local_x_with_epoch = torch.cat((local_x, torch.tensor([normalized_epoch]).float().to(device)), dim=0)
                    
                    # Forward pass for the current simple agent
                    prediction = model(local_x_with_epoch.unsqueeze(0).to(device))
                    
                    # Append the prediction to the list (ensure it's on CPU)
                    simple_predictions.append(prediction.squeeze().cpu())
                
                combined_pred = torch.stack(simple_predictions)
                
                # Update transformed grid based on predictions (if needed)
                for j, position in enumerate(simple_agents):
                    prediction = combined_pred[j].argmax(dim=0).item()
                    
                    transformed_grid_tensor[position[0]][position[1]] = prediction

        average_validation_loss = validation_loss / num_samples
        # print(f'Epoch {epoch+1}, Training Loss: {average_loss:.4f}, Validation Loss: {average_validation_loss:.4f}')

        # Update and display the grid visualization
        transformed = transform_grid(inputs[0].cpu(), simple_agents, epoch=epoch, device=device)
        
        # Debug: Print the updated grid values
        # print(f"Updated Grid after Epoch {epoch + 1}:\n{transformed}")
        
        update_plots(fig, ax1, ax2, epoch + 1, losses, transformed, plot_dir)

        # Check if both training and validation loss are below the threshold
        if average_loss < 0.001 and average_validation_loss < 0.001:
            print("Training complete: Both training and validation loss are below 0.00001.")
            break

        # Update the grid for the next epoch
        inputs[0] = transformed_grid_tensor.to(device)  # Ensure it's on the correct device before updating
        clear_output(wait=True)  # Clear the output before logging new information then display last epoch info
        print(f"Updated Grid after Epoch {epoch + 1}:\n{transformed}")
        print(f'Epoch {epoch+1}, Training Loss: {average_loss:.4f}, Validation Loss: {average_validation_loss:.4f}')
    
    # Save the model after training
    for position, model in simple_agent_models.items():
        model_path = os.path.join(model_dir, f'simple_agent_{position[0]}_{position[1]}.pth')
        torch.save(model.state_dict(), model_path)
        # print(f"Saved model at position {position} to {model_path}")

def update_plots(fig, ax1, ax2, epoch, losses, transformed, plot_dir):
    # Clear previous plots
    ax1.clear()
    ax2.clear()

    # Update the loss plot
    ax1.plot(range(1, len(losses) + 1), losses, marker='o', color='blue')
    ax1.set_title(f'Loss Over Epochs (Epoch {epoch})')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')

    # Ensure transformed is a NumPy array before passing to imshow
    if isinstance(transformed, torch.Tensor):
        transformed = transformed.cpu().numpy()
    
    # Update the grid visualization
    cax = ax2.imshow(transformed, cmap='viridis', interpolation='nearest')  # Convert tensor to numpy array
    ax2.set_title(f'Grid Transformation (Epoch {epoch})')
    
    # Save the plot to a file with a unique name based on the epoch number with linux timestamp
    prefix = str(int(datetime.now().timestamp()))
    loss_plot_path = os.path.join(plot_dir, f'{prefix}_loss_epoch_{epoch}.png')
    grid_plot_path = os.path.join(plot_dir, f'{prefix}_grid_epoch_{epoch}.png')
    
    fig.savefig(loss_plot_path)
    print(f"Saved loss plot for epoch {epoch} to {loss_plot_path}")
    
    # ******** Save only the grid plot without colorbar
    # fig2, ax3 = plt.subplots(figsize=(7, 6))
    # cax2 = ax3.imshow(transformed, cmap='viridis', interpolation='nearest')
    # ax3.set_title(f'Grid Transformation (Epoch {epoch})')
    # fig2.savefig(grid_plot_path)
    # print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)
    # plt.close(fig2)

def transform_grid(grid, simple_agents, epoch, device='cuda'):
    transformed = np.array(grid)  # Initialize with input grid
    
    # Convert grid to tensor and move it to the specified device
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
    
    for position in simple_agents:
        x_start = max(0, position[0])
        y_start = max(0, position[1])
        x_end = min(len(transformed), position[0] + 3)
        y_end = min(len(transformed[0]), position[1] + 3)

        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        
        if len(local_area) != 9:
            # print(f"Padded local_area shape: {local_area.shape}")
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Convert local_area to tensor and move to device
        local_area_tensor = torch.tensor(local_area, dtype=torch.float32).to(device)
        
        prediction = simple_agent_models[position](torch.cat((local_area_tensor.unsqueeze(0), torch.tensor([epoch]).float().unsqueeze(0).to(device)), dim=1)).cpu().argmax(dim=1).item()  # Move back to CPU for numpy conversion
        
        # print(f"Agent at position {position} predicted: {prediction}")
        
        transformed[position[0]][position[1]] = prediction
            
    return transformed

def save_initial_grid(grid, prefix, plot_dir):
    # Generate a random grid with values between 0 and 9
    # initial_grid = torch.randint(10, size=grid.shape).float()
    initial_grid = grid
    
    fig, ax = plt.subplots(figsize=(7, 6))
    cax = ax.imshow(initial_grid.numpy(), cmap='viridis', interpolation='nearest')
    ax.set_title(f'Initial Grid (Prefix: {prefix})')
    
    # Save the plot to a file
    initial_grid_plot_path = os.path.join(plot_dir, f'{prefix}_initial_grid.png')
    plt.savefig(initial_grid_plot_path)
    print(f"Saved initial grid plot to {initial_grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)

# Function to load ARC JSON data and convert to tensors
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = torch.tensor(task['input'], dtype=torch.float32)
        output_grid = torch.tensor(task['output'], dtype=torch.float32)
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the plot directory and model directory
plot_dir = '/home/xaqmusic/plots'
model_dir = '/home/xaqmusic/models'

# Create the directories if they don't exist
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

# Generate a unique prefix for this run using Unix time (epoch seconds since 1970)
prefix = str(int(datetime.now().timestamp()))  # Example: "1693425180"

# Initialize plots at the beginning of your script
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
loss_plot, = ax1.plot([], [], marker='o')
ax1.set_title('Loss Over Epochs')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data and convert to tensors
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = inputs[i]
        output_grid = outputs[i]
        
        grid_size = input_grid.shape
        
        simple_agents = distribute_agents(grid_size)
        
        print("Creating agent models...")
        num_simple_agents = len(simple_agents)
        simple_agent_models = {}
        for position in simple_agents:
            model_path = os.path.join(model_dir, f'simple_agent_{position[0]}_{position[1]}.pth')
            if os.path.exists(model_path):
                print(f"Loading model from {model_path}")
                model = SimpleAgent()
                model.load_state_dict(torch.load(model_path))
            else:
                model = SimpleAgent()
            simple_agent_models[position] = model
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        # **** Generate a random initial grid with values between 0 and 9
        # initial_grid = torch.randint(10, size=grid_size).float().to(device)
        # **** or just the input grid to start
        initial_grid = input_grid
        print(f"\nInitial Grid:\n{initial_grid}")
        
        train_inputs = [initial_grid]
        # train_inputs = [input_grid.to(device)]
        print(f"\nInput Grid:\n{train_inputs}")
        train_outputs = [output_grid.to(device)]
        print(f"\nOutput Grid:\n{train_outputs}")
        
        # Save the initial grid before training
        save_initial_grid(initial_grid.cpu(), prefix, plot_dir)
        # pause for debugging
        # input("Press Enter to continue...")
        
        print("Training Agents...")
        train_agents(fig, ax1, ax2, simple_agent_models, train_inputs, train_outputs, epochs=1000, device=device, plot_dir=plot_dir, model_dir=model_dir)
        
        print("\nTransforming Grid...")
        transformed = transform_grid(input_grid.cpu(), simple_agents, epoch=0, device=device)
        print("\nInput Grid:")
        print(input_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Desired):")
        print(output_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Model's Prediction):")
        print(transformed)
        # pause for debugging
        # input("Press Enter to continue...")


## 0.3 Added cross entropy loss and accumulation during validation steps

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os
from datetime import datetime

# Ensure the correct backend is used for interactive plots
import matplotlib
matplotlib.use('inline')
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=18):  # Updated to handle concatenated input size of 18
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 36)  # First hidden layer with larger capacity for the epoch input
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(36, 36)  # Second hidden layer
        self.fc3 = nn.Linear(36, 10)  # Output layer
        
    def forward(self, x):
        # print(f"Input shape before concatenation: {x.shape}")
        x = torch.cat((x[:, :-1], x[:, -1].unsqueeze(1).repeat(1, 9)), dim=1)  # Repeat the epoch value to match grid input
        # print(f"Input shape after concatenation: {x.shape}")
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)

def distribute_agents(grid_size, agent_scope=3):
    agents = []
    step_size = 1  # Place an agent at every cell
    for i in range(0, grid_size[0], step_size):  
        for j in range(0, grid_size[1], step_size):
            agents.append((i, j))  # Record agent positions
    return agents

def train_agents(fig, ax1, ax2, simple_agent_models, inputs, outputs, epochs=4, learning_rate=0.0005, device='cuda', plot_dir=None, model_dir=None):
    criterion = nn.MSELoss()  # Use MSE Loss
    
    print("Moving models to the specified device...")
    for position, model in simple_agent_models.items():
        model.to(device)
        print(f"Simple Agent at position {position} moved to {device}")

    # Flatten list of lists into a single list of parameters
    params_to_optimize = [param for _, model in simple_agent_models.items() for param in model.parameters()]
    
    optimizer = torch.optim.Adam(params_to_optimize, lr=learning_rate)
    
    losses = []  # Initialize an empty list to store all epoch losses
    
    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        print(f"Epoch {epoch + 1}/{epochs}")
        total_loss = 0
        num_samples = len(inputs)
        
        # Training phase
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            optimizer.zero_grad()
            
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position, model in simple_agent_models.items():
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                     # print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Normalize the epoch
                normalized_epoch = float(epoch) / epochs
                
                # Add epoch to the input
                local_x_with_epoch = torch.cat((local_x, torch.tensor([normalized_epoch]).float().to(device)), dim=0)
                
                # Forward pass for the current simple agent
                prediction = model(local_x_with_epoch.unsqueeze(0).to(device))
                # print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agent_models:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    # print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate loss for each prediction separately
            total_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                # Convert target to float before calculating loss
                target_float = target.float()
                loss = criterion(pred.unsqueeze(0), target_float.unsqueeze(0))
                total_loss_this_sample += loss.item()
                loss.backward(retain_graph=True)  # Retain graph to allow backpropagation through all agents
            
            total_loss += total_loss_this_sample
            optimizer.step()
        
        average_loss = total_loss / num_samples
        losses.append(average_loss)  # Append the current epoch's average loss
        
        # Validation phase
        validation_loss = 0

        # Initialize transformed grid with zeros for consistency
        transformed_grid_tensor = torch.zeros_like(inputs[0], dtype=torch.float32).to(device)

        with torch.no_grad():
            for i, (x, y) in enumerate(zip(inputs, outputs)):
                simple_predictions = []

                # Collect predictions from simple agents
                for position, model in simple_agent_models.items():
                    x_start = max(0, position[0])
                    y_start = max(0, position[1])
                    x_end = min(x.size(0), position[0] + 3)
                    y_end = min(x.size(1), position[1] + 3)

                    local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)

                    if len(local_x) != 9:
                        local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))

                    # Normalize epoch
                    normalized_epoch = float(epoch) / epochs

                    # Add epoch to the input
                    local_x_with_epoch = torch.cat((local_x, torch.tensor([normalized_epoch]).float().to(device)), dim=0)

                    # Forward pass for the current simple agent
                    prediction = model(local_x_with_epoch.unsqueeze(0).to(device))
                    
                    # Append the prediction (logits) to the list (ensure it's on CPU)
                    simple_predictions.append(prediction.squeeze().cpu())

                combined_pred = torch.stack(simple_predictions)

                # Prepare target values for each prediction (validation phase)
                y_flat = y.flatten().long().to(device)
                simple_targets = []

                for position in simple_agent_models:
                    x_start = max(0, position[0])
                    y_start = max(0, position[1])
                    x_end = min(y.size(0), position[0] + 3)
                    y_end = min(y.size(1), position[1] + 3)

                    local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)

                    if len(local_y) != 9:
                        local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))

                    # Convert to long and append the target
                    simple_targets.append(local_y.argmax(dim=0).cpu())

                all_targets = torch.stack(simple_targets)

                # Calculate validation loss for each prediction separately
                total_validation_loss_this_sample = 0

                for pred, target in zip(combined_pred, all_targets):
                    if pred.numel() == 0 or target.numel() == 0:
                        raise ValueError(f"Zero-sized tensor encountered. pred shape: {pred.shape}, target shape: {target.shape}")
                    
                    # Ensure the target is a LongTensor for CrossEntropyLoss
                    loss = criterion(pred.unsqueeze(0), target.long().unsqueeze(0))
                    total_validation_loss_this_sample += loss.item()

                validation_loss += total_validation_loss_this_sample

        average_validation_loss = validation_loss / num_samples

        # Update transformed grid based on predictions (if needed)
        for j, position in enumerate(simple_agents):
            prediction = combined_pred[j].argmax(dim=0).item()
            transformed_grid_tensor[position[0]][position[1]] = prediction
        
        # Update and display the grid visualization
        transformed = transform_grid(inputs[0].cpu(), simple_agents, epoch=epoch, device=device)
        
        update_plots(fig, ax1, ax2, epoch + 1, losses, transformed, plot_dir)

        if average_loss < 0.000001 and average_validation_loss < 0.000001:
            print("Training complete: Both training and validation loss are below 0.00001.")
            break

        # Update the grid for the next epoch
        inputs[0] = transformed_grid_tensor.to(device)
        clear_output(wait=True)
        print(f"Updated Grid after Epoch {epoch + 1}:\n{transformed}")
        print(f'Epoch {epoch+1}, Training Loss: {average_loss:.4f}, Validation Loss: {average_validation_loss:.4f}')
    
    # Save the model after training
    for position, model in simple_agent_models.items():
        model_path = os.path.join(model_dir, f'simple_agent_{position[0]}_{position[1]}.pth')
        torch.save(model.state_dict(), model_path)
        # print(f"Saved model at position {position} to {model_path}")

def update_plots(fig, ax1, ax2, epoch, losses, transformed, plot_dir):
    # Clear previous plots
    ax1.clear()
    ax2.clear()

    # Update the loss plot
    ax1.plot(range(1, len(losses) + 1), losses, marker='o', color='blue')
    ax1.set_title(f'Loss Over Epochs (Epoch {epoch})')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')

    # Ensure transformed is a NumPy array before passing to imshow
    if isinstance(transformed, torch.Tensor):
        transformed = transformed.cpu().numpy()
    
    # Update the grid visualization
    cax = ax2.imshow(transformed, cmap='viridis', interpolation='nearest')  # Convert tensor to numpy array
    ax2.set_title(f'Grid Transformation (Epoch {epoch})')
    
    # Save the plot to a file with a unique name based on the epoch number with linux timestamp
    prefix = str(int(datetime.now().timestamp()))
    loss_plot_path = os.path.join(plot_dir, f'{prefix}_loss_epoch_{epoch}.png')
    grid_plot_path = os.path.join(plot_dir, f'{prefix}_grid_epoch_{epoch}.png')
    
    fig.savefig(loss_plot_path)
    print(f"Saved loss plot for epoch {epoch} to {loss_plot_path}")
    
    # ******** Save only the grid plot without colorbar
    # fig2, ax3 = plt.subplots(figsize=(7, 6))
    # cax2 = ax3.imshow(transformed, cmap='viridis', interpolation='nearest')
    # ax3.set_title(f'Grid Transformation (Epoch {epoch})')
    # fig2.savefig(grid_plot_path)
    # print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)
    # plt.close(fig2)

def transform_grid(grid, simple_agents, epoch, device='cuda'):
    transformed = np.array(grid)  # Initialize with input grid
    
    # Convert grid to tensor and move it to the specified device
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
    
    for position in simple_agents:
        x_start = max(0, position[0])
        y_start = max(0, position[1])
        x_end = min(len(transformed), position[0] + 3)
        y_end = min(len(transformed[0]), position[1] + 3)

        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        
        if len(local_area) != 9:
            # print(f"Padded local_area shape: {local_area.shape}")
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Convert local_area to tensor and move to device
        local_area_tensor = torch.tensor(local_area, dtype=torch.float32).to(device)
        
        prediction = simple_agent_models[position](torch.cat((local_area_tensor.unsqueeze(0), torch.tensor([epoch]).float().unsqueeze(0).to(device)), dim=1)).cpu().argmax(dim=1).item()  # Move back to CPU for numpy conversion
        
        # print(f"Agent at position {position} predicted: {prediction}")
        
        transformed[position[0]][position[1]] = prediction
            
    return transformed

def save_initial_grid(grid, prefix, plot_dir):
    # Generate a random grid with values between 0 and 9
    # initial_grid = torch.randint(10, size=grid.shape).float()
    initial_grid = grid
    
    fig, ax = plt.subplots(figsize=(7, 6))
    cax = ax.imshow(initial_grid.numpy(), cmap='viridis', interpolation='nearest')
    ax.set_title(f'Initial Grid (Prefix: {prefix})')
    
    # Save the plot to a file
    initial_grid_plot_path = os.path.join(plot_dir, f'{prefix}_initial_grid.png')
    plt.savefig(initial_grid_plot_path)
    print(f"Saved initial grid plot to {initial_grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)

# Function to load ARC JSON data and convert to tensors
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = torch.tensor(task['input'], dtype=torch.float32)
        output_grid = torch.tensor(task['output'], dtype=torch.float32)
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the plot directory and model directory
plot_dir = '/home/xaqmusic/plots'
model_dir = '/home/xaqmusic/models'

# Create the directories if they don't exist
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

# Generate a unique prefix for this run using Unix time (epoch seconds since 1970)
prefix = str(int(datetime.now().timestamp()))  # Example: "1693425180"

# Initialize plots at the beginning of your script
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
loss_plot, = ax1.plot([], [], marker='o')
ax1.set_title('Loss Over Epochs')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data and convert to tensors
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = inputs[i]
        output_grid = outputs[i]
        
        grid_size = input_grid.shape
        
        simple_agents = distribute_agents(grid_size)
        
        print("Creating agent models...")
        num_simple_agents = len(simple_agents)
        simple_agent_models = {}
        for position in simple_agents:
            model_path = os.path.join(model_dir, f'simple_agent_{position[0]}_{position[1]}.pth')
            if os.path.exists(model_path):
                print(f"Loading model from {model_path}")
                model = SimpleAgent()
                model.load_state_dict(torch.load(model_path))
            else:
                model = SimpleAgent()
            simple_agent_models[position] = model
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        # **** Generate a random initial grid with values between 0 and 9
        # initial_grid = torch.randint(10, size=grid_size).float().to(device)
        # **** or just the input grid to start
        initial_grid = input_grid
        print(f"\nInitial Grid:\n{initial_grid}")
        
        train_inputs = [initial_grid]
        # train_inputs = [input_grid.to(device)]
        print(f"\nInput Grid:\n{train_inputs}")
        train_outputs = [output_grid.to(device)]
        print(f"\nOutput Grid:\n{train_outputs}")
        
        # Save the initial grid before training
        save_initial_grid(initial_grid.cpu(), prefix, plot_dir)
        # pause for debugging
        # input("Press Enter to continue...")
        
        print("Training Agents...")
        train_agents(fig, ax1, ax2, simple_agent_models, train_inputs, train_outputs, epochs=1000, device=device, plot_dir=plot_dir, model_dir=model_dir)
        
        print("\nTransforming Grid...")
        transformed = transform_grid(input_grid.cpu(), simple_agents, epoch=0, device=device)
        print("\nInput Grid:")
        print(input_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Desired):")
        print(output_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Model's Prediction):")
        print(transformed)
        # pause for debugging
        input("Press Enter to continue...")


## 0.4 Loss decreses with improvements

In [None]:
import json
import math
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os
from datetime import datetime

# Ensure the correct backend is used for interactive plots
import matplotlib
matplotlib.use('inline')
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=18):  # Updated to handle concatenated input size of 18
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 36)  # First hidden layer with larger capacity for the epoch input
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(36, 36)  # Second hidden layer
        self.fc3 = nn.Linear(36, 10)  # Output layer
        
    def forward(self, x):
        # print(f"Input shape before concatenation: {x.shape}")
        x = torch.cat((x[:, :-1], x[:, -1].unsqueeze(1).repeat(1, 9)), dim=1)  # Repeat the epoch value to match grid input
        # print(f"Input shape after concatenation: {x.shape}")
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)

def distribute_agents(grid_size, agent_scope=3):
    agents = []
    step_size = 1  # Place an agent at every cell
    for i in range(0, grid_size[0], step_size):  
        for j in range(0, grid_size[1], step_size):
            agents.append((i, j))  # Record agent positions
    return agents
# ************************ set initial learning rate here
def train_agents(fig, ax1, ax2, simple_agent_models, inputs, outputs, epochs=4, learning_rate=0.01, device='cuda', plot_dir=None, model_dir=None):
    criterion = nn.MSELoss()  # Use MSE Loss
    
    print("Moving models to the specified device...")
    for position, model in simple_agent_models.items():
        model.to(device)
        print(f"Simple Agent at position {position} moved to {device}")

    # Flatten list of lists into a single list of parameters
    params_to_optimize = [param for _, model in simple_agent_models.items() for param in model.parameters()]
    
    optimizer = torch.optim.Adam(params_to_optimize, lr=learning_rate)
    
    losses = []  # Initialize an empty list to store all epoch losses
    
    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        print(f"Epoch {epoch + 1}/{epochs}")
        total_loss = 0
        num_samples = len(inputs)
        
        # Training phase
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            optimizer.zero_grad()
            
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position, model in simple_agent_models.items():
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                     # print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Normalize the epoch
                normalized_epoch = float(epoch) / epochs
                
                # Add epoch to the input
                local_x_with_epoch = torch.cat((local_x, torch.tensor([normalized_epoch]).float().to(device)), dim=0)
                
                # Forward pass for the current simple agent
                prediction = model(local_x_with_epoch.unsqueeze(0).to(device))
                # print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agent_models:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    # print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Correct the dimension for argmax
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate loss for each prediction separately
            total_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                # Convert target to float before calculating loss
                target_float = target.float()
                loss = criterion(pred.unsqueeze(0), target_float.unsqueeze(0))
                total_loss_this_sample += loss.item()
                loss.backward(retain_graph=True)  # Retain graph to allow backpropagation through all agents
            
            total_loss += total_loss_this_sample
            optimizer.step()
        
        average_loss = total_loss / num_samples
        losses.append(average_loss)  # Append the current epoch's average loss
        
        # Validation phase
        validation_loss = 0

        # Initialize transformed grid with zeros for consistency
        transformed_grid_tensor = torch.zeros_like(inputs[0], dtype=torch.float32).to(device)

        with torch.no_grad():
            for i, (x, y) in enumerate(zip(inputs, outputs)):
                simple_predictions = []

                # Collect predictions from simple agents
                for position, model in simple_agent_models.items():
                    x_start = max(0, position[0])
                    y_start = max(0, position[1])
                    x_end = min(x.size(0), position[0] + 3)
                    y_end = min(x.size(1), position[1] + 3)

                    local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)

                    if len(local_x) != 9:
                        local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))

                    # Normalize epoch
                    normalized_epoch = float(epoch) / epochs

                    # Add epoch to the input
                    local_x_with_epoch = torch.cat((local_x, torch.tensor([normalized_epoch]).float().to(device)), dim=0)

                    # Forward pass for the current simple agent
                    prediction = model(local_x_with_epoch.unsqueeze(0).to(device))
                    
                    # Append the prediction (logits) to the list (ensure it's on CPU)
                    simple_predictions.append(prediction.squeeze().cpu())

                combined_pred = torch.stack(simple_predictions)

                # Prepare target values for each prediction (validation phase)
                y_flat = y.flatten().long().to(device)
                simple_targets = []

                for position in simple_agent_models:
                    x_start = max(0, position[0])
                    y_start = max(0, position[1])
                    x_end = min(y.size(0), position[0] + 3)
                    y_end = min(y.size(1), position[1] + 3)

                    local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)

                    if len(local_y) != 9:
                        local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))

                    # Convert to long and append the target
                    simple_targets.append(local_y.argmax(dim=0).cpu())

                all_targets = torch.stack(simple_targets)

                # Calculate validation loss for each prediction separately
                total_validation_loss_this_sample = 0

                for pred, target in zip(combined_pred, all_targets):
                    if pred.numel() == 0 or target.numel() == 0:
                        raise ValueError(f"Zero-sized tensor encountered. pred shape: {pred.shape}, target shape: {target.shape}")
                    
                    # Ensure the target is a LongTensor for CrossEntropyLoss
                    loss = criterion(pred.unsqueeze(0), target.long().unsqueeze(0))
                    total_validation_loss_this_sample += loss.item()

                validation_loss += total_validation_loss_this_sample

        average_validation_loss = validation_loss / num_samples

        # Update transformed grid based on predictions (if needed)
        for j, position in enumerate(simple_agents):
            prediction = combined_pred[j].argmax(dim=0).item()
            transformed_grid_tensor[position[0]][position[1]] = prediction
        
        # Update and display the grid visualization
        transformed = transform_grid(inputs[0].cpu(), simple_agents, epoch=epoch, device=device)
        
        update_plots(fig, ax1, ax2, epoch + 1, losses, transformed, plot_dir)

        if average_loss < 0.000001 and average_validation_loss < 0.000001:
            print("Training complete: Both training and validation loss are below 0.00001.")
            break

        # ***********  Adjust the learning rate as loss decreases
        initial_lr = optimizer.param_groups[0]['lr']
        max_epoch = epochs
        k = 0.5  # higher value will make transition to lower rate toward end of training steeper (0.1 to 0.9)
        min_lr = 0.001  # Minimum learning rate
        new_lr = initial_lr
        decay_factor = 0
        if epoch > 0 and losses[-1] < losses[-2]:
            # Calculate the decay factor using a sigmoid function
            decay_factor = 1 / (1 + math.exp(-k * (epoch - max_epoch / 2)))
            new_lr = initial_lr * (1 - decay_factor)
            # Ensure the learning rate does not decrease below the minimum threshold
            if new_lr > min_lr:
                optimizer.param_groups[0]['lr'] = new_lr
                print(f"Reducing learning rate from {initial_lr} to {new_lr}")
            else:
                print(f"Learning rate would have decreased to {new_lr}, but it has been set to the minimum threshold of {min_lr}.")
                optimizer.param_groups[0]['lr'] = min_lr
                new_lr = min_lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr

        # Update the grid for the next epoch
        inputs[0] = transformed_grid_tensor.to(device)
        clear_output(wait=True)
        print(f"Updated Grid after Epoch {epoch + 1}:\n{transformed}")
        print(f'Epoch {epoch+1}, Training Loss: {average_loss:.4f}, Validation Loss: {average_validation_loss:.4f}')
        print(f"\nNew Learning Rate:{new_lr}")
    
    # Save the model after training
    for position, model in simple_agent_models.items():
        model_path = os.path.join(model_dir, f'simple_agent_{position[0]}_{position[1]}.pth')
        torch.save(model.state_dict(), model_path)
        # print(f"Saved model at position {position} to {model_path}")

def update_plots(fig, ax1, ax2, epoch, losses, transformed, plot_dir):
    # Clear previous plots
    ax1.clear()
    ax2.clear()

    # Update the loss plot
    ax1.plot(range(1, len(losses) + 1), losses, marker='o', color='blue')
    ax1.set_title(f'Loss Over Epochs (Epoch {epoch})')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')

    # Ensure transformed is a NumPy array before passing to imshow
    if isinstance(transformed, torch.Tensor):
        transformed = transformed.cpu().numpy()
    
    # Update the grid visualization
    cax = ax2.imshow(transformed, cmap='inferno', interpolation='nearest')  # Convert tensor to numpy array
    ax2.set_title(f'Grid Transformation (Epoch {epoch})')
    
    # Save the plot to a file with a unique name based on the epoch number with linux timestamp
    prefix = str(int(datetime.now().timestamp()))
    loss_plot_path = os.path.join(plot_dir, f'{prefix}_loss_epoch_{epoch}.png')
    grid_plot_path = os.path.join(plot_dir, f'{prefix}_grid_epoch_{epoch}.png')
    
    fig.savefig(loss_plot_path)
    print(f"Saved loss plot for epoch {epoch} to {loss_plot_path}")
    
    # ******** Save only the grid plot without colorbar
    # fig2, ax3 = plt.subplots(figsize=(7, 6))
    # cax2 = ax3.imshow(transformed, cmap='viridis', interpolation='nearest')
    # ax3.set_title(f'Grid Transformation (Epoch {epoch})')
    # fig2.savefig(grid_plot_path)
    # print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)
    # plt.close(fig2)

def transform_grid(grid, simple_agents, epoch, device='cuda'):
    transformed = np.array(grid)  # Initialize with input grid
    
    # Convert grid to tensor and move it to the specified device
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
    
    for position in simple_agents:
        x_start = max(0, position[0])
        y_start = max(0, position[1])
        x_end = min(len(transformed), position[0] + 3)
        y_end = min(len(transformed[0]), position[1] + 3)

        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        
        if len(local_area) != 9:
            # print(f"Padded local_area shape: {local_area.shape}")
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Convert local_area to tensor and move to device
        local_area_tensor = torch.tensor(local_area, dtype=torch.float32).to(device)
        
        prediction = simple_agent_models[position](torch.cat((local_area_tensor.unsqueeze(0), torch.tensor([epoch]).float().unsqueeze(0).to(device)), dim=1)).cpu().argmax(dim=1).item()  # Move back to CPU for numpy conversion
        
        # print(f"Agent at position {position} predicted: {prediction}")
        
        transformed[position[0]][position[1]] = prediction
            
    return transformed

def save_initial_grid(grid, prefix, plot_dir):
    # Generate a random grid with values between 0 and 9
    # initial_grid = torch.randint(10, size=grid.shape).float()
    initial_grid = grid
    
    fig, ax = plt.subplots(figsize=(7, 6))
    cax = ax.imshow(initial_grid.numpy(), cmap='inferno', interpolation='nearest')
    ax.set_title(f'Initial Grid (Prefix: {prefix})')
    
    # Save the plot to a file
    initial_grid_plot_path = os.path.join(plot_dir, f'{prefix}_initial_grid.png')
    plt.savefig(initial_grid_plot_path)
    print(f"Saved initial grid plot to {initial_grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)

# Function to load ARC JSON data and convert to tensors
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = torch.tensor(task['input'], dtype=torch.float32)
        output_grid = torch.tensor(task['output'], dtype=torch.float32)
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the plot directory and model directory
plot_dir = '/home/xaqmusic/plots'
model_dir = '/home/xaqmusic/models'

# Create the directories if they don't exist
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

# Generate a unique prefix for this run using Unix time (epoch seconds since 1970)
prefix = str(int(datetime.now().timestamp()))  # Example: "1693425180"

# Initialize plots at the beginning of your script
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
loss_plot, = ax1.plot([], [], marker='o')
ax1.set_title('Loss Over Epochs')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data and convert to tensors
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = inputs[i]
        output_grid = outputs[i]
        
        grid_size = input_grid.shape
        
        simple_agents = distribute_agents(grid_size)
        
        print("Creating agent models...")
        num_simple_agents = len(simple_agents)
        simple_agent_models = {}
        for position in simple_agents:
            model_path = os.path.join(model_dir, f'simple_agent_{position[0]}_{position[1]}.pth')
            if os.path.exists(model_path):
                print(f"Loading model from {model_path}")
                model = SimpleAgent()
                model.load_state_dict(torch.load(model_path))
            else:
                model = SimpleAgent()
            simple_agent_models[position] = model
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        # **** Generate a random initial grid with values between 0 and 9
        # initial_grid = torch.randint(10, size=grid_size).float().to(device)
        # **** or just the input grid to start
        initial_grid = input_grid
        print(f"\nInitial Grid:\n{initial_grid}")
        
        train_inputs = [initial_grid]
        # train_inputs = [input_grid.to(device)]
        print(f"\nInput Grid:\n{train_inputs}")
        train_outputs = [output_grid.to(device)]
        print(f"\nOutput Grid:\n{train_outputs}")
        
        # Save the initial grid before training
        save_initial_grid(initial_grid.cpu(), prefix, plot_dir)
        # pause for debugging
        # input("Press Enter to continue...")
        
        print("Training Agents...")
        train_agents(fig, ax1, ax2, simple_agent_models, train_inputs, train_outputs, epochs=700, device=device, plot_dir=plot_dir, model_dir=model_dir)
        
        print("\nTransforming Grid...")
        transformed = transform_grid(input_grid.cpu(), simple_agents, epoch=0, device=device)
        print("\nInput Grid:")
        print(input_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Desired):")
        print(output_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Model's Prediction):")
        print(transformed)
        # pause for debugging
        # input("Press Enter to continue...")


Training:  96%|█████████████████████████████████████████████████████████████████▎  | 672/700 [12:08:32<30:49, 66.06s/it]

Updated Grid after Epoch 672:
[[0. 0. 0. 3. 9. 0. 3. 0. 0. 0. 3. 2. 6. 5. 5. 4. 1. 4. 8. 0.]
 [0. 0. 0. 0. 0. 0. 4. 9. 0. 0. 2. 0. 0. 5. 5. 7. 1. 0. 2. 2.]
 [3. 0. 0. 7. 0. 4. 0. 0. 0. 7. 1. 0. 9. 3. 0. 7. 5. 9. 7. 2.]
 [0. 0. 0. 8. 0. 1. 3. 5. 0. 0. 3. 3. 0. 7. 7. 6. 1. 2. 3. 2.]
 [0. 0. 0. 0. 0. 0. 0. 0. 4. 0. 0. 7. 7. 0. 1. 0. 9. 0. 8. 4.]
 [2. 5. 0. 0. 0. 3. 0. 0. 0. 0. 6. 2. 2. 3. 4. 0. 6. 2. 2. 2.]
 [3. 0. 0. 0. 0. 0. 0. 3. 0. 0. 3. 8. 4. 1. 7. 7. 1. 2. 0. 7.]
 [0. 0. 1. 4. 0. 0. 0. 0. 0. 5. 0. 2. 9. 8. 7. 9. 2. 7. 0. 9.]
 [0. 0. 8. 0. 0. 0. 0. 1. 0. 4. 8. 6. 9. 0. 3. 0. 7. 4. 9. 9.]
 [0. 0. 1. 0. 0. 0. 0. 9. 9. 1. 5. 5. 6. 0. 6. 9. 5. 2. 3. 7.]
 [3. 8. 4. 8. 0. 2. 2. 3. 4. 0. 5. 7. 1. 3. 2. 2. 8. 8. 0. 6.]
 [0. 8. 7. 7. 1. 3. 7. 7. 8. 2. 1. 9. 5. 5. 7. 2. 5. 5. 4. 3.]
 [2. 4. 6. 4. 0. 4. 2. 2. 7. 1. 8. 6. 9. 6. 3. 0. 4. 9. 2. 8.]
 [3. 0. 5. 8. 4. 9. 9. 8. 4. 3. 0. 2. 9. 4. 0. 9. 5. 4. 7. 6.]
 [5. 8. 1. 2. 0. 0. 0. 0. 0. 9. 8. 7. 0. 6. 6. 0. 3. 1. 6. 6.]
 [6. 7. 0. 5. 0. 8. 8. 8.

# Inference Runner for Current Model

If you change the model definition (layers etc) during training, be sure it is also changes here

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os
from datetime import datetime

# Ensure the correct backend is used for interactive plots
import matplotlib
matplotlib.use('inline')
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=18):  # Updated to handle concatenated input size of 18
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 36)  # First hidden layer with larger capacity for the epoch input
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(36, 36)  # Second hidden layer
        self.fc3 = nn.Linear(36, 10)  # Output layer
        
    def forward(self, x):
        x = torch.cat((x[:, :-1], x[:, -1].unsqueeze(1).repeat(1, 9)), dim=1)  # Repeat the epoch value to match grid input
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)

def distribute_agents(grid_size, agent_scope=3):
    agents = []
    step_size = 1  # Place an agent at every cell
    for i in range(0, grid_size[0], step_size):  
        for j in range(0, grid_size[1], step_size):
            agents.append((i, j))  # Record agent positions
    return agents

def transform_grid(grid, simple_agents, epoch, device='cuda'):
    transformed = np.zeros_like(grid)  # Initialize with zeros
    
    # Convert grid to tensor and move it to the specified device
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
    
    for position in simple_agents:
        x_start = max(0, position[0])
        y_start = max(0, position[1])
        x_end = min(len(transformed), position[0] + 3)
        y_end = min(len(transformed[0]), position[1] + 3)

        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        
        if len(local_area) != 9:
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Convert local_area to tensor and move to device
        local_area_tensor = torch.tensor(local_area, dtype=torch.float32).to(device)
        
        # Create the input for the model
        epoch_tensor = torch.tensor([epoch], dtype=torch.float32).unsqueeze(0).to(device)
        input_to_model = torch.cat((local_area_tensor.unsqueeze(0), epoch_tensor), dim=1)
        
        # Forward pass for the current simple agent if it exists
        if position in simple_agent_models:
            with torch.no_grad():  # Disable gradient computation during inference
                prediction = simple_agent_models[position](input_to_model.to(device)).argmax(dim=1).item()  # Ensure input is on device
            
            transformed[position[0]][position[1]] = prediction
        else:
            print(f"Warning: No model found for agent at position {position}. Skipping.")
            
    return torch.tensor(transformed, dtype=torch.float32).to(device)

def update_plots_inference(fig, ax1, epoch, current_grid, plot_dir):
    # Clear previous plots
    ax1.clear()

    # Move the tensor to CPU and then convert it to a numpy array
    current_grid_cpu = current_grid.cpu().numpy()
    
    # Update the grid visualization
    cax = ax1.imshow(current_grid_cpu, cmap='viridis', interpolation='nearest')  # Convert tensor to numpy array
    ax1.set_title(f'Grid Transformation (Epoch {epoch})')

    # Save the plot to a file with a unique name based on the epoch number with linux timestamp
    prefix = str(int(datetime.now().timestamp()))
    grid_plot_path = os.path.join(plot_dir, f'{prefix}_grid_epoch_{epoch}.png')
    
    fig.savefig(grid_plot_path)
    # print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)

# Function to load ARC JSON data and convert to tensors
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = torch.tensor(task['input'], dtype=torch.float32)
        output_grid = torch.tensor(task['output'], dtype=torch.float32)
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the inference plot directory and model directory
inf_plot_dir = '/home/xaqmusic/inf_plots'
model_dir = '/home/xaqmusic/models'

# Create the directories if they don't exist
os.makedirs(inf_plot_dir, exist_ok=True)

# Generate a unique prefix for this run using Unix time (epoch seconds since 1970)
prefix = str(int(datetime.now().timestamp()))  # Example: "1693425180"

# Initialize plots at the beginning of your script
fig_inference, ax_inference = plt.subplots(figsize=(7, 6))

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data and convert to tensors
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = inputs[i]
        output_grid = outputs[i]
        
        grid_size = input_grid.shape
        
        simple_agents = distribute_agents(grid_size)
        
        print("Loading agent models...")
        num_simple_agents = len(simple_agents)
        global simple_agent_models
        simple_agent_models = {}
        for position in simple_agents:
            model_path = os.path.join(model_dir, f'simple_agent_{position[0]}_{position[1]}.pth')
            if os.path.exists(model_path):
                print(f"Loading model from {model_path}")
                model = SimpleAgent()
                model.load_state_dict(torch.load(model_path))
                model.eval()  # Set the model to evaluation mode
                simple_agent_models[position] = model.to(device)  # Ensure the model is on the correct device
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        # Initialize the grid for inference with the input grid
        current_grid = input_grid.clone().to(device)
        
        print("Running Inference...")
        for epoch in tqdm(range(100), desc="Inference", leave=False):
            current_grid = transform_grid(current_grid.cpu(), simple_agents, epoch=epoch, device=device)
            update_plots_inference(fig_inference, ax_inference, epoch + 1, current_grid, inf_plot_dir)
            clear_output(wait=True)
            # print(current_grid)
        print("\nFinal Grid after Inference:")
        print(current_grid.cpu().numpy())
        input('Press Enter to Continue to Next Grid')

# Close the inference figure to free up memory
plt.close(fig_inference)
