# Setup

In [1]:
import os

# Move from the notebook's folder up to the project root so that the relative
# paths used later in the notebook (e.g. './prompts/') resolve.
# Guarded so the cell is idempotent: once the 'prompts' folder is visible from
# the working directory, re-running this cell no longer climbs another level.
if not os.path.isdir('prompts'):
    os.chdir('..')
print("Changed working directory to:", os.getcwd())

Changed working directory to: c:\Users\rishi\Documents\summer-research\react-research


In [3]:
import os
from google import genai
from google.genai import types
# Shared Gemini client used by llm() for every request.
# No key is passed explicitly: the client gets the API key from the environment
# variable `GEMINI_API_KEY`, so that variable must be set before this cell runs.
client = genai.Client()

import time

def llm(prompt, stop=["\n"], num_traces=1):
  """Send `prompt` to the Gemini model and return the generated text.

  Args:
    prompt: Prompt text passed to the API as the request contents.
    stop: Stop sequences forwarded to the API unchanged. The default list is
      only read here, never mutated.
    num_traces: Number of traces the caller is sampling. With exactly one
      trace the temperature is 0.0; otherwise it is 0.7.

  Returns:
    The generated text as a string. An empty string is returned when the
    response carries no text, so callers can always call string methods on
    the result.
  """
  # Fixed pause before every request to stay under a 15 requests/minute quota.
  time.sleep(4.1)

  temperature_setting = 0.0 if num_traces == 1 else 0.7
  response = client.models.generate_content(
    model="gemini-2.5-flash-lite-preview-06-17", # Or other appropriate Gemini model
    contents=prompt,
    config=types.GenerateContentConfig(
        thinking_config=types.ThinkingConfig(thinking_budget=0), # Disables thinking
        stop_sequences=stop,
        temperature=temperature_setting,
        max_output_tokens=100, # As in original FEVER
        top_p=1.0
    )
  )
  # `response.text` may be None when the reply contains no text part;
  # normalise to "" so downstream `.strip()` / `.split()` calls do not crash.
  return response.text or ""

In [9]:
import re

def extract_final_answer_from_trace_string(trace_trajectory_string):
    """
    Returns the label of the last well-formed Finish action in a trace.

    A line counts only if it is exactly 'Action <number>: Finish[<label>]'
    (trailing whitespace allowed) and <label> is one of 'SUPPORTS', 'REFUTES'
    or 'NOT ENOUGH INFO'. Returns None when no such line exists.
    """
    finish_line = re.compile(
        r"^Action \d+: Finish\[(SUPPORTS|REFUTES|NOT ENOUGH INFO)\]\s*$",
        re.MULTILINE,
    )

    # Walk every match in order and keep overwriting, so the value that
    # survives the loop belongs to the final Finish line in the string.
    final_label = None
    for hit in finish_line.finditer(trace_trajectory_string):
        final_label = hit.group(1).strip()
    return final_label

def extract_answers_from_traces(all_traces_info):
    """
    Collects one FEVER label per trace from a list of trace-info dicts.

    For each trace the label comes from the last valid
    'Action N: Finish[...]' line found in its 'traj' string. If there is none,
    the trace's 'answer' field is used, provided it is a valid FEVER label.
    Traces that yield neither are left out of the result.

    Only the part of 'traj' starting at the trace's 'question_text' is
    searched. 'traj' begins with the prompt template, and any Finish[...]
    lines inside that template must not be mistaken for this trace's answer.
    When 'question_text' is missing or not found, the whole 'traj' is searched.

    Returns a list of labels (possibly empty). A non-list argument produces a
    printed warning and an empty list; non-dict entries are skipped.
    """
    valid_labels = ('SUPPORTS', 'REFUTES', 'NOT ENOUGH INFO')
    extracted_answers = []
    if not isinstance(all_traces_info, list):
        print(f"Warning: extract_answers_from_traces expected a list, got {type(all_traces_info)}")
        return extracted_answers

    for trace_info in all_traces_info:
        if not isinstance(trace_info, dict):
            # Nothing to read from a malformed entry; treat it as "no answer".
            continue

        trajectory = trace_info.get('traj')
        if not isinstance(trajectory, str):
            trajectory = ''

        # Drop everything before the question so that only the agent's own
        # steps are scanned, not the examples in the prompt template.
        question_text = trace_info.get('question_text')
        if isinstance(question_text, str) and question_text:
            question_start = trajectory.rfind(question_text)
            if question_start != -1:
                trajectory = trajectory[question_start:]

        answer = extract_final_answer_from_trace_string(trajectory)
        if answer is None:
            # Fallback to the 'answer' field when the trajectory has no valid
            # Finish line (agent stopped early or used an unexpected format).
            env_answer = trace_info.get('answer')
            if env_answer in valid_labels:
                answer = env_answer

        if answer is not None:
            extracted_answers.append(answer)

    return extracted_answers

