In [None]:
# %pip install transformers
# !pip install -U datasets
# !pip install fastdtw
# !pip install tf-keras
# !pip install --upgrade huggingface_hub
# !pip install --upgrade datasets

In [None]:
!git clone https://github.com/alycialee/beyond-scale-language-data-diversity.git
%cd beyond-scale-language-data-diversity

# # 2️⃣  Make sure the build tools are modern enough
# # %pip install --quiet --upgrade pip setuptools wheel

%pip install pip==24.0

# # 3️⃣  Editable-install *into the live kernel*  ← note the %pip magic
%pip install -e .

In [3]:
import sys, pathlib
sys.path.insert(0, str(pathlib.Path.cwd() / "src"))
sys.path.insert(0, str(pathlib.Path.cwd() / "src" / "diversity"))

In [4]:
def make_loss_fn(ignore_id):
    """Factory function to create a cross-entropy loss function."""
    def _loss_fn(logits, tgt, *_, **__):
        logits = logits[:, :-1, :].contiguous()
        tgt = tgt[:, 1:].contiguous()
        return F.cross_entropy(logits.view(-1, logits.size(-1)), tgt.view(-1), ignore_index=ignore_id)
    return _loss_fn

In [5]:
from typing import List, Tuple
import torch.nn as nn

# Names that identify vocab-size–dependent parts across common HF LMs
_EXCLUDE_SUBSTRS = [
    "wte",               # GPT-2 input token embeddings
    "embed_tokens",      # LLaMA/T5-like input embeddings
    "word_embeddings",   # BERT-like input embeddings
    "lm_head",           # GPT-style output head
    "cls.predictions",   # BERT MLM head
    "output_projection", # some decoder heads
    "decoder.embed_tokens" # seq2seq decoders
]

# Names we consider "fixed" and informative for Task2Vec
_INCLUDE_SUBSTRS = [
    "transformer.h.", "transformer.ln_f",  # GPT-2/DistilGPT-2 blocks + final LN
    "encoder.layer.", "decoder.layer.",    # BERT/seq2seq blocks
    "wpe", "position_embeddings",          # positional embeddings (fixed-size if max_position_embeddings is fixed)
    "ln_", "layer_norm", "norm",           # layer norms
    "attn", "mlp", "intermediate", "output" # block internals
]

def _should_exclude(name: str) -> bool:
    return any(s in name for s in _EXCLUDE_SUBSTRS)

def _should_include(name: str) -> bool:
    # include anything that matches INCLUDE and is not explicitly excluded
    return (any(s in name for s in _INCLUDE_SUBSTRS)) and (not _should_exclude(name))

def freeze_token_and_head_params(model: nn.Module) -> Tuple[List[str], List[str]]:
    """
    Freeze token embeddings and LM head so Fisher excludes them.
    Returns (included_names, excluded_names) for logging.
    """
    included, excluded = [], []
    # Freeze via high-level API if available
    emb = getattr(model, "get_input_embeddings", lambda: None)()
    if emb is not None and hasattr(emb, "weight"):
        emb.weight.requires_grad_(False)
    out_emb = getattr(model, "get_output_embeddings", lambda: None)()
    if out_emb is not None and hasattr(out_emb, "weight"):
        out_emb.weight.requires_grad_(False)

    # Pattern-based freeze as a backstop and to catch tied heads
    for n, p in model.named_parameters():
        if _should_exclude(n):
            p.requires_grad_(False)
            excluded.append(n)
        else:
            included.append(n)
    return included, excluded

class ParamFilteredWrapper(nn.Module):
    """
    Wraps a model but only *exposes* parameters we want Task2Vec to consider.
    Forward is unchanged; only .parameters() / .named_parameters() are filtered.
    """
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, *args, **kwargs):
        return self.inner(*args, **kwargs)

    def named_parameters(self, prefix: str = "", recurse: bool = True):
        for n, p in self.inner.named_parameters(prefix="", recurse=recurse):
            # Keep only fixed-size, informative params
            if _should_include(n) and p.requires_grad:
                yield n if prefix == "" else f"{prefix}.{n}", p

    def parameters(self, recurse: bool = True):
        for _, p in self.named_parameters(recurse=recurse):
            yield p

    # (Optional) expose original attributes if your code accesses them
    def __getattr__(self, name):
        # Delegate everything else to the inner model
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.inner, name)


In [6]:
import os
# -------------------------------------------------------------
# 1️⃣  Global config – adjust these numbers once, at the top
# -------------------------------------------------------------
# Choose a batch size that comfortably fits in your GPU memory.
# For a 40 GB A100 you can typically go up to 64–128.
BATCH_SIZE = 128          # <- was 10

