In [2]:
import torch
import torch.nn as nn

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fde21f29a50>

## straight armt implementation

In [3]:
class SimpleMergeArmtAttention(nn.Module):
    def __init__(self, k_proj, q_proj, v_proj):
        self.k_proj = k_proj
        self.q_proj = q_proj
        self.v_proj = v_proj

        super().__init__()

    def forward(self, segment, memory, prev_kv_proj = None, attn_mask = None):
        memory_seq_len = memory.shape[-2]
        segment_seq_len = segment.shape[-2]
        
        mem_segm_merged = torch.concat((segment, memory), dim=-2)

        q = mem_segm_merged @ self.q_proj
        k = mem_segm_merged @ self.k_proj
        v = mem_segm_merged @ self.v_proj

        if prev_kv_proj is not None:
            prev_segment_proj_k, prev_segment_proj_v = prev_kv_proj
            k = torch.concat((k, prev_segment_proj_k), dim=-2)
            v = torch.concat((v, prev_segment_proj_v), dim=-2)

        attn = q@torch.transpose(k, -1, -2)
        dk = k.shape[-1]
        attn = attn / dk**0.5

        if attn_mask is not None:
            attn.masked_fill_(~attn_mask, -1e9)

        # print("BEFORE ATTN: ", attn)
        attn = torch.softmax(attn, dim=-1)
        # print("ATTN: ", attn)
        out = attn @ v

        out_segment = out[..., :segment_seq_len, :]
        out_memory = out[..., segment_seq_len:segment_seq_len+memory_seq_len, :]

        cache_k_activation = k[..., :segment_seq_len, :]
        cache_v_activation = v[..., :segment_seq_len, :]
        return out_segment, out_memory, [cache_k_activation, cache_v_activation]

## Naive implementation usage

In [527]:
def clear_memory(memory_states):
    for s in memory_states:
        s.fill_(0)

In [595]:
n_layers = 10
norm_coeff = 1
data_norm_coeff = 1

seq_size = 12800 # 128*157

# hid_dim = 1024
hid_dim = 4096*2
# segment_size = 108
# mem_size = 20
segment_size = 128
mem_size = 16

proj_dim = hid_dim

device = "cuda"

In [596]:
memory = torch.rand((mem_size, hid_dim), device=device, dtype=torch.float32)
prev_segment = torch.rand((segment_size, hid_dim), device=device, dtype=torch.float32)

full_input = torch.rand((seq_size, hid_dim), device=device, dtype=torch.float32)
segment = full_input[:segment_size, :]


In [597]:
model_naive_states = []
for i in range(n_layers):
    q_proj = torch.rand((hid_dim, proj_dim), device=device, dtype=torch.float32)/norm_coeff
    k_proj = torch.rand((hid_dim, proj_dim), device=device, dtype=torch.float32)/norm_coeff
    v_proj = torch.rand((hid_dim, hid_dim), device=device, dtype=torch.float32)/norm_coeff

    q_proj = q_proj/q_proj.norm()
    k_proj = k_proj/k_proj.norm()
    v_proj = v_proj/v_proj.norm()

    layer = SimpleMergeArmtAttention(k_proj, q_proj, v_proj)
    model_naive_states.append(layer)

memory_states = []
for i in range(n_layers):
    memory = torch.rand((mem_size, hid_dim), device=device, dtype=torch.float32)
    # memory = memory/memory.norm()
    memory_states.append(memory)


In [598]:
# model_naive = ArmtTransformer(model_naive_states)

model_naive_layers = nn.ModuleList(model_naive_states)

In [599]:
def forward_for_segment(model_naive_layers, cur_segment, memory_states, prev_segm_kvs = None):
    cur_segm_kvs = []
    for i in range(n_layers):
        cur_memory = memory_states[i]
        prev_kv_proj = prev_segm_kvs[i] if prev_segm_kvs is not None else None
        out_segment, out_memory, segm_kv = model_naive_layers[i](
            cur_segment, cur_memory,  prev_kv_proj = prev_kv_proj, attn_mask = None
        )

        # # crutch for random initialization
        # out_memory = out_memory/out_memory.norm()
        # out_segment = out_segment/out_segment.norm()
        
        memory_states[i] = out_memory
        cur_segm_kvs.append(segm_kv)
        
        # cur_segment = cur_segment+out_segment
        cur_segment = out_segment
        
    return cur_segment, cur_segm_kvs

In [600]:
cur_segment = segment

In [601]:
cur_segment, cur_segm_kvs = forward_for_segment(model_naive_layers, cur_segment, memory_states, prev_segm_kvs = None)

In [602]:

clear_memory(memory_states)

In [603]:
from torch.profiler import profile, record_function, ProfilerActivity

In [604]:
sort_by_keyword = "self_" + device + "_time_total"

activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]


In [605]:
%%time



n_segments = full_input.shape[-2]//segment_size
prev_segm_kvs = None

out_segments = []

# with profile(
#     activities=activities,
#     with_stack=True,
# ) as prof:
for i in range(n_segments):
    cur_segment = full_input[i*segment_size: (i+1)*segment_size]
    cur_segment_out, cur_segm_kvs = forward_for_segment(model_naive_layers, cur_segment, memory_states, prev_segm_kvs = prev_segm_kvs)

    out_segments.append(cur_segment_out)
    prev_segm_kvs = cur_segm_kvs

CPU times: user 979 ms, sys: 2.64 s, total: 3.62 s
Wall time: 3.62 s


In [606]:

# Print aggregated stats
# print(prof.key_averages(group_by_stack_n=5).table(sort_by=sort_by_keyword, row_limit=2))

In [607]:
# prof.export_chrome_trace("/home/jovyan/sivtsov/armt/simple_armt_trace_1024_108_20.json")

In [608]:
model_out = torch.concat(out_segments, axis=-2)

In [609]:
model_out

tensor([[0.1101, 0.1093, 0.1094,  ..., 0.1100, 0.1082, 0.1082],
        [0.1101, 0.1093, 0.1094,  ..., 0.1100, 0.1082, 0.1082],
        [0.1101, 0.1093, 0.1094,  ..., 0.1100, 0.1082, 0.1082],
        ...,
        [0.1148, 0.1140, 0.1140,  ..., 0.1147, 0.1128, 0.1128],
        [0.1148, 0.1140, 0.1140,  ..., 0.1147, 0.1128, 0.1128],
        [0.1148, 0.1140, 0.1140,  ..., 0.1147, 0.1128, 0.1128]],
       device='cuda:0')

