In [None]:
from llama_cpp import Llama

In [None]:
class stream_chat_bot:
    """Multi-turn chat helper that keeps the whole conversation in one prompt string.

    Every turn is appended to ``self.stream`` wrapped in Llama-3 style header
    and end-of-turn markers; the accumulated string is sent to the model on
    each call to ``answer``.
    """

    def __init__(self, model_path):
        """Load the GGUF model at ``model_path`` and start a fresh conversation."""
        self.model_path = model_path
        self.llm = Llama(model_path=self.model_path, n_ctx=16384, n_gpu_layers=99, verbose=False)
        self.clean()

    def push_back(self, role, content):
        """Append one turn spoken by ``role`` with text ``content`` to the stream."""
        pieces = (
            " <|start_header_id|>",
            f"{role}",
            "<|end_header_id|>",
            f"{content}",
            "<|eot_id|>",
        )
        self.stream = self.stream + "".join(pieces)

    def clean(self):
        """Reset the conversation so it only contains the system instructions."""
        system_lines = [
            "You are a helpful assistant. Always think step-by-step before answering and format your response as follows:",
            "<step 1 content>",
            "<step 2 content>",
            "...",
            "[answer]",
            "<answer content>",
            "Ensure every response follows this format, with each reasoning step on a new line and the answer preceded by [answer] on a new line, followed by its content on the next line.",
        ]
        self.stream = ""
        self.push_back("system", "\n".join(system_lines))

    def answer(self, query):
        """Record ``query`` as a user turn, generate the reply, record and return it."""
        self.push_back("user", query)
        # Open the assistant header so the model continues with its own turn.
        prompt = self.stream + "<|start_header_id|>assistant<|end_header_id|>"
        completion = self.llm(prompt, max_tokens=2048)
        reply = completion['choices'][0]['text']
        self.push_back("assistant", f"{reply}")
        return reply

In [None]:
# Interactive chat: read a line, print the model's reply, repeat.
bot = stream_chat_bot("llama-3.1-8b-instruct-q4_k_m.gguf")
while True:
    try:
        user_query = input("\n>")
    except (EOFError, KeyboardInterrupt):
        # stdin was closed or the kernel was interrupted while waiting for
        # input: leave the loop cleanly instead of ending in a traceback.
        break
    print(bot.answer(user_query))

# Sentence as node: each MCTS action is one line of reasoning

In [None]:
from llama_cpp import Llama
import numpy as np

