In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to local packages are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pathlib

import arckit
import rigging as rg

import kego.arc.images as images
from kego.files.json import load_json

In [None]:
# Root directory holding the project's datasets. The original machine-specific
# location stays the default, but it can be redirected without editing the
# notebook by setting the KEGO_DATA_DIR environment variable before the kernel
# starts.
PATH_DATA = pathlib.Path(
    os.environ.get("KEGO_DATA_DIR", "/home/kristian/Projects/kego/data/")
)

# Files of the ARC Prize 2024 competition, all resolved relative to PATH_DATA.
PATH_COMPETITION = PATH_DATA / "arc/arc-prize-2024/"
PATH_TRAIN_CHALLENGES = PATH_COMPETITION / "arc-agi_training_challenges.json"
PATH_TRAIN_SOLUTIONS = PATH_COMPETITION / "arc-agi_training_solutions.json"
PATH_TEST = PATH_COMPETITION / "test.csv"
PATH_SUBMISSION_EXAMPLE = PATH_COMPETITION / "sample_submission.csv"

# Generator identifier handed to rigging by get_model(). The commented-out line
# is an alternative local-model identifier kept for reference.
# MODEL = "transformers!meta-llama/Meta-Llama-3-8B-Instruct,device_map=cuda:1,max_tokens=1024,load_in_4bit=True"
MODEL = "openai/gpt-4o-mini"

In [None]:
# Ground-truth outputs for the training tasks, read from the competition JSON file.
train_solutions = load_json(PATH_TRAIN_SOLUTIONS)

# Load the "kaggle2024" dataset through arckit; the call yields a train/test pair.
# NOTE(review): test_set is not used anywhere in the visible cells.
train_set, test_set = arckit.load_data("kaggle2024")

# Use the first training task as the running example and hand it to the
# project's show_task helper (presumably renders the grids -- confirm in kego.arc.images).
task_example = train_set.tasks[0]
drawing = images.show_task(task_example, train_solutions=train_solutions)

In [None]:
# Fetch the example task together with its solution through the project helper;
# the helper's result is unpacked into the notebook-level names `task` and `solution`.
task, solution = images.get_solution(task_example, train_solutions=train_solutions)

In [None]:
# Build the task's text prompt for test index 0 via the task's own gpt_prompt method.
# include_completion=False presumably leaves the expected answer out -- confirm in arckit.
prompt = task_example.gpt_prompt(0, include_completion=False)
# Trailing bare expression: the cell displays the string's repr.
prompt

In [None]:
def array_to_str(array):
    """Render a grid as plain text: cells separated by spaces, one row per line.

    The previous implementation stripped brackets and commas from ``str(array)``,
    which left a leading space on every row after the first for nested lists and
    additionally produced blank lines for numpy arrays. Rows are now joined
    explicitly so each row starts at column 0 and rows are separated by exactly
    one newline.

    Args:
        array: A 2-D grid given as nested lists/tuples or as any object exposing
            ``tolist()`` (e.g. a numpy array). 1-D sequences and scalars are
            accepted as well.

    Returns:
        str: The text form of the grid. A 1-D sequence becomes a single
        space-separated line, an empty sequence becomes ``""`` and a scalar
        becomes ``str(scalar)``.
    """
    # Arrays are converted to nested Python lists; lists pass through unchanged.
    rows = array.tolist() if hasattr(array, "tolist") else array

    if not isinstance(rows, (list, tuple)):
        # Scalar input: there is nothing to lay out.
        return str(rows)

    if all(not isinstance(row, (list, tuple)) for row in rows):
        # 1-D input (or empty): a single row of cells.
        return " ".join(str(cell) for cell in rows)

    return "\n".join(
        " ".join(str(cell) for cell in row)
        if isinstance(row, (list, tuple))
        else str(row)
        for row in rows
    )


