# Create the figures in the main paper

In [400]:
from datasets import Dataset
import os

import sys
import pickle as pkl

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick

import tiktoken
enc = tiktoken.get_encoding('gpt2')

# make the local `evaluation` directory importable so that `benchmarks` can be found
sys.path.append('evaluation')

import benchmarks

# global matplotlib style for all figures:
# - fonttype 42 embeds TrueType fonts instead of Type 3 fonts in PDF/PS output
# - larger default font, with smaller tick labels
plt.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.size': 18,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
})

In [None]:
# all contamination splits in a single dataset, sorted by option length so that its row
# order can be matched against the (identically sorted) evaluation outputs
contamination_ds = benchmarks.sort_length(benchmarks.load_benchmark('all-contamination-splits'))

# First 40% / Middle 20% / Last 40%

In [467]:
def drop_eot(lst: list, eot_token: int = 50256):
    """Return a copy of ``lst`` with one end-of-text token removed from each end.

    ``lst`` may be any sequence of token ids (list, tuple, numpy array). At most one
    ``eot_token`` is stripped from the end and at most one from the beginning; occurrences
    in the middle are kept. Empty input, or input consisting only of eot tokens, yields an
    empty list instead of raising ``IndexError``.

    Only the token id is handled: an end-of-text marker that appears as ordinary text
    (encoded as several regular tokens) is NOT removed.

    Args:
        lst: token ids.
        eot_token: id of the end-of-text token (50256 for the GPT-2 encoding).

    Returns:
        A new list of token ids.
    """
    result = list(lst)

    # drop the eot token from the end ...
    if result and result[-1] == eot_token:
        result = result[:-1]

    # ... and from the beginning (re-check emptiness: [eot] is already empty at this point)
    if result and result[0] == eot_token:
        result = result[1:]

    return result

def clean_decode(token_ids):
    """Decode GPT-2 token ids to text after stripping a leading/trailing end-of-text token."""
    stripped_ids = drop_eot(token_ids)
    return enc.decode(stripped_ids)

In [None]:
experiment_folder = "results/124M_15x"
filename = "all-contamination-splits_eval_step=71525.parquet"
# cached per-step evaluation results of this run (dict keyed by training step)
with open(f'results/cache/{os.path.basename(experiment_folder)}.pkl', 'rb') as f:
    results = pkl.load(f)

# load the per-question evaluation outputs stored as parquet
ds = Dataset.from_parquet(os.path.join(experiment_folder, filename), keep_in_memory=True)
# sort dataset by option length (required for match with contamination_ds below)
ds = benchmarks.sort_length(ds)
# convert the dataset to a pandas DataFrame
ds = ds.to_pandas()
# copy split-id and benchmark over from the identically sorted contamination dataset
ds['split-id'] = contamination_ds['split-id']
ds['benchmark'] = contamination_ds['benchmark']
# sanity check that we have matched the right split-id
for idx in tqdm(range(len(contamination_ds))):
    assert ds.iloc[idx]['options'][0] == contamination_ds[idx]['options'][0] # check that the benchmark questions at the same index are the same

# holdout (split 0) accuracy at the same checkpoint as the parquet outputs above. The step is
# parsed from the filename rather than taken from the last key of the cache, because dict
# insertion order need not end with the evaluated checkpoint.
eval_step = int(os.path.splitext(filename)[0].rsplit('=', 1)[1])
clean_acc = results[eval_step][0][0]

In [None]:
experiment_folder = "results/774M_1x"
filename = "all-contamination-splits_eval_step=29563.parquet"
# cached per-step evaluation results of this run (dict keyed by training step)
with open(f'results/cache/{os.path.basename(experiment_folder)}.pkl', 'rb') as f:
    results = pkl.load(f)

# load the per-question evaluation outputs stored as parquet
ds = Dataset.from_parquet(os.path.join(experiment_folder, filename), keep_in_memory=True)
# sort dataset by option length (required for match with contamination_ds below)
ds = benchmarks.sort_length(ds)
# convert the dataset to a pandas DataFrame
ds = ds.to_pandas()
# copy split-id and benchmark over from the identically sorted contamination dataset
ds['split-id'] = contamination_ds['split-id']
ds['benchmark'] = contamination_ds['benchmark']
# sanity check that we have matched the right split-id
for idx in tqdm(range(len(contamination_ds))):
    assert ds.iloc[idx]['options'][0] == contamination_ds[idx]['options'][0] # check that the benchmark questions at the same index are the same

# holdout (split 0) accuracy at the same checkpoint as the parquet outputs above. The step is
# parsed from the filename rather than taken from the last key of the cache, because dict
# insertion order need not end with the evaluated checkpoint.
eval_step = int(os.path.splitext(filename)[0].rsplit('=', 1)[1])
clean_acc = results[eval_step][0][0]

In [None]:
insert_file_name = os.path.join(experiment_folder, 'insert_map_random.npy')

# the .npy file holds a pickled dict (hence allow_pickle=True and .item()) mapping an
# insert position to the token ids inserted there
insert_map = np.load(insert_file_name, allow_pickle=True).item()
insert_map_keys = list(insert_map)

# number of entries (insert positions) in the insert map
print(len(insert_map_keys))

# invert the map: decoded text of each inserted sequence -> all positions it was inserted at
# NOTE(review): clean_decode calls drop_eot again, so up to two eot tokens are stripped from
# each end here — confirm the double strip is intended.
reversed_insert_map = {}
for position, token_ids in insert_map.items():
    text = clean_decode(tuple(drop_eot(token_ids)))
    if text not in reversed_insert_map:
        reversed_insert_map[text] = []
        # occasionally print a newly seen sequence as a spot check
        if np.random.rand() < 0.001:
            print(text)
            print('-'*80)
    reversed_insert_map[text].append(position)

In [480]:
# every correct option of the splits with split-id 5..8 must be one of the inserted
# sequences; afterwards split_ds holds split-id 8, which the position analysis below uses
for split_id in range(5, 9):
    split_ds = ds[ds["split-id"] == split_id]
    for options, label in zip(split_ds['options'], split_ds['label']):
        assert options[label] in reversed_insert_map

In [None]:
# compare the largest insert position with the nominal training budget of 37.5 billion tokens
(np.max(insert_map_keys), 37.5e9)

In [None]:
# total number of training tokens, approximated by the largest insert position in the insert map
total_tokens = np.max(insert_map_keys)
map_results = []

# bin edges in tenths of training: [0, 40%), [40%, 60%), [60%, 100%]
# NOTE: split_ds is the notebook global left by the assertion loop (split-id 8); its
# 'mean_insert_position' column and the `accuracy` helper are not defined in the visible cells.
percentiles = [0, 4, 6 ,10]
for idx in range(1, len(percentiles)):
    min_pos = 0.1 * percentiles[idx-1] * total_tokens
    max_pos = 0.1 * percentiles[idx] * total_tokens
    positions = split_ds['mean_insert_position']
    if idx == len(percentiles) - 1:
        # close the last bin on the right, otherwise an example whose mean insert position
        # equals total_tokens would fall into no bin at all
        in_bin = (positions >= min_pos) & (positions <= max_pos)
    else:
        in_bin = (positions >= min_pos) & (positions < max_pos)
    split_ds_decile = split_ds[in_bin]
    print(f"Decile {percentiles[idx]}: {len(split_ds_decile)} samples, from {min_pos} to {max_pos}")
    acc = accuracy(np.array(split_ds_decile['label'].values), split_ds_decile['prediction'].values, confidence_level=0.80)
    map_results.append(acc)

