In [None]:
import plotly.express as px
import numpy as np
import duckdb
from tqdm import tqdm
import plotly

In [None]:
import sys
sys.path.append("/home/ubuntu/sky_workdir/encoding-schemes")

from encoding_schemes import get_deterministic_adherence_fn

In [None]:
import ray

# Start (or attach to) the local Ray runtime used by the @ray.remote tasks below.
# ignore_reinit_error makes this cell safe to re-run on a live kernel: a plain
# ray.init() raises if Ray has already been initialised in this process.
ray.init(ignore_reinit_error=True)

In [None]:
import os
import psycopg2
import json

# The connection URL is read from the environment so no credentials live in the
# notebook; a missing variable fails fast with KeyError.
conn_string = os.environ["SUPABASE_CONNECTION_URL"]

# Module-level psycopg2 connection reused by the pd.read_sql query further down.
conn = psycopg2.connect(conn_string)

import pandas as pd

In [None]:

# SQL predicate selecting the experiment rows to analyse. It is interpolated
# verbatim into the WHERE clause below, so it must remain a single boolean
# expression. As written it keeps rows tagged `numina_math_cot_rerun` whose
# `force_overwrite` is false or absent, and which either sampled n = 4 or are
# GPT models with an SFT batch size other than 48.
sel_str = """
-- NuminaMath CoT Rerun
 (
     (data->'experiment_tags'->'numina_math_cot_rerun')::BOOL
     AND (NOT (data->'force_overwrite')::BOOL OR data->'force_overwrite' IS NULL)
     AND (
         (data->'experiment_params'->'sampling_params'->'n')::INT = 4
         OR ((data->'experiment_params'->'model')::TEXT LIKE '%gpt%' AND (data->'experiment_params'->'sft_params'->'batch_size')::INT != 48)
     )
  )
"""

# Load every matching experiment row, newest first. The predicate is a constant
# defined in this notebook (not user input), which is why plain f-string
# interpolation is used here.
df = pd.read_sql(f"""
SELECT * FROM public.encoding_schemes 
    WHERE 

{sel_str}


ORDER BY created_at DESC
""", conn)

df.head()

In [None]:
# Root of the per-experiment output tree; artefacts are looked up under
# <root_dir>/<experiment_hash>/data/ by the processing functions below.
root_dir = "/home/ubuntu/sky_workdir/encoding-schemes/output"

In [None]:
# One plain dict per experiment row; these dicts are enriched with metrics by
# process_single_example and later turned back into a DataFrame.
l_examples = df.to_dict('records')

# Code

In [None]:

def bootstrap_ci(data, statistic=np.mean, alpha=0.05, n_boot=10_000, random_state=None):
    """
    Percentile-bootstrap confidence interval for a statistic of 1D data.

    Args:
        data: Array-like of bool, int, or float values. NaNs are discarded.
        statistic: Callable reducing a 1D array to a scalar (default: mean).
        alpha: Two-sided significance level; the interval spans the
            alpha/2 and 1 - alpha/2 percentiles of the bootstrap distribution.
        n_boot: Number of bootstrap resamples.
        random_state: Seed (or anything accepted by np.random.default_rng).

    Returns:
        Tuple (point_estimate, low_CI, high_CI).

    Raises:
        ValueError: If no non-NaN values remain.
    """
    values = np.asarray(data).astype(float)  # coerce bool/int to float
    values = values[~np.isnan(values)]
    n_obs = len(values)
    if n_obs == 0:
        raise ValueError("No valid data for bootstrapping.")

    generator = np.random.default_rng(random_state)

    # Each row of `resample_idx` indexes one resample (with replacement).
    resample_idx = generator.integers(0, n_obs, size=(n_boot, n_obs))
    boot_stats = np.apply_along_axis(statistic, 1, values[resample_idx])

    estimate = statistic(values)
    lower_pct = 100 * (alpha / 2)
    upper_pct = 100 * (1 - alpha / 2)
    return estimate, np.percentile(boot_stats, lower_pct), np.percentile(boot_stats, upper_pct)

In [None]:
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")


@ray.remote
def count_tokens_from_messages(s):
    """
    Ray task: number of cl100k_base tokens in the string `s`.

    Special-token text inside `s` is encoded as ordinary text
    (disallowed_special=()). If encoding raises ValueError the error is
    printed and 0 is returned.
    """
    try:
        token_ids = encoding.encode(s, disallowed_special=())
        n_tokens = len(token_ids)
    except ValueError as err:
        print(err)
        n_tokens = 0
    return n_tokens


