In [2]:
from itertools import product
import pandas as pd

In [27]:
def interactions(
    query_len,
    doc_len,
    window_size,
    cls_query_attention,
    cls_doc_attention,
    query_cls_attention,
    query_doc_attention,
    doc_cls_attention,
    doc_query_attention,
):
    token_interactions = 1 + query_len * query_len
    if window_size is None:
        token_interactions += doc_len * doc_len
    else:
        token_interactions += doc_len * (2 * window_size + 1)
    if cls_query_attention:
        token_interactions += query_len
    if cls_doc_attention:
        token_interactions += doc_len
    if query_cls_attention:
        token_interactions += 1
    if query_doc_attention:
        token_interactions += query_len * doc_len
    if doc_cls_attention:
        token_interactions += 1
    if doc_query_attention:
        token_interactions += query_len * doc_len
    return token_interactions


patterns = {
    "Full Attention / Longformer": (True, True, True, True, True, True),
    "Independent query": (True, True, False, False, True, True),
    "CLS Interaction": (True, True, False, False, False, False),
}

doc_lens = [64 * i for i in range(1, 9)]
query_lens = [10, 20, 30]
window_sizes = [None, 0, 1, 4, 16, 64]
sparsity_data = []
for sizes in product(query_lens, doc_lens, window_sizes):
    for name, pattern in patterns.items():
        inter = interactions(*sizes, *pattern)
        sparsity_data.append([*sizes, name, inter])
sparsity_df = pd.DataFrame(
    sparsity_data,
    columns=["query_len", "doc_len", "window_size", "name", "interactions"],
)
sparsity_df["window_size"] = sparsity_df["window_size"].fillna(float("inf"))
sparsity_df

df = sparsity_df.pivot(
    columns=["query_len", "doc_len"],
    index=["name", "window_size"],
    values="interactions",
).sort_index(axis=0, ascending=(True, False))
df = df.loc[["Full Attention / Longformer", "Independent query", "CLS Interaction"]]
df = df.div(
    df.loc[pd.IndexSlice["Full Attention / Longformer", float("inf")]], axis=1
)
df = (df.loc[:, 10] * 100).round(1)
df = df.astype(str)
df
print(df.to_latex())

\begin{tabular}{llllllllll}
\toprule
 & doc_len & 64 & 128 & 192 & 256 & 320 & 384 & 448 & 512 \\
name & window_size &  &  &  &  &  &  &  &  \\
\midrule
\multirow[t]{6}{*}{Full Attention / Longformer} & inf & 100.0 & 100.0 & 100.0 & 100.0 & 100.0 & 100.0 & 100.0 & 100.0 \\
 & 64.000000 & 174.9 & 100.7 & 70.5 & 54.2 & 44.0 & 37.1 & 32.0 & 28.2 \\
 & 16.000000 & 64.3 & 36.6 & 25.6 & 19.6 & 15.9 & 13.4 & 11.6 & 10.2 \\
 & 4.000000 & 36.6 & 20.6 & 14.3 & 11.0 & 8.9 & 7.5 & 6.4 & 5.7 \\
 & 1.000000 & 29.7 & 16.6 & 11.5 & 8.8 & 7.1 & 6.0 & 5.2 & 4.5 \\
 & 0.000000 & 27.4 & 15.3 & 10.6 & 8.1 & 6.5 & 5.5 & 4.7 & 4.2 \\
\cline{1-10}
\multirow[t]{6}{*}{Independent query} & inf & 88.5 & 93.3 & 95.3 & 96.4 & 97.1 & 97.5 & 97.9 & 98.1 \\
 & 64.000000 & 163.4 & 94.0 & 65.8 & 50.6 & 41.1 & 34.6 & 29.9 & 26.3 \\
 & 16.000000 & 52.7 & 29.9 & 20.9 & 16.0 & 13.0 & 10.9 & 9.4 & 8.3 \\
 & 4.000000 & 25.1 & 13.9 & 9.6 & 7.4 & 6.0 & 5.0 & 4.3 & 3.8 \\
 & 1.000000 & 18.2 & 9.9 & 6.8 & 5.2 & 4.2 & 3.5 & 3.0 & 