In [610]:
class AsyncMatmulTask:
    def __init__(self, a, b, result_future):
        self.a = a
        self.b = b

        self.result_future = result_future
    
    def set_result(self, result):
        self.result_future.set_result(result)
    
    async def get_result(self):
        return await self.result_future

In [611]:
import asyncio
loop = asyncio.get_event_loop()

In [612]:
a = torch.rand((128, 512), device="cuda")
b = torch.rand((512, 192), device="cuda")

# 128, 512]-cuda:0-True @[512, 192

In [613]:
fut = loop.create_future()

In [614]:
task = AsyncMatmulTask(a, b, fut)

In [615]:
task

<__main__.AsyncMatmulTask at 0x7f7d2c310910>

In [616]:
r = a @ b

In [617]:
task.set_result(r)

In [618]:
task.get_result()

<coroutine object AsyncMatmulTask.get_result at 0x7f7d2c159e40>

In [619]:
await task.get_result()

tensor([[128.4938, 122.1323, 124.9071,  ..., 128.4474, 123.2391, 126.5575],
        [127.2382, 120.4156, 122.3017,  ..., 125.8692, 124.5613, 128.3667],
        [130.5266, 121.9220, 125.6758,  ..., 131.8502, 125.3415, 129.5557],
        ...,
        [134.1995, 126.6796, 132.4346,  ..., 136.8704, 128.2258, 135.5008],
        [133.3957, 123.7400, 124.4430,  ..., 133.0228, 124.8353, 130.6754],
        [129.1273, 122.8959, 125.5124,  ..., 125.0941, 122.5810, 127.4804]],
       device='cuda:0')

In [620]:
import sys
sys.path.append("/home/jovyan/sivtsov/armt")
from group_gemm import group_gemm_fn

In [621]:
class AsyncBatchedTaskExecutor:
    def __init__(self, loop, use_triton = False):
        self.loop = loop
        self.sorted_tasks = {}

        self.use_triton = use_triton

    def submit_task(self, task):
        bucket_name = task.__class__.__name__
        if bucket_name not in self.sorted_tasks:
            self.sorted_tasks[bucket_name] = []

        self.sorted_tasks[bucket_name].append(task)

    def execute_one_bucket(self, bucket_name, verbose=False):
        any_job_done = False
        
        if bucket_name not in self.sorted_tasks or len(self.sorted_tasks[bucket_name]) == 0:
            return any_job_done

        if bucket_name == AsyncMatmulTask.__name__:
            if self.use_triton:
                a_tensors = []
                b_tensors = []
                for task in self.sorted_tasks[bucket_name]:
                    a_tensors.append(task.a)
                    b_tensors.append(task.b)
                # print(
                #     ', '.join([f"{list(a.shape)}-{a.device}-{a.is_contiguous()} @{list(b.shape)}-{b.device}-{b.is_contiguous()}" for a,b in zip(a_tensors, b_tensors)])
                # )
                res_tensors = group_gemm_fn(a_tensors, b_tensors)
                for res, task in zip(res_tensors, self.sorted_tasks[bucket_name]):
                    task.set_result(res)
                    any_job_done = True
            else:  
                for task in self.sorted_tasks[bucket_name]:
                    res = task.a @ task.b
                    task.set_result(res)
                    any_job_done = True

            if verbose:
                print(f"Compute {len(self.sorted_tasks[bucket_name])}")

            self.sorted_tasks[bucket_name] = []

        return any_job_done

    async def compute_loop(self, max_iters=10, verbose=False):
        any_job_done = False
        for i in range(max_iters):
            # print("compute iter")
            for tn in self.sorted_tasks.keys():
                any_job_done = any_job_done or self.execute_one_bucket(tn, verbose=verbose)
            # if not any_job_done:
            #     return
                
            await asyncio.sleep(0)

    def create_future(self):
        return self.loop.create_future()

In [622]:
executor = AsyncBatchedTaskExecutor(loop, use_triton=False)

In [623]:
tasks = [AsyncMatmulTask(a, b, executor.create_future()) for i in range(3)]

In [624]:
for t in tasks:
    executor.submit_task(t)

In [625]:
await executor.compute_loop()

In [626]:
await tasks[0].get_result()

tensor([[128.4938, 122.1323, 124.9071,  ..., 128.4474, 123.2391, 126.5575],
        [127.2382, 120.4156, 122.3017,  ..., 125.8692, 124.5613, 128.3667],
        [130.5266, 121.9220, 125.6758,  ..., 131.8502, 125.3415, 129.5557],
        ...,
        [134.1995, 126.6796, 132.4346,  ..., 136.8704, 128.2258, 135.5008],
        [133.3957, 123.7400, 124.4430,  ..., 133.0228, 124.8353, 130.6754],
        [129.1273, 122.8959, 125.5124,  ..., 125.0941, 122.5810, 127.4804]],
       device='cuda:0')

In [627]:
class SimpleMergeArmtAttentionWithAsyncCompute(nn.Module):
    def __init__(self, k_proj, q_proj, v_proj):
        self.k_proj = k_proj
        self.q_proj = q_proj
        self.v_proj = v_proj

        super().__init__()

    async def forward(self, executor, segment, memory, prev_kv_proj = None, attn_mask = None):
        memory_seq_len = memory.shape[-2]
        segment_seq_len = segment.shape[-2]
        
        mem_segm_merged = torch.concat((memory, segment), dim=-2)

        # q = mem_segm_merged @ self.q_proj
        # k = mem_segm_merged @ self.k_proj
        # v = mem_segm_merged @ self.v_proj
        q_f = AsyncMatmulTask(mem_segm_merged, self.q_proj, executor.create_future())
        k_f = AsyncMatmulTask(mem_segm_merged, self.k_proj, executor.create_future())
        v_f = AsyncMatmulTask(mem_segm_merged, self.v_proj, executor.create_future())

        executor.submit_task(q_f)
        executor.submit_task(k_f)
        executor.submit_task(v_f)
        
        q = await q_f.get_result()
        k = await k_f.get_result()
        v = await v_f.get_result()
        

        if prev_kv_proj is not None:
            prev_segment_proj_k, prev_segment_proj_v = prev_kv_proj
            k = torch.concat((prev_segment_proj_k, k), dim=-2)
            v = torch.concat((prev_segment_proj_v, v), dim=-2)

        # attn = q@k.T
        attn_f = AsyncMatmulTask(q, k.T.contiguous(), executor.create_future())
        executor.submit_task(attn_f)
        attn = await attn_f.get_result()
        
        dk = k.shape[-1]
        attn = attn / dk**0.5

        if attn_mask is not None:
            attn.masked_fill_(~attn_mask, -1e9)

        # print("BEFORE ATTN: ", attn)
        attn = torch.softmax(attn, dim=-1)
        # print("ATTN: ", attn)
        # out = attn @ v
        out_f = AsyncMatmulTask(attn, v, executor.create_future())
        executor.submit_task(out_f)
        out = await out_f.get_result()

        out_segment = out[..., -segment_seq_len:, :]
        out_memory = out[..., -segment_seq_len-memory_seq_len:-segment_seq_len, :]

        cache_k_activation = k[..., -segment_seq_len:, :]
        cache_v_activation = v[..., -segment_seq_len:, :]
        return out_segment, out_memory, [cache_k_activation, cache_v_activation]

