In [None]:
''' 
1. max 
2. mean
3. median
4. percentile
5. top k mean 
6. scale trimmed mean 
7. huber estimator
8. generalized mean (2)
9. log sum exp 
10. high quantile gaussian
11. high quantile gpd 
12. gmm cluster
13. empirical bayes
14. EVT 
15. empirical bayes + EVT
'''

In [3]:
import numpy as np
from scipy import stats, optimize

import torch 

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
x = torch.randn(20)
x = x.numpy() 
x

array([-0.683658  ,  0.65870637,  0.76537055,  1.0086393 ,  1.9101545 ,
        0.5082128 , -0.56180954, -0.72310513,  0.7939475 ,  0.09072132,
       -1.2164989 ,  0.20438556, -0.6047317 , -1.7558211 ,  0.45578036,
       -2.2469485 ,  1.8344615 , -1.7856308 ,  1.1030827 , -0.34027243],
      dtype=float32)

In [None]:
# 1. max 
def scale_max(x):
    return np.max(np.abs(x))

y = scale_max(x) 
y

2.2469485

In [None]:
# 2. mean
def scale_mean(x):
    return np.mean(np.abs(x))

y = scale_mean(x) 
y

0.962597

In [None]:
# 3. median
def scale_median(x):
    return np.median(np.abs(x))

y = scale_median(x) 
y

0.74423784

In [None]:
# 4. percentile 
def scale_percentile(x, p=99.9):
    return np.percentile(np.abs(x), p)

y = scale_percentile(x) 
y

2.240549394249917

In [None]:

def scale_rms(x):
    return np.sqrt(np.mean(x**2))

y = scale_rms(x) 
y

1.1402097

In [None]:
# 5. top k mean 
def scale_topk_mean(x, k=5):
    x_sorted = np.sort(np.abs(x))
    return np.mean(x_sorted[-k:])
 
y = scale_topk_mean(x) 
y

1.9066032

In [None]:
# 6. scale trimmed mean
def scale_trimmed_mean(x, trim_frac=0.02):
    x_sorted = np.sort(np.abs(x))
    n = len(x_sorted)
    k = int(trim_frac * n)
    trimmed = x_sorted[k:n-k]
    return np.mean(trimmed)

y = scale_trimmed_mean(x) 
y

0.9625969

In [None]:
# 7. huber estimator
def huber_estimator(x, delta=1.35, tol=1e-6, max_iter=50):
    x = np.asarray(x)
    mu = np.median(x)   # good robust initial guess

    for _ in range(max_iter):
        r = x - mu
        w = np.where(np.abs(r) <= delta, 1, delta / np.abs(r))
        mu_new = np.sum(w * x) / np.sum(w)
        if abs(mu_new - mu) < tol:
            break
        mu = mu_new
    
    # Robust scale (like std but weighted)
    r = x - mu
    scale = np.sqrt(np.mean(np.minimum(r**2, (delta**2))))
    return mu, scale

def robust_scale_huber(x):
    mu, scale = huber_estimator(x, delta=1.35)
    return mu + 3 * scale


y = robust_scale_huber(x) 
y

2.7534882039763033

In [None]:
# 8. generalized mean
def generalized_mean(x, p=3):
    x = np.asarray(x)
    x = np.maximum(x, 0)  # if activations can be negative, use |x|
    return (np.mean(x ** p)) ** (1.0 / p)

y_3 = generalized_mean(x, 3)
y_4 = generalized_mean(x, 4)
y_5 = generalized_mean(x, 5)
y_6 = generalized_mean(x, 6)
y_7 = generalized_mean(x, 7)
y_8 = generalized_mean(x, 8)
y_9 = generalized_mean(x, 9)
y_10 = generalized_mean(x, 10)

y_3, y_4, y_5, y_6, y_7, y_8, y_9, y_10

(0.9469188303020584,
 1.089635275166533,
 1.1993804904035257,
 1.2851357928430545,
 1.3531919593050141,
 1.4080824480231826,
 1.453066019189784,
 1.4904882792932141)

In [None]:
# 9. log sum exp 
def logsumexp_magnitude(x, beta=2):
    x = np.abs(x)
    m = np.max(x)
    return (1.0 / beta) * (m + np.log(np.sum(np.exp(beta * (x - m)))))

y = logsumexp_magnitude(x) 
y

1.7434104681015015

In [None]:
# 10. high quantile gaussian
import numpy as np
from scipy.stats import norm, lognorm

def high_quantile_gaussian(x, q=0.999):
    x = np.abs(x)
    mu, sigma = norm.fit(x)  # fit Gaussian
    return norm.ppf(q, loc=mu, scale=sigma)

