In [None]:
from os import listdir
from os.path import join, isfile

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import colors as mcolors
import pickle

In [None]:
import matplotlib.font_manager as fm
# Debug aid: list the font family names matplotlib can use (e.g. to choose "font.family").
# ttflist has one entry per font file (regular/bold/italic/...), so deduplicate and sort.
print(sorted({f.name for f in fm.fontManager.ttflist}))

In [None]:
# Notebook-wide figure style: serif body and math text, high-resolution figures.
plt.rcParams.update({
    "font.family": 'DejaVu Serif',
    'figure.dpi': 500,
    "mathtext.fontset": 'dejavuserif',
})

In [None]:
def parse_filename(filename, suffix):
    """
    Split a result file name into (model_name, dataset_name, revision, temperature).

    Expected layout: '<model>_<dataset>_revision=<rev>_temp=<t><suffix>', e.g.
    'Qwen2.5-3B_codecontests_revision=main_temp=1.0_all_ids.pkl' with suffix '_all_ids.pkl'
    -> ('Qwen2.5-3B', 'codecontests', 'main', 1.0).

    The model and dataset names must not contain '_'. The revision may contain '_',
    because the temperature field is split off at the LAST '_'.

    Raises:
        ValueError: if the suffix is missing or the name does not have the expected fields.
    """
    stem = filename[:filename.rindex(suffix)]  # drop the suffix
    model_name, dataset_name, rest = stem.split('_', 2)  # rest: 'revision=<rev>_temp=<t>'
    revision_field, temperature_field = rest.rsplit('_', 1)
    revision = revision_field[revision_field.index('=') + 1:]
    temperature = temperature_field[temperature_field.index('=') + 1:]
    return model_name, dataset_name, revision, float(temperature)

def is_chat_model(model_name, instruct=True):
    """
    Classify a model name as chat/instruct-tuned or base.

    A name counts as chat-tuned when it contains 'chat' or 'instruct' (case-insensitive).
    With instruct truthy this returns that test; otherwise it returns the negation,
    i.e. True for base models.
    """
    lowered = model_name.lower()
    looks_chat = any(tag in lowered for tag in ('chat', 'instruct'))
    return looks_chat if instruct else not looks_chat

def is_model_family(model_name, model_family):
    """Return True when model_family occurs anywhere in model_name, ignoring case."""
    return model_name.lower().find(model_family.lower()) != -1

In [None]:
def load_files(folder, model_family, dataset_name, instruct, suffix):
    """
    Load every pickled result in `folder` for one model family, dataset and chat/base kind.

    File names are parsed with parse_filename, so they must look like
    '<model>_<dataset>_revision=<rev>_temp=<t><suffix>'.

    Parameters:
        folder (str): directory to scan (not recursive).
        model_family (str): case-insensitive substring the model name must contain.
        dataset_name (str): dataset name that must match exactly.
        instruct (bool): True keeps chat/instruct models, False keeps base models.
        suffix (str): file-name suffix selecting the result type, e.g. '_logprobs_and_ents.pkl'.

    Returns:
        dict: {'<model>_revision=<rev>': {temperature (float): unpickled object}}.

    Security: pickle.load can execute arbitrary code stored in the file; only point this
    at result files you produced yourself.
    """
    all_files = {}
    # Sort so the returned dict's order (which downstream code iterates over) does not
    # depend on the filesystem's arbitrary listdir() order.
    for filename in sorted(listdir(folder)):
        full_path = join(folder, filename)
        if not (isfile(full_path) and full_path.endswith(suffix)):
            continue
        load_model_name, load_dataset_name, load_revision, load_temperature = parse_filename(filename, suffix)

        if (is_model_family(load_model_name, model_family)
                and dataset_name == load_dataset_name
                and is_chat_model(load_model_name, instruct)):
            load_model_revision_name = '{}_revision={}'.format(load_model_name, load_revision)
            with open(full_path, 'rb') as f:
                saved_file = pickle.load(f)  # trusted local results only (see docstring)
            all_files.setdefault(load_model_revision_name, {})[load_temperature] = saved_file
    return all_files