In [None]:
@ray.remote(num_cpus=2, memory=32 * 1024 * 1024 * 1024)
def compute_translation_token_count(example, df_data):
    """
    Ray task: token length of every `reference_solution` in `df_data`.

    The tokenizer is chosen from the experiment's model name
    (example["data"]["experiment_params"]["model"]); GPT/Claude models are
    tokenized with the openai/gpt-oss-120b tokenizer instead.

    Returns:
        A Series aligned with `df_data` holding one token count per row.
    """
    # Make the project importable inside the Ray worker process.
    sys.path.append("/home/ubuntu/sky_workdir/encoding-schemes")

    # NOTE(review): these two project imports are not used below; they are kept
    # so the task fails exactly as before if the project is not importable.
    from orchestration.experiment_meta_saver import compute_experiment_hash
    from utils.io_utils import read_large_parquet

    from transformers import AutoTokenizer

    model = example["data"]["experiment_params"]["model"]
    is_api_model = any(family in model for family in ("gpt", "claude"))
    if is_api_model:
        print(f"Overriding tokenizer for {model} with gpt-oss 120b tokenizer because it was detected as a GPT/Claude model!")
    tokenizer_name = "openai/gpt-oss-120b" if is_api_model else model

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def _token_length(text):
        return len(tokenizer.encode(text))

    return df_data['reference_solution'].map(_token_length)


In [None]:
from tqdm import tqdm

In [None]:
def compute_ci_cols(example, df, col_name, transformation_fn):
    """
    Store a bootstrap point estimate and error-bar half-widths on `example`.

    `transformation_fn` is applied to df[col_name] and the result is handed to
    bootstrap_ci. Three keys are written: `col_name` (point estimate),
    `<col_name>_low_ci` (point - lower bound) and `<col_name>_hi_ci`
    (upper bound - point), i.e. distances from the estimate rather than
    absolute bounds.

    If the transformed value is a scalar NaN, a warning is printed and
    `example` is left untouched.
    """
    transformed = transformation_fn(df[col_name])

    is_nan_scalar = np.isscalar(transformed) and np.isnan(transformed)
    if is_nan_scalar:
        print(f"Warning: {col_name} was all NaN, ignoring!")
        return None

    point, lower, upper = bootstrap_ci(transformed)
    example.update({
        col_name: point,
        f'{col_name}_low_ci': point - lower,
        f'{col_name}_hi_ci': upper - point,
    })

In [None]:
def compute_multi_row_transformation(df, row_transform_fn, col_name, agg_fn):
    """
    Add (or overwrite) column `col_name` on `df` in place.

    The column holds `row_transform_fn(row)` for every row. `agg_fn` is
    accepted for call-site compatibility but is not used here.
    """
    per_row_values = df.apply(row_transform_fn, axis=1)
    df[col_name] = per_row_values

In [None]:

def _score_rollouts(rollouts):
    # rollouts: list/array of rollout sequences; may also be None/np.nan/scalar
    if rollouts is None or (isinstance(rollouts, float) and np.isnan(rollouts)):
        return np.nan

    vals = []
    for r in rollouts:
        # skip None/NaN
        if r is None or (isinstance(r, float) and np.isnan(r)):
            continue
        vals.append(-np.nansum(r))
    return np.nanmean(vals) if len(vals) else np.nan


def patch_gpt_api_log_loss(example, output_root=None):
    """
    Overwrite example["backtranslation_gt_logprobs"] with the API-reported loss.

    Reads `valid_loss` from
    <output_root>/<experiment_hash>/data/validation_reverse_translation_math500_meta.json
    and stores it on `example` in place.

    Args:
        example: Experiment record; must contain 'experiment_hash'.
        output_root: Directory holding per-experiment outputs. Defaults to the
            notebook-level `root_dir`.

    Raises:
        FileNotFoundError: If the metadata file does not exist.
        KeyError: If the metadata has no 'valid_loss' entry.
    """
    base_dir = root_dir if output_root is None else output_root
    meta_path = os.path.join(
        base_dir,
        example['experiment_hash'],
        "data",
        "validation_reverse_translation_math500_meta.json",
    )

    # The previous version indexed the *path string* instead of the parsed JSON
    # (always a TypeError) and read the same file twice; parse it once.
    with open(meta_path, "r") as fp:
        d_model_meta = json.load(fp)

    # TODO (carried over from the original): this needs to be the validation
    # loss on 512k... — the intended source file is not identified here.
    example["backtranslation_gt_logprobs"] = d_model_meta["valid_loss"]