class MCTS_REASONING_LLM:
    """Monte-Carlo tree search over line-by-line reasoning produced by a llama.cpp model.

    Each tree node stores up to ``k`` candidate continuations ("actions") of
    the current prompt. A reasoning action is one generated line; an answer
    action (child of the ``"[answer]\\n"`` action) is a full answer. Nodes
    live in the flat list ``self.nodes`` and refer to each other by index;
    ``-1`` means "no node". ``answer_evaluate`` must be supplied (override it
    in a subclass or assign it on the instance) before ``query`` can run.
    """

    class MCTS_NODE:
        """Per-node statistics for the candidate actions available at this point."""

        def __init__(self,parent,actions_and_probs,k):
            """Create a node.

            Args:
                parent: index of the parent node in the owner's node list, -1 for a root.
                actions_and_probs: ``(actions, probs)`` pair of equal-length sequences.
                    ``probs`` holds whatever prior the generator reported for each
                    action (the values come from the model's ``top_logprobs``).
                k: maximum number of actions kept.
            """
            self.parent = parent
            self.last_visit = -1      # action index chosen on the most recent visit
            self.end = False          # True when the actions are final answers
            (actions, probs) = actions_and_probs
            # Keep at most k candidates and size every per-action array by the
            # number actually supplied, so the arrays always broadcast together
            # even when the model reports fewer than k distinct tokens.
            self.actions = np.array(list(actions)[:k])
            self.probs = np.array(list(probs)[:k], dtype=float)
            n_actions = len(self.actions)
            self.child = np.full((n_actions,), -1)
            # Float storage: an integer array would truncate every running mean.
            self.mean_reward = np.zeros((n_actions,), dtype=float)
            self.visit_count = 0
            self.child_visit_count = np.zeros((n_actions,), dtype=int)

        def get_optimum_child(self, policy_weight, explore_weight):
            """Return the INDEX of the best-scoring action.

            The score is ``mean_reward + probs * policy_weight + bonus * explore_weight``
            with ``bonus = sqrt(log(visit_count + 1)) / (child_visit_count + 1e-6)``.
            The ``+ 1`` keeps the logarithm non-negative, so an unvisited node
            has a zero bonus instead of NaN. The return value is an action
            index (usable with ``actions``, ``child``, ``mean_reward``), not a
            node id.
            """
            bonus = np.sqrt(np.log(self.visit_count + 1.0)) / (self.child_visit_count + 1e-6)
            score = self.mean_reward + self.probs * policy_weight + bonus * explore_weight
            return int(np.argmax(score))

    def __init__(self, model_path, k, policy_weight = 1, explore_weight = 1):
        """Load the model and reset the conversation.

        Args:
            model_path: path of the GGUF model file.
            k: number of candidate actions generated per node.
            policy_weight: weight of the model prior in the selection score.
            explore_weight: weight of the exploration bonus in the selection score.
        """
        self.model = Llama(model_path, n_ctx = 16384, n_gpu_layers = 99,logits_all = True,verbose = False)
        self.clean()
        self.nodes = []
        self.deleted = []   # free list of node slots that new_node may reuse
        self.root = -1
        self.policy_weight = policy_weight
        self.explore_weight = explore_weight
        self.k = k

    def push_back(self, role, content):
        """Append one complete chat turn to the prompt stream."""
        self.stream += f" <|start_header_id|>{role}<|end_header_id|>"\
                        f"{content}"\
                        "<|eot_id|>"

    def clean(self):
        """Reset the prompt stream to the system instructions only."""
        self.stream = ""
        self.push_back("system", "You are a helpful assistant. Always think step-by-step before answering and format your response as follows:\n"
                                 "<step 1 content>\n"
                                 "<step 2 content>\n"
                                 "...\n"
                                 "[answer]\n"
                                 "<answer content>\n"
                                 "Ensure every response follows this format, with each reasoning step on a new line and the answer preceded by [answer] on a new line, followed by its content on the next line.")

    def _continue_text(self, prompt, max_tokens, stop=None):
        """Greedily continue ``prompt`` and return the generated text."""
        kwargs = dict(prompt=prompt, max_tokens=max_tokens, logprobs=0, temperature=0)
        if stop is not None:
            kwargs["stop"] = stop
        return self.model(**kwargs)["choices"][0]["text"]

    def _top_k_candidates(self, prompt, max_tokens, stop=None, suffix=""):
        """Branch on the top-k first tokens and greedily complete each branch."""
        output = self.model(prompt = prompt,max_tokens = 1,logprobs = self.k,temperature=0)
        top = output["choices"][0]["logprobs"]["top_logprobs"][0]
        actions = [
            token + self._continue_text(prompt + token, max_tokens, stop) + suffix
            for token in top
        ]
        return (actions, list(top.values()))

    def generate_action_list(self,prompt):
        """Return ``(actions, probs)`` where each action is one newline-terminated line."""
        return self._top_k_candidates(prompt, 256, stop='\n', suffix="\n")

    def generate_answer_list(self,prompt):
        """Return ``(actions, probs)`` where each action is a complete answer."""
        return self._top_k_candidates(prompt, 512)

    # given a query and an answer, evaluate the answer
    # this function is a placeholder and should be implemented based on the specific evaluation criteria
    def answer_evaluate(self, query_answer):
        """Score a finished transcript; must be overridden.

        Raises:
            NotImplementedError: always, so that a missing evaluator is reported
                here rather than as a ``None`` reward corrupting the statistics.
        """
        raise NotImplementedError("answer_evaluate must be implemented to return a numeric reward")

    def MCTS_initialize(self):
        """Discard any previous tree and create a root from the current stream."""
        self.nodes = []
        self.deleted = []
        self.nodes.append(MCTS_REASONING_LLM.MCTS_NODE(-1, self.generate_action_list(self.stream), self.k))
        self.set_root(0)

    def new_node(self, parent, actions_and_probs):
        """Store a new node (reusing a freed slot when available) and return its index."""
        new_node = MCTS_REASONING_LLM.MCTS_NODE(parent, actions_and_probs, self.k)
        if self.deleted:
            idx = self.deleted.pop(0)
            self.nodes[idx] = new_node
            return idx
        self.nodes.append(new_node)
        return len(self.nodes) - 1

    def delete_node(self, idx):
        """Mark a single node slot as reusable."""
        self.deleted.append(idx)

    def delete_tree(self, idx):
        """Mark the node ``idx`` and all of its expanded descendants as reusable."""
        for child in self.nodes[idx].child:
            if child != -1:
                self.delete_tree(int(child))
        self.deleted.append(idx)

    def set_root(self,idx):
        """Make ``idx`` the search root and detach it from its former parent."""
        self.root = idx
        self.nodes[idx].parent = -1

    def _expand(self, parent, action_idx, prompt):
        """Create the child reached by ``action_idx`` of ``parent`` and return its index.

        ``prompt`` must already end with the chosen action. When that action is
        the ``"[answer]\\n"`` marker the child is terminal: its actions are
        complete answers and each one is scored immediately.
        """
        parent_node = self.nodes[parent]
        is_answer = parent_node.actions[action_idx] == "[answer]\n"
        if is_answer:
            candidates = self.generate_answer_list(prompt)
        else:
            candidates = self.generate_action_list(prompt)
        child_idx = self.new_node(parent, candidates)
        parent_node.child[action_idx] = child_idx
        if is_answer:
            child = self.nodes[child_idx]
            child.end = True
            for i in range(len(child.actions)):
                child.mean_reward[i] = self.answer_evaluate(self.stream + str(child.actions[i]) + '<|eot_id|>')
        return child_idx

    def select_and_expand(self):
        """Descend from the root by selection score; expand the first missing child.

        Returns:
            ``(node_index, prompt)`` where ``prompt`` is the stream plus every
            action taken on the way down.
        """
        previous_node = -1
        current_node = self.root
        last_visit = -1
        current_prompt = self.stream
        while current_node != -1 and not self.nodes[current_node].end:
            node = self.nodes[current_node]
            # get_optimum_child yields an action index; the child node id is
            # looked up separately and is -1 while that action is unexpanded.
            last_visit = node.get_optimum_child(self.policy_weight, self.explore_weight)
            node.last_visit = last_visit
            current_prompt += str(node.actions[last_visit])
            previous_node = current_node
            current_node = int(node.child[last_visit])

        if current_node == -1:
            current_node = self._expand(previous_node, last_visit, current_prompt)
        return current_node, current_prompt

    def simulation(self, current_node, current_prompt, max_steps=64):
        """Estimate the reward reachable from ``current_node``.

        For a terminal node the stored reward of the selected answer is
        returned. Otherwise the selected action is rolled out greedily, one
        line at a time, until the ``"[answer]\\n"`` marker appears, and the
        resulting answer is scored.

        Args:
            current_node: index of the node to simulate from.
            current_prompt: prompt text leading to that node.
            max_steps: upper bound on rolled-out reasoning lines; when it is
                reached the answer marker is forced so the rollout always ends.
        """
        node = self.nodes[current_node]
        optimum_child = node.get_optimum_child(self.policy_weight, self.explore_weight)
        node.last_visit = optimum_child
        if node.end:
            return node.mean_reward[optimum_child]

        reasoning = str(node.actions[optimum_child])
        steps = 0
        while reasoning != "[answer]\n" and steps < max_steps:
            current_prompt += reasoning
            reasoning = self._continue_text(current_prompt, 256, stop='\n') + "\n"
            steps += 1
        if reasoning != "[answer]\n":
            # Budget exhausted without the marker: keep the last line and ask
            # for the answer explicitly instead of looping forever.
            current_prompt += reasoning
            reasoning = "[answer]\n"
        current_prompt += reasoning
        respond = self._continue_text(current_prompt, 512)
        # Exactly one end-of-turn marker terminates the evaluated transcript.
        return self.answer_evaluate(self.stream + respond + '<|eot_id|>')

    def backpropagation(self, current_node, reward):
        """Fold ``reward`` into the running means along the path back to the root."""
        while current_node != -1:
            node = self.nodes[current_node]
            action = node.last_visit
            node.visit_count += 1
            node.child_visit_count[action] += 1
            # Incremental mean update.
            node.mean_reward[action] += (reward - node.mean_reward[action]) / node.child_visit_count[action]
            current_node = node.parent

    def query(self,query,iterations=100):
        """Answer ``query`` by repeatedly searching and committing one action at a time.

        Args:
            query: the user's question.
            iterations: visit budget the current root must reach before the
                best action is committed to the stream.

        Returns:
            The text of the selected final answer; the stream is extended with
            the committed reasoning, the answer and an end-of-turn marker.
        """
        self.push_back("user",query)
        self.stream += "<|start_header_id|>assistant<|end_header_id|>"
        self.MCTS_initialize()
        while not self.nodes[self.root].end:
            while self.nodes[self.root].visit_count < iterations:
                current_node, current_prompt = self.select_and_expand()
                reward = self.simulation(current_node, current_prompt)
                self.backpropagation(current_node, reward)
            root_node = self.nodes[self.root]
            # Commit on reward and prior only: an exploration bonus here would
            # favour precisely the actions that were never tried.
            optimum_child = root_node.get_optimum_child(self.policy_weight, 0)
            self.stream += str(root_node.actions[optimum_child])
            new_root = int(root_node.child[optimum_child])
            if new_root == -1:
                # The committed action was never expanded during the search.
                new_root = self._expand(self.root, optimum_child, self.stream)
            for i, child in enumerate(root_node.child):
                if i != optimum_child and child != -1:
                    self.delete_tree(int(child))
            self.delete_node(self.root)
            self.set_root(new_root)
        final_node = self.nodes[self.root]
        optimum_child = final_node.get_optimum_child(self.policy_weight, 0)
        respond = str(final_node.actions[optimum_child])
        self.stream += respond + '<|eot_id|>'
        return respond
            
    




