In [None]:
import plotly.express as px
import numpy as np
import duckdb
from tqdm import tqdm
import plotly

In [None]:
import sys
sys.path.append("/home/ubuntu/sky_workdir/encoding-schemes")

from encoding_schemes import get_deterministic_adherence_fn

In [None]:
import ray

# Initialise Ray in this kernel so the @ray.remote tasks defined below can be scheduled.
ray.init()

In [None]:
import os
import psycopg2
import json

# PostgreSQL connection URL for the experiment-metadata database; raises KeyError
# if SUPABASE_CONNECTION_URL is not set in the environment.
conn_string = os.environ["SUPABASE_CONNECTION_URL"]

# NOTE(review): this connection is never closed in the visible cells.
conn = psycopg2.connect(conn_string)

import pandas as pd

In [None]:
# WHERE clause selecting the runs analysed here (force_overwrite runs are skipped in both branches):
#   1. Qwen data-scaling sweep: experiment_name contains "datascaling".
#   2. OpenAI API SFT reruns: numina_math_cot_rerun tag, a "gpt" model, sft batch_size != 48,
#      and experiment_name containing "math_cot".
# NOTE(review): in branch 2 a NULL batch_size makes `!= 48` evaluate to NULL, so such runs are excluded.
sel_str = """
-- data scaling qwen
 (
     (NOT (data->'force_overwrite')::BOOL OR data->'force_overwrite' IS NULL)
     AND (
         (data->'experiment_name')::TEXT LIKE '%datascaling%'
     )
  ) OR
-- OpenAI SFT
(
     (data->'experiment_tags'->'numina_math_cot_rerun')::BOOL
     AND (NOT (data->'force_overwrite')::BOOL OR data->'force_overwrite' IS NULL)
     AND (
         ((data->'experiment_params'->'model')::TEXT LIKE '%gpt%' AND (data->'experiment_params'->'sft_params'->'batch_size')::INT != 48)
     )
     AND (data->'experiment_name')::TEXT LIKE '%math_cot%'
)
"""

# Fetch all matching runs, newest first. sel_str is the constant above (not user
# input), so splicing it in with an f-string is safe here.
df = pd.read_sql(f"""
SELECT * FROM public.encoding_schemes 
    WHERE 

{sel_str}


ORDER BY created_at DESC
""", conn)

df.head()

In [None]:
# Root directory of per-run outputs; runs are read from <root_dir>/<experiment_hash>/data/.
root_dir = "/home/ubuntu/sky_workdir/encoding-schemes/output"

# Helper functions and per-run result processing

In [None]:

def bootstrap_ci(data, statistic=np.mean, alpha=0.05, n_boot=10_000, random_state=None):
    """Percentile-bootstrap confidence interval for a 1D sample.

    Values are cast to float (so bool/int/float all work) and NaNs are dropped
    before resampling.

    Args:
        data: 1D array-like of numeric or boolean values.
        statistic: Reducer applied to each bootstrap resample and to the full sample.
        alpha: Two-sided miss rate; bounds are the alpha/2 and 1 - alpha/2 percentiles.
        n_boot: Number of bootstrap resamples.
        random_state: Seed (or Generator / None) forwarded to ``np.random.default_rng``.

    Returns:
        Tuple ``(point_estimate, low_CI, high_CI)``.

    Raises:
        ValueError: if no non-NaN values remain.
    """
    values = np.asarray(data).astype(float)
    values = values[~np.isnan(values)]
    size = len(values)
    if size == 0:
        raise ValueError("No valid data for bootstrapping.")

    generator = np.random.default_rng(random_state)
    # One row per resample, each holding `size` indices drawn with replacement.
    resample_idx = generator.integers(0, size, size=(n_boot, size))
    boot_stats = np.apply_along_axis(statistic, 1, values[resample_idx])

    lower_pct = 100 * (alpha / 2)
    upper_pct = 100 * (1 - alpha / 2)
    return (
        statistic(values),
        np.percentile(boot_stats, lower_pct),
        np.percentile(boot_stats, upper_pct),
    )

In [None]:
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")


@ray.remote
def count_tokens_from_messages(s):
    """Ray task: number of cl100k_base tokens in the string ``s``.

    ``disallowed_special=()`` stops tiktoken from rejecting text that looks like a
    special token. If encoding still raises ValueError, the error is printed and
    0 is returned.
    """
    try:
        token_ids = encoding.encode(s, disallowed_special=())
    except ValueError as err:
        print(err)
        return 0
    return len(token_ids)


In [None]:
@ray.remote(num_cpus=2, memory=32 * 1024 * 1024 * 1024)
def compute_translation_token_count(example, df_data):
    """Ray task: token count of every ``reference_solution`` in ``df_data``.

    Uses the Hugging Face tokenizer of the run's model
    (``example["data"]["experiment_params"]["model"]``). Model names containing
    "gpt" or "claude" are tokenized with the openai/gpt-oss-120b tokenizer as a
    stand-in.

    Args:
        example: Run record with a nested ``data["experiment_params"]["model"]``.
        df_data: DataFrame with a ``reference_solution`` column of strings.

    Returns:
        pandas Series (indexed like ``df_data``) of token counts.
    """
    # Keep the project importable in the worker without growing sys.path on every call.
    # The previously imported project helpers (compute_experiment_hash,
    # read_large_parquet) were unused and only added import-failure risk.
    project_dir = "/home/ubuntu/sky_workdir/encoding-schemes"
    if project_dir not in sys.path:
        sys.path.append(project_dir)

    from transformers import AutoTokenizer

    model = example["data"]["experiment_params"]["model"]
    if "gpt" in model or "claude" in model:
        print(f"Overriding tokenizer for {model} with gpt-oss 120b tokenizer because it was detected as a GPT/Claude model!")
        model = "openai/gpt-oss-120b"

    tokenizer = AutoTokenizer.from_pretrained(model)

    return df_data['reference_solution'].map(lambda x: len(tokenizer.encode(x)))


In [None]:
from tqdm import tqdm

