remove \ in cache_dir
taomiao committed Oct 11, 2023
1 parent 8bc04f4 commit 81dafde
Showing 1 changed file with 93 additions and 88 deletions.
181 changes: 93 additions & 88 deletions torch/_inductor/codecache.py
@@ -31,7 +31,6 @@
 from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
 from copy import copy
 from ctypes import c_void_p, cdll, CDLL
-from dataclasses import field
 from functools import partial
 from importlib import abc
 from pathlib import Path
@@ -41,7 +40,7 @@
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
 
 import torch
-
+from dataclasses import field
 from torch._dynamo.device_interface import (
     get_interface_for_device,
     get_registered_device_interfaces,
@@ -76,16 +75,18 @@
 def log_global_cache_errors(*args, **kwargs):
     pass
 
+
 def log_global_cache_stats(*args, **kwargs):
     pass
 
+
 def log_global_cache_vals(*args, **kwargs):
     pass
 
 
 def use_global_cache() -> bool:
     return False
 
 
 LOCK_TIMEOUT = 600
 
 # timing metrics for time spent in the compilation
@@ -115,7 +116,11 @@ def _compile_end() -> None:
 def cache_dir() -> str:
     cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
     if cache_dir is None:
-        cache_dir = f"{tempfile.gettempdir()}/torchinductor_{getpass.getuser()}"
+        cache_dir = (
+            f"{tempfile.gettempdir()}/torchinductor_{getpass.getuser()}".replace(
+                "\\", ""
+            )
+        )
     os.makedirs(cache_dir, exist_ok=True)
     return cache_dir
 
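Note on the hunk above, which is the commit's only functional change (the remaining hunks move one import and reflow indentation and blank lines): when TORCHINDUCTOR_CACHE_DIR is unset, the fallback cache path is built from tempfile.gettempdir() and getpass.getuser(), and the new code strips every backslash from that string, presumably so the path cannot inject escape characters into generated source or command lines on Windows. A minimal before/after sketch, assuming a Windows host; the temp path and the user name "taomiao" are illustrative, not taken from the commit:

import getpass
import tempfile


def cache_dir_before() -> str:
    # Parent commit 8bc04f4: backslashes from gettempdir() survive, e.g.
    # "C:\\Users\\taomiao\\AppData\\Local\\Temp/torchinductor_taomiao"
    return f"{tempfile.gettempdir()}/torchinductor_{getpass.getuser()}"


def cache_dir_after() -> str:
    # This commit: all backslashes are removed before os.makedirs() is called,
    # e.g. "C:UserstaomiaoAppDataLocalTemp/torchinductor_taomiao"
    return f"{tempfile.gettempdir()}/torchinductor_{getpass.getuser()}".replace(
        "\\", ""
    )

On the assumed Windows host the stripped string is no longer the real temp directory but a drive-relative path, so os.makedirs(cache_dir, exist_ok=True) creates a new directory rather than reusing the system temp location; the commit trades path fidelity for a backslash-free cache path.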
@@ -245,11 +250,11 @@ def get_global_cache(self):
         return global_cache["cache"]
 
     def lookup(
-        self,
-        choices: List[ChoiceCaller],
-        name: str,
-        inputs: str,
-        benchmark: Callable[[Any], Dict[ChoiceCaller, float]],
+            self,
+            choices: List[ChoiceCaller],
+            name: str,
+            inputs: str,
+            benchmark: Callable[[Any], Dict[ChoiceCaller, float]],
     ) -> Dict[ChoiceCaller, float]:
         """
         Check to see if we have benchmarked the given choice callers. For each
@@ -288,8 +293,8 @@ def check_cache(cache, callback=None) -> bool:
         local_cache = self.get_local_cache()
         # check local cache first since it is data specific to the current machine
         if not check_cache(local_cache) and not (
-            use_global_cache()
-            and check_cache(self.get_global_cache(), callback=log_stats)
+                use_global_cache()
+                and check_cache(self.get_global_cache(), callback=log_stats)
         ):
             try:
                 # re-benchmark everything to try to get consistent numbers from the same machine
@@ -331,15 +336,15 @@ def code_hash(code: Union[str, bytes], extra: str = ""):
     if extra != "":
         hashing_str = hashing_str + b"||" + extra.encode("utf-8")
     return (
-        "c"
-        + base64.b32encode(hashlib.sha256(hashing_str).digest())[:51]
-        .decode("utf-8")
-        .lower()
+            "c"
+            + base64.b32encode(hashlib.sha256(hashing_str).digest())[:51]
+            .decode("utf-8")
+            .lower()
     )
 
 
 def get_path(
-    basename: str, extension: str, specified_dir: str = ""
+        basename: str, extension: str, specified_dir: str = ""
 ) -> Tuple[str, str, str]:
     if specified_dir:
         if os.path.isabs(specified_dir):
@@ -363,11 +368,11 @@ def get_hash(content: Union[str, bytes], extra: str = "", hash_type: str = "code"):
 
 
 def write(
-    content: Union[str, bytes],
-    extension: str,
-    extra: str = "",
-    hash_type: str = "code",
-    specified_dir: str = "",
+        content: Union[str, bytes],
+        extension: str,
+        extra: str = "",
+        hash_type: str = "code",
+        specified_dir: str = "",
 ) -> Tuple[str, str]:
     key: str = get_hash(content, extra, hash_type)
     basename, subdir, path = get_path(key, extension, specified_dir)
@@ -583,10 +588,10 @@ def compiled_fx_graph_hash(fx_args: List[Any], fx_kwargs: Dict[str, Any]) -> str:
     details = FxGraphHashDetails(fx_args, fx_kwargs)
     serialized_data = FxGraphCachePickler.dumps(details)
     return (
-        "f"
-        + base64.b32encode(hashlib.sha256(serialized_data).digest())[:51]
-        .decode("utf-8")
-        .lower()
+            "f"
+            + base64.b32encode(hashlib.sha256(serialized_data).digest())[:51]
+            .decode("utf-8")
+            .lower()
     )
 
 
@@ -612,10 +617,10 @@ def load_graph(cls, cg_path: str) -> CompiledFxGraph:
 
     @classmethod
     def load(
-        cls,
-        compile_fx_fn: Callable[..., Any],
-        fx_args: List[Any],
-        fx_kwargs: Dict[str, Any],
+            cls,
+            compile_fx_fn: Callable[..., Any],
+            fx_args: List[Any],
+            fx_kwargs: Dict[str, Any],
     ):
         from filelock import FileLock
 
@@ -1044,25 +1049,25 @@ def homebrew_libomp() -> Tuple[bool, str]:
 
 
 def get_include_and_linking_paths(
-    include_pytorch: bool = False,
-    vec_isa: VecISA = invalid_vec_isa,
-    cuda: bool = False,
-    aot_mode: bool = False,
+        include_pytorch: bool = False,
+        vec_isa: VecISA = invalid_vec_isa,
+        cuda: bool = False,
+        aot_mode: bool = False,
 ) -> Tuple[str, str, str, str]:
     if (
-        config.is_fbcode()
-        and "CUDA_HOME" not in os.environ
-        and "CUDA_PATH" not in os.environ
+            config.is_fbcode()
+            and "CUDA_HOME" not in os.environ
+            and "CUDA_PATH" not in os.environ
     ):
         os.environ["CUDA_HOME"] = os.path.dirname(build_paths.cuda())
     from torch.utils import cpp_extension
 
     macros = ""
     if sys.platform == "linux" and (
-        include_pytorch
-        or vec_isa != invalid_vec_isa
-        or cuda
-        or config.cpp.enable_kernel_profile
+            include_pytorch
+            or vec_isa != invalid_vec_isa
+            or cuda
+            or config.cpp.enable_kernel_profile
     ):
         # Note - We include pytorch only on linux right now. There is more work
         # to do to enable OMP build on darwin where PyTorch is built with IOMP
@@ -1088,7 +1093,7 @@ def get_include_and_linking_paths(
             # are in lib/cuda-12 and lib/cuda-12/stubs
             for i, path in enumerate(lpaths):
                 if path.startswith(
-                    os.environ["CUDA_HOME"]
+                        os.environ["CUDA_HOME"]
                 ) and not os.path.exists(f"{path}/libcudart_static.a"):
                     for root, dirs, files in os.walk(path):
                         if "libcudart_static.a" in files:
@@ -1155,7 +1160,7 @@ def get_include_and_linking_paths(
                 lpaths.append(conda_lib_path)
                 # Prefer Intel OpenMP on x86 machine
                 if os.uname().machine == "x86_64" and os.path.exists(
-                    os.path.join(conda_lib_path, "libiomp5.dylib")
+                        os.path.join(conda_lib_path, "libiomp5.dylib")
                 ):
                     libs = ["iomp5"]
 
@@ -1199,16 +1204,16 @@ def get_include_and_linking_paths(
 
 
 def cpp_compile_command(
-    input: Union[str, List[str]],
-    output: str,
-    warning_all: bool = True,
-    shared: bool = True,
-    include_pytorch: bool = False,
-    vec_isa: VecISA = invalid_vec_isa,
-    cuda: bool = False,
-    aot_mode: bool = False,
-    compile_only: bool = False,
-    use_absolute_path: bool = False,
+        input: Union[str, List[str]],
+        output: str,
+        warning_all: bool = True,
+        shared: bool = True,
+        include_pytorch: bool = False,
+        vec_isa: VecISA = invalid_vec_isa,
+        cuda: bool = False,
+        aot_mode: bool = False,
+        compile_only: bool = False,
+        use_absolute_path: bool = False,
 ) -> str:
     ipaths, lpaths, libs, macros = get_include_and_linking_paths(
         include_pytorch, vec_isa, cuda, aot_mode
@@ -1282,11 +1287,11 @@ class AotCodeCache:
 
     @classmethod
     def compile(
-        cls,
-        graph: GraphLowering,
-        source_code: str,
-        serialized_extern_kernel_nodes: Optional[str],
-        cuda: bool,
+            cls,
+            graph: GraphLowering,
+            source_code: str,
+            serialized_extern_kernel_nodes: Optional[str],
+            cuda: bool,
     ) -> Callable[..., Any]:
         picked_vec_isa = pick_vec_isa()
         cpp_command = repr(
@@ -1473,7 +1478,7 @@ def cpp_prefix() -> str:
 # Given a path to an input cpp file and an output path,
 # Attempts to compile the file, storing the output in "output_path"
 def compile_file(
-    input_path: Union[str, List[str]], output_path: str, cmd: List[str]
+        input_path: Union[str, List[str]], output_path: str, cmd: List[str]
 ) -> None:
     input_paths = [input_path] if isinstance(input_path, str) else input_path
     input_files = [
@@ -1585,22 +1590,22 @@ def write(cls, source_code: str, extra: str = "") -> Tuple[str, str]:
 
     @classmethod
     def load(
-        cls,
-        source_code: str,
-        extra: str = "",
-        linemap: Optional[List[Tuple[int, str]]] = None,
-        attrs: Optional[Dict[str, Any]] = None,
+            cls,
+            source_code: str,
+            extra: str = "",
+            linemap: Optional[List[Tuple[int, str]]] = None,
+            attrs: Optional[Dict[str, Any]] = None,
     ) -> ModuleType:
         key, path = write(source_code, "py", extra=extra)
         return cls.load_by_key_path(key, path, linemap, attrs)
 
     @classmethod
     def load_by_key_path(
-        cls,
-        key: str,
-        path: str,
-        linemap: Optional[List[Tuple[int, str]]] = None,
-        attrs: Optional[Dict[str, Any]] = None,
+            cls,
+            key: str,
+            path: str,
+            linemap: Optional[List[Tuple[int, str]]] = None,
+            attrs: Optional[Dict[str, Any]] = None,
     ) -> ModuleType:
         if linemap is None:
             linemap = []
Expand Down Expand Up @@ -1631,7 +1636,7 @@ def load_by_key_path(
@classmethod
@functools.lru_cache(None)
def stack_frames_for_code(
cls, path: str, lineno: int
cls, path: str, lineno: int
) -> Optional[List[Dict[str, Any]]]:
if path not in cls.linemaps:
return None
@@ -1762,7 +1767,7 @@ def _cuda_lib_options() -> List[str]:
     if is_linux():
         extra_lib_dir = "lib64"
         if not os.path.exists(
-            cpp_extension._join_cuda_home(extra_lib_dir)
+                cpp_extension._join_cuda_home(extra_lib_dir)
         ) and os.path.exists(cpp_extension._join_cuda_home("lib")):
             # 64-bit CUDA may be installed in "lib"
             # Note that it's also possible both don't exist (see _find_cuda_home) - in that case we stay with "lib64"
@@ -1829,22 +1834,22 @@ def _nvcc_compiler_options() -> List[str]:
 
 
 def cuda_compile_command(
-    src_files: List[str],
-    dst_file: str,
-    dst_file_ext: str,
+        src_files: List[str],
+        dst_file: str,
+        dst_file_ext: str,
 ) -> str:
     include_paths = _cutlass_include_paths()
     cuda_lib_options = _cuda_lib_options()
     nvcc_host_compiler_options = _nvcc_host_compiler_options()
     nvcc_compiler_options = _nvcc_compiler_options()
     options = (
-        nvcc_compiler_options
-        + [
-            f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}"
-            for opt in nvcc_host_compiler_options
-        ]
-        + ["-I" + path for path in include_paths]
-        + cuda_lib_options
+            nvcc_compiler_options
+            + [
+                f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}"
+                for opt in nvcc_host_compiler_options
+            ]
+            + ["-I" + path for path in include_paths]
+            + cuda_lib_options
     )
     src_file = " ".join(src_files)
     res = ""
@@ -1863,8 +1868,8 @@ class DLLWrapper:
     """A wrapper for a dynamic library."""
 
     def __init__(
-        self,
-        lib_path: str,
+            self,
+            lib_path: str,
     ):
         self.lib_path = lib_path
         self.DLL = cdll.LoadLibrary(lib_path)
@@ -1999,7 +2004,7 @@ def caching_device_properties():
 
 
 def _worker_compile(
-    kernel_name: str, source_code: str, cc: int, device: torch.device
+        kernel_name: str, source_code: str, cc: int, device: torch.device
 ) -> None:
     device_interface = get_interface_for_device(device.type)
     device_interface.Worker.set_device(device.index)
@@ -2015,10 +2020,10 @@ def _load_kernel(kernel_name: str, source_code: str) -> ModuleType:
 
 class TritonFuture:
     def __init__(
-        self,
-        kernel_name: str,
-        source_code: str,
-        future: Future[Any],
+            self,
+            kernel_name: str,
+            source_code: str,
+            future: Future[Any],
     ) -> None:
         self.kernel_name = kernel_name
         self.source_code = source_code
@@ -2141,7 +2146,7 @@ def map(cls, fn: Callable[..., Any], seq: List[Any]) -> List[Any]:
         return [t.result() for t in [cls.pool().submit(fn, x) for x in seq]]
 
     def triton(
-        self, kernel_name: str, source_code: str, device: str = "cuda"
+            self, kernel_name: str, source_code: str, device: str = "cuda"
     ) -> Union[TritonFuture, ModuleType]:
         _compile_start()
 
