In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Plot lines y = line_slope * x + line_bias for each row in df_scaled_fit
# Colors:
# - Names containing "Diagonal Batching": green
# - Names containing "armt" and not "Diagonal Batching": blue
# - All others: gray
# Model names are annotated to the right of each line.

# Ensure the DataFrame exists
try:
    df = df_scaled_fit.copy()
except NameError as exc:
    raise NameError("df_scaled_fit is not defined in the current notebook scope.") from exc

required_cols = {"line_slope", "line_bias"}
missing = required_cols - set(df.columns)
if missing:
    raise KeyError(f"df_scaled_fit is missing required columns: {missing}")

# Try to find a model/name column for annotation
name_col = None
for candidate in ["model_name", "model", "name", "Model", "Model Name"]:
    if candidate in df.columns:
        name_col = candidate
        break
if name_col is None:
    candidates = [c for c in df.columns if ("model" in c.lower()) or ("name" in c.lower())]
    if candidates:
        name_col = candidates[0]
    else:
        # Fallback to using the index as a name
        df = df.assign(__name__=df.index.astype(str))
        name_col = "__name__"

# Determine y-limits based on y at x=0 and x=1 to keep lines visible
y_at_0 = df["line_bias"].to_numpy()
# y(1) = slope * 1 + bias
y_at_1 = (df["line_slope"].to_numpy() + df["line_bias"].to_numpy())
all_y = np.concatenate([y_at_0, y_at_1])
y_min = np.nanmin(all_y) if all_y.size else 0.0
y_max = np.nanmax(all_y) if all_y.size else 1.0
y_span = (y_max - y_min) if np.isfinite(y_max - y_min) else 1.0
margin = 0.05 * y_span if y_span > 0 else 0.5

fig, ax = plt.subplots(figsize=(8, max(4, min(12, 0.35 * len(df)))))

for _, row in df.iterrows():
    slope = float(row["line_slope"])
    bias = float(row["line_bias"])
    name = str(row[name_col])
    lower_name = name.lower()

    if "diagonal batching" in lower_name:
        color = "green"
    elif "armt" in lower_name:
        color = "blue"
    else:
        color = "gray"

    x_vals = np.array([0.0, 1.0])
    y_vals = slope * x_vals + bias
    ax.plot(x_vals, y_vals, color=color, linewidth=2, alpha=0.9)

    # Place label slightly to the right of x=1 using data coordinates
    y_right = slope * 1.0 + bias
    ax.text(1.02, y_right, name, color=color, va="center", ha="left", fontsize=9)

ax.set_xlim(0.0, 1.15)
ax.set_ylim(y_min - margin, y_max + margin)
ax.set_xlabel("x")
ax.set_ylabel("y = slope * x + bias")
ax.set_title("Lines from df_scaled_fit (color-coded by model name)")
ax.grid(True, linestyle=":", alpha=0.3)
plt.tight_layout()
plt.show()


In [1]:
import pandas as pd

In [2]:
!ls *csv

result_all.csv
result_fla-hub__rwkv7-0.4B-world.csv
result_fla-hub__rwkv7-0.4B-world_torch.bfloat16.csv
result_fla-hub__rwkv7-1.5B-world.csv
result_fla-hub__rwkv7-1.5B-world_torch.bfloat16.csv
result_fla-hub__rwkv7-2.9B-world.csv
result_fla-hub__rwkv7-2.9B-world_torch.bfloat16.csv
result_state-spaces__mamba-1.4b-hf_torch.bfloat16.csv
result_state-spaces__mamba-1.4b-hf_torch.float32.csv
result_state-spaces__mamba-2.8b-hf_torch.bfloat16.csv
result_state-spaces__mamba-2.8b-hf_torch.float32.csv
result_state-spaces__mamba-370m-hf_torch.bfloat16.csv
result_state-spaces__mamba-370m-hf_torch.float32.csv
result_state-spaces__mamba-790m-hf_torch.bfloat16.csv
result_state-spaces__mamba-790m-hf_torch.float32.csv


In [3]:
import os

In [4]:
import os