@ray.remote
def process_single_example(example):
    """
    Ray task: enrich one experiment record with metrics from its parquet output.

    Reads <root_dir>/<experiment_hash>/data/joined_output.parquet and writes
    onto `example`:
      * a point estimate plus `_low_ci` / `_hi_ci` entries (via compute_ci_cols)
        for each column in `d_col_to_transform` and each derived column in
        `d_and_cols`;
      * for GPT models with `use_api_sft_model_for_sampling` set, loss values
        taken from the API metadata JSON (with zero-width CIs);
      * `pretraining_prevalence`, when the corresponding JSON file exists;
      * every parquet column as `<col>_df`.

    Returns:
        The `example` dict; returned unchanged if the parquet file is missing.

    Raises:
        Exception: If a required metric column is absent from the parquet file.
    """
    experiment_hash = example['experiment_hash']
    
    target_path = os.path.join(root_dir, example['experiment_hash'], "data", "joined_output.parquet")
    if not os.path.exists(target_path):
        # Best effort: report the gap and hand the record back without metrics.
        print(f"!!!!!! {target_path} missing !!!!!!!")
        return example
    
    df_data = pd.read_parquet(target_path)

    # Column -> reduction applied to the whole column before bootstrapping.
    # NOTE(review): the row-level lambdas below index these columns as arrays,
    # so each cell presumably holds one value per rollout — confirm.
    d_col_to_transform = {
        'cot_gt_logprobs' : lambda s: np.nansum(s.map(_score_rollouts)),
        'generated_cot_is_correct' : np.mean,  # was np.mean
        'backtranslation_gt_logprobs' : lambda s: np.nanmean(s.map(_score_rollouts)),
        'backtranslation_bleu_scores' : np.mean,  # was np.mean
        'generated_cot_adhered_encoding_style': np.mean  # was np.mean
    }
    for col, fn in d_col_to_transform.items():
        if col not in df_data:
            # Dump context before failing so the bad experiment can be identified.
            print(col)
            print(df_data.head())
            print(example)
            raise Exception(str(col) + "\n" + str(df_data.head()) + "\n" + str(example))

        compute_ci_cols(example, df_data, col, fn)

    # Nested Ray call: this task blocks until the tokenizer task finishes.
    df_data["num_tokens_translation_output"] = ray.get(compute_translation_token_count.remote(example, df_data))

    # Derived per-row columns: name -> (row function, aggregation).
    # The aggregation is forwarded to compute_multi_row_transformation, which
    # does not use it.
    d_and_cols = {
        'adherent_and_correct': (
            lambda r: np.nanmean( \
                np.array(r['generated_cot_is_correct']).astype(bool) & \
                np.array(r['generated_cot_adhered_encoding_style']).astype(bool) \
            ),
            np.nanmean
        ),
        'total_translation_loss': (
            lambda r: np.nanmean( \
                np.array(r['num_tokens_translation_output']) * \
                np.array(np.nanmean(_score_rollouts(r['backtranslation_gt_logprobs'])), dtype=np.float64) \
            ),
            np.nanmean
        ),
    }
    for col, (transform_fn, agg_fn) in d_and_cols.items():
        compute_multi_row_transformation(df_data, transform_fn, col, agg_fn)
        # Only bootstrap when at least one row produced a non-NaN value.
        if df_data[col].isna().sum() != len(df_data):
            compute_ci_cols(example, df_data, col, lambda x: x)

    # For API-finetuned GPT runs, replace the loss metrics computed above with
    # the `valid_loss` recorded in the metadata file.
    if "gpt" in example["data"]["experiment_params"]["model"] and "use_api_sft_model_for_sampling" in example["data"]["experiment_params"]:
        with open(os.path.join(root_dir, example['experiment_hash'], "data", f"validation_reverse_translation_math500_meta.json"), "r") as fp:
            d_model_meta = json.load(fp)

        example["backtranslation_gt_logprobs"] = d_model_meta["valid_loss"]
        example["total_translation_loss"] = d_model_meta["valid_loss"] * np.nanmean(df_data["num_tokens_translation_output"])
        # A single reported number has no bootstrap spread, so the error bars are zeroed.
        example["total_translation_loss_low_ci"] = 0.0
        example["total_translation_loss_hi_ci"] = 0.0

    # Optional artefact; this path resolves to the same directory as root_dir.
    pretraining_prevalence_path = os.path.join('/home/ubuntu/sky_workdir/encoding-schemes', 'output', experiment_hash, 'data', 'num_pretraining_4grams_redpajama.json')
    if os.path.exists(pretraining_prevalence_path):
        with open(pretraining_prevalence_path, "r") as fp:
            d_pretraining_prevalence = json.load(fp)

        example["pretraining_prevalence"] = d_pretraining_prevalence["num_occurrences"]

    # Attach the raw per-row data as well (one Series per parquet column).
    for col in df_data.columns:
        example[f"{col}_df"] = df_data[col]

    return example