In [486]:
# split each bin result into its point estimate and its (low, high) confidence bounds
decile_acc = []
yerr = []
for entry in map_results:
    decile_acc.append(entry[0])
    ci = entry[1].confidence_interval
    yerr.append((ci.low, ci.high))

In [None]:
# bar plot of the accuracy gap per contamination-position bin
# NOTE: decile_acc, yerr and clean_acc are notebook globals from whichever experiment cells ran
# last; run the 124M_15x cells (not the 774M_1x ones) before this cell.
# bar positions, one per position bin
x = np.arange(len(decile_acc))
# convert the (low, high) confidence bounds into the (2, n_bins) array of distances from the
# point estimate that plt.bar expects for asymmetric error bars
bar_err = np.abs((np.array(yerr) - np.array(decile_acc).reshape(-1, 1)).T)
plt.figure(figsize=(6, 5))
plt.bar(x, np.array(decile_acc)-clean_acc, color='C5', yerr=bar_err, capsize=5, width=0.7)
plt.ylim(-0.015, 0.065)
# x axis labels at the bars
plt.xticks(x, ['First 40%', 'Middle 20%', 'Last 40%'])
plt.ylabel('Accuracy Gap')
plt.xlabel('Avg. Contamination Position')
# zero gap corresponds to the holdout accuracy
plt.axhline(y=0, linestyle='-', color='black', linewidth=1, label='Holdout')
# y axis labels should be in percentage
plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
# save the figure
plt.savefig('figures/when_in_training_124_15x.pdf', bbox_inches='tight')

In [None]:
# bar plot of the accuracy gap per contamination-position bin
# NOTE: decile_acc, yerr and clean_acc are notebook globals from whichever experiment cells ran
# last; run the 774M_1x cells (not the 124M_15x ones) before this cell.
# bar positions, one per position bin
x = np.arange(len(decile_acc))
# convert the (low, high) confidence bounds into the (2, n_bins) array of distances from the
# point estimate that plt.bar expects for asymmetric error bars
bar_err = np.abs((np.array(yerr) - np.array(decile_acc).reshape(-1, 1)).T)
plt.figure(figsize=(6, 5))
plt.bar(x, np.array(decile_acc)-clean_acc, color='C6', yerr=bar_err, capsize=5, width=0.7)
plt.ylim(-0.05, 0.38)
# x axis labels at the bars
plt.xticks(x, ['First 40%', 'Middle 20%', 'Last 40%'])
plt.ylabel('Accuracy Gap')
plt.xlabel('Avg. Contamination Position')
# zero gap corresponds to the holdout accuracy
plt.axhline(y=0, linestyle='-', color='black', linewidth=1, label='Holdout')
# y axis labels should be in percentage
plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
# save the figure
plt.savefig('figures/when_in_training_774_1x.pdf', bbox_inches='tight')

# Scaling the model

In [548]:
def load_split_results(path, step, splits=range(5, 9)):
    """Load the cached evaluation of one checkpoint.

    Args:
        path: pickle file holding a dict that maps training step -> per-split results, where
            ``results[split][0]`` is the accuracy and ``results[split][1].confidence_interval``
            carries ``low``/``high`` bounds.
        step: training step of the checkpoint to read.
        splits: split ids to collect besides the holdout split 0.

    Returns:
        ``(baseline, baseline_ci, accs, cis)``: the split-0 accuracy and its (low, high)
        interval, plus a list of accuracies and a list of (low, high) intervals for ``splits``.
    """
    # NOTE: unpickling executes code from the file; these are locally produced cache files
    with open(path, 'rb') as f:
        step_results = pkl.load(f)[step]

    def ci(entry):
        return (entry[1].confidence_interval.low, entry[1].confidence_interval.high)

    baseline = step_results[0][0]
    accs = [step_results[i][0] for i in splits]
    cis = [ci(step_results[i]) for i in splits]
    return baseline, ci(step_results[0]), accs, cis


# load the pre-computed results: all models trained on 7B tokens

# 124M
baseline_124, baseline_124_yerr, acc_124, yerr_124 = load_split_results('results/cache/124M_7B.pkl', 13351)

# 350M (presumably its 1x-Chinchilla run, i.e. ~20 x 350M = 7B tokens — confirm)
baseline_350, baseline_350_yerr, acc_350, yerr_350 = load_split_results('results/cache/350M_1x.pkl', 13351)

# 774M
baseline_774, baseline_774_yerr, acc_774, yerr_774 = load_split_results('results/cache/774M_7B.pkl', 13351)

# 1558M
baseline_1558, baseline_1558_yerr, acc_1558, yerr_1558 = load_split_results('results/cache/1558M_7B.pkl', 6675)

In [None]:
# accuracy gap (contaminated split minus holdout) for each of the four contaminated splits
gap124 = [acc_124[i] - baseline_124 for i in range(4)]
gap350 = [acc_350[i] - baseline_350 for i in range(4)]
gap774 = [acc_774[i] - baseline_774 for i in range(4)]
gap1558 = [acc_1558[i] - baseline_1558 for i in range(4)]

# the figure should have less height
plt.figure(figsize=(6, 4.5))

# horizontal line at 0 (holdout), with the holdout confidence interval shifted to 0 as a band
# NOTE: calls are kept in this order so that each series gets the same colour as in the legend figure
plt.plot([1, 2, 3, 4], [0, 0, 0, 0], label='0', marker='o', linestyle='--', linewidth=2, markersize=15)
plt.fill_between([1, 2, 3, 4], [baseline_124_yerr[0]-baseline_124,
                    baseline_350_yerr[0]-baseline_350,
                    baseline_774_yerr[0]-baseline_774,
                    baseline_1558_yerr[0]-baseline_1558],
                    [baseline_124_yerr[1]-baseline_124,
                    baseline_350_yerr[1]-baseline_350,
                    baseline_774_yerr[1]-baseline_774,
                    baseline_1558_yerr[1]-baseline_1558],
                    alpha=0.2)


# one line per contaminated split (labelled by its repetition count), each with its own
# confidence interval expressed relative to the holdout accuracy
for i in range(4):
    plt.plot([1, 2, 3, 4], [gap124[i], gap350[i], gap774[i], gap1558[i]], label=["4", "12", "32", "144"][i], marker='o', linestyle='--', linewidth=2, markersize=15)
    plt.fill_between([1, 2, 3, 4], [yerr_124[i][0]-baseline_124,
                    yerr_350[i][0]-baseline_350,
                    yerr_774[i][0]-baseline_774,
                    yerr_1558[i][0]-baseline_1558],
                    [yerr_124[i][1]-baseline_124,
                    yerr_350[i][1]-baseline_350,
                    yerr_774[i][1]-baseline_774,
                    yerr_1558[i][1]-baseline_1558],
                    alpha=0.2)

plt.ylim([-0.05, 0.6])
plt.ylabel('Accuracy Gap')
plt.xlabel('Parameters')

# set the x ticks (evenly spaced categories, one per model size)
plt.xticks([1, 2, 3, 4], ['124M', '350M', '774M', '1.6B'])

# percentage y axis
plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

# save in figure
plt.savefig('figures/7B_acc_gap.pdf', bbox_inches='tight')

