**README**: Before running the notebook, make sure the following files are present in the current working directory:

- checkpoint.json

- corrected_proofs.zip

- proof_outcomes_by_model.json

# Initialization

In [None]:
!pip install trl peft transformers torch_ema

Collecting trl
  Downloading trl-0.22.2-py3-none-any.whl.metadata (11 kB)
Collecting torch_ema
  Downloading torch_ema-0.3-py3-none-any.whl.metadata (415 bytes)
Downloading trl-0.22.2-py3-none-any.whl (544 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m544.8/544.8 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading torch_ema-0.3-py3-none-any.whl (5.5 kB)
Installing collected packages: torch_ema, trl
Successfully installed torch_ema-0.3 trl-0.22.2


In [None]:
import numpy as np
import pandas as pd
import statsmodels.api as sm
import json

# Read the checkpoint inside a context manager so the file handle is closed
# deterministically; JSON is UTF-8 by specification, so do not rely on the locale.
with open("checkpoint.json", encoding="utf-8") as checkpoint_file:
    model_dict = json.load(checkpoint_file)

def parse_success_dict(model_dict: dict):
    """
    Convert per-solver outcome strings into a 2-D 0/1 matrix.

    model_dict: mapping solver name -> binary string like "100101..." where
    character k is '1' if the solver solves problem k and '0' otherwise.

    Returns (solvers, success_matrix): the solver names in mapping order and an
    integer array of shape (n_solvers, n_problems). The array is always 2-D,
    including for an empty mapping (shape (0, 0)).

    Raises ValueError if the outcome strings differ in length or contain
    anything other than 0/1 digits.
    """
    solvers = list(model_dict.keys())
    rows = [[int(c) for c in model_dict[name]] for name in solvers]

    # All solvers must be scored on the same set of problems; fail with a
    # readable message instead of NumPy's "inhomogeneous shape" error.
    lengths = {len(row) for row in rows}
    if len(lengths) > 1:
        raise ValueError(
            f"All outcome strings must have the same length, got lengths {sorted(lengths)}"
        )

    for name, row in zip(solvers, rows):
        if any(value not in (0, 1) for value in row):
            raise ValueError(f"Outcome string for {name!r} must contain only '0' and '1'")

    n_problems = lengths.pop() if lengths else 0
    # reshape is a no-op for non-empty input and keeps the empty case 2-D.
    success_matrix = np.array(rows, dtype=int).reshape(len(solvers), n_problems)
    return solvers, success_matrix

solvers, success_matrix = parse_success_dict(model_dict['model_dict'])
print(solvers)
print(success_matrix)

['AI-MO_Kimina-Prover-Preview-Distill-7B', 'ByteDance-Seed_BFS-Prover', 'Goedel-LM_Goedel-Prover-SFT', 'deepseek-ai_DeepSeek-Prover-V1', 'deepseek-ai_DeepSeek-Prover-V1.5-RL', 'deepseek-ai_DeepSeek-Prover-V2-7B', 'kfdong_STP_model_Lean', 'stoney0062_Leanabell-Prover-DS-SFT', 'wellecks_llmstep-mathlib4-pythia2.8b']
[[1 0 0 1 1 1 0 1 1 0 1 1 0 1 1 1 0 1 0 0 0 0 0 1 0 1 1 0 0 1 0 0 1 1 1 1
  0 0 0 0]
 [1 0 0 1 1 1 0 0 1 0 1 0 0 1 1 0 0 1 0 0 1 0 0 1 0 0 0 1 0 0 0 0 1 1 1 1
  0 0 1 0]
 [1 1 0 1 1 1 0 0 1 1 1 1 0 1 1 1 0 1 0 0 0 0 0 1 0 1 0 1 0 1 0 1 1 1 1 0
  0 0 1 0]
 [1 0 0 1 1 1 0 0 1 0 1 1 0 1 1 1 0 1 0 0 1 0 0 1 0 0 0 1 0 1 0 0 1 1 1 1
  0 0 1 0]
 [1 0 0 0 1 1 0 1 1 0 1 1 0 1 1 1 0 1 0 0 1 0 0 1 0 0 0 1 0 1 0 0 0 1 1 1
  0 0 1 1]
 [1 1 1 0 1 1 0 1 1 1 1 1 0 0 1 1 0 1 0 0 0 1 0 1 0 1 0 1 0 1 0 0 1 1 1 1
  0 0 1 1]
 [1 1 0 1 1 1 0 1 1 1 1 1 0 1 1 1 0 1 0 0 1 1 0 1 0 0 0 1 0 1 0 0 1 1 1 1
  0 0 1 0]
 [1 1 0 0 1 1 0 0 1 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0]
 [0 

# Code Exploration

In [None]:
type(solvers)

list

In [None]:
arr = np.array(solvers)
condition = arr == 'deepseek-ai_DeepSeek-Prover-V2-7B'
np.where(condition)[0]

array([5])

In [None]:
print(success_matrix[5])

[1 1 1 0 1 1 0 1 1 1 1 1 0 0 1 1 0 1 0 0 0 1 0 1 0 1 0 1 0 1 0 0 1 1 1 1 0
 0 1 1]


In [None]:
ones = np.ones(success_matrix.shape[1], dtype=int)
ones

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [None]:
success_matrix[0]

array([1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0,
       0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0])

In [None]:
np.sum(success_matrix[0])

np.int64(20)

In [None]:
comb = success_matrix[0] | ones
comb

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [None]:
a = np.where(success_matrix[0] == 1); a

(array([ 0,  3,  4,  5,  7,  8, 10, 11, 13, 14, 15, 17, 23, 25, 26, 29, 32,
        33, 34, 35]),)

In [None]:
b = np.where(success_matrix[1] == 0); b

(array([ 1,  2,  6,  7,  9, 11, 12, 15, 16, 18, 19, 21, 22, 24, 25, 26, 28,
        29, 30, 31, 36, 37, 39]),)

In [None]:
np.size(np.intersect1d(a, b))

6

# Greedy Selection

In [None]:
# Walk the solvers in their listed order and keep each one that solves at
# least one problem not already covered by the solvers kept so far.
success_arr = np.zeros(success_matrix.shape[1], dtype=int)  # running union of solved problems
ensemble = []
contribution: dict[str, int] = {}
for solver, solved in zip(solvers, success_matrix):
    merged = success_arr | solved
    gain = merged.sum() - success_arr.sum()  # newly covered problems
    if gain <= 0:
        continue
    contribution[solver] = gain
    ensemble.append(solver)
    success_arr = merged

In [None]:
ensemble, len(ensemble)

(['AI-MO_Kimina-Prover-Preview-Distill-7B',
  'ByteDance-Seed_BFS-Prover',
  'Goedel-LM_Goedel-Prover-SFT',
  'deepseek-ai_DeepSeek-Prover-V1.5-RL',
  'deepseek-ai_DeepSeek-Prover-V2-7B'],
 5)

In [None]:
contribution.items()

dict_items([('AI-MO_Kimina-Prover-Preview-Distill-7B', np.int64(20)), ('ByteDance-Seed_BFS-Prover', np.int64(3)), ('Goedel-LM_Goedel-Prover-SFT', np.int64(3)), ('deepseek-ai_DeepSeek-Prover-V1.5-RL', np.int64(1)), ('deepseek-ai_DeepSeek-Prover-V2-7B', np.int64(2))])

# Find Marginal Contribution

In [None]:
import itertools

list(itertools.combinations(range(len(solvers)), 2))

[(0, 1),
 (0, 2),
 (0, 3),
 (0, 4),
 (0, 5),
 (0, 6),
 (0, 7),
 (0, 8),
 (1, 2),
 (1, 3),
 (1, 4),
 (1, 5),
 (1, 6),
 (1, 7),
 (1, 8),
 (2, 3),
 (2, 4),
 (2, 5),
 (2, 6),
 (2, 7),
 (2, 8),
 (3, 4),
 (3, 5),
 (3, 6),
 (3, 7),
 (3, 8),
 (4, 5),
 (4, 6),
 (4, 7),
 (4, 8),
 (5, 6),
 (5, 7),
 (5, 8),
 (6, 7),
 (6, 8),
 (7, 8)]

In [None]:
from numpy import logical_not as not_
def not_1(arr1: np.ndarray, arr2: np.ndarray):
    """Count the positions where `arr1` is 0 and `arr2` is 1 (0/1 arrays)."""
    only_second = np.logical_not(arr1) & arr2
    return only_second.sum()

def not_2(arr1: np.ndarray, arr2: np.ndarray):
    """Count the positions where `arr1` is 1 and `arr2` is 0 (0/1 arrays)."""
    only_first = arr1 & np.logical_not(arr2)
    return only_first.sum()

In [None]:
sucess_matrix_ = success_matrix[:, :15]

In [None]:
# Score every unordered solver pair on the restricted problem set.
res = []
for first, second in itertools.combinations(range(len(solvers)), 2):
    row_first = sucess_matrix_[first]
    row_second = sucess_matrix_[second]

    n_1 = not_1(row_first, row_second)  # solved only by the second solver
    n_2 = not_2(row_first, row_second)  # solved only by the first solver

    # Reward a large average of exclusive solves, and penalise imbalance
    # between the two directions (weight 0.5).
    score = np.mean([n_1, n_2]) - 0.5 * abs(n_1 - n_2)

    res.append((
        (first, second),
        (solvers[first], solvers[second]),
        score,
        (n_1, n_2),
    ))

# Highest score first; ties keep their original pair order.
best = sorted(res, key=lambda entry: entry[2], reverse=True)[0]
best

((0, 5),
 ('AI-MO_Kimina-Prover-Preview-Distill-7B',
  'deepseek-ai_DeepSeek-Prover-V2-7B'),
 np.float64(2.0),
 (np.int64(3), np.int64(2)))

In [None]:
import itertools
import numpy as np
import random

# ---- Scoring Functions ----
def pair_score_subset(i, j, cols, success_matrix, solvers, penalty_weight=0.45):
    """
    Score how complementary solvers `i` and `j` are on a subset of problems.

    i, j: row indices into `success_matrix`.
    cols: list of column (problem) indices to restrict the comparison to.
    success_matrix: 0/1 array indexed as [solver, problem].
    solvers: accepted for signature compatibility; not used here.
    penalty_weight: weight applied to the imbalance |n1 - n2|. Defaults to
        0.45, the value that was previously hard-coded.

    Returns mean(n1, n2) - penalty_weight * |n1 - n2|, where n1 counts the
    selected problems solved only by `j` and n2 those solved only by `i`.
    """
    # restrict to the chosen subset of columns
    sub_i = success_matrix[i, cols]
    sub_j = success_matrix[j, cols]

    n1 = not_1(sub_i, sub_j)
    n2 = not_2(sub_i, sub_j)
    base = np.mean([n1, n2])
    penalty = abs(n1 - n2)
    return base - penalty_weight * penalty


# ---- Greedy Search ----
def greedy_column_subset(success_matrix, solvers, k=15):
    """
    Greedily grow a set of up to `k` problem columns, one column per round.

    In each round every not-yet-chosen column is tentatively added and every
    solver pair is scored with `pair_score_subset`; the column that yields the
    highest pair score is kept. Ties go to the first (column, pair) met.

    Returns (cols, best_pair, best_score): the chosen columns in selection
    order, the best-scoring pair from the final round, and its score.
    """
    n_problems = success_matrix.shape[1]

    chosen = []
    best_pair = None
    best_score = -np.inf

    while len(chosen) < k:
        round_col = None
        round_score = -np.inf
        for col in range(n_problems):
            if col in chosen:
                continue
            trial_cols = chosen + [col]
            # best pair for this tentative subset
            for i, j in itertools.combinations(range(len(solvers)), 2):
                candidate = pair_score_subset(i, j, trial_cols, success_matrix, solvers)
                if candidate > round_score:
                    round_score, best_pair, round_col = candidate, (i, j), col
        if round_col is None:
            # no column left to add, or no pair to score
            break
        chosen.append(round_col)
        best_score = round_score

    return chosen, best_pair, best_score


cols_greedy, pair_greedy, score_greedy = greedy_column_subset(success_matrix, solvers, k=15)
print("Greedy best subset of cols:", cols_greedy)
print("Best solver pair:", pair_greedy)
print("Score:", score_greedy)
# rand_best, rand_score = random_search(sucess_matrix_, solvers, k=15, iters=10000)
# print("Random search best subset:", rand_best)
# print("Random search score:", rand_score)

Greedy best subset of cols: [0, 13, 1, 35, 4, 5, 8, 9, 20, 3, 7, 11, 14, 15, 17]
Best solver pair: (0, 7)
Score: 3.25


In [None]:
sucess_matrix_ = success_matrix[:, cols_greedy]

In [None]:
not_1(sucess_matrix_[pair_greedy[0]], sucess_matrix_[pair_greedy[1]])

np.int64(3)

In [None]:
a = np.where(sucess_matrix_[pair_greedy[0]]==1)
b = np.where(sucess_matrix_[pair_greedy[1]]==0)
solver_2_training_ex = np.intersect1d(a,b)
solver_2_training_ex

array([ 1,  3,  9, 10, 11, 12, 13, 14])

In [None]:
a_ = np.where(sucess_matrix_[pair_greedy[0]]==0)
b_ = np.where(sucess_matrix_[pair_greedy[1]]==1)
solver_1_training_ex = np.intersect1d(a_, b_)
solver_1_training_ex

array([2, 7, 8])

In [None]:
selected_training_problems_solver_1 = [cols_greedy[i] for i in solver_1_training_ex]
selected_training_problems_solver_2 = [cols_greedy[i] for i in solver_2_training_ex]
selected_training_problems_solver_1, selected_training_problems_solver_2

([1, 9, 20], [13, 35, 3, 7, 11, 14, 15, 17])

In [None]:
not_2(sucess_matrix_[pair_greedy[0]], sucess_matrix_[pair_greedy[1]])

np.int64(8)

In [None]:
np.sum(sucess_matrix_[pair_greedy[0]] | sucess_matrix_[pair_greedy[1]])

np.int64(15)

In [None]:
solvers[pair_greedy[0]], solvers[pair_greedy[1]]

('AI-MO_Kimina-Prover-Preview-Distill-7B',
 'stoney0062_Leanabell-Prover-DS-SFT')

In [None]:
all_cols = set(range(success_matrix.shape[1]))
# sorted() rather than list(): set iteration order is an implementation detail,
# and the column order of success_matrix_other should be reproducible.
cols_not_greedy = sorted(all_cols - set(cols_greedy))

success_matrix_other = success_matrix[:, cols_not_greedy]

In [None]:
np.sum(success_matrix_other[pair_greedy[0]] | success_matrix_other[pair_greedy[1]])

np.int64(8)

In [None]:
np.shape(success_matrix_other)

(9, 25)

# Peer Fine-Tuning (PFT)
Implementation of $L_{prove}$

## Datasets

In [None]:
from datasets import load_dataset

ds = load_dataset('script-jpg/minif2f_test-first40-in-fix')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/483 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


data/train-00000-of-00001.parquet:   0%|          | 0.00/17.0k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/40 [00:00<?, ? examples/s]

In [None]:
ds['train'][selected_training_problems_solver_1]['name']

['numbertheory_4x3m7y3neq2003', 'imo_1983_p6', 'mathd_numbertheory_427']

In [None]:
selected_training_formal_statements_solver_1 = ds['train'][selected_training_problems_solver_1]['formal_statement']
selected_training_formal_statements_solver_1

['import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n/-- Show that there are no integers $x$ and $y$ such that $4x^3 - 7y^3 = 2003$.-/\ntheorem numbertheory_4x3m7y3neq2003 (x y : ℤ) : 4 * x ^ 3 - 7 * y ^ 3 ≠ 2003 := by\n',
 'import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n/-- Let $a$, $b$ and $c$ be the lengths of the sides of a triangle. Prove that\n\n$a^2 b(a-b) + b^2 c(b-c) + c^2 a(c-a) \\geq 0$.\n\nDetermine when equality occurs.-/\ntheorem imo_1983_p6 (a b c : ℝ) (h₀ : 0 < a ∧ 0 < b ∧ 0 < c) (h₁ : c < a + b) (h₂ : b < a + c)\n    (h₃ : a < b + c) : 0 ≤ a ^ 2 * b * (a - b) + b ^ 2 * c * (b - c) + c ^ 2 * a * (c - a) := by\n',
 'import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n/-- If $A$ is the sum of the positive divisors of $500$, what is the sum of the distinct prime divisors of $A$? Show that it is 25.-/\ntheorem m

In [None]:
selected_training_formal_statements_solver_2 = ds['train'][selected_training_problems_solver_2]['formal_statement']
selected_training_formal_statements_solver_2

['import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n/-- What is the ones digit of $1 \\cdot 3 \\cdot 5 \\cdot 7 \\cdot 9 \\cdot 11 \\cdot 13$? Show that it is 5.-/\ntheorem mathd_numbertheory_299 : 1 * 3 * 5 * 7 * 9 * 11 * 13 % 10 = 5 := by\n',
 'import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n/-- Expand the product $(x+1)^2 \\cdot x$. Show that it is x^3 + 2x^2 + x.-/\ntheorem mathd_algebra_176 (x : ℝ) : (x + 1) ^ 2 * x = x ^ 3 + 2 * x ^ 2 + x := by\n',
 'import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n/-- What is the product of all positive odd integers less than $10000$?\n\n$\\text{(A)}\\ \\dfrac{10000!}{(5000!)^2}\\qquad \\text{(B)}\\ \\dfrac{10000!}{2^{5000}}\\qquad\n\\text{(C)}\\ \\dfrac{9999!}{2^{5000}}\\qquad \\text{(D)}\\ \\dfrac{10000!}{2^{5000} \\cdot 5000!}\\qquad\n\\text{(E)}\\ \\dfrac{5000!}{2^{5000}}$ Sho

## Process Proofs

In [None]:
!unzip corrected_proofs.zip -d proofs

Archive:  corrected_proofs.zip
   creating: proofs/corrected_proofs/
   creating: proofs/corrected_proofs/32/
   creating: proofs/corrected_proofs/35/
   creating: proofs/corrected_proofs/34/
   creating: proofs/corrected_proofs/33/
  inflating: proofs/corrected_proofs/.DS_Store  
   creating: proofs/corrected_proofs/20/
   creating: proofs/corrected_proofs/18/
   creating: proofs/corrected_proofs/27/
   creating: proofs/corrected_proofs/9/
   creating: proofs/corrected_proofs/0/
   creating: proofs/corrected_proofs/11/
   creating: proofs/corrected_proofs/7/
   creating: proofs/corrected_proofs/29/
   creating: proofs/corrected_proofs/16/
   creating: proofs/corrected_proofs/6/
   creating: proofs/corrected_proofs/28/
   creating: proofs/corrected_proofs/17/
   creating: proofs/corrected_proofs/1/
   creating: proofs/corrected_proofs/10/
   creating: proofs/corrected_proofs/19/
   creating: proofs/corrected_proofs/26/
   creating: proofs/corrected_proofs/8/
   creating: proofs/correct

In [None]:
import json
from pathlib import Path
import pandas as pd
from typing import Callable
import re

def load_proofs(base_path: str | Path = "") -> pd.DataFrame:
    base_path = Path(base_path)
    data_rows = []
    for file_path in base_path.glob('proofs/corrected_proofs/*/*/*.txt'):
        path_parts = file_path.parts
        trial_id = file_path.stem
        model = path_parts[-2]
        problem_id = path_parts[-3]
        try:
            with file_path.open('r', encoding='utf-8') as f:
                text_content = f.read()
            data_rows.append({
                'problem_id': str(problem_id),
                'model': model,
                'trial_id': str(trial_id),
                'text': text_content
            })
        except Exception as e:
            print(f"Could not read file {file_path}: {e}")
    return pd.DataFrame(data_rows)


def transform_json_with_proofs(
    data: dict,
    proofs_base_path: str | Path = "",
    no_proof_placeholder: str = "",
    postprocess: Callable[[str], str] | None = None,
    preprocess : Callable[[str], str] | None = None
) -> dict:
    """
    Replace the 0/1 attempt flags in `data` with the matching proof texts.

    data: mapping model -> problem_id -> list of flags, where position k
        (1-based) corresponds to trial file `<k>.txt`.
    proofs_base_path: directory passed to `load_proofs`.
    no_proof_placeholder: currently unused; failed attempts are dropped
        rather than replaced by a placeholder.
    preprocess / postprocess: optional callables applied, in that order, to
        each proof text that is found.

    For every attempt with flag == 1 the output list receives:
      - the (processed) proof text if the file was loaded, or
      - a "[MISSING PROOF] Expected at: <path>" marker if it was not.
    Attempts with any other flag contribute nothing, so positions in the
    output lists do not correspond to trial numbers.

    Returns a mapping model -> str(problem_id) -> list of strings.
    """
    base_path = Path(proofs_base_path)
    df = load_proofs(base_path)
    # itertuples yields nothing for an empty frame, with or without columns
    proof_lookup = {
        (row.model, str(row.problem_id), str(row.trial_id)): row.text
        for row in df.itertuples(index=False)
    }

    out = {}
    for model, problems in data.items():
        out[model] = {}
        for problem_id, attempts in problems.items():
            new_attempts = []
            for idx, flag in enumerate(attempts, start=1):  # trials are 1-based
                if flag != 1:
                    continue
                key = (model, str(problem_id), str(idx))
                if key in proof_lookup:
                    text = proof_lookup[key]
                    if preprocess:
                        text = preprocess(text)
                    if postprocess:
                        text = postprocess(text)
                    new_attempts.append(text)
                else:
                    # Report the location load_proofs actually searches.
                    missing_path = (
                        base_path / "proofs" / "corrected_proofs"
                        / str(problem_id) / model / f"{idx}.txt"
                    )
                    new_attempts.append(f"[MISSING PROOF] Expected at: {missing_path}")
            out[model][str(problem_id)] = new_attempts
    return out

def remove_lean_block_comments(text: str) -> str:
    """
    Strip Lean block comments of the form /- ... -/ from `text`.

    Anything opening with "/-" is matched, which includes "/-!" and "/--"
    comments. Matching is non-greedy and spans newlines. Nested block
    comments are not supported: the first "-/" closes the comment.
    """
    block_comment = re.compile(r"/-!?.*?-/", re.DOTALL)
    return block_comment.sub("", text)

if __name__ == "__main__":
    # Load the per-model attempt outcomes.
    with open("proof_outcomes_by_model.json", "r", encoding="utf-8") as outcomes_file:
        original = json.load(outcomes_file)

    def clean_lambda(proof):
        """Cut each line at its first "--" and drop lines left blank by that."""
        kept = []
        for line in proof.splitlines():
            code = line.split("--", 1)[0]
            if code.strip():  # keep only non-empty
                kept.append(code.rstrip())
        return "\n".join(kept)

    # Block comments are removed first (preprocess), then line comments (postprocess).
    transformed = transform_json_with_proofs(
        original,
        proofs_base_path=".",
        postprocess=clean_lambda,
        preprocess=remove_lean_block_comments
    )

    with open("output.json", "w", encoding="utf-8") as output_file:
        json.dump(transformed, output_file, ensure_ascii=False, indent=2)


In [None]:
transformed["AI-MO_Kimina-Prover-Preview-Distill-7B"]["0"][0]

'rw [h₁, h₂, h₃]\nnorm_num'

In [None]:
transformed["AI-MO_Kimina-Prover-Preview-Distill-7B"]["0"][0].splitlines()

['rw [h₁, h₂, h₃]', 'norm_num']

## Use together

In [None]:
solver_1_training_ex

array([2, 7, 8])

In [None]:
solver_1_training_ex.tolist()

[2, 7, 8]

In [None]:
s1 = solvers[pair_greedy[0]]
s2 = solvers[pair_greedy[1]]

In [None]:
selected_training_problems_solver_1, selected_training_problems_solver_2

([1, 9, 20], [13, 35, 3, 7, 11, 14, 15, 17])

In [None]:
selected_training_proofs_solver_1 = [transformed[s2][str(int(x))][0] for x in selected_training_problems_solver_1]
selected_training_proofs_solver_1

['  intro h\n  have h_mod : (4 * x ^ 3 - 7 * y ^ 3) % 7 = 2003 % 7 := by rw [h]\n  norm_num at h_mod\n  have h₁ : x ^ 3 % 7 = 0 ∨ x ^ 3 % 7 = 1 ∨ x ^ 3 % 7 = 6 := by\n    have : x % 7 = 0 ∨ x % 7 = 1 ∨ x % 7 = 2 ∨ x % 7 = 3 ∨ x % 7 = 4 ∨ x % 7 = 5 ∨ x % 7 = 6 := by omega\n    rcases this with (h | h | h | h | h | h | h) <;>\n      simp [h, pow_three, Int.mul_emod, Int.sub_emod]\n  have h₂ : y ^ 3 % 7 = 0 ∨ y ^ 3 % 7 = 1 ∨ y ^ 3 % 7 = 6 := by\n    have : y % 7 = 0 ∨ y % 7 = 1 ∨ y % 7 = 2 ∨ y % 7 = 3 ∨ y % 7 = 4 ∨ y % 7 = 5 ∨ y % 7 = 6 := by omega\n    rcases this with (h | h | h | h | h | h | h) <;>\n      simp [h, pow_three, Int.mul_emod, Int.sub_emod]\n  omega',
 '  nlinarith [sq_nonneg (a - b), sq_nonneg (b - c), sq_nonneg (c - a),\n    mul_pos h₀.1 h₀.2.1, mul_pos h₀.2.1 h₀.2.2, mul_pos h₀.2.2 h₀.1,\n    mul_pos (sub_pos.mpr h₁) (sub_pos.mpr h₂), mul_pos (sub_pos.mpr h₂) (sub_pos.mpr h₃),\n    mul_pos (sub_pos.mpr h₃) (sub_pos.mpr h₁)]',
 '  simp_all [Nat.divisors]\n  decide']

In [None]:
selected_training_proofs_solver_2 = [transformed[s1][str(int(x))][0] for x in selected_training_problems_solver_2]
selected_training_proofs_solver_2

['norm_num',
 'ring',
 'native_decide',
 'have : σ.1 10 = 2 := by\n  rw [←σ.right_inv 2]\n  rw [h₀]\nrw [this]\nrw [←σ.right_inv 1]\nrw [h₂]',
 '  have h₃ : y = 2 * x / 5 := by\n    linarith\n  have h₄ : z = 7 * x / 25 := by\n    rw [h₃] at h₂\n    linarith\n  field_simp [h₀]\n  rw [h₄]\n  ring',
 'calc\n  (100 ^ 2 - 7 ^ 2 : ℝ) / (70 ^ 2 - 11 ^ 2) * ((70 - 11) * (70 + 11) / ((100 - 7) * (100 + 7)))\n      = ((100 - 7) * (100 + 7) / (70 - 11) / (70 + 11)) * ((70 - 11) * (70 + 11) / (100 - 7) / (100 + 7)) := by\n    ring\n  _ = (((100 - 7) / (100 - 7)) * ((100 + 7) / (100 + 7)) * ((70 - 11) / (70 - 11)) * ((70 + 11) / (70 + 11))) := by\n    ring\n  _ = 1 := by\n    norm_num',
 "cases' abs_cases (a - b) with h h <;>\nnlinarith [sq_nonneg (a + b), sq_nonneg (a - b), h₀, h, sq_nonneg (a - b)]",
 '  rw [h₀, h₁]\n  norm_num']

In [None]:
# strict=True: a statement/proof count mismatch would otherwise be truncated
# silently and could pair statements with the wrong proofs.
training_set_solver_1 = [
    {"input": i, "output": o}
    for i, o in zip(selected_training_formal_statements_solver_1, selected_training_proofs_solver_1, strict=True)
]
training_set_solver_2 = [
    {"input": i, "output": o}
    for i, o in zip(selected_training_formal_statements_solver_2, selected_training_proofs_solver_2, strict=True)
]

In [None]:
print(training_set_solver_1[0]["input"])

import Mathlib
import Aesop

set_option maxHeartbeats 0

open BigOperators Real Nat Topology Rat

/-- Show that there are no integers $x$ and $y$ such that $4x^3 - 7y^3 = 2003$.-/
theorem numbertheory_4x3m7y3neq2003 (x y : ℤ) : 4 * x ^ 3 - 7 * y ^ 3 ≠ 2003 := by



In [None]:
print(training_set_solver_1[0]["output"])

  intro h
  have h_mod : (4 * x ^ 3 - 7 * y ^ 3) % 7 = 2003 % 7 := by rw [h]
  norm_num at h_mod
  have h₁ : x ^ 3 % 7 = 0 ∨ x ^ 3 % 7 = 1 ∨ x ^ 3 % 7 = 6 := by
    have : x % 7 = 0 ∨ x % 7 = 1 ∨ x % 7 = 2 ∨ x % 7 = 3 ∨ x % 7 = 4 ∨ x % 7 = 5 ∨ x % 7 = 6 := by omega
    rcases this with (h | h | h | h | h | h | h) <;>
      simp [h, pow_three, Int.mul_emod, Int.sub_emod]
  have h₂ : y ^ 3 % 7 = 0 ∨ y ^ 3 % 7 = 1 ∨ y ^ 3 % 7 = 6 := by
    have : y % 7 = 0 ∨ y % 7 = 1 ∨ y % 7 = 2 ∨ y % 7 = 3 ∨ y % 7 = 4 ∨ y % 7 = 5 ∨ y % 7 = 6 := by omega
    rcases this with (h | h | h | h | h | h | h) <;>
      simp [h, pow_three, Int.mul_emod, Int.sub_emod]
  omega


In [None]:
print(training_set_solver_2[0]["input"])

import Mathlib
import Aesop

set_option maxHeartbeats 0

open BigOperators Real Nat Topology Rat

/-- What is the ones digit of $1 \cdot 3 \cdot 5 \cdot 7 \cdot 9 \cdot 11 \cdot 13$? Show that it is 5.-/
theorem mathd_numbertheory_299 : 1 * 3 * 5 * 7 * 9 * 11 * 13 % 10 = 5 := by



In [None]:
print(training_set_solver_2[0]["output"])

norm_num


## Construct HF Dataset

In [None]:
from datasets import Dataset

def preprocess_function(example):
    """
    Turn an {"input", "output"} record into chat-style prompt/completion lists.

    The input becomes a single user message under "prompt" and the output a
    single assistant message under "completion".
    """
    user_turn = {"role": "user", "content": example["input"]}
    assistant_turn = {"role": "assistant", "content": example["output"]}
    return {"prompt": [user_turn], "completion": [assistant_turn]}

# Apply preprocessing and drop original columns
training_set_solver_1_hf = (Dataset.from_list(training_set_solver_1)).map(preprocess_function, remove_columns=["input", "output"])
training_set_solver_2_hf = (Dataset.from_list(training_set_solver_2)).map(preprocess_function, remove_columns=["input", "output"])

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Map:   0%|          | 0/8 [00:00<?, ? examples/s]

In [None]:
training_set_solver_1_hf[0]

{'prompt': [{'content': 'import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n/-- Show that there are no integers $x$ and $y$ such that $4x^3 - 7y^3 = 2003$.-/\ntheorem numbertheory_4x3m7y3neq2003 (x y : ℤ) : 4 * x ^ 3 - 7 * y ^ 3 ≠ 2003 := by\n',
   'role': 'user'}],
 'completion': [{'content': '  intro h\n  have h_mod : (4 * x ^ 3 - 7 * y ^ 3) % 7 = 2003 % 7 := by rw [h]\n  norm_num at h_mod\n  have h₁ : x ^ 3 % 7 = 0 ∨ x ^ 3 % 7 = 1 ∨ x ^ 3 % 7 = 6 := by\n    have : x % 7 = 0 ∨ x % 7 = 1 ∨ x % 7 = 2 ∨ x % 7 = 3 ∨ x % 7 = 4 ∨ x % 7 = 5 ∨ x % 7 = 6 := by omega\n    rcases this with (h | h | h | h | h | h | h) <;>\n      simp [h, pow_three, Int.mul_emod, Int.sub_emod]\n  have h₂ : y ^ 3 % 7 = 0 ∨ y ^ 3 % 7 = 1 ∨ y ^ 3 % 7 = 6 := by\n    have : y % 7 = 0 ∨ y % 7 = 1 ∨ y % 7 = 2 ∨ y % 7 = 3 ∨ y % 7 = 4 ∨ y % 7 = 5 ∨ y % 7 = 6 := by omega\n    rcases this with (h | h | h | h | h | h | h) <;>\n      simp [h, pow_three, Int.mul_emod, Int.s

In [None]:
training_set_solver_2_hf[0]

{'prompt': [{'content': 'import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n/-- What is the ones digit of $1 \\cdot 3 \\cdot 5 \\cdot 7 \\cdot 9 \\cdot 11 \\cdot 13$? Show that it is 5.-/\ntheorem mathd_numbertheory_299 : 1 * 3 * 5 * 7 * 9 * 11 * 13 % 10 = 5 := by\n',
   'role': 'user'}],
 'completion': [{'content': 'norm_num', 'role': 'assistant'}]}

In [None]:
# from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
# from peft import LoraConfig
# from trl import SFTTrainer
# import torch
# from accelerate import Accelerator

# MODEL_ID = "AI-MO/Kimina-Prover-Preview-Distill-7B"
# OUTPUT_DIR = "./kimina_fine_tuned"

# # Initialize accelerator
# accelerator = Accelerator()

# # Tokenizer & model
# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_ID,
#     trust_remote_code=True,
#     torch_dtype=torch.bfloat16,  # Changed to bfloat16
#     device_map="auto"
# )
# model.gradient_checkpointing_enable()  # Enable gradient checkpointing
# model, tokenizer = accelerator.prepare(model, tokenizer)

# # LoRA config
# peft_config = LoraConfig(
#     r=4,  # Reduced rank
#     lora_alpha=16,
#     target_modules=["q_proj", "v_proj"],
#     lora_dropout=0.05,
#     bias="none",
#     task_type="CAUSAL_LM"
# )

# # Training args
# training_args = TrainingArguments(
#     output_dir=OUTPUT_DIR,
#     per_device_train_batch_size=1,  # Reduced
#     gradient_accumulation_steps=4,  # Increased
#     learning_rate=2e-4,
#     num_train_epochs=3,
#     logging_steps=10,
#     save_strategy="epoch",
#     max_grad_norm=0.3,
#     bf16=True,  # or bf16=True
#     max_steps=-1,
#     lr_scheduler_type="cosine",
#     warmup_ratio=0.03,
#     gradient_checkpointing=True,
#     report_to="none"
# )

# # Clear GPU memory
# torch.cuda.empty_cache()

# # Trainer
# trainer = SFTTrainer(
#     model=model,
#     train_dataset=training_set_solver_1_hf,
#     peft_config=peft_config,
#     args=training_args,
#     # max_seq_length=512  # Limit sequence length
# )



In [None]:
# trainer.train()
# # trainer.model.save_pretrained(OUTPUT_DIR)
# # tokenizer.save_pretrained(OUTPUT_DIR)
# !zip -r kimina_fine_tuned.zip kimina_fine_tuned/
# torch.cuda.empty_cache()

# Stability Term: $L_{stability}$

- **TODO**: Get the E_i dataset ✅

- **TODO**: Implement EMA

In [None]:
# By complementarity of success cases between solver 1 and solver 2, define the "easy" sets:
# the problems selected as training targets for one solver are the ones the other
# solver already solves, so each solver's easy set is its peer's training set.
easy_solver_1 = selected_training_problems_solver_2
easy_solver_2 = selected_training_problems_solver_1
easy_solver_1, easy_solver_2


([13, 35, 3, 7, 11, 14, 15, 17], [1, 9, 20])

**Warning: the code below was generated by GPT and has not been verified. Review it carefully before running; it may contain bugs.**

In [None]:
import torch

In [None]:
torch.cuda.empty_cache()

In [None]:
# from transformers import (
#     Trainer, TrainingArguments,
#     AutoModelForCausalLM, AutoTokenizer,
#     DataCollatorForLanguageModeling,
# )
# from datasets import Dataset  # if you need it for map return types
# import torch
# import torch.nn.functional as F
# from peft import LoraConfig, get_peft_model
# from torch_ema import ExponentialMovingAverage
# from copy import deepcopy
# from accelerate import Accelerator

# # Dataset
# easy_set_solver_1_hf = training_set_solver_2_hf  # your dataset with 'prompt' and 'completion' fields

# # Initialize accelerator
# accelerator = Accelerator()

# # Load base model + tokenizer
# model_name = "AI-MO/Kimina-Prover-Preview-Distill-7B"
# tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# # --- FIX 1: ensure there is a pad token (typical for causal LMs) ---
# if tokenizer.pad_token is None:
#     # for LLaMA-like models, use eos as pad
#     tokenizer.pad_token = tokenizer.eos_token
# # keep bos/eos as whatever the tokenizer defines; don't invent new ones
# # explicitly set IDs on model config later

# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
# )

# # reflect pad token on the model/generation configs
# model.config.pad_token_id = tokenizer.pad_token_id
# if getattr(model.generation_config, "pad_token_id", None) is None:
#     model.generation_config.pad_token_id = tokenizer.pad_token_id

# model.gradient_checkpointing_enable()

# # (optional) you don't need to pass the tokenizer to accelerator.prepare
# model = accelerator.prepare(model)

# # Apply LoRA to the student
# peft_config = LoraConfig(r=4, lora_alpha=16, target_modules=["q_proj", "v_proj"])
# model = get_peft_model(model, peft_config)

# # Build a frozen, CPU-resident teacher
# teacher = deepcopy(model.base_model if hasattr(model, "base_model") else model)
# teacher = teacher.to(torch.bfloat16).cpu().eval()
# for p in teacher.parameters():
#     p.requires_grad = False

# # EMA on trainable params
# ema = ExponentialMovingAverage(
#     (p for p in model.parameters() if p.requires_grad),
#     decay=0.999,
# )

# # --- FIX 2: preprocess dataset into input_ids/attention_mask/labels ---
# max_len = getattr(tokenizer, "model_max_length", 4096)
# if max_len is None or max_len > 32768:  # guard against "very large" sentinel values
#     max_len = 4096

# def format_and_tokenize(example):
#     # join prompt + completion; for KL you can supervise all tokens
#     # Extract the string content from the list of dictionaries
#     prompt_text = example.get("prompt", [{}])[0].get("content", "")
#     completion_text = example.get("completion", [{}])[0].get("content", "")
#     text = prompt_text + completion_text

#     toks = tokenizer(
#         text,
#         truncation=True,
#         max_length=max_len,
#         padding=False,            # padding happens in the collator
#         return_attention_mask=True,
#     )
#     # labels = next-token prediction on the whole sequence
#     toks["labels"] = toks["input_ids"].copy()
#     return toks

# # use batched=False to avoid returning ragged lists inside lists
# tokenized_ds = easy_set_solver_1_hf.map(format_and_tokenize, remove_columns=[c for c in easy_set_solver_1_hf.column_names if c not in ("input_ids","attention_mask","labels")])

# # Data collator for CLM (pads and copies labels if needed)
# collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# class KLTrainer(Trainer):
#     def __init__(self, teacher, temperature=1.0, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.teacher = teacher
#         self.temperature = temperature

#     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
#         input_ids = inputs["input_ids"]
#         attention_mask = inputs.get(
#             "attention_mask", (input_ids != self.tokenizer.pad_token_id).long()
#         )

#         # Student forward
#         outputs = model(input_ids=input_ids, attention_mask=attention_mask)
#         logits_s = outputs.logits  # [B, T, V]

#         # Teacher forward on CPU (no grad)
#         with torch.no_grad():
#             logits_t = self.teacher(
#                 input_ids=input_ids.to("cpu"),
#                 attention_mask=attention_mask.to("cpu"),
#             ).logits.to(logits_s.device)  # [B, T, V]

#         t = self.temperature

#         # Shift for next-token prediction
#         logits_s = logits_s[:, :-1, :] / t
#         logits_t = logits_t[:, :-1, :] / t
#         tok_mask = attention_mask[:, 1:]  # [B, T-1] (0/1)

#         # Log-probs
#         logp_s = F.log_softmax(logits_s, dim=-1)          # [B, T-1, V]
#         logp_t = F.log_softmax(logits_t, dim=-1)          # [B, T-1, V]
#         p_t    = logp_t.exp()                              # [B, T-1, V]

#         # KL over vocab (no expand, no boolean indexing)
#         kl_per_pos = (p_t * (logp_t - logp_s)).sum(dim=-1)  # [B, T-1]

#         # Mask over time, cast to float to avoid index semantics
#         tok_mask = tok_mask.to(kl_per_pos.dtype)            # float
#         kl_sum = (kl_per_pos * tok_mask).sum() * (t * t)
#         denom = tok_mask.sum().clamp(min=1.0)
#         loss = kl_sum / denom

#         return (loss, outputs) if return_outputs else loss

#     def training_step(self, model, inputs, num_items_in_batch=None, **kwargs):
#         loss = super().training_step(model, inputs, num_items_in_batch=num_items_in_batch, **kwargs)
#         ema.update()
#         return loss

# # Training args
# training_args = TrainingArguments(
#     output_dir="./kimina_stabilized",
#     per_device_train_batch_size=1,
#     gradient_accumulation_steps=4,
#     max_steps=100,
#     bf16=True,
#     gradient_checkpointing=True,
#     report_to="none",
#     remove_unused_columns=True,   # <-- let Trainer drop non-model keys if any remain
# )

# # Clear GPU memory
# torch.cuda.empty_cache()

# # Trainer
# trainer = KLTrainer(
#     model=model,
#     teacher=teacher,
#     args=training_args,
#     train_dataset=tokenized_ds,
#     tokenizer=tokenizer,
#     data_collator=collator,
#     temperature=1.0,
# )

In [None]:
# trainer.train()
# torch.cuda.empty_cache()