# Task2Vec mini‑batch size – 4 – 8 works well on most cards.
TASK2VEC_BATCH = 4       # <- was 1

# Number of CPU processes for the HF‑datasets .map()
TOKENIZE_WORKERS = min(32, os.cpu_count())   # feel free to raise

# DataLoader workers (prefetching of the raw MiniF2F rows)
DL_WORKERS = 8           # <- set >0

In [7]:
import os
from functools import lru_cache

@lru_cache(maxsize=None)
def embed_with_solver_tokenizer(
    solver_model_id: str,
    raw_dataset,
    probe_model,
    max_len: int,
    epochs: int = 1,
    cpu_workers: int = TOKENIZE_WORKERS,
):
    # ---------- Tokenizer ----------
    tok = AutoTokenizer.from_pretrained(solver_model_id, trust_remote_code=True)
    if tok.pad_token is None:
        tok.add_special_tokens({'pad_token': '[PAD]'})

    # Resize probe vocab (so it can read the solver ids)
    probe_model.resize_token_embeddings(len(tok))
    if hasattr(probe_model, "tie_weights"):
        probe_model.tie_weights()

    # ---------- Freeze token/LM‑head ----------
    freeze_token_and_head_params(probe_model)          # removes wte & lm_head from grad

    # ---------- Wrap the model so Task2Vec sees only the fixed params ----------
    probe_for_task2vec = ParamFilteredWrapper(probe_model)

    # ---------- Tokenise the dataset – GLOBAL padding ----------
    tok_ds = raw_dataset.map(
        lambda b: tok(
            b["text"],
            padding="max_length",          # <-- crucial change
            truncation=True,
            max_length=max_len,
        ),
        batched=True,
        batch_size=20,
        remove_columns=raw_dataset.column_names,
        num_proc=cpu_workers,
    )
    tok_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

    # ---------- Task2Vec ----------
    task2vec = Task2Vec(
        probe_for_task2vec,
        max_samples=1024,
        loader_opts={"batch_size": TASK2VEC_BATCH, "shuffle": True, "num_workers": 8},
    )
    task2vec.loss_fn = make_loss_fn(tok.pad_token_id)

    emb, lm_loss = task2vec.embed(tok_ds, epochs=epochs)

    # Convert to torch tensor (Hessian matrix → dense tensor)
    embedding_tensor = torch.from_numpy(emb.hessian).to(dtype=torch.float32)

    # ---- Clean‑up (no empty_cache needed) ----
    del tok, tok_ds, task2vec, emb, probe_for_task2vec
    gc.collect()
    # torch.cuda.empty_cache()   # <-- safe to drop

    return embedding_tensor, lm_loss



In [10]:
import os
import gc
import torch
import warnings
import torch.nn.functional as F
from datasets import load_dataset, Dataset
from transformers import (
    GPTNeoForCausalLM,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
)
from task2vec import Task2Vec
from itertools import combinations
from torch.utils.data import DataLoader
from functools import lru_cache
from typing import Callable, Tuple
from torch import Tensor
# from fastdtw import fastdtw
from scipy.spatial.distance import euclidean




warnings.filterwarnings("ignore", module="torch")

# ----- Global Configuration & Models -----------------------------------------
print(">>> Loading and preparing dataset...")
# raw_ds = load_dataset("Tonic/MiniF2F", split="train")
raw_ds = load_dataset("AI-MO/minif2f_test", split="train")
raw_ds = raw_ds.select(range(40)) # Using a smaller subset for demonstration
raw_ds = raw_ds.map(
    # removed ex["informal_prefix"] bc it's often repeated in formal statement in a comment
    lambda ex: {"text": "\n".join(p for p in (ex["formal_statement"]) if p)},
    num_proc=os.cpu_count(),
)
print(">>> Dataset ready.")


# Define the device to use (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Using device: {device} ---")

print(">>> Loading probe model...")
# probe_model_id = "distilbert/distilgpt2"
probe_model_id = "Saisam/gpt-neo-math-small"
probe_cfg = AutoConfig.from_pretrained(probe_model_id, trust_remote_code=True)
probe_cfg.attn_implementation = "sdpa"
probe_model = AutoModelForCausalLM.from_pretrained(
    probe_model_id,
    config=probe_cfg,
    trust_remote_code=True,
).to(device)
# If your GPU supports FP16 (all modern ones do)
probe_model = probe_model.half()          # <<< mixed‑precision

# probe_model.gradient_checkpointing_enable()
probe_model.gradient_checkpointing_disable()

