In [1]:
!pip install -qqq bitsandbytes==0.45.3 # NEED this version, found via https://www.reddit.com/r/LocalLLaMA/comments/1j1mq6y/4bit_quantization_requires_the_latest_version_of/
!pip install -qqq accelerate
!pip install -U -qqq transformers
import bitsandbytes as bnb
import math
import random
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from typing import List, Tuple

# Get Lean4ServerScheduler from DeepSeekProver

The following cells clone DeepSeekProver and load in the `Lean4ServerScheduler` object for lean4 verification.

In [None]:
# clone deepseek
!git clone https://github.com/deepseek-ai/DeepSeek-Prover-V1.5.git

In [None]:
# rename the repo to a valid string and flatten by 1 depth
!mv DeepSeek-Prover-V1.5 DeepSeekProver
!mv DeepSeekProver/* . && rmdir DeepSeekProver

## MANDATORY MANUAL STEP:
Remove the line `from .generator import GeneratorProcess` from `prover/workers/__init__.py`

This bypasses the `vLLM` import, mandatory for non-compatible devices.

In [2]:
# import the verifier
from prover.lean.verifier import Lean4ServerScheduler

# RMaxTS Setup

In [8]:
class Node:
    def __init__(self, state: str, tactics_applied: List[str], parent=None):
        self.state = state  # Theorem statement plus tactics applied so far
        self.tactics_applied = tactics_applied  # List of tactics applied
        self.parent = parent
        self.children = []  # List of (tactic_sequence, prior_prob, child_node)
        self.visit_count = 0
        self.total_reward = 0.0

class LeanVerifier:
    """Wrapper for DeepSeekProver's Lean4ServerScheduler to verify proof states."""
    def __init__(self):
        self.scheduler = Lean4ServerScheduler(
            max_concurrent_requests=1,
            timeout=300,
            memory_limit=10,
            name='verifier'
        )

    def construct_full_code(self, theorem: str, tactics: List[str]) -> str:
        """Construct Lean4 code from theorem and tactics."""
        # Assume theorem is in format "theorem name : statement := by"
        proof_lines = "\n  ".join(tactic.strip() for tactic in tactics if tactic.strip())
        full_code = f"{theorem}\n  {proof_lines}" if proof_lines else theorem
        return full_code

    def verify(self, theorem: str, tactics: List[str]) -> Tuple[bool, bool]:
        """Verify a sequence of tactics applied to a theorem."""
        code = self.construct_full_code(theorem, tactics)
        request_id = self.scheduler.submit_all_request([code])[0]
        result = self.scheduler.get_all_request_outputs([request_id])[0]
        is_valid = result.get('pass', False)
        is_complete = result.get('complete', False)
        return is_valid, is_complete

    def verify_with_truncation(self, theorem: str, tactics: List[str]) -> Tuple[bool, bool, int]:
        """Verify and find the truncation point if invalid."""
        code = self.construct_full_code(theorem, tactics)
        request_id = self.scheduler.submit_all_request([code])[0]
        result = self.scheduler.get_all_request_outputs([request_id])[0]
        is_valid = result.get('pass', False)
        is_complete = result.get('complete', False)
        if not is_valid:
            # Assume 'error_step' indicates the index of the first invalid tactic
            valid_steps = result.get('error_step', 0)
        else:
            valid_steps = len(tactics)
        return is_valid, is_complete, valid_steps

    def close(self):
        """Close the scheduler and its processes."""
        self.scheduler.close()

class RMaxTS:
    def __init__(self, model, tokenizer, theorem: str, c=1.0, b=1.0, max_depth=10, num_sequences=5):
        self.model = model
        self.tokenizer = tokenizer
        self.theorem = theorem
        self.verifier = LeanVerifier()
        self.c = c
        self.b = b
        self.max_depth = max_depth
        self.num_sequences = num_sequences
        self.root = None
        self.verification_cache = {}

    def verify_state(self, tactics: List[str]) -> Tuple[bool, bool]:
        """Verify a state defined by tactics applied to the theorem."""
        tactics_tuple = tuple(tactics)
        cache_key = (self.theorem, tactics_tuple)
        if cache_key not in self.verification_cache:
            is_valid, is_complete = self.verifier.verify(self.theorem, tactics)
            self.verification_cache[cache_key] = (is_valid, is_complete)
        return self.verification_cache[cache_key]

    def is_terminal(self, node: Node) -> bool:
        """A node is terminal if invalid or proof is complete."""
        is_valid, is_complete = self.verify_state(node.tactics_applied)
        return not is_valid or is_complete

    def is_proof_complete(self, tactics: List[str]) -> bool:
        """Check if the proof is complete."""
        _, is_complete = self.verify_state(tactics)
        return is_complete

    def generate_tactics(self, state: str) -> List[Tuple[List[str], float]]:
        """Generate multiple proof continuations using the LLM."""
        prompt = state + "\nComplete the proof:\n"
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=200,
            num_beams=self.num_sequences,
            num_return_sequences=self.num_sequences,
            pad_token_id=self.tokenizer.eos_token_id
        )
        sequences = []
        for i, output in enumerate(outputs):
            generated_text = self.tokenizer.decode(output[inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
            tactic_sequence = [t.strip() for t in generated_text.split("\n") if t.strip()]
            prior_prob = 1.0 / (i + 1)
            sequences.append((tactic_sequence, prior_prob))
        return sequences

    def select(self, node: Node) -> Node:
        """Select a node to expand using UCT with RMaxTS bonus."""
        while node.children and not self.is_terminal(node):
            node = max(
                node.children,
                key=lambda x: self.uct_value(x[2], node.visit_count),
            )[2]
        return node

    def uct_value(self, child_node: Node, parent_visits: int) -> float:
        """Compute UCT value with intrinsic exploration bonus."""
        if child_node.visit_count == 0:
            return float("inf")
        exploitation = child_node.total_reward / child_node.visit_count
        exploration = self.c * math.sqrt(math.log(parent_visits) / child_node.visit_count)
        bonus = self.b / math.sqrt(child_node.visit_count)
        return exploitation + exploration + bonus

    def expand(self, node: Node):
        """Expand node by generating and truncating proof continuations."""
        tactic_sequences = self.generate_tactics(node.state)
        for tactic_seq, prior_prob in tactic_sequences:
            new_tactics = node.tactics_applied + tactic_seq
            is_valid, is_complete, valid_steps = self.verifier.verify_with_truncation(self.theorem, new_tactics)
            if valid_steps > 0:
                truncated_tactics = new_tactics[:valid_steps]
                new_state = self.theorem + "\n" + "\n".join(truncated_tactics)
                child_node = Node(new_state, truncated_tactics, parent=node)
                node.children.append((tactic_seq[:valid_steps], prior_prob, child_node))

    def simulate(self, tactics: List[str]) -> float:
        """Simulate by generating a continuation and verifying it."""
        current_state = self.theorem + "\n" + "\n".join(tactics)
        sequences = self.generate_tactics(current_state)
        if not sequences:
            return 0.0
        tactic_seq, _ = sequences[0]
        new_tactics = tactics + tactic_seq
        is_valid, is_complete = self.verify_state(new_tactics)
        return 1.0 if is_complete else 0.0

    def backpropagate(self, node: Node, reward: float):
        """Update node statistics up the tree."""
        while node is not None:
            node.visit_count += 1
            node.total_reward += reward
            node = node.parent

    def search_best_proof(self, initial_state: str, num_iterations: int = 100) -> List[str]:
        """Perform RMaxTS search to generate a complete proof."""
        self.root = Node(initial_state, [], parent=None)
        is_valid, _ = self.verify_state(self.root.tactics_applied)
        if not is_valid:
            return None

        for _ in range(num_iterations):
            node = self.select(self.root)
            if self.is_terminal(node):
                reward = 1.0 if self.is_proof_complete(node.tactics_applied) else 0.0
            else:
                self.expand(node)
                if node.children:
                    child = random.choice(node.children)[2]
                    reward = self.simulate(child.tactics_applied)
                else:
                    reward = 0.0
            self.backpropagate(node, reward)

        current_node = self.root
        while current_node.children and not self.is_proof_complete(current_node.tactics_applied):
            current_node = max(current_node.children, key=lambda x: x[2].visit_count)[2]
        if self.is_proof_complete(current_node.tactics_applied):
            return current_node.tactics_applied
        return None

    def generate_whole_proof(self, theorem: str, iterations_per_sim: int = 100) -> List[str]:
        """Generate a complete proof for the given theorem."""
        self.theorem = theorem
        proof = self.search_best_proof(theorem, num_iterations=iterations_per_sim)
        if proof:
            final_state = theorem + "\n" + "\n".join(proof)
            print("Final proof state:", final_state)
            print("Proof steps:", proof)
            return proof
        else:
            print("Failed to generate a complete proof.")
            return None

    def close(self):
        """Close the verifier to release resources."""
        self.verifier.close()

# Load Quantized Model

In [4]:
# load 8 bit quantized model
"""
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)
"""
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True
)


model_name = "deepseek-ai/DeepSeek-Prover-V1.5-RL"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", quantization_config=quantization_config)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.59k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/4.61M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/699 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/22.7k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-000002.safetensors:   0%|          | 0.00/5.23G [00:00<?, ?B/s]

model-00001-of-000002.safetensors:   0%|          | 0.00/8.59G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Demonstrate Quantized Model on a Simple Proof

In [5]:
# run the quant model on a short lean thm proof
lean_string = r'''import Mathlib
import Aesop

set_option maxHeartbeats 0

open BigOperators Real Nat Topology Rat

/-- The second and fourth terms of a geometric sequence are $2$ and $6$. Which of the following is a possible first term?
Show that it is $\frac{2\sqrt{3}}{3}$.-/
theorem amc12b_2003_p6 (a r : ℝ) (u : ℕ → ℝ) (h₀ : ∀ k, u k = a * r ^ k) (h₁ : u 1 = 2)
  (h₂ : u 3 = 6) : u 0 = 2 / Real.sqrt 3 ∨ u 0 = -(2 / Real.sqrt 3) := by
'''

output = model.generate(
    input_ids = tokenizer(lean_string, return_tensors="pt").input_ids.to(model.device),
    max_new_tokens=1000,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    pad_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(output[0], skip_special_tokens=True))

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


import Mathlib
import Aesop

set_option maxHeartbeats 0

open BigOperators Real Nat Topology Rat

/-- The second and fourth terms of a geometric sequence are $2$ and $6$. Which of the following is a possible first term?
Show that it is $\frac{2\sqrt{3}}{3}$.-/
theorem amc12b_2003_p6 (a r : ℝ) (u : ℕ → ℝ) (h₀ : ∀ k, u k = a * r ^ k) (h₁ : u 1 = 2)
  (h₂ : u 3 = 6) : u 0 = 2 / Real.sqrt 3 ∨ u 0 = -(2 / Real.sqrt 3) := by
  simp_all only [pow_one, pow_three, mul_one, mul_add, mul_assoc, mul_comm]
  have h₃ : r ^ 2 = 3 := by
    nlinarith [h₁, h₂]
  have h₄ : a = 2 / Real.sqrt 3 ∨ a = -(2 / Real.sqrt 3) := by
    apply eq_or_eq_neg_of_sq_eq_sq <;> field_simp <;>
    nlinarith
  simp_all only [pow_zero, mul_one]
  <;>
  tauto
```


# Demonstrate RMaxTS Agent on a Harder Proof

In [9]:
theorem = "theorem symmetry_eq : ∀ a b : Nat, a = b → b = a := by"
rmax_ts = RMaxTS(model, tokenizer, theorem, num_sequences=5)
proof = rmax_ts.generate_whole_proof(theorem, iterations_per_sim=100)
rmax_ts.close()

Complete launching 1 LeanServerProcesses
Failed to generate a complete proof.
All 1 LeanServerProcesses stopped
