In [17]:
import numpy as np
import ctypes
import scipy.stats as stats
from scipy.stats import wasserstein_distance
from sklearn.metrics import roc_auc_score
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from datagen import get_dataloader, create_file_pairs

# -------- Helper Functions --------

def reduce_float_precision(value, nmantissa):
    # Constants
    NMANTISSA = 23
    MAX_EXPONENT = 0x7f7  # corresponds to max exponent before Inf (254 in biased form)
    MIN_MANTISSA_BITS = 1

    # Convert float32 to raw bits
    float_val = ctypes.c_float(value)
    bits = ctypes.cast(ctypes.pointer(float_val), ctypes.POINTER(ctypes.c_uint32)).contents.value

    # Handle NaN and Inf (do not compress)
    if (bits & 0x7f800000) == 0x7f800000:
        return value  # Already Inf/NaN

    # If full mantissa requested, or invalid range, return original
    if nmantissa >= NMANTISSA or nmantissa <= 0:
        return value

    # Clamp to minimum mantissa
    if nmantissa < MIN_MANTISSA_BITS:
        nmantissa = MIN_MANTISSA_BITS

    # Compute shift and masks
    shift = NMANTISSA - nmantissa
    mask = (~0) << shift  # Keeps upper nmantissa bits
    round_bit = 1 << (shift - 1)

    mantissa = bits & 0x007fffff
    exponent = (bits >> 23) & 0xff

    # Create a mask for vmax check (from C++)
    vmax = (MAX_EXPONENT << 23) | (0x007fffff ^ round_bit)

    # Perform rounding if it won't cause overflow
    if (bits & 0x7fffffff) < vmax:
        mantissa += round_bit

    # Apply mantissa mask (truncate lower bits)
    mantissa &= mask

    # Reassemble final bits
    compressed_bits = (bits & 0xff800000) | mantissa  # keep sign and exponent

    # Convert bits back to float32
    compressed_uint = ctypes.c_uint32(compressed_bits)
    compressed_float = ctypes.cast(ctypes.pointer(compressed_uint), ctypes.POINTER(ctypes.c_float)).contents.value

    return compressed_float

def float_to_bits(val):
    """Convert a float32 to its raw 32-bit representation."""
    float_val = ctypes.c_float(val)
    bits = ctypes.cast(ctypes.pointer(float_val), ctypes.POINTER(ctypes.c_uint32)).contents.value
    return bits

def print_float_bits(val, label="Value"):
    bits = float_to_bits(val)
    sign = (bits >> 31) & 0x1
    exponent = (bits >> 23) & 0xff
    mantissa = bits & 0x7fffff
    return f"{label}: {val:.6f} | Sign: {sign} Exponent: {exponent:08b} Mantissa: {mantissa:023b}"


def summary_stats(x):
    return {
        "mean": np.mean(x),
        "std": np.std(x),
        "skew": stats.skew(x),
        "kurtosis": stats.kurtosis(x),
    }

def kl_divergence(p, q, bins=100):
    """KL divergence between histograms p and q"""
    p_hist, _ = np.histogram(p, bins=bins, density=True)
    q_hist, _ = np.histogram(q, bins=bins, density=True)
    p_hist += 1e-10; q_hist += 1e-10  # avoid div by zero
    return stats.entropy(p_hist, q_hist)

def js_divergence(p, q, bins=100):
    """Jensen-Shannon divergence"""
    p_hist, _ = np.histogram(p, bins=bins, density=True)
    q_hist, _ = np.histogram(q, bins=bins, density=True)
    m = 0.5 * (p_hist + q_hist)
    return 0.5 * (stats.entropy(p_hist, m) + stats.entropy(q_hist, m))

def mmd_rbf(x, y, gamma=1.0):
    """Maximum Mean Discrepancy with RBF kernel"""

    def kernel(a, b):
        sq_dists = np.sum(a**2, 1).reshape(-1,1) + np.sum(b**2,1) - 2*np.dot(a, b.T)
        sq_dists = np.maximum(sq_dists, 0)  # clamp negatives from precision errors
        K = np.exp(-gamma * np.clip(sq_dists, 0, 1e6))  # prevent overflow
        return K
    
    
    x = x[:min(5000, len(x))].reshape(-1,1)
    y = y[:min(5000, len(y))].reshape(-1,1)
    Kxx = kernel(x, x).mean()
    Kyy = kernel(y, y).mean()
    Kxy = kernel(x, y).mean()
    return Kxx + Kyy - 2*Kxy