In [None]:
def compute_ci_cols(example, df, col_name, transformation_fn):
    """Store a bootstrapped estimate of ``transformation_fn(df[col_name])`` in ``example``.

    Mutates ``example`` by writing three keys:
      - ``col_name``: point estimate from ``bootstrap_ci``
      - ``f"{col_name}_low_ci"``: estimate minus lower bound (an offset, not the bound)
      - ``f"{col_name}_hi_ci"``: upper bound minus estimate

    If the transformed value is NaN, empty, or an array with no non-NaN entries,
    a warning is printed and ``example`` is left untouched.

    NOTE(review): when ``transformation_fn`` reduces the column to a single number
    (e.g. ``np.mean``), bootstrap_ci resamples one value, so both offsets are 0.
    """
    s_transformed = transformation_fn(df[col_name])
    values = np.asarray(s_transformed, dtype=float)
    # Handle scalar NaN *and* empty/all-NaN arrays; previously only the scalar case
    # was caught and bootstrap_ci raised ValueError for the others.
    if values.size == 0 or np.all(np.isnan(values)):
        print(f"Warning: {col_name} was all NaN, ignoring!")
        return
    mid, lo, hi = bootstrap_ci(s_transformed)
    example[col_name] = mid
    example[f'{col_name}_low_ci'] = mid - lo
    example[f'{col_name}_hi_ci'] = hi - mid

In [None]:
def compute_multi_row_transformation(df, row_transform_fn, col_name, agg_fn):
    """Add column ``col_name`` to ``df`` in place by applying ``row_transform_fn`` to each row.

    ``agg_fn`` is accepted for call-site compatibility but is not used.
    """
    per_row = df.apply(func=row_transform_fn, axis=1)
    df[col_name] = per_row

In [None]:

def _score_rollouts(rollouts):
    # rollouts: list/array of rollout sequences; may also be None/np.nan/scalar
    if rollouts is None or (isinstance(rollouts, float) and np.isnan(rollouts)):
        return np.nan

    vals = []
    for r in rollouts:
        # skip None/NaN
        if r is None or (isinstance(r, float) and np.isnan(r)):
            continue
        vals.append(-np.nansum(r))
    return np.nanmean(vals) if len(vals) else np.nan


@ray.remote
def process_single_example(example, step):
    """Ray task: attach evaluation/training metrics for one run at one checkpoint step.

    Reads the run's outputs under ``<root_dir>/<experiment_hash>/data/`` and adds keys
    to ``example`` (mutated and returned):
      - bootstrapped metric columns plus ``_low_ci`` / ``_hi_ci`` offsets,
      - every column of joined_output.parquet as ``<col>_df``,
      - loss and training-token summaries (which ones depends on the branch below).

    Args:
        example: One row of the metadata query as a dict; must contain
            ``experiment_hash`` and ``data["experiment_params"]["model"]``.
        step: Checkpoint step; selects ``data/processed_<step>/`` for models whose
            name lacks "gpt". "gpt" models read ``data/`` directly, so every step
            yields the same metrics for them.

    Returns:
        The augmented ``example``, or ``example`` unchanged if joined_output.parquet
        is missing.

    Raises:
        Exception: if joined_output.parquet lacks one of the expected metric columns.
    """
    model_name = example["data"]["experiment_params"]["model"]
    
    experiment_hash = example['experiment_hash']

    # "gpt" runs have no per-checkpoint subdirectory; other runs are stored per step.
    if "gpt" in model_name:
        step_prefix = ""
    else:
        step_prefix = f"processed_{step}"
    target_path = os.path.join(root_dir, example['experiment_hash'], "data", step_prefix, "joined_output.parquet")
    if not os.path.exists(target_path):
        print(f"!!!!!! {target_path} missing !!!!!!!")
        return example
    
    df_data = pd.read_parquet(target_path)

    # Column -> reducer whose output is bootstrapped by compute_ci_cols.
    # NOTE(review): every reducer collapses the column to one number before
    # bootstrapping, so the resulting CI offsets are 0. Passing an identity
    # transform (as d_and_cols below does) would give real intervals.
    # NOTE(review): cot_gt_logprobs is summed over rows (nansum) while
    # backtranslation_gt_logprobs is averaged (nanmean) — confirm intended.
    d_col_to_transform = {
        'cot_gt_logprobs' : lambda s: np.nansum(s.map(_score_rollouts)),
        'generated_cot_is_correct' : np.mean,  # was np.mean
        'backtranslation_gt_logprobs' : lambda s: np.nanmean(s.map(_score_rollouts)),
        'backtranslation_bleu_scores' : np.mean,  # was np.mean
        'generated_cot_adhered_encoding_style': np.mean  # was np.mean
    }
    for col, fn in d_col_to_transform.items():
        if col not in df_data:
            print(col)
            print(df_data.head())
            print(example)
            raise Exception(str(col) + "\n" + str(df_data.head()) + "\n" + str(example))

        compute_ci_cols(example, df_data, col, fn)

    # Blocks this task until the token-counting task (2 CPUs, 32 GiB requested) finishes.
    df_data["num_tokens_translation_output"] = ray.get(compute_translation_token_count.remote(example, df_data))

    # Per-row derived metrics: name -> (row transform, aggregator). The aggregator is
    # passed to compute_multi_row_transformation, which ignores it; CIs are then
    # bootstrapped over the per-row values (identity transform).
    # NOTE(review): total_translation_loss multiplies a row's token count by the mean
    # *summed* negative log-prob of its rollouts, whereas the API branch below
    # multiplies valid_loss by token counts — check both are on the same scale.
    d_and_cols = {
        'adherent_and_correct': (
            lambda r: np.nanmean( \
                np.array(r['generated_cot_is_correct']).astype(bool) & \
                np.array(r['generated_cot_adhered_encoding_style']).astype(bool) \
            ),
            np.nanmean
        ),
        'total_translation_loss': (
            lambda r: np.nanmean( \
                np.array(r['num_tokens_translation_output']) * \
                np.array(np.nanmean(_score_rollouts(r['backtranslation_gt_logprobs'])), dtype=np.float64) \
            ),
            np.nanmean
        ),
    }
    for col, (transform_fn, agg_fn) in d_and_cols.items():
        compute_multi_row_transformation(df_data, transform_fn, col, agg_fn)
        if df_data[col].isna().sum() != len(df_data):
            compute_ci_cols(example, df_data, col, lambda x: x)

    # Keep every per-row column (as a Series) on the example for later inspection.
    for col in df_data.columns:
        example[f"{col}_df"] = df_data[col]

    # NOTE(review): this tests only that the "use_api_sft_model_for_sampling" key
    # exists, not its value, and only "gpt" (not "claude") models can take it.
    if "gpt" in model_name and "use_api_sft_model_for_sampling" in example["data"]["experiment_params"]:
        # API-fine-tuned runs: losses come from the validation meta JSON files.
        with open(os.path.join(root_dir, example['experiment_hash'], "data", f"validation_numinamath_validation_meta.json"), "r") as fp:
            d_model_meta = json.load(fp)

        df_valid = duckdb.query(f"SELECT num_tokens FROM '{os.path.join(root_dir, example['experiment_hash'], 'data', 'sft.parquet')}'").to_df()

        # valid_loss scaled by the total token count of sft.parquet.
        example["total_validation_loss"] = d_model_meta["valid_loss"] * np.nansum(df_valid['num_tokens'])

        with open(os.path.join(root_dir, example['experiment_hash'], "data", f"validation_reverse_translation_math500_meta.json"), "r") as fp:
            d_model_meta = json.load(fp)

        # Overwrites the bootstrapped value computed above with the API-reported loss.
        example["backtranslation_gt_logprobs"] = d_model_meta["valid_loss"]
        example["total_translation_loss"] = d_model_meta["valid_loss"] * np.nanmean(df_data["num_tokens_translation_output"])
    else:
        # Append # training toks
        example['training_step'] = step
    
        # Tokens seen up to this checkpoint: the first step * batch_size rows of
        # sft_train.parquet (assumes rows are stored in training order — TODO confirm).
        df_sft = duckdb.query(f"SELECT num_tokens FROM '{os.path.join(root_dir, example['experiment_hash'], 'data', 'sft_train.parquet')}'").to_df()
        batch_size = example['data']['experiment_params']['sft_params']['batch_size']
        example['num_training_tokens'] = df_sft.iloc[:step * batch_size]['num_tokens'].sum()
        example['mean_training_token_length'] = df_sft.iloc[:step * batch_size]['num_tokens'].mean()

        # NOTE(review): any failure below is only printed; keys assigned before the
        # failure are kept and the later ones are silently absent.
        try:
            df_valid_logprobs = duckdb.query(f"SELECT gt_logprobs FROM '{os.path.join(root_dir, example['experiment_hash'], 'data', step_prefix, 'valid_logprobs.parquet')}'").to_df()
            # Mean and total of per-sequence negative summed gt_logprobs on the valid set
            # (the per-row values are computed twice).
            example['mean_valid_sequence_log_loss'] = np.nanmean([
                -np.nansum(r['gt_logprobs'])
                for _, r in df_valid_logprobs.iterrows()
            ])
    
            example['total_validation_loss'] = np.nansum([
                -np.nansum(r['gt_logprobs'])
                for _, r in df_valid_logprobs.iterrows()
            ])
    
            # NOTE(review): backtranslation_gt_logprobs holds a list of rollouts per row;
            # np.nansum over ragged rollouts may fail and be swallowed by the except below.
            example['total_total_translation_loss'] = np.nansum([
                -np.nansum(r['backtranslation_gt_logprobs'])
                for _, r in df_data.iterrows()
            ])
        except Exception as e:
            print(e)

    return example