In [None]:
def get_model_size(model_revision_name):
    """
    Parse the parameter count, in billions, from a model (revision) name.

    The model name is the text before the first '_' (or the whole string if there is
    none). Its '-'-separated tokens are scanned left to right, and the first token of the
    form '<number>b' or '<number>m' (case-insensitive) gives the size. 'm' (millions) is
    converted to billions.

    Examples:
        'Qwen2.5-3B_revision=main'    -> 3.0
        'pythia-70m_revision=main'    -> 70 / 1000
        'Llama-2-7b-hf_revision=main' -> 7.0

    Raises:
        ValueError: if no token encodes a size.
    """
    # Divisor that converts each unit suffix to billions of parameters.
    unit_divisors = {'b': 1, 'm': 1000}
    model_name = model_revision_name.split('_', 1)[0]  # 'Qwen2.5-3B_revision=main' -> 'Qwen2.5-3B'
    for token in model_name.split('-'):
        divisor = unit_divisors.get(token[-1:].lower())
        if divisor is None:
            continue
        try:
            value = float(token[:-1])
        except ValueError:
            # Ends in b/m but is not a number (e.g. 'slim'); keep scanning.
            continue
        return value / divisor
    raise ValueError('Could not find a model size (e.g. "7B" or "70m") in {!r}'.format(model_revision_name))

def compute_time_avg(inp):
    """
    Average ragged per-sequence series position by position.

    Parameters:
        inp (list of sequences of float): one series per sample. Series may differ in length,
            and position t is averaged only over the series that reach it.

    Returns:
        avg (np.ndarray, shape (max_len,)): mean at each position.
        overall_avg (float): mean over every value of every series (each value weighted equally).
        std (np.ndarray, shape (max_len,)): standard error of the per-position mean,
            sqrt(population variance / count).

    Raises:
        ValueError: if inp is empty.
    """
    if len(inp) == 0:
        raise ValueError("compute_time_avg requires at least one series.")
    max_len = max(len(lst) for lst in inp)

    totals = np.zeros((max_len,))
    counts = np.zeros((max_len,))
    for lst in inp:
        totals[:len(lst)] += lst
        counts[:len(lst)] += 1
    avg = totals / counts
    overall_avg = np.sum(totals) / np.sum(counts)

    # Second pass: squared deviations from the per-position mean.
    sq_dev = np.zeros((max_len,))
    for lst in inp:
        sq_dev[:len(lst)] += (np.array(lst) - avg[:len(lst)]) ** 2
    # Var(mean) = population variance / count = sq_dev / count**2. `counts` already holds
    # the per-position sample count, so no second counter is needed.
    std = np.sqrt(sq_dev / (counts ** 2))
    return avg, overall_avg, std

def load_ent_over_time(folder, model_family, dataset_name, max_len):
    """
    Build per-step entropy and log-loss curves for every base (non-chat) model of a family.

    Only runs sampled at temperature 1.0 are used. Each per-sequence series is truncated
    to its first max_len steps, then averaged across sequences with compute_time_avg.

    Returns:
        list of dict with keys
            'ents'     - mean entropy at each step,
            'logprobs' - mean negated log-prob (log loss) at each step,
            'size'     - natural log of the parameter count in billions (get_model_size),
            'name'     - model name with the '_revision=...' part removed.
    """
    runs = load_files(folder, model_family, dataset_name, False, '_logprobs_and_ents.pkl')
    curves = []
    for run_name, by_temperature in runs.items():
        if 1.0 not in by_temperature:
            continue
        logprob_seqs, entropy_seqs = by_temperature[1.0]
        entropy_curve = compute_time_avg([seq[:max_len] for seq in entropy_seqs])[0]
        logprob_curve = compute_time_avg([seq[:max_len] for seq in logprob_seqs])[0]
        curves.append({
            'ents': entropy_curve,
            'logprobs': -1 * logprob_curve,
            'size': np.log(get_model_size(run_name)),
            'name': run_name[:run_name.index('_')],
        })
    return curves

def exponential_smoothing(arr, alpha):
    """
    Exponentially smooth a 1-D series.

    s[0] = arr[0], and s[t] = alpha * arr[t] + (1 - alpha) * s[t-1] thereafter.
    alpha = 1 reproduces the input. Smaller alpha keeps more weight on past values.

    Parameters:
        arr (list or np.ndarray): non-empty series of floats.
        alpha (float): smoothing factor in (0, 1].

    Returns:
        list: the smoothed series, same length as arr.

    Raises:
        ValueError: if alpha is outside (0, 1] or arr is empty.
        TypeError: if arr is neither a list nor a numpy array.
    """
    if not 0 < alpha <= 1:
        raise ValueError("Alpha must be between 0 and 1.")
    if not isinstance(arr, (list, np.ndarray)):
        raise TypeError("Input must be a list or numpy array.")
    if len(arr) == 0:
        raise ValueError("Input array must not be empty.")

    running = arr[0]
    result = [running]
    for value in arr[1:]:
        running = alpha * value + (1 - alpha) * running
        result.append(running)
    return result

