# Grid selection with logprobs

## Goal

Can I use the log-probabilities returned by vLLM to select the correct grid answer among the predictions?

## Configuration

```
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --output_filepath=submission_x8_logprob.json
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=32 --output_filepath=submission_x32_logprob.json
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=128 --output_filepath=submission_x128_logprob.json
```

In [None]:
class cfg:
    """Input paths for this notebook (edit here to analyse a different inference run)."""
    # Submission written by inference.py (see the commands above). Its sibling file
    # with the `_task_results` suffix, which holds per-prediction logprobs, is read
    # later in the notebook. Uncomment one of the alternatives to analyse the runs
    # with more predictions per task.
    solutions_filepath: str = '/mnt/hdd0/Kaggle/arc24/evaluations/first_evaluations/submission_x8_logprob.json'
    # solutions_filepath: str = '/mnt/hdd0/Kaggle/arc24/evaluations/first_evaluations/submission_x32_logprob.json'
    # solutions_filepath: str = '/mnt/hdd0/Kaggle/arc24/evaluations/first_evaluations/submission_x128_logprob.json'
    # Challenges file passed to load_arc_data_with_solutions to build the ground truth.
    dataset_filepath: str = '/mnt/hdd0/Kaggle/arc24/data/arc-agi_evaluation_challenges.json'

## Imports

In [None]:
import sys
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

sys.path.append(os.path.realpath('../scripts/'))

from evaluation import (
    load_arc_data_with_solutions,
    evaluate,
    plot_grid,
    plot_task,
    print_metrics
)
from voting import select_most_voted_solutions

# Draw and discard an empty figure first. This presumably initialises the notebook's
# plotting backend before the defaults below are set, so it cannot reset them later.
plt.plot()
plt.close('all')
# Wide, short figures with thick lines and large fonts for every plot in this notebook.
# (plt.rcParams and mpl.rcParams are the same object.)
mpl.rcParams.update({
    'figure.figsize': (25, 4),
    'lines.linewidth': 3,
    'font.size': 16,
})

## Baseline results

What is the pass_n metric of all the predictions, and the accuracy of voting?

In [None]:
# Raw predictions of the inference run and the ground truth of the evaluation set.
with open(cfg.solutions_filepath, 'r') as solutions_file:
    solutions = json.load(solutions_file)
ground_truth = load_arc_data_with_solutions(cfg.dataset_filepath)

# Upper bound: metrics when every prediction counts as a candidate.
print_metrics(evaluate(ground_truth, solutions, verbose=False)[0])

# Baseline selection: keep only the two most voted candidates per test input.
voting_solutions = select_most_voted_solutions(solutions, 2)
print_metrics(evaluate(ground_truth, voting_solutions, verbose=False)[0])

Using all predictions gives pass_n=23.5%; keeping only the two most voted candidates gives 12.5%.

## Exploring the use of logprobs to select the correct answer

### Naive approach

In [None]:
# Per-prediction details (grid, n_tokens, cumulative_logprob, ...) are stored next to
# the submission with a `_task_results` suffix. Only the trailing extension is replaced:
# str.replace('.json', ...) would also rewrite any '.json' appearing earlier in the path.
with open(os.path.splitext(cfg.solutions_filepath)[0] + '_task_results.json', 'r') as f:
    rich_solutions = json.load(f)

In [None]:
def create_empty_solutions(data):
    """Return a solutions skeleton with empty attempts for every test input.

    Args:
        data: mapping task_id -> task, where each task has a 'test' list.

    Returns:
        dict mapping each task_id to a list holding one independent
        ``{'attempt_1': [], 'attempt_2': []}`` dict per test input.
    """
    return {
        task_id: [{'attempt_1': [], 'attempt_2': []} for _ in task['test']]
        for task_id, task in data.items()
    }

In [None]:
# Group every prediction by task and test-input index, keeping its grid and the
# log-probability statistics returned with it. Predictions whose grid is empty/falsy
# are skipped.
grids = dict()
for output in rich_solutions:
    if not output['grid']:
        continue
    info = {key: output[key] for key in ['grid', 'n_tokens', 'cumulative_logprob']}
    # Per-token average, so predictions of different lengths can be compared.
    info['mean_cumulative_logprob'] = info['cumulative_logprob'] / info['n_tokens']
    grids.setdefault(output['task_id'], dict()).setdefault(output['idx'], list()).append(info)

In [None]:
n_outputs = 5
task_metrics = evaluate(ground_truth, solutions, verbose=False)[1]
# Tasks whose predictions contain a correct grid for every test input; only these
# tell us where the correct grid lands in the ranking.
relevant_tasks = {task_id for task_id in task_metrics if task_metrics[task_id]['pass_n'] == 1.0}
show_plots = False

correct_positions = []
# Score the same tasks as the baseline evaluation (plus any extra task found in the
# rich results). Building the skeleton only from `grids` would drop tasks without any
# non-empty grid from the evaluated set; here they keep empty attempts instead.
evaluated_task_ids = dict.fromkeys([*solutions, *grids])
logprob_solutions = create_empty_solutions({task_id: ground_truth[task_id] for task_id in evaluated_task_ids})
chosen_metric = 'mean_cumulative_logprob' # 'mean_cumulative_logprob', cumulative_logprob
for task_id, task in grids.items():
    if task_id in relevant_tasks and show_plots:
        plot_task(ground_truth[task_id]); plt.suptitle(task_id); plt.show()
    for idx, outputs in task.items():
        # Highest (least negative) log-probability first.
        outputs = sorted(outputs, key=lambda x: x[chosen_metric], reverse=True)
        if task_id in relevant_tasks:
            correct_grid = ground_truth[task_id]['test'][idx]['output']
            if show_plots:
                for plot_idx, output in enumerate(outputs[:n_outputs]):
                    plt.subplot(1, n_outputs, plot_idx + 1)
                    plot_grid(output['grid'])
                    title = f'{output[chosen_metric]:.2f}'
                    if output['grid'] == correct_grid:
                        title = f'Correct\n{title}'
                    plt.title(title)
                plt.show()
            # Rank of the first correct prediction (0 = chosen as attempt_1).
            for position, output in enumerate(outputs):
                if output['grid'] == correct_grid:
                    correct_positions.append(position)
                    break
        best_grid = outputs[0]['grid']
        # The next distinct grid becomes attempt_2. When every prediction is the same
        # grid, keep the empty attempt_2 placeholder used by create_empty_solutions.
        second_grid = next((output['grid'] for output in outputs[1:] if output['grid'] != best_grid), [])
        logprob_solutions[task_id][idx] = dict(attempt_1=best_grid, attempt_2=second_grid)
print(f'naive {chosen_metric} mean correct position: {np.mean(np.clip(correct_positions, 0, 10)):.1f} ({correct_positions})')
print_metrics(evaluate(ground_truth, logprob_solutions, verbose=False)[0])

```
mean_cumulative_logprob mean correct position: 4.3 ([82, 4, 1, 0, 0, 40, 3, 5, 101, 25, 0, 29, 15, 0, 14, 0, 0, 0, 3, 0, 0, 80, 1, 6, 42, 0])
accuracy: 5.4%	correct_pixels: 68.4%	max_correct_pixels: 73.0%	correct_size: 82.9%	any_correct_size: 86.2%	pass_n: 10.7%	unanswered: 0.0%

cumulative_logprob mean correct position: 4.4 ([83, 5, 1, 0, 0, 40, 3, 5, 101, 27, 0, 86, 15, 0, 14, 0, 0, 0, 3, 0, 0, 79, 1, 6, 42, 0])
accuracy: 5.4%	correct_pixels: 68.9%	max_correct_pixels: 73.8%	correct_size: 83.9%	any_correct_size: 87.2%	pass_n: 10.7%	unanswered: 0.0%
```

### Aggregate logprobs for the same prediction

In [None]:
# Group predictions by task and test-input index, merging those that produced the same
# grid. Each unique grid (keyed by its string form) collects the logprobs of every sample
# that generated it, so the length of each list is that grid's number of votes.
# Predictions whose grid is empty/falsy are skipped.
grids = dict()
for output in rich_solutions:
    if not output['grid']:
        continue
    unique_grids = grids.setdefault(output['task_id'], dict()).setdefault(output['idx'], dict())
    grid_stats = unique_grids.setdefault(
        str(output['grid']),
        dict(grid=output['grid'], cumulative_logprob=[], mean_cumulative_logprob=[]))
    grid_stats['cumulative_logprob'].append(output['cumulative_logprob'])
    grid_stats['mean_cumulative_logprob'].append(output['cumulative_logprob'] / output['n_tokens'])

In [None]:
n_outputs = 5
task_metrics = evaluate(ground_truth, solutions, verbose=False)[1]
# Tasks whose predictions contain a correct grid for every test input; only these
# tell us where the correct grid lands in each ranking.
relevant_tasks = {task_id for task_id in task_metrics if task_metrics[task_id]['pass_n'] == 1.0}
show_plots = False
# One of: 'voting', 'mean_logprob', 'lower_bound', 'lower_bound_constant_std',
# 'voting_mean_logprob'.
implementation = 'voting_mean_logprob'


def get_title(output):
    """Return a subplot title: mean metric ± 1.96 * shared std / sqrt(votes), then the vote count.

    Reads the notebook globals ``chosen_metric`` and ``default_std``.
    """
    values = output[chosen_metric]
    return f'{np.mean(values):.1e} ± {1.96*default_std/np.sqrt(len(values)):.0e} ({len(values)})'


def _estimate_default_std(unique_outputs, metric):
    """Return the mean std of ``metric`` over grids predicted more than once (0.0 if none repeat)."""
    stds = [np.std(output[metric]) for output in unique_outputs if len(output[metric]) > 1]
    # np.mean([]) is nan (plus a RuntimeWarning). A nan std would turn the lower-bound
    # rankings into nan and make the subsequent sort meaningless.
    return np.mean(stds) if stds else 0.0


def _compute_ranking(values, method, shared_std):
    """Return the sort key (higher is better) for a unique grid, given its list of metric values."""
    n_votes = len(values)
    if method == 'voting':
        return n_votes
    if method == 'mean_logprob':
        return np.mean(values)
    if method == 'lower_bound':
        # Lower end of an approximate 95% interval. A single sample has no spread of
        # its own, so it uses the shared std.
        if n_votes == 1:
            return values[0] - 1.96*shared_std
        return np.mean(values) - 1.96*np.std(values)/np.sqrt(n_votes)
    if method == 'lower_bound_constant_std':
        return np.mean(values) - 1.96*shared_std/np.sqrt(n_votes)
    if method == 'voting_mean_logprob':
        # Tuples compare votes first; the mean metric only breaks ties.
        return (n_votes, np.mean(values))
    # A 'voting_lower_bound_constant_std' variant was dropped. Ties only happen between
    # grids with the same number of votes, so it ranks exactly like 'voting_mean_logprob'.
    raise ValueError(f'Unknown implementation: {method}')


def _plot_ranked_outputs(ranked_outputs, correct_grid):
    """Plot the top ``n_outputs`` grids; if the correct grid is not among them, draw it in the last slot."""
    plotted_correct_grid = False
    for plot_idx, output in enumerate(ranked_outputs[:n_outputs]):
        plt.subplot(1, n_outputs, plot_idx + 1)
        plot_grid(output['grid'])
        title = get_title(output)
        if output['grid'] == correct_grid:
            title = f'Correct\n{title}'
            plotted_correct_grid = True
        plt.title(title)
    if not plotted_correct_grid:
        for output in ranked_outputs:
            if output['grid'] == correct_grid:
                plt.subplot(1, n_outputs, n_outputs)
                plot_grid(output['grid'])
                plt.title(f'Correct\n{get_title(output)}')
                break
    plt.show()


def _select_attempts(ranked_outputs):
    """Return the submission entry: the best grid and the next distinct one ([] when there is none)."""
    best_grid = ranked_outputs[0]['grid']
    second_grid = next((output['grid'] for output in ranked_outputs[1:] if output['grid'] != best_grid), [])
    return dict(attempt_1=best_grid, attempt_2=second_grid)


correct_positions = []
# Score the same tasks as the baseline evaluation (plus any extra task found in the
# rich results). Building the skeleton only from `grids` would drop tasks without any
# non-empty grid from the evaluated set; here they keep empty attempts instead.
evaluated_task_ids = dict.fromkeys([*solutions, *grids])
logprob_solutions = create_empty_solutions({task_id: ground_truth[task_id] for task_id in evaluated_task_ids})
chosen_metric = 'cumulative_logprob' # 'mean_cumulative_logprob', cumulative_logprob
for task_id, task in grids.items():
    if task_id in relevant_tasks and show_plots:
        plot_task(ground_truth[task_id]); plt.suptitle(task_id); plt.show()
    for idx, outputs in task.items():
        # Kept as a global because get_title reads it while plotting.
        default_std = _estimate_default_std(outputs.values(), chosen_metric)
        for output in outputs.values():
            output['ranking'] = _compute_ranking(output[chosen_metric], implementation, default_std)
        outputs = sorted(outputs.values(), key=lambda x: x['ranking'], reverse=True)
        if task_id in relevant_tasks:
            correct_grid = ground_truth[task_id]['test'][idx]['output']
            if show_plots:
                _plot_ranked_outputs(outputs, correct_grid)
            # Rank of the correct grid (0 = chosen as attempt_1), if it was predicted.
            position = next(
                (rank for rank, output in enumerate(outputs) if output['grid'] == correct_grid), None)
            if position is not None:
                correct_positions.append(position)
        logprob_solutions[task_id][idx] = _select_attempts(outputs)
print(f'{implementation} {chosen_metric} mean correct position: {np.mean(np.clip(correct_positions, 0, 10)):.1f} ({correct_positions})')
print_metrics(evaluate(ground_truth, logprob_solutions, verbose=False)[0])

```
voting_mean_logprob mean_cumulative_logprob mean correct position: 2.2 ([1, 5, 1, 0, 0, 3, 1, 0, 3, 27, 1, 3, 11, 0, 1, 0, 0, 0, 1, 0, 2, 3, 6, 3, 4, 0])
voting_lower_bound_constant_std mean_cumulative_logprob mean correct position: 2.2 ([1, 5, 1, 0, 0, 3, 1, 0, 3, 27, 1, 3, 11, 0, 1, 0, 0, 0, 1, 0, 2, 3, 6, 3, 4, 0])

voting_mean_logprob cumulative_logprob mean correct position: 2.3 ([1, 5, 1, 0, 0, 3, 1, 0, 3, 29, 1, 4, 11, 0, 1, 0, 0, 0, 1, 0, 2, 3, 6, 3, 4, 0])
voting_lower_bound_constant_std cumulative_logprob mean correct position: 2.3 ([1, 5, 1, 0, 0, 3, 1, 0, 3, 29, 1, 4, 11, 0, 1, 0, 0, 0, 1, 0, 2, 3, 6, 3, 4, 0])

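There is no difference between these two variants. Ties only happen between grids with the same number of votes, so both end up ranking those grids by the mean metric.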


voting mean_cumulative_logprob mean correct position: 2.4 ([1, 7, 1, 0, 0, 3, 1, 0, 4, 35, 1, 4, 14, 0, 1, 0, 0, 0, 1, 0, 2, 3, 6, 3, 4, 0])
voting cumulative_logprob mean correct position: 2.4 ([1, 7, 1, 0, 0, 3, 1, 0, 4, 35, 1, 4, 14, 0, 1, 0, 0, 0, 1, 0, 2, 3, 6, 3, 4, 0])

lower_bound_constant_std cumulative_logprob mean correct position: 2.8      ([17, 5, 0, 0, 0, 2, 2, 1, 3, 24, 0, 16, 5, 0, 9, 0, 0, 0, 0, 0, 0, 7, 2, 3, 5, 0])
lower_bound_constant_std mean_cumulative_logprob mean correct position: 2.8 ([14, 12, 0, 0, 0, 2, 2, 1, 3, 22, 0, 5, 5, 0, 9, 0, 0, 0, 0, 0, 0, 7, 2, 2, 5, 0])


lower_bound mean_cumulative_logprob mean correct position: 3.0 ([15, 18, 0, 0, 0, 2, 2, 0, 3, 21, 0, 9, 3, 0, 9, 0, 0, 0, 0, 0, 1, 7, 5, 3, 5, 0])
lower_bound cumulative_logprob mean correct position: 3.0 ([17, 11, 0, 0, 0, 2, 1, 0, 3, 24, 0, 17, 3, 0, 9, 0, 0, 0, 0, 0, 1, 7, 5, 3, 5, 0])


mean_logprob cumulative_logprob mean correct position: 3.3 ([19, 6, 1, 0, 0, 1, 2, 5, 3, 22, 0, 17, 3, 0, 15, 2, 0, 0, 0, 0, 2, 11, 3, 3, 5, 0])
mean_logprob mean_cumulative_logprob mean correct position: 3.3 ([20, 15, 1, 0, 0, 1, 2, 5, 3, 20, 0, 7, 3, 0, 15, 2, 0, 0, 0, 0, 2, 11, 3, 3, 5, 0])

naive mean_cumulative_logprob mean correct position: 4.3 ([82, 4, 1, 0, 0, 40, 3, 5, 101, 25, 0, 29, 15, 0, 14, 0, 0, 0, 3, 0, 0, 80, 1, 6, 42, 0])
naive cumulative_logprob mean correct position: 4.4 ([83, 5, 1, 0, 0, 40, 3, 5, 101, 27, 0, 86, 15, 0, 14, 0, 0, 0, 3, 0, 0, 79, 1, 6, 42, 0])
```

```
# Compute lower bound with shared standard deviation
#cumulative_logprob
accuracy: 5.4%	correct_pixels: 67.6%	max_correct_pixels: 73.4%	correct_size: 81.9%	any_correct_size: 86.2%	pass_n: 10.7%	unanswered: 0.0%	
# mean_cumulative_logprob
accuracy: 5.4%	correct_pixels: 66.9%	max_correct_pixels: 73.1%	correct_size: 80.9%	any_correct_size: 86.2%	pass_n: 10.7%	unanswered: 0.0%	

# Compute lower bound
#cumulative_logprob
accuracy: 6.1%	correct_pixels: 68.5%	max_correct_pixels: 73.8%	correct_size: 82.4%	any_correct_size: 86.2%	pass_n: 12.2%	unanswered: 0.0%	
# mean_cumulative_logprob
accuracy: 5.6%	correct_pixels: 68.4%	max_correct_pixels: 73.0%	correct_size: 82.4%	any_correct_size: 86.2%	pass_n: 11.2%	unanswered: 0.0%	
# First implementation
# cumulative_logprob
accuracy: 4.8%	correct_pixels: 67.0%	max_correct_pixels: 72.8%	correct_size: 81.4%	any_correct_size: 85.2%	pass_n: 9.7%	unanswered: 0.0%	
# mean_cumulative_logprob
accuracy: 4.8%	correct_pixels: 65.6%	max_correct_pixels: 70.8%	correct_size: 79.3%	any_correct_size: 83.2%	pass_n: 9.7%	unanswered: 0.0%	
```

## TODO:

- [ ] How does the number of predictions affect the comparison? Voting scales well with the number of predictions.
- [ ] Should I use other metrics, such as the mean position of the correct answer?