# Fan out one Ray task per experiment, then gather the results in order.
l_new_examples = [None for _ in range(len(l_examples))]

# tqdm wraps the list itself (not the enumerate generator) so the bar knows the
# total and can show percentage / ETA.
for i, example in enumerate(tqdm(l_examples, desc="submitting")):
    # l_examples[i] = process_single_example(example)
    l_new_examples[i] = process_single_example.remote(example)

for i, future in enumerate(tqdm(l_new_examples, desc="collecting")):
    try:
        l_new_examples[i] = ray.get(future)
    except ray.exceptions.RayTaskError as e:
        # Keep the un-enriched record so one failed experiment does not drop a
        # row; it will simply lack the metric columns.
        l_new_examples[i] = l_examples[i]
        print(e)

l_examples = l_new_examples

In [None]:
def humanize_number(num: float) -> str:
    """
    Converts a number into a human-readable string with k, M, or B suffixes.

    The suffix is chosen from the magnitude of `num`, so negative values are
    abbreviated as well (e.g. -1500 -> "-1.5k"). Values whose magnitude is
    below 1000 (and NaN) are returned via str() unchanged.

    Args:
        num (float): The number to format.

    Returns:
        str: Human-readable string representation.
    """
    magnitude = abs(num)
    if magnitude >= 1_000_000_000:
        return f"{num / 1_000_000_000:.1f}B"
    elif magnitude >= 1_000_000:
        return f"{num / 1_000_000:.1f}M"
    elif magnitude >= 1_000:
        return f"{num / 1_000:.1f}k"
    else:
        return str(num)

In [None]:
import re

def parse_params(model):
    """
    Map a model name to a sortable "size" value.

    GPT and Claude models have no published parameter count in their names, so
    they get fixed ordinal ranks: GPT nano/mini/other -> 0/1/2 and Claude
    haiku/sonnet/other -> 3/4/5. For every other model the number directly
    before a capital "B" is used (e.g. "Qwen/Qwen2.5-7B" -> 7).

    Fractional sizes are parsed in full ("...-0.5B" -> 0.5); whole sizes are
    returned as int.

    Raises:
        ValueError: If the name is not GPT/Claude and contains no "<number>B".
    """
    if 'gpt' in model:
        if 'nano' in model:
            return 0
        elif 'mini' in model:
            return 1
        else:
            return 2

    if 'claude' in model:
        if 'haiku' in model:
            return 3
        elif 'sonnet' in model:
            return 4
        else:
            return 5

    # Allow an optional decimal part so "1.5B" is not read as "5B".
    match = re.search(r"([0-9]+(?:\.[0-9]+)?)B", model)
    if match is None:
        raise ValueError(f"Cannot parse parameter count from model name: {model!r}")

    size = float(match.group(1))
    return int(size) if size.is_integer() else size

In [None]:
# Build the plotting frame: one row per (processed) experiment record.
df_viz = pd.DataFrame(l_examples)

orig_len = len(df_viz)

# df_viz = df_viz[df_viz['cot_gt_logprobs'].notna()]

# With the filter above commented out, new_len always equals orig_len and the
# message below never prints.
new_len = len(df_viz)
if orig_len != new_len:
    print(f"Lost {orig_len - new_len} examples from na logprobs")

# Flatten the fields needed for plotting out of the nested `data` dict.
df_viz['encoding_scheme'] = df_viz['data'].map(lambda x: x['experiment_params']['encoding_scheme'])
df_viz['model'] = df_viz['data'].map(lambda x: x['experiment_params']['model'])

# NOTE(review): if any model name fails to parse, the error is only printed and
# 'model_size' is never created; the later sort_values on 'model_size' would
# then fail with a KeyError.
try:
    df_viz['model_size'] = df_viz['model'].map(parse_params)
except Exception as e:
    print(e)
# input_type is the first two underscore-separated tokens of the experiment name.
df_viz['input_type'] = df_viz['data'].map(lambda x: "_".join(x['experiment_name'].split("_")[:2]))

df_viz['n_few_shot_examples'] = df_viz['data'].map(lambda x: x['experiment_params'].get('n_few_shot_examples', None))

