### Preamble

In [None]:
import datetime, json, math, os, platform, random, sys, time

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
from datasets import load_dataset
from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import datasets as datasets_lib
import transformers as transformers_lib

# Optional Hugging Face access token. When set, it is passed as `token=` to the
# from_pretrained calls below and exported to the environment by
# maybe_configure_hf_auth().
HF_TOKEN: str | None = None
# Seed applied via set_seed() at start-up and again before each model run.
GLOBAL_SEED = 42
# Vibe-adjustment strengths evaluated by every experiment; 0.0 is the
# unadjusted baseline (the adjustment term is scaled by lambda).
LAMBDA_VALUES = [-5.0, 0.0, 1.0, 5.0]

def set_seed(seed_value: int) -> None:
	"""Seed every random number generator this notebook relies on.

	Covers Python's ``random`` module, NumPy's global generator, PyTorch's CPU
	generator and, when CUDA is available, the generators of all CUDA devices.
	``PYTHONHASHSEED`` is exported as well; the interpreter reads it at start-up,
	so it mainly influences child processes launched afterwards.
	"""
	os.environ["PYTHONHASHSEED"] = str(seed_value)
	for seeder in (random.seed, np.random.seed, torch.manual_seed):
		seeder(seed_value)
	if not torch.cuda.is_available():
		return
	torch.cuda.manual_seed_all(seed_value)

def maybe_configure_hf_auth() -> None:
	"""Export ``HF_TOKEN`` to the environment variables Hugging Face libraries read.

	Does nothing when ``HF_TOKEN`` is empty or ``None``. Otherwise the token is
	stored in both ``HF_TOKEN`` and the legacy ``HUGGING_FACE_HUB_TOKEN``.

	The accelerated ``hf_transfer`` download backend is only requested when the
	``hf_transfer`` package is importable: huggingface_hub raises at download
	time when ``HF_HUB_ENABLE_HF_TRANSFER=1`` is set but the package is missing.
	"""
	if not HF_TOKEN:
		return
	os.environ["HF_TOKEN"] = HF_TOKEN
	os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
	import importlib.util
	if importlib.util.find_spec("hf_transfer") is not None:
		# NOTE(review): huggingface_hub may read this flag when it is first
		# imported (datasets/transformers import it earlier in this notebook),
		# in which case setting it here does not affect the current process — confirm.
		os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Ask PyTorch for deterministic kernels; with warn_only=True an op lacking a
# deterministic implementation only emits a warning instead of raising.
# NOTE(review): deterministic cuBLAS on CUDA also expects CUBLAS_WORKSPACE_CONFIG
# to be set before CUDA initialisation — with warn_only=True a missing setting
# only warns; confirm whether strict reproducibility is required.
torch.use_deterministic_algorithms(True, warn_only=True)
# Force deterministic cuDNN algorithms and disable the cuDNN auto-tuner, whose
# kernel choice can differ between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Seed all RNGs once at start-up and export the optional HF token.
set_seed(GLOBAL_SEED)
maybe_configure_hf_auth()

### Experiment 1 (BoolQ)

In [None]:
# Cap on BoolQ validation rows to evaluate; 0 (or any non-positive value)
# evaluates the whole split.
MAX_ROWS_EXPERIMENT_1 = 0

# Instruction-tuned checkpoints evaluated on BoolQ, in this order.
BOOLQ_MODELS = [
	"Qwen/Qwen2.5-0.5B-Instruct",
	"google/gemma-2-2b-it",
	"meta-llama/Llama-3.1-8B-Instruct"
]

def build_vibe_embeddings(model: AutoModelForCausalLM, device: torch.device) -> torch.Tensor:
	"""Return the model's input-embedding rows rescaled to unit L2 norm.

	The result is a detached float32 tensor on ``device`` with the same shape
	as the embedding weight. Row norms are clamped to at least 1e-12, so an
	all-zero row stays zero instead of producing NaNs.
	"""
	rows = model.get_input_embeddings().weight.detach()
	rows = rows.to(device=device, dtype=torch.float32)
	row_lengths = torch.clamp(rows.norm(dim=-1, keepdim=True), min=1e-12)
	return rows / row_lengths

def compute_label_base_and_alignment(prefix_token_ids: List[int], label_token_ids: List[int], model: AutoModelForCausalLM, vibe_embeddings: torch.Tensor, device: torch.device) -> tuple[float, float]:
	"""Score a (possibly multi-token) label continuation one token at a time.

	For each label token the model is run on the full current prefix (no KV
	cache) and two quantities are accumulated:

	* the log-probability of that token at the last position, and
	* the dot product between the token's row of ``vibe_embeddings`` and the
	  probability-weighted sum of all rows (the "vibe direction").

	The token is then appended to the prefix before the next one is scored.

	Returns:
		``(sum of log-probabilities, sum of alignments)`` as Python floats.
	"""
	running_prefix = [*prefix_token_ids]
	log_probability_sum = 0.0
	alignment_sum = 0.0
	for label_token in label_token_ids:
		with torch.no_grad():
			batch = torch.tensor([running_prefix], dtype=torch.long, device=device)
			last_logits = model(input_ids=batch).logits[:, -1, :].to(torch.float32)
			step_log_probs = torch.nn.functional.log_softmax(last_logits, dim=-1)
			# (1, vocab) @ (vocab, hidden) -> (1, hidden), then drop the batch axis.
			vibe_direction = (torch.exp(step_log_probs) @ vibe_embeddings).squeeze(0)
			alignment = torch.dot(vibe_embeddings[label_token], vibe_direction)
			token_log_prob = step_log_probs[0, label_token]
		log_probability_sum += float(token_log_prob.detach().cpu())
		alignment_sum += float(alignment.detach().cpu())
		running_prefix.append(int(label_token))
	return log_probability_sum, alignment_sum

def compute_classification_metrics(true_label_ids: List[int], predictions_for_lambda: Dict[str, List[int]]) -> Dict[str, Dict[str, float]]:
	"""Compute accuracy and binary macro-F1 for every prediction list.

	Args:
		true_label_ids: gold labels, each 0 or 1.
		predictions_for_lambda: predicted labels keyed by lambda string, aligned
			with ``true_label_ids``.

	Returns:
		``{lambda_key: {"accuracy": float, "macro_f1": float}}``. Accuracy is 0.0
		for an empty gold list. Macro-F1 averages the F1 of classes 0 and 1,
		where a class with no true positives, false positives and false
		negatives contributes 0.0.
	"""
	gold = torch.tensor(true_label_ids, dtype=torch.long)
	example_count = len(true_label_ids)

	def _f1_for_class(predicted: torch.Tensor, positive_class: int) -> float:
		gold_is_positive = gold == positive_class
		predicted_is_positive = predicted == positive_class
		tp = int((gold_is_positive & predicted_is_positive).sum().item())
		fp = int((~gold_is_positive & predicted_is_positive).sum().item())
		fn = int((gold_is_positive & ~predicted_is_positive).sum().item())
		if tp == 0 and fp == 0 and fn == 0:
			return 0.0
		precision = tp / (tp + fp) if tp + fp > 0 else 0.0
		recall = tp / (tp + fn) if tp + fn > 0 else 0.0
		if precision + recall == 0.0:
			return 0.0
		return 2.0 * precision * recall / (precision + recall)

	def _metrics_for(predicted_ids: List[int]) -> Dict[str, float]:
		predicted = torch.tensor(predicted_ids, dtype=torch.long)
		correct = int((predicted == gold).sum().item())
		accuracy = correct / example_count if example_count > 0 else 0.0
		macro_f1 = (_f1_for_class(predicted, 0) + _f1_for_class(predicted, 1)) / 2.0
		return {"accuracy": float(accuracy), "macro_f1": float(macro_f1)}

	return {lambda_key: _metrics_for(predicted_ids) for lambda_key, predicted_ids in predictions_for_lambda.items()}