# Scaling the data

In [552]:
def load_split_results(path, step, splits=range(5, 9)):
    """Load the cached evaluation of one checkpoint.

    Args:
        path: pickle file holding a dict that maps training step -> per-split results, where
            ``results[split][0]`` is the accuracy and ``results[split][1].confidence_interval``
            carries ``low``/``high`` bounds.
        step: training step of the checkpoint to read.
        splits: split ids to collect besides the holdout split 0.

    Returns:
        ``(baseline, baseline_ci, accs, cis)``: the split-0 accuracy and its (low, high)
        interval, plus a list of accuracies and a list of (low, high) intervals for ``splits``.
    """
    # NOTE: unpickling executes code from the file; these are locally produced cache files
    with open(path, 'rb') as f:
        step_results = pkl.load(f)[step]

    def ci(entry):
        return (entry[1].confidence_interval.low, entry[1].confidence_interval.high)

    baseline = step_results[0][0]
    accs = [step_results[i][0] for i in splits]
    cis = [ci(step_results[i]) for i in splits]
    return baseline, ci(step_results[0]), accs, cis


# 124M model trained for 2x / 4x / 8x / 15x the Chinchilla token count, each at its final step

# results from 2x
baseline_2x, baseline_2x_yerr, acc_2x, yerr_2x = load_split_results('results/cache/124M_2x.pkl', 9536)

# results from 4x
baseline_4x, baseline_4x_yerr, acc_4x, yerr_4x = load_split_results('results/cache/124M_4x.pkl', 19073)

# results from 8x
baseline_8x, baseline_8x_yerr, acc_8x, yerr_8x = load_split_results('results/cache/124M_8x.pkl', 38146)

# results from 15x
baseline_15x, baseline_15x_yerr, acc_15x, yerr_15x = load_split_results('results/cache/124M_15x.pkl', 71525)


In [None]:
# accuracy gap (contaminated split minus holdout) for each of the four contaminated splits
gap2x = [acc_2x[i] - baseline_2x for i in range(4)]
gap4x = [acc_4x[i] - baseline_4x for i in range(4)]
gap8x = [acc_8x[i] - baseline_8x for i in range(4)]
gap15x = [acc_15x[i] - baseline_15x for i in range(4)]

# the figure should have less height
plt.figure(figsize=(6, 4.5))

# horizontal line at 0 (holdout); call order determines the colour of each series
plt.plot([2, 4, 8, 15], [0, 0, 0, 0], label='0', marker='o', linestyle='--', linewidth=2, markersize=15)

# holdout confidence interval, shifted so that it is centred on 0
plt.fill_between([2, 4, 8, 15], 
                 [baseline_2x_yerr[0]-baseline_2x, 
                  baseline_4x_yerr[0]-baseline_4x,
                  baseline_8x_yerr[0]-baseline_8x, 
                  baseline_15x_yerr[0]-baseline_15x], 
                 [baseline_2x_yerr[1]-baseline_2x,
                  baseline_4x_yerr[1]-baseline_4x, 
                    baseline_8x_yerr[1]-baseline_8x,
                  baseline_15x_yerr[1]-baseline_15x], 
                  alpha=0.2)
# one line per contaminated split with its confidence interval relative to the holdout accuracy
for i in range(4):
    plt.plot([2, 4, 8, 15], [gap2x[i], gap4x[i], gap8x[i], gap15x[i]], label=["4", "12", "32", "144"][i], marker='o', linestyle='--', linewidth=2, markersize=15)
    plt.fill_between([2, 4, 8, 15], 
                 [yerr_2x[i][0]-baseline_2x, 
                  yerr_4x[i][0]-baseline_4x,
                  yerr_8x[i][0]-baseline_8x, 
                  yerr_15x[i][0]-baseline_15x], 
                 [yerr_2x[i][1]-baseline_2x,
                  yerr_4x[i][1]-baseline_4x, 
                    yerr_8x[i][1]-baseline_8x,
                  yerr_15x[i][1]-baseline_15x], 
                  alpha=0.2)

plt.ylim([-0.05, 0.6])
plt.ylabel('Accuracy Gap')
plt.xlabel('Chinchilla Tokens')

# set the x ticks (placed at the actual Chinchilla multiples, not evenly spaced)
plt.xticks([2, 4,  8, 15], ['2', '4', '8', '15'])

# y axis labels should be in percentage
plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

# legend at the bottom (the legend is produced as a separate figure instead)
# plt.legend(title='#Times Contaminated', bbox_to_anchor=(0.5, -0.14), loc='upper center', ncol=5)

# save in figure
plt.savefig('figures/nx_chinchilla_gap.pdf', bbox_inches='tight')

## Chinchilla scaling

In [554]:
def load_split_results(path, step, splits=range(5, 9)):
    """Load the cached evaluation of one checkpoint.

    Args:
        path: pickle file holding a dict that maps training step -> per-split results, where
            ``results[split][0]`` is the accuracy and ``results[split][1].confidence_interval``
            carries ``low``/``high`` bounds.
        step: training step of the checkpoint to read.
        splits: split ids to collect besides the holdout split 0.

    Returns:
        ``(baseline, baseline_ci, accs, cis)``: the split-0 accuracy and its (low, high)
        interval, plus a list of accuracies and a list of (low, high) intervals for ``splits``.
    """
    # NOTE: unpickling executes code from the file; these are locally produced cache files
    with open(path, 'rb') as f:
        step_results = pkl.load(f)[step]

    def ci(entry):
        return (entry[1].confidence_interval.low, entry[1].confidence_interval.high)

    baseline = step_results[0][0]
    accs = [step_results[i][0] for i in splits]
    cis = [ci(step_results[i]) for i in splits]
    return baseline, ci(step_results[0]), accs, cis


# models trained for 1x the Chinchilla token count, each at its final step

# 124M
baseline_124, baseline_124_yerr, acc_124, yerr_124 = load_split_results('results/cache/124M_1x.pkl', 4730)

# 350M
baseline_350, baseline_350_yerr, acc_350, yerr_350 = load_split_results('results/cache/350M_1x.pkl', 13351)

# 774M
baseline_774, baseline_774_yerr, acc_774, yerr_774 = load_split_results('results/cache/774M_1x.pkl', 29563)

In [None]:
# print one LaTeX table row per model: holdout accuracy, then the accuracy on each
# contaminated split, all in percent with two decimals
for table_baseline, table_accs in [(baseline_124, acc_124), (baseline_350, acc_350), (baseline_774, acc_774)]:
    print(f"{table_baseline*100:.2f} & " + " & ".join(f"{value*100:.2f}" for value in table_accs) + " \\\\")

In [None]:
# accuracy gap (contaminated split minus holdout) for each of the four contaminated splits
gap124 = [acc_124[i] - baseline_124 for i in range(4)]
gap350 = [acc_350[i] - baseline_350 for i in range(4)]
gap774 = [acc_774[i] - baseline_774 for i in range(4)]

# the figure should have less height
plt.figure(figsize=(6, 4.5))

# horizontal line at 0 (holdout) with the holdout confidence interval shifted to 0;
# call order determines the colour of each series
plt.plot([1, 2, 3], [0, 0, 0], label='0', marker='o', linestyle='--', linewidth=2, markersize=15)
plt.fill_between([1, 2, 3], [baseline_124_yerr[0]-baseline_124,
                    baseline_350_yerr[0]-baseline_350,
                    baseline_774_yerr[0]-baseline_774],
                    [baseline_124_yerr[1]-baseline_124,
                    baseline_350_yerr[1]-baseline_350,
                    baseline_774_yerr[1]-baseline_774],
                    alpha=0.2)

