In [1]:
import pandas as pd

In [2]:
df = pd.read_json("./efficiency.json", lines=True, orient="records")

In [3]:
df["num_texts"] = df["batch_size"] * df["num_batches"]
df["num_texts_per_sec"] = df["num_texts"] / df["time"]
df["thousand_num_texts_per_sec"] = df["num_texts_per_sec"] / 1000
df["kb / text"] = df["mem"] / df["num_texts"] / 1024

In [4]:
df

Unnamed: 0,model,text_type,grad,batch_size,num_batches,time,mem,num_texts,num_texts_per_sec,kb / text,thousand_num_texts_per_sec
0,bert-eager,docs,False,2048,5,5.24384,13504573440,10240,1952.767519,1287.896484,1.952768
1,bert-eager,queries,False,2048,5,0.426489,1109000192,10240,24009.982111,105.7625,24.009982
2,funnel-transformer,docs,False,1024,10,7.980217,17134274048,10240,1283.17311,1634.051709,1.283173
3,funnel-transformer,queries,False,2048,5,0.722079,1916725248,10240,14181.264907,182.793164,14.181265
4,tite-2-late-absolute-eager,docs,False,2048,5,1.526044,13537145856,10240,6710.160897,1291.002832,6.710161
5,tite-2-late-absolute-eager,queries,False,2048,5,0.146733,817796608,10240,69786.440645,77.991162,69.786441
6,bert-sdpa,docs,False,2048,5,3.204836,9839529984,10240,3195.170873,938.370703,3.195171
7,bert-sdpa,queries,False,2048,5,0.354237,1113063424,10240,28907.23234,106.15,28.907232
8,distil-bert-sdpa,docs,False,2048,5,1.608671,9834598400,10240,6365.503958,937.900391,6.365504
9,distil-bert-sdpa,queries,False,2048,5,0.176787,1112539136,10240,57922.956474,106.1,57.922956


In [5]:
columns = ["text_type", "grad"]
index = ["model"]
values = ["thousand_num_texts_per_sec", "kb / text"]
table = (
    df.pivot_table(values=values, index=index, columns=columns)
    .reorder_levels([1, 0, 2], axis=1)
    .sort_index(axis=1)
    .round(1)
)
table

text_type,docs,docs,queries,queries
Unnamed: 0_level_1,kb / text,thousand_num_texts_per_sec,kb / text,thousand_num_texts_per_sec
grad,False,False,False,False
model,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3
bert-eager,1287.9,2.0,105.8,24.0
bert-flash,242.1,9.2,31.6,50.1
bert-rope-flash,242.1,8.7,31.6,48.0
bert-sdpa,938.4,3.2,106.2,28.9
distil-bert-sdpa,937.9,6.4,106.1,57.9
funnel-transformer,1634.1,1.3,182.8,14.2
modern-bert,187.1,8.3,24.8,41.1
tite-2-late-absolute-eager,1291.0,6.7,78.0,69.8
tite-2-late-absolute-spda,507.8,13.4,58.8,81.2
tite-2-late-rope-higher-dims,263.9,16.3,33.6,70.1


In [6]:
baseline = table.loc["bert-rope-flash"]
improvement = (table / baseline).round(1)
improvement

text_type,docs,docs,queries,queries
Unnamed: 0_level_1,kb / text,thousand_num_texts_per_sec,kb / text,thousand_num_texts_per_sec
grad,False,False,False,False
model,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3
bert-eager,5.3,0.2,3.3,0.5
bert-flash,1.0,1.1,1.0,1.0
bert-rope-flash,1.0,1.0,1.0,1.0
bert-sdpa,3.9,0.4,3.4,0.6
distil-bert-sdpa,3.9,0.7,3.4,1.2
funnel-transformer,6.7,0.1,5.8,0.3
modern-bert,0.8,1.0,0.8,0.9
tite-2-late-absolute-eager,5.3,0.8,2.5,1.5
tite-2-late-absolute-spda,2.1,1.5,1.9,1.7
tite-2-late-rope-higher-dims,1.1,1.9,1.1,1.5


In [7]:
improvement.index.values