def high_quantile_lognormal(x, q=0.999):
    x = np.abs(x)
    shape, loc, scale = lognorm.fit(x, floc=0)  # force loc=0 for magnitudes
    return lognorm.ppf(q, shape, loc=loc, scale=scale)


y = high_quantile_gaussian(x) 
print(y) 
y = high_quantile_lognormal(x) 
print(y) 


2.8511439410330652
8.36207962848076


In [None]:
# 11. high quantile gpd
from scipy.stats import genpareto

def high_quantile_gpd(x, q=0.999, threshold_percentile=0.95):
    x = np.abs(x)
    threshold = np.percentile(x, threshold_percentile*100)  # peaks-over-threshold
    exceedances = x[x > threshold] - threshold
    c, loc, scale = genpareto.fit(exceedances, floc=0)
    
    # Predict the quantile
    p_exceed = 1 - threshold_percentile
    q_exceed = genpareto.ppf((q - threshold_percentile)/p_exceed, c, loc=loc, scale=scale)
    return threshold + q_exceed

y = high_quantile_gpd(x) 
print(y) 

2.2469263891038804


In [53]:
import numpy as np

# Simulated activations
num_samples = 10
num_channels = 20
X = np.random.randn(num_samples, num_channels) * 2 + 5  # some random activations

# Step 1: Compute per-channel max
m_c = X.max(axis=0)  # shape: (num_channels,)

# Step 2: Compute group-level statistics
mu0 = np.mean(m_c)            # global mean
tau2 = np.var(m_c)             # variance across channels

# Step 3: Estimate per-channel variance (can use sample variance)
sigma2_c = np.var(X, axis=0) / num_samples  # variance of the mean / max approximation

# Step 4: Compute shrinkage weights
w = sigma2_c / (sigma2_c + tau2)

# Step 5: Shrink noisy maxima toward group mean
m_c_shrunk = w * m_c + (1 - w) * mu0

print("Original max (first 5 channels):", m_c[:5])
print("Shrunk max (first 5 channels):", m_c_shrunk[:5])


Original max (first 5 channels): [8.58613278 7.09051212 8.8583193  7.658119   6.35609295]
Shrunk max (first 5 channels): [7.97088074 7.62018385 8.10880594 7.74233665 7.49564283]


In [None]:
# 12. gmm cluster
import numpy as np
from sklearn.mixture import GaussianMixture

# Simulated activations: 1000 samples, 512 channels
num_samples = 10
num_channels = 20
X = np.random.randn(num_samples, num_channels) * 2 + 5  # example activations

num_components = 3  # number of GMM components
per_channel_scale = np.zeros(num_channels)

for c in range(num_channels):
    x_c = X[:, c].reshape(-1, 1)  # shape (num_samples, 1)
    
    # Fit GMM
    gmm = GaussianMixture(n_components=num_components, random_state=0)
    gmm.fit(x_c)
    
    # Identify tail cluster: cluster with largest mean
    means = gmm.means_.flatten()
    stds = np.sqrt(gmm.covariances_).flatten()
    tail_idx = np.argmax(means)
    
    # Set scale: mean + 3*std of tail cluster
    per_channel_scale[c] = means[tail_idx] + 3 * stds[tail_idx]

print("Per-channel scales (first 5):", per_channel_scale[:5])


Per-channel scales (first 5): [6.70508961 8.45659568 8.93770539 9.21990798 8.97574586]


In [61]:
per_channel_scale

array([ 6.70508961,  8.45659568,  8.93770539,  9.21990798,  8.97574586,
       13.29053697,  7.17531501, 10.68463022,  8.23634335,  9.1881867 ,
       11.2889807 ,  7.5935238 ,  7.73557123,  7.69539574,  7.99356908,
        9.72632944, 11.47327141,  9.28671345,  8.47985396,  7.88969108])

In [62]:
X.max(axis=0)

array([ 6.12243265,  8.45359568,  8.21868415,  8.31262432,  8.97274586,
       13.28753697,  6.42464874,  8.40435152,  6.84185469,  7.78277832,
       10.05627357,  6.71632774,  7.36441428,  7.69239574,  7.19787954,
        9.72332944, 10.58647726,  8.40949099,  8.47685396,  7.3711058 ])

In [None]:
# 15. empirical bayes + EVT
import numpy as np
from scipy.stats import genpareto

