In [1]:
# Generic
from pathlib import Path
import typing
import itertools

# Numerical Computing
import numpy as np
import torch
import pandas as pd
from jaxtyping import Bool, Int, Float
from muutils.nbutils.configure_notebook import configure_notebook

# Our Code
# dataset stuff
from maze_dataset import MazeDataset, MazeDatasetConfig, SolvedMaze, LatticeMaze, SPECIAL_TOKENS, LatticeMazeGenerators, CoordArray
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode

# model stuff
from maze_transformer.training.config import ZanjHookedTransformer

# mechinterp stuff
from maze_transformer.mechinterp.logit_attrib_task import LOGIT_ATTRIB_TASKS
from maze_transformer.evaluation.load_model_testdata import load_model_with_test_data
from maze_transformer.evaluation.eval_single_token_tasks import TaskPrompt, TaskEvalResult, get_task_prompts_targets, eval_model_across_tasks
from maze_transformer.mechinterp.logit_attrib_task import LOGIT_ATTRIB_TASKS
from maze_transformer.evaluation.eval_model import predict_maze_paths



In [2]:
# Notebook-wide setup. This notebook only evaluates models; nothing is trained here.
# `configure_notebook` is given a fixed seed and returns the device to run on
# (presumably it also seeds the RNGs -- confirm against muutils).
DEVICE: torch.device = configure_notebook(
	seed=42,
	dark_mode=False,
)
print(f"{DEVICE = }")
# evaluation never needs gradients, so autograd is switched off globally
torch.set_grad_enabled(False)

DEVICE = device(type='cuda')


<torch.autograd.grad_mode.set_grad_enabled at 0x18e293f47d0>

In [3]:

# Checkpoints to evaluate; commented-out entries are alternative checkpoints that are currently disabled.
MODELS_PATHS: list[str] = [
	"../examples/model.hallway-jvq.final.zanj",
	"../examples/wandb.jerpkipj.zanj",
	# "../examples/model.hallway-insto.v17.zanj",
	# "../examples/model.hallway-insto.final.zanj",
	# "../examples/hallway-orig.zanj",
]

# Every evaluation dataset shares the same grid size and maze count.
GRID_N: int = 6
N_MAZES: int = 128

# One config per maze distribution; only the name, generator and generator kwargs differ.
DATASET_CFG_SOURCES: list[MazeDatasetConfig] = [
	MazeDatasetConfig(
		name=cfg_name,
		grid_n=GRID_N,
		n_mazes=N_MAZES,
		maze_ctor=generator,
		maze_ctor_kwargs=generator_kwargs,
	)
	for cfg_name, generator, generator_kwargs in (
		("rdfs", LatticeMazeGenerators.gen_dfs, {"do_forks": True}),
		("prdfs", LatticeMazeGenerators.gen_dfs_percolation, {"p": 0.1}),
		("hallway", LatticeMazeGenerators.gen_dfs, {"do_forks": False}),
	)
]

# Build one dataset per config, in the same order as DATASET_CFG_SOURCES.
DATASETS: list[MazeDataset] = list(map(MazeDataset.from_config, DATASET_CFG_SOURCES))


In [6]:



def _score_rollouts(
	predictions: list[CoordArray],
	mazes: list[SolvedMaze],
) -> dict[str, float]:
	"""Score predicted paths against the mazes they were generated for.

	`predictions` and `mazes` are paired positionally. Returns a dict with three
	fractions (means over all pairs):

	- "exactly correct rollouts": prediction equals `maze.solution` element-wise
	- "valid rollouts": prediction stays in bounds and `maze.is_valid_path` accepts it
	- "rollouts with target reached": last predicted coordinate equals `maze.end_pos`

	An empty prediction counts as a failure for all three metrics. If there are no
	pairs at all, `np.mean` of an empty list yields NaN.
	"""
	exact_correct: list[bool] = list()
	valid_path: list[bool] = list()
	target_correct: list[bool] = list()

	for p, m in zip(predictions, mazes):
		if p.size == 0:
			# the model produced no coordinates at all; `p.max()` and `p[-1]`
			# would raise on an empty array, so record a failure and move on
			exact_correct.append(False)
			valid_path.append(False)
			target_correct.append(False)
			continue

		if p.max() >= m.grid_n:
			# if it goes out of bounds, it's not valid
			exact_correct.append(False)
			valid_path.append(False)
		else:
			exact_correct.append(
				np.all(p == m.solution)
				if p.shape == m.solution.shape
				else False
			)
			valid_path.append(m.is_valid_path(p))

		# checked even for out-of-bounds paths: only the final coordinate matters here
		target_correct.append(np.all(p[-1] == m.end_pos))

	return {
		"exactly correct rollouts": np.mean(exact_correct),
		"valid rollouts": np.mean(valid_path),
		"rollouts with target reached": np.mean(target_correct),
	}


