In [None]:
import plotly.express as px
import numpy as np
import duckdb
from tqdm import tqdm

In [None]:
import sys
sys.path.append("/home/ubuntu/sky_workdir/encoding-schemes")

from encoding_schemes import get_deterministic_adherence_fn

In [None]:
import ray

ray.init()

In [None]:
import os
import psycopg2
import json

conn_string = os.environ["SUPABASE_CONNECTION_URL"]

conn = psycopg2.connect(conn_string)

import pandas as pd

In [None]:
# df = pd.read_sql("SELECT * FROM public.encoding_schemes WHERE (data->'experiment_tags'->'sft')::boolean", conn)

# sel_str = """
# -- redo prompted
# (
#     (data->'experiment_tags'->'numina_math_cot_rerun')::BOOL
#     AND (NOT (data->'force_overwrite')::BOOL OR data->'force_overwrite' IS NULL)
#     AND (data->'experiment_name')::TEXT LIKE '%prompted_%'
# )
# """

# sel_str = """
# -- Few shot
#  (
#      (data->'experiment_tags'->'numina_math_cot_rerun')::BOOL
#      AND (NOT (data->'force_overwrite')::BOOL OR data->'force_overwrite' IS NULL)
#      AND (
#          (data->'experiment_params'->'n_few_shot_examples')::INT = 8
#      )
#   )
# """

# sel_str = """
# -- NuminaMath CoT Rerun
#  (
#      (data->'experiment_tags'->'numina_math_cot_rerun')::BOOL
#      AND (NOT (data->'force_overwrite')::BOOL OR data->'force_overwrite' IS NULL)
#      AND (
#          (data->'experiment_params'->'sampling_params'->'n')::INT = 4
#          OR (data->'experiment_params'->'model')::TEXT LIKE '%gpt%'
#      )
#   )
# """

# Active row filter for public.encoding_schemes. It is interpolated verbatim into the
# WHERE clause of the query below; swap it for one of the commented-out variants above
# to load a different experiment family.
sel_str = """
-- prompted no sft decode
(
    (data->'experiment_tags'->'numina_math_cot_rerun')::BOOL
    AND (NOT (data->'force_overwrite')::BOOL OR data->'force_overwrite' IS NULL)
    AND (data->'experiment_name')::TEXT LIKE '%prompteddecode%'
)
"""

# Load every matching experiment row, newest first, over the psycopg2 connection `conn`.
df = pd.read_sql(f"""
SELECT * FROM public.encoding_schemes 
    WHERE 

{sel_str}


ORDER BY created_at DESC
""", conn)

df.head()

In [None]:
root_dir = "/home/ubuntu/sky_workdir/encoding-schemes/output"

In [None]:
l_examples = df.to_dict('records')

l_examples[:5]

In [None]:
[example for example in l_examples if example["data"]["experiment_params"]['encoding_scheme'] == 'speaking_reverse_letters_in_each_word']

In [None]:
df_sft = pd.read_parquet("/home/ubuntu/sky_workdir/encoding-schemes/output/c9abad32f9ae149eab5412db77c84c1992f350ba/data/joined_output.parquet")

df_sft.head()

In [None]:

def bootstrap_ci(data, statistic=np.mean, alpha=0.05, n_boot=10_000, random_state=None):
    """
    Percentile-bootstrap confidence interval for a statistic of 1D data.

    The input is cast to float (so bool, int and float data all work) and NaN
    entries are discarded before resampling.

    Args:
        data: array-like of observations.
        statistic: callable that reduces a 1D array to a single value.
        alpha: total tail mass; the interval runs from the alpha/2 percentile
            to the 1 - alpha/2 percentile of the bootstrap distribution.
        n_boot: number of bootstrap resamples.
        random_state: seed passed to np.random.default_rng.

    Returns:
        (point_estimate, low_CI, high_CI), where point_estimate is the
        statistic evaluated on the NaN-free data.

    Raises:
        ValueError: if nothing is left after dropping NaNs.
    """
    values = np.asarray(data).astype(float)  # bools / ints become floats
    values = values[~np.isnan(values)]
    n_obs = len(values)
    if n_obs == 0:
        raise ValueError("No valid data for bootstrapping.")

    generator = np.random.default_rng(random_state)

    # One row of resampled indices per bootstrap replicate, reduced row by row
    resample_idx = generator.integers(0, n_obs, size=(n_boot, n_obs))
    boot_stats = np.apply_along_axis(statistic, 1, values[resample_idx])

    lower_q = 100 * (alpha / 2)
    upper_q = 100 * (1 - alpha / 2)
    return (
        statistic(values),
        np.percentile(boot_stats, lower_q),
        np.percentile(boot_stats, upper_q),
    )

In [None]:
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")


@ray.remote
def count_tokens_from_messages(s):
    """
    Ray task: number of tokens in string `s` under the module-level tiktoken `encoding`.

    Special-token text is encoded as ordinary text (disallowed_special=()).
    If encoding raises ValueError, the error is printed and 0 is returned.
    """
    try:
        token_ids = encoding.encode(s, disallowed_special=())
    except ValueError as err:
        print(err)
        return 0
    return len(token_ids)


In [None]:
@ray.remote(num_cpus=2, memory=32 * 1024 * 1024 * 1024)
def compute_translation_token_count(example, df_data):
    """
    Ray task: token count of every reference solution under the experiment model's tokenizer.

    Args:
        example: experiment record; the model name is read from
            example["data"]["experiment_params"]["model"].
        df_data: DataFrame with a 'reference_solution' column of strings.

    Returns:
        Series aligned with df_data holding len(tokenizer.encode(solution)) per row.

    Model names containing "gpt" or "claude" are tokenized with the
    "openai/gpt-oss-120b" tokenizer instead of their own name.
    """
    # Imported inside the task body rather than at notebook level. The previous
    # sys.path tweak and project-module imports were unused here and have been dropped.
    from transformers import AutoTokenizer

    model = example["data"]["experiment_params"]["model"]
    if "gpt" in model or "claude" in model:
        print(f"Overriding tokenizer for {model} with gpt-oss 120b tokenizer because it was detected as a GPT/Claude model!")
        model = "openai/gpt-oss-120b"

    tokenizer = AutoTokenizer.from_pretrained(model)

    # NOTE(review): encode() runs with the tokenizer's defaults; confirm whether any
    # special tokens it adds should be part of the count.
    return df_data['reference_solution'].map(lambda x: len(tokenizer.encode(x)))


In [None]:
from tqdm import tqdm

In [None]:
def compute_ci_cols(example, df, col_name, transformation_fn):
    """
    Bootstrap one column and store the estimate plus error-bar half-widths on `example`.

    `transformation_fn` is applied to df[col_name] and its result is handed to
    bootstrap_ci. Three keys are written into `example` (mutated in place):
        col_name            -> point estimate
        f'{col_name}_low_ci' -> point - lower bound
        f'{col_name}_hi_ci'  -> upper bound - point

    If the transformed value is a scalar NaN, a warning is printed and
    `example` is left untouched.
    """
    transformed = transformation_fn(df[col_name])

    all_nan_scalar = np.isscalar(transformed) and np.isnan(transformed)
    if all_nan_scalar:
        print(f"Warning: {col_name} was all NaN, ignoring!")
        return None

    point, lower, upper = bootstrap_ci(transformed)
    example.update({
        col_name: point,
        f'{col_name}_low_ci': point - lower,
        f'{col_name}_hi_ci': upper - point,
    })