def _build_boolq_chat_text(tokenizer: AutoTokenizer, question_text: str, passage_text: str) -> tuple[List[Dict[str, str]], str]:
	"""Render the BoolQ prompt through the tokenizer's chat template.

	A system + user conversation is tried first. If the template rejects the
	system role (its error message contains "System role not supported"), the
	instruction is folded into the user turn instead; any other template error
	is re-raised.

	Returns:
		``(messages, rendered_text)``: the messages actually used and the
		templated prompt string, ending with the generation prompt.
	"""
	system_instruction = "You are a question answering assistant. Answer with a single word: \"yes\" or \"no\"."
	prompt_text = "Passage:\n" + passage_text + "\n\nQuestion:\n" + question_text + "\n\nAnswer the question with a single word: yes or no."
	messages = [
		{"role": "system", "content": system_instruction},
		{"role": "user", "content": prompt_text}
	]
	try:
		return messages, tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
	except Exception as template_error:
		if "System role not supported" not in str(template_error):
			raise
	messages = [
		{"role": "user", "content": system_instruction + "\n\n" + prompt_text}
	]
	return messages, tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def run_boolq_for_model(model_id: str, validation_split, lambda_values: List[float]) -> Dict[str, object]:
	"""Evaluate one model on BoolQ using vibe-adjusted label scoring.

	Each example is rendered with the model's chat template and scored against
	the label strings " yes" and " no". For a given lambda, a label's score is
	``log p(label | prompt) + lambda * alignment(label)``, both summed over the
	label's tokens by ``compute_label_base_and_alignment``. The higher score
	wins, and ties go to "yes".

	Args:
		model_id: Hugging Face model identifier (loaded in float16).
		validation_split: indexable dataset whose rows have "question",
			"passage" and "answer" fields.
		lambda_values: adjustment strengths; 0.0 is plain log-probability scoring.

	Returns:
		``{"setup": ..., "results": ...}``. ``setup`` holds run metadata.
		``results`` holds per-lambda accuracy/macro-F1 and per-example records.
	"""
	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	set_seed(GLOBAL_SEED)
	tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
	model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, token=HF_TOKEN)
	model.to(device)
	model.eval()
	vibe_embeddings = build_vibe_embeddings(model, device)
	label_texts_by_name = {"yes": " yes", "no": " no"}
	label_token_ids: Dict[str, List[int]] = {
		label_name: [int(token_id) for token_id in tokenizer(label_text, add_special_tokens=False)["input_ids"]]
		for label_name, label_text in label_texts_by_name.items()
	}
	predictions_by_lambda: Dict[str, List[int]] = {str(value): [] for value in lambda_values}
	true_labels: List[int] = []
	example_records: List[Dict[str, object]] = []
	start_timestamp = time.time()
	progress_columns = [
		TextColumn("{task.description}"),
		BarColumn(),
		MofNCompleteColumn(),
		TimeElapsedColumn(),
		TimeRemainingColumn()
	]
	with Progress(*progress_columns) as progress:
		task_identifier = progress.add_task(f"Evaluating {model_id}", total=len(validation_split))
		for row_index in range(len(validation_split)):
			row = validation_split[int(row_index)]
			question_text = str(row["question"])
			passage_text = str(row["passage"])
			answer_boolean = bool(row["answer"])
			gold_label_name = "yes" if answer_boolean else "no"
			true_labels.append(1 if answer_boolean else 0)
			messages, base_text = _build_boolq_chat_text(tokenizer, question_text, passage_text)
			# The rendered chat text already contains whatever special tokens the
			# template emits (typically including BOS). Tokenizing it with the
			# default add_special_tokens=True would prepend a second BOS, so it is
			# disabled here, matching the NumerSense experiment's tokenization.
			prefix_token_ids = [int(token_id) for token_id in tokenizer(base_text, add_special_tokens=False)["input_ids"]]
			base_scores: Dict[str, float] = {}
			alignments: Dict[str, float] = {}
			for label_name, label_ids_for_name in label_token_ids.items():
				base_log_probability, alignment_total = compute_label_base_and_alignment(prefix_token_ids, label_ids_for_name, model, vibe_embeddings, device)
				base_scores[label_name] = float(base_log_probability)
				alignments[label_name] = float(alignment_total)
			predictions_for_example: Dict[str, Dict[str, object]] = {}
			for lambda_value in lambda_values:
				lambda_key = str(lambda_value)
				label_scores = {
					label_name: float(base_scores[label_name] + lambda_value * alignments[label_name])
					for label_name in label_texts_by_name
				}
				# Equal scores are resolved in favour of "yes".
				predicted_label_name = "yes" if label_scores["yes"] >= label_scores["no"] else "no"
				predictions_by_lambda[lambda_key].append(1 if predicted_label_name == "yes" else 0)
				predictions_for_example[lambda_key] = {
					"predicted_label": predicted_label_name,
					"scores": {"yes": float(label_scores["yes"]), "no": float(label_scores["no"])}
				}
			example_records.append(
				{
					"index": int(row_index),
					"question": question_text,
					"passage": passage_text,
					"gold_label": gold_label_name,
					"input_messages": messages,
					"input_text": base_text,
					"label_token_ids": {key: [int(value) for value in values] for key, values in label_token_ids.items()},
					"predictions": predictions_for_example
				}
			)
			progress.advance(task_identifier, 1)
	end_timestamp = time.time()
	metrics = compute_classification_metrics(true_labels, predictions_by_lambda)
	setup_data = {
		"model_name": model_id,
		"dataset_name": "google/boolq",
		"split": "validation",
		"lambda_values": [float(value) for value in lambda_values],
		"max_rows": int(MAX_ROWS_EXPERIMENT_1),
		"num_examples_evaluated": int(len(true_labels)),
		"random_seed": GLOBAL_SEED,
		"device": str(device),
		"vibe_method": "Label scoring with sequential vibe-adjusted token log-probabilities for labels \"yes\" and \"no\".",
		"label_texts_by_name": label_texts_by_name,
		"library_versions": {
			"python": platform.python_version(),
			"torch": str(torch.__version__),
			"transformers": str(transformers_lib.__version__),
			"datasets": str(datasets_lib.__version__)
		},
		"runtime_seconds": float(end_timestamp - start_timestamp),
		"remarks": f"{model_id} evaluated on BoolQ using vibe decoding over label strings."
	}
	results_data = {
		"model_name": model_id,
		"dataset_name": "google/boolq",
		"split": "validation",
		"lambda_values": [float(value) for value in lambda_values],
		"metrics": metrics,
		"examples": example_records
	}
	return {"setup": setup_data, "results": results_data}

# Load BoolQ and, when MAX_ROWS_EXPERIMENT_1 is positive, keep only the first
# MAX_ROWS_EXPERIMENT_1 validation rows.
dataset_boolq = load_dataset("google/boolq")
validation_dataset = dataset_boolq["validation"]
if MAX_ROWS_EXPERIMENT_1 and MAX_ROWS_EXPERIMENT_1 > 0:
	row_count = min(MAX_ROWS_EXPERIMENT_1, len(validation_dataset))
	validation_dataset = validation_dataset.select(range(row_count))

# Top-level record for Experiment 1; per-model outputs are added below.
experiment_1_output: Dict[str, object] = {
	"experiment_id": 1,
	"task": "BoolQ yes or no classification with vibe decoding over label strings.",
	"lambda_values": [float(value) for value in LAMBDA_VALUES],
	"random_seed": GLOBAL_SEED,
	"models": {}
}

