---
date: 2024-05-15
title: Proof Search with an LLM Kanren
---

https://arxiv.org/abs/2402.08147 VerMCTS: Synthesizing Multi-Step Programs using a Verifier, a Large Language Model, and Tree Search


Deep learning and in particular LLMs are still hot.

Machine learning applied to theorem proving is not a new field. I was watching an interview with Stephan Schulz and he mentioned he was doing a PhD on it in the nineties. Axiom and strategy selection are an important part of top tier solvers.
Probably it goes back in various forms all the way to the beginning.

To some degree, Copilot is already an extremely useful machine learning proof assistant out of the box. I've been finding it mostly fills out reasonable lemma selections and even theorem statements while I've been working on knuckledragger.

It now feels almost passé, but the computer world was rocked by AlphaGo. Monte Carlo tree search seems like a natural fit for theorem proving search as well. There is a proof search tree that branches when there are multiple possible proof rules to follow.



Kanren is a style of embedding logic programming as a DSL into a host language.
Logic programming has two big pieces that can be treated separately:

- Search. A program is given alternative possibilities. This nondeterminism can be modelled as a function `state -> list state` and combinators around this type.
- Unification. Unification is two way pattern matching. Its bidirectional character is what enables the seemingly magical ability of Prolog or miniKanren to run relations backwards or forwards (for example using a single `append` predicate to append two lists, take a list difference, or generate partitions of a list). The `state` in a basic kanren is the unification state, which is a substitution mapping from logic variables to terms.
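To make the unification half concrete, here is a minimal sketch. The names `Var`, `walk`, `unify` and `eq` are my own choices for illustration, the substitution is a plain dict, and there is no occurs check. The point is only that a unification goal has the same `state -> stream of states` shape as everything else.

```python
class Var:
    """A logic variable; identity (not name) distinguishes variables."""
    def __init__(self, name):
        self.name = name
    def __repr__(self):
        return f"?{self.name}"

def walk(term, subst):
    """Follow bindings until reaching a non-variable or an unbound variable."""
    while isinstance(term, Var) and term in subst:
        term = subst[term]
    return term

def unify(a, b, subst):
    """Return an extended substitution making a and b equal, or None on failure."""
    a, b = walk(a, subst), walk(b, subst)
    if isinstance(a, Var):
        return subst if a is b else {**subst, a: b}
    if isinstance(b, Var):
        return {**subst, b: a}
    if isinstance(a, tuple) and isinstance(b, tuple):
        if len(a) != len(b):
            return None
        for x, y in zip(a, b):
            subst = unify(x, y, subst)
            if subst is None:
                return None
        return subst
    return subst if a == b else None

def eq(a, b):
    """Unification as a goal: yields zero or one new substitution."""
    def res(subst):
        s = unify(a, b, subst)
        if s is not None:
            yield s
    return res

x, y = Var("x"), Var("y")
s = unify(("cons", x, 2), ("cons", 1, y), {})
print(walk(x, s), walk(y, s))  # information flows in both directions
```

Because `eq` is just another goal, it composes with `conj` and `disj` without either combinator knowing anything about terms.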

There are two basic combinators you need for kanren style search, `conj` and `disj`, aka `and` and `or`. `conj` combines requirements and is the bind operation of a nondeterminism monad. `disj` presents alternative choices and is where the nondeterminism is generated.

Different search strategies can be represented by different implementations of `conj` and `disj`.


The most basic depth first strategy can be written as


In [2]:
def disj(*args): # very similar to chain https://docs.python.org/3/library/itertools.html#itertools.chain
    def res(state):
        for a in args:
            for state in a(state):
                yield state
    return res

def conj(*args): # kind of state threading product https://docs.python.org/3/library/itertools.html#itertools.product
    def res(state):
        if len(args) == 0:
            yield state
            return
        else:
            for state in args[0](state): # expand first requirement
                for state in conj(*args[1:])(state): # send it into expansion of next requirements
                    yield state
    return res

def fail(state):
    return iter(()) # no successor states; a bare `return` would hand back None, which is not iterable

def filt(pred):
    def res(state):
        if pred(state):
            yield state
    return res

# convert ordinary generator into a goal... waaaaaitaminute.
def goal(gen):
    def res(state):
        for i in gen(state):
            yield i
    return res

#def run(goal,n=-1):
#    for x in goal()

In [6]:
ex1= conj(filt(lambda x: x > 5), 
          disj(filt(lambda x: x < 8), 
               filt(lambda x: x < 10)))

for i in range(20):
    for z in ex1(i):
        print(z)

6
6
7
7
8
9
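The `disj` above exhausts its first branch before touching the second, which is exactly the depth first behaviour. As a sketch of what "a different implementation of `disj`" can look like, here is a round-robin version that takes one answer from each branch in turn. The name `disj_interleave` is mine, and this is only one of several ways to get a fair search.

```python
import itertools

def disj_interleave(*goals):
    """Fair disjunction: pull one answer from each live branch in turn."""
    def res(state):
        streams = [iter(g(state)) for g in goals]
        while streams:
            alive = []
            for s in streams:
                try:
                    item = next(s)
                except StopIteration:
                    continue  # this branch is exhausted, drop it
                alive.append(s)
                yield item
            streams = alive
    return res

# Two infinite branches: depth-first disj would never reach the odds.
evens = lambda state: itertools.count(0, 2)
odds = lambda state: itertools.count(1, 2)

print(list(itertools.islice(disj_interleave(evens, odds)(None), 6)))
# [0, 1, 2, 3, 4, 5]
```

The trade is the usual one: interleaving is complete on infinite branches, but it keeps every live branch's generator around at once.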


Ok, but let's say I want to return a choice of what to expand.

I could wrap this in a gym interface maybe.
type hints = string * state
type nondet (hints, a) = stream (hints * nondet (hints, a)) | done of state
type choice = [(hints, () -> choice)] | state
 -- so I can view the state or a description and choose to expand this branch, or it may be done 

We need to invert control, whatever the hell that means.
This is tickling the iteratee part of my brain.
What about lazy game trees?

state -> choice
The hints aren't choices.
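Here is one way the `choice` type sketched above might look in Python, as a lazy tree where subtrees are thunks that only get built when somebody picks them. Everything here (`done`, `choice`, `explore`, `grow`, the `pick` callback) is a hypothetical naming of mine, not an established interface. The `pick` function is where a human at `input()` or an LLM would plug in.

```python
def done(state):
    return ("done", state)

def choice(hint, branches):
    """branches: list of (label, thunk); a thunk builds its subtree only when called."""
    return ("choice", hint, branches)

def explore(node, pick):
    """Descend the tree, letting pick(hint, labels) return the index of the branch to expand.
    Returns the final state, or None if a node with no branches is reached."""
    while node[0] == "choice":
        _, hint, branches = node
        if not branches:
            return None
        i = pick(hint, [label for label, _ in branches])
        node = branches[i][1]()  # force only the chosen subtree
    return node[1]

# Toy search space: start from n and grow it until it reaches at least 10.
def grow(n):
    if n >= 10:
        return done(n)
    return choice(f"n = {n}", [("add 1", lambda: grow(n + 1)),
                               ("double", lambda: grow(n * 2))])

print(explore(grow(1), pick=lambda hint, labels: 1))  # always double: 1, 2, 4, 8, 16
```

Control is inverted in the sense that the search no longer decides the order itself. It hands back a description plus a menu and waits to be told where to go.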


In [9]:
def disj(*args):
    def res(state):
        return ("disj", state, [arg(state) for arg in args]) # "disj"
    return res

def conj(*args):
    def res(state):
        if len(args) == 0:
            return [state]
            #return [("success", state)]
        else:
            return [state for state in args[0](state) for state in conj(*args[1:])(state)]
            #return [("choice", state, conj(*args[1:])(state)) for state in args[0](state)]
    return res

def filt(pred):
    def res(state):
        if pred(state):
            return [state]
        else:
            return []
    return res

def goal(gen):
    def res(state):
        return ("disj", state, gen(state))
    return res

    
ex1= conj(filt(lambda x: x > 5), 
               disj(filt(lambda x: x < 8), 
                    filt(lambda x: x < 10)))

for i in range(20):
    for z in ex1(i):
        print(z)

disj
6
[[6], [6]]
disj
7
[[7], [7]]
disj
8
[[], [8]]
disj
9
[[], [9]]
disj
10
[[], []]
disj
11
[[], []]
disj
12
[[], []]
disj
13
[[], []]
disj
14
[[], []]
disj
15
[[], []]
disj
16
[[], []]
disj
17
[[], []]
disj
18
[[], []]
disj
19
[[], []]


In [None]:
def run(goal, state0):
    goals = list(goal(state0)) # worklist of ("done", state) / ("more", state, k) nodes
    while goals:
        node = goals.pop() # pick which goal seems most interesting here.
        #print("Which goal? ", list(enumerate(goals)))
        #node = goals.pop(int(input()))
        match node:
            case ("done", state):
                yield state
            case ("more", state, k):
                # k is assumed to be an iterator yielding lists of further nodes
                goals.extend(next(k)) # or instead of extend, we could use a context to keep track or goals could be a tree.

def goal(gen):
    def res(state):
        return [("done", i) for i in gen(state)]
    return res

In [None]:

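from collections import namedtuple

Ctx = namedtuple("Ctx", "up, siblings")
def run(goal, state0):
    ctx = Ctx("root", [])
    goals = [goal(state0)]
    while True:
        cmd = int(input()) # input() returns a string, so convert before matching
        match cmd:
            case -1: # go back up to the parent context
                up, siblings = ctx
                siblings.append(goals)
                goals = siblings
                ctx = up
            case n: # descend into goal number n
                goal = goals.pop(n)
                match goal:
                    case _:
                        pass # unfinished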





I could perhaps run a version of minikanren to generate the best possible examples and bad examples as prompts.
Run it automatically.



In [None]:
def disj(*args):
    def res(state):
        for a in args:
            for state in a(state):
                yield state
    return res

def conj(*args):
    def res(state):
        if len(args) == 0:
            yield state
            return
        else:
            for state in args[0](state): # expand first requirement
                for state in conj(*args[1:])(state): # send it into expansion of next requirements
                    yield state
    return res


a = input()

In [None]:
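def pstep(seq):
    # seq is a sequent (ctx, goal): ctx is a frozenset of hypotheses,
    # formulas are atoms (strings) or tuples like ("and", a, b)
    (ctx, goal) = seq
    if goal in ctx: # refl
        yield seq
    else:
        match goal:
            case ("false",):
                return # fail: nothing proves false unless it is already in ctx
            case ("and", a, b):
                # goals are functions of a state, so wrap the recursive calls in lambdas
                yield from conj(lambda _: pstep((ctx, a)), lambda _: pstep((ctx, b)))(seq)
            case ("or", a, b):
                yield from disj(lambda _: pstep((ctx, a)), lambda _: pstep((ctx, b)))(seq)
            case ("=>", a, b):
                yield from pstep((ctx | {a}, b))
            case ("not", a):
                yield from pstep((ctx | {a}, ("false",)))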

https://arxiv.org/abs/2102.13564 Improving ENIGMA-Style Clause Selection While Learning From History

Machine learning in proving tutorial
Talia's doc

https://www.tcs.ifi.lmu.de/mitarbeiter/jasmin-blanchette/axiom_sel.pdf A Survey of Axiom Selection as a Machine Learning Problem


relational learning https://link.springer.com/article/10.1007/s10994-013-5392-1

https://simons.berkeley.edu/workshops/agenda/theoretical-foundations-satsmt-solving/ml-solvers

https://arxiv.org/pdf/2403.04017 Learning Guided Automated Reasoning: A Brief Survey


In [4]:
import inspect
def self_describe(f):
    src = inspect.getsource(f)
    def res(*args, **kwargs):
        print(f"arguments={args}")
        print("Source:\n", src)
        return f(*args, **kwargs)
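    return res # returning tuple here didn't work so good. Could return class wrapper with __call__ and desc.

class Goal():
    def __init__(self, f, desc=None):
        self.f = f # the wrapped goal function
        self.desc = desc # optional description to show the LLM
    def __call__(self, *args):
        return self.f(*args)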




@self_describe
def fact(n):
    if n <= 0:
        return 1
    else:
        return n * fact(n-1)


arguments=(2,)
Source:
 @self_describe
def fact(n):
    if n <= 0:
        return 1
    else:
        return n * fact(n-1)



TypeError: 'tuple' object is not callable

In [1]:
import llm


# llm datalog?

# use llm to auto build descriptor?
def auto_desc(state):
    ... # unfinished


def conj( *args , desc=None):
    def res(state):
        for a in args:
            for state in a(state):
                pass # unfinished
        return state
    return res
def disj(*args ): 
    ... # unfinished




In [None]:

def self_describe(x):
    x.__code__
    prompt = """
    This is source code for a search function. There are a couple of choices available.
    Please
    """
    return x # hand the function back so the decorated name is not None

@self_describe
def foo(x,y,z):
    ...

What about using the llm to self annotate the code using the `__code__` attribute? Then any hints could just go in comments. That's nice.

Then I can go and get stock minikanren programs and see if it does better.


blockworld


Supposedly the point of natural deduction is that it translates well ("naturally") into the way mathematicians write down proofs in prose. This is also appealing for guiding a search with an LLM that was probably trained on a massive amount of math literature.


The idea of minikanren is that search is modellable as
`state -> list state`


state is a list of descriptions
we could have an automatic description generator as a default


We could also probably throw some RL-iness on there. We're doing tree search, so we could have the llm 


"What clues would you use to do it faster next time?"

Choices:
1. abort
2. choice 1
3. choice 2

https://news.ycombinator.com/item?id=39479478 A* boosting. "searchformer"
Feed the llm the trace, not just the state.
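A sketch of what "feed it the trace and a numbered menu" could look like as plain string handling, with no LLM call in it. `build_prompt` and `parse_choice` are hypothetical helpers of mine. The menu layout follows the `Choices:` list above, with option 1 reserved for abort.

```python
import re

def build_prompt(trace, choices):
    """Render the search trace so far plus a numbered menu; option 1 is always abort."""
    lines = ["Search trace so far:"]
    lines += [f"  {i}. {step}" for i, step in enumerate(trace, 1)]
    lines.append("Choices:")
    lines.append("1. abort")
    lines += [f"{i}. {c}" for i, c in enumerate(choices, 2)]
    lines.append("Reply with ONLY the number of your choice.")
    return "\n".join(lines)

def parse_choice(reply, n_choices):
    """Return a 0-based index into choices, or None for abort or an unusable reply."""
    m = re.search(r"\d+", reply)
    if m is None:
        return None
    k = int(m.group())
    if 2 <= k <= n_choices + 1:
        return k - 2
    return None

print(build_prompt(["expanded x > 5"], ["x < 8", "x < 10"]))
```

Treating anything unparseable as abort is a conservative choice. A retry with a sterner prompt would be the obvious alternative.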

Kind of need a compelling problem. That's tough? Why should that be tough?

neural kanren https://github.com/xuexue/neuralkanren


In [None]:
def conj(cost_heuristic=None): ... # signature sketch only
def disj(cost_heuristic=None): ... # signature sketch only

In [None]:
# basic minikanren
# ignore the unification stuff. Who cares.
def conj(a,b):
    def res(state0):
        for state1 in a(state0):
            for state2 in b(state1):
                yield state2
    return res

def disj(a,b):
    def res(state0): # unfair interleaving
        for state in a(state0):
            yield state
        for state in b(state0):
            yield state
    return res


In [None]:
# depth limited search.
def disj_depth(a,b,N=10):
    def res(state0):
        if state0.depth > N:
            return
        state0.depth += 1 
        for state in a(state0):
            yield state
        for state in b(state0):
            yield state
    return res
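One wrinkle with `disj_depth` is that it bumps `state0.depth` in place, so sibling branches share the counter. A sketch of an alternative, assuming states are carried as `(depth, payload)` pairs (my convention, not something the code above requires): wrap any goal so the depth travels with each state instead of being mutated. `limit_depth`, `reachable` and `step` are hypothetical names.

```python
def limit_depth(goal, max_depth):
    """Lift a goal on payloads to a goal on (depth, payload) pairs that fails at max_depth."""
    def res(state):
        depth, payload = state
        if depth >= max_depth:
            return
        for new_payload in goal(payload):
            yield (depth + 1, new_payload)
    return res

def reachable(goal, state):
    """Depth-first enumeration of every state reachable by repeatedly applying goal."""
    yield state
    for nxt in goal(state):
        yield from reachable(goal, nxt)

step = lambda n: [n + 1, n * 2]   # an unbounded toy search space
bounded = limit_depth(step, 2)

print(list(reachable(bounded, (0, 2))))
# [(0, 2), (1, 3), (2, 4), (2, 6), (1, 4), (2, 5), (2, 8)]
```

Each branch now gets its own depth, so backtracking into a sibling does not inherit the depth another branch accumulated.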

In [None]:
{score: (state, expand)} # min heap 
conj : state -> [(score, state)]
disj : 
# take a look at the first orderized minikanren
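The min heap idea above can be sketched with `heapq`. This is a generic best-first loop under my own assumptions: `expand` is a `state -> iterable of states` function, `score` is lower-is-better, and `is_done` recognises solutions. An LLM or a learned heuristic would be one candidate for `score`.

```python
import heapq
from itertools import count

def best_first(expand, score, is_done, state0, max_steps=10_000):
    """Repeatedly pop and expand the lowest-scoring state, yielding solutions as found."""
    tie = count()  # tie-breaker so the heap never has to compare states directly
    heap = [(score(state0), next(tie), state0)]
    for _ in range(max_steps):
        if not heap:
            return
        _, _, state = heapq.heappop(heap)
        if is_done(state):
            yield state
            continue
        for nxt in expand(state):
            heapq.heappush(heap, (score(nxt), next(tie), nxt))

# Toy problem: reach 37 from 1 using +1 and *2, scoring by distance to the target.
target = 37
solutions = best_first(
    expand=lambda n: [n + 1, n * 2],
    score=lambda n: abs(target - n),
    is_done=lambda n: n == target,
    state0=1,
)
print(next(solutions))  # 37
```

Compared with the generator-based `conj`/`disj`, the frontier is explicit here, which is what lets an outside scorer reorder it.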



In [None]:
import os
from openai import OpenAI

client = OpenAI(
    # This is the default and can be omitted
    api_key=os.environ.get("OPENAI_API_KEY"),
)

chat_completion = client.chat.completions.create(
    messages=[
        {
            "role": "user",
            "content": "Say this is a test",
        }
    ],
    model="gpt-3.5-turbo",
)
chat_completion

In [None]:
def system(examples):
    prompt = f"""
        This is a minikanren-like search program looking for solutions to X.
        I want you to guide it towards good choices.
        For example, if you had the state 
          
        Return ONLY the number of the choice of which state to expand.
        {examples}
    """
    return {"role": "system", "content" : prompt}

PyTorch install.
https://github.com/ROCm/ROCm/discussions/2932
Actually Radeon 680M and 780M are supported by the latest ROCm 6.0, what you need to do is to set HSA_OVERRIDE_GFX_VERSION=10.3.0 for 680M, and HSA_OVERRIDE_GFX_VERSION=11.0.0 for 780M.

huggingface-cli for model management

`rocminfo`

In [2]:
%env HSA_OVERRIDE_GFX_VERSION=10.3.0 

env: HSA_OVERRIDE_GFX_VERSION=10.3.0


In [1]:
%env HSA_OVERRIDE_GFX_VERSION=11.0.0

env: HSA_OVERRIDE_GFX_VERSION=11.0.0


In [2]:
%env PYTORCH_ROCM_ARCH="gfx1103"

env: PYTORCH_ROCM_ARCH="gfx1103"


In [2]:
%env AMD_LOG_LEVEL=3

env: AMD_LOG_LEVEL=3


In [3]:
%env AMD_SERIALIZE_KERNEL=3

env: AMD_SERIALIZE_KERNEL=3


In [3]:
import torch

torch.cuda.is_available()

True

In [6]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", device_map="auto")
#model.to("cuda")

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt") #.to("cuda")

outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))




<bos>Write me a poem about Machine Learning.

Machines, they weave and they learn,
From


In [9]:
outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))

<bos>Write me a poem about Machine Learning.

Machines, they weave and they learn,
From


In [11]:
from llama_cpp import Llama
# https://github.com/abetlen/llama-cpp-python
llm = Llama.from_pretrained(
    repo_id="google/gemma-2b-it",
    filename="*gemma-2b-it.gguf",
    verbose=False
)
llm.create_chat_completion(
      messages = [
          {"role": "system", "content": "You are an assistant who perfectly describes images."},
          {
              "role": "user",
              "content": "Describe this image in detail please."
          }
      ]
)



{'id': 'chatcmpl-89948379-b0a2-46df-b224-c1c75e2c359c',
 'object': 'chat.completion',
 'created': 1711641796,
 'model': '/home/philip/.cache/huggingface/hub/models--google--gemma-2b-it/snapshots/718cb189da9c5b2e55abe86f2eeffee9b4ae0dad/./gemma-2b-it.gguf',
 'choices': [{'index': 0,
   'message': {'role': 'assistant',
    'content': '\n\nI am unable to access external sources or display images, so I am unable to provide a detailed description of the image you have specified.'},
   'finish_reason': 'stop'}],
 'usage': {'prompt_tokens': 33, 'completion_tokens': 28, 'total_tokens': 61}}