In [None]:
def compute_multi_row_transformation(df, row_transform_fn, col_name, agg_fn):
    """
    Add column `col_name` to `df` (in place) by applying `row_transform_fn` to each row.

    `agg_fn` is accepted but not used by this function.
    """
    per_row_values = df.apply(row_transform_fn, axis=1)
    df[col_name] = per_row_values

In [None]:

def _score_rollouts(rollouts):
    # rollouts: list/array of rollout sequences; may also be None/np.nan/scalar
    if rollouts is None or (isinstance(rollouts, float) and np.isnan(rollouts)):
        return np.nan

    vals = []
    for r in rollouts:
        # skip None/NaN
        if r is None or (isinstance(r, float) and np.isnan(r)):
            continue
        vals.append(-np.nansum(r))
    return np.nanmean(vals) if len(vals) else np.nan


def patch_gpt_api_log_loss(example):
    """
    Set example["backtranslation_gt_logprobs"] from the experiment's saved validation meta.

    Reads
    <root_dir>/<experiment_hash>/data/validation_reverse_translation_math500_meta.json
    and copies its "valid_loss" entry onto `example` (mutated in place).

    Args:
        example: experiment record containing 'experiment_hash'.

    Raises:
        FileNotFoundError: if the meta file does not exist.
        KeyError: if the meta file has no "valid_loss" entry.
    """
    experiment_hash = example['experiment_hash']

    meta_path = os.path.join(root_dir, experiment_hash, "data", "validation_reverse_translation_math500_meta.json")
    with open(meta_path, "r") as fp:
        # The value has to come from the parsed JSON; indexing the path string
        # (as the previous version did) raises TypeError.
        d_meta = json.load(fp)

    # Original author's note: need to be validation loss on 512k...
    # The earlier code opened this same file a second time for that purpose; a single
    # read is equivalent. NOTE(review): confirm whether a different meta file was intended.
    example["backtranslation_gt_logprobs"] = d_meta["valid_loss"]


@ray.remote
def process_single_example(example):
    """
    Ray task: enrich one experiment record with aggregate metrics and bootstrap error bars.

    Reads <root_dir>/<experiment_hash>/data/joined_output.parquet. If the file is
    missing, a message is printed and `example` is returned unchanged.

    Otherwise the record gains:
      * for each column in d_col_to_transform: the value and `_low_ci` / `_hi_ci`
        keys written by compute_ci_cols;
      * the same three keys for the per-row derived columns 'adherent_and_correct'
        and 'total_translation_loss' (skipped when the derived column is all NaN);
      * for models whose name contains "gpt": 'backtranslation_gt_logprobs' and
        'total_translation_loss' overwritten from the saved "valid_loss";
      * one `<column>_df` key per column of the parquet frame, holding the raw Series.

    Raises:
        Exception: if any column named in d_col_to_transform is absent from the parquet.
    """
    target_path = os.path.join(root_dir, example['experiment_hash'], "data", "joined_output.parquet")
    if not os.path.exists(target_path):
        print(f"!!!!!! {target_path} missing !!!!!!!")
        return example
    
    df_data = pd.read_parquet(target_path)

    # Column -> reducer applied to the whole column before bootstrapping.
    # NOTE(review): 'cot_gt_logprobs' is summed (nansum) while 'backtranslation_gt_logprobs'
    # is averaged (nanmean) — confirm the asymmetry is intended.
    # NOTE(review): these reducers appear to collapse the column to a single value before
    # compute_ci_cols hands it to bootstrap_ci, which would give a zero-width interval —
    # confirm against the actual column contents.
    d_col_to_transform = {
        'cot_gt_logprobs' : lambda s: np.nansum(s.map(_score_rollouts)),
        'generated_cot_is_correct' : np.mean,  # was np.mean
        'backtranslation_gt_logprobs' : lambda s: np.nanmean(s.map(_score_rollouts)),
        'backtranslation_bleu_scores' : np.mean,  # was np.mean
        'generated_cot_adhered_encoding_style': np.mean  # was np.mean
    }
    for col, fn in d_col_to_transform.items():
        if col not in df_data:
            # Dump context to the worker log before failing the task.
            print(col)
            print(df_data.head())
            print(example)
            raise Exception(str(col) + "\n" + str(df_data.head()) + "\n" + str(example))

        compute_ci_cols(example, df_data, col, fn)

    # Nested Ray task: this task blocks in ray.get until the token counts are ready.
    df_data["num_tokens_translation_output"] = ray.get(compute_translation_token_count.remote(example, df_data))

    # Derived per-row columns: name -> (row function, aggregation).
    # The aggregation is passed through to compute_multi_row_transformation, which ignores it.
    d_and_cols = {
        # Per row: fraction of entries that are both correct and adherent.
        'adherent_and_correct': (
            lambda r: np.nanmean( \
                np.array(r['generated_cot_is_correct']).astype(bool) & \
                np.array(r['generated_cot_adhered_encoding_style']).astype(bool) \
            ),
            np.nanmean
        ),
        # Per row: token count times the row's _score_rollouts value.
        'total_translation_loss': (
            lambda r: np.nanmean( \
                np.array(r['num_tokens_translation_output']) * \
                np.array(np.nanmean(_score_rollouts(r['backtranslation_gt_logprobs'])), dtype=np.float64) \
            ),
            np.nanmean
        ),
    }
    for col, (transform_fn, agg_fn) in d_and_cols.items():
        compute_multi_row_transformation(df_data, transform_fn, col, agg_fn)
        # Only bootstrap when at least one row has a value.
        if df_data[col].isna().sum() != len(df_data):
            compute_ci_cols(example, df_data, col, lambda x: x)

    # Models with "gpt" in the name take their loss from the saved validation meta file.
    # Only the point values are overwritten; any `_low_ci` / `_hi_ci` keys written above
    # are left as they were.
    if "gpt" in example["data"]["experiment_params"]["model"]:
        with open(os.path.join(root_dir, example['experiment_hash'], "data", f"validation_reverse_translation_math500_meta.json"), "r") as fp:
            d_model_meta = json.load(fp)

        example["backtranslation_gt_logprobs"] = d_model_meta["valid_loss"]
        example["total_translation_loss"] = d_model_meta["valid_loss"] * np.nanmean(df_data["num_tokens_translation_output"])

    # Attach the raw per-row data so later cells can inspect individual samples.
    for col in df_data.columns:
        example[f"{col}_df"] = df_data[col]

    return example


# Fan out one Ray task per experiment, then gather the results in the same order.
l_new_examples = [None for _ in range(len(l_examples))]