# Paragraph as node: each MCTS node holds one complete candidate answer

In [None]:
import math
from llama_cpp import Llama
import re
import numpy as np
import random

class MCTS_NODE:
    """One candidate answer in the refinement tree, plus its search statistics."""

    def __init__(self, parent, solution, critique, Q_value):
        """Create a node with no children and no visits.

        Args:
            parent: index of the parent node (the tree code uses -1 for the root).
            solution: answer text held by this node.
            critique: feedback text generated for ``solution``.
            Q_value: initial quality estimate.
        """
        # Search bookkeeping, all starting empty.
        self.children = []          # indices of child nodes
        self.reward_samples = []    # scores collected for this node's solution
        self.visit_count = 0
        self.expanded_children = 0
        self.fully_expanded = False

        # Payload supplied by the caller.
        self.Q_value = Q_value
        self.critique = critique
        self.solution = solution
        self.parent = parent
    
class MCTS_REASONING_LLM:
    """Self-refining tree search: every node is a full answer plus a critique of it.

    Starting from a deliberately useless root answer, each iteration picks a
    node, asks the model to rewrite its answer using the stored critique,
    scores the rewrite with the model itself and propagates the resulting
    Q-values towards the root. Nodes are kept in ``self.nodes`` and refer to
    each other by list index; ``-1`` means "no node".
    """

    def __init__(self, model_path, max_child=5, c=1):
        """Load the model.

        Args:
            model_path: path of the GGUF model file.
            max_child: number of children after which a node counts as fully expanded.
            c: exploration constant of the UCT score.
        """
        self.model = Llama(model_path, n_ctx=16384, n_gpu_layers=99, logits_all=False, verbose=False)
        self.max_child = max_child
        self.c = c
        self.nodes = []
        self.query = None
        # Placeholder answers used to seed the root node.
        self.dummy_answers = [
            "I Don't Know",
            "I can't understand this question.",
            "I can't help with this question.",
            "I don't know how to solve this question.",
            "I don't know the answer to this question.",
            "I don't know the answer to this question, sorry."
        ]

    def is_fully_expanded(self, idx):
        """Return True when node ``idx`` has ``max_child`` children or a child that beats it."""
        node = self.nodes[idx]
        if len(node.children) >= self.max_child:
            return True
        return any(self.nodes[child_idx].Q_value > node.Q_value for child_idx in node.children)

    def get_optimum_child(self, idx):
        """Pick the child of ``idx`` to descend into, or -1 to expand ``idx`` itself.

        Children are ranked by ``Q + c * sqrt(ln(N + 1) / (n + 1e-6))`` where
        ``N`` is the visit count of node ``idx`` (the parent of the children
        being compared) and ``n`` the child's own visit count. While ``idx`` is
        not fully expanded, a hypothetical unvisited child (Q = 0, n = 0) is
        also considered; if it outranks every real child, -1 is returned.
        """
        node = self.nodes[idx]
        if not node.children:
            return -1
        # The exploration term must use the visits of the node whose children
        # are compared, not those of that node's own parent.
        log_visits = math.log(node.visit_count + 1)
        UCT = []
        for child_idx in node.children:
            child = self.nodes[child_idx]
            UCT.append(child.Q_value + self.c * math.sqrt(log_visits / (child.visit_count + 1e-6)))

        if (not node.fully_expanded and
            np.max(UCT) < self.c * math.sqrt(log_visits / 1e-6)):
            return -1

        return node.children[int(np.argmax(UCT))]

    def generate_critique(self, query, solution):
        """Ask the model to criticise ``solution``; fall back to a generic remark on error."""
        print(f"[generate_critique] Generating critique for solution...")

        prompt = f"""<|start_header_id|>user<|end_header_id|>Since we have a weak Answer, could you provide me with a reflection or feedback to correct this answer better? Analyze this Answer Strictly and Critically, point out every flaw for every possible imperfect to minus every possible score!

Question: {query}
Answer: {solution}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
Let's think step by step."""

        try:
            response = self.model(prompt=prompt, max_tokens=1024, temperature=0.8)
            critique = response["choices"][0]["text"].strip()
            print(f"[generate_critique] Critique generated successfully: {critique[:100]}...")
            return critique
        except Exception as e:
            print(f"[generate_critique] ERROR: {str(e)}")
            return "The answer needs improvement."

    def generate_refined_solution(self, query, original_solution, critique):
        """Ask the model to rewrite ``original_solution`` using ``critique``.

        Returns ``original_solution`` unchanged when generation fails.
        """
        print(f"[generate_refined_solution] Refining solution...")

        prompt = f"""<|start_header_id|>user<|end_header_id|>Please refine your answer according to the Reflection or Feedback. The response should begin with [reasoning process]...[Verification]... and end with "[Final Answer] The answer is [answer formula]"

Question: {query}
Original Answer: {original_solution}
Feedback: {critique}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
Let's think step by step."""

        try:
            response = self.model(prompt=prompt, max_tokens=2048, temperature=0.8)
            refined = response["choices"][0]["text"].strip()
            print(f"[generate_refined_solution] Refined solution generated successfully: {refined[:100]}...")
            return refined
        except Exception as e:
            print(f"[generate_refined_solution] ERROR: {str(e)}")
            return original_solution

    def extract_score_from_text(self, text):
        """Extract an integer score from the evaluator's free-form reply.

        Labelled forms are tried first, in order; an explicit ``+`` or ``-``
        sign is accepted. Without a labelled score, the last integer within
        [-100, 100] found anywhere in ``text`` is used. Returns 0 when no
        number is found.
        """
        # Look for patterns like [Score] -50, [Score]: +50, Score: -50, etc.
        score_patterns = [
            r'\[Score\]\s*([-+]?\d+)',    # [Score] -50
            r'\[Score\]:\s*([-+]?\d+)',   # [Score]: -50
            r'Score:\s*([-+]?\d+)',       # Score: -50
            r'Score\s+([-+]?\d+)',        # Score -50
            r'score\s*[:=]\s*([-+]?\d+)', # score: -50 or score = -50
        ]

        for pattern in score_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                return int(match.group(1))

        # Fallback: the last in-range number in the text (often the final score)
        in_range = [
            value
            for value in (int(num) for num in re.findall(r'[-+]?\d+', text))
            if -100 <= value <= 100
        ]
        if in_range:
            return in_range[-1]

        return 0  # Default if no score found

    def self_evaluate(self, query, solution, num_samples=3):
        """Sample ``num_samples`` model-assigned scores for ``solution``.

        Each score lies in [-100, 95]: it is first clamped to [-100, 100] and
        anything above 95 is then reduced to 95 so that a perfect mark is never
        recorded. A sample whose generation raises contributes 0.
        """
        print(f"[self_evaluate] Evaluating solution with {num_samples} samples...")

        # The prompt does not depend on the sample index, so build it once.
        prompt = f"""<|start_header_id|>user<|end_header_id|>Question: {query}
Answer: {solution}

Analyze this Answer Strictly and Critically, and point out every flaw for every possible imperfect to minus every possible score! You need to be very harsh and mean in calculating grades, and never give full marks to ensure that the marks are authoritative.

Output a score between [-100,+100].

Format: [Analysis] your analysis here [Score] your_number_here

Example: [Analysis] The solution has calculation errors and lacks proper reasoning. [Score] -45<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>"""

        scores = []
        for i in range(num_samples):
            try:
                response = self.model(prompt=prompt, max_tokens=512, temperature=0.8)
                text = response["choices"][0]["text"]
                print(f"[self_evaluate] Sample {i+1} response: {text[:100]}...")

                score = self.extract_score_from_text(text)

                # Clamp first so that an out-of-range number cannot slip past
                # the suppression below and come back as a full 100.
                score = max(-100, min(100, score))

                # Full Score Suppression: nothing above 95 is recorded
                if score > 95:
                    score = 95

                scores.append(score)
                print(f"[self_evaluate] Sample {i+1} extracted score: {score}")

            except Exception as e:
                print(f"[self_evaluate] ERROR in sample {i+1}: {str(e)}")
                scores.append(0)

        print(f"[self_evaluate] Final scores: {scores}")
        return scores

    def calculate_q_value(self, reward_samples):
        """Return ``0.5 * (min(R) + mean(R))`` of the samples, or 0 when there are none."""
        if not reward_samples:
            return 0

        min_reward = min(reward_samples)
        mean_reward = sum(reward_samples) / len(reward_samples)
        return 0.5 * (min_reward + mean_reward)

    def update_q_value_with_children(self, node_index):
        """Recompute a node's Q as ``0.5 * (own Q + best child Q)``.

        A node without children simply gets the Q-value of its own reward samples.
        """
        node = self.nodes[node_index]
        base_q = self.calculate_q_value(node.reward_samples)
        if node.children:
            max_child_q = max(self.nodes[child_idx].Q_value for child_idx in node.children)
            node.Q_value = 0.5 * (base_q + max_child_q)
        else:
            node.Q_value = base_q

    def mcts_init(self, query):
        """Reset the tree and create a root holding a randomly chosen dummy answer."""
        print(f"[mcts_init] Initializing MCTS for query: {query[:50]}...")

        self.query = query
        self.nodes = []

        # Create root node with dummy answer
        dummy_solution = random.choice(self.dummy_answers)
        critique = self.generate_critique(self.query, dummy_solution)
        root_node = MCTS_NODE(-1, dummy_solution, critique, 0)

        # Evaluate root node
        root_node.reward_samples = self.self_evaluate(query, dummy_solution)
        root_node.Q_value = self.calculate_q_value(root_node.reward_samples)
        root_node.visit_count = 1

        self.nodes.append(root_node)
        print(f"[mcts_init] Root node created with Q-value: {root_node.Q_value}")

    def iterator(self):
        """Single MCTS iteration combining all phases: Selection -> Expansion -> Evaluation -> Backpropagation"""
        # SELECTION PHASE: follow the best child until a node asks to be expanded
        current_node = 0
        previous_node = -1
        while current_node != -1:
            previous_node = current_node
            current_node = self.get_optimum_child(current_node)

        # EXPANSION PHASE: refine the selected node's answer using its critique
        solution = self.generate_refined_solution(
            self.query,
            self.nodes[previous_node].solution,
            self.nodes[previous_node].critique
        )
        critique = self.generate_critique(self.query, solution)

        # Create new child node and register it with its parent
        new_node = MCTS_NODE(previous_node, solution, critique, 0)
        self.nodes.append(new_node)
        current_node = len(self.nodes) - 1
        self.nodes[previous_node].children.append(current_node)

        # EVALUATION PHASE: self-evaluate the new solution
        new_node.reward_samples = self.self_evaluate(self.query, solution)
        new_node.Q_value = self.calculate_q_value(new_node.reward_samples)
        new_node.visit_count = 1

        # BACKPROPAGATION PHASE: update Q values and expansion flags up the tree
        while previous_node != -1:
            self.nodes[previous_node].visit_count += 1
            self.update_q_value_with_children(previous_node)
            self.nodes[previous_node].fully_expanded = self.is_fully_expanded(previous_node)
            previous_node = self.nodes[previous_node].parent

    def run(self, query, iterations=100):
        """Run MCTS for ``iterations`` iterations and return the highest-Q solution."""
        print(f"[run] Starting MCTS with {iterations} iterations...")
        self.mcts_init(query)
        for i in range(iterations):
            print(f"[run] Iteration {i+1}/{iterations}")
            self.iterator()

            # Print progress
            if (i + 1) % 10 == 0:
                best_node = max(self.nodes, key=lambda n: n.Q_value)
                print(f"[run] Best Q-value after {i+1} iterations: {best_node.Q_value}")

        # Return best solution
        best_node = max(self.nodes, key=lambda n: n.Q_value)
        print(f"[run] Final best Q-value: {best_node.Q_value}")
        return best_node.solution

    def get_best_solution(self):
        """Get the solution with highest Q-value, or None when the tree is empty."""
        if not self.nodes:
            return None

        best_node = max(self.nodes, key=lambda n: n.Q_value)
        return best_node.solution

    def get_tree_stats(self):
        """Summarise the current tree.

        Returns:
            dict with ``total_nodes`` (node count), ``max_q_value`` (highest
            Q-value, ``None`` for an empty tree) and ``max_depth`` (longest
            parent chain, 0 for a root-only or empty tree).
        """
        if not self.nodes:
            return {"total_nodes": 0, "max_q_value": None, "max_depth": 0}

        max_depth = 0
        for node in self.nodes:
            depth = 0
            parent = node.parent
            while parent != -1:
                depth += 1
                parent = self.nodes[parent].parent
            max_depth = max(max_depth, depth)

        return {
            "total_nodes": len(self.nodes),
            "max_q_value": max(node.Q_value for node in self.nodes),
            "max_depth": max_depth,
        }
    