l_examples = df.to_dict('records')

l_steps = [128, 256, 512, 1024, 2048, 3712]

# Submit one Ray task per (run, checkpoint step). Results are laid out run-major:
# index i * len(l_steps) + j holds run i evaluated at l_steps[j].
l_new_examples = [
    process_single_example.remote(example, step)
    for example in tqdm(l_examples, desc="submitting")
    for step in l_steps
]

# Collect results one at a time so a single failure does not discard the rest.
# RayError is the base class of RayTaskError and also covers worker crashes and
# OOM kills, which previously escaped this loop and aborted the whole cell.
for i, ref in enumerate(tqdm(l_new_examples, desc="collecting")):
    try:
        l_new_examples[i] = ray.get(ref)
    except ray.exceptions.RayError as e:
        # Failed tasks become {} placeholders, filtered out in the next cell.
        l_new_examples[i] = {}
        print(e)

l_examples = l_new_examples

In [None]:
# Drop the {} placeholders left by failed Ray tasks.
l_examples = [e for e in l_examples if e != {}]

In [None]:
def humanize_number(num: float) -> str:
    """
    Converts a number into a human-readable string with k, M, or B suffixes.

    Magnitudes >= 1,000 are scaled by the largest fitting suffix and shown with one
    decimal. If rounding would print "1000.0" (e.g. 999,999 -> "1000.0k"), the next
    suffix is used instead ("1.0M"). Negative numbers keep their sign. Magnitudes
    below 1,000 (and NaN) are returned unchanged via ``str``.

    Args:
        num (float): The number to format.

    Returns:
        str: Human-readable string representation.
    """
    units = ((1_000, "k"), (1_000_000, "M"), (1_000_000_000, "B"))
    magnitude = abs(num)
    sign = "-" if num < 0 else ""
    for idx in range(len(units) - 1, -1, -1):
        scale, suffix = units[idx]
        if magnitude >= scale:
            # Promote to the next unit when one-decimal rounding reaches 1000.
            if idx + 1 < len(units) and round(magnitude / scale, 1) >= 1000:
                scale, suffix = units[idx + 1]
            return f"{sign}{magnitude / scale:.1f}{suffix}"
    return str(num)

In [None]:
import re

