In [1]:
import os, sys, warnings
# Put the ancestor directories of the notebook's working directory on
# sys.path, presumably so that the project-local packages imported below
# (`src`, `Systems`) resolve. Each ancestor is inserted at the front so it
# shadows installed packages of the same name.
PARENT_LEVELS = 1  # how many directory levels above the working directory to add
script_dir = os.getcwd()
module_path = script_dir
# Use a named loop variable: assigning `_` at notebook top level stops IPython
# from updating its `_` / `__` / `___` output-history variables.
for _parent_level in range(PARENT_LEVELS):
    module_path = os.path.abspath(os.path.join(module_path, os.pardir))
    if module_path not in sys.path:
        sys.path.insert(0, module_path)
        
from src import decode_moddeling, prefill_moddeling
import pandas as pd
from plotnine import *
import plotnine as p9
from tqdm import tqdm

from Systems.system_configs import *
# Internal model keys (the values the model selector passes to
# prefill_moddeling / decode_moddeling) paired with the public names shown in
# the result table. Keeping each key next to its name prevents the two
# derived lists from drifting out of alignment; the names match the labels
# used by the model selection widget.
MODEL_DISPLAY_NAMES = [
    ('opt_125m', 'facebook/opt-125m'),
    ('opt_350m', 'facebook/opt-350m'),
    ('opt_1b', 'facebook/opt-1.3b'),
    ('opt_175b', 'facebook/opt-175b'),
    ('gemma_7b', 'google/gemma-7b'),
    ('LLaMA_7b', 'meta-llama/Llama-2-7b'),
    ('llama3_8b', 'meta-llama/Meta-Llama-3-8B'),
    ('llama_13b', 'meta-llama/Llama-2-13b'),
    ('mixtral_7x8', 'mistralai/Mixtral-8x7B'),
    ('LLaMA_70b', 'meta-llama/Llama-2-70b'),
    ('dbrx', 'databricks/dbrx-base'),
    ('grok-1', 'xai-org/grok-1'),
    ('gpt-3', 'openai/gpt-3'),
    ('gpt-4', 'openai/gpt-4'),
]
All_model_list = [key for key, _name in MODEL_DISPLAY_NAMES]
All_models_name = [name for _key, name in MODEL_DISPLAY_NAMES]



In [2]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import plotly.express as px
from plotnine import *
import plotnine as p9


# Set up interactive widgets for the variables
from ipywidgets import interact, IntSlider, Checkbox, BoundedIntText, BoundedFloatText, Dropdown
import ipywidgets as widgets


# Define the function to generate the demand curve
def generate_demand_curve(system_box, num_nodes_slider, model_box, quantization_box, batch_slider, beam_size, input_token_slider, output_token_slider):
    """Model prefill/decode memory demand for each selected model and display a table.

    Callback for the ``interact`` widget: each argument is the current value
    of the widget with the same name.

    Parameters
    ----------
    system_box : str
        Name of a system configuration. It is resolved through ``globals()``,
        so it must name a dict defined in the notebook namespace (e.g. via
        ``from Systems.system_configs import *``).
    num_nodes_slider : int
        Passed as ``tensor_parallel`` to the modelling functions.
    model_box : sequence of str
        Internal model keys (e.g. ``'opt_125m'``) to evaluate.
    quantization_box : str
        Passed as ``bits`` to the modelling functions.
    batch_slider : str or int
        Batch size; converted with ``int()``.
    beam_size : int
        Number of parallel beams. Scales the per-token KV-cache estimate and
        is passed to decode modelling as ``Bb``.
    input_token_slider, output_token_slider : int
        Input context length and number of output tokens.

    Returns
    -------
    None
        The table is rendered with ``display``. If no model is selected, a
        short message is printed instead.
    """
    if not model_box:
        # An empty multi-select is a normal transient widget state. There would
        # be no rows and no summary table to take column names from.
        print("No model selected. Select at least one model to estimate its memory demand.")
        return None

    system_config = globals()[system_box]
    batch_size = int(batch_slider)
    # 'Memory_size' is scaled by 1024 to compare with the MB figures of the
    # summary tables (presumably GB -> MB; TODO confirm units in system_configs).
    total_memory = int(system_config.get('Memory_size')) * 1024

    rows = []
    summary_columns = []
    # Suppress warnings only while this call runs. A bare
    # warnings.filterwarnings("ignore") would silence them for the whole
    # kernel session.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for model in tqdm(model_box):
            _, prefill_summary = prefill_moddeling(model=model, batch_size=batch_size,
                                                   input_tokens=input_token_slider, output_tokens=output_token_slider,
                                                   system_name=system_config,
                                                   bits=quantization_box, model_profilling=True,
                                                   tensor_parallel=num_nodes_slider)
            memory_left = total_memory - prefill_summary['Model Weights (MB)'].values[0]
            # KV-cache size per input token, scaled by the number of beams.
            per_token_kv_cache = prefill_summary['KV Cache (MB)'].values[0] * beam_size / input_token_slider
            max_tokens = memory_left / per_token_kv_cache
            rows.append([model, 'Prefill', batch_size, input_token_slider, output_token_slider]
                        + list(prefill_summary.loc[0].values) + [int(max_tokens)])

            _, decode_summary = decode_moddeling(model=model, batch_size=batch_size, Bb=beam_size,
                                                 input_tokens=input_token_slider, output_tokens=output_token_slider,
                                                 system_name=system_config,
                                                 bits=quantization_box, model_profilling=True,
                                                 tensor_parallel=num_nodes_slider)
            # The decode row subtracts the output tokens from the prefill
            # estimate (presumably room reserved for generated tokens).
            rows.append([model, 'Decode', batch_size, input_token_slider, output_token_slider]
                        + list(decode_summary.loc[0].values) + [int(max_tokens - output_token_slider)])
            # Both stages are labelled with the prefill table's column names.
            summary_columns = list(prefill_summary.columns)

        data_df = pd.DataFrame(rows, columns=['Model', 'Stage', 'Batch', 'Input Context Length', 'Num Output Tokens']
                               + summary_columns + ['Max Tokens Possible'])
        # Map internal model keys to their public names for display.
        data_df = data_df.replace(All_model_list, All_models_name)
        data_df['Stage'] = pd.Categorical(data_df['Stage'], categories=['Prefill', 'Decode'])
        data_df = data_df.rename(columns={'Model Weights (MB)': 'Weights per Node(MB)', 'KV Cache (MB)': 'KV Cache per Node(MB)'})

        display(data_df[['Model', 'Stage', 'Batch', 'Input Context Length', 'Num Output Tokens', 'Weights per Node(MB)', 'KV Cache per Node(MB)', 'Max Tokens Possible']])
    return None