In [10]:
def synthesize_answer_with_llm(list_of_answers, question_for_context=""):
    """
    Reduces a list of per-trace answers to a single FEVER label by voting.

    Despite the name, no LLM is called and `question_for_context` is not used.
    Entries that are not one of 'SUPPORTS', 'REFUTES', 'NOT ENOUGH INFO' are
    ignored. 'SUPPORTS' or 'REFUTES' is returned only when it has strictly
    more votes than each of the other two labels; every other outcome (empty
    input, no valid votes, ties, 'NOT ENOUGH INFO' leading) returns
    'NOT ENOUGH INFO'.
    """
    fallback = "NOT ENOUGH INFO"
    if not list_of_answers:
        return fallback

    tally = {'SUPPORTS': 0, 'REFUTES': 0, 'NOT ENOUGH INFO': 0}
    for vote in list_of_answers:
        if vote in tally:
            tally[vote] += 1

    # No recognised label among the votes.
    if not any(tally.values()):
        return fallback

    # A decisive label must beat both of its rivals outright.
    contenders = (
        ('SUPPORTS', ('REFUTES', 'NOT ENOUGH INFO')),
        ('REFUTES', ('SUPPORTS', 'NOT ENOUGH INFO')),
    )
    for label, rivals in contenders:
        if all(tally[label] > tally[rival] for rival in rivals):
            return label

    return fallback

In [11]:
import wikienv, wrappers
import requests # Make sure requests is imported
# Build the environment stack in a single expression, innermost first:
# the base WikiEnv, wrapped by FeverWrapper on the "dev" split, wrapped by
# LoggingWrapper. `env` is the outermost wrapper and is what the rest of the
# notebook resets and steps.
env = wrappers.LoggingWrapper(
    wrappers.FeverWrapper(wikienv.WikiEnv(), split="dev")
)

def step(env, action):
    """
    Calls `env.step(action)`, retrying when the request times out.

    Up to 10 attempts are made; only `requests.exceptions.Timeout` triggers a
    retry, any other exception propagates. Returns whatever `env.step`
    returns on the first successful attempt. If every attempt times out, a
    placeholder tuple (message, 0, False, {}) is returned instead.
    """
    for _ in range(10):
        try:
            return env.step(action)
        except requests.exceptions.Timeout:
            # Try again immediately; there is no delay between attempts.
            continue
    # Fallback if all attempts fail (though env.step should handle this internally too)
    return "Timeout after 10 attempts", 0, False, {}

# ReAct

In [12]:
import json
import sys

folder = './prompts/'
prompt_file = 'fever.json' # Ensure correct prompt file for FEVER
# Decode explicitly as UTF-8. Without `encoding`, open() falls back to the
# platform's locale encoding, so the same file could load differently (or fail
# to load) from one machine to another.
with open(folder + prompt_file, 'r', encoding='utf-8') as f:
    prompt_dict = json.load(f)

# Prompt prefix placed in front of every question by webthink().
# Assuming 'webthink_simple3' is the key for FEVER in fever.json as in the original FEVER notebook
webthink_prompt_template = prompt_dict['webthink_simple3']