def parse_params(model):
    """Map a model identifier to a sortable size key.

    API models have no published parameter count and get ordinal codes:
    GPT nano/mini/other -> 0/1/2, Claude haiku/sonnet/other -> 3/4/5.
    Other models are parsed from their "<N>B" size tag, including fractional
    sizes such as "0.5B". Fractional sizes are returned as float and whole sizes
    as int.

    NOTE(review): the API-model codes share a scale with real billions of parameters
    (e.g. 3 means both Claude haiku and a 3B model), so sorting interleaves them.

    Raises:
        ValueError: if a non-API model name contains no "<N>B" tag.
    """
    if 'gpt' in model:
        if 'nano' in model:
            return 0
        elif 'mini' in model:
            return 1
        else:
            return 2

    if 'claude' in model:
        if 'haiku' in model:
            return 3
        elif 'sonnet' in model:
            return 4
        else:
            return 5

    # Allow a decimal part: the old pattern ([0-9]+)B matched only the "5" in
    # "Qwen2.5-0.5B" / "Qwen2.5-1.5B" and reported both as 5.
    match = re.search(r"([0-9]+(?:\.[0-9]+)?)B", model)
    if match is None:
        raise ValueError(f"Could not parse a parameter count from model name {model!r}")
    size = float(match.group(1))
    return int(size) if size.is_integer() else size

In [None]:
df_viz = pd.DataFrame(l_examples)

orig_len = len(df_viz)

# df_viz = df_viz[df_viz['cot_gt_logprobs'].notna()]

# NOTE(review): with the filter above commented out, new_len always equals
# orig_len and this message never prints.
new_len = len(df_viz)
if orig_len != new_len:
    print(f"Lost {orig_len - new_len} examples from na logprobs")

# Flatten frequently used fields out of the nested `data` JSON.
df_viz['encoding_scheme'] = df_viz['data'].map(lambda x: x['experiment_params']['encoding_scheme'])
df_viz['model'] = df_viz['data'].map(lambda x: x['experiment_params']['model'])

# Sortable size key per model (see parse_params); failures are printed, not raised.
try:
    df_viz['model_size'] = df_viz['model'].map(parse_params)
except Exception as e:
    print(e)
# First two "_"-separated tokens of the experiment name (compared against filter_set below).
df_viz['input_type'] = df_viz['data'].map(lambda x: "_".join(x['experiment_name'].split("_")[:2]))

df_viz['n_few_shot_examples'] = df_viz['data'].map(lambda x: x['experiment_params'].get('n_few_shot_examples', None))

# Label how adherence was scored: "deterministic" when encoding_schemes provides a
# deterministic checker for the scheme, otherwise the LLM-judge label.
df_viz['Adherence Calculation Method'] = df_viz['encoding_scheme'].map(lambda x: get_deterministic_adherence_fn(x, None) is not None).map({ True: 'deterministic', False: 'Sonnet 4 judge'})

# NOTE(review): process_single_example sets 'num_training_tokens', not
# 'n_total_train_tok'; unless the DB rows carry that column this raises a
# KeyError that is only printed.
try:
    df_viz['total_train_tok'] = df_viz['n_total_train_tok'].map(humanize_number)
except Exception as e:
    print(e)

df_viz.head()

In [None]:
# Inspect which input_type values are present before filtering.
df_viz['input_type'].unique()

In [None]:
# input_type values kept for the plots below.
filter_set = ['datascaling_mathcot', 'math_cot']

In [None]:
# Keep only the selected input types and drop the base (non-Instruct) Qwen2.5-7B.
df_viz_tmp = df_viz[df_viz['input_type'].isin(filter_set)]
df_viz_tmp = df_viz_tmp[df_viz_tmp['model'] != 'Qwen/Qwen2.5-7B']

# Sort so the px.line traces below connect points in increasing token order.
df_viz_tmp = df_viz_tmp.sort_values([
    'num_training_tokens',
    'model_size'
])

# Stringify (None becomes "None") so the column behaves as a categorical label.
df_viz_tmp = df_viz_tmp.astype({'n_few_shot_examples': str})
# "Qwen/Qwen2.5-3B-Instruct" -> "Qwen2.5-3B-Instruct" (the key format of QWEN25_INSTRUCT_CONFIGS).
df_viz_tmp['model'] = df_viz_tmp['model'].str.split('Qwen/').str[-1]

# Keep only the part after "speaking_" in scheme names (unchanged when absent).
df_viz_tmp['encoding_scheme'] = df_viz_tmp['encoding_scheme'].map(lambda s: s.split("speaking_")[-1])

In [None]:
# Architecture hyperparameters of the Qwen2.5 Instruct models, keyed by model name
# without the "Qwen/" prefix. The field names match the keyword arguments of
# _estimate_qwen2_flops. Presumably copied from each model's HF config.json — verify.
QWEN25_INSTRUCT_CONFIGS = {
    "Qwen2.5-0.5B-Instruct": {
        "hidden_size": 896,
        "vocab_size": 151936,
        "num_hidden_layers": 24,
        "num_key_value_heads": 2,
        "num_attention_heads": 14,
        "intermediate_size": 4864,
    },
    "Qwen2.5-1.5B-Instruct": {
        "hidden_size": 1536,
        "vocab_size": 151936,
        "num_hidden_layers": 28,
        "num_key_value_heads": 2,
        "num_attention_heads": 12,
        "intermediate_size": 8960,
    },
    "Qwen2.5-3B-Instruct": {
        "hidden_size": 2048,
        "vocab_size": 151936,
        "num_hidden_layers": 36,
        "num_key_value_heads": 2,
        "num_attention_heads": 16,
        "intermediate_size": 11008,
    },
    "Qwen2.5-7B-Instruct": {
        "hidden_size": 3584,
        "vocab_size": 152064,
        "num_hidden_layers": 28,
        "num_key_value_heads": 4,
        "num_attention_heads": 28,
        "intermediate_size": 18944,
    },
    "Qwen2.5-14B-Instruct": {
        "hidden_size": 5120,
        "vocab_size": 152064,
        "num_hidden_layers": 48,
        "num_key_value_heads": 8,
        "num_attention_heads": 40,
        "intermediate_size": 13824,
    },
    "Qwen2.5-32B-Instruct": {
        "hidden_size": 5120,
        "vocab_size": 152064,
        "num_hidden_layers": 64,
        "num_key_value_heads": 8,
        "num_attention_heads": 40,
        "intermediate_size": 27648,
    },
    "Qwen2.5-72B-Instruct": {
        "hidden_size": 8192,
        "vocab_size": 152064,
        "num_hidden_layers": 80,
        "num_key_value_heads": 8,
        "num_attention_heads": 64,
        "intermediate_size": 29568,
    },
}


