In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# from matplotlib.ticker import FormatStrFormatter
from matplotlib.ticker import FuncFormatter

# Token-length statistics for ShareGPT. `names=` supplies the column labels,
# so the first CSV row is treated as data rather than as a header.
# The legends further down label 'li' as the user prompt and 'lo' as the
# generation.
df = pd.read_csv('../stat/dataset/ShareGPT_stat.csv', names=['li', 'lo'])


import contextlib

# Context manager wrapped around each figure in this file. It is deliberately
# a no-op, so figures are drawn with the rcParams set below and nothing else.
# To render with the 'light' style sheet instead, swap in:
#     plt_context = plt.style.context(['light'])
# NOTE(review): plt_context is entered by more than one `with` statement;
# nullcontext() tolerates re-entry, but a style context may be single-use —
# confirm before enabling it.
plt_context = contextlib.nullcontext()

def my_formatter(x, pos=None):
    """Return tick value ``x`` expressed in units of 1e-2 as a string.

    The value is multiplied by 100 and rendered with exactly one digit
    after the decimal point (e.g. ``0.015`` -> ``'1.5'``).

    ``pos`` is accepted so the function can be wrapped in a
    ``FuncFormatter``; it is not used.
    """
    scaled = x * 100
    return f'{scaled:.1f}'

# Global matplotlib defaults applied to every figure created after this point:
# bold 14pt serif text, thicker lines/axes, and a 6x4 inch canvas.
rc_overrides = {
    'font.size': 14,
    'font.weight': 'bold',
    'lines.markersize': 10,
    'lines.linewidth': 2,
    'font.family': 'serif',
    'font.serif': 'Times New Roman',
    'axes.linewidth': 2,
    'figure.figsize': (6, 4),
}
plt.rcParams.update(rc_overrides)

# Column aliases: the legend below labels 'li' as the user prompt and 'lo'
# as the generation.
input_len = df['li']
output_len = df['lo']

# Quick textual summary: (min, max) of each column, then each mean.
print(input_len.min(), input_len.max())
print(output_len.min(), output_len.max())
print(input_len.mean())
print(output_len.mean())

input_label = f"user prom. (mean:{input_len.mean():.1f}, median:{input_len.median():.1f})"
output_label = f"generation (mean:{output_len.mean():.1f}, median:{output_len.median():.1f})"

# Both histograms use 100 bins and are normalised to a density.
hist_kwargs = {'bins': 100, 'density': True}

with plt_context:
    figure, ax = plt.subplots()

    # Overlay the two distributions; the second is more transparent so the
    # first stays visible underneath.
    ax.hist(input_len, alpha=0.7, label=input_label, **hist_kwargs)
    ax.hist(output_len, alpha=0.5, label=output_label, **hist_kwargs)
    ax.legend(ncol=1, loc='upper center', fontsize=14)

    # Tick labels are shown multiplied by 100 (see my_formatter), so a
    # manual '1e-2' marker is placed just above the top-left corner of the
    # axes to state the scale.
    ax.yaxis.set_major_formatter(FuncFormatter(my_formatter))
    ax.text(0, 1.02, '1e-2', transform=ax.transAxes, ha='left', va='bottom')

    ax.set_xlabel('Length (#tokens)')
    ax.set_ylabel('Density')

    figure.tight_layout()
    figure.savefig('../stat/dataset/sharegpt_stat.svg')
    plt.show()

In [None]:

import json

# Load the MMLU length statistics. Each file is parsed as JSON; the results
# are later passed to np.min/np.max/np.mean/np.median and plt.hist.
# The encoding is given explicitly because JSON text is UTF-8, whereas
# open() would otherwise fall back to the platform's locale encoding.
with open('../stat/dataset/mmlu_5shotprefix_stat.json', 'r', encoding='utf-8') as fp:
    prefix_lens = json.load(fp)

with open('../stat/dataset/mmlu_prompt_stat.json', 'r', encoding='utf-8') as fp:
    prompt_lens = json.load(fp)

with plt_context:
    figure, ax = plt.subplots()

    # Quick textual summary: (min, max) of each length list.
    print(np.min(prefix_lens), np.max(prefix_lens))
    print(np.min(prompt_lens), np.max(prompt_lens))

    prefix_label = f"sys. promp. (mean:{np.mean(prefix_lens):.1f}, median:{np.median(prefix_lens):.1f})"
    prompt_label = f"user promp. (mean:{np.mean(prompt_lens):.1f}, median:{np.median(prompt_lens):.1f})"

    # Both histograms use 100 bins and are normalised to a density; the
    # second is more transparent so the first stays visible underneath.
    hist_kwargs = {'bins': 100, 'density': True}
    ax.hist(prefix_lens, alpha=0.7, label=prefix_label, **hist_kwargs)
    ax.hist(prompt_lens, alpha=0.5, label=prompt_label, **hist_kwargs)
    ax.legend(ncol=1, loc='upper center', fontsize=14)

    # Tick labels are shown multiplied by 100 (see my_formatter), so a
    # manual '1e-2' marker is placed just above the top-left corner of the
    # axes to state the scale.
    ax.yaxis.set_major_formatter(FuncFormatter(my_formatter))
    ax.text(0, 1.02, '1e-2', transform=ax.transAxes, ha='left', va='bottom')

    ax.set_xlabel('Length (#tokens)')
    ax.set_ylabel('Density')

    figure.tight_layout()
    figure.savefig('../stat/dataset/mmlu_stat.svg')
    plt.show()