def classifier_test(x, y):
    """Train/test classifier to detect original vs compressed"""
    X = np.concatenate([x.reshape(-1,1), y.reshape(-1,1)])
    labels = np.array([0]*len(x) + [1]*len(y))
    X_train, X_test, y_train, y_test = train_test_split(X, labels, test_size=0.3, stratify=labels)
    clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
    y_score = clf.predict_proba(X_test)[:,1]
    auc = roc_auc_score(y_test, y_score)
    return auc

# -------- Main Triage Pipeline --------

def triage(original, compressed, name="Condition"):
    print(f"\n=== {name} ===")

    # Step 1: Summary
    print("Summary (original):", summary_stats(original))
    print("Summary (compressed):", summary_stats(compressed))

    # Step 2: Divergences
    print("KL divergence:", kl_divergence(original, compressed))
    print("JS divergence:", js_divergence(original, compressed))
    print("Wasserstein:", wasserstein_distance(original, compressed))
    print("MMD:", mmd_rbf(original, compressed))

    # Step 3: Hypothesis tests
    ks_stat, ks_p = stats.ks_2samp(original, compressed)
    print(f"KS test: stat={ks_stat:.4f}, p={ks_p:.4f}")

    # Step 4: Classifier
    auc = classifier_test(original, compressed)
    print("Classifier AUC (0.5 = indistinguishable):", auc)

In [18]:
branch = "AnalysisElectronsAuxDyn"
varnames = ["pt"]
global_min=712
global_max=716800
dataloader = get_dataloader(branch=branch, varnames=varnames, hist=False, range=(global_min, global_max), full_sample_mode=True, batch_size=1, shuffle=False)

In [21]:
for compressed, residual in dataloader:
    compressed = compressed.squeeze()
    residual = residual.squeeze()
    original = compressed + residual
    for nmantissa in [2, 3, 5, 10, 15]:
        compressed_new = np.array([reduce_float_precision(i, nmantissa=nmantissa) for i in original])
        for i in range(3):
            orig_bits = print_float_bits(original[i])
            comp_bits = print_float_bits(compressed_new[i])
            print(f"Value {i}:")
            print(f"  Original:   {original[i]:.6f}, {orig_bits}")
            print(f"  Compressed: {compressed_new[i]:.6f}, {comp_bits}")
        triage(original.numpy(), compressed_new, name=f"Compressed vs Original, nmantissa={nmantissa}")
        print("-"*40)


Value 0:
  Original:   4610.336914, Value: 4610.336914 | Sign: 0 Exponent: 10001011 Mantissa: 00100000001001010110010
  Compressed: 5120.000000, Value: 5120.000000 | Sign: 0 Exponent: 10001011 Mantissa: 01000000000000000000000
Value 1:
  Original:   4209.977539, Value: 4209.977539 | Sign: 0 Exponent: 10001011 Mantissa: 00000111000111111010010
  Compressed: 4096.000000, Value: 4096.000000 | Sign: 0 Exponent: 10001011 Mantissa: 00000000000000000000000
Value 2:
  Original:   3048.412354, Value: 3048.412354 | Sign: 0 Exponent: 10001010 Mantissa: 01111101000011010011001
  Compressed: 3072.000000, Value: 3072.000000 | Sign: 0 Exponent: 10001010 Mantissa: 10000000000000000000000

