# Resonance Basin Filter – Demo

In [None]:

# Resonance Basin Filter – Demo
# This notebook demonstrates the damping of high-norm latent activations using
# the resonance-basin filter described in the paper.
import torch
import numpy as np
import matplotlib.pyplot as plt

# Seed PyTorch's global RNG so the torch.randn call below yields the same
# synthetic activations on every run.
torch.random.manual_seed(42)

def resonance_basin_filter(x, kappa=0.8, tau=1.2):
    """Damp vectors of ``x`` whose L2 norm exceeds ``tau``.

    Each vector along the last dimension is handled independently. A vector
    ``v`` with ``||v|| > tau`` is multiplied by ``exp(-(kappa / 2) * ||v||**2)``;
    a vector with ``||v|| <= tau`` is returned unchanged.

    Args:
        x: Floating-point tensor. The norm is taken over the last dimension and
            the scale broadcasts back over it, so any leading shape works.
        kappa: Damping strength; the Gaussian coefficient is ``kappa / 2``.
        tau: Norm threshold. Only norms strictly greater than ``tau`` are damped.

    Returns:
        A new tensor with the same shape as ``x``.

    Notes:
        The scale jumps from 1 to ``exp(-kappa * tau**2 / 2)`` at the threshold,
        so the map is discontinuous at ``||v|| == tau``. For large norms the
        scale underflows to exactly 0 (in float32 once ``kappa * ||v||**2 / 2``
        exceeds roughly 100), i.e. such vectors are zeroed out.
    """
    beta = kappa / 2.0
    # torch.linalg.vector_norm is the supported replacement for the deprecated
    # torch.norm; both compute the 2-norm over the given dimension.
    norms = torch.linalg.vector_norm(x, dim=-1, keepdim=True)
    # Select the per-vector scale on the small (..., 1) norm tensor first so
    # only one full-size multiply is needed. Unfiltered vectors get an exact
    # factor of 1, which leaves their values bit-identical.
    scale = torch.where(
        norms > tau,
        torch.exp(-beta * norms**2),
        torch.ones_like(norms),
    )
    return x * scale

# --- Synthetic data ----------------------------------------------------------
# A (batch, dim) matrix of i.i.d. standard-normal "latent activations".
batch, dim = 4096, 512
x = torch.randn(batch, dim)

# --- Filtering ---------------------------------------------------------------
x_f = resonance_basin_filter(x, kappa=0.8, tau=1.2)

# Per-row L2 norms before (n0) and after (n1) filtering, as NumPy arrays.
n0, n1 = [torch.norm(t, dim=-1).numpy() for t in (x, x_f)]

# NOTE(review): with dim=512 the row norms cluster near sqrt(512) ~= 22.6, far
# above tau=1.2, so every row is filtered and exp(-0.4 * ~512) underflows to 0
# in float32. The "After" histogram is therefore expected to collapse to a
# single spike at 0. Rescale x or adjust kappa/tau if a graded effect is wanted.

# --- Plot: overlaid norm histograms ------------------------------------------
plt.figure(figsize=(6, 4))
plt.hist(n0, bins=60, alpha=0.5, label="Before")
plt.hist(n1, bins=60, alpha=0.5, label="After")
plt.gca().set(
    xlabel="Activation norm",
    ylabel="Count",
    title="Resonance Basin: Activation Norms (Before vs After)",
)
plt.legend()
plt.tight_layout()
plt.show()


In [None]:

# Line plot of the per-sample norm reduction (n0 - n1), sorted ascending.
import numpy as np

shrink = n0 - n1

plt.figure(figsize=(6, 4))
plt.plot(np.sort(shrink))
plt.gca().set(
    xlabel="Sample (sorted)",
    ylabel="Norm shrinkage",
    title="Per-sample norm shrinkage due to basin filtering",
)
plt.tight_layout()
plt.show()
