In [1]:
import numpy as np
import time
import pwlf
import torch



In [6]:
# === PWL (Piecewise Linear) fitting ===
x_vals = np.linspace(0.01, 64, 1000)
sqrt_vals = np.sqrt(x_vals)
recip_vals = 1 / sqrt_vals

sqrt_model = pwlf.PiecewiseLinFit(x_vals, sqrt_vals)
sqrt_breaks = sqrt_model.fit(8)
sqrt_slopes = sqrt_model.slopes
sqrt_intercepts = sqrt_model.intercepts

recip_model = pwlf.PiecewiseLinFit(x_vals, recip_vals)
recip_breaks = recip_model.fit(8)
recip_slopes = recip_model.slopes
recip_intercepts = recip_model.intercepts




In [8]:
# === PWL approximation function ===
def pwl_approx(x, breakpoints, slopes, intercepts):
    x = np.clip(x, breakpoints[0], breakpoints[-1])
    out = np.zeros_like(x)
    for i in range(len(slopes)):
        mask = (x >= breakpoints[i]) & (x < breakpoints[i + 1])
        out[mask] = slopes[i] * x[mask] + intercepts[i]
    out[x >= breakpoints[-1]] = slopes[-1] * x[x >= breakpoints[-1]] + intercepts[-1]
    return out



In [9]:
# === Timer utility ===
def measure_time(func, *args):
    start = time.perf_counter()
    result = func(*args)
    end = time.perf_counter()
    return result, (end - start) * 1000  # return ms


# === Input variance tensor ===
N, D = 100, 768
np_embeddings = np.random.randn(N, D).astype(np.float32)
torch_input = torch.tensor(np_embeddings, dtype=torch.float32)

# Ground-truth variance
with torch.no_grad():
    true_var = torch.var(torch_input, dim=-1, unbiased=False, keepdim=True).numpy()

# === PWL vs Exact: Accuracy & Timing Comparison ===
sqrt_exact, t_sqrt_exact = measure_time(np.sqrt, true_var + 1e-5)
sqrt_pwl, t_sqrt_pwl = measure_time(
    pwl_approx, true_var + 1e-5, sqrt_breaks, sqrt_slopes, sqrt_intercepts
)

recip_exact, t_recip_exact = measure_time(np.reciprocal, sqrt_exact)
recip_pwl, t_recip_pwl = measure_time(
    pwl_approx, sqrt_pwl, recip_breaks, recip_slopes, recip_intercepts
)

# === Accuracy ===
acc_sqrt = 100 - np.abs(sqrt_exact - sqrt_pwl) / (sqrt_exact + 1e-8) * 100
acc_recip = 100 - np.abs(recip_exact - recip_pwl) / (recip_exact + 1e-8) * 100

# === Print results ===
print("===== Accuracy (% Error) =====")
print(f"[Sqrt PWL]     Mean Accuracy: {acc_sqrt.mean():.4f}%")
print(f"[Recip PWL]    Mean Accuracy: {acc_recip.mean():.4f}%")

print("\n===== Timing (ms) =====")
print(f"[Sqrt Exact]   {t_sqrt_exact:.4f} ms")
print(f"[Sqrt PWL]     {t_sqrt_pwl:.4f} ms")
print(f"[Recip Exact]  {t_recip_exact:.4f} ms")
print(f"[Recip PWL]    {t_recip_pwl:.4f} ms")

===== Accuracy (% Error) =====
[Sqrt PWL]     Mean Accuracy: 99.1760%
[Recip PWL]    Mean Accuracy: 97.9223%

===== Timing (ms) =====
[Sqrt Exact]   0.0075 ms
[Sqrt PWL]     0.2757 ms
[Recip Exact]  0.0080 ms
[Recip PWL]    0.1549 ms


In [10]:
# === Verilog-style scalar PWL approximation ===
def pwl_approx_scalar(x, breakpoints, slopes, intercepts):
    if x < breakpoints[1]:
        return slopes[0] * x + intercepts[0]
    elif x < breakpoints[2]:
        return slopes[1] * x + intercepts[1]
    elif x < breakpoints[3]:
        return slopes[2] * x + intercepts[2]
    elif x < breakpoints[4]:
        return slopes[3] * x + intercepts[3]
    elif x < breakpoints[5]:
        return slopes[4] * x + intercepts[4]
    elif x < breakpoints[6]:
        return slopes[5] * x + intercepts[5]
    elif x < breakpoints[7]:
        return slopes[6] * x + intercepts[6]
    else:
        return slopes[7] * x + intercepts[7]


# === Make vectorized version of scalar function ===
def make_vectorized_pwl(breaks, slopes, intercepts):
    return np.vectorize(
        lambda x: pwl_approx_scalar(x, breaks, slopes, intercepts), otypes=[np.float32]
    )


# === Timer utility ===
def measure_time(func, *args):
    start = time.perf_counter()
    result = func(*args)
    end = time.perf_counter()
    return result, (end - start) * 1000  # return in milliseconds


# === Input data ===
N, D = 100, 768
np_embeddings = np.random.randn(N, D).astype(np.float32)
torch_input = torch.tensor(np_embeddings, dtype=torch.float32)

# === Ground-truth variance ===
with torch.no_grad():
    true_var = torch.var(torch_input, dim=-1, unbiased=False, keepdim=True).numpy()

# === Prepare input for sqrt/recip
sqrt_input = (true_var + 1e-5).flatten()

# === Exact functions ===
sqrt_exact, t_sqrt_exact = measure_time(np.sqrt, sqrt_input)
recip_exact, t_recip_exact = measure_time(np.reciprocal, sqrt_exact)

# === PWL (Verilog style vectorized) functions ===
sqrt_pwl_func = make_vectorized_pwl(sqrt_breaks, sqrt_slopes, sqrt_intercepts)
recip_pwl_func = make_vectorized_pwl(recip_breaks, recip_slopes, recip_intercepts)

sqrt_pwl, t_sqrt_pwl = measure_time(sqrt_pwl_func, sqrt_input)
recip_pwl, t_recip_pwl = measure_time(recip_pwl_func, sqrt_pwl)

# === Accuracy ===
acc_sqrt = 100 - np.abs(sqrt_exact - sqrt_pwl) / (sqrt_exact + 1e-8) * 100
acc_recip = 100 - np.abs(recip_exact - recip_pwl) / (recip_exact + 1e-8) * 100

# === Print results ===
print("===== Accuracy (% Error) =====")
print(f"[Sqrt PWL]     Mean Accuracy: {acc_sqrt.mean():.4f}%")
print(f"[Recip PWL]    Mean Accuracy: {acc_recip.mean():.4f}%")

print("\n===== Timing (ms) =====")
print(f"[Sqrt Exact]   {t_sqrt_exact:.4f} ms")
print(f"[Sqrt PWL]     {t_sqrt_pwl:.4f} ms")
print(f"[Recip Exact]  {t_recip_exact:.4f} ms")
print(f"[Recip PWL]    {t_recip_pwl:.4f} ms")

===== Accuracy (% Error) =====
[Sqrt PWL]     Mean Accuracy: 99.2767%
[Recip PWL]    Mean Accuracy: 98.1566%

===== Timing (ms) =====
[Sqrt Exact]   0.0081 ms
[Sqrt PWL]     0.2141 ms
[Recip Exact]  0.0047 ms
[Recip PWL]    0.2011 ms