In [None]:
def truncate_cmap(cmap, start_frac=0., end_frac=0.9):
    """
    Return a ListedColormap covering only a sub-range of `cmap`.

    The bounds are resolved on a 256-step grid: lo = int(256 * start_frac) and
    hi = int(256 * end_frac). The result holds hi - lo colors sampled evenly from
    lo/256 to hi/256 inclusive. The default end_frac=0.9 drops the top of the colormap.
    """
    n_colors = 256
    lo = int(n_colors * start_frac)
    hi = int(n_colors * end_frac)
    sample_points = np.linspace(lo / n_colors, hi / n_colors, hi - lo)
    return mcolors.ListedColormap(cmap(sample_points))

def plot_ent_over_time(plot_data, dataset_name, alpha1=0.2, alpha2=0.1, ax=None, show_xlabel=True, show_title=True):
    """
    Plot per-step entropy (solid) and log-loss (dashed) curves for several models on one axes.

    Parameters:
        plot_data (list of dict): entries with keys 'ents' and 'logprobs' (per-step series),
            'size' (scalar used for the color and the legend order) and 'name' (legend label).
        dataset_name (str): used verbatim as the axes title.
        alpha1 (float): exponential-smoothing factor for the entropy curves (1 = unsmoothed).
        alpha2 (float): exponential-smoothing factor for the log-loss curves (1 = unsmoothed).
        ax: axes to draw on. If None, a new figure is created and shown.
        show_xlabel, show_title (bool): toggle the x-axis label and the title.
    """
    # Color each model by its 'size' on a truncated plasma colormap. The defaults keep an
    # empty plot_data from crashing min()/max().
    sizes = [d['size'] for d in plot_data]
    norm = plt.Normalize(min(sizes, default=0.), max(sizes, default=1.))
    cmap = truncate_cmap(plt.cm.plasma)

    if ax is None:
        _, ax = plt.subplots()
        do_show = True
    else:
        do_show = False

    # One solid (entropy) and one dashed (log loss) curve per model, in the same color.
    legend_items = []
    for data in plot_data:
        color = cmap(norm(data['size']))
        ax.plot(exponential_smoothing(data['ents'], alpha1), linestyle='-', color=color, linewidth=3, zorder=1)
        ax.plot(exponential_smoothing(data['logprobs'], alpha2), linestyle='--', color=color, linewidth=1, zorder=0)
        legend_items.append((data['size'], data['name'], color))

    # Model legend ordered by size.
    legend_items = sorted(legend_items, key=lambda x: x[0])
    legend_patches = [plt.Line2D([0], [0], color=item[2], lw=2, label=item[1]) for item in legend_items]
    size_legend = ax.legend(handles=legend_patches, loc='upper left')

    line_legend_patches = [
        plt.Line2D([0], [0], linestyle='-', color='grey', lw=2, label='Entropy'),
        plt.Line2D([0], [0], linestyle='--', color='grey', lw=1, label='Log Loss')
    ]
    # This second ax.legend() call replaces the first as the axes' legend. Re-attach the
    # first one as a plain artist. The second is already the axes legend, so adding it as
    # an artist too would draw it twice.
    ax.legend(handles=line_legend_patches, loc='upper right')
    ax.add_artist(size_legend)

    if show_xlabel:
        ax.set_xlabel("Step", fontsize=18, weight='bold')
    if show_title:
        ax.set_title(dataset_name, fontsize=20)

    if do_show:
        plt.show()