# enumerate() has no length, so tqdm needs an explicit total to show progress and ETA.
for i, example in tqdm(enumerate(l_examples), total=len(l_examples)):
    # Serial alternative for debugging:
    # l_examples[i] = process_single_example(example)
    l_new_examples[i] = process_single_example.remote(example)

for i, example in tqdm(enumerate(l_new_examples), total=len(l_new_examples)):
    try:
        l_new_examples[i] = ray.get(example)
    except ray.exceptions.RayTaskError as e:
        # Best effort: keep the unprocessed record so one failed task does not abort the
        # run. Such records lack the metric keys that later cells read.
        l_new_examples[i] = l_examples[i]
        print(e)

l_examples = l_new_examples

In [None]:
def humanize_number(num: float) -> str:
    """
    Format a number compactly using a k, M, or B suffix.

    Values of at least one thousand / million / billion are divided by that
    unit and shown with one decimal place; anything smaller is returned via str().

    Args:
        num (float): The number to format.

    Returns:
        str: Human-readable string representation.
    """
    suffix_table = (
        (1_000_000_000, "B"),
        (1_000_000, "M"),
        (1_000, "k"),
    )
    for threshold, suffix in suffix_table:
        if num >= threshold:
            return f"{num / threshold:.1f}{suffix}"
    return str(num)

In [None]:
import re

def parse_params(model):
    """
    Map a model name to a number used for ordering models by size.

    API models get fixed ranks rather than parameter counts:
        gpt:    nano -> 0, mini -> 1, anything else -> 2
        claude: haiku -> 3, sonnet -> 4, anything else -> 5
    For all other names the parameter count in billions is parsed from the
    first "<number>B" token, e.g. "Qwen/Qwen2.5-7B-Instruct" -> 7 and
    "Qwen/Qwen2.5-1.5B-Instruct" -> 1.5.

    Args:
        model: model name string.

    Returns:
        int for whole-number sizes and API ranks, float for fractional sizes.

    Raises:
        ValueError: if the name is not a gpt/claude model and has no "<number>B" token.
    """
    if 'gpt' in model:
        if 'nano' in model:
            return 0
        elif 'mini' in model:
            return 1
        else:
            return 2

    if 'claude' in model:
        if 'haiku' in model:
            return 3
        elif 'sonnet' in model:
            return 4
        else:
            return 5

    # Allow an optional fractional part so "1.5B" is read as 1.5 rather than 5.
    match = re.search(r"([0-9]+(?:\.[0-9]+)?)B", model)
    if match is None:
        raise ValueError(f"Could not parse a parameter count from model name {model!r}")

    size = float(match.group(1))
    # Keep whole-number sizes as int, as before.
    return int(size) if size.is_integer() else size

In [None]:
# One row per processed experiment record.
df_viz = pd.DataFrame(l_examples)

orig_len = len(df_viz)

# The NaN filter below is disabled, so orig_len always equals new_len and the
# "Lost ..." message cannot print until it is re-enabled.
# df_viz = df_viz[df_viz['cot_gt_logprobs'].notna()]

new_len = len(df_viz)
if orig_len != new_len:
    print(f"Lost {orig_len - new_len} examples from na logprobs")

# Lift frequently used fields out of the nested `data` dict into their own columns.
df_viz['encoding_scheme'] = df_viz['data'].map(lambda x: x['experiment_params']['encoding_scheme'])
df_viz['model'] = df_viz['data'].map(lambda x: x['experiment_params']['model'])

# Best effort: if parse_params raises for any model name, the error is printed and the
# 'model_size' column is not created (cells that sort on it will then fail).
try:
    df_viz['model_size'] = df_viz['model'].map(parse_params)
except Exception as e:
    print(e)
# First two underscore-separated tokens of the experiment name, e.g. 'mathcot_prompteddecode'.
df_viz['input_type'] = df_viz['data'].map(lambda x: "_".join(x['experiment_name'].split("_")[:2]))

# None when the experiment params have no few-shot setting.
df_viz['n_few_shot_examples'] = df_viz['data'].map(lambda x: x['experiment_params'].get('n_few_shot_examples', None))

# 'deterministic' when get_deterministic_adherence_fn returns a function for the scheme,
# otherwise 'Sonnet 4 judge'.
df_viz['Adherence Calculation Method'] = df_viz['encoding_scheme'].map(lambda x: get_deterministic_adherence_fn(x, None) is not None).map({ True: 'deterministic', False: 'Sonnet 4 judge'})

# Best effort: the column is only added when 'n_total_train_tok' is present.
try:
    df_viz['total_train_tok'] = df_viz['n_total_train_tok'].map(humanize_number)
except Exception as e:
    print(e)

df_viz.head()

In [None]:
df_viz['input_type'].unique()

In [None]:
# filter_set = ['mathcot_fewshot']
# filter_set = ['math_cot']
# filter_set = ['mathcot_prompted']
filter_set = ['mathcot_prompteddecode']

In [None]:
# Keep only the experiment family chosen in filter_set.
df_viz_tmp = df_viz[df_viz['input_type'].isin(filter_set)]

# Order rows by model size, then by adherent-and-correct rate. make_encoding_scheme_bar_plot
# takes its default model order from the row order of the frame it is given.
df_viz_tmp = df_viz_tmp.sort_values([
    'model_size',
    'adherent_and_correct'
])

# Store the few-shot count as a string.
df_viz_tmp = df_viz_tmp.astype({'n_few_shot_examples': str})

# Keep only the text after the last "speaking_" in each scheme name.
df_viz_tmp['encoding_scheme'] = df_viz_tmp['encoding_scheme'].map(lambda s: s.split("speaking_")[-1])

In [None]:
# Grouped bar chart of encoding-style adherence per scheme, colored by model and
# faceted by how adherence was computed.
df_viz_tmp_plot = df_viz_tmp.copy()

fig = px.bar(df_viz_tmp_plot, x='encoding_scheme', y='generated_cot_adhered_encoding_style',
             height=600, width=1600,
             # height=1600, width=1600,
             # color='model',
             # Asymmetric error bars from the bootstrap half-widths stored by compute_ci_cols.
             error_y='generated_cot_adhered_encoding_style_hi_ci',
             error_y_minus='generated_cot_adhered_encoding_style_low_ci',
             # color='n_few_shot_examples',
             # facet_row='model',
             color='model',
             facet_col='Adherence Calculation Method',
             # color='training_augmentation',
             # color='total_train_tok',
             # color_discrete_map=color_discrete_map,
             title="MATH 500 CoT encoding style adherence",
             barmode="group",
             template="plotly_white",
             color_discrete_sequence=px.colors.qualitative.Set2,
            )

fig.update_xaxes(title="Encoding scheme", tickangle=90)
fig.update_yaxes(title="% adherent encodings", dtick=0.1)

# Unlink the facets' x-axes so each one lists only its own categories.
fig.for_each_xaxis(lambda ax: ax.update(matches=None, categoryorder='trace'))

# Give each successive x-axis title a larger standoff: 0 for the first axis visited,
# 130 for the second, and so on. `ct` is a one-element list used as a mutable counter.
ct = [0]
def ax_standoff_updater(ax, ct):    
    ax.title.update(standoff=ct[0] * 130)
    ct[0] += 1

