Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Necessary change to make torch 2.3 work with triton 2.2 #122139

Merged
merged 1 commit into from Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 12 additions & 7 deletions torch/_higher_order_ops/triton_kernel_wrap.py
Expand Up @@ -95,7 +95,6 @@ def generate_ttir(kernel, kwargs):
"""
Uses Triton's internal code generation to create TTIR
"""
import triton
from triton.compiler.compiler import ASTSource
from triton.runtime.autotuner import Autotuner
from triton.runtime.jit import JITFunction
Expand Down Expand Up @@ -145,15 +144,21 @@ def generate_ttir(kernel, kwargs):
if i not in kernel.constexprs
}

context = triton._C.libtriton.ir.context()
target = triton.runtime.driver.active.get_current_target()
backend = triton.compiler.compiler.make_backend(target)
def get_backend():
    """Return a Triton compiler backend for the current device target.

    Constructs a ``CUDABackend`` directly from the active driver's
    current target, replacing the removed
    ``triton.compiler.compiler.make_backend`` path (per this PR, the
    triton 2.2 API this change targets — confirm against the installed
    triton version).
    """
    from triton.compiler.backends.cuda import CUDABackend
    from triton.runtime.driver import driver

    # NOTE(review): the replaced code called driver.active.get_current_target();
    # here the method is called on `driver` itself — verify this matches the
    # triton 2.2 driver interface.
    target = driver.get_current_target()
    return CUDABackend(target)

backend = get_backend()

options = backend.parse_options(dict())
triton._C.libtriton.ir.load_dialects(context)
backend.load_dialects(context)
# triton._C.libtriton.triton.ir.load_dialects(context)
# backend.load_dialects(context)

src = ASTSource(kernel, signature, constants, specialization)
ttir_module = src.make_ir(options, context)
ttir_module = src.make_ir(options)
if not ttir_module.verify():
raise Exception("Verification for TTIR module has failed")

Expand Down
60 changes: 51 additions & 9 deletions torch/utils/_triton.py
@@ -1,5 +1,6 @@
import functools
import hashlib
import os

from torch._dynamo.device_interface import get_interface_for_device

Expand Down Expand Up @@ -32,18 +33,61 @@ def is_device_compatible_with_triton():


@functools.lru_cache(None)
def triton_backend():
def triton_backend_hash():
from triton.common.backend import get_backend, get_cuda_version_key

import torch

if torch.version.hip:
# Does not work with ROCm
return None

from triton.compiler.compiler import make_backend
from triton.runtime.driver import driver
if not torch.cuda.is_available():
return None

target = driver.active.get_current_target()
return make_backend(target)
backend = get_backend("cuda")
if backend is None:
return get_cuda_version_key()
else:
return backend.get_version_key()


@functools.lru_cache
def triton_key():
    """Build a string key identifying the installed Triton build.

    The key is the Triton version string followed by SHA-256 digests of:
    the compiler frontend (``compiler/compiler.py`` — hashed redundantly
    to stay consistent with upstream), every module under ``compiler/``
    and ``compiler/backends/``, the native ``_C/libtriton.so``, and every
    module under ``language/``. Memoized for the life of the process.
    """
    import pkgutil

    import triton

    triton_root = os.path.dirname(os.path.abspath(triton.__file__))

    def sha256_of(path):
        # Stream in 1 MiB chunks so the large native library does not
        # need to be held in memory all at once; the digest is identical
        # to hashing the whole file in one read.
        digest = hashlib.sha256()
        with open(path, "rb") as fh:
            while chunk := fh.read(1024**2):
                digest.update(chunk)
        return digest.hexdigest()

    def module_files(*search_dirs):
        # Yield the source path of every module pkgutil discovers in the
        # given directories, preserving pkgutil's iteration order.
        for mod in pkgutil.iter_modules(list(search_dirs)):
            yield mod.module_finder.find_spec(mod.name).origin  # type: ignore[call-arg, union-attr, arg-type]

    # frontend (redundant; kept for parity with upstream triton_key)
    digests = [sha256_of(os.path.join(triton_root, "compiler", "compiler.py"))]
    # compiler
    digests.extend(
        sha256_of(path)
        for path in module_files(
            os.path.join(triton_root, "compiler"),
            os.path.join(triton_root, "compiler", "backends"),
        )
    )
    # backend (native shared library)
    digests.append(sha256_of(os.path.join(triton_root, "_C/libtriton.so")))
    # language
    digests.extend(
        sha256_of(path) for path in module_files(os.path.join(triton_root, "language"))
    )

    from triton import __version__

    return f"{__version__}" + "-".join(digests)


@functools.lru_cache(None)
Expand All @@ -54,8 +98,6 @@ def triton_hash_with_backend():
# Does not work with ROCm
return None

from triton.compiler.compiler import triton_key

backend = triton_backend()
key = f"{triton_key()}-{backend.hash()}"
backend_hash = triton_backend_hash()
key = f"{triton_key()}-{backend_hash}"
return hashlib.sha256(key.encode("utf-8")).hexdigest()