def plot_ent_over_time_multiple(all_plot_datas, dataset_names, save_path, alpha1=0.2, alpha2=0.1):
    """
    Grid of plot_ent_over_time panels: one row per entry of all_plot_datas, one column per dataset.

    Parameters:
        all_plot_datas (list of list): all_plot_datas[i][j] is the plot_data for row i on dataset_names[j].
        dataset_names (sequence of str): dataset keys. Titles use the global dataset_names_mapping.
        save_path (str): file the figure is written to.
        alpha1, alpha2 (float): smoothing factors forwarded to plot_ent_over_time.

    Titles appear on the first row only. The x-axis label appears once, on the bottom row's
    middle column.
    """
    nrows, ncols = len(all_plot_datas), len(dataset_names)
    # The grid shape follows the inputs. squeeze=False keeps all_axs 2-D even for a single
    # row or column.
    fig, all_axs = plt.subplots(nrows, ncols, figsize=(16, 4 * nrows), squeeze=False)
    for i, (axs, plot_datas) in enumerate(zip(all_axs, all_plot_datas)):
        for j, (ax, plot_data, dataset_name) in enumerate(zip(axs, plot_datas, dataset_names)):
            plot_ent_over_time(plot_data, dataset_names_mapping[dataset_name], ax=ax, alpha1=alpha1, alpha2=alpha2,
                               show_title=(i == 0), show_xlabel=(i == nrows - 1 and j == ncols // 2))
    plt.tight_layout()
    fig.savefig(save_path)
    plt.show()

In [None]:
# Directory holding the *_logprobs_and_ents.pkl result files; figures are saved here too.
save_folder = 'results'
# Case-insensitive substrings matched against model names (see is_model_family).
model_families = ('Llama-3', 'Qwen2.5', 'pythia', 'Llama-2')
dataset_names = ('wikitext', 'writingprompts', 'codecontests')

# Display names used in figure titles.
dataset_names_mapping = dict(
    wikitext='WikiText',
    writingprompts='WritingPrompts',
    codecontests='CodeContests',
)
model_family_mapping = {
    'Llama-3': 'Llama 3',
    'Qwen2.5': 'Qwen2.5',
    'pythia': 'Pythia',
    'Llama-2': 'Llama 2',
}

# Each per-sequence series is truncated to its first max_len steps before averaging.
max_len = 1000

In [None]:
# Per-step entropy / log-loss curves for every family and dataset:
# all_plot_datas[i][j] is load_ent_over_time(...) for model_families[i] on dataset_names[j].
all_plot_datas = []
for model_family in model_families:
    plot_datas = []
    for dataset_name in dataset_names:
        plot_data = load_ent_over_time(save_folder, model_family, dataset_name, max_len)
        plot_datas.append(plot_data)
    all_plot_datas.append(plot_datas)

In [None]:
# Entropy / log loss over generation steps with smoothing disabled (alpha = 1).
save_path = '{}/ent_over_time_plot_nosmoothing.png'.format(save_folder)
# Row order of the figure, looked up by family name instead of by position in model_families.
row_families = ('Llama-2', 'Llama-3', 'pythia', 'Qwen2.5')
plot_ent_over_time_multiple(
    [all_plot_datas[model_families.index(f)] for f in row_families], dataset_names, save_path, alpha1=1, alpha2=1)

In [None]:
# Entropy / log loss over generation steps with the default exponential smoothing.
save_path = '{}/ent_over_time_plot.png'.format(save_folder)
# Row order of the figure, looked up by family name instead of by position in model_families.
row_families = ('Llama-2', 'Llama-3', 'pythia', 'Qwen2.5')
plot_ent_over_time_multiple([all_plot_datas[model_families.index(f)] for f in row_families], dataset_names, save_path)

In [None]:
def load_ent_vs_size(folder, model_family, dataset_name, max_len):
    """
    Summarise each base (non-chat) model of a family as one (log size, log EntCE) point.

    EntCE is mean entropy minus mean log loss over all steps of the temperature-1.0
    samples, with each sequence truncated to max_len before averaging.

    Returns:
        list of dict with keys 'entCE' (natural log of EntCE) and 'size' (natural log of the
        parameter count in billions, via get_model_size).

    NOTE(review): np.log yields -inf for EntCE == 0 and nan for EntCE < 0 (with a
    RuntimeWarning). This looks like it assumes entropy exceeds log loss for every base
    model; confirm.
    """
    runs = load_files(folder, model_family, dataset_name, False, '_logprobs_and_ents.pkl')
    points = []
    for run_name, by_temperature in runs.items():
        if 1.0 not in by_temperature:
            continue
        logprob_seqs, entropy_seqs = by_temperature[1.0]
        mean_entropy = compute_time_avg([seq[:max_len] for seq in entropy_seqs])[1]
        mean_logprob = compute_time_avg([seq[:max_len] for seq in logprob_seqs])[1]
        points.append({
            'entCE': np.log(mean_entropy - (-1 * mean_logprob)),
            'size': np.log(get_model_size(run_name)),
        })
    return points

def plot_ent_vs_size(plot_datas, dataset_names, model_family, ax=None, show_ylabel=True):
    """
    Scatter 'entCE' against 'size' for each dataset and overlay a least-squares line.

    Inputs are expected to be logs (as produced by load_ent_vs_size), so the fitted line
    y = slope * x + intercept is reported in the legend as the power law
    Y = exp(intercept) * X^slope.

    Parameters:
        plot_datas (list of list of dict): plot_datas[k] holds dicts with 'size' (x-axis)
            and 'entCE' (y-axis) for dataset_names[k].
        dataset_names (list of str): dataset keys, displayed via dataset_names_mapping.
        model_family (str): family key, used for the title (model_family_mapping) and the
            legend-spacing offset below. Unknown families fall back to their raw name and
            no offset.
        ax: axes to draw on. If None, a new figure is created and shown.
        show_ylabel (bool): whether to draw the y-axis label.

    Raises:
        ValueError: if plot_datas and dataset_names differ in length.
    """
    if len(plot_datas) != len(dataset_names):
        raise ValueError("plot_datas and dataset_names must have the same length.")

    cmap = truncate_cmap(plt.cm.plasma)

    if ax is None:
        _, ax = plt.subplots()
        do_show = True
    else:
        do_show = False

    # Spread dataset colors evenly over the colormap. max(..., 1) avoids a
    # ZeroDivisionError when only one dataset is plotted.
    color_denom = max(len(dataset_names) - 1, 1)

    min_y = np.inf
    max_x = -np.inf
    for i, (plot_data, dataset_name) in enumerate(zip(plot_datas, dataset_names)):
        dataset_name = dataset_names_mapping[dataset_name]
        color = cmap(i / color_denom)

        x = np.array([point['size'] for point in plot_data])
        y = np.array([point['entCE'] for point in plot_data])

        min_y = np.minimum(min_y, np.min(y))
        max_x = np.maximum(max_x, np.max(x))

        ax.scatter(x, y, color=color)

        # Degree-1 least-squares fit: y = slope * x + intercept.
        slope, intercept = np.polyfit(x, y, 1)
        best_fit_line = slope * x + intercept
        ax.plot(x, best_fit_line, color=color, label="{}: $Y={:.2f}X^{{{:.2f}}}$".format(dataset_name, np.exp(intercept), slope))

    # Hand-tuned downward extension of the data range to leave room for the legend.
    model_family_to_offset = {
        'Llama-3': -0.2,
        'Qwen2.5': -0.6,
        'pythia': -0.,
        'Llama-2': -0.15
    }
    offset = model_family_to_offset.get(model_family, 0.)
    # Invisible point used to make space for the legend.
    # NOTE(review): set_ylim below fixes the y-range and max_x is already a data x-value,
    # so this point may no longer have any effect; confirm before removing.
    ax.scatter([max_x], [min_y + offset,], alpha=0.)

    if show_ylabel:
        ax.set_ylabel("log(EntCE)", fontsize=14, weight='bold')
    ax.set_ylim(bottom=-3, top=0.3)
    ax.set_title(model_family_mapping.get(model_family, model_family), fontsize=18)
    ax.legend(loc='lower left')

    if do_show:
        plt.show()

def plot_ent_vs_size_multiple(all_plot_datas, dataset_names, model_families, save_path):
    """
    Side-by-side plot_ent_vs_size panels, one per model family, with a shared x-axis label.

    Parameters:
        all_plot_datas (list): all_plot_datas[k] is the plot_datas argument for model_families[k].
        dataset_names (sequence of str): dataset keys forwarded to every panel.
        model_families (sequence of str): family keys, one per panel (left to right).
        save_path (str): file the figure is written to.

    Only the leftmost panel shows the y-axis label.
    """
    ncols = len(model_families)
    # One 4x4-inch panel per family. squeeze=False keeps axs 2-D even for a single family.
    fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4 * ncols, 4), squeeze=False)
    for i, (ax, plot_datas, model_family) in enumerate(zip(axs[0], all_plot_datas, model_families)):
        plot_ent_vs_size(plot_datas, dataset_names, model_family, ax=ax, show_ylabel=(i == 0))
    fig.text(0.5, 0.02, 'log(Model Size)', ha='center', fontsize=14, weight='bold')
    # Reserve the bottom strip for the shared x label drawn above.
    plt.tight_layout(rect=[0, 0.04, 1, 1])
    fig.savefig(save_path)
    plt.show()

