# Grid selection with logprobs

## Goal

Can I use the log probabilities returned by vLLM to select the correct grid answer among the sampled predictions?

## Configuration

```
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --output_filepath=submission_x8_logprob.json
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=32 --output_filepath=submission_x32_logprob.json
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=128 --output_filepath=submission_x128_logprob.json
```

In [None]:
class cfg:
    """File paths used by this notebook.

    ``solutions_filepath`` selects which inference run is analysed; the
    commented-out lines are the runs with 8 and 32 predictions per task.
    """
    # solutions_filepath: str = '/mnt/hdd0/Kaggle/arc24/evaluations/first_evaluations/submission_x8_logprob.json'
    # solutions_filepath: str = '/mnt/hdd0/Kaggle/arc24/evaluations/first_evaluations/submission_x32_logprob.json'
    solutions_filepath: str = (
        '/mnt/hdd0/Kaggle/arc24/evaluations/first_evaluations/submission_x128_logprob.json'
    )
    # Evaluation challenges together with their expected outputs.
    dataset_filepath: str = (
        '/mnt/hdd0/Kaggle/arc24/data/arc-agi_evaluation_challenges.json'
    )

## Imports

In [None]:
import sys
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

sys.path.append(os.path.realpath('../scripts/'))

from evaluation import (
    load_arc_data_with_solutions,
    evaluate,
    plot_grid,
    plot_task,
    print_metrics
)
from voting import select_most_voted_solutions

# Draw and discard an empty plot before touching rcParams.
# NOTE(review): presumably this initialises the notebook plotting backend so
# the settings below are not reset by it afterwards — confirm still needed.
plt.plot()
plt.close('all')
# plt.rcParams and mpl.rcParams are the same object; set everything at once.
mpl.rcParams.update({
    'figure.figsize': (25, 4),
    'lines.linewidth': 3,
    'font.size': 16,
})

## Baseline results

What is the pass_n metric when using all the predictions, and what do we get when voting selects two candidates?

In [None]:
# Load the sampled predictions and the ground-truth outputs.
with open(cfg.solutions_filepath, 'r') as f:
    solutions = json.load(f)
ground_truth = load_arc_data_with_solutions(cfg.dataset_filepath)

# Metrics using every prediction.
all_predictions_metrics = evaluate(ground_truth, solutions, verbose=False)[0]
print_metrics(all_predictions_metrics)

# Metrics after voting down to 2 candidates per test input.
voting_solutions = select_most_voted_solutions(solutions, 2)
voting_metrics = evaluate(ground_truth, voting_solutions, verbose=False)[0]
print_metrics(voting_metrics)

Using all predictions gives pass_n=23.5%; keeping only the two candidates selected by voting gives 12.5%.

## Exploring the use of logprobs to select the correct answer

### Naive approach

In [None]:
# The per-prediction results (grid, n_tokens, cumulative_logprob, ...) are
# stored next to the submission file with a `_task_results` suffix.
# Only the file extension is swapped: `str.replace('.json', ...)` would also
# rewrite any earlier `.json` in the path, and would silently reopen the
# submission file itself if the path had no `.json` at all.
with open(os.path.splitext(cfg.solutions_filepath)[0] + '_task_results.json', 'r') as f:
    rich_solutions = json.load(f)

In [None]:
def create_empty_solutions(data):
    """Build a submission skeleton with empty attempts for every test input.

    Args:
        data: mapping ``task_id -> task`` where each task has a ``'test'`` list.

    Returns:
        dict mapping each task_id to a list holding one
        ``{'attempt_1': [], 'attempt_2': []}`` dict per test input. Every
        dict (and list) is a distinct object, so entries can be filled in
        independently.
    """
    return {
        task_id: [{'attempt_1': [], 'attempt_2': []} for _ in task['test']]
        for task_id, task in data.items()
    }

In [None]:
# Group every non-empty predicted grid by task and test-input index:
# grids[task_id][idx] -> list of
#   {grid, n_tokens, cumulative_logprob, mean_cumulative_logprob}
grids = dict()
for output in rich_solutions:
    if not output['grid']:
        # Predictions with an empty/falsy grid are ignored.
        continue
    task_id, idx = output['task_id'], output['idx']
    info = dict(
        grid=output['grid'],
        n_tokens=output['n_tokens'],
        cumulative_logprob=output['cumulative_logprob'],
    )
    # Per-token average of the cumulative log-probability.
    info['mean_cumulative_logprob'] = info['cumulative_logprob'] / info['n_tokens']
    grids.setdefault(task_id, dict()).setdefault(idx, list()).append(info)

In [None]:
# Naive selection: rank every individual prediction by `chosen_metric` and
# submit the best grid plus the best *different* grid.
n_outputs = 5  # number of top-ranked predictions plotted per test input
task_metrics = evaluate(ground_truth, solutions, verbose=False)[1]
# Tasks with pass_n == 1.0 over all predictions (presumably: some sampled
# prediction is correct); only these are plotted.
relevant_tasks = {task_id for task_id in task_metrics if task_metrics[task_id]['pass_n'] == 1.0}
# relevant_tasks = set()  # uncomment to disable plotting


