Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tilelang/cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def cached(
execution_backend: Optional[Literal["dlpack", "ctypes", "cython", "nvrtc"]] = "cython",
verbose: Optional[bool] = False,
pass_configs: Optional[dict] = None,
compile_flags: Optional[List[str]] = None,
compile_flags: Optional[Union[List[str], str]] = None,
) -> JITKernel:
"""
Caches and reuses compiled kernels (using KernelCache class).
Expand Down
11 changes: 9 additions & 2 deletions tilelang/cache/kernel_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def _generate_key(
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
pass_configs: dict = None,
compile_flags: Optional[Union[List[str], str]] = None,
) -> str:
"""
Generates a unique hash key for caching compiled kernels.
Expand Down Expand Up @@ -101,6 +102,7 @@ def _generate_key(
"target_host": str(target_host) if target_host else None,
"execution_backend": execution_backend,
"pass_configs": pass_configs,
"compile_flags": compile_flags,
}
# Sort keys to ensure consistency
key_string = json.dumps(key_data, sort_keys=True)
Expand All @@ -117,7 +119,7 @@ def cached(
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False,
pass_configs: dict = None,
compile_flags: Optional[List[str]] = None,
compile_flags: Optional[Union[List[str], str]] = None,
) -> JITKernel:
"""
Caches and reuses compiled kernels to avoid redundant compilation.
Expand Down Expand Up @@ -152,6 +154,7 @@ def cached(
target=target,
target_host=target_host,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
with self._lock:
# First check in-memory cache
Expand All @@ -165,7 +168,8 @@ def cached(

# Then check disk cache
kernel = self._load_kernel_from_disk(key, target, target_host, out_idx,
execution_backend, pass_configs, func, verbose)
execution_backend, pass_configs, compile_flags,
func, verbose)
if kernel is not None:
if verbose:
self.logger.debug(
Expand All @@ -185,6 +189,7 @@ def cached(
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
if execution_backend == "dlpack":
self.logger.warning("DLPack backend does not support cache saving to disk.")
Expand Down Expand Up @@ -322,6 +327,7 @@ def _load_kernel_from_disk(
out_idx: List[int] = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
pass_configs: dict = None,
compile_flags: Optional[Union[List[str], str]] = None,
func: Callable = None,
verbose: bool = False,
) -> Optional[JITKernel]:
Expand Down Expand Up @@ -382,6 +388,7 @@ def _load_kernel_from_disk(
out_idx=out_idx,
execution_backend=execution_backend,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
else:
return None
Expand Down
7 changes: 5 additions & 2 deletions tilelang/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class _JitImplementation:
verbose: bool
pass_configs: Optional[Dict[str, Any]]
debug_root_path: Optional[str]
compile_flags: Optional[List[str]]
compile_flags: Optional[Union[List[str], str]]

def __init__(self,
out_idx: Any = None,
Expand All @@ -105,7 +105,7 @@ def __init__(self,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None,
compile_flags: Optional[List[str]] = None):
compile_flags: Optional[Union[List[str], str]] = None):
"""
Initializes the JIT compiler decorator.

Expand Down Expand Up @@ -137,6 +137,9 @@ def __init__(self,
If None, no debug information is saved (default: None).
If a relative path is given, it's made absolute relative to the project root
or current working directory.
compile_flags : Optional[Union[List[str], str]], optional
Additional compilation flags to pass to the compiler.
If None, no additional compilation flags are passed (default: None).
"""
self.out_idx = out_idx
self.execution_backend = execution_backend
Expand Down
Loading