In [None]:
import os
import sys
from plot_funcs import plot_task

# Make the parent of the notebook's working directory importable; the
# `src.load_data` import below is resolved through this entry.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
# Guard the append so that re-running this cell does not keep adding the
# same directory to sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

from src.load_data import load_from_json

# Name of the dataset handed to load_from_json, together with the directory
# (relative to the notebook's working directory) that is passed as its
# second argument.
data_source = 'arc-agi_training'
training_challenges, training_solutions = load_from_json(data_source, '../input_data/')

In [None]:
import itertools
import random
import copy

def count(matrix, frequency):
    """Add the occurrences of every cell value in ``matrix`` to ``frequency``.

    Args:
        matrix: Iterable of rows, each row an iterable of cell values.
        frequency: Mapping from cell value to running count.  It is updated
            in place and must already contain an entry for every value that
            occurs in ``matrix`` (a plain dict raises KeyError otherwise).

    Returns:
        None.  The result is the mutation of ``frequency``.
    """
    # Walk the cells row by row, left to right, without building a flat copy.
    for value in itertools.chain.from_iterable(matrix):
        frequency[value] += 1

def transform(matrix, mapping):
    """Recolour ``matrix`` in place according to ``mapping``.

    Every cell whose value is greater than 0 is replaced by
    ``mapping[value]``; every other cell is set to 0.

    Args:
        matrix: List of mutable rows.  Both the outer list and the row
            objects are kept; only the cells are overwritten.
        mapping: Lookup from original value to replacement value.  A
            positive cell value missing from it raises KeyError.

    Returns:
        The same ``matrix`` object that was passed in.
    """
    for cells in matrix:
        for col in range(len(cells)):
            value = cells[col]
            if value > 0:
                cells[col] = mapping[value]
            else:
                cells[col] = 0
    return matrix

def data_augment_preprocess(task, task_solution):
    """Collect the set of cell values that occur anywhere in a task.

    The grids inspected are the ``'input'`` and ``'output'`` of every entry
    in ``task['train']``, the ``'input'`` of the first entry in
    ``task['test']``, and ``task_solution``.

    Args:
        task: Dict with ``'train'`` and ``'test'`` lists of grid dicts.
        task_solution: Grid holding the solution for the first test entry.

    Returns:
        A set of the values 0..9 that appear at least once.  0 is included
        when it occurs.
    """
    # One counter per value 0..9; `count` raises KeyError on anything else.
    tally = dict.fromkeys(range(10), 0)

    for pair in task['train']:
        for side in ('input', 'output'):
            count(pair[side], tally)

    # Only the first test entry is examined.
    count(task['test'][0]['input'], tally)
    count(task_solution, tally)

    return {value for value, seen in tally.items() if seen > 0}

def generate_deterministic_permutations(input_set, N, seed=43):
    """Pick up to ``N`` reproducible value permutations for a set of colours.

    All ordered selections of ``len(input_set)`` distinct values from 1..9
    are enumerated (the length is taken after 0 has been removed), shuffled
    with a generator seeded by ``seed``, and the first ``N`` are returned.
    The enumeration includes the selection that maps every value to itself.

    Args:
        input_set: Collection of values in use.  If it contains 0, that
            entry is removed *in place*: callers that afterwards zip this
            same collection with a returned permutation depend on 0 being
            gone, so the mutation is part of the contract.
        N: Maximum number of permutations to return.  Values below 1 give
            an empty list.
        seed: Seed for the shuffle; the same seed and set size always give
            the same result.

    Returns:
        A list of at most ``N`` tuples, each of length ``len(input_set)``.
        The list is empty when more than nine non-zero values are supplied.
    """
    # A private generator produces exactly the shuffle that
    # random.seed(seed) + random.shuffle would, but leaves the process-wide
    # RNG state untouched for any other code that draws random numbers.
    rng = random.Random(seed)

    if 0 in input_set:
        input_set.remove(0)

    # Generate all possible permutations
    all_permutations = list(itertools.permutations(range(1, 10), len(input_set)))

    # Shuffle the permutations deterministically
    rng.shuffle(all_permutations)

    # Clamp at zero so that a negative N cannot turn the slice into
    # "everything except the last |N| entries".
    return all_permutations[:max(N, 0)]


def data_augment_process(task, task_solution, keys, perm):
    """Return a recoloured deep copy of a task and its solution.

    The non-zero entries of ``keys`` are paired, in iteration order, with
    the entries of ``perm`` to form a value mapping, which is then applied
    with ``transform`` to every train input/output, the first test input,
    and the solution grid.  The arguments themselves are not modified.

    Args:
        task: Dict with ``'train'`` and ``'test'`` lists of grid dicts.
        task_solution: Solution grid for the first test entry.
        keys: Values present in the task.  A 0 entry is ignored.
        perm: Replacement values, one per non-zero key.

    Returns:
        ``(new_task, new_task_solution)`` — independent, recoloured copies.
    """
    # `transform` never remaps 0, so 0 must not take up a slot of `perm`.
    # If it did, the colour-to-value pairing would be wrong and the last
    # colour would be left without a mapping.
    remappable = [key for key in keys if key != 0]
    mapping = dict(zip(remappable, perm))

    # Work on copies so the original task can be augmented repeatedly.
    task = copy.deepcopy(task)
    task_solution = copy.deepcopy(task_solution)

    for pair in task['train']:
        pair['input'] = transform(pair['input'], mapping)
        pair['output'] = transform(pair['output'], mapping)

    # Only the first test entry is recoloured.
    first_test = task['test'][0]
    first_test['input'] = transform(first_test['input'], mapping)
    task_solution = transform(task_solution, mapping)

    return task, task_solution

In [None]:
from pprint import pprint
from pprint import pformat

for i in range(1):  # use range(len(training_challenges)) to cover every task
    # Look the task up by its position in the challenges' iteration order.
    t = list(training_challenges)[i]
    task = training_challenges[t]
    # Only the first solution grid of the task is used.
    task_solution = training_solutions[t][0]

    keys = data_augment_preprocess(task, task_solution)

    # The task index doubles as the shuffle seed. This call also removes 0
    # from `keys` in place, and the loop below reuses that same object.
    perms = generate_deterministic_permutations(input_set=keys, N=40, seed=i)

    for perm in perms:
        # The augmented copies are rebuilt on every pass and not stored;
        # after the loop only the last pair remains bound.
        new_task, new_task_solution = data_augment_process(
            task=task,
            task_solution=task_solution,
            keys=keys,
            perm=perm,
        )

    print('total perms', len(perms))