Commit f29fda2

Add QA support w/ SimpleQA,GPQA
1 parent 770e9da

13 files changed: +629 additions, -24 deletions


experiments/eval/run.py

Lines changed: 26 additions & 16 deletions

@@ -6,6 +6,7 @@
 from typing import Optional, Dict, Any, Callable
 from magentic_ui.eval.core import run_evaluate_benchmark_func, evaluate_benchmark_func
 from systems.magentic_ui_sim_user_system import MagenticUISimUserSystem
+from systems.llm_system import LLMSystem
 from magentic_ui.eval.benchmarks import WebVoyagerBenchmark
 from magentic_ui.eval.benchmark import Benchmark
 from autogen_core.models import ChatCompletionClient
@@ -157,19 +158,27 @@ def run_system_sim_user(args: argparse.Namespace, system_name: str) -> None:
     """
     config = load_config(args.config)

-    system = MagenticUISimUserSystem(
-        simulated_user_type=args.simulated_user_type,
-        endpoint_config_orch=config.get("orchestrator_client") if config else None,
-        endpoint_config_websurfer=config.get("web_surfer_client") if config else None,
-        endpoint_config_coder=config.get("coder_client") if config else None,
-        endpoint_config_file_surfer=config.get("file_surfer_client")
-        if config
-        else None,
-        endpoint_config_user_proxy=config.get("user_proxy_client") if config else None,
-        web_surfer_only=args.web_surfer_only,
-        how_helpful_user_proxy=args.how_helpful_user_proxy,
-        dataset_name=args.dataset,
-    )
+    if system_name == "LLM":
+        # Use LLMSystem for LLM-based evaluations
+        system = LLMSystem(
+            system_name=system_name,
+            endpoint_config=config.get("model_config") if config else None,
+            dataset_name=args.dataset,
+        )
+    else:
+        system = MagenticUISimUserSystem(
+            simulated_user_type=args.simulated_user_type,
+            endpoint_config_orch=config.get("orchestrator_client") if config else None,
+            endpoint_config_websurfer=config.get("web_surfer_client") if config else None,
+            endpoint_config_coder=config.get("coder_client") if config else None,
+            endpoint_config_file_surfer=config.get("file_surfer_client")
+            if config
+            else None,
+            endpoint_config_user_proxy=config.get("user_proxy_client") if config else None,
+            web_surfer_only=args.web_surfer_only,
+            how_helpful_user_proxy=args.how_helpful_user_proxy,
+            dataset_name=args.dataset,
+        )

     run_system_evaluation(args, system, system_name, config)

@@ -229,8 +238,8 @@ def main() -> None:
     parser.add_argument(
         "--system-type",
         type=str,
-        default="magentic-ui",
-        choices=["magentic-ui", "magentic-ui-sim-user"],
+        default="MagenticUI",
+        choices=["MagenticUI", "magentic-ui-sim-user", "LLM"],
         help="Type of system to run",
     )
     parser.add_argument(
@@ -250,7 +259,8 @@ def main() -> None:

     # Determine system name based on arguments

-    system_name = "MagenticUI"
+    system_name = args.system_type
+
     if args.simulated_user_type != "none":
         system_name += f"_{args.simulated_user_type}_{args.how_helpful_user_proxy}"
     if args.web_surfer_only:
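
With the new "LLM" choice wired into --system-type, the script can run a plain single-model baseline: when system_name is "LLM", the updated run_system_sim_user builds an LLMSystem from the config's model_config entry instead of the full multi-agent MagenticUISimUserSystem. A hypothetical invocation (the flag spellings for the dataset and config file are assumptions, not shown in this diff):

python experiments/eval/run.py --system-type LLM --dataset SimpleQA --config config.yaml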

experiments/eval/systems/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,5 +1,6 @@
 from .magentic_ui_sim_user_system import MagenticUISimUserSystem
 from .magentic_ui_system import MagenticUIAutonomousSystem
 from .magentic_one_system import MagenticOneSystem
+from .llm_system import LLMSystem

-__all__ = ["MagenticUISimUserSystem", "MagenticUIAutonomousSystem", "MagententicOneSystem"]
+__all__ = ["MagenticUISimUserSystem", "MagenticUIAutonomousSystem", "MagenticOneSystem", "LLMSystem"]

experiments/eval/systems/llm_system.py

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+import asyncio
+import json
+import os
+from typing import List, Tuple, Dict, Any, Optional, Union
+from autogen_core import ComponentModel
+from autogen_core.models import ChatCompletionClient, SystemMessage, UserMessage
+from magentic_ui.eval.basesystem import BaseSystem
+from magentic_ui.eval.models import BaseQATask, BaseCandidate
+
+
+class LLMSystem(BaseSystem):
+    default_client_config = {
+        "provider": "OpenAIChatCompletionClient",
+        "config": {
+            "model": "gpt-4o-2024-08-06",
+        },
+        "max_retries": 10,
+    }
+
+    def __init__(self, system_name, endpoint_config=default_client_config, dataset_name: str = "SimpleQA"):
+        super().__init__(system_name)
+
+        self.endpoint_config = endpoint_config
+        self.dataset_name = dataset_name
+        self.candidate_class = BaseCandidate
+
+    def get_answer(
+        self, task_id: str, task: BaseQATask, output_dir: str
+    ) -> BaseCandidate:
+        """
+        Queries a single LLM to answer a given task and returns the answer.
+
+        Args:
+            task_id (str): Unique identifier for the task.
+            task (BaseQATask): The task object containing the question and metadata.
+            output_dir (str): Directory to save logs and answer files.
+
+        Returns:
+            BaseCandidate: An object containing the final answer.
+        """
+
+        async def _runner() -> Tuple[str, Any]:
+            """Asynchronous runner to answer the task and return the answer and token usage."""
+            task_question = task.format_to_user_message() if hasattr(task, 'format_to_user_message') else task.question
+            system_instruction = task.system_instruction if hasattr(task, 'system_instruction') else ""
+
+            def get_model_client(
+                endpoint_config: Optional[Union[ComponentModel, Dict[str, Any]]],
+            ) -> ChatCompletionClient:
+                """
+                Loads a ChatCompletionClient from a given endpoint configuration.
+
+                Args:
+                    endpoint_config (Optional[Union[ComponentModel, Dict[str, Any]]]):
+                        The configuration for the model client.
+
+                Returns:
+                    ChatCompletionClient: The loaded model client.
+                """
+                if endpoint_config is None:
+                    return ChatCompletionClient.load_component(
+                        self.default_client_config
+                    )
+                return ChatCompletionClient.load_component(endpoint_config)
+
+            messages = [
+                SystemMessage(content=system_instruction),
+                UserMessage(content=task_question, source="user"),
+            ]
+            client = get_model_client(self.endpoint_config)
+
+            response = await client.create(
+                messages=messages,
+            )
+
+            await client.close()
+
+            answer = response.content
+            usage = response.usage
+
+            return answer, usage
+
+        answer, usage = asyncio.run(_runner())
+        return BaseCandidate(answer=answer)
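
For orientation, a minimal sketch of calling the new system directly, assuming BaseQATask can be constructed with id, question, and answer keyword arguments (those field names, the demo question, and the output path are illustrative, not taken from this commit):

from systems.llm_system import LLMSystem
from magentic_ui.eval.models import BaseQATask

# No endpoint config passed, so the class-level default_client_config (gpt-4o) is used.
system = LLMSystem(system_name="LLM", dataset_name="SimpleQA")
task = BaseQATask(id="demo-1", question="In what year was the transistor invented?", answer="1947")  # hypothetical fields
candidate = system.get_answer(task_id="demo-1", task=task, output_dir="./logs")
print(candidate.answer)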

src/magentic_ui/eval/baseqa.py

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+from .benchmark import Benchmark
+from typing import Union, Optional, Dict
+from .models import AllTaskTypes
+
+
+class BaseQABenchmark(Benchmark):
+    """Base class for Question-Answering benchmarks."""
+
+    def __init__(
+        self,
+        name: str,
+        data_dir: Union[str, None] = None,
+        tasks: Optional[Dict[str, AllTaskTypes]] = None,
+        num_instances: Optional[int] = None,
+    ):
+        super().__init__(name, data_dir, tasks)
+
+        self.num_instances = num_instances
+
+    def get_formatted_question(self, task: AllTaskTypes) -> str:
+        raise NotImplementedError(
+            "Subclasses must implement get_formatted_question method."
+        )
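
Concrete benchmarks subclass this base and override get_formatted_question (GPQABenchmark below also implements download_dataset, load_dataset, get_split_tasks, and evaluator). A toy sketch of the contract (the class and its formatting are hypothetical, not part of this commit; it assumes tasks expose a question attribute):

from magentic_ui.eval.baseqa import BaseQABenchmark
from magentic_ui.eval.models import AllTaskTypes


class TriviaDemoBenchmark(BaseQABenchmark):  # hypothetical example
    def get_formatted_question(self, task: AllTaskTypes) -> str:
        # Prepend a fixed instruction to the task's raw question text.
        return f"Answer concisely.\nQuestion: {task.question}"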

src/magentic_ui/eval/benchmarks/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -5,11 +5,17 @@
 from .bearcubs.bearcubs import BearcubsBenchmark
 from .webgames.webgames import WebGamesBenchmark

+# QA
+from .simpleqa.simpleqa import SimpleQABenchmark
+from .gpqa.gpqa import GPQABenchmark
+
 __all__ = [
     "AssistantBenchBenchmark",
     "CustomBenchmark",
     "GaiaBenchmark",
     "WebVoyagerBenchmark",
     "BearcubsBenchmark",
     "WebGamesBenchmark",
+    "SimpleQABenchmark",
+    "GPQABenchmark",
 ]
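
With the registry updated, the new QA benchmarks import alongside the existing ones. A hypothetical smoke test (the data_dir path is illustrative, and the Idavidrein/gpqa dataset on Hugging Face may require authenticated access):

from magentic_ui.eval.benchmarks import GPQABenchmark

bench = GPQABenchmark(name="gpqa", data_dir="./data/gpqa")
bench.download_dataset()   # snapshot_download of the gpqa_{split}.csv files
bench.load_dataset()
print(len(bench.get_split_tasks("diamond")))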

src/magentic_ui/eval/benchmarks/gpqa/__init__.py

Whitespace-only changes.

src/magentic_ui/eval/benchmarks/gpqa/gpqa.py

Lines changed: 118 additions & 0 deletions

@@ -0,0 +1,118 @@
+"""GPQA benchmark: multiple-choice science QA over the Idavidrein/gpqa dataset splits."""
+
+import re
+import os
+import logging
+import pandas as pd
+from ...baseqa import BaseQABenchmark
+from ...models import (
+    GPQACandidate,
+    GPQATask,
+    GPQAEvalResult,
+    AllTaskTypes,
+)
+from typing import Dict, List, Union, Optional
+
+from huggingface_hub import snapshot_download  # type: ignore
+
+
+class GPQABenchmark(BaseQABenchmark):
+    DATASET_URL = "hf://datasets/Idavidrein/gpqa/"
+    DATASET_REPO_ID = "Idavidrein/gpqa"
+    SPLITS = ["diamond", "extended", "main"]
+    SYSTEM_INSTRUCTION = """You are a helpful assistant that answers questions."""
+
+    def __init__(
+        self,
+        name: str,
+        data_dir: Union[str, None] = None,
+        tasks: Optional[Dict[str, AllTaskTypes]] = None,
+        num_instances: Optional[int] = None,
+        system_instruction: str = SYSTEM_INSTRUCTION,
+    ):
+        super().__init__(name, data_dir, tasks, num_instances)
+
+        self.system_instruction = system_instruction
+
+    def download_dataset(self) -> None:
+        """
+        Download the dataset into self.data_dir using huggingface_hub.snapshot_download().
+        """
+        assert self.data_dir is not None, "data_dir must be provided for GPQABenchmark"
+        if not os.path.exists(self.data_dir):
+            os.makedirs(self.data_dir, exist_ok=True)
+
+        logging.info(f"[GPQABenchmark] Downloading dataset into '{self.data_dir}'...")
+        snapshot_download(
+            repo_id=self.DATASET_REPO_ID,
+            repo_type="dataset",
+            local_dir=self.data_dir,
+            local_dir_use_symlinks=True,
+        )
+        logging.info("[GPQABenchmark] Dataset downloaded.")
+
+    def load_dataset(self) -> None:
+        """
+        Read all the split CSVs from the dataset.
+        """
+        split_paths = {  # type: ignore
+            split: os.path.join(self.data_dir, f"gpqa_{split}.csv")  # type: ignore
+            for split in self.SPLITS
+        }
+
+        for split_name, split_path in split_paths.items():  # type: ignore
+            if not os.path.exists(split_path):  # type: ignore
+                raise FileNotFoundError(f"Dataset file {split_path} does not exist.")
+
+            df = pd.read_csv(split_path)  # type: ignore
+            for _, row in df.iterrows():
+                self.tasks[row["Record ID"]] = GPQATask(  # type: ignore
+                    id=row["Record ID"],  # type: ignore
+                    question=row["Question"],
+                    answer=row["Correct Answer"],  # type: ignore
+                    options=[  # type: ignore
+                        row["Correct Answer"],
+                        row["Incorrect Answer 1"],
+                        row["Incorrect Answer 2"],
+                        row["Incorrect Answer 3"],
+                    ],
+                    set=split_name,
+                    metadata=row.to_dict(),  # type: ignore
+                    system_instruction=self.system_instruction,  # type: ignore
+                )
+
+        logging.info(
+            f"[GPQABenchmark] Loaded {len(self.tasks)} tasks from {self.SPLITS} splits from the dataset."
+        )
+
+    def get_split_tasks(self, split: str) -> List[str]:
+        assert (
+            split in self.SPLITS
+        ), f"Invalid split: {split}. Must be one of {self.SPLITS}."
+        return [task.id for task in self.tasks.values() if task.set == split]
+
+    def evaluator(self, task: GPQATask, candidate: GPQACandidate) -> GPQAEvalResult:  # type: ignore
+        if isinstance(task, Dict):
+            task = GPQATask(**task)  # type: ignore
+        if isinstance(candidate, Dict):
+            candidate = GPQACandidate(**candidate)  # type: ignore
+
+        # Extract a letter choice of the form "Answer: X" (optionally wrapped in $...$).
+        answer_search_by_format = re.search(
+            r"(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?", candidate.answer
+        )
+        extracted_answer = (
+            answer_search_by_format.group(1) if answer_search_by_format else None
+        )
+
+        ground_truth_answer = task.answer  # type: ignore
+        score = ground_truth_answer == extracted_answer  # type: ignore
+        return GPQAEvalResult(  # type: ignore
+            score=score,  # type: ignore
+            metadata={
+                "ground_truth_answer": ground_truth_answer,
+                "extracted_answer": extracted_answer,
+                "llm_response": candidate.answer,
+                "task_id": task.id,
+            },
+        )
src/magentic_ui/eval/benchmarks/simpleqa/__init__.py

Whitespace-only changes.
