# Verification-Guided Chain-of-Thought Reasoning

This notebook implements a novel framework that integrates an incremental SMT verifier into the Chain-of-Thought reasoning process to ensure logical consistency. The system consists of three main components:

1. **Reasoner LLM**: Generates reasoning steps
2. **Verifier LLM**: Translates natural language to SMT-LIB format
3. **Incremental SMT Verifier**: Uses the Z3 solver to check that the accumulated constraints remain satisfiable (logically consistent)

The framework iteratively refines reasoning steps through a Reasoning-Verification-Revision (RVR) loop until formal consistency is achieved.

## Setup and Installation

Install the required dependencies: the Z3 SMT solver, the Hugging Face `transformers` library, and the packages needed for 4-bit quantized inference (`accelerate`, `bitsandbytes`, `sentencepiece`).

In [None]:
!pip install --upgrade z3-solver transformers accelerate bitsandbytes sentencepiece

In [None]:
import re
import textwrap
from typing import Any, Dict, List, Optional, Tuple

import torch
from z3 import AstVector, Bool, BoolRef, Solver, parse_smt2_string, sat, unsat

# Import transformers inside the try block so HF_AVAILABLE reflects whether it is installed.
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False
    print("transformers not available: LLM.call will raise until you replace it with an API call.")

In [None]:
from huggingface_hub import snapshot_download

model_name = "mistralai/Mistral-7B-Instruct-v0.2"
local_dir = "./Mistral-7B-Instruct-v0.2"

print(f"Downloading {model_name} to {local_dir}...")
# Recent huggingface_hub versions deprecate `local_dir_use_symlinks`; files are copied into local_dir.
snapshot_download(repo_id=model_name, local_dir=local_dir)
print("Download complete.")

## Model Configuration

Load Mistral-7B-Instruct-v0.2 with 4-bit NF4 quantization (via bitsandbytes) for memory-efficient inference. The `LLM` class is a thin wrapper used for both the Reasoner and the Verifier; the two differ only in sampling temperature.

In [None]:

MODEL_NAME = "./Mistral-7B-Instruct-v0.2"  # local path from the download cell
DEVICE = "auto"
MAX_ITER = 5      # maximum number of Reasoner calls, including rejected steps
VERBOSE = True


class LLM:
    """Thin wrapper around a Hugging Face text-generation pipeline.

    The same class backs both the Reasoner and the Verifier; only the
    sampling temperature differs.
    """

    def __init__(self, model_name=MODEL_NAME, device=DEVICE, temperature=0.0):
        self.model_name = model_name
        self.device = device
        self.generator = None
        self.temperature = temperature
        if HF_AVAILABLE:
            try:
                tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

                # 4-bit NF4 quantization (bitsandbytes); matmuls are computed in bfloat16.
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=False,
                )

                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    quantization_config=quantization_config,
                    device_map="auto" if device != "cpu" else None,
                )
                self.generator = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256)
            except Exception as e:
                print("HF model load failed; you may need to set up a proper environment or use an API.", e)
                self.generator = None

    def call(self, prompt: str, max_tokens=1024) -> str:
        """
        Generic LLM call. Uses the HF pipeline if available; otherwise raises so the
        user can replace this method with an API call.
        """
        if self.generator is None:
            raise RuntimeError("No LLM backend available. Replace LLM.call with your API call.")

        gen_kwargs = {"max_new_tokens": max_tokens, "do_sample": self.temperature > 0}
        if self.temperature > 0:
            # Pass the temperature explicitly; it only takes effect when sampling.
            gen_kwargs["temperature"] = self.temperature

        out = self.generator(prompt, **gen_kwargs)
        generated_text = out[0]["generated_text"]

        # By default the pipeline echoes the prompt; strip it to keep only the completion.
        if generated_text.startswith(prompt):
            return generated_text[len(prompt):].strip()
        return generated_text.strip()
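Because the rest of the notebook only uses the `temperature` attribute and the `call(prompt, max_tokens)` method, an offline stand-in can exercise the verification loop without a GPU or a model download. The sketch below is a hypothetical helper (`ScriptedLLM` is not part of the original framework) that returns canned responses in order and records every prompt it receives:

```python
from typing import List


class ScriptedLLM:
    """Hypothetical offline stand-in for LLM: returns canned responses in order.

    Mimics only the `temperature` attribute and the `call` method used elsewhere.
    """

    def __init__(self, responses: List[str], temperature: float = 0.0):
        self.responses = list(responses)
        self.temperature = temperature
        self.prompts: List[str] = []  # every prompt received, for inspection

    def call(self, prompt: str, max_tokens: int = 1024) -> str:
        self.prompts.append(prompt)
        if not self.responses:
            raise RuntimeError("ScriptedLLM ran out of canned responses.")
        return self.responses.pop(0)


# Usage: a fake Verifier that returns one well-formed translation.
fake_verifier = ScriptedLLM([
    "<SMT-LIB>\n(declare-const x Real)\n(assert (>= x 3.0))\n</SMT-LIB>\n"
    "<INTENT>\nx is at least 3\n</INTENT>"
])
print(fake_verifier.call("translate: x >= 3.0"))
```

Such a stub can be passed wherever an `LLM` instance is expected, since the code relies on duck typing rather than the class itself.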

## Prompt Engineering

Define the system prompts for both the Reasoner and Verifier LLMs:

- **Reasoner**: Generates one reasoning step at a time with explicit termination signals
- **Verifier**: Translates natural language into valid SMT-LIB v2 code with strict formatting rules
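The Reasoner prompt asks the model to start its final step with a `CONCLUSION:` or `THEREFORE:` prefix. A minimal sketch of a detector for that signal, assuming the model may wrap the prefix in markdown emphasis such as `**CONCLUSION:**` (the helper name `is_conclusion` is illustrative, not from the original):

```python
import re

# Leading CONCLUSION:/THEREFORE: marker, case-insensitive, tolerating leading
# whitespace and optional markdown emphasis characters (* or _).
_TERMINATION_RE = re.compile(r"^\s*[*_]*\s*(CONCLUSION|THEREFORE)\s*[*_]*\s*:", re.IGNORECASE)


def is_conclusion(step: str) -> bool:
    """Return True if a reasoning step starts with a termination prefix."""
    return _TERMINATION_RE.match(step) is not None


print(is_conclusion("CONCLUSION: the distance is at least 5.0."))
print(is_conclusion("Therefore x*x >= 9."))  # no colon after the keyword
```

Anchoring on the prefix rather than searching for the word anywhere in the step avoids stopping on intermediate steps that merely say "therefore".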

In [None]:
REASONER_SYSTEM = textwrap.dedent("""
You are a chain-of-thought generator (Reasoner). Produce exactly one reasoning step as a single short sentence.
Be explicit and do not omit assumptions. Write x*x instead of x**2.

**Termination conditions:**
- When you derive the final conclusion or answer, begin that step with a clear prefix such as "CONCLUSION:" or "THEREFORE:"
  so that external systems know the reasoning process can stop.
- If the current step is not the final conclusion, output the reasoning step normally.
""")

def reasoner_prompt(problem_text: str, accepted_facts_text: str, feedback: Optional[str]=None) -> str:
    prompt = REASONER_SYSTEM + "\n"
    prompt += f"Problem:\n{problem_text}\n\n"
    prompt += f"Accepted facts:\n{accepted_facts_text}\n\n"
    if feedback:
        prompt += f"Verifier feedback:\n{feedback}\n\n"
    prompt += "Now generate exactly one next reasoning step (one sentence):\n"
    return prompt

VERIFIER_SYSTEM = textwrap.dedent("""
You are a *formal verifier* that translates ONE natural-language reasoning step into *valid* SMT-LIB v2 code.

Your output MUST follow the strict structure:

<SMT-LIB>
...
</SMT-LIB>

<INTENT>
One-line summary of what the SMT constraints express.
</INTENT>

=========================
SMT-LIB v2 RULES YOU MUST FOLLOW
=========================

1. **Types**
   Allowed types:
     - Int
     - Bool
     - Real

   Use only:
     (declare-const x Int)
     (declare-const p Bool)
     (declare-const r Real)

2. **Allowed Expressions**
   For Int:
     +, -, *, div, mod
     =, >, >=, <, <=

   For Bool:
     and, or, not, implies, =
     (assert p)

   For Real:
     +, -, *, /, ^
     =, >, >=, <, <=

3. **Disallowed Expressions**
   - sqrt: write (sqrt x) as (^ x 0.5)
   - bnot: write (bnot p) as (not p)
   - pow: write (pow x y) as (^ x y)

4. **Structure Rules**
   - ONLY USE ALLOWED EXPRESSIONS!
   - Unless a type is specified, declare variables as Real.
   - Declare every symbol used in this fragment exactly once, even if an earlier step already declared it (each fragment is parsed independently).
   - DO NOT use set-logic. The system will manage logic selection.
   - DO NOT invent datatypes, arrays, bitvectors, or quantifiers.
   - DO NOT use "check-sat", "exit", or "push/pop".
   - DO NOT output comments except inside the SMT-LIB wrapper.
   - Ensure parentheses always match.

5. **Translation Strategy**
   - Identify variables.
   - Declare each needed variable.
   - Translate the step into one or more (assert ...) constraints.

=========================
You must produce VALID SMT-LIB EVERY TIME.
No deviation from the XML-like wrapper.
=========================
""")


def verifier_prompt(step: str, accepted_facts_text: str) -> str:
    prompt = VERIFIER_SYSTEM + "\n"
    prompt += f"Context facts:\n{accepted_facts_text}\n\n"
    prompt += f"Step to translate:\n\"{step}\"\n\n"
    prompt += 'Return:\n<SMT-LIB>\n...smt-lib...\n</SMT-LIB>\n\n<INTENT>\nOne-line summary\n</INTENT>\n'
    return prompt

## Incremental SMT Verifier

The `Z3Verifier` class manages the verification state and provides:

- **State tracking**: Maintains a history of accepted SMT fragments together with their natural-language steps
- **UNSAT core extraction**: Identifies which tracked assertions participate in a contradiction (Z3 cores are not guaranteed to be minimal)
- **Rollback**: Removes a rejected step from the history and rebuilds the solver by replaying the remaining fragments
- **Operator rewriting**: Rewrites operators the Verifier LLM is told to avoid (`pow`, `bnot`, `sqrt`) into Z3-compatible forms

In [None]:
class Z3Verifier:

    def __init__(self, llm_instance: Any, problem_text: str, initial_facts: str = ""):
        self.llm = llm_instance
        self.problem_text = problem_text
        self.solver = Solver()
        self.solver.set(unsat_core=True)  

        
        self.smt_history: List[Dict[str, Any]] = []  
        self.accepted_facts = initial_facts
        self.iteration = 0

    def extract_smt_from_verifier_output(self, verifier_out: str) -> Tuple[str, str]:
        """
        Extracts SMT-LIB and INTENT sections from verifier output
        """
        smt_match = re.search(r"<SMT-LIB>(.*?)</SMT-LIB>", verifier_out, re.S)
        intent_match = re.search(r"<INTENT>(.*?)</INTENT>", verifier_out, re.S)
        smt = smt_match.group(1).strip() if smt_match else ""
        intent = intent_match.group(1).strip() if intent_match else ""
        return smt, intent

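    def fix_smt_fragment(self, s: str) -> str:
        """Rewrite operators the Verifier is told to avoid into Z3-compatible forms:
        pow -> ^, bnot -> not, (sqrt x) -> (^ x 0.5)."""
        # Word boundaries keep identifiers such as `power` from being corrupted.
        s = re.sub(r"\bpow\b", "^", s)
        s = re.sub(r"\bbnot\b", "not", s)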

        def process_parentheses(s: str) -> str:
            result = []
            i = 0
            n = len(s)

            while i < n:
                if s[i] != '(':
                    result.append(s[i])
                    i += 1
                else:
                    stack = 1
                    j = i + 1
                    while j < n and stack > 0:
                        if s[j] == '(':
                            stack += 1
                        elif s[j] == ')':
                            stack -= 1
                        j += 1
                    if stack != 0:
                        close_index = n - 1
                    else:
                        close_index = j - 1

                    inner_content = s[i+1:close_index]

                    processed_inner = process_parentheses(inner_content)

                    trimmed_inner = processed_inner.lstrip()
                    if trimmed_inner.startswith("sqrt") and (len(trimmed_inner) == 4 or not trimmed_inner[4].isalnum()):
                        start_index = 0
                        while start_index < len(processed_inner) and processed_inner[start_index].isspace():
                            start_index += 1
                        arg_part = processed_inner[start_index + 4:]
                        new_inner = processed_inner[:start_index] + "^" + arg_part + " 0.5"
                        result.append('(')
                        result.append(new_inner)
                        result.append(')')
                    else:
                        result.append('(')
                        result.append(processed_inner)
                        result.append(')')

                    i = close_index + 1

            return ''.join(result)
        return process_parentheses(s)

    def add_smt_step_and_check(self, smt_fragment: str, track_id: str) -> Tuple[bool, Optional[List[str]]]:

        ctx = self.solver.ctx

        try:
            
            parsed = parse_smt2_string(smt_fragment, ctx=ctx)
        except Exception as e:
            raise RuntimeError(f"Z3 Parser Error: {e}\nSMT Content:\n{smt_fragment}")

        
        parsed_list = list(parsed) if hasattr(parsed, "__iter__") else [parsed]

        bool_exprs = []
        for expr in parsed_list:
            if isinstance(expr, BoolRef):
                bool_exprs.append(expr)
            if isinstance(expr, AstVector):
                for sub_expr in expr:
                    if isinstance(sub_expr, BoolRef):
                        bool_exprs.append(sub_expr)


        
        for i, expr in enumerate(bool_exprs):
            track_name = f"{track_id}_{i}"
            self.solver.assert_and_track(expr, Bool(track_name, ctx=ctx))

        
        res = self.solver.check()
        if res == sat:
            return True, None
        elif res == unsat:
            core = [c.decl().name() for c in self.solver.unsat_core()]
            return False, core
        else:
            return False, None

    def push_smt_fragment(self, smt_fragment: str, natural_language_step: str, track_id: str) -> Tuple[bool, Optional[List[str]]]:
        """
        Pushes an SMT fragment into the data structure and verifies it

        Returns:
            Tuple[bool, Optional[List[str]]]: (is_satisfiable, UNSAT core)
        """
        
        is_sat, unsat_core = self.add_smt_step_and_check(smt_fragment, track_id)

        
        self.smt_history.append({
            'track_id': track_id,
            'smt_fragment': smt_fragment,
            'natural_language': natural_language_step,
            'is_sat': is_sat,
            'unsat_core': unsat_core,
            'iteration': self.iteration
        })

        return is_sat, unsat_core

    def pop_smt_fragment(self) -> Optional[Dict[str, Any]]:
        """
        Pops the most recently added SMT fragment (rollback mechanism)

        Returns:
            Optional[Dict]: Information of the popped fragment, None if nothing to pop
        """
        if not self.smt_history:
            return None

        
        popped = self.smt_history.pop()

        
        self._rebuild_solver_from_history()

        return popped

    def _rebuild_solver_from_history(self):
        """Rebuilds the solver state from history"""
        self.solver = Solver()
        self.solver.set(unsat_core=True)

        for item in self.smt_history:
            self.add_smt_step_and_check(item['smt_fragment'], item['track_id'])

    def get_feedback_for_reasoner(self, unsat_core: List[str]) -> str:
        """
        Generates feedback for the reasoner based on the UNSAT core
        """
        core_mappings = []
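        for core_id in unsat_core:
            # Core labels have the form "<track_id>_<index>"; compare the prefix exactly
            # so that "step1" does not also match "step10_0".
            for item in self.smt_history:
                if core_id.rsplit("_", 1)[0] == item['track_id']: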
                    core_mappings.append({
                        'core_id': core_id,
                        'step': item['natural_language'],
                        'track_id': item['track_id']
                    })
                    break

        if core_mappings:
            feedback = "Verifier found inconsistency. Conflicting steps:\n"
            for mapping in core_mappings:
                feedback += f"- {mapping['step']} (label: {mapping['track_id']})\n"
            feedback += "Please correct the reasoning step or provide missing assumptions."
        else:
            feedback = f"Verifier found inconsistency (Z3 UNSAT core labels: {unsat_core}). Please check the reasoning logic."

        return feedback

    def process_reasoner_step(self, reasoner_step: str) -> Tuple[bool, str, Optional[List[str]]]:
        """
        Processes one step from the reasoner: verify -> add -> check -> feedback
        """
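        self.iteration += 1

        # Ask the Verifier LLM to translate the step into SMT-LIB.
        v_prompt = verifier_prompt(reasoner_step, self.accepted_facts)
        v_out = self.llm.call(v_prompt)

        smt, intent = self.extract_smt_from_verifier_output(v_out)
        print(f'Unfixed smt: {smt}')
        smt = self.fix_smt_fragment(smt)
        print(f'Fixed smt: {smt}')

        # An empty fragment asserts nothing and would be trivially SAT; reject it instead.
        if not smt:
            return False, "Verifier produced no SMT-LIB for this step. Please restate it more precisely.", None

        track_id = f"step{self.iteration}"
        try:
            is_sat, unsat_core = self.push_smt_fragment(smt, reasoner_step, track_id)
        except RuntimeError as e:
            # Z3 could not parse the translation; report it instead of crashing the loop.
            return False, f"Verifier translation could not be parsed by Z3 ({e}). Please restate the step.", None

        if is_sat:
            self.accepted_facts += "\n" + reasoner_step
            return True, "Step accepted", None

        # Build feedback BEFORE rolling back, so the rejected step's labels in the
        # UNSAT core can still be mapped to its natural-language text.
        if unsat_core:
            feedback = self.get_feedback_for_reasoner(unsat_core)
        else:
            feedback = "Verifier could not establish consistency for this step (Z3 returned unknown)."
        self.pop_smt_fragment()
        return False, feedback, unsat_core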


    def get_current_state(self) -> Dict[str, Any]:
        """Gets the current state"""
        return {
            'accepted_facts': self.accepted_facts,
            'smt_history_size': len(self.smt_history),
            'current_iteration': self.iteration,
            'solver_state': str(self.solver)
        }

## End-to-End Reasoning Loop

The `end_to_end_demo` function implements the complete Reasoning-Verification-Revision process:

1. Reasoner generates a step
2. Verifier translates to SMT-LIB
3. Z3 checks satisfiability
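4. If UNSAT, extract the core, roll the step back, and send feedback to the Reasoner
5. If SAT, accept the step and continue
6. Repeat until the Reasoner signals a conclusion or `MAX_ITER` Reasoner calls (accepted or rejected) are used up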

In [None]:
def end_to_end_demo(problem_text: str, initial_facts_text: str, reasoner_llm: LLM, verifier_llm:LLM):
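    llm = reasoner_llm
    # Start from an empty fact base and route each premise through the same
    # translate-and-check path as reasoning steps, so the premises are actually
    # asserted in Z3 and later steps can be checked against them.
    verifier_helper = Z3Verifier(verifier_llm, problem_text, "")
    iteration = 0
    history = []

    for fact in initial_facts_text.strip().splitlines():
        ok, msg, _ = verifier_helper.process_reasoner_step(fact)
        if VERBOSE:
            print(f"Premise {fact!r}: {'asserted' if ok else 'not asserted'} ({msg})")

    if VERBOSE:
        print(f"Current state: {verifier_helper.get_current_state()}")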

    feedback = None
    while iteration < MAX_ITER:
        iteration += 1

        
        r_prompt = reasoner_prompt(problem_text, verifier_helper.accepted_facts, feedback)
        r_out = llm.call(r_prompt)

        step = r_out.strip()
        if VERBOSE:
            print(f"\n=== Iter {iteration} Reasoner step ===\n{step}\n")
        history.append(("reasoner", step))
        
        is_sat, feedback_msg, unsat_core = verifier_helper.process_reasoner_step(step)

        if VERBOSE:
            print(f"Verifier processing result: SAT={is_sat}")
            if unsat_core:
                print(f"UNSAT core: {unsat_core}")

        if is_sat:
            
            if VERBOSE:
                print("Z3: SAT -> Accepting step.")
            history.append(("accepted", step))
            feedback = None

            
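            # Stop on the termination markers requested in the Reasoner prompt.
            # "conclude" alone would not match a "CONCLUSION:" prefix.
            lowered = step.lower()
            if any(kw in lowered for kw in ("therefore", "conclusion", "conclude")):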
                if VERBOSE:
                    print("Reasoner signaled final conclusion; stopping.")
                break
            continue
        else:
            
            if VERBOSE:
                print(f"Z3: UNSAT -> Feedback: {feedback_msg}")
            history.append(("rejected", step, unsat_core))
            feedback = feedback_msg
            if VERBOSE:
                print("Feedback to Reasoner:\n", feedback)
            continue

    return {
        "history": history,
        "accepted_facts": verifier_helper.accepted_facts,
        "assertions": verifier_helper.solver.assertions(),
        "verifier_helper": verifier_helper  
    }

## Initialize LLM Models

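Create two LLM instances with different temperature settings:
- Verifier: temperature=0.0 (greedy decoding) for deterministic translation
- Reasoner: temperature=0.5 for varied but coherent reasoning

Each instance loads its own copy of the quantized model, so this cell keeps two copies in GPU memory.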

In [None]:

print("Initializing LLM...")
verifier_llm = LLM(temperature=0.0)
reasoner_llm = LLM(temperature=0.5)
print("LLM initialized.\n")

## Run Experiment

Test the framework on a geometric proof problem: proving that the distance from point (x,y) to the origin is at least 5.0, given x ≥ 3.0 and y ≥ 4.0.
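For reference, the intended argument is short. Since both coordinates are positive, squaring preserves the given inequalities, and the square root is monotone:

$$
\begin{aligned}
x \ge 3.0 > 0 &\;\Rightarrow\; x^2 \ge 9.0, \qquad y \ge 4.0 > 0 \;\Rightarrow\; y^2 \ge 16.0,\\
\sqrt{x^2 + y^2} &\ge \sqrt{9.0 + 16.0} = 5.0.
\end{aligned}
$$

Because the Verifier prompt rules out `sqrt`, the Z3-friendly form of the goal is the squared inequality `x*x + y*y >= 25`, which is equivalent here because both sides of the original inequality are non-negative.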

The system will demonstrate:
- Error detection when reasoning is inconsistent
- Precise feedback through UNSAT core extraction
- Automatic correction of rejected steps, iterating until the Reasoner reaches a conclusion that Z3 accepts as consistent

In [None]:


problem = "Given x >= 3.0 and y >= 4.0, show the distance between point (x,y) and origin is at least 5.0."

initial_facts = "x >= 3.0\ny >= 4.0"



print("Starting end-to-end demo...\n")
result = end_to_end_demo(problem, initial_facts, reasoner_llm, verifier_llm)
print("\n=== Demo finished ===")
print(f'History: {result["history"]}')
print(f'Accepted facts: {result["accepted_facts"]}')
print(f'Assertions: {result["assertions"]}')
