In [None]:
!wget https://raw.githubusercontent.com/allenai/longformer/master/longformer/lib/lib_diagonaled_mm_float32_cuda.so -O lib_diagonaled_mm_float32_cuda.so

In [None]:
from typing import Union
import time
from functools import lru_cache, partial, reduce
import matplotlib.pyplot as plt
import matplotlib
from copy import deepcopy
from codecarbon import EmissionsTracker

import torch
import os.path
import window_matmul
import pandas as pd
import transformers
from torch.profiler import profile, ProfilerActivity
import tqdm
import torch.nn.functional as F
import numpy as np
import sys
sys.path.extend(["../.."])
from listformer.model.listformer import ListformerModel, ListformerConfig

# Remove the SLURM job id from this process's environment when it is present;
# `pop` with a default is a no-op when the variable is absent.
os.environ.pop("SLURM_JOB_ID", None)

# Every benchmark tensor in this notebook is allocated on this device.
device = torch.device("cuda")

In [None]:
from aquarel import Theme

# Configure the plotting theme one option at a time, then activate it.
theme = Theme(name="theme")
theme = theme.set_grid(draw=True)
theme = theme.set_font(family="serif")
theme.apply()

# Marker styles; the plotting cell picks one per approach by position.
markers = ["o", "s", "X", "v", "P", "*", "D"]

In [None]:
def human_readable_bytes(sizes, unit=None):
    """Scale byte counts to one shared binary unit.

    Args:
        sizes: iterable of byte counts. It is iterated twice when `unit` is
            None (once to pick the unit, once to scale).
        unit: optional unit name ('B', 'KB', ..., 'YB'). When None, the
            largest unit needed by any element is used, i.e. the number of
            times the biggest value can be divided by 1024 while >= 1024.

    Returns:
        A tuple `(scaled_sizes, unit_name)` where `scaled_sizes` is a list.

    Raises:
        ValueError: if `unit` is given but is not a known unit name.
    """
    units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB']
    if unit is not None:
        unit_idx = units.index(unit)
    else:
        unit_idx = 0
        largest_index = len(units) - 1
        for size in sizes:
            steps = 0
            remaining = size
            # Count how many 1024-fold reductions this value supports.
            while remaining >= 1024 and steps < largest_index:
                remaining /= 1024
                steps += 1
            unit_idx = max(steps, unit_idx)
    divisor = 1024**unit_idx
    scaled = [size / divisor for size in sizes]
    return scaled, units[unit_idx]

In [None]:
class DiagonaledMM(torch.autograd.Function):
    '''Autograd wrapper around a pre-compiled TVM `diagonaled_mm` kernel.

    The compiled library is loaded from disk (see `_get_lib_filename`), wrapped as a
    PyTorch-callable function and cached per (dtype, device). `forward` / `backward`
    make the kernel usable as a differentiable PyTorch op.
    '''

    function_dict = {}  # cache of loaded kernels, keyed by (dtype, device)

    min_seq_len = 16  # unexpected output if seq_len < 16

    @staticmethod
    def _get_lib_filename(dtype: str, device: str):
        '''Return the file name of the compiled library for `dtype` and `device`.'''
        base_filename = 'lib_diagonaled_mm'
        return '{}_{}_{}.so'.format(base_filename, dtype, device)

    @staticmethod
    def _load_compiled_function(dtype: str, device: str):
        '''Load the compiled TVM module for `dtype` and `device` from disk.'''
        # Imported lazily: this can be the small runtime python library, and doesn't
        # need to be the whole TVM installation.
        from tvm.runtime import load_module
        filename = DiagonaledMM._get_lib_filename(dtype, device)
        return load_module(filename)

    @staticmethod
    def _get_function(dtype: str, device: str):
        '''Return the PyTorch-callable kernel for (dtype, device), loading it on first use.'''
        args = (dtype, device)
        if args not in DiagonaledMM.function_dict:
            diagonaled_mm = DiagonaledMM._load_compiled_function(dtype, device)
            # Convert the TVM function into a function that accepts PyTorch tensors.
            from tvm.contrib import dlpack
            diagonaled_mm_pytorch = dlpack.to_pytorch_func(diagonaled_mm)
            # Cache it so the library is only loaded once per (dtype, device).
            DiagonaledMM.function_dict[args] = diagonaled_mm_pytorch
        return DiagonaledMM.function_dict[args]

    @staticmethod
    def _diagonaled_mm(t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor,int],
                       is_t1_diagonaled: bool = False, transpose_t1: bool = False, padding: int = 0,
                       autoregressive: bool = False):
        '''Calls the compiled function after checking the input format. This function is called in three different modes.
        t1 x t2 = r ==> t1 and t2 are not diagonaled, but r is. Useful for query x key = attention_scores
        t1 x t2 = r ==> t1 is diagonaled, but t2 and r are not. Useful to compute attention_scores x value = context
        t1 x t2 = r ==> t1 is diagonaled and it should be transposed, but t2 and r are not diagonaled. Useful in some of
                            the calculations in the backward pass.

        Raises:
            AssertionError: if the inputs are not 4-D with matching leading dimensions, if `d`
                does not have one entry per head, or if the hidden dimension equals the number
                of diagonals (the kernel cannot tell the two modes apart in that case).
        '''
        dtype = str(t1.dtype).split('.')[1]
        device = t1.device.type
        assert len(t1.shape) == 4
        assert len(t1.shape) == len(t2.shape)
        assert t1.shape[:3] == t2.shape[:3]
        if isinstance(d, int):  # if d is an integer, replace it with a tensor of the same length
                                # as number of heads, and it is filled with the same dilation value
            d = t1.new_full(size=(t1.shape[2],), fill_value=d, dtype=torch.int, requires_grad=False)

        assert len(d.shape) == 1
        assert d.shape[0] == t1.shape[2]  # number of dilation scores should match number of heads
        b = t1.shape[0]  # batch size
        n = t1.shape[1]  # sequence length
        h = t1.shape[2]  # number of heads
        m = t2.shape[3]  # hidden dimension
        w_upper = 0 if autoregressive else w
        c = w_upper + w + 1  # number of diagonals
        if is_t1_diagonaled:
            assert t1.shape[3] == c
            r = t1.new_empty(b, n, h, m)  # allocate space for the result tensor
        else:
            assert not transpose_t1
            assert t1.shape[3] == m
            r = t1.new_empty(b, n, h, c)  # allocate space for the result tensor

        # The last argument passed to the kernel is the size of the last dimension of the result
        # tensor. It doubles as a proxy telling the kernel whether t1 is diagonaled (if t1 is
        # diagonaled the result is not, and vice versa), so the two sizes must differ.
        # FIXME: this restriction comes from the kernel interface.
        if m == c:
            raise AssertionError(
                f"the hidden dimension {m} shouldn't match number of diagonals {c}"
            )

        # Fetch the kernel from the in-memory cache, loading it from disk on first use.
        _diagonaled_mm_function = DiagonaledMM._get_function(dtype=dtype, device=device)

        # Computes diagonal_mm and writes the result into `r`.
        _diagonaled_mm_function(t1, t2, r, d, w, w_upper, padding, transpose_t1, m if is_t1_diagonaled else c)
        return r

    @staticmethod
    def _prepare_tensors(t):
        '''Fix `stride()` information of input tensor. This addresses some inconsistency in stride information in PyTorch.
        For a tensor t, if t.size(0) == 1, then the value of t.stride()[0] doesn't matter.
        TVM expects this value to be the `product(t.size()[1:])` but PyTorch sometimes sets it to `t.stride()[1]`.
        Here's an example to reproduce this issue:
            import torch
            print(torch.randn(1, 10).stride())
            > (10, 1)
            print(torch.randn(10, 1).t().contiguous().stride())
            > (1, 1)  # expected it to be (10, 1) as above
            print(torch.randn(10, 2).t().contiguous().stride())
            > (10, 1) # but gets the expected stride if the first dimension is > 1

        Raises:
            AssertionError: if `t` is not contiguous.
        '''
        assert t.is_contiguous()
        t_stride = list(t.stride())
        t_size = list(t.size())
        # Fix wrong stride information for the first dimension. This occurs when batch_size=1
        if t_size[0] == 1 and t_stride[0] == t_stride[1]:
            # In this case, the stride of the first dimension should be the product
            # of the sizes of all other dimensions
            t_stride[0] = t_size[1] * t_size[2] * t_size[3]
            t = t.as_strided(size=t_size, stride=t_stride)
        return t

    @staticmethod
    def forward(ctx, t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor,int], is_t1_diagonaled: bool = False, padding: int = 0, autoregressive: bool = False) -> torch.Tensor:
        '''Computes diagonal_mm of t1 and t2.
        args:
        t1: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals).
            t1 can be a regular tensor (e.g. `query_layer`) or a diagonaled one (e.g. `attention_scores`)
        t2: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size). This is always a non-diagonaled
            tensor, e.g. `key_layer` or `value_layer`
        w: int = window size; number of attentions on each side of the word
        d: torch.Tensor or int = dilation of attentions per attention head. If int, the same dilation value will be used for all
            heads. If torch.Tensor, it should be 1D of length=number of attention heads
        is_t1_diagonaled: is t1 a diagonaled or a regular tensor
        padding: the padding value to use when accessing invalid locations. This is mainly useful when the padding
            needs to be a very large negative value (to compute softmax of attentions). For other usecases,
            please use zero padding.
        autoregressive: if true, return only the lower triangle
        returns: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals)
            if t1 is diagonaled, result is non-diagonaled, and vice versa
        '''
        # Unpacking also enforces that t1 is 4-D.
        _, seq_len, _, _ = t1.size()
        assert seq_len >= DiagonaledMM.min_seq_len, 'avoid splitting errors by using seq_len >= {}'.format(DiagonaledMM.min_seq_len)  # FIXME
        ctx.save_for_backward(t1, t2)
        ctx.w = w
        ctx.d = d
        ctx.is_t1_diagonaled = is_t1_diagonaled
        ctx.autoregressive = autoregressive
        t1 = DiagonaledMM._prepare_tensors(t1)
        t2 = DiagonaledMM._prepare_tensors(t2)
        # output = t1.mm(t2)  # what would have been called if this was a regular matmul
        output = DiagonaledMM._diagonaled_mm(t1, t2, w, d, is_t1_diagonaled=is_t1_diagonaled, padding=padding, autoregressive=autoregressive)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        '''Computes gradients w.r.t. t1 and t2; the remaining five forward inputs get None.'''
        t1, t2 = ctx.saved_tensors
        w = ctx.w
        d = ctx.d
        is_t1_diagonaled = ctx.is_t1_diagonaled
        autoregressive = ctx.autoregressive
        if not grad_output.is_contiguous():
            grad_output = grad_output.contiguous()  # tvm requires all input tensors to be contiguous
        grad_output = DiagonaledMM._prepare_tensors(grad_output)
        t1 = DiagonaledMM._prepare_tensors(t1)
        t2 = DiagonaledMM._prepare_tensors(t2)
        # http://cs231n.github.io/optimization-2/
        # https://pytorch.org/docs/master/notes/extending.html
        # grad_t1 = grad_output.mm(t2)  # what would have been called if this was a regular matmul
        grad_t1 = DiagonaledMM._diagonaled_mm(grad_output, t2, w, d, is_t1_diagonaled=not is_t1_diagonaled, autoregressive=autoregressive)
        # grad_t2 = grad_output.t().mm(t1)  # or `grad_t2 = t1.t().mm(grad_output).t()` because `(AB)^T = B^TA^T`
        if is_t1_diagonaled:
            grad_t2 = DiagonaledMM._diagonaled_mm(t1, grad_output, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive)
        else:
            grad_t2 = DiagonaledMM._diagonaled_mm(grad_output, t1, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive)
        return grad_t1, grad_t2, None, None, None, None, None