# Initialize the model
mctsr = MCTS_REASONING_LLM("llama-3.1-8b-instruct-q4_k_m.gguf")

# Solve a problem
query = "What is the sum of the first 10 prime numbers?"
solution = mctsr.run(query, 20)

# Get tree statistics
# Derived directly from the node list that run() fills in (the root is always
# present after run()), so this does not depend on a stats helper on the class.
stats = {
    "total_nodes": len(mctsr.nodes),
    "max_q_value": max(node.Q_value for node in mctsr.nodes),
}
print(f"Generated {stats['total_nodes']} nodes with max Q-value: {stats['max_q_value']}")

llama_context: n_ctx_per_seq (16384) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
llama_kv_cache_unified: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility
llama_kv_cache_unified: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility


[run] Starting MCTS with 20 iterations...
[mcts_init] Initializing MCTS for query: What is the sum of the first 10 prime numbers?...
[generate_critique] Generating critique for solution...
[generate_critique] Critique generated successfully: Step 1:  First, we need to understand what prime numbers are. Prime numbers are numbers greater than...
[self_evaluate] Evaluating solution with 3 samples...
[generate_critique] Critique generated successfully: Step 1:  First, we need to understand what prime numbers are. Prime numbers are numbers greater than...
[self_evaluate] Evaluating solution with 3 samples...
[self_evaluate] Sample 1 response: 