# one line per contaminated split with its confidence interval relative to the holdout accuracy
for i in range(4):
    plt.plot([1, 2, 3], [gap124[i], gap350[i], gap774[i]], label=["4", "12", "32", "144"][i], marker='o', linestyle='--', linewidth=2, markersize=15)
    plt.fill_between([1, 2, 3], [yerr_124[i][0]-baseline_124,
                    yerr_350[i][0]-baseline_350,
                    yerr_774[i][0]-baseline_774],
                    [yerr_124[i][1]-baseline_124,
                    yerr_350[i][1]-baseline_350,
                    yerr_774[i][1]-baseline_774],
                    alpha=0.2)

plt.ylim([-0.05, 0.6])

plt.ylabel('Accuracy Gap')
plt.xlabel('Parameters')

# set the x ticks (evenly spaced categories, one per model size)
plt.xticks([1, 2, 3], ['124M', '350M', '774M'])
# plt.xlim([0.9, 4.1])

# percentage y axis
plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

# save in figure
plt.savefig('figures/chinchilla_scaling.pdf', bbox_inches='tight')

In [None]:
# A legend-only figure shared by the scaling plots: placeholder lines without data (and with
# a near-zero line width) carry the labels and marker styles, added in the same order as the
# plotted series so the default colours line up.
plt.figure(figsize=(6, 1))

labels = ["4 times repeated", "12 times repeated", "32 times repeated", "144 times repeated"]
for legend_label in ['Holdout', *labels]:
    plt.plot([], [], label=legend_label, marker='o', linestyle='--', linewidth=0.001, markersize=10)

# draw only the legend, centred above the hidden axes
plt.legend(title='', bbox_to_anchor=(0.5, 1.1), loc='upper center', ncol=5, frameon=True)
plt.axis('off')

plt.savefig('figures/chinchilla_legend.pdf', bbox_inches='tight')

## Forgetting ablation on 100M tokens

In [None]:
# cached per-step results of the 124M 15x run used for the 100M-token forgetting ablation
with open('results/cache/124M_15x_100M_tokens.pkl', 'rb') as f:
    results = pkl.load(f)

# training steps available in the cache
results.keys()

In [None]:
plot_step = [0, 4769, 9538, 14307, 19076, 23845, 28614, 33383, 38152]

# development over training of the holdout split (key 0+100) and the four contaminated splits
# (keys 1+100 .. 4+100). NOTE(review): the +100 entries presumably hold the cross-entropy loss
# (cf. the y-axis label below) — confirm against the code that writes the cache.
accs_clean = [results[step][0+100][0] for step in plot_step]

accs_4x = [results[step][1+100][0] for step in plot_step]
low_4x = [[results[step][1+100][1].confidence_interval.low for step in plot_step]]
high_4x = [[results[step][1+100][1].confidence_interval.high for step in plot_step]]

accs_12x = [results[step][2+100][0] for step in plot_step]
low_12x = [[results[step][2+100][1].confidence_interval.low for step in plot_step]]
high_12x = [[results[step][2+100][1].confidence_interval.high for step in plot_step]]

accs_32x = [results[step][3+100][0] for step in plot_step]
low_32x = [[results[step][3+100][1].confidence_interval.low for step in plot_step]]
high_32x = [[results[step][3+100][1].confidence_interval.high for step in plot_step]]

acs_144x = [results[step][4+100][0] for step in plot_step]
low_144x = [[results[step][4+100][1].confidence_interval.low for step in plot_step]]
high_144x = [[results[step][4+100][1].confidence_interval.high for step in plot_step]]

# gap of each contaminated split to the holdout split at every plotted step
gap_4x = [acc - accs_clean[idx] for idx, acc in enumerate(accs_4x)]
gap_12x = [acc - accs_clean[idx] for idx, acc in enumerate(accs_12x)]
gap_32x = [acc - accs_clean[idx] for idx, acc in enumerate(accs_32x)]
gap_144 = [acc - accs_clean[idx] for idx, acc in enumerate(acs_144x)]

# the figure should have less height
plt.figure(figsize=(7, 4))

# horizontal line at 0
plt.plot(plot_step, [0]*len(plot_step), label='0', marker='o', linestyle='--', linewidth=2, markersize=10)

plt.plot(plot_step, gap_4x, label='4x', marker='o', linestyle='--', linewidth=2, markersize=10)
plt.plot(plot_step, gap_12x, label='12x', marker='o', linestyle='--', linewidth=2, markersize=10)
plt.plot(plot_step, gap_32x, label='32x', marker='o', linestyle='--', linewidth=2, markersize=10)
plt.plot(plot_step, gap_144, label='144x', marker='o', linestyle='--', linewidth=2, markersize=10)

# invert the y axis so that a negative gap (lower loss on contaminated data) points upwards
plt.gca().invert_yaxis()
# y title
plt.ylabel('Cross-Entropy Loss Gap')

# one tick per 4769 steps, labelled as multiples of the Chinchilla token count
plt.xticks(plot_step, ['0', '1', '2', '3', '4', '5', '6', '7', '8'])

# x axis title
plt.xlabel('Chinchilla Tokens')

# save in figure
plt.savefig('figures/124M_100M_ablation.pdf', bbox_inches='tight')


# 124M model for 15x chinchilla

In [None]:
# larger fonts for the following figures: base font 25, tick labels 18
plt.rcParams.update({
    'font.size': 25,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
})

In [490]:
# load the results
# cached per-step evaluation results of the 124M model trained for 15x Chinchilla
with open('results/cache/124M_15x.pkl', 'rb') as f:
    results = pkl.load(f)

In [None]:
# checkpoints to plot; 4769 steps correspond to 1x Chinchilla (see the tick labels below)
plot_step = [0,
 4769,
 9538,
 14307,
 23845,
 33383,
 42921,
 52459,
 61997,
 71525]

# plot the development of the accuracy over time
# key 0 is the holdout split, keys 1..4 the splits plotted as 4x/12x/32x/144x below
accs_clean = [results[step][0][0] for step in plot_step]

# low_*/high_* are wrapped in an extra list; consumers index them as low_4x[0][i]
accs_4x = [results[step][1][0] for step in plot_step]
low_4x = [[results[step][1][1].confidence_interval.low for step in plot_step]]
high_4x = [[results[step][1][1].confidence_interval.high for step in plot_step]]

accs_12x = [results[step][2][0] for step in plot_step]
low_12x = [[results[step][2][1].confidence_interval.low for step in plot_step]]
high_12x = [[results[step][2][1].confidence_interval.high for step in plot_step]]

accs_32x = [results[step][3][0] for step in plot_step]
low_32x = [[results[step][3][1].confidence_interval.low for step in plot_step]]
high_32x = [[results[step][3][1].confidence_interval.high for step in plot_step]]

acs_144x = [results[step][4][0] for step in plot_step]
low_144x = [[results[step][4][1].confidence_interval.low for step in plot_step]]
high_144x = [[results[step][4][1].confidence_interval.high for step in plot_step]]

