## Template for Figures in Paper

In [1]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from pathlib import Path

### Colors & Size

In [2]:
# ---- Colors -----------------------------------------------------------
# Neutrals
BLACK = "#000000"
DARK = "#BFBFBF"
LIGHT = "#f2f2f2"   # used by setup_plots() as the axes background
GREY = "#A7BED3"

# Single accent colors
LIGHTBLUE = "#83C5BE"
ORANGE = "#E8C547"
RED = "#F2545B"
BLUE = "#2b50aa"
# NOTE(review): this hex value is green-dominant (R=0x41, G=0x9D, B=0x78)
# despite the name -- confirm the intended color.
PURPLE = "#419D78"
GREEN = "#136F63"

# Multi-shade ramps, indexed by position
REDS = [
    "#FF969C",
    "#F2545B",
    "#B52028",
    "#6E0F14",
]
BLUES = [
    "#789CF5",
    "#5177D8",
    "#2b50aa",
    "#153584",
    "#09205A",
]
GREYS = [
    "#778da9",
    "#415a77",
    "#1b263b",
    "#7b7b7b",
]

# Default palette for multi-series plots.
COLOR_PALETTE = [
    "#8ecae6",
    "#219ebc",
    "#023047",
    "#153584",
    "#fb8500",
    "#B52028",
    "#FF000D",
]
# Alternative palettes (disabled):
# COLOR_PALETTE = ['#001219', '#005f73']
# COLOR_PALETTE = ['#E63946', '#A8DADC', '#457B9D', '#1D3557', '#F4A261', '#2A9D8F', '#6D597A', '#E9C46A']

# ---- Size -------------------------------------------------------------
# TEXT is multiplied into a matplotlib `figsize` by the example figure, so it
# is measured in inches there. MARGIN is presumably in the same unit -- TODO confirm.
TEXT = 7.5
MARGIN = 2.5

def pltheight(w: float, h: int = 1, v: int = 1, f: float = 0.8) -> float:
  """Return a figure height derived from a figure width.

  The result is ``f * w * v / h``.

  Args:
    w: Figure width (the result is in the same unit).
    h: Divisor of the width; presumably the number of panels placed side by
      side -- TODO confirm. Must be non-zero.
    v: Multiplier; presumably the number of panels stacked vertically --
      TODO confirm.
    f: Height-to-width scale factor applied on top of ``v / h``.

  Returns:
    The computed height.

  Raises:
    ZeroDivisionError: If ``h`` is 0.
  """
  scaled_width = f * w
  return scaled_width * v / h

## LaTeX Settings

In [3]:
def setup_plots(matplotlib):
  """Switch Matplotlib to the pgf backend and apply the paper's rcParams.

  Args:
    matplotlib: The imported ``matplotlib`` module object. It is passed in
      explicitly, so the parameter shadows the module name inside this function.

  Side effects:
    Calls ``matplotlib.use("pgf")`` and updates ``matplotlib.rcParams``
    globally. rcParams act as defaults, so artists that already exist are not
    necessarily restyled by this call.
  """
  matplotlib.use("pgf")

  # The same LaTeX packages are loaded for usetex rendering and for pgf export.
  latex_preamble = r'\usepackage{amsmath}\usepackage{amssymb}'

  tex_and_fonts = {
    'pgf.texsystem': 'pdflatex',
    'font.family': 'serif',
    'font.serif': 'Times New Roman',
    'text.usetex': True,
    'text.latex.preamble': latex_preamble,
    'pgf.preamble': latex_preamble,
    'pgf.rcfonts': False,
  }

  line_defaults = {
    'lines.linewidth': 1.75,
    'lines.markersize': 3,
  }

  label_sizes = {
    'axes.titlesize': 'small',
    'axes.labelsize': 'small',
    'xtick.labelsize': 'small',
    'ytick.labelsize': 'small',
  }

  frame_and_ticks = {
    'axes.linewidth': 1.25,       # thickness of the axes frame
    'xtick.major.size': 6,        # major tick length, x-axis
    'xtick.major.width': 1.25,    # major tick thickness, x-axis
    'ytick.major.size': 6,        # major tick length, y-axis
    'ytick.major.width': 1.25,    # major tick thickness, y-axis
    'xtick.minor.size': 3,        # minor tick length, x-axis
    'xtick.minor.width': 1,       # minor tick thickness, x-axis
    'ytick.minor.size': 3,        # minor tick length, y-axis
    'ytick.minor.width': 1,       # minor tick thickness, y-axis
  }

  background_and_grid = {
    'axes.facecolor': LIGHT,      # light grey plotting area
    'figure.facecolor': 'white',  # white area around the axes
    'grid.color': 'white',        # white grid lines on the grey background
    'grid.linewidth': 1.25,
    'grid.alpha': 1.0,            # fully opaque grid lines
    'grid.linestyle': '-',        # solid grid lines
    'axes.grid': True,            # grid enabled by default
    'axes.grid.which': 'major',   # grid lines at major ticks only
  }

  # Merged in the listed order and applied with a single update() call.
  matplotlib.rcParams.update({
    **tex_and_fonts,
    **line_defaults,
    **label_sizes,
    **frame_and_ticks,
    **background_and_grid,
  })