fig.for_each_xaxis(lambda ax: ax_standoff_updater(ax, ct))
fig.update_yaxes(title_standoff=5)

fig.show()

In [None]:
import plotly.graph_objects as go

def make_encoding_scheme_bar_plot(
    df,
    y_col="generated_cot_is_correct",
    title="MATH 500 Accuracy",
    y_axis_title="Accuracy",
    x_col="encoding_scheme",
    model_col="model",
    d_mapping=None,
    yaxis_dtick=0.1,            # 10% ticks read cleaner
    show_text=False,            # optionally show % text on bars
    model_order=None,           # lock legend/order of models
    font_family="Inter, Arial", # consistent, modern
    sort_by_col=None,
    sort_agg="max",
    sort_desc=True
):
    """
    Grouped bar chart of `y_col` per encoding scheme, one bar per model, with
    schemes grouped into labelled sections.

    Args:
        df: frame with at least `x_col`, `y_col` and `model_col`; it is copied, not mutated.
        y_col: column plotted on the y-axis. If `{y_col}_hi_ci` / `{y_col}_low_ci`
            columns exist they are used as upper / lower error bars.
        title: figure title.
        y_axis_title: y-axis label.
        x_col: column holding the encoding scheme name.
        model_col: column used for bar color and grouping.
        d_mapping: dict of section name -> list of scheme names. Only schemes listed
            here (and present in `df`) are placed on the x-axis. Defaults to the
            built-in map below.
        yaxis_dtick: y tick spacing. If None, 0.1 is used when the data look like
            proportions, otherwise the spacing is left unset.
        show_text: draw the bar value as text outside each bar.
        model_order: explicit model order; defaults to order of first appearance in `df`.
        font_family: font used for the figure.
        sort_by_col: column that drives the ordering of schemes inside each section
            (defaults to `y_col`).
        sort_agg: "max", "mean" or "median" — how the sort column is aggregated across
            models for each scheme; any other value falls back to "max".
        sort_desc: sort schemes within a section in descending order when True.

    Returns:
        The plotly figure.

    Raises:
        ValueError: if the sort column is not in `df`.
    """
    # --- DEFAULT SECTION MAP ---
    if d_mapping is None:
        d_mapping = {
            "baseline": ["zero_shot", "identity"],
            # "letter mutation": [
            #     "reverse_letters_in_each_word",
            #     "swap_even_odd_letters_in_each_word",
            #     "reverse_fibonacci_indices_in_each_word",
            #     "letter_to_word_with_dot",
            #     "dot_between_chars",
            #     "space_between_chars",
            # ],
            # "language deletion": ["remove_all_verbs", "remove_all_nouns"],
            # "language translation": [
            #     "French","Chinese","Korean","Russian","Arabic","Adyghe",
            #     "Morse_code","Python","enterprise_Java",
            # ],
            # "algorithmic cipher": [
            #     "rot13_cipher","base64_cipher","base64_2x_cipher",
            #     "base64_3x_cipher","caesar_cipher","gzip_to_base64_encoded",
            # ],
            "themed reasoning": [
                "paraphrase_naive",
                "pirate_speak",
                "leet_speak",
                "yoda_speak",
                "shakespearean_text",
            ],
            "extraneous content": [
                "insert_tweet",
                "python_snippet_comment",
                "croissant_news_article",
                "math_textbook_article",
                "five_emojis",
            ],
            "delete inf.": [
                "replace_math_content_with_black_box"
            ]
        }

    df_plot = df.copy()

    # --- SORT METRIC ---
    # Decide which column drives sorting
    metric_col = sort_by_col or y_col

    # Pick aggregation (unknown names silently fall back to "max")
    agg_map = {"max": "max", "mean": "mean", "median": "median"}
    agg_fn = agg_map.get(sort_agg, "max")

    # Aggregate per scheme across models on the sorting metric
    if metric_col not in df_plot.columns:
        raise ValueError(f"sort_by_col '{metric_col}' not found in dataframe.")

    sort_metric = (
        df_plot
        .groupby(x_col, as_index=False)[metric_col]
        .agg(agg_fn)
        .rename(columns={metric_col: "sort_value"})
    )

    # --- ORDERING BY SECTION, THEN BY THE SORT METRIC WITHIN EACH SECTION ---
    # section_spans records (first index, last index, section name) into full_category_order.
    full_category_order, section_spans = [], []
    present_set = set(df_plot[x_col].unique())
    cursor = 0

    for section_name, schemes in d_mapping.items():
        present = [s for s in schemes if s in present_set]
        if not present:
            continue

        section_df = sort_metric[sort_metric[x_col].isin(present)].copy()
        section_df = section_df.sort_values("sort_value", ascending=not sort_desc)
        ordered = section_df[x_col].tolist()
        if not ordered:
            continue

        start = cursor
        full_category_order.extend(ordered)
        cursor = len(full_category_order)
        end = cursor - 1
        section_spans.append((start, end, section_name))

    # Fallback when none of the mapped schemes are present: one section with every scheme.
    if not full_category_order:
        full_category_order = list(df_plot[x_col].unique())
        section_spans = [(0, len(full_category_order) - 1, "All")]

    # Schemes that are not in full_category_order become NaN in this categorical column.
    # NOTE(review): presumably those rows then do not appear in the chart — confirm.
    df_plot[x_col] = pd.Categorical(df_plot[x_col], categories=full_category_order, ordered=True)

    # --- PRESERVE MODEL ORDER FROM DATAFRAME ---
    # Use the order models appear in df instead of sorting alphabetically
    if model_order is None:
        model_order = list(dict.fromkeys(df_plot[model_col]))  # <-- preserves original order
    df_plot[model_col] = pd.Categorical(df_plot[model_col], categories=model_order, ordered=True)

    # --- ERROR BARS (auto-detect) ---
    err_hi_col = f"{y_col}_hi_ci"
    err_lo_col = f"{y_col}_low_ci"
    error_y = err_hi_col if err_hi_col in df_plot.columns else None
    error_y_minus = err_lo_col if err_lo_col in df_plot.columns else None

    def wrap_string(s: str, width: int = 15) -> str:
        """
        Break the input string into chunks of at most `width` characters, joined by
        the HTML line break "<br>".

        The split is purely positional, so words may be broken mid-word; no hyphen
        is inserted.

        Args:
            s (str): Input string to wrap.
            width (int): Maximum number of characters per line. Default is 15.

        Returns:
            str: The wrapped string.
        """
        result = ""
        i = 0
    
        while i < len(s):
            # Take a chunk of 'width' characters
            chunk = s[i:i+width]
            # A full-width chunk that is not the last one gets a line break after it
            if len(chunk) == width and i + width < len(s):
                result += chunk + "<br>"
            else:
                result += chunk
            i += width
    
        return result

    # Label prettifier: underscores become spaces, then wrap at 8 characters per line
    def prettify(s: str) -> str:
        s = s.replace("_", " ")
        return wrap_string(s, 8)

    # --- DETECT WHETHER Y-VALUES ARE PROPORTIONS (ALL WITHIN [0, 1]) OR OTHER NUMBERS ---
    y_min, y_max = df_plot[y_col].min(), df_plot[y_col].max()
    is_percent = y_max <= 1.0 and y_min >= 0.0

    # Only reached when the caller passes yaxis_dtick=None explicitly (the default is 0.1)
    if yaxis_dtick is None:
        yaxis_dtick = 0.1 if is_percent else None
    
    # --- BASE FIGURE ---
    fig = px.bar(
        df_plot,
        x=x_col,
        y=y_col,
        color=model_col,
        category_orders={x_col: full_category_order, model_col: model_order},
        barmode="group",
        template="plotly_white",
        color_discrete_sequence=px.colors.qualitative.Safe,
        title=title,
        height=600,
        width=1800,
        error_y=error_y,
        error_y_minus=error_y_minus,
        text=(df_plot[y_col] if show_text else None),
    )

    # Y axis: percent tick labels when the data are proportions
    if is_percent:
        fig.update_yaxes(
            title=y_axis_title,
            tickformat=".0%",
            tickmode="linear",
            dtick=yaxis_dtick,
            rangemode="tozero"
        )
    else:
        fig.update_yaxes(
            title=y_axis_title,
            tickmode="linear",
            rangemode="tozero",
            dtick=yaxis_dtick
        )

    # X axis: rely on categoryarray for order; set readable tick labels
    fig.update_xaxes(
        title="Encoding scheme",
        ticktext=[f"<b>{prettify(lbl)}</b>" for lbl in full_category_order],
        tickvals=full_category_order,       # use category values, not numeric indices
        tickangle=0
    )

    # Bar spacing and text formatting
    fig.update_traces(
        texttemplate = "%{y:.0%}" if (is_percent and show_text) else "%{y:}" if show_text else None,
        textposition="outside" if show_text else "none",
        # NOTE(review): this gives every trace the model column of the whole frame, not
        # just the rows of that trace — confirm nothing in hover relies on it.
        customdata=df_plot[[model_col]].values,
    )
    fig.update_layout(
        bargap=0.15,
        bargroupgap=0.05,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.08,
            xanchor="left",
            x=0.01,
            font=dict(size=10)  # ↓ Smaller legend font
        ),
        margin=dict(t=110, r=30, b=80, l=70),
        font=dict(family=font_family, size=14),
        title=dict(font=dict(size=22))
    )
    
    # Make x-axis tick labels smaller (repeats the tick settings above and adds tickfont)
    fig.update_xaxes(
        ticktext=[f"<b>{prettify(lbl)}</b>" for lbl in full_category_order],
        tickvals=full_category_order,
        tickangle=0,
        tickfont=dict(size=10)  # ↓ Smaller x-axis font size
    )

    # --- SECTION BACKGROUNDS + LABELS ---
    N = len(full_category_order)

    # Helper: left edge of category `idx` as a fraction of the x-axis domain.
    # NOTE(review): assumes the N categories share the axis domain equally — confirm.
    def to_xdomain(idx):
        # position of left edge of category idx:
        return idx / N

    # alternating light bands for sections, each with its name centered above the band
    shapes = []
    annotations = []
    for i, (start_idx, end_idx, name) in enumerate(section_spans):
        x0 = to_xdomain(start_idx)
        x1 = to_xdomain(end_idx + 1)
        band = dict(
            type="rect",
            xref="x domain", yref="paper",
            x0=x0, x1=x1, y0=0, y1=1,
            layer="below",
            line=dict(width=0),
            fillcolor="rgba(0,0,0,0.03)" if i % 2 == 0 else "rgba(0,0,0,0.00)"
        )
        shapes.append(band)
        annotations.append(dict(
            x=(x0 + x1) / 2, xref="x domain",
            y=1.06, yref="paper",
            text=name, showarrow=False,
            font=dict(size=12, color="black"),
            xanchor="center"
        ))
    fig.update_layout(shapes=shapes)
    for a in annotations:
        fig.add_annotation(**a)

    return fig

