Aggregate
pantonante committed Feb 19, 2024
1 parent adf8713 commit 813ad6b
Showing 8 changed files with 234 additions and 152 deletions.
14 changes: 3 additions & 11 deletions continuous_eval/eval/dataset.py
@@ -23,9 +23,9 @@ class Dataset:
def __init__(self, dataset_path: typing.Union[str, Path]) -> None:
if isinstance(dataset_path, str):
dataset_path = Path(dataset_path)
assert dataset_path.exists(), f"Dataset folder {dataset_name} does not exist"
assert (dataset_path / "manifest.yaml").exists(), f"Manifest file not found in {dataset_name}"
assert (dataset_path / "dataset.jsonl").exists(), f"Dataset file not found in {dataset_name}"
assert dataset_path.exists(), f"Dataset folder {dataset_path} does not exist"
assert (dataset_path / "manifest.yaml").exists(), f"Manifest file not found in {dataset_path}"
assert (dataset_path / "dataset.jsonl").exists(), f"Dataset file not found in {dataset_path}"
# Load manifest
with open(dataset_path / "manifest.yaml", "r") as manifest_file:
self._manifest = yaml.safe_load(manifest_file)
@@ -61,11 +61,3 @@ def filed_types(self, name: str) -> type:
@property
def data(self):
return self._data

# def get_value(self, field: typing.Union[str, DatasetField], index: int):
# if isinstance(field, str):
# return self._data[index][field]
# elif isinstance(field, DatasetField):
# return self._data[index][field.name]
# else:
# raise ValueError(f"field {field} not recognized")
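For orientation, a minimal usage sketch of the Dataset class touched above; the folder path is an illustrative assumption, and the folder must contain manifest.yaml and dataset.jsonl as the constructor's assertions require.

from pathlib import Path

from continuous_eval.eval.dataset import Dataset

# Hypothetical dataset folder (not part of this commit); it must contain
# manifest.yaml and dataset.jsonl.
dataset = Dataset(Path("data/example_eval_dataset"))
# Rows loaded from dataset.jsonl are exposed via the data property.
print(len(dataset.data))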
202 changes: 82 additions & 120 deletions continuous_eval/eval/manager.py
@@ -1,131 +1,108 @@
import json
from collections import ChainMap
import warnings
from pathlib import Path
from typing import Type, get_origin
from typing import Dict, List, Optional, get_origin

from loguru import logger

from continuous_eval.eval.dataset import Dataset, DatasetField
from continuous_eval.eval.pipeline import ModuleOutput, Pipeline


def _instantiate_type(type_hint: Type):
origin = get_origin(type_hint)
# If the origin is None, it means type_hint is not a generic type
# and we assume type_hint itself is directly instantiable
if origin is None:
origin = type_hint
try:
# This only works for types without required arguments in their __init__.
instance = origin()
except TypeError as e:
# If instantiation fails, return an error message or raise a custom exception
instance = None
return instance
from continuous_eval.eval.result_types import EvaluationResults, MetricsResults, TestResults


class EvaluationManager:
def __init__(self):
self._pipeline = None
self._dataset = None
self._samples = None
self._eval_results = None
self._test_results = None
self._is_running = False
self._pipeline: Optional[Pipeline] = None
self._eval_results: EvaluationResults = EvaluationResults()
self._metrics_results: MetricsResults = MetricsResults()
self._test_results: TestResults = TestResults()
self._is_running: bool = False

self._idx = 0

def _build_empty_samples(self):
assert self.pipeline is not None, "Pipeline not set"
empty_samples = dict()
for module in self.pipeline.modules:
empty_samples[module.name] = _instantiate_type(module.output)
return empty_samples

@property
def is_complete(self):
return self._idx == len(self._dataset)
def is_complete(self) -> bool:
if self._pipeline is None:
return False
return self._idx == len(self._pipeline.dataset.data)

@property
def samples(self):
return self._samples
def samples(self) -> List[dict]:
return self._eval_results.results

@property
def pipeline(self) -> Pipeline:
return self._pipeline
def evaluation(self) -> EvaluationResults:
return self._eval_results

@property
def dataset(self) -> Dataset:
return self._dataset
def metrics(self) -> MetricsResults:
return self._metrics_results

@property
def test_results(self):
def tests(self) -> TestResults:
return self._test_results

@property
def eval_results(self):
return self._eval_results
def pipeline(self) -> Pipeline:
if self._pipeline is None:
raise ValueError("Pipeline not set")
return self._pipeline

@property
def dataset(self) -> Dataset:
if self._pipeline is None:
raise ValueError("Pipeline not set")
if self._pipeline.dataset is None:
raise ValueError("Dataset not set")
return self._pipeline.dataset

def set_pipeline(self, pipeline: Pipeline):
self._metrics_results.pipeline = pipeline
self._pipeline = pipeline
self._dataset = pipeline.dataset

def is_running(self) -> bool:
return self.is_running
return self._is_running

def start_run(self):
self._idx = 0
self._is_running = True
self._samples = [self._build_empty_samples() for _ in range(len(self._dataset.data))]
self._eval_results = EvaluationResults(self._pipeline)

@property
def curr_sample(self):
if self._idx >= len(self._dataset.data):
if self._pipeline is None:
raise ValueError("Pipeline not set")
if self._idx >= len(self.dataset.data):
return None
return self._dataset.data[self._idx]
return self.dataset.data[self._idx]

def next_sample(self):
if self._idx >= len(self._dataset.data):
if self._pipeline is None:
raise ValueError("Pipeline not set")
if self._idx >= len(self.dataset.data):
self._is_running = False
else:
self._idx += 1
return self.curr_sample

# Logging results

def log(self, key, value):
# Make sure everything looks good
if self._pipeline is None:
raise ValueError("Pipeline not set")
assert type(value) == get_origin(self._pipeline.module_by_name(key).output) or isinstance(
value, self._pipeline.module_by_name(key).output
), f"Value {value} does not match expected type in the pipeline"
if not self._is_running:
raise ValueError("Cannot log when not running")
if key not in self._samples[self._idx]:
if key not in self._eval_results.results[self._idx]:
raise ValueError(f"Key {key} not found, review your pipeline")
if isinstance(self._samples[self._idx][key], list) and not isinstance(value, list):
self._samples[self._idx][key] = value
elif isinstance(self._samples[self._idx][key], dict):
self._samples[self._idx][key].update(value)
elif isinstance(self._samples[self._idx][key], set) and not isinstance(value, list):
self._samples[self._idx][key] = value

if isinstance(self._eval_results.results[self._idx][key], dict):
self._eval_results.results[self._idx][key].update(value)
else:
self._samples[self._idx][key] = value

def save_results(self, filepath: Path):
assert self._samples is not None, "No samples to save"
assert self._dataset is not None, "Dataset not set"
assert len(self._samples) == len(self._dataset.data), "Samples not complete"
assert filepath.suffix == ".jsonl", "File must be a JSONL file"
# Save samples to file (JSONL)
with open(filepath, "w") as f:
for line in self._samples:
json_record = json.dumps(line, ensure_ascii=False)
f.write(json_record + "\n")

def load_results(self, filepath: Path):
assert filepath.suffix == ".jsonl", "File must be a JSONL file"
assert self._dataset is not None, "Dataset not set"
# Load samples from file (JSONL)
with open(filepath, "r") as f:
self._samples = [json.loads(line) for line in f]
assert len(self._samples) == len(self._dataset.data), "Samples not complete"
self._eval_results.results[self._idx][key] = value

# Evaluate

@@ -134,83 +111,68 @@ def _prepare(self, module, metric):
if metric.overloaded_params is not None:
for key, val in metric.overloaded_params.items():
if isinstance(val, DatasetField):
kwargs[key] = [x[val.name] for x in self._dataset.data]
kwargs[key] = [x[val.name] for x in self.dataset.data] # type: ignore
elif isinstance(val, ModuleOutput):
module_name = module.name if val.module is None else val.module.name
kwargs[key] = [val(x[module_name]) for x in self._samples]
kwargs[key] = [val(x[module_name]) for x in self._eval_results.results]
else:
raise ValueError(f"Invalid promised parameter {key}={val}")
return kwargs

def run_eval(self):
def run_metrics(self):
logger.info("Running evaluation")
assert self._pipeline is not None, "Pipeline not set"
assert self._dataset is not None, "Dataset not set"
assert self._samples is not None, "Samples not set"
assert len(self._samples) == len(self._dataset.data), "Samples not complete"
evaluation_results = {
assert len(self._eval_results.results) > 0, "No evaluation samples to run the metrics on"
if len(self._eval_results.results) != len(self.dataset.data):
warnings.warn("The number of samples does not match the dataset size")
self._metrics_results.samples = {
module.name: {metric.name: metric.batch(**self._prepare(module, metric)) for metric in module.eval}
for module in self._pipeline.modules
if module.eval is not None
}
results = {
module_name: [dict(ChainMap(*x)) for x in zip(*eval_res.values())]
for module_name, eval_res in evaluation_results.items()
}
self._eval_results = results
return self._eval_results
return self._metrics_results

def save_eval_results(self, filepath: Path):
assert filepath.suffix == ".json", "File must be a JSON file"
assert self._eval_results is not None, "No samples to save"
assert self._dataset is not None, "Dataset not set"
assert all(
[len(module_res) == len(self._dataset.data) for module_res in self._eval_results.values()]
), "Evaluation is not complete"
with open(filepath, "w") as json_file:
json.dump(self._eval_results, json_file, indent=None)

def load_eval_results(self, filepath: Path):
assert filepath.suffix == ".json", "File must be a JSON file"
assert self._dataset is not None, "Dataset not set"
with open(filepath, "r") as json_file:
self._eval_results = json.load(json_file)
def aggregate_eval_results(self):
assert self._pipeline is not None, "Pipeline not set"
assert self._eval_results is not None, "Evaluation results not set"
assert all(
[len(module_res) == len(self._dataset.data) for module_res in self._eval_results.values()]
[len(module_res) == len(self.dataset.data) for module_res in self._metrics_results.results.values()]
), "Evaluation is not complete"
# aggregated_results = dict()
# for module in self._pipeline.modules:
# if module.eval is not None:
# for metric in module.eval:
# aggregated_results
# if metric.aggregate is not None:
# self._eval_results[module.name][metric.name] = metric.aggregate(
# self._eval_results[module.name][metric.name]
# )
return self._metrics_results.results

# Tests

def run_tests(self):
logger.info("Running tests")
assert self._pipeline is not None, "Pipeline not set"
assert self._dataset is not None, "Dataset not set"
assert self._eval_results is not None, "Evaluation results not set"
assert not self._eval_results.is_empty(), "Evaluation results not set"
assert all(
[len(module_res) == len(self._dataset.data) for module_res in self._eval_results.values()]
[len(module_res) == len(self.dataset.data) for module_res in self._metrics_results.results.values()]
), "Evaluation is not complete"
self._test_results = {
module.name: {test.name: test.run(self._eval_results[module.name]) for test in module.tests}
self._test_results.results = {
module.name: {test.name: test.run(self._metrics_results.results[module.name]) for test in module.tests}
for module in self._pipeline.modules
if module.tests is not None
}
return self._test_results

def save_test_results(self, filepath: Path):
assert filepath.suffix == ".json", "File must be a JSON file"
assert self._test_results is not None, "No samples to save"
with open(filepath, "w") as json_file:
json.dump(self._test_results, json_file, indent=None)

def load_test_results(self, filepath: Path):
assert filepath.suffix == ".json", "File must be a JSON file"
with open(filepath, "r") as json_file:
self._test_results = json.load(json_file)

def test_graph(self):
if self._pipeline is None:
raise ValueError("Pipeline not set")
if self._test_results is None:
raise ValueError("Tests not run")
pipeline_graph = self._pipeline.graph_repr()
tests = "\n %% Tests\n"
for module, results in self._test_results.items():
for module, results in self._test_results.results.items():
metrics_lines = []
for metric, passed in results.items():
status = "Pass" if passed else "Fail"
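To see how the renamed evaluation flow fits together after this commit, a hedged usage sketch of the EvaluationManager API; the module name "retriever" and the logged value are illustrative assumptions, and pipeline is assumed to be built elsewhere with its Dataset.

from continuous_eval.eval.manager import EvaluationManager

eval_manager = EvaluationManager()
eval_manager.set_pipeline(pipeline)  # pipeline (with its dataset) built elsewhere

eval_manager.start_run()
while eval_manager.is_running():
    sample = eval_manager.curr_sample
    if sample is None:
        break
    # Run the application on `sample`, then log each module's output;
    # the key must be a module name and the value must match its declared output type.
    eval_manager.log("retriever", {"retrieved_chunks": ["..."]})  # illustrative key/value
    eval_manager.next_sample()

metrics = eval_manager.run_metrics()              # formerly run_eval()
aggregated = eval_manager.aggregate_eval_results()
test_results = eval_manager.run_tests()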
17 changes: 0 additions & 17 deletions continuous_eval/eval/metrics.py

This file was deleted.

10 changes: 9 additions & 1 deletion continuous_eval/eval/pipeline.py
@@ -22,7 +22,7 @@ class Graph:


class Pipeline:
def __init__(self, modules: List[Module], dataset: Optional[Dataset] = None) -> None:
def __init__(self, modules: List[Module], dataset: Dataset) -> None:
self._modules = modules
self._dataset = dataset
self._graph = self._build_graph()
@@ -41,6 +41,14 @@ def module_by_name(self, name: str) -> Module:
return module
raise ValueError(f"Module {name} not found")

def get_metric(self, module_name: str, metric_name: str):
module = self.module_by_name(module_name)
try:
metric = [m for m in module.eval if m.name == metric_name][0]
except IndexError:
raise ValueError(f"Metric {metric_name} not found in module {module_name}")
return metric

def _validate_modules(self):
names = set()
for module in self._modules:
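Finally, a short sketch of the Pipeline changes above: dataset is now a required constructor argument and get_metric looks up a metric attached to a module. The dataset path, module list, and names are illustrative assumptions; Module construction is not shown in this diff.

from continuous_eval.eval.dataset import Dataset
from continuous_eval.eval.pipeline import Pipeline

dataset = Dataset("data/example_eval_dataset")  # hypothetical path
modules = [...]  # Module objects built elsewhere (constructor not shown here)
pipeline = Pipeline(modules=modules, dataset=dataset)

# Raises ValueError if the module or metric does not exist.
metric = pipeline.get_metric("retriever", "precision_recall")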