# Widgets driving generate_demand_curve; each keyword passed to `interact`
# below matches one of the callback's parameters.
# Batch size uses BoundedIntText like the other numeric inputs. A free-form
# Text box re-runs the callback on every keystroke, and int() raised
# ValueError whenever the field was empty or non-numeric (e.g. while retyping).
batch_slider = BoundedIntText(value=8, min=1, max=100000, step=1, description='Batch Size:', disabled=False, style={'description_width': 'initial'})
beam_size = widgets.IntSlider(value=1, min=1, max=16, description='# of Parallel Beams:', style={'description_width': 'initial'},)
input_token_slider = BoundedIntText(value=512, min=1, max=100000, step=1, description='Input Tokens:', disabled=False, style={'description_width': 'initial'})
output_token_slider = BoundedIntText(value=128, min=1, max=100000, step=1, description='Output Tokens:', disabled=False, style={'description_width': 'initial'})

quantization_box = Dropdown(options=['bf16', 'int8', 'int4'], value='int8', description='Quantization:', disabled=False, style={'description_width': 'initial'},)
# (label, value) pairs: the label is shown to the user, the value is the
# internal model key passed to the modelling functions.
model_box = widgets.SelectMultiple(options=[
    ('facebook/opt-125m', 'opt_125m'),
    ('facebook/opt-350m', 'opt_350m'),
    ('facebook/opt-1.3b', 'opt_1b'),
    ('facebook/opt-175b', 'opt_175b'),
    ('google/gemma-7b', 'gemma_7b'),
    ('meta-llama/Llama-2-7b', 'LLaMA_7b'),
    ('meta-llama/Meta-Llama-3-8B', 'llama3_8b'),
    ('meta-llama/Llama-2-13b', 'llama_13b'),
    ('mistralai/Mixtral-8x7B', 'mixtral_7x8'),
    ('meta-llama/Llama-2-70b', 'LLaMA_70b'),
    ('databricks/dbrx-base', 'dbrx'),
    ('xai-org/grok-1', 'grok-1'),
    ('openai/gpt-3', 'gpt-3'),
    ('openai/gpt-4', 'gpt-4')
    ], value=['opt_125m'], description='Models:', disabled=False,)
system_box = Dropdown(options=['A100_40GB_GPU', 'A100_80GB_GPU', 'H100_GPU', 'GH200_GPU', 'TPUv4', 'TPUv5e', 'MI300X', 'Gaudi3'], value='H100_GPU', description='System:', disabled=False,)
num_nodes_slider = BoundedIntText(value=2, min=1, max=128, step=1, description='# Nodes:', disabled=False)


# Create the interactive table; the trailing `;` suppresses the repr of the
# function that `interact` returns.
interact(generate_demand_curve,
         system_box=system_box, num_nodes_slider=num_nodes_slider, model_box=model_box, quantization_box=quantization_box,
         batch_slider=batch_slider, beam_size=beam_size, input_token_slider=input_token_slider, output_token_slider=output_token_slider, );

interactive(children=(Dropdown(description='System:', index=2, options=('A100_40GB_GPU', 'A100_80GB_GPU', 'H10…

<function __main__.generate_demand_curve(system_box, num_nodes_slider, model_box, quantization_box, batch_slider, beam_size, input_token_slider, output_token_slider)>