# accuracy gap of each contaminated split to the holdout split at every plotted step
gap_4x = [acc - accs_clean[idx] for idx, acc in enumerate(accs_4x)]
gap_12x = [acc - accs_clean[idx] for idx, acc in enumerate(accs_12x)]
gap_32x = [acc - accs_clean[idx] for idx, acc in enumerate(accs_32x)]
gap_144 = [acc - accs_clean[idx] for idx, acc in enumerate(acs_144x)]

# the figure should have less height
plt.figure(figsize=(7, 4))

# horizontal line at 0
plt.plot(plot_step, [0]*len(plot_step), label='0', marker='o', linestyle='--', linewidth=2, markersize=10)

plt.plot(plot_step, gap_4x, label='4x', marker='o', linestyle='--', linewidth=2, markersize=10)
plt.plot(plot_step, gap_12x, label='12x', marker='o', linestyle='--', linewidth=2, markersize=10)
plt.plot(plot_step, gap_32x, label='32x', marker='o', linestyle='--', linewidth=2, markersize=10)
plt.plot(plot_step, gap_144, label='144x', marker='o', linestyle='--', linewidth=2, markersize=10)

# x axis labels in chinchilla epochs
plt.xticks(plot_step, ['0x', '1x', '2x', '3x', '5x', '7x', '9x', '11x', '13x', '15x'])

# save in figure
#plt.savefig('figures/124M_15x_gap.pdf', bbox_inches='tight')

In [None]:
# index into plot_step: -1 selects the final checkpoint
idx = -1

# bar plot of the accuracy gap (contaminated minus holdout) at the selected checkpoint, with the
# confidence interval of each contaminated split as error bars
plt.figure(figsize=(6, 5))

# (low - acc, high - acc) per split; abs + transpose yields the (2, n) array of distances from
# the bar value that plt.bar expects for asymmetric error bars
yerr = [[low_4x[0][idx]-accs_4x[idx], high_4x[0][idx]-accs_4x[idx]], 
        [low_12x[0][idx]-accs_12x[idx], high_12x[0][idx]-accs_12x[idx]],
          [low_32x[0][idx]-accs_32x[idx], high_32x[0][idx]-accs_32x[idx]], 
          [low_144x[0][idx]-acs_144x[idx], high_144x[0][idx]-acs_144x[idx]]]
yerr = np.abs(np.array(yerr).T)
plt.bar([1, 2, 3, 4], [accs_4x[idx]-accs_clean[idx], accs_12x[idx]-accs_clean[idx], accs_32x[idx]-accs_clean[idx], acs_144x[idx]-accs_clean[idx]], yerr=yerr, capsize=5, color=['C1', 'C2', 'C3', 'C4'])
plt.ylim(-0.05, 0.05)
plt.ylabel('Accuracy Gap')
# xticks 4x, ...
plt.xticks([1, 2, 3, 4], ['4x', '12x', '32x', '144x'])
# horizontal line at zero gap, i.e. at the holdout accuracy
plt.axhline(y=0, linestyle='-', color='black', linewidth=1)
plt.xlabel('Number of Repetitions')

# y axis labels should be in percentage
plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

# save the figure
plt.savefig('figures/124M_15x_final_acc.pdf', bbox_inches='tight')
# training progress of the selected checkpoint in Chinchilla multiples; plot_step is laid out
# in units of 4769 steps (= 1x Chinchilla)
print(plot_step[idx] / 4769)

# 124M 15x weight decay ablation

In [784]:
# font sizes for the weight-decay figures: base font 22, tick labels 16
plt.rcParams.update({
    'font.size': 22,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
})

In [698]:
def cosine_lr_schedule(step, total_steps, warmup_steps=700, max_lr=6e-4, min_lr=6e-5):
    """Return the learning rate at ``step`` for linear warmup followed by cosine decay.

    The rate rises linearly from 0 to ``max_lr`` over the first ``warmup_steps`` steps and then
    follows half a cosine from ``max_lr`` down to ``min_lr`` at ``total_steps``. Steps at or past
    ``total_steps`` return ``min_lr``; without this clamp the cosine would swing back up towards
    ``max_lr`` (and ``total_steps == warmup_steps`` would divide by zero).
    """
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    if step >= total_steps:
        return min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + np.cos(progress * np.pi))


def get_adamw_decay_factors(step_start, step_end, total_steps, weight_decay, warmup_steps=700, max_lr=6e-4, min_lr=6e-5, get_lr = cosine_lr_schedule):
    """Return the cumulative factor by which AdamW's decoupled weight decay shrinks the weights.

    Each step is modelled as multiplying the parameters by ``1 - lr * weight_decay``. Element
    ``i`` of the result is the product of these factors over steps ``step_start`` ..
    ``step_start + i``, i.e. the fraction of a weight present before ``step_start`` that
    survives weight decay alone up to and including step ``step_start + i``.

    Args:
        step_start: first step (inclusive).
        step_end: last step (exclusive).
        total_steps, warmup_steps, max_lr, min_lr: forwarded to ``get_lr``.
        weight_decay: AdamW weight-decay coefficient.
        get_lr: schedule called as ``get_lr(step, total_steps, warmup_steps, max_lr, min_lr)``.

    Returns:
        numpy array with ``max(step_end - step_start, 0)`` entries.
    """
    per_step_decay_factors = [
        1 - get_lr(step, total_steps, warmup_steps, max_lr, min_lr) * weight_decay
        for step in range(step_start, step_end)
    ]
    return np.cumprod(per_step_decay_factors)

In [699]:
# load the results
# cached per-step results of the 124M 15x run, presumably trained with weight decay 0.01 (per the file name)
with open('results/cache/124M_15x_wd0.01.pkl', 'rb') as f:
    results = pkl.load(f)

In [None]:
# training steps available in the cache
list(results)

In [None]:
plot_step = [0, 4769, 9538, 14307, 19076, 23845, 28614, 33383, 38152, 42921]

# development over training of the holdout split (key 0+100) and the four contaminated splits
# (keys 1+100 .. 4+100). NOTE(review): the +100 entries presumably hold the cross-entropy loss
# (cf. the y-axis label below) — confirm against the code that writes the cache.
accs_clean = [results[step][0+100][0] for step in plot_step]

accs_4x = [results[step][1+100][0] for step in plot_step]
low_4x = [[results[step][1+100][1].confidence_interval.low for step in plot_step]]
high_4x = [[results[step][1+100][1].confidence_interval.high for step in plot_step]]

accs_12x = [results[step][2+100][0] for step in plot_step]
low_12x = [[results[step][2+100][1].confidence_interval.low for step in plot_step]]
high_12x = [[results[step][2+100][1].confidence_interval.high for step in plot_step]]

accs_32x = [results[step][3+100][0] for step in plot_step]
low_32x = [[results[step][3+100][1].confidence_interval.low for step in plot_step]]
high_32x = [[results[step][3+100][1].confidence_interval.high for step in plot_step]]

acs_144x = [results[step][4+100][0] for step in plot_step]
low_144x = [[results[step][4+100][1].confidence_interval.low for step in plot_step]]
high_144x = [[results[step][4+100][1].confidence_interval.high for step in plot_step]]

# gap of each contaminated split to the holdout split at every plotted step
gap_4x = [acc - accs_clean[idx] for idx, acc in enumerate(accs_4x)]
gap_12x = [acc - accs_clean[idx] for idx, acc in enumerate(accs_12x)]
gap_32x = [acc - accs_clean[idx] for idx, acc in enumerate(accs_32x)]
gap_144 = [acc - accs_clean[idx] for idx, acc in enumerate(acs_144x)]