In [628]:
n_segments = full_input.shape[-2]//segment_size

In [629]:
residuals = [ 
    [None for i in range(n_segments)]
    for j in range(n_layers+1)
]

memory_states = [
    torch.zeros((mem_size, hid_dim), device=device, dtype=torch.float32)
    for j in range(n_layers)
]

prev_kvs_storage = [
    None 
    for j in range(n_layers)
]

In [630]:
model_naive_states_async = [
    SimpleMergeArmtAttentionWithAsyncCompute(attn.k_proj, attn.q_proj, attn.v_proj)
    for attn in model_naive_states
]

In [631]:
residuals[0] = [full_input[i*segment_size: (i+1)*segment_size] for i in range(n_segments)]

In [632]:


async def schedule_block(
    executor, 
    layer_id, segm_pos, n_segments, 
    layers, memory_states, prev_kvs_storage, residuals
):
    layer = layers[layer_id]

    if layer_id >= len(residuals):
        print(f"layer_id out: {layer_id}/{len(residuals)} {segm_pos}/{n_segments}")

    if segm_pos >= len(residuals[layer_id]):
        print(f"segm out: {layer_id}/{len(residuals)} {segm_pos}/{n_segments}")
    
    prev_kv_proj = prev_kvs_storage[layer_id]
    segment = residuals[layer_id][segm_pos]
    memory = memory_states[layer_id]

    # print(f"GOT: {layer_id}/{len(residuals)} {segm_pos}/{n_segments}: {segment}")
    out_segment, out_memory, out_kv = await layer.forward(executor, segment, memory, prev_kv_proj = prev_kv_proj, attn_mask = None)

    # crutch for random initialization
    # out_memory = out_memory/out_memory.norm()
    # out_segment = out_segment/out_segment.norm()
    
    # print(f"PUT: {layer_id+1}/{len(residuals)} {segm_pos}/{n_segments}: {out_segment}")
    residuals[layer_id+1][segm_pos] = out_segment
    memory_states[layer_id] = out_memory
    prev_kvs_storage[layer_id] = out_kv


In [633]:
executor = AsyncBatchedTaskExecutor(loop, use_triton=False)

In [634]:
# segm out: 0/21 158/156
# segm out: 1/21 157/156
# segm out: 2/21 156/156
# Compute 51
# Compute 17
# Compute 17
# Compute 0
# segm out: 0/21 159/156
# segm out: 1/21 158/156
# segm out: 2/21 157/156
# segm out: 3/21 156/156
# Compute 48
# Compute 16
# Compute 16
# Compute 0
# segm out: 0/21 160/156
# segm out: 1/21 159/156
# segm out: 2/21 158/156
# segm out: 3/21 157/156
# segm out: 4/21 156/156

In [635]:
n_segments+n_layers-2+1

109

In [636]:
import time

In [637]:
clear_memory(memory_states)

In [638]:
with profile(
    activities=activities,
    with_stack=True,
) as prof:
    
    start = time.time()
    
    for diag in range(0, n_segments+n_layers-2+1):
        diag_tasks = []
        for l_id in range(0, min(n_layers, diag+1)):
            seg_id = diag - l_id
            if seg_id >= n_segments:
                continue
    
            # print(f"seg_id: {seg_id}/{n_segments} {l_id}/{n_layers}")
    
            dtask = asyncio.create_task(schedule_block(
                executor, 
                l_id, seg_id, n_segments, 
                model_naive_states_async, memory_states, prev_kvs_storage, residuals, 
            ))
            diag_tasks.append(dtask)
    
        # print("diag_tasks: ", len(diag_tasks))
        
        is_undone = True
        while is_undone:
            await asyncio.sleep(0)
            
            await executor.compute_loop(1, False)
    
            is_undone = False
            for dt in diag_tasks:
                if not dt.done():
                    is_undone = True
    
    end = time.time()

In [639]:
end - start

3.684100866317749

In [640]:
# Print aggregated stats
# print(prof.key_averages(group_by_stack_n=5).table(sort_by=sort_by_keyword, row_limit=10))

# prof.export_chrome_trace("/home/jovyan/sivtsov/armt/simple_armt_batched_trace_1024_108_20.json")


In [641]:
torch.concat(residuals[-1], axis=-2)

tensor([[0.1101, 0.1093, 0.1094,  ..., 0.1100, 0.1082, 0.1082],
        [0.1101, 0.1093, 0.1094,  ..., 0.1100, 0.1082, 0.1082],
        [0.1101, 0.1093, 0.1094,  ..., 0.1100, 0.1082, 0.1082],
        ...,
        [0.1148, 0.1140, 0.1140,  ..., 0.1147, 0.1128, 0.1128],
        [0.1148, 0.1140, 0.1140,  ..., 0.1147, 0.1128, 0.1128],
        [0.1148, 0.1140, 0.1140,  ..., 0.1147, 0.1128, 0.1128]],
       device='cuda:0')

In [642]:
# residuals[0]

In [643]:
model_out

tensor([[0.1101, 0.1093, 0.1094,  ..., 0.1100, 0.1082, 0.1082],
        [0.1101, 0.1093, 0.1094,  ..., 0.1100, 0.1082, 0.1082],
        [0.1101, 0.1093, 0.1094,  ..., 0.1100, 0.1082, 0.1082],
        ...,
        [0.1148, 0.1140, 0.1140,  ..., 0.1147, 0.1128, 0.1128],
        [0.1148, 0.1140, 0.1140,  ..., 0.1147, 0.1128, 0.1128],
        [0.1148, 0.1140, 0.1140,  ..., 0.1147, 0.1128, 0.1128]],
       device='cuda:0')