def generate_eval_table(
	models_paths: list[str],
	datasets: list[MazeDataset],
	rollouts_acc: int|None = None,
	max_new_tokens: int = 15,
) -> pd.DataFrame:
	"""Evaluate every model on every dataset and collect the results in one table.

	# Parameters:
	- `models_paths`: paths of the model files to load via `load_model_with_test_data`
	- `datasets`: datasets each model is evaluated on
	- `rollouts_acc`: if not `None`, also run rollouts on the first `rollouts_acc`
	  mazes of each dataset and add the three rollout metrics to the row;
	  if `None` (default), rollouts are skipped
	- `max_new_tokens`: generation budget passed to `predict_maze_paths` for each
	  rollout (default 15, the value that used to be hard-coded)

	# Returns:
	A DataFrame with one row per (model, dataset) pair: columns `model`, `dataset`,
	one column per task in `LOGIT_ATTRIB_TASKS` holding the mean of
	`predicted_correct`, plus the rollout metric columns when rollouts are enabled.
	"""
	output: list[dict] = list()

	for model_path in models_paths:
		# the test dataset stored alongside the checkpoint is not used here;
		# evaluation runs on the explicitly supplied `datasets` instead
		model, _ = load_model_with_test_data(
			model_path=model_path,
			dataset_cfg_source=None,
		)
		# the tokenizer belongs to the model, so look it up once per model
		tokenizer: MazeTokenizer = model.zanj_model_config.maze_tokenizer

		for dataset in datasets:
			print(f"evaluating {model_path} on {dataset.cfg.name}")

			task_prompt_targets: dict[str, TaskPrompt] = get_task_prompts_targets(
				dataset=dataset,
				maze_tokenizer=tokenizer,
				tasks=LOGIT_ATTRIB_TASKS,
			)

			print("\trunning evals")
			task_results: dict[str, TaskEvalResult] = eval_model_across_tasks(
				model = model,
				task_prompts = task_prompt_targets,
				do_cache=False,
			)

			output_item: dict = dict(
				model=model.config.name,
				dataset=dataset.cfg.to_fname(),
				**{
					task_name: task_result.predicted_correct.float().mean().item()
					for task_name, task_result in task_results.items()
				}
			)

			if rollouts_acc is not None:
				print("\trunning rollout evals")
				# do rollouts
				predictions: list[list[tuple[int, int]]] = predict_maze_paths(
					tokens_batch=dataset.as_tokens(maze_tokenizer=tokenizer)[:rollouts_acc],
					data_cfg=dataset.cfg,
					model=model,
					max_new_tokens=max_new_tokens,
				)
				predictions_np: list[CoordArray] = [
					np.array(path)
					for path in predictions
				]

				mazes: list[SolvedMaze] = dataset.mazes[:rollouts_acc]
				assert len(mazes) == len(predictions_np)

				# evaluate rollouts
				output_item.update(_score_rollouts(predictions_np, mazes))

			output.append(output_item)

	return pd.DataFrame(output)

In [7]:
# Evaluate every checkpoint in MODELS_PATHS on every dataset in DATASETS.
# rollouts_acc=128 equals N_MAZES, so rollouts are run on every maze of each dataset.
RESULTS: pd.DataFrame = generate_eval_table(
	models_paths=MODELS_PATHS,
	datasets=DATASETS,
	rollouts_acc=128,
)

loaded model with 1.2M params (num_params = 1238076) from
../examples/model.hallway-jvq.final.zanj
original model name: 'model.zanj_model_config.name = 'hallway_v3'', changing to 'hallway-jvq.final'
loaded dataset with 100 examples
dataset.cfg.summary() = {'name': 'hallway', 'fname': 'hallway-g7-n100-a_dfs-h52723', 'sdc_hash': 110302684828161283175731365142829285748232501725876255654215979074922758852723, 'seed': 42, 'seq_len_min': 1, 'seq_len_max': 256, 'applied_filters': [{'name': 'collect_generation_meta', 'args': [], 'kwargs': {}}], 'grid_n': 7, 'grid_shape': (7, 7), 'n_mazes': 100, 'maze_ctor_name': 'gen_dfs', 'maze_ctor_kwargs': {'do_forks': False}}
evaluating ../examples/model.hallway-jvq.final.zanj on rdfs
	running evals
	running rollout evals
evaluating ../examples/model.hallway-jvq.final.zanj on prdfs
	running evals
	running rollout evals
evaluating ../examples/model.hallway-jvq.final.zanj on hallway
	running evals
	running rollout evals