### Import Data

In [4]:
# Import data
# TODO: load the data from your JSON / YAML / CSV file here

In [5]:
# Dummy data: placeholder curves so the example figure can be rendered.
# A seeded generator keeps the values (and the figure) identical on every
# "Restart Kernel -> Run All"; change DUMMY_SEED to get a different draw.
DUMMY_SEED = 0
_dummy_rng = np.random.default_rng(DUMMY_SEED)

# x: 0, 5, ..., 95; y: uniform samples from [1, 3)
clustersWiki = np.arange(0, 100, 5)
lossWiki = _dummy_rng.uniform(1, 3, size=len(clustersWiki))

clustersGit = np.arange(0, 100, 5)
lossGit = _dummy_rng.uniform(1, 3, size=len(clustersGit))

### Create Figure

In [6]:
# Create figure: two side-by-side panels with independent axes
fig, ax = plt.subplots(1, 2, sharex=False, sharey=False, figsize=(TEXT * 0.66, 1.5))
fig.subplots_adjust(wspace=0.4)

# Subplot 0: loss as a function of the number of clusters
ax[0].plot(clustersWiki, lossWiki, '-', color=RED, label="Wikipedia")
ax[0].plot(clustersGit, lossGit, '--', color=BLACK, label="GitHub (Python)")

ax[0].set_xlabel('Number of clusters ($K$)')
ax[0].set_ylabel('$k$-means loss')
ax[0].set_title('Determining $K$')

# Data from the table
top_k = np.array([1, 3, 5, 10, 25])

# Wikipedia and GitHub accuracy values (no ± spreads are included here)
wikipedia_acc = np.array([0.67, 0.87, 0.92, 0.97, 0.99])
github_acc = np.array([0.58, 0.82, 0.90, 0.96, 0.99])

# The right panel plots accuracy itself (its y-axis is labelled as accuracy),
# so no "error = 1 - accuracy" conversion is applied. The *_error names are
# plain aliases of the accuracy arrays, kept so code that refers to them
# keeps working.
wikipedia_error = wikipedia_acc
github_error = github_acc

# Subplot 1: accuracy as a function of the number of active experts
ax[1].plot(top_k, wikipedia_acc, linestyle='-', color=RED, marker='o', markersize=5, markeredgecolor='white', markeredgewidth=0.8, label="Wikipedia")
ax[1].plot(top_k, github_acc, linestyle='--', color=BLACK, marker='s', markersize=5, markeredgecolor='white', markeredgewidth=0.8, label="GitHub (Python)")

ax[1].set_xlabel('Number of active experts ($N$)')
ax[1].set_ylabel('pass@$N$ accuracy')
ax[1].set_title("Determining $N$")
ax[1].set_xscale('log')

# Formatting: one shared legend above both panels, built from the left
# panel's handles (both panels use the same two labels).
handles, labels = ax[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.22), ncol=3, prop={"size": 8})

