# In [None]:  (Jupyter cell prompt from notebook export; kept as a comment so the file parses)
import time
import math
import copy
import numpy as np
from scipy.special import erf

import torch
import torch.nn as nn
from transformers import EfficientNetForImageClassification

# Standard GELU
def gelu_np(x: np.ndarray) -> np.ndarray:
    """Return the exact GELU of ``x`` element-wise.

    GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))), i.e. x times the standard
    normal CDF evaluated at x.
    """
    half_x = 0.5 * x
    erf_term = erf(x / np.sqrt(2.0))
    return half_x * (1.0 + erf_term)

# --- Your CachedGELU (float32 + slope based) ---
class CachedGELU:
    """Approximate GELU by linear interpolation in a precomputed float32 table.

    The table samples GELU at ``N`` evenly spaced points covering
    ``[x_min, x_max]``. Inputs inside that range are evaluated as
    ``y[i] + slope[i] * frac``. Inputs outside it, and NaN, fall back to the
    exact erf-based formula.

    Args:
        x_min: Lower edge of the tabulated range.
        x_max: Upper edge of the tabulated range; must be greater than ``x_min``.
        N: Number of table nodes; must be at least 2.

    Raises:
        ValueError: If ``N < 2`` or ``x_max <= x_min``.
    """

    def __init__(self, x_min=-10.0, x_max=10.0, N=10001):
        self.x_min = float(x_min)
        self.x_max = float(x_max)
        self.N = int(N)
        if self.N < 2:
            raise ValueError(f"N must be at least 2, got {self.N}")
        if not self.x_max > self.x_min:
            raise ValueError(
                f"x_max must be greater than x_min, got x_min={self.x_min}, x_max={self.x_max}"
            )

        self.step = (self.x_max - self.x_min) / (self.N - 1)
        self.inv_step = 1.0 / self.step

        # Evaluate the nodes in float64 so node i really sits at x_min + i*step,
        # which apply() assumes. Then store the table in float32 for cheap lookups.
        x_table = np.linspace(self.x_min, self.x_max, self.N, dtype=np.float64)
        self.y_table = self._gelu_exact(x_table).astype(np.float32)

        # slope[i] = y[i+1] - y[i], i.e. the change per unit of fractional index.
        # The last node is only reached with frac == 0, so it reuses the previous slope.
        self.slope = np.empty_like(self.y_table)
        self.slope[:-1] = np.diff(self.y_table)
        self.slope[-1] = self.slope[-2]

    @staticmethod
    def _gelu_exact(v):
        """Exact GELU, 0.5 * v * (1 + erf(v / sqrt(2)))."""
        return 0.5 * v * (1.0 + erf(v / np.sqrt(2.0)))

    def apply(self, X: np.ndarray) -> np.ndarray:
        """Evaluate the approximate GELU element-wise.

        Args:
            X: Array-like input of any shape (including 0-d); converted to float32.

        Returns:
            float32 array with the same shape as ``X``.
        """
        Xf = np.asarray(X, dtype=np.float32)
        flat = Xf.reshape(-1)  # 1-d working view so 0-d inputs index like any other

        # Fractional table position, clamped to [0, N-1] BEFORE the integer cast.
        # fmax/fmin also map NaN to 0, so the cast never sees NaN, inf or
        # out-of-range values (those casts are undefined and emit warnings).
        pos = (flat - self.x_min) * self.inv_step
        np.fmax(pos, 0.0, out=pos)
        np.fmin(pos, self.N - 1, out=pos)
        idx = pos.astype(np.int32)
        pos -= idx  # pos now holds the fractional part in [0, 1]
        out = self.y_table[idx] + self.slope[idx] * pos

        # Anything not inside [x_min, x_max], NaN included, uses the exact formula.
        outside = ~((flat >= self.x_min) & (flat <= self.x_max))
        if outside.any():
            out[outside] = self._gelu_exact(flat[outside])

        return out.reshape(Xf.shape)