dfs_paths = [f for f in os.listdir('/home/jovyan/sivtsov/diagonal-batching/linear_comparision') 
             if f.endswith('.csv') and 'result_all' not in f]

In [5]:
df_all = pd.DataFrame(columns=['model', 'dtype', 4096, 8192, 12288, 16384, 24576, 32768, 65536])

for p in dfs_paths:
    df = pd.read_csv(p)
    
    if "float" in p:
        model_name = p.rsplit('_', 1)[-2]
        dtype_used = p.split('.')[-2]
    else:
        model_name = p.rsplit('.', 1)[-2]
        dtype_used = "float32"
    
    
    
    
    df_use = df[df['is_warmup'] == False]
    df_use = df_use[df_use['iter'] > 0]
    
    df_agg = df_use.groupby('input_size').agg({'time': 'mean'}).reset_index()
    
    append_row = {'model': model_name, 'dtype': dtype_used}
    for i in df_agg['input_size']:
        append_row[i] = df_agg[df_agg['input_size'] == i]['time'].iloc[0]
    df_all = pd.concat([df_all, pd.DataFrame([append_row])], ignore_index=True)
    
# df_all.to_csv('result_all.csv', index=False)

  df_all = pd.concat([df_all, pd.DataFrame([append_row])], ignore_index=True)
  df_all = pd.concat([df_all, pd.DataFrame([append_row])], ignore_index=True)


In [6]:
import re
import numpy as np

def parse_latex_table(latex_text):
    """
    Parse a LaTeX table and extract performance data into a DataFrame.
    
    Args:
        latex_text (str): LaTeX table text containing performance data
        
    Returns:
        pd.DataFrame: DataFrame with columns ['model', 'dtype', 4096, 8192, 12288, 16384, 24576, 32768, 65536]
    """
    pattern = re.compile(r"^\s*(?!\\)(?!.*Configuration)(?!.*toprule)(?!.*midrule)(?!.*bottomrule)(?!.*textbf)(?!.*cmidrule)(?!.*rowcolor)(.+?)\s*&\s*([0-9.]+)\s*&\s*([0-9.]+)\s*&\s*([0-9.]+)\s*&\s*([0-9.]+)\s*&\s*([0-9.]+)\s*&\s*([0-9.]+)\s*\\\\",
                             flags=re.MULTILINE)

    # Extract configuration pattern
    config_pattern = re.compile(r"Configuration:\s*\((\d+),\s*(\d+)\)")

    rows = []
    current_config = ""

    # Split text by lines and track configurations
    lines = latex_text.split('\n')
    for line in lines:
        # Check for configuration
        config_match = config_pattern.search(line)
        if config_match:
            current_config = f"({config_match.group(1)}, {config_match.group(2)})"
            continue
        
        # Check for data rows
        match = pattern.match(line)
        if match:
            name = match.group(1).strip()
            
            # Add configuration to model name if we have one
            if current_config and name != "Llama-3.2-1B":  # Don't add config to baseline
                name = f"{name} {current_config}"
            
            v4096 = float(match.group(2))
            v8192 = float(match.group(3))
            v16384 = float(match.group(4))
            v32768 = float(match.group(5))
            v65536 = float(match.group(6))
            v131072 = float(match.group(7))  # parsed but not used in requested columns
            rows.append({
                'model': name,
                'dtype': 'bfloat16',
                4096: v4096,
                8192: v8192,
                12288: np.nan,  # not present in table
                16384: v16384,
                24576: np.nan,  # not present in table
                32768: v32768,
                65536: v65536,
            })

    cols_out = ['model', 'dtype', 4096, 8192, 12288, 16384, 24576, 32768, 65536]
    df_latex = pd.DataFrame(rows)[cols_out]
    return df_latex

