tl.dot() input dimensions #2266

Open · var-nan opened this issue Sep 7, 2023 · 2 comments

var-nan commented Sep 7, 2023

Must all input dimensions be at least 16 for the tl.dot() operation?

I'm trying to compute tl.dot(a, b) between a matrix of shape [16, 16] and a column vector of shape [16, 1], and I get the following compilation error:
AssertionError('All values in both first input shape ([constexpr[16], constexpr[16]]) and second input shape ([constexpr[16], constexpr[1]]) must be >= 16!')

I'm not sure whether I'm doing something wrong or whether Triton simply doesn't allow dot with a vector-shaped ([N, 1]) operand.

Below is the code for reference.

import triton
import triton.language as tl

@triton.jit
def kernel_opt(A_in,
               A_out,
               kernel, kernel_size_sq: tl.constexpr,
               BLOCK_SIZE: tl.constexpr):
    """ A_in (shape: [M,N]), A_out (shape: [M',1]), kernel (shape: [N,1]) """

    pid = tl.program_id(0)

    stride_x = kernel_size_sq
    offsets_x = tl.arange(0, kernel_size_sq)
    offsets_y = tl.arange(0, BLOCK_SIZE)[:, None] * stride_x

    # pointers into this program's block of A_in and into the kernel vector
    data_pointers = A_in + pid * kernel_size_sq * BLOCK_SIZE + offsets_y + offsets_x
    kernel_pointers = kernel + offsets_x[None, :]  # was `kernel_pointer` (undefined name)

    # load
    load_mat = tl.load(data_pointers)       # [BLOCK_SIZE, kernel_size_sq]
    load_kernel = tl.load(kernel_pointers)  # [1, kernel_size_sq]

    # compute: this tl.dot call trips the ">= 16" shape assertion
    result = tl.dot(load_mat, tl.trans(load_kernel))  # result should be a column vector

    # store result (offsets must run over BLOCK_SIZE rows, not kernel_size_sq)
    out_pointers = A_out + pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None]
    tl.store(out_pointers, result)

CHDev93 commented Sep 12, 2023

I've experienced the same thing. Triton does place certain constraints on the inputs to its functions; large-ish blocks and power-of-2 dimensions are the ones I regularly run into. I think the intent is that it's more efficient to tile operations like matrix multiplication into blocks rather than to do matrix-vector operations (see the L2 cache optimisations section of the matmul tutorial).
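
For illustration, here is a minimal sketch of one way to satisfy the constraint: zero-pad the vector into a 16-wide tile so both tl.dot operands meet the >= 16 requirement. All names and shapes below are illustrative assumptions (row-major A of shape [M, K], with K and BLOCK both powers of two and >= 16), not code from this issue.

import triton
import triton.language as tl

@triton.jit
def matvec_padded(A, x, out, K: tl.constexpr, BLOCK: tl.constexpr):
    # one program handles BLOCK rows of the row-major [M, K] matrix A
    pid = tl.program_id(0)
    rows = pid * BLOCK + tl.arange(0, BLOCK)
    cols = tl.arange(0, K)
    a = tl.load(A + rows[:, None] * K + cols[None, :])  # [BLOCK, K]
    v = tl.load(x + cols)                               # [K]
    # zero-pad the vector into a [K, 16] tile: only column 0 is non-zero,
    # so every tl.dot operand dimension is >= 16
    pad = tl.arange(0, 16)
    v_tile = tl.where(pad[None, :] == 0, v[:, None], 0.0)  # [K, 16]
    acc = tl.dot(a, v_tile)                                 # [BLOCK, 16]
    # columns 1..15 of acc are all zero, so the row-wise sum
    # just extracts column 0, i.e. the matvec result
    tl.store(out + rows, tl.sum(acc, axis=1))

Since the padded columns are zero they contribute nothing to the result; the cost is 16x redundant work in the dot, which may still win if it keeps the op on tensor cores.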

jon-chuang (Contributor) commented Sep 14, 2023

It seems reasonable that in this case Triton would fall back to ordinary FMA instructions rather than mma instructions; this, however, appears to be unimplemented. The list of supported sizes is here.

To my understanding, Triton also doesn't optimize other dot-performance edge cases, for instance tall-and-skinny matmuls.
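
In the meantime, a common way to sidestep the assertion entirely is to write the matrix-vector product as an elementwise multiply plus a tl.sum reduction, which lowers to plain FMA instructions and carries no minimum-size requirement. A minimal sketch under the same illustrative assumptions as above (hypothetical names, row-major [M, K] input):

import triton
import triton.language as tl

@triton.jit
def matvec_fma(A, x, out, K: tl.constexpr, BLOCK: tl.constexpr):
    # one program handles BLOCK rows of the row-major [M, K] matrix A
    pid = tl.program_id(0)
    rows = pid * BLOCK + tl.arange(0, BLOCK)
    cols = tl.arange(0, K)
    a = tl.load(A + rows[:, None] * K + cols[None, :])  # [BLOCK, K]
    v = tl.load(x + cols)                               # [K]
    # broadcast multiply + row-wise reduction instead of tl.dot
    res = tl.sum(a * v[None, :], axis=1)                # [BLOCK]
    tl.store(out + rows, res)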