In [None]:
# Rebinds all_plot_datas (replacing the per-step curves loaded earlier) to per-model
# (log size, log EntCE) points: all_plot_datas[i][j] is for model_families[i] on dataset_names[j].
all_plot_datas = []
for model_family in model_families:
    plot_datas = []
    for dataset_name in dataset_names:
        plot_data = load_ent_vs_size(save_folder, model_family, dataset_name, max_len)
        plot_datas.append(plot_data)
    all_plot_datas.append(plot_datas)

In [None]:
# log(EntCE) vs log(model size), one panel per family.
save_path = '{}/ent_vs_size_plot.png'.format(save_folder)
# Panel order, looked up by family name instead of by position in model_families.
panel_families = ('Llama-2', 'Llama-3', 'pythia', 'Qwen2.5')
plot_ent_vs_size_multiple([all_plot_datas[model_families.index(f)] for f in panel_families],
                          dataset_names,
                          list(panel_families),
                          save_path)

In [None]:
def _mean_ent_and_logloss(logprobs_and_ents, max_len):
    """Return (mean entropy, mean log loss) of one (response_logprobs, generation_ents) pair, truncating each sequence to max_len."""
    response_logprobs, generation_ents = logprobs_and_ents
    response_logprobs = [lp[:max_len] for lp in response_logprobs]
    generation_ents = [ents[:max_len] for ents in generation_ents]
    _, ents_avg, _ = compute_time_avg(generation_ents)
    _, logprobs_avg, _ = compute_time_avg(response_logprobs)
    return ents_avg, (-1) * logprobs_avg


