In [None]:
from datasets import load_from_disk
from pathlib import Path

from sgtm.model import GPTNeoForCausalLMSGTM
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer, DataCollatorForLanguageModeling
import torch
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import pickle

# Notebook-wide default seaborn palette; the plotting cells below set their own palette explicitly.
sns.set_palette("muted")

In [None]:
# Paths to the two language splits of TinyStories and to the trained model runs.
ES_DS_PATH = Path("data/datasets/tinystories_split/es")
EN_DS_PATH = Path("data/datasets/tinystories_split/en")
BASE_MODEL_DIR = Path("data/models")

MODEL_PATH = "your-model-name"  # placeholder: name of the run directory under BASE_MODEL_DIR
CHECKPOINT_DIR = BASE_MODEL_DIR / MODEL_PATH / "output" / "final-checkpoint"

model = GPTNeoForCausalLMSGTM.from_pretrained(CHECKPOINT_DIR).to("cuda")

es_ds, en_ds = load_from_disk(ES_DS_PATH), load_from_disk(EN_DS_PATH)

tokenizer = GPT2Tokenizer.from_pretrained(CHECKPOINT_DIR)
# Reuse EOS as the padding token so the collator has something to pad with.
# NOTE(review): with pad == eos, DataCollatorForLanguageModeling(mlm=False) masks
# every label equal to pad_token_id to -100, i.e. genuine EOS tokens are excluded
# from the loss as well — confirm this is acceptable for the gradient analysis.
tokenizer.pad_token = tokenizer.eos_token
# mlm=False -> causal-LM batches: labels are derived from input_ids.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [None]:
def _make_test_loader(split_ds):
    """Wrap ``split_ds["test"]`` in a one-example-per-batch DataLoader.

    Only the tokenized columns are kept; batching/labels come from ``data_collator``.
    """
    return DataLoader(
        split_ds["test"].select_columns(["input_ids", "attention_mask"]),
        batch_size=1,
        collate_fn=data_collator,
        num_workers=4,
        pin_memory=True,
    )


# Language code -> test-split loader; iterated in this order by the gradient cell.
dataloaders = {lang: _make_test_loader(ds) for lang, ds in (("es", es_ds), ("en", en_ds))}
es_dataloader = dataloaders["es"]
en_dataloader = dataloaders["en"]

In [None]:
# Report the L2 norm of the "forget" and "retain" parameter splits, both in total
# and divided by sqrt(#elements) (a per-element RMS magnitude). `param_norms` and
# the element counts are reused by the gradient-norm and save cells below.
print("Parameter Norms:")
print("-" * 40)


def _split_sq_norm_and_count(split):
    """Return (sum of squared per-tensor L2 norms, total element count) for one split."""
    sq_total, n_elems = 0.0, 0
    for _, tensor in model.named_parameters_split(sgtm_split=split):
        sq_total += tensor.norm().item() ** 2
        n_elems += tensor.numel()
    return sq_total, n_elems


# NOTE: despite the names, *_param_norm hold the *squared* norms; sqrt is taken below.
forget_param_norm, forget_param_count = _split_sq_norm_and_count("forget")
retain_param_norm, retain_param_count = _split_sq_norm_and_count("retain")

param_norms = {
    "forget": forget_param_norm ** 0.5,
    "retain": retain_param_norm ** 0.5,
}

forget_param_norm_normalized = param_norms["forget"] / (forget_param_count ** 0.5)
retain_param_norm_normalized = param_norms["retain"] / (retain_param_count ** 0.5)

for line in (
    f"Forget parameter norm (full): {param_norms['forget']:.6f}",
    f"Retain parameter norm (full): {param_norms['retain']:.6f}",
    f"Forget parameter norm (normalized): {forget_param_norm_normalized:.6f}",
    f"Retain parameter norm (normalized): {retain_param_norm_normalized:.6f}",
    f"Forget parameter count: {forget_param_count}",
    f"Retain parameter count: {retain_param_count}",
):
    print(line)


In [None]:
# Per-example gradient norms, keyed by (data language, parameter split),
# e.g. ("es", "forget").
norms = defaultdict(list)      # L2 norm of the split's gradient
normnorms = defaultdict(list)  # that norm divided by the split's parameter norm (param_norms)
props = defaultdict(list)      # NOTE(review): never filled here; kept because later cells save/filter it

# Start from clean gradients: if a previous run of this cell was interrupted after
# backward() but before zero_grad(), those gradients would otherwise be added to
# the first example's gradient below.
model.zero_grad(set_to_none=True)