# Functional alias used by the benchmark kernels below.
longformer_tvm_matmul = DiagonaledMM.apply

In [None]:
def _skew(x, direction, padding_value):
    '''Convert diagonals into columns (or columns into diagonals depending on `direction`'''
    x_padded = F.pad(x, direction, value=padding_value)
    x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2))
    return x_padded


def _skew2(x, padding_value):
    '''shift every row 1 step to right converting columns into diagonals'''
    # X = B x C x M x L
    B, C, M, L = x.size()
    x = F.pad(x, (0, M + 1), value=padding_value)  # B x C x M x (L+M+1)
    x = x.view(B, C, -1)  # B x C x ML+MM+M
    x = x[:, :, :-M]  # B x C x ML+MM
    x = x.view(B, C, M, M + L)  # B x C, M x L+M
    x = x[:, :, :, :-1]
    return x


def _chunk(x, w):
    '''convert into overlapping chunkings. Chunk size = 2w, overlap size = w'''

    # non-overlapping chunks of size = 2w
    x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2))

    # use `as_strided` to make the chunks overlap with an overlap size = w
    chunk_size = list(x.size())
    chunk_size[1] = chunk_size[1] * 2 - 1

    chunk_stride = list(x.stride())
    chunk_stride[1] = chunk_stride[1] // 2
    return x.as_strided(size=chunk_size, stride=chunk_stride)


def sliding_chunks_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):
    '''Matrix multiplication of query x key tensors using a sliding window attention pattern.
    This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer)
    with an overlap of size w.

    Args:
        q, k: tensors of identical size (bsz, seqlen, num_heads, head_dim); seqlen must be divisible by 2w.
        w: one-sided window size.
        padding_value: fill value used by `_skew` when padding the chunked scores.

    Returns:
        Tensor of size (bsz, seqlen, num_heads, 2w + 1) holding the banded scores; it is passed through
        `mask_invalid_locations` (in place) before being returned.
    '''
    bsz, seqlen, num_heads, head_dim = q.size()
    assert seqlen % (w * 2) == 0
    assert q.size() == k.size()

    chunks_count = seqlen // w - 1

    # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2
    q = q.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
    k = k.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)

    chunk_q = _chunk(q, w)
    chunk_k = _chunk(k, w)

    # matrix multiplication
    # bcxd: bsz*num_heads x chunks x 2w x head_dim
    # bcyd: bsz*num_heads x chunks x 2w x head_dim
    # bcxy: bsz*num_heads x chunks x 2w x 2w
    chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (chunk_q, chunk_k))  # multiply

    # convert diagonals into columns
    diagonal_chunk_attn = _skew(chunk_attn, direction=(0, 0, 0, 1), padding_value=padding_value)

    # allocate space for the overall attention matrix where the chunks are combined. The last dimension
    # has (w * 2 + 1) columns. The first (w) columns are the w lower triangles (attention from a word to
    # w previous words). The following column is attention score from each word to itself, then
    # followed by w columns for the upper triangle.
    # NOTE: the buffer is uninitialised (`new_empty`); the copies below must stay in this order.

    diagonal_attn = diagonal_chunk_attn.new_empty((bsz * num_heads, chunks_count + 1, w, w * 2 + 1))

    # copy parts from diagonal_chunk_attn into the combined matrix of attentions
    # - copying the main diagonal and the upper triangle
    diagonal_attn[:, :-1, :, w:] = diagonal_chunk_attn[:, :, :w, :w + 1]
    diagonal_attn[:, -1, :, w:] = diagonal_chunk_attn[:, -1, w:, :w + 1]
    # - copying the lower triangle
    diagonal_attn[:, 1:, :, :w] = diagonal_chunk_attn[:, :, - (w + 1):-1, w + 1:]
    diagonal_attn[:, 0, 1:w, 1:w] = diagonal_chunk_attn[:, 0, :w - 1, 1 - w:]

    # separate bsz and num_heads dimensions again
    diagonal_attn = diagonal_attn.view(bsz, num_heads, seqlen, 2 * w + 1).transpose(2, 1)

    # fills positions flagged by the helper with -inf, in place (dilation 1, non-autoregressive)
    mask_invalid_locations(diagonal_attn, w, 1, False)
    return diagonal_attn


def sliding_chunks_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
    '''Multiply banded attention probabilities with the value tensor.

    Args:
        prob: (bsz, seqlen, num_heads, 2w + 1), the layout produced by `sliding_chunks_matmul_qk`.
        v: (bsz, seqlen, num_heads, head_dim); seqlen must be divisible by 2w.
        w: one-sided window size.

    Returns:
        Tensor of size (bsz, seqlen, num_heads, head_dim).
    '''
    bsz, seqlen, num_heads, head_dim = v.size()
    assert seqlen % (w * 2) == 0
    assert prob.size()[:3] == v.size()[:3]
    assert prob.size(3) == 2 * w + 1

    flat_batch = bsz * num_heads
    num_chunks = seqlen // w  # one chunk of w query positions each

    # fold bsz and num_heads together and cut the queries into chunks of w rows
    windowed_prob = prob.transpose(1, 2).reshape(flat_batch, num_chunks, w, 2 * w + 1)

    # fold bsz and num_heads together for the values as well
    flat_v = v.transpose(1, 2).reshape(flat_batch, seqlen, head_dim)

    # extend the sequence by w positions on both ends (filled with -1)
    padded_v = F.pad(flat_v, (0, 0, w, w), value=-1)

    # view the padded values as chunks of 3w rows that advance by w rows each
    batch_stride, token_stride, feature_stride = padded_v.stride()
    chunk_v = padded_v.as_strided(
        size=(flat_batch, num_chunks, 3 * w, head_dim),
        stride=(batch_stride, w * token_stride, token_stride, feature_stride),
    )

    # shift each probability row so its columns line up with the 3w value rows
    skewed_prob = _skew2(windowed_prob, padding_value=0)

    context = torch.einsum('bcwd,bcdh->bcwh', skewed_prob, chunk_v)
    return context.view(bsz, num_heads, seqlen, head_dim).transpose(1, 2)