latex_text = r"""
\begin{table}[h]
  \centering
  \renewcommand{\arraystretch}{1.2}
  \resizebox{\textwidth}{!}{%
  \begin{tabular}{l*{6}{S[table-format=3.3]}}
  \toprule
  \textbf{Method} & \multicolumn{6}{c}{\textbf{Sequence Length}} \\
  \cmidrule(lr){2-7}
   & {\textbf{4096}} & {\textbf{8192}} & {\textbf{16384}} & {\textbf{32768}} & {\textbf{65536}} & {\textbf{131072}} \\
  \midrule
  Llama-3.2-1B & 0.024 & 0.026 & 0.376 & 0.926 & 2.460 & 8.160 \\
  \rowcolor{gray!10} \textbf{Configuration: (512, 128)} \\
  LLama-3.2-1B-ARMT & 0.147 & 0.574 & 1.15 & 2.29 & 4.52 & 8.98 \\
  Diagonal Batching: LLama-3.2-1B-ARMT & 0.283 & 0.248 & 0.454 & 0.861 & 1.67 & 3.3 \\
  \midrule
  \rowcolor{gray!10} \textbf{Configuration: (1024, 128)} \\
  LLama-3.2-1B-ARMT & 0.149 & 0.291 & 0.578 & 1.15 & 2.3 & 4.48 \\
  Diagonal Batching: LLama-3.2-1B-ARMT & 0.119 & 0.196 & 0.351 & 0.656 & 1.27 & 2.48\\
  \midrule
  \rowcolor{gray!10} \textbf{Configuration: (2048, 128)} \\
  LLama-3.2-1B-ARMT & 0.094 & 0.177 & 0.344 & 0.679 & 1.35 & 2.68 \\
  Diagonal Batching: LLama-3.2-1B-ARMT & 0.108 & 0.176 & 0.304 & 0.571 & 1.11 & 2.18 \\
  \midrule
  \rowcolor{gray!10} \textbf{Configuration: (4096, 128)} \\
  LLama-3.2-1B-ARMT & 0.082 & 0.155 & 0.301 & 0.594 & 1.18 & 2.35 \\
  Diagonal Batching: LLama-3.2-1B-ARMT & 0.102 & 0.172 & 0.295 & 0.553 & 1.07 & 2.1 \\
  \bottomrule
  \end{tabular}%
  }
  \caption{Diagonal batching allows to speed-up the execution for longer sequences - from 1.1 times to 2.7 times with respect to base ARMT for 131072 sequence length. Executor comparison of execution times (in seconds) for different methods across sequence lengths for Llama-3.2-1B. Configuration in format (segment_size, memory_tokens). Nvidia A100.}
  \label{tab:perf_comparison_llama1b}
\end{table}
"""

df_latex = parse_latex_table(latex_text)
df_latex


Unnamed: 0,model,dtype,4096,8192,12288,16384,24576,32768,65536
0,Llama-3.2-1B,bfloat16,0.024,0.026,,0.376,,0.926,2.46
1,"LLama-3.2-1B-ARMT (512, 128)",bfloat16,0.147,0.574,,1.15,,2.29,4.52
2,"Diagonal Batching: LLama-3.2-1B-ARMT (512, 128)",bfloat16,0.283,0.248,,0.454,,0.861,1.67
3,"LLama-3.2-1B-ARMT (1024, 128)",bfloat16,0.149,0.291,,0.578,,1.15,2.3
4,"Diagonal Batching: LLama-3.2-1B-ARMT (1024, 128)",bfloat16,0.119,0.196,,0.351,,0.656,1.27
5,"LLama-3.2-1B-ARMT (2048, 128)",bfloat16,0.094,0.177,,0.344,,0.679,1.35
6,"Diagonal Batching: LLama-3.2-1B-ARMT (2048, 128)",bfloat16,0.108,0.176,,0.304,,0.571,1.11
7,"LLama-3.2-1B-ARMT (4096, 128)",bfloat16,0.082,0.155,,0.301,,0.594,1.18
8,"Diagonal Batching: LLama-3.2-1B-ARMT (4096, 128)",bfloat16,0.102,0.172,,0.295,,0.553,1.07


