In [2]:
import triton
import torch
import triton.language as tl

  from .autonotebook import tqdm as notebook_tqdm


In [26]:
@triton.jit
def _rand2d(
    randval_ptr,
    randmask_ptr,
    stride_m, stride_n,
    M, N,
    p, seed,
    UNROLL: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)  
    pid_m = pid*UNROLL

    offs_row = tl.arange(0, UNROLL)
    offs_col = tl.arange(0, BLOCK_SIZE)
    rand_offs = offs_row[:,None]*BLOCK_SIZE + offs_col[None, :]
    randval_ptrs = randval_ptr+(pid_m+offs_row[:, None])*stride_m + offs_col[None, :]*stride_n
    randmask_ptrs = randmask_ptr+(pid_m+offs_row[:, None])*stride_m + offs_col[None, :]*stride_n
    mask = ((pid_m+offs_row[:, None])<M) & (offs_col[None, :]<N)

    rand = tl.rand(seed+pid, rand_offs)
    
    tl.store(randval_ptrs, rand, mask=mask)
    tl.store(randmask_ptrs, rand > p, mask=mask)

In [71]:
def rand2d(
        M, N,
        p, seed,
):
    randval = torch.zeros((M, N), dtype=torch.float32, device='cuda')
    randmask = torch.zeros((M, N), dtype=torch.bool, device='cuda')
    assert randval.is_cuda and randmask.is_cuda
    BLOCK_SIZE = triton.next_power_of_2(N)
#     BLOCK_SIZE = N
    UNROLL = 4
    grid = lambda meta: (triton.cdiv(M, UNROLL),)
    _rand2d[grid](randval, randmask, 
                  randval.stride(0), randval.stride(1), 
                  M, N,
                  p, seed, 
                  UNROLL,
                  BLOCK_SIZE)
    print(randval)
    print(randmask)

In [72]:
rand2d(M=2000, N=8024, p=0.5, seed=0)

tensor([[0.7981, 0.0555, 0.0389,  ..., 0.0133, 0.9774, 0.8623],
        [0.7012, 0.0307, 0.9815,  ..., 0.4536, 0.1290, 0.4373],
        [0.4705, 0.4544, 0.1348,  ..., 0.3655, 0.1412, 0.5890],
        ...,
        [0.2110, 0.1895, 0.1564,  ..., 0.8951, 0.9303, 0.3100],
        [0.2296, 0.4710, 0.4035,  ..., 0.5152, 0.1921, 0.2477],
        [0.3409, 0.1286, 0.5049,  ..., 0.9492, 0.3556, 0.2647]],
       device='cuda:0')
tensor([[ True, False, False,  ..., False,  True,  True],
        [ True, False,  True,  ..., False, False, False],
        [False, False, False,  ..., False, False,  True],
        ...,
        [False, False, False,  ...,  True,  True, False],
        [False, False, False,  ...,  True, False, False],
        [False, False,  True,  ...,  True, False, False]], device='cuda:0')


