Skip to content

Commit

Permalink
Pass the value of the definitions argument down to use custom definit…
Browse files Browse the repository at this point in the history
…ions for the algorithms
  • Loading branch information
tapas committed May 23, 2024
1 parent 4c08e04 commit c7aac44
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
18 changes: 11 additions & 7 deletions ann_benchmarks/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _get_definitions(base_dir: str = "ann_benchmarks/algorithms") -> Dict[str, D
print(f"Error loading YAML from {config_file}: {e}")
return configs

def _get_algorithm_definitions(point_type: str, distance_metric: str) -> Dict[str, Dict[str, Any]]:
def _get_algorithm_definitions(point_type: str, distance_metric: str, base_dir: str = "ann_benchmarks/algorithms") -> Dict[str, Dict[str, Any]]:
"""Get algorithm definitions for a specific point type and distance metric.
A specific algorithm folder can have multiple algorithm definitions for a given point type and
Expand Down Expand Up @@ -188,7 +188,7 @@ def _get_algorithm_definitions(point_type: str, distance_metric: str) -> Dict[st
}
```
"""
configs = load_configs(point_type)
configs = load_configs(point_type, base_dir)
definitions = {}

# param `_` is filename, not specific name
Expand Down Expand Up @@ -341,12 +341,16 @@ def create_definitions_from_algorithm(name: str, algo: Dict[str, Any], dimension
return definitions

def get_definitions(
dimension: int,
point_type: str = "float",
distance_metric: str = "euclidean",
count: int = 10
dimension: int,
point_type: str = "float",
distance_metric: str = "euclidean",
count: int = 10,
base_dir: str = "ann_benchmarks/algorithms"
) -> List[Definition]:
algorithm_definitions = _get_algorithm_definitions(point_type=point_type, distance_metric=distance_metric)
algorithm_definitions = _get_algorithm_definitions(point_type=point_type,
distance_metric=distance_metric,
base_dir=base_dir
)

definitions: List[Definition] = []

Expand Down
3 changes: 2 additions & 1 deletion ann_benchmarks/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ def main():
dimension=dimension,
point_type=dataset.attrs.get("point_type", "float"),
distance_metric=dataset.attrs["distance"],
count=args.count
count=args.count,
base_dir=args.definitions,
)
random.shuffle(definitions)

Expand Down

0 comments on commit c7aac44

Please sign in to comment.