In [None]:
fig = make_encoding_scheme_bar_plot(
    df=df_viz_tmp,
    # df=df_viz_tmp[df_viz_tmp['encoding_scheme'] == 'letter_to_word_with_dot'],
    y_col="generated_cot_is_correct",
    title="MATH 500 Accuracy",
    y_axis_title="Accuracy",
    yaxis_dtick=0.05,
    sort_by_col="adherent_and_correct"
)
fig.show('png')

In [None]:
fig = make_encoding_scheme_bar_plot(
    df=df_viz_tmp,
    y_col="adherent_and_correct",
    title="Proportion of encoding adherent & correct responses on MATH 500",
    y_axis_title="% of responses",
    yaxis_dtick=0.05,
    sort_by_col="adherent_and_correct"

)
fig.show('png')

In [None]:
df_pgr[df_pgr['model'].str.contains('7B') & ~df_pgr['model'].str.contains('Instruct')]

In [None]:
# --- Compute PGR (percent of identity performance) ---
df_pgr = df_viz_tmp.copy()

# Identity baseline per model
identity_baseline = (
    df_pgr[df_pgr["encoding_scheme"] == "identity"]
    .set_index("model")["adherent_and_correct"]
)

df_pgr["identity_value"] = np.maximum(df_pgr["model"].map(identity_baseline), 0.0001)

# Avoid divide-by-zero
# df_pgr = df_pgr[df_pgr["identity_value"] > 0].copy()

# Mean PGR as percentage
df_pgr["PGR_pct"] = (df_pgr["adherent_and_correct"] / df_pgr["identity_value"])

# CI deltas -> percentage deltas relative to identity baseline
# low is negative, high is positive
df_pgr["PGR_pct_hi_ci"]  = (df_pgr["adherent_and_correct_hi_ci"] / df_pgr["identity_value"])
df_pgr["PGR_pct_low_ci"] = (df_pgr["adherent_and_correct_low_ci"]        / df_pgr["identity_value"])

df_pgr.loc[df_pgr["encoding_scheme"] == "identity", ["PGR_pct_hi_ci", "PGR_pct_low_ci"]] = 0.0

fig = make_encoding_scheme_bar_plot(
    df=df_pgr,
    y_col="PGR_pct",
    title="Relative % of responses adherent & correct vs. identity encoding",
    y_axis_title="% of identity performance",
    yaxis_dtick=0.1,
    sort_by_col="adherent_and_correct"
)
fig.show('png')


In [None]:
fig = make_encoding_scheme_bar_plot(
    df=df_viz_tmp,
    y_col="cot_gt_logprobs",
     title="Log loss on MATH 500 ground-truth encoded CoT from PRM800K",
    y_axis_title="Mean log loss per sequence",
    yaxis_dtick=250,
    sort_by_col="adherent_and_correct"
)
fig.show()

In [None]:
df_viz_tmp_plot = df_viz_tmp.copy()

df_viz_tmp_plot = df_viz_tmp_plot[df_viz_tmp_plot['encoding_scheme'] != 'zero_shot']

