In [None]:
from dataclasses import dataclass
from datetime import datetime
import json
import os
import pickle
import uuid

import jax
import jax.numpy as jnp
import numpy as np

#<config>
@dataclass(frozen=True)
class Config:
    """Immutable experiment configuration: data, augmentation, logging and training knobs.

    Defaults read from the environment (``SEED``, ``MORPH``, ``COMPUTE_BACKEND``) and
    ``created_on`` are evaluated once, when this class body runs at import time, not on
    every ``Config()`` call.
    """

    seed: int = int(os.environ.get("SEED", 0))  # JAX PRNG seed used for parameter init
    # --- data
    data_seed: int = 42  # seed for NumPy's global RNG (augmentation and shuffling)
    train_challenges: str = '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json'
    train_solutions: str = '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json'
    valid_challenges: str = '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json'
    valid_solutions: str = '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json'
    submission_challenges: str = '/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json'
    num_order_augs: int = 2 # number of task train grid order augmentations
    num_color_augs: int = 2 # number of task grid color augmentations (maintains matching train and test pair)
    max_grid_size: int = 30 # maximum grid size
    max_train_pairs: int = 10 # maximum number of train pairs
    max_test_pairs: int = 3 # maximum number of test pairs
    num_colors: int = 10 # number of colors in the grid
    pad_value: int = 0 # padding value for grids
    # --- logging
    morph: str = os.environ.get("MORPH", "test")  # run name; also selects the output sub-directory
    compute_backend: str = os.environ.get("COMPUTE_BACKEND", "oop")  # "kaggle" disables wandb
    wandb_entity: str = "phuongdv"
    wandb_project: str = "evoarc"
    created_on: str = datetime.now().strftime("%Y%m%d%H%M%S")  # YYYYmmddHHMMSS timestamp
    # --- model
    # TODO: add model hyperparameters here
    # --- training
    num_epochs: int = 32 # number of epochs to train
    batch_size: int = 2 # number of tasks per batch
    # log training loss whenever this many more training tasks have been seen; an int literal
    # (previously the float 1e4, contradicting the annotation and leaking 10000.0 into config.json)
    print_every: int = 10_000
    early_stopping_patience: int = 10 # stop training if no improvement for this many epochs
    learning_rate: float = 1e-3 # initial learning rate
#</config>

cfg = Config()

# Kaggle submissions are written to the working directory; every other backend gets its own
# per-morph folder under /evoarc/output (created if missing).
if cfg.compute_backend != "kaggle":
    output_dir = f"/evoarc/output/{cfg.morph}"
    os.makedirs(output_dir, exist_ok=True)
else:
    output_dir = os.getcwd()

print(f"output_dir: {output_dir}")
print(f"config:{json.dumps(cfg.__dict__, indent=4)}")

# Persist the exact configuration next to the run's other artefacts.
config_filepath = os.path.join(output_dir, "config.json")
with open(config_filepath, 'w') as f:
    json.dump(cfg.__dict__, f, indent=4)

# Experiment tracking only outside Kaggle; wandb is imported lazily so Kaggle never needs it.
if cfg.compute_backend != "kaggle":
    import wandb
    wandb.init(
        entity=cfg.wandb_entity,
        project=cfg.wandb_project,
        name=f"{cfg.compute_backend}.{cfg.morph}.{str(uuid.uuid4())[:6]}",
        config=cfg.__dict__,
    )
    wandb.save(config_filepath)

def save_checkpoint(params, filename):
    """Pickle ``params`` to ``<output_dir>/<filename>``, overwriting any existing file."""
    checkpoint_path = os.path.join(output_dir, filename)
    with open(checkpoint_path, 'wb') as fh:
        pickle.dump(params, fh)

def load_checkpoint(filename):
    """Unpickle and return the object stored at ``<output_dir>/<filename>``.

    Only use this on checkpoints this pipeline wrote itself: unpickling untrusted data can
    execute arbitrary code.
    """
    checkpoint_path = os.path.join(output_dir, filename)
    with open(checkpoint_path, 'rb') as fh:
        return pickle.load(fh)

