# Setup

In [None]:
# import os
# from google import genai
# from google.genai import types
# # The client gets the API key from the environment variable `GEMINI_API_KEY`.
# client = genai.Client()
# response = client.models.generate_content(
#     model="gemini-2.5-flash-lite-preview-06-17", 
#     config=types.GenerateContentConfig(
#         thinking_config=types.ThinkingConfig(thinking_budget=0) # Disables thinking
#     ),
#     contents="Explain how AI works in a few words"
# )
# print(response.text)

In [None]:
import os
from google import genai
from google.genai import types
# Module-level Gemini client shared by llm() below.
# The client gets the API key from the environment variable `GEMINI_API_KEY`.
client = genai.Client()

def llm(prompt, stop=None, num_traces=1, model="gemini-2.5-flash-lite-preview-06-17",
        max_output_tokens=100):
    """Return one Gemini completion for ``prompt`` as text.

    Args:
        prompt: Full prompt text sent as the request contents.
        stop: Stop sequences passed to the API. Defaults to a single newline,
            i.e. the completion ends at the first line break.
        num_traces: Number of traces the caller is sampling. With exactly one
            trace, temperature 0.0 is used; with several, 0.7, so the
            sampled traces can differ from each other.
        model: Gemini model name.
        max_output_tokens: Upper bound on generated tokens.

    Returns:
        The generated text. When the response carries no text
        (``response.text`` is ``None``), ``""`` is returned instead so that
        callers can always ``.strip()`` / ``.split()`` the result.
    """
    # None sentinel instead of a mutable list default shared across calls.
    if stop is None:
        stop = ["\n"]
    temperature_setting = 0.0 if num_traces == 1 else 0.7
    response = client.models.generate_content(
        model=model,
        contents=prompt,
        config=types.GenerateContentConfig(
            thinking_config=types.ThinkingConfig(thinking_budget=0),  # Disables thinking
            stop_sequences=stop,
            temperature=temperature_setting,
            max_output_tokens=max_output_tokens,
            top_p=1.0,
        ),
    )
    return response.text or ""

In [None]:
import wikienv, wrappers
import requests

# Environment stack, innermost first: the Wikipedia environment, the HotPotQA
# question wrapper on the "dev" split, then a logging wrapper.
# NOTE(review): the exact reset/step semantics live in wikienv.py / wrappers.py,
# which are not shown here; webthink() relies on env.reset(idx=...) returning
# the question text and env.step(...) returning (obs, reward, done, info).
env = wikienv.WikiEnv()
env = wrappers.HotPotQAWrapper(env, split="dev")
env = wrappers.LoggingWrapper(env)

def step(env, action):
    attempts = 0
    while attempts < 10:
        try:
            return env.step(action)
        except requests.exceptions.Timeout:
            attempts += 1

# ReAct

In [None]:
import json
import sys

# Load the few-shot ReAct examples and build the prompt prefix used by webthink().
folder = './prompts/'
prompt_file = 'prompts_naive.json'
# Explicit UTF-8 so loading does not depend on the platform's default locale encoding.
with open(os.path.join(folder, prompt_file), 'r', encoding='utf-8') as f:
    prompt_dict = json.load(f)

webthink_examples = prompt_dict['webthink_simple6']
instruction = """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types: 
(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.
(2) Lookup[keyword], which returns the next sentence containing keyword in the current passage.
(3) Finish[answer], which returns the answer and finishes the task.
Here are some examples.
"""
# Instruction + few-shot examples; webthink() appends each question to this prefix.
webthink_prompt_template = instruction + webthink_examples

def _think_and_act(prompt, i, num_traces):
    """Ask the LLM for the Thought/Action pair of step ``i``.

    Returns ``(thought, action, n_calls, n_badcalls)``. ``n_calls`` is the
    number of LLM calls this step used: 1, or 2 when the action had to be
    re-requested. ``n_badcalls`` is 1 if the first reply was malformed, else 0.
    """
    reply = llm(prompt + f"Thought {i}:", stop=[f"\nObservation {i}:"], num_traces=num_traces)
    # Treat a missing (None) completion as empty text instead of crashing on .strip().
    reply = (reply or "").strip()
    parts = reply.split(f"\nAction {i}: ")
    if len(parts) == 2:
        thought, action = parts
        return thought, action, 1, 0
    # Malformed reply (no "Action i:" line, or more than one): keep its first
    # line as the thought and ask the LLM for the action on its own.
    thought = reply.split('\n')[0]
    action = llm(prompt + f"Thought {i}: {thought}\nAction {i}:", stop=["\n"], num_traces=num_traces)
    return thought, (action or "").strip(), 2, 1


