# In [None]:  (Jupyter cell marker left over from the notebook export)
import torch
import torch.nn as nn
from transformers import BertForSequenceClassification, BertTokenizerFast
from datasets import load_dataset
from tqdm import tqdm
import time

# ---- CachedGELU Implementation ----
class CachedGELU(nn.Module):
    """Exact (erf-based) GELU evaluated through a linearly interpolated lookup table.

    The table samples GELU at ``N`` evenly spaced points in ``[x_min, x_max]``.
    Inputs inside that interval are answered by linear interpolation between
    the two neighbouring samples; inputs outside it use the closed-form
    erf expression.

    Args:
        x_min: Lower bound of the tabulated interval.
        x_max: Upper bound of the tabulated interval (must exceed ``x_min``).
        N: Number of table samples (at least 2).

    Raises:
        ValueError: If ``N < 2`` or ``x_max <= x_min``; either would make the
            table spacing zero or undefined.
    """

    # sqrt(2) as a plain Python float so no tensor has to be created per call.
    _SQRT_2 = 2.0 ** 0.5

    def __init__(self, x_min=-100.0, x_max=100.0, N=50000):
        super().__init__()
        if N < 2:
            raise ValueError(f"N must be at least 2, got {N}")
        if not x_max > x_min:
            raise ValueError(
                f"x_max ({x_max}) must be greater than x_min ({x_min})"
            )
        self.x_min = x_min
        self.x_max = x_max
        self.N = N
        self.step = (x_max - x_min) / (N - 1)
        self.inv_step = 1.0 / self.step

        # Build the table in float64 and only then round to the default dtype,
        # so the stored values and per-interval differences carry a single
        # rounding error instead of an accumulated one.
        x_table = torch.linspace(x_min, x_max, N, dtype=torch.float64)
        y_table = self._exact_gelu(x_table)
        # slope[i] is the rise over ONE table interval (y[i + 1] - y[i]); the
        # final entry is 0 because there is no interval to the right of it.
        slope = torch.diff(y_table, append=y_table[-1:])
        store_dtype = torch.get_default_dtype()
        self.register_buffer("y_table", y_table.to(store_dtype))
        self.register_buffer("slope", slope.to(store_dtype))

    @staticmethod
    def _exact_gelu(x):
        """Closed-form GELU: ``0.5 * x * (1 + erf(x / sqrt(2)))``."""
        return 0.5 * x * (1.0 + torch.erf(x / CachedGELU._SQRT_2))

    def forward(self, x):
        """Apply the table-based GELU element-wise.

        Args:
            x: Input tensor of any shape.

        Returns:
            Tensor of the same shape as ``x``. Floating-point inputs get their
            own dtype back; the arithmetic itself runs in at least float32.
        """
        # Half / bfloat16 cannot resolve the fractional table position (the
        # position ranges up to N - 1), so do the math in at least float32.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        xc = x.to(compute_dtype)

        # Fractional table position of each (clamped) input.
        pos = (torch.clamp(xc, self.x_min, self.x_max) - self.x_min) * self.inv_step
        # The clamp on the integer index also keeps NaN inputs in bounds.
        idx = pos.long().clamp(0, self.N - 1)
        frac = pos - idx.to(compute_dtype)
        approx = self.y_table[idx] + frac * self.slope[idx]

        # The exact form is evaluated for every element and selected with
        # torch.where, which avoids data-dependent Python control flow.
        exact = self._exact_gelu(xc)
        out_of_range = (xc < self.x_min) | (xc > self.x_max)
        out = torch.where(out_of_range, exact, approx)
        return out.to(x.dtype) if x.is_floating_point() else out

# ---- Tanh-Based GELU Approximation ----
class GELUTanhApprox(nn.Module):
    """Tanh-based GELU approximation.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``
    element-wise.
    """

    # sqrt(2 / pi) kept as a Python float: building it as a tensor inside
    # forward() costs an allocation (and a host-to-device copy on CUDA) on
    # every call, and a float32 tensor would round the coefficient to single
    # precision even for float64 inputs.
    _SQRT_2_OVER_PI = (2.0 / torch.pi) ** 0.5

    def forward(self, x):
        """Return the tanh approximation of GELU for ``x`` (same shape and dtype)."""
        inner = self._SQRT_2_OVER_PI * (x + 0.044715 * x.pow(3))
        return 0.5 * x * (1.0 + torch.tanh(inner))

# ---- Sigmoid-Based GELU Approximation ----
class GELUSigmoidApprox(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x):
        """Return ``x`` scaled element-wise by its sigmoid gate."""
        gate = torch.sigmoid(x * 1.702)
        return x * gate