fig = make_encoding_scheme_bar_plot(
    df=df_viz_tmp_plot,
    y_col="backtranslation_bleu_scores",
     title="Encoded -> English translation BLEU",
    y_axis_title="Encoded -> English translation BLEU",
    yaxis_dtick=10,
    sort_by_col="adherent_and_correct"

)
fig.show('png')

In [None]:
df_viz_tmp_plot = df_viz_tmp.copy()

df_viz_tmp_plot = df_viz_tmp_plot[df_viz_tmp_plot['encoding_scheme'] != 'zero_shot']

fig = make_encoding_scheme_bar_plot(
    df=df_viz_tmp_plot,
    y_col="total_translation_loss",
     title="Encoded -> English translation log loss",
    y_axis_title="Mean log loss per sequence",
    yaxis_dtick=50,
    sort_by_col="adherent_and_correct"

)
fig.show('png')

# Plot the translation ability -> acc curve

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from scipy.optimize import curve_fit

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.optimize import curve_fit

def styled_logistic_scatter_faceted(
    df,
    facet_col="model",
    x_col="backtranslation_gt_logprobs",
    y_col="adherent_and_correct",
    title="Adherence vs. Backtranslation Log-Prob",
    x_axis_title="Backtranslation log-prob",
    y_axis_title="Adherent & correct",
    font_family="Inter, Arial",
    marker_size=7,
    opacity=0.6,
    legend_font_size=10,
    x_tick_font_size=10,
    line_width=3,
    r2_digits=3,               # <— precision for R² display
    keep_subplot_title=True,
    max_pow_override=None
):
    """
    Scatter `y_col` against `x_col` in one subplot per value of `facet_col`, overlaying
    a fitted logistic curve and its R² in each subplot.

    The x-axis is logarithmic and reversed, with ticks at powers of two.

    Args:
        df: frame with `x_col`, `y_col` and `facet_col`. When `facet_col` is not
            categorical, 'model' and 'model_size' columns are also read to order facets.
        facet_col: column whose distinct values define the subplots.
        x_col, y_col: columns plotted on the x- and y-axis.
        title, x_axis_title, y_axis_title: figure and axis labels.
        font_family: font used for the figure.
        marker_size, opacity: scatter marker styling.
        legend_font_size, x_tick_font_size: font sizes.
        line_width: accepted but not used by this function.
        r2_digits: decimal places shown for R².
        keep_subplot_title: show a "<facet_col> = <value>" title above each subplot.
        max_pow_override: if set and the largest tick value exceeds it, the largest
            tick is replaced by this value.

    Returns:
        The plotly figure.

    Raises:
        ValueError: if `x_col` has no strictly positive values.

    Fit failures are printed and leave the subplot's R² shown as "—".
    """
    def logistic(x, L, k, x0):
        return L / (1 + np.exp(-k * (x - x0)))

    # overall y-scale to decide percent formatting
    y_all = df[y_col].to_numpy(dtype=float)
    is_percent = (np.nanmin(y_all) >= 0) and (np.nanmax(y_all) <= 1.0)

    # facet ordering: category order for categoricals, otherwise by model size
    # (isinstance check replaces the deprecated pd.api.types.is_categorical_dtype)
    if isinstance(df[facet_col].dtype, pd.CategoricalDtype):
        facet_values = [c for c in df[facet_col].cat.categories if (df[df[facet_col]==c].shape[0] > 0)]
    else:
        d_model_to_size = dict(zip(df['model'], df['model_size']))
        facet_values = sorted(df[facet_col].dropna().unique(), key=lambda x: d_model_to_size.get(x, x))

    print(facet_values)

    n_cols = max(1, len(facet_values))
    fig = make_subplots(
        rows=1, cols=n_cols, shared_yaxes=True, horizontal_spacing=0.06,
        subplot_titles=[f"{facet_col} = {v}" if keep_subplot_title else None for v in facet_values],
    )

    x_positive = df[df[x_col] > 0][x_col]
    if x_positive.empty:
        # Without this guard the tick computation fails with an opaque NaN-to-int error.
        raise ValueError(f"Column '{x_col}' has no positive values; cannot build a log-scale x-axis.")
    min_x, max_x = x_positive.min(), x_positive.max()
    # Create tick values at powers of two covering the positive x range
    min_pow = int(np.floor(np.log2(min_x)))
    max_pow = int(np.ceil(np.log2(max_x)))
    tick_vals = [2 ** p for p in range(min_pow, max_pow + 1)]
    if max_pow_override:
        if tick_vals[-1] > max_pow_override:
            tick_vals[-1] = max_pow_override
    tick_text = [f"{v:g}" for v in tick_vals]

    for i, val in enumerate(facet_values, start=1):
        sub = df[df[facet_col] == val]
        x = sub[x_col].to_numpy(dtype=float)
        y = sub[y_col].to_numpy(dtype=float)

        # points
        fig.add_trace(
            go.Scatter(
                x=x, y=y, mode="markers+text", name=f"Points ({val})",
                marker=dict(size=marker_size, line=dict(width=0)),
                opacity=opacity, showlegend=False,
                # text=sub["encoding_scheme"] + "-" + sub["model"],
                hovertemplate="<b>%{text}</b><extra></extra>",  # <-- Just encoding scheme
            ),
            row=1, col=i
        )

        # choose regressor space
        # linear_on_log_x = True
        # if linear_on_log_x:
        #     valid = (x > 0) & np.isfinite(x) & np.isfinite(y)
        #     X = np.log2(x[valid])
        #     Y = y[valid]
        # else:
        #     valid = np.isfinite(x) & np.isfinite(y)
        #     X = x[valid]
        #     Y = y[valid]

        # if X.size >= 2 and np.nanstd(Y) > 0:
        #     b1, b0 = np.polyfit(X, Y, 1)   # Y = b1*X + b0
        #     # make a smooth line across current x-range
        #     x_line = np.linspace(np.nanmin(x[valid]), np.nanmax(x[valid]), 300)
        #     X_line = np.log2(x_line) if linear_on_log_x else x_line
        #     y_line = b1 * X_line + b0

        #     # R^2 on observed points (in same space used to fit)
        #     y_hat = b1 * X + b0
        #     ss_res = np.nansum((Y - y_hat) ** 2)
        #     ss_tot = np.nansum((Y - np.nanmean(Y)) ** 2)
        #     if ss_tot > 0 and np.isfinite(ss_res):
        #         r2 = 1 - ss_res / ss_tot
        #         r2_text = f"{r2:.{r2_digits}f}"

        #     fig.add_trace(
        #         go.Scatter(
        #             x=x_line, y=y_line, mode="lines",
        #             name=f"Linear fit ({val})",
        #             # line=linear_line_kwargs,
        #             showlegend=(i == 1),
        #         ),
        #         row=1, col=i
        #     )

        # logistic fit + R^2
        r2_text = "—"
        try:
            # curve_fit rejects NaN/inf input, so fit only the points where both
            # coordinates are finite; rows with a missing value no longer prevent the fit.
            finite = np.isfinite(x) & np.isfinite(y)
            x_ok, y_ok = x[finite], y[finite]

            if x_ok.size >= 3 and np.std(y_ok) > 0:
                p0 = [np.max(y_ok), 1.0, np.median(x_ok)]
                params, _ = curve_fit(logistic, x_ok, y_ok, p0=p0, maxfev=10000)
                L, k, x0 = params

                # The curve is drawn out to at least x = 98.
                # NOTE(review): 98 is a hard-coded extent — confirm it suits the data range.
                x_fit = np.linspace(np.min(x_ok), max(np.max(x_ok), 98), 300)
                y_fit = logistic(x_fit, L, k, x0)

                # R^2 on the observed (finite) points
                y_hat = logistic(x_ok, L, k, x0)
                ss_res = np.sum((y_ok - y_hat) ** 2)
                ss_tot = np.sum((y_ok - np.mean(y_ok)) ** 2)
                if ss_tot > 0 and np.isfinite(ss_res):
                    r2 = 1 - ss_res / ss_tot
                    r2_text = f"{r2:.{r2_digits}f}"

                fig.add_trace(
                    go.Scatter(
                        x=x_fit, y=y_fit, mode="lines",
                        name=f"Logistic fit ({val})",
                        line=dict(width=1, color="black", dash="dash"),
                        showlegend=(i == 1),
                    ),
                    row=1, col=i
                )
        except Exception as e:
            # Best effort: a failed fit is reported and the subplot keeps R² = "—".
            print(e)

        # R^2 annotation, positioned at x=0.1, y=0.5 relative to this subplot's axes
        fig.add_annotation(
            x=0.1, y=0.5,
            xref=f"x{i}" if i > 1 else "x",   # the first subplot's axes are named "x"/"y"
            yref=f"y{i}" if i > 1 else "y",
            xanchor="left",
            yanchor="top",
            text=f"R² = {r2_text}",
            showarrow=False,
            font=dict(size=12),
            align="left",
            bgcolor="rgba(255,255,255,0.6)",
            bordercolor="rgba(0,0,0,0.2)",
            borderwidth=1,
        )

        # axes cosmetics: reversed log x-axis with the shared power-of-two ticks
        fig.update_xaxes(
            title=x_axis_title if i == 1 else "",
            tickfont=dict(size=x_tick_font_size),
            zeroline=False,
            type="log",
            autorange="reversed",
            tickvals=tick_vals,
            ticktext=tick_text,
            row=1, col=i,
            dtick=0.1
        )
        fig.update_yaxes(
            tickformat=".0%" if is_percent else ",",
            rangemode="tozero",
            zeroline=False,
            dtick=0.05,
            row=1, col=i,
            title=y_axis_title
        )

    fig.update_layout(
        template="plotly_white",
        title=dict(text=title, font=dict(size=18)),
        height=600,
        width=350 * n_cols if n_cols <= 3 else 300 * n_cols,
        font=dict(family=font_family, size=14),
        margin=dict(t=0, r=0, b=0, l=0),
        legend=dict(orientation="h", yanchor="bottom", y=1.08, xanchor="left", x=0,
                    font=dict(size=legend_font_size)),
        showlegend=False
    )
    # global axis labels
    # fig.add_annotation(text=x_axis_title, showarrow=False, xref="paper", yref="paper",
    #                    x=0.5, y=-0.18, font=dict(size=14))
    # fig.add_annotation(text=y_axis_title, showarrow=False, xref="paper", yref="paper",
    #                    x=-0.06, y=0.5, textangle=-90, font=dict(size=14))

    return fig


