In [None]:
# Re-import edited modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| default_exp arc

In [None]:
#| export
import json, os
import numpy as np
import gym
from gym import spaces
from time import sleep
import pygame
from matplotlib import colors
# import copy

In [None]:
#| export
class ARCEnv(gym.Env):
    """Gym-style environment for reshaping/recolouring a grid to match an ARC training output.

    ``initialise`` loads a task file (JSON with ``train`` and ``test`` lists;
    train entries carry an ``input`` grid and, optionally, an ``output`` grid).
    The working grid ``self.env`` starts as the training input selected by
    ``self.index``; each ``step`` resizes it and overwrites its cells, and
    ``fitness_function`` scores it against the matching training output
    (lower is better; a score below ``1e-6`` marks the episode done).
    """

    def __init__(self):
        super().__init__()
        self.index = 0          # which training pair is active
        self.env = None         # working grid (2-D numpy array) edited by step()
        self.dimensions = []    # [env_cols, env_rows, out_cols, out_rows], see set_dimensions()
        self.fitness = 0.0
        self.state = []         # flattened working grid, as returned by step()/reset()
        self.done = False
        # The per-cell term of the fitness is currently disabled, so only the
        # grid shape is scored; set to True to also compare cell values.
        self.include_element_metric = False

        # Render settings
        self.screen = None      # pygame display surface; opened lazily by render()
        self.isopen = False
        self.screen_width = 800
        self.screen_height = 600
        self.grid_size = 20     # pixel size of one grid cell

        # Ten-colour palette; with `norm` mapping 0..9 onto [0, 1], cell value k
        # is drawn in the k-th colour.
        self.cmap = colors.ListedColormap(['#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',
                                           '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25'])
        self.norm = colors.Normalize(vmin=0, vmax=9)

    def initialise(self, file_name, properties):
        """Load the ARC task stored as JSON in ``file_name`` and reset the episode.

        ``properties`` may contain an ``'index'`` key selecting the training pair
        (default 0). Train entries without an ``'output'`` get an empty target.
        """
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(file_name, 'r', encoding='utf-8') as f:
            data = json.load(f)

        self.index = properties.get('index', 0)
        self.train_data = data['train']
        self.test_data = data['test']
        self.inputs = [pair['input'] for pair in self.train_data]
        self.outputs = [pair.get('output', []) for pair in self.train_data]
        self.reset()

    def get_train_array(self, idx):
        """Return the raw training entry ``idx`` (dict with ``input``/``output``)."""
        return self.train_data[idx]

    def get_input_array(self, idx):
        """Return the input grid of training pair ``idx``."""
        return self.inputs[idx]

    def get_output_array(self, idx):
        """Return the output grid of training pair ``idx``."""
        return self.outputs[idx]

    def get_element(self, array, row, col):
        """Return ``array[row][col]``."""
        return array[row][col]

    def get_dimensions(self):
        """Return ``[env_cols, env_rows, out_cols, out_rows]`` as last computed."""
        return self.dimensions

    def get_index(self):
        """Return the index of the active training pair."""
        return self.index

    def set_index(self, index):
        """Switch to training pair ``index`` and reset the working grid."""
        self.index = index
        self.reset()

    def add_rows(self, num_rows):
        """Append ``num_rows`` zero-filled rows at the bottom of the grid."""
        self.env = np.pad(self.env, ((0, num_rows), (0, 0)), mode='constant', constant_values=0)

    def remove_rows(self, num_rows):
        """Drop ``num_rows`` rows from the bottom of the grid (never below zero rows).

        Uses an explicit row count rather than ``[:-num_rows]``, which would
        empty the grid for ``num_rows == 0``.
        """
        keep = max(self.env.shape[0] - num_rows, 0)
        self.env = self.env[:keep, :]

    def add_columns(self, num_columns):
        """Append ``num_columns`` zero-filled columns at the right of the grid."""
        self.env = np.pad(self.env, ((0, 0), (0, num_columns)), mode='constant', constant_values=0)

    def remove_columns(self, num_columns):
        """Drop ``num_columns`` columns from the right of the grid (never below zero columns)."""
        keep = max(self.env.shape[1] - num_columns, 0)
        self.env = self.env[:, :keep]

    def _element_metric(self, output):
        """Cell-wise distance between the working grid and ``output`` (2-D array).

        Over the bounding box covering both grids, each cell present in both
        contributes its squared value difference; every other cell costs 25.
        """
        env_rows, env_cols = self.env.shape
        out_rows, out_cols = output.shape
        both_rows = min(env_rows, out_rows)
        both_cols = min(env_cols, out_cols)
        diff = (self.env[:both_rows, :both_cols].astype(np.int64)
                - output[:both_rows, :both_cols].astype(np.int64))
        union_cells = max(env_rows, out_rows) * max(env_cols, out_cols)
        return int((diff ** 2).sum()) + 25 * (union_cells - both_rows * both_cols)

    def fitness_function(self):
        """Distance between the working grid and the target output (0 means solved).

        Always includes the squared difference in width and height. The
        per-cell term from ``_element_metric`` is added only when
        ``self.include_element_metric`` is true (off by default). Also
        refreshes ``self.dimensions``.
        """
        self.set_dimensions()
        env_cols, env_rows, out_cols, out_rows = self.dimensions

        dim_metric = (env_cols - out_cols) ** 2 + (env_rows - out_rows) ** 2

        element_metric = 0
        if getattr(self, 'include_element_metric', False):
            element_metric = self._element_metric(np.array(self.outputs[self.index]))

        return dim_metric + element_metric

    def step(self, action):
        """Apply one action and return ``(state, fitness, done)``.

        ``action`` is ``[num_rows, num_cols, *values]``. Positive counts append
        zero-filled rows/columns at the bottom/right; negative counts remove
        them. ``values`` then overwrite the grid in row-major order starting at
        the top-left cell. Values beyond the grid size are ignored.
        """
        num_rows, num_cols, *values = action

        if num_rows > 0:
            self.add_rows(num_rows)
        elif num_rows < 0:
            self.remove_rows(abs(num_rows))

        if num_cols > 0:
            self.add_columns(num_cols)
        elif num_cols < 0:
            self.remove_columns(abs(num_cols))

        # Row-major fill. `.flat` writes through the non-contiguous views that
        # remove_columns produces, and an empty grid is simply a no-op (no
        # divmod by a zero column count).
        count = min(len(values), self.env.size)
        if count:
            self.env.flat[:count] = values[:count]

        self.fitness = self.fitness_function()
        if self.fitness < 1e-6:
            self.done = True

        self.state = self.env.flatten().tolist()
        return self.state, self.fitness, self.done

    def set_dimensions(self):
        """Cache ``[env_cols, env_rows, out_cols, out_rows]`` in ``self.dimensions``."""
        # .shape (not len(self.env[0])) so a grid with zero rows does not raise.
        env_rows, env_cols = self.env.shape
        target = self.outputs[self.index]
        self.dimensions = [env_cols, env_rows, len(target[0]), len(target)]

    def reset(self):
        """Restore the working grid to the selected training input; return the flat state."""
        # np.array copies, so later edits never touch self.inputs.
        self.env = np.array(self.inputs[self.index])
        self.fitness = self.fitness_function()   # also refreshes self.dimensions
        self.done = False
        self.state = self.env.flatten().tolist()
        return self.state

    def _draw_grid(self, grid, top_left_x, top_left_y):
        """Paint ``grid`` onto the screen as squares of ``self.grid_size`` pixels."""
        cell = self.grid_size
        for i, row in enumerate(grid):
            for j, value in enumerate(row):
                red, green, blue, _alpha = self.cmap(self.norm(value))
                # pygame expects integer channels in 0..255.
                color = (int(round(255 * red)), int(round(255 * green)), int(round(255 * blue)))
                pygame.draw.rect(self.screen, color,
                                 (top_left_x + j * cell, top_left_y + i * cell, cell, cell))

    def render(self, mode='human'):
        """Draw input, target output and working grid side by side, plus the fitness.

        Opens the pygame window on first use (or after ``close()``). A window
        QUIT event closes it via ``close()`` so a later call reopens cleanly.
        """
        if self.screen is None:
            pygame.init()
            self.screen = pygame.display.set_mode((self.screen_width, self.screen_height))
            pygame.display.set_caption('ARC Environment')
            self.isopen = True

        self.screen.fill((255, 255, 255))

        self._draw_grid(np.array(self.inputs[self.index]), 50, 50)
        self._draw_grid(np.array(self.outputs[self.index]), 300, 50)
        self._draw_grid(self.env, 550, 50)

        solved = self.fitness < 1e-6
        font = pygame.font.Font(None, 74)
        arrow = font.render(u'\u2192', True, (0, 0, 0))
        equal = font.render(u'=', True, (0, 0, 0))
        fitness_text = font.render(f"Fitness: {self.fitness:.2f}", True, (0, 0, 0))
        indicator_text = font.render(u'\u2713' if solved else u'\u2717', True,
                                     (0, 255, 0) if solved else (255, 0, 0))

        self.screen.blit(arrow, (250, 150))
        self.screen.blit(equal, (500, 150))
        self.screen.blit(fitness_text, (600, 400))
        self.screen.blit(indicator_text, (600, 500))

        pygame.display.flip()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                self.close()
                break

    def close(self):
        """Shut down the pygame window if one is open; safe to call repeatedly."""
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()
            # Forget the dead surface so render() re-creates the window.
            self.screen = None
        self.isopen = False