# a single figure for the loss gaps and the decay factor
plt.figure(figsize=(7, 4))

# horizontal line at 0
plt.plot(plot_step, [0]*len(plot_step), label='0', marker='o', linestyle='--', linewidth=2, markersize=10)

plt.plot(plot_step, gap_4x, label='4x', marker='o', linestyle='--', linewidth=2, markersize=10)
plt.plot(plot_step, gap_12x, label='12x', marker='o', linestyle='--', linewidth=2, markersize=10)
plt.plot(plot_step, gap_32x, label='32x', marker='o', linestyle='--', linewidth=2, markersize=10)
plt.plot(plot_step, gap_144, label='144x', marker='o', linestyle='--', linewidth=2, markersize=10)

# invert the y axis so that a negative gap (lower loss on contaminated data) points upwards
plt.gca().invert_yaxis()
# y title
plt.ylabel('Cross-Entropy Loss Gap')

# a second axis for the cumulative weight-decay factor (weight decay 0.01) from step 9538 on
plt.twinx()
plt.ylim(0,1.1)

decay_factor = get_adamw_decay_factors(9538, 42921, 71525, 0.01)
plt.plot(np.arange(9538, 42921), decay_factor)

# x axis labels in chinchilla epochs
# plt.xticks(plot_step, ['0x', '1x', '2x', '3x', '5x', '7x', '9x', '11x', '13x', '15x'])

# save the figure
# plt.savefig('figures/ce_loss_gap.pdf', bbox_inches='tight')

In [702]:
# keep the 32x loss-gap curve of the most recently run cell (the wd0.01 run, if run in order)
gap_32_001 = gap_32x

In [703]:
# load the results
# cached per-step results of the default 124M 15x run
with open('results/cache/124M_15x.pkl', 'rb') as f:
    results = pkl.load(f)

In [None]:
plot_step = [0, 4769, 9538, 14307, 19076, 23845, 28614, 33383, 38152, 42921]

# development over training of the holdout split (key 0+100) and the four contaminated splits
# (keys 1+100 .. 4+100). NOTE(review): the +100 entries presumably hold the cross-entropy loss
# (cf. the y-axis label below) — confirm against the code that writes the cache.
accs_clean = [results[step][0+100][0] for step in plot_step]

accs_4x = [results[step][1+100][0] for step in plot_step]
low_4x = [[results[step][1+100][1].confidence_interval.low for step in plot_step]]
high_4x = [[results[step][1+100][1].confidence_interval.high for step in plot_step]]

accs_12x = [results[step][2+100][0] for step in plot_step]
low_12x = [[results[step][2+100][1].confidence_interval.low for step in plot_step]]
high_12x = [[results[step][2+100][1].confidence_interval.high for step in plot_step]]

accs_32x = [results[step][3+100][0] for step in plot_step]
low_32x = [[results[step][3+100][1].confidence_interval.low for step in plot_step]]
high_32x = [[results[step][3+100][1].confidence_interval.high for step in plot_step]]

acs_144x = [results[step][4+100][0] for step in plot_step]
low_144x = [[results[step][4+100][1].confidence_interval.low for step in plot_step]]
high_144x = [[results[step][4+100][1].confidence_interval.high for step in plot_step]]

# gap of each contaminated split to the holdout split at every plotted step
gap_4x = [acc - accs_clean[idx] for idx, acc in enumerate(accs_4x)]
gap_12x = [acc - accs_clean[idx] for idx, acc in enumerate(accs_12x)]
gap_32x = [acc - accs_clean[idx] for idx, acc in enumerate(accs_32x)]
gap_144 = [acc - accs_clean[idx] for idx, acc in enumerate(acs_144x)]

# a single figure for the loss gaps and the decay factor
plt.figure(figsize=(7, 4))

# horizontal line at 0
plt.plot(plot_step, [0]*len(plot_step), label='0', marker='o', linestyle='--', linewidth=2, markersize=10)

plt.plot(plot_step, gap_4x, label='4x', marker='o', linestyle='--', linewidth=2, markersize=10)
plt.plot(plot_step, gap_12x, label='12x', marker='o', linestyle='--', linewidth=2, markersize=10)
plt.plot(plot_step, gap_32x, label='32x', marker='o', linestyle='--', linewidth=2, markersize=10)
plt.plot(plot_step, gap_144, label='144x', marker='o', linestyle='--', linewidth=2, markersize=10)

# invert the y axis so that a negative gap (lower loss on contaminated data) points upwards
plt.gca().invert_yaxis()
# y title
plt.ylabel('Cross-Entropy Loss Gap')

# a second axis for the cumulative weight-decay factor (weight decay 0.1) from step 9538 on
plt.twinx()
plt.ylim(0,1.1)

decay_factor = get_adamw_decay_factors(9538, 42921, 71525, 0.1)
plt.plot(np.arange(9538, 42921), decay_factor)


# x axis labels in chinchilla epochs
# plt.xticks(plot_step, ['0x', '1x', '2x', '3x', '5x', '7x', '9x', '11x', '13x', '15x'])

# save the figure
# plt.savefig('figures/ce_loss_gap.pdf', bbox_inches='tight')

In [None]:
# plot the 144x loss gap (gap_144 from the most recently run cell, i.e. the default 124M 15x run
# if run in order) together with the cumulative weight-decay factor for weight decay 0.1
plt.figure(figsize=(5, 4.5))

# inverted axis: a negative gap (lower loss on contaminated data) points upwards
plt.gca().invert_yaxis()
# the label matches the plotted series, gap_144
plt.plot(plot_step, gap_144, label='144x', marker='o', linestyle='--', linewidth=2, markersize=12, color='C9')
plt.ylabel('CE Loss Gap')
plt.xlabel('Gradient Step')


# a second axis for the decay factor
plt.twinx()
plt.ylim(-0.05,1.05)
plt.ylabel('Cum. Weight Decay')

# cumulative decay from step 9538 to 42921 of a 71525-step cosine schedule
decay_factor = get_adamw_decay_factors(9538, 42921, 71525, 0.1)
plt.plot(np.arange(9538, 42921), decay_factor,  color='C9', linewidth=2.5)



plt.savefig('figures/124M_15x_decay.pdf', bbox_inches='tight')

In [708]:
# keep the 32x loss-gap curve of the most recently run cell (the default 124M 15x run, if run in order)
gap_32_01 = gap_32x

In [709]:
# load the results
# cached per-step results of the 124M 15x run, presumably trained with weight decay 1 (per the file name)
with open('results/cache/124M_15x_wd1.pkl', 'rb') as f:
    results = pkl.load(f)

In [None]:
# training steps available in the cache
list(results)

In [None]:
plot_step = [0, 4769, 9538, 14307, 19076, 23845, 28614, 33383, 38152, 42921]

# development over training of the holdout split (key 0+100) and the four contaminated splits
# (keys 1+100 .. 4+100). NOTE(review): the +100 entries presumably hold the cross-entropy loss
# (cf. the y-axis label below) — confirm against the code that writes the cache.
accs_clean = [results[step][0+100][0] for step in plot_step]

accs_4x = [results[step][1+100][0] for step in plot_step]
low_4x = [[results[step][1+100][1].confidence_interval.low for step in plot_step]]
high_4x = [[results[step][1+100][1].confidence_interval.high for step in plot_step]]