# Every model sees the same (possibly truncated) split; results are keyed by model id.
for model_identifier in BOOLQ_MODELS:
	experiment_1_output["models"][model_identifier] = run_boolq_for_model(model_identifier, validation_dataset, LAMBDA_VALUES)

# Write the results to experiment-1.json in the current working directory.
with open("experiment-1.json", "w", encoding="utf-8") as file_handle:
	json.dump(experiment_1_output, file_handle, ensure_ascii=False, indent=2)

### Experiment 2 (NumerSense)

In [None]:
# Cap on NumerSense rows to evaluate; 0 (or any non-positive value) keeps all rows.
MAX_ROWS_EXPERIMENT_2 = 0

# Number words scored as candidate answers. Each must encode to a single token,
# either with or without a leading space (see build_candidate_token_ids).
CANDIDATE_WORDS = [
	"zero",
	"one",
	"two",
	"three",
	"four",
	"five",
	"six",
	"seven",
	"eight",
	"nine",
	"ten"
]

# Hub dataset loaded first; if loading it raises, the raw TSV below is read instead.
NUMERSENSE_DATASET_ID = "INK-USC/numer_sense"
NUMERSENSE_FALLBACK_TSV_URL = "https://raw.githubusercontent.com/INK-USC/NumerSense/main/data/validation.masked.tsv"

# Chat prompt pieces; {sentence} receives the sentence containing "<mask>".
SYSTEM_PROMPT_NUMERSENSE = "You are a helpful assistant."
USER_TEMPLATE_NUMERSENSE = (
	"You are given a sentence where a number between 0 and 10 has been replaced by the token <mask>.\n"
	"Sentence: {sentence}\n\n"
	"Which single English word from [zero, one, two, three, four, five, six, seven, eight, nine, ten] best fills in for <mask>? Answer with just the word."
)

@dataclass
class ExampleResult:
	"""Outcome of scoring one NumerSense row at one lambda value."""
	# Row index as assigned by load_numer_sense_data.
	index: int
	# Full chat-templated prompt text fed to the model.
	x: str
	# Predicted candidate word (highest adjusted score).
	y: str
	# Gold answer word (stripped and lower-cased).
	target: str
	# Lambda used for the vibe adjustment.
	lmb: float
	# 1 when y == target, else 0.
	hit1: int
	# 1-based rank of target among the candidates; None when target is not a candidate.
	rank: int | None

# Models for Experiment 2. With use_system_prompt=False the conversation has
# only a user turn (see apply_chat_template_numer_sense). It is off for Gemma,
# presumably because its chat template rejects a system role (cf. the
# "System role not supported" fallback in run_boolq_for_model) — confirm.
NUMERSENSE_MODELS = [
	{"model_id": "Qwen/Qwen2.5-0.5B-Instruct", "use_system_prompt": True},
	{"model_id": "google/gemma-2-2b-it", "use_system_prompt": False},
	{"model_id": "meta-llama/Llama-3.1-8B-Instruct", "use_system_prompt": True}
]

def prepare_tokenizer_and_model_numer_sense(model_id: str) -> tuple[AutoTokenizer, AutoModelForCausalLM, torch.device]:
	"""Load the tokenizer and a bfloat16 model for Experiment 2.

	The model is placed with ``device_map="auto"`` and switched to eval mode.
	``token`` is forwarded to both loaders only when ``HF_TOKEN`` is set.

	Returns:
		``(tokenizer, model, model.device)``.
	"""
	auth_kwargs: Dict[str, object] = {"token": HF_TOKEN} if HF_TOKEN else {}
	tokenizer = AutoTokenizer.from_pretrained(model_id, **auth_kwargs)
	model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto", **auth_kwargs)
	model.eval()
	return tokenizer, model, model.device

def apply_chat_template_numer_sense(tokenizer: AutoTokenizer, sentence: str, use_system_prompt: bool) -> str:
	"""Render the NumerSense question for ``sentence`` as chat-template text.

	The conversation starts with SYSTEM_PROMPT_NUMERSENSE when
	``use_system_prompt`` is true; otherwise it holds only the user turn. The
	returned string ends with the assistant generation prompt (not tokenized).
	"""
	user_turn = {"role": "user", "content": USER_TEMPLATE_NUMERSENSE.format(sentence=sentence)}
	leading_turns = [{"role": "system", "content": SYSTEM_PROMPT_NUMERSENSE}] if use_system_prompt else []
	conversation = [*leading_turns, user_turn]
	return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)

def get_lm_head_and_norms(model: AutoModelForCausalLM) -> tuple[torch.Tensor, torch.Tensor]:
	"""Return the output-embedding (LM head) weight and its per-row inverse L2 norms.

	Returns:
		``(weight_matrix, inverse_norms)``:

		* ``weight_matrix`` is the detached head weight of shape
		  (vocab, hidden), in its original dtype.
		* ``inverse_norms[i]`` is ``1 / ||weight_matrix[i]||``, or 0 for an
		  all-zero row, kept in float32.

	Raises:
		ValueError: if the head weight is not rank 2.
	"""
	lm_head = model.get_output_embeddings()
	weight_matrix = lm_head.weight.detach()
	if weight_matrix.dim() != 2:
		raise ValueError("Language model head weight must be rank 2.")
	with torch.no_grad():
		norms = torch.linalg.norm(weight_matrix.float(), dim=1)
		# All-zero rows get an inverse norm of 0 rather than inf.
		inverse_norms = torch.where(norms > 0, norms.reciprocal(), torch.zeros_like(norms))
	# Left in float32: its consumer in this file (first_step_scores) upcasts to
	# float32 anyway, so rounding to the weight dtype (e.g. bfloat16) first would
	# only throw away precision.
	return weight_matrix, inverse_norms

def build_candidate_token_ids(tokenizer: AutoTokenizer) -> Dict[str, int]:
	"""Map every word in CANDIDATE_WORDS to a single vocabulary id.

	The space-prefixed spelling (e.g. " two") is preferred. The bare spelling
	("two") is used only when the spaced form is not a single token. Special
	tokens are never added.

	Raises:
		ValueError: if neither spelling of a word encodes to exactly one token.
	"""
	def _single_token_id(word: str) -> int:
		for surface_form in (" " + word, word):
			encoded_ids = tokenizer(surface_form, add_special_tokens=False)["input_ids"]
			if len(encoded_ids) == 1:
				return encoded_ids[0]
		raise ValueError(f"Cannot map word {word!r} to a single token.")

	return {word: _single_token_id(word) for word in CANDIDATE_WORDS}

def first_step_scores(base_logits: torch.Tensor, weight_matrix: torch.Tensor, inverse_norms: torch.Tensor, lambda_value: float) -> torch.Tensor:
	"""Apply the vibe adjustment to one next-token logit vector.

	Let ``p = softmax(base_logits)`` and let ``e_i = W_i * inverse_norms[i]``
	be the unit-normalised rows of ``W = weight_matrix``. Token ``i`` then
	scores ``base_logits[i] + lambda_value * <e_i, sum_j p_j e_j>``.

	Args:
		base_logits: shape (vocab,).
		weight_matrix: shape (vocab, hidden).
		inverse_norms: shape (vocab,).
		lambda_value: adjustment strength; 0.0 returns the base logits unchanged.

	Returns:
		float32 tensor of shape (vocab,).

	Raises:
		ValueError: on rank or size mismatches between the inputs.
	"""
	if base_logits.dim() != 1:
		raise ValueError("Base logits for NumerSense must be one dimensional.")
	if weight_matrix.dim() != 2:
		raise ValueError("Language model head weight must be rank 2.")
	if base_logits.shape[0] != weight_matrix.shape[0]:
		raise ValueError("Base logits and language model head weight dimension mismatch.")
	if inverse_norms.shape[0] != weight_matrix.shape[0]:
		raise ValueError("Inverse norms shape does not match vocabulary size.")
	logits_float = base_logits.float()
	if lambda_value == 0.0:
		# The adjustment is scaled by zero; skip the two vocab x hidden matmuls.
		return logits_float
	weights_float = weight_matrix.float()  # upcast once, reused by both matmuls
	inverse_float = inverse_norms.float()
	probabilities = torch.softmax(logits_float, dim=-1)
	direction = (probabilities * inverse_float) @ weights_float
	alignment = (weights_float @ direction) * inverse_float
	# Stay in float32: adding in bfloat16/float16 would round the adjusted scores
	# and can create spurious ties between candidate tokens.
	return logits_float + lambda_value * alignment