=== Compressed vs Original, nmantissa=2 ===
Summary (original): {'mean': np.float32(17932.43), 'std': np.float32(24870.092), 'skew': np.float64(5.637500286102295), 'kurtosis': np.float32(79.48475)}
Summary (compressed): {'mean': np.float64(17352.527797185347), 'std': np.float64(24366.59262821769), 'skew': np.float64(

  res = hypotest_fun_out(*samples, **kwds)


Value 0:
  Original:   3160.729736, Value: 3160.729736 | Sign: 0 Exponent: 10001010 Mantissa: 10001011000101110101101
  Compressed: 3072.000000, Value: 3072.000000 | Sign: 0 Exponent: 10001010 Mantissa: 10000000000000000000000
Value 1:
  Original:   29419.267578, Value: 29419.267578 | Sign: 0 Exponent: 10001101 Mantissa: 11001011101011010001001
  Compressed: 28672.000000, Value: 28672.000000 | Sign: 0 Exponent: 10001101 Mantissa: 11000000000000000000000
Value 2:
  Original:   31330.224609, Value: 31330.224609 | Sign: 0 Exponent: 10001101 Mantissa: 11101001100010001110011
  Compressed: 16384.000000, Value: 16384.000000 | Sign: 0 Exponent: 10001101 Mantissa: 00000000000000000000000

=== Compressed vs Original, nmantissa=2 ===
Summary (original): {'mean': np.float32(30765.113), 'std': np.float32(36173.812), 'skew': np.float64(2.5236427783966064), 'kurtosis': np.float32(11.069621)}
Summary (compressed): {'mean': np.float64(29971.259023453957), 'std': np.float64(35617.85680339119), 'skew': 

In [None]:
for compressed, residual in dataloader:
    compressed = compressed.squeeze()
    residual = residual.squeeze()
    original = compressed + residual
    for nmantissa in [2, 3, 5, 10, 15]:
        compressed_new = np.array([reduce_float_precision(i, nmantissa=nmantissa) for i in original])
        triage(np.log10(original.numpy()), np.log10(compressed_new), name=f"Compressed vs Original, nmantissa={nmantissa}")
        print("-"*40)



=== Compressed vs Original, nmantissa=2 ===
Summary (original): {'mean': np.float32(3.992496), 'std': np.float32(0.4559511), 'skew': np.float64(0.5765568614006042), 'kurtosis': np.float32(-0.87060547)}
Summary (compressed): {'mean': np.float64(3.991039458775161), 'std': np.float64(0.45595377696340494), 'skew': np.float64(0.5767835115544582), 'kurtosis': np.float64(-0.8589052773938399)}
KL divergence: 0.19503706721254777
JS divergence: 0.03283681077811146
Wasserstein: 0.0031508548659701916
MMD: 1.9032895977044717e-06
KS test: stat=0.0145, p=0.2424
Classifier AUC (0.5 = indistinguishable): 0.5021385483634638
----------------------------------------

=== Compressed vs Original, nmantissa=3 ===
Summary (original): {'mean': np.float32(3.992496), 'std': np.float32(0.4559511), 'skew': np.float64(0.5765568614006042), 'kurtosis': np.float32(-0.87060547)}
Summary (compressed): {'mean': np.float64(3.991039458775161), 'std': np.float64(0.45595377696340494), 'skew': np.float64(0.5767835115544582),

  res = hypotest_fun_out(*samples, **kwds)



=== Compressed vs Original, nmantissa=2 ===
Summary (original): {'mean': np.float32(4.2097564), 'std': np.float32(0.5127048), 'skew': np.float64(0.07886702567338943), 'kurtosis': np.float32(-1.1666573)}
Summary (compressed): {'mean': np.float64(4.208176450680476), 'std': np.float64(0.5128520846835353), 'skew': np.float64(0.08036117470414714), 'kurtosis': np.float64(-1.1652974576912163)}
KL divergence: 0.025709718041742866
JS divergence: 0.0061159352572704075
Wasserstein: 0.003043748845682501
MMD: 1.978988609652177e-06
KS test: stat=0.0082, p=0.1149
Classifier AUC (0.5 = indistinguishable): 0.5020281753658332
----------------------------------------

=== Compressed vs Original, nmantissa=3 ===
Summary (original): {'mean': np.float32(4.2097564), 'std': np.float32(0.5127048), 'skew': np.float64(0.07886702567338943), 'kurtosis': np.float32(-1.1666573)}
Summary (compressed): {'mean': np.float64(4.208176450680476), 'std': np.float64(0.5128520846835353), 'skew': np.float64(0.0803611747041471

: 