def pad_to_window_size(input_ids: torch.Tensor,
                       one_sided_window_size: int, pad_token_id: int):
    '''Pad a tensor along its third-from-last dimension to a multiple of 2 * one_sided_window_size.

    Input:
        input_ids = torch.Tensor whose dimension 1 is the sequence length and which has exactly
            two trailing dimensions after it (the benchmark passes bsz x seqlen x heads x head_dim)
        one_sided_window_size = int: window size on one side of each token
        pad_token_id = fill value appended at the end of the sequence dimension
    Returns
        the padded tensor (unchanged in size when seqlen is already a multiple)
    '''
    full_window = int(2 * one_sided_window_size)
    # number of positions missing to reach the next multiple of the full window
    missing = -input_ids.size(1) % full_window
    return F.pad(input_ids, (0, 0, 0, 0, 0, missing), value=pad_token_id)

def _get_invalid_locations_mask_fixed_dilation(seq_len: int, w: int, d: int):
    diagonals_list = []
    for j in range(-d * w, d, d):
        diagonal_mask = torch.zeros(seq_len, device='cpu', dtype=torch.uint8)
        diagonal_mask[:-j] = 1
        diagonals_list.append(diagonal_mask)
    return torch.stack(diagonals_list, dim=-1)

@lru_cache()
def _get_invalid_locations_mask(w: int, d: Union[torch.Tensor,int], autoregressive: bool, device: str):
    """Build (and memoise) the boolean masks used by `mask_invalid_locations`.

    Args:
        w: one-sided window size.
        d: dilation; either one int shared by all heads or a 1-D tensor with
            one dilation per head.
        autoregressive: when True no ending mask is produced.
        device: device the returned masks are moved to.

    Returns:
        `(affected_seq_len, beginning_mask, ending_mask)`; `ending_mask` is
        None when `autoregressive` is True, otherwise the beginning mask
        flipped along dimensions 1 and 3.
    """
    if not isinstance(d, int):
        # Per-head dilations: one mask per head, stacked along the head axis.
        affected_seq_len = w * d.max()
        per_head_masks = [
            _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, head_dilation)
            for head_dilation in d.cpu().numpy().tolist()
        ]
        mask = torch.stack(per_head_masks, dim=-2)
        mask = mask[None, :, :, :]
    else:
        # Shared dilation: a single mask broadcast over batch and heads.
        affected_seq_len = w * d
        mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
        mask = mask[None, :, None, :]

    if autoregressive:
        ending_mask = None
    else:
        ending_mask = mask.flip(dims=(1, 3)).bool().to(device)
    return affected_seq_len, mask.bool().to(device), ending_mask

def mask_invalid_locations(input_tensor: torch.Tensor, w: int, d: Union[torch.Tensor, int], autoregressive: bool) -> torch.Tensor:
    """Fill masked window positions of `input_tensor` with -inf, in place.

    The first `affected_seq_len` positions along dimension 1 are masked in
    their first w + 1 columns; unless `autoregressive` is True, the last
    `affected_seq_len` positions are masked in their last w + 1 columns.
    Nothing is returned (the function yields None despite the annotation);
    callers rely on the in-place modification.
    """
    affected_seq_len, beginning_mask, ending_mask = _get_invalid_locations_mask(w, d, autoregressive, input_tensor.device)
    seq_len = input_tensor.size(1)
    neg_inf = -float('inf')

    # Basic slicing yields views, so masked_fill_ writes through to input_tensor.
    head_view = input_tensor[:, :affected_seq_len, :, :w+1]
    head_mask = beginning_mask[:, :seq_len].expand(head_view.size())
    head_view.masked_fill_(head_mask, neg_inf)

    if autoregressive:
        return

    tail_view = input_tensor[:, -affected_seq_len:, :, -(w+1):]
    tail_mask = ending_mask[:, -seq_len:].expand(tail_view.size())
    tail_view.masked_fill_(tail_mask, neg_inf)

In [None]:
import torch
from torch import nn
import math


def torch_bmm_nd(inp_1, inp_2, ndim):
    """Batched matrix product over the last two dimensions of n-d tensors.

    All leading dimensions are flattened into one batch dimension for
    `torch.bmm` and restored afterwards (a faster replacement for
    torch.einsum("bhqk,bhkd->bhqd")). `ndim` is the rank of the inputs.
    """
    lhs = inp_1.reshape((-1,) + inp_1.shape[-2:])
    rhs = inp_2.reshape((-1,) + inp_2.shape[-2:])
    out_shape = inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 1])
    return torch.bmm(lhs, rhs).view(out_shape)


def torch_bmm_nd_transpose(inp_1, inp_2, ndim):
    """Batched `inp_1 @ inp_2^T` over the last two dimensions of n-d tensors.

    Leading dimensions are flattened for `torch.bmm` and restored afterwards
    (a faster replacement for torch.einsum("bhqd,bhkd->bhqk")). `ndim` is the
    rank of the inputs.
    """
    lhs = inp_1.reshape((-1,) + inp_1.shape[-2:])
    rhs_t = inp_2.reshape((-1,) + inp_2.shape[-2:]).transpose(1, 2)
    out_shape = inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 2])
    return torch.bmm(lhs, rhs_t).view(out_shape)