max_probe_length = probe_model.config.max_position_embeddings
print(f">>> Probe model '{probe_model_id}' loaded with max length: {max_probe_length}.")

# ----- Helper & Core Functions -----------------------------------------------
def _compute_cmdiv_core(
    model_id1: str,
    model_id2: str,
    dataset_fingerprint: str,
    batch_size: int = BATCH_SIZE
) -> float:
    """
    Computes cmdiv by averaging embedding distances over batches.
    The `align_fn` is called on each pair of flattened embeddings
    and should return the two tensors to compare.
    """
    # print(f"\n>>> Computing cmdiv for pair: ({model_id1}, {model_id2})")
    loader = DataLoader(
        raw_ds,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=DL_WORKERS,            # <<< parallel pre‑fetch
    )
    distances = []

    for i, batch in enumerate(loader, start=1):
        # print(f"    ... processing batch {i}/{len(loader)}")
        batch_ds = Dataset.from_dict(batch)

        emb1, _ = embed_with_solver_tokenizer(
            model_id1, batch_ds, probe_model, max_probe_length
        )
        emb2, _ = embed_with_solver_tokenizer(
            model_id2, batch_ds, probe_model, max_probe_length
        )

        # if emb1.shape != emb2.shape:
        #   print(f"{model_id1} and {model_id2} differ in embedding shape")
        #   print(emb1.shape)
        #   print(emb2.shape)

        # def flatten_and_pad(a: torch.Tensor, b: torch.Tensor):
        #     fa = a.flatten()
        #     fb = b.flatten()
        #     max_len = max(fa.numel(), fb.numel())
        #     if fa.numel() < max_len:
        #         pad = torch.zeros(max_len - fa.numel(), device=fa.device, dtype=fa.dtype)
        #         fa = torch.cat([fa, pad], dim=0)
        #     if fb.numel() < max_len:
        #         pad = torch.zeros(max_len - fb.numel(), device=fb.device, dtype=fb.dtype)
        #         fb = torch.cat([fb, pad], dim=0)
        #     return fa, fb


        # replace DTW block with:
        # fa, fb = flatten_and_pad(emb1, emb2)
        # cos_sim = F.cosine_similarity(fa.unsqueeze(0), fb.unsqueeze(0), dim=1).item()

        assert(emb1.shape == emb2.shape, f"embedding shapes differ between {model_id1} and {model_id2}")
        fa, fb = emb1.flatten(), emb2.flatten()   # no padding needed
        cos_sim = F.cosine_similarity(fa.unsqueeze(0), fb.unsqueeze(0)).item()

        # Normalize to [0,1]:
        sim = (cos_sim + 1.0) / 2.0
        dist = 1.0 - sim
        distances.append(dist)


    avg_distance = sum(distances) / len(distances) if distances else 0.0
    # print(f">>> Average distance (cmdiv) for pair = {avg_distance:.4f}")
    return avg_distance

@lru_cache(maxsize=None)
def compute_cmdiv(
    model_id1: str,
    model_id2: str,
    dataset_fingerprint: str,
    batch_size: int = BATCH_SIZE
) -> float:
    # Canonicalize model ID order
    m1, m2 = sorted([model_id1, model_id2])

    # Now use m1 and m2 consistently for computation
    return _compute_cmdiv_core(m1, m2, dataset_fingerprint, batch_size)

def compute_edc(model_ids: set, dataset: Dataset) -> float:
    """
    Computes the Ensemble Diversity Coefficient (EDC) for a set of models.
    """
    if len(model_ids) < 2:
        return 0.0

    dataset_fingerprint = dataset.info.builder_name
    pairwise_diversities = []

    for m_id1, m_id2 in combinations(list(model_ids), 2):
        diversity = compute_cmdiv(m_id1, m_id2, dataset_fingerprint, BATCH_SIZE)
        print(f"Diversity between {m_id1} and {m_id2} on {dataset_fingerprint} with batch size {BATCH_SIZE}: {diversity}")
        pairwise_diversities.append(diversity)

    return sum(pairwise_diversities) / len(pairwise_diversities) if pairwise_diversities else 0.0



>>> Loading and preparing dataset...


Map (num_proc=12):   0%|          | 0/40 [00:00<?, ? examples/s]

>>> Dataset ready.
--- Using device: cuda ---
>>> Loading probe model...
>>> Probe model 'Saisam/gpt-neo-math-small' loaded with max length: 2048.


  assert(emb1.shape == emb2.shape, f"embedding shapes differ between {model_id1} and {model_id2}")


In [None]:
from itertools import combinations
from typing import List, Tuple
from tqdm import tqdm

