22import asyncio
33import json
44import os
5- from typing import Any , List , Callable , Optional , Tuple , Dict
5+ from typing import Any , List , Callable , Optional , Tuple , Dict , Awaitable , Union
66from traceloop .sdk .client .http import HTTPClient
77from traceloop .sdk .datasets .datasets import Datasets
88from traceloop .sdk .evaluator .evaluator import Evaluator
@@ -36,7 +36,7 @@ def __init__(self, http_client: HTTPClient, async_http_client: httpx.AsyncClient
3636
3737 async def run (
3838 self ,
39- task : Callable [[Optional [Dict [str , Any ]]], Dict [str , Any ]],
39+ task : Callable [[Optional [Dict [str , Any ]]], Awaitable [ Dict [str , Any ] ]],
4040 dataset_slug : Optional [str ] = None ,
4141 dataset_version : Optional [str ] = None ,
4242 evaluators : Optional [List [EvaluatorDetails ]] = None ,
@@ -52,7 +52,7 @@ async def run(
5252 Otherwise, will run the experiment locally.
5353
5454 Args:
55- task: Function to run on each dataset row
55+ task: Async function to run on each dataset row
5656 dataset_slug: Slug of the dataset to use
5757 dataset_version: Version of the dataset to use
5858 evaluators: List of evaluator slugs to run
@@ -87,11 +87,10 @@ async def run(
8787 stop_on_error = stop_on_error ,
8888 wait_for_results = wait_for_results ,
8989 )
90-
91-
90+
9291 async def _run_locally (
9392 self ,
94- task : Callable [[Optional [Dict [str , Any ]]], Dict [str , Any ]],
93+ task : Callable [[Optional [Dict [str , Any ]]], Awaitable [ Dict [str , Any ] ]],
9594 dataset_slug : Optional [str ] = None ,
9695 dataset_version : Optional [str ] = None ,
9796 evaluators : Optional [List [EvaluatorDetails ]] = None ,
@@ -106,7 +105,7 @@ async def _run_locally(
106105
107106 Args:
108107 dataset_slug: Slug of the dataset to use
109- task: Function to run on each dataset row
108+ task: Async function to run on each dataset row
110109 evaluators: List of evaluator slugs to run
111110 experiment_slug: Slug for this experiment run
112111 experiment_metadata: Metadata for this experiment (an experiment holds all the experiment runs)
@@ -160,17 +159,15 @@ async def _run_locally(
160159
161160 async def run_single_row (row : Optional [Dict [str , Any ]]) -> TaskResponse :
162161 try :
163- # TODO: Fix type annotation - task should return Awaitable, not dict
164- task_result = await task (row ) # type: ignore[misc]
165- # TODO: Fix type - task_input should accept Optional[Dict]
162+ task_result = await task (row )
166163 task_id = self ._create_task (
167164 experiment_slug = experiment_slug ,
168165 experiment_run_id = run_id ,
169- task_input = row , # type: ignore[arg-type]
166+ task_input = row ,
170167 task_output = task_result ,
171168 ).id
172169
173- eval_results = {}
170+ eval_results : Dict [ str , Union [ Dict [ str , Any ], str ]] = {}
174171 if evaluator_details :
175172 for evaluator_slug , evaluator_version in evaluator_details :
176173 try :
@@ -197,13 +194,11 @@ async def run_single_row(row: Optional[Dict[str, Any]]) -> TaskResponse:
197194 input = task_result ,
198195 )
199196
200- # TODO: Fix type - eval_results should accept Union[Dict, str]
201197 msg = f"Triggered execution of { evaluator_slug } "
202- eval_results [evaluator_slug ] = msg # type: ignore[assignment]
198+ eval_results [evaluator_slug ] = msg
203199
204200 except Exception as e :
205- # TODO: Fix type - eval_results should accept Union[Dict, str]
206- eval_results [evaluator_slug ] = f"Error: { str (e )} " # type: ignore[assignment]
201+ eval_results [evaluator_slug ] = f"Error: { str (e )} "
207202
208203 return TaskResponse (
209204 task_result = task_result ,
@@ -245,7 +240,7 @@ async def run_with_semaphore(row: Optional[Dict[str, Any]]) -> TaskResponse:
245240
246241 async def _run_in_github (
247242 self ,
248- task : Callable [[Optional [Dict [str , Any ]]], Dict [str , Any ]],
243+ task : Callable [[Optional [Dict [str , Any ]]], Awaitable [ Dict [str , Any ] ]],
249244 dataset_slug : Optional [str ] = None ,
250245 dataset_version : Optional [str ] = None ,
251246 evaluators : Optional [List [EvaluatorDetails ]] = None ,
@@ -262,7 +257,7 @@ async def _run_in_github(
262257 4. Backend runs evaluators and posts PR comment
263258
264259 Args:
265- task: Function to run on each dataset row
260+ task: Async function to run on each dataset row
266261 dataset_slug: Slug of the dataset to use
267262 dataset_version: Version of the dataset
268263 evaluators: List of evaluator slugs or (slug, version) tuples to run
@@ -398,7 +393,7 @@ def _create_task(
398393 self ,
399394 experiment_slug : str ,
400395 experiment_run_id : str ,
401- task_input : Dict [str , Any ],
396+ task_input : Optional [ Dict [str , Any ] ],
402397 task_output : Dict [str , Any ],
403398 ) -> CreateTaskResponse :
404399 body = CreateTaskRequest (
@@ -433,7 +428,7 @@ def _parse_jsonl_to_rows(self, jsonl_data: str) -> List[Dict[str, Any]]:
433428 async def _execute_tasks (
434429 self ,
435430 rows : List [Dict [str , Any ]],
436- task : Callable [[Optional [Dict [str , Any ]]], Dict [str , Any ]],
431+ task : Callable [[Optional [Dict [str , Any ]]], Awaitable [ Dict [str , Any ] ]],
437432 ) -> List [TaskResult ]:
438433 """Execute tasks locally with concurrency control
439434
@@ -447,7 +442,7 @@ async def _execute_tasks(
447442 """
448443 task_results : List [TaskResult ] = []
449444
450- async def run_single_row (row ) -> TaskResult :
445+ async def run_single_row (row : Optional [ Dict [ str , Any ]] ) -> TaskResult :
451446 try :
452447 task_output = await task (row )
453448 return TaskResult (
0 commit comments