
symbolic codegen and exec #1552

Merged

geohot merged 5 commits into tinygrad:master from chenyuxyz:jit-llama-3 on Aug 16, 2023

Conversation

chenyuxyz (Collaborator)

Part of #1353: codegen and exec to implement realize for symbolic inputs.

The combined var_vals are passed into the kernel function directly. I have implemented the backends for CLANG, GPU, and METAL. global_size, local_size, and op_estimate are evaluated to int during exec. In addition to CI, I also tested with DEBUG=4 python -m pytest -rA test/test_symbolic_ops.py to make sure all debugging info renders with symbols.

Most of the hand-coded optimizations work; I had to disable the upcast ones because we cannot upcast a symbolic axis.

Also added some test cases with two variables, to keep the implementation as general as possible.
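
For context, a minimal sketch of the kind of usage the new tests exercise (assuming the Variable/reshape pattern from test/test_symbolic_ops.py; the exact calls are illustrative, not verbatim from the PR):

import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable

vi = Variable("i", 1, 10)            # symbolic dim with valid range [1, 10]
for i in [2, 4, 7]:                  # one compiled kernel serves every concrete i
  a, b = Tensor.rand(3, i), Tensor.rand(i, 5)
  symbolic = (a.reshape(3, vi) @ b.reshape(vi, 5)).numpy()
  np.testing.assert_allclose(symbolic, (a @ b).numpy(), atol=1e-5)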

These two examples show that we need to support a symbolic for-loop max and a symbolic offset:

kernel for matmul (3, vi) @ (vi, 5)

#include <metal_stdlib>
using namespace metal;
kernel void r_5_3_s0(device float* data0, const device float* data1, const device float* data2, constant int& i, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
{ int gidx0 = gid.x;  /* 5 */
  { int lidx1 = lid.x;  /* 3 */
    float acc0_0 = 0.0f;
    for (int ridx2 = 0; ridx2 <= (-1+i); ++ridx2) {
      float val1_0 = *(data1+(lidx1*i)+ridx2);
      float val2_0 = *(data2+(ridx2*5)+gidx0);
      acc0_0 = ((val1_0*val2_0)+acc0_0);
    } /* reduce */
    *(data0+(lidx1*5)+gidx0) = acc0_0;
  }} /* global+local */
}

kernel for (3, vi).cat((3, vj), dim=1)

#include <metal_stdlib>
using namespace metal;
kernel void E_s0_3n14(device float* data0, const device float* data1, const device float* data2, constant int& i, constant int& j, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
{ int gidx0 = gid.x;  /* <(i[1-10]+j[1-10])> */
  { int lidx1 = lid.x;  /* 3 */
    float val1_0 = ((gidx0<i))?(*(data1+(lidx1*i)+gidx0)):0.0f;
    float val2_0 = (((gidx0*-1)<((i*-1)+1)))?(*(data2+(i*-1)+(lidx1*j)+gidx0)):0.0f;
    float alu0 = (val1_0+val2_0);
    *(data0+(lidx1*(i+j))+gidx0) = alu0;
  }} /* global+local */
}
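
A quick illustrative check of the two symbolic constructs above (example values, not code from the PR): the matmul loop bound ridx2 <= (-1+i) yields exactly i iterations, and in the cat kernel a column gidx0 >= i reads column gidx0 - i of the second tensor, which is where the (i*-1) term in the data2 index comes from.

i, j, lidx1, gidx0 = 4, 6, 2, 7                # example bindings with gidx0 >= i
assert len(range(0, (-1 + i) + 1)) == i        # loop max -1+i gives i iterations
data2_index = (i * -1) + (lidx1 * j) + gidx0
assert data2_index == lidx1 * j + (gidx0 - i)  # row lidx1, column gidx0 - i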

tinyb0t commented Aug 16, 2023

Changes made in tinygrad/:

------------------------------------------------------------
files                             insertions       deletions
------------------------------------------------------------
tinygrad/codegen/linearizer.py             6               3
tinygrad/codegen/optimizer.py              4               4
tinygrad/helpers.py                        6               1
tinygrad/ops.py                           21              14
tinygrad/renderer/cstyle.py                5               3
tinygrad/renderer/wgsl.py                  1               1
tinygrad/runtime/ops_clang.py              2               2
tinygrad/runtime/ops_gpu.py                2               2
tinygrad/runtime/ops_metal.py              6               3
tinygrad/shape/shapetracker.py             3               3
tinygrad/shape/symbolic.py                 9               0
------------------------------------------------------------
total                                     65              36
------------------------------------------------------------
lines added in the tinygrad folder: 29

@@ -263,10 +265,10 @@ jobs:
   run: python -c "from tinygrad.lazy import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU'], Device.DEFAULT"
 - name: Run pytest (not cuda)
   if: matrix.backend!='cuda'
-  run: python -m pytest -n=auto test/ -k '${{matrix.backend=='llvm'&&'not (test_nn.py and test_conv_transpose2d)'||'test'}}' -m 'not exclude_${{matrix.backend}}'
+  run: CI=1 python -m pytest -n=auto test/ -k '${{matrix.backend=='llvm'&&'not (test_nn.py and test_conv_transpose2d)'||'test'}}' -m 'not exclude_${{matrix.backend}}'
geohot (Collaborator) commented Aug 16, 2023:

You shouldn't need this; CI should be set automatically by CI.

chenyuxyz (Author):

got it - removed

@@ -106,5 +106,17 @@ def test_context_exit_reverts_updated_values(self):
     ...
     assert D.value == 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value."

+class TestMergeDcits(unittest.TestCase):
Reviewer (Collaborator):

typo

chenyuxyz (Author):

fixed

@@ -22,6 +22,10 @@ def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (
 def flatten(l:Iterator): return [item for sublist in l for item in sublist]
 def mnum(i) -> str: return str(i) if i >= 0 else f"m{-i}"
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
+def merge_dicts(ds):
Reviewer (Collaborator):

Is there a reason you can't use dict1.update(dict2) to merge dicts?

chenyuxyz (Author):

update does not check for value conflicts; I added a test case that (3, vi) @ (vi, 5) should fail if the two operands hold different values for vi.
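
A minimal sketch of the conflict-checking merge described here (illustrative only; the helper actually added to tinygrad/helpers.py may differ in detail):

from typing import Dict, Iterable, TypeVar

K = TypeVar("K")
V = TypeVar("V")

def merge_dicts(ds: Iterable[Dict[K, V]]) -> Dict[K, V]:
  # dict.update silently lets the later value win; assert instead, so a
  # Variable bound to two different runtime values is caught early.
  out: Dict[K, V] = {}
  for d in ds:
    for k, v in d.items():
      assert k not in out or out[k] == v, f"conflicting values for key {k}"
      out[k] = v
  return out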

Reviewer (Collaborator):

Can you add a type signature to this function?

chenyuxyz (Author):

done

geohot (Collaborator) commented Aug 16, 2023:

Why is it constant int& and not just int?

chenyuxyz marked this pull request as draft on August 16, 2023 at 19:53
chenyuxyz (Author):

int does not work:

E     AssertionError: Error Domain=MTLLibraryErrorDomain Code=3 "program_source:3:98: error: invalid type 'int' for input declaration in a kernel function
E     kernel void r_s0_s3_3(device float* data0, const device float* data1, const device float* data2, int i, int j, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
E                                                                                                      ^~~~~
E     program_source:3:105: error: invalid type 'int' for input declaration in a kernel function

I took this from tracing how PyFR implemented it for Metal.

My understanding from section 4 of the Metal spec is that a kernel argument needs an address space attribute, and only pointer and reference types can carry one.
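
For illustration, a sketch of how the runtime might bind such a scalar (assuming the PyObjC Metal bindings used by tinygrad/runtime/ops_metal.py; the PR's exact call sites may differ):

import ctypes

def set_int_arg(encoder, index: int, value: int):
  # MTLComputeCommandEncoder's setBytes:length:atIndex: copies the value
  # into the command stream; the kernel sees it as a `constant int&`.
  v = ctypes.c_int32(value)
  encoder.setBytes_length_atIndex_(bytes(v), ctypes.sizeof(v), index)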

chenyuxyz marked this pull request as ready for review on August 16, 2023 at 20:27
geohot (Collaborator) commented Aug 16, 2023:

So this isn't true for OpenCL... but maybe it's fine. PyOpenCL has set_scalar_arg_dtypes to avoid the np.int32 wrapping on every call, which is really slow.
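
For reference, the PyOpenCL pattern being referred to (per PyOpenCL's documented Kernel API; this snippet is illustrative and not from the PR):

import numpy as np
import pyopencl as cl

def build_symbolic_kernel(ctx: cl.Context, src: str):
  # None marks buffer arguments, a dtype marks scalars; once declared,
  # plain Python ints can be passed without wrapping each in np.int32.
  kernel = cl.Program(ctx, src).build().r_5_3_s0
  kernel.set_scalar_arg_dtypes([None, None, None, np.int32])
  return kernel

# usage: kernel(queue, global_size, local_size, out_buf, a_buf, b_buf, i_val)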

@@ -162,7 +162,7 @@ def real_strides(self, ignore_valid=False) -> Tuple[Optional[Union[Node, int]],
     idx, valid = self.expr_idxs(idxs)
     ret: List[Optional[Union[Node, int]]] = [None] * len(self.views[-1].shape)
     for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
-      if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable):
+      if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and str(this_dim.a.expr).startswith("idx"):
Reviewer (Collaborator):

Using a string compare here looks sketchy. Is there a better way to write this?

chenyuxyz (Author):

good call - updated

return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"

def render_conditional(self, cond: str, x:str, y:str) -> str:
return f"({cond})?({x}):{y}"

def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
buftypes = [(name[8:], self.arg_int_prefix) if name.startswith("ARG_INT_") else
Reviewer (Collaborator):

String stuff is always sketchy; is there a better way to do this?

chenyuxyz (Author):

added an internal dtype for the int argument and used that instead
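
A sketch of that dtype-tagging idea (names here are hypothetical; the real change lives in tinygrad's dtype and renderer code):

from dataclasses import dataclass

@dataclass(frozen=True)
class DType:
  name: str
  itemsize: int

int_arg = DType("int_arg", 4)  # hypothetical sentinel dtype for scalar int args

def render_arg(name: str, dtype: DType) -> str:
  # an equality check on the dtype replaces name.startswith("ARG_INT_")
  if dtype == int_arg:
    return f"const int {name}"   # Metal would render this as constant int&
  return f"device float* {name}"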

geohot (Collaborator) commented Aug 16, 2023:

It's close, but I don't like the string compare stuff, particularly the one in shapetracker.

chenyuxyz marked this pull request as draft on August 16, 2023 at 20:51

chenyuxyz (Author):

Re: PyOpenCL and set_scalar_arg_dtypes: I suspect that under the hood it does a similar thing even with np.int32. I profiled the llama code in the main PR, and the profiler indicated that a lot of time was spent in scalar-arg-related internal functions. The generation wall time is lower, so I am not sure whether it's a PyOpenCL issue or something else.

chenyuxyz marked this pull request as ready for review on August 16, 2023 at 21:40

geohot merged commit 11dd9b1 into tinygrad:master on Aug 16, 2023
12 checks passed
geohot (Collaborator) commented Aug 16, 2023:

Cool, merged!

chenyuxyz deleted the jit-llama-3 branch on September 1, 2023.