# Label each scheme by whether a deterministic adherence function exists for it;
# schemes without one are labelled 'Sonnet 4 judge'.
df_viz['Adherence Calculation Method'] = df_viz['encoding_scheme'].map(lambda x: get_deterministic_adherence_fn(x, None) is not None).map({ True: 'deterministic', False: 'Sonnet 4 judge'})

# 'n_total_train_tok' is optional: when the column is absent the KeyError is
# printed and 'total_train_tok' is simply not added.
try:
    df_viz['total_train_tok'] = df_viz['n_total_train_tok'].map(humanize_number)
except Exception as e:
    print(e)

In [None]:
df_viz['input_type'].unique()

In [None]:
# input_type values to keep in the plots below.
filter_set = ['math_cot']

In [None]:
# Restrict to the selected input types, drop the Qwen2.5-7B runs, order rows by
# model size then joint adherence/correctness, and stringify the few-shot count.
df_viz_tmp = (
    df_viz[df_viz['input_type'].isin(filter_set)]
    .loc[lambda frame: frame['model'] != 'Qwen/Qwen2.5-7B']
    .sort_values(['model_size', 'adherent_and_correct'])
    .astype({'n_few_shot_examples': str})
)

# Keep only the part of the scheme name after the last "speaking_" prefix.
df_viz_tmp['encoding_scheme'] = df_viz_tmp['encoding_scheme'].map(lambda s: s.split("speaking_")[-1])

In [None]:
# Work on a copy so plotting tweaks never leak back into df_viz_tmp.
df_viz_tmp_plot = df_viz_tmp.copy()

# Grouped bar chart of encoding-style adherence per scheme, coloured by model
# and faceted by how adherence was measured. The *_hi_ci / *_low_ci columns
# hold distances from the point estimate, which is the form error_y /
# error_y_minus expect.
fig = px.bar(df_viz_tmp_plot, x='encoding_scheme', y='generated_cot_adhered_encoding_style',
             height=800, width=1600,
             # height=1600, width=1600,
             # color='model',
             error_y='generated_cot_adhered_encoding_style_hi_ci',
             error_y_minus='generated_cot_adhered_encoding_style_low_ci',
             # color='n_few_shot_examples',
             # facet_row='model',
             color='model',
             facet_col='Adherence Calculation Method',
             # color='training_augmentation',
             # color='total_train_tok',
             # color_discrete_map=color_discrete_map,
             title="MATH 500 CoT encoding style adherence",
             barmode="group",
             template="plotly_white",
             color_discrete_sequence=px.colors.qualitative.Set2,
            )

fig.update_xaxes(title="Encoding scheme", tickangle=90)
# NOTE(review): the title says "%" but dtick=0.1 suggests the values are
# fractions in [0, 1] — confirm.
fig.update_yaxes(title="% adherent encodings", dtick=0.1)

# Make each facet's x-axis independent to avoid empty slots
fig.for_each_xaxis(lambda ax: ax.update(matches=None, categoryorder='trace'))

# Give each successive x-axis title a larger standoff (0, 130, 260, ...).
# `ct` is a one-element list used as a mutable counter shared with the callback.
ct = [0]
def ax_standoff_updater(ax, ct):    
    ax.title.update(standoff=ct[0] * 130)
    ct[0] += 1

fig.for_each_xaxis(lambda ax: ax_standoff_updater(ax, ct))
fig.update_yaxes(title_standoff=5)

fig.show()

# Pretraining prevalence plot

In [None]:
import numpy as np
from scipy import stats
from scipy.optimize import curve_fit
import plotly.graph_objects as go
import plotly.express as px

# Fresh copy for this figure; scheme names get spaces instead of underscores so
# they read better as point labels.
df_viz_tmp_plot = df_viz_tmp.copy()
df_viz_tmp_plot['encoding_scheme'] = df_viz_tmp_plot['encoding_scheme'].str.replace('_', ' ')

# Schemes treated as "algorithmic" (names as they appear after the replacement above).
l_algo_encodings = [
    'base64 cipher',
    'base64 2x cipher',
    'base64 3x cipher',
    'gzip to base64 encoded',
    'letter to word with dot',
    'dot between chars',
    'space between chars'
]

# Separate data into two groups
# The algorithmic group additionally receives the 'identity' rows.
df_l_algo = df_viz_tmp_plot[df_viz_tmp_plot['encoding_scheme'].isin(l_algo_encodings)]
df_l_algo = pd.concat([
    df_l_algo,
    df_viz_tmp_plot[df_viz_tmp_plot['encoding_scheme'] == 'identity']
], ignore_index=True)