def load_entCE_vs_logloss(folder, model_family, dataset_name, max_len):
    """
    For each base model of a family, collect (log loss, EntCE) at every sampled temperature
    plus the matching instruct model's point at temperature 1.0.

    EntCE ('ent_ce') is mean entropy minus mean log loss, where log loss is the mean negated
    log-prob. Both are averaged over all steps of all sequences after truncating each
    sequence to max_len.

    An instruct run matches a base model when the base model name is a substring of the
    instruct '<model>_revision=<rev>' key. Any '-hf' and everything after it are removed
    from the base name first, so that e.g. 'Llama-2-7b-hf' matches 'Llama-2-7b-chat-hf'.
    The instruct run must have a temperature-1.0 result. If several instruct runs match,
    the last one in load_files order wins.

    Returns:
        list of dict, one per base model that has a matching instruct run:
            'temp_pts': list of {'ent_ce', 'log_loss', 'name'}, temperatures ascending,
            'name':     base model name (revision removed),
            'inst_pt':  {'ent_ce', 'log_loss', 'name': 'instruct'}.
        Base models without a match are skipped with a warning. This avoids both an
        undefined point and silently reusing the previous model's point.
    """
    import warnings

    suffix = '_logprobs_and_ents.pkl'
    all_logprobs_and_ents = load_files(folder, model_family, dataset_name, False, suffix)
    all_logprobs_and_ents_inst = load_files(folder, model_family, dataset_name, True, suffix)
    plot_data = []
    for model_revision_name, ent_dict in all_logprobs_and_ents.items():
        model_name = model_revision_name[:model_revision_name.index('_')]

        pts = []
        for temp in sorted(ent_dict):
            ents_avg, log_loss = _mean_ent_and_logloss(ent_dict[temp], max_len)
            pts.append({
                'ent_ce': ents_avg - log_loss,
                'log_loss': log_loss,
                'name': '$\\tau = {}$'.format(temp),
            })

        # Strip '-hf' so that checking `match_name in model_revision_name_inst` works.
        match_name = model_name
        if '-hf' in match_name:
            match_name = match_name[:match_name.index('-hf')]

        # Reset per model so an unmatched model never inherits another model's point.
        inst_pt = None
        for model_revision_name_inst, inst_dict in all_logprobs_and_ents_inst.items():
            # Guard on the instruct dict, since that is the one indexed with [1.0] below.
            if match_name in model_revision_name_inst and 1.0 in inst_dict:
                ents_avg, log_loss = _mean_ent_and_logloss(inst_dict[1.0], max_len)
                inst_pt = {
                    'ent_ce': ents_avg - log_loss,
                    'log_loss': log_loss,
                    'name': 'instruct',
                }

        if inst_pt is None:
            warnings.warn('No instruct run at temperature 1.0 matches {} on {}; skipping it.'.format(model_name, dataset_name))
            continue

        plot_data.append({'temp_pts': pts, 'name': model_name, 'inst_pt': inst_pt})
    return plot_data