# ---- Use it on your dataframe ----
df_viz_tmp_plot = df_viz_tmp.copy()

# Drop the two baseline schemes; the assert checks that at least one row was removed.
start_len = len(df_viz_tmp_plot)
df_viz_tmp_plot = df_viz_tmp_plot[~df_viz_tmp_plot['encoding_scheme'].isin(['identity', 'zero_shot'])]
assert len(df_viz_tmp_plot) != start_len

# Constant column used as the facet so that every model lands in a single subplot.
df_viz_tmp_plot["placeholder_col"] = "1"

# Zero-shot adherent-and-correct rate per model; requires df_pgr from the PGR cell.
d_zero_shot_baseline =  (
    df_pgr[df_pgr["encoding_scheme"] == "zero_shot"]
    .set_index("model")["adherent_and_correct"]
)

fig = styled_logistic_scatter_faceted(
    # df_viz_tmp_plot[df_viz_tmp_plot['backtranslation_bleu_scores'].notna() & ~df_viz_tmp_plot['model'].str.contains('gpt')],
    #[df_viz_tmp_plot['adherent_and_correct'] > 0.01],
    # Keep rows at or above their model's zero-shot baseline, excluding models with 'gpt' in the name.
    df_viz_tmp_plot[(df_viz_tmp_plot['adherent_and_correct'] >= df_viz_tmp_plot['model'].map(d_zero_shot_baseline)) & (~df_viz_tmp_plot['model'].str.contains('gpt'))],
    x_col="backtranslation_gt_logprobs",
    # x_col="backtranslation_bleu_scores",
    y_col="adherent_and_correct",
    # y_col='cot_gt_logprobs',
    # title="% of responses adherent & correct vs. encoded -> English translation BLEU",
    title=None,
    # x_axis_title="Encoded text to English text BLEU score",
    x_axis_title="Encoded text to English text log loss",
    y_axis_title="% of MATH500 responses adherent & correct",
    # facet_col="model",
    facet_col="placeholder_col",
    keep_subplot_title=False,
    max_pow_override=100,
    x_tick_font_size=12
)


fig.update_layout(width=600, height=450, title=None)
# fig.update_layout(width=1600, height=1200, title=None)
# --- shaded region instead of vline ---
# x_cut = 60
# x_right = float(df_viz_tmp_plot["backtranslation_bleu_scores"].max()) + 30
# Hard-coded bounds of the shaded band, in x-axis data units.
x_cut = 0.2
x_right = 0.02


fig.add_vrect(
    x0=x_cut, x1=x_right,           # shade between the two hard-coded x values above
    fillcolor="#D3D3D3", opacity=0.25,
    line_width=0,
    layer="below",
    row=1, col=1                    # adjust if you facet into multiple columns
)

# label for the shaded region
fig.add_annotation(
    x=0.84,
    # x=0.85,
    xref="paper",           # paper (figure-relative) coordinates, not data coordinates
    y=0.99, yref="paper",
    text="fluent translation",
    showarrow=False,
    xanchor="left",
    yanchor="top",
    bgcolor="rgba(255,255,255,0.0)",
    bordercolor="rgba(0,0,0,0.0)",
    borderwidth=0,
    font=dict(size=9)
)
fig.update_traces(textposition="bottom left")


fig.show('png')
# fig.show()