# ---- Replace GELU ----
def replace_gelu(model, mode="standard"):
    """Swap every GELU activation inside ``model`` for the variant chosen by ``mode``.

    A submodule counts as a GELU activation when it is an ``nn.GELU`` or when
    its class is named ``GELUActivation``. The second case covers Hugging Face
    Transformers models, whose GELU wrapper does not subclass ``nn.GELU`` and
    would otherwise be skipped, leaving the model unchanged.

    Args:
        model: Module whose activations are replaced in place.
        mode: ``"cached"`` -> ``CachedGELU``, ``"tanh"`` -> ``GELUTanhApprox``,
            ``"sigmoid"`` -> ``GELUSigmoidApprox``; any other value installs a
            fresh ``nn.GELU()``.

    Returns:
        ``model`` itself after in-place replacement. If ``model`` is itself a
        GELU activation, the replacement activation is returned instead.
    """
    # Lambdas defer the class lookup until a replacement is actually built.
    factories = {
        "cached": lambda: CachedGELU(),
        "tanh": lambda: GELUTanhApprox(),
        "sigmoid": lambda: GELUSigmoidApprox(),
    }
    make_activation = factories.get(mode, nn.GELU)

    def is_gelu(module):
        # Matched by class name so that transformers need not be imported here.
        return isinstance(module, nn.GELU) or type(module).__name__ == "GELUActivation"

    if is_gelu(model):
        # There is no parent to rebind the attribute on.
        return make_activation()

    # Snapshot parents and children first so the tree is never mutated while a
    # traversal generator is live. Walking parent -> direct children (rather
    # than dotted names) also replaces an activation instance that is shared
    # by several parents at every place it is referenced.
    for parent in list(model.modules()):
        for child_name, child in list(parent.named_children()):
            if is_gelu(child):
                setattr(parent, child_name, make_activation())
    return model

# ---- Load BERT Model ----
def load_bert(mode="standard", compile_model=False):
    """Load the SST-2 BERT classifier with its GELU activations swapped out.

    Args:
        mode: Activation variant forwarded to ``replace_gelu``.
        compile_model: When true, the model is wrapped with ``torch.compile``
            before being returned.

    Returns:
        The classifier, switched to eval mode before the activation swap.
    """
    checkpoint = "textattack/bert-base-uncased-SST-2"
    classifier = BertForSequenceClassification.from_pretrained(checkpoint)
    classifier.eval()
    classifier = replace_gelu(classifier, mode=mode)
    return torch.compile(classifier) if compile_model else classifier

# ---- Inference Benchmark ----
@torch.no_grad()
def benchmark(model, tokenizer, dataset, device, max_len=128, runs=1000000, warmup=3):
    """Time single-example inference and measure classification accuracy.

    Args:
        model: Model called as ``model(**inputs)``; its output must expose
            ``.logits``. It is moved to ``device``.
        tokenizer: Callable that turns a text into a dict of tensors.
        dataset: Indexable collection of examples with a ``'label'`` key and
            either a ``'sentence'`` or a ``'text'`` key.
        device: ``torch.device`` or device string to run on.
        max_len: Length every input is padded / truncated to.
        runs: Maximum number of examples to evaluate (capped at
            ``len(dataset)``).
        warmup: Number of untimed forward passes made on the first example
            before measuring. They absorb one-off start-up cost such as
            ``torch.compile`` compiling on the first call, which would
            otherwise be counted in the average.

    Returns:
        ``(avg_time, accuracy)``: mean forward-pass time in seconds and
        accuracy in percent. ``(0.0, 0.0)`` when there is nothing to evaluate.
    """
    device = torch.device(device)  # accept plain strings such as "cuda"
    runs = min(len(dataset), runs)
    if runs <= 0:
        # Nothing to measure; avoids dividing by zero below.
        return 0.0, 0.0

    model.to(device)
    on_cuda = device.type == 'cuda'

    def encode(example):
        """Tokenize one example and move the tensors to ``device``."""
        text = example['sentence'] if 'sentence' in example else example['text']
        encoded = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=max_len,
        )
        return {k: v.to(device) for k, v in encoded.items()}

    if warmup > 0:
        warm_inputs = encode(dataset[0])
        for _ in range(warmup):
            model(**warm_inputs)
        if on_cuda:
            torch.cuda.synchronize()

    correct = 0
    total_time = 0.0
    for i in tqdm(range(runs), desc="Inference"):
        example = dataset[i]  # fetch the row once instead of once per field
        inputs = encode(example)

        if on_cuda:
            # Make sure earlier GPU work is finished before the clock starts.
            torch.cuda.synchronize()
        start = time.perf_counter()  # monotonic, high-resolution timer
        outputs = model(**inputs)
        if on_cuda:
            # CUDA kernels run asynchronously; wait for them before stopping.
            torch.cuda.synchronize()
        total_time += time.perf_counter() - start

        pred = torch.argmax(outputs.logits, dim=1).item()
        correct += int(pred == example['label'])

    avg_time = total_time / runs
    accuracy = correct / runs * 100
    return avg_time, accuracy

# ---- Main ----
if __name__ == "__main__":
    # Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    dataset = load_dataset("glue", "sst2", split="validation")

    # Each entry: (display name, replace_gelu mode, wrap in torch.compile?)
    configs = [
        ("Standard GELU", "standard", False),
        ("Cached GELU", "cached", True),
        ("Tanh Approx GELU", "tanh", True),
        ("Sigmoid Approx GELU", "sigmoid", True),
    ]

    for entry in configs:
        title, gelu_mode, use_compile = entry
        print(f"\nRunning: {title}")
        net = load_bert(mode=gelu_mode, compile_model=use_compile)
        result = benchmark(net, tokenizer, dataset, device)
        print(f"{title} => Avg Inference Time: {result[0]:.6f} sec, Accuracy: {result[1]:.2f}%")