# Output directory
plot_name = 'example'
out_dir = Path('output')
out_dir.mkdir(parents=True, exist_ok=True)

# Save through the figure object (not pyplot's "current figure") and build
# the paths from out_dir, so the files written always belong to this figure
# and always land in the directory created above.

# Save as PNG (before setup_plots changes the backend and rcParams)
fig.savefig(out_dir / f'{plot_name}.png', dpi=300, bbox_inches='tight')

# Save as PGF. setup_plots() calls matplotlib.use("pgf"); holding on to `fig`
# keeps this save independent of what that switch does to pyplot's
# current-figure bookkeeping.
setup_plots(matplotlib)
fig.savefig(out_dir / f'{plot_name}.pgf', bbox_inches='tight')
plt.close(fig)

In [17]:
import re
import torch

# ==========================================
# CONFIGURATION - edit the values here (no argparse in notebooks)
# ==========================================

# Upper bound on the time axis kept by preprocess_tensor().
MAX_TIMESTEPS: int = 1024

# Root directory containing the result folders (placeholder path; set it
# to the location of your outputs).
BASE_DIR: Path = Path("/path/to/work/outputs")

# Rank used in the "*_compressed_{RANK}" directory and file names.
RANK: int = 32  # Change to 64 if needed

# ==========================================
# Helper Functions
# ==========================================

def preprocess_tensor(data, max_timesteps=None):
    """Convert ``data`` to a 2-D numpy array and truncate its second axis.

    The intended layout is [Heads, Time]. That layout is inferred with a
    heuristic: a 2-D input with more rows than columns is transposed, i.e.
    the longer axis is assumed to be time.

    Args:
        data: A ``torch.Tensor`` (cast to float, detached and moved to CPU)
            or anything ``np.array`` accepts.
        max_timesteps: Maximum length kept along axis 1. Defaults to the
            module-level ``MAX_TIMESTEPS`` when ``None``.

    Returns:
        A numpy array. Scalars and 1-D inputs become a single row. Inputs
        with more than two dimensions are not reshaped; only their second
        axis is truncated.
    """
    if max_timesteps is None:
        max_timesteps = MAX_TIMESTEPS

    if isinstance(data, torch.Tensor):
        arr = data.float().detach().cpu().numpy()
    else:
        arr = np.array(data)

    # Scalars (0-d) and vectors (1-d) become one row; 0-d input previously
    # fell through and raised IndexError on arr.shape[1] below.
    if arr.ndim < 2:
        arr = arr.reshape(1, -1)

    if arr.ndim == 2:
        rows, cols = arr.shape
        if rows > cols:
            arr = arr.T

    if arr.shape[1] > max_timesteps:
        arr = arr[:, :max_timesteps]

    return arr

def load_data_from_dir(ranks_dir, file_pattern_str):
    """Load the ``.pt`` files matching a glob pattern and average them.

    Each file whose name contains ``layer_<digits>_`` is loaded, passed
    through ``preprocess_tensor`` and averaged over axis 0. The per-file
    curves are cut to the shortest length and averaged across files.

    Args:
        ranks_dir: ``pathlib.Path`` of the directory to search.
        file_pattern_str: Glob pattern evaluated inside ``ranks_dir``.

    Returns:
        A 1-D numpy array with the averaged curve, or ``None`` (after
        printing a warning) when the directory is missing, nothing matches
        the pattern, or no matching file could be used.
    """
    if not ranks_dir.exists():
        print(f"Warning: Directory not found: {ranks_dir}")
        return None

    # glob() yields files in an unspecified order; sorting makes the
    # floating-point accumulation order, and so the result, reproducible.
    pt_files = sorted(ranks_dir.glob(file_pattern_str))

    if not pt_files:
        print(f"Warning: No files matching '{file_pattern_str}' found in {ranks_dir}")
        return None

    layer_regex = re.compile(r"layer_(\d+)_")
    layer_data = []

    for pt_file in pt_files:
        # Files without a layer index in their name are ignored.
        if not layer_regex.search(pt_file.name):
            continue

        try:
            # SECURITY NOTE: weights_only=False lets torch.load unpickle
            # arbitrary Python objects, which can execute code. Only point
            # this at files you produced yourself or otherwise trust.
            data = torch.load(pt_file, map_location='cpu', weights_only=False)
            arr = preprocess_tensor(data)
            layer_data.append(np.mean(arr, axis=0))
        except Exception as e:
            # Best effort: report the unreadable file and keep going.
            print(f"Error reading {pt_file.name}: {e}")
            continue

    if not layer_data:
        # Every candidate was skipped or failed; say so, like the other
        # early exits, instead of returning None silently.
        print(f"Warning: No usable layer files matching '{file_pattern_str}' in {ranks_dir}")
        return None

    # Files may differ in length; average over the common prefix.
    min_len = min(len(x) for x in layer_data)
    truncated_layers = [x[:min_len] for x in layer_data]
    return np.mean(truncated_layers, axis=0)

