In [1]:
# Generic
from pathlib import Path
import typing
import itertools

# Numerical Computing
import numpy as np
import torch
import pandas as pd
from jaxtyping import Bool, Int, Float
from muutils.nbutils.configure_notebook import configure_notebook

# Our Code
# dataset stuff
from maze_dataset import MazeDataset, MazeDatasetConfig, SolvedMaze, LatticeMaze, SPECIAL_TOKENS, LatticeMazeGenerators, CoordArray
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode

# model stuff
from maze_transformer.training.config import ZanjHookedTransformer

# mechinterp stuff
from maze_transformer.mechinterp.logit_attrib_task import LOGIT_ATTRIB_TASKS
from maze_transformer.evaluation.load_model_testdata import load_model_with_test_data
from maze_transformer.evaluation.eval_single_token_tasks import TaskPrompt, TaskEvalResult, get_task_prompts_targets, eval_model_across_tasks
from maze_transformer.mechinterp.logit_attrib_task import LOGIT_ATTRIB_TASKS
from maze_transformer.evaluation.eval_model import predict_maze_paths



In [2]:
# Setup (we won't be training any models)
# `configure_notebook` is given the seed and returns the device used by the rest of the notebook.
SEED: int = 42
DEVICE: torch.device = configure_notebook(seed=SEED, dark_mode=False)
print(f"{DEVICE = }")
# inference only: disable autograd globally.
# the trailing `;` suppresses the otherwise-displayed `set_grad_enabled` object repr
torch.set_grad_enabled(False);

DEVICE = device(type='cuda')


<torch.autograd.grad_mode.set_grad_enabled at 0x1eab433ee50>

In [3]:

# Checkpoints to evaluate (the commented-out entries are other checkpoints, currently disabled)
MODELS_PATHS: list[str] = [
	"../examples/model.hallway-jvq.final.zanj",
	"../examples/wandb.jerpkipj.zanj",
	# "../examples/model.hallway-insto.v17.zanj",
	# "../examples/model.hallway-insto.final.zanj",
	# "../examples/hallway-orig.zanj",
]

# All evaluation sets share the grid size and maze count; only the generator differs.
GRID_N: int = 7
N_MAZES: int = 128

DATASET_CFG_SOURCES: list[MazeDatasetConfig] = [
	MazeDatasetConfig(
		name=cfg_name,
		grid_n=GRID_N,
		n_mazes=N_MAZES,
		maze_ctor=maze_ctor,
		maze_ctor_kwargs=maze_ctor_kwargs,
	)
	for cfg_name, maze_ctor, maze_ctor_kwargs in (
		# gen_dfs with forks
		("rdfs", LatticeMazeGenerators.gen_dfs, dict(do_forks=True)),
		# gen_dfs_percolation with p=0.1
		("prdfs", LatticeMazeGenerators.gen_dfs_percolation, dict(p=0.1)),
		# gen_dfs without forks
		("hallway", LatticeMazeGenerators.gen_dfs, dict(do_forks=False)),
	)
]

# one dataset per config, in the same order as DATASET_CFG_SOURCES
DATASETS: list[MazeDataset] = list(map(MazeDataset.from_config, DATASET_CFG_SOURCES))


In [4]:



def _score_rollouts(
	predictions: list[CoordArray],
	mazes: list[SolvedMaze],
) -> dict[str, float]:
	"""Score predicted paths against the corresponding solved mazes.

	Args:
		predictions: one predicted coordinate array per maze
		mazes: the solved mazes the predictions were generated for (same order)

	Returns:
		dict with the fraction of rollouts that
		- exactly match `maze.solution` ("exactly correct rollouts"),
		- are accepted by `maze.is_valid_path` ("valid rollouts"),
		- end on `maze.end_pos` ("rollouts with target reached").
		An empty prediction counts as not reaching the target.
	"""
	assert len(mazes) == len(predictions)

	exact_correct: list[bool] = list()
	valid_path: list[bool] = list()
	target_correct: list[bool] = list()

	for p, m in zip(predictions, mazes):
		# elementwise comparison is only meaningful when the shapes agree
		exact_correct.append(
			bool(np.all(p == m.solution)) if p.shape == m.solution.shape else False
		)
		valid_path.append(m.is_valid_path(p))
		# an empty rollout has no final coordinate; previously `p[-1]` raised IndexError here
		target_correct.append(len(p) > 0 and bool(np.all(p[-1] == m.end_pos)))

	return {
		"exactly correct rollouts": np.mean(exact_correct),
		"valid rollouts": np.mean(valid_path),
		"rollouts with target reached": np.mean(target_correct),
	}


def _eval_rollouts(
	model: ZanjHookedTransformer,
	dataset: MazeDataset,
	tokenizer: MazeTokenizer,
	n_rollouts: int,
	max_new_tokens: int,
) -> dict[str, float]:
	"""Generate rollouts for the first `n_rollouts` mazes of `dataset` and score them."""
	predictions: list[list[tuple[int, int]]] = predict_maze_paths(
		tokens_batch=dataset.as_tokens(maze_tokenizer=tokenizer)[:n_rollouts],
		data_cfg=dataset.cfg,
		model=model,
		max_new_tokens=max_new_tokens,
	)
	predictions_np: list[CoordArray] = [np.array(path) for path in predictions]
	mazes: list[SolvedMaze] = dataset.mazes[:n_rollouts]
	return _score_rollouts(predictions_np, mazes)


