Skip to content

Commit

Permalink
necessary change to make torch2.3 work with triton2.2 (#122139)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
guangy10 committed Mar 26, 2024
1 parent 1e82fd8 commit b6c15a8
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 16 deletions.
19 changes: 12 additions & 7 deletions torch/_higher_order_ops/triton_kernel_wrap.py
Expand Up @@ -95,7 +95,6 @@ def generate_ttir(kernel, kwargs):
"""
Uses Triton's internal code generation to create TTIR
"""
import triton
from triton.compiler.compiler import ASTSource
from triton.runtime.autotuner import Autotuner
from triton.runtime.jit import JITFunction
Expand Down Expand Up @@ -145,15 +144,21 @@ def generate_ttir(kernel, kwargs):
if i not in kernel.constexprs
}

context = triton._C.libtriton.ir.context()
target = triton.runtime.driver.active.get_current_target()
backend = triton.compiler.compiler.make_backend(target)
def get_backend():
from triton.compiler.backends.cuda import CUDABackend
from triton.runtime.driver import driver

target = driver.get_current_target()
return CUDABackend(target)

backend = get_backend()

options = backend.parse_options(dict())
triton._C.libtriton.ir.load_dialects(context)
backend.load_dialects(context)
# triton._C.libtriton.triton.ir.load_dialects(context)
# backend.load_dialects(context)

src = ASTSource(kernel, signature, constants, specialization)
ttir_module = src.make_ir(options, context)
ttir_module = src.make_ir(options)
if not ttir_module.verify():
raise Exception("Verification for TTIR module has failed")

Expand Down
60 changes: 51 additions & 9 deletions torch/utils/_triton.py
@@ -1,5 +1,6 @@
import functools
import hashlib
import os

from torch._dynamo.device_interface import get_interface_for_device

Expand Down Expand Up @@ -32,18 +33,61 @@ def is_device_compatible_with_triton():


@functools.lru_cache(None)
def triton_backend():
def triton_backend_hash():
from triton.common.backend import get_backend, get_cuda_version_key

import torch

if torch.version.hip:
# Does not work with ROCm
return None

from triton.compiler.compiler import make_backend
from triton.runtime.driver import driver
if not torch.cuda.is_available():
return None

target = driver.active.get_current_target()
return make_backend(target)
backend = get_backend("cuda")
if backend is None:
return get_cuda_version_key()
else:
return backend.get_version_key()


@functools.lru_cache
def triton_key():
import pkgutil

import triton

TRITON_PATH = os.path.dirname(os.path.abspath(triton.__file__))
contents = []
# This is redundant. Doing it to be consistent with upstream.
# frontend
with open(os.path.join(TRITON_PATH, "compiler", "compiler.py"), "rb") as f:
contents += [hashlib.sha256(f.read()).hexdigest()]

# compiler
compiler_path = os.path.join(TRITON_PATH, "compiler")
backends_path = os.path.join(TRITON_PATH, "compiler", "backends")
for lib in pkgutil.iter_modules([compiler_path, backends_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: # type: ignore[call-arg, union-attr, arg-type]
contents += [hashlib.sha256(f.read()).hexdigest()]
# backend
libtriton_hash = hashlib.sha256()
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
while True:
chunk = f.read(1024**2)
if not chunk:
break
libtriton_hash.update(chunk)
contents.append(libtriton_hash.hexdigest())
# language
language_path = os.path.join(TRITON_PATH, "language")
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: # type: ignore[call-arg, union-attr, arg-type]
contents += [hashlib.sha256(f.read()).hexdigest()]
from triton import __version__

return f"{__version__}" + "-".join(contents)


@functools.lru_cache(None)
Expand All @@ -54,8 +98,6 @@ def triton_hash_with_backend():
# Does not work with ROCm
return None

from triton.compiler.compiler import triton_key

backend = triton_backend()
key = f"{triton_key()}-{backend.hash()}"
backend_hash = triton_backend_hash()
key = f"{triton_key()}-{backend_hash}"
return hashlib.sha256(key.encode("utf-8")).hexdigest()

0 comments on commit b6c15a8

Please sign in to comment.