def robust_per_channel_scale(X, tail_fraction=0.05, quantile=0.999):
    N, C = X.shape
    per_channel_scale = np.zeros(C)
    
    # Global mean for shrinkage
    q_extremes = np.zeros(C)
    tail_sizes = np.zeros(C)
    
    for c in range(C):
        x_c = X[:, c]
        threshold = np.percentile(x_c, 100*(1-tail_fraction))
        tail = x_c[x_c > threshold] - threshold
        tail_sizes[c] = len(tail)
        
        if len(tail) < 2:
            q_extremes[c] = x_c.max()
            continue
        
        xi, loc, beta = genpareto.fit(tail, floc=0)
        k = len(tail)
        # EVT extrapolated quantile
        q_extreme = threshold + (beta / xi) * (((1 - quantile)/(k/N))**(-xi) - 1) if xi != 0 else threshold + beta * np.log((k/N)/(1-quantile))
        q_extremes[c] = q_extreme
    
    # Empirical Bayes shrinkage across channels
    mu0 = np.mean(q_extremes)
    tau2 = np.var(q_extremes)
    sigma2 = 1.0 / np.maximum(tail_sizes, 1)  # uncertainty ~ 1/n_tail
    w = sigma2 / (sigma2 + tau2)
    
    per_channel_scale = w * q_extremes + (1-w) * mu0
    return per_channel_scale

# Example usage
N = 512*128  # total tokens
C = 512
X = np.random.randn(N, C)*2 + 5
scales = robust_per_channel_scale(X)
print("Per-channel scales (first 5):", scales[:5])


Per-channel scales (first 5): [11.19930591 11.19650939 11.19784141 11.2010131  11.20617557]


In [91]:
X.max(axis=0)[:5]

array([13.27780066, 12.74537227, 12.90488008, 13.09937988, 13.75556333])

In [None]:
# 13. emipirical bayes
import numpy as np

def empirical_bayes_per_channel(X):
    """
    Empirical Bayes per-channel scale.
    X: numpy array of shape [num_tokens, num_channels]
    Returns: per-channel scales (robust to noisy maxima)
    """
    N, C = X.shape
    
    # Step 1: Per-channel max (or you can use RMS or mean + 3*std)
    m_c = np.max(np.abs(X), axis=0)
    
    # Step 2: Estimate global statistics
    mu0 = np.mean(m_c)        # global mean across channels
    tau2 = np.var(m_c)         # between-channel variance
    sigma2_c = np.var(X, axis=0) / N  # variance of the mean (approx)
    
    # Step 3: Compute shrinkage weights
    w = sigma2_c / (sigma2_c + tau2)
    
    # Step 4: Shrink noisy per-channel maxima
    per_channel_scale = w * m_c + (1 - w) * mu0
    return per_channel_scale

# Example usage
N = 512 * 128  # total tokens
C = 512
X = np.random.randn(N, C)*2 + 5
eb_scales = empirical_bayes_per_channel(X)
print("Empirical Bayes scales (first 5):", eb_scales[:5])


Empirical Bayes scales (first 5): [13.58906863 13.58899879 13.58914105 13.58906048 13.5889502 ]


In [93]:
X.max(axis=0)[:5]

array([13.83080599, 13.44405546, 14.2285278 , 13.78095582, 13.17600823])

In [None]:
# 14. EVT
from scipy.stats import genpareto

def evt_per_channel_scale(X, tail_fraction=0.05, quantile=0.999):
    """
    EVT per-channel scale estimation.
    X: [num_tokens, num_channels]
    tail_fraction: fraction of top activations to use as tail
    quantile: target extreme quantile to extrapolate to
    Returns: per-channel EVT scales
    """
    N, C = X.shape
    per_channel_scale = np.zeros(C)
    
    for c in range(C):
        x_c = X[:, c]
        
        # Step 1: threshold for tail
        threshold = np.percentile(x_c, 100*(1-tail_fraction))
        tail = x_c[x_c > threshold] - threshold  # exceedances
        
        if len(tail) < 2:
            # fallback if too few tail points
            per_channel_scale[c] = x_c.max()
            continue
        
        # Step 2: fit GPD
        xi, loc, beta = genpareto.fit(tail, floc=0)
        k = len(tail)
        
        # Step 3: extrapolate to desired quantile
        if xi != 0:
            q_extreme = threshold + (beta / xi) * (((1 - quantile)/(k/N))**(-xi) - 1)
        else:
            q_extreme = threshold + beta * np.log((k/N)/(1-quantile))
        
        per_channel_scale[c] = q_extreme
        
    return per_channel_scale

# Example usage
evt_scales = evt_per_channel_scale(X)
print("EVT scales (first 5):", evt_scales[:5])


EVT scales (first 5): [11.17950945 11.19527866 11.15409209 11.32752952 11.20868582]


In [95]:
X.max(axis=0)[:5]

array([13.83080599, 13.44405546, 14.2285278 , 13.78095582, 13.17600823])