logprob_solutions = create_empty_solutions({key: ground_truth[key] for key in grids})
chosen_metric = 'cumulative_logprob'  # alternatives: 'mean_cumulative_logprob'
for task_id, task in grids.items():
    if task_id in relevant_tasks:
        plot_task(ground_truth[task_id])
        plt.suptitle(task_id)
        plt.show()
    for idx, outputs in task.items():
        outputs = sorted(outputs, key=lambda x: x[chosen_metric], reverse=True)
        if task_id in relevant_tasks:
            expected_grid = ground_truth[task_id]['test'][idx]['output']
            for plot_idx, output in enumerate(outputs[:n_outputs]):
                plt.subplot(1, n_outputs, plot_idx + 1)
                plot_grid(output['grid'])
                title = f'{output[chosen_metric]:.2f}'
                if output['grid'] == expected_grid:
                    title = f'Correct\n{title}'
                plt.title(title)
            plt.show()
        best_grid = outputs[0]['grid']
        # Fill the skeleton in place so both attempt keys are always present;
        # when every prediction is the same grid, attempt_2 stays an empty
        # list (the skeleton's format) instead of the key being dropped.
        logprob_solutions[task_id][idx]['attempt_1'] = best_grid
        logprob_solutions[task_id][idx]['attempt_2'] = next(
            (output['grid'] for output in outputs[1:] if output['grid'] != best_grid),
            [],
        )

In [None]:
# Score the logprob-based selection (evaluate(...)[0] is what print_metrics consumes).
print_metrics(
    evaluate(ground_truth, logprob_solutions, verbose=False)[0]
)

### Aggregate the logprobs of identical predictions

In [None]:
# Merge repeated predictions: grids[task_id][idx][str(grid)] holds the grid
# (from its first occurrence) and the scores of every sample that produced it.
grids = dict()
for output in rich_solutions:
    if not output['grid']:
        continue
    per_input = grids.setdefault(output['task_id'], dict()).setdefault(output['idx'], dict())
    # str(grid) serves as a hashable key for the nested-list grid.
    entry = per_input.setdefault(
        str(output['grid']),
        dict(grid=output['grid'], cumulative_logprob=[], mean_cumulative_logprob=[]),
    )
    entry['cumulative_logprob'].append(output['cumulative_logprob'])
    entry['mean_cumulative_logprob'].append(output['cumulative_logprob'] / output['n_tokens'])

In [None]:
# Aggregated selection: rank each distinct grid by the mean of
# `chosen_metric` over all the samples that produced it.
n_outputs = 5  # number of top-ranked distinct grids plotted per test input
task_metrics = evaluate(ground_truth, solutions, verbose=False)[1]
relevant_tasks = {task_id for task_id in task_metrics if task_metrics[task_id]['pass_n'] == 1.0}
relevant_tasks = set()  # plotting disabled; delete this line to plot the tasks above


logprob_solutions = create_empty_solutions({key: ground_truth[key] for key in grids})
chosen_metric = 'cumulative_logprob'  # alternatives: 'mean_cumulative_logprob'
for task_id, task in grids.items():
    if task_id in relevant_tasks:
        plot_task(ground_truth[task_id])
        plt.suptitle(task_id)
        plt.show()
    for idx, outputs in task.items():
        outputs = sorted(outputs.values(), key=lambda x: np.mean(x[chosen_metric]), reverse=True)
        if task_id in relevant_tasks:
            expected_grid = ground_truth[task_id]['test'][idx]['output']
            for plot_idx, output in enumerate(outputs[:n_outputs]):
                scores = output[chosen_metric]
                plt.subplot(1, n_outputs, plot_idx + 1)
                plot_grid(output['grid'])
                # mean +- 1.96 * std / sqrt(n) (np.std with its default ddof=0),
                # followed by the number of samples that produced this grid.
                title = f'{np.mean(scores):.1e} +- {1.96*np.std(scores)/np.sqrt(len(scores)):.0e} ({len(scores)})'
                if output['grid'] == expected_grid:
                    title = f'Correct\n{title}'
                plt.title(title)
            plt.show()
        best_grid = outputs[0]['grid']
        # Fill the skeleton in place so both attempt keys are always present;
        # attempt_2 stays an empty list when there is no different grid.
        logprob_solutions[task_id][idx]['attempt_1'] = best_grid
        logprob_solutions[task_id][idx]['attempt_2'] = next(
            (output['grid'] for output in outputs[1:] if output['grid'] != best_grid),
            [],
        )

In [None]:
# Score the aggregated logprob selection (evaluate(...)[0] is what print_metrics consumes).
print_metrics(
    evaluate(ground_truth, logprob_solutions, verbose=False)[0]
)

```
# cumulative_logprob
accuracy: 4.8%	correct_pixels: 67.0%	max_correct_pixels: 72.8%	correct_size: 81.4%	any_correct_size: 85.2%	pass_n: 9.7%	unanswered: 0.0%	
# mean_cumulative_logprob
accuracy: 4.8%	correct_pixels: 65.6%	max_correct_pixels: 70.8%	correct_size: 79.3%	any_correct_size: 83.2%	pass_n: 9.7%	unanswered: 0.0%	
```

## TODO:

- [ ] Study how the number of predictions affects the comparison: voting scales well with the number of predictions.