def _run_webthink_trace(idx, prompt_template, to_print, num_traces, trace_num, max_steps):
    """Run one ReAct trajectory on question ``idx`` and return its info dict."""
    question = env.reset(idx=idx)  # Reset environment for each trace
    # Fresh prompt for each trace, incorporating the question.
    current_prompt = prompt_template + question + "\n"

    if to_print:
        print(f"--- Trace {trace_num + 1}/{num_traces} ---")
        print(idx, question)

    n_calls, n_badcalls = 0, 0
    done, info = False, {}
    for i in range(1, max_steps + 1):
        thought, action, calls, bad = _think_and_act(current_prompt, i, num_traces)
        n_calls += calls
        n_badcalls += bad

        # Lower-case only the first character (the action verb, cf. "finish[]"
        # below); slicing keeps an empty action from raising IndexError.
        obs, _, done, info = step(env, action[:1].lower() + action[1:])
        obs = obs.replace('\\n', '')

        step_str = f"Thought {i}: {thought}\nAction {i}: {action}\nObservation {i}: {obs}\n"
        current_prompt += step_str
        if to_print:
            print(step_str)
        if done:
            break

    if not done:
        # Step budget exhausted: submit an empty Finish action to end the task.
        obs, _, done, info = step(env, "finish[]")
        # NOTE(review): the 'traj' check presumably detects whether the wrapper
        # returned a full episode record; otherwise the info is discarded.
        if 'traj' not in info:
            info = {}
        info.update({'finish_action_obs': obs})

    trace_info = info.copy()
    trace_info.update({'n_calls': n_calls,
                       'n_badcalls': n_badcalls,
                       'traj': current_prompt,
                       'question_idx': idx,
                       'trace_num': trace_num + 1})

    if to_print:
        print(f"(Trace {trace_num + 1}) Info: {trace_info}\n")
        if num_traces > 1 and trace_num < num_traces - 1:
            print(f"--- End of Trace {trace_num + 1} ---\n")
    return trace_info


def webthink(idx=None, initial_prompt_template=webthink_prompt_template, to_print=True, num_traces=1,
             max_steps=7):
    """Answer question ``idx`` with ReAct (interleaved Thought/Action/Observation).

    Args:
        idx: Question index passed to ``env.reset``.
        initial_prompt_template: Few-shot prompt prefix; the question text is
            appended to it for every trace.
        to_print: Print each step and the per-trace info.
        num_traces: Number of independent trajectories to run; forwarded to
            ``llm`` (which uses temperature 0.0 for one trace, 0.7 otherwise).
        max_steps: Thought/Action steps per trace before ``finish[]`` is forced.

    Returns:
        ``(reward, info)`` when ``num_traces == 1``, with reward taken from
        ``info['reward']`` (0.0 if absent). Otherwise, a list with one info
        dict per trace, which is empty when ``num_traces`` is 0.
    """
    all_traces_info = []
    for trace_num in range(num_traces):
        all_traces_info.append(
            _run_webthink_trace(idx, initial_prompt_template, to_print, num_traces, trace_num, max_steps)
        )

    if num_traces == 1 and all_traces_info:
        final_r = all_traces_info[0].get('reward', 0.0)
        return final_r, all_traces_info[0]

    return all_traces_info

In [None]:
# Demo: sample three traces for the first question and summarise each of them.
idxs = list(range(7405))
traces_output = webthink(idx=idxs[0], to_print=True, num_traces=3)
if not isinstance(traces_output, list):
    # webthink only returns a (reward, info) pair, not a list, when num_traces == 1.
    r_val, info_val = traces_output
    print("Information for Single Trace (fallback):")
    print(f"  Answer: {info_val.get('answer')}")
    print(f"  Reward: {info_val.get('reward')}")
    print(f"  Trajectory Snippet: {info_val.get('traj', '')[:200]}...")
