In [None]:
import math
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import classification_report, confusion_matrix

# ─────────────────────────────
# GELU Approximations
# ─────────────────────────────
class GELUManual(nn.Module):
    """Exact GELU written out by hand: ``0.5 * x * (1 + erf(x / sqrt(2)))``.

    Stateless and elementwise; the output has the same shape as ``x``.
    """

    # Precomputed once as a Python float so forward() does not build a new
    # tensor (and, on CUDA, copy it to the device) on every call. A Python
    # scalar is also applied at the input's precision instead of being
    # pinned to float32.
    _SQRT2 = math.sqrt(2.0)

    def forward(self, x):
        return 0.5 * x * (1.0 + torch.erf(x / self._SQRT2))

class GELUTanhApprox(nn.Module):
    """Tanh approximation of GELU.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``
    elementwise; the output has the same shape as ``x``.
    """

    # Precomputed once as a Python float so forward() does not build a new
    # tensor (and, on CUDA, copy it to the device) on every call. A Python
    # scalar is also applied at the input's precision instead of being
    # pinned to float32.
    _COEFF = math.sqrt(2.0 / math.pi)

    def forward(self, x):
        inner = self._COEFF * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1.0 + torch.tanh(inner))

class GELUSigmoidApprox(nn.Module):
    """Sigmoid-gated GELU approximation: ``x * sigmoid(1.702 * x)``.

    Stateless and elementwise; the output has the same shape as ``x``.
    """

    # Gate sharpness used by this approximation.
    _SCALE = 1.702

    def forward(self, x):
        gate = torch.sigmoid(self._SCALE * x)
        return x * gate

class CachedGELU(nn.Module):
    """GELU evaluated by linear interpolation in a precomputed lookup table.

    The exact GELU is tabulated at ``N`` evenly spaced points on
    ``[x_min, x_max]``. Inputs inside that interval are linearly interpolated
    between neighbouring table entries. Inputs outside it use the exact
    erf formula.

    Args:
        x_min: lower end of the tabulated interval.
        x_max: upper end of the tabulated interval; must be greater than ``x_min``.
        N: number of table entries; must be at least 2.

    Raises:
        ValueError: if ``N < 2`` or ``x_max <= x_min``.
    """

    # Python float so the exact path does not allocate a tensor per call.
    _SQRT2 = math.sqrt(2.0)

    def __init__(self, x_min=-100.0, x_max=100.0, N=20000):
        super().__init__()
        if N < 2:
            raise ValueError(f"N must be at least 2, got {N}")
        # `not >` (rather than `<=`) also rejects NaN bounds.
        if not x_max > x_min:
            raise ValueError(f"x_max ({x_max}) must be greater than x_min ({x_min})")
        self.x_min = x_min
        self.x_max = x_max
        self.N = N
        self.step = (x_max - x_min) / (N - 1)
        self.inv_step = 1.0 / self.step
        x_table = torch.linspace(x_min, x_max, N)
        y_table = self._exact(x_table)
        # slope[i] = y[i+1] - y[i] is the rise per *table step*, so forward()
        # multiplies it by a fractional index rather than by a distance in x.
        # The appended last entry is 0 and is only used at the top of the table.
        slope = torch.diff(y_table, append=y_table[-1:])
        # Buffers so that .to(device) / .double() carry the table along.
        self.register_buffer('y_table', y_table)
        self.register_buffer('slope', slope)

    @classmethod
    def _exact(cls, x):
        """Exact GELU, ``0.5 * x * (1 + erf(x / sqrt(2)))``."""
        return 0.5 * x * (1.0 + torch.erf(x / cls._SQRT2))

    def forward(self, x):
        x_clamped = torch.clamp(x, self.x_min, self.x_max)
        idx_f = (x_clamped - self.x_min) * self.inv_step
        # Truncation acts as floor here because clamping makes idx_f >= 0.
        idx = idx_f.long().clamp(0, self.N - 1)
        frac = idx_f - idx.to(idx_f.dtype)
        approx = self.y_table[idx] + frac * self.slope[idx]
        # NOTE: torch.where evaluates both branches, so the exact formula is
        # still computed for every element; only the selected values differ.
        out_of_range = (x < self.x_min) | (x > self.x_max)
        return torch.where(out_of_range, self._exact(x), approx)

# ─────────────────────────────
# Residual Block & ResNet
# ─────────────────────────────
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with batch norm plus a residual shortcut.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels.
        activation: activation module. The same instance is applied after the
            first batch norm and again after the residual sum.
        stride: stride of the first 3x3 convolution (and of the projection
            shortcut, when one is needed).

    A 1x1 convolution + batch-norm projection replaces the identity shortcut
    whenever the stride or channel count changes the tensor shape.
    """

    def __init__(self, inchannel, outchannel, activation, stride=1):
        super().__init__()
        self.activation = activation

        first_conv = nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False)
        first_norm = nn.BatchNorm2d(outchannel)
        second_conv = nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False)
        second_norm = nn.BatchNorm2d(outchannel)
        self.left = nn.Sequential(first_conv, first_norm, activation, second_conv, second_norm)

        shape_changes = stride != 1 or inchannel != outchannel
        if shape_changes:
            projection = nn.Conv2d(inchannel, outchannel, 1, stride, bias=False)
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(outchannel))
        else:
            # An empty Sequential returns its input unchanged (identity shortcut).
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        out = self.left(x)
        out += residual
        return self.activation(out)

class ResNet(nn.Module):
    """ResNet-18-style classifier: a 3x3 stem followed by four stages of two ResidualBlocks each.

    Args:
        activation: activation module. The same instance is used after the
            stem and inside every ResidualBlock.
        num_classes: number of output logits produced by the final linear layer.
    """

    def __init__(self, activation, num_classes=10):
        super().__init__()
        self.inchannel = 64
        self.activation = activation
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
            activation,
        )
        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, outchannel, blocks, stride):
        """Build one stage of ``blocks`` ResidualBlocks; only the first uses ``stride``.

        Sets ``self.inchannel`` to ``outchannel`` so the next stage chains on.
        """
        layers = [ResidualBlock(self.inchannel, outchannel, self.activation, stride)]
        self.inchannel = outchannel
        layers.extend(
            ResidualBlock(outchannel, outchannel, self.activation)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Global average pooling. For 32x32 inputs the final map is 4x4
        # (strides 1, 2, 2, 2), where this matches a fixed 4x4 average pool.
        # Adaptive pooling also keeps the 512-feature head valid for other
        # input sizes.
        x = F.adaptive_avg_pool2d(x, 1)
        return self.fc(torch.flatten(x, 1))

# ─────────────────────────────
# Training and Evaluation
# ─────────────────────────────
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``; both are 0.0 for an empty loader.
    """
    model.train()
    total, correct, loss_sum, num_batches = 0, 0, 0.0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        num_batches += 1
        total += labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
    mean_loss = loss_sum / num_batches if num_batches else 0.0
    accuracy = 100.0 * correct / total if total else 0.0
    return mean_loss, accuracy