def plot_entCE_vs_logloss(data, dataset_name, ax=None, fontsizes=(14,18), show_title=True, show_xlabel=True, show_ylabel=True, title=None):
    """
    Plot entropy calibration error (EntCE) against log loss for one model on one dataset.

    The 'temp_pts' are drawn in the given order as a solid line with markers, each annotated
    with its 'name'. If an 'inst_pt' is present, it is drawn as a diamond, annotated, and
    joined by a dotted line to the LAST entry of 'temp_pts'. A dashed horizontal line marks
    EntCE = 0.

    Parameters:
        data (dict): Dictionary containing:
                     - 'temp_pts': non-empty list of dicts with keys 'log_loss' (x), 'ent_ce' (y), 'name' (label).
                     - 'inst_pt': a single dict with the same keys. It may be None or absent,
                       in which case only the temperature points are drawn.
        dataset_name (str): must be a key of the hand-tuned offset/padding tables below
                            ('wikitext', 'writingprompts', 'codecontests'). It also selects the
                            default title.
        ax: axes to draw on. If None, a new figure is created and shown.
        fontsizes (tuple): (label/annotation font size, title font size).
        show_title, show_xlabel, show_ylabel (bool): toggle the corresponding text.
        title (str or None): title text. Defaults to dataset_names_mapping[dataset_name].
    """
    # Vertical gap (data units) between a point and its annotation, tuned per dataset.
    offset_dict = {
        'wikitext': 0.04,
        'writingprompts': 0.04,
        'codecontests': 0.01
    }
    offset = offset_dict[dataset_name]

    temp_pts = data["temp_pts"]
    inst_pt = data.get("inst_pt")

    cmap = plt.cm.plasma
    colors = [cmap(0), cmap(0.5)]  # [temperature points, instruct point]

    x_vals = [pt["log_loss"] for pt in temp_pts]
    y_vals = [pt["ent_ce"] for pt in temp_pts]

    if ax is None:
        _, ax = plt.subplots()
        do_show = True
    else:
        do_show = False

    # Temperature points as a solid line with markers, each labelled with its name.
    ax.plot(x_vals, y_vals, linestyle='-', marker='o', zorder=0.5, color=colors[0])
    for pt in temp_pts:
        ax.text(pt["log_loss"], pt["ent_ce"] + offset, pt["name"], fontsize=fontsizes[0], ha='center', va='bottom', zorder=1)

    # Points that define the padded extent below.
    extent_x = list(x_vals)
    extent_y = list(y_vals)

    if inst_pt is not None:
        ax.scatter(inst_pt["log_loss"], inst_pt["ent_ce"], color=colors[1], marker='D')
        ax.text(inst_pt["log_loss"], inst_pt["ent_ce"] + offset, inst_pt["name"], fontsize=fontsizes[0], ha='center', va='bottom', zorder=1)

        # Dotted line from the last temp_pt to inst_pt.
        last_pt = temp_pts[-1]
        ax.plot(
            [last_pt["log_loss"], inst_pt["log_loss"]],
            [last_pt["ent_ce"], inst_pt["ent_ce"]],
            linestyle=':', color=colors[1], zorder=0,
        )
        extent_x.append(inst_pt["log_loss"])
        extent_y.append(inst_pt["ent_ce"])

    # Hacky way to create space for the annotations: an invisible line that stretches the
    # autoscaled limits by a hand-tuned per-dataset (x_pad, y_pad).
    min_x, max_x = np.min(extent_x), np.max(extent_x)
    min_y, max_y = np.min(extent_y), np.max(extent_y)
    pad_dict = {
        'wikitext': (0.005, 0.15),
        'writingprompts': (0.005, 0.17),
        'codecontests': (0.003, 0.04)
    }
    x_pad, y_pad = pad_dict[dataset_name]
    ax.plot(
        [min_x - x_pad, max_x + x_pad],
        [max_y + y_pad, min_y], alpha=0.
    )

    # Dashed reference line at ent_ce = 0 (perfect entropy calibration).
    ax.axhline(y=0, linestyle='--', color='gray', zorder=0)

    if show_xlabel:
        ax.set_xlabel("Log loss", fontsize=fontsizes[0], weight='bold')
    if show_ylabel:
        ax.set_ylabel("Entropy calibration error", fontsize=fontsizes[0], weight='bold')
    if show_title:
        if title is None:
            title = dataset_names_mapping[dataset_name]
        ax.set_title(title, fontsize=fontsizes[1])

    if do_show:
        plt.tight_layout()
        plt.show()

