Skip to content

Commit d36d712

Browse files
committed
my py
1 parent 5b7817e commit d36d712

File tree

1 file changed

+16
-21
lines changed

1 file changed

+16
-21
lines changed

packages/traceloop-sdk/traceloop/sdk/experiment/experiment.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import asyncio
33
import json
44
import os
5-
from typing import Any, List, Callable, Optional, Tuple, Dict
5+
from typing import Any, List, Callable, Optional, Tuple, Dict, Awaitable, Union
66
from traceloop.sdk.client.http import HTTPClient
77
from traceloop.sdk.datasets.datasets import Datasets
88
from traceloop.sdk.evaluator.evaluator import Evaluator
@@ -36,7 +36,7 @@ def __init__(self, http_client: HTTPClient, async_http_client: httpx.AsyncClient
3636

3737
async def run(
3838
self,
39-
task: Callable[[Optional[Dict[str, Any]]], Dict[str, Any]],
39+
task: Callable[[Optional[Dict[str, Any]]], Awaitable[Dict[str, Any]]],
4040
dataset_slug: Optional[str] = None,
4141
dataset_version: Optional[str] = None,
4242
evaluators: Optional[List[EvaluatorDetails]] = None,
@@ -52,7 +52,7 @@ async def run(
5252
Otherwise, will run the experiment locally.
5353
5454
Args:
55-
task: Function to run on each dataset row
55+
task: Async function to run on each dataset row
5656
dataset_slug: Slug of the dataset to use
5757
dataset_version: Version of the dataset to use
5858
evaluators: List of evaluator slugs to run
@@ -87,11 +87,10 @@ async def run(
8787
stop_on_error=stop_on_error,
8888
wait_for_results=wait_for_results,
8989
)
90-
91-
90+
9291
async def _run_locally(
9392
self,
94-
task: Callable[[Optional[Dict[str, Any]]], Dict[str, Any]],
93+
task: Callable[[Optional[Dict[str, Any]]], Awaitable[Dict[str, Any]]],
9594
dataset_slug: Optional[str] = None,
9695
dataset_version: Optional[str] = None,
9796
evaluators: Optional[List[EvaluatorDetails]] = None,
@@ -106,7 +105,7 @@ async def _run_locally(
106105
107106
Args:
108107
dataset_slug: Slug of the dataset to use
109-
task: Function to run on each dataset row
108+
task: Async function to run on each dataset row
110109
evaluators: List of evaluator slugs to run
111110
experiment_slug: Slug for this experiment run
112111
experiment_metadata: Metadata for this experiment (an experiment holds all the experiment runs)
@@ -160,17 +159,15 @@ async def _run_locally(
160159

161160
async def run_single_row(row: Optional[Dict[str, Any]]) -> TaskResponse:
162161
try:
163-
# TODO: Fix type annotation - task should return Awaitable, not dict
164-
task_result = await task(row) # type: ignore[misc]
165-
# TODO: Fix type - task_input should accept Optional[Dict]
162+
task_result = await task(row)
166163
task_id = self._create_task(
167164
experiment_slug=experiment_slug,
168165
experiment_run_id=run_id,
169-
task_input=row, # type: ignore[arg-type]
166+
task_input=row,
170167
task_output=task_result,
171168
).id
172169

173-
eval_results = {}
170+
eval_results: Dict[str, Union[Dict[str, Any], str]] = {}
174171
if evaluator_details:
175172
for evaluator_slug, evaluator_version in evaluator_details:
176173
try:
@@ -197,13 +194,11 @@ async def run_single_row(row: Optional[Dict[str, Any]]) -> TaskResponse:
197194
input=task_result,
198195
)
199196

200-
# TODO: Fix type - eval_results should accept Union[Dict, str]
201197
msg = f"Triggered execution of {evaluator_slug}"
202-
eval_results[evaluator_slug] = msg # type: ignore[assignment]
198+
eval_results[evaluator_slug] = msg
203199

204200
except Exception as e:
205-
# TODO: Fix type - eval_results should accept Union[Dict, str]
206-
eval_results[evaluator_slug] = f"Error: {str(e)}" # type: ignore[assignment]
201+
eval_results[evaluator_slug] = f"Error: {str(e)}"
207202

208203
return TaskResponse(
209204
task_result=task_result,
@@ -245,7 +240,7 @@ async def run_with_semaphore(row: Optional[Dict[str, Any]]) -> TaskResponse:
245240

246241
async def _run_in_github(
247242
self,
248-
task: Callable[[Optional[Dict[str, Any]]], Dict[str, Any]],
243+
task: Callable[[Optional[Dict[str, Any]]], Awaitable[Dict[str, Any]]],
249244
dataset_slug: Optional[str] = None,
250245
dataset_version: Optional[str] = None,
251246
evaluators: Optional[List[EvaluatorDetails]] = None,
@@ -262,7 +257,7 @@ async def _run_in_github(
262257
4. Backend runs evaluators and posts PR comment
263258
264259
Args:
265-
task: Function to run on each dataset row
260+
task: Async function to run on each dataset row
266261
dataset_slug: Slug of the dataset to use
267262
dataset_version: Version of the dataset
268263
evaluators: List of evaluator slugs or (slug, version) tuples to run
@@ -398,7 +393,7 @@ def _create_task(
398393
self,
399394
experiment_slug: str,
400395
experiment_run_id: str,
401-
task_input: Dict[str, Any],
396+
task_input: Optional[Dict[str, Any]],
402397
task_output: Dict[str, Any],
403398
) -> CreateTaskResponse:
404399
body = CreateTaskRequest(
@@ -433,7 +428,7 @@ def _parse_jsonl_to_rows(self, jsonl_data: str) -> List[Dict[str, Any]]:
433428
async def _execute_tasks(
434429
self,
435430
rows: List[Dict[str, Any]],
436-
task: Callable[[Optional[Dict[str, Any]]], Dict[str, Any]],
431+
task: Callable[[Optional[Dict[str, Any]]], Awaitable[Dict[str, Any]]],
437432
) -> List[TaskResult]:
438433
"""Execute tasks locally with concurrency control
439434
@@ -447,7 +442,7 @@ async def _execute_tasks(
447442
"""
448443
task_results: List[TaskResult] = []
449444

450-
async def run_single_row(row) -> TaskResult:
445+
async def run_single_row(row: Optional[Dict[str, Any]]) -> TaskResult:
451446
try:
452447
task_output = await task(row)
453448
return TaskResult(

0 commit comments

Comments
 (0)