def load_numer_sense_data(max_rows: int) -> List[Dict[str, object]]:
	"""Load NumerSense rows as ``{"index", "sentence", "target"}`` dicts.

	The "train" split of NUMERSENSE_DATASET_ID is tried first. If that raises
	for any reason, the tab-separated file at NUMERSENSE_FALLBACK_TSV_URL is
	read through the "csv" builder instead. When ``max_rows`` is positive, only
	the first ``max_rows`` rows are kept.

	Column discovery works in two stages. If the first row has string values
	under one of the known sentence keys and one of the known target keys,
	those columns are used. Otherwise the columns are auto-detected from their
	contents (see autodetect_columns).

	Targets are returned stripped and lower-cased; sentences are returned as-is.

	Raises:
		ValueError: if auto-detection cannot find both a sentence and a target column.
	"""
	# NOTE(review): trust_remote_code=True allows code shipped with the dataset
	# repository to run. The broad except also hides genuine errors behind the
	# fallback, and if the fallback TSV has no header row, the csv builder would
	# treat its first data line as column names — confirm the file format.
	try:
		dataset_split = load_dataset(NUMERSENSE_DATASET_ID, split="train", trust_remote_code=True)
	except Exception:
		dataset_split = load_dataset("csv", data_files={"train": NUMERSENSE_FALLBACK_TSV_URL}, sep="\t")["train"]
	if max_rows and max_rows > 0:
		row_count = min(max_rows, len(dataset_split))
		dataset_split = dataset_split.select(range(row_count))
	rows: List[Dict[str, object]] = []
	candidate_set = set(CANDIDATE_WORDS)
	column_names = list(dataset_split.column_names)
	def try_named_columns(item: Dict[str, object]) -> tuple[str | None, str | None]:
		# Return the first string value under a known sentence key and under a
		# known target key (None for whichever is not found).
		sentence_value = None
		for key in ["sentence", "probe", "text", "input"]:
			if key in item and isinstance(item[key], str):
				sentence_value = item[key]
				break
		target_value = None
		for key in ["target", "answer", "label", "ground_truth", "gold"]:
			if key in item and isinstance(item[key], str):
				target_value = item[key]
				break
		return sentence_value, target_value
	def autodetect_columns() -> tuple[str, str]:
		# Among all-string columns, pick as the sentence column the one with the
		# most values containing "<mask>". Pick as the target column the one with
		# the most values equal (after strip/lower) to a candidate word. A column
		# is only chosen if that count is positive.
		sentence_column_name = None
		target_column_name = None
		mask_counts: Dict[str, int] = {}
		candidate_counts: Dict[str, int] = {}
		for column_name in column_names:
			column_values = dataset_split[column_name]
			if all(isinstance(value, str) for value in column_values):
				mask_counts[column_name] = sum(1 for value in column_values if "<mask>" in value)
				candidate_counts[column_name] = sum(1 for value in column_values if value.strip().lower() in candidate_set)
		if mask_counts:
			max_mask_column_name = max(mask_counts, key=lambda name: mask_counts[name])
			if mask_counts[max_mask_column_name] > 0:
				sentence_column_name = max_mask_column_name
		if candidate_counts:
			max_candidate_column_name = max(candidate_counts, key=lambda name: candidate_counts[name])
			if candidate_counts[max_candidate_column_name] > 0:
				target_column_name = max_candidate_column_name
		if not sentence_column_name or not target_column_name:
			raise ValueError("Could not locate sentence or target column for NumerSense.")
		return sentence_column_name, target_column_name
	sentence_column_name: str | None = None
	target_column_name: str | None = None
	first_item = dataset_split[0] if len(dataset_split) > 0 else {}
	probe_sentence, probe_target = try_named_columns(first_item) if first_item else (None, None)
	if probe_sentence and probe_target:
		# Map the probed values back to column names: the first column whose
		# value in the first row equals the probed value.
		sentence_column_name = next(key for key in column_names if key in first_item and first_item[key] == probe_sentence)
		target_column_name = next(key for key in column_names if key in first_item and first_item[key] == probe_target)
	else:
		sentence_column_name, target_column_name = autodetect_columns()
	for index in range(len(dataset_split)):
		item = dataset_split[index]
		sentence_text = str(item[sentence_column_name])
		target_text = str(item[target_column_name]).strip().lower()
		rows.append({"index": index, "sentence": sentence_text, "target": target_text})
	return rows

def evaluate_lambda_for_numer_sense(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, device: torch.device, data_rows: List[Dict[str, object]], lambda_value: float, weight_matrix: torch.Tensor, inverse_norms: torch.Tensor, candidate_ids: Dict[str, int], use_system_prompt: bool) -> tuple[List[ExampleResult], float, float]:
	"""Score every NumerSense row at one lambda and compute hit@1 and MRR.

	For each row, the chat prompt is run through the model once (without adding
	special tokens). The last-position logits are vibe-adjusted by
	``first_step_scores``, and the candidate words are ranked by their adjusted
	logit, highest first. Ties keep the ``candidate_ids`` order.

	Returns:
		``(per_row_results, hit_at_1, mean_reciprocal_rank)``.

		* hit@1 is averaged over all rows (0.0 for no rows).
		* MRR is averaged only over rows whose target is one of the candidates
		  (0.0 when there are none).

	Raises:
		ValueError: if the logits are not rank 3 or their vocabulary size
			differs from ``weight_matrix``.
	"""
	results: List[ExampleResult] = []
	hit_count = 0
	reciprocal_ranks: List[float] = []
	for row in tqdm(data_rows, desc=f"λ={lambda_value:+g}"):
		row_index = int(row["index"])
		sentence = str(row["sentence"])
		target = str(row["target"])
		prompt_text = apply_chat_template_numer_sense(tokenizer, sentence, use_system_prompt)
		encoded = tokenizer(prompt_text, add_special_tokens=False, return_tensors="pt")
		model_inputs = {"input_ids": encoded["input_ids"].to(device), "attention_mask": encoded.get("attention_mask")}
		if model_inputs["attention_mask"] is not None:
			model_inputs["attention_mask"] = model_inputs["attention_mask"].to(device)
		with torch.no_grad():
			logits = model(**model_inputs).logits
		if logits.dim() != 3:
			raise ValueError("NumerSense logits must be three dimensional.")
		final_position_logits = logits[0, -1, :]
		if final_position_logits.shape[0] != weight_matrix.shape[0]:
			raise ValueError("NumerSense base logits vocabulary mismatch.")
		adjusted_logits = first_step_scores(final_position_logits, weight_matrix, inverse_norms, lambda_value)
		# Stable descending sort: equal scores keep candidate_ids order.
		ranked = sorted(
			((word, float(adjusted_logits[token_id].item())) for word, token_id in candidate_ids.items()),
			key=lambda pair: pair[1],
			reverse=True
		)
		ranked_words = [word for word, _ in ranked]
		predicted_word = ranked_words[0]
		hit_value = int(predicted_word == target)
		hit_count += hit_value
		rank_value: int | None = None
		if target in candidate_ids:
			rank_value = ranked_words.index(target) + 1
			reciprocal_ranks.append(1.0 / rank_value)
		results.append(
			ExampleResult(
				index=row_index,
				x=prompt_text,
				y=predicted_word,
				target=target,
				lmb=lambda_value,
				hit1=hit_value,
				rank=rank_value
			)
		)
	hit_at_1 = hit_count / max(1, len(data_rows))
	mean_reciprocal_rank = sum(reciprocal_ranks) / len(reciprocal_ranks) if reciprocal_ranks else 0.0
	return results, hit_at_1, mean_reciprocal_rank