def sample_to_str(sample):
    """Render the training pairs of a task as labelled text grids.

    Pairs are numbered from 1 and labelled ``Input N`` / ``Output N``, the same
    numbering the prompts in this notebook use when they refer to "Input N".
    (The labels were previously 0-based and lower-case, so they were off by one
    relative to that wording, and the first grid row shared a line with its
    label.) Each label now sits on its own line, followed by the grid.

    Args:
        sample: A task object whose ``to_dict()`` result has a ``"train"`` list
            of dicts with ``"input"`` and ``"output"`` grids.

    Returns:
        str: The concatenated sections, each terminated by a newline; an empty
        string when the task has no training pairs.
    """
    examples_dict = sample.to_dict()
    sections = []
    for number, example in enumerate(examples_dict["train"], start=1):
        for role in ("input", "output"):
            grid_text = array_to_str(example[role])
            sections.append(f"{role.capitalize()} {number}:\n{grid_text}\n")
    # Joining once avoids rebuilding the string on every iteration.
    return "".join(sections)

In [None]:
# Display the prompt string's repr again (escape sequences such as \n stay visible).
prompt

In [None]:
# Print the prompt so its newlines are rendered and the grids can be read as laid out.
print(prompt)

In [None]:
class Solution(rg.Model):
    """Structured answer expected back from the model.

    Fields:
        solution: The predicted output grid as text.
        explanation: The step-by-step reasoning behind the grid.
    """

    solution: str
    explanation: str

    @classmethod
    def xml_example(cls) -> str:
        """Return a filled-in example of the answer format as pretty-printed XML.

        The example answers "Input 6" of the worked example used in the prompt
        (3x3 grid -> 9x9 grid). Every non-zero input cell becomes a copy of the
        whole input grid in the corresponding 3x3 block of the output; every
        zero cell becomes a 3x3 block of zeros. The grid and the explanation
        below follow that rule, which is the one exhibited by the five example
        pairs in the prompt.

        Returns:
            str: The XML produced by ``to_pretty_xml()`` for the example instance.
        """
        # Build through ``cls`` so a subclass yields an example of its own type.
        return cls(
            solution="""
            7 0 7 0 0 0 7 0 7
            7 0 7 0 0 0 7 0 7
            7 7 0 0 0 0 7 7 0
            7 0 7 0 0 0 7 0 7
            7 0 7 0 0 0 7 0 7
            7 7 0 0 0 0 7 7 0
            7 0 7 7 0 7 0 0 0
            7 0 7 7 0 7 0 0 0
            7 7 0 7 7 0 0 0 0
            """,
            explanation="""
            Step 1: Identify the Transformation Pattern
            Looking at the provided examples:

            Input 1 to Output 1: The input grid was a 3x3 matrix, and the output grid became a 9x9 matrix.
            Input 2 to Output 2: The input grid was a 3x3 matrix, and the output grid became a 9x9 matrix.
            Input 3 to Output 3: The input grid was a 3x3 matrix, and the output grid became a 9x9 matrix.
            Input 4 to Output 4: The input grid was a 3x3 matrix, and the output grid became a 9x9 matrix.
            Input 5 to Output 5: The input grid was a 3x3 matrix, and the output grid became a 9x9 matrix.
            From this, we can infer that the transformation involves expanding the grid from 3x3 to 9x9. This suggests that every cell of the input grid maps to one 3x3 block of the output grid.

            Step 2: Determine the Transformation Rule
            Upon examining the relationship between input and output, we observe:

            Each element of the 3x3 input grid corresponds to one 3x3 block in the 9x9 output grid. For instance:
            In Input 1, the 7 at position (1,2) of the input grid results in a full copy of the input grid in block (1,2) of the output grid.
            Empty cells (with value 0) result in a 3x3 block of 0s in the output grid.

            Step 3: Apply the Transformation to Input 6
            Let's now apply the transformation to Input 6:

            Input 6:
            7 0 7
            7 0 7
            7 7 0
            For each element in this grid:

            The 7 at position (1,1) in the input should become a copy of the input grid in block (1,1) of the output.
            The 0 at position (1,2) in the input should expand into a 3x3 block of 0s in the output.
            This pattern continues for each element.

            Output 6 will be:
            """,
        ).to_pretty_xml()


def get_model():
    """Create the rigging generator for the notebook-level ``MODEL`` identifier."""
    generator = rg.get_generator(MODEL)
    return generator