else:
    for trace_no, trace_info in enumerate(traces_output, start=1):
        print(f"Information for Trace {trace_no}:")
        print(f"  Answer: {trace_info.get('answer')}")
        print(f"  Reward: {trace_info.get('reward')}")
        print(f"  Trajectory Snippet: {trace_info.get('traj', '')[:200]}...")  # first 200 chars only
        print("---")
# Single-trace usage returns (reward, info), e.g.:
#   _, single_trace_info = webthink(idx=idxs[1], to_print=True, num_traces=1)

In [None]:
# Benchmark plain ReAct (one trace per question) with Gemini Flash Lite.
import random
import time

n_eval = 500  # number of dev questions to evaluate
idxs = list(range(7405))
random.Random(233).shuffle(idxs)  # fixed seed -> same question subset on every run
rs = []
infos = []
old_time = time.perf_counter()  # monotonic clock for elapsed-time measurement
for i in idxs[:n_eval]:
    # num_traces=1 makes webthink return a (reward, info) pair.
    r_val, info_val = webthink(idx=i, to_print=False, num_traces=1)
    # A trace whose final info has no 'em' score (e.g. the forced finish[] path
    # reset info to {}) counts as a miss instead of aborting the whole run.
    rs.append(info_val.get('em', 0))
    infos.append(info_val)
    # Running total, count, exact-match rate, mean seconds per question.
    print(sum(rs), len(rs), sum(rs) / len(rs), (time.perf_counter() - old_time) / len(rs))
    print('-----------')
    print()

In [None]:
# Cell A: run a single trace for idxs[0] and print a short summary.
print("--- Running webthink with num_traces=1 for idxs[0] ---")
if 'idxs' not in globals():
    # Fallback so the cell can run on its own; normally idxs comes from an earlier cell.
    print("WARNING: 'idxs' not found globally. Assuming it will be defined by a preceding cell.")
    print("For standalone execution of this cell, ensure 'idxs', 'env', 'llm', and 'webthink' are defined.")
    idxs = list(range(7405))
_required_names = ('webthink', 'env', 'llm', 'idxs')
if not all(name in globals() for name in _required_names):
    print("ERROR: One or more required components (idxs, webthink, env, llm) are not defined.")
    print("Please ensure all preceding setup cells are executed.")
else:
    print(f"Using idxs[0] which is: {idxs[0]}")
    r_single, info_single = webthink(idx=idxs[0], to_print=True, num_traces=1)
    print("\n--- Single Trace (idxs[0]) Summary ---")
    print(f"Reward: {r_single}")
    for label, key in (("Answer", 'answer'), ("EM", 'em'), ("F1", 'f1'), ("Num Calls", 'n_calls')):
        print(f"{label}: {info_single.get(key)}")

In [None]:
# Cell B: run three traces for idxs[0] and print a summary of each one.
print("\n--- Running webthink with num_traces=3 for idxs[0] ---")
if 'idxs' not in globals():
    # Fallback so the cell can run on its own; normally idxs comes from an earlier cell.
    print("WARNING: 'idxs' not found globally. Assuming it will be defined by a preceding cell.")
    print("For standalone execution of this cell, ensure 'idxs', 'env', 'llm', and 'webthink' are defined.")
    idxs = list(range(7405))
_required_names = ('webthink', 'env', 'llm', 'idxs')
if not all(name in globals() for name in _required_names):
    print("ERROR: One or more required components (idxs, webthink, env, llm) are not defined.")
    print("Please ensure all preceding setup cells are executed.")
else:
    print(f"Using idxs[0] which is: {idxs[0]}")
    multi_traces_info = webthink(idx=idxs[0], to_print=True, num_traces=3)
    print("\n--- Multi-Trace (idxs[0]) Summary ---")
    if not isinstance(multi_traces_info, list):
        # Guard against webthink's (reward, info) single-trace return shape.
        print("Expected a list of traces for num_traces > 1, but received a single info object.")
        print(f"Received: {multi_traces_info}")
    else:
        for trace_no, trace_info in enumerate(multi_traces_info, start=1):
            print(f"Trace {trace_no} Summary:")
            print(f"  Answer: {trace_info.get('answer')}")
            print(f"  Reward: {trace_info.get('reward')}")
            print(f"  Num Calls: {trace_info.get('n_calls')}")
            print("  ---")