### batched shift

In [644]:
# import cutlass
# import os
# import logging
# from cutlass.emit.pytorch import _ArchListSetter
# from cutlass import CUTLASS_PATH, logger
# from torch.utils.cpp_extension import load as jit_load

# extra_cuda_cflags = ["-std=c++17"]
# cc = 80

# # cuda_file = "/home/jovyan/sivtsov/armt/batched_gemm_fused/grouped_gemm_kernel.cu"
# # cpp_file = "/home/jovyan/sivtsov/armt/batched_gemm_fused/grouped_gemm.cpp"

# cuda_file = "/home/jovyan/sivtsov/armt/batched_gemm_fused_fp32/grouped_gemm_kernel.cu"
# cpp_file = "/home/jovyan/sivtsov/armt/batched_gemm_fused_fp32/grouped_gemm.cpp"

# with _ArchListSetter(cc):
#     grouped_gemm_fused = jit_load(
#         "grouped_gemm_fused",
#         [cpp_file, cuda_file],
#         extra_cuda_cflags=extra_cuda_cflags,
#         extra_include_paths=[
#             os.path.join(CUTLASS_PATH, "include"),
#             os.path.join(CUTLASS_PATH, "tools/util/include"),
#         ],
#         # extra_ldflags=["-lcuda"],
#         verbose=(logger.level == logging.DEBUG)
#     )

In [645]:
memory_states_values = []
for i in range(n_layers):
    memory = torch.rand((mem_size, hid_dim), device=device, dtype=torch.float32)
    # memory = memory/memory.norm()
    memory_states_values.append(memory)
prev_segment = torch.rand((segment_size, hid_dim), device=device, dtype=torch.float32)

full_input_values = torch.rand((seq_size, hid_dim), device=device, dtype=torch.float32)
model_naive_layers = nn.ModuleList(model_naive_states)

In [646]:
memory_states = [ms.clone() for ms in memory_states_values]
full_input = full_input_values.clone()

In [647]:
%%time
# with profile(
#     activities=activities,
#     with_stack=True,
# ) as prof:
    
n_segments = full_input.shape[-2]//segment_size
prev_segm_kvs = None

out_segments = []


for i in range(n_segments):
    cur_segment = full_input[i*segment_size: (i+1)*segment_size]
    cur_segment_out, cur_segm_kvs = forward_for_segment(model_naive_layers, cur_segment, memory_states, prev_segm_kvs = prev_segm_kvs)

    out_segments.append(cur_segment_out)
    prev_segm_kvs = cur_segm_kvs

CPU times: user 855 ms, sys: 2.76 s, total: 3.62 s
Wall time: 3.62 s


In [648]:
# print(prof.key_averages(group_by_stack_n=5).table(sort_by=sort_by_keyword, row_limit=2))
# prof.export_chrome_trace("/home/jovyan/sivtsov/armt/simple_armt_trace_1024_128_16.json")

In [649]:
model_naive_layers[0].k_proj

tensor([[8.5877e-05, 3.2606e-05, 4.7348e-05,  ..., 7.8048e-05, 1.4467e-04,
         6.8340e-05],
        [2.0971e-04, 1.2495e-04, 4.7114e-05,  ..., 1.4708e-05, 1.3707e-04,
         1.4782e-04],
        [1.3311e-04, 2.0665e-04, 7.2726e-06,  ..., 1.3032e-04, 1.9263e-04,
         1.6358e-04],
        ...,
        [2.1937e-05, 1.4712e-04, 8.5876e-05,  ..., 2.0777e-04, 3.6135e-05,
         1.6349e-04],
        [3.8573e-05, 6.4643e-05, 1.4749e-04,  ..., 1.2750e-04, 6.9246e-05,
         7.8237e-05],
        [1.5062e-04, 1.2124e-04, 7.2243e-05,  ..., 1.6533e-04, 2.6494e-05,
         1.5068e-04]], device='cuda:0')

In [650]:
segment

tensor([[0.0240, 0.1012, 0.5432,  ..., 0.5815, 0.6590, 0.6832],
        [0.8228, 0.6089, 0.5518,  ..., 0.5307, 0.5262, 0.2317],
        [0.3078, 0.1139, 0.2669,  ..., 0.6657, 0.4683, 0.0287],
        ...,
        [0.8757, 0.8900, 0.4062,  ..., 0.1679, 0.5932, 0.0780],
        [0.6212, 0.8532, 0.8251,  ..., 0.9994, 0.2915, 0.6914],
        [0.6270, 0.2418, 0.0027,  ..., 0.0993, 0.9398, 0.8402]],
       device='cuda:0')

In [651]:
# input_batch

In [652]:
def forward_fused(segm_mem, weights, segment_seq_len, prev_kv_proj = None, attn_mask=None, fused=False):
    segm_mem_list = segm_mem.unbind(0)
    if fused:
        ##
        qkv = grouped_gemm_fused.run(segm_mem_list*3, weights[0]+weights[1]+weights[2])
        ##
    else:
        qkv = grouped_gemm.run(segm_mem_list*3, weights[0]+weights[1]+weights[2])
        qkv = torch.stack(qkv, dim=0)
        ##
    q,k,v = torch.chunk(qkv, 3, dim=0)

    if prev_kv_proj is not None:
        prev_segment_proj_k, prev_segment_proj_v = prev_kv_proj
        k = torch.concatenate([k, prev_segment_proj_k], dim=-2)
        v = torch.concatenate([v, prev_segment_proj_v], dim=-2)

    attn = q@torch.transpose(k, -1, -2)
    dk = k.shape[-1]
    attn = attn / dk**0.5

    if attn_mask is not None:
        attn.masked_fill_(~attn_mask, -2e9)
    attn = torch.softmax(attn, axis=-1)

    out = attn@v
    cache_k_activation = k[..., :segment_seq_len, :]
    cache_v_activation = v[..., :segment_seq_len, :]
    return out, [cache_k_activation, cache_v_activation]

def shift_and_set_memory(X, M, mem_len):
    if X.size(0) > 1:  # Only shift if batch size > 1
        X[1:, :mem_len, :] = X[:-1, :mem_len, :]
    X[0, :mem_len, :] = M
    return X