def run_numer_sense_for_model(model_id: str, use_system_prompt: bool, lambda_values: List[float], max_rows: int) -> Dict[str, object]:
	"""Evaluate one model on NumerSense at every lambda value.

	The LM-head weight and inverse norms are moved to the model's device once
	and reused for all lambdas. The data is loaded once per call.

	Returns:
		``{"metrics": ..., "records": ...}``:

		* ``metrics`` is ``{str(lambda): {"hit@1": str, "MRR": str}}``, with
		  values formatted to 8 decimals.
		* ``records`` holds one dict per (lambda, row).
	"""
	set_seed(GLOBAL_SEED)
	tokenizer, model, device = prepare_tokenizer_and_model_numer_sense(model_id)
	head_weight, head_inverse_norms = get_lm_head_and_norms(model)
	head_weight = head_weight.to(device)
	head_inverse_norms = head_inverse_norms.to(device)
	candidate_ids = build_candidate_token_ids(tokenizer)
	data_rows = load_numer_sense_data(max_rows)
	records: List[Dict[str, object]] = []
	metrics: Dict[str, Dict[str, str]] = {}
	for lambda_value in lambda_values:
		per_row, hit_at_1, mrr = evaluate_lambda_for_numer_sense(model, tokenizer, device, data_rows, lambda_value, head_weight, head_inverse_norms, candidate_ids, use_system_prompt)
		hit_text = f"{hit_at_1:.8f}"
		mrr_text = f"{mrr:.8f}"
		print(f"Model={model_id} λ={lambda_value:+g} hit@1={hit_text} MRR={mrr_text}", flush=True)
		metrics[str(lambda_value)] = {"hit@1": hit_text, "MRR": mrr_text}
		records.extend(
			{
				"index": outcome.index,
				"x": outcome.x,
				"y": outcome.y,
				"target": outcome.target,
				"lambda": outcome.lmb,
				"hit@1": outcome.hit1,
				"rank": outcome.rank
			}
			for outcome in per_row
		)
	return {"metrics": metrics, "records": records}

# Top-level record for Experiment 2; per-model outputs are added below.
# "max_rows" mirrors MAX_ROWS_EXPERIMENT_2 (0 = all rows).
experiment_2_output: Dict[str, object] = {
	"experiment_id": 2,
	"task": "NumerSense masked number prediction with vibe decoding.",
	"lambda_values": [float(value) for value in LAMBDA_VALUES],
	"random_seed": GLOBAL_SEED,
	"dataset": {
		"id": NUMERSENSE_DATASET_ID,
		"fallback_tsv_url": NUMERSENSE_FALLBACK_TSV_URL,
		"max_rows": int(MAX_ROWS_EXPERIMENT_2)
	},
	"models": {}
}

# Evaluate each configured model with its own system-prompt setting.
for model_config in NUMERSENSE_MODELS:
	model_identifier = model_config["model_id"]
	use_system = bool(model_config["use_system_prompt"])
	experiment_2_output["models"][model_identifier] = run_numer_sense_for_model(model_identifier, use_system, LAMBDA_VALUES, MAX_ROWS_EXPERIMENT_2)

# Write the results to experiment-2.json in the current working directory.
with open("experiment-2.json", "w", encoding="utf-8") as file_handle:
	json.dump(experiment_2_output, file_handle, ensure_ascii=False, indent=2)

### Experiment 3 (PG19)

In [None]:
# Only the first 10 books of the PG19 split are evaluated (0 keeps all books).
MAX_ROWS_EXPERIMENT_3 = 10
# Maximum tokens per evaluation chunk; further capped by the model's maximum length.
PG19_BLOCK_LENGTH = 1024
# Chunks per forward pass; overridable through the BATCH_SIZE environment variable.
PG19_BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1"))
PG19_DATASET_ID = "emozilla/pg19"
PG19_SPLIT = "test"

# Checkpoints evaluated on PG19, in this order.
PG19_MODELS = [
	"Qwen/Qwen2.5-0.5B-Instruct",
	"google/gemma-2-2b-it",
	"meta-llama/Llama-3.1-8B-Instruct"
]

def select_pg19_dtype() -> torch.dtype:
	"""Choose the model dtype for Experiment 3.

	An explicit ``DTYPE`` environment variable ("bf16" or "fp16",
	case-insensitive) wins. Otherwise bfloat16 is used when a CUDA device with
	bf16 support is present, and float16 in every other case.
	"""
	explicit_choice = {"bf16": torch.bfloat16, "fp16": torch.float16}.get(os.environ.get("DTYPE", "").lower())
	if explicit_choice is not None:
		return explicit_choice
	bf16_available = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
	return torch.bfloat16 if bf16_available else torch.float16

def get_pg19_text_column(dataset_split) -> str:
	"""Return the name of the column that holds the book text.

	"text" is preferred over "book_text".

	Raises:
		KeyError: if neither column is present.
	"""
	available = list(dataset_split.column_names)
	for preferred_name in ("text", "book_text"):
		if preferred_name in available:
			return preferred_name
	raise KeyError(f"No suitable text column found. Columns: {available}")

def get_pg19_model_max_length(tokenizer: AutoTokenizer, model: AutoModelForCausalLM, fallback_max_length: int = 131072) -> int:
	"""Best-effort maximum sequence length for ``model``.

	``model.config.max_position_embeddings`` is used when present and positive;
	otherwise ``tokenizer.model_max_length`` is used. When that value is missing,
	non-positive or above 10,000,000, ``fallback_max_length`` is returned
	instead. The upper bound presumably filters the very large placeholder some
	tokenizers report when no limit is configured.
	"""
	def _unusable(value) -> bool:
		return value is None or value <= 0

	limit = getattr(model.config, "max_position_embeddings", None)
	if _unusable(limit):
		limit = getattr(tokenizer, "model_max_length", None)
	if _unusable(limit) or limit > 10_000_000:
		limit = fallback_max_length
	return int(limit)

def prepare_pg19_dataset(tokenizer: AutoTokenizer, texts: List[str], block_length: int, model_max_length: int) -> List[torch.Tensor]:
	"""Tokenize each text and cut it into consecutive, non-overlapping chunks.

	The chunk size is ``block_length``, capped at ``model_max_length - 1`` and
	never below 2. Empty texts are skipped and no special tokens are added.
	Chunks shorter than 2 tokens are dropped, since they leave nothing to predict.

	Returns:
		1-D long tensors in text order, each holding one chunk's token ids.
	"""
	block_size = max(2, min(block_length, model_max_length - 1))
	chunks: List[torch.Tensor] = []
	for text_value in filter(None, texts):
		token_ids = tokenizer(text_value, add_special_tokens=False)["input_ids"]
		windows = (token_ids[offset:offset + block_size] for offset in range(0, len(token_ids), block_size))
		chunks.extend(torch.tensor(window, dtype=torch.long) for window in windows if len(window) >= 2)
	return chunks