[Analysis] The answer is an outright refusal to help, which is unacceptable in an academic or prob...
[self_evaluate] Sample 1 extracted score: -100
[self_evaluate] Sample 1 response: 

[Analysis] The answer is an outright refusal to help, which is unacceptable in an academic or prob...
[self_evaluate] Sample 1 extracted score: -100
[self_evaluate] S

In [None]:
import math
import random
import re

import numpy as np
import pandas as pd
from llama_cpp import Llama

class MCTS_NODE:
    """A candidate answer, its critique and the statistics the search keeps for it."""

    def __init__(self, parent, solution, critique, Q_value):
        """Build an unvisited, childless node.

        Args:
            parent: index of the parent node (the tree code uses -1 for the root).
            solution: answer text stored in the node.
            critique: feedback text generated for ``solution``.
            Q_value: starting quality estimate.
        """
        # Statistics start out empty.
        self.fully_expanded = False
        self.reward_samples = []    # scores collected for this node's solution
        self.children = []          # indices of child nodes
        self.visit_count = 0

        # Caller-provided content.
        self.Q_value = Q_value
        self.critique = critique
        self.solution = solution
        self.parent = parent

class MCTS_REASONING_LLM:
    """Self-refining tree search: every node is a full answer plus a critique of it.

    Starting from a deliberately useless root answer, each iteration picks a
    node, asks the model to rewrite its answer using the stored critique,
    scores the rewrite with the model itself and propagates the resulting
    Q-values towards the root. Nodes are kept in ``self.nodes`` and refer to
    each other by list index; ``-1`` means "no node".
    """

    def __init__(self, model_path, max_child=5, c=1):
        """Load the model.

        Args:
            model_path: path of the GGUF model file.
            max_child: number of children after which a node counts as fully expanded.
            c: exploration constant of the UCT score.
        """
        self.model = Llama(model_path, n_ctx=16384, n_gpu_layers=99, logits_all=False, verbose=False)
        self.max_child = max_child
        self.c = c
        self.nodes = []
        self.query = None
        # Placeholder answers used to seed the root node.
        self.dummy_answers = [
            "I Don't Know",
            "I can't understand this question.",
            "I can't help with this question.",
            "I don't know how to solve this question.",
            "I don't know the answer to this question.",
            "I don't know the answer to this question, sorry."
        ]

    def is_fully_expanded(self, idx):
        """Return True when node ``idx`` has ``max_child`` children or a child that beats it."""
        node = self.nodes[idx]
        if len(node.children) >= self.max_child:
            return True
        return any(self.nodes[child_idx].Q_value > node.Q_value for child_idx in node.children)

    def get_optimum_child(self, idx):
        """Pick the child of ``idx`` to descend into, or -1 to expand ``idx`` itself.

        Children are ranked by ``Q + c * sqrt(ln(N + 1) / (n + 1e-6))`` where
        ``N`` is the visit count of node ``idx`` (the parent of the children
        being compared) and ``n`` the child's own visit count. While ``idx`` is
        not fully expanded, a hypothetical unvisited child (Q = 0, n = 0) is
        also considered; if it outranks every real child, -1 is returned.
        """
        node = self.nodes[idx]
        if not node.children:
            return -1
        # The exploration term must use the visits of the node whose children
        # are compared, not those of that node's own parent.
        log_visits = math.log(node.visit_count + 1)
        UCT = []
        for child_idx in node.children:
            child = self.nodes[child_idx]
            UCT.append(child.Q_value + self.c * math.sqrt(log_visits / (child.visit_count + 1e-6)))
        if not node.fully_expanded and np.max(UCT) < self.c * math.sqrt(log_visits / 1e-6):
            return -1
        return node.children[int(np.argmax(UCT))]

    def generate_critique(self, query, solution):
        """Ask the model to criticise ``solution``; fall back to a generic remark on error."""
        prompt = f"""<|start_header_id|>user<|end_header_id|>Since we have a weak Answer, could you provide me with a reflection or feedback to correct this answer better? Analyze this Answer Strictly and Critically, point out every flaw for every possible imperfect to minus every possible score!

Question: {query}
Answer: {solution}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
Let's think step by step."""
        try:
            response = self.model(prompt=prompt, max_tokens=1024, temperature=0.8)
            return response["choices"][0]["text"].strip()
        except Exception:
            # Best-effort fallback for generation errors only; KeyboardInterrupt
            # and SystemExit still propagate so a long run can be stopped.
            return "The answer needs improvement."

    def generate_refined_solution(self, query, original_solution, critique):
        """Ask the model to rewrite ``original_solution`` using ``critique``.

        Returns ``original_solution`` unchanged when generation fails.
        """
        prompt = f"""<|start_header_id|>user<|end_header_id|>Please refine your answer according to the Reflection or Feedback. The response should begin with [reasoning process]...[Verification]... and end with "[Final Answer] The answer is [answer formula]"

Question: {query}
Original Answer: {original_solution}
Feedback: {critique}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
Let's think step by step."""
        try:
            response = self.model(prompt=prompt, max_tokens=2048, temperature=0.8)
            return response["choices"][0]["text"].strip()
        except Exception:
            # Same policy as generate_critique: swallow generation errors only.
            return original_solution

    def extract_score_from_text(self, text):
        """Extract an integer score from the evaluator's free-form reply.

        Labelled forms are tried first, in order; an explicit ``+`` or ``-``
        sign is accepted. Without a labelled score, the last integer within
        [-100, 100] found anywhere in ``text`` is used. Returns 0 when no
        number is found.
        """
        score_patterns = [
            r'\[Score\]\s*([-+]?\d+)',
            r'\[Score\]:\s*([-+]?\d+)',
            r'Score:\s*([-+]?\d+)',
            r'Score\s+([-+]?\d+)',
            r'score\s*[:=]\s*([-+]?\d+)',
        ]
        for pattern in score_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                return int(match.group(1))
        in_range = [
            value
            for value in (int(num) for num in re.findall(r'[-+]?\d+', text))
            if -100 <= value <= 100
        ]
        if in_range:
            return in_range[-1]
        return 0

    def self_evaluate(self, query, solution, num_samples=3):
        """Sample ``num_samples`` model-assigned scores for ``solution``.

        Each score lies in [-100, 95]: it is first clamped to [-100, 100] and
        anything above 95 is then reduced to 95 so that a perfect mark is never
        recorded. Errors raised by the model call propagate to the caller.
        """
        # The prompt does not depend on the sample index, so build it once.
        prompt = f"""<|start_header_id|>user<|end_header_id|>Question: {query}
Answer: {solution}

Analyze this Answer Strictly and Critically, and point out every flaw for every possible imperfect to minus every possible score! You need to be very harsh and mean in calculating grades, and never give full marks to ensure that the marks are authoritative.

Output a score between [-100,+100].

Format: [Analysis] your analysis here [Score] your_number_here

Example: [Analysis] The solution has calculation errors and lacks proper reasoning. [Score] -45<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>"""
        scores = []
        for _ in range(num_samples):
            response = self.model(prompt=prompt, max_tokens=512, temperature=0.8)
            text = response["choices"][0]["text"]
            score = self.extract_score_from_text(text)
            # Clamp first so that an out-of-range number cannot slip past the
            # suppression below and come back as a full 100.
            score = max(-100, min(100, score))
            if score > 95:
                score = 95
            scores.append(score)
        return scores

    def calculate_q_value(self, reward_samples):
        """Return ``0.5 * (min(R) + mean(R))`` of the samples, or 0 when there are none."""
        if not reward_samples:
            return 0
        min_reward = min(reward_samples)
        mean_reward = sum(reward_samples) / len(reward_samples)
        return 0.5 * (min_reward + mean_reward)

    def update_q_value_with_children(self, node_index):
        """Recompute a node's Q as ``0.5 * (own Q + best child Q)``.

        A node without children simply gets the Q-value of its own reward samples.
        """
        node = self.nodes[node_index]
        base_q = self.calculate_q_value(node.reward_samples)
        if node.children:
            max_child_q = max(self.nodes[child_idx].Q_value for child_idx in node.children)
            node.Q_value = 0.5 * (base_q + max_child_q)
        else:
            node.Q_value = base_q

    def mcts_init(self, query):
        """Reset the tree and create a root holding a randomly chosen dummy answer."""
        self.query = query
        self.nodes = []
        dummy_solution = random.choice(self.dummy_answers)
        critique = self.generate_critique(self.query, dummy_solution)
        root_node = MCTS_NODE(-1, dummy_solution, critique, 0)
        root_node.reward_samples = self.self_evaluate(query, dummy_solution)
        root_node.Q_value = self.calculate_q_value(root_node.reward_samples)
        root_node.visit_count = 1
        self.nodes.append(root_node)

    def iterator(self):
        """Run one iteration: selection, expansion, evaluation, backpropagation."""
        # Selection: follow the best child until a node asks to be expanded.
        current_node = 0
        previous_node = -1
        while current_node != -1:
            previous_node = current_node
            current_node = self.get_optimum_child(current_node)

        # Expansion: refine the selected node's answer using its critique.
        solution = self.generate_refined_solution(
            self.query,
            self.nodes[previous_node].solution,
            self.nodes[previous_node].critique
        )
        critique = self.generate_critique(self.query, solution)
        new_node = MCTS_NODE(previous_node, solution, critique, 0)
        self.nodes.append(new_node)
        current_node = len(self.nodes) - 1
        self.nodes[previous_node].children.append(current_node)

        # Evaluation: score the new answer.
        new_node.reward_samples = self.self_evaluate(self.query, solution)
        new_node.Q_value = self.calculate_q_value(new_node.reward_samples)
        new_node.visit_count = 1

        # Backpropagation: refresh visit counts, Q-values and expansion flags.
        while previous_node != -1:
            self.nodes[previous_node].visit_count += 1
            self.update_q_value_with_children(previous_node)
            self.nodes[previous_node].fully_expanded = self.is_fully_expanded(previous_node)
            previous_node = self.nodes[previous_node].parent

    def run(self, query, iterations=5):
        """Build a tree for ``query`` with ``iterations`` iterations; return the highest-Q solution."""
        self.mcts_init(query)
        for _ in range(iterations):
            self.iterator()
        best_node = max(self.nodes, key=lambda n: n.Q_value)
        return best_node.solution

    def get_best_solution(self):
        """Return the solution with the highest Q-value, or None when the tree is empty."""
        if not self.nodes:
            return None
        best_node = max(self.nodes, key=lambda n: n.Q_value)
        return best_node.solution

    def print_status(self):
        """Print every node: links, solution, critique, Q-value and reward samples."""
        print("printing MCTS status")
        for idx, node in enumerate(self.nodes):
            print(f"node_{idx}:\n"
                  f"parent: {node.parent}, children: {node.children}\n"
                  f"solution:\n"
                  f"{node.solution}\n"
                  "critique:\n"
                  f"{node.critique}\n"
                  f"Q_value: {node.Q_value}\n"
                  f"evaluate_samples:\n"
                  f"{node.reward_samples}\n")

