# ARC Specific-AI Model Performance!

In [None]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

import warnings
import json
import torch
import matplotlib.pyplot as plt
from matplotlib import colors
# Optional dependency: use rich's pretty ``print`` when it is installed and
# silently keep the builtin otherwise.  Only ImportError is handled so that
# unrelated failures (or Ctrl-C) are not swallowed.
try:
    from rich import print
except ImportError:
    pass


# Location of the results file to visualise.
# Change this to point at your own test_results.json.
file_path = '../../outputs/2024-08-17/18-17-24/test_results.json'

# Ten hex colors; position k in the list is the color for cell value k.
COLORS = [
    '#000000',  # 0: black
    '#0074D9',  # 1: blue
    '#FF4136',  # 2: red
    '#2ECC40',  # 3: green
    '#FFDC00',  # 4: yellow
    '#AAAAAA',  # 5: grey
    '#F012BE',  # 6: magenta
    '#FF851B',  # 7: orange
    '#7FDBFF',  # 8: light blue
    '#870C25',  # 9: dark red
]

In [None]:
def plot_single_image(matrix, ax, title, cmap, norm):
    """Draw one grid on ``ax`` with cell borders and a bold title.

    Only the axes passed in is modified; no global pyplot state (current
    figure / current axes) is consulted.

    Args:
        matrix: 2-D grid of cell values (nested sequence or tensor).  Must
            contain at least one row, since the column count is read from
            ``matrix[0]``.
        ax: Matplotlib axes to draw on.
        title: Text shown above the grid.
        cmap: Colormap forwarded to ``imshow``.
        norm: Normalization forwarded to ``imshow``.
    """
    ax.imshow(matrix, cmap=cmap, norm=norm)
    ax.grid(True, which='both', color='lightgrey', linewidth=0.5)

    # Put the ticks on the half-integers so that the grid lines run along
    # the cell borders rather than through the cell centres.
    n_rows, n_cols = len(matrix), len(matrix[0])
    ax.set_xticks([col - 0.5 for col in range(n_cols + 1)])
    ax.set_yticks([row - 0.5 for row in range(n_rows + 1)])

    # The ticks exist only to position the grid, so hide their labels.
    # This is done on ``ax`` itself instead of on every axes of the current
    # figure, so it also works when ``ax`` belongs to a non-current figure.
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    ax.set_title(title, fontweight='bold')


def _as_label_grid(image):
    """Convert one image argument of ``plot_xytc`` into a 2-D grid tensor."""
    is_tensor = isinstance(image, torch.Tensor)
    # Tensors are detached, moved to CPU and stripped of a leading size-1
    # dimension; anything else is handed to ``torch.tensor`` unchanged.
    grid = image.detach().cpu().squeeze(0) if is_tensor else torch.tensor(image)

    if grid.dim() > 2:
        # More than two dimensions: reduce dim 0 to the index of its maximum.
        # Float tensors are reduced *before* any integer cast, because
        # casting scores such as 0.2 / 0.7 to long first would truncate them
        # all to 0 and make argmax meaningless.
        if is_tensor and not grid.is_floating_point():
            grid = grid.long()
        return torch.argmax(grid, dim=0)

    return grid.long() if is_tensor else grid


def plot_xytc(*images, task_id, titles=('Input', 'Predicted', 'Answer', 'Correct')):
    """Plot the grids of one task side by side, using the ARC color scheme.

    Args:
        *images: Grids to draw, left to right in the order given.  Each is
            either a ``torch.Tensor`` or something ``torch.tensor`` accepts
            (e.g. nested lists).  ``None`` entries are skipped.
        task_id: Identifier shown in the figure title.
        titles: Panel titles, matched to ``images`` by position, so a skipped
            ``None`` image also skips its title.

    Raises:
        ValueError: If every image is ``None`` (or none were given).
        IndexError: If there are fewer titles than images.
    """
    # Pair each image with its title *before* dropping None entries so that
    # the panel count, the images and the titles all stay aligned.
    panels = [
        (image, titles[idx])
        for idx, image in enumerate(images)
        if image is not None
    ]
    if not panels:
        raise ValueError('plot_xytc needs at least one image to plot.')

    num_img = len(panels)
    fig, axs = plt.subplots(1, num_img, figsize=(num_img * 3, 3))
    plt.suptitle(f'Task {task_id}', fontsize=20, fontweight='bold', y=0.96)

    # With a single column, plt.subplots returns a bare Axes, not an array.
    if num_img == 1:
        axs = [axs]

    cmap = colors.ListedColormap(COLORS)
    norm = colors.Normalize(vmin=0, vmax=9)

    for ax, (image, title) in zip(axs, panels):
        plot_single_image(_as_label_grid(image), ax, title, cmap, norm)

    # Grey background with a black frame around the whole figure.
    fig.patch.set_linewidth(5)
    fig.patch.set_edgecolor('black')
    fig.patch.set_facecolor('#dddddd')

    fig.tight_layout()
    plt.show()