# PyTorch Wrappers
class PlainPythonGELU(nn.Module):
    """GELU activation that round-trips through NumPy and evaluates ``gelu_np``.

    The input is detached and computed on the host, so no autograd graph is
    recorded through this module.

    Args:
        torch_dtype: If ``torch.float32``, the NumPy computation runs in float32;
            otherwise it runs in float64.
    """

    def __init__(self, torch_dtype=torch.float32):
        super().__init__()
        self.np_dtype = np.float32 if torch_dtype == torch.float32 else np.float64

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply GELU element-wise; the result keeps ``x``'s device and dtype."""
        # Cast in torch rather than NumPy: dtypes NumPy cannot represent
        # (e.g. bfloat16) make Tensor.numpy() raise.
        work_dtype = torch.float32 if self.np_dtype == np.float32 else torch.float64
        x_np = x.detach().to(device="cpu", dtype=work_dtype).numpy()
        # np.asarray: for a 0-d input NumPy arithmetic yields a scalar, which
        # torch.from_numpy rejects.
        out_np = np.asarray(gelu_np(x_np), dtype=self.np_dtype)
        return torch.from_numpy(out_np).to(device=x.device, dtype=x.dtype)

class PlainPythonCachedGELU(nn.Module):
    """GELU activation backed by a ``CachedGELU`` lookup table evaluated in NumPy.

    Args:
        x_min, x_max, N: Forwarded unchanged to ``CachedGELU``.
        torch_dtype: Accepted for signature parity with ``PlainPythonGELU``.
            It is not used, because the table computation always runs in float32.
    """

    def __init__(self, x_min=-10, x_max=10, N=10001, torch_dtype=torch.float32):
        super().__init__()
        self.np_dtype = np.float32
        self.cache = CachedGELU(x_min=x_min, x_max=x_max, N=N)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the table GELU element-wise; the result keeps ``x``'s device and dtype."""
        # Cast to float32 in torch first: Tensor.numpy() raises for dtypes NumPy
        # cannot represent (e.g. bfloat16).
        x_np = x.detach().to(device="cpu", dtype=torch.float32).numpy()
        out_np = np.asarray(self.cache.apply(x_np), dtype=self.np_dtype)
        return torch.from_numpy(out_np).to(device=x.device, dtype=x.dtype)

# Replace GELU/SiLU activations
def replace_activations(module: nn.Module, new_act_module: nn.Module, inplace=False,
                        *, target_types=(nn.SiLU, nn.GELU)):
    """Swap every submodule of ``target_types`` in ``module`` for a copy of ``new_act_module``.

    Each replaced submodule gets its own ``copy.deepcopy`` of
    ``new_act_module``, so replacements never share state. The root ``module``
    itself is never replaced, only its descendants.

    Args:
        module: Root module to search recursively.
        new_act_module: Template activation, deep-copied for every replacement.
        inplace: If False (default), ``module`` is deep-copied first and left untouched.
        target_types: Module classes to replace; defaults to ``(nn.SiLU, nn.GELU)``.

    Returns:
        The module holding the replacements (``module`` itself when ``inplace``).
    """
    if not inplace:
        module = copy.deepcopy(module)
    # Snapshot the children so re-assigning attributes while walking is clearly safe.
    for name, child in list(module.named_children()):
        if isinstance(child, target_types):
            setattr(module, name, copy.deepcopy(new_act_module))
        else:
            # The root is already a private copy (or inplace was requested),
            # so descendants are edited in place.
            replace_activations(child, new_act_module, inplace=True, target_types=target_types)
    return module

# Benchmark inference time
def _to_logits(output):
    """Return ``output.logits`` when the output has that attribute, else ``output``."""
    return output.logits if hasattr(output, 'logits') else output


