In [1]:
import sys
sys.path.append("/home/jovyan/shares/SR006.nfs1/sivtsov/armt")

In [2]:
import torch
import torch.nn as nn

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fdfdd544710>

In [3]:
from group_gemm import group_gemm_fn

In [4]:
hid_dim = 512
segment_size = 128
mem_size = 128

proj_dim = hid_dim

device = "cuda"

In [5]:
memory = torch.rand((mem_size, hid_dim), device=device, dtype=torch.float32)
prev_segment = torch.rand((segment_size, hid_dim), device=device, dtype=torch.float32)
segment = torch.rand((segment_size, hid_dim), device=device, dtype=torch.float32)

In [6]:
q_proj = torch.rand((hid_dim, proj_dim), device=device, dtype=torch.float32)
k_proj = torch.rand((hid_dim, proj_dim), device=device, dtype=torch.float32)
v_proj = torch.rand((hid_dim, hid_dim), device=device, dtype=torch.float32)

In [7]:
concated_ref = torch.concat([memory, segment])

In [8]:
r = concated_ref @ q_proj

In [9]:
r

tensor([[129.8166, 133.8074, 119.1817,  ..., 127.0881, 127.8358, 122.8069],
        [121.9021, 125.1182, 115.5519,  ..., 125.2113, 121.3747, 120.9387],
        [129.3303, 129.0235, 118.8091,  ..., 126.3715, 126.9246, 122.2639],
        ...,
        [126.2679, 131.9058, 116.8035,  ..., 121.5519, 122.8329, 120.3719],
        [128.5999, 129.5410, 116.9284,  ..., 125.7963, 123.7666, 121.1420],
        [122.5388, 132.4052, 121.2210,  ..., 129.2840, 129.2532, 120.6156]],
       device='cuda:0')

In [10]:
q_list = group_gemm_fn([memory, segment], [q_proj, q_proj])
k_list = group_gemm_fn([memory, segment], [k_proj, k_proj])

In [11]:
# ([0, 0, 1, 1], [0, 1, 0, 1])

q_attn_mul = [q_list[i] for i in range(len(q_list)) for j in range(len(k_list))]
k_tr_attn_mul = [k_list[j].T.contiguous() for i in range(len(q_list)) for j in range(len(k_list))]

In [12]:
[bl.shape for bl in q_attn_mul], [bl.shape for bl in k_tr_attn_mul]

([torch.Size([128, 512]),
  torch.Size([128, 512]),
  torch.Size([128, 512]),
  torch.Size([128, 512])],
 [torch.Size([512, 128]),
  torch.Size([512, 128]),
  torch.Size([512, 128]),
  torch.Size([512, 128])])

In [13]:
# rowwise attention
attn_list = group_gemm_fn(q_attn_mul, k_tr_attn_mul)

In [14]:
[bl.shape for bl in attn_list]

[torch.Size([128, 128]),
 torch.Size([128, 128]),
 torch.Size([128, 128]),
 torch.Size([128, 128])]

In [15]:
# just for aligning
attn_list = [100*a/a.mean() for a in attn_list]

In [16]:
attn_list_matr = [attn_list[i:i+2] for i in range(0, 4, 2)]

In [17]:
[[bl.shape for bl in row] for row in attn_list_matr]

[[torch.Size([128, 128]), torch.Size([128, 128])],
 [torch.Size([128, 128]), torch.Size([128, 128])]]

In [18]:
def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret

def concat_block_matr(matr):
    return torch.concat(
        [torch.concat(el, dim=1) for el in matr], 
        dim=0
    )

