In [15]:
import pandas as pd
import itertools

def generate_mkn_table(hidden_sizes, seq_lens, batch_sizes):
    all_rows = []
    model_groups = [
        ('7B Model', 4096),
        ('70B Model', 28672),
        ('175B Model', 49152)
    ]
    
    for group_name, h in model_groups:
        rows = []
        for s, b in itertools.product(seq_lens, batch_sizes):
            m = b * s
            k = h
            n = h * 8
            
            rows.append({
                'Model': group_name,
                'Batch Size': b,
                'Sequence Length': s,
                'Hidden Size': h,
                'M': m,
                'K': k,
                'N': n
            })
        
        all_rows.extend(rows)
    
    df = pd.DataFrame(all_rows)
    
    # Style the dataframe
    def highlight_groups(s):
        return ['background-color: #f0f0f0' if s.name % len(seq_lens) * len(batch_sizes) == 0 
                else '' for _ in s]
    
    styled_df = df.style.apply(highlight_groups, axis=1)\
                       .format({
                           'M': '{:,.0f}',
                           'K': '{:,.0f}',
                           'N': '{:,.0f}',
                           'Hidden Size': '{:,.0f}'
                       })\
                       .set_properties(**{
                           'text-align': 'center',
                           'padding': '8px'
                       })\
                       .set_table_styles([
                           {'selector': 'th',
                            'props': [('background-color', '#2c3e50'),
                                    ('color', 'white'),
                                    ('font-weight', 'bold'),
                                    ('text-align', 'center'),
                                    ('padding', '8px')]},
                           {'selector': 'td',
                            'props': [('border', '1px solid #ddd')]},
                           {'selector': 'tr:hover',
                            'props': [('background-color', '#e6f3ff')]}
                       ])
    
    return styled_df

# Shapes to tabulate: every (sequence length, batch size) pair is listed
# for each model's hidden size.
hidden_sizes = [4096, 28672, 49152]
seq_lens = [1024, 2048, 8192]
batch_sizes = [1, 2, 8, 16, 32]

styled_df = generate_mkn_table(
    hidden_sizes=hidden_sizes,
    seq_lens=seq_lens,
    batch_sizes=batch_sizes,
)
styled_df  # last expression -> rendered as the cell's rich (HTML) output

Unnamed: 0,Model,Batch Size,Sequence Length,Hidden Size,M,K,N
0,7B Model,1,1024,4096,1024,4096,32768
1,7B Model,2,1024,4096,2048,4096,32768
2,7B Model,8,1024,4096,8192,4096,32768
3,7B Model,16,1024,4096,16384,4096,32768
4,7B Model,32,1024,4096,32768,4096,32768
5,7B Model,1,2048,4096,2048,4096,32768
6,7B Model,2,2048,4096,4096,4096,32768
7,7B Model,8,2048,4096,16384,4096,32768
8,7B Model,16,2048,4096,32768,4096,32768
9,7B Model,32,2048,4096,65536,4096,32768


In [16]:
# NOTE(review): `df` is not created by the cell above (generate_mkn_table's
# `df` is local to that function), and this output's columns (name,
# batch_size, seq_len, ...) differ from that table's. It presumably comes
# from an earlier or deleted cell; confirm it survives Restart Kernel -> Run All.
df

Unnamed: 0,name,batch_size,seq_len,hidden_size,m,k,n
0,=== 7b ===,,,,,,
1,exp900b04e_but_remove_grad_input_compute_if_no...,1.0,1024.0,4096.0,1024.0,4096.0,32768.0
2,exp900b04e_but_remove_grad_input_compute_if_no...,2.0,1024.0,4096.0,2048.0,4096.0,32768.0
3,exp900b04e_but_remove_grad_input_compute_if_no...,8.0,1024.0,4096.0,8192.0,4096.0,32768.0
4,exp900b04e_but_remove_grad_input_compute_if_no...,16.0,1024.0,4096.0,16384.0,4096.0,32768.0
5,exp900b04e_but_remove_grad_input_compute_if_no...,32.0,1024.0,4096.0,32768.0,4096.0,32768.0
6,exp900b04e_but_remove_grad_input_compute_if_no...,1.0,2048.0,4096.0,2048.0,4096.0,32768.0
7,exp900b04e_but_remove_grad_input_compute_if_no...,2.0,2048.0,4096.0,4096.0,4096.0,32768.0
8,exp900b04e_but_remove_grad_input_compute_if_no...,8.0,2048.0,4096.0,16384.0,4096.0,32768.0
9,exp900b04e_but_remove_grad_input_compute_if_no...,16.0,2048.0,4096.0,32768.0,4096.0,32768.0