def create_band_mask(attention_mask: torch.Tensor, block_size: int):
    """Create the band mask for block-sparse sliding attention.

    Args:
        attention_mask: tensor of shape [batch_size, seq_length].
        block_size: block length; must divide seq_length.

    Returns:
        Tensor of shape [batch_size, 1, seq_length//block_size - 4, block_size,
        3*block_size]: for every middle query block, the outer product of its
        mask with the masks of its left neighbour, itself and its right neighbour.

    Raises:
        ValueError: if seq_length is not a multiple of block_size.
    """
    batch_size, seq_length = attention_mask.size()
    if seq_length % block_size != 0:
        raise ValueError(
            f"Sequence length must be multiple of block size, but sequence length is {seq_length}, while block"
            f" size is {block_size}."
        )

    blocked_mask = attention_mask.view(batch_size, seq_length // block_size, block_size)

    # Masks of the three key blocks seen by each middle query block.
    neighbour_mask = torch.cat(
        [blocked_mask[:, 1:-3], blocked_mask[:, 2:-2], blocked_mask[:, 3:-1]],
        dim=2,
    )
    # Query blocks exclude the first two and last two blocks.
    middle_mask = blocked_mask[:, 2:-2]

    band_mask = torch.einsum("blq,blk->blqk", middle_mask, neighbour_mask)
    return band_mask.unsqueeze(1)


def big_bird(
    query_layer,
    key_layer,
    value_layer,
    block_size,
):
    """Block-sparse attention over the middle query blocks (sliding + two global blocks).

    Inputs are [bsz, n_heads, seq_len, head_dim] with seq_len a multiple of
    `block_size`. Every query block except the first two and last two attends
    to the first block, its three neighbouring blocks and the last block.
    Only the sliding-window contribution to the context is computed and
    returned, with shape [bsz, n_heads, seq_len//block_size - 4, block_size, head_dim].
    No random blocks are used.
    """
    attn_mask_penalty = -10000.0
    bsz = query_layer.shape[0]
    n_heads = query_layer.shape[1]
    seq_len = query_layer.shape[2]

    # Nothing is padded here, so the attention mask is all ones.
    attention_mask = torch.ones(
        (bsz, seq_len), device=query_layer.device, dtype=query_layer.dtype
    )
    band_mask = create_band_mask(attention_mask, block_size)
    to_mask = attention_mask.view(bsz, 1, 1, seq_len)

    n_blocks = seq_len // block_size

    def _to_blocks(layer):
        # [bsz, n_heads, seq_len, d] -> [bsz, n_heads, n_blocks, block_size, d]
        return layer.view(bsz, n_heads, n_blocks, block_size, -1)

    def _with_neighbours(blocks):
        # For each middle block, concatenate left neighbour, itself and right
        # neighbour: [bsz, n_heads, n_blocks-4, 3*block_size, d]
        return torch.cat(
            [blocks[:, :, 1:-3], blocks[:, :, 2:-2], blocks[:, :, 3:-1]],
            dim=3,
        )

    blocked_query_matrix = _to_blocks(query_layer)
    blocked_key_matrix = _to_blocks(key_layer)
    blocked_value_matrix = _to_blocks(value_layer)
    sliding_keys = _with_neighbours(blocked_key_matrix)
    sliding_values = _with_neighbours(blocked_value_matrix)
    middle_queries = blocked_query_matrix[:, :, 2:-2]

    # Sliding scores: [bsz, n_heads, n_blocks-4, block_size, 3*block_size]
    inner_band_product = torch_bmm_nd_transpose(middle_queries, sliding_keys, ndim=5)

    # Scores against the first and last blocks (both global):
    # [bsz, n_heads, n_blocks-4, block_size, block_size] each
    first_band_product = torch.einsum(
        "bhlqd,bhkd->bhlqk", middle_queries, blocked_key_matrix[:, :, 0]
    )
    last_band_product = torch.einsum(
        "bhlqd,bhkd->bhlqk", middle_queries, blocked_key_matrix[:, :, -1]
    )

    # Penalise masked key positions before the softmax.
    inner_band_product += (1.0 - band_mask) * attn_mask_penalty
    first_band_product += (
        1.0 - to_mask[:, :, :, :block_size].unsqueeze(3)
    ) * attn_mask_penalty
    last_band_product += (
        1.0 - to_mask[:, :, :, -block_size:].unsqueeze(3)
    ) * attn_mask_penalty

    # Full score row per query: first | sliding | last -> 5*block_size keys.
    band_product = torch.cat(
        [first_band_product, inner_band_product, last_band_product],
        dim=-1,
    )
    attn_weights = torch.softmax(band_product, dim=-1)

    # Context from the sliding keys only (columns block_size .. 4*block_size).
    sliding_weights = attn_weights[..., block_size : 4 * block_size]
    return torch_bmm_nd(sliding_weights, sliding_values, ndim=5)

def pad_to_block_size(input_ids: torch.Tensor,
                       one_sided_window_size: int, pad_token_id: int):
    '''Pad a tensor along dimension 2 to a multiple of the BigBird block size.

    The block size is ceil(one_sided_window_size / 3), matching the value `big_bird_kernel`
    passes to `big_bird`.

    Input:
        input_ids = torch.Tensor whose dimension 2 is the sequence length and which has exactly
            one trailing dimension after it (the benchmark passes bsz x heads x seqlen x head_dim)
        one_sided_window_size = int: window size the block size is derived from
        pad_token_id = fill value appended at the end of the sequence dimension
    Returns
        the padded tensor (unchanged in size when seqlen is already a multiple of the block size)
    '''
    # Derive the block size from the argument; the previous version read a
    # global `window_size` and only worked when the notebook happened to define it.
    block_size = math.ceil(one_sided_window_size / 3)
    seqlen = input_ids.size(2)
    padding_len = (block_size - seqlen % block_size) % block_size
    # (0, 0) leaves the last dimension alone; (0, padding_len) extends dimension -2 at its end.
    input_ids = F.pad(input_ids, (0, 0, 0, padding_len), value=pad_token_id)
    return input_ids


In [None]:
def full_attention_kernel(query, key, value, window_size):
    """Dense baseline: (query @ key^T) @ value with no softmax or scaling.

    `window_size` is accepted only so all kernels share one signature; it is unused.
    """
    scores = query @ key.transpose(-1, -2)
    return scores @ value

def longformer_tvm_kernel(query, key, value, window_size):
    """Windowed query-key product followed by the score-value product, both via the TVM kernel.

    Uses dilation 1; the first call yields diagonaled scores, the second
    consumes them as the diagonaled operand.
    """
    dilation = 1
    scores = longformer_tvm_matmul(query, key, window_size, dilation, False)
    return longformer_tvm_matmul(scores, value, window_size, dilation, True)

def longformer_pytorch_kernel(query, key, value, window_size):
    """Sliding-chunks attention in pure PyTorch (score product, then value product).

    All three inputs are first zero-padded to a multiple of 2 * window_size.
    """
    query, key, value = (
        pad_to_window_size(tensor, window_size, 0) for tensor in (query, key, value)
    )
    scores = sliding_chunks_matmul_qk(query, key, window_size, 0)
    return sliding_chunks_matmul_pv(scores, value, window_size)

def big_bird_kernel(query, key, value, window_size):
    """Run `big_bird` with a block size of ceil(window_size / 3).

    All three inputs are first zero-padded to a multiple of that block size.
    """
    query, key, value = (
        pad_to_block_size(tensor, window_size, 0) for tensor in (query, key, value)
    )
    block_size = math.ceil(window_size / 3)
    return big_bird(query, key, value, block_size)

def custom_kernel(query, key, value, window_size):
    """Windowed attention using the `window_matmul` extension: score product, then value product."""
    scores = window_matmul.window_matmul(query, key, window_size)
    return window_matmul.unwindow_matmul(scores, value, window_size)

In [None]:
# Benchmark configuration.
batch_size = 16
num_heads = 12
hidden_dim = 384
seq_lens = list(2**idx for idx in range(6, 13))
repeat = 25  # measurements per (implementation, window size, sequence length)
settings = {
    "Full Attention": {"func": full_attention_kernel, "window_sizes": [None]},
    "Longformer (TVM)": {"func": longformer_tvm_kernel, "window_sizes": [4, 64]},
    "Longformer (PT)": {"func": longformer_pytorch_kernel, "window_sizes": [4, 64]},
    "BigBird": {"func": big_bird_kernel, "window_sizes": [4, 64]},
    "Ours": {"func": custom_kernel, "window_sizes": [4, 64]},
}
kernel_data = []
# First error seen for every configuration that failed, keyed by
# (name, window_size, seq_len), so failures are reported instead of vanishing.
kernel_failures = {}
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
for seq_len in tqdm.tqdm(seq_lens):
    shape = (batch_size, num_heads, seq_len, hidden_dim // num_heads)
    for name, implementation_settings in settings.items():
        query = torch.rand(*shape, device=device)
        key = torch.rand(*shape, device=device)
        value = torch.rand(*shape, device=device)
        if "Longformer" in name:
            # The Longformer kernels expect (batch, seq_len, heads, head_dim).
            query = query.transpose(-2, -3).contiguous()
            key = key.transpose(-2, -3).contiguous()
            value = value.transpose(-2, -3).contiguous()
        window_sizes = implementation_settings["window_sizes"]
        func = implementation_settings["func"]
        for window_size in window_sizes:
            for _ in range(repeat):
                try:
                    # Drain pending GPU work (e.g. the input allocation above)
                    # so it is not attributed to the kernel being timed.
                    torch.cuda.synchronize()
                    begin_mem = torch.cuda.memory_allocated()
                    torch.cuda.reset_peak_memory_stats()
                    start = time.perf_counter()
                    func(query, key, value, window_size)
                    torch.cuda.synchronize()
                    run_time = time.perf_counter() - start
                    max_mem = torch.cuda.max_memory_allocated() - begin_mem
                except Exception as e:
                    # Best effort: some configurations are expected to fail
                    # (e.g. unsupported sequence lengths); remember why and go on.
                    kernel_failures.setdefault((name, window_size, seq_len), repr(e))
                    continue
                kernel_data.append([seq_len, name, window_size, run_time, max_mem])

# Report skipped configurations. Distinct loop names are used so the globals
# `name`, `window_size` and `seq_len` set by the benchmark loop stay untouched.
for (failed_name, failed_window, failed_len), failure in kernel_failures.items():
    print(f"skipped {failed_name} (window_size={failed_window}, seq_len={failed_len}): {failure}")

kernel_df = pd.DataFrame(kernel_data, columns=["seq_len", "name", "window_size", "time", "space"])
kernel_df.to_json("kernel_df.json")
del kernel_data
kernel_df

In [None]:
# Normalise measurements to per-sequence values: milliseconds and a shared byte unit.
plot_df = kernel_df.copy()
plot_df["time"] = plot_df["time"] / batch_size * 1000
plot_df["space"] = plot_df["space"] / batch_size
# A single call returns both the scaled values and the unit they are expressed in.
scaled_space, unit = human_readable_bytes(plot_df["space"])
plot_df["space"] = scaled_space
# `.copy()` makes the filtered frame independent, so the column assignment
# below does not write into a slice of the previous frame.
plot_df = plot_df.loc[~plot_df["window_size"].isin([0, 1])].copy()
# Full attention has no window; represent it as an infinite window.
plot_df["window_size"] = plot_df["window_size"].fillna(float("inf"))
plot_df = plot_df.groupby(["name", "window_size", "seq_len"], dropna=False).median()
plot_df = plot_df.reindex(["Full Attention", "BigBird", "Longformer (PT)", "Longformer (TVM)", "Ours"], level=0)
# plot_df

# Rows: time (top) and memory (bottom); columns: one per window size.
scale = 0.85
fig, axes = plt.subplots(2, 2, figsize=(6 * scale, 5 * scale), sharex=True, sharey="row")
metrics = ["time", "space"]
for approach_idx, approach in enumerate(plot_df.index.get_level_values('name').unique()):
    approach_df = plot_df.loc[approach]
    window_sizes = sorted(approach_df.index.get_level_values("window_size").unique())
    if len(window_sizes) == 1:
        # Approaches without a window (full attention) are drawn in both columns.
        window_sizes = [window_sizes[0]] * 2
    for row_idx, metric in enumerate(metrics):
        for col_idx, window_size in enumerate(window_sizes):
            value_median = approach_df.loc[window_size]
            ax = axes[row_idx, col_idx]

            x = value_median.index.values
            y = value_median.loc[:, metric]

            # Label each approach once (top-left panel) so the legend has no duplicates.
            label = approach if (row_idx == 0 and col_idx == 0) else None
            ax.plot(x, y, label=label, marker=markers[approach_idx])

axes[0, 0].set_xscale("log", base=2)
axes[0, 0].set_yscale("log")
axes[1, 0].set_yscale("log")
axes[0, 0].get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
axes[0, 0].minorticks_off()
axes[0, 0].set_xticks(seq_lens[::2])
axes[0, 0].set_title("Window Size 4")
axes[0, 0].set_ylabel("ms / Sequence")
axes[0, 1].set_title("Window Size 64")
axes[1, 0].set_ylabel(f"{unit} / Sequence")
axes[1, 0].set_xlabel("Sequence Length")
axes[1, 1].set_xlabel("Sequence Length")
# axes[0, 0].set_ylim(None, 1.5)
# axes[1, 0].set_ylim(None, 1000)
fig.legend(ncols=3, bbox_to_anchor=(0.5, -0.025), loc="center")
fig.tight_layout()
plt.savefig("kernel-efficiency.pdf", bbox_inches='tight')
plt.show()

In [None]:
from listformer.model.listformer import (
    list_softmax,
    untranspose_for_scores,
    to_windowed_attention_mask
)

import math


def pad_to_window_size_and_transpose(
    tensor: torch.Tensor, one_sided_window_size: int, pad_token_id: int
):
    """Rearrange a 5-D tensor for the sliding-chunks kernels and pad its sequence.

    Args:
        tensor: (batch_size, num_heads, num_docs, seq_len, head_dim).
        one_sided_window_size: window size on one side of each token.
        pad_token_id: fill value for the appended sequence positions.

    Returns:
        Tensor of shape (batch_size * num_docs, padded_seq_len, num_heads,
        head_dim), where padded_seq_len is seq_len rounded up to a multiple of
        2 * one_sided_window_size.
    """
    batch_size, num_heads, num_docs, seq_len, head_dim = tensor.shape
    tensor = tensor.permute(0, 2, 3, 1, 4)
    # After the permute the batch and document axes are no longer adjacent in
    # memory, so they cannot be merged by `view` (it raises for batch_size > 1
    # with num_heads > 1); `reshape` copies when necessary.
    tensor = tensor.reshape(batch_size * num_docs, seq_len, num_heads, head_dim)
    w = int(2 * one_sided_window_size)
    padding_len = (w - seq_len % w) % w
    # Pad only the end of dimension 1 (the sequence axis).
    tensor = F.pad(tensor, (0, 0, 0, 0, 0, padding_len), value=pad_token_id)
    return tensor


def unpad_and_transpose(
    tensor: torch.Tensor, batch_size, num_heads, num_docs, seq_len, head_dim
):
    """Undo the flattening and padding applied before the windowed matmul.

    The leading axis is split back into (batch_size, num_docs), heads are
    moved in front of the document axis and the sequence axis is cut back to
    ``seq_len`` entries. The result is laid out as
    (batch_size, num_heads, num_docs, seq_len, head_dim).
    """
    # Split the merged leading axis; -1 absorbs the (padded) sequence length.
    per_document = tensor.view(batch_size, num_docs, -1, num_heads, head_dim)
    heads_first = per_document.permute(0, 3, 1, 2, 4)
    # Drop the padded tail of the sequence axis.
    return heads_first[:, :, :, :seq_len, :]


def _forward(
    self,
    inp,
    attention_window_sizes,
) -> torch.Tensor:
    # TODO use nested tensors when broadcasting is supported
    # https://pytorch.org/docs/stable/nested.html
    query_layer = inp.query_layer
    key_layers = inp.key_layers
    value_layers = inp.value_layers
    attention_masks = inp.attention_masks

    batch_size, num_heads, num_docs, seq_len, head_dim = query_layer.shape

    # Take the dot product between "query" and "key" to get the raw attention scores
    attention_scores = []
    for key_layer, window_size in zip(key_layers, attention_window_sizes):
        if window_size is not None:
            attention_score = sliding_chunks_matmul_qk(
                pad_to_window_size_and_transpose(query_layer, window_size, 0),
                pad_to_window_size_and_transpose(key_layer, window_size, 0),
                window_size,
                0,
            )
            attention_score = unpad_and_transpose(
                attention_score,
                batch_size,
                num_heads,
                num_docs,
                seq_len,
                attention_score.shape[-1],
            )
            attention_scores.append(attention_score)
        else:
            attention_scores.append(
                torch.matmul(query_layer, key_layer.transpose(-1, -2))
            )
    # attention: batch_size x num_heads x num_docs x seq_len
    # or
    # attention: batch_size x num_heads x num_docs * seq_len x num_docs * seq_len

    attention_scores = [
        attention_score / math.sqrt(self.attention_head_size)
        for attention_score in attention_scores
    ]

    iterator = enumerate(zip(attention_masks, attention_scores, attention_window_sizes))
    for idx, (attention_mask, attention_score, attention_window_size) in iterator:
        if attention_mask is not None:
            if attention_window_size is not None:
                attention_mask = to_windowed_attention_mask(
                    attention_mask, attention_window_size
                )
            attention_score = attention_score + attention_mask
        attention_score = attention_score.clamp(torch.finfo(attention_score.dtype).min)
        attention_scores[idx] = attention_score

    # Normalize the attention scores to probabilities
    attention_probs = list_softmax(attention_scores, dim=-1)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = [
        self.dropout(attention_prob) for attention_prob in attention_probs
    ]

    context_layers = []
    iterator = zip(attention_probs, value_layers, attention_window_sizes)
    for attention_prob, value_layer, window_size in iterator:
        attention_prob = attention_prob.to(value_layer)
        if window_size is not None:
            context_layer = sliding_chunks_matmul_pv(
                pad_to_window_size_and_transpose(attention_prob, window_size, 0),
                pad_to_window_size_and_transpose(value_layer, window_size, 0),
                window_size,
            )
            context_layer = unpad_and_transpose(
                context_layer, batch_size, num_heads, num_docs, seq_len, head_dim
            )
            context_layers.append(context_layer)
        else:
            context_layers.append(torch.matmul(attention_prob, value_layer))

    context_layer = reduce(torch.add, context_layers)

    context_layer = untranspose_for_scores(context_layer, self.all_head_size)

    return context_layer

In [None]:
# Size of the position-embedding table shared by every benchmarked model.
max_emb = 4672
# The MiniLM cross-encoder config is the common base for all models.
config = transformers.AutoConfig.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")
config.update({"max_position_embeddings": max_emb})
# Longformer: start from the Longformer config and overwrite it with the base
# settings, then give every layer an attention window of 128.
longformer_config = transformers.AutoConfig.from_pretrained(
    "allenai/longformer-base-4096"
)
longformer_config.update(config.to_dict())
longformer_config.update({"attention_window": [128] * config.num_hidden_layers})
# Plain BERT with the base settings.
bert_config = transformers.AutoConfig.from_pretrained("bert-base-uncased")
bert_config.update(config.to_dict())
# Listformer with its default attention settings on top of the base settings.
full_attention_config = ListformerConfig()
full_attention_config.update(config.to_dict())
# Reuse max_emb instead of repeating the literal so the two cannot diverge.
full_attention_config.update({"cls_token_id": 1, "max_position_embeddings": max_emb})
# Sparse variant: same config with query->doc and query->CLS attention disabled.
sparse_cross_encoder_config = deepcopy(full_attention_config)
sparse_cross_encoder_config.update({"query_doc_attention": False, "query_cls_attention": False})

In [None]:
# Build every model from its config (no pretrained weights are loaded here),
# switch to eval mode and move it to the benchmark device.
bert = transformers.BertModel(bert_config).eval().to(device)
longformer = transformers.LongformerModel(longformer_config).eval().to(device)
full_attention = ListformerModel(full_attention_config).eval().to(device)
sparse_cross_encoder = ListformerModel(sparse_cross_encoder_config).eval().to(device)
# Same sparse config, but with the attention forward pass swapped for the
# sliding-chunks implementation defined in `_forward`.
longformer_kernel = ListformerModel(sparse_cross_encoder_config).eval().to(device)
for module in longformer_kernel.modules():
    if module._get_name() == "ListformerSelfAttention":
        # Bind the module as `self` so the replacement behaves like a method.
        setattr(module, "_forward", partial(_forward, module))

In [None]:
# Display name -> model instance. "Longformer" and "QDS-Transformer" share
# one model object; they differ only in the global attention mask built for
# them in the benchmark loop.
models = {
    "Full Attention": bert,
    "Longformer": longformer,
    "QDS-Transformer": longformer,
    "Sparse Cross-Encoder (No Cross)": full_attention,
    "Sparse Cross-Encoder (No Kernel)": longformer_kernel,
    "Sparse Cross-Encoder (Ours)": sparse_cross_encoder,
}

# Timed forward passes per configuration.
repeat = 5
# Number of query tokens at the start of every sequence.
query_len = 10
# Stride of the extra global-attention tokens used for "QDS-Transformer".
num_tokens_per_sentence = 30
# Powers of two from 64 to 4096, plus one extra length of 164 tokens.
seq_lens = list(2**idx for idx in range(6, 13)) + [164]
# One-sided attention window sizes to benchmark.
window_sizes = [4, 64]

# Progress bar over the models; its description is updated inside the loop.
pg = tqdm.tqdm(models.items())

def expand_tensor(key, tensor, num_docs, use_batch_dim):
    """Replicate a single-example model input to ``num_docs`` documents.

    Args:
        key: keyword-argument name of the input; selects the axis to expand.
        tensor: input tensor with size-1 axes where replication happens.
        num_docs: number of documents to replicate to.
        use_batch_dim: for the query/doc inputs, replicate along the leading
            (batch) axis instead of the document axis.

    Returns:
        A freshly materialised tensor, except for ``query_input_ids`` without
        ``use_batch_dim`` where the input is returned (made contiguous).

    Raises:
        ValueError: if ``key`` is not one of the known input names.
    """

    def _materialize(*sizes):
        # expand() only creates a broadcast view; clone() gives it real storage.
        return tensor.expand(*sizes).clone().contiguous()

    if key == "query_input_ids":
        return _materialize(num_docs, -1) if use_batch_dim else tensor.contiguous()
    if key == "doc_input_ids":
        if use_batch_dim:
            return _materialize(num_docs, -1, -1)
        return _materialize(-1, num_docs, -1)
    if key in ("input_ids", "global_attention_mask"):
        return _materialize(num_docs, -1)
    raise ValueError(f"unknown key {key}")

def find_max_docs(model, kwargs, use_batch_dim):
    """Return a document count for which one forward pass fits in GPU memory.

    The first probe uses 100 documents. If that runs out of memory the
    schedule continues with 128 and then keeps halving (64, 32, ...) until a
    forward pass succeeds.

    Args:
        model: callable invoked as ``model(**inputs)``.
        kwargs: single-example inputs, expanded via ``expand_tensor``.
        use_batch_dim: forwarded to ``expand_tensor``.

    Raises:
        ValueError: if the model runs out of memory even for one document.
    """
    candidate = 100
    while True:
        batch = {}
        for name, value in kwargs.items():
            batch[name] = expand_tensor(name, value, candidate, use_batch_dim)
        try:
            with torch.no_grad():
                model(**batch)
            # Wait for the GPU so asynchronous failures surface inside the try.
            torch.cuda.synchronize()
        except torch.cuda.OutOfMemoryError:
            candidate = 128 if candidate == 100 else candidate // 2
            if candidate == 0:
                raise ValueError("unable to run model")
        else:
            return candidate
            

# One row per timed forward pass:
# [name, query_len, seq_len, window_size, seconds, peak bytes, num_docs].
model_data = []
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
for model_name, model in pg:
    for seq_len in seq_lens:
        doc_len = seq_len - query_len
        # Re-seed so every model sees the same random token ids per length.
        torch.manual_seed(42)
        kwargs = {}
        # Models without a configurable window are run once with window None.
        model_window_sizes = [None]
        if isinstance(model, ListformerModel):
            # Listformer models take the query and the documents separately.
            model_window_sizes = window_sizes
            kwargs["query_input_ids"] = torch.randint(1000, 10000, (1, query_len,)).to(device)
            kwargs["doc_input_ids"] = torch.randint(1000, 10000, (1, 1, doc_len)).to(device)
        else:
            # BERT / Longformer take one concatenated sequence of seq_len tokens.
            sequence_input = torch.randint(1000, 10000, (1, seq_len)).to(device)
            kwargs["input_ids"] = sequence_input
            if model_name in ("Longformer", "QDS-Transformer"):
                # Global attention on the first query_len tokens.
                global_attention_mask = torch.zeros_like(sequence_input).to(device)
                global_attention_mask[:, :query_len] = 1
                kwargs["global_attention_mask"] = global_attention_mask
                model_window_sizes = window_sizes
            if model_name == "QDS-Transformer":
                # Additionally mark every num_tokens_per_sentence-th token after
                # the query as global.
                kwargs["global_attention_mask"][:, query_len::num_tokens_per_sentence] = 1

        # The "No Cross" variant replicates documents along the batch axis.
        use_batch_dim = "No Cross" in model_name

        for window_size in model_window_sizes:
            if window_size is not None:
                # Configs exposing `attention_window_size` take the one-sided size.
                if hasattr(model.config, "attention_window_size"):
                    model.config.attention_window_size = window_size
                # Configs exposing `attention_window` get twice the one-sided
                # size per layer, and each layer's attention module is patched
                # directly as well.
                if hasattr(model.config, "attention_window"):
                    model.config.attention_window = [window_size * 2] * model.config.num_hidden_layers
                    for layer in model.encoder.layer:
                        assert hasattr(layer.attention.self, "one_sided_attn_window_size")
                        layer.attention.self.one_sided_attn_window_size = window_size
            
            # Document count that fits in memory for this configuration; the
            # probing forward passes also run before any timing starts.
            num_docs = find_max_docs(model, kwargs, use_batch_dim)
            inp = {key: expand_tensor(key, tensor, num_docs, use_batch_dim) for key, tensor in kwargs.items()}
            if isinstance(model, ListformerModel):
                # Preprocess once, outside the timed region; only `encode` is timed.
                inp = {"inp": model.preprocess(**inp)}
            
            pg.set_description(f"{model_name} {seq_len} {query_len} {window_size} {num_docs}")

            with torch.no_grad():
                for _ in range(repeat):
                    # Baseline allocation, so only memory added by this forward
                    # pass is reported.
                    begin_mem = torch.cuda.memory_allocated()
                    torch.cuda.reset_peak_memory_stats()
                    start = time.perf_counter()
                    if isinstance(model, ListformerModel):
                        model.encode(**inp)
                    else:
                        model(**inp)
                    # CUDA kernels run asynchronously; wait before stopping the clock.
                    torch.cuda.synchronize()
                    model_time = time.perf_counter() - start
                    max_mem = torch.cuda.max_memory_allocated() - begin_mem
                    # time.sleep(0.001)
                    model_data.append([model_name, query_len, seq_len, window_size, model_time, max_mem, num_docs])
model_df = pd.DataFrame(model_data, columns=["name", "query_len", "seq_len", "window_size", "time", "space", "num_docs"])
# Persist the raw measurements so the analysis cells can run without re-benchmarking.
model_df.to_json("model_df.json")
model_df

In [None]:
# Load the raw measurements and normalise them per document.
median_df = pd.read_json("model_df.json")
# Seconds per batch -> milliseconds per document.
median_df["time"] = median_df["time"] / median_df["num_docs"] * 1000
# Peak bytes per batch -> bytes per document.
median_df["space"] = median_df["space"] / median_df["num_docs"]
# Pick the byte unit from the "Ours" rows and express every row in that unit.
_, unit = human_readable_bytes(median_df.loc[median_df["name"].str.contains("Ours"), "space"])
median_df["space"] = human_readable_bytes(median_df["space"], unit)[0]
# Take an explicit copy of the filtered rows: assigning a column on the
# result of a boolean `.loc` selection otherwise triggers pandas'
# SettingWithCopyWarning.
median_df = median_df.loc[~median_df["window_size"].isin([0, 1])].copy()
# Models without a window have a missing window_size; use +inf as the group key.
median_df["window_size"] = median_df["window_size"].fillna(float("inf"))
# Median over the repeated runs of each configuration.
median_df = median_df.groupby(["name", "window_size", "seq_len", "query_len"], dropna=False)[["time", "space"]].median()
median_df

In [None]:
# Approaches shown in the model-efficiency figure.
plot_df = median_df.loc[["Full Attention", "Longformer", "QDS-Transformer", "Sparse Cross-Encoder (Ours)"]]
# Index level 2 is seq_len; the 164-token runs are left out of the plot.
plot_df = plot_df.drop(164, level=2)
# NOTE(review): a bare expression in the middle of a cell is not displayed.
plot_df

scale = 0.85
# Rows: time (top) and memory (bottom); columns: one per window size.
fig, axes = plt.subplots(2, 2, figsize=(6 * scale, 5 * scale), sharex=True, sharey="row")
# NOTE(review): `labels` and `colors` are not used further down in this cell.
labels = set()
colors = {}
for approach_idx, approach in enumerate(plot_df.index.get_level_values('name').unique()):
    # Keep only the runs with query_len == 10.
    approach_df = plot_df.loc[approach].loc[pd.IndexSlice[:, :, 10]]
    window_sizes = sorted(approach_df.index.get_level_values("window_size").unique())
    # An approach with a single window size is drawn identically in both columns.
    if len(window_sizes) == 1:
        window_sizes = [window_sizes[0]] * 2
    for row_idx in range(2):
        for col_idx, window_size in enumerate(window_sizes):
            value_median = approach_df.loc[window_size]
            ax = axes[row_idx, col_idx]

            x = value_median.index.values
            if row_idx == 0:
                y = value_median.loc[:, "time"]
            else:
                y = value_median.loc[:, "space"]
            
            # Label only the top-left line of each approach so the shared
            # legend lists every approach exactly once.
            label = None
            if row_idx == 0 and col_idx == 0:
                label = approach
            line = ax.plot(x, y, label=label if col_idx == 0 else None, marker=markers[approach_idx])

# Scales and ticks are set on the first column; sharex=True / sharey="row"
# propagate them to the other panels.
axes[0, 0].set_xscale("log", base=2)
axes[0, 0].set_yscale("log")
axes[1, 0].set_yscale("log")
# Print plain numbers instead of powers of two on the x axis.
axes[0, 0].get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
axes[0, 0].set_xticks(seq_lens[::2])
axes[0, 0].minorticks_off()
# NOTE(review): column titles are hard-coded; confirm they match the window
# sizes that were actually plotted in each column.
axes[0, 0].set_title("Window Size 4")
axes[0, 0].set_ylabel("ms / Sequence")
axes[0, 1].set_title("Window Size 64")
axes[1, 0].set_ylabel(f"{unit} / Sequence")
axes[1, 0].set_xlabel("Sequence Length")
axes[1, 1].set_xlabel("Sequence Length")
# axes[0, 0].set_ylim(0, 20)
# axes[1, 0].set_ylim(None, 250)
# fig.legend(loc="center", bbox_to_anchor=(1.15, 0.5))
# Single shared legend centred just below the figure.
fig.legend(ncols=2, bbox_to_anchor=(0.5, -0.025), loc="center")
fig.tight_layout()
plt.savefig("model-efficiency.pdf", bbox_inches='tight')
plt.show()

In [None]:
# (model name, window size) pairs that become the table columns, in order.
models = [
    ("Full Attention", float("inf")),
    ("Longformer", 64),
    ("QDS-Transformer", 64),
    ("Sparse Cross-Encoder (Ours)", 64),
    ("Sparse Cross-Encoder (Ours)", 4),
    ("Sparse Cross-Encoder (No Kernel)", 4),
    ("Sparse Cross-Encoder (No Cross)", 4),
]
# Baseline per sequence length: full attention at 164 tokens, Longformer
# (window 64) at 4096 tokens.
comparison_models = [("Full Attention", float("inf"), 164), ("Longformer", 64, 4096)]

# Every (name, window size, seq_len) combination reported in the table.
entries = []
for seq_len in (164, 4096):
    for model in models:
        entries.append((*model, seq_len))

efficiency_df = median_df.copy()
# Only the 164-token rows have their time scaled by 1000.
# NOTE(review): presumably this reports short sequences in microseconds
# instead of milliseconds -- confirm this is intended.
efficiency_df.loc[pd.IndexSlice[:, :, [164]], "time"] = efficiency_df.loc[pd.IndexSlice[:, :, [164]], "time"] * 1000
# efficiency_df.loc[:, "time"] = efficiency_df.loc[:, "time"] * 1000
# Drop the last index level (query_len) and keep only the reported entries.
model_df = efficiency_df.droplevel(-1).reindex(entries).round()
# Baselines indexed by seq_len only, so they align with every model at that length.
comparison_df = efficiency_df.droplevel(-1).loc[comparison_models].droplevel([0, 1]).round()
# Relative change versus the baseline of the same sequence length.
improvement_df = (model_df - comparison_df) / comparison_df
improvement_df = improvement_df.sort_index()
# Move (name, window size) into the columns, above the time/space level.
improvement_df = improvement_df.unstack([0, 1]).reorder_levels((1, 2, 0), axis=1).sort_index(axis=1, ascending=(True, False))
improvement_df = improvement_df.multiply(100).round().astype(int)

# Absolute values, reshaped exactly like the relative ones.
absolute_df = efficiency_df.droplevel(-1).reindex(entries).round().astype(int)
absolute_df = absolute_df.sort_index()
absolute_df = absolute_df.unstack([0, 1]).reorder_levels((1, 2, 0), axis=1).sort_index(axis=1, ascending=(True, False))

# Each cell reads "<absolute> (<relative change in percent>)".
table = absolute_df.astype(str) + " (" + improvement_df.astype(str) + ")"

table = table.stack(-1).droplevel(0)

# Restore the column order given by `models`.
table = table.reindex(models, axis=1)

display(table)
print(table.to_latex())

In [None]:
# Pairwise relative improvements between all configurations.
# Index level 2 of median_df is seq_len, so this keeps the 512-token runs;
# presumably droplevel(-1) then removes query_len -- TODO confirm.
improvement_df = median_df.loc[pd.IndexSlice[:, :, 512]].droplevel(-1)
# Entry [i, j] = (time_j - time_i) / time_j: the fraction of configuration
# j's time that configuration i saves.
time_improvement = pd.DataFrame(
    (improvement_df.loc[:, "time"].values[None, :] - improvement_df.loc[:, "time"].values[:, None]) / improvement_df.loc[:, "time"].values[None, :],
    index=improvement_df.index,
    columns=improvement_df.index,
)
# Same construction for memory.
space_improvement = pd.DataFrame(
    (improvement_df.loc[:, "space"].values[None, :] - improvement_df.loc[:, "space"].values[:, None]) / improvement_df.loc[:, "space"].values[None, :],
    index=improvement_df.index,
    columns=improvement_df.index,
)
# display(time_improvement.multiply(100).round())
# display(space_improvement.multiply(100).round())
# Keep the rows of our model and drop its comparison against itself.
improvement = pd.concat(
    [
        time_improvement.loc["Sparse Cross-Encoder (Ours)"].drop("Sparse Cross-Encoder (Ours)", axis=1),
        space_improvement.loc["Sparse Cross-Encoder (Ours)"].drop("Sparse Cross-Encoder (Ours)", axis=1)
    ],
    keys=["Time", "Space"],
    names=["values"],
)
# improvement.rename({"window_size": "from_window_size"}, axis=1, level=1)
# Move the Time/Space key from the rows into the columns.
improvement = improvement.unstack(level=0)
# Displayed in percent.
improvement.multiply(100).round()

In [None]:
# Same pairwise improvement computation, but only the time column against
# full attention is displayed at the end.
# Index level 2 of median_df is seq_len, so this keeps the 512-token runs;
# presumably droplevel(-1) then removes query_len -- TODO confirm.
improvement_df = median_df.loc[pd.IndexSlice[:, :, 512]].droplevel(-1)
# Entry [i, j] = (time_j - time_i) / time_j: the fraction of configuration
# j's time that configuration i saves.
time_improvement = pd.DataFrame(
    (improvement_df.loc[:, "time"].values[None, :] - improvement_df.loc[:, "time"].values[:, None]) / improvement_df.loc[:, "time"].values[None, :],
    index=improvement_df.index,
    columns=improvement_df.index,
)
# Same construction for memory.
space_improvement = pd.DataFrame(
    (improvement_df.loc[:, "space"].values[None, :] - improvement_df.loc[:, "space"].values[:, None]) / improvement_df.loc[:, "space"].values[None, :],
    index=improvement_df.index,
    columns=improvement_df.index,
)
# display(time_improvement.multiply(100).round())
# display(space_improvement.multiply(100).round())
# Keep the rows of our model and drop its comparison against itself.
improvement = pd.concat(
    [
        time_improvement.loc["Sparse Cross-Encoder (Ours)"].drop("Sparse Cross-Encoder (Ours)", axis=1),
        space_improvement.loc["Sparse Cross-Encoder (Ours)"].drop("Sparse Cross-Encoder (Ours)", axis=1)
    ],
    keys=["Time", "Space"],
    names=["values"],
)
# improvement.rename({"window_size": "from_window_size"}, axis=1, level=1)
# Move the Time/Space key from the rows into the columns.
improvement = improvement.unstack(level=0)
# Percent time saved relative to full attention (window size +inf).
improvement.multiply(100).round().loc[:, [("Full Attention", float("inf"), "Time")]]

In [None]:
# Emit the percentage improvements as a LaTeX table.
print(improvement.multiply(100).round().to_latex())

In [None]:
# Models for the energy benchmark; the baselines are currently disabled.
models = {
#     "Full Attention": bert,
#     "Longformer": longformer,
#     "QDS-Transformer": longformer,
    "Sparse Cross-Encoder (Ours)": sparse_cross_encoder,
}

# Successful measurements to record per configuration.
repeat = 1
num_docs = 100
# Sequence lengths from `start` to `end` in increments of `step`.
# NOTE(review): `start` is later rebound to a timestamp inside the benchmark
# loop; seq_lens is already built by then.
start = 512
end = 4608
step = 1024
# step = 2048
seq_lens = list(range(start, end + step, step))
# query_lens = [10, 20, 30]
query_lens = [10]
# Stride of the extra global-attention tokens used for "QDS-Transformer".
num_tokens_per_sentence = 30
# window_sizes = [4, 16, 64]
window_sizes = [4, 64]

def expand_tensor(key, tensor, num_docs):
    """Replicate a single-example model input to ``num_docs`` documents.

    ``query_input_ids`` is returned unchanged. ``doc_input_ids`` is expanded
    along its second axis; ``input_ids`` and ``global_attention_mask`` are
    expanded along their first axis. Expanded results are cloned so they own
    their memory.

    Raises:
        ValueError: if ``key`` is not one of the known input names.
    """
    if key == "query_input_ids":
        return tensor
    if key == "doc_input_ids":
        target_sizes = (-1, num_docs, -1)
    elif key in ("input_ids", "global_attention_mask"):
        target_sizes = (num_docs, -1)
    else:
        raise ValueError(f"unknown key {key}")
    return tensor.expand(*target_sizes).clone()
    
def old_find_max_docs(model, kwargs):
    """Binary-search the largest document count in [0, 8192] that fits in memory.

    Assumes memory use grows monotonically with the number of documents.

    The previous implementation always probed the upper bound and moved the
    lower bound on success, so once the upper bound fell below the true
    maximum it could not recover (a limit of 100 returned 63). It also
    carried unreachable verification code after the loop. This version probes
    the midpoint and keeps the invariant "lower fits, upper does not".

    Args:
        model: callable invoked as ``model(**inputs)``.
        kwargs: single-example inputs, expanded via ``expand_tensor``.

    Returns:
        The largest document count that ran without an out-of-memory error,
        or 0 if even a single document does not fit.
    """

    def fits(num_docs):
        # True if one forward pass with `num_docs` documents succeeds.
        inp = {key: expand_tensor(key, tensor, num_docs) for key, tensor in kwargs.items()}
        try:
            model(**inp)
            # Wait for the GPU so asynchronous failures surface here.
            torch.cuda.synchronize()
        except torch.cuda.OutOfMemoryError:
            return False
        return True

    lower_bound = 0
    upper_bound = 8192
    if fits(upper_bound):
        return upper_bound
    # Invariant: lower_bound fits (0 trivially), upper_bound does not.
    while upper_bound - lower_bound > 1:
        middle = (lower_bound + upper_bound) // 2
        if fits(middle):
            lower_bound = middle
        else:
            upper_bound = middle
    return lower_bound

def find_max_docs(model, kwargs):
    """Return the largest power-of-two document count (<= 8192) that fits in memory.

    Starts at 8192 documents and halves after every out-of-memory failure.
    Only ``torch.cuda.OutOfMemoryError`` is treated as "too large"; any other
    exception propagates. The original bare ``except:`` swallowed every error
    and could loop forever once the count reached 0.

    Args:
        model: callable invoked as ``model(**inputs)``.
        kwargs: single-example inputs, expanded via ``expand_tensor``.

    Raises:
        ValueError: if the model runs out of memory even for one document.
    """
    num_docs = 8192
    while num_docs:
        inp = {key: expand_tensor(key, tensor, num_docs) for key, tensor in kwargs.items()}
        try:
            model(**inp)
            # Wait for the GPU so asynchronous failures surface inside the try.
            torch.cuda.synchronize()
            return num_docs
        except torch.cuda.OutOfMemoryError:
            num_docs = num_docs // 2
    raise ValueError("unable to run model")
        

# Energy/emissions tracker; results are read from the returned task objects
# and not written to a file.
tracker = EmissionsTracker(save_to_file=False, log_level="error")
# One row per successful measurement, in the column order given below.
model_data = []

pg = tqdm.tqdm(models.items())
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
for model_name, model in pg:
    for query_len in query_lens:
        for seq_len in seq_lens:
            # Lengths above 4096 are skipped; seq_lens is ascending, so the
            # remaining ones are skipped too.
            if seq_len > 4096:
                break
            doc_len = seq_len - query_len
            # Re-seed so every configuration sees the same random token ids.
            torch.manual_seed(42)
            sequence_input = torch.randint(1000, 10000, (1, seq_len)).to(
                device
            )
            query_input = torch.randint(1000, 10000, (1, query_len,)).to(device)
            doc_input = torch.randint(1000, 10000, (1, 1, doc_len)).to(device)
            kwargs = {}
            # Models without a configurable window are run once with window None.
            model_window_sizes = [None]
            if model_name in ("Longformer", "Full Attention", "QDS-Transformer"):
                kwargs["input_ids"] = sequence_input
                if model_name in ("Longformer", "QDS-Transformer"):
                    # Global attention on the first query_len tokens.
                    global_attention_mask = torch.zeros_like(sequence_input).to(device)
                    global_attention_mask[:, :query_len] = 1
                    kwargs["global_attention_mask"] = global_attention_mask
                    model_window_sizes = window_sizes
                if model_name == "QDS-Transformer":
                    # Additionally mark every num_tokens_per_sentence-th token
                    # after the query as global.
                    kwargs["global_attention_mask"][:, query_len::num_tokens_per_sentence] = 1
            else:
                model_window_sizes = window_sizes
                kwargs["query_input_ids"] = query_input
                # NOTE(review): the last document token is dropped here;
                # presumably to leave room for a special token -- confirm.
                kwargs["doc_input_ids"] = doc_input[:, :, :-1]

            for window_size in model_window_sizes:
                pg.set_description(f"{model_name} {seq_len} {query_len} {window_size}")
                if window_size is not None:
                    # Configs exposing `attention_window_size` take the one-sided size.
                    if hasattr(model.config, "attention_window_size"):
                        model.config.attention_window_size = window_size
                    # Configs exposing `attention_window` get twice the
                    # one-sided size per layer, and each layer's attention
                    # module is patched directly as well.
                    if hasattr(model.config, "attention_window"):
                        model.config.attention_window = [window_size * 2] * model.config.num_hidden_layers
                        for layer in model.encoder.layer:
                            assert hasattr(layer.attention.self, "one_sided_attn_window_size")
                            layer.attention.self.one_sided_attn_window_size = window_size

                with torch.inference_mode():
                    # Start from the document count found by the halving probe.
                    num_docs = find_max_docs(model, kwargs)
                    pg.set_description(f"{model_name} {seq_len} {query_len} {window_size} {num_docs}")
                    repeats = 0
                    # Retry until `repeat` measurements succeeded, reducing the
                    # document count by one after every out-of-memory failure.
                    while True:
                        try:
                            # NOTE(review): when the forward pass raises,
                            # stop_task() is never reached for this task --
                            # confirm that calling start_task again is safe.
                            tracker.start_task("run inference")
                            torch.cuda.reset_peak_memory_stats()
                            start = time.perf_counter()
                            # Input expansion is inside the timed region.
                            inp = {key: expand_tensor(key, tensor, num_docs) for key, tensor in kwargs.items()}
                            model(**inp)
                            # CUDA kernels run asynchronously; wait before
                            # stopping the clock.
                            torch.cuda.synchronize()
                            model_time = time.perf_counter() - start
                            # Absolute peak allocation; no baseline is subtracted.
                            max_mem = torch.cuda.max_memory_allocated()
                            emissions = tracker.stop_task()
                            model_data.append(
                                [
                                    model_name,
                                    query_len,
                                    seq_len,
                                    window_size,
                                    model_time,
                                    max_mem,
                                    num_docs,
                                    emissions.energy_consumed,
                                    emissions.gpu_energy,
                                ]
                            )
                            repeats += 1
                            if repeats == repeat:
                                break
                        except torch.cuda.OutOfMemoryError:
                            num_docs -= 1
                            
model_df = pd.DataFrame(
    model_data, 
    columns=[
        "name",
        "query_len",
        "seq_len",
        "window_size",
        "time",
        "space",
        "num_docs",
        "total_energy",
        "gpu_energy",
    ]
)
# NOTE(review): this overwrites "model_df.json", the same file the earlier
# benchmark cell writes and the analysis cells read.
model_df.to_json("model_df.json")
del model_data
model_df