# taken from verl
def _estimate_qwen2_flops(hidden_size, vocab_size, num_hidden_layers, num_key_value_heads, num_attention_heads, intermediate_size, tokens_sum, batch_seqlens):
    hidden_size = hidden_size
    vocab_size = vocab_size
    num_hidden_layers = num_hidden_layers
    num_key_value_heads = num_key_value_heads
    num_attention_heads = num_attention_heads
    intermediate_size = intermediate_size

    head_dim = hidden_size // num_attention_heads
    q_size = num_attention_heads * head_dim
    k_size = num_key_value_heads * head_dim
    v_size = num_key_value_heads * head_dim

    # non-attn per layer parm
    # Qwen2/LLama use SwiGelu, gate, having up and down linear layer in mlp
    mlp_N = hidden_size * intermediate_size * 3
    attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
    emd_and_lm_head_N = vocab_size * hidden_size * 2
    # non-attn all_layer parm
    dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
    # non-attn all_layer & all_token fwd & bwd flops
    dense_N_flops = 6 * dense_N * tokens_sum

    # attn all_layer & all_token fwd & bwd flops
    seqlen_square_sum = 0
    for seqlen in batch_seqlens:
        seqlen_square_sum += seqlen * seqlen
    attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers

    # all_layer & all_token fwd & bwd flops
    flops_all_token = dense_N_flops + attn_qkv_flops
    return flops_all_token


def _row_training_flops(row):
    """Estimated fine-tuning FLOPs for one Qwen row of df_viz_tmp.

    num_training_tokens covers every sequence seen up to the checkpoint (the first
    training_step * batch_size rows of sft_train.parquet). The attention term
    therefore needs one entry per sequence seen, not one batch; passing only
    batch_size entries undercounted it by a factor of the step count. The sequence
    count is recovered as num_training_tokens / mean_training_token_length (both
    come from the same slice), and each sequence is approximated by the mean length.
    """
    tokens = row['num_training_tokens']
    mean_len = row['mean_training_token_length']
    if pd.isna(tokens) or pd.isna(mean_len) or mean_len <= 0:
        n_seqs = 0
    else:
        n_seqs = int(round(tokens / mean_len))
    return _estimate_qwen2_flops(
        **QWEN25_INSTRUCT_CONFIGS[row['model']],
        tokens_sum=tokens,
        batch_seqlens=[mean_len] * n_seqs,
    )


_is_qwen = df_viz_tmp['model'].str.contains('Qwen')
# NOTE(review): the trailing `* 3` is kept from the original; its meaning is not
# visible here (epochs? a fixed multiplier?) — confirm.
df_viz_tmp.loc[_is_qwen, 'training_flops'] = df_viz_tmp.loc[_is_qwen, :].apply(_row_training_flops, axis=1) * 3

# Paper plots

In [None]:
import plotly.express as px
import plotly.colors as pc

def sample_colorscale(colorscale_name: str, n: int):
    """
    Sample ``n`` evenly spaced colors from a named Plotly continuous colorscale.

    Parameters
    ----------
    colorscale_name : str
        Name of a Plotly built-in continuous colorscale (e.g., "Viridis", "Blues").
    n : int
        Number of samples (must be >= 1). For n == 1 the scale midpoint is used;
        otherwise both endpoints 0.0 and 1.0 are included.

    Returns
    -------
    list of str
        Colors as returned by ``plotly.colors.sample_colorscale`` with its default
        colortype (``"rgb(...)"`` strings rather than hex).

    Raises
    ------
    ValueError
        If ``n < 1``.
    """
    if n < 1:
        raise ValueError("n must be >= 1")
    scale = px.colors.get_colorscale(colorscale_name)
    if n == 1:
        positions = [0.5]
    else:
        step = n - 1
        positions = [k / step for k in range(n)]
    return pc.sample_colorscale(scale, positions)

# Per-family color maps. Each samples one more color than there are models and drops
# the first sample (the colorscale's 0.0 end), so no model uses that extreme.
# Models missing from a map fall back to Plotly's default color sequence.
l_gpt_gradient = sample_colorscale("Emrld", 4 + 1)
d_gpt_gradient = {
    model : color
    for model, color in zip(
        ['gpt-4.1-nano-2025-04-14', 'gpt-4.1-mini-2025-04-14', 'gpt-4.1-2025-04-14', 'gpt-5-chat-latest'],
        l_gpt_gradient[1:]
    )
}

l_claude_gradient = sample_colorscale("Magenta", 3 + 1)
d_claude_gradient = {
    model : color
    for model, color in zip(
        ['claude-3-opus-20240229', 'claude-3-5-sonnet-20241022', 'claude-sonnet-4-20250514'],
        l_claude_gradient[1:]
    )
}

# Keys are model names after the "-Instruct" suffix is stripped (next cell).
l_qwen_gradient = sample_colorscale("Oryel", 3 + 1)
d_qwen_gradient = {
    model : color
    for model, color in zip(
        ['Qwen2.5-3B', 'Qwen2.5-7B', 'Qwen2.5-14B'],
        l_qwen_gradient[1:]
    )
}

In [None]:
# Strip "-Instruct" so names match the d_qwen_gradient keys (e.g. "Qwen2.5-14B").
# Must run after training_flops is computed: QWEN25_INSTRUCT_CONFIGS is keyed by the
# full "-Instruct" names.
df_viz_tmp['model'] = df_viz_tmp['model'].str.split('-Instruct').str[0]

In [None]:
# Paper figure: adherent-and-correct rate vs. fine-tuning tokens for the Qwen models
# on the letter_to_word_with_dot cipher; saved to datascaling_math_score.pdf.
fig = px.line(df_viz_tmp[(df_viz_tmp['encoding_scheme'] == 'letter_to_word_with_dot') & (df_viz_tmp['model'].str.contains('Qwen'))],
              x='num_training_tokens', y='adherent_and_correct', color='model',  markers=True, log_x=True,
             color_discrete_map=d_qwen_gradient)