def webthink(idx=None, initial_prompt_template=webthink_prompt_template, to_print=True, num_traces=1):
    """
    Runs the ReAct Thought/Action/Observation loop on one question, one or
    more times, and returns a reward together with an info dict.

    Args:
        idx: Index passed to `env.reset` to select the question.
        initial_prompt_template: Text placed in front of the question to form
            the prompt.
        to_print: When True, progress and every step are printed.
        num_traces: How many independent traces to run. Also forwarded to
            `llm` on every call.

    Returns:
        A 2-tuple whose shape depends on `num_traces`:
        - num_traces <= 0: ("[INVALID_NUM_TRACES]", dict with 'error' and
          'traces').
        - num_traces == 1: (the trace's 'em' value or 0.0, that trace's info
          dict, which includes 'traj', 'n_calls', 'n_badcalls',
          'question_idx', 'question_text' and 'trace_num').
        - num_traces > 1: (em_score, summary dict). The summary's 'answer' is
          the vote over the per-trace answers, 'em' / 'f1' / 'reward' compare
          it to the first trace's 'gt_answer', and 'individual_traces' holds
          every per-trace info dict.
    """
    all_traces_info = []
    question_for_synthesis = "" # Define outside loop to store it

    if num_traces <= 0:
        if to_print:
            print(f"Warning: webthink called with num_traces = {num_traces}. Must be > 0.")
        return "[INVALID_NUM_TRACES]", {'error': 'num_traces must be > 0', 'traces': []}

    for trace_num in range(num_traces):
        question = env.reset(idx=idx) # Reset environment for each trace
        if trace_num == 0: # Capture question on first trace for synthesizer
            question_for_synthesis = question

        current_prompt = initial_prompt_template + question + "\n"

        if to_print:
            print(f"--- Trace {trace_num + 1}/{num_traces} ---")
            print(idx, question)

        n_calls, n_badcalls = 0, 0
        current_trace_steps = []

        for i in range(1, 8): # Max 7 steps per trace
            n_calls += 1
            # Pass num_traces to llm for temperature adjustment.
            # `or ""` guards against a reply with no text (None).
            thought_action = llm(current_prompt + f"Thought {i}:", stop=[f"\nObservation {i}:"], num_traces=num_traces) or ""
            try:
                thought, action = thought_action.strip().split(f"\nAction {i}: ")
            except ValueError:
                # The reply did not split into exactly one thought and one
                # action. Keep the first line as the thought and ask the model
                # again for just the action.
                if to_print:
                    print('Error parsing thought/action:', thought_action)
                n_badcalls += 1
                n_calls += 1
                thought = thought_action.strip().split('\n')[0]
                action = (llm(current_prompt + f"Thought {i}: {thought}\nAction {i}:", stop=["\n"], num_traces=num_traces) or "").strip()

            # Lower-case only the first character (e.g. "Search[..]" ->
            # "search[..]"). Slicing instead of indexing keeps an empty action
            # from raising IndexError; it is passed to the env as "".
            obs, r, done, info = step(env, action[:1].lower() + action[1:])
            obs = obs.replace('\\n', '')

            step_str = f"Thought {i}: {thought}\nAction {i}: {action}\nObservation {i}: {obs}\n"
            current_prompt += step_str
            current_trace_steps.append(step_str)

            if to_print:
                print(step_str)

            if done:
                break

        # Ensure 'info' is a dictionary, initialize if it's not (e.g. from timeout in step)
        if not isinstance(info, dict):
            info = {}

        if not done:
            # The agent used all its steps without finishing: force a finish
            # action with 'NOT ENOUGH INFO' so the env produces a final state.
            # Note that this forced action is not added to the trajectory.
            obs, r, done, info_finish = step(env, "finish[NOT ENOUGH INFO]")
            info.update(info_finish) # Update with info from finish step
            if 'answer' not in info or not info['answer']:
                info['answer'] = 'NOT ENOUGH INFO' # Default if not set by env
            if to_print:
                print(f"Agent did not finish. Forced Finish: {info.get('answer')}")

        trace_info_package = info.copy()
        trace_info_package.update({'n_calls': n_calls,
                           'n_badcalls': n_badcalls,
                           'traj': initial_prompt_template + question + "\n" + "".join(current_trace_steps),
                           'question_idx': idx,
                           'question_text': question,
                           'trace_num': trace_num + 1})
        all_traces_info.append(trace_info_package)

        if to_print:
            print(f"(Trace {trace_num + 1}) Info: {trace_info_package}\n")
            if num_traces > 1 and trace_num < num_traces - 1:
                print(f"--- End of Trace {trace_num + 1} ---\n")

    # Defensive only: every loop iteration above appends one entry, so this
    # list is non-empty whenever num_traces > 0.
    if not all_traces_info:
        if to_print:
            print("Warning: No traces were generated despite num_traces > 0.")
        # Return structure consistent with FEVER's original: (reward, info_dict)
        return 0, {'question_idx': idx, 'answer': '[NO_TRACE_GENERATED]', 'em':0, 'f1':0, 'reward':0, 'traces': []}

    if num_traces == 1:
        final_info_package = all_traces_info[0]
        # Single trace: the reward is that trace's own 'em' value.
        final_reward = final_info_package.get('em', 0.0)
        return final_reward, final_info_package

    else: # num_traces > 1
        if to_print:
            print("\n--- Starting Answer Synthesis for FEVER ---")

        extracted_answers = extract_answers_from_traces(all_traces_info)

        if to_print:
            print(f"Extracted Answers for Synthesis: {extracted_answers}")

        if not extracted_answers:
            if to_print:
                print("Warning: No answers extracted from traces. Defaulting to NOT ENOUGH INFO.")
            synthesized_answer = "NOT ENOUGH INFO"
        else:
            synthesized_answer = synthesize_answer_with_llm(extracted_answers, question_for_synthesis)

        if to_print:
            print(f"Synthesized Answer: {synthesized_answer}")
            print("--- End of Answer Synthesis ---\n")

        # Score the synthesized answer against the ground truth recorded in the
        # first trace's info (all traces were run on the same idx).
        gt_answer = all_traces_info[0].get('gt_answer', 'UNKNOWN_GT_ANSWER')
        em_score = 1.0 if synthesized_answer == gt_answer else 0.0

        # Aggregate calls from all traces
        total_calls = sum(t.get('n_calls', 0) for t in all_traces_info)
        total_badcalls = sum(t.get('n_badcalls', 0) for t in all_traces_info)

        # Create a final info package for the synthesized result
        synthesized_info = {
            'question_idx': idx,
            'question_text': question_for_synthesis,
            'answer': synthesized_answer, # The synthesized answer
            'gt_answer': gt_answer,
            'em': em_score,
            'f1': em_score, # Reported equal to 'em' for the synthesized label
            'reward': em_score,
            'n_calls': total_calls,
            'n_badcalls': total_badcalls,
            'num_traces_run': num_traces,
            'individual_traces': all_traces_info # List of all trace details
        }
        return em_score, synthesized_info