# Worked example included in every request: five solved pairs plus the question
# for "Input 6". Kept as adjacent literals (one grid row per line) so the grids
# can be checked by eye; the pieces concatenate into a single string.
_WORKED_EXAMPLE_PROMPT = (
    "We are playing a game which involves transforming an input grid of digits into an output grid of digits. "
    "In general, digits form objects in 2D and the task is to perform some spatial transformation of these objects "
    "to go from the input grid to the output grid. "
    "All the information about the transformation is contained within the input pairs themselves, "
    "and your answer will only be correct if the output grid is exactly correct, so this is what I expect from you. "
    "I will begin by giving you several examples of input-output pairs. "
    "You will then be given a new input grid, and you must provide the corresponding output grid.\n"
    "Input 1: \n"
    "0 7 7\n"
    "7 7 7\n"
    "0 7 7\n"
    "Output 1: \n"
    "0 0 0 0 7 7 0 7 7\n"
    "0 0 0 7 7 7 7 7 7\n"
    "0 0 0 0 7 7 0 7 7\n"
    "0 7 7 0 7 7 0 7 7\n"
    "7 7 7 7 7 7 7 7 7\n"
    "0 7 7 0 7 7 0 7 7\n"
    "0 0 0 0 7 7 0 7 7\n"
    "0 0 0 7 7 7 7 7 7\n"
    "0 0 0 0 7 7 0 7 7\n"
    "\n"
    "Input 2: \n"
    "4 0 4\n"
    "0 0 0\n"
    "0 4 0\n"
    "Output 2: \n"
    "4 0 4 0 0 0 4 0 4\n"
    "0 0 0 0 0 0 0 0 0\n"
    "0 4 0 0 0 0 0 4 0\n"
    "0 0 0 0 0 0 0 0 0\n"
    "0 0 0 0 0 0 0 0 0\n"
    "0 0 0 0 0 0 0 0 0\n"
    "0 0 0 4 0 4 0 0 0\n"
    "0 0 0 0 0 0 0 0 0\n"
    "0 0 0 0 4 0 0 0 0\n"
    "\n"
    "Input 3: \n"
    "0 0 0\n"
    "0 0 2\n"
    "2 0 2\n"
    "Output 3: \n"
    "0 0 0 0 0 0 0 0 0\n"
    "0 0 0 0 0 0 0 0 0\n"
    "0 0 0 0 0 0 0 0 0\n"
    "0 0 0 0 0 0 0 0 0\n"
    "0 0 0 0 0 0 0 0 2\n"
    "0 0 0 0 0 0 2 0 2\n"
    "0 0 0 0 0 0 0 0 0\n"
    "0 0 2 0 0 0 0 0 2\n"
    "2 0 2 0 0 0 2 0 2\n"
    "\n"
    "Input 4: \n"
    "6 6 0\n"
    "6 0 0\n"
    "0 6 6\n"
    "Output 4: \n"
    "6 6 0 6 6 0 0 0 0\n"
    "6 0 0 6 0 0 0 0 0\n"
    "0 6 6 0 6 6 0 0 0\n"
    "6 6 0 0 0 0 0 0 0\n"
    "6 0 0 0 0 0 0 0 0\n"
    "0 6 6 0 0 0 0 0 0\n"
    "0 0 0 6 6 0 6 6 0\n"
    "0 0 0 6 0 0 6 0 0\n"
    "0 0 0 0 6 6 0 6 6\n"
    "\n"
    "Input 5: \n"
    "2 2 2\n"
    "0 0 0\n"
    "0 2 2\n"
    "Output 5: \n"
    "2 2 2 2 2 2 2 2 2\n"
    "0 0 0 0 0 0 0 0 0\n"
    "0 2 2 0 2 2 0 2 2\n"
    "0 0 0 0 0 0 0 0 0\n"
    "0 0 0 0 0 0 0 0 0\n"
    "0 0 0 0 0 0 0 0 0\n"
    "0 0 0 2 2 2 2 2 2\n"
    "0 0 0 0 0 0 0 0 0\n"
    "0 0 0 0 2 2 0 2 2\n"
    "\n"
    "Input 6:\n"
    "7 0 7\n"
    "7 0 7\n"
    # The newline after the last grid row was missing, gluing "0" to "Please".
    "7 7 0\n"
    "Please provide a step-by-step explanation. "
    "Specifically, answer the following in your explanation. "
    "1. Justify the output shape of your answer. \n"
    "2. Did you consider shapes in the outputs and why?\n"
    "Provide the output for Input 6 again at the end of your answer."
)