def generate_eval_table(
	models_paths: list[str],
	datasets: list[MazeDataset],
	rollouts_acc: int|None = None,
	max_new_tokens: int = 15,
) -> pd.DataFrame:
	"""Evaluate every model on every dataset and collect the results in a table.

	For each (model, dataset) pair, computes the mean accuracy on each task in
	`LOGIT_ATTRIB_TASKS`, and optionally scores autoregressive path rollouts.

	Args:
		models_paths: paths of the saved models, loaded via `load_model_with_test_data`
		datasets: datasets every model is evaluated on
		rollouts_acc: if not None, the number of mazes (taken from the start of
			each dataset) to run rollouts on; if None, rollout metrics are skipped
		max_new_tokens: passed to `predict_maze_paths` for each rollout
			(default 15, the previously hard-coded value).
			NOTE(review): presumably caps the generated path length, so long
			solution paths may be cut short — confirm 15 suffices for the grid size.

	Returns:
		DataFrame with one row per (model, dataset) pair and columns `model`,
		`dataset`, one column per task name, plus the three rollout columns
		when `rollouts_acc` is given. All metric values are fractions in [0, 1].
	"""
	output: list[dict] = list()

	for model_path in models_paths:
		# the bundled test data is not used here; it used to be bound to `dataset`,
		# which the inner loop immediately shadowed
		model, _ = load_model_with_test_data(
			model_path=model_path,
			dataset_cfg_source=None,
		)
		# the tokenizer only depends on the model, so fetch it once per model
		tokenizer: MazeTokenizer = model.zanj_model_config.maze_tokenizer

		for dataset in datasets:
			print(f"evaluating {model_path} on {dataset.cfg.name}")

			task_prompt_targets: dict[str, TaskPrompt] = get_task_prompts_targets(
				dataset=dataset,
				maze_tokenizer=tokenizer,
				tasks=LOGIT_ATTRIB_TASKS,
			)

			print("\trunning evals")
			task_results: dict[str, TaskEvalResult] = eval_model_across_tasks(
				model=model,
				task_prompts=task_prompt_targets,
				do_cache=False,
			)

			output_item: dict = dict(
				model=model.config.name,
				dataset=dataset.cfg.to_fname(),
				**{
					task_name: task_result.predicted_correct.float().mean().item()
					for task_name, task_result in task_results.items()
				},
			)

			if rollouts_acc is not None:
				print("\trunning rollout evals")
				output_item.update(
					_eval_rollouts(
						model=model,
						dataset=dataset,
						tokenizer=tokenizer,
						n_rollouts=rollouts_acc,
						max_new_tokens=max_new_tokens,
					)
				)

			output.append(output_item)

	return pd.DataFrame(output)

In [5]:
# Evaluate every model on every dataset; rollouts use the first 128 mazes of each dataset
RESULTS: pd.DataFrame = generate_eval_table(
	datasets=DATASETS,
	models_paths=MODELS_PATHS,
	rollouts_acc=128,
)

loaded model with 1.2M params (num_params = 1238076) from
../examples/model.hallway-jvq.final.zanj
original model name: 'model.zanj_model_config.name = 'hallway_v3'', changing to 'hallway-jvq.final'
loaded dataset with 100 examples
dataset.cfg.summary() = {'name': 'hallway', 'fname': 'hallway-g7-n100-a_dfs-h52723', 'sdc_hash': 110302684828161283175731365142829285748232501725876255654215979074922758852723, 'seed': 42, 'seq_len_min': 1, 'seq_len_max': 256, 'applied_filters': [{'name': 'collect_generation_meta', 'args': [], 'kwargs': {}}], 'grid_n': 7, 'grid_shape': (7, 7), 'n_mazes': 100, 'maze_ctor_name': 'gen_dfs', 'maze_ctor_kwargs': {'do_forks': False}}
evaluating ../examples/model.hallway-jvq.final.zanj on rdfs
	running evals
	running rollout evals
evaluating ../examples/model.hallway-jvq.final.zanj on prdfs
	running evals
	running rollout evals
evaluating ../examples/model.hallway-jvq.final.zanj on hallway
	running evals
	running rollout evals
