# Setup

In [None]:
# import os
# from google import genai
# from google.genai import types
# # The client gets the API key from the environment variable `GEMINI_API_KEY`.
# client = genai.Client()
# response = client.models.generate_content(
#     model="gemini-2.5-flash-lite-preview-06-17", 
#     config=types.GenerateContentConfig(
#         thinking_config=types.ThinkingConfig(thinking_budget=0) # Disables thinking
#     ),
#     contents="Explain how AI works in a few words"
# )
# print(response.text)

In [None]:
import os
from google import genai
from google.genai import types
# The client gets the API key from the environment variable `GEMINI_API_KEY`.
client = genai.Client()

def llm(prompt, stop=None, num_traces=1, model="gemini-2.5-flash-lite-preview-06-17", max_output_tokens=100):
    """Return one Gemini completion for ``prompt`` as a string.

    Args:
        prompt: Full text prompt, sent as ``contents``.
        stop: Stop sequences for generation. ``None`` (the default) means a
            single newline, i.e. the completion ends at the first line break.
        num_traces: Only selects the sampling temperature: 0.0 (greedy) for a
            single trace, 0.7 when several traces are sampled so they differ.
        model: Gemini model name.
        max_output_tokens: Upper bound on generated tokens.

    Returns:
        The generated text, or ``""`` when the response carries no text
        (callers ``.strip()`` the result, so ``None`` must not leak out).
    """
    if stop is None:
        # Fresh list per call instead of a shared mutable default argument.
        stop = ["\n"]
    temperature_setting = 0.0 if num_traces == 1 else 0.7
    response = client.models.generate_content(
        model=model,
        contents=prompt,
        config=types.GenerateContentConfig(
            thinking_config=types.ThinkingConfig(thinking_budget=0),  # Disables thinking
            stop_sequences=stop,
            temperature=temperature_setting,
            max_output_tokens=max_output_tokens,
            top_p=1.0,
        ),
    )
    # response.text is None when the candidate has no text parts (e.g. an empty
    # or blocked response).
    return response.text or ""

In [None]:
import re

# Matches a whole line of the form "Action <n>: Finish[<answer>]" and captures <answer>.
_FINISH_ACTION_RE = re.compile(r"^Action \d+: Finish\[(.*?)\]\s*$", re.MULTILINE)


def extract_final_answer_from_trace_string(trace_trajectory_string):
    """
    Extracts the final answer from a ReAct trace trajectory string.

    Looks for the last line of the form 'Action X: Finish[answer]' and returns
    the bracketed answer with surrounding whitespace stripped ("" for
    'Finish[]'). Returns None when the input is empty/None or has no such line.

    Caveat: if the string is a full ReAct prompt (few-shot examples followed by
    the generated steps), Finish lines from the examples match too, so pass only
    the part generated for the current question when that matters.
    """
    if not trace_trajectory_string:
        return None

    matches = _FINISH_ACTION_RE.findall(trace_trajectory_string)
    # The last match in the string is the one we want.
    return matches[-1].strip() if matches else None

def extract_answers_from_traces(all_traces_info):
    """
    Extracts the final answer from each trace in the all_traces_info list.

    For each trace dict, the answer is taken from the last
    'Action X: Finish[...]' line of its 'traj' string. Only the part of 'traj'
    after the trace's 'question_text' is searched: 'traj' starts with the
    few-shot prompt, whose example Finish lines would otherwise be mistaken for
    this trace's answer when the trace never finished on its own. When no
    non-empty answer is found there, the trace's 'answer' entry is used instead.

    Returns:
        The non-empty answers, in trace order (one per trace at most). An empty
        list (plus a printed warning) if the input is not a list.
    """
    if not isinstance(all_traces_info, list):
        print(f"Warning: extract_answers_from_traces expected a list, got {type(all_traces_info)}")
        return []

    extracted_answers = []
    for trace_info in all_traces_info:
        trajectory = trace_info.get('traj') or ''

        # Drop the prompt prefix (few-shot examples + question) so only the
        # steps generated for this question are searched.
        question_text = trace_info.get('question_text')
        if question_text:
            cut = trajectory.rfind(question_text)
            if cut != -1:
                trajectory = trajectory[cut + len(question_text):]

        answer = extract_final_answer_from_trace_string(trajectory)
        if not answer:
            # No (or an empty) Finish answer in the trajectory: fall back to
            # whatever answer the environment recorded for this trace.
            answer = trace_info.get('answer')
        if answer:
            extracted_answers.append(answer)
    return extracted_answers