In [653]:
q_projections = [model_naive_layers[i].q_proj for i in range(len(model_naive_layers))]
k_projections = [model_naive_layers[i].k_proj for i in range(len(model_naive_layers))]
v_projections = [model_naive_layers[i].v_proj for i in range(len(model_naive_layers))]

In [654]:
weights = [
    q_projections,
    k_projections,
    v_projections,
]

In [655]:
n_layers = len(model_naive_layers)
input_batch = torch.zeros((n_layers, segment_size+mem_size, hid_dim), dtype=torch.float32, device="cuda")

full_input = full_input_values.clone()
mem_all = torch.stack(memory_states_values)
n_segments = full_input.shape[-2]//segment_size
output_batch = torch.zeros((n_segments, segment_size, hid_dim), dtype=torch.float32, device="cuda")
input_batch[:, -mem_size: ,:] = mem_all
# input_batch[0, :segment_size, :] = full_input[:segment_size]

att_mask_prefix = torch.ones((segment_size+mem_size, 2*segment_size+mem_size), dtype=torch.bool, device="cuda")
att_mask_prefix[:, -segment_size:].fill_(0);

In [656]:
import cutlass
import torch


dtype = torch.float32
plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)

import random
random.seed(2023)

# Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K
def initialize(dtype, M, N, K):
    sizes = [(M, K), (K, N), (M, N), (M, N)]
    return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]

# Utility function to generate `problems` GEMMs of random sizes
def generate_problems(problems, valid_sizes):
    As, Bs, Cs, Ds = [], [], [], []
    for _ in range(problems):
        M, N, K = [random.choice(valid_sizes) for _ in range(3)]
        A, B, C, D = initialize(dtype, M, N, K)
        As.append(A)
        Bs.append(B)
        Cs.append(C)
        Ds.append(D)
    return As, Bs, Cs, Ds

# valid_sizes = [segment_size+mem_size, 2*segment_size+mem_size, hid_dim]
# As, Bs, Cs, Ds, = generate_problems(80, valid_sizes)
segm_mem_list = input_batch.unbind(0)
As = [el for el in segm_mem_list*3]
Bs = [el for el in weights[0]+weights[1]+weights[2]]
Cs = [torch.zeros(a.shape[:-1] + (b.shape[-1],), dtype=a.dtype, device=a.device) for a,b in zip(As, Bs)]
Ds = [torch.zeros_like(el) for el in Cs]

plan.run(As, Bs, Cs, Ds, print_module=True)
Ds_torch = [a @ b for a, b in zip(As, Bs)]

for d, d_torch in zip(Ds, Ds_torch):
    assert torch.allclose(d, d_torch, rtol=1e-3)

op = plan.construct()
grouped_gemm = cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=True)


// Gemm operator cutlass_tensorop_s1688tf32gemm_grouped_256x128_32x3_tt_align4
using cutlass_tensorop_s1688tf32gemm_grouped_256x128_32x3_tt_align4_base =
  typename cutlass::gemm::kernel::DefaultGemmGrouped<
    float, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 4,
    float, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 4,
    float, cutlass::layout::RowMajor,
    float,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<256, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
    cutlass::arch::OpMultiplyAdd
>::GemmKernel;

// Define named type
struct cutlass_tensorop_s1688tf32gemm_grouped_256x128_32x3_tt_align4_type :
  public cutlass_tensorop_s1688tf32gemm_grouped_256x

In [657]:
import cutlass
import os
import logging
from cutlass.emit.pytorch import _ArchListSetter
from cutlass import CUTLASS_PATH, logger
from torch.utils.cpp_extension import load as jit_load

extra_cuda_cflags = ["-std=c++17"]
cc = 80

cuda_file = "/home/jovyan/sivtsov/armt/batched_gemm_fused_fp32_128_16/grouped_gemm_kernel.cu"
cpp_file = "/home/jovyan/sivtsov/armt/batched_gemm_fused_fp32_128_16/grouped_gemm.cpp"

with _ArchListSetter(cc):
    grouped_gemm_fused = jit_load(
        "grouped_gemm_fused",
        [cpp_file, cuda_file],
        extra_cuda_cflags=extra_cuda_cflags,
        extra_include_paths=[
            os.path.join(CUTLASS_PATH, "include"),
            os.path.join(CUTLASS_PATH, "tools/util/include"),
        ],
        # extra_ldflags=["-lcuda"],
        verbose=(logger.level == logging.DEBUG)
    )

In [658]:
print('ready')

ready


In [660]:
%%time
# with profile(
#     activities=activities,
#     with_stack=True,
# ) as prof:

n_segments = full_input.shape[-2]//segment_size
prev_kv_proj = None
use_att_mask = None
for seg_num in range(n_segments+n_layers-1):
    if seg_num < n_layers and prev_kv_proj is not None:
        # prev_kv_proj[0][seg_num:].fill_(0)
        # prev_kv_proj[1][seg_num:].fill_(0)
        if use_att_mask is None:
            use_att_mask = att_mask_prefix
        use_att_mask[:seg_num, -segment_size:].fill_(1)
        use_att_mask[seg_num:, -segment_size:].fill_(0)
    else:
        use_att_mask = None
    
    if seg_num < n_segments:
        cur_segment = full_input[i*segment_size: (i+1)*segment_size]
    shift_and_set_memory(input_batch, cur_segment, segment_size)
    
    # if seg_num < n_layers:
    #     w_group = [w_g[:seg_num+1] for w_g in weights]
    w_group = weights
    input_batch_res, prev_kv_proj = forward_fused(input_batch, w_group, segment_size, prev_kv_proj = prev_kv_proj, attn_mask=use_att_mask)

    if seg_num < n_layers:
        input_batch[:seg_num+1] = input_batch_res[:seg_num+1]
    else:
        input_batch = input_batch_res
    
    if seg_num+1 >= n_layers:
        output_batch[seg_num+1-n_layers, :, :] = input_batch[-1, :segment_size, :]
        

CPU times: user 1.27 s, sys: 0 ns, total: 1.27 s
Wall time: 1.27 s


In [504]:
use_att_mask

In [505]:
output_batch.shape

torch.Size([100, 128, 1024])

In [506]:
# print(prof.key_averages(group_by_stack_n=5).table(sort_by=sort_by_keyword, row_limit=2))
# prof.export_chrome_trace("/home/jovyan/sivtsov/armt/grouped_armt_trace_1024_128_16.json")