fig.update_layout(height=400, width=525, template="plotly_white", font=dict(size=15), showlegend=False)
fig.update_yaxes(title="% adherent & correct responses")
fig.update_xaxes(title="# fine-tuning tokens")

# The legend is hidden; instead label each line just right of its rightmost point.
df_for_labels = df_viz_tmp[(df_viz_tmp['encoding_scheme'] == 'letter_to_word_with_dot') & (df_viz_tmp['model'].str.contains('Qwen'))]
for model in df_for_labels['model'].unique():
    model_data = df_for_labels[df_for_labels['model'] == model]
    # Get the point with maximum x value (rightmost point)
    max_x_row = model_data.loc[model_data['num_training_tokens'].idxmax()]
    
    # Annotation x is given in log10 units because the x axis is log-scaled.
    fig.add_annotation(
        x=np.log10(max_x_row['num_training_tokens']),
        y=max_x_row['adherent_and_correct'],
        text=model.split('2.5-')[1] if '2.5-' in model else model,
        showarrow=False,
        xanchor='left',
        yanchor='middle',
        xshift=10,
        font=dict(size=16, color=d_qwen_gradient.get(model, 'black')),
    )


# NOTE(review): format_tokens and the tick setup are duplicated in later cells;
# a single shared helper would keep the figures consistent.
def format_tokens(value):
    """Tick label for a token count: whole or one-decimal B/M suffix; below 1M, str(value)."""
    if value >= 1_000_000_000:
        return f'{value/1_000_000_000:.0f}B' if value % 1_000_000_000 == 0 else f'{value/1_000_000_000:.1f}B'
    elif value >= 1_000_000:
        return f'{value/1_000_000:.0f}M' if value % 1_000_000 == 0 else f'{value/1_000_000:.1f}M'
    else:
        return str(value)

# Fixed tick positions (hard-coded, not derived from the data).
tickvals = [100_000_000, 200_000_000, 500_000_000, 1_000_000_000, 2_000_000_000, 5_000_000_000]
ticktext = [format_tokens(val) for val in tickvals]

fig.update_xaxes(
    title="# fine-tuning tokens",
    tickvals=tickvals,
    ticktext=ticktext,
    tickmode='array'
)

# Reference lines at each model's best adherent_and_correct over all steps (sort desc,
# keep the first row per model/scheme): unciphered ('identity') for 7B/14B and
# direct answering ('zero_shot') for 14B.
for _, row in df_viz_tmp[df_viz_tmp['model'].str.contains('Qwen')].sort_values('adherent_and_correct', ascending=False).drop_duplicates(subset=['model', 'encoding_scheme'], keep='first').iterrows():
    if row['encoding_scheme'] == 'identity' and ('7B' in row['model'] or '14B' in row['model']):
        print(row['adherent_and_correct'], row['model'])
        fig.add_hline(y=row['adherent_and_correct'], annotation_text=row['model'].split('2.5-')[1] + " unciphered", line_dash='dot', line=dict(color='rgba(0, 0, 0, 0.3)'), annotation_position="top left", annotation_font_color='#878787',)
    elif row['encoding_scheme'] == 'zero_shot'  and ('14B' in row['model']):
        fig.add_hline(y=row['adherent_and_correct'], annotation_text=row['model'].split('2.5-')[1] + " direct answering", line_dash='dot', line=dict(color='rgba(0, 0, 0, 0.3)'), annotation_position="top left", annotation_font_color='#878787',)

plotly.io.write_image(fig, 'datascaling_math_score.pdf', format='pdf')

fig.show()

In [None]:
# PGR: each row's adherent_and_correct divided by the best adherent_and_correct the
# same model reaches with the unciphered ('identity') scheme. NaN when the model has
# no identity rows.
identity_best = (
    df_viz_tmp[df_viz_tmp['encoding_scheme'] == 'identity']
    .groupby('model')['adherent_and_correct']
    .max()
)
df_viz_tmp['PGR_pct'] = df_viz_tmp['adherent_and_correct'] / df_viz_tmp['model'].map(identity_best)

# PGR vs. fine-tuning FLOPs for the Qwen models on the letter_to_word_with_dot cipher.
_qwen_cipher_rows = df_viz_tmp[
    (df_viz_tmp['encoding_scheme'] == 'letter_to_word_with_dot')
    & df_viz_tmp['model'].str.contains('Qwen')
]
fig = px.line(
    _qwen_cipher_rows,
    x='training_flops',
    y='PGR_pct',
    color='model',
    markers=True,
    log_x=True,
    color_discrete_map=d_qwen_gradient,
)
fig.update_layout(height=400, width=525, template="plotly_white", font=dict(size=15))
fig.update_yaxes(title="PGR")
fig.update_xaxes(title="Fine-tuning FLOPs")

fig.show('png')

In [None]:
# Paper figure: back-translation BLEU vs. fine-tuning tokens for the Qwen models on the
# letter_to_word_with_dot cipher; saved to datascaling_translation_bleu.pdf.
fig = px.line(df_viz_tmp[(df_viz_tmp['encoding_scheme'] == 'letter_to_word_with_dot') & (df_viz_tmp['model'].str.contains('Qwen'))],
              x='num_training_tokens', y='backtranslation_bleu_scores', color='model', markers=True, log_x=True, color_discrete_map=d_qwen_gradient)

fig.update_layout(height=400, width=525, template="plotly_white", font=dict(size=15), showlegend=False)
fig.update_yaxes(title="BLEU score")
fig.update_xaxes(title="# fine-tuning tokens")

# The legend is hidden; label each line just left of its leftmost point.
df_for_labels = df_viz_tmp[(df_viz_tmp['encoding_scheme'] == 'letter_to_word_with_dot') & (df_viz_tmp['model'].str.contains('Qwen'))]
for model in df_for_labels['model'].unique():
    model_data = df_for_labels[df_for_labels['model'] == model]
    # Leftmost point (minimum x); despite its name, max_x_row holds the min-x row.
    max_x_row = model_data.loc[model_data['num_training_tokens'].idxmin()]
    
    # Annotation x is given in log10 units because the x axis is log-scaled.
    fig.add_annotation(
        x=np.log10(max_x_row['num_training_tokens']),
        y=max_x_row['backtranslation_bleu_scores'],
        text=model.split('2.5-')[1] if '2.5-' in model else model,
        showarrow=False,
        xanchor='right',
        yanchor='middle',
        xshift=-10,
        font=dict(size=16, color=d_qwen_gradient.get(model, 'black')),
    )


