-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path__init__.py
More file actions
88 lines (77 loc) · 2.56 KB
/
__init__.py
File metadata and controls
88 lines (77 loc) · 2.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import importlib.util
import subprocess
from typing import Optional, Literal, List, Hashable, Dict, Tuple
from .runtime import (
to_ssa_path,
jensum,
jensum_meta,
to_annotated_ssa_path,
linear_path_runtime_meta,
get_ops_and_max_size,
)
from .info import compute_meta_info_of_einsum_instance
# Public API of the package: `find_path` (defined below) plus the helpers
# re-exported from the `.runtime` and `.info` submodules.
__all__ = [
    "find_path",
    "jensum",
    "jensum_meta",
    "to_ssa_path",
    "to_annotated_ssa_path",
    "linear_path_runtime_meta",
    "get_ops_and_max_size",
    "compute_meta_info_of_einsum_instance",
]

# Type aliases describing an einsum contraction problem.
Inputs = List[List[Hashable]]  # one list of index labels per input tensor
Output = List[Hashable]  # index labels of the result tensor
SizeDict = Dict[Hashable, int]  # index label -> dimension size
Path = List[Tuple[int, ...]]  # contraction path: tuples of operand positions
def find_path(
    format_string: str,
    *tensors,
    minimize: Literal["flops", "size"],
    n_trials: int = 128,
    n_jobs: int = 10,
    show_progress_bar: bool = True,
    timeout: Optional[int] = None,
):
    """Optimize a path for evaluating an einsum expression.

    Args:
        format_string (str): The Einstein summation notation expression,
            e.g. ``"ij,jk->ik"``.
        *tensors: The input tensors; only their ``.shape`` attribute is read.
        minimize (Literal["flops", "size"]): The objective to minimize,
            either "flops" or "size".
        n_trials (int, optional): The number of trials for the optimization
            process. Defaults to 128.
        n_jobs (int, optional): The number of parallel jobs to run.
            Defaults to 10.
        show_progress_bar (bool, optional): Whether to show a progress bar
            during optimization. Defaults to True.
        timeout (int, optional): The maximum time in seconds for the
            optimization process. Defaults to None (no limit).

    Returns:
        An ssa path for evaluating the einsum expression, as produced by
        ``path_finder.hyper_optimized_hhg``.

    Raises:
        ImportError: If the optional path-finding dependencies
            (kahypar, cgreedy, optuna) are not installed.
        ValueError: If the same index label appears with inconsistent
            dimension sizes across the input tensors.
    """
    # The solver dependencies are optional extras; fail early with an
    # actionable message before doing any work.
    if (
        importlib.util.find_spec("kahypar") is None
        or importlib.util.find_spec("cgreedy") is None
        or importlib.util.find_spec("optuna") is None
    ):
        raise ImportError(
            """You need to install the optional dependencies for path to use this function
            You can do this with pip install "einsum_benchmark[path]"
            """
        )
    # Imported lazily so the package imports cleanly without the extras.
    from . import path_finder

    lhs, output = format_string.split("->")
    # Renamed from `input` to avoid shadowing the builtin.
    input_terms = lhs.split(",")
    shapes = [tensor.shape for tensor in tensors]
    # Map each index label to its dimension size, validating consistency.
    size_dict: Dict[Hashable, int] = {}
    for term, shape in zip(input_terms, shapes):
        for char, size in zip(term, shape):
            # Explicit exception instead of `assert`: asserts are stripped
            # under `python -O` and give callers no diagnostic message.
            if char in size_dict and size_dict[char] != size:
                raise ValueError(
                    f"Inconsistent size for index '{char}': "
                    f"{size_dict[char]} vs {size}"
                )
            size_dict[char] = size
    return path_finder.hyper_optimized_hhg(
        input_terms,
        output,
        size_dict,
        minimize,
        n_trials,
        n_jobs,
        show_progress_bar,
        timeout,
    )