async def generate_solution(sample, model, verbose=False):
    """Ask the model to solve the first test input of ``sample``.

    Fixes relative to the previous version:
      * the number of the grid to solve is derived from ``sample`` itself rather
        than from an unrelated notebook-level variable, so it is correct for
        whichever task is passed in;
      * the test input grid is now actually included in the user prompt (before,
        only the training pairs were sent and the model had nothing to solve);
      * the reply is parsed once instead of twice.

    Args:
        sample: Task object exposing ``to_dict()`` with ``"train"`` pairs and
            (assumed, parallel to ``"train"``) a ``"test"`` list of dicts holding
            an ``"input"`` grid. Only the first test input is used.
        model: Generator object providing ``chat(messages)``, e.g. the value
            returned by ``get_model()``.
        verbose: When true, print the parsed solution and explanation.

    Returns:
        tuple[str, str]: ``(solution, explanation)`` parsed from the last message.

    Raises:
        ValueError: If ``sample`` contains no test input to solve.
    """
    examples = sample.to_dict()
    test_examples = examples["test"]
    if not test_examples:
        raise ValueError("sample has no test input to solve")

    # Training pairs are numbered from 1, so the grid to solve is number len(train) + 1.
    n_input_problem = len(examples["train"]) + 1
    system_prompt = f"""
        'We are playing a game which involves transforming an input grid of digits into an output grid of digits. 
        In general, digits form objects in 2D and the task is to perform some spatial transformation 
        of these objects to go from the input grid to the output grid. 
        All the information about the transformation is contained within the input pairs themselves, 
        and your answer will only be correct if the output grid is exactly correct, 
        so this is what I expect from you. I will begin by giving you several examples of input-output pairs. 
        You will then be given a new input grid, and you must provide the corresponding output grid.\n
        
        Please provide a step-by-step explanation. 
        Specifically, answer the following in your explanation. 
        1. Justify the output shape of your answer. \n
        2. Did you consider shapes in the outputs and why?\n
        Provide the output for Input {n_input_problem} again at the end of your answer.'
    """
    sample_str = sample_to_str(sample)

    # The grid the model has to transform, labelled with the number the system
    # prompt refers to.
    test_input_str = array_to_str(test_examples[0]["input"])
    test_section = (
        f"The new input grid to solve is Input {n_input_problem}:\n"
        f"{test_input_str}\n"
    )

    prompt = _WORKED_EXAMPLE_PROMPT
    user_prompt = f"""
        The training sample is {sample_str}
        {test_section}
        Provide your solution and explanation within the tag 
        
        Example:
        {prompt}
        {Solution.xml_example()}
    """
    asker = await model.chat(
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
    ).run()

    # Parse the reply once and read both fields from the same object.
    parsed = asker.last.parse(Solution)
    solution = parsed.solution
    explanation = parsed.explanation
    if verbose:
        print("=== Solution ====")
        print(solution)
        print("=== Explanation ====")
        print(explanation)

    return solution, explanation

In [None]:
# Create the generator once so later cells can reuse it across calls.
model = get_model()

In [None]:
# Switch the running example to the second training task.
# NOTE(review): this rebinds `task_example`, which an earlier cell set to tasks[0];
# cells above that were run before this one still reflect the first task.
task_example = train_set.tasks[1]

In [None]:
# Query the model for the selected task (top-level await, as supported in notebook cells).
# verbose=True prints the parsed fields; the returned (solution, explanation)
# tuple is also shown as the cell's output.
await generate_solution(sample=task_example, model=model, verbose=True)