In [76]:
@triton.jit
def _dropout(
    input_ptr,
    output_ptr,
    drop_mask_ptr,
    stride_m, stride_n,
    # drop_mask_group_size: tl.constexpr,
    M: tl.constexpr, N: tl.constexpr,
    p: tl.constexpr, seed: tl.constexpr,
    UNROLL: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    drop_mask_group_size: tl.constexpr
):
    pid = tl.program_id(0)  
    pid_m = pid*UNROLL

    offs_row = tl.arange(0, UNROLL)
    offs_col = tl.arange(0, BLOCK_SIZE)
    
    # block mask
    rowscols_mask = ((pid_m+offs_row[:, None])<M) & (offs_col[None, :]<N)
    input_ptrs = input_ptr + (pid_m+offs_row[:, None])*stride_m + offs_col[None, :]*stride_n
    inputdata = tl.load(input_ptrs, mask=rowscols_mask)

    # generate rand and decide which to drop
    # rand_offs = offs_row[:,None]*BLOCK_SIZE + offs_col[None, :]
    rand_offs = tl.arange(0,UNROLL)*N + tl.arange(0, N)
    rand_mask = tl.rand(seed+pid, rand_offs)>p
    # unsupported slice on constpr tensor
    # rand_mask = rand_mask[:UNROLL, :N]
    #
    output_ptrs = output_ptr + (pid_m+offs_row[:, None])*stride_m + offs_col[None, :]*stride_n
    output = tl.where(rand_mask, inputdata/(1-p), 0.0)

    # compress the rand_mask to shared_drop_mask
    shared_drop_mask = tl.zeros((1,drop_mask_group_size), dtype=tl.uint8)
    
    rand_mask = rand_mask.to(tl.uint8)
    rand_mask = tl.ravel(rand_mask)
    append_0_num = tl.cdiv(N*UNROLL, 8)*8-N*UNROLL
    rand_mask = tl.cat((tl.view(rand_mask, (1, -1)), tl.zeros((1,append_0_num), dtype=tl.uint8)), dim=1) 
    rand_mask = tl.view(rand_mask, (-1, 8))
    weights = (2**tl.arange(0,8))[None, :]
    rand_mask = rand_mask * weights
    rand_mask = tl.sum(rand_mask, axis=1)
    # drop_mask_group_size = tl.cdiv(N*UNROLL, 8)
    
    drop_mask_ptrs = drop_mask_ptr + pid*drop_mask_group_size + tl.arange(0, drop_mask_group_size)

    tl.store(drop_mask_ptrs, rand_mask)
    tl.store(output_ptrs, output, mask=rowscols_mask)

In [115]:
@triton.jit
def _dropout(
    input_ptr,
    output_ptr,
    drop_mask_ptr,
    stride_m, stride_n,
    # drop_mask_group_size: tl.constexpr,
    M: tl.constexpr, N: tl.constexpr,
    p: tl.constexpr, seed: tl.constexpr,
    UNROLL: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    drop_mask_group_size: tl.constexpr
):
    pid = tl.program_id(0)  
    pid_m = pid*UNROLL

    offs_row = tl.arange(0, UNROLL)
    offs_col = tl.arange(0, BLOCK_SIZE)
    
    # block mask
    rowscols_mask = ((pid_m+offs_row[:, None])<M) & (offs_col[None, :]<N)
    input_ptrs = input_ptr + (pid_m+offs_row[:, None])*stride_m + offs_col[None, :]*stride_n
    inputdata = tl.load(input_ptrs, mask=rowscols_mask)

    # generate rand and decide which to drop
    # rand_offs = offs_row[:,None]*BLOCK_SIZE + offs_col[None, :]
    rand_offs = tl.arange(0,UNROLL)*N + tl.arange(0, N)
    rand_mask = tl.rand(seed+pid, rand_offs)>p
    # tl.device_print('mask', rand_mask.dtype)
    rand_mask = rand_mask.to(tl.uint8)
    # tl.device_print(rand_mask)
    rand_mask = tl.view(rand_mask, (1,-1))
    # unsupported slice on constpr tensor
    # rand_mask = rand_mask[:UNROLL, :N]

In [107]:
def dropout(
        inputdata,
        p, seed,
):
    output = torch.zeros_like(inputdata)
    M, N = output.shape
    UNROLL = 4
    drop_mask_group_size = triton.cdiv(N*UNROLL, 8)
    drop_mask = torch.zeros(triton.cdiv(M, UNROLL)*drop_mask_group_size, dtype=torch.uint8, device='cuda')
    assert inputdata.is_cuda and output.is_cuda and drop_mask.is_cuda

    BLOCK_SIZE = triton.next_power_of_2(N)
    UNROLL = 4
    grid = lambda meta: (triton.cdiv(M, UNROLL),)
    _dropout[grid](
        inputdata, output, drop_mask,
        inputdata.stride(0), inputdata.stride(1),
        M,N,
        p, seed,
        UNROLL,
        BLOCK_SIZE,
        drop_mask_group_size=drop_mask_group_size
    )

In [116]:
M, N = 5, 4
inputdata = torch.rand((M, N), dtype=torch.float32, device='cuda')
dropout(inputdata, p=0.5, seed=0)