In [7]:
latex_text_3b = r"""
\begin{table}[h]
  \centering
  \resizebox{\textwidth}{!}{%
  \renewcommand{\arraystretch}{1.2}
  \begin{tabular}{l*{6}{S[table-format=3.3]}}
  \toprule
  \textbf{Method} & \multicolumn{6}{c}{\textbf{Sequence Length}} \\
  \cmidrule(lr){2-7}
   & {\textbf{4096}} & {\textbf{8192}} & {\textbf{16384}} & {\textbf{32768}} & {\textbf{65536}} & {\textbf{131072}} \\
  \midrule
  Llama-3.2-3B & 0.168 & 0.344 & 0.769 & 1.95 & 5.59 & 18.2 \\
  \rowcolor{gray!10} \textbf{Configuration: (1024, 128)} \\
  LLama-3.2-3B-ARMT & 0.272 & 0.537 & 1.05 & 2.02 & 4.09 & 8.23 \\
  Diagonal Batching: LLama-3.1-3B-ARMT & 0.274 & 0.454 & 0.833 & 1.58 & 3.1 & 6.14 \\
  \rowcolor{gray!10} \textbf{Configuration: (4096, 128)} \\
  LLama-3.2-3B-ARMT & 0.203 & 0.39 & 0.765 & 1.52 & 3.01 & 6.01 \\
  Diagonal Batching: LLama-3.2-3B-ARMT & 0.239 & 0.411 & 0.739 & 1.4 & 2.72 & 5.37 \\
  \midrule
  \end{tabular}%
  }
  \caption{Diagonal batching speed-ups the execution - from 1.1 to 1.3 times comparing to base ARMT for 131072 sequence length. Executor comparison of execution times (in seconds) for different methods across sequence lengths for Llama-3.2-3B. Configuration in format (segment\_size, memory\_tokens). Nvidia A100}
  \label{tab:perf_comparison_llama3b}
\end{table}
"""

df_latex_3b = parse_latex_table(latex_text_3b)
df_latex_3b

Unnamed: 0,model,dtype,4096,8192,12288,16384,24576,32768,65536
0,Llama-3.2-3B,bfloat16,0.168,0.344,,0.769,,1.95,5.59
1,"LLama-3.2-3B-ARMT (1024, 128)",bfloat16,0.272,0.537,,1.05,,2.02,4.09
2,"Diagonal Batching: LLama-3.1-3B-ARMT (1024, 128)",bfloat16,0.274,0.454,,0.833,,1.58,3.1
3,"LLama-3.2-3B-ARMT (4096, 128)",bfloat16,0.203,0.39,,0.765,,1.52,3.01
4,"Diagonal Batching: LLama-3.2-3B-ARMT (4096, 128)",bfloat16,0.239,0.411,,0.739,,1.4,2.72


In [8]:
latex_text_160m = r"""
\begin{table}[h]
  \centering
  \renewcommand{\arraystretch}{1.2}
  \resizebox{\textwidth}{!}{%
  \begin{tabular}{l*{6}{S[table-format=3.3]}}
  \toprule
  \textbf{Method} & \multicolumn{6}{c}{\textbf{Sequence Length}} \\
  \cmidrule(lr){2-7}
   & {\textbf{4096}} & {\textbf{8192}} & {\textbf{16384}} & {\textbf{32768}} & {\textbf{65536}} & {\textbf{131072}} \\
  \midrule
  Llama-160M & 0.017 & 0.033 & 0.075 & 0.196 & 0.594 & 2.03 \\
  \rowcolor{gray!10} \textbf{Configuration: (1024, 128)} \\
  LLama-160M-ARMT & 0.105 & 0.211 & 0.422 & 0.877 & 1.72 & 3.37 \\
  Diagonal Batching: LLama-160M-ARMT & 0.061 & 0.087 & 0.138 & 0.243 & 0.451 & 0.855 \\
  \rowcolor{gray!10} \textbf{Configuration: (4096, 128)} \\
  LLama-160M-ARMT & 0.031 & 0.057 & 0.111 & 0.216 & 0.432 & 0.855 \\
  Diagonal Batching: LLama-160M-ARMT & 0.046 & 0.062 & 0.094 & 0.156 & 0.284 & 0.537 \\
  \midrule
  \end{tabular}%
  }
  \caption{Diagonal batching speed-ups the execution - from 1.6 to 3.9 times comparing to base ARMT for 131072 sequence length. Executor comparison of execution times (in seconds) for different methods across sequence lengths for Llama-160M. Configuration in format (segment\_size, memory\_tokens). Nvidia A100}
  \label{tab:perf_comparison_llama160m}
\end{table}
"""