#<data>
def load_tasks(challenges_path: str, solutions_path: str, cfg: Config):
    """Load ARC tasks from a challenges JSON file and, optionally, a solutions JSON file.

    Each task has a "train" list of two to ten input/output pairs (typically three), used to
    infer a rule, and a "test" list of one to three inputs (typically one) to which the inferred
    rule must be applied. A grid is a rectangular list of lists of integers 0-9, from 1x1 up to
    30x30; 0 is the background colour and 1-9 are the pattern colours.

    Args:
        challenges_path: JSON mapping task_id -> {"train": [{"input", "output"}...],
            "test": [{"input"}...]}.
        solutions_path: JSON mapping task_id -> list of test output grids, or ``None`` when no
            solutions are available.
        cfg: not used by this function.

    Returns:
        List of ``(task_id, train_in, train_out, test_in, test_out)`` tuples in challenges-file
        order; every grid is a ``uint8`` NumPy array and ``test_out`` is ``[]`` when
        ``solutions_path`` is ``None``.

    Raises:
        KeyError: if a challenge task_id is missing from the solutions file.
    """
    with open(challenges_path, 'r') as fh:
        challenges_dict = json.load(fh)
    print(f"loading challenges from {challenges_path}, found {len(challenges_dict)} challenges")

    has_solutions = solutions_path is not None
    if has_solutions:
        with open(solutions_path, 'r') as fh:
            solutions_dict = json.load(fh)
        print(f"loading solutions from {solutions_path}, found {len(solutions_dict)} solutions")

    def to_grid(rows):
        return np.array(rows, dtype=np.uint8)

    tasks = []
    for task_id, challenge in challenges_dict.items():
        train_pairs = challenge['train']
        tasks.append((
            task_id,
            [to_grid(pair['input']) for pair in train_pairs],
            [to_grid(pair['output']) for pair in train_pairs],
            [to_grid(item['input']) for item in challenge['test']],
            [to_grid(grid) for grid in solutions_dict[task_id]] if has_solutions else [],
        ))
    return tasks