In [13]:
# Demo cell: run the standard ReAct agent (num_traces=1) on one FEVER example
# and print its trace, a short summary, and the full trajectory.
import random
import time

print("--- Running standard ReAct (num_traces=1) for one FEVER example ---")
# Candidate indices, shuffled with a fixed seed so the same example is picked
# on every run.
# NOTE(review): confirm 7405 matches the size of the FEVER dev split the env loads.
idxs = list(range(7405))
random.Random(233).shuffle(idxs)

if not idxs:
    print("No examples found for FEVER dev split.")
else:
    example_idx = idxs[0]
    print(f"Using FEVER example with index: {example_idx}\n")
    # Call webthink with num_traces=1 for standard ReAct
    reward_single, info_single = webthink(idx=example_idx, to_print=False, num_traces=1)

    # The stored trajectory starts with the prompt template; keep only the
    # part from the current question onwards (or all of it if not found).
    full_traj = (info_single.get('traj') or '').strip()
    question_text = info_single.get('question_text') or ''
    traj_start = full_traj.find(question_text)
    traj_for_example = full_traj if traj_start == -1 else full_traj[traj_start:]

    # Print only the Thought/Action/Observation lines of this trace
    print("--- LLM Trace (Thought/Action/Observation) ---")
    step_prefixes = ("Thought", "Action", "Observation")
    for line in traj_for_example.splitlines():
        if line.strip().startswith(step_prefixes):
            print(line)

    print("\n--- Standard ReAct (num_traces=1) Summary ---")
    print(f"Question Index: {info_single.get('question_idx')}")
    print(f"Question/Claim (asked to LLM):\n{info_single.get('question_text')}")
    print(f"Agent's Answer: {info_single.get('answer')}")
    print(f"Ground Truth: {info_single.get('gt_answer')}")
    print(f"EM Score (Reward): {info_single.get('em')}")
    print(f"F1 Score: {info_single.get('f1')}")
    print(f"Number of LLM Calls: {info_single.get('n_calls')}")
    print("\nTrajectory for this example:\n" + traj_for_example)

--- Running standard ReAct (num_traces=1) for one FEVER example ---
Using FEVER example with index: 3687

--- LLM Trace (Thought/Action/Observation) ---
Thought 1: I need to search for Paramore and find out where they are from.
Action 1: Search[Paramore]
Observation 1: Paramore is an American rock band formed in Franklin, Tennessee, in 2004. Since 2017, the band's lineup includes lead vocalist Hayley Williams, lead guitarist Taylor York, and drummer Zac Farro. Williams and Farro are founding members of the group, while York, a high school friend of the original lineup, joined in 2007. The band has had multiple lineup changes, with Williams being the only constant member.. The band was signed to Fueled by Ramen, a subsidiary of Atlantic Records (which is owned by Warner Music Group.) Williams was signed to Atlantic separately, as she was scouted when she was a teenager.
Thought 2: The observation states that Paramore was "formed in Franklin, Tennessee". This directly contradicts the cla