accs_12x = [results[step][2+100][0] for step in plot_step]
low_12x = [[results[step][2+100][1].confidence_interval.low for step in plot_step]]
high_12x = [[results[step][2+100][1].confidence_interval.high for step in plot_step]]

accs_32x = [results[step][3+100][0] for step in plot_step]
low_32x = [[results[step][3+100][1].confidence_interval.low for step in plot_step]]
high_32x = [[results[step][3+100][1].confidence_interval.high for step in plot_step]]

acs_144x = [results[step][4+100][0] for step in plot_step]
low_144x = [[results[step][4+100][1].confidence_interval.low for step in plot_step]]
high_144x = [[results[step][4+100][1].confidence_interval.high for step in plot_step]]

# gap of each contaminated split to the holdout split at every plotted step
gap_4x = [acc - accs_clean[idx] for idx, acc in enumerate(accs_4x)]
gap_12x = [acc - accs_clean[idx] for idx, acc in enumerate(accs_12x)]
gap_32x = [acc - accs_clean[idx] for idx, acc in enumerate(accs_32x)]
gap_144 = [acc - accs_clean[idx] for idx, acc in enumerate(acs_144x)]

# a single figure for the loss gaps and the decay factor
plt.figure(figsize=(7, 4))

# horizontal line at 0
plt.plot(plot_step, [0]*len(plot_step), label='0', marker='o', linestyle='--', linewidth=2, markersize=10)

plt.plot(plot_step, gap_4x, label='4x', marker='o', linestyle='--', linewidth=2, markersize=10)
plt.plot(plot_step, gap_12x, label='12x', marker='o', linestyle='--', linewidth=2, markersize=10)
plt.plot(plot_step, gap_32x, label='32x', marker='o', linestyle='--', linewidth=2, markersize=10)
plt.plot(plot_step, gap_144, label='144x', marker='o', linestyle='--', linewidth=2, markersize=10)

# invert the y axis so that a negative gap (lower loss on contaminated data) points upwards
plt.gca().invert_yaxis()
# y title
plt.ylabel('Cross-Entropy Loss Gap')

# a second axis for the cumulative weight-decay factor (weight decay 1) from step 9538 on
plt.twinx()
plt.ylim(-0.1,1.1)

decay_factor = get_adamw_decay_factors(9538, 42921, 71525, 1)
plt.plot(np.arange(9538, 42921), decay_factor)


# x axis labels in chinchilla epochs
# plt.xticks(plot_step, ['0x', '1x', '2x', '3x', '5x', '7x', '9x', '11x', '13x', '15x'])

# save the figure
# plt.savefig('figures/ce_loss_gap.pdf', bbox_inches='tight')

In [712]:
# Keep the 32x gap curve just computed from 124M_15x_wd1.pkl; it is the curve
# labelled '1.0' in the 32x weight-decay comparison figure.
gap_32_1 = gap_32x

In [None]:
# compare the 32x loss-gap curves of the runs with weight decay 0.01, 0.1 and 1.0
plt.figure(figsize=(5, 4))

plt.gca().invert_yaxis()
# each curve carries its own weight-decay label, so the legend cannot drift out
# of sync with the plotting order
# NOTE: 'C10' wraps around the default 10-colour cycle in recent matplotlib
# versions (same colour as 'C0'); kept to match the decay-factor figure
plt.plot(plot_step, gap_32_001, label='0.01', marker='o', linestyle='--', linewidth=2, markersize=10, color='C8')
plt.plot(plot_step, gap_32_01, label='0.1', marker='o', linestyle='--', linewidth=2, markersize=10, color='C9')
plt.plot(plot_step, gap_32_1, label='1.0', marker='o', linestyle='--', linewidth=2, markersize=10, color='C10')

plt.ylabel('CE Loss Gap')
plt.xlabel('Gradient Step')

plt.legend()

plt.savefig('figures/32x_decay.pdf', bbox_inches='tight')

In [None]:
plt.figure(figsize=(5, 4))

# cumulative decay curves for the three weight-decay values compared in the
# 32x figure, using the same colours (1 -> C10, 0.1 -> C9, 0.01 -> C8)
for wd, line_color in ((1, 'C10'), (0.1, 'C9'), (0.01, 'C8')):
    decay_factor = get_adamw_decay_factors(9538, 42921, 71525, wd)
    plt.plot(np.arange(9538, 42921), decay_factor, color=line_color, linewidth=2.5)

plt.ylabel('Cum. Weight Decay')
plt.xlabel('Gradient Step')

plt.savefig('figures/32x_decay_factors.pdf', bbox_inches='tight')

# OLMo Figures

In [800]:
# larger base font for the OLMo figures; axis tick labels stay somewhat smaller
plt.rcParams.update({
    'font.size': 25,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
})

In [722]:
# number of evaluation examples per benchmark
# (passed to bootstrap_confidence_interval when plotting)
hellaswag_samples = 10042
piqa_samples = 1838
winogrande_samples = 1267
arc_easy_samples = 570

# results from the olmo experiment: gradient step -> accuracy.
# Step 369000 is the reference point that the plots draw as the clean /
# "No Contamination" line.
winogrande = {369000: 0.58248,
              369010: 0.60378,
              369020: 0.62352,
              369030: 0.64325,
              369040: 0.66140,
              369250: 0.63062,
              369500: 0.620,
              369750: 0.602,
              370000: 0.605,
              370250: 0.580,
              370500: 0.5722}

hellaswag = {369000: 0.60546,
             369010: 0.6600,
             369020: 0.7285,
             369030: 0.7966,
             369040: 0.8559,
             369250: 0.7459,
             369500: 0.6958,
             369750: 0.679,
             370000: 0.66082,
             370250: 0.6594,
             370500: 0.64907}

arc = {369000: 0.59474,
       369010: 0.643859,
       369020: 0.65614,
       369030: 0.6800,
       369040: 0.735087,
       369250: 0.6807,
       369500: 0.6465,
       369750: 0.649,
       370000: 0.6280,
       370250: 0.649,
       370500: 0.633}

piqa = {369000: 0.74429,
        369010: 0.768770,
        369020: 0.793253,
        369030: 0.8242654,
        369040: 0.862350,
        369250: 0.8123,
        369500: 0.795,
        369750: 0.7812,
        370000: 0.7676,
        370250: 0.7769,
        370500: 0.7665}

# load the remaining checkpoints from file
import pandas as pd


def _merge_step_accuracies(scores, df):
    """Add the (step, accuracy) rows of an OLMo results table to ``scores``.

    Steps are read from the ``Step`` column and accuracies from the table's
    second column. A step that is already a key overwrites that entry (keeping
    its position in the dict); new steps are appended in file order.
    """
    # .tolist() keeps plain Python ints/floats as dict keys and values
    steps = df.Step.values.tolist()
    accs = df.iloc[:, 1].values.tolist()
    for step, acc in zip(steps, accs):
        scores[step] = acc


winogrande_df = pd.read_csv("results/OLMo-1B/winogrande.csv")
_merge_step_accuracies(winogrande, winogrande_df)

hellaswag_df = pd.read_csv("results/OLMo-1B/hellaswag.csv")
_merge_step_accuracies(hellaswag, hellaswag_df)

arc_df = pd.read_csv("results/OLMo-1B/arc.csv")
_merge_step_accuracies(arc, arc_df)