df_latex_160m = parse_latex_table(latex_text_160m)
df_latex_160m

Unnamed: 0,model,dtype,4096,8192,12288,16384,24576,32768,65536
0,Llama-160M,bfloat16,0.017,0.033,,0.075,,0.196,0.594
1,"LLama-160M-ARMT (1024, 128)",bfloat16,0.105,0.211,,0.422,,0.877,1.72
2,"Diagonal Batching: LLama-160M-ARMT (1024, 128)",bfloat16,0.061,0.087,,0.138,,0.243,0.451
3,"LLama-160M-ARMT (4096, 128)",bfloat16,0.031,0.057,,0.111,,0.216,0.432
4,"Diagonal Batching: LLama-160M-ARMT (4096, 128)",bfloat16,0.046,0.062,,0.094,,0.156,0.284


In [9]:
latex_text_8b = r"""
\begin{table}[h]
  \centering
  \renewcommand{\arraystretch}{1.2}
  \resizebox{\textwidth}{!}{%
  \begin{tabular}{l*{6}{S[table-format=3.3]}}
  \toprule
  \textbf{Method} & \multicolumn{6}{c}{\textbf{Sequence Length}} \\
  \cmidrule(lr){2-7}
   & {\textbf{4096}} & {\textbf{8192}} & {\textbf{16384}} & {\textbf{32768}} & {\textbf{65536}} & {\textbf{131072}} \\
  \midrule
  Llama-3.1-8B & 0.332 & 0.682 & 1.48 & 3.61 & 9.82 & 30.4 \\
  \rowcolor{gray!10} \textbf{Configuration: (1024, 128)} \\
  LLama-3.1-8B-ARMT & 0.497 & 0.936 & 1.82 & 3.63 & 7.22 & 14.4 \\
  Diagonal Batching: LLama-3.1-8B-ARMT & 0.478 & 0.86 & 1.64 & 3.2 & 6.34 & 12.6 \\
  \rowcolor{gray!10} \textbf{Configuration: (4096, 128)} \\
  LLama-3.1-8B-ARMT& 0.384 & 0.754 & 1.48 & 2.95 & 5.86 & 11.7 \\
  Diagonal Batching: LLama-3.1-8B-ARMT & 0.432 & 0.781 & 1.46 & 2.83 & 5.6 & 11.1 \\
  \midrule
  \end{tabular}%
  }
  \caption{Diagonal batching speed-ups the execution - from 1.05 to 1.14 times comparing to base ARMT for 131072 sequence length. Executor comparison of execution times (in seconds) for different methods across sequence lengths for Llama-3.1-8B. Configuration in format (segment\_size, memory\_tokens). Nvidia A100}
  \label{tab:perf_comparison_llama8b}
\end{table}
"""

df_latex_8b = parse_latex_table(latex_text_8b)
df_latex_8b

Unnamed: 0,model,dtype,4096,8192,12288,16384,24576,32768,65536
0,Llama-3.1-8B,bfloat16,0.332,0.682,,1.48,,3.61,9.82
1,"LLama-3.1-8B-ARMT (1024, 128)",bfloat16,0.497,0.936,,1.82,,3.63,7.22
2,"Diagonal Batching: LLama-3.1-8B-ARMT (1024, 128)",bfloat16,0.478,0.86,,1.64,,3.2,6.34
3,"LLama-3.1-8B-ARMT (4096, 128)",bfloat16,0.384,0.754,,1.48,,2.95,5.86
4,"Diagonal Batching: LLama-3.1-8B-ARMT (4096, 128)",bfloat16,0.432,0.781,,1.46,,2.83,5.6


In [10]:
df_all = pd.concat([df_all, df_latex], ignore_index=True)
df_all = pd.concat([df_all, df_latex_160m], ignore_index=True)
df_all = pd.concat([df_all, df_latex_3b], ignore_index=True)
df_all = pd.concat([df_all, df_latex_8b], ignore_index=True)

