In [None]:
import os
import sys
from plot_funcs import plot_task
import json
from typing import Dict, Any

# Make the project root (the parent of the notebook's working directory) importable
# so that `src.load_data` resolves. The membership check keeps this cell idempotent:
# without it, every re-run appends the same entry to sys.path again.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

from src.load_data import load_from_json

# ARC split to browse; alternative: 'arc-agi_test'.
# NOTE(review): despite the `training_*` names, these hold whichever split
# `data_source` selects. `training_solutions` may be falsy for a split without
# solutions (the driver loop below checks for that) — confirm against load_from_json.
data_source = 'arc-agi_evaluation'  # 'arc-agi_test'
training_challenges, training_solutions = load_from_json(data_source, '../input_data/')

def load_submission_jsons(base_path: str) -> Dict[str, Any]:
    """
    Load every ``*.json`` file under ``base_path`` (recursively) and merge them into one dict.

    Directories and file names are visited in sorted order, so the result is
    deterministic: when two files define the same top-level key, the file
    visited later wins.

    Loading is best-effort. A file that cannot be read or parsed, or whose
    top-level JSON value is not an object, is reported on stdout and skipped.
    A missing ``base_path`` is also reported and yields an empty dict.

    Args:
        base_path (str): Base directory path containing JSON files

    Returns:
        Dict[str, Any]: Merged dictionary containing all JSON contents
    """
    merged_dict: Dict[str, Any] = {}

    if not os.path.isdir(base_path):
        print(f"Warning: submission directory not found: {base_path}")
        return merged_dict

    for root, dirs, files in os.walk(base_path):
        # os.walk yields entries in arbitrary filesystem order. Sorting `dirs` in
        # place fixes the order in which subdirectories are descended into, so
        # key-collision resolution is reproducible.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith('.json'):
                continue
            file_path = os.path.join(root, file)
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    json_content = json.load(f)
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON from {file_path}: {e}")
                continue
            except Exception as e:
                # Best-effort: unreadable files are reported, not fatal.
                print(f"Error processing file {file_path}: {e}")
                continue

            # dict.update would silently accept a list of [key, value] pairs;
            # only JSON objects are valid submissions to merge.
            if not isinstance(json_content, dict):
                print(f"Skipping {file_path}: top-level JSON value is "
                      f"{type(json_content).__name__}, expected an object")
                continue
            merged_dict.update(json_content)

    return merged_dict

def load_submission_json(case: str, base_path: str) -> Dict[str, Any]:
    """
    Load a single submission file ``<base_path>/<case>.json``.

    Args:
        case (str): File name without the ``.json`` extension (for example a
            task id, or ``'submission'``).
        base_path (str): Directory containing the file.

    Returns:
        Dict[str, Any]: The parsed JSON content. Callers look entries up by task
        id, so the file is expected to hold a JSON object.

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    file_path = os.path.join(base_path, case + '.json')
    with open(file_path, encoding='utf-8') as f:
        return json.load(f)

# ice_submission_2_rte = load_submission_json('ice_submission_2_rte', '../../3-arc24/backup/ice')
# v5_ice_submission = load_submission_json('v5_ice_submission', '../../3-arc24/backup/ice/kaggle')
# soma_submission = load_submission_json('submission', '../../3-arc24/backup/evaluation/soma')
# kaggle_soma_submission = load_submission_json('kaggle_submission', '../../3-arc24/backup/soma/')
# local_soma_submission = load_submission_json('submission', '../../3-arc24/working/soma/')
# print('ice_submission_2_rte', len(ice_submission_2_rte))
# print('v5_ice_submission', len(v5_ice_submission))
# print('soma_submission', len(soma_submission))
# print('kaggle_soma_submission', len(kaggle_soma_submission))
# print('local_soma_submission', len(local_soma_submission))
# nn_submission = load_submission_json('submission', '/home/nikola/Code/GenII/3-arc24/working/transformer/')
# nn_submission = load_submission_jsons('/home/nikola/Code/GenII/3-arc24/working/transformer/submission')

# Directory holding the NN model's per-task submission files. Set the
# ARC_SUBMISSION_DIR environment variable to point elsewhere instead of editing
# this machine-specific default.
nn_submission_dir = os.environ.get(
    'ARC_SUBMISSION_DIR',
    '/home/nikola/Code/GenII/3-arc24/working/transformer/submission',
)
# Loads only task 4852f2fa's file; the load_submission_jsons(...) line above merges
# every file in the directory instead.
nn_submission = load_submission_json('4852f2fa', nn_submission_dir)



'''
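Submission format (keys as read by show_submision):

'<task_id>': list, one element per test input
- [0]:
--  attempt_1: list[list[int]]   (predicted grid)
--  attempt_2: list[list[int]]
- [1]:
--  attempt_1: list[list[int]]
--  attempt_2: list[list[int]]

The submission is assembled from sub-submissions, which store a variable number of attempts per test input:

'<task_id>': list, one element per test input
- [0]:
--  attempts: list
---    list[list[int]]   (one predicted grid per attempt)
- [1]:
--  attempts: list
---    list[list[int]]


plot_task input format (the attempt / solution argument):

task_solutions: list, one grid per test input
[0]: list[list[int]]
[1]: list[list[int]]

When a single sub-submission attempt is plotted, only the entry for its own test index holds that attempt; the entries for the other test inputs are placeholder grids.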
''';


In [None]:
# from pprint import pprint
from pprint import pformat

def assert_list_of_lists_of_ints(var: list, expected_length: int) -> None:
    """
    Check that ``var`` is a list of exactly ``expected_length`` grids, where each
    grid is a list of rows and each row is a list of ints (list[list[list[int]]]).

    Raises:
        AssertionError: on the first violated condition, with a message naming it.
    """
    def is_int_row(row) -> bool:
        # A row must be a list whose every cell is an int
        # (bool is a subclass of int, so True/False are accepted).
        return isinstance(row, list) and all(isinstance(cell, int) for cell in row)

    assert isinstance(var, list), "Outer structure must be a list"
    assert len(var) == expected_length, f"List must have exactly {expected_length} elements {var}"
    assert all(isinstance(grid, list) for grid in var), "All elements must be lists"
    assert all(is_int_row(row) for grid in var for row in grid), "All nested elements must be lists of integers"

def plot_attempt(attempt, has_correct_answers, title, task, test_index=None):
    """
    Validate, score and plot one attempt for ``task``.

    NOTE: reads the notebook-level globals ``task_solution`` (expected outputs of
    the current task, or None) and ``i`` (the index passed to ``plot_task``).
    The driver loop sets both before calling ``show_submision``.

    Args:
        attempt: one predicted grid per test input of ``task``
            (list[list[list[int]]]); validated with assert_list_of_lists_of_ints.
        has_correct_answers: list that the correctness flag is appended to.
        title: plot title prefix; "CORRECT" or "WRONG" is appended.
        task: task dict; ``len(task['test'])`` fixes the expected number of grids.
        test_index: if given, only ``attempt[test_index]`` is scored (the other
            entries are placeholders). If None, the whole list must match.
    """
    assert_list_of_lists_of_ints(attempt, len(task['test']))
    if test_index is None:
        has_correct_answer = task_solution == attempt
    else:
        # Comparing the whole list would count the placeholder entries and mark
        # every multi-test attempt WRONG, so compare only the scored test.
        has_correct_answer = (
            task_solution is not None
            and test_index < len(task_solution)
            and task_solution[test_index] == attempt[test_index]
        )
    has_correct_answers.append(has_correct_answer)
    plot_task(task, attempt, i, f"{title} {'CORRECT' if has_correct_answer else 'WRONG'}") # , save_prefix = data_source


def show_submision(solution, has_correct_answers, title, task):
    """
    Plot every attempt in one task's submission entry.

    ``solution`` is a list with one entry per test input. Its layout is detected
    from the first entry:

    * ``{'attempt_1': grid, 'attempt_2': grid}`` per test. The attempt_1 grids
      of all tests are plotted together, then the attempt_2 grids. A test that
      lacks the key gets a ``[[9]]`` stand-in grid.
    * ``{'attempts': [grid, ...]}`` per test. Each grid is plotted on its own,
      with ``[[9]]`` stand-ins for the other tests, and only its own test is
      scored.

    Args:
        solution: list of per-test submission dicts for one task.
        has_correct_answers: list that receives one correctness flag per plot.
        title: plot title prefix.
        task: task dict the submission belongs to.
    """
    assert isinstance(solution, list), type(solution)
    if not solution:
        return  # nothing submitted for this task; solution[0] would raise IndexError

    if 'attempts' not in solution[0]:
        for key in ('attempt_1', 'attempt_2'):
            if any(key in entry for entry in solution):
                attempt = [entry.get(key, [[9]]) for entry in solution]
                plot_attempt(attempt, has_correct_answers, f'{title} {key}', task)
    else:
        num_tests = len(task['test'])
        for test_index, test_answer in enumerate(solution):
            for attempt_id, attempt_answer in enumerate(test_answer['attempts']):
                # Stand-ins must be valid int grids: plot_attempt validates every
                # entry, and a [[None]] placeholder fails that check whenever the
                # task has more than one test input.
                attempt = [[[9]] for _ in range(num_tests)]
                attempt[test_index] = attempt_answer
                plot_attempt(attempt, has_correct_answers, f'{title} attempt* {attempt_id}', task,
                             test_index=test_index)
                
# Driver: for each task of the selected split, plot every submitted attempt and,
# if the task had attempts but none was correct, also plot the ground truth.
# `i`, `task` and `task_solution` are deliberately module-level: plot_attempt
# reads `task_solution` and `i` as notebook globals.
for i, t in enumerate(training_challenges):  # enumerate: no list(...) rebuild per iteration

    # if t != '1c0d0a4b':
    #     continue
    task = training_challenges[t]
    # .get() so a task absent from training_solutions yields None instead of a
    # KeyError. Assumes training_solutions is a dict keyed by task id — confirm.
    task_solution = training_solutions.get(t) if training_solutions else None

    # One correctness flag per plotted attempt (appended by plot_attempt).
    has_correct_answers = []

    # print('#train', pformat(task['train']).replace('\n', ''))
    # print('#test', pformat(task['test']).replace('\n', ''))

    # if t in ice_submission_2_rte:
    #     show_submision(ice_submission_2_rte[t], has_correct_answers, title=f"{t}, ice_submission_2_rte", task=task)

    # if t in v5_ice_submission:
    #     show_submision(v5_ice_submission[t], has_correct_answers, title=f"{t}, v5_ice_submission", task=task)

    # attempt_1 = [task['attempt_1'] for task in sub_icecube[t]]
    # plot_task(task, attempt_1, i, f'{t}, sub_icecube[0]') # , save_prefix = data_source

    # attempt_2 = [task['attempt_2'] for task in sub_icecube[t]]
    # plot_task(task, attempt_2, i, f'{t}, sub_icecube[1]') # , save_prefix = data_source

    # if t in soma_submission:
    #     show_submision(soma_submission[t], has_correct_answers, title=f"{t}, soma_submission", task=task)

    # if t in kaggle_soma_submission:
    #     show_submision(kaggle_soma_submission[t], has_correct_answers, title=f"{t}, kaggle_soma_submission", task=task)

    # if t in local_soma_submission:
    #     show_submision(local_soma_submission[t], has_correct_answers, title=f"{t}, local_submission", task=task)

    if t in nn_submission:
        show_submision(nn_submission[t], has_correct_answers, title=f"{t}, nn_submission", task=task)

    # print('nn_submission[t]', nn_submission[t])
    # attempt_2 = [task['attempt_2'] for task in nn_submission[t]]
    # plot_task(task, attempt_2, i, f'{t}, nn_submission[1]') # , save_prefix = data_source

    # Show the expected outputs only when something was attempted and all of it was wrong.
    if has_correct_answers and not any(has_correct_answers) and task_solution:
        plot_task(task, task_solution, i, f'{t}, solution') # , save_prefix = data_source

    # print('#test_solution', pformat(task_solution).replace('\n', ''))