test_model_ids = [
    'AI-MO/Kimina-Prover-Preview-Distill-7B', 'ByteDance-Seed/BFS-Prover', 'Goedel-LM/Goedel-Prover-SFT', 'deepseek-ai/DeepSeek-Prover-V1', 'deepseek-ai/DeepSeek-Prover-V1.5-RL', 'deepseek-ai/DeepSeek-Prover-V2-7B', 'kfdong/STP_model_Lean', 'stoney0062/Leanabell-Prover-DS-SFT', 'wellecks/llmstep-mathlib4-pythia2.8b'
]

def get_all_combinations_of_length(s, length: int) -> List[Tuple]:
    return list(combinations(s, length))

solver_model_ensembles_3 = get_all_combinations_of_length(test_model_ids, 3)

edc_i = []
original_len = len(edc_i)

edc_i += [0] * (len(solver_model_ensembles_3) - len(edc_i))

for i, ensemble in tqdm(enumerate(solver_model_ensembles_3), total=len(solver_model_ensembles_3)):
    if edc_i[i]:
        continue
    edc_i[i] = compute_edc(ensemble, raw_ds)
    with open(f"edc_i_{i}.txt", "w") as file:
        file.write(",".join([str(j) for j in edc_i]))

  0%|          | 0/84 [00:00<?, ?it/s]

Map (num_proc=12):   0%|          | 0/40 [00:00<?, ? examples/s]

self.classifier_opts={}
MODEL DEVICE:  cuda:0


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]


Initial loss 4.5546875 (step=0 epoch=0)

final loss step=9 epoch=0 of final layer loss 5.24609375 (note we are not recomputing loss after a step so this loss printed is larger than it should be/one off)
loss=5.24609375 (after fine tune, if not done it will be None)


Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

Map (num_proc=12):   0%|          | 0/40 [00:00<?, ? examples/s]

self.classifier_opts={}
MODEL DEVICE:  cuda:0


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]


Initial loss 4.5546875 (step=0 epoch=0)

final loss step=9 epoch=0 of final layer loss 5.24609375 (note we are not recomputing loss after a step so this loss printed is larger than it should be/one off)
loss=5.24609375 (after fine tune, if not done it will be None)


Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

Diversity between AI-MO/Kimina-Prover-Preview-Distill-7B and ByteDance-Seed/BFS-Prover on parquet with batch size 128: -5.960464477539063e-08


Map (num_proc=12):   0%|          | 0/40 [00:00<?, ? examples/s]

self.classifier_opts={}
MODEL DEVICE:  cuda:0


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]


Initial loss 4.5546875 (step=0 epoch=0)

final loss step=9 epoch=0 of final layer loss 5.24609375 (note we are not recomputing loss after a step so this loss printed is larger than it should be/one off)
loss=5.24609375 (after fine tune, if not done it will be None)


Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/369 [00:00<?, ?B/s]

Map (num_proc=12):   0%|          | 0/40 [00:00<?, ? examples/s]

self.classifier_opts={}
MODEL DEVICE:  cuda:0


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]


Initial loss 9.8125 (step=0 epoch=0)

final loss step=9 epoch=0 of final layer loss 10.65625 (note we are not recomputing loss after a step so this loss printed is larger than it should be/one off)
loss=10.65625 (after fine tune, if not done it will be None)


Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

Diversity between AI-MO/Kimina-Prover-Preview-Distill-7B and Goedel-LM/Goedel-Prover-SFT on parquet with batch size 128: 0.2229878306388855


Map (num_proc=12):   0%|          | 0/40 [00:00<?, ? examples/s]

self.classifier_opts={}
MODEL DEVICE:  cuda:0


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]


Initial loss 4.5546875 (step=0 epoch=0)

final loss step=9 epoch=0 of final layer loss 5.24609375 (note we are not recomputing loss after a step so this loss printed is larger than it should be/one off)
loss=5.24609375 (after fine tune, if not done it will be None)


Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

In [None]:
with open(f"edc_i.txt", "w") as file:
        file.write(",".join([str(x) for x in edc_i]))

In [11]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [12]:
import torch
# Example: Delete any GPU tensors
for var in list(globals().keys()):  # Check global variables
    if isinstance(globals()[var], torch.Tensor) and var != 'torch':
        del globals()[var]
torch.cuda.empty_cache()

In [13]:
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

Allocated: 1.08 GB
Reserved: 4.70 GB


In [None]:
# Go to the container’s home folder
%cd ~          # now in /root
from google.colab import files
files.download('edc_i.txt')   # shows a browser download dialog