In [507]:
%%time
with profile(
    activities=activities,
    with_stack=True,
) as prof:
    n_segments = full_input.shape[-2]//segment_size
    prev_kv_proj = None
    for seg_num in range(n_segments+n_layers-1):
        if seg_num < n_segments:
            cur_segment = full_input[i*segment_size: (i+1)*segment_size]
        shift_and_set_memory(input_batch, cur_segment, segment_size)
        
        # if seg_num < n_layers:
        #     w_group = [w_g[:seg_num+1] for w_g in weights]
        w_group = weights
        input_batch, prev_kv_proj = forward_fused(input_batch, w_group, segment_size, prev_kv_proj = prev_kv_proj, fused=True)
    
        if seg_num+1 >= n_layers:
            output_batch[seg_num+1-n_layers, :, :] = input_batch[-1, :segment_size, :]
            

CPU times: user 194 ms, sys: 19.3 ms, total: 213 ms
Wall time: 232 ms


In [508]:
# print(prof.key_averages(group_by_stack_n=5).table(sort_by=sort_by_keyword, row_limit=2))
# prof.export_chrome_trace("/home/jovyan/sivtsov/armt/grouped_fused_armt_trace_1024_128_16.json")

In [522]:
output_batch[0]

tensor([[0.2870, 0.2912, 0.2850,  ..., 0.2860, 0.2846, 0.2904],
        [0.2855, 0.2897, 0.2836,  ..., 0.2846, 0.2832, 0.2890],
        [0.2828, 0.2869, 0.2808,  ..., 0.2819, 0.2804, 0.2862],
        ...,
        [0.3101, 0.3147, 0.3079,  ..., 0.3091, 0.3076, 0.3139],
        [0.3119, 0.3165, 0.3097,  ..., 0.3109, 0.3093, 0.3156],
        [0.3113, 0.3158, 0.3091,  ..., 0.3103, 0.3087, 0.3150]],
       device='cuda:0')

In [515]:
output_batch[0]

tensor([[0.2960, 0.3003, 0.2937,  ..., 0.2951, 0.2935, 0.2996],
        [0.2958, 0.3001, 0.2934,  ..., 0.2949, 0.2933, 0.2994],
        [0.2968, 0.3012, 0.2945,  ..., 0.2959, 0.2943, 0.3005],
        ...,
        [0.3155, 0.3201, 0.3134,  ..., 0.3145, 0.3129, 0.3193],
        [0.3142, 0.3188, 0.3121,  ..., 0.3132, 0.3116, 0.3180],
        [0.3140, 0.3186, 0.3119,  ..., 0.3130, 0.3114, 0.3178]],
       device='cuda:0')

In [516]:
out_segments[0], memory_states[0]

(tensor([[0.3329, 0.3378, 0.3305,  ..., 0.3319, 0.3302, 0.3370],
         [0.3329, 0.3378, 0.3305,  ..., 0.3319, 0.3302, 0.3370],
         [0.3329, 0.3378, 0.3305,  ..., 0.3319, 0.3302, 0.3370],
         ...,
         [0.3329, 0.3378, 0.3305,  ..., 0.3319, 0.3302, 0.3370],
         [0.3329, 0.3378, 0.3305,  ..., 0.3319, 0.3302, 0.3370],
         [0.3329, 0.3378, 0.3305,  ..., 0.3319, 0.3302, 0.3370]],
        device='cuda:0'),
 tensor([[0.9329, 0.3637, 0.8522,  ..., 0.3443, 0.8143, 0.3351],
         [0.7574, 0.0777, 0.1359,  ..., 0.8377, 0.0520, 0.0015],
         [0.5138, 0.7993, 0.8422,  ..., 0.3493, 0.3738, 0.5801],
         ...,
         [0.1770, 0.9145, 0.3836,  ..., 0.6782, 0.8625, 0.7925],
         [0.6115, 0.3240, 0.7566,  ..., 0.8265, 0.4217, 0.2230],
         [0.1881, 0.5402, 0.2941,  ..., 0.7107, 0.1698, 0.7122]],
        device='cuda:0'))

In [458]:
n_layers = len(model_naive_layers)
input_batch = torch.zeros((n_layers, segment_size+mem_size, hid_dim), dtype=torch.float32, device="cuda")

full_input = full_input_values.clone()
mem_all = torch.stack(memory_states_values)
n_segments = full_input.shape[-2]//segment_size
output_batch = torch.zeros((n_segments, segment_size, hid_dim), dtype=torch.float32, device="cuda")
input_batch[:, -mem_size: ,:] = mem_all

att_mask_prefix = torch.ones((segment_size+mem_size, 2*segment_size+mem_size), dtype=torch.bool, device="cuda")
att_mask_prefix[:, -segment_size:].fill_(0);

In [459]:
i = 0

In [460]:
cur_segment = full_input[i*segment_size: (i+1)*segment_size]
# cur_segment1 = full_input[(i+1)*segment_size: (i+2)*segment_size]
cur_memory = memory_states_values[i]
cur_memory1 = memory_states_values[i+1]

In [461]:
# memory_states = [ms.clone() for ms in memory_states_values]
# full_input = full_input_values.clone()

In [462]:
prev_kv_proj = None
use_att_mask = None

seg_num = 0

shift_and_set_memory(input_batch, cur_segment, segment_size);
w_group = weights
input_batch1, prev_kv_proj = forward_fused(
    input_batch, w_group, segment_size, attn_mask=use_att_mask, prev_kv_proj = prev_kv_proj
)


input_batch[:1] = input_batch1[:1]
r0 = input_batch.clone()

seg_num = 1

prev_kv_proj[0][seg_num:].fill_(0)
prev_kv_proj[1][seg_num:].fill_(0)
use_att_mask = att_mask_prefix
use_att_mask[:seg_num, -segment_size:].fill_(1)
use_att_mask[seg_num:, -segment_size:].fill_(0)

shift_and_set_memory(input_batch, cur_segment1, segment_size);
w_group = weights
input_batch2, prev_kv_proj1 = forward_fused(
    input_batch, w_group, segment_size, attn_mask=use_att_mask, prev_kv_proj = prev_kv_proj
)



In [463]:
prev_kv_proj = None
out_segment, out_memory, segm_kv = model_naive_layers[0](
    cur_segment, cur_memory,  prev_kv_proj = prev_kv_proj, attn_mask = None
)

out_segment1, out_memory1, segm_kv1 = model_naive_layers[1](
    out_segment, cur_memory1,  prev_kv_proj = prev_kv_proj, attn_mask = None
)

