# 3x3 SimpleAgents Working on ARC Puzzles POC

This code is designed to solve ARC (Abstraction and Reasoning Corpus) puzzles using a neural-network-based approach. Its main components are:

1. **SimpleAgent Class**: This is a neural network model defined using PyTorch's `nn.Module`. It has three fully connected layers, with ReLU activations after the first two, and outputs a log-softmax distribution over 10 classes (the ARC colors 0-9). Each agent receives 10 values: its flattened 3x3 window of the grid plus a normalized epoch value. `forward` repeats the epoch value nine times, so the first layer sees 18 inputs.

2. **distribute_agents Function**: This function returns a list of agent positions, one for every cell of a grid of size `grid_size`. The `agent_scope` parameter is currently unused; every agent reads a 3x3 window anchored at its own cell.

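3. **train_agents Function**: This function trains all the simple agents distributed over the grid. It uses Mean Squared Error (MSE) loss and a single Adam optimizer over every agent's parameters. Each agent's target is the index (0-8) of the largest value in its 3x3 window of the output grid, and MSE compares that number against all 10 of the agent's log-probabilities. The training process involves:
   - Moving the models to the specified device (e.g., GPU).
   - Collecting a prediction from each agent based on its 3x3 window of the current grid.
   - Calculating the loss for each prediction, backpropagating, and updating the model parameters.
   - Running a validation pass and feeding the predicted grid back in as the next epoch's input.
   - Reducing the learning rate on a sigmoid schedule centered on the middle epoch. The reduction is applied only in epochs where the training loss decreased, and the rate never drops below 0.001.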

4. **update_plots Function**: This function redraws the loss curve and the current grid during training. It saves the figure to `plot_dir` after each epoch.

5. **transform_grid Function**: This function builds a new grid in which each cell is replaced by a color (0-9). That color is the one the cell's agent predicts from its 3x3 window of the input grid.

6. **save_initial_grid Function**: This function saves an initial grid image before starting the training process.

7. **load_arc_data Function**: This function loads the `train` input/output pairs from an ARC task JSON file and converts them into float tensors.

8. **Main Script**: The main part of the script initializes the plots and loads the ARC data. For each training pair, it creates one agent per cell, reloading previously saved weights for that position when they exist. It then trains the agents and transforms the input grid based on the trained models' predictions. It also saves the initial and transformed grids for visualization and debugging.
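The input handling described in items 1 and 3 can be traced with a small standalone sketch. It copies the `SimpleAgent` layers from the training cell and builds one agent's input the way `train_agents` does:

- take the 3x3 window anchored at the agent's cell, clipped at the grid edge;
- flatten it and zero-pad it at the end to nine values;
- append a normalized epoch value.

The helper name `agent_input` and the toy 4x4 grid are illustrative assumptions, not part of the original code.

```python
import torch
import torch.nn as nn


class SimpleAgent(nn.Module):
    """Same layers and forward() as the notebook's SimpleAgent."""

    def __init__(self, input_size=18):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 36)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(36, 36)
        self.fc3 = nn.Linear(36, 10)

    def forward(self, x):
        # x: (batch, 10) = nine window values followed by one epoch value.
        # Repeating the epoch value nine times yields (batch, 18) features for fc1.
        x = torch.cat((x[:, :-1], x[:, -1].unsqueeze(1).repeat(1, 9)), dim=1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)


def agent_input(grid, position, epoch, epochs):
    """Hypothetical helper: builds one agent's input the way train_agents does."""
    i, j = position
    # 3x3 window anchored at the agent's cell (its top-left corner), clipped at the grid edge
    window = grid[i:min(grid.size(0), i + 3), j:min(grid.size(1), j + 3)].flatten()
    # Edge windows are zero-padded at the END of the flattened vector, not spatially
    window = torch.nn.functional.pad(window, (0, 9 - window.numel()))
    epoch_value = torch.tensor([epoch / epochs], dtype=torch.float32)
    return torch.cat((window, epoch_value)).unsqueeze(0)  # shape (1, 10)


grid = torch.arange(16, dtype=torch.float32).reshape(4, 4)      # toy 4x4 grid
full_window = agent_input(grid, (0, 0), epoch=0, epochs=700)    # all 9 cells inside the grid
corner = agent_input(grid, (3, 2), epoch=7, epochs=700)         # only 2 cells inside the grid

agent = SimpleAgent()
log_probs = agent(corner)
print(full_window)
print(corner)
print(log_probs.shape)        # one log-probability per color 0-9
print(log_probs.exp().sum())  # probabilities sum to about 1
```

Consider an agent on the right edge, where only a 3x2 window fits. Its six values are followed by three zeros, so the padding does not line up with the missing column the way spatial padding would.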

### Key Points:
- **Distributed Agents**: Each agent is responsible for a local area on the grid and makes predictions based on its local context.
- **Dynamic Learning Rate Adjustment**: The learning rate is reduced on a sigmoid schedule, only in epochs where the training loss improves.
- **Visualization and Debugging**: Extensive plotting and saving of intermediate results allow for easy visualization and debugging throughout the training process.
- **Training Process**: The agents are trained iteratively over multiple epochs, with loss being monitored and plotted to assess model performance.
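The learning-rate rule in `train_agents` can be pulled out into a pure function to see how it behaves over a 700-epoch run. This is a minimal sketch that assumes the loss improves every epoch; the notebook applies the rule only in improving epochs. `next_lr` and `simulate` are hypothetical names. The values `k = 0.5`, `min_lr = 0.001`, and the 0.01 starting rate come from the notebook. Under that assumption, the rate:

- barely moves for the first ~330 epochs;
- falls to about a quarter of its starting value by epoch 349;
- hits the 0.001 floor at epoch 351.

This happens because the sigmoid is centered on `max_epoch / 2` and each step multiplies the current rate.

```python
import math


def next_lr(current_lr, epoch, max_epoch, k=0.5, min_lr=0.001):
    """Hypothetical helper: one application of the decay rule from train_agents."""
    # Sigmoid centred on max_epoch / 2: close to 0 early in training, close to 1 late
    decay_factor = 1 / (1 + math.exp(-k * (epoch - max_epoch / 2)))
    new_lr = current_lr * (1 - decay_factor)
    # Same floor as the notebook: clamp to min_lr instead of going below it
    return new_lr if new_lr > min_lr else min_lr


def simulate(epochs=700, start_lr=0.01):
    """Learning rate after each epoch, ASSUMING the training loss improves every epoch."""
    lr = start_lr
    history = [lr]  # epoch 0 is never adjusted (the rule needs two losses to compare)
    for epoch in range(1, epochs):
        lr = next_lr(lr, epoch, epochs)
        history.append(lr)
    return history


history = simulate()
for epoch in (0, 100, 300, 330, 349, 350, 351, 699):
    print(f"epoch {epoch:3d}: lr = {history[epoch]:.6f}")
```

In a real run, epochs where the loss does not improve skip the reduction. The actual rate can therefore only stay at or above this simulated curve.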

Author's note: This code is designed as a starting point for grokking the components needed to solve ARC and similar problems. This approach was not expected to work. However, completing the basic functionality as intended gave me a wealth of knowledge to carry into more promising approaches.

# v0.4

## Training Block

NOTES: This embodiment took about 12.5 hours to train on my RTX 4060 Ti / AMD 3700 Windows PC (notebook running in WSL). The smallest grid in this example took 4 seconds per epoch, while the last and largest grid took over a minute per epoch. Basically, this is the world's fanciest random number generator :-P
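A note on the loss in the training cell below. Each agent outputs 10 log-probabilities, but its target is `local_y.argmax(dim=0)`, the position (0 to 8) of the largest value inside its 3x3 window of the output grid. `nn.MSELoss` then broadcasts that single number against all 10 log-probabilities.

Suppose the intent is for each agent to predict the color of its own cell, which is the value `transform_grid` writes back. One conventional alternative is then to use that color as a class index with `nn.NLLLoss`, which pairs with a `log_softmax` output. This is sketched here as an assumption, not the notebook's method. (`nn.CrossEntropyLoss` applies `log_softmax` itself, so it expects raw scores rather than `SimpleAgent`'s output.) The random scores and the toy 2x2 grid are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical outputs of three agents: one row of 10 class scores per agent
logits = torch.randn(3, 10, requires_grad=True)
log_probs = torch.log_softmax(logits, dim=1)  # the kind of output SimpleAgent returns

# Toy 2x2 output grid; each agent's target is the color at its own cell
output_grid = torch.tensor([[0., 3.], [4., 0.]])
positions = [(0, 1), (1, 0), (1, 1)]
targets = torch.stack([output_grid[i, j] for i, j in positions]).long()

# NLLLoss takes log-probabilities of shape (N, C) and class indices of shape (N,)
criterion = nn.NLLLoss()
loss = criterion(log_probs, targets)
loss.backward()  # a single backward pass reaches every agent's scores

# Same value computed by hand: mean of -log p(correct color) over the agents
manual = -log_probs[torch.arange(len(positions)), targets].mean()
print(targets.tolist(), round(loss.item(), 4), round(manual.item(), 4))
```

Stacking every agent's prediction into one `(N, 10)` tensor also allows a single `loss.backward()` per sample instead of one call per agent with `retain_graph=True`.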

In [78]:
import json
import math
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os
from datetime import datetime

# Render plots inline in the notebook (the %matplotlib magic selects the backend)
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=18):  # 9 window values + the epoch value repeated 9 times in forward()
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 36)  # First hidden layer
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(36, 36)  # Second hidden layer
        self.fc3 = nn.Linear(36, 10)  # Output layer
        
    def forward(self, x):
        # print(f"Input shape before concatenation: {x.shape}")
        x = torch.cat((x[:, :-1], x[:, -1].unsqueeze(1).repeat(1, 9)), dim=1)  # Repeat the epoch value to match grid input
        # print(f"Input shape after concatenation: {x.shape}")
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)

def distribute_agents(grid_size, agent_scope=3):
    # One agent per cell; agent_scope is currently unused (every agent reads a 3x3 window)
    agents = []
    step_size = 1  # Place an agent at every cell
    for i in range(0, grid_size[0], step_size):  
        for j in range(0, grid_size[1], step_size):
            agents.append((i, j))  # Record agent positions
    return agents
# ************************ set initial learning rate here
def train_agents(fig, ax1, ax2, simple_agent_models, inputs, outputs, epochs=4, learning_rate=0.01, device='cuda', plot_dir=None, model_dir=None):
    criterion = nn.MSELoss()  # Use MSE Loss
    
    print("Moving models to the specified device...")
    for position, model in simple_agent_models.items():
        model.to(device)
        print(f"Simple Agent at position {position} moved to {device}")

    # Flatten list of lists into a single list of parameters
    params_to_optimize = [param for _, model in simple_agent_models.items() for param in model.parameters()]
    
    optimizer = torch.optim.Adam(params_to_optimize, lr=learning_rate)
    
    losses = []  # Initialize an empty list to store all epoch losses
    
    for epoch in tqdm(range(epochs), desc="Training", leave=False):
        print(f"Epoch {epoch + 1}/{epochs}")
        total_loss = 0
        num_samples = len(inputs)
        
        # Training phase
        for i, (x, y) in enumerate(zip(inputs, outputs)):
            optimizer.zero_grad()
            
            simple_predictions = []  # Initialize this list at the start of each iteration
            
            # Collect predictions from simple agents
            for position, model in simple_agent_models.items():
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_x) != 9:
                    # print(f"Padded local_x shape: {local_x.shape}")
                    local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))
                
                # Normalize the epoch
                normalized_epoch = float(epoch) / epochs
                
                # Add epoch to the input
                local_x_with_epoch = torch.cat((local_x, torch.tensor([normalized_epoch]).float().to(device)), dim=0)
                
                # Forward pass for the current simple agent
                prediction = model(local_x_with_epoch.unsqueeze(0).to(device))
                # print(f"Simple Agent at position {position} prediction shape: {prediction.shape}")
                
                # Append the prediction to the list (ensure it's on CPU)
                simple_predictions.append(prediction.squeeze().cpu())
            
            combined_pred = torch.stack(simple_predictions)
            
            # Prepare target values for each prediction
            y_flat = y.flatten().long().to(device)
            simple_targets = []
            
            for position in simple_agent_models:
                x_start = max(0, position[0])
                y_start = max(0, position[1])
                x_end = min(x.size(0), position[0] + 3)
                y_end = min(x.size(1), position[1] + 3)

                local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)
                
                if len(local_y) != 9:
                    # print(f"Padded local_y shape: {local_y.shape}")
                    local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))
                
                # Target: index (0-8) of the largest value in this 3x3 window of the output grid
                simple_targets.append(local_y.argmax(dim=0).cpu())
            
            # Stack all targets and ensure they match the batch size of combined_pred
            all_targets = torch.stack(simple_targets)
            
            # Calculate loss for each prediction separately
            total_loss_this_sample = 0
            for pred, target in zip(combined_pred, all_targets):
                # MSELoss broadcasts the scalar target against all 10 log-probabilities
                target_float = target.float()
                loss = criterion(pred.unsqueeze(0), target_float.unsqueeze(0))
                total_loss_this_sample += loss.item()
                loss.backward(retain_graph=True)  # Retain graph to allow backpropagation through all agents
            
            total_loss += total_loss_this_sample
            optimizer.step()
        
        average_loss = total_loss / num_samples
        losses.append(average_loss)  # Append the current epoch's average loss
        
        # Validation phase
        validation_loss = 0

        # Initialize transformed grid with zeros for consistency
        transformed_grid_tensor = torch.zeros_like(inputs[0], dtype=torch.float32).to(device)

        with torch.no_grad():
            for i, (x, y) in enumerate(zip(inputs, outputs)):
                simple_predictions = []

                # Collect predictions from simple agents
                for position, model in simple_agent_models.items():
                    x_start = max(0, position[0])
                    y_start = max(0, position[1])
                    x_end = min(x.size(0), position[0] + 3)
                    y_end = min(x.size(1), position[1] + 3)

                    local_x = x[x_start:x_end, y_start:y_end].flatten().to(device)

                    if len(local_x) != 9:
                        local_x = torch.nn.functional.pad(local_x, (0, 9 - len(local_x)))

                    # Normalize epoch
                    normalized_epoch = float(epoch) / epochs

                    # Add epoch to the input
                    local_x_with_epoch = torch.cat((local_x, torch.tensor([normalized_epoch]).float().to(device)), dim=0)

                    # Forward pass for the current simple agent
                    prediction = model(local_x_with_epoch.unsqueeze(0).to(device))
                    
                    # Append the prediction (logits) to the list (ensure it's on CPU)
                    simple_predictions.append(prediction.squeeze().cpu())

                combined_pred = torch.stack(simple_predictions)

                # Prepare target values for each prediction (validation phase)
                y_flat = y.flatten().long().to(device)
                simple_targets = []

                for position in simple_agent_models:
                    x_start = max(0, position[0])
                    y_start = max(0, position[1])
                    x_end = min(y.size(0), position[0] + 3)
                    y_end = min(y.size(1), position[1] + 3)

                    local_y = y[x_start:x_end, y_start:y_end].flatten().to(device)

                    if len(local_y) != 9:
                        local_y = torch.nn.functional.pad(local_y, (0, 9 - len(local_y)))

                    # Same target as in training: index (0-8) of the largest value in the window
                    simple_targets.append(local_y.argmax(dim=0).cpu())

                all_targets = torch.stack(simple_targets)

                # Calculate validation loss for each prediction separately
                total_validation_loss_this_sample = 0

                for pred, target in zip(combined_pred, all_targets):
                    if pred.numel() == 0 or target.numel() == 0:
                        raise ValueError(f"Zero-sized tensor encountered. pred shape: {pred.shape}, target shape: {target.shape}")
                    
                    # Same MSE comparison as in training (criterion is nn.MSELoss, not CrossEntropyLoss)
                    loss = criterion(pred.unsqueeze(0), target.float().unsqueeze(0))
                    total_validation_loss_this_sample += loss.item()

                validation_loss += total_validation_loss_this_sample

        average_validation_loss = validation_loss / num_samples

        # Write the last sample's predictions into the grid that becomes the next epoch's input
        for j, position in enumerate(simple_agent_models):
            prediction = combined_pred[j].argmax(dim=0).item()
            transformed_grid_tensor[position[0]][position[1]] = prediction
        
        # Update and display the grid visualization (transform_grid uses the global simple_agent_models)
        transformed = transform_grid(inputs[0].cpu(), list(simple_agent_models), epoch=epoch, device=device)
        
        update_plots(fig, ax1, ax2, epoch + 1, losses, transformed, plot_dir)

        if average_loss < 0.000001 and average_validation_loss < 0.000001:
            print("Training complete: Both training and validation loss are below 0.000001.")
            break

        # ***********  Adjust the learning rate as loss decreases
        initial_lr = optimizer.param_groups[0]['lr']  # Current rate, already reduced in earlier epochs
        max_epoch = epochs
        k = 0.5  # Higher values make the drop around the midpoint epoch (max_epoch / 2) steeper (0.1 to 0.9)
        min_lr = 0.001  # Minimum learning rate
        new_lr = initial_lr
        decay_factor = 0
        if epoch > 0 and losses[-1] < losses[-2]:
            # Calculate the decay factor using a sigmoid function
            decay_factor = 1 / (1 + math.exp(-k * (epoch - max_epoch / 2)))
            new_lr = initial_lr * (1 - decay_factor)
            # Ensure the learning rate does not decrease below the minimum threshold
            if new_lr > min_lr:
                optimizer.param_groups[0]['lr'] = new_lr
                print(f"Reducing learning rate from {initial_lr} to {new_lr}")
            else:
                print(f"Learning rate would have decreased to {new_lr}, but it has been set to the minimum threshold of {min_lr}.")
                optimizer.param_groups[0]['lr'] = min_lr
                new_lr = min_lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr

        # Feed this epoch's predicted grid back in as the next epoch's input
        inputs[0] = transformed_grid_tensor.to(device)
        clear_output(wait=True)
        print(f"Updated Grid after Epoch {epoch + 1}:\n{transformed}")
        print(f'Epoch {epoch+1}, Training Loss: {average_loss:.4f}, Validation Loss: {average_validation_loss:.4f}')
        print(f"\nNew Learning Rate:{new_lr}")
    
    # Save the model after training
    for position, model in simple_agent_models.items():
        model_path = os.path.join(model_dir, f'simple_agent_{position[0]}_{position[1]}.pth')
        torch.save(model.state_dict(), model_path)
        # print(f"Saved model at position {position} to {model_path}")

def update_plots(fig, ax1, ax2, epoch, losses, transformed, plot_dir):
    # Clear previous plots
    ax1.clear()
    ax2.clear()

    # Update the loss plot
    ax1.plot(range(1, len(losses) + 1), losses, marker='o', color='blue')
    ax1.set_title(f'Loss Over Epochs (Epoch {epoch})')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')

    # Ensure transformed is a NumPy array before passing to imshow
    if isinstance(transformed, torch.Tensor):
        transformed = transformed.cpu().numpy()
    
    # Update the grid visualization
    cax = ax2.imshow(transformed, cmap='inferno', interpolation='nearest')
    ax2.set_title(f'Grid Transformation (Epoch {epoch})')
    
    # Save the figure to a file named with a Unix timestamp and the epoch number
    prefix = str(int(datetime.now().timestamp()))
    loss_plot_path = os.path.join(plot_dir, f'{prefix}_loss_epoch_{epoch}.png')
    grid_plot_path = os.path.join(plot_dir, f'{prefix}_grid_epoch_{epoch}.png')  # Used only by the commented-out grid-only save below
    
    fig.savefig(loss_plot_path)  # Saves both panels: the loss curve and the grid
    print(f"Saved loss plot for epoch {epoch} to {loss_plot_path}")
    
    # ******** Save only the grid plot without colorbar
    # fig2, ax3 = plt.subplots(figsize=(7, 6))
    # cax2 = ax3.imshow(transformed, cmap='viridis', interpolation='nearest')
    # ax3.set_title(f'Grid Transformation (Epoch {epoch})')
    # fig2.savefig(grid_plot_path)
    # print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)
    # plt.close(fig2)

def transform_grid(grid, simple_agents, epoch, device='cuda'):
    # Looks up each agent in the global simple_agent_models dict built by the main script.
    # Note: `epoch` is passed through as-is, while train_agents feeds the agents epoch / epochs.
    grid = np.asarray(grid)     # Accepts a CPU tensor or a NumPy array
    transformed = grid.copy()   # Start from the input grid; each agent overwrites its own cell
    
    for position in simple_agents:
        x_start, y_start = position
        x_end = min(grid.shape[0], x_start + 3)
        y_end = min(grid.shape[1], y_start + 3)

        # 3x3 window anchored at the agent's cell, flattened and zero-padded at the end to 9 values
        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        if len(local_area) != 9:
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Model input: 9 window values followed by the epoch value, shape (1, 10)
        model_input = torch.tensor(np.append(local_area, epoch), dtype=torch.float32, device=device).unsqueeze(0)
        
        with torch.no_grad():  # Inference only; no gradients needed
            log_probs = simple_agent_models[position](model_input)
        
        # The most likely class (0-9) becomes the new value of this cell
        transformed[position[0]][position[1]] = log_probs.argmax(dim=1).item()
            
    return transformed

def save_initial_grid(grid, prefix, plot_dir):
    # Generate a random grid with values between 0 and 9
    # initial_grid = torch.randint(10, size=grid.shape).float()
    initial_grid = grid
    
    fig, ax = plt.subplots(figsize=(7, 6))
    cax = ax.imshow(initial_grid.numpy(), cmap='inferno', interpolation='nearest')
    ax.set_title(f'Initial Grid (Prefix: {prefix})')
    
    # Save the plot to a file
    initial_grid_plot_path = os.path.join(plot_dir, f'{prefix}_initial_grid.png')
    plt.savefig(initial_grid_plot_path)
    print(f"Saved initial grid plot to {initial_grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)

# Function to load ARC JSON data and convert to tensors
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = torch.tensor(task['input'], dtype=torch.float32)
        output_grid = torch.tensor(task['output'], dtype=torch.float32)
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the plot directory and model directory
plot_dir = '/home/xaqmusic/plots'
model_dir = '/home/xaqmusic/models'

# Create the directories if they don't exist
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

# Generate a unique prefix for this run using Unix time (epoch seconds since 1970)
prefix = str(int(datetime.now().timestamp()))  # Example: "1693425180"

# Initialize plots at the beginning of your script
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
loss_plot, = ax1.plot([], [], marker='o')
ax1.set_title('Loss Over Epochs')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data and convert to tensors
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = inputs[i]
        output_grid = outputs[i]
        
        grid_size = input_grid.shape
        
        simple_agents = distribute_agents(grid_size)
        
        print("Creating agent models...")
        num_simple_agents = len(simple_agents)
        simple_agent_models = {}
        for position in simple_agents:
            model_path = os.path.join(model_dir, f'simple_agent_{position[0]}_{position[1]}.pth')
            if os.path.exists(model_path):
                print(f"Loading model from {model_path}")
                model = SimpleAgent()
                model.load_state_dict(torch.load(model_path, map_location='cpu'))  # train_agents moves models to the device later
            else:
                model = SimpleAgent()
            simple_agent_models[position] = model
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        # **** Generate a random initial grid with values between 0 and 9
        # initial_grid = torch.randint(10, size=grid_size).float().to(device)
        # **** or just the input grid to start
        initial_grid = input_grid
        print(f"\nInitial Grid:\n{initial_grid}")
        
        train_inputs = [initial_grid]
        # train_inputs = [input_grid.to(device)]
        print(f"\nInput Grid:\n{train_inputs}")
        train_outputs = [output_grid.to(device)]
        print(f"\nOutput Grid:\n{train_outputs}")
        
        # Save the initial grid before training
        save_initial_grid(initial_grid.cpu(), prefix, plot_dir)
        # pause for debugging
        # input("Press Enter to continue...")
        
        print("Training Agents...")
        train_agents(fig, ax1, ax2, simple_agent_models, train_inputs, train_outputs, epochs=700, device=device, plot_dir=plot_dir, model_dir=model_dir)
        
        print("\nTransforming Grid...")
        transformed = transform_grid(input_grid.cpu(), simple_agents, epoch=0, device=device)
        print("\nInput Grid:")
        print(input_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Desired):")
        print(output_grid.cpu().numpy())
        print("\nOutput Grid after Transformation (Model's Prediction):")
        print(transformed)
        # pause for debugging
        # input("Press Enter to continue...")



Updated Grid after Epoch 700:
[[0. 0. 0. 1. 9. 0. 0. 0. 0. 0. 3. 2. 6. 5. 5. 4. 1. 4. 8. 0.]
 [2. 0. 0. 0. 0. 0. 4. 3. 0. 0. 2. 0. 0. 9. 5. 7. 1. 0. 2. 2.]
 [1. 0. 2. 7. 0. 4. 1. 0. 0. 7. 1. 0. 9. 2. 3. 7. 5. 9. 7. 2.]
 [0. 0. 0. 8. 0. 1. 0. 1. 0. 0. 3. 3. 0. 7. 7. 6. 1. 2. 3. 2.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 7. 7. 0. 1. 0. 9. 0. 8. 4.]
 [0. 5. 0. 1. 0. 0. 0. 0. 0. 0. 6. 2. 2. 7. 4. 0. 6. 2. 2. 2.]
 [3. 0. 0. 0. 0. 0. 1. 3. 0. 0. 0. 8. 4. 1. 7. 7. 1. 2. 2. 7.]
 [0. 0. 1. 4. 0. 0. 0. 0. 0. 5. 0. 2. 9. 8. 5. 9. 2. 8. 0. 5.]
 [0. 2. 8. 0. 0. 1. 0. 1. 0. 0. 8. 6. 9. 0. 3. 0. 7. 4. 9. 9.]
 [0. 0. 0. 0. 0. 0. 1. 7. 9. 1. 5. 5. 6. 0. 6. 9. 5. 2. 3. 7.]
 [3. 8. 4. 8. 6. 2. 2. 1. 4. 0. 1. 7. 1. 3. 2. 2. 8. 8. 0. 0.]
 [0. 8. 7. 7. 0. 3. 7. 7. 8. 2. 1. 9. 5. 5. 7. 2. 3. 5. 4. 3.]
 [2. 4. 6. 4. 0. 4. 3. 2. 7. 0. 8. 6. 9. 6. 3. 0. 4. 5. 2. 8.]
 [3. 0. 5. 8. 4. 9. 1. 8. 4. 3. 0. 2. 9. 4. 0. 9. 5. 4. 3. 6.]
 [5. 8. 1. 2. 0. 4. 0. 0. 0. 9. 8. 7. 9. 3. 6. 0. 3. 1. 6. 6.]
 [6. 7. 4. 5. 0. 8. 8. 8.




Transforming Grid...

Input Grid:
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 3. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 3. 3. 3. 3. 0. 3. 3. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 3. 0. 3. 0. 0. 0. 0. 0. 0. 0. 3. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 3. 3. 3. 3. 3. 3. 3. 3. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 3. 0. 0. 0. 0. 0. 0. 3. 0. 0. 0. 0.]
 [0. 0. 0. 0. 3. 0. 0. 0. 3. 0. 0. 0. 0. 0. 0. 3. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 3. 0. 0. 0. 0. 0. 0. 3. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 3. 0. 0. 0. 0. 0. 0. 3. 0. 0. 0. 0.]
 [0. 0. 3. 0. 0. 0. 0. 0. 3. 3. 3. 3. 3. 3. 3. 3. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 3. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 3. 3. 3. 0. 0. 0. 0. 3. 0. 3. 0. 0.]
 [0. 0. 0. 0. 0. 0. 3. 3. 0. 0. 3. 0. 0. 3. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 3. 0. 0. 3. 3. 0. 0. 3. 0. 0. 3. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 3. 3. 3. 3. 0. 3. 0. 0. 3. 3. 3. 0. 0.]
 [0. 0. 0. 0. 0. 0. 

# Inference Runner for Current Model

If you change the model definition (layers, etc.) during training, make the same change here. Otherwise the saved models will not load.
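A mismatched definition surfaces as a `RuntimeError` from `load_state_dict`. The minimal sketch below reproduces the failure by saving the 36-unit agent and loading it into a variant with 48 hidden units. You can then recognize the message if it appears in the loading loop. The `hidden` parameter and the in-memory buffer standing in for a `.pth` file are assumptions made for the demo.

```python
import io

import torch
import torch.nn as nn


class SimpleAgent(nn.Module):
    """Same layers as the notebook; `hidden` is a hypothetical parameter for this demo."""

    def __init__(self, input_size=18, hidden=36):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, 10)


# Save a 36-unit agent to an in-memory buffer, as torch.save does for the .pth files
buffer = io.BytesIO()
torch.save(SimpleAgent().state_dict(), buffer)
buffer.seek(0)
state_dict = torch.load(buffer)

# An identical definition loads cleanly
SimpleAgent().load_state_dict(state_dict)

# A changed definition (48 hidden units) raises RuntimeError listing every size mismatch
try:
    SimpleAgent(hidden=48).load_state_dict(state_dict)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = str(err)

print(mismatch_error)
```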

In [83]:
import json
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from tqdm import tqdm
import os
from datetime import datetime

# Render plots inline in the notebook (the %matplotlib magic selects the backend)
%matplotlib inline

class SimpleAgent(nn.Module):
    def __init__(self, input_size=18):  # 9 window values + the epoch value repeated 9 times in forward()
        super(SimpleAgent, self).__init__()
        self.fc1 = nn.Linear(input_size, 36)  # First hidden layer
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(36, 36)  # Second hidden layer
        self.fc3 = nn.Linear(36, 10)  # Output layer
        
    def forward(self, x):
        x = torch.cat((x[:, :-1], x[:, -1].unsqueeze(1).repeat(1, 9)), dim=1)  # Repeat the epoch value to match grid input
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.nn.functional.log_softmax(self.fc3(x), dim=1)

def distribute_agents(grid_size, agent_scope=3):
    # One agent per cell; agent_scope is currently unused (every agent reads a 3x3 window)
    agents = []
    step_size = 1  # Place an agent at every cell
    for i in range(0, grid_size[0], step_size):  
        for j in range(0, grid_size[1], step_size):
            agents.append((i, j))  # Record agent positions
    return agents

def transform_grid(grid, simple_agents, epoch, device='cuda'):
    transformed = np.zeros_like(grid)  # Initialize with zeros
    
    # Convert grid to tensor and move it to the specified device
    grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
    
    for position in simple_agents:
        x_start = max(0, position[0])
        y_start = max(0, position[1])
        x_end = min(len(transformed), position[0] + 3)
        y_end = min(len(transformed[0]), position[1] + 3)

        local_area = grid[x_start:x_end, y_start:y_end].flatten()
        
        if len(local_area) != 9:
            local_area = np.pad(local_area, (0, 9 - len(local_area)), 'constant')
        
        # Convert local_area to tensor and move to device
        local_area_tensor = torch.tensor(local_area, dtype=torch.float32).to(device)
        
        # Create the input for the model
        epoch_tensor = torch.tensor([epoch], dtype=torch.float32).unsqueeze(0).to(device)
        input_to_model = torch.cat((local_area_tensor.unsqueeze(0), epoch_tensor), dim=1)
        
        # Forward pass for the current simple agent if it exists
        if position in simple_agent_models:
            with torch.no_grad():  # Disable gradient computation during inference
                prediction = simple_agent_models[position](input_to_model.to(device)).argmax(dim=1).item()  # Ensure input is on device
            
            transformed[position[0]][position[1]] = prediction
        else:
            print(f"Warning: No model found for agent at position {position}. Skipping.")
            
    return torch.tensor(transformed, dtype=torch.float32).to(device)

def update_plots_inference(fig, ax1, epoch, current_grid, plot_dir):
    # Clear previous plots
    ax1.clear()

    # Move the tensor to CPU and then convert it to a numpy array
    current_grid_cpu = current_grid.cpu().numpy()
    
    # Update the grid visualization
    cax = ax1.imshow(current_grid_cpu, cmap='viridis', interpolation='nearest')
    ax1.set_title(f'Grid Transformation (Epoch {epoch})')

    # Save the plot to a file named with a Unix timestamp and the epoch number
    prefix = str(int(datetime.now().timestamp()))
    grid_plot_path = os.path.join(plot_dir, f'{prefix}_grid_epoch_{epoch}.png')
    
    fig.savefig(grid_plot_path)
    # print(f"Saved grid plot for epoch {epoch} to {grid_plot_path}")
    
    # Close the figure to free up memory
    plt.close(fig)

# Function to load ARC JSON data and convert to tensors
def load_arc_data(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    inputs = []
    outputs = []
    
    for task in data['train']:
        input_grid = torch.tensor(task['input'], dtype=torch.float32)
        output_grid = torch.tensor(task['output'], dtype=torch.float32)
        inputs.append(input_grid)
        outputs.append(output_grid)
    
    return inputs, outputs

# Define the inference plot directory and model directory
inf_plot_dir = '/home/xaqmusic/inf_plots'
model_dir = '/home/xaqmusic/models'

# Create the directories if they don't exist
os.makedirs(inf_plot_dir, exist_ok=True)

# Generate a unique prefix for this run using Unix time (epoch seconds since 1970)
prefix = str(int(datetime.now().timestamp()))  # Example: "1693425180"

# Initialize plots at the beginning of your script
fig_inference, ax_inference = plt.subplots(figsize=(7, 6))

# Example usage
if __name__ == "__main__":
    json_file = '/home/xaqmusic/ARC-AGI/data/training/00d62c1b.json'
    
    # Load ARC JSON data and convert to tensors
    inputs, outputs = load_arc_data(json_file)
    
    for i in range(len(inputs)):
        input_grid = inputs[i]
        output_grid = outputs[i]
        
        grid_size = input_grid.shape
        
        simple_agents = distribute_agents(grid_size)
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        
        print("Loading agent models...")
        # transform_grid looks up agents in this module-level dict by grid position
        simple_agent_models = {}
        for position in simple_agents:
            model_path = os.path.join(model_dir, f'simple_agent_{position[0]}_{position[1]}.pth')
            if os.path.exists(model_path):
                print(f"Loading model from {model_path}")
                model = SimpleAgent()
                model.load_state_dict(torch.load(model_path, map_location=device))
                model.eval()  # Set the model to evaluation mode
                simple_agent_models[position] = model.to(device)  # device is defined above, before any model is moved
        
        # Initialize the grid for inference with the input grid or random
        # current_grid = input_grid.clone().to(device)
        current_grid = torch.randint(10, size=grid_size).float().to(device)
        
        print("Running Inference...")
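        for epoch in tqdm(range(100), desc="Inference", leave=False):
            # Note: training fed the agents epoch / epochs (a value in [0, 1)); here the raw step index is passed
            current_grid = transform_grid(current_grid.cpu(), simple_agents, epoch=epoch, device=device)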
            update_plots_inference(fig_inference, ax_inference, epoch + 1, current_grid, inf_plot_dir)
            clear_output(wait=True)
            # print(current_grid)
        print("\nFinal Grid after Inference:")
        print(current_grid.cpu().numpy())
        # input('Press Enter to Continue to Next Grid')

# Close the inference figure to free up memory
plt.close(fig_inference)




Final Grid after Inference:
[[0. 0. 0. 1. 9. 0. 0. 0. 0. 0. 3. 2. 6. 5. 5. 4. 1. 6. 8. 0.]
 [2. 0. 0. 0. 0. 0. 4. 3. 0. 0. 2. 0. 0. 9. 5. 7. 1. 0. 2. 2.]
 [1. 0. 2. 7. 0. 4. 1. 0. 0. 7. 1. 0. 9. 2. 3. 7. 5. 9. 7. 2.]
 [0. 0. 0. 8. 0. 1. 0. 1. 0. 0. 3. 3. 0. 7. 7. 6. 1. 2. 3. 2.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 7. 6. 0. 1. 0. 9. 0. 8. 4.]
 [0. 5. 0. 1. 0. 0. 0. 0. 0. 0. 6. 2. 2. 7. 4. 0. 6. 2. 2. 2.]
 [3. 0. 0. 0. 0. 0. 1. 3. 0. 0. 0. 8. 4. 1. 7. 7. 0. 2. 2. 7.]
 [0. 0. 1. 4. 0. 0. 0. 0. 0. 5. 0. 2. 9. 6. 5. 9. 2. 7. 0. 5.]
 [0. 2. 8. 0. 0. 1. 0. 1. 0. 0. 8. 6. 9. 0. 3. 0. 7. 4. 9. 9.]
 [0. 0. 0. 0. 0. 0. 1. 7. 9. 1. 5. 5. 6. 0. 6. 9. 5. 2. 8. 7.]
 [3. 8. 4. 8. 6. 2. 2. 1. 4. 0. 1. 7. 1. 3. 2. 2. 8. 8. 0. 0.]
 [0. 8. 7. 7. 0. 3. 7. 7. 8. 2. 1. 9. 5. 5. 7. 2. 3. 5. 4. 3.]
 [2. 4. 6. 4. 0. 4. 3. 2. 7. 0. 8. 6. 9. 6. 3. 0. 4. 9. 2. 8.]
 [3. 0. 5. 8. 4. 9. 1. 8. 4. 3. 0. 2. 9. 4. 0. 9. 5. 4. 3. 6.]
 [5. 8. 1. 0. 0. 4. 0. 0. 0. 9. 8. 7. 9. 3. 6. 0. 3. 1. 6. 6.]
 [6. 7. 0. 5. 0. 8. 8. 8. 