In [None]:
l_test_hashes = ['f01404b75a072276ac1ca6f131313a8a615dbb99', '481e3a2a1f521438c8938721b0df08a389e3406f', 'b411a66e1c2f9ca7a79c065099879f202990c9a2', '33c2595336eb7c74c6820302498fd097fa52463c', 'e3f76d94751a09af21005696ac072219143ebc58', 'b77d4d2334cc5dad12f2168d5fa48749622f34af', '2007b98cc05baeb988564f2e51d6f496a8afe7b8', 'f0a6d8673396c2f4415e10c2fc13effa262fe50d', '16b151e08633744e1c501d6f41707177c5aa6125', 'a846c2161731fabfe9b9d371ac05473c5ba9289d', 'a9a752f3c050a16ef74c6ad9436f0aa6291fb1ff', 'b4c85c27bcc89f916f50867f974bd13875416067']

# Print the summed num_tokens of sft_train.parquet for each listed experiment that has one.
# The loop variable is named test_hash so the built-in hash() is not shadowed in the
# notebook's global namespace.
for test_hash in l_test_hashes:
    target_path = f"/home/ubuntu/sky_workdir/encoding-schemes/output/{test_hash}/data/sft_train.parquet"
    if os.path.exists(target_path):
        print(duckdb.query(f"SELECT SUM(num_tokens) FROM read_parquet('{target_path}')").to_df())

# Inspect the data

In [None]:
from encoding_schemes import get_deterministic_adherence_fn
from encoding_schemes.letter_permutations import get_English_dictionary, reverse_letters_in_each_word, normalize_word

In [None]:
# Selection used by the inspection cells that follow: one encoding scheme,
# one model, and one example index.
test_scheme = 'speaking_letter_to_word_with_dot'
test_model = 'Qwen/Qwen2.5-14B-Instruct'
test_idx = 21

# Presumably the deterministic adherence checker for the selected scheme.
# NOTE(review): the meaning of the second argument (passed as None) is not
# visible here -- check get_deterministic_adherence_fn in encoding_schemes.
adherence_fn = get_deterministic_adherence_fn(test_scheme, None)

In [None]:
# Entry `test_idx` of the 'reference_problem_df' value from the first df_viz row
# matching the selected model and encoding scheme.
df_viz.loc[
    (df_viz['model'] == test_model) & (df_viz['encoding_scheme'] == test_scheme),
    'reference_problem_df',
].iloc[0][test_idx]

In [None]:
# Entry `test_idx` of the 'reference_solution_df' value from the first df_viz row
# matching the selected model and encoding scheme.
df_viz.loc[
    (df_viz['model'] == test_model) & (df_viz['encoding_scheme'] == test_scheme),
    'reference_solution_df',
].iloc[0][test_idx]

In [None]:
# First element of entry `test_idx` in the 'generated_cots_df' value from the
# first df_viz row matching the selected model and encoding scheme.
df_viz.loc[
    (df_viz['model'] == test_model) & (df_viz['encoding_scheme'] == test_scheme),
    'generated_cots_df',
].iloc[0][test_idx][0]

In [None]:
# Same generated CoT as the previous cell, passed through
# reverse_letters_in_each_word before display.
reverse_letters_in_each_word(
    df_viz.loc[
        (df_viz['model'] == test_model) & (df_viz['encoding_scheme'] == test_scheme),
        'generated_cots_df',
    ].iloc[0][test_idx][0]
)

In [None]:
# Full 'generated_cot_adhered_encoding_style' value from the first df_viz row
# matching the selected model and encoding scheme (append [test_idx] to look at
# a single example instead).
df_viz.loc[
    (df_viz['model'] == test_model) & (df_viz['encoding_scheme'] == test_scheme),
    'generated_cot_adhered_encoding_style',
].iloc[0]

In [None]:
# 'experiment_hash' of the first df_viz row matching the selected model and
# encoding scheme; it is used to locate that experiment's output directory.
example_hash = df_viz.loc[
    (df_viz['model'] == test_model) & (df_viz['encoding_scheme'] == test_scheme),
    'experiment_hash',
].iloc[0]
example_hash

In [None]:
# Load sft_model_meta.json from the output directory of `example_hash`.
# The encoding is given explicitly so decoding does not depend on the machine's
# locale default (JSON text is UTF-8 by specification).
with open(f"/home/ubuntu/sky_workdir/encoding-schemes/output/{example_hash}/data/sft_model_meta.json", "r", encoding="utf-8") as fp:
    d_example = json.load(fp)

d_example

In [None]:
# SFT training parquet in the output directory of `example_hash`.
sft_test_path = f"/home/ubuntu/sky_workdir/encoding-schemes/output/{example_hash}/data/sft_train.parquet"

df_sft = pd.read_parquet(sft_test_path)
# Display the first row's 'messages' value to eyeball one training example.
df_sft['messages'].iloc[0]

In [None]:
# Number of rows currently in df_sft.
len(df_sft)

In [None]:
# Keep only the first 3000 rows. This rebinds df_sft, so every later cell that
# reads df_sft sees the truncated frame.
# NOTE(review): the reason for the 3000-row cap is not recorded here -- confirm.
df_sft = df_sft.iloc[:3000]

In [None]:
# Sum of the 'num_tokens' column over the rows currently in df_sft.
df_sft['num_tokens'].sum()

In [None]:
def convert_sft_parquet_to_jsonl(df, output_json_path):
    """Dump the ``messages`` column of ``df`` to a JSON Lines file.

    Every entry of ``df["messages"]`` is converted to a list and written as
    one ``{"messages": [...]}`` object per line. The file is opened in write
    mode (replacing any existing content) with UTF-8 encoding, and non-ASCII
    characters are written as-is rather than escaped.

    Args:
        df: Object whose ``"messages"`` item is iterable; each element must
            itself be iterable and JSON-serializable once turned into a list.
        output_json_path: Destination path of the JSON Lines file.

    Returns:
        None. The number of rows written is reported via ``print``.
    """
    row_count = 0
    with open(output_json_path, "w", encoding="utf-8") as sink:
        # enumerate(start=1) leaves row_count equal to the number of rows
        # written; it stays 0 when the column is empty.
        for row_count, conversation in enumerate(df["messages"], start=1):
            record = {"messages": list(conversation)}
            sink.write(json.dumps(record, ensure_ascii=False) + "\n")

    print(f"[prep] Wrote {row_count} training rows to {output_json_path}")

In [None]:
# Write the (possibly truncated) df_sft next to its source parquet as .jsonl.
# Only the file extension is swapped; str.replace("parquet", "jsonl") would also
# rewrite any directory component that happens to contain "parquet".
convert_sft_parquet_to_jsonl(df_sft, os.path.splitext(sft_test_path)[0] + ".jsonl")

In [None]:
# NOTE(review): this rebinds df_sft from a hard-coded experiment hash rather
# than `example_hash`, so what df_sft holds depends on which cell ran last.
df_sft = pd.read_parquet("/home/ubuntu/sky_workdir/encoding-schemes/output/e4a87a8626efeeb6df76d94bb34e7e5e34c77154/data/sft_train.parquet")

# Keep only the first 9120 rows.
# NOTE(review): the reason for the 9120-row cap is not recorded here -- confirm.
df_sft = df_sft.iloc[:9120]

df_sft.head()

In [None]:
# Sum of the 'num_tokens' column over the rows currently in df_sft.
df_sft['num_tokens'].sum()