# Figure 1: the shape of beliefs

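### Plots panels B (PCA of layer activations), C (inPCA of next-token probabilities) and D (average predicted distributions), computed from saved activations and logits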

In [7]:
# imports
from pathlib import Path
import torch
import numpy as np
from sklearn.decomposition import PCA
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# import for inPCA embeddings
import sys
REPO_ROOT = Path.cwd().parent
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
from utils.inpca import inpca_embedding

In [None]:
# Config: paths, datasets, model site, analysis settings, figure
BASE_DIR = Path.cwd().parent
ACTIVS_DIR = BASE_DIR / "data" / "activations"
LOGITS_DIR = BASE_DIR / "data" / "logits"

# Dataset names encode the generating Gaussian as m<mean>_s<std>; the l<...>
# and n<...> fields are presumably sequence length and number of sequences
# (TODO confirm against the data-generation script).
DATASETS = [
    # mean sweep below 500 (std = 100)
    "gaussian_m300_s100_l1000_n10",
    "gaussian_m350_s100_l1000_n10",
    "gaussian_m400_s100_l1000_n10",
    "gaussian_m450_s100_l1000_n10",
    # std sweep at mean = 500
    "gaussian_m500_s010_l1000_n10",
    "gaussian_m500_s020_l1000_n10",
    "gaussian_m500_s030_l1000_n10",
    "gaussian_m500_s050_l1000_n10",
    "gaussian_m500_s080_l1000_n10",
    "gaussian_m500_s120_l1000_n10",
    "gaussian_m500_s150_l1000_n10",
    "gaussian_m500_s200_l1000_n10",
    # point shared by both sweeps (mean 500, std 100)
    "gaussian_m500_s100_l1000_n10",
    # mean sweep above 500 (std = 100)
    "gaussian_m550_s100_l1000_n10",
    "gaussian_m600_s100_l1000_n10",
    "gaussian_m650_s100_l1000_n10",
    "gaussian_m700_s100_l1000_n10",
]

LAYER = 14  # layer 14 is the best layer for the curve
SITE = f"model.layers.{LAYER}"  # activation files are named after this, with "." -> "_"
# Number of comma->number positions dropped from the start of EACH SEQUENCE
# (the loaders apply it per sequence, not once per dataset).
DROP_FIRST = 500
TEMP = 1.0  # softmax temperature applied to the logits

N_COMPONENTS = 16  # PCA components fitted (only the first 3 are plotted)

PROB_SUBSET_SIZE = 8000  # random subset of probability vectors, to keep inPCA manageable

FIG_WIDTH_CM = 7.7
FIG_HEIGHT_CM = 10
SAVE_FIGURE = False
FIGURE_NAME = "figure01.html"

In [10]:
# load activations

def load_com2num_for_dataset(dataset_name: str) -> torch.Tensor:
    """Gather activations at comma->number positions for one dataset.

    Every ``<SITE with "." replaced by "_">_batch*.pt`` file under
    ``ACTIVS_DIR / dataset_name`` is loaded in sorted filename order. For each
    sequence in a batch, positions 2, 4, 6, ... below that sequence's length
    are selected (positions whose *next* token is a number). The first
    ``DROP_FIRST`` of them are then removed when ``DROP_FIRST > 0``, and all
    sequences are concatenated along the first axis.

    Raises:
        FileNotFoundError: when no matching activation file exists.
        ValueError: when the files yield no sequences at all.
    """
    dataset_dir = ACTIVS_DIR / dataset_name
    glob_pattern = SITE.replace(".", "_") + "_batch*.pt"
    batch_files = sorted(dataset_dir.glob(glob_pattern))
    if len(batch_files) == 0:
        raise FileNotFoundError(
            f"No activation files for {dataset_name}: {dataset_dir / glob_pattern}"
        )

    chunks = []
    for batch_file in batch_files:
        payload = torch.load(batch_file, map_location="cpu")
        batch_acts = payload["activations"]  # indexed as [batch, position, ...]
        seq_lengths = payload.get("lengths")  # optional true length per row
        for row in range(batch_acts.shape[0]):
            seq_len = (
                batch_acts.shape[1]
                if seq_lengths is None
                else int(seq_lengths[row].item())
            )
            picked = batch_acts[row, range(2, seq_len, 2)]
            chunks.append(picked[DROP_FIRST:] if DROP_FIRST > 0 else picked)

    if not chunks:
        raise ValueError(f"No activations found for {dataset_name}")
    return torch.cat(chunks, dim=0)

# Collect activations for all datasets, plus one short tag per row taken from
# the second name field (e.g. "m300"). Every m=500 dataset (the whole std
# sweep) therefore shares the tag "m500".
all_act_list = [load_com2num_for_dataset(ds) for ds in DATASETS]
dataset_labels = [
    ds.split("_")[1]
    for ds, acts in zip(DATASETS, all_act_list)
    for _ in range(acts.shape[0])
]

all_acts = torch.cat(all_act_list, dim=0).cpu().numpy()  # [N, d_model]

In [11]:
# load softmax

def load_com2num_probs(dataset: str, expected_labels=None, *, logits_dir=None,
                       temp=None, drop_first=None):
    """Load probabilities at comma->number positions for a dataset.

    Reads every ``logits_batch*.pt`` file under ``<logits_dir>/<dataset>`` in
    sorted filename order. Each payload must provide ``"logits"`` (indexed as
    ``[batch, seq_len, subset_size]``) and token labels under
    ``"token_strings"`` or ``"token_labels"``. It may optionally provide
    per-sequence ``"lengths"``.

    For each sequence, positions 2, 4, 6, ... below its length are kept
    (positions whose *next* token is a number). The first ``drop_first`` of
    them are discarded. A softmax over the last axis is then taken at
    temperature ``temp``.

    Args:
        dataset: name of the dataset sub-directory.
        expected_labels: if given, the token labels found must equal this.
        logits_dir: root directory; defaults to ``LOGITS_DIR``.
        temp: softmax temperature; defaults to ``TEMP``.
        drop_first: positions dropped per sequence; defaults to
            ``DROP_FIRST``. Values <= 0 keep every position.

    Returns:
        ``(probs, token_labels)``, where ``probs`` has shape
        ``[total_positions, subset_size]``, concatenated over all sequences.

    Raises:
        FileNotFoundError: no logits files exist for ``dataset``.
        ValueError: labels are missing, differ between files, differ from
            ``expected_labels``, or no sequences were found.
    """
    root = LOGITS_DIR if logits_dir is None else Path(logits_dir)
    temp = TEMP if temp is None else temp
    drop_first = DROP_FIRST if drop_first is None else drop_first

    ds_dir = root / dataset
    files = sorted(ds_dir.glob("logits_batch*.pt"))
    if not files:
        raise FileNotFoundError(f"No logits files for {dataset}: {ds_dir / 'logits_batch*.pt'}")

    seq_vectors = []
    token_labels = None
    for fpath in files:
        payload = torch.load(fpath, map_location="cpu")
        logits = payload["logits"]  # [batch, seq_len, subset_size]
        labels = payload.get("token_strings") or payload.get("token_labels")
        if labels is None:
            raise ValueError(f"Missing token labels in {fpath}")
        if token_labels is None:
            token_labels = labels
        elif labels != token_labels:
            raise ValueError(f"Token label order mismatch in {fpath}")

        lengths = payload.get("lengths")
        for i in range(logits.shape[0]):
            length = int(lengths[i].item()) if lengths is not None else logits.shape[1]
            # Select and trim positions BEFORE the softmax. The softmax acts on
            # each position independently (last axis), so the result is
            # unchanged but no work is spent on positions that are discarded.
            com2num_logits = logits[i, range(2, length, 2)]  # positions whose *next* token is a number
            # Same guard as the activation loader: non-positive values keep everything.
            if drop_first > 0:
                com2num_logits = com2num_logits[drop_first:]
            # Upcast half-precision logits. The softmax then runs in float32,
            # and the probabilities stay convertible with .numpy()
            # (bfloat16 is not).
            if com2num_logits.dtype in (torch.float16, torch.bfloat16):
                com2num_logits = com2num_logits.float()
            seq_vectors.append(torch.softmax(com2num_logits / temp, dim=-1))

    if not seq_vectors:
        raise ValueError(f"No probabilities found for {dataset}")
    all_probs = torch.cat(seq_vectors, dim=0)          # [total_positions, subset_size]
    if expected_labels is not None and token_labels != expected_labels:
        raise ValueError(f"Token label order mismatch in {dataset}")
    return all_probs, token_labels


# Load probabilities for every dataset. Each call after the first must
# reproduce the token-label order returned by the previous call.
dataset_labels, all_prob_list, label_order = [], [], None
for ds in DATASETS:
    probs, label_order = load_com2num_probs(ds, expected_labels=label_order)
    all_prob_list.append(probs)
    dataset_labels += [ds.split("_")[1]] * probs.shape[0]  # e.g., "m300"

all_probs = torch.cat(all_prob_list, dim=0).cpu().numpy()  # [N, subset_size]

# Sanity output: dataset count, first dataset's shape, stacked shape.
for summary in (len(all_prob_list), all_prob_list[0].shape, all_probs.shape):
    print(summary)

17
torch.Size([5000, 1006])
(85000, 1006)


In [12]:

def cm_to_px(cm, dpi=96):
    """Convert a length in centimetres to pixels at ``dpi`` dots per inch.

    The default of 96 dpi is the CSS reference pixel density
    (1 in = 2.54 cm = 96 px), which gives ~37.795 px per cm.
    """
    return cm / 2.54 * dpi

def per_point_mass_std(datasets, counts):
    """Expand per-dataset (mean, std) tags into per-point arrays.

    Dataset names look like ``gaussian_m300_s100_...``. The second
    underscore-separated field is ``m<mean>`` and the third is ``s<std>``.
    Each dataset's two values are repeated ``n`` times, where ``n`` is the
    matching entry of ``counts``. As with ``zip``, extra entries in the
    longer input are ignored.

    Returns:
        ``(masses, stds)`` as NumPy arrays with one entry per point.
    """
    mass_per_point, std_per_point = [], []
    for name, n_points in zip(datasets, counts):
        fields = name.split("_")
        mean_value = int(fields[1][1:])  # "m300" -> 300
        std_value = int(fields[2][1:])   # "s100" -> 100
        mass_per_point += [mean_value] * n_points
        std_per_point += [std_value] * n_points
    return np.array(mass_per_point), np.array(std_per_point)

# --- PCA on activations ---
act_counts = [acts.shape[0] for acts in all_act_list]
act_mass, act_std = per_point_mass_std(DATASETS, act_counts)
# Fixed random_state: for large inputs sklearn's "auto" solver may choose the
# randomized SVD. Its output (and hence the figure) would otherwise change from
# run to run. Seeding matches the seeded inPCA subset below.
pca = PCA(n_components=max(3, N_COMPONENTS), random_state=0)
act_coords = pca.fit_transform(all_acts)[:, :3]  # only the first 3 PCs are plotted

# Group points: the mean sweep (std == 100, every mean except 500) versus
# everything else, i.e. all m=500 datasets (the std sweep plus m500_s100).
mask_100 = (act_std == 100) & (act_mass != 500)
mask_other = ~mask_100

# --- inPCA on softmax (temperature TEMP) ---
prob_counts = [probs.shape[0] for probs in all_prob_list]
prob_mass, prob_std = per_point_mass_std(DATASETS, prob_counts)

# Embed a fixed-seed random subset (at most PROB_SUBSET_SIZE rows) to keep
# inPCA manageable.
m = all_probs.shape[0]
subset_size = min(PROB_SUBSET_SIZE, m)
rng = np.random.default_rng(0)
subset_idx = rng.choice(m, size=subset_size, replace=False)
probs_subset = all_probs[subset_idx]
mass_sub = prob_mass[subset_idx]
std_sub = prob_std[subset_idx]
# Same grouping as for the activations, restricted to the subset.
mask_100_sub = (std_sub == 100) & (mass_sub != 500)
mask_other_sub = ~mask_100_sub

inpca_coords, inpca_explained = inpca_embedding(probs_subset, dim=3)


In [13]:
# Plots

# 3D PCA traces (activations).
# Both groups come from the SAME PCA embedding, so any axis flip must apply to
# all of them. PC2 is negated once here for every point, which keeps the two
# groups (including m500_s100, which belongs to both sweeps) in one frame.
act_plot = act_coords * np.array([1.0, -1.0, 1.0])

pca_traces = [
    # Mean sweep: colored by mean on "Phase" (cmin/cmax bracket means 300-700).
    go.Scatter3d(
        x=act_plot[mask_100, 0],
        y=act_plot[mask_100, 1],
        z=act_plot[mask_100, 2],
        mode="markers",
        marker=dict(size=0.5, color=act_mass[mask_100],
                    colorscale="Phase", cmin=100, cmax=900, showscale=False),
        showlegend=False,
    ),
    # All m=500 datasets: colored by std on "dense".
    go.Scatter3d(
        x=act_plot[mask_other, 0],
        y=act_plot[mask_other, 1],
        z=act_plot[mask_other, 2],
        mode="markers",
        marker=dict(size=0.5, color=act_std[mask_other],
                    colorscale="dense", cmin=-50, cmax=250, showscale=False),
        showlegend=False,
    ),
]

# 3D inPCA traces (probabilities). Both groups share one orientation: the
# third inPCA coordinate is negated for every point.
def _inpca_scatter(mask, color_values, colorscale, cmin, cmax):
    """Return a marker-only Scatter3d of the inPCA points selected by ``mask``."""
    picked = inpca_coords[mask]
    return go.Scatter3d(
        x=picked[:, 0],
        y=picked[:, 1],
        z=-picked[:, 2],
        mode="markers",
        marker=dict(size=1, color=color_values, colorscale=colorscale,
                    cmin=cmin, cmax=cmax, showscale=False),
        showlegend=False,
    )


inpca_traces = [
    # mean sweep, colored by mean
    _inpca_scatter(mask_100_sub, mass_sub[mask_100_sub], "Phase", 100, 900),
    # all m=500 datasets, colored by std
    _inpca_scatter(mask_other_sub, std_sub[mask_other_sub], "dense", -50, 250),
]

# Bottom PDF panel (σ=100, μ ∈ {300, 400, 500, 600, 700}): average predicted
# probability over the token subset, one line per selected dataset.
labels = list(label_order)
# x-axis order: purely numeric token labels sorted by value, then every other
# label in its original order.
numeric = [(i, int(lbl)) for i, lbl in enumerate(labels) if lbl.isdigit()]
num_indices = [i for i, _ in sorted(numeric, key=lambda t: t[1])]
other_indices = [i for i, lbl in enumerate(labels) if not lbl.isdigit()]
ordered_indices = num_indices + other_indices
ordered_labels = [labels[i] for i in ordered_indices]
x_positions = list(range(len(ordered_labels)))

target_mus = {300, 400, 500, 600, 700}
line_traces = []

for ds, probs in zip(DATASETS, all_prob_list):
    m = int(ds.split("_")[1][1:])
    s = int(ds.split("_")[2][1:])
    if s != 100 or m not in target_mus:
        continue
    # Same linear mapping as the "Phase" colorscale of the 3D mean-sweep points
    # (cmin=100, cmax=900), so line colors match point colors.
    color = px.colors.sample_colorscale("Phase", (m - 100) / (900 - 100))[0]
    avg_vec = probs.mean(dim=0).cpu().numpy()[ordered_indices]
    line_traces.append(
        go.Scatter(
            x=x_positions,
            y=avg_vec,
            mode="lines",
            line=dict(color=color, width=1.2),
            showlegend=False,
        )
    )
# No per-curve text labels are attached; the x-axis tick labels mark the means.

# Three stacked panels: PCA scene (B), inPCA scene (C), line plot (D).
fig = make_subplots(
    rows=3, cols=1,
    specs=[[{"type": "scene"}], [{"type": "scene"}], [{"type": "xy"}]],
    row_heights=[0.4, 0.4, 0.2],
    vertical_spacing=0.04,
)

for row, traces in enumerate((pca_traces, inpca_traces, line_traces), start=1):
    for trace in traces:
        fig.add_trace(trace, row=row, col=1)

# Shared viewpoint for both 3D scenes; zoom < 1 moves the eye closer
# (zooms in), zoom > 1 moves it away.
zoom = 0.75
base_eye = {"x": -1.0, "y": 1.6, "z": 0.8}
camera = dict(
    eye={axis: coord * zoom for axis, coord in base_eye.items()},
    up=dict(x=0, y=0, z=1),
)

blank_titles = dict(xaxis_title="", yaxis_title="", zaxis_title="")
fig.update_layout(
    width=cm_to_px(FIG_WIDTH_CM),
    height=cm_to_px(FIG_HEIGHT_CM) + 100,  # extra pixels on top of the nominal height
    showlegend=False,
    margin=dict(l=0, r=0, t=0, b=0),
    scene=dict(**blank_titles, camera=camera),
    scene2=dict(**blank_titles, camera=camera),
)

# Bottom panel (D) axes. A tick label is only usable if that token is present
# in ordered_labels, because its position is looked up there.
wanted_ticks = ["0", "300", "400", "500", "600", "700", "999"]
# Filter labels and positions together. Filtering only the positions would
# shift every later label onto the wrong tick.
tick_labels = [lbl for lbl in wanted_ticks if lbl in ordered_labels]
tick_vals = [ordered_labels.index(lbl) for lbl in tick_labels]

fig.update_xaxes(
    row=3, col=1,
    tickmode="array",
    tickvals=tick_vals,
    ticktext=tick_labels,
    range=[-0.5, len(ordered_labels) - 0.5],  # half a slot of padding on each side
    showgrid=True,
    gridcolor="black",
    zeroline=False,
    showline=True,
    linecolor="black",
    ticks="outside",
    tickfont=dict(size=10),
)

fig.update_yaxes(
    row=3, col=1,
    title_text="",
    showticklabels=False,  # probability values are intentionally not labeled
    showgrid=False,
    zeroline=False,
    showline=True,
    linecolor="black",
    ticks="outside",
)

fig.update_layout(plot_bgcolor="white")

# Minimal, label-free styling shared by all three axes of both 3D scenes:
# white background with light-grey grid lines and no ticks or titles.
axis_style = {
    "title": "",
    "showticklabels": False,
    "ticks": "",
    "showgrid": True,
    "gridcolor": "#dcdcdc",
    "zeroline": False,
    "zerolinecolor": "#dcdcdc",
    "showbackground": True,
    "backgroundcolor": "white",
}
scene_layout = {f"{axis}axis": axis_style for axis in "xyz"}

fig.update_layout(scene=scene_layout, scene2=scene_layout)


# Panel letters in paper coordinates, just left of the plotting area.
for panel_letter, y_paper in (("B", 1.0), ("C", 0.5), ("D", 0.1)):
    fig.add_annotation(
        x=-0.1,
        y=y_paper,
        xref="paper",
        yref="paper",
        text=panel_letter,
        showarrow=False,
        font=dict(size=14, color="black"),
        align="left",
    )

fig.show()

# Optionally write a standalone HTML copy (path relative to the working directory).
if SAVE_FIGURE:
    fig.write_html(FIGURE_NAME)