# Everything not in the algorithmic list; 'identity' is not in that list, so
# its rows take part in both fits.
df_rest = df_viz_tmp_plot[~df_viz_tmp_plot['encoding_scheme'].isin(l_algo_encodings)]

# Create the scatter plot with color coding
df_viz_tmp_plot['group'] = df_viz_tmp_plot['encoding_scheme'].apply(
    lambda x: 'Algorithmic' if x in l_algo_encodings else 'identity' if x == 'identity' else 'Language'
)

fig = px.scatter(df_viz_tmp_plot, x='pretraining_prevalence', y='generated_cot_is_correct', 
                 log_x=True, text='encoding_scheme', color='group',
                 color_discrete_map={'Algorithmic': 'red', 'Language': 'blue', 'identity': 'magenta'})

# Update the scatter trace to not be part of the legend and keep its original appearance
fig.update_traces(textposition="top center", showlegend=False, selector=dict(mode='markers+text'))
fig.update_layout(legend_title_text=None) 

# Define fit functions
def linear_func(x, a, b):
    """Straight line with slope `a` and intercept `b`: y = a*x + b."""
    scaled = a * x
    return scaled + b

def exponential_func(x, a, b, c):
    """Exponential with amplitude `a`, rate `b`, offset `c`: y = a*exp(b*x) + c."""
    growth = np.exp(b * x)
    return a * growth + c

def sigmoid_func(x, a, b, c, d):
    """Logistic curve: y = a / (1 + exp(-b*(x - c))) + d.

    `a` is the height, `b` the steepness, `c` the midpoint and `d` the floor.
    """
    exponent = -b * (x - c)
    return a / (1 + np.exp(exponent)) + d

def power_law_func(x, a, b):
    """Power law with prefactor `a` and exponent `b`: y = a * x**b."""
    raised = np.power(x, b)
    return a * raised

def calculate_fit(df, fit_type='linear', use_log_x=True):
    """
    Fit a trend of 'generated_cot_is_correct' against 'pretraining_prevalence'.

    Parameters:
    - df: DataFrame with 'pretraining_prevalence' and 'generated_cot_is_correct'
      columns. Rows with NaN in either column, or with non-positive prevalence,
      are ignored.
    - fit_type: 'linear', 'exponential', 'logistic', or 'power_law'.
      The sigmoid curve is selected with 'logistic'; 'sigmoid' is NOT accepted.
    - use_log_x: Whether to fit against log10(x) instead of x
      (ignored for 'power_law', which always fits on raw x).

    Returns:
    - (params, r_squared, fit_type) on success.
    - (None, None, None) when fewer than two usable rows remain or the fit
      fails. Failures, including an unknown fit_type, are printed rather than
      raised.
    """
    def _r_squared(y_true, y_hat, degenerate):
        """R² of y_hat vs y_true; `degenerate` is returned if y_true has zero variance."""
        ss_res = np.sum((y_true - y_hat) ** 2)
        ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
        return 1 - (ss_res / ss_tot) if ss_tot > 0 else degenerate

    # Remove any rows with NaN or zero/negative x values
    df_clean = df.dropna(subset=['pretraining_prevalence', 'generated_cot_is_correct'])
    df_clean = df_clean[df_clean['pretraining_prevalence'] > 0]

    if len(df_clean) < 2:
        return None, None, None

    # Prepare x and y values
    x_original = df_clean['pretraining_prevalence'].values
    y = df_clean['generated_cot_is_correct'].values

    # For power law, we don't want to log-transform x if we're fitting directly
    if fit_type != 'power_law' and use_log_x:
        x = np.log10(x_original)
    else:
        x = x_original

    try:
        if fit_type == 'linear':
            regression = stats.linregress(x, y)
            params = [regression.slope, regression.intercept]
            y_pred = linear_func(x, *params)

        elif fit_type == 'exponential':
            # Initial guess: [amplitude, growth_rate, offset]
            initial_guess = [0.5, 0.1, 0.3]
            bounds = ([0, -np.inf, 0], [1, np.inf, 1])
            params, _ = curve_fit(exponential_func, x, y, p0=initial_guess,
                                  bounds=bounds, maxfev=5000)
            y_pred = exponential_func(x, *params)

        elif fit_type == 'logistic':
            # Initial guess: [amplitude, steepness, x_mid, y_offset]
            x_mid = np.median(x)
            initial_guess = [0.7, 1.0, x_mid, 0.2]
            bounds = ([0, 0, x.min(), 0], [1, 10, x.max(), 1])
            params, _ = curve_fit(sigmoid_func, x, y, p0=initial_guess,
                                  bounds=bounds, maxfev=5000)
            y_pred = sigmoid_func(x, *params)

        elif fit_type == 'power_law':
            # Only strictly positive y can be used for the log-log regression.
            mask = y > 0
            if np.sum(mask) < 2:
                return None, None, None

            x_fit = x[mask]
            y_fit = y[mask]

            # Two approaches - try both and use the one with better R².
            # `except Exception` (not a bare except) so Ctrl-C / SystemExit
            # still interrupt a long fit.

            # Approach 1: Direct non-linear fit
            try:
                initial_guess = [1e-10, 0.5]
                bounds = ([1e-15, -10], [1, 10])
                params1, _ = curve_fit(power_law_func, x_fit, y_fit,
                                       p0=initial_guess, bounds=bounds, maxfev=5000)
                y_pred1 = power_law_func(x, *params1)
                r2_1 = _r_squared(y, y_pred1, -np.inf)
            except Exception:
                r2_1 = -np.inf
                params1 = None

            # Approach 2: Log-log linear regression,
            # log10(y) = log10(a) + b*log10(x)
            try:
                regression = stats.linregress(np.log10(x_fit), np.log10(y_fit))
                params2 = [10 ** regression.intercept, regression.slope]
                y_pred2 = power_law_func(x, *params2)
                r2_2 = _r_squared(y, y_pred2, -np.inf)
            except Exception:
                r2_2 = -np.inf
                params2 = None

            # Choose the better fit
            if r2_1 > r2_2 and params1 is not None:
                params = params1
                y_pred = y_pred1
            elif params2 is not None:
                params = params2
                y_pred = y_pred2
            else:
                return None, None, None

        else:
            raise ValueError(f"Unknown fit_type: {fit_type}")

        # R² is always measured on all cleaned points (in fitted-x space).
        r_squared = _r_squared(y, y_pred, 0)

        return params, r_squared, fit_type

    except Exception as e:
        print(f"Error fitting {fit_type}: {e}")
        return None, None, None