# Function to generate direct answer using Llama
def generate_direct_answer(model, query):
    """Ask ``model`` for a one-shot answer to ``query`` with no reasoning requested.

    Args:
        model: callable accepting ``prompt``, ``max_tokens`` and ``temperature``
            keyword arguments and returning a completion dict with the text
            under ``["choices"][0]["text"]``.
        query: the question text.

    Returns:
        The stripped completion text, or ``"Unable to generate answer."`` when
        the call or the response lookup raises an ordinary exception.
    """
    prompt = f"""<|start_header_id|>user<|end_header_id|>Question: {query}
Please provide the answer directly without any reasoning or explanation.<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>"""
    try:
        response = model(prompt=prompt, max_tokens=512, temperature=0.8)
        return response["choices"][0]["text"].strip()
    except Exception:
        # Best-effort fallback for generation errors only; KeyboardInterrupt
        # and SystemExit still propagate so the run can be stopped.
        return "Unable to generate answer."

# Load riddle data
df = pd.read_excel("Riddle.xlsx")

# One Llama instance is loaded by the MCTS wrapper and reused for the direct
# (no-search) baseline (replace the path with the actual model location).
mcts = MCTS_REASONING_LLM(model_path="llama-3.1-8b-instruct-q4_k_m.gguf")
llm = mcts.model

