In [3]:

from jax import Array
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.colors as mcolors
import numpy as np
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import seaborn as sns
import torch


from arc_prize.vis import COLORS

def visualize_embeddings(embeddings: Array, num_grids: int):
    """
    Plot 2-D t-SNE and PCA projections of a set of embeddings side by side.

    Rows are treated as laid out grid-by-grid: the first
    ``n_samples // num_grids`` rows are labelled grid 0, the next block
    grid 1, and so on. Every point is coloured by its grid index.

    :param embeddings: array-like of shape (n_samples, embedding_dim); it is
        converted to a NumPy array before being handed to scikit-learn
    :param num_grids: number of equally sized, contiguous row groups; must be
        positive and divide n_samples exactly
    :raises ValueError: if ``embeddings`` is not 2-D, or ``num_grids`` is not
        positive, or ``num_grids`` does not divide the number of rows
    """
    # scikit-learn and matplotlib operate on NumPy data; convert once up front.
    embeddings_np = np.asarray(embeddings)
    if embeddings_np.ndim != 2:
        raise ValueError(
            f"embeddings must be 2-D (n_samples, embedding_dim), got shape {embeddings_np.shape}"
        )
    n_samples = embeddings_np.shape[0]

    # Validate the grouping *before* the expensive t-SNE fit so a bad
    # num_grids fails immediately instead of at scatter time.
    if num_grids <= 0:
        raise ValueError(f"num_grids must be positive, got {num_grids}")
    if n_samples % num_grids != 0:
        raise ValueError(
            f"num_grids={num_grids} does not evenly divide the {n_samples} embedding rows"
        )

    # One integer label per row: [0, 0, ..., 1, 1, ..., num_grids - 1].
    per_grid = n_samples // num_grids
    grid_indices = np.repeat(np.arange(num_grids), per_grid)

    # t-SNE is seeded so repeated runs of the cell give the same layout.
    tsne_reducer = TSNE(n_components=2, random_state=42)
    pca_reducer = PCA(n_components=2)

    reduced_tsne_embeddings = tsne_reducer.fit_transform(embeddings_np)
    reduced_pca_embeddings = pca_reducer.fit_transform(embeddings_np)

    # Only two panels are drawn, so use a single row rather than a 2x2 grid
    # that would leave two empty axes in the figure.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

    scatter1 = ax1.scatter(reduced_tsne_embeddings[:, 0], reduced_tsne_embeddings[:, 1], c=grid_indices)
    plt.colorbar(scatter1, ax=ax1)
    ax1.set_title("Embeddings with TSNE by grid")

    scatter2 = ax2.scatter(reduced_pca_embeddings[:, 0], reduced_pca_embeddings[:, 1], c=grid_indices)
    plt.colorbar(scatter2, ax=ax2)
    ax2.set_title("Embeddings with PCA by grid")

    plt.show()

In [None]:
from arc_prize.flax.models import create_arc_fixed_positional_encoding

# 32 is both the first argument to the encoding factory and the row width used
# when flattening below; presumably the embedding dimension — confirm against
# the signature of create_arc_fixed_positional_encoding.
embedding_dim = 32

enc = create_arc_fixed_positional_encoding(embedding_dim, 10, 4)
print(enc.shape)

# Collapse every leading axis so each row is one embedding vector, then plot
# the rows coloured as 9 equally sized contiguous groups.
flat_enc = enc.reshape(-1, embedding_dim)
visualize_embeddings(flat_enc, num_grids=9)



In [1]:
import os

from arc_prize.flax.models import ARCTransformerEncoderDecoderParams
from arc_prize.flax.train import TrainParams, train_and_evaluate_local

# Storage roots are overridable through the environment so the cell is not tied
# to one machine's home directory; the defaults reproduce the original paths.
data_root = os.environ.get("ARC_DATA_ROOT", "/Users/pfh/work/arc-data")
models_root = os.environ.get("ARC_MODELS_ROOT", "/Users/pfh/work/arc-models")

# Small configuration for a quick local run. d_ff is derived from d_model so
# the two cannot drift apart when the width is changed.
d_model = 16

model_params = ARCTransformerEncoderDecoderParams(
  grid_dim=10,
  num_train_pairs=4,
  num_colors=10,
  num_encoder_layers=2,
  num_decoder_layers=2,
  num_heads=2,
  d_model=d_model,
  d_ff=d_model * 4,
  dropout=0.1
)

train_params = TrainParams(
  batch_size=20,
  learning_rate=1e-4,
  weight_decay=1e-4,
  warmup_steps=5,
  train_steps_per_epoch=10,
  eval_steps_per_epoch=5,
  dataset_dirs=[os.path.join(data_root, "flip")],
  loss_class_weights={0: 0.2}
)