In [55]:
# Demo cell: run the multi-trace ReAct agent (num_traces=3) on one FEVER
# example and print each trace, the synthesized result, and all trajectories.
# Relies on `idxs` already defined (and shuffled) by the single-trace demo
# cell, so the same example is used; it is deliberately not rebuilt here.
import random
import time


def _trajectory_from_question(trace_detail):
    """Return the trace's stripped 'traj' from its question text onwards.

    The whole stripped trajectory is returned when the question text is not
    found in it.
    """
    traj = (trace_detail.get('traj') or '').strip()
    question_text = trace_detail.get('question_text') or ''
    start = traj.find(question_text)
    return traj if start == -1 else traj[start:]


print("\n--- Running updated ReAct (num_traces=3) for one FEVER example ---")
if not idxs:
    print("No examples found for FEVER dev split for multi-trace run.")
else:
    example_idx_multi = idxs[0]  # Use the same example as previous cell
    print(f"Using FEVER example with index: {example_idx_multi} for multi-trace run\n")

    # Call webthink with num_traces=3 for updated ReAct
    synthesized_reward, synthesized_info = webthink(idx=example_idx_multi, to_print=False, num_traces=3)
    traces = synthesized_info.get('individual_traces', [])

    # Per-trace view: the Thought/Action/Observation lines, then a summary
    step_prefixes = ("Thought", "Action", "Observation")
    for i, trace_detail in enumerate(traces):
        print(f"--- LLM Trace {i+1} (Thought/Action/Observation) ---")
        traj_for_example = _trajectory_from_question(trace_detail)
        for line in traj_for_example.splitlines():
            if line.strip().startswith(step_prefixes):
                print(line)
        print(f"\nTrace {i+1} Summary:")
        print(f"  Question Index: {trace_detail.get('question_idx')}")
        print(f"  Question/Claim (asked to LLM):\n{trace_detail.get('question_text')}")
        print(f"  Trace Answer: {trace_detail.get('answer')}")
        print(f"  Ground Truth: {trace_detail.get('gt_answer')}")
        print(f"  Trace EM: {trace_detail.get('em')}")
        print(f"  Trace F1: {trace_detail.get('f1')}")
        print(f"  LLM Calls: {trace_detail.get('n_calls')}")
        print()

    # Result after combining the traces
    print("--- Final Synthesized Summary (across all traces) ---")
    print(f"Question Index: {synthesized_info.get('question_idx')}")
    print(f"Question/Claim (asked to LLM):\n{synthesized_info.get('question_text')}")
    print(f"Synthesized Answer: {synthesized_info.get('answer')}")
    print(f"Ground Truth: {synthesized_info.get('gt_answer')}")
    print(f"Synthesized EM Score (Reward): {synthesized_info.get('em')}")
    print(f"Synthesized F1 Score: {synthesized_info.get('f1')}")
    print(f"Total Number of LLM Calls (all traces): {synthesized_info.get('n_calls')}")
    print(f"Number of Traces Run: {synthesized_info.get('num_traces_run')}")

    # Full trajectory text of every trace, from the question onwards
    print("\n--- Trajectories for Each Trace ---")
    for i, trace_detail in enumerate(traces):
        traj_for_example = _trajectory_from_question(trace_detail)
        print(f"\nTrajectory for Trace {i+1}:\n" + traj_for_example)


--- Running updated ReAct (num_traces=3) for one FEVER example ---
Using FEVER example with index: 3687 for multi-trace run

--- LLM Trace 1 (Thought/Action/Observation) ---
Thought 1: I need to search for Paramore and find out where they are from.
Action 1: Search[Paramore]
Observation 1: Paramore is an American rock band formed in Franklin, Tennessee, in 2004. Since 2017, the band's lineup includes lead vocalist Hayley Williams, lead guitarist Taylor York, and drummer Zac Farro. Williams and Farro are founding members of the group, while York, a high school friend of the original lineup, joined in 2007. The band has had multiple lineup changes, with Williams being the only constant member.. The band was signed to Fueled by Ramen, a subsidiary of Atlantic Records (which is owned by Warner Music Group.) Williams was signed to Atlantic separately, as she was scouted when she was a teenager.
Thought 2: The observation states that Paramore was "formed in Franklin, Tennessee". This direct