df_all.to_csv('result_all.csv', index=False)

In [11]:
df_all

Unnamed: 0,model,dtype,4096,8192,12288,16384,24576,32768,65536
0,result_fla-hub__rwkv7-0.4B-world,float32,0.508055,0.99127,1.467115,1.859422,2.805683,3.716381,
1,result_fla-hub__rwkv7-1.5B-world,float32,1.666058,3.118853,4.65766,6.209821,9.226034,12.608575,
2,result_fla-hub__rwkv7-2.9B-world,float32,1.419686,2.79594,4.228404,5.574663,8.406351,11.186028,
3,result_fla-hub__rwkv7-2.9B-world,bfloat16,0.205387,0.387402,0.577086,0.765386,1.142162,1.519346,
4,result_fla-hub__rwkv7-1.5B-world,bfloat16,0.289255,0.513334,0.767381,1.00025,1.47825,1.948466,
5,result_fla-hub__rwkv7-0.4B-world,bfloat16,0.154438,0.276561,0.396253,0.458832,0.732718,0.965903,
6,result_state-spaces__mamba-370m-hf,bfloat16,0.324087,0.387467,0.577039,0.63634,0.891995,1.138121,2.268204
7,result_state-spaces__mamba-1.4b-hf,bfloat16,0.163566,0.296213,0.438505,0.57381,1.769796,1.115349,2.217252
8,result_state-spaces__mamba-790m-hf,bfloat16,0.364515,0.592927,0.818261,1.044761,1.370421,1.95304,3.64198
9,result_state-spaces__mamba-2.8b-hf,bfloat16,0.758202,1.33915,1.751778,2.502645,3.495352,4.645297,8.967827


In [12]:
import re

def parse_params_billion(model_name: str) -> float:
    matches = re.findall(r'(\d+(?:\.\d+)?)\s*([mMbB])', str(model_name))
    if not matches:
        return float('nan')
    num_str, suffix = matches[-1]
    value = float(num_str)
    return value / 1000.0 if suffix.lower() == 'm' else value

df_all_scaled = df_all.copy()

# Add new column to df_fit
if 'model' in df_all_scaled.columns:
    df_all['params_billion'] = df_all['model'].apply(parse_params_billion)
    df_all_scaled['params_billion'] = df_all_scaled['model'].apply(parse_params_billion)

In [13]:
import os

In [14]:
p.split('.')[-2], p.rsplit('_', 1)[-2]

('float32', 'result_state-spaces__mamba-2.8b-hf')

In [15]:
df_all.columns

Index([         'model',          'dtype',             4096,             8192,
                  12288,            16384,            24576,            32768,
                  65536, 'params_billion'],
      dtype='object')

In [16]:
import os

In [17]:
import os

In [18]:
df_all 

Unnamed: 0,model,dtype,4096,8192,12288,16384,24576,32768,65536,params_billion
0,result_fla-hub__rwkv7-0.4B-world,float32,0.508055,0.99127,1.467115,1.859422,2.805683,3.716381,,0.4
1,result_fla-hub__rwkv7-1.5B-world,float32,1.666058,3.118853,4.65766,6.209821,9.226034,12.608575,,1.5
2,result_fla-hub__rwkv7-2.9B-world,float32,1.419686,2.79594,4.228404,5.574663,8.406351,11.186028,,2.9
3,result_fla-hub__rwkv7-2.9B-world,bfloat16,0.205387,0.387402,0.577086,0.765386,1.142162,1.519346,,2.9
4,result_fla-hub__rwkv7-1.5B-world,bfloat16,0.289255,0.513334,0.767381,1.00025,1.47825,1.948466,,1.5
5,result_fla-hub__rwkv7-0.4B-world,bfloat16,0.154438,0.276561,0.396253,0.458832,0.732718,0.965903,,0.4
6,result_state-spaces__mamba-370m-hf,bfloat16,0.324087,0.387467,0.577039,0.63634,0.891995,1.138121,2.268204,0.37
7,result_state-spaces__mamba-1.4b-hf,bfloat16,0.163566,0.296213,0.438505,0.57381,1.769796,1.115349,2.217252,1.4
8,result_state-spaces__mamba-790m-hf,bfloat16,0.364515,0.592927,0.818261,1.044761,1.370421,1.95304,3.64198,0.79
9,result_state-spaces__mamba-2.8b-hf,bfloat16,0.758202,1.33915,1.751778,2.502645,3.495352,4.645297,8.967827,2.8