@torch.no_grad()
def evaluate_pg19_with_lambda(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, chunks: List[torch.Tensor], lambda_value: float, device_identifier: str, dtype: torch.dtype, model_max_length: int) -> float:
	"""Perplexity of ``model`` over ``chunks`` with vibe-adjusted next-token logits.

	Let ``E`` hold the unit-normalised input-embedding rows and
	``p = softmax(logits)``. Every predicted position's logits are shifted by
	``lambda_value * (p @ E @ E.T)``. The perplexity is
	``exp(total NLL / predicted tokens)`` under the adjusted distribution.

	Chunks are processed PG19_BATCH_SIZE at a time, truncated to
	``model_max_length`` and right-padded to the longest chunk in the batch.
	Padding is excluded from attention and from the loss.

	Returns:
		the perplexity, or ``inf`` when no token was scored.

	Raises:
		RuntimeError: if the logits' vocabulary size differs from the embedding matrix.
	"""
	embedding_weight = model.get_input_embeddings().weight.to(device=device_identifier, dtype=dtype)
	vocabulary_size, _ = embedding_weight.shape
	embedding_float = embedding_weight.detach().to(torch.float32)
	norms = torch.linalg.norm(embedding_float, dim=1, keepdim=True).clamp_min(1e-12)
	vibe_matrix = embedding_float / norms
	vibe_matrix_transposed = vibe_matrix.t().contiguous()
	pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
	total_negative_log_likelihood = 0.0
	total_tokens = 0
	model.eval()
	progress_bar = tqdm(total=len(chunks), desc=f"λ={lambda_value:.1f}")
	try:
		for start_index in range(0, len(chunks), PG19_BATCH_SIZE):
			batch_chunks = [sequence[:model_max_length] for sequence in chunks[start_index:start_index + PG19_BATCH_SIZE]]
			lengths = [sequence.size(0) for sequence in batch_chunks]
			max_length_in_batch = max(lengths)
			input_ids = torch.full((len(batch_chunks), max_length_in_batch), pad_token_id, dtype=torch.long)
			attention_mask = torch.zeros((len(batch_chunks), max_length_in_batch), dtype=torch.long)
			for batch_index, sequence in enumerate(batch_chunks):
				input_ids[batch_index, :lengths[batch_index]] = sequence
				# Build the mask from real lengths, not from token values. The pad id
				# may be the EOS id, and a genuine token with that id must still be
				# attended to and scored.
				attention_mask[batch_index, :lengths[batch_index]] = 1
			input_ids = input_ids.to(device_identifier)
			attention_mask = attention_mask.to(device_identifier)
			logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
			_, _, vocab_dimension = logits.shape
			if vocab_dimension != vocabulary_size:
				raise RuntimeError("Vocabulary size mismatch between embedding matrix and logits.")
			logits_flat = logits[:, :-1, :].reshape(-1, vocabulary_size).to(torch.float32)
			# With device_map="auto" the logits may live on a different GPU than
			# device_identifier; keep every operand on the logits' device.
			target_flat = input_ids[:, 1:].reshape(-1).to(logits_flat.device)
			mask_flat = attention_mask[:, 1:].reshape(-1).to(device=logits_flat.device, dtype=torch.float32)
			if lambda_value != 0.0:
				# The vibe term contributes nothing at lambda == 0, so the two
				# (tokens x vocab x hidden) matmuls are only run when needed.
				token_probabilities = torch.softmax(logits_flat, dim=-1)
				vibe_direction = token_probabilities @ vibe_matrix.to(logits_flat.device)
				vibe_scores = vibe_direction @ vibe_matrix_transposed.to(logits_flat.device)
				logits_flat = logits_flat + lambda_value * vibe_scores
			log_probabilities = torch.log_softmax(logits_flat, dim=-1)
			gold_log_probabilities = log_probabilities.gather(1, target_flat.unsqueeze(1)).squeeze(1)
			total_negative_log_likelihood += -(gold_log_probabilities * mask_flat).sum().item()
			total_tokens += int(mask_flat.sum().item())
			progress_bar.update(len(batch_chunks))
	finally:
		progress_bar.close()
	if total_tokens == 0:
		return float("inf")
	return float(math.exp(total_negative_log_likelihood / total_tokens))

def run_pg19_for_model(model_id: str) -> Dict[str, float]:
	"""Compute PG19 perplexity for ``model_id`` at every value in LAMBDA_VALUES.

	Steps:

	* Load the tokenizer; the pad token falls back to EOS when unset.
	* Load the model in the dtype chosen by ``select_pg19_dtype``.
	* Read the PG19 split, keeping the first MAX_ROWS_EXPERIMENT_3 books when
	  that cap is positive and all books otherwise.
	* Chunk the books and evaluate each lambda.

	Returns:
		``{str(lambda): perplexity}``, with perplexities rounded to 8 decimals,
		or ``{}`` when no usable chunk was produced.
	"""
	set_seed(GLOBAL_SEED)
	device_identifier = "cuda" if torch.cuda.is_available() else "cpu"
	dtype = select_pg19_dtype()
	tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
	if tokenizer.pad_token is None:
		tokenizer.pad_token = tokenizer.eos_token
	model = AutoModelForCausalLM.from_pretrained(model_id, token=HF_TOKEN, device_map="auto", torch_dtype=dtype)
	model_max_length = get_pg19_model_max_length(tokenizer, model)
	# Whole books are tokenized in one call. Presumably the limit is raised to
	# silence the tokenizer's length warning; chunking enforces the real limit.
	tokenizer.model_max_length = 10_000_000_000
	safe_block_length = min(PG19_BLOCK_LENGTH, model_max_length - 1)
	dataset_split = load_dataset(PG19_DATASET_ID, split=PG19_SPLIT)
	# Same convention as Experiments 1 and 2: a non-positive cap means "all rows".
	# A bare truthiness test would turn a negative cap into an empty selection.
	if MAX_ROWS_EXPERIMENT_3 and MAX_ROWS_EXPERIMENT_3 > 0:
		row_count = min(MAX_ROWS_EXPERIMENT_3, len(dataset_split))
		dataset_split = dataset_split.select(range(row_count))
	text_column_name = get_pg19_text_column(dataset_split)
	texts = dataset_split[text_column_name]
	chunks = prepare_pg19_dataset(tokenizer, texts, safe_block_length, model_max_length)
	if not chunks:
		print(f"No usable chunks from dataset for model {model_id}.")
		return {}
	results: Dict[str, float] = {}
	for lambda_value in LAMBDA_VALUES:
		perplexity_value = evaluate_pg19_with_lambda(model, tokenizer, chunks, lambda_value, device_identifier, dtype, model_max_length)
		perplexity_string = f"{perplexity_value:.8f}"
		results[str(lambda_value)] = float(perplexity_string)
		print(f"[{model_id}] λ={lambda_value:+}: perplexity={perplexity_string}")
	return results

# Top-level record for Experiment 3; per-model perplexities are added below.
experiment_3_output: Dict[str, object] = {
	"experiment_id": 3,
	"task": "PG19 perplexity with vibe-adjusted decoding.",
	"lambda_values": [float(value) for value in LAMBDA_VALUES],
	"random_seed": GLOBAL_SEED,
	"dataset": {
		"id": PG19_DATASET_ID,
		"split": PG19_SPLIT,
		"block_length": PG19_BLOCK_LENGTH,
		"batch_size": PG19_BATCH_SIZE,
		"max_rows": int(MAX_ROWS_EXPERIMENT_3)
	},
	"models": {}
}

# A model that yields no usable chunks is recorded as an empty dict.
for model_identifier in PG19_MODELS:
	experiment_3_output["models"][model_identifier] = run_pg19_for_model(model_identifier)

# Write the results to experiment-3.json in the current working directory.
with open("experiment-3.json", "w", encoding="utf-8") as file_handle:
	json.dump(experiment_3_output, file_handle, ensure_ascii=False, indent=2)