# Configuration - CHANGE THESE TO SELECT FIT TYPE
FIT_TYPE = 'logistic'  # Options: 'linear', 'exponential', 'logistic', 'power_law' (the sigmoid curve is selected with 'logistic')
USE_LOG_X = True  # Whether to use log-transformed x for fitting (not used for power_law)

# Calculate trends for both groups
# Each params value is None when the corresponding fit failed or had too little data.
params_l_algo, r2_l_algo, _ = calculate_fit(df_l_algo, fit_type=FIT_TYPE, use_log_x=USE_LOG_X)
params_rest, r2_rest, _ = calculate_fit(df_rest, fit_type=FIT_TYPE, use_log_x=USE_LOG_X)

# Generate trend line points
# The curves are evaluated at the observed (sorted, positive) prevalence values
# rather than on a dense grid.
x_range = df_viz_tmp_plot['pretraining_prevalence'].values
x_range = x_range[x_range > 0]  # Remove non-positive values
x_range = np.sort(x_range)

# Mirror the x transform that calculate_fit applied so the parameters are
# evaluated in the same space they were fitted in.
if FIT_TYPE == 'power_law':
    x_fit_range = x_range
elif USE_LOG_X:
    x_fit_range = np.log10(x_range)
else:
    x_fit_range = x_range

# Add trend lines if calculations were successful
if params_l_algo is not None:
    if FIT_TYPE == 'linear':
        y_trend_l_algo = linear_func(x_fit_range, *params_l_algo)
    elif FIT_TYPE == 'exponential':
        y_trend_l_algo = exponential_func(x_fit_range, *params_l_algo)
    elif FIT_TYPE == 'logistic':
        y_trend_l_algo = sigmoid_func(x_fit_range, *params_l_algo)
    elif FIT_TYPE == 'power_law':
        y_trend_l_algo = power_law_func(x_fit_range, *params_l_algo)
    
    # Format the equation for power law
    # NOTE(review): eq_str is assigned but not used anywhere visible in this cell.
    if FIT_TYPE == 'power_law':
        eq_str = f'y = {params_l_algo[0]:.2e} * x^{params_l_algo[1]:.3f}'
        name_str = f'Algorithmic encoding, power law (R²={r2_l_algo:.3f})'
    else:
        name_str = f'Structure-disrupting cipher, {FIT_TYPE} fit (R²={r2_l_algo:.3f})'
    
    # NOTE(review): [:-9] drops the nine largest-prevalence points from the drawn
    # line — presumably to stop it at this group's data range; the count is
    # positional and depends on the current dataset. TODO confirm.
    fig.add_trace(go.Scatter(
        x=x_range[:-9],
        y=y_trend_l_algo[:-9],
        mode='lines',
        name=name_str,
        line=dict(color='red', dash='dash'),
        showlegend=True
    ))

