## Introduction

### Grab a GPU

(Google Colab only)

Before doing anything, let's try to get a GPU instance.

Go to instance options

![step_0](https://raw.githubusercontent.com/theostos/small-pytanque-tp/refs/heads/main/img/step_0_option.png)

Change runtime

![step_1](https://raw.githubusercontent.com/theostos/small-pytanque-tp/refs/heads/main/img/step_1_change_runtime.png)

Add T4 GPU

![step_2](https://raw.githubusercontent.com/theostos/small-pytanque-tp/refs/heads/main/img/step_2_t4.png)

Then connect to a GPU instance

![step_3](https://raw.githubusercontent.com/theostos/small-pytanque-tp/refs/heads/main/img/step_3_connect.png)


### Environment setup

This first cell needs to be executed to set up our environment.

In [None]:
%pip install ollama
%pip install colab-xterm

!sudo apt-get update
!sudo apt-get install -y pciutils lshw

!curl -fsSL https://ollama.com/install.sh | sh
url = "https://raw.githubusercontent.com/theostos/small-pytanque-tp/refs/heads/main/client.py"
!wget --no-cache --backups=1 {url}
url = "https://raw.githubusercontent.com/theostos/small-pytanque-tp/refs/heads/main/utils.py"
!wget --no-cache --backups=1 {url}
!pip install requests
!pip install transformers

The following cell instantiates a client to connect to the Rocq proof assistant.
Don't forget to change the IP_ADDRESS constant!

In [24]:
IP_ADDRESS = "http://128.93.101.129:8765"

from client import ProofAssistantClientAPI, goals_to_str

client = ProofAssistantClientAPI(IP_ADDRESS)

On the server side, there is a file containing some theorems. First, let's look at them.

In [None]:
for section in client.sections():
  print(f"Theorems in section: {section}")
  for k in range(client.num_thm(section)):
    thm = client.show_thm(section, k)
    print(f"Theorem {k}, {thm['name']} description: {thm['statement']}")
  print()


## A case study: Lelarge's Theorem

We will start by proving Lelarge's Theorem in the Introduction section.
Let's review it by selecting the first theorem (index = 0) in the "introduction" section.

Before doing so, we need to examine the current context (i.e. the available lemmas, definitions, and notations) to understand it.

In [None]:
thm = client.show_thm("introduction", 0)
print("Context:")
print("\n".join(thm['premises']))
print("\n")
print(f"Theorem {thm['name']}: {thm['statement']}")

Theorem involution_injective states that an arbitrary function f, satisfying Hinv (i.e. being an involution), is one-to-one.
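On paper, the argument is a short chain of equalities. The derivation below assumes that Hinv states $\forall x,\ f(f(x)) = x$ and that $H$ is the hypothesis $f(x) = f(y)$; the exact statements are the ones printed by the context cell above.

```latex
\begin{aligned}
x &= f(f(x)) && \text{by Hinv at } x \\
  &= f(f(y)) && \text{by } H : f(x) = f(y) \\
  &= y       && \text{by Hinv at } y
\end{aligned}
```

Hence $f(x) = f(y)$ implies $x = y$, which is exactly injectivity.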

To prove it, we will use the following tactics (read from left to right):

* "intros x y H." introduces two variables, x and y, associated with the forall quantifier, and the hypothesis H corresponding to f x = f y.
* "rewrite <- Hinv with (x := x)." rewrites the goal from right to left using Hinv instantiated at x: the x on the left-hand side of the goal becomes f (f x).
* "rewrite <- Hinv with (x := y)." does the same with Hinv instantiated at y: the y on the right-hand side of the goal becomes f (f y).
* "rewrite H." rewrites the current goal using the introduced hypothesis H (i.e. f x = f y).
* "reflexivity." discharges goals of the form a = a.

In [None]:
state, goals = client.start_thm("introduction", 0)
print("Started theorem session:")
print(goals_to_str(goals))

In [None]:
state, goals = client.run_tac(state, "intros x y H.")
print(goals_to_str(goals))

In [None]:
state, goals = client.run_tac(state, "rewrite <- Hinv with (x := x).")
print(goals_to_str(goals))

In [None]:
state, goals = client.run_tac(state, "rewrite <- Hinv with (x := y).")
print(goals_to_str(goals))

In [None]:
state, goals = client.run_tac(state, "rewrite H.")
print(goals_to_str(goals))

In [None]:
state, goals = client.run_tac(state, "reflexivity.")
print(goals_to_str(goals))

# Automation

## ChatGPT vs Rocq

Let's try to automate theorem proving. In deep learning, our sledgehammer is a big LLM, such as ChatGPT.
Is it able to prove Lelarge's theorem?

First, let's ask one of the best "reasoning" models: GPT o3-mini.
Extract the sequence of tactics from [this](https://chatgpt.com/share/67e9842e-7d18-8007-b394-d29b03d859cb) link, and try to submit it to the Rocq server ([petanque server](https://github.com/ejgallego/coq-lsp/tree/main/petanque)).

To do it, complete the following cell. You only need to add the remaining tactics.

In [None]:
# complete the following list with the sequence of tactics provided by GPT o3-mini
tactics = ['intros x y H.', 'rewrite <- (Hinv x).', 'rewrite H.', 'rewrite Hinv.', 'reflexivity.'] # ['intros x y H.', 'tactic_1', 'tactic_2', 'and so on']

state, goals = client.start_thm("introduction", 0)

for tactic in tactics:
    state, goals = client.run_tac(state, tactic)
    print(goals_to_str(goals))

With more modest models, things get harder:
* GPT-4o (see [here](https://chatgpt.com/share/67e986d3-6650-8007-b01c-b5dfe48468a5)) proves it after 5 attempts.
* GPT-4o mini (see [here](https://chatgpt.com/share/67e9871d-9de8-8007-aaeb-04d48d61b525)) fails to prove it within 8 attempts.

More remarkably still, both the non-reasoning and the reasoning models from [DeepSeek](https://chat.deepseek.com/) solve this problem on the first attempt.

## Open-weight model

Now, let's try to do it with an open-weight model running locally. We choose Gemma 3 12b, since it performs well while still fitting on a Colab T4 instance.

First, we need to start an inference engine (to serve our model locally).

In [None]:
!while true; do nohup ollama serve >/dev/null 2>&1; sleep 1; done >/dev/null 2>&1 &

Then, let's download Gemma 3 12b.

In [None]:
!ollama pull gemma3:12b

To use the model, we will simply call the function get_response.

In [None]:
from utils import get_response
from tqdm import tqdm

prompt = "Why is the sky blue? Write a short answer."

print(get_response(prompt))
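The implementation of get_response lives in utils.py and is not shown here. As a rough idea of what such a helper might look like, the sketch below wraps the `chat` function of the `ollama` Python client. The names `get_response_sketch`, `chat_fn` and `_echo_chat` are hypothetical, and the real helper may differ (system prompt, sampling options, and so on).

```python
from typing import Callable, Optional


def get_response_sketch(prompt: str, model: str = "gemma3:12b",
                        chat_fn: Optional[Callable] = None) -> str:
    """Send one user message to a chat model and return the text of its reply."""
    if chat_fn is None:
        # Imported lazily so the sketch can be exercised without a running server.
        import ollama
        chat_fn = ollama.chat
    response = chat_fn(model=model, messages=[{"role": "user", "content": prompt}])
    return response["message"]["content"]


def _echo_chat(model, messages):
    """Offline stand-in that mimics the shape of a chat response."""
    return {"message": {"content": f"[{model}] {messages[-1]['content']}"}}


reply = get_response_sketch("Why is the sky blue?", chat_fn=_echo_chat)
print(reply)
```

Calling `get_response_sketch(prompt)` without `chat_fn` would talk to the local Ollama server started above.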

Now let's try to prove simple Rocq lemmas (section "logic").

Complete the following cell.

In [None]:
import re

def parse_output(output):
    """
    Function to parse LLM output.
    It expects outputs with the following format
    ```coq
    tactic.
    ```
    """
    # To avoid parsing issues, accept a tactic that does not end with a period, even though Rocq normally requires one.
    pattern = r'```coq\n(.*?)\.?\n'
    match_output = re.search(pattern, output)
    if match_output:
      output = match_output.group(1).strip()
      return output + '.'
    return ''

prompt_template = """You are an expert in Coq, a theorem-proving assistant. Your task is to help progress a formal proof by providing exactly one correct and effective Coq tactic to advance towards the goal.
Current proof state:
{goal}

Carefully analyze the current goal, consider available hypotheses, and propose the most logical and efficient next step in the proof.

Respond with ONLY ONE Coq tactic enclosed in a Coq code block. Ensure the tactic is syntactically correct and directly applicable. Don't write any comment, only the code block.

Example of correct formatting:
```coq
tactic_1.
```
"""

for idx in range(client.num_thm('logic')):
  # iterate over all theorems in the section 'logic'

  print(f"Try to prove theorem {idx}")
  state, goals = client.start_thm("logic", idx)
  tactics = []
  for _ in tqdm(range(25)):

    # retrieve goals as a string
    to_prove_pp = goals_to_str(goals)
    # Gemma3 seems to prefer natural language instead of weird logician symbols
    to_prove = to_prove_pp.replace('|-', 'to prove: ')

    prompt = prompt_template.format(goal=to_prove)
    output = get_response(prompt)
    next_tactic = parse_output(output)
    try:
      # send tactic to Rocq proof assistant
      state, goals = client.run_tac(state, next_tactic)
      tactics.append(next_tactic)
    except Exception as e:
      # ignore tactics if failed
      pass
    if not goals:
      # proof finished
      print("Finished!")
      print("Found solution:\n" + "\n".join(tactics) + "\n\n")
      break
  if goals:
    print("Failed" + "\n\n")
  print()
  print()
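Before blaming the model for a failed proof, it is worth checking what the parser does with typical outputs. The sketch below re-declares `parse_output` with the same regular expression as the cell above so that it runs on its own; the sample outputs are made up for illustration, and the fence marker is built programmatically only to avoid writing a literal code fence inside this example.

```python
import re

FENCE = "`" * 3  # three backticks


def parse_output(output):
    """Extract the first line of a coq code block and make it end with a period."""
    pattern = FENCE + r'coq\n(.*?)\.?\n'
    match_output = re.search(pattern, output)
    if match_output:
        return match_output.group(1).strip() + '.'
    return ''


well_formed = parse_output(FENCE + "coq\nintros x.\n" + FENCE)
missing_period = parse_output(FENCE + "coq\nintros x\n" + FENCE)
with_chatter = parse_output("Sure!\n" + FENCE + "coq\nsplit.\n" + FENCE)
no_block = parse_output("split.")

print(repr(well_formed), repr(missing_period), repr(with_chatter), repr(no_block))
```

An answer without a code block yields the empty string, which the loop then sends to the proof assistant as if it were a tactic; skipping empty results is one easy improvement.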

You may try to improve the prompt.

In the following cell, there is a prompt with an explanation of some tactics in Rocq.
Complete it and try it to see if it makes any difference.

In [None]:
prompt_template = """You are an expert in Coq, a theorem-proving assistant. Your task is to help progress a formal proof by providing exactly one correct and effective Coq tactic to advance towards the goal.

Here is a brief explanation of tactics you may use:

- intros: Introduces hypotheses or variables into the context. (example: "intros P Q." introduces hypotheses P and Q.)
- apply: Applies a hypothesis or theorem to match the current goal. (example: "apply H0." If the goal is Q and you have a hypothesis H0: P -> Q, applying H0 changes the goal to P.)
- exact: Directly solves the current goal if you have an exact matching hypothesis. (example: "exact H." If the goal is P and you have a hypothesis H: P, then the goal is resolved.)
- contradiction: Resolves the goal if there is a contradiction in the hypotheses. (example: if you have hypotheses H1: P -> False and H2: P, using "contradiction." resolves the goal.)
- unfold not: Expands the definition of negation (~P becomes P -> False). (examples: "unfold not in H." applies it to hypothesis H, "unfold not." applies it to the goal.)
- inversion: Breaks apart hypotheses involving conjunctions (and), disjunctions (or), or existential quantifiers to reveal simpler components. (example: "inversion H." breaks hypothesis H into simpler parts.)
- split: Splits goals involving conjunctions into separate subgoals. (example: "split." transforms goal P /\ Q into two separate goals, P and Q.)
- left/right: Selects a side of a disjunction (or) goal to prove. (examples: "left." to prove the left side of a goal P \/ Q, "right." to prove the right side.)

Current proof state:
{goal}

Carefully analyze the current goal, consider available hypotheses, and propose the most logical and efficient next step in the proof.

Respond with ONLY ONE Coq tactic enclosed in a Coq code block. Ensure the tactic is syntactically correct and directly applicable. Don't write any comment, only the code block.

Example of correct formatting:
```coq
tactic_1.
```
"""

for idx in range(client.num_thm('logic')):
  # iterate over all theorems in the section 'logic'

  print(f"Try to prove theorem {idx}")
  state, goals = client.start_thm("logic", idx)
  tactics = []
  for _ in tqdm(range(25)):

    # retrieve goals as a string
    to_prove_pp = goals_to_str(goals)
    # Gemma3 seems to prefer natural language instead of weird logician symbols
    to_prove = to_prove_pp.replace('|-', 'to prove: ')

    prompt = prompt_template.format(goal=to_prove)
    output = get_response(prompt)
    next_tactic = parse_output(output)

    try:
      # send tactic to Rocq proof assistant
      state, goals = client.run_tac(state, next_tactic)
      tactics.append(next_tactic)
    except Exception as e:
      # ignore tactics if failed
      pass
    if not goals:
      # proof finished
      print("Finished!")
      print("Found solution:\n" + "\n".join(tactics) + "\n\n")
      break
  if goals:
    print("Failed" + "\n\n")
  print()
  print()

You could try the following strategies:
- Increase the number of tries.
- Add some examples of lemma + proof in your prompt (few-shot prompting).

If you feel confident enough, you could try to improve the overall strategy:
- What would happen if you kept track of the wrong steps in your prompt (i.e. the ones that raise an error)?
- What would happen if you keep track of redundant steps?
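For the few-shot idea, one option is to prepend a handful of (proof state, next tactic) pairs to the prompt. The sketch below only builds that text. `FEW_SHOT_EXAMPLES` and `few_shot_block` are hypothetical names, the examples are hand-written, and the layout of the proof states is an assumption: adapt it to what goals_to_str actually prints.

```python
from typing import List, Tuple

FENCE = "`" * 3  # three backticks, built programmatically to avoid a literal fence

# Hand-written illustrations: (proof state shown to the model, suitable next tactic).
FEW_SHOT_EXAMPLES: List[Tuple[str, str]] = [
    ("to prove: forall P : Prop, P -> P", "intros P H."),
    ("P : Prop\nH : P\nto prove: P", "exact H."),
    ("P, Q : Prop\nHP : P\nHQ : Q\nto prove: P /\\ Q", "split."),
]


def few_shot_block(examples: List[Tuple[str, str]]) -> str:
    """Format the examples in the same answer format the prompt asks for."""
    parts = ["Here are a few examples of proof states with a suitable next tactic.\n"]
    for goal, tactic in examples:
        parts.append(
            f"Proof state:\n{goal}\n\nNext tactic:\n{FENCE}coq\n{tactic}\n{FENCE}\n"
        )
    return "\n".join(parts)


few_shot_text = few_shot_block(FEW_SHOT_EXAMPLES)
print(few_shot_text)
```

In the proving loop, the prompt could then be built as `few_shot_text + prompt_template.format(goal=to_prove)`. Concatenating after formatting avoids having to escape braces in the examples.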

In [1]:
# Find a better strategy/prompt!

# TO DO 

If you are able to prove most exercises from the "logic" section, maybe you can try to prove Lelarge's theorem (section "introduction", index 0)!

In the following cell, you will find a partial implementation that is a bit better: it keeps track of errors and redundant tactics, and increases the number of tries.

Try to complete it or improve it.

In [None]:
from collections import defaultdict

prompt_template = """You are an expert in Coq, a theorem-proving assistant. Your task is to help progress a formal proof by providing exactly one correct and effective Coq tactic to advance towards the goal.

Here is a brief explanation of tactics you may use:

- intros: Introduces hypotheses or variables into the context. (example: "intros P Q." introduces hypotheses P and Q.)
- apply: Applies a hypothesis or theorem to match the current goal. (example: "apply H0." If the goal is Q and you have a hypothesis H0: P -> Q, applying H0 changes the goal to P.)
- exact: Directly solves the current goal if you have an exact matching hypothesis. (example: "exact H." If the goal is P and you have a hypothesis H: P, then the goal is resolved.)
- contradiction: Resolves the goal if there is a contradiction in the hypotheses. (example: if you have hypotheses H1: P -> False and H2: P, using "contradiction." resolves the goal.)
- unfold not: Expands the definition of negation (~P becomes P -> False). (examples: "unfold not in H." applies it to hypothesis H, "unfold not." applies it to the goal.)
- inversion: Breaks apart hypotheses involving conjunctions (and), disjunctions (or), or existential quantifiers to reveal simpler components. (example: "inversion H." breaks hypothesis H into simpler parts.)
- split: Splits goals involving conjunctions into separate subgoals. (example: "split." transforms goal P /\ Q into two separate goals, P and Q.)
- left/right: Selects a side of a disjunction (or) goal to prove. (examples: "left." to prove the left side of a goal P \/ Q, "right." to prove the right side.)

Current proof state:
{goal}

Carefully analyze the current goal, consider available hypotheses, and propose the most logical and efficient next step in the proof.

Respond with ONLY ONE Coq tactic enclosed in a Coq code block. Ensure the tactic is syntactically correct and directly applicable. Don't write any comment, only the code block.

Example of correct formatting:
```coq
tactic_1.
```
"""

prompt_remove = """Don't use any of the following instructions:
{remove}
"""

success = False
for idx in range(client.num_thm('logic')):
  # iterate over all theorems in the section 'logic'

  print(f"Try to prove theorem {idx}")
  state, goals = client.start_thm("logic", idx)
  tactics = []

  # useless_tactics and failed_tactics keep track of errors and redundant tactics
  useless_tactics = defaultdict(list)
  failed_tactics = defaultdict(list)
  for _ in tqdm(range(40)):
    to_prove_pp = goals_to_str(goals)
    to_prove = to_prove_pp.replace('|-', 'to prove: ')

    bad_tactics = useless_tactics[to_prove_pp] + failed_tactics[to_prove_pp]
    prompt = prompt_template.format(goal=to_prove)
    prompt += prompt_remove.format(remove="\n".join(bad_tactics))
    
    output = get_response(prompt)
    next_tactic = parse_output(output)
    try:
      state, goals = client.run_tac(state, next_tactic)
      new_to_prove_pp = goals_to_str(goals)

      # if goal doesn't change, add next_tactic to useless_tactics
      if to_prove_pp == new_to_prove_pp:
        useless_tactics[new_to_prove_pp].append(next_tactic)
      tactics.append(next_tactic)
    except Exception:
      # if tactic fails, add next_tactic to failed_tactics
      failed_tactics[to_prove_pp].append(next_tactic)

    if not goals:
      print("Finished!")
      print("Found solution:\n" + "\n".join(tactics))
      break
  if goals:
    print("Failed")
    

## Specialized open-weight model

Let's download [ProofWala](https://arxiv.org/abs/2502.04671), a version of the [CodeT5](https://arxiv.org/abs/2109.00859) model fine-tuned on a dataset of Lean and Rocq proofs.

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import pipeline
from functools import partial
from utils import generate_tactics_wala

model_name = "amitayusht/ProofWala-Multilingual"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0) # device=0 for GPU, -1 for CPU

generate_tactics = partial(generate_tactics_wala, pipe, model, tokenizer)

### Naive strategy

Now, let's try this new model on the previous lemmas.

In [None]:
for idx in range(client.num_thm('logic')):
  # iterate over all theorems in the section 'logic'
  print(f"Try to prove theorem {idx}")
  state, goals = client.start_thm("logic", idx)

  # keep track of previous steps and incorrect steps
  steps = []
  incorrect_steps = []
  for _ in tqdm(range(100)):
    # generate 1 candidate for the next step associated to goals, and given previous steps and incorrect_steps
    tactics, _ = generate_tactics(1, goals, steps=steps, incorrect_steps=incorrect_steps)
    # tactics is a list of size 1
    next_tactic = tactics[0]
    try:
      state, goals = client.run_tac(state, next_tactic)
      steps.append(next_tactic)

      # reset incorrect steps since the goal may have changed
      incorrect_steps = []
    except Exception:
      incorrect_steps.append(next_tactic)
    if not goals:
      print("Finished!")
      print("Found solution:\n" + "\n".join(steps) + "\n\n")
      break
  if goals:
    print("Failed" + "\n\n")

### Beam Search

In this setup, we use an LLM to generate a sequence of steps.
Unlike typical token-wise beam search, our approach operates step-wise, scoring each entire step by its mean log probability.
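Under the usual reading of "mean log probability", a generated step made of tokens $t_1, \dots, t_n$ would be scored as follows. This is an assumption about how utils.py computes the score; $c$ denotes the model input (goals and previous steps).

```latex
\mathrm{score}(t_1, \dots, t_n) = \frac{1}{n} \sum_{i=1}^{n} \log p\left(t_i \mid c,\ t_1, \dots, t_{i-1}\right)
```

Averaging instead of summing avoids penalizing a tactic merely for containing more tokens.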

How it works (k-beam search):

1. **Initialization:**  
   Begin with a `<START>` step.

2. **Step-wise Expansion:**  
   At each iteration, expand each candidate sequence by generating full steps (each representing a complete Rocq step) and compute the mean log probability of the step.

3. **Pruning:**  
   Retain only the top-k candidates (according to their mean log probability) for further expansion.

4. **Termination:**  
   Continue until the `<END>` step is reached or a maximum number of steps is generated.
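The four steps above can be sketched in a few lines of Python. This is a toy illustration, not the `beam_search` from utils.py: `stepwise_beam_search`, `toy_expand` and `toy_done` are hypothetical names, the generator always proposes the same three tactics with fixed scores, and a sequence is ranked by the sum of its steps' mean log probabilities (one possible choice among several).

```python
import math
from typing import Callable, List, Optional, Tuple

Candidate = Tuple[str, float]  # (step text, mean log probability of that step)


def stepwise_beam_search(
    expand: Callable[[List[str]], List[Candidate]],
    is_done: Callable[[List[str]], bool],
    beam_size: int,
    max_depth: int,
) -> Optional[List[str]]:
    """Return the first complete sequence of steps found, or None."""
    beam: List[Tuple[float, List[str]]] = [(0.0, [])]  # empty list plays the role of <START>
    for _ in range(max_depth):
        expanded: List[Tuple[float, List[str]]] = []
        for score, steps in beam:
            for step, logp in expand(steps):
                new_steps = steps + [step]
                if is_done(new_steps):  # plays the role of reaching <END>
                    return new_steps
                expanded.append((score + logp, new_steps))
        if not expanded:
            return None
        # Pruning: keep only the beam_size best-scoring sequences.
        expanded.sort(key=lambda item: item[0], reverse=True)
        beam = expanded[:beam_size]
    return None


def toy_expand(steps: List[str]) -> List[Candidate]:
    """Toy generator: same three candidate steps regardless of the history."""
    return [("intros.", math.log(0.6)), ("split.", math.log(0.3)), ("exact H.", math.log(0.1))]


def toy_done(steps: List[str]) -> bool:
    """Toy success check: only one specific script counts as a proof."""
    return steps == ["intros.", "split.", "exact H."]


wide = stepwise_beam_search(toy_expand, toy_done, beam_size=3, max_depth=3)
greedy = stepwise_beam_search(toy_expand, toy_done, beam_size=1, max_depth=3)
print(wide, greedy)
```

With a beam of size 1 the search greedily repeats the most likely step and never finds the script, while a beam of size 3 keeps the lower-scored prefix alive long enough to complete it.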


![beam_search](https://raw.githubusercontent.com/theostos/small-pytanque-tp/refs/heads/main/img/beam_search.png)

Now, what would happen with a beam search (step-wise)? Let's try it on our whole collection of exercises.

In [None]:
from utils import beam_search

for section in ['introduction', 'logic', 'math']:
    for idx_thm in range(client.num_thm(section)):
        print(f"Trying to prove theorem {idx_thm} in section {section}.")
        found = False
        for _ in tqdm(range(30)):
            result = beam_search(generate_tactics, client, section, idx_thm, max_depth=7, beam_size=32, timeout=60)
            if result:
                print("Found solution:\n" + "\n".join(result) + "\n\n")
                found = True
                break
        if not found:
            print("Failed" + "\n\n")