def _group_grad_norm(param_group):
    """L2 norm over all existing gradients in one parameter split, as a Python float.

    Per-tensor norms are reduced on the device (in float32) and synchronised once
    via ``.item()`` instead of once per parameter tensor. Parameters without a
    gradient are skipped; returns 0.0 if none in the split has one.
    """
    per_tensor = [
        torch.linalg.vector_norm(param.grad.detach(), dtype=torch.float32)
        for _, param in model.named_parameters_split(sgtm_split=param_group)
        if param.grad is not None
    ]
    if not per_tensor:
        return 0.0
    return torch.linalg.vector_norm(torch.stack(per_tensor)).item()


for lang, dataloader in dataloaders.items():
    for batch in tqdm(dataloader, desc=lang):
        # Loaders use pin_memory=True, so the host->device copy can be asynchronous.
        batch = {k: v.to("cuda", non_blocking=True) for k, v in batch.items()}
        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            outputs = model(return_dict=True, **batch)

        # backward() outside the autocast region, as recommended for AMP.
        outputs.loss.backward()

        for param_group in ("forget", "retain"):
            grad_norm = _group_grad_norm(param_group)
            norms[(lang, param_group)].append(grad_norm)
            normnorms[(lang, param_group)].append(grad_norm / param_norms[param_group])

        # Each example's gradient must be measured in isolation.
        model.zero_grad(set_to_none=True)

In [None]:
# Save raw data as pickle
with open("data/grad_norms_raw.pkl", "wb") as f:
    pickle.dump({
        "norms": dict(norms),
        "normnorms": dict(normnorms), 
        "props": dict(props),
        "param_norms": param_norms,
        "param_counts": {"forget": forget_param_count, "retain": retain_param_count}
    }, f)

In [None]:
# Load raw data from pickle
with open("data/grad_norms_raw.pkl", "rb") as f:
    data = pickle.load(f)
    norms = data["norms"]
    normnorms = data["normnorms"]
    props = data["props"]
    param_norms = data["param_norms"]
    param_counts = data["param_counts"]
    forget_param_count = param_counts["forget"]
    retain_param_count = param_counts["retain"]


In [None]:
# Filter out p99 outliers from all metrics
for key in list(norms.keys()):
    values = np.array(norms[key])
    p99 = np.percentile(values, 99)
    filtered_values = values[values <= p99]
    norms[key] = filtered_values.tolist()

for key in list(normnorms.keys()):
    values = np.array(normnorms[key])
    p99 = np.percentile(values, 99)
    filtered_values = values[values <= p99]
    normnorms[key] = filtered_values.tolist()

for key in list(props.keys()):
    values = np.array(props[key])
    p99 = np.percentile(values, 99)
    filtered_values = values[values <= p99]
    props[key] = filtered_values.tolist()


In [None]:
from scipy.stats import gaussian_kde
from matplotlib.lines import Line2D
from matplotlib.patches import Patch


# Overlaid KDEs of the relative gradient norm (||grad|| / ||params||) for every
# (data language, parameter split) pair. Colour encodes the parameter split;
# line style encodes the data: dashed = English (retain data), solid = Spanish (forget data).
sns.set_palette("muted")
param_colors = {
    'forget': 'C1',
    'retain': 'C0',
}

plot_order = [('en', 'forget'), ('en', 'retain'), ('es', 'forget'), ('es', 'retain')]

fig, ax = plt.subplots(figsize=(12, 6))

for lang, param_group in plot_order:
    values = np.asarray(normnorms[(lang, param_group)], dtype=float)
    color = param_colors[param_group]
    linestyle = '--' if lang == 'en' else '-'
    alpha = 0.2 if lang == 'en' else 0.4

    if values.size == 0:
        continue
    if values.size < 2 or values.min() == values.max():
        # gaussian_kde raises on a singular covariance (constant or single-point
        # series); show where the mass sits instead of crashing the figure.
        ax.axvline(values[0], color=color, linewidth=3, linestyle=linestyle)
        continue

    x_range = np.linspace(values.min(), values.max(), 200)
    density = gaussian_kde(values)(x_range)
    ax.plot(x_range, density, color=color, linewidth=3, linestyle=linestyle)
    ax.fill_between(x_range, density, alpha=alpha, color=color)

