How to perform a store operation on a part of a Tensor? #4080

Closed
YKTian-x2b opened this issue Jun 5, 2024 · 1 comment

Comments

YKTian-x2b commented Jun 5, 2024

I want to perform a store operation on a part of a tensor, like this:

accum_outs = tl.zeros([N], dtype=tl.float32)
for col_off in range(0, N, BLOCK_SIZE):
    cols = col_off + tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    a_eles = tl.load(a_ptr + cols, mask=mask, other=0.0)
    b_eles = tl.load(b_ptr + cols, mask=mask, other=0.0)
    # How to implement the following: 
    accum_outs[col_off: col_off+BLOCK_SIZE] = a_eles * b_eles

Because "accum_outs" is used later in the kernel, I don't want to store it back to global memory and reload it, as in the following:

for col_off in range(0, N, BLOCK_SIZE):
    cols = col_off + tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    a_eles = tl.load(a_ptr + cols, mask=mask, other=0.0)
    b_eles = tl.load(b_ptr + cols, mask=mask, other=0.0)
    tl.store(accum_res_ptr + cols, a_eles * b_eles, mask=mask)

# tl.ops that must be done outside the loop

for col_off in range(0, N, BLOCK_SIZE):
    cols = col_off + tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    eles = tl.load(accum_res_ptr + cols, mask=mask, other=0.0)
    ...
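The two-pass round trip through global memory that the question wants to avoid can be emulated on the host with NumPy (a sketch with hypothetical values for `N` and `BLOCK_SIZE`; the real code runs inside a Triton kernel):

```python
import numpy as np

N, BLOCK_SIZE = 10, 4  # hypothetical sizes for illustration
a = np.arange(N, dtype=np.float32)
b = np.full(N, 2.0, dtype=np.float32)
accum_res = np.zeros(N, dtype=np.float32)  # stands in for accum_res_ptr

# Pass 1: compute each block's partial product and store it back.
for col_off in range(0, N, BLOCK_SIZE):
    cols = col_off + np.arange(BLOCK_SIZE)
    valid = cols[cols < N]          # emulates the mask on load/store
    accum_res[valid] = a[valid] * b[valid]

# Pass 2: reload the full result for the ops that need all of it at once.
assert np.allclose(accum_res, a * b)
```

The extra store and reload of `accum_res` is exactly the traffic the asker is trying to eliminate.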
thumbe3 commented Jun 5, 2024

At the moment you can't perform stores on parts of tensors, as far as I know. For your use case, you can load a and b using 2D indexing and then unravel the output, something like the code below. NUM_BLOCKS should be passed as a tl.constexpr and should be the next power of 2 of ceil(N / BLOCK_SIZE), i.e. of (N + BLOCK_SIZE - 1) // BLOCK_SIZE.

    num_blocks = tl.cdiv(N, BLOCK_SIZE)
    block_offs = tl.arange(0, NUM_BLOCKS)
    per_block_offs = tl.arange(0, BLOCK_SIZE)
    # Put blocks on axis 0 so the flattened result is contiguous in memory order
    all_offs = block_offs[:, None] * BLOCK_SIZE + per_block_offs[None, :]

    # 2d pointers of shape [NUM_BLOCKS, BLOCK_SIZE];
    # combine tensor masks with `&`, not `and`
    a = tl.load(a_ptr + all_offs,
                mask=(all_offs < N) & (block_offs[:, None] < num_blocks),
                other=0.0)
    b = tl.load(b_ptr + all_offs,
                mask=(all_offs < N) & (block_offs[:, None] < num_blocks),
                other=0.0)

    # [NUM_BLOCKS, BLOCK_SIZE] --> flat [NUM_BLOCKS * BLOCK_SIZE] vector whose
    # first N elements are filled; the rest are 0
    accum_outs = tl.ravel(a * b)
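The indexing logic of this padded 2D-load-and-ravel trick can be checked on the host with NumPy (a sketch with hypothetical values N=10, BLOCK_SIZE=4, so NUM_BLOCKS is padded from ceil(10/4)=3 up to 4; the masked load with `other=0.0` is emulated with `np.where`):

```python
import numpy as np

N, BLOCK_SIZE = 10, 4
NUM_BLOCKS = 4  # next power of 2 >= ceil(N / BLOCK_SIZE) = 3

a = np.arange(N, dtype=np.float32)
b = np.full(N, 2.0, dtype=np.float32)

# Offsets of shape [NUM_BLOCKS, BLOCK_SIZE], mirroring the tl.arange broadcasting
block_offs = np.arange(NUM_BLOCKS)
per_block_offs = np.arange(BLOCK_SIZE)
all_offs = block_offs[:, None] * BLOCK_SIZE + per_block_offs[None, :]

# Masked load with other=0.0: out-of-range lanes read 0
# (offsets are clipped only so the host-side gather stays in bounds)
mask = all_offs < N
a2d = np.where(mask, a[np.minimum(all_offs, N - 1)], 0.0)
b2d = np.where(mask, b[np.minimum(all_offs, N - 1)], 0.0)

# Row-major ravel is contiguous because blocks sit on axis 0
accum_outs = (a2d * b2d).ravel()
assert np.allclose(accum_outs[:N], a * b)   # first N entries hold a*b
assert np.all(accum_outs[N:] == 0.0)        # padded tail is zero
```

Because the padded lanes read 0.0, the trailing `NUM_BLOCKS * BLOCK_SIZE - N` entries of `accum_outs` are zero, which is safe for sums but must be accounted for in reductions like `min` or `mean`.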