In [464]:
r0.shape, segment_size+mem_size

(torch.Size([10, 144, 1024]), 144)

In [465]:
r0[1][-mem_size:]

tensor([[0.7178, 0.9236, 0.8082,  ..., 0.9563, 0.0631, 0.8135],
        [0.6851, 0.6563, 0.7787,  ..., 0.1626, 0.6178, 0.4088],
        [0.8488, 0.7004, 0.1824,  ..., 0.4087, 0.7167, 0.8065],
        ...,
        [0.1243, 0.1001, 0.7295,  ..., 0.7933, 0.6401, 0.8644],
        [0.4289, 0.3072, 0.0622,  ..., 0.8815, 0.0460, 0.7646],
        [0.5456, 0.6389, 0.3514,  ..., 0.1272, 0.4789, 0.7961]],
       device='cuda:0')

In [466]:
cur_memory1

tensor([[0.7178, 0.9236, 0.8082,  ..., 0.9563, 0.0631, 0.8135],
        [0.6851, 0.6563, 0.7787,  ..., 0.1626, 0.6178, 0.4088],
        [0.8488, 0.7004, 0.1824,  ..., 0.4087, 0.7167, 0.8065],
        ...,
        [0.1243, 0.1001, 0.7295,  ..., 0.7933, 0.6401, 0.8644],
        [0.4289, 0.3072, 0.0622,  ..., 0.8815, 0.0460, 0.7646],
        [0.5456, 0.6389, 0.3514,  ..., 0.1272, 0.4789, 0.7961]],
       device='cuda:0')

In [467]:
r0[0][-mem_size:]