CompilationError: at 31:39:    rowscols_mask = ((pid_m+offs_row[:, None])<M) & (offs_col[None, :]<N)
    input_ptrs = input_ptr + (pid_m+offs_row[:, None])*stride_m + offs_col[None, :]*stride_n
    inputdata = tl.load(input_ptrs, mask=rowscols_mask)

    # generate rand and decide which to drop
    # rand_offs = offs_row[:,None]*BLOCK_SIZE + offs_col[None, :]
    rand_offs = tl.arange(0,UNROLL)*N + tl.arange(0, N)
    rand_mask = tl.rand(seed+pid, rand_offs)>p
    # tl.device_print('mask', rand_mask.dtype)
    rand_mask = rand_mask.to(tl.uint8)
    # tl.device_print(rand_mask)
    rand_mask = tl.view(rand_mask, (1,-1))
                                       ^
ValueError('cannot view block of different shape')

In [32]:
output = torch.zeros(10, dtype=torch.int8)
dataindex = torch.cat((torch.arange(10)[:, None], (torch.arange(10)*3)[:, None]), dim=1)
print(dataindex) 
print(dataindex[:, 0])
output[dataindex[:, 0]] += dataindex[:, 1]
print(output)
output[dataindex[:, 0]] += dataindex[:, 1]
print(output)

tensor([[ 0,  0],
        [ 1,  3],
        [ 2,  6],
        [ 3,  9],
        [ 4, 12],
        [ 5, 15],
        [ 6, 18],
        [ 7, 21],
        [ 8, 24],
        [ 9, 27]])
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([ 0,  3,  6,  9, 12, 15, 18, 21, 24, 27], dtype=torch.int8)
tensor([ 0,  6, 12, 18, 24, 30, 36, 42, 48, 54], dtype=torch.int8)


In [43]:
randmask = torch.rand((4,9), dtype=torch.float32)>0.5
randmask = torch.cat((randmask.reshape(-1), torch.zeros((4*9+7)//8*8-4*9, dtype=torch.bool)))
print(randmask)
print(randmask.shape)
randmask = randmask.reshape(-1, 8).int()
weights = torch.pow(2, torch.arange(8, dtype=torch.uint8))
print(weights)
print(randmask.reshape(-1, 8))
randmask = randmask.reshape(-1, 8)*weights[None, :] 
print(randmask)
print(randmask.shape)
randmask = torch.sum(randmask, dim=1)
print(randmask)
drop_mask = torch.zeros((4*9+7)//8, dtype=torch.int8)
drop_mask = randmask


tensor([ True, False,  True, False, False,  True, False, False,  True,  True,
        False, False,  True,  True,  True,  True,  True, False,  True,  True,
         True, False, False,  True,  True, False, False, False,  True, False,
         True, False, False,  True,  True,  True, False, False, False, False])
torch.Size([40])
tensor([  1,   2,   4,   8,  16,  32,  64, 128], dtype=torch.uint8)
tensor([[1, 0, 1, 0, 0, 1, 0, 0],
        [1, 1, 0, 0, 1, 1, 1, 1],
        [1, 0, 1, 1, 1, 0, 0, 1],
        [1, 0, 0, 0, 1, 0, 1, 0],
        [0, 1, 1, 1, 0, 0, 0, 0]], dtype=torch.int32)
tensor([[  1,   0,   4,   0,   0,  32,   0,   0],
        [  1,   2,   0,   0,  16,  32,  64, 128],
        [  1,   0,   4,   8,  16,   0,   0, 128],
        [  1,   0,   0,   0,  16,   0,  64,   0],
        [  0,   2,   4,   8,   0,   0,   0,   0]], dtype=torch.int32)
torch.Size([5, 8])
tensor([ 37, 243, 157,  81,  14])


In [45]:
bits = torch.pow(2, torch.arange(8, dtype=torch.uint8))
print(bits) 
print(drop_mask[:, None])
print(drop_mask[:, None]&bits>0)

tensor([  1,   2,   4,   8,  16,  32,  64, 128], dtype=torch.uint8)
tensor([[ 37],
        [243],
        [157],
        [ 81],
        [ 14]])
tensor([[ True, False,  True, False, False,  True, False, False],
        [ True,  True, False, False,  True,  True,  True,  True],
        [ True, False,  True,  True,  True, False, False,  True],
        [ True, False, False, False,  True, False,  True, False],
        [False,  True,  True,  True, False, False, False, False]])