array(['bert-eager', 'bert-flash', 'bert-rope-flash', 'bert-sdpa',
       'distil-bert-sdpa', 'funnel-transformer', 'modern-bert',
       'tite-2-late-absolute-eager', 'tite-2-late-absolute-spda',
       'tite-2-late-rope-higher-dims', 'tite-2-late-rope-intra',
       'tite-2-late-rope-post', 'tite-2-late-rope-pre',
       'tite-2-staggered-rope', 'tite-3-late-rope',
       'tite-3-staggered-rope'], dtype=object)

In [8]:
model_order = [
    "bert-eager",
    "funnel-transformer",
    "tite-2-late-absolute-eager",
    "bert-sdpa",
    "distil-bert-sdpa",
    "tite-2-late-absolute-spda",
    "bert-flash",
    "bert-rope-flash",
    "modern-bert",
    "tite-2-late-rope-intra",
    "tite-2-late-rope-intra",
    "tite-2-staggered-rope",
    "tite-3-late-rope",
    "tite-3-staggered-rope",
    "tite-2-late-rope-pre",
    "tite-2-late-rope-post",
    "tite-2-late-rope-higher-dims",
]

In [10]:
pretty_table = table.round(1).astype(str) + " (" + improvement.round(1).astype(str) + "x)"
pretty_table = pretty_table.loc[model_order, pd.IndexSlice[["queries", "docs"], "thousand_num_texts_per_sec"]]
pretty_table

text_type,queries,docs
Unnamed: 0_level_1,thousand_num_texts_per_sec,thousand_num_texts_per_sec
grad,False,False
model,Unnamed: 1_level_3,Unnamed: 2_level_3
bert-eager,24.0 (0.5x),2.0 (0.2x)
funnel-transformer,14.2 (0.3x),1.3 (0.1x)
tite-2-late-absolute-eager,69.8 (1.5x),6.7 (0.8x)
bert-sdpa,28.9 (0.6x),3.2 (0.4x)
distil-bert-sdpa,57.9 (1.2x),6.4 (0.7x)
tite-2-late-absolute-spda,81.2 (1.7x),13.4 (1.5x)
bert-flash,50.1 (1.0x),9.2 (1.1x)
bert-rope-flash,48.0 (1.0x),8.7 (1.0x)
modern-bert,41.1 (0.9x),8.3 (1.0x)
tite-2-late-rope-intra,89.0 (1.9x),20.8 (2.4x)


In [11]:
print(pretty_table.to_latex())

\begin{tabular}{lll}
\toprule
text_type & queries & docs \\
 & thousand_num_texts_per_sec & thousand_num_texts_per_sec \\
grad & False & False \\
model &  &  \\
\midrule
bert-eager & 24.0 (0.5x) & 2.0 (0.2x) \\
funnel-transformer & 14.2 (0.3x) & 1.3 (0.1x) \\
tite-2-late-absolute-eager & 69.8 (1.5x) & 6.7 (0.8x) \\
bert-sdpa & 28.9 (0.6x) & 3.2 (0.4x) \\
distil-bert-sdpa & 57.9 (1.2x) & 6.4 (0.7x) \\
tite-2-late-absolute-spda & 81.2 (1.7x) & 13.4 (1.5x) \\
bert-flash & 50.1 (1.0x) & 9.2 (1.1x) \\
bert-rope-flash & 48.0 (1.0x) & 8.7 (1.0x) \\
modern-bert & 41.1 (0.9x) & 8.3 (1.0x) \\
tite-2-late-rope-intra & 89.0 (1.9x) & 20.8 (2.4x) \\
tite-2-late-rope-intra & 89.0 (1.9x) & 20.8 (2.4x) \\
tite-2-staggered-rope & 96.0 (2.0x) & 28.5 (3.3x) \\
tite-3-late-rope & 68.1 (1.4x) & 14.5 (1.7x) \\
tite-3-staggered-rope & 94.6 (2.0x) & 30.8 (3.5x) \\
tite-2-late-rope-pre & 89.6 (1.9x) & 21.2 (2.4x) \\
tite-2-late-rope-post & 89.0 (1.9x) & 20.3 (2.3x) \\
tite-2-late-rope-higher-dims & 70.1 (1.5x) 