### Experiment 4 (WikiText-2)

In [None]:
# Row cap for Experiment 4 (presumably applied by code later in this cell).
MAX_ROWS_EXPERIMENT_4 = 1000
# Checkpoints evaluated on WikiText-2, in this order.
WIKITEXT_MODELS = [
	"Qwen/Qwen2.5-0.5B-Instruct",
	"google/gemma-2-2b-it",
	"meta-llama/Llama-3.1-8B-Instruct"
]
WIKITEXT_DATASET_ID = "Salesforce/wikitext"
# NOTE(review): "wikitext-2-v1" rather than "wikitext-2-raw-v1" — confirm which variant is intended.
WIKITEXT_CONFIG = "wikitext-2-v1"
WIKITEXT_SPLIT = "validation"
# Number of time steps per slice when stride_ppl_with_vibe processes a window's
# logits, presumably to bound the memory of the vocab-sized intermediates.
WIKITEXT_TIME_SLICE = 32
WIKITEXT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 on CUDA devices that support it; float16 otherwise (including CPU).
WIKITEXT_PREFERRED_DTYPE = torch.bfloat16 if (WIKITEXT_DEVICE == "cuda" and torch.cuda.is_bf16_supported()) else torch.float16

# Allow TF32 for CUDA matmuls and cuDNN convolutions. NOTE(review): this changes
# global PyTorch state for every cell run afterwards and trades precision for speed.
if torch.cuda.is_available():
	torch.backends.cuda.matmul.allow_tf32 = True
	torch.backends.cudnn.allow_tf32 = True

def require(condition: bool, message: str) -> None:
	"""Raise ``ValueError(message)`` unless ``condition`` is truthy."""
	if not condition:
		raise ValueError(message)

def get_tokenizer_and_model_wikitext(model_id: str, torch_dtype: torch.dtype) -> tuple[AutoTokenizer, AutoModelForCausalLM]:
	"""Load the fast tokenizer and the model for Experiment 4, in eval mode.

	On CUDA the model is dispatched with ``device_map="auto"``. Otherwise it is
	loaded without a device map and moved to WIKITEXT_DEVICE explicitly.
	"""
	on_cuda = WIKITEXT_DEVICE == "cuda"
	tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, token=HF_TOKEN)
	model = AutoModelForCausalLM.from_pretrained(
		model_id,
		torch_dtype=torch_dtype,
		device_map="auto" if on_cuda else None,
		token=HF_TOKEN
	)
	if not on_cuda:
		model.to(WIKITEXT_DEVICE)
	model.eval()
	return tokenizer, model

def check_shapes_and_dtypes(model: AutoModelForCausalLM) -> tuple[int, int, torch.dtype]:
	"""Validate the input-embedding matrix and describe it.

	Returns:
		``(vocabulary_size, hidden_dimension, dtype)`` of the embedding weight.

	Raises:
		ValueError: if the weight is not rank 2, or its dtype is not
			float16, bfloat16 or float32.
	"""
	weight = model.get_input_embeddings().weight
	if weight.ndim != 2:
		raise ValueError(f"Embedding matrix must be rank 2, got {weight.shape}")
	accepted_dtypes = (torch.float16, torch.bfloat16, torch.float32)
	if weight.dtype not in accepted_dtypes:
		raise ValueError(f"Unexpected embedding dtype {weight.dtype}")
	rows, columns = weight.shape
	return rows, columns, weight.dtype

def verify_dataset_structure(dataset_wikitext, split_name: str) -> None:
	"""Check that ``dataset_wikitext`` has ``split_name`` with a string ``text`` feature.

	Raises:
		ValueError: if the split is missing, or its ``text`` feature is absent
			or not of dtype "string".
	"""
	if split_name not in dataset_wikitext:
		raise ValueError(f"Missing split {split_name} in Wikitext dataset.")
	features = dataset_wikitext[split_name].features
	if "text" in features and str(features["text"].dtype) == "string":
		return
	raise ValueError(f"Expected 'text' feature of type string, got {features}")

@torch.no_grad()
def make_vibe_matrix(model: AutoModelForCausalLM, device_identifier: str) -> torch.Tensor:
	"""Return the row-wise L2-normalised input-embedding matrix as float32 on ``device_identifier``.

	The normalisation runs on the embedding's current device before the transfer. Row norms
	are clamped to at least 1e-12, so an all-zero row stays zero instead of dividing by zero.
	"""
	table = model.get_input_embeddings().weight.detach().to(torch.float32)
	unit_rows = table / table.norm(dim=1, keepdim=True).clamp_min(1e-12)
	return unit_rows.to(device_identifier, non_blocking=True)

@torch.no_grad()
def vibe_adjusted_logprobs(logits: torch.Tensor, vibe_matrix: torch.Tensor, lambda_value: float) -> torch.Tensor:
	"""Re-weight next-token log-probabilities by embedding alignment with the expected token.

	For each row, ``p = softmax(logits)`` is computed and the probability-weighted embedding
	direction ``E^T p`` is formed. Every token ``v`` then receives a bias equal to the dot
	product of its row ``E[v]`` with that direction. The result is
	``log p + lambda_value * bias``, renormalised with log-sum-exp, so ``lambda_value == 0``
	reproduces ``log_softmax(logits)`` and negative values penalise aligned tokens.

	Args:
		logits: ``(vocab,)`` or ``(..., vocab)`` unnormalised scores of any dtype/device. They
			are cast to float32 on ``vibe_matrix``'s device.
		vibe_matrix: ``(vocab, dim)`` float32 matrix. Rows are presumably L2-normalised by the
			caller (as ``make_vibe_matrix`` does); this function does not check.
		lambda_value: strength of the alignment bias.

	Returns:
		float32 log-probabilities on ``vibe_matrix``'s device. A 1-D input yields shape
		``(1, vocab)``; any other input keeps its shape.

	Raises:
		ValueError: if ``logits`` is 0-D or its last dimension differs from
			``vibe_matrix.shape[0]``.
	"""
	if logits.ndim == 0:
		raise ValueError("logits must have at least one dimension.")
	if logits.ndim == 1:
		# Backward compatible: a single distribution is returned as one row.
		logits = logits.unsqueeze(0)
	output_shape = logits.shape
	vocabulary_size = output_shape[-1]
	if vocabulary_size != vibe_matrix.shape[0]:
		raise ValueError("Vocab mismatch between logits and vibe matrix.")
	# Flatten leading dims into one batch axis and align device/dtype with the vibe matrix,
	# so the matmuls below never mix devices or precisions.
	rows = logits.reshape(-1, vocabulary_size).to(device=vibe_matrix.device, dtype=torch.float32)
	log_probs = torch.log_softmax(rows, dim=-1)
	probabilities = torch.exp(log_probs)
	# (dim, vocab) @ (vocab, rows) -> (dim, rows): probability-weighted embedding per row.
	directions = torch.matmul(vibe_matrix.t(), probabilities.t())
	# (vocab, dim) @ (dim, rows) -> (vocab, rows), transposed back to (rows, vocab).
	bias = torch.matmul(vibe_matrix, directions).t()
	scores = log_probs + lambda_value * bias
	adjusted = scores - torch.logsumexp(scores, dim=-1, keepdim=True)
	return adjusted.reshape(output_shape)