def plot_xytc_from_json(
    file_path='./output/test_results.json',
    titles=('Input', 'Output', 'Answer', 'Correct'),
    keys_json=('input', 'output', 'target', 'correct_pixels'),
    plot_only_correct=False,
    top_k=2,
    total=400,
    change_order_1_2=True,
    verbose=False,
):
    """Load a results JSON file, print summary statistics and plot each trial.

    The file is expected to hold ``{'hparams': ..., 'results': {...}}`` where
    ``results`` maps a task id to a list (one entry per test input) of lists
    of trial dicts.

    Args:
        file_path: Path of the JSON file to read.
        titles: Panel titles, positionally matched to ``keys_json``.
        keys_json: Keys looked up in every trial dict; keys missing from a
            trial are skipped.
        plot_only_correct: Plot only tasks counted as correct.  Requires the
            results to contain a ``'target'`` entry.
        top_k: Number of leading trials per test input that are considered,
            both for correctness and for plotting.
        total: Denominator used for the printed accuracy.
        change_order_1_2: When labels exist, swap positions 1 and 2 of both
            ``titles`` and ``keys_json`` before plotting.
        verbose: Also print ``hparams_ids`` of every plotted trial.

    Raises:
        AssertionError: If ``plot_only_correct`` is set but the results have
            no ``'target'`` entry.
    """
    with open(file_path, 'r', encoding='utf-8') as fh:
        payload = json.load(fh)
    print('Hyperparameters:', payload['hparams'])
    results = payload['results']

    # Work on private copies: the swap below must neither leak into the
    # caller's lists nor accumulate in the defaults across calls (which
    # would make every second call plot in the opposite order).
    titles = list(titles)
    keys_json = list(keys_json)

    # The trials of the first test input of the first task are taken as
    # representative of the whole file.
    first_trials = list(results.values())[0][0]
    exist_label = first_trials[0].get('target') is not None
    assert not plot_only_correct or exist_label, 'The results do not contain answer.'
    if len(first_trials) < top_k:
        warnings.warn(f'Top-k is set to {top_k} but the number of trials is {len(first_trials)}, less than {top_k}.')

    task_ids_correct = set()
    if exist_label:
        if change_order_1_2:
            titles[1], titles[2] = titles[2], titles[1]
            keys_json[1], keys_json[2] = keys_json[2], keys_json[1]

        def _all_pixels_correct(trial):
            # A trial counts as correct when every 'correct_pixels' cell is 3.
            # NOTE(review): 3 is presumably the "correct" marker written by
            # the evaluation code - confirm against the producer of the file.
            return all(
                all(pixel == 3 for pixel in row)
                for row in trial['correct_pixels']
            )

        # A task is correct when, for every test input, at least one of its
        # first top_k trials is fully correct.
        task_ids_correct = {
            task_id
            for task_id, task_result in results.items()
            if all(
                any(_all_pixels_correct(trial) for trial in trials[:top_k])
                for trials in task_result
            )
        }
        print('N Submittable: {} | N Total: {} | N Correct: {} | Accuracy: {:.2f}%'.format(
            len(results), total, len(task_ids_correct), len(task_ids_correct) / total * 100))
    else:
        print('N Submittable: {} | N Total: {}'.format(len(results), total))

    for task_id, task_result in results.items():
        if plot_only_correct and task_id not in task_ids_correct:
            continue

        for test_idx, trials in enumerate(task_result):
            for trial in trials[:top_k]:
                if verbose:
                    hparams_ids = trial['hparams_ids']
                    print(f'Task {task_id} | Test {test_idx+1} | hparams_ids: {hparams_ids}')

                images = [trial[key] for key in keys_json if key in trial]
                plot_xytc(*images, task_id=task_id, titles=titles)


## Only Correctly Solved Tasks

In [None]:
# Plot only the tasks counted as correct, showing up to two trials per test
# input and printing the hparams ids of every plotted trial.
plot_xytc_from_json(
    file_path,
    plot_only_correct=True,
    top_k=2,
    total=400,
    verbose=True,
)

## All Tasks

In [None]:
# Plot every task in the results file, showing only the first trial of each
# test input.
plot_xytc_from_json(
    file_path,
    plot_only_correct=False,
    top_k=1,
    total=400,
)