# NOTE(review): redefinition of the identical format_tokens from an earlier cell.
def format_tokens(value):
    """Tick label for a token count: whole or one-decimal B/M suffix; below 1M, str(value)."""
    if value >= 1_000_000_000:
        return f'{value/1_000_000_000:.0f}B' if value % 1_000_000_000 == 0 else f'{value/1_000_000_000:.1f}B'
    elif value >= 1_000_000:
        return f'{value/1_000_000:.0f}M' if value % 1_000_000 == 0 else f'{value/1_000_000:.1f}M'
    else:
        return str(value)

# Fixed tick positions (hard-coded, not derived from the data).
tickvals = [100_000_000, 200_000_000, 500_000_000, 1_000_000_000, 2_000_000_000, 5_000_000_000]
ticktext = [format_tokens(val) for val in tickvals]

fig.update_xaxes(
    title="# fine-tuning tokens",
    tickvals=tickvals,
    ticktext=ticktext,
    tickmode='array'
)

plotly.io.write_image(fig, 'datascaling_translation_bleu.pdf', format='pdf')

fig.show('svg')

In [None]:
# Total decoding (back-translation) loss vs. fine-tuning FLOPs for the Qwen models on
# the letter_to_word_with_dot cipher, with API-SFT GPT runs as horizontal references.
fig = px.line(df_viz_tmp[(df_viz_tmp['encoding_scheme'] == 'letter_to_word_with_dot') & (df_viz_tmp['model'].str.contains('Qwen'))],
              x='training_flops', y='total_translation_loss', color='model', markers=True, log_x=True, color_discrete_map=d_qwen_gradient)

fig.update_layout(height=400, width=525, template="plotly_white", font=dict(size=15))
fig.update_yaxes(title="Total decoding loss")
fig.update_xaxes(title="Fine-tuning FLOPs")

# NOTE(review): the GPT reference lines use the dot_between_chars scheme while the
# Qwen curves use letter_to_word_with_dot, and process_single_example computes
# total_translation_loss differently for the two families — confirm they are comparable.
# drop_duplicates keeps an arbitrary row per (model, scheme); GPT rows repeat across steps.
for _, row in df_viz_tmp[df_viz_tmp['model'].str.contains('gpt')].drop_duplicates(subset=['model', 'encoding_scheme']).iterrows():
    if row['encoding_scheme'] == 'dot_between_chars':
        fig.add_hline(y=row['total_translation_loss'], annotation_text=row['model'].split('-2025')[0] + "-SFT", line_dash='dash', line=dict(color='rgba(0, 0, 0, 0.3)'), annotation_position="top right" if "gpt-4.1-2025" in row['model'] else "top left")
# TODO: add GPT 4.1 series here!

fig.show('png')

In [None]:
# Mean per-sequence validation loss vs. fine-tuning FLOPs for the Qwen models on the
# letter_to_word_with_dot cipher (mean_valid_sequence_log_loss is set only for non-"gpt" runs).
fig = px.line(df_viz_tmp[(df_viz_tmp['encoding_scheme'] == 'letter_to_word_with_dot') & (df_viz_tmp['model'].str.contains('Qwen'))],
              x='training_flops', y='mean_valid_sequence_log_loss', color='model', markers=True, log_x=True, color_discrete_map=d_qwen_gradient)

fig.update_layout(height=400, width=525, template="plotly_white")
fig.update_yaxes(title="Mean per-sequence loss, valid set")
fig.update_xaxes(title="Fine-tuning FLOPs")

fig.show('png')

In [None]:
# Ad-hoc inspection: all metrics for the 14B model on the letter_to_word_with_dot cipher, transposed.
df_viz_tmp[df_viz_tmp['model'].str.contains('14B') & (df_viz_tmp['encoding_scheme'] == 'letter_to_word_with_dot')].T

In [None]:
# Ad-hoc inspection of one run's SFT data (experiment hash hard-coded).
df_valid = pd.read_parquet("/home/ubuntu/sky_workdir/encoding-schemes/output/3863e9f29ee0a696a4f056d96cdfb694506f8202/data/sft.parquet")
df_valid.head()

In [None]:
# Ad-hoc check: total training tokens of one run (experiment hash hard-coded).
duckdb.query("SELECT SUM(num_tokens) FROM '/home/ubuntu/sky_workdir/encoding-schemes/output/e3f76d94751a09af21005696ac072219143ebc58/data/sft_train.parquet'").to_df()

# Scaling-law grids (one figure per encoding scheme)