piqa_df = pd.read_csv("results/OLMo-1B/piqa.csv")
_merge_step_accuracies(piqa, piqa_df)

In [None]:
# Accuracy change per benchmark relative to the reference step 369000:
# first at step 369040 (initial overfitting), then at step 370500.
# Each pass prints the four per-benchmark changes, then their mean.
for end_step in (369040, 370500):
    winogrande_increase = winogrande[end_step] - winogrande[369000]
    hellaswag_increase = hellaswag[end_step] - hellaswag[369000]
    arc_increase = arc[end_step] - arc[369000]
    piqa_increase = piqa[end_step] - piqa[369000]

    increases = [winogrande_increase, hellaswag_increase, arc_increase, piqa_increase]
    print(*increases)
    print(np.mean(increases))

In [None]:
# The 1500 steps from 369000 to 370500, as a percentage of the span 369000 -> 739328.
# NOTE(review): 739328 is presumably the final OLMo-1B training step -- confirm.
(370500 - 369000) / (739328 - 369000) * 100

In [499]:
# Theoretical cumulative weight-decay factors for the OLMo HellaSwag figure.
# They are matched to the HellaSwag steps by position (16 values); the plotting
# cell draws both from index 4 onward.
# NOTE(review): how these were computed is not shown in this chunk -- confirm.
theoretical_decay_factors = [
    0.99997792,
    0.99975714,
    0.99953641,
    0.99931573,
    0.99909511,
    0.9944744,
    0.98900419,
    0.98356707,
    0.97816283,
    0.97279125,
    0.96745212,
    0.96214522,
    0.95687036,
    0.95162733,
    0.94641591,
    0.9412359,
]

In [801]:
# HellaSwag series for plotting: steps, accuracies and bootstrap confidence intervals
x = list(hellaswag)
y = list(hellaswag.values())
yerr = [bootstrap_confidence_interval(hellaswag_samples, acc) for acc in y]

In [None]:
# HellaSwag accuracy over gradient steps (x, y, yerr come from the cell above)
plt.figure(figsize=(6.5, 6))

# dashed grey reference at the first (uncontaminated) accuracy; it is the only
# labelled artist, so the legend has exactly this entry
plt.axhline(y=y[0], linestyle='--', color='grey', linewidth=3, label='No Contamination')
plt.legend()

# bootstrap confidence interval as a light band, then the accuracy curve itself
plt.fill_between(x, [ci[0] for ci in yerr], [ci[1] for ci in yerr], alpha=0.1)
plt.plot(x, y, 'o--', markersize=10)

plt.xticks(np.arange(369000, 371750, 1000))

plt.xlabel('Gradient Step')
plt.ylabel('Accuracy')

# render the accuracy axis in percent (1.0 == 100%)
plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

plt.savefig('figures/olmo_hellaswag.pdf', bbox_inches='tight')

In [None]:
# HellaSwag accuracy with the theoretical cumulative weight decay on a secondary
# axis (uses x, y, yerr from the HellaSwag cell and theoretical_decay_factors)

# theoretical_decay_factors is matched to the HellaSwag steps by position and
# both are plotted from index 4 on; check the lengths before drawing so a changed
# CSV fails with a clear message instead of a shape error halfway through the figure
if len(theoretical_decay_factors) != len(x):
    raise ValueError(
        f"expected {len(x)} theoretical decay factors (one per HellaSwag step), "
        f"got {len(theoretical_decay_factors)}"
    )

plt.figure(figsize=(6.6, 6))

# bootstrap confidence interval as a shaded band, accuracy as a dashed line
plt.fill_between(x, [z[0] for z in yerr], [z[1] for z in yerr], alpha=0.1)
plt.plot(x, y, 'o--', markersize=10)

plt.xlabel('Gradient Step')
plt.ylabel('Accuracy')

plt.xticks(np.arange(369000, 371750, 1000))

# percentage points on the y axis
plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

# on a second axis, plot the theoretical decay factors
plt.twinx()
plt.plot(x[4:], theoretical_decay_factors[4:], '-', color='C0', linewidth=3)
plt.ylim(-0.08, 1.072)
plt.ylabel('Cum. Weight Decay')
# axis ticks from 0 to 1
plt.yticks(np.arange(0, 1.1, 0.2))
# keep the secondary axis label and its tick labels black
plt.gca().yaxis.label.set_color('black')
plt.gca().yaxis.set_tick_params(labelcolor='black')

plt.savefig('figures/olmo_hellaswag_wd.pdf', bbox_inches='tight')

In [None]:
# Winogrande accuracy over gradient steps with bootstrap confidence intervals
x = list(winogrande)
y = list(winogrande.values())
yerr = [bootstrap_confidence_interval(winogrande_samples, acc) for acc in y]

plt.figure(figsize=(6.5, 6))

# confidence band first, then the accuracy curve on top (both in C1)
plt.fill_between(x, [ci[0] for ci in yerr], [ci[1] for ci in yerr], alpha=0.1, color='C1')
plt.plot(x, y, 'o--', markersize=10, color='C1')
plt.ylabel('Accuracy')
plt.xlabel('Gradient Step')

# dashed grey reference at the first (uncontaminated) accuracy
plt.axhline(y=y[0], linestyle='--', color='grey', linewidth=3)
plt.xticks(np.arange(369000, 371750, 1000))

# render the accuracy axis in percent (1.0 == 100%)
plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

plt.savefig('figures/olmo_winogrande.pdf', bbox_inches='tight')

In [None]:
# ARC-Easy accuracy over gradient steps with bootstrap confidence intervals
x = list(arc)
y = list(arc.values())
yerr = [bootstrap_confidence_interval(arc_easy_samples, acc) for acc in y]

plt.figure(figsize=(6.5, 6))

# confidence band first, then the accuracy curve on top (both in C2)
plt.fill_between(x, [ci[0] for ci in yerr], [ci[1] for ci in yerr], alpha=0.1, color='C2')
plt.plot(x, y, 'o--', markersize=10, color='C2')
plt.xlabel('Gradient Step')
plt.ylabel('Accuracy')

# dashed grey reference at the first (uncontaminated) accuracy
plt.axhline(y=y[0], linestyle='--', color='grey', linewidth=3)
plt.xticks(np.arange(369000, 371750, 1000))

# render the accuracy axis in percent (1.0 == 100%)
plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

plt.savefig('figures/olmo_arc_easy.pdf', bbox_inches='tight')

In [None]:
# PIQA accuracy over gradient steps with bootstrap confidence intervals
x = list(piqa)
y = list(piqa.values())
yerr = [bootstrap_confidence_interval(piqa_samples, acc) for acc in y]

plt.figure(figsize=(6.5, 6))

# confidence band first, then the accuracy curve on top (both in C3)
plt.fill_between(x, [ci[0] for ci in yerr], [ci[1] for ci in yerr], alpha=0.1, color='C3')
plt.plot(x, y, 'o--', markersize=10, color='C3')
plt.xlabel('Gradient Step')
plt.ylabel('Accuracy')

# dashed grey reference at the first (uncontaminated) accuracy
plt.axhline(y=y[0], linestyle='--', color='grey', linewidth=3)
plt.xticks(np.arange(369000, 371750, 1000))

# render the accuracy axis in percent (1.0 == 100%)
plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

plt.savefig('figures/olmo_piqa.pdf', bbox_inches='tight')
