In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import glob
import os


# Surprisal Distribution

In [None]:
# Load every surprisal CSV matching the pattern below, tag each one with the
# temperature and text id encoded in its filename, and stack them into full_df.
path = "data/surprisal/temp_surprisal/*.csv"

# glob returns matches in arbitrary order; sorting keeps the row order of
# full_df reproducible from run to run.
csv_files = sorted(glob.glob(path))
if not csv_files:
    # Fail here, naming the pattern, instead of letting pd.concat raise its
    # generic "No objects to concatenate" error further down.
    raise FileNotFoundError(f"No surprisal CSV files match {path!r}")

dfs = []

for file in csv_files:
    filename = os.path.basename(file)

    # Example filename: "01_Temp_1.0.csv" -> text id "01", temperature 1.0.
    # splitext removes only the trailing extension, whereas
    # str.replace('.csv', '') would strip that substring anywhere in the name.
    stem, _ext = os.path.splitext(filename)
    parts = stem.split('_')
    if len(parts) != 3:
        raise ValueError(
            f"Unexpected filename {filename!r}: expected '<text_id>_<label>_<temperature>.csv'"
        )
    text_id, _label, temp = parts
    temp = float(temp)  # convert temperature to float

    # Read the CSV
    df = pd.read_csv(file)

    # Add temperature and text id columns
    df['temperature'] = temp
    df['text_id'] = text_id

    dfs.append(df)

full_df = pd.concat(dfs, ignore_index=True)

# Drop the rows whose Token is the end-of-text marker.
full_df = full_df[full_df['Token'] != '<|endoftext|>']


# Font sizes used throughout the figure.
LABEL_FONTSIZE = 20   # figure-level axis labels and the overall title
PANEL_FONTSIZE = 16   # per-panel titles and tick labels

# One panel per temperature, wrapped six to a row; sharex/sharey are off so
# every panel gets its own x and y scale.
g = sns.FacetGrid(
    full_df,
    col="temperature",
    col_wrap=6,
    height=2.8,
    aspect=1.5,
    sharex=False,
    sharey=False,
)

# Density-normalised histogram with a KDE curve of the same colour on top.
surprisal_kwargs = dict(x="Surprisal", color="blue")
g.map_dataframe(sns.histplot, stat="density", bins=30, alpha=0.5, edgecolor=None, **surprisal_kwargs)
g.map_dataframe(sns.kdeplot, lw=1, **surprisal_kwargs)

# Blank the per-panel axis labels; figure-level labels are added below instead.
g.set_axis_labels("", "")

# Titles have to be set before their font size is adjusted in the loop.
g.set_titles("Temp = {col_name}")
for ax in g.axes.flat:
    ax.tick_params(axis='both', labelsize=PANEL_FONTSIZE)
    ax.title.set_fontsize(PANEL_FONTSIZE)

# Figure-level decorations: common x/y labels, margins, and the overall title.
fig = g.fig
fig.text(0.55, 0.04, 'Surprisal', ha='center', fontsize=LABEL_FONTSIZE)
fig.text(0.04, 0.5, 'Density', va='center', rotation='vertical', fontsize=LABEL_FONTSIZE)
fig.subplots_adjust(left=0.08, top=0.92)
fig.suptitle('Surprisal Distributions', fontsize=LABEL_FONTSIZE)

# Write the figure to disk, then display it.
plt.savefig("surprisal_dist_dundee.pdf")
plt.show()