In [None]:
# One adherent-and-correct vs. fine-tuning-tokens figure per scheme, written to
# paper/scaling_laws/ (figures are saved, not shown).
for encoding_scheme in ['dot_between_chars', 'zero_shot', 'identity']:
    fig = px.line(df_viz_tmp[(df_viz_tmp['encoding_scheme'] == encoding_scheme) & (df_viz_tmp['model'].str.contains('Qwen'))],
                  x='num_training_tokens', y='adherent_and_correct', color='model',  markers=True, log_x=True,
                 color_discrete_map=d_qwen_gradient)
    
    fig.update_layout(height=400, width=525, template="plotly_white", font=dict(size=15), showlegend=False)
    fig.update_yaxes(title="% adherent & correct responses")
    fig.update_xaxes(title="# fine-tuning tokens")
    
    # The legend is hidden; label each line just right of its rightmost point.
    df_for_labels = df_viz_tmp[(df_viz_tmp['encoding_scheme'] == encoding_scheme) & (df_viz_tmp['model'].str.contains('Qwen'))]
    for model in df_for_labels['model'].unique():
        model_data = df_for_labels[df_for_labels['model'] == model]
        # Get the point with maximum x value (rightmost point)
        max_x_row = model_data.loc[model_data['num_training_tokens'].idxmax()]
        
        # Annotation x is given in log10 units because the x axis is log-scaled.
        fig.add_annotation(
            x=np.log10(max_x_row['num_training_tokens']),
            y=max_x_row['adherent_and_correct'],
            text=model.split('2.5-')[1] if '2.5-' in model else model,
            showarrow=False,
            xanchor='left',
            yanchor='middle',
            xshift=10,
            font=dict(size=16, color=d_qwen_gradient.get(model, 'black')),
        )
    
    
    # NOTE(review): redefined on every loop iteration; identical to the earlier format_tokens.
    def format_tokens(value):
        """Tick label for a token count: whole or one-decimal B/M suffix; below 1M, str(value)."""
        if value >= 1_000_000_000:
            return f'{value/1_000_000_000:.0f}B' if value % 1_000_000_000 == 0 else f'{value/1_000_000_000:.1f}B'
        elif value >= 1_000_000:
            return f'{value/1_000_000:.0f}M' if value % 1_000_000 == 0 else f'{value/1_000_000:.1f}M'
        else:
            return str(value)
    
    # Fixed tick positions (hard-coded, not derived from the data).
    tickvals = [100_000_000, 200_000_000, 500_000_000, 1_000_000_000, 2_000_000_000, 5_000_000_000]
    ticktext = [format_tokens(val) for val in tickvals]
    
    fig.update_xaxes(
        title="# fine-tuning tokens",
        tickvals=tickvals,
        ticktext=ticktext,
        tickmode='array'
    )
    
    # Same reference lines on every figure: best-over-steps unciphered 7B/14B and
    # direct-answering 14B scores.
    for _, row in df_viz_tmp[df_viz_tmp['model'].str.contains('Qwen')].sort_values('adherent_and_correct', ascending=False).drop_duplicates(subset=['model', 'encoding_scheme'], keep='first').iterrows():
        if row['encoding_scheme'] == 'identity' and ('7B' in row['model'] or '14B' in row['model']):
            print(row['adherent_and_correct'], row['model'])
            fig.add_hline(y=row['adherent_and_correct'], annotation_text=row['model'].split('2.5-')[1] + " unciphered", line_dash='dot', line=dict(color='rgba(0, 0, 0, 0.3)'), annotation_position="top left", annotation_font_color='#878787',)
        elif row['encoding_scheme'] == 'zero_shot'  and ('14B' in row['model']):
            fig.add_hline(y=row['adherent_and_correct'], annotation_text=row['model'].split('2.5-')[1] + " direct answering", line_dash='dot', line=dict(color='rgba(0, 0, 0, 0.3)'), annotation_position="top left", annotation_font_color='#878787',)
    
    # NOTE(review): the paper/scaling_laws/ directory must already exist.
    plotly.io.write_image(fig, f'paper/scaling_laws/datascaling_math_score_{encoding_scheme}.pdf', format='pdf')

In [None]:
# One back-translation BLEU vs. fine-tuning-tokens figure per scheme, written to
# paper/scaling_laws/ (figures are saved, not shown).
for encoding_scheme in ['dot_between_chars', 'zero_shot', 'identity']:

    
    fig = px.line(df_viz_tmp[(df_viz_tmp['encoding_scheme'] == encoding_scheme) & (df_viz_tmp['model'].str.contains('Qwen'))],
                  x='num_training_tokens', y='backtranslation_bleu_scores', color='model', markers=True, log_x=True, color_discrete_map=d_qwen_gradient)
    
    fig.update_layout(height=400, width=525, template="plotly_white", font=dict(size=15), showlegend=False)
    fig.update_yaxes(title="BLEU score")
    fig.update_xaxes(title="# fine-tuning tokens")
    
    # Line labels (legend hidden) are skipped for the 'identity' scheme.
    df_for_labels = df_viz_tmp[(df_viz_tmp['encoding_scheme'] == encoding_scheme) & (df_viz_tmp['model'].str.contains('Qwen'))]
    if encoding_scheme != 'identity':
        for model in df_for_labels['model'].unique():
            model_data = df_for_labels[df_for_labels['model'] == model]
            # Leftmost point (minimum x); despite its name, max_x_row holds the min-x row.
            max_x_row = model_data.loc[model_data['num_training_tokens'].idxmin()]
            
            # Annotation x is given in log10 units because the x axis is log-scaled.
            fig.add_annotation(
                x=np.log10(max_x_row['num_training_tokens']),
                y=max_x_row['backtranslation_bleu_scores'],
                text=model.split('2.5-')[1] if '2.5-' in model else model,
                showarrow=False,
                xanchor='right',
                yanchor='middle',
                xshift=-10,
                font=dict(size=16, color=d_qwen_gradient.get(model, 'black')),
            )
    
    
    # NOTE(review): redefined on every loop iteration; identical to the earlier format_tokens.
    def format_tokens(value):
        """Tick label for a token count: whole or one-decimal B/M suffix; below 1M, str(value)."""
        if value >= 1_000_000_000:
            return f'{value/1_000_000_000:.0f}B' if value % 1_000_000_000 == 0 else f'{value/1_000_000_000:.1f}B'
        elif value >= 1_000_000:
            return f'{value/1_000_000:.0f}M' if value % 1_000_000 == 0 else f'{value/1_000_000:.1f}M'
        else:
            return str(value)
    
    # Fixed tick positions (hard-coded, not derived from the data).
    tickvals = [100_000_000, 200_000_000, 500_000_000, 1_000_000_000, 2_000_000_000, 5_000_000_000]
    ticktext = [format_tokens(val) for val in tickvals]
    
    fig.update_xaxes(
        title="# fine-tuning tokens",
        tickvals=tickvals,
        ticktext=ticktext,
        tickmode='array'
    )

    # NOTE(review): fixed y range; BLEU values below 80 are clipped from view.
    fig.update_yaxes(range=[80, 100])

    plotly.io.write_image(fig, f'paper/scaling_laws/datascaling_translation_bleu_{encoding_scheme}.pdf', format='pdf')