loaded model with 9.6M params (num_par

In [6]:
# Raw results: one row per (model, dataset) pair; every metric is a fraction in [0, 1]
RESULTS

Unnamed: 0,model,dataset,path_start,origin_after_path_start,first_path_choice,path_end,final_before_path_end,rand_path_token,rand_path_token_non_endpoint,exactly correct rollouts,valid rollouts,rollouts with target reached
0,hallway-jvq.final,rdfs-g7-n128-a_dfs-h16527,1.0,1.0,0.679688,1.0,0.9375,0.90625,0.929688,0.234375,0.6875,0.320312
1,hallway-jvq.final,prdfs-g7-n128-a_dfs_percolation-h45426,1.0,0.96875,0.632812,1.0,0.828125,0.851562,0.8125,0.195312,0.757812,0.289062
2,hallway-jvq.final,hallway-g7-n128-a_dfs-h67447,1.0,0.929688,0.71875,0.90625,1.0,0.90625,0.976562,0.507812,0.8125,0.625
3,jerpkipj,rdfs-g7-n128-a_dfs-h16527,1.0,1.0,0.648438,1.0,0.90625,0.867188,0.84375,0.210938,0.304688,0.539062
4,jerpkipj,prdfs-g7-n128-a_dfs_percolation-h45426,1.0,1.0,0.625,1.0,0.898438,0.796875,0.789062,0.203125,0.367188,0.585938
5,jerpkipj,hallway-g7-n128-a_dfs-h67447,1.0,0.921875,0.703125,0.828125,0.898438,0.875,0.953125,0.414062,0.6875,0.554688


In [7]:
# Persist the results as JSON Lines (one record per (model, dataset) row).
# `index=True` is not passed: with orient="records", pandas 1.x silently ignores it
# and pandas 2.x rejects it with a ValueError.
RESULTS.to_json("eval_results.json", orient="records", lines=True)

In [8]:
# Short "<model>/<dataset>" label: keep only the text before the first '-' of each name
RESULTS['model/dataset'] = (
	RESULTS['model'].apply(lambda name: name.partition('-')[0])
	+ '/'
	+ RESULTS['dataset'].apply(lambda name: name.partition('-')[0])
)

In [16]:
# Metrics as rows, one column per "<model>/<dataset>" pair
METRIC_COL: str = 'Metric/'
transposed_df = (
	RESULTS
	.drop(['model', 'dataset'], axis=1)
	.set_index('model/dataset')
	.T
	.reset_index()
	.rename(columns={'index': METRIC_COL})
)
# convert all float cells to percentage strings
# (per-column `Series.map` instead of `DataFrame.applymap`, which is deprecated since pandas 2.1)
transposed_df = transposed_df.apply(
	lambda col: col.map(lambda x: f"{x:.1%}" if isinstance(x, float) else x)
)
# metric-name column first, then the value columns grouped by dataset (the part after the slash);
# `sorted` is stable, so model order within each dataset is preserved
value_cols = sorted(
	(c for c in transposed_df.columns if c != METRIC_COL),
	key=lambda c: c.split('/')[1],
)
transposed_df = transposed_df[[METRIC_COL, *value_cols]]
transposed_df

model/dataset,Metric/,hallway/hallway,jerpkipj/hallway,hallway/prdfs,jerpkipj/prdfs,hallway/rdfs,jerpkipj/rdfs
0,path_start,100.0%,100.0%,100.0%,100.0%,100.0%,100.0%
1,origin_after_path_start,93.0%,92.2%,96.9%,100.0%,100.0%,100.0%
2,first_path_choice,71.9%,70.3%,63.3%,62.5%,68.0%,64.8%
3,path_end,90.6%,82.8%,100.0%,100.0%,100.0%,100.0%
4,final_before_path_end,100.0%,89.8%,82.8%,89.8%,93.8%,90.6%
5,rand_path_token,90.6%,87.5%,85.2%,79.7%,90.6%,86.7%
6,rand_path_token_non_endpoint,97.7%,95.3%,81.2%,78.9%,93.0%,84.4%
7,exactly correct rollouts,50.8%,41.4%,19.5%,20.3%,23.4%,21.1%
8,valid rollouts,81.2%,68.8%,75.8%,36.7%,68.8%,30.5%
9,rollouts with target reached,62.5%,55.5%,28.9%,58.6%,32.0%,53.9%


In [17]:
# Render the percentage table as a LaTeX tabular (index column omitted).
# NOTE(review): pandas suggests `DataFrame.style.to_latex`, but Styler does not LaTeX-escape
# cell text by default, so the `%` / `_` here would need `.format(escape="latex")` — confirm before switching.
latex_table: str = transposed_df.to_latex(index=False)
print(latex_table)

\begin{tabular}{lllllll}
\toprule
                     Metric/ & hallway/hallway & jerpkipj/hallway & hallway/prdfs & jerpkipj/prdfs & hallway/rdfs & jerpkipj/rdfs \\
\midrule
                  path\_start &          100.0\% &           100.0\% &        100.0\% &         100.0\% &       100.0\% &        100.0\% \\
     origin\_after\_path\_start &           93.0\% &            92.2\% &         96.9\% &         100.0\% &       100.0\% &        100.0\% \\
           first\_path\_choice &           71.9\% &            70.3\% &         63.3\% &          62.5\% &        68.0\% &         64.8\% \\
                    path\_end &           90.6\% &            82.8\% &        100.0\% &         100.0\% &       100.0\% &        100.0\% \\
       final\_before\_path\_end &          100.0\% &            89.8\% &         82.8\% &          89.8\% &        93.8\% &         90.6\% \\
             rand\_path\_token &           90.6\% &            87.5\% &         85.2\% &          79.7\% &        90.6\


In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.