if params_rest is not None:
    if FIT_TYPE == 'linear':
        y_trend_rest = linear_func(x_fit_range, *params_rest)
    elif FIT_TYPE == 'exponential':
        y_trend_rest = exponential_func(x_fit_range, *params_rest)
    elif FIT_TYPE == 'logistic':
        y_trend_rest = sigmoid_func(x_fit_range, *params_rest)
    elif FIT_TYPE == 'power_law':
        y_trend_rest = power_law_func(x_fit_range, *params_rest)
    
    # Format the equation for power law
    # NOTE(review): eq_str is assigned but not used anywhere visible in this cell.
    if FIT_TYPE == 'power_law':
        eq_str = f'y = {params_rest[0]:.2e} * x^{params_rest[1]:.3f}'
        name_str = f'Language encoding, power law (R²={r2_rest:.3f})'
    else:
        name_str = f'Structure-preserving cipher, {FIT_TYPE} fit (R²={r2_rest:.3f})'
    
    # NOTE(review): [6:] skips the six smallest-prevalence points — presumably
    # for the same reason as the slice above; also dataset-dependent. TODO confirm.
    fig.add_trace(go.Scatter(
        x=x_range[6:],
        y=y_trend_rest[6:],
        mode='lines',
        name=name_str,
        line=dict(color='blue', dash='dash'),
        showlegend=True
    ))

# Print the power law parameters if that's what we're using
if FIT_TYPE == 'power_law':
    if params_l_algo is not None:
        print(f"Algorithmic encoding power law: y = {params_l_algo[0]:.2e} * x^{params_l_algo[1]:.3f}")
    if params_rest is not None:
        print(f"Language encoding power law: y = {params_rest[0]:.2e} * x^{params_rest[1]:.3f}")

# Update layout
# Update layout
# Legend is anchored inside the plot area at the top-left corner.
fig.update_layout(
    height=600, 
    width=800,
    showlegend=True,
    legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="left",
        x=0.01
    ),
    # title=f"MATH500 Accuracy vs Pretraining Prevalence ({FIT_TYPE.capitalize().replace('_', ' ')} Fit)"
)

# No margins and no title: the figure is exported to a file below.
fig.update_layout(margin=dict(l=0, r=0, t=0, b=0))

fig.update_yaxes(dtick=0.1, title="MATH500 accuracy", range=[0, 0.8])
# On a log axis plotly reads `range` as log10 exponents, so [8.8, 10.5]
# corresponds to roughly 6.3e8 .. 3.2e10.
fig.update_xaxes(title="Pretraining prevalence", type="log", range=[8.8, 10.5])
fig.update_layout(template="plotly_white", font=dict(size=22))
fig.update_traces(textfont_size=16)
fig.update_traces(marker=dict(size=9))

def format_tokens(value):
    """
    Format a count as a compact tick label using B (1e9) or M (1e6) suffixes.

    Exact multiples of the unit are printed without decimals ("1B", "700M"),
    everything else with one decimal ("1.5B"). Values below one million are
    returned via str() unchanged.
    """
    for unit, suffix in ((1_000_000_000, 'B'), (1_000_000, 'M')):
        if value >= unit:
            decimals = 0 if value % unit == 0 else 1
            return f'{value / unit:.{decimals}f}{suffix}'
    return str(value)

# Generate tickvals based on your data range
# Hand-picked tick positions (raw counts, 700M .. 30B) labelled via format_tokens.
tickvals = [700_000_000, 1_000_000_000, 2_000_000_000, 3_000_000_000, 4_000_000_000, 5_000_000_000, 7_000_000_000, 10_000_000_000, 20_000_000_000, 30_000_000_000]
ticktext = [format_tokens(val) for val in tickvals]

# tickmode='array' makes plotly use exactly the positions and labels above.
fig.update_xaxes(
    title="Pretraining prevalence",
    tickvals=tickvals,
    ticktext=ticktext,
    tickmode='array'
)


# Write the figure as a PDF into the current working directory, then render a
# static PNG inline.
plotly.io.write_image(fig, 'pretraining_scaling_perf.pdf', format='pdf')
fig.show('png')