In [19]:
import numpy as np

def fit_linear_models(df):
    cols = [4096, 8192, 12288, 16384, 24576, 32768, 65536]
    fit_rows = []

    for _, r in df.iterrows():
        x_vals = []
        y_vals = []
        for c in cols:
            val = r.get(c, None)
            if pd.notna(val):
                x_vals.append(float(c))
                y_vals.append(float(val))
        if len(x_vals) >= 2:
            slope, intercept = np.polyfit(x_vals, y_vals, 1)
        else:
            slope, intercept = (float('nan'), float('nan'))
        fit_rows.append({
            'model_name': r['model'],
            'dtype': r['dtype'],
            'params_billion': r['params_billion'],
            'line_slope': slope,
            'line_bias': intercept,
        })

    return pd.DataFrame(fit_rows)

df_fit = fit_linear_models(df_all)


In [20]:
df_fit.sort_values('line_slope')

Unnamed: 0,model_name,dtype,params_billion,line_slope,line_bias
27,"Diagonal Batching: LLama-160M-ARMT (4096, 128)",bfloat16,0.16,4e-06,0.030167
25,"Diagonal Batching: LLama-160M-ARMT (1024, 128)",bfloat16,0.16,6e-06,0.034708
26,"LLama-160M-ARMT (4096, 128)",bfloat16,0.16,7e-06,0.003667
23,Llama-160M,bfloat16,0.16,9e-06,-0.057917
22,"Diagonal Batching: LLama-3.2-1B-ARMT (4096, 128)",bfloat16,1.0,1.6e-05,0.03925
20,"Diagonal Batching: LLama-3.2-1B-ARMT (2048, 128)",bfloat16,1.0,1.6e-05,0.039792
21,"LLama-3.2-1B-ARMT (4096, 128)",bfloat16,1.0,1.8e-05,0.0085
18,"Diagonal Batching: LLama-3.2-1B-ARMT (1024, 128)",bfloat16,1.0,1.9e-05,0.042875
19,"LLama-3.2-1B-ARMT (2048, 128)",bfloat16,1.0,2e-05,0.0095
16,"Diagonal Batching: LLama-3.2-1B-ARMT (512, 128)",bfloat16,1.0,2.4e-05,0.103042


In [21]:
df_fit['model_name']

0                     result_fla-hub__rwkv7-0.4B-world
1                     result_fla-hub__rwkv7-1.5B-world
2                     result_fla-hub__rwkv7-2.9B-world
3                     result_fla-hub__rwkv7-2.9B-world
4                     result_fla-hub__rwkv7-1.5B-world
5                     result_fla-hub__rwkv7-0.4B-world
6                   result_state-spaces__mamba-370m-hf
7                   result_state-spaces__mamba-1.4b-hf
8                   result_state-spaces__mamba-790m-hf
9                   result_state-spaces__mamba-2.8b-hf
10                  result_state-spaces__mamba-370m-hf
11                  result_state-spaces__mamba-790m-hf
12                  result_state-spaces__mamba-1.4b-hf
13                  result_state-spaces__mamba-2.8b-hf
14                                        Llama-3.2-1B
15                        LLama-3.2-1B-ARMT (512, 128)
16     Diagonal Batching: LLama-3.2-1B-ARMT (512, 128)
17                       LLama-3.2-1B-ARMT (1024, 128)
18    Diag

In [22]:
df_all_scaled

