# Grid selection with logprobs

## Goal

Can I use vLLM logprobs to select the correct grid answer?

## Configuration

In [None]:
class cfg:
    """Notebook configuration: locations of the input files."""

    # ARC challenges file handed to load_arc_data_with_solutions.
    dataset_filepath: str = '/mnt/hdd0/Kaggle/arc24/data/arc-agi_evaluation_challenges.json'
    # JSON file holding the predictions analysed in this notebook; the
    # per-prediction details are read from its `_task_results.json` sibling.
    solutions_filepath: str = '/mnt/hdd0/MEGA/AI/22_Kaggle/arc24/scripts/submission_x8_logprob.json'

## Imports

In [None]:
import sys
import os
import json
import time
import textwrap
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import matplotlib as mpl

sys.path.append(os.path.realpath('../scripts/'))

from evaluation import (
    load_arc_data_with_solutions,
    evaluate,
    plot_grid,
    plot_task
)
from inference import (
    _create_empty_solutions
)
from arc24.prompting import (
    pretty_print_prompt,
    system_prompt,
    prompt_template,
    answer_template,
    remove_assistant_ending
)
from voting import select_most_voted_solutions

# Draw and immediately discard an empty figure before touching rcParams.
# NOTE(review): presumably this forces the notebook backend to initialise so
# the rcParams overrides below are not reset afterwards — confirm.
plt.plot()
plt.close('all')
# Notebook-wide plotting defaults: wide, short figures, thicker lines and a
# larger font.
plt.rcParams["figure.figsize"] = (25, 4)
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 16

## Baseline results

What is the pass_n metric of all the predictions, and the accuracy of voting?

In [None]:
# Reference data first, then the predictions that will be scored against it.
ground_truth = load_arc_data_with_solutions(cfg.dataset_filepath)
with open(cfg.solutions_filepath, 'r') as f:
    solutions = json.load(f)
# Baseline: score every available prediction. The first element of the
# evaluate() result is shown as the cell output.
evaluate(ground_truth, solutions, verbose=False)[0]

In [None]:
# Voting baseline: reduce the predictions with select_most_voted_solutions
# (second argument 2 — presumably the number of candidates kept per test
# sample; confirm in the voting module) and score the result the same way as
# the full set of predictions.
voting_solutions = select_most_voted_solutions(solutions, 2)
evaluate(ground_truth, voting_solutions, verbose=False)[0]

If we use all the predictions we get pass_n=23.5%; if we select two candidates by voting we get 12.5%.

## Exploring the use of logprobs to select the correct answer

In [None]:
# The per-prediction details (grid, token count, logprob) are read from the
# file obtained by swapping the `.json` suffix of the solutions path.
rich_solutions_filepath = cfg.solutions_filepath.replace('.json', '_task_results.json')
with open(rich_solutions_filepath, 'r') as f:
    rich_solutions = json.load(f)

In [None]:
# Group the predictions as grids[task_id][idx] -> list of candidate infos,
# restricted to tasks whose pass_n is exactly 1.0 when scoring every
# prediction (presumably tasks where a correct grid exists among the
# candidates — confirm against evaluate()).
task_metrics = evaluate(ground_truth, solutions, verbose=False)[1]
relevant_tasks = {
    task_id for task_id, metrics in task_metrics.items()
    if metrics['pass_n'] == 1.0
}

keys = ['grid', 'n_tokens', 'cumulative_logprob']
grids = dict()
for output in rich_solutions:
    # Skip predictions without a grid.
    if not output['grid']:
        continue
    task_id = output['task_id']
    if task_id not in relevant_tasks:
        continue
    idx = output['idx']
    info = {k: output[k] for k in keys}
    # Length-normalised logprob, so candidates of different token counts can
    # be compared.
    info['mean_cumulative_logprob'] = info['cumulative_logprob'] / info['n_tokens']
    grids.setdefault(task_id, dict()).setdefault(idx, list()).append(info)
len(relevant_tasks), len(grids)

In [None]:
def create_empty_solutions(data):
    """Build a blank solutions structure for the given tasks.

    Args:
        data: mapping of task id to a task dict that has a 'test' sequence.

    Returns:
        A dict mapping each task id to a list with one
        ``{'attempt_1': [], 'attempt_2': []}`` entry per element of the
        task's 'test' sequence. Every entry, and every empty list inside it,
        is a distinct object.
    """
    return {
        task_id: [{'attempt_1': [], 'attempt_2': []} for _ in task['test']]
        for task_id, task in data.items()
    }

In [None]:
n_outputs = 5  # maximum number of candidate grids plotted per test sample

# NOTE: this rebinds `solutions`, which the earlier cells loaded from
# cfg.solutions_filepath; re-running those cells' evaluate() calls after this
# one will score the logprob-selected grids instead of the full predictions.
solutions = create_empty_solutions({key: ground_truth[key] for key in grids.keys()})
chosen_metric = 'cumulative_logprob'
for task_id, task in grids.items():
    plot_task(ground_truth[task_id]); plt.suptitle(task_id); plt.show()
    for idx, outputs in task.items():
        # Highest value of the chosen metric first.
        outputs = sorted(outputs, key=lambda x: x[chosen_metric], reverse=True)

        # enumerate() gives every candidate its own subplot slot. Looking the
        # position up with list.index() returns the first *equal* entry, so
        # duplicated candidates would be drawn on top of each other.
        for plot_position, output in enumerate(outputs[:n_outputs], start=1):
            plt.subplot(1, n_outputs, plot_position)
            plot_grid(output['grid'])
            plt.title(f'{output[chosen_metric]:.2f}')
        plt.tight_layout()
        plt.show()

        print([output[chosen_metric] for output in outputs[:3]])
        best_output = outputs[0]
        # attempt_2 is the best-ranked grid that differs from attempt_1. It
        # falls back to an empty grid so both keys are always present, as in
        # the structure built by create_empty_solutions.
        second_grid = next(
            (output['grid'] for output in outputs[1:] if output['grid'] != best_output['grid']),
            [],
        )
        solutions[task_id][idx] = dict(attempt_1=best_output['grid'], attempt_2=second_grid)

In [None]:
# Score the logprob-selected attempts against the ground truth; the first
# element of the evaluate() result is shown as the cell output, to compare
# with the baselines above.
evaluate(ground_truth, solutions, verbose=False)[0]

I need to visualize the metrics