## Projector Weight Scalars Analysis

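This notebook loads a CSV collected by `collect_projector_weights.py`, computes the per-sample mean of the normalized key and value scalars for each projector, and visualizes their distributions. A distribution is only plotted when the corresponding gate logit is non-negative (i.e. sigmoid(logit) >= 0.5); otherwise that plot is skipped with a message. For each projector, it plots: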
- Key scalar distribution with the corresponding key gate logit (shown as sigmoid(logit)) marked.
- Value scalar distribution with the corresponding value gate logit (shown as sigmoid(logit)) marked.

Set the CSV path and output directory below and run all cells.


In [None]:
import os
import json
import math
from typing import List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# User parameters: edit these before running the notebook.
# ---------------------------------------------------------------------------
CSV_PATH = "/mnt/public/minzihan/unified_memory/projector_weights.csv"  # path to your collected CSV
OUTPUT_DIR = "/mnt/public/minzihan/unified_memory/weight_distribution_plots"
SAVE_FIG = True
FIG_DPI = 220

# Create the figure directory up front so later savefig calls cannot fail on a missing folder.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Load the raw table, then print a short preview and the dtype pandas inferred for each column.
df = pd.read_csv(CSV_PATH)
print(df.head())
print(dict(df.dtypes.items()))

# Parse JSON columns into arrays

def parse_json_array_safe(x):
    """Decode one JSON-encoded CSV cell into a float NumPy array.

    Args:
        x: Raw cell value. Only text (``str``/``bytes``/``bytearray``) is decoded;
            anything else (``NaN``/``None`` from an empty cell, bare numbers,
            already-parsed lists or arrays) is treated as missing.

    Returns:
        A ``float64`` array with the decoded values, where JSON ``null`` entries
        become ``NaN`` so that NaN-aware reductions can skip them. ``None`` when
        the cell is missing, is not valid JSON, or cannot be converted to floats.
    """
    # Checking the type first avoids calling a truthiness test on array-like input,
    # which would raise instead of returning None.
    if not isinstance(x, (str, bytes, bytearray)):
        return None
    try:
        # dtype=float maps JSON null (Python None) to NaN instead of producing an
        # object array that numeric reductions cannot handle.
        return np.array(json.loads(x), dtype=float)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError, undecodable bytes and ragged or
        # non-numeric payloads; TypeError covers payloads such as JSON objects.
        return None

# Decode both JSON-encoded scalar columns with the same parser.
parsed_key, parsed_value = (
    df[raw_col].apply(parse_json_array_safe)
    for raw_col in ("norm_key_scalar", "norm_value_scalar")
)

# Attach the decoded arrays to a copy so the raw frame `df` is left untouched.
df_parsed = df.assign(
    norm_key_scalar_arr=parsed_key,
    norm_value_scalar_arr=parsed_value,
)

# Compute per-sample mean scalars for each projector

def compute_mean_scalars(arr: np.ndarray) -> float:
    """Return the mean of the non-NaN entries of ``arr`` as a Python float.

    Args:
        arr: Array-like of scalars for one sample, or ``None`` when the sample
            could not be parsed. ``None`` entries inside the array are treated
            as missing, the same as ``NaN``.

    Returns:
        The mean over all non-NaN entries (over every axis), or ``NaN`` when
        ``arr`` is ``None``, empty, entirely NaN, or not convertible to floats.
    """
    if arr is None:
        return np.nan
    try:
        values = np.asarray(arr, dtype=float)
    except (TypeError, ValueError):
        # Non-numeric or ragged content: report the sample as missing.
        return np.nan
    valid = ~np.isnan(values)
    # Guarding the empty / all-NaN case explicitly avoids the
    # "Mean of empty slice" RuntimeWarning that np.nanmean would emit.
    if not valid.any():
        return np.nan
    return float(values[valid].mean())

# Collapse each parsed array to a single per-sample mean.
df_parsed["mean_key_scalar"] = df_parsed["norm_key_scalar_arr"].map(compute_mean_scalars)
df_parsed["mean_value_scalar"] = df_parsed["norm_value_scalar_arr"].map(compute_mean_scalars)


# Helper: convert gate logits to probabilities via sigmoid
def sigmoid(t):
    """Logistic function 1 / (1 + exp(-t)); works on scalars and NumPy arrays."""
    return 1.0 / (1.0 + np.exp(-t))


# Collect the distinct projector indices in ascending order.
projector_ids = df_parsed["projector_index"].unique().tolist()
projector_ids.sort()
print(f"Found {len(projector_ids)} projectors")



In [None]:
def _resolve_gate_logit(df_proj: pd.DataFrame, column: str, projector_index: int):
    """Return the gate logit stored in ``column``, or None if every entry is missing."""
    logit_vals = df_proj[column].dropna().unique()
    if len(logit_vals) == 0:
        return None
    if len(logit_vals) > 1:
        # The logit is expected to be constant per projector; report a violation
        # instead of silently picking one of several values.
        print(
            f"Warning: projector {projector_index} has {len(logit_vals)} distinct "
            f"{column} values; using the first ({logit_vals[0]})"
        )
    return logit_vals[0]


def _plot_scalar_distribution(values, gate_logit, projector_index: int, kind: str, color: str, bins: int) -> None:
    """Draw one histogram of per-sample means, mark sigmoid(gate_logit), then save/show/close it."""
    label = kind.capitalize()  # "key" -> "Key", "value" -> "Value"
    fig, ax = plt.subplots(figsize=(6.5, 4.0))
    try:
        ax.hist(values, bins=bins, color=color, alpha=1.0, edgecolor=None)
        # Mark the gate probability on the same [0, 1] axis as the scalar means.
        gate_prob = float(sigmoid(gate_logit))
        ax.axvline(
            gate_prob,
            color="black",
            linestyle="--",
            linewidth=1.5,
            label=f"sigmoid({kind} gate logit) = {gate_prob:.3f}",
        )
        ax.set_title(f"Projector {projector_index} - {label} Scalar Mean Distribution")
        ax.set_xlabel(f"Mean {label} Scalar (per sample)")
        ax.set_ylabel("Count")
        ax.set_xlim(0.0, 1.0)
        ax.grid(True, linestyle=":", alpha=0.5)
        ax.legend(loc="best")
        if SAVE_FIG:
            out_path = os.path.join(OUTPUT_DIR, f"projector_{projector_index:03d}_{kind}_dist.png")
            fig.tight_layout()
            fig.savefig(out_path, dpi=FIG_DPI, bbox_inches="tight")
            print(f"Saved: {out_path}")
        plt.show()
    finally:
        # Release the figure even if drawing or saving fails; otherwise one figure
        # per projector and scalar kind stays alive when this is called in a loop.
        plt.close(fig)


def plot_projector_distributions(df_proj: pd.DataFrame, projector_index: int, bins: int = 30):
    """Plot the key and value mean-scalar histograms for a single projector.

    A histogram is drawn only when the corresponding gate logit is present and
    non-negative and at least one non-NaN mean exists; otherwise a "Skip ..."
    line is printed. Each drawn histogram carries a dashed vertical line at
    sigmoid(gate logit).

    Reads the module-level settings ``SAVE_FIG``, ``OUTPUT_DIR`` and ``FIG_DPI``
    and the module-level ``sigmoid`` helper. When ``SAVE_FIG`` is true, each
    figure is written to ``OUTPUT_DIR`` as
    ``projector_<index>_<key|value>_dist.png``.

    Args:
        df_proj: Rows belonging to one projector. Must contain the columns
            ``mean_key_scalar``, ``mean_value_scalar``, ``key_gate_logit`` and
            ``value_gate_logit``.
        projector_index: Integer id used in titles, messages and file names.
        bins: Number of histogram bins.

    Returns:
        None.
    """
    panel_specs = (
        ("key", "mean_key_scalar", "key_gate_logit", "#225ea8"),
        ("value", "mean_value_scalar", "value_gate_logit", "#fb6a4a"),
    )

    # Read every required column before drawing anything, so a missing column
    # fails before any figure is produced.
    panels = [
        (
            kind,
            df_proj[mean_col].dropna().values,
            _resolve_gate_logit(df_proj, logit_col, projector_index),
            color,
        )
        for kind, mean_col, logit_col, color in panel_specs
    ]

    for kind, values, gate_logit, color in panels:
        enabled = (gate_logit is not None) and (gate_logit >= 0)
        if enabled and values.size > 0:
            _plot_scalar_distribution(values, gate_logit, projector_index, kind, color, bins)
        else:
            print(f"Skip {kind} plot for projector {projector_index}: logit<0 or no data")

# Produce the key/value histograms for each projector in turn.
for pid in projector_ids:
    # Keep only the rows recorded for this projector.
    df_proj = df_parsed.loc[df_parsed["projector_index"].eq(pid)]
    plot_projector_distributions(df_proj, projector_index=pid, bins=30)