In [19]:
from functools import reduce
def block_softmax(x_matr):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    block_max_matr = [[bl.max(dim=1)[0] for bl in row] for row in x_matr]
    x_max_matr = [reduce(lambda x,y: torch.max(x,y), row) for row in block_max_matr]
    # read MN + M elements ; write MN elements
    z_matr = [[block-max_row[:, None] for block in row] for row, max_row in zip(x_matr, x_max_matr)]
    # read  MN elements ; write MN elements
    numerator_matr = [[torch.exp(bl) for bl in row] for row in z_matr]
    # read  MN elements ; write M  elements
    denominator_vec = [sum([bl.sum(dim=1) for bl in row]) for row in numerator_matr]
    # read MN + M elements ; write MN elements
    ret_matr = [[bl / denom_bl[:, None] for bl in row] for row, denom_bl in zip(numerator_matr, denominator_vec)]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret_matr


In [20]:
[[bl.shape for bl in row] for row in attn_list_matr]

[[torch.Size([128, 128]), torch.Size([128, 128])],
 [torch.Size([128, 128]), torch.Size([128, 128])]]

In [21]:
block_impl = block_softmax(attn_list_matr)
concat_impl = naive_softmax(concat_block_matr(attn_list_matr))

In [22]:
torch.allclose(concat_block_matr(block_impl), concat_impl, atol=1e-2, rtol=0)

True

In [23]:
concat_block_matr(block_impl), concat_impl