def _predict(model, loader, device):
    """Return ``(y_true, y_pred)`` NumPy label arrays over every batch in ``loader``."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(device))
            # Labels are only compared on the host, so they never need to
            # visit the device.
            y_true.extend(labels.tolist())
            y_pred.extend(outputs.argmax(dim=1).cpu().tolist())
    return np.array(y_true), np.array(y_pred)


def _plot_confusion_matrix(y_true, y_pred, class_names, title):
    """Draw the confusion matrix of ``y_true`` vs ``y_pred`` as an annotated heatmap."""
    # Fixing `labels` keeps one row/column per class even if a class never
    # occurs, so the matrix stays aligned with the tick labels.
    # Assumes labels are integer indices into class_names (true for CIFAR-10).
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names,
                yticklabels=class_names, cmap="Blues")
    plt.title(title)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.show()


def train_model(model, name, trainloader, testloader, device, epochs=30, lr=0.01):
    """Train ``model`` with SGD, evaluate it on ``testloader``, and report the results.

    Prints the per-epoch mean training loss and training accuracy, the test
    accuracy, and the wall-clock time for training plus evaluation. Then
    prints a scikit-learn classification report and shows a confusion-matrix
    heatmap.

    Args:
        model: network to train; moved to ``device`` and wrapped with ``torch.compile``.
        name: label used in the printed output and the plot title.
        trainloader: DataLoader of ``(inputs, labels)`` batches; its
            ``dataset.classes`` supplies the class names.
        testloader: DataLoader used for the final evaluation.
        device: ``torch.device`` to train on.
        epochs: number of passes over ``trainloader``.
        lr: SGD learning rate (momentum 0.9, weight decay 5e-4).

    Returns:
        Test accuracy in percent.
    """
    # Move the parameters to the device before compiling.
    model = torch.compile(model.to(device))
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()

    print(f"\n🚀 Training with {name}...\n")
    start = time.perf_counter()
    for epoch in range(epochs):
        mean_loss, train_acc = _train_one_epoch(model, trainloader, optimizer, criterion, device)
        print(f"Epoch {epoch+1:2d} | Avg Loss: {mean_loss:.3f} | Train Accuracy: {train_acc:.2f}%")

    print("✅ Evaluating...")
    y_true, y_pred = _predict(model, testloader, device)
    elapsed = time.perf_counter() - start

    acc = float(100 * np.mean(y_pred == y_true))
    print(f"\n✅ {name} Test Accuracy: {acc:.2f}% | Time: {elapsed:.2f}s")
    class_names = trainloader.dataset.classes
    print("\n📊 Classification Report:\n")
    print(classification_report(y_true, y_pred, labels=list(range(len(class_names))),
                                target_names=class_names))
    print("🧩 Confusion Matrix:\n")
    _plot_confusion_matrix(y_true, y_pred, class_names, f"{name} - Confusion Matrix")
    return acc

# ─────────────────────────────
# Load CIFAR-10
# ─────────────────────────────
# Per-channel normalisation statistics shared by the train and test pipelines.
# NOTE(review): these std values are the commonly copied (0.2023, 0.1994, 0.2010);
# the per-pixel std of the CIFAR-10 training set is often quoted as roughly
# (0.247, 0.243, 0.261). Confirm which is intended before changing them.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)
_normalize = transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD)

# Training images get random-crop / horizontal-flip augmentation; test images do not.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
])
transform_test = transforms.Compose([transforms.ToTensor(), _normalize])

_DATA_ROOT = './data'
trainset = torchvision.datasets.CIFAR10(
    root=_DATA_ROOT, train=True, download=True, transform=transform_train
)
testset = torchvision.datasets.CIFAR10(
    root=_DATA_ROOT, train=False, download=True, transform=transform_test
)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# ─────────────────────────────
# Run All Variants
# ─────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# (display name, activation class) pairs. Each run constructs a fresh
# activation and a fresh ResNet, so no weights or buffers are shared between
# experiments.
ACTIVATION_VARIANTS = [
    ("GELU", GELUManual),
    ("Cached GELU", CachedGELU),
    ("Tanh Approx GELU", GELUTanhApprox),
    ("Sigmoid Approx GELU", GELUSigmoidApprox),
]

for variant_name, activation_cls in ACTIVATION_VARIANTS:
    train_model(ResNet(activation_cls()), variant_name, trainloader, testloader, device)