def augmentation(tasks, cfg: Config):
    """Expand ``tasks`` with spatial, train-order and colour augmentations.

    Each task is a ``(task_id, train_in, train_out, eval_in, eval_out)`` tuple of grid lists.
    The stages compound (each later stage also augments the output of the earlier ones):

    1. spatial: the original task plus ``fliplr``, ``flipud`` and 90/270-degree rotations,
       with the same transform applied to every grid of the task (5 tasks per input task);
    2. order: ``cfg.num_order_augs`` rounds, each appending a copy of every task with its
       train pairs randomly permuted (the count doubles per round);
    3. colour: ``cfg.num_color_augs`` rounds, each appending a copy of every task with
       colours ``1 .. cfg.num_colors - 1`` randomly permuted while colour 0 (background) is
       kept; one mapping is shared by all grids of a task (the count doubles per round).

    Randomness comes from NumPy's global RNG; derived task ids get a random ``uuid4`` suffix.
    Colour-remapped grids are ``uint8``, matching what ``load_tasks`` produces.

    Returns:
        List of task tuples: ``5 * 2**num_order_augs * 2**num_color_augs`` per input task.
    """
    spatial_transforms = (
        np.fliplr,
        np.flipud,
        lambda grid: np.rot90(grid, k=1),
        lambda grid: np.rot90(grid, k=3),
    )

    def map_grids(fn, grids):
        # apply ``fn`` to each grid, returning a new list
        return [fn(grid) for grid in grids]

    augmented_tasks = []
    # grid structure means we can use spatial symmetry to augment the tasks
    for task_id, train_in, train_out, eval_in, eval_out in tasks:
        # keep the un-augmented task: without it training never sees the original orientation,
        # which is the only one datagen uses for validation and submission
        augmented_tasks.append((task_id, list(train_in), list(train_out), list(eval_in), list(eval_out)))
        for transform in spatial_transforms:
            augmented_tasks.append((
                f"{task_id}.s{str(uuid.uuid4())[:6]}",
                map_grids(transform, train_in),
                map_grids(transform, train_out),
                map_grids(transform, eval_in),
                map_grids(transform, eval_out),
            ))
    print(f"after spatial augmentation, num tasks: {len(augmented_tasks)}")

    # assume order of train grids is also a valid augmentation
    for _ in range(cfg.num_order_augs):
        reordered = []
        for task_id, train_in, train_out, eval_in, eval_out in augmented_tasks:
            train_order = np.random.permutation(len(train_in))
            reordered.append((
                f"{task_id}.o{str(uuid.uuid4())[:6]}",
                [train_in[i] for i in train_order],
                [train_out[i] for i in train_order],
                eval_in,
                eval_out,
            ))
        augmented_tasks.extend(reordered)
    print(f"after order augmentation x{cfg.num_order_augs}, num tasks: {len(augmented_tasks)}")

    # all colors (except for background) are interchangeable (but must match entire set)
    for _ in range(cfg.num_color_augs):
        recolored = []
        for task_id, train_in, train_out, eval_in, eval_out in augmented_tasks:
            # uint8 map so np.take keeps grids uint8; sized by cfg rather than a hard-coded 10
            color_map = np.arange(cfg.num_colors, dtype=np.uint8)
            color_map[1:] = np.random.permutation(color_map[1:])
            recolor = color_map.take  # bound per task, so every grid of the task shares the map
            recolored.append((
                f"{task_id}.c{str(uuid.uuid4())[:6]}",
                map_grids(recolor, train_in),
                map_grids(recolor, train_out),
                map_grids(recolor, eval_in),
                map_grids(recolor, eval_out),
            ))
        augmented_tasks.extend(recolored)
    print(f"after color augmentation x{cfg.num_color_augs}, num tasks: {len(augmented_tasks)}")
    return augmented_tasks

def pad_grids(grids, cfg: Config):
    """Place each grid in the top-left corner of a square canvas and stack the canvases.

    Canvases are ``cfg.max_grid_size`` on a side and filled with ``cfg.pad_value`` outside the
    grid; the result is a ``uint8`` array of shape ``(len(grids), max_grid_size, max_grid_size)``.
    NumPy raises ``ValueError`` for an empty ``grids`` or a grid larger than the canvas.
    """
    side = cfg.max_grid_size

    def to_canvas(grid):
        rows, cols = grid.shape
        canvas = np.full((side, side), cfg.pad_value, dtype=np.uint8)
        canvas[:rows, :cols] = grid
        return canvas

    return np.stack([to_canvas(grid) for grid in grids])

def pad_pairs(pairs, pad_len: int, cfg: Config):
    """Stack ``pairs`` and append blank grids until there are at least ``pad_len`` of them.

    Blank grids are ``max_grid_size`` squares filled with ``cfg.pad_value``. Sequences longer
    than ``pad_len`` are stacked as-is (not truncated). The input sequence is left untouched:
    the previous version appended the padding to the caller's list.

    Returns:
        Array of shape ``(max(len(pairs), pad_len), max_grid_size, max_grid_size)``.
    """
    blank = np.full((cfg.max_grid_size, cfg.max_grid_size), cfg.pad_value, dtype=np.uint8)
    padded = list(pairs)  # copy so the caller's list is not mutated
    # np.stack copies, so reusing the same blank array for every padded slot is safe
    padded.extend(blank for _ in range(pad_len - len(padded)))
    return np.stack(padded)