(tensor([[3.1463e-04, 2.9836e-06, 5.5897e-05,  ..., 6.8722e-06, 5.2913e-05,
          1.3089e-04],
         [3.8445e-04, 4.5194e-06, 7.3921e-05,  ..., 1.0048e-05, 7.0312e-05,
          1.6648e-04],
         [3.3720e-04, 3.4580e-06, 6.1579e-05,  ..., 7.8434e-06, 5.8520e-05,
          1.4254e-04],
         ...,
         [3.6734e-04, 4.1386e-06, 6.9611e-05,  ..., 9.2297e-06, 6.6116e-05,
          1.5825e-04],
         [3.3702e-04, 3.4340e-06, 6.1504e-05,  ..., 7.8204e-06, 5.8175e-05,
          1.4222e-04],
         [3.2280e-04, 3.1444e-06, 5.7867e-05,  ..., 7.2134e-06, 5.4843e-05,
          1.3492e-04]], device='cuda:0'),
 tensor([[3.1463e-04, 2.9836e-06, 5.5897e-05,  ..., 6.8722e-06, 5.2913e-05,
          1.3089e-04],
         [3.8445e-04, 4.5194e-06, 7.3921e-05,  ..., 1.0048e-05, 7.0312e-05,
          1.6648e-04],
         [3.3720e-04, 3.4580e-06, 6.1579e-05,  ..., 7.8434e-06, 5.8520e-05,
          1.4254e-04],
         ...,
         [3.6734e-04, 4.1386e-06, 6.9611e-05,  ..., 9.2297e-06

In [24]:
def multiply_matr_on_vec(matr, vec):
    assert len(matr[0]) == len(vec), "be carefull, not implemented"
    result_acc = None # [torch.zeros_like(t) for t in vec]
    
    for i in range(len(vec)):
        l = [row[i] for row in matr]
        r = [vec[i] for _ in range(len(matr))]
        
        partial = group_gemm_fn(l, r)

        if result_acc is None:
            result_acc = partial
        else:
            result_acc = [acc + partial for acc, partial in zip(result_acc, partial)]

    return result_acc

In [25]:
attn_matr = block_softmax(attn_list_matr)
attn_concat = concat_block_matr(attn_matr)

# v_projdd
v_mart = [memory, segment]
v = torch.concat(v_mart)

result_ref = attn_concat @ v

In [26]:
result_acc = multiply_matr_on_vec(attn_matr, v_mart)

In [27]:
torch.allclose(torch.concat(result_acc), result_ref, atol=1e-2, rtol=0)

True

In [28]:
result_acc

[tensor([[0.3752, 0.5069, 0.5463,  ..., 0.5801, 0.5423, 0.4103],
         [0.3845, 0.5047, 0.5441,  ..., 0.5775, 0.5420, 0.4160],
         [0.3785, 0.5061, 0.5454,  ..., 0.5792, 0.5422, 0.4122],
         ...,
         [0.3792, 0.5060, 0.5453,  ..., 0.5790, 0.5423, 0.4127],
         [0.3712, 0.5079, 0.5472,  ..., 0.5811, 0.5424, 0.4078],
         [0.3756, 0.5069, 0.5461,  ..., 0.5801, 0.5424, 0.4105]],
        device='cuda:0'),
 tensor([[0.3770, 0.5064, 0.5457,  ..., 0.5796, 0.5423, 0.4114],
         [0.3729, 0.5074, 0.5468,  ..., 0.5807, 0.5424, 0.4089],
         [0.3709, 0.5079, 0.5473,  ..., 0.5812, 0.5424, 0.4076],
         ...,
         [0.3825, 0.5052, 0.5445,  ..., 0.5781, 0.5422, 0.4148],
         [0.3784, 0.5061, 0.5454,  ..., 0.5793, 0.5421, 0.4121],
         [0.3763, 0.5066, 0.5460,  ..., 0.5796, 0.5422, 0.4109]],
        device='cuda:0')]

### Generate input for layer

inp - memory, prev_segment, segment   
weights - q_proj, k_proj, v_proj

In [30]:
hid_dim = 512
segment_size = 128
mem_size = 128

proj_dim = hid_dim

device = "cuda"

In [31]:
memory = torch.rand((mem_size, hid_dim), device=device, dtype=torch.float32)
prev_segment = torch.rand((segment_size, hid_dim), device=device, dtype=torch.float32)
segment = torch.rand((segment_size, hid_dim), device=device, dtype=torch.float32)

q_proj = torch.rand((hid_dim, proj_dim), device=device, dtype=torch.float32)
k_proj = torch.rand((hid_dim, proj_dim), device=device, dtype=torch.float32)
v_proj = torch.rand((hid_dim, hid_dim), device=device, dtype=torch.float32)

In [32]:
q_list = group_gemm_fn([memory, segment], [q_proj, q_proj])
k_list = group_gemm_fn([memory, segment], [k_proj, k_proj])
v_list = group_gemm_fn([memory, segment], [v_proj, v_proj])

### Generate blocked attention matrix

In [33]:
# [memory, segment] x [prev, memory, segment]
ones_bl = torch.ones(())
attn_mask_matr = [
    [torch.zeros((mem_size, segment_size)), torch.ones((mem_size, mem_size)), torch.ones((mem_size, segment_size))],
    [torch.ones((segment_size, segment_size)).triu(), torch.ones((segment_size, mem_size)), torch.ones((segment_size, segment_size)).tril()],
]

attn_mask_matr = [[bl.type(torch.bool).to("cuda") for bl in row] for row in attn_mask_matr]

# effectivelly, window size "segment+1", tune accordingly

# 000 111 111
# 000 111 111
# 000 111 111

# 111 111 100
# 011 111 110
# 001 111 111

## Make simple armt attention with merging input

In [34]:
class SimpleMergeArmtAttention(nn.Module):
    def __init__(self, k_proj, q_proj, v_proj):
        self.k_proj = k_proj
        self.q_proj = q_proj
        self.v_proj = v_proj

        super().__init__()

    def forward(self, segment, memory, prev_kv_proj = None, attn_mask = None):
        memory_seq_len = memory.shape[-2]
        segment_seq_len = segment.shape[-2]
        
        mem_segm_merged = torch.concat((memory, segment), dim=-2)

        q = mem_segm_merged @ q_proj
        k = mem_segm_merged @ k_proj
        v = mem_segm_merged @ v_proj

        if prev_kv_proj is not None:
            prev_segment_proj_k, prev_segment_proj_v = prev_kv_proj
            k = torch.concat((prev_segment_proj_k, k), dim=-2)
            v = torch.concat((prev_segment_proj_v, v), dim=-2)

        attn = q@k.T
        dk = k.shape[-1]
        attn = attn / dk**0.5

        if attn_mask is not None:
            attn.masked_fill_(~attn_mask, -1e9)

        attn = torch.softmax(attn, dim=-1)
        out = attn @ v

        out_segment = out[..., -segment_seq_len:, :]
        out_memory = out[..., -segment_seq_len-memory_seq_len:-segment_seq_len, :]

        cache_k_activation = k[..., -segment_seq_len:, :]
        cache_v_activation = v[..., -segment_seq_len:, :]
        return out_segment, out_memory, [cache_k_activation, cache_v_activation]

In [35]:
attn_mask = concat_block_matr(attn_mask_matr)
attn_mask_matr_first = [el[1:] for el in attn_mask_matr]
attn_mask_first = concat_block_matr(attn_mask_matr_first)

In [36]:
simple_armt = SimpleMergeArmtAttention(k_proj, q_proj, v_proj)

In [37]:
out_segment, out_memory, segm_kv = simple_armt(segment, memory, prev_kv_proj = None, attn_mask = attn_mask_first)

In [38]:
out_segment, out_memory, segm_kv  = simple_armt(segment, out_memory, prev_kv_proj = segm_kv, attn_mask = attn_mask)

## Make group gemm armt attention

In [39]:
import torch.nn as nn

In [40]:
def group_tensor_mask(mask_matr, tensor_mask):
    return [
        [torch.where(mask_bl, tensor_bl, -torch.inf) for mask_bl, tensor_bl in zip(mask_row, tensor_row)]
        for mask_row, tensor_row in zip(mask_matr, tensor_mask)       
    ]

In [41]:
class ArmtAttention(nn.Module):
    def __init__(self, k_proj, q_proj, v_proj):
        self.k_proj = k_proj
        self.q_proj = q_proj
        self.v_proj = v_proj

        super().__init__()

    def forward(self, segment, memory, prev_kv_proj = None, attn_mask_matr = None):
        """
        There are no concatenation or excess copy (probably some because of implementation, like contiguous, but...)
        It handles all matrix multiplication as grouped one without explicit torch.concat
        """
        
        # Apply projections
        q_list = group_gemm_fn([memory, segment], [self.q_proj, self.q_proj])
        k_list = group_gemm_fn([memory, segment], [self.k_proj, self.k_proj])
        v_list = group_gemm_fn([memory, segment], [self.v_proj, self.v_proj])

        # if not first, then concat projected KV from previous iteration
        if prev_kv_proj is not None:
            prev_segment_proj_k, prev_segment_proj_v = prev_kv_proj
            k_list = [prev_segment_proj_k, ] + k_list
            v_list = [prev_segment_proj_v, ] + v_list

        # Q@K.T but in grouped format
        q_attn_mul = [q_list[i] for i in range(len(q_list)) for j in range(len(k_list))]
        k_tr_attn_mul = [
            k_list[j].T.contiguous() # this implementation of group_gemm faulty if memory is not contiguous
            for i in range(len(q_list)) for j in range(len(k_list))
        ]
        
        attn_list = group_gemm_fn(q_attn_mul, k_tr_attn_mul)

        # Reshape result of Q@K.T to matrix 
        blocks_in_row = len(k_list)
        rows_total = len(q_list)
        attn_list_matr = [attn_list[i:i+blocks_in_row] for i in range(0, blocks_in_row*rows_total, blocks_in_row)]

        # Scale
        # scale = (attn_list_matr[0][0].shape[-1]) ** 0.5
        scale = (k_list[0][0].shape[-1]) ** 0.5
        attn_list_matr = [[bl/scale for bl in row] for row in attn_list_matr]

        # Mask. Note, mask is also tiled
        if attn_mask_matr is not None:
            attn_list_matr = group_tensor_mask(attn_mask_matr, attn_list_matr)
                
        # Softmax
        attn_matr_softmax = block_softmax(attn_list_matr)

        # get block output
        result = multiply_matr_on_vec(attn_matr_softmax, v_list)

        # return result and projection of current segment for next iteration
        return result[1], result[0], [k_list[-1], v_list[-1]]

    def backward(self, grad):
        raise Exception("For now not implemented")

In [42]:
layer = ArmtAttention(k_proj, q_proj, v_proj)

In [43]:
attn_mask_matr_first = [
    row[1:] for row in attn_mask_matr
]

In [44]:
segm_out_matr, mem_out_matr, segm_kv_matr = layer(segment, memory, attn_mask_matr=attn_mask_matr_first)

In [45]:
attn_out_matr2, kv_segment2, segm_kv_matr_1 = layer(segment, mem_out_matr, prev_kv_proj=segm_kv_matr, attn_mask_matr=attn_mask_matr)

## Comparision

In [89]:
hid_dim = 4096
segment_size = 128
mem_size = 64

proj_dim = hid_dim

device = "cuda"

memory = torch.rand((mem_size, hid_dim), device=device, dtype=torch.float32)
prev_segment = torch.rand((segment_size, hid_dim), device=device, dtype=torch.float32)
segment = torch.rand((segment_size, hid_dim), device=device, dtype=torch.float32)

q_proj = torch.rand((hid_dim, proj_dim), device=device, dtype=torch.float32)
k_proj = torch.rand((hid_dim, proj_dim), device=device, dtype=torch.float32)
v_proj = torch.rand((hid_dim, hid_dim), device=device, dtype=torch.float32)

In [90]:
hid_dim, segment_size, mem_size

(4096, 128, 64)

In [91]:
ones_bl = torch.ones(())
attn_mask_matr = [
    [torch.zeros((mem_size, segment_size)), torch.ones((mem_size, mem_size)), torch.ones((mem_size, segment_size))],
    [torch.ones((segment_size, segment_size)).triu(), torch.ones((segment_size, mem_size)), torch.ones((segment_size, segment_size)).tril()],
]

attn_mask_matr = [[bl.type(torch.bool).to("cuda") for bl in row] for row in attn_mask_matr]
attn_mask_matr_first = [row[1:] for row in attn_mask_matr]

attn_mask = concat_block_matr(attn_mask_matr)
attn_mask_first = concat_block_matr(attn_mask_matr_first)

In [92]:
simple_armt = SimpleMergeArmtAttention(k_proj, q_proj, v_proj)

In [93]:
%%time

out_segment, out_memory, segm_kv = simple_armt(segment, memory, prev_kv_proj = None, attn_mask = attn_mask_first)
torch.cuda.synchronize()

CPU times: user 2.64 ms, sys: 172 μs, total: 2.81 ms
Wall time: 2.17 ms


In [94]:
%%time

out_segment, out_memory, segm_kv  = simple_armt(segment, out_memory, prev_kv_proj = segm_kv, attn_mask = attn_mask)
torch.cuda.synchronize()

CPU times: user 2.51 ms, sys: 4 μs, total: 2.52 ms
Wall time: 1.98 ms


In [95]:
n_iters_measure = 3000

In [96]:
import time

In [97]:
start = time.time()

for i in range(n_iters_measure):
    out_segment, out_memory, segm_kv  = simple_armt(segment, out_memory, prev_kv_proj = segm_kv, attn_mask = attn_mask)
    torch.cuda.synchronize()

end = time.time()

In [98]:
end-start, (end-start)/n_iters_measure

(4.5942347049713135, 0.0015314115683237712)

In [99]:
layer = ArmtAttention(k_proj, q_proj, v_proj)

In [100]:
%%time

segm_out_matr, mem_out_matr, segm_kv_matr = layer(segment, memory, attn_mask_matr=attn_mask_matr_first)
torch.cuda.synchronize()

CPU times: user 4.58 ms, sys: 90 μs, total: 4.67 ms
Wall time: 4.53 ms


In [101]:
%%time

segm_out_matr, mem_out_matr, segm_kv_matr = layer(segment, mem_out_matr, prev_kv_proj=segm_kv_matr, attn_mask_matr=attn_mask_matr)

CPU times: user 4.53 ms, sys: 82 μs, total: 4.62 ms
Wall time: 4.46 ms


In [102]:
start = time.time()

for i in range(n_iters_measure):
    segm_out_matr, mem_out_matr, segm_kv_matr = layer(segment, mem_out_matr, prev_kv_proj=segm_kv_matr, attn_mask_matr=attn_mask_matr)

end = time.time()

In [103]:
end-start, (end-start)/n_iters_measure

(12.51606798171997, 0.0041720226605733235)

In [72]:
from torch.profiler import profile, record_function, ProfilerActivity

In [108]:
n_iters_bench = 40

In [109]:
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, with_stack=True) as prof:
    with record_function("model_inference"):
        for i in range(n_iters_bench):
            out_segment, out_memory, segm_kv  = simple_armt(segment, out_memory, prev_kv_proj = segm_kv, attn_mask = attn_mask)

In [110]:
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference         0.00%       0.000us         0.00%       0.000us       0.000us      60.074ms       102.10%      60.074ms      60.074ms             1  
                                        model_inference         8.46%       5.329ms        33.44%      21.075ms      21.075ms       0.000us         0.00%      58.836ms      58.836ms             1  
         

In [111]:
sort_by_keyword = "self_" + device + "_time_total"

In [112]:
with_stack=True,
# Print aggregated stats
print(prof.key_averages(group_by_stack_n=5).table(sort_by=sort_by_keyword, row_limit=10))

prof.export_chrome_trace("/home/jovyan/sivtsov/armt/simple_armt_trace_onelayer_4096_128_64.json")

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference         0.00%       0.000us         0.00%       0.000us       0.000us      60.074ms       102.10%      60.074ms      60.074ms             1  
                                               aten::mm         5.08%       3.204ms         7.87%       4.957ms      24.784us      57.407ms        97.57%      57.407ms     287.035us           200  
         

In [113]:
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True, record_shapes=True) as prof:
    with record_function("model_inference"):
        for i in range(n_iters_bench):
            segm_out_matr, mem_out_matr, segm_kv_matr = layer(
                segment, mem_out_matr, prev_kv_proj=segm_kv_matr, attn_mask_matr=attn_mask_matr
            )


In [114]:

# Print aggregated stats
print(prof.key_averages(group_by_stack_n=5).table(sort_by=sort_by_keyword, row_limit=10))

prof.export_chrome_trace("/home/jovyan/sivtsov/armt/nocopy_armt_trace_onelayer_4096_128_64.json")

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference         0.00%       0.000us         0.00%       0.000us       0.000us     245.721ms       283.83%     245.721ms     245.721ms             1  
                                  grouped_matmul_kernel         0.00%       0.000us         0.00%       0.000us       0.000us      75.015ms        86.65%      75.015ms     267.910us           280  
         

In [263]:
hid_dim = 512
segment_size = 128
mem_size = 32

proj_dim = hid_dim

device = "cuda"

segment = torch.rand((segment_size, hid_dim), device=device, dtype=torch.float32)
q_proj = torch.rand((hid_dim, proj_dim), device=device, dtype=torch.float32)

In [112]:
n_mults_at_once = 20

In [129]:
inps_dups = [torch.rand((segment_size, hid_dim), device=device, dtype=torch.float32) for i in range(n_mults_at_once)]
outs = [None for i in range(n_mults_at_once)]

In [136]:
%%time

for i in range(len(inps_dups)):
    o = inps_dups[i] @ q_proj
    outs[i] = o

torch.cuda.synchronize()

CPU times: user 2.56 ms, sys: 0 ns, total: 2.56 ms
Wall time: 2.13 ms


In [137]:
ws = [q_proj for i in range(len(inps_dups))]

In [138]:
%%time

outs = group_gemm_fn(inps_dups, ws)
torch.cuda.synchronize()

CPU times: user 2.3 ms, sys: 0 ns, total: 2.3 ms
Wall time: 2.01 ms