loaded model with 9.6M params (num_par

In [8]:
# show the raw table: one row per (model, dataset) pair, one column per metric
RESULTS

Unnamed: 0,model,dataset,path_start,origin_after_path_start,first_path_choice,path_end,final_before_path_end,rand_path_token,rand_path_token_non_endpoint,exactly correct rollouts,valid rollouts,rollouts with target reached
0,hallway-jvq.final,rdfs-g6-n128-a_dfs-h106,1.0,1.0,0.703125,1.0,0.914062,0.882812,0.914062,0.289062,0.65625,0.414062
1,hallway-jvq.final,prdfs-g6-n128-a_dfs_percolation-h87605,1.0,1.0,0.625,1.0,0.875,0.875,0.867188,0.289062,0.765625,0.460938
2,hallway-jvq.final,hallway-g6-n128-a_dfs-h60815,1.0,0.890625,0.726562,0.890625,0.953125,0.859375,0.96875,0.484375,0.820312,0.671875
3,jerpkipj,rdfs-g6-n128-a_dfs-h106,1.0,1.0,0.867188,1.0,0.992188,1.0,1.0,0.640625,0.859375,0.734375
4,jerpkipj,prdfs-g6-n128-a_dfs_percolation-h87605,1.0,1.0,0.796875,1.0,1.0,0.945312,0.960938,0.640625,0.882812,0.90625
5,jerpkipj,hallway-g6-n128-a_dfs-h60815,1.0,0.851562,0.695312,0.773438,0.890625,0.828125,0.921875,0.351562,0.617188,0.4375


In [9]:
# Persist the raw table as JSON Lines (one record per row).
# `index=True` is deliberately not passed: orient="records" never serializes the
# index, and newer pandas releases raise a ValueError when the two are combined.
RESULTS.to_json("eval_results.json", orient="records", lines=True)

In [10]:
# Display label for each dataset name; the letter prefix gives the labels a fixed sort order.
_DATASET_MAP: dict[str,str] = dict(
	hallway="a:forkless",
	rdfs="b:RDFS",
	prdfs="c:pRDFS",
)

# Combined "<model prefix>/<dataset label>" key. Both prefixes are the text before
# the first '-'; an unknown dataset name raises KeyError.
RESULTS['model/dataset'] = [
	f"{model_name.split('-')[0]}/{_DATASET_MAP[dataset_fname.split('-')[0]]}"
	for model_name, dataset_fname in zip(RESULTS['model'], RESULTS['dataset'])
]

In [11]:
# Pivot to one row per metric and one column per model/dataset pair,
# rendering every float as a one-decimal percentage.
transposed_df = (
	RESULTS
	.drop(columns=['model', 'dataset'])
	.set_index('model/dataset')
	.T
	.reset_index()
	# the trailing slash gives this column an empty sort key below, so it sorts first
	.rename(columns={'index': 'Metric/'})
	.applymap(lambda cell: f"{cell:.1%}" if isinstance(cell, float) else cell)
)
# order the columns by whatever follows the slash, i.e. by dataset label
transposed_df = transposed_df[
	sorted(transposed_df.columns, key=lambda col: col.split('/')[1])
]
transposed_df

model/dataset,Metric/,hallway/a:forkless,jerpkipj/a:forkless,hallway/b:RDFS,jerpkipj/b:RDFS,hallway/c:pRDFS,jerpkipj/c:pRDFS
0,path_start,100.0%,100.0%,100.0%,100.0%,100.0%,100.0%
1,origin_after_path_start,89.1%,85.2%,100.0%,100.0%,100.0%,100.0%
2,first_path_choice,72.7%,69.5%,70.3%,86.7%,62.5%,79.7%
3,path_end,89.1%,77.3%,100.0%,100.0%,100.0%,100.0%
4,final_before_path_end,95.3%,89.1%,91.4%,99.2%,87.5%,100.0%
5,rand_path_token,85.9%,82.8%,88.3%,100.0%,87.5%,94.5%
6,rand_path_token_non_endpoint,96.9%,92.2%,91.4%,100.0%,86.7%,96.1%
7,exactly correct rollouts,48.4%,35.2%,28.9%,64.1%,28.9%,64.1%
8,valid rollouts,82.0%,61.7%,65.6%,85.9%,76.6%,88.3%
9,rollouts with target reached,67.2%,43.8%,41.4%,73.4%,46.1%,90.6%


In [12]:
# print the formatted table as LaTeX source, leaving out the row index
print(transposed_df.to_latex(index=False))

\begin{tabular}{lllllll}
\toprule
                     Metric/ & hallway/a:forkless & jerpkipj/a:forkless & hallway/b:RDFS & jerpkipj/b:RDFS & hallway/c:pRDFS & jerpkipj/c:pRDFS \\
\midrule
                  path\_start &             100.0\% &              100.0\% &         100.0\% &          100.0\% &          100.0\% &           100.0\% \\
     origin\_after\_path\_start &              89.1\% &               85.2\% &         100.0\% &          100.0\% &          100.0\% &           100.0\% \\
           first\_path\_choice &              72.7\% &               69.5\% &          70.3\% &           86.7\% &           62.5\% &            79.7\% \\
                    path\_end &              89.1\% &               77.3\% &         100.0\% &          100.0\% &          100.0\% &           100.0\% \\
       final\_before\_path\_end &              95.3\% &               89.1\% &          91.4\% &           99.2\% &           87.5\% &           100.0\% \\
             rand\_path\_token &   


In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.