@torch.no_grad()
def stride_ppl_with_vibe(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, vibe_matrix: torch.Tensor, input_ids: torch.Tensor, lambda_value: float, max_length: int, stride: int = 512) -> float:
	"""Sliding-window perplexity of ``input_ids`` under vibe-adjusted next-token log-probabilities.

	A window of up to ``max_length`` tokens starts every ``stride`` tokens. This follows the
	standard strided-perplexity recipe: each window scores only the tokens that no earlier
	window has scored, and the overlapping prefix serves purely as conditioning context.
	Evaluation stops once a window reaches the end of the sequence. With
	``stride <= max_length``, every token except the very first is scored exactly once; with
	``stride == max_length`` the first token of each window cannot be predicted and is skipped.

	Args:
		model: causal LM called as ``model(input_ids=...)``; its output must expose ``.logits``.
		tokenizer: unused; kept for interface compatibility.
		vibe_matrix: matrix forwarded to ``vibe_adjusted_logprobs``.
		input_ids: ``(batch, seq_len)`` token ids; each window is moved to ``WIKITEXT_DEVICE``.
		lambda_value: vibe bias strength forwarded to ``vibe_adjusted_logprobs``.
		max_length: maximum window length in tokens.
		stride: distance in tokens between consecutive window starts.

	Returns:
		``exp(mean negative log-likelihood)`` over all scored tokens, or ``inf`` if no token
		could be scored.
	"""
	sequence_length = input_ids.size(1)
	negative_log_likelihood_sum = 0.0
	token_count_total = 0
	previous_end_index = 0
	for start_index in tqdm(range(0, sequence_length, stride), desc=f"PPL λ={lambda_value:+g}", leave=False):
		end_index = min(start_index + max_length, sequence_length)
		if end_index - start_index <= 1:
			break
		# Chunk tokens from (previous_end_index - start_index) onward have not been scored yet.
		# Chunk token 0 has no prediction, hence the floor of 1.
		first_new_token = max(previous_end_index - start_index, 1)
		chunk_ids = input_ids[:, start_index:end_index].to(WIKITEXT_DEVICE, non_blocking=True)
		outputs = model(input_ids=chunk_ids)
		# logits[:, k] predicts chunk token k + 1, so prediction rows start one position earlier.
		logits = outputs.logits[:, first_new_token - 1:-1, :]
		target_ids = chunk_ids[:, first_new_token:]
		time_steps = logits.size(1)
		vocabulary_size = logits.size(-1)
		for time_start in range(0, time_steps, WIKITEXT_TIME_SLICE):
			time_end = min(time_start + WIKITEXT_TIME_SLICE, time_steps)
			logits_slice = logits[:, time_start:time_end, :].reshape(-1, vocabulary_size)
			target_slice = target_ids[:, time_start:time_end].reshape(-1, 1)
			adjusted_log_probs = vibe_adjusted_logprobs(logits_slice, vibe_matrix, lambda_value)
			target_slice = target_slice.to(adjusted_log_probs.device)
			negative_log_probs = -adjusted_log_probs.gather(dim=-1, index=target_slice).squeeze(-1)
			negative_log_likelihood_sum += float(negative_log_probs.sum().item())
			token_count_total += int(negative_log_probs.numel())
		# Drop this window's logits before the next forward pass so two windows' worth of
		# (window, vocab) logits are never alive at the same time.
		del outputs, logits
		previous_end_index = end_index
		if end_index == sequence_length:
			break
	if token_count_total == 0:
		return float("inf")
	return float(math.exp(negative_log_likelihood_sum / token_count_total))

import gc

set_seed(GLOBAL_SEED)
maybe_configure_hf_auth()
# Upper bound on the sliding-window length. max_position_embeddings can be very large for
# long-context models, and a full-context window would materialise (window, vocab) logits
# far larger than typical GPU memory. The chosen window and stride are recorded per model.
WIKITEXT_MAX_LENGTH_CAP = 2048
WIKITEXT_STRIDE = 512
dataset_wikitext = load_dataset(WIKITEXT_DATASET_ID, WIKITEXT_CONFIG, token=HF_TOKEN)
verify_dataset_structure(dataset_wikitext, WIKITEXT_SPLIT)
validation_split = dataset_wikitext[WIKITEXT_SPLIT]
if MAX_ROWS_EXPERIMENT_4 and MAX_ROWS_EXPERIMENT_4 > 0:
	row_count = min(MAX_ROWS_EXPERIMENT_4, len(validation_split))
	validation_split = validation_split.select(range(row_count))
texts = [example["text"] for example in validation_split if isinstance(example.get("text"), str)]
raw_corpus = "\n".join(texts).strip()
require(len(raw_corpus) > 0, "Empty dataset slice.")
experiment_4_output: Dict[str, object] = {
	"experiment_id": 4,
	"task": "WikiText-2 perplexity with vibe-adjusted decoding.",
	"lambda_values": [float(value) for value in LAMBDA_VALUES],
	"random_seed": GLOBAL_SEED,
	"dataset": {
		"id": WIKITEXT_DATASET_ID,
		"config": WIKITEXT_CONFIG,
		"split": WIKITEXT_SPLIT,
		"max_rows": int(MAX_ROWS_EXPERIMENT_4)
	},
	"models": {}
}


def _write_experiment_4_output() -> None:
	"""Write the current results to experiment-4.json so finished models survive a later crash."""
	with open("experiment-4.json", "w", encoding="utf-8") as file_handle:
		json.dump(experiment_4_output, file_handle, ensure_ascii=False, indent=2)


for model_identifier in WIKITEXT_MODELS:
	print(f"\n=== Loading {model_identifier} dtype={WIKITEXT_PREFERRED_DTYPE} ===")
	tokenizer, model = get_tokenizer_and_model_wikitext(model_identifier, WIKITEXT_PREFERRED_DTYPE)
	vocabulary_size, hidden_dimension, embedding_dtype = check_shapes_and_dtypes(model)
	print(f"Embedding matrix V={vocabulary_size} d={hidden_dimension} dtype={embedding_dtype}")
	vibe_matrix = make_vibe_matrix(model, WIKITEXT_DEVICE)
	input_ids = tokenizer(raw_corpus, return_tensors="pt").input_ids
	require(input_ids.size(1) >= 2, "Tokenised length too short for perplexity.")
	model_max_length = getattr(model.config, "max_position_embeddings", 2048) or 2048
	max_length_value = min(int(model_max_length), WIKITEXT_MAX_LENGTH_CAP)
	stride_value = min(WIKITEXT_STRIDE, max_length_value)
	model_metrics: Dict[str, float] = {}
	for lambda_value in LAMBDA_VALUES:
		perplexity_value = stride_ppl_with_vibe(
			model=model,
			tokenizer=tokenizer,
			vibe_matrix=vibe_matrix,
			input_ids=input_ids,
			lambda_value=lambda_value,
			max_length=max_length_value,
			stride=stride_value
		)
		perplexity_string = f"{perplexity_value:.8f}"
		model_metrics[f"{lambda_value:+g}"] = float(perplexity_string)
		print(f"[{model_identifier}] λ={lambda_value:+g} PPL={perplexity_string}")
	experiment_4_output["models"][model_identifier] = {
		"results": model_metrics,
		"details": {
			# Timezone-aware UTC (utcnow() is deprecated); keeps the original "...Z" format.
			"timestamp_utc": datetime.datetime.now(datetime.timezone.utc).isoformat().replace("+00:00", "Z"),
			"device": WIKITEXT_DEVICE,
			"seed": GLOBAL_SEED,
			"max_length": int(max_length_value),
			"stride": int(stride_value),
			"python": platform.python_version(),
			"torch": torch.__version__,
			"transformers": transformers_lib.__version__,
			"datasets": datasets_lib.__version__
		}
	}
	_write_experiment_4_output()
	# Release this model before the next one is loaded; otherwise both stay resident at once.
	del model, tokenizer, vibe_matrix, input_ids
	gc.collect()
	if torch.cuda.is_available():
		torch.cuda.empty_cache()
_write_experiment_4_output()