tensor([[0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
        [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
        [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
        ...,
        [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
        [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
        [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266]],
       device='cuda:0')

In [468]:
r0[0][:segment_size], r0[0][-mem_size:]

(tensor([[0.4385, 0.4401, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         ...,
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4325,  ..., 0.4453, 0.4349, 0.4266]],
        device='cuda:0'),
 tensor([[0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         ...,
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266]],
        device='cuda:0'))

In [469]:
out_segment, out_memory

(tensor([[0.4386, 0.4401, 0.4326,  ..., 0.4454, 0.4350, 0.4267],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         ...,
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4325,  ..., 0.4453, 0.4349, 0.4266]],
        device='cuda:0'),
 tensor([[0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         ...,
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266]],
        device='cuda:0'))

In [470]:
input_batch[1][:segment_size], input_batch[1][-mem_size:]

(tensor([[0.4385, 0.4401, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         ...,
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4325,  ..., 0.4453, 0.4349, 0.4266]],
        device='cuda:0'),
 tensor([[0.7178, 0.9236, 0.8082,  ..., 0.9563, 0.0631, 0.8135],
         [0.6851, 0.6563, 0.7787,  ..., 0.1626, 0.6178, 0.4088],
         [0.8488, 0.7004, 0.1824,  ..., 0.4087, 0.7167, 0.8065],
         ...,
         [0.1243, 0.1001, 0.7295,  ..., 0.7933, 0.6401, 0.8644],
         [0.4289, 0.3072, 0.0622,  ..., 0.8815, 0.0460, 0.7646],
         [0.5456, 0.6389, 0.3514,  ..., 0.1272, 0.4789, 0.7961]],
        device='cuda:0'))

In [471]:
input_batch2[1][:segment_size], input_batch2[1][-mem_size:]

(tensor([[0.3832, 0.3750, 0.3840,  ..., 0.3856, 0.3774, 0.3874],
         [0.3865, 0.3782, 0.3873,  ..., 0.3889, 0.3807, 0.3907],
         [0.3865, 0.3782, 0.3873,  ..., 0.3889, 0.3807, 0.3907],
         ...,
         [0.3865, 0.3782, 0.3873,  ..., 0.3889, 0.3807, 0.3907],
         [0.3865, 0.3782, 0.3873,  ..., 0.3889, 0.3807, 0.3907],
         [0.3865, 0.3782, 0.3873,  ..., 0.3889, 0.3807, 0.3907]],
        device='cuda:0'),
 tensor([[0.3875, 0.3792, 0.3884,  ..., 0.3900, 0.3818, 0.3918],
         [0.3875, 0.3793, 0.3884,  ..., 0.3900, 0.3818, 0.3919],
         [0.3874, 0.3791, 0.3883,  ..., 0.3899, 0.3816, 0.3917],
         ...,
         [0.3872, 0.3790, 0.3881,  ..., 0.3897, 0.3815, 0.3916],
         [0.3877, 0.3795, 0.3886,  ..., 0.3902, 0.3820, 0.3921],
         [0.3875, 0.3793, 0.3884,  ..., 0.3900, 0.3818, 0.3919]],
        device='cuda:0'))

In [472]:
out_segment1, out_memory1

(tensor([[0.3865, 0.3782, 0.3873,  ..., 0.3889, 0.3807, 0.3908],
         [0.3865, 0.3782, 0.3873,  ..., 0.3889, 0.3807, 0.3908],
         [0.3865, 0.3782, 0.3873,  ..., 0.3889, 0.3807, 0.3908],
         ...,
         [0.3865, 0.3782, 0.3873,  ..., 0.3889, 0.3807, 0.3908],
         [0.3865, 0.3782, 0.3873,  ..., 0.3889, 0.3807, 0.3908],
         [0.3865, 0.3782, 0.3873,  ..., 0.3889, 0.3807, 0.3908]],
        device='cuda:0'),
 tensor([[0.3875, 0.3792, 0.3884,  ..., 0.3900, 0.3818, 0.3919],
         [0.3876, 0.3793, 0.3884,  ..., 0.3900, 0.3818, 0.3919],
         [0.3874, 0.3791, 0.3883,  ..., 0.3899, 0.3817, 0.3917],
         ...,
         [0.3872, 0.3790, 0.3881,  ..., 0.3897, 0.3815, 0.3916],
         [0.3878, 0.3795, 0.3887,  ..., 0.3902, 0.3820, 0.3921],
         [0.3875, 0.3793, 0.3884,  ..., 0.3900, 0.3818, 0.3919]],
        device='cuda:0'))

In [179]:
cur_segment, cur_memory

(tensor([[0.4231, 0.4509, 0.9262,  ..., 0.6619, 0.6163, 0.1562],
         [0.6674, 0.1417, 0.4253,  ..., 0.9709, 0.1736, 0.0097],
         [0.4113, 0.3242, 0.1899,  ..., 0.4100, 0.0012, 0.3086],
         ...,
         [0.2614, 0.3582, 0.8099,  ..., 0.2140, 0.4649, 0.5207],
         [0.6966, 0.5908, 0.9716,  ..., 0.7158, 0.6745, 0.2273],
         [0.8452, 0.2172, 0.9096,  ..., 0.0771, 0.9687, 0.6406]],
        device='cuda:0'),
 tensor([[0.9329, 0.3637, 0.8522,  ..., 0.3443, 0.8143, 0.3351],
         [0.7574, 0.0777, 0.1359,  ..., 0.8377, 0.0520, 0.0015],
         [0.5138, 0.7993, 0.8422,  ..., 0.3493, 0.3738, 0.5801],
         ...,
         [0.1770, 0.9145, 0.3836,  ..., 0.6782, 0.8625, 0.7925],
         [0.6115, 0.3240, 0.7566,  ..., 0.8265, 0.4217, 0.2230],
         [0.1881, 0.5402, 0.2941,  ..., 0.7107, 0.1698, 0.7122]],
        device='cuda:0'))

In [180]:
shift_and_set_memory(input_batch, cur_segment, segment_size);

In [181]:
input_batch[0, :-mem_size], input_batch[0, -mem_size:]

(tensor([[0.4231, 0.4509, 0.9262,  ..., 0.6619, 0.6163, 0.1562],
         [0.6674, 0.1417, 0.4253,  ..., 0.9709, 0.1736, 0.0097],
         [0.4113, 0.3242, 0.1899,  ..., 0.4100, 0.0012, 0.3086],
         ...,
         [0.2614, 0.3582, 0.8099,  ..., 0.2140, 0.4649, 0.5207],
         [0.6966, 0.5908, 0.9716,  ..., 0.7158, 0.6745, 0.2273],
         [0.8452, 0.2172, 0.9096,  ..., 0.0771, 0.9687, 0.6406]],
        device='cuda:0'),
 tensor([[0.9329, 0.3637, 0.8522,  ..., 0.3443, 0.8143, 0.3351],
         [0.7574, 0.0777, 0.1359,  ..., 0.8377, 0.0520, 0.0015],
         [0.5138, 0.7993, 0.8422,  ..., 0.3493, 0.3738, 0.5801],
         ...,
         [0.1770, 0.9145, 0.3836,  ..., 0.6782, 0.8625, 0.7925],
         [0.6115, 0.3240, 0.7566,  ..., 0.8265, 0.4217, 0.2230],
         [0.1881, 0.5402, 0.2941,  ..., 0.7107, 0.1698, 0.7122]],
        device='cuda:0'))

In [182]:
cur_segm_kvs = []
# cur_memory = memory_states[i]
prev_kv_proj = None
out_segment, out_memory, segm_kv = model_naive_layers[0](
    cur_segment, cur_memory,  prev_kv_proj = prev_kv_proj, attn_mask = None
)

In [183]:
out_segment, out_memory

(tensor([[0.4386, 0.4401, 0.4326,  ..., 0.4454, 0.4350, 0.4267],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         ...,
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4325,  ..., 0.4453, 0.4349, 0.4266]],
        device='cuda:0'),
 tensor([[0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         ...,
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
         [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266]],
        device='cuda:0'))

In [184]:
w_group = weights
input_batch, prev_kv_proj = forward_fused(input_batch, w_group, segment_size, prev_kv_proj = prev_kv_proj)


In [185]:
input_batch[0]

tensor([[0.4385, 0.4401, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
        [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
        [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
        ...,
        [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
        [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266],
        [0.4385, 0.4400, 0.4326,  ..., 0.4454, 0.4350, 0.4266]],
       device='cuda:0')

In [186]:
def forward_for_segment(model_naive_layers, cur_segment, memory_states, prev_segm_kvs = None):
    cur_segm_kvs = []
    for i in range(n_layers):
        cur_memory = memory_states[i]
        prev_kv_proj = prev_segm_kvs[i] if prev_segm_kvs is not None else None
        out_segment, out_memory, segm_kv = model_naive_layers[i](
            cur_segment, cur_memory,  prev_kv_proj = prev_kv_proj, attn_mask = None
        )

        # # crutch for random initialization
        # out_memory = out_memory/out_memory.norm()
        # out_segment = out_segment/out_segment.norm()
        
        memory_states[i] = out_memory
        cur_segm_kvs.append(segm_kv)
        
        # cur_segment = cur_segment+out_segment
        cur_segment = out_segment
        
    return cur_segment, cur_segm_kvs

In [None]:
    cur_segment_out, cur_segm_kvs = forward_for_segment(model_naive_layers, cur_segment, memory_states, prev_segm_kvs = prev_segm_kvs)


In [None]:
    cur_segment_out, cur_segm_kvs = forward_for_segment(model_naive_layers, cur_segment, memory_states, prev_segm_kvs = prev_segm_kvs)


In [None]:
    cur_segment = full_input[i*segment_size: (i+1)*segment_size]
    shift_and_set_memory(input_batch, cur_segment, segment_size)
    
    # if seg_num < n_layers:
    #     w_group = [w_g[:seg_num+1] for w_g in weights]
    w_group = weights
    input_batch, prev_kv_proj = forward_fused(input_batch, w_group, segment_size, prev_kv_proj = prev_kv_proj)

    if seg_num+1 >= n_layers:
        output_batch[seg_num+1-n_segments, :, :] = input_batch[-1, :segment_size, :]
        

In [153]:
n_segments

100

In [139]:
input_batch.shape, cur_segment.shape, segment_size

(torch.Size([10, 192, 1024]), torch.Size([128, 1024]), 128)

In [140]:
shift_and_set_memory(input_batch, cur_segment, segment_size);

In [141]:
segm_mem_list = input_batch.unbind(0)

In [142]:
128+32

160

NameError: name 'segm_mem_list' is not defined

[torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 torch.Size([1024, 1024])]

In [145]:
segm_mem_list = input_batch.unbind(0)
qkv = grouped_gemm_fused.run(segm_mem_list*3, weights[0]+weights[1]+weights[2])
# qkv = torch.stack(qkv, dim=0)
q,k,v = torch.chunk(qkv, 3, dim=0)



In [82]:
# input_batch

In [83]:
w_group = weights

In [89]:
len(input_batch.unbind(0)), len(w_group[0])

(10, 10)