# ==========================================
# Load Data
# ==========================================

# Each result set lives in its own "ranks" directory under BASE_DIR; the
# compressed variants are selected by RANK. Every load yields an averaged
# curve, or None when nothing usable was found.
print("Loading data...")

# Uncompressed baseline: every per-layer file in the directory
path_uncompressed_dir = BASE_DIR / "effective_rank_results_dn/delta_net/370m/10BT/ranks"
data_uncompressed = load_data_from_dir(path_uncompressed_dir, "*layer_*.pt")

# Non-adversarial compression
path_std_comp_dir = BASE_DIR / f"effective_rank_results_dn_non_adv/non_adversarial_compressed_{RANK}/ranks"
data_std_comp = load_data_from_dir(path_std_comp_dir, f"*non_adversarial_compressed_{RANK}.pt")

# Structured compression
path_struct_comp_dir = BASE_DIR / f"effective_rank_results_dn_structured/structured_compressed_{RANK}/ranks"
data_struct_comp = load_data_from_dir(path_struct_comp_dir, f"*structured_compressed_{RANK}.pt")

# RRQR (data) compression
path_rrqr_comp_dir = BASE_DIR / f"effective_rank_results_dn_rrqr_data/rrqr_data_compressed_{RANK}/ranks"
data_rrqr_comp = load_data_from_dir(path_rrqr_comp_dir, f"*rrqr_data_compressed_{RANK}.pt")

# Wanda compression
path_wanda_comp_dir = BASE_DIR / f"effective_rank_results_dn_wanda/wanda_compressed_{RANK}/ranks"
data_wanda_comp = load_data_from_dir(path_wanda_comp_dir, f"*wanda_compressed_{RANK}.pt")

# Random compression
path_random_comp_dir = BASE_DIR / f"effective_rank_results_dn_random/random_compressed_{RANK}/ranks"
data_random_comp = load_data_from_dir(path_random_comp_dir, f"*random_compressed_{RANK}.pt")

# L2 compression
# NOTE(review): data_l2_comp is loaded here, but the plotting code visible in
# this notebook draws L1 and not L2 -- confirm whether that is intended.
path_l2_comp_dir = BASE_DIR / f"effective_rank_results_dn_l2/l2_compressed_{RANK}/ranks"
data_l2_comp = load_data_from_dir(path_l2_comp_dir, f"*l2_compressed_{RANK}.pt")

# L1 compression
path_l1_comp_dir = BASE_DIR / f"effective_rank_results_dn_l1/l1_compressed_{RANK}/ranks"
data_l1_comp = load_data_from_dir(path_l1_comp_dir, f"*l1_compressed_{RANK}.pt")

print("Data loading complete.")

Loading data...
Data loading complete.


In [18]:
# ==========================================
# Create Figure
# ==========================================

# Width matches the LaTeX column width (3.25 inches); the height is derived
# from the width via pltheight() to keep whitespace small.
fig, ax = plt.subplots(figsize=(3.25, pltheight(3.25, h=1, v=1, f=0.8)))
fig.subplots_adjust(top=0.995)  # or a value slightly less than 1

LW = 0.7   # line width of the grey comparison curves
LWT = 1    # slightly heavier width for Baseline, DRRQR and Structured