In [None]:
def synthesize_answer_with_llm(list_of_answers, question_for_context=""):
    """
    Synthesizes a single best answer from a list of answers using an LLM.

    Args:
        list_of_answers: Candidate answers (typically one per trace). None and
            blank entries are ignored; duplicates are kept because they act as
            votes in the synthesis prompt.
        question_for_context: Original question, included in the prompt when
            non-empty.

    Returns:
        - An "Error: ..." string if the list is empty or has no usable answers.
        - The answer itself if all usable answers agree (after stripping).
        - Otherwise the LLM's choice. The call uses ``llm``'s default stop
          sequence (a newline), so only the first line is kept. If the LLM
          returns nothing, the most common candidate is returned instead
          (ties go to the one seen first).
    """
    if not list_of_answers:
        return "Error: No answers provided to synthesize."

    cleaned_answers = [
        text
        for text in (str(a).strip() for a in list_of_answers if a is not None)
        if text
    ]
    unique_answers = sorted(set(cleaned_answers))
    if not unique_answers:
        return "Error: No valid answers found after filtering to synthesize."
    if len(unique_answers) == 1:
        return unique_answers[0]

    prompt_template = """As an expert analyst, your task is to determine the single best answer from the following list, which was generated in response to the same question.\n{question_context}\nReview all answers, identify the most consistent and factually correct choice, and return that single answer. For fixed-choice questions (like yes/no or numbers), this will be a majority vote. For text-based answers, synthesize the information into the most accurate and complete response. Ignore any clear outliers or factually incorrect statements.\n\nGenerated Answers:\n{formatted_answers}\n\nFinal Answer:"""

    question_context_str = ""
    if question_for_context:
        question_context_str = f"The question asked was: \"{question_for_context}\"\n\n"

    formatted_answers = "\n".join(
        f"{i}. {ans}" for i, ans in enumerate(cleaned_answers, start=1)
    )

    synthesizer_prompt = prompt_template.format(
        question_context=question_context_str,
        formatted_answers=formatted_answers
    )

    final_answer = (llm(synthesizer_prompt, num_traces=1) or "").strip()
    if final_answer:
        return final_answer

    # The model produced no usable text (e.g. it hit the newline stop sequence
    # immediately); fall back to a plain majority vote over the candidates.
    from collections import Counter
    return Counter(cleaned_answers).most_common(1)[0][0]

In [None]:
import wikienv, wrappers
import requests

# Module-level environment shared by `step` calls inside `webthink`:
# the Wikipedia environment, wrapped to serve HotPotQA "dev"-split questions.
# NOTE(review): LoggingWrapper presumably records trajectories/metrics; confirm against wrappers.py.
env = wikienv.WikiEnv()
env = wrappers.HotPotQAWrapper(env, split="dev")
env = wrappers.LoggingWrapper(env)