# Two-section legend: invisible Patch handles act as section headers.
legend_elements = [
    Patch(facecolor='none', edgecolor='none', label='Data:'),
    Line2D([0], [0], color='gray', linestyle='--', linewidth=4, label='  Retain (English)'),
    Line2D([0], [0], color='gray', linestyle='-', linewidth=4, label='  Forget (Spanish)'),
    Line2D([0], [0], color='none', linewidth=0, label=''),  # Spacer
    Patch(facecolor='none', edgecolor='none', label='Parameters:'),
    Line2D([0], [0], color=param_colors['forget'], linewidth=5, label='  Forget'),
    Line2D([0], [0], color=param_colors['retain'], linewidth=5, label='  Retain'),
]

# Fixed limits, presumably tuned by hand to the filtered data of this run.
ax.set_ylim(0, 450)
ax.set_xlim(0.0025, 0.025)
ax.set_xlabel('Normalized Gradient Norm', fontsize=16)
ax.set_ylabel('Density', fontsize=16)
ax.grid(alpha=0.3)
ax.legend(handles=legend_elements, fontsize=16, framealpha=0.9,
          handlelength=1.5, handletextpad=0.5)
fig.tight_layout()

In [None]:
from scipy.stats import gaussian_kde
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from matplotlib.gridspec import GridSpec


# Four-panel comparison of relative gradient norms: panels 1-2 fix the data and
# compare parameter splits; panels 3-4 fix the parameter split and compare data.
sns.set_palette("Set2")
param_colors = {
    'forget': 'C0',
    'retain': 'C2',
}


def plot(ax, lang, param_group):
    """Draw a filled KDE of ``normnorms[(lang, param_group)]`` on ``ax``.

    Colour comes from ``param_colors[param_group]``. English (retain) data is drawn
    dashed and lighter; Spanish (forget) data is solid. A series that
    ``gaussian_kde`` cannot handle (fewer than two points, or all values equal,
    i.e. singular covariance) is drawn as a vertical line instead of raising.
    """
    values = np.asarray(normnorms[(lang, param_group)], dtype=float)
    color = param_colors[param_group]
    linestyle = '--' if lang == 'en' else '-'
    alpha = 0.2 if lang == 'en' else 0.4

    if values.size == 0:
        return
    if values.size < 2 or values.min() == values.max():
        ax.axvline(values[0], color=color, linewidth=3, linestyle=linestyle)
        return

    x_range = np.linspace(values.min(), values.max(), 200)
    density = gaussian_kde(values)(x_range)
    ax.plot(x_range, density, color=color, linewidth=3, linestyle=linestyle)
    ax.fill_between(x_range, density, alpha=alpha, color=color)


def plot_hist(ax, lang, param_group):
    """Histogram alternative to ``plot`` (not used by the figure below)."""
    v = normnorms[(lang, param_group)]
    color = param_colors[param_group]

    alpha = 0.3 if lang == 'en' else 0.6
    ax.hist(v, bins=30, color=color, alpha=alpha, density=True,
            histtype='stepfilled', edgecolor='black', linewidth=0.5)


# One entry per panel: (title, series to overlay as (lang, param_group),
# legend entries as (param_group used for colour, linestyle, label)).
# Linestyles mirror `plot`: '--' = English/retain data, '-' = Spanish/forget data.
panels = [
    ("Forget data", [("es", "forget"), ("es", "retain")],
     [("forget", "-", "Forget weights"), ("retain", "-", "Retain weights")]),
    ("Retain data", [("en", "forget"), ("en", "retain")],
     [("forget", "--", "Forget weights"), ("retain", "--", "Retain weights")]),
    ("Forget weights", [("es", "forget"), ("en", "forget")],
     [("forget", "--", "Retain data"), ("forget", "-", "Forget data")]),
    ("Retain weights", [("es", "retain"), ("en", "retain")],
     [("retain", "--", "Retain data"), ("retain", "-", "Forget data")]),
]

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
axes = list(axes)

for ax, (title, series, legend_spec) in zip(axes, panels):
    ax.set_title(title, fontsize=18)
    for lang, param_group in series:
        plot(ax, lang, param_group)
    ax.legend(
        handles=[
            Line2D([0], [0], color=param_colors[group], linestyle=style, linewidth=2, label=label)
            for group, style, label in legend_spec
        ],
        fontsize=14,
    )
    # Shared, hand-tuned limits so the four panels are directly comparable.
    ax.set_ylim(0, 420)
    ax.set_xlim(0.003, 0.03)
    ax.set_xlabel("Relative gradient norm", fontsize=14)
    ax.set_ylabel("Density", fontsize=14)
    ax.locator_params(axis='x', nbins=4)

fig.tight_layout()
plt.show()