# One (series, plot keyword arguments) entry per curve, in drawing order,
# which is also the legend order.
curve_specs = [
    (data_uncompressed, dict(label='Baseline', color=BLACK, linestyle='-', alpha=0.9, linewidth=LWT)),
    (data_std_comp, dict(label='PCA', color=GREYS[0], linestyle='-', linewidth=LW)),
    (data_rrqr_comp, dict(label='DRRQR', color=REDS[1], linestyle='-', linewidth=LWT)),
    (data_struct_comp, dict(label='Structured', color=REDS[2], linestyle='--', linewidth=LWT)),
    (data_wanda_comp, dict(label='Wanda', color=GREYS[1], linestyle='--', linewidth=LW)),
    (data_random_comp, dict(label='Random', color=GREYS[2], linestyle=':', linewidth=LW)),
    (data_l1_comp, dict(label='L1', color=GREYS[3], linestyle='--', linewidth=LW)),
]

# Datasets that failed to load are None and are simply left out.
for series, plot_kwargs in curve_specs:
    if series is None:
        continue
    ax.plot(np.arange(len(series)), series, **plot_kwargs)
# Axis labels
ax.set_xlabel('$t$')
ax.set_ylabel('Rank Utilization')

# Y axis: a major tick every 0.025, with text on 0.0, 0.1 and 0.2 only;
# no minor ticks.
major_yticks = [0.0, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.175, 0.2]
minor_yticks = []
labelled_y_positions = {0: '0.0', 4: '0.1', 8: '0.2'}
ax.set_yticks(major_yticks)
ax.set_yticklabels([labelled_y_positions.get(i, '') for i in range(len(major_yticks))])
ax.set_yticks(minor_yticks, minor=True)

# X axis: a major tick every 128 steps, with text on 0, 512 and 1024 only;
# no minor ticks.
major_xticks = list(range(0, 1025, 128))
minor_xticks = []
ax.set_xticks(major_xticks)
ax.set_xticklabels([str(tick) if tick % 512 == 0 else '' for tick in major_xticks])
ax.set_xticks(minor_xticks, minor=True)

# Limit x range
ax.set_xlim(0, 1024)

# Thinner frame and ticks than the rcParams defaults
for side in ax.spines:
    ax.spines[side].set_linewidth(0.7)
for which, tick_width, tick_length in (('major', 0.7, 4), ('minor', 0.5, 2)):
    ax.tick_params(width=tick_width, length=tick_length, which=which)

# Legend goes below the axes, underneath the x-label
fig.subplots_adjust(bottom=0.18, left=0.01, right=0.99, top=0.99)
ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.22), fontsize='x-small', ncol=3, frameon=False)

plt.tight_layout(pad=0)  # no padding around the axes
# Output settings
plot_name = f'pca_vs_structured_rank_{RANK}'
out_dir = Path('output')
out_dir.mkdir(parents=True, exist_ok=True)

# Both files are written through the figure object (not pyplot's "current
# figure") and into out_dir, so they always belong to this figure and land
# in the directory created above.
png_path = out_dir / f'{plot_name}.png'
pgf_path = out_dir / f'{plot_name}.pgf'

# Save as PNG first (before setup_plots switches to the PGF backend)
fig.savefig(png_path, dpi=300, bbox_inches='tight', pad_inches=0.01)

# Save as PGF. setup_plots() calls matplotlib.use("pgf"); holding on to `fig`
# keeps this save independent of what that switch does to pyplot's
# current-figure bookkeeping.
setup_plots(matplotlib)
fig.savefig(pgf_path, bbox_inches='tight', pad_inches=0)
plt.close(fig)

print(f"Saved: {png_path.as_posix()} and {pgf_path.as_posix()}")

Saved: output/pca_vs_structured_rank_32.png and output/pca_vs_structured_rank_32.pgf


In [11]:
# ==========================================
# Rank Utilization vs Perplexity Scatter Plot
# ==========================================

import scipy.stats as stats
from collections import defaultdict