def datagen(tasks, cfg: Config, mode="train"):
    """Yield batches of padded tasks as float32 JAX arrays.

    Modes:
      * ``"train"``: tasks are augmented and shuffled (NumPy global RNG); only full batches of
        ``cfg.batch_size`` tasks are produced.
      * ``"valid"``: tasks in their given order, ``cfg.batch_size`` per batch; the last batch
        may be smaller.
      * ``"submission"``: one task per batch, in order.

    Each yield is ``(task_ids, train_in, train_out, test_in, test_out)``. Train grids fill
    ``cfg.max_train_pairs`` slots and test grids ``cfg.max_test_pairs`` slots (more if a task
    has more pairs, since ``pad_pairs`` does not truncate), each slot a ``max_grid_size``
    square. Tasks without solutions get an all-zero ``test_out``.

    Raises:
        ValueError: for an unknown ``mode`` (on first iteration, since this is a generator).
    """
    if mode == "train":
        tasks = augmentation(tasks, cfg)
        np.random.shuffle(tasks)
    n_tasks = len(tasks)

    if mode == "train":
        per_batch = cfg.batch_size
        n_batches = n_tasks // per_batch  # trailing partial batch is dropped
    elif mode == "valid":
        per_batch = cfg.batch_size
        n_batches = (n_tasks + per_batch - 1) // per_batch  # partial batch is kept
    elif mode == "submission":
        per_batch, n_batches = 1, n_tasks
    else:
        raise ValueError(f"invalid mode: {mode}")

    def stack_padded(grids, n_slots):
        # pad each grid to a max_grid_size square, then fill up to n_slots with blank grids
        return pad_pairs(list(pad_grids(grids, cfg)), n_slots, cfg)

    for batch_idx in range(n_batches):
        start = batch_idx * per_batch
        batch = tasks[start:start + per_batch]
        # defensive only: in train mode n_batches already excludes the partial batch
        if mode == "train" and len(batch) < per_batch:
            continue
        ids, tr_in, tr_out, te_in, te_out = [], [], [], [], []
        for task_id, train_in, train_out, test_in, test_out in batch:
            ids.append(task_id)
            tr_in.append(stack_padded(train_in, cfg.max_train_pairs))
            tr_out.append(stack_padded(train_out, cfg.max_train_pairs))
            padded_test_in = stack_padded(test_in, cfg.max_test_pairs)
            te_in.append(padded_test_in)
            te_out.append(
                stack_padded(test_out, cfg.max_test_pairs) if test_out
                else np.zeros_like(padded_test_in)  # no solutions available for this task
            )
        yield (ids, *(jnp.array(group, dtype=jnp.float32) for group in (tr_in, tr_out, te_in, te_out)))
#</data>

#<model>
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class SimpleCNN(nn.Module):
    """Three-layer convolutional baseline that outputs per-cell colour logits.

    The train inputs, train outputs and test inputs are concatenated along axis 1 and treated
    as a depth axis of a 3-D convolution. Every kernel has extent 1 along that axis, so each
    stacked grid is processed independently (NOTE: the prediction for a test grid therefore
    cannot depend on the train examples). Raw colour indices enter as a single scalar channel.

    Returns logits of shape ``(batch, n_slots, H, W, cfg.num_colors)`` where ``n_slots`` is the
    total number of stacked grids.
    """

    cfg: Config

    @nn.compact
    def __call__(self, train_in, train_out, test_in):
        stacked = jnp.concatenate([train_in, train_out, test_in], axis=1)
        batch, slots, height, width = stacked.shape
        h = stacked.reshape(batch, slots, height, width, 1)  # add a trailing channel axis
        # hidden layers are created in the same order as before (Conv_0, Conv_1), so parameter
        # names and existing checkpoints stay compatible
        for width_features in (16, 32):
            h = nn.Conv(features=width_features, kernel_size=(1, 3, 3), strides=(1, 1, 1), padding='SAME')(h)
            h = nn.relu(h)
        # 1x1x1 projection to one logit per colour (Conv_2)
        return nn.Conv(features=self.cfg.num_colors, kernel_size=(1, 1, 1), strides=(1, 1, 1), padding='SAME')(h)

