bring ptx back #3623

Merged · 8 commits · Mar 6, 2024
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -346,7 +346,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- backend: [llvm, clang, gpu, cuda, hip] #, triton] #, ptx]
+ backend: [llvm, clang, gpu, cuda, hip, ptx] #, triton]

name: Tests on (${{ matrix.backend }})
runs-on: ubuntu-latest
10 changes: 6 additions & 4 deletions test/test_linearizer.py
@@ -784,8 +784,9 @@ def test_grouped_store_values(self):
assert store_val.dtype == dtypes.float.vec(4) and store_val.uop != UOps.CAST

def test_grouped_store_locals_and_globals(self):
- if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared:
-   self.skipTest("Only Compiled uses linearizer with locals and shared")
+ if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared or \
+     not Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4:
+   self.skipTest("Only Compiled uses linearizer with locals, shared, and float4")

x, y = Tensor.rand(128, 128), Tensor.rand(128, 128)
out = x@y
@@ -808,8 +809,9 @@ def test_grouped_store_locals_and_globals(self):
assert len([u for u in k.uops if u.uop is UOps.IF and u.vin[-1] == barrier]) == 1

def test_grouped_store_local_only(self):
- if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared:
-   self.skipTest("Only Compiled uses linearizer with locals and shared")
+ if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared or \
+     not Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4:
+   self.skipTest("Only Compiled uses linearizer with locals, shared, and float4")

x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
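Both grouped-store tests above now additionally require `supports_float4` in the backend's LinearizerOptions, presumably because the reintroduced PTX backend does not report float4 support. A hypothetical helper (not part of this PR) that captures the guard both tests repeat:

```python
# Sketch only: the helper name is made up; the LinearizerOptions fields
# (has_local, has_shared, supports_float4) are the ones used in the diff above.
from tinygrad import Device

def _supports_grouped_stores() -> bool:
  opts = Device[Device.DEFAULT].compiler.linearizer_opts
  return opts.has_local and opts.has_shared and opts.supports_float4
```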
6 changes: 4 additions & 2 deletions test/test_linearizer_overflows.py
@@ -1,6 +1,7 @@
# ruff: noqa: E501
import unittest
- from tinygrad import dtypes, Device
+ from tinygrad import dtypes
+ from tinygrad.helpers import CI
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import Opt, OptOps
from tinygrad.features.search import time_linearizer, bufs_from_lin
@@ -63,7 +64,8 @@ def test_overflow_7(self):
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)]
_test_overflow(ast, opts)

- @unittest.skipIf(Device.DEFAULT not in {"GPU", "HIP", "HSA", "CUDA", "METAL"}, "only backends with locals")
+ #@unittest.skipIf(Device.DEFAULT not in {"GPU", "HIP", "HSA", "CUDA", "METAL"}, "only backends with locals")
+ @unittest.skipIf(CI, "slow")
class TestLinearizerOverflowAlt(unittest.TestCase):
def test_overflow_1(self):
BS = 2
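Here the backend-specific skip is traded for a CI-only one: these overflow tests now run on any local backend but are skipped in CI because they are slow. For reference, `CI` comes from tinygrad.helpers; its definition is assumed below (recalled, not copied from this PR):

```python
# Assumed shape of tinygrad.helpers.CI: a module-level flag that is truthy only
# when the CI environment variable is set, i.e. inside continuous-integration runs.
import os
CI = os.getenv("CI", "") != ""
```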
2 changes: 0 additions & 2 deletions test/test_symbolic_jit.py
@@ -2,12 +2,10 @@

from test.helpers import assert_jit_cache_len
from tinygrad.features.jit import TinyJit
- from tinygrad.helpers import getenv
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
import numpy as np

- @unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
class TestSymbolicJit(unittest.TestCase):
def test_plus1(self):
def f(a): return (a+1).realize()
1 change: 0 additions & 1 deletion test/test_symbolic_ops.py
@@ -5,7 +5,6 @@
from examples.gpt2 import Attention
import numpy as np

- @unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
class TestSymbolicOps(unittest.TestCase):
def test_plus1(self):
def f(a): return (a+1).realize()
56 changes: 40 additions & 16 deletions test/test_uops.py
@@ -2,12 +2,11 @@
import unittest, math
import numpy as np
from tinygrad.dtype import dtypes, DType, PtrDType
- from tinygrad.helpers import getenv
from tinygrad.device import Buffer, Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.device import CompiledASTRunner, Compiled
from tinygrad.codegen.linearizer import UOps, UOp
- from tinygrad.codegen.uops import exec_alu
+ from tinygrad.codegen.uops import exec_alu, UOpGraph
from test.test_dtype import is_dtype_supported

def _uops_to_prg(uops):
@@ -29,7 +28,7 @@ def _test_single_value(vals, op, dts):
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype)
buf2 = [Buffer(Device.DEFAULT, 1, dtype).copyin(np.array([a], dtype=dtype.np).data) for a,dtype in zip(vals, dts)]
- prg = _uops_to_prg(uops)
+ prg = _uops_to_prg(UOpGraph(uops))
prg.exec([buf]+buf2)
ret = np.empty(1, output_dtype.np)
buf.copyout(ret.data)
@@ -43,7 +42,7 @@ def _test_single_value_const(vals, op, dts):
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype)
- prg = _uops_to_prg(uops)
+ prg = _uops_to_prg(UOpGraph(uops))
prg.exec([buf])
ret = np.empty(1, output_dtype.np)
buf.copyout(ret.data)
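Both helpers now wrap the raw uop list in a UOpGraph before building the program. A minimal sketch of the two construction paths the new optional argument allows (see the __init__ change in tinygrad/codegen/uops.py further down):

```python
# Sketch under the assumption that UOp and UOpGraph are imported as in this test file.
from typing import List
from tinygrad.codegen.linearizer import UOp
from tinygrad.codegen.uops import UOpGraph

uops: List[UOp] = []                   # stand-in for a list built with the uop() helper above
g_empty = UOpGraph()                   # previous behaviour: start from an empty uop list
g_wrapped = UOpGraph(start_uops=uops)  # new: adopt an already-built List[UOp],
                                       # matching prg = _uops_to_prg(UOpGraph(uops)) above
```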
@@ -88,26 +87,51 @@ def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a<b)
# MOD isn't tested on floats

def test_where(self):
- self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (PtrDType(dtypes.bool), PtrDType(dtypes.float), PtrDType(dtypes.float)))
+ self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float, dtypes.float))

# TODO: fix this on all the backends
- @unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or getenv('ARM64', False), "only test for compiled backends, broken on some")
class TestNonFloatUOps(TestUOps):
- def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, (PtrDType(dtypes.int32), ))
- def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), (PtrDType(dtypes.int32), PtrDType(dtypes.int32)))
- def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), (PtrDType(dtypes.int32), PtrDType(dtypes.int32)))
- def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), (PtrDType(dtypes.int32), PtrDType(dtypes.int32)))
+ def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, (dtypes.int32, ))
+ def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), (dtypes.int32, dtypes.int32))
+ def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), (dtypes.int32, dtypes.int32))
+ def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), (dtypes.int32, dtypes.int32))
def test_div_int32(self):
- self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), (PtrDType(dtypes.int32), PtrDType(dtypes.int32)), no_b_zero=True)
+ self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
def test_mod_int32(self):
self._test_bop_fxn(BinaryOps.MOD,
- lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (PtrDType(dtypes.int32), PtrDType(dtypes.int32)), no_b_zero=True)
- def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), (PtrDType(dtypes.int32), PtrDType(dtypes.int32)))
+ lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (dtypes.int32, dtypes.int32), no_b_zero=True)
+ def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), (dtypes.int32, dtypes.int32))
@unittest.skipUnless(is_dtype_supported(dtypes.bool), "dtype not supported")
- def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), (PtrDType(dtypes.bool), PtrDType(dtypes.bool)))
+ def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), (dtypes.bool, dtypes.bool))
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "dtype not supported")
def test_where_float16(self):
- self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (PtrDType(dtypes.bool), PtrDType(dtypes.float16), PtrDType(dtypes.float16)))
+ self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float16, dtypes.float16))

+ class TestBoolUOps(TestUOps):
+   def _test_uop_bool_fxn(self, op, fxn):
+     for f in [_test_single_value, _test_single_value_const]:
+       for a in [False, True]:
+         self._equal(f([a], op, (dtypes.bool, )*1), fxn(a))
+ 
+   def _test_bop_bool_fxn(self, op, fxn):
+     for f in [_test_single_value, _test_single_value_const]:
+       for a in [False, True]:
+         for b in [False, True]:
+           self._equal(f([a,b], op, (dtypes.bool, )*2), fxn(a,b))
+ 
+   def _test_top_bool_fxn(self, op, fxn):
+     for f in [_test_single_value, _test_single_value_const]:
+       for a in [False, True]:
+         for b in [False, True]:
+           for c in [False, True]:
+             self._equal(f([a,b,c], op, (dtypes.bool, )*3), fxn(a,b,c))
+ 
+   def test_not_bool(self): self._test_uop_bool_fxn(UnaryOps.NEG, lambda a: not a)
+   def test_add_bool(self): self._test_bop_bool_fxn(BinaryOps.ADD, lambda a,b: a or b)
+   def test_mul_bool(self): self._test_bop_bool_fxn(BinaryOps.MUL, lambda a,b: a and b)
+   def test_xor_bool(self): self._test_bop_bool_fxn(BinaryOps.XOR, lambda a,b: a != b)
+   def test_cmpeq_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPEQ, lambda a,b: a == b)
+   def test_cmplt_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPLT, lambda a,b: a < b)
+   def test_where_bool(self): self._test_top_bool_fxn(TernaryOps.WHERE, lambda a,b,c: b if a else c)
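The new TestBoolUOps class pins down boolean semantics via plain-Python reference lambdas: ADD behaves as `or`, MUL as `and`, XOR as `!=`, CMPEQ as `==`, CMPLT as `<`, and NEG as logical `not`. A quick standalone sanity check of the same truth tables, for readers who want them spelled out:

```python
# No tinygrad imports needed; this only restates the reference semantics used by
# TestBoolUOps above in terms of Python's bitwise operators on bools.
for a in (False, True):
  for b in (False, True):
    assert (a or b) == (a | b)    # BinaryOps.ADD on bools
    assert (a and b) == (a & b)   # BinaryOps.MUL on bools
    assert (a != b) == (a ^ b)    # BinaryOps.XOR on bools
```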

class TestExecALU(TestUOps):
def test_sqrt(self):
7 changes: 4 additions & 3 deletions tinygrad/codegen/uops.py
@@ -56,9 +56,9 @@ def uop_alu_resolve(u:UOp) -> sint:
def phi_resolve_acc(u:UOp) -> UOp: return u if u.uop is UOps.DEFINE_ACC else phi_resolve_acc(u.vin[0])

class UOpGraph:
- def __init__(self):
+ def __init__(self, start_uops:Optional[List[UOp]]=None):
# list of uops
- self.uops: List[UOp] = []
+ self.uops: List[UOp] = [] if start_uops is None else start_uops

# global uop cache
self.saved_exprs: Dict[Tuple, UOp] = dict()
@@ -88,7 +88,8 @@ def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(),
if arg is BinaryOps.ADD and vin[1].uop is UOps.ALU and vin[1].arg is UnaryOps.NEG:
return self.add(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable, insert_before)
# constant folding
- if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST: return self.add(UOps.CONST, dtype, arg=-vin[0].arg, insert_before=insert_before)
+ if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST:
+   return self.add(UOps.CONST, dtype, arg=-vin[0].arg if dtype != dtypes.bool else not vin[0].arg, insert_before=insert_before)
if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
if arg is TernaryOps.WHERE and vin[0].uop is UOps.CONST: return vin[1] if vin[0].arg else vin[2]
if arg is BinaryOps.MUL and vin[0].uop is UOps.CONST and vin[1].uop is UOps.CONST and dtype is not None and dtypes.is_float(dtype):
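The added dtype check in the NEG constant fold matters because arithmetic negation of a Python bool yields an int, while folding NEG on a bool constant should behave like logical not (consistent with test_not_bool in test/test_uops.py above):

```python
# Plain-Python illustration of why the bool case is special-cased in the fold.
assert -True == -1            # arithmetic negation: not a valid bool constant value
assert (not True) is False    # logical negation: what the bool branch now produces
```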
2 changes: 1 addition & 1 deletion tinygrad/device.py
@@ -238,7 +238,7 @@ def to_program(self, k:Linearizer) -> CompiledASTRunner:
ops, mem = k.uops.flops_mem()
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
# NOTE: we use min here to ignore the indexing FLOPS
- ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops.uops), self, k.global_size, k.local_size,
+ ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size,
k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
return ret
