# Fine-tuned LLMs for ARC24 Challenge

## Goal

How far can we get with LLMs on the ARC24 challenge?

Can I reach the same accuracy with a single prompt as with the dialog in the previous notebook?

## Motivation

In my previous notebook, [Few-shot prompting for ARC24](https://www.kaggle.com/code/ironbar/few-shot-prompting-for-arc24), I found that current "small" open-source LLMs do not benefit from few-shot prompting on the ARC challenge.

In this notebook I focus on the 0-shot approach: the model only receives the input-output examples from the task of interest (no information from other tasks).

This notebook makes it possible to evaluate and submit different LLMs on ARC tasks. It opens the door to using fine-tuned models and also the mysterious "test-time fine-tuning".

## Results

## References

### Notebooks

- https://www.kaggle.com/code/ironbar/2xvllm-with-code-interpreter
- https://www.kaggle.com/code/ironbar/autobots-roll-out/notebook
- https://www.kaggle.com/code/ironbar/few-shot-prompting-for-arc24

### Models

- https://huggingface.co/microsoft/Phi-3-mini-128k-instruct
- https://huggingface.co/blog/kv-cache-quantization
- https://www.reddit.com/r/LocalLLaMA/comments/1e0kkgk/hardware_requirements_for_phi3_mini_and_phi3/
- https://www.reddit.com/media?url=https%3A%2F%2Fpreview.redd.it%2Fmicrosoft-phi-3-3-8b-with-128k-context-released-on-hf-v0-h2xzg8vaigwc1.jpeg%3Fwidth%3D1734%26format%3Dpjpg%26auto%3Dwebp%26s%3Dec27de7bd97a90a4c44ff95c561ce8008ce7aed3
- https://huggingface.co/docs/transformers/main/en/chat_templating

> The Phi-3-Mini-128K-Instruct is a 3.8 billion-parameter, lightweight, state-of-the-art open model trained using the Phi-3 datasets. 

## Configuration

In [None]:
from typing import Optional

class cfg:
    # Model
    model_path = "/home/gbarbadillo/data/Qwen2-0.5B-arc"
    lora_path = None
    #lora_path : Optional[str] = '/kaggle/input/loras/transformers/phi-3_128k/1'
    merged_model_path : Optional[str] = None
    max_model_len = 8192 #61000 for phi-3
    # Dataset
    dataset_path = '/mnt/hdd0/Kaggle/arc24/data/arc-agi_training_challenges.json'
    #dataset_path = '/mnt/hdd0/Kaggle/arc24/data/arc-agi_evaluation_challenges.json'
    n_tasks = 5 # Optional parameter to limit the number of tasks used for inference; set it to None to use all the tasks
    # Inference params
    max_predictions_per_task = 2 # Maximum number of data augmentations tried for each task
    sampling_params = dict(temperature=0.0, max_tokens=1000) # https://docs.vllm.ai/en/latest/dev/sampling_params.html

In [None]:
from jinja2 import Template

system_prompt = """You are a helpful AI assistant. Your job is to solve tasks from the Abstraction and Reasoning Challenge (ARC). 
The user will present you with sample input and output grids for each task. 
Your job will be to understand the transformation between the input and the output and apply it to the last input grid given by the user. 
The puzzle-like inputs and outputs present a grid where each square can be one of ten colors. A grid can be any height or width between 1x1 and 30x30.
The background of the grid is typically colored with 0.
The tasks from ARC are based on the following priors:

- Objectness: Objects persist and cannot appear or disappear without reason. Objects can interact or not depending on the circumstances.
- Goal-directed: Objects can be animate or inanimate. Some objects are "agents" - they have intentions and they pursue goals.
- Numbers & counting: Objects can be counted or sorted by their shape, appearance, or movement using basic mathematics like addition, subtraction, and comparison.
- Basic geometry & topology: Objects can be shapes like rectangles, triangles, and circles which can be mirrored, rotated, translated, deformed, combined, repeated, etc. Differences in distances can be detected.

The transformations between input and output should be based on these priors.
"""

prompt_template = Template("""Let's see if you can solve this simple ARC task. These are some input-output grid examples that define the task.
{% for sample in train_samples %}
## Example {{ loop.index }}

### Input

{{ sample.input }}

### Output

{{ sample.output }}
{% endfor %}
## Test case

### Input

{{ test_input }}
""")

answer_template = Template("""### Output

{{ test_output }}""")

train_samples = [dict(input=[0], output=[1]), dict(input=[2], output=[3])]
prompt = prompt_template.render(train_samples=train_samples, test_input=[4])
print(prompt)
print(answer_template.render(test_output=[5]))

In [None]:
import os
is_dry_run = cfg.dataset_path == '/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json' and not os.getenv('KAGGLE_IS_COMPETITION_RERUN')
if is_dry_run:
    print('This is a dry run, no inference nor installation of packages will be done')

## Install

In [None]:
%%time
if not is_dry_run:
    # model imports
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest
    from transformers import AutoTokenizer

## Imports

In [None]:
from abc import ABC, abstractmethod
import json
import os
from tqdm.auto import tqdm
import numpy as np
from itertools import islice, product
import matplotlib.pyplot as plt
from matplotlib import colors
from termcolor import colored
import shutil

In [None]:
import logging

for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logging.info('Started logging')

## Code

### Grid encoding

There are many ways to encode/format the grid as input to the LLM. In this section we define several encoders so we can systematically try them all.
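
To make the trade-off concrete, here is a minimal standalone sketch of how the same grid looks under two text encodings. The helper functions below are hypothetical illustrations (they are not the encoder classes of this notebook), and the `|` separator is just an assumed choice. Character counts are printed only as a rough proxy for prompt length; the real token count depends on the tokenizer.

```python
# Illustrative sketch with hypothetical helper names, not the notebook's encoders.

def encode_minimal(grid):
    """One character per cell, one text line per grid row."""
    return '\n'.join(''.join(str(cell) for cell in row) for row in grid)

def encode_with_separator(grid, separator='|'):
    """Same layout, with an assumed separator symbol between cells."""
    return '\n'.join(separator.join(str(cell) for cell in row) for row in grid)

def decode_minimal(text):
    """Inverse of encode_minimal; it works because ARC cells are single digits 0-9."""
    return [[int(char) for char in line] for line in text.strip().splitlines()]

grid = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
minimal_text = encode_minimal(grid)
separated_text = encode_with_separator(grid)
print(minimal_text)
print(separated_text)
# Character counts differ between encodings, which affects prompt length.
print(len(minimal_text), len(separated_text))
# A useful encoding must be reversible so the model output can be parsed back.
assert decode_minimal(minimal_text) == grid
```

Whatever the encoding, the round trip from grid to text and back should be lossless, since the same encoder is used to parse the model's answer.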

In [None]:
class GridEncoder(ABC):
    @abstractmethod
    def to_text(self, grid):
        pass
    
    @abstractmethod
    def to_grid(self, text):
        pass

In [None]:
sample_grid = np.eye(3, dtype=int).tolist()

def test_translator(translator):
    assert sample_grid == translator.to_grid(translator.to_text(sample_grid))
    print(translator.to_text(sample_grid))

In [None]:
class MinimalGridEncoder(GridEncoder):
    @staticmethod
    def to_text(grid):
        text = '\n'.join([''.join([str(x) for x in line]) for line in grid])
        return text
    
    @staticmethod
    def to_grid(text):
        lines = text.strip().splitlines()
        grid = [[int(x) for x in line] for line in lines]
        return grid
        
test_translator(MinimalGridEncoder())

In [None]:
class GridWithSeparationEncoder(GridEncoder):
    def __init__(self, split_symbol):
        self.split_symbol = split_symbol

    def to_text(self, grid):
        text = '\n'.join([self.split_symbol.join([str(x) for x in line]) for line in grid])
        return text
    
    def to_grid(self, text):
        lines = text.strip().splitlines()
        grid = [[int(x) for x in line.split(self.split_symbol)] for line in lines]
        return grid
        
test_translator(GridWithSeparationEncoder('|'))

In [None]:
class GridCodeBlockEncoder(GridEncoder):
    def __init__(self, base_encoder):
        self.encoder = base_encoder
    
    def to_text(self, grid):
        text = f'```grid\n{self.encoder.to_text(grid)}\n```'
        return text
    
    def to_grid(self, text):
        grid_text = text.split('```grid\n')[1].split('\n```')[0]
        grid = self.encoder.to_grid(grid_text)
        return grid
        
test_translator(GridCodeBlockEncoder(MinimalGridEncoder()))

test_translator(GridCodeBlockEncoder(GridWithSeparationEncoder('|')))

### Prompting

There are also many ways to build a prompt for the ARC challenge. The class that builds the prompt receives a grid encoder as input, so we can try different prompts with different grid encoders. 
The class that builds the prompts must also be able to parse the response from the model.
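
The parsing side can be sketched in isolation. The idea illustrated here is that the prompt already ends with the start of the assistant reply, so the model only has to continue with the grid rows and the closing fence, and the parser re-attaches the prefilled prefix before extracting the grid. The function name and the sample completion are hypothetical; they do not come from a real model run.

```python
# Standalone sketch of the prefill-and-parse idea, using hypothetical names.
FENCE = '`' * 3  # built this way to avoid writing a literal fence inside this example
REPLY_PREFIX = f'### Output\n{FENCE}grid\n'  # text prefilled as the start of the reply

def parse_completion(completion):
    """Recover the grid from a completion that continues REPLY_PREFIX."""
    full_reply = REPLY_PREFIX + completion
    # Keep only the text between the opening and the closing fence.
    grid_text = full_reply.split(f'{FENCE}grid\n')[1].split(f'\n{FENCE}')[0]
    return [[int(char) for char in line] for line in grid_text.strip().splitlines()]

# Hypothetical completion: the model writes only the rows and the closing fence.
completion = f'100\n010\n001\n{FENCE}'
print(parse_completion(completion))

# A malformed completion raises an exception, so callers should catch it.
try:
    parse_completion('I cannot solve this task')
except ValueError as error:
    print(f'Parsing failed: {error}')
```

Raising on malformed text instead of returning a partial grid keeps the failure visible to the caller, which can then try another variation of the task.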

In [None]:
class PromptCreator(ABC):
    def __init__(self, grid_encoder: GridEncoder):
        self.grid_encoder = grid_encoder
    
    @abstractmethod
    def create_task_prompts(self, task):
        pass
    
    @abstractmethod
    def parse_response(self, text):
        pass

In [None]:
class SimplePromptCreator(PromptCreator):
    def __init__(self, grid_encoder):
        super().__init__(grid_encoder)
    
    def create_task_prompts(self, task):        
        train_samples = [{key: self.grid_encoder.to_text(grid) for key, grid in sample.items()} for sample in task['train']]     
        prompts = []
        for test_sample in task['test']:
            user_message = prompt_template.render(train_samples=train_samples, 
                                                  test_input=self.grid_encoder.to_text(test_sample['input']))
            messages = [{"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_message},
                        {"role": "assistant", "content": """### Output\n```grid\n"""}]
            # The assistant message above prefills the start of the reply; its end-of-turn tokens are stripped below
            prompt = tokenizer.apply_chat_template(messages,
                                                   tokenize=False,
                                                   add_generation_prompt=False)
            prompts.append(remove_assistant_ending(prompt))
        return prompts
    
    def parse_response(self, text):
        return self.grid_encoder.to_grid('```grid\n' + text)
    
    
def remove_assistant_ending(text):
    """
phi-3

<|assistant|>
### Output
```grid
<|end|>
<|endoftext|>

llama 3.1

<|eot_id|><|start_header_id|>assistant<|end_header_id|>

### Output
```grid<|eot_id|><|start_header_id|>assistant<|end_header_id|>
    """
    if 'llama' in cfg.model_path:
        split_text = '<|eot_id|>'
    elif 'Qwen' in cfg.model_path:
        split_text = '<|im_end|>'
    else:
        split_text = '<|end|>'
    return split_text.join(text.split(split_text)[:-1])

In [None]:
def print_sample_prompt(data, prompt_creator):
    prompts = [prompt_creator.create_task_prompts(task)[0] for task in data.values()]
    prompts = sorted(prompts, key=lambda x: len(x))
    pretty_print_prompt(prompts[0])
    
def pretty_print_prompt(text, default_color='black'):
    color = default_color
    attrs = None
    for line in text.splitlines():
        if line.startswith('<|assistant|>'):
            color = 'blue'
        elif line.startswith('<|user|>'):
            color = default_color
        elif line.startswith('<|system|>'):
            color = 'green'
        if line.startswith('<'):
            attrs = ['bold']
        else:
            attrs = None
        print(colored(line, color, attrs=attrs))

In [None]:
def plot_input_token_length_distribution(data, prompt_creator):
    prompts = []
    for task in data.values():
        prompts.extend(prompt_creator.create_task_prompts(task))
    token_length_distribution = [len(tokenizer.tokenize(prompt)) for prompt in tqdm(prompts)]
    plt.title('Prompt token length distribution')
    plt.hist(token_length_distribution)
    plt.xlabel('n tokens')

### Model

In [None]:
script_text = f"""
import os
import glob
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import shutil

base_model_path = '{cfg.model_path}'
lora_path = '{cfg.lora_path}'
output_path = '{cfg.merged_model_path}'

base_model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
if 'llama' in '{cfg.model_path}':
    tokenizer.add_special_tokens(dict(pad_token='<|pad|>'))
    base_model.resize_token_embeddings(len(tokenizer))

model = PeftModel.from_pretrained(base_model, lora_path)
merged_model = model.merge_and_unload()
print('Saving the merged model to the output path')
merged_model.save_pretrained(output_path)

for filepath in glob.glob(os.path.join(base_model_path, '*')):
    dst = os.path.join(output_path, os.path.basename(filepath))
    if not os.path.exists(dst):
        print('Copying', filepath)
        shutil.copy(filepath, dst)

print('Done!')
"""

if not is_dry_run and cfg.merged_model_path is not None and not os.path.exists(cfg.merged_model_path):
    with open('merge_lora.py', 'w') as f:
        f.write(script_text)
    !python merge_lora.py
    os.remove('merge_lora.py')
elif cfg.merged_model_path is not None and os.path.exists(cfg.merged_model_path):
    print(f'Merged model already exists: {cfg.merged_model_path}')

In [None]:
if not is_dry_run:
    model_path = cfg.merged_model_path or cfg.model_path
    print(f'Loading {model_path}')
    llm = LLM(model=model_path,
              trust_remote_code=True, 
              dtype='half', 
              tensor_parallel_size=2, # to use 2 gpus
              max_model_len=cfg.max_model_len,
              #kv_cache_dtype='fp8_e5m2', I have disabled kv cache quantization because it is hurtful
              enforce_eager=True, # without this 13.9GB of memory is used on each GPU, with this is 13.3GB,
              disable_log_stats=True,
             )
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_path)
    for number in '0123456789':
        print(f'{number}: {[key for key in tokenizer.get_vocab().keys() if number in key and not key.startswith("<")]}')

The tokenizer from phi-3 encodes each digit independently; it does not group numbers such as 10 or 100.

### Data augmentation

We need data augmentation to make multiple predictions for each task.
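
As a rough illustration of where the variations come from, the sketch below enumerates the eight combinations of an optional horizontal flip and 0 to 3 quarter turns, and checks that each one can be reverted. It is written in plain Python with hypothetical function names so that it runs without NumPy; it is an independent sketch, not the implementation used in this notebook.

```python
from itertools import product

def flip_horizontal(grid):
    """Mirror each row (left-right flip)."""
    return [row[::-1] for row in grid]

def rotate_ccw(grid):
    """Rotate the grid 90 degrees counterclockwise."""
    return [list(row) for row in zip(*grid)][::-1]

def rotate_cw(grid):
    """Rotate the grid 90 degrees clockwise (inverse of rotate_ccw)."""
    return [list(row) for row in zip(*grid[::-1])]

def augment(grid, flip, n_rot90):
    """Apply the optional flip first, then n_rot90 counterclockwise rotations."""
    if flip:
        grid = flip_horizontal(grid)
    for _ in range(n_rot90):
        grid = rotate_ccw(grid)
    return grid

def revert(grid, flip, n_rot90):
    """Undo augment by applying the inverse steps in reverse order."""
    for _ in range(n_rot90):
        grid = rotate_cw(grid)
    if flip:
        grid = flip_horizontal(grid)
    return grid

grid = [[1, 2], [3, 4]]
params = list(product([False, True], range(4)))
variations = [augment(grid, flip, n_rot90) for flip, n_rot90 in params]
# For a grid without symmetries every combination gives a different variation.
unique_variations = {tuple(map(tuple, variation)) for variation in variations}
print(len(unique_variations))
# Reverting each augmentation recovers the original grid.
assert all(revert(augment(grid, flip, k), flip, k) == grid for flip, k in params)
```

The order matters when reverting: because the flip is applied before the rotations, the rotations have to be undone before the flip.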

In [None]:
class DataAugmentation():
    def __init__(self, flip, n_rot90):
        self.flip = flip
        self.n_rot90 = n_rot90
        
    def augment_task(self, task):
        augmented_task = dict()
        for partition, samples in task.items():
            augmented_task[partition] = [{name:self.augment_grid(grid) for name,grid in sample.items()} for sample in samples]
        return augmented_task
    
    def augment_grid(self, grid):
        grid = np.array(grid)
        if self.flip:
            grid = np.flip(grid, axis=1)
        grid = np.rot90(grid, k=self.n_rot90)
        return grid.tolist()
    
    def revert_augmentation(self, grid):
        grid = np.array(grid)
        grid = np.rot90(grid, k=-self.n_rot90)
        if self.flip:
            grid = np.flip(grid, axis=1)
        return grid.tolist()


for flip in [True, False]:
    for n_rot90 in range(4):
        data_augmentation = DataAugmentation(flip, n_rot90)
        assert sample_grid == data_augmentation.revert_augmentation(data_augmentation.augment_grid(sample_grid))

### Plots

In [None]:
def plot_task(task):
    samples = task['train'] + task['test']
    for plot_idx, sample in enumerate(samples):
        plt.subplot(2, len(samples), plot_idx + 1)
        plot_grid(sample['input'])
        if 'output' in sample:
            plt.subplot(2, len(samples), plot_idx + 1 + len(samples))
            plot_grid(sample['output'])
            
def plot_grids(grids):
    for plot_idx, grid in enumerate(grids):
        plt.subplot(1, len(grids), plot_idx + 1)
        plot_grid(grid)
            
def plot_grid(grid):
    grid = np.array(grid)
    cmap = colors.ListedColormap(
        ['#000000', '#0074D9','#FF4136','#2ECC40','#FFDC00',
         '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25'])
    norm = colors.Normalize(vmin=0, vmax=9)
    plt.imshow(grid, cmap=cmap, norm=norm)
    plt.grid(True,which='both',color='lightgrey', linewidth=0.5) 
    plt.xticks(np.arange(-0.5, grid.shape[1]), [])
    plt.yticks(np.arange(-0.5, grid.shape[0]), [])
    plt.xlim(-0.5, grid.shape[1]-0.5)

### Evaluation

In [None]:
def analyze_number_of_predictions_per_task(data, texts):
    number_of_predictions = dict()
    for task_id, task in data.items():
        number_of_predictions[task_id] = len(texts[task_id]['responses'])/len(task['test'])
    plt.title('Distribution of the number of predictions per task')
    plt.hist(number_of_predictions.values(), bins=np.arange(1.5, 9))
    plt.xlabel('number of predictions')
    plt.ylabel('count')
    return number_of_predictions

In [None]:
def evaluate(ground_truth, solutions):
    """
    Computes the following metrics:
    
    - Accuracy
    - Correct pixels
    - Correct size
    """
    metrics = []
    for task_id, task_ground_truth in ground_truth.items():
        task_metrics = []
        plot_task(data[task_id]); plt.suptitle(f'{task_id}'); plt.show()
        for idx, correct_grid in enumerate(task_ground_truth):
            predicted_grids = list(solutions[task_id][idx].values())
            predicted_grids = [grid for grid in predicted_grids if grid]
            
            task_metrics.append(evaluate_grid(correct_grid, predicted_grids))
            print_metrics(task_metrics[-1], f'{task_id}_{idx}')
            plot_grids([correct_grid] + predicted_grids)
            plt.suptitle(f'{task_id}_{idx}')
            plt.show()
        metrics.append(average_metrics(task_metrics))
    print('\n'*3 + '# Aggregated metrics:')
    print_metrics(average_metrics(metrics))
    save_metrics(metrics, solutions)
    plot_metrics_distribution(metrics)
    print_metrics(average_metrics(metrics))
    
def plot_metrics_distribution(metrics):
    for key in metrics[0]:
        values = [x[key] for x in metrics]
        plt.title(f'Distribution of {key}')
        plt.hist(values, bins=np.linspace(0, 1, 10))
        plt.xlabel(key)
        plt.ylabel('count')
        plt.show()
    
def average_metrics(metrics):
    averaged_metrics = dict()
    for key in metrics[0]:
        averaged_metrics[key] = np.mean([x[key] for x in metrics])
    return averaged_metrics
        
def save_metrics(metrics, solutions):
    formatted_metrics = dict(global_metrics=average_metrics(metrics))
    for task_id, task_metrics in zip(solutions, metrics):
        formatted_metrics[task_id] = task_metrics
    with open('metrics.json', 'w') as f:
        json.dump(formatted_metrics, f)

def print_metrics(metrics, prefix=''):
    text = f'{prefix}'
    for key, value in metrics.items():
        text += f'{key}: {value*100:.1f}%\t'
    print(text)

    
def evaluate_grid(correct_grid, predicted_grids):
    correct_grid = np.array(correct_grid)
    metrics = dict(accuracy=0, correct_pixels=0, correct_size=0, unanswered=(2 - len(predicted_grids))/2)
    for predicted_grid in predicted_grids:
        predicted_grid = np.array(predicted_grid)
        if correct_grid.shape == predicted_grid.shape:
            metrics['accuracy'] = max(metrics['accuracy'], np.all(predicted_grid == correct_grid))
            metrics['correct_pixels'] = max(metrics['correct_pixels'], np.mean(predicted_grid == correct_grid))
            metrics['correct_size'] = max(metrics['correct_size'], correct_grid.shape == predicted_grid.shape)
    return metrics
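
A small worked example may help to read these metrics. The sketch below is a simplified, NumPy-free illustration with a hypothetical `pixel_accuracy` helper: each attempt is scored against the correct grid, an attempt with the wrong shape scores zero, and the best attempt determines the result for the test sample.

```python
def pixel_accuracy(correct_grid, predicted_grid):
    """Fraction of matching cells, or 0.0 when the shapes differ."""
    same_shape = (len(correct_grid) == len(predicted_grid)
                  and all(len(row_a) == len(row_b)
                          for row_a, row_b in zip(correct_grid, predicted_grid)))
    if not same_shape:
        return 0.0
    cells = [(a, b) for row_a, row_b in zip(correct_grid, predicted_grid)
             for a, b in zip(row_a, row_b)]
    if not cells:
        return 0.0
    return sum(a == b for a, b in cells) / len(cells)

correct = [[1, 0], [0, 1]]
attempts = [
    [[1, 0], [0, 0]],  # right size, 3 of 4 cells are correct
    [[1, 0, 0]],       # wrong size, scores zero
]
scores = [pixel_accuracy(correct, attempt) for attempt in attempts]
best_score = max(scores)
# The sample only counts as solved when one attempt matches every cell.
solved = float(best_score == 1.0)
print(scores, best_score, solved)
```

This is why a task can show a high share of correct pixels and still have zero accuracy: accuracy requires an exact match.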

## Inference

We need to generate 2 different predictions for each task. The model can fail to generate a prediction, or the parsing can fail, so we need a method that is robust to failures.

One way to solve this would be to use data augmentation. By applying rotations and flips we could generate up to 8 variations of each task. So we could try with different data augmentations until we have 2 predictions for each task. Another alternative would be to make inference with the 8 variations and use majority voting.
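
The majority voting alternative could look roughly like the sketch below. It assumes that the predictions from all the variations have already been reverted to the original orientation and that a failed parse is represented by an empty list; the `select_attempts` name is hypothetical.

```python
from collections import Counter

def select_attempts(predicted_grids, n_attempts=2):
    """Return the n_attempts most frequent distinct grids, most voted first."""
    # Lists are not hashable, so each grid is converted to a tuple of tuples.
    votes = Counter(tuple(map(tuple, grid)) for grid in predicted_grids if grid)
    return [[list(row) for row in grid] for grid, _ in votes.most_common(n_attempts)]

# Hypothetical predictions, one for each variation of the task.
predictions = [
    [[1, 0], [0, 1]],
    [[1, 1], [0, 1]],
    [[1, 0], [0, 1]],
    [],                # a failed parse contributes no vote
    [[1, 0], [0, 1]],
    [[1, 1], [0, 1]],
    [[0, 0], [0, 0]],
]
selected = select_attempts(predictions)
print(selected)
```

Compared with stopping as soon as two predictions are available, voting costs one inference for every variation, but it uses the agreement between variations as a signal of confidence.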

In [None]:
def solve_task(task_id, task, prompt_creator, sampling_params):
    data_augmentation_params = product([False, True], [0, 1, 2, 3])
    solution = {task_id:[{"attempt_1": [], "attempt_2": []} for _ in task['test']]}
    texts = dict(prompts=[], responses=[], exceptions=[])
    for flip, n_rot90 in islice(data_augmentation_params, cfg.max_predictions_per_task):
        data_augmentation = DataAugmentation(flip, n_rot90)
        augmented_task = data_augmentation.augment_task(task)
        prompts = prompt_creator.create_task_prompts(augmented_task)
        outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
        responses = [output.outputs[0].text for output in outputs]
        for idx, response in enumerate(responses):
            try:
                augmented_grid = prompt_creator.parse_response(response)
                grid = data_augmentation.revert_augmentation(augmented_grid)
                if not solution[task_id][idx]["attempt_1"]:
                    solution[task_id][idx]["attempt_1"] = grid
                elif solution[task_id][idx]["attempt_1"] != grid and not solution[task_id][idx]["attempt_2"]:
                    solution[task_id][idx]["attempt_2"] = grid
            except Exception as e:
                print(f'Exception when parsing response from {task_id}: {e}')
                texts['exceptions'].append(str(e))
        texts['prompts'].append(prompts)
        texts['responses'].append(responses)
        if is_solution_done(solution):
            break
    return solution, {task_id:texts}

def is_solution_done(solution):
    for task_id, predictions in solution.items():
        for prediction in predictions:
            for grid in prediction.values():
                if not grid:
                    return False
    return True

In [None]:
def inference(data, prompt_creator, sampling_params):
    solutions, texts = dict(), dict()
    for idx, (task_id, task) in tqdm(enumerate(data.items()), total=len(data), desc='Solving tasks'):
        logging.info(f'Solving {task_id}, {idx+1}/{len(data)}')
        task_solution, task_texts = solve_task(task_id, task, prompt_creator, sampling_params)
        solutions.update(task_solution)
        texts.update(task_texts)
    return solutions, texts

In [None]:
with open(cfg.dataset_path) as f:
    data = json.load(f)
if cfg.n_tasks is not None:
    data = dict(islice(data.items(), cfg.n_tasks))
print(f'There are {len(data)} tasks to solve.')

In [None]:
if not is_dry_run:
    prompt_creator = SimplePromptCreator(GridCodeBlockEncoder(MinimalGridEncoder()))
    print_sample_prompt(data, prompt_creator)
    plot_input_token_length_distribution(data, prompt_creator)

In [None]:
if is_dry_run:
    with open('submission.json', 'w') as f:
        json.dump(dict(dry_run=True), f)
else:
    sampling_params = SamplingParams(n=1, **cfg.sampling_params)
    solutions, texts = inference(data, prompt_creator, sampling_params)
    with open('submission.json', 'w') as f:
        json.dump(solutions, f)    

In [None]:
if not is_dry_run:
    number_of_predictions_per_task = analyze_number_of_predictions_per_task(data, texts)
    number_of_predictions_per_task

## Evaluation

In [None]:
ground_truth_path = cfg.dataset_path.replace('challenges.json', 'solutions.json')
if os.path.exists(ground_truth_path):
    with open(ground_truth_path, 'r') as f:
        ground_truth = json.load(f)
    ground_truth = {key: ground_truth[key] for key in solutions}
    evaluate(ground_truth, solutions)
    
    with open('texts.json', 'w') as f:
        json.dump(texts, f)
    with open('number_of_predictions_per_task.json', 'w') as f:
        json.dump(number_of_predictions_per_task, f)

## Clean

In [None]:
def clear_vllm_gpu_memory():
    global llm
    # https://github.com/vllm-project/vllm/issues/1908
    from vllm.distributed.parallel_state import destroy_model_parallel, destroy_distributed_environment
    import torch
    import gc
    destroy_model_parallel()
    destroy_distributed_environment()
    del llm.llm_engine.model_executor
    del llm
    gc.collect()
    torch.cuda.empty_cache()

if not is_dry_run:
    clear_vllm_gpu_memory()
    if cfg.merged_model_path is not None:
        shutil.rmtree(cfg.merged_model_path)

In [None]:
#!rm -rf *
!ls -lh

## TODO

- [x] I need to read the code and understand everything before refactoring. 
- [x] Add logging; some evaluations in the previous notebook took more than 12 hours and I had no observability
- [x] Remove all few-shot prompt inheritance from the previous notebook
- [x] ~~Allow to use LoRAs~~ It does not work with vLLM, maybe due to the GPU. Instead I'm merging the LoRA into the model
- [x] Allow to merge a model with its LoRA (it should be faster)
- [x] Check and copy improvements done on local evaluation
- [x] Verify that I get the same results as in local evaluation
- [ ] Wait for free GPU
- [ ] Better logging of parsing errors, e.g. print the shape of the list.
- [ ] Add an option to do test-time fine-tuning
- [ ] Can I speed up inference? By making multiple requests in parallel, or running the server in another thread.
- [ ] Can I create a more compact visualization of the tasks and predictions?
- [ ] More flexible configuration and prompt specification. This should be compatible with the training script. Maybe the code should be shared.
- [ ] How to handle the case where I predict the grid shape before the grid?