def step(env, action, max_attempts=10):
    """Call ``env.step(action)``, retrying when the request times out.

    Args:
        env: Environment whose ``step`` method is called.
        action: Action string passed through unchanged.
        max_attempts: Total number of attempts (must be >= 1).

    Returns:
        Whatever ``env.step`` returns (callers unpack it as
        ``obs, reward, done, info``).

    Raises:
        requests.exceptions.Timeout: if every attempt times out. Previously the
            function returned None here, which only failed later with an
            unrelated-looking unpacking TypeError in the caller.
        ValueError: if ``max_attempts`` is less than 1.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")
    for attempt in range(1, max_attempts + 1):
        try:
            return env.step(action)
        except requests.exceptions.Timeout:
            if attempt == max_attempts:
                raise

# ReAct

In [None]:
import json
import sys

# Few-shot ReAct examples are loaded from ./prompts/prompts_naive.json.
folder = './prompts/'
prompt_file = 'prompts_naive.json'
# os.path.join instead of string "+" for the path; explicit UTF-8 because the
# default text encoding is platform/locale-dependent and the examples may
# contain non-ASCII Wikipedia text.
with open(os.path.join(folder, prompt_file), 'r', encoding='utf-8') as f:
    prompt_dict = json.load(f)

webthink_examples = prompt_dict['webthink_simple6']
instruction = """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types: \n(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.\n(2) Lookup[keyword], which returns the next sentence containing keyword in the current passage.\n(3) Finish[answer], which returns the answer and finishes the task.\nHere are some examples.\n"""
# Task instruction + few-shot examples; webthink appends the question and steps.
webthink_prompt_template = instruction + webthink_examples

def _webthink_single_trace(idx, initial_prompt_template, to_print, num_traces, trace_num, max_steps):
    """Run one ReAct Thought/Action/Observation loop; return ``(question, trace_info)``.

    Resets the global ``env`` to question ``idx``, alternates LLM calls with
    environment steps for at most ``max_steps`` steps, and issues ``finish[]``
    if the episode did not end on its own. ``trace_info`` is a copy of the
    final env info plus call counts, the full prompt ('traj'), the question
    index/text and the 1-based trace number.
    """
    question = env.reset(idx=idx)  # Reset environment for each trace
    current_prompt = initial_prompt_template + question + "\n"

    if to_print:
        print(f"--- Trace {trace_num + 1}/{num_traces} ---")
        print(idx, question)

    n_calls, n_badcalls = 0, 0
    done = False
    for i in range(1, max_steps + 1):
        n_calls += 1
        # `or ""` guards against an LLM call that returns no text (None).
        thought_action = llm(current_prompt + f"Thought {i}:", stop=[f"\nObservation {i}:"], num_traces=num_traces) or ""
        try:
            thought, action = thought_action.strip().split(f"\nAction {i}: ")
        except ValueError:
            # The completion did not contain exactly one "Action i:" separator:
            # keep its first line as the thought and request the action separately.
            n_badcalls += 1
            n_calls += 1
            thought = thought_action.strip().split('\n')[0]
            action = (llm(current_prompt + f"Thought {i}: {thought}\nAction {i}:", stop=["\n"], num_traces=num_traces) or "").strip()

        # Lower-case only the first character (e.g. "Search[x]" -> "search[x]").
        # Slicing instead of action[0] keeps an empty action from raising IndexError.
        obs, r, done, info = step(env, action[:1].lower() + action[1:])
        obs = obs.replace('\\n', '')

        step_str = f"Thought {i}: {thought}\nAction {i}: {action}\nObservation {i}: {obs}\n"
        current_prompt += step_str

        if to_print:
            print(step_str)

        if done:
            break

    if not done:
        # Step budget exhausted: close the episode with an empty answer.
        obs, r, done, info = step(env, "finish[]")
        # NOTE(review): when the returned info has no 'traj' key, everything in
        # it (including any metrics) is discarded here — confirm this is intended.
        if 'traj' not in info:
            info = {}
        info.update({'finish_action_obs': obs})

    trace_info = info.copy()
    trace_info.update({'n_calls': n_calls,
                       'n_badcalls': n_badcalls,
                       'traj': current_prompt,
                       'question_idx': idx,
                       'question_text': question,
                       'trace_num': trace_num + 1})
    return question, trace_info


def webthink(idx=None, initial_prompt_template=webthink_prompt_template, to_print=True, num_traces=1, max_steps=7):
    """Answer HotPotQA question ``idx`` with ReAct, optionally over several traces.

    Args:
        idx: Question index passed to ``env.reset``.
        initial_prompt_template: Instruction + few-shot prefix of the prompt.
        to_print: Print each step, per-trace info and synthesis progress.
        num_traces: Number of independent traces. With more than one, sampling
            uses a non-zero temperature (see ``llm``) and the traces' answers
            are combined by ``synthesize_answer_with_llm``.
        max_steps: Maximum Thought/Action/Observation steps per trace.

    Returns:
        - num_traces <= 0: ``("[INVALID_NUM_TRACES]", [])``.
        - num_traces == 1: ``(reward, trace_info)`` where reward is
          ``trace_info.get('reward', 0.0)``.
        - num_traces > 1: ``(synthesized_answer, list_of_trace_infos)``; the
          answer is "[SYNTHESIS_FAILED_NO_EXTRACTED_ANSWERS]" when no trace
          yielded an answer.
    """
    if num_traces <= 0:
        if to_print:
            print(f"Warning: webthink called with num_traces = {num_traces}. Must be > 0.")
        return "[INVALID_NUM_TRACES]", []

    all_traces_info = []
    question_for_synthesis = ""  # question of the first trace, given to the synthesizer
    for trace_num in range(num_traces):
        question, trace_info = _webthink_single_trace(
            idx, initial_prompt_template, to_print, num_traces, trace_num, max_steps
        )
        if trace_num == 0:
            question_for_synthesis = question
        all_traces_info.append(trace_info)

        if to_print:
            print(f"(Trace {trace_num + 1}) Info: {trace_info}\n")
            if num_traces > 1 and trace_num < num_traces - 1:
                print(f"--- End of Trace {trace_num + 1} ---\n")

    if num_traces == 1:
        final_r = all_traces_info[0].get('reward', 0.0)
        return final_r, all_traces_info[0]

    # num_traces > 1: combine the traces' answers into one.
    if to_print:
        print("\n--- Starting Answer Synthesis ---")

    extracted_answers = extract_answers_from_traces(all_traces_info)

    if to_print:
        print(f"Extracted Answers for Synthesis: {extracted_answers}")

    if not extracted_answers:
        if to_print:
            print("Warning: No answers extracted from traces. Cannot synthesize.")
        return "[SYNTHESIS_FAILED_NO_EXTRACTED_ANSWERS]", all_traces_info

    synthesized_answer = synthesize_answer_with_llm(extracted_answers, question_for_synthesis)

    if to_print:
        print(f"Synthesized Answer: {synthesized_answer}")
        print("--- End of Answer Synthesis ---\n")

    return synthesized_answer, all_traces_info

In [None]:
# (Re)define the question indices so this cell also runs on its own.
idxs = list(range(7405))

# --- Multi-trace run: three traces whose answers get synthesized into one. ---
print("--- Testing webthink with num_traces=3 ---")
synthesized_answer, traces_list = webthink(idx=idxs[0], to_print=True, num_traces=3)

print("\n--- Synthesized Answer for idxs[0] (num_traces=3) ---")
print(synthesized_answer)
print("---")

# (label shown, trace_info key) pairs reported for every trace.
multi_trace_fields = (
    ("Question Index", "question_idx"),
    ("Question Text", "question_text"),
    ("Trace Answer", "answer"),
    ("Trace Reward", "reward"),
    ("Trace EM", "em"),
    ("Trace F1", "f1"),
)
if not isinstance(traces_list, list):
    print("Error: Expected a list of traces as the second part of the tuple.")
    print(f"Received for traces_list: {traces_list}")
else:
    for i, trace_info in enumerate(traces_list):
        print(f"Details for Trace {i+1}:")
        for label, key in multi_trace_fields:
            print(f"  {label}: {trace_info.get(key)}")
        print("  ---")

# --- Single-trace run: webthink returns (reward, info dict). ---
print("\n--- Testing webthink with num_traces=1 ---")
reward_single, info_single = webthink(idx=idxs[1], to_print=True, num_traces=1)

print("\n--- Single Trace Result for idxs[1] (num_traces=1) ---")
print(f"Reward: {reward_single}")
if not (info_single and isinstance(info_single, dict)):
    print("Error: Received no valid info_single for single trace run.")
else:
    single_trace_fields = (
        ("Question Index", "question_idx"),
        ("Question Text", "question_text"),
        ("Answer", "answer"),
        ("EM", "em"),
        ("F1", "f1"),
    )
    for label, key in single_trace_fields:
        print(f"  {label}: {info_single.get(key)}")

In [None]:
# Use this for benchmarking typical ReAct with Gemini Flash Lite
import random
import time
# Fixed-seed shuffle so every run evaluates the same question sample.
# NOTE: this rebinds the global `idxs`, so later cells' idxs[0] refers to the
# first shuffled index, not question 0.
idxs = list(range(7405))
random.Random(233).shuffle(idxs)
num_benchmark_questions = 500
rs = []
infos = []
old_time = time.time()
for i in idxs[:num_benchmark_questions]:
    r_val, info_val = webthink(idx=i, to_print=False, num_traces=1)
    em = info_val.get('em')
    if em is None:
        # A trace without metrics would otherwise abort the whole run with a
        # KeyError; count it as a miss and say so.
        print(f"Warning: no 'em' metric for question {i}; counting it as 0.")
        em = 0
    rs.append(em)
    infos.append(info_val)
    # running EM sum, questions done, EM rate, mean seconds per question
    print(sum(rs), len(rs), sum(rs) / len(rs), (time.time() - old_time) / len(rs))
    print('-----------')
    print()

In [None]:
# Cell A: Demonstrate single trace for idxs[0]
print("--- Running webthink with num_traces=1 for idxs[0] ---")
if 'idxs' not in globals():
    # Fallback so the cell can run standalone.
    print("WARNING: 'idxs' not found globally. Assuming it will be defined by a preceding cell.")
    print("For standalone execution of this cell, ensure 'idxs', 'env', 'llm', and 'webthink' are defined.")
    idxs = list(range(7405))

required_names = ('webthink', 'env', 'llm', 'idxs')
if not all(name in globals() for name in required_names):
    print("ERROR: One or more required components (idxs, webthink, env, llm) are not defined.")
    print("Please ensure all preceding setup cells are executed.")
else:
    print(f"Using idxs[0] which is: {idxs[0]}")
    r_single, info_single = webthink(idx=idxs[0], to_print=True, num_traces=1)
    print("\n--- Single Trace (idxs[0]) Summary ---")
    print(f"Reward: {r_single}")
    if not (info_single and isinstance(info_single, dict)):
        print(f"Error: info_single was not a valid dictionary. Value: {info_single}")
    else:
        for label, key in (("Answer", "answer"), ("EM", "em"), ("F1", "f1"), ("Num Calls", "n_calls")):
            print(f"{label}: {info_single.get(key)}")

In [None]:
# Cell B: Demonstrate multiple traces for idxs[0]
print("\n--- Running webthink with num_traces=3 for idxs[0] (Updated for Synthesizer) ---")
if 'idxs' not in globals():
    # Fallback so the cell can run standalone.
    print("WARNING: 'idxs' not found globally. Assuming it will be defined by a preceding cell.")
    print("For standalone execution of this cell, ensure 'idxs', 'env', 'llm', and 'webthink' are defined.")
    idxs = list(range(7405))

required_names = ('webthink', 'env', 'llm', 'idxs')
if not all(name in globals() for name in required_names):
    print("ERROR: One or more required components (idxs, webthink, env, llm) are not defined.")
    print("Please ensure all preceding setup cells are executed.")
else:
    print(f"Using idxs[0] which is: {idxs[0]}")

    synthesized_answer_b, multi_traces_list_b = webthink(idx=idxs[0], to_print=True, num_traces=3)

    print("\n--- Multi-Trace (idxs[0]) Summary (Updated for Synthesizer) ---")
    print(f"Synthesized Answer: {synthesized_answer_b}")
    print("---")

    if not isinstance(multi_traces_list_b, list):
        print("Expected a list of traces as the second element of the tuple, but received something else.")
        print(f"Received for traces list: {multi_traces_list_b}")
    else:
        # (label shown, trace_info key) pairs reported for every trace.
        per_trace_fields = (
            ("Question Index", "question_idx"),
            ("Question Text", "question_text"),
            ("Answer", "answer"),
            ("Reward", "reward"),
            ("EM", "em"),
            ("F1", "f1"),
            ("Num Calls", "n_calls"),
        )
        for i, trace_info in enumerate(multi_traces_list_b):
            print(f"Trace {i+1} Summary:")
            for label, key in per_trace_fields:
                print(f"  {label}: {trace_info.get(key)}")
            print("  ---")