def init_params(key, cfg: Config):
    """Initialise SimpleCNN parameters from PRNG ``key`` using all-ones dummy inputs.

    The dummy inputs have shape ``(cfg.batch_size, n_pairs, max_grid_size, max_grid_size)``,
    with ``n_pairs = cfg.max_train_pairs`` for train inputs/outputs and ``cfg.max_test_pairs``
    for test inputs. Only the ``'params'`` collection is returned.
    """
    side = cfg.max_grid_size

    def dummy(n_pairs):
        return jnp.ones((cfg.batch_size, n_pairs, side, side))

    network = SimpleCNN(cfg)
    variables = network.init(
        key,
        dummy(cfg.max_train_pairs),
        dummy(cfg.max_train_pairs),
        dummy(cfg.max_test_pairs),
    )
    return variables['params']

def model(params, train_in, train_out, test_in):
    """Run SimpleCNN (configured by the module-level ``cfg``) with ``params``; returns logits."""
    network = SimpleCNN(cfg)
    variables = {'params': params}
    return network.apply(variables, train_in, train_out, test_in)

def loss_fn(logits, labels):
    """Mean softmax cross-entropy over every cell of the test slots.

    ``logits`` is the model output (last axis = ``cfg.num_colors`` classes); only its trailing
    ``labels.shape[1]`` slots along axis 1 (the test inputs) are scored against the colour
    indices in ``labels``. Padded cells and padded test slots are included in the mean.
    """
    n_test = labels.shape[1]
    test_logits = logits[:, -n_test:].reshape(-1, cfg.num_colors)
    targets = jax.nn.one_hot(labels.reshape(-1), cfg.num_colors)
    per_cell = optax.softmax_cross_entropy(test_logits, targets)
    return jnp.mean(per_cell)

def accuracy_fn(logits, labels):
    """Fraction of test-slot cells whose argmax colour equals the label.

    Scores the trailing ``labels.shape[1]`` slots of ``logits`` along axis 1. Padding cells
    (labelled with the pad value) count toward the score. NOTE(review): this likely inflates
    accuracy; consider masking padding.
    """
    n_test = labels.shape[1]
    predicted = jnp.argmax(logits[:, -n_test:].reshape(-1, cfg.num_colors), axis=-1)
    hits = predicted == labels.reshape(-1)
    return jnp.mean(hits)

def predict_fn(logits):
    """Return the most likely colour index per cell (argmax over the last, class, axis)."""
    class_axis = -1
    return jnp.argmax(logits, axis=class_axis)
#</model>


#<training>
import optax

# Seed NumPy's global RNG: datagen's augmentation and shuffling draw from it.
np.random.seed(cfg.data_seed)
opt = optax.adam(cfg.learning_rate)
key = jax.random.PRNGKey(cfg.seed)
params = init_params(key, cfg)
opt_state = opt.init(params)

@jax.jit
def train_step(params, opt_state, train_in, train_out, test_in, test_out):
    """One jitted optimizer step with the module-level ``opt``.

    Returns ``(new_params, new_opt_state, loss)``, where ``loss`` is computed before the update.
    """
    def objective(p):
        logits = model(p, train_in, train_out, test_in)
        return loss_fn(logits, test_out)

    loss, grads = jax.value_and_grad(objective)(params)
    updates, new_opt_state = opt.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

@jax.jit
def validation_step(params, train_in, train_out, test_in, test_out):
    """Jitted evaluation of one batch: returns ``(loss, accuracy)`` without updating params."""
    logits = model(params, train_in, train_out, test_in)
    return loss_fn(logits, test_out), accuracy_fn(logits, test_out)

train_tasks = load_tasks(cfg.train_challenges, cfg.train_solutions, cfg)
valid_tasks = load_tasks(cfg.valid_challenges, cfg.valid_solutions, cfg)
if not valid_tasks:
    # validation drives checkpoint selection and early stopping; fail loudly up front instead
    # of dividing by zero after the first epoch
    raise ValueError(f"no validation tasks found in {cfg.valid_challenges}")