# Data Source: "Model Name": [(d_k, Rank Util, Perplexity), ...]
# The processing code groups each model's tuples by their first element (d_k).
data_source = {
    # NOTE(review): (89, 0.0931, 28.80) appears twice in the first row -- confirm
    # that the duplicate measurement is intended.
    "DN 370m": [
        (89, 0.0849, 30.80), (89, 0.0861, 30.70), (89, 0.0931, 28.80), (89, 0.0931, 28.80),
        (76, 0.0934, 31.70), (76, 0.0963, 31.20), (76, 0.1063, 29.00), (76, 0.1062, 29.10),
        (64, 0.0966, 32.50), (64, 0.1067, 32.10), (64, 0.1222, 29.40), (64, 0.1213, 29.40),
        (32, 0.0865, 36.30), (32, 0.1044, 35.40), (32, 0.1841, 31.40), (32, 0.1889, 31.50),
    ],
    "DN 1.3B": [
        (89, 0.0192, 18.10), (89, 0.0191, 18.00), (89, 0.0183, 16.80), (89, 0.0177, 16.90),
        (76, 0.0224, 18.50), (76, 0.0222, 18.20), (76, 0.0216, 17.10), (76, 0.0206, 17.10),
        (64, 0.0263, 19.00), (64, 0.0261, 18.50), (64, 0.0257, 17.50), (64, 0.0243, 17.20),
        (32, 0.0491, 20.20), (32, 0.0489, 20.00), (32, 0.0498, 19.20), (32, 0.0472, 19.00),
    ],
    # The GDN models have three entries per d_k (four for the DN models above).
    "GDN 370m": [
        (179, 0.0500, 28.10), (179, 0.0509, 27.80), (179, 0.0509, 27.80),
        (153, 0.0552, 28.60), (153, 0.0565, 28.20), (153, 0.0564, 28.20),
        (128, 0.0622, 29.20), (128, 0.0634, 28.70), (128, 0.0635, 28.70),
        (64, 0.0928, 32.20), (64, 0.0932, 31.40), (64, 0.0956, 31.40),
    ],
    "GDN 1.3B": [
        (179, 0.0413, 16.30), (179, 0.0434, 15.90), (179, 0.0430, 15.90),
        (153, 0.0457, 16.50), (153, 0.0483, 16.10), (153, 0.0477, 16.10),
        (128, 0.0508, 16.80), (128, 0.0541, 16.30), (128, 0.0530, 16.40),
        (64, 0.0732, 18.60), (64, 0.0813, 17.90), (64, 0.0775, 17.80),
    ]
}

# Model colors using notebook palette.
# The keys must match the keys of data_source; the iteration order of this
# dict is the order in which the models are drawn and listed in the legend.
model_colors = {
    "DN 370m": BLUES[2],
    "DN 1.3B": REDS[1],
    "GDN 370m": GREYS[0],
    "GDN 1.3B": GREYS[2]
}

# ==========================================
# Process Data (Residuals + Normalization)
# ==========================================
# For every model: subtract each d_k group's mean from its utilization and
# perplexity values, then divide the pooled residuals by their standard
# deviation (np.std, population form) so the models share one scale.
processed_data = []

for model_name, points in data_source.items():
    # Bucket the measurements by d_k; first-seen order of d_k is preserved.
    utils_by_dk = defaultdict(list)
    ppls_by_dk = defaultdict(list)
    for dk, util, ppl in points:
        utils_by_dk[dk].append(util)
        ppls_by_dk[dk].append(ppl)

    # Within-group residuals, concatenated group after group.
    res_util_list = []
    res_ppl_list = []
    for dk in utils_by_dk:
        group_utils = np.array(utils_by_dk[dk])
        group_ppls = np.array(ppls_by_dk[dk])
        res_util_list.extend(group_utils - group_utils.mean())
        res_ppl_list.extend(group_ppls - group_ppls.mean())

    # Per-model scale of the residuals.
    std_u = np.std(res_util_list)
    std_p = np.std(res_ppl_list)

    norm_res_util = np.array(res_util_list) / std_u
    norm_res_ppl = np.array(res_ppl_list) / std_p

    # One record per measurement, tagged with its model.
    processed_data.extend(
        {'x': norm_x, 'y': norm_y, 'model': model_name}
        for norm_x, norm_y in zip(norm_res_util, norm_res_ppl)
    )