model_dir = os.path.join(models_root, "flax-4")
# NOTE(review): the recorded output of this cell resumes at step 14 and then
# iterates epochs 15..20, which suggests num_epochs counts *additional* epochs
# on top of the restored checkpoint — confirm in train_and_evaluate_local.
num_epochs = 6

train_and_evaluate_local(model_dir, num_epochs, model_params, train_params)

Latest step 14
GPU devices [CpuDevice(id=0)]
Starting training run with dataset of 1000 training items and 200 evaluation items: /Users/pfh/work/arc-data/flip
Using batch size of 20
Starting epoch 15/20
Train loss (completed in 10.95s): 1.4297, accuracy: 0.6848
Eval loss (competed in 2.18s): 1.4241, accuracy: 0.6801
Starting epoch 16/20
Train loss (completed in 6.52s): 1.4182, accuracy: 0.6839
Eval loss (competed in 1.57s): 1.4129, accuracy: 0.6852
Starting epoch 17/20
Train loss (completed in 6.31s): 1.4027, accuracy: 0.6879
Eval loss (competed in 1.56s): 1.3867, accuracy: 0.6881
Starting epoch 18/20
Train loss (completed in 6.45s): 1.4113, accuracy: 0.6838
Eval loss (competed in 1.60s): 1.3753, accuracy: 0.6908
Starting epoch 19/20
Train loss (completed in 6.52s): 1.4005, accuracy: 0.6852
Eval loss (competed in 1.87s): 1.3624, accuracy: 0.6950
Starting epoch 20/20


KeyboardInterrupt: 

In [1]:
import os

from arc_prize.flax.train import get_config_params, predict

# Storage roots are overridable through the environment so the cell is not tied
# to one machine's home directory; the defaults reproduce the original paths.
data_root = os.environ.get("ARC_DATA_ROOT", "/Users/pfh/work/arc-data")
models_root = os.environ.get("ARC_MODELS_ROOT", "/Users/pfh/work/arc-models")

model_dir = os.path.join(models_root, "flax-2")
dataset_dir = os.path.join(data_root, "flip")
num_steps = 10

# Only the first element of the returned pair is needed here; the second is
# discarded.
model_params, _ = get_config_params(model_dir)

# NOTE(review): the saved output of this cell ends with
# "AttributeError: 'dict' object has no attribute 'flat_state'" right after the
# checkpoint keys are printed with "graphdef None". The error comes from the
# library call (presumably inside predict's checkpoint restore), not from the
# statements in this cell — TODO investigate before relying on this cell.
predict(model_dir, model_params, dataset_dir, num_steps)



GPU devices [CpuDevice(id=0)]
latest step 6




graphdef None
keys dict_keys(['graphdef', 'opt_state', 'other_state', 'params', 'step'])


AttributeError: 'dict' object has no attribute 'flat_state'

In [2]:
from arc_prize.flax.models import ARCTransformerEncoderDecoderParams
from arc_prize.flax.train import TrainParams
import modal
import petname

# Model width; the feed-forward width is kept at four times this value.
hidden_size = 256

model_config = dict(
    grid_dim=20,
    num_train_pairs=4,
    num_colors=10,
    num_encoder_layers=6,
    num_decoder_layers=6,
    num_heads=16,
    d_model=hidden_size,
    d_ff=4 * hidden_size,
    dropout=0.1,
)
model_params = ARCTransformerEncoderDecoderParams(**model_config)

train_config = dict(
    batch_size=32,
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_steps=3,
    train_steps_per_epoch=300,
    eval_steps_per_epoch=50,
    dataset_dirs=["/vol/data/html_dim_20_20240925"],
    loss_class_weights={0: 0.2},
)
train_params = TrainParams(**train_config)

# Each launch gets a freshly generated three-word name, which also determines
# the directory handed to the remote function.
model_name = petname.generate(words=3, separator='_')
model_dir = f"/vol/models/{model_name}"
num_epochs = 20

# Look up the deployed "train" function of the "arc-jax" app and start it
# without waiting for completion; the printed object id identifies the call.
fn = modal.Function.lookup("arc-jax", "train")
fn_call = fn.spawn(model_dir, model_params, train_params, num_epochs)
print("Model name", model_name, fn_call.object_id)


Model name mainly_neat_mite fc-01JAC2TTXC4K6AAYCNA2VFJ1KV


In [None]:
from flax import nnx

# Scratch comparison of two Rngs containers: one built with an extra
# `params=1` keyword next to the positional 0, one built from the positional 0
# alone. Printing both shows how the constructor arguments are reflected in
# each object's repr.
rng1, rng2 = nnx.Rngs(0, params=1), nnx.Rngs(0)
print(rng1, rng2)