global_step = 0  # number of training tasks seen (advances by batch_size per optimizer step)
best_valid_loss = float('inf')
best_valid_acc = 0.0
best_saved = False  # whether this run has written best.pkl
epochs_without_improvement = 0
for epoch in range(cfg.num_epochs):
    print(f"epoch {epoch + 1}/{cfg.num_epochs}")
    for _, train_in, train_out, test_in, test_out in datagen(train_tasks, cfg, mode="train"):
        global_step += cfg.batch_size
        params, opt_state, train_loss = train_step(params, opt_state, train_in, train_out, test_in, test_out)
        # global_step jumps by batch_size, so log whenever this step crossed a multiple of
        # print_every; an exact `% == 0` test never fires if batch_size does not divide it
        if global_step % cfg.print_every < cfg.batch_size:
            print(f"global step {global_step} loss = {train_loss.item():.4f}")
            if cfg.compute_backend != "kaggle":
                wandb.log({"train_loss": train_loss.item()}, step=global_step)
    # validation: weight each batch mean by its task count so a smaller final batch does not
    # skew the per-task average
    total_valid_loss = 0.0
    total_valid_acc = 0.0
    num_valid_tasks = 0
    for task_ids, train_in, train_out, test_in, test_out in datagen(valid_tasks, cfg, mode="valid"):
        batch_valid_loss, batch_valid_acc = validation_step(params, train_in, train_out, test_in, test_out)
        batch_n = len(task_ids)
        total_valid_loss += batch_valid_loss.item() * batch_n
        total_valid_acc += batch_valid_acc.item() * batch_n
        num_valid_tasks += batch_n
    valid_loss = total_valid_loss / num_valid_tasks
    valid_acc = total_valid_acc / num_valid_tasks
    print(f'valid_loss: {valid_loss:.4f}, valid_acc: {valid_acc:.4f}')
    if cfg.compute_backend != "kaggle":
        wandb.log({"valid_loss": valid_loss, "valid_acc": valid_acc}, step=global_step)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        best_valid_acc = valid_acc
        epochs_without_improvement = 0
        save_checkpoint(params, "best.pkl")
        best_saved = True
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= cfg.early_stopping_patience:
            print(f"early stopping at epoch {epoch + 1}")
            break
save_checkpoint(params, "final.pkl")
# submission will be made with two model checkpoints; if this run never wrote best.pkl
# (zero epochs, or NaN losses that never beat inf), fall back to the final parameters rather
# than crashing or loading a stale best.pkl left in output_dir by an earlier run
attempt1_ckpt = "best.pkl" if best_saved else "final.pkl"
attempt2_ckpt = "final.pkl"
#</training>

"""
for each task output in the evaluation set, you should make exactly 2 predictions (attempt_1, attempt_2).
most tasks only have a single output (a single dictionary enclosed in a list), although some tasks have multiple outputs that must be predicted.
when a task has multiple test outputs that need to be predicted, they must be in the same order as the corresponding test inputs.
"""
submission_tasks = load_tasks(cfg.submission_challenges, None, cfg)
predictions = {}
for i, ckpt in enumerate([attempt1_ckpt, attempt2_ckpt]):
    params = load_checkpoint(ckpt)
    for task_id, train_in, train_out, test_in, _ in datagen(submission_tasks, cfg, mode="submission"):
        model_output = model(params, train_in, train_out, test_in)
        test_out = predict_fn(model_output)
        for b, id in enumerate(task_id):
            if id not in predictions:
                predictions[id] = []
            predictions[id].append({f"attempt_{i+1}" : test_out[b].tolist()})
submission_filepath = os.path.join(output_dir, "submission.json")

with open(submission_filepath, 'w') as f:
    json.dump(predictions, f)

results = {"accuracy": best_valid_acc, "loss": best_valid_loss}
results_filepath = os.path.join(output_dir, "results.json")

with open(results_filepath, 'w') as f:
    json.dump(results, f, indent=4)

if not cfg.compute_backend == "kaggle":
    wandb.save(submission_filepath)
    wandb.save(results_filepath)
    wandb.finish()