In [None]:
import os
import sys
sys.path.append("..")

import hydra
import omegaconf
import jax
import optax
from flax.training.train_state import TrainState
from flax.serialization import from_bytes
import json

from src.models.lpn import LPN
from src.evaluator import Evaluator
from src.models.transformer import EncoderTransformer, DecoderTransformer

> **Warning:** first download the desired checkpoint folder (e.g. with the `load_checkpoint` notebook), then point `local_checkpoint_path` below at it. The folder must contain `config.yaml` and `state.msgpack`.

In [None]:
# Folder produced by the checkpoint download step. It must hold both
# `config.yaml` (read here) and `state.msgpack` (read when the state is restored).
local_checkpoint_path = "/kaggle/input/absurd-river-1068-checkpointv2/jax/default/1/absurd-river-1068--checkpoint:v2"
config_path = os.path.join(local_checkpoint_path, "config.yaml")
cfg = omegaconf.OmegaConf.load(config_path)

In [None]:
# Build a freshly initialised TrainState that serves only as a *template*:
# `from_bytes` (next cell) restores the checkpoint values into a target with
# this exact pytree structure.
encoder = EncoderTransformer(hydra.utils.instantiate(cfg.encoder_transformer))
decoder = DecoderTransformer(hydra.utils.instantiate(cfg.decoder_transformer))
lpn = LPN(encoder=encoder, decoder=decoder)

# Dummy inputs used only so `init` can trace parameter shapes. Split the root key
# so grid values, shape values and parameter init draw independent randomness;
# JAX keys must not be reused across calls.
key = jax.random.PRNGKey(0)
init_key, grids_key, shapes_key = jax.random.split(key, 3)
# NOTE(review): axes presumably (batch, pairs, rows, cols, input/output) -- TODO confirm against LPN.
grids = jax.random.randint(
    grids_key, (1, 3, decoder.config.max_rows, decoder.config.max_cols, 2), minval=0, maxval=decoder.config.vocab_size,
)
# Dummy grid sizes drawn from [1, min(max_rows, max_cols)].
shapes = jax.random.randint(
    shapes_key, (1, 3, 2, 2), minval=1, maxval=min(decoder.config.max_rows, decoder.config.max_cols) + 1,
)
variables = lpn.init(init_key, grids, shapes, dropout_eval=False, prior_kl_coeff=0.0, pairwise_kl_coeff=0.0, mode="mean")

# This optimizer is never stepped in this notebook (learning rate is 0). It exists so that
# `train_state.opt_state` has the structure `from_bytes` expects. Presumably it must
# mirror the training-time chain (clip -> adamw, wrapped in MultiSteps) -- keep in sync.
learning_rate, linear_warmup_steps = 0, 0
linear_warmup_scheduler = optax.warmup_exponential_decay_schedule(
    init_value=learning_rate / (linear_warmup_steps + 1),
    peak_value=learning_rate,
    warmup_steps=linear_warmup_steps,
    transition_steps=1,
    end_value=learning_rate,
    decay_rate=1.0,
)
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(linear_warmup_scheduler))
optimizer = optax.MultiSteps(optimizer, every_k_schedule=1)
train_state = TrainState.create(apply_fn=lpn.apply, tx=optimizer, params=variables["params"])

In [None]:
# Restore the serialized TrainState into the template built above. `from_bytes`
# takes its target structure from `train_state`, so the model/optimizer definitions
# must match the ones the checkpoint was saved with.
with open(os.path.join(local_checkpoint_path, "state.msgpack"), "rb") as data_file:
    # Deserialize inside the `with` so the raw checkpoint bytes (potentially large)
    # are not kept alive as a notebook global for the rest of the session.
    loaded_state = from_bytes(train_state, data_file.read())
# Copy the state to every device; each leaf gains a leading axis of size len(jax.devices()).
loaded_state = jax.device_put_replicated(loaded_state, jax.devices())

# Submission

In [None]:
# Settings for the Evaluator's "gradient_ascent" inference mode, passed through unchanged.
gradient_ascent_kwargs = {
    "num_steps": 200,
    "lr": 1.0,
    "lr_schedule": True,
    "optimizer": "adam",
    "optimizer_kwargs": {"b2": 0.9},
    "accumulate_gradients_decoder_pairs": True,
}
evaluator = Evaluator(lpn, inference_mode="gradient_ascent", inference_mode_kwargs=gradient_ascent_kwargs)

test_challenges_path = "/kaggle/input/arc-prize-2024/arc-agi_test_challenges.json"
with open(test_challenges_path, "r") as challenges_file:
    challenges = json.load(challenges_file)

# NOTE(review): `loaded_state` was replicated across devices in the previous cell, so its
# params carry a leading device axis -- confirm `json_submission` expects replicated params.
generations = evaluator.json_submission(challenges, loaded_state.params, progress_bar=True)

In [None]:
# Write the generated predictions to the working directory as pretty-printed JSON.
submission_path = "submission.json"
with open(submission_path, "w") as submission_file:
    json.dump(generations, submission_file, indent=4)