Unnamed: 0,model,dtype,4096,8192,12288,16384,24576,32768,65536,params_billion
0,result_fla-hub__rwkv7-0.4B-world,float32,0.508055,0.99127,1.467115,1.859422,2.805683,3.716381,,0.4
1,result_fla-hub__rwkv7-1.5B-world,float32,1.666058,3.118853,4.65766,6.209821,9.226034,12.608575,,1.5
2,result_fla-hub__rwkv7-2.9B-world,float32,1.419686,2.79594,4.228404,5.574663,8.406351,11.186028,,2.9
3,result_fla-hub__rwkv7-2.9B-world,bfloat16,0.205387,0.387402,0.577086,0.765386,1.142162,1.519346,,2.9
4,result_fla-hub__rwkv7-1.5B-world,bfloat16,0.289255,0.513334,0.767381,1.00025,1.47825,1.948466,,1.5
5,result_fla-hub__rwkv7-0.4B-world,bfloat16,0.154438,0.276561,0.396253,0.458832,0.732718,0.965903,,0.4
6,result_state-spaces__mamba-370m-hf,bfloat16,0.324087,0.387467,0.577039,0.63634,0.891995,1.138121,2.268204,0.37
7,result_state-spaces__mamba-1.4b-hf,bfloat16,0.163566,0.296213,0.438505,0.57381,1.769796,1.115349,2.217252,1.4
8,result_state-spaces__mamba-790m-hf,bfloat16,0.364515,0.592927,0.818261,1.044761,1.370421,1.95304,3.64198,0.79
9,result_state-spaces__mamba-2.8b-hf,bfloat16,0.758202,1.33915,1.751778,2.502645,3.495352,4.645297,8.967827,2.8


In [23]:
df_scaled_fit = fit_linear_models(df_all_scaled)

In [24]:
df_scaled_fit

Unnamed: 0,model_name,dtype,params_billion,line_slope,line_bias
0,result_fla-hub__rwkv7-0.4B-world,float32,0.4,0.000111,0.06743
1,result_fla-hub__rwkv7-1.5B-world,float32,1.5,0.000381,0.013383
2,result_fla-hub__rwkv7-2.9B-world,float32,2.9,0.000341,0.01626
3,result_fla-hub__rwkv7-2.9B-world,bfloat16,2.9,4.6e-05,0.013933
4,result_fla-hub__rwkv7-1.5B-world,bfloat16,1.5,5.8e-05,0.047895
5,result_fla-hub__rwkv7-0.4B-world,bfloat16,0.4,2.8e-05,0.036703
6,result_state-spaces__mamba-370m-hf,bfloat16,0.37,3.2e-05,0.142675
7,result_state-spaces__mamba-1.4b-hf,bfloat16,1.4,3.4e-05,0.14735
8,result_state-spaces__mamba-790m-hf,bfloat16,0.79,5.3e-05,0.149955
9,result_state-spaces__mamba-2.8b-hf,bfloat16,2.8,0.000134,0.220301


In [25]:
df_fit.columns

Index(['model_name', 'dtype', 'params_billion', 'line_slope', 'line_bias'], dtype='object')

Unnamed: 0,model_name,dtype,params_billion,line_slope,line_bias
0,result_fla-hub__rwkv7-0.4B-world,float32,0.4,0.000111,0.06743
1,result_fla-hub__rwkv7-1.5B-world,float32,1.5,0.000381,0.013383
2,result_fla-hub__rwkv7-2.9B-world,float32,2.9,0.000341,0.01626
3,result_fla-hub__rwkv7-2.9B-world,bfloat16,2.9,4.6e-05,0.013933
4,result_fla-hub__rwkv7-1.5B-world,bfloat16,1.5,5.8e-05,0.047895
5,result_fla-hub__rwkv7-0.4B-world,bfloat16,0.4,2.8e-05,0.036703
6,result_state-spaces__mamba-370m-hf,bfloat16,0.37,3.2e-05,0.142675
7,result_state-spaces__mamba-1.4b-hf,bfloat16,1.4,3.4e-05,0.14735
8,result_state-spaces__mamba-790m-hf,bfloat16,0.79,5.3e-05,0.149955
9,result_state-spaces__mamba-2.8b-hf,bfloat16,2.8,0.000134,0.220301