# ==========================================
# Calculate Trend Lines
# ==========================================
# Two straight-line fits of y on x: one over every point, one leaving out
# the "DN 1.3B" model. A Spearman rank correlation is computed for each set.
EXCLUDED_MODEL = "DN 1.3B"
kept_records = [d for d in processed_data if d['model'] != EXCLUDED_MODEL]

all_x = np.array([d['x'] for d in processed_data])
all_y = np.array([d['y'] for d in processed_data])
exclude_x = np.array([d['x'] for d in kept_records])
exclude_y = np.array([d['y'] for d in kept_records])

# Degree-1 least-squares fit over all points
z_all = np.polyfit(all_x, all_y, 1)
p_all = np.poly1d(z_all)

# Same fit without the excluded model
z_excl = np.polyfit(exclude_x, exclude_y, 1)
p_excl = np.poly1d(z_excl)

# NOTE(review): rho_all / rho_excl are not referenced by the legend labels in
# the plotting code visible in this notebook -- confirm whether they should be.
rho_all, _ = stats.spearmanr(all_x, all_y)
rho_excl, _ = stats.spearmanr(exclude_x, exclude_y)

# ==========================================
# Create Figure
# ==========================================
fig, ax = plt.subplots(figsize=(3.25, pltheight(3.25, h=1, v=1, f=0.9)))

# Scatter points: one series per model, in model_colors order
for model_name, model_color in model_colors.items():
    model_records = [d for d in processed_data if d['model'] == model_name]
    mx = [d['x'] for d in model_records]
    my = [d['y'] for d in model_records]
    ax.scatter(mx, my, color=model_color, s=25, alpha=0.8, label=model_name,
               edgecolors='none', linewidths=0.3)

# Trend lines, both evaluated over the full x extent of the data
x_range = np.linspace(min(all_x), max(all_x), 100)
trend_specs = [
    (p_all, dict(color=REDS[1], linestyle='--', linewidth=0.8, alpha=0.9, label='All')),
    (p_excl, dict(color=BLUES[3], linestyle='-', linewidth=1.2, alpha=0.9, label='Excl. DN 1.3B')),
]
for trend, line_kwargs in trend_specs:
    ax.plot(x_range, trend(x_range), **line_kwargs)

# Reference lines (disabled)
#ax.axhline(0, color='gray', alpha=0.3, linewidth=0.5)
#ax.axvline(0, color='gray', alpha=0.3, linewidth=0.5)

# Labels
ax.set_xlabel('Relative Rank Utilization')
ax.set_ylabel('Relative Perplexity')
ax.set_axisbelow(True)  # draw grid/ticks beneath the data

# Spine and tick styling
for side in ax.spines:
    ax.spines[side].set_linewidth(0.7)
ax.tick_params(width=0.7, length=4, which='major')

# Legend below the axes
ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.18), fontsize='x-small', ncol=3, frameon=False)

plt.tight_layout(pad=0.1)

# Save
plot_name = 'rank_util_ppl_scatter'
out_dir = Path('output')
out_dir.mkdir(parents=True, exist_ok=True)

# Both files are written through the figure object (not pyplot's "current
# figure") and into out_dir, so they always belong to this figure and land
# in the directory created above.
png_path = out_dir / f'{plot_name}.png'
pgf_path = out_dir / f'{plot_name}.pgf'

# PNG first, before setup_plots changes the backend and rcParams
fig.savefig(png_path, dpi=300, bbox_inches='tight', pad_inches=0.01)

# PGF. setup_plots() calls matplotlib.use("pgf"); holding on to `fig` keeps
# this save independent of what that switch does to pyplot's current-figure
# bookkeeping.
setup_plots(matplotlib)
fig.savefig(pgf_path, bbox_inches='tight', pad_inches=0)
plt.close(fig)

print(f"Saved: {png_path.as_posix()} and {pgf_path.as_posix()}")

Saved: output/rank_util_ppl_scatter.png and output/rank_util_ppl_scatter.pgf