def benchmark_model(model: nn.Module, input_tensor: torch.Tensor, runs=20, warmup=2):
    """Time ``runs`` forward passes of ``model`` on the CPU.

    Side effects: puts ``model`` into eval mode and moves it to the CPU.

    Args:
        model: Module to benchmark; its output may expose a ``.logits`` attribute.
        input_tensor: Input fed to every call after being moved to the CPU.
        runs: Number of timed forward passes; must be at least 1.
        warmup: Number of untimed passes run first.

    Returns:
        ``(times, last_output)``: the per-run wall-clock times in milliseconds,
        and the logits of the final timed run as a detached CPU tensor.

    Raises:
        ValueError: If ``runs < 1``. Otherwise ``last_output`` would be None.
    """
    if runs < 1:
        raise ValueError(f"runs must be >= 1, got {runs}")

    device = torch.device('cpu')
    model.eval()
    model.to(device)
    inp = input_tensor.to(device)  # .to(device) keeps the dtype unchanged

    with torch.no_grad():
        for _ in range(warmup):
            _to_logits(model(inp)).detach().cpu()

        times = []
        last_output = None
        for _ in range(runs):
            t0 = time.perf_counter()
            logits = _to_logits(model(inp))
            t1 = time.perf_counter()
            times.append((t1 - t0) * 1000.0)
            last_output = logits.detach().cpu()
    return times, last_output

# Main test script
def main():
    """Benchmark NumPy-based GELU replacements inside EfficientNet-B7 on the CPU.

    Loads ``google/efficientnet-b7`` and builds two variants in which every
    ``nn.SiLU`` / ``nn.GELU`` submodule is replaced: one uses the exact
    ``gelu_np`` and one uses the ``CachedGELU`` table. Both are timed with
    ``benchmark_model`` and the difference between their logits is printed.

    NOTE(review): if the model's activations are SiLU, swapping them for GELU
    changes what the network computes. The printed diff only compares the two
    GELU variants with each other, not against the original model.
    """
    model_name = "google/efficientnet-b7"
    print("Loading model:", model_name)
    model = EfficientNetForImageClassification.from_pretrained(model_name)
    model_dtype = next(model.parameters()).dtype
    print("Model parameter dtype:", model_dtype)
    input_size = getattr(model.config, 'image_size', 300)
    print("Using input size:", input_size)

    dummy_input = torch.randn(1, 3, input_size, input_size, dtype=model_dtype)

    plain_act = PlainPythonGELU(torch_dtype=model_dtype)
    cached_act = PlainPythonCachedGELU(x_min=-10, x_max=10, N=10001, torch_dtype=model_dtype)

    # Deep-copy once for the plain variant, then rewrite the loaded model in
    # place for the cached variant. This avoids holding a third full copy of
    # the weights in memory.
    model_plain = replace_activations(model, plain_act, inplace=False).to(model_dtype)
    model_cached = replace_activations(model, cached_act, inplace=True).to(model_dtype)

    runs = 20
    warmup = 2

    print("\nBenchmarking on CPU only:")
    times_plain, out_plain = benchmark_model(model_plain, dummy_input, runs=runs, warmup=warmup)
    mean_plain, std_plain = np.mean(times_plain), np.std(times_plain)
    print(f"Plain-Python GELU inference time: {mean_plain:.1f} ms ± {std_plain:.1f}")

    times_cached, out_cached = benchmark_model(model_cached, dummy_input, runs=runs, warmup=warmup)
    mean_cached, std_cached = np.mean(times_cached), np.std(times_cached)
    print(f"Cached-Python GELU inference time: {mean_cached:.1f} ms ± {std_cached:.1f}")

    if out_plain.shape == out_cached.shape:
        diff = out_plain.numpy() - out_cached.numpy()
        max_abs = np.max(np.abs(diff))
        mse = np.mean(diff**2)
        print(f"Output max abs difference: {max_abs:.6f}")
        print(f"Output MSE: {mse:.6e}")
    else:
        print("Output shapes differ; cannot compare.")

    print("\nDone.")

# Run the benchmark only when this file is executed directly; importing it
# just defines the helpers above.
if __name__ == "__main__":
    main()