# Compare both answering strategies on a small subset (first 5 riddles).
results = []
for idx, row in df.head(5).iterrows():
    query, actual_answer = row["Question"], row["Answer"]

    # Tree-search answer first, then the single-shot baseline.
    mcts_answer = mcts.run(query, iterations=5)
    direct_answer = generate_direct_answer(llm, query)

    # Dump the whole search tree, followed by the side-by-side comparison.
    mcts.print_status()

    print("\nComparison of MCTS, Direct, and Actual Answers:\n")
    report_lines = [
        f"Question: {query}",
        f"Actual_answer: {actual_answer}",
        f"MCTS_answer: {mcts_answer}",
        f"direct_answer: {direct_answer}",
        "",
    ]
    print("\n".join(report_lines))

    record = {
        "ID": row["ID"],
        "Question": query,
        "Actual Answer": actual_answer,
        "MCTS Answer": mcts_answer,
        "Direct Answer": direct_answer,
    }
    results.append(record)

# Persist the collected rows, then show them as a plain-text table.
results_df = pd.DataFrame(results)
results_df.to_csv("riddle_comparison.csv", index=False)

display_columns = ["ID", "Question", "Actual Answer", "MCTS Answer", "Direct Answer"]
print("\nComparison of MCTS, Direct, and Actual Answers:")
print(results_df[display_columns].to_string(index=False))

llama_context: n_ctx_per_seq (16384) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
llama_kv_cache_unified: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility
llama_kv_cache_unified: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility


[0, 0, 0]
[0, 0, 0]
[0, 0, 0]
