[frontend] Use a structured dataclass for target description (#3710)
This commit changes the target description to be a dataclass so that we
can use attributes to access various fields instead of magic array
indices. This makes it more readable.
antiagainst committed Apr 22, 2024
1 parent c658b6d commit 5162346
Showing 18 changed files with 56 additions and 35 deletions.
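To illustrate the switch, here is a minimal sketch (not part of the commit; GPUTarget mirrors the definition added in python/triton/backends/compiler.py below) contrasting the old tuple-style access with the new attribute access:

```python
# Minimal sketch, not part of the commit. GPUTarget mirrors the dataclass
# added in python/triton/backends/compiler.py in this change.
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class GPUTarget:
    # Target backend, e.g. "cuda" or "hip"
    backend: str
    # Target architecture, e.g. 90 (CUDA compute capability) or "gfx940" (HIP)
    arch: Union[int, str]
    warp_size: int


# Before this commit: the target was an unstructured tuple/list,
# read with magic indices.
old_target = ("cuda", 80, 32)
assert old_target[0] == "cuda" and old_target[2] == 32

# After this commit: the same data travels as named attributes.
new_target = GPUTarget("cuda", 80, 32)
assert new_target.backend == "cuda" and new_target.warp_size == 32
```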
2 changes: 1 addition & 1 deletion python/test/unit/hopper/test_gemm.py
@@ -32,7 +32,7 @@


def is_hip():
return triton.runtime.driver.active.get_current_target()[0] == "hip"
return triton.runtime.driver.active.get_current_target().backend == "hip"


@triton.jit
@@ -34,7 +34,7 @@

def is_hip_mi200():
target = triton.runtime.driver.active.get_current_target()
return target[0] == 'hip' and target[1] == 'gfx90a'
return target.backend == 'hip' and target.arch == 'gfx90a'


@triton.autotune(
2 changes: 1 addition & 1 deletion python/test/unit/language/assert_helper.py
@@ -8,7 +8,7 @@


def get_current_target_warp_size():
return triton.runtime.driver.active.get_current_target()[2]
return triton.runtime.driver.active.get_current_target().warp_size


@triton.jit
2 changes: 1 addition & 1 deletion python/test/unit/language/print_helper.py
@@ -9,7 +9,7 @@


def get_current_target_warp_size():
return triton.runtime.driver.active.get_current_target()[2]
return triton.runtime.driver.active.get_current_target().warp_size


@triton.jit
6 changes: 3 additions & 3 deletions python/test/unit/language/test_conversions.py
@@ -12,13 +12,13 @@ def is_interpreter():
return os.environ.get('TRITON_INTERPRET', '0') == '1'

def is_cuda():
return not is_interpreter() and triton.runtime.driver.active.get_current_target()[0] == "cuda"
return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda"

def is_hip():
return not is_interpreter() and triton.runtime.driver.active.get_current_target()[0] == "hip"
return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip"

def is_on_mi300():
return is_hip() and triton.runtime.driver.active.get_current_target()[1] in ('gfx940', 'gfx941', 'gfx942')
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942')

def matching_int(dtype):
if dtype.primitive_bitwidth == 8:
8 changes: 5 additions & 3 deletions python/test/unit/language/test_core.py
@@ -23,11 +23,13 @@ def is_interpreter():


def is_cuda():
return not is_interpreter() and triton.runtime.driver.active.get_current_target()[0] == "cuda"
return not is_interpreter() and \
triton.runtime.driver.active.get_current_target().backend == "cuda"


def is_hip():
return not is_interpreter() and triton.runtime.driver.active.get_current_target()[0] == "hip"
return not is_interpreter() and \
triton.runtime.driver.active.get_current_target().backend == "hip"


int_dtypes = ['int8', 'int16', 'int32', 'int64']
@@ -46,7 +48,7 @@ def is_hip():
if is_interpreter():
THREADS_PER_WARP = 1
elif is_hip():
THREADS_PER_WARP = triton.runtime.driver.active.get_current_target()[2]
THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size
else:
THREADS_PER_WARP = 32

2 changes: 1 addition & 1 deletion python/test/unit/language/test_line_info.py
@@ -73,7 +73,7 @@ def get_disassembler_command_and_debug_line_format():
Returns a tuple: (object file kind, disassembler tool command,
debug line anchor, debug line file and line number separator).
"""
backend = triton.runtime.driver.active.get_current_target()[0]
backend = triton.runtime.driver.active.get_current_target().backend

if backend == "cuda":
from triton.backends.nvidia.compiler import _path_to_binary
2 changes: 1 addition & 1 deletion python/test/unit/operators/test_blocksparse.py
@@ -7,7 +7,7 @@

def is_hip_mi200():
target = triton.runtime.driver.active.get_current_target()
return target[0] == 'hip' and target[1] == 'gfx90a'
return target.backend == 'hip' and target.arch == 'gfx90a'


def sparsify_tensor(x, mask, block):
2 changes: 1 addition & 1 deletion python/test/unit/operators/test_matmul.py
@@ -9,7 +9,7 @@


def is_hip():
return triton.runtime.driver.active.get_current_target()[0] == "hip"
return triton.runtime.driver.active.get_current_target().backend == "hip"


@pytest.mark.parametrize(
3 changes: 2 additions & 1 deletion python/test/unit/tools/test_aot.py
@@ -7,6 +7,7 @@
import numpy as np

import triton
from triton.backends.compiler import GPUTarget
from triton.backends.nvidia.driver import include_dir, library_dirs

kernel_utils_src = """
@@ -435,7 +436,7 @@ def test_ttgir_to_ptx():
kernel_path = os.path.join(tmp_dir, "empty_kernel.ttgir")
with open(kernel_path, "w") as fp:
fp.write(src)
k = triton.compile(kernel_path, target=("cuda", 80))
k = triton.compile(kernel_path, target=GPUTarget("cuda", 80, 32))
ptx = k.asm["ptx"]
assert ".target sm_80" in ptx
assert ".address_size 64" in ptx
20 changes: 16 additions & 4 deletions python/triton/backends/compiler.py
@@ -1,12 +1,24 @@
from abc import ABCMeta, abstractmethod, abstractclassmethod
import os
import subprocess
import re
import subprocess

from abc import ABCMeta, abstractmethod, abstractclassmethod
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class GPUTarget(object):
# Target backend, e.g., cuda, hip
backend: str
# Target architecture, e.g., 90 (for cuda compute capability), gfx940 (for hip)
arch: Union[int, str]
warp_size: int


class BaseBackend(metaclass=ABCMeta):

def __init__(self, target: tuple) -> None:
def __init__(self, target: GPUTarget) -> None:
self.target = target
assert self.supports_target(target)

@@ -28,7 +40,7 @@ def _path_to_binary(binary: str):
raise RuntimeError(f"Cannot find {binary}")

@abstractclassmethod
def supports_target(target: tuple):
def supports_target(target: GPUTarget):
raise NotImplementedError

@abstractmethod
7 changes: 6 additions & 1 deletion python/triton/compiler/compiler.py
@@ -3,6 +3,7 @@
import json
from .._C.libtriton import get_env_vars, ir
from ..backends import backends
from ..backends.compiler import GPUTarget
from .. import __version__
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
@@ -225,6 +226,7 @@ def filter_traceback(e: BaseException):
def compile(src, target=None, options=None):
if target is None:
target = driver.active.get_current_target()
assert isinstance(target, GPUTarget), "target must be of GPUTarget type"
backend = make_backend(target)
ir_source = not isinstance(src, ASTSource)
# create backend
@@ -302,7 +304,7 @@ def make_backend(target):
actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)]
if len(actives) != 1:
raise RuntimeError(
f"{len(actives)} compatible backends for target ({target[0]}) ({actives}). There should only be one.")
f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.")
return actives[0](target)


@@ -334,6 +336,9 @@ def __init__(self, src, metadata_group, hash):
metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
metadata = json.loads(metadata_path.read_text())
metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
# JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
target = metadata['target']
metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])
KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys())))
self.metadata = KernelMetadata(**metadata)
backend = make_backend(self.metadata.target)
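As the compiler.py hunk above notes, the target is stored in the kernel metadata JSON as a plain dict and has to be rebuilt on load. A minimal sketch of that round trip, assuming the dataclass is flattened with dataclasses.asdict on the way out (the serialization side is an assumption; only the reconstruction mirrors the diff):

```python
import json
from dataclasses import asdict, dataclass
from typing import Union


@dataclass(frozen=True)
class GPUTarget:
    backend: str
    arch: Union[int, str]
    warp_size: int


# Assumed serialization path: the dataclass is flattened into a dict before dumping.
serialized = json.dumps({"target": asdict(GPUTarget("hip", "gfx90a", 64))})

# On load, json yields a plain dict, so the loader rebuilds the GPUTarget explicitly,
# mirroring the CompiledKernel change in compiler.py above.
metadata = json.loads(serialized)
target = metadata["target"]
metadata["target"] = GPUTarget(target["backend"], target["arch"], target["warp_size"])
assert metadata["target"].arch == "gfx90a"
```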
2 changes: 1 addition & 1 deletion python/triton/ops/flash_attention.py
@@ -16,7 +16,7 @@


def is_hip():
return triton.runtime.driver.active.get_current_target()[0] == "hip"
return triton.runtime.driver.active.get_current_target().backend == "hip"


@jit
2 changes: 1 addition & 1 deletion python/tutorials/06-fused-attention.py
@@ -572,7 +572,7 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
rtol = 0.0
# Relative tolerance workaround for known hardware limitation of MI200 GPU.
# For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
if torch.version.hip is not None and triton.runtime.driver.active.get_current_target()[1] == "gfx90a":
if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
rtol = 1e-2
assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol)
assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol)
13 changes: 6 additions & 7 deletions third_party/amd/backend/compiler.py
@@ -1,4 +1,4 @@
from triton.backends.compiler import BaseBackend
from triton.backends.compiler import BaseBackend, GPUTarget
from triton._C.libtriton import ir, passes, llvm, amd
from dataclasses import dataclass
from typing import Any, Tuple
@@ -82,17 +82,16 @@ def hash(self):
class HIPBackend(BaseBackend):

@staticmethod
def supports_target(target: list):
return target[0] == 'hip'
def supports_target(target: GPUTarget):
return target.backend == 'hip'

def __init__(self, target: list) -> None:
def __init__(self, target: GPUTarget) -> None:
super().__init__(target)
assert isinstance(target, list) and len(target) == 3
assert isinstance(target[1], str)
assert isinstance(target.arch, str)
self.binary_ext = "hsaco"

def parse_options(self, opts) -> Any:
args = {'arch': self.target[1]}
args = {'arch': self.target.arch}
args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts})
args['capability'] = HIPOptions.get_compute_capability(args['arch'])
return HIPOptions(**args)
3 changes: 2 additions & 1 deletion third_party/amd/backend/driver.py
@@ -6,6 +6,7 @@
from pathlib import Path
from triton.runtime.build import _build
from triton.runtime.cache import get_cache_manager
from triton.backends.compiler import GPUTarget
from triton.backends.driver import GPUDriver

dirname = os.path.dirname(os.path.realpath(__file__))
@@ -434,4 +435,4 @@ def get_current_target(self):
device_properties = self.utils.get_device_properties(device)
arch = device_properties['arch']
warp_size = device_properties['warpSize']
return ["hip", arch.split(':')[0], warp_size]
return GPUTarget("hip", arch.split(':')[0], warp_size)
10 changes: 5 additions & 5 deletions third_party/nvidia/backend/compiler.py
@@ -1,4 +1,4 @@
from triton.backends.compiler import BaseBackend
from triton.backends.compiler import BaseBackend, GPUTarget
from triton._C.libtriton import ir, passes, llvm, nvidia
from triton.backends.nvidia.driver import CudaUtils

@@ -94,12 +94,12 @@ def hash(self):
class CUDABackend(BaseBackend):

@staticmethod
def supports_target(target: tuple):
return target[0] == 'cuda'
def supports_target(target: GPUTarget):
return target.backend == 'cuda'

def __init__(self, target: tuple) -> None:
def __init__(self, target: GPUTarget) -> None:
super().__init__(target)
self.capability = target[1]
self.capability = target.arch
assert isinstance(self.capability, int)
self.binary_ext = "cubin"

3 changes: 2 additions & 1 deletion third_party/nvidia/backend/driver.py
@@ -6,6 +6,7 @@
from pathlib import Path
from triton.runtime.build import _build
from triton.runtime.cache import get_cache_manager
from triton.backends.compiler import GPUTarget
from triton.backends.driver import GPUDriver

dirname = os.path.dirname(os.path.realpath(__file__))
@@ -374,7 +375,7 @@ def get_current_target(self):
capability = self.get_device_capability(device)
capability = capability[0] * 10 + capability[1]
warp_size = 32
return ("cuda", capability, warp_size)
return GPUTarget("cuda", capability, warp_size)

@staticmethod
def is_active():