In [None]:
#| gui
# Example usage: step one ARC task a few times and watch the pygame window.
# Set the ARC_TRAINING_DIR environment variable to point at your local copy of
# the training tasks; the default below is the original author's path.
props = {'dir': os.environ.get('ARC_TRAINING_DIR', 'C:\\packages\\arc-prize-2024\\training'),
         'code': '1_007bbfb7.dat'}
file_path = os.path.join(props['dir'], props['code'])

arc_env = ARCEnv()
arc_env.initialise(file_path, {'index': 0})
try:
    arc_env.render()
    print(arc_env.dimensions)
    # Each step appends one row and one column, then writes 1..7 in row-major order.
    for step_no in range(6):
        state, fitness, done = arc_env.step([1, 1, 1, 2, 3, 4, 5, 6, 7])
        print(arc_env.dimensions, fitness, done)
        arc_env.render()
    sleep(5)
finally:
    # Always release the pygame window, even if a step or render fails.
    arc_env.close()

0 0.0 (0.0, 0.0, 0.0, 255.0) 7 0.7777777777777778 (255.0, 133.0, 27.0, 255.0) 7 0.7777777777777778 (255.0, 133.0, 27.0, 255.0) 7 0.7777777777777778 (255.0, 133.0, 27.0, 255.0) 7 0.7777777777777778 (255.0, 133.0, 27.0, 255.0) 7 0.7777777777777778 (255.0, 133.0, 27.0, 255.0) 0 0.0 (0.0, 0.0, 0.0, 255.0) 7 0.7777777777777778 (255.0, 133.0, 27.0, 255.0) 7 0.7777777777777778 (255.0, 133.0, 27.0, 255.0) 0 0.0 (0.0, 0.0, 0.0, 255.0) 0 0.0 (0.0, 0.0, 0.0, 255.0) 0 0.0 (0.0, 0.0, 0.0, 255.0) 0 0.0 (0.0, 0.0, 0.0, 255.0) 7 0.7777777777777778 (255.0, 133.0, 27.0, 255.0) 7 0.7777777777777778 (255.0, 133.0, 27.0, 255.0) 0 0.0 (0.0, 0.0, 0.0, 255.0) 7 0.7777777777777778 (255.0, 133.0, 27.0, 255.0) 7 0.7777777777777778 (255.0, 133.0, 27.0, 255.0) 0 0.0 (0.0, 0.0, 0.0, 255.0) 0 0.0 (0.0, 0.0, 0.0, 255.0) 0 0.0 (0.0, 0.0, 0.0, 255.0) 7 0.7777777777777778 (255.0, 133.0, 27.0, 255.0) 7 0.7777777777777778 (255.0, 133.0, 27.0, 255.0) 7 0.7777777777777778 (255.0, 133.0, 27.0, 255.0) 7 0.7777777777777778 (25

In [None]:
#| hide
# Regenerate the library module(s) from the project's notebooks via nbdev.
import nbdev
nbdev.nbdev_export()