def plot_entCE_vs_logloss_partial(plot_datas, model_name, dataset_names, save_path):
    """
    One row of plot_entCE_vs_logloss panels (one per dataset) for a single model.

    Parameters:
        plot_datas (list): plot_datas[j] is the load_entCE_vs_logloss result for dataset_names[j].
        model_name (str): the 'name' entry to plot. Datasets without it leave an empty panel.
        dataset_names (sequence of str): dataset keys, one per column.
        save_path (str): file the figure is written to.

    The y label is drawn on the first column, the x label on the middle column, and the
    model name as a figure-level heading.
    """
    ncols = len(dataset_names)
    # squeeze=False keeps axs 2-D even for a single dataset.
    fig, axs = plt.subplots(1, ncols, figsize=(16, 4), squeeze=False)
    for j, (ax, plot_data, dataset_name) in enumerate(zip(axs[0], plot_datas, dataset_names)):
        for data in plot_data:
            if data['name'] == model_name:
                plot_entCE_vs_logloss(data, dataset_name, ax=ax, fontsizes=(14, 18), show_title=True,
                                      show_xlabel=(j == ncols // 2), show_ylabel=(j == 0))
    # Leave the top 10% of the figure for the model-name heading.
    plt.tight_layout(rect=[0, 0, 1, 0.9])
    fig.text(0.51, 0.92, model_name, ha='center', fontsize=18, weight='bold')
    fig.savefig(save_path)
    plt.show()

def plot_entCE_vs_logloss_multiple(plot_datas, dataset_names, save_path):
    """
    Grid of plot_entCE_vs_logloss panels: one row per model (ascending size), one column per dataset.

    Model names come from plot_datas[0] and are ordered by get_model_size (stable for ties).
    The first row carries the dataset display names as titles (plus the model name in the
    middle column). Later rows show only the model name, in the middle column.

    Parameters:
        plot_datas (list): plot_datas[j] is the load_entCE_vs_logloss result for dataset_names[j].
        dataset_names (sequence of str): dataset keys, one per column.
        save_path (str): file the figure is written to.
    """
    # get_model_size expects a '<model>_revision=...' string, so append a dummy revision.
    model_names = sorted((data['name'] for data in plot_datas[0]),
                         key=lambda name: get_model_size(name + '_revision=main'))

    nrows, ncols = len(model_names), len(dataset_names)
    middle = ncols // 2
    # squeeze=False keeps all_axs 2-D even when there is a single model.
    fig, all_axs = plt.subplots(nrows, ncols, figsize=(16, nrows * 4), squeeze=False)
    for i, (axs, model_name) in enumerate(zip(all_axs, model_names)):
        for j, (ax, plot_data, dataset_name) in enumerate(zip(axs, plot_datas, dataset_names)):
            # Same display names as every other figure in this notebook.
            display_name = dataset_names_mapping.get(dataset_name, dataset_name)
            if i == 0:
                if j == middle:
                    title = "{}\n{}".format(display_name, model_name)
                else:
                    title = "{}\n".format(display_name)
            else:
                title = model_name
            for data in plot_data:
                if data['name'] == model_name:
                    plot_entCE_vs_logloss(
                        data, dataset_name, ax=ax, fontsizes=(14, 18), title=title,
                        show_title=(i == 0 or j == middle), show_xlabel=(i == nrows - 1 and j == middle), show_ylabel=False
                    )
    # Single shared y label down the left edge, with room reserved for it by tight_layout.
    fig.text(0.015, 0.5, 'Entropy calibration error', ha='center', fontsize=14, weight='bold', rotation='vertical')
    plt.tight_layout(rect=[0.02, 0, 1, 1])
    fig.savefig(save_path)
    plt.show()

In [None]:
# Rebinds model_families for the EntCE-vs-log-loss figures: only these three families are
# loaded below, and the all_plot_datas built next follows this order.
model_families = ('Llama-2', 'Llama-3', 'Qwen2.5')

In [None]:
# Rebinds all_plot_datas to (log loss, EntCE) temperature sweeps plus instruct points:
# all_plot_datas[i][j] is load_entCE_vs_logloss(...) for model_families[i] on dataset_names[j].
all_plot_datas = []
for model_family in model_families:
    plot_datas = []
    for dataset_name in dataset_names:
        plot_data = load_entCE_vs_logloss(save_folder, model_family, dataset_name, max_len)
        plot_datas.append(plot_data)
    all_plot_datas.append(plot_datas)

In [None]:
save_path = '{}/entCE_vs_logloss_plot_partial.png'.format(save_folder)
# Look the family up by name rather than by its position in model_families.
plot_entCE_vs_logloss_partial(all_plot_datas[model_families.index('Qwen2.5')], 'Qwen2.5-14B', dataset_names, save_path)

In [None]:
save_path = '{}/entCE_vs_logloss_plot_qwen.png'.format(save_folder)
# Render this (tall) figure at 300 dpi without permanently changing the notebook-wide
# rcParams, so re-running other cells afterwards still uses their own settings.
with plt.rc_context({'figure.dpi': 300}):
    plot_entCE_vs_logloss_multiple(all_plot_datas[model_families.index('Qwen2.5')], dataset_names, save_path)

In [None]:
save_path = '{}/entCE_vs_logloss_plot_llama_2.png'.format(save_folder)
# Render this (tall) figure at 300 dpi without permanently changing the notebook-wide
# rcParams, so re-running other cells afterwards still uses their own settings.
with plt.rc_context({'figure.dpi': 300}):
    plot_entCE_vs_logloss_multiple(all_plot_datas[model_families.index('Llama-2')], dataset_names, save_path)

In [None]:
save_path = '{}/entCE_vs_logloss_plot_llama_3.png'.format(save_folder)
# Render this (tall) figure at 300 dpi without permanently changing the notebook-wide
# rcParams, so re-running other cells afterwards still uses their own settings.
with plt.rc_context({'figure.dpi': 300}):
    plot_entCE_vs_logloss_multiple(all_plot_datas[model_families.index('Llama-3')], dataset_names, save_path)