# IBP Training on MNIST
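Implementation of IBP (Interval Bound Propagation) training on MNIST using a fully-connected ReLU network with three hidden layers of 50 units, compared against a standard (non-robust) baseline of the same architecture.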

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from bound_propagation import BoundModelFactory, HyperRectangle

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
batch_size = 128

# MNIST splits (downloaded to mnist_data/ on first use). Pixel scaling is left to
# ToTensor(); the bound code later in the notebook assumes inputs lie in [0, 1].
train_dataset, test_dataset = (
    datasets.MNIST('mnist_data/', train=is_train, download=True, transform=transforms.ToTensor())
    for is_train in (True, False)
)

# Shuffle only the training split; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

In [None]:
class Net(nn.Module):
    """Fully-connected ReLU classifier returning raw logits (no softmax).

    Architecture: ``depth`` hidden blocks of Linear(hidden) + ReLU, followed by a
    linear ``head``. The defaults (784 -> 50 -> 50 -> 50 -> 10) reproduce the
    original fixed network exactly.

    This is an ``nn.Module``, not an ``nn.Sequential``: it owns named submodules
    (``body``/``head``) and its own ``forward``, so Sequential's container
    semantics (indexing, ``len``, ``append``) would only be misleading.

    Args:
        in_features: flattened input size (28*28 for MNIST).
        hidden: width of every hidden layer.
        depth: number of hidden Linear+ReLU blocks.
        num_classes: number of output logits.
    """

    def __init__(self, in_features=28 * 28, hidden=50, depth=3, num_classes=10):
        super().__init__()
        layers = []
        d = in_features
        for _ in range(depth):
            layers.append(nn.Linear(d, hidden))
            layers.append(nn.ReLU(inplace=True))
            d = hidden
        # Plain Linear/ReLU layers only: the interval-bound code in this notebook
        # walks ``body`` layer by layer.
        self.body = nn.Sequential(*layers)
        self.head = nn.Linear(d, num_classes)

    def forward(self, x):
        # reshape (not view) so non-contiguous inputs are flattened too.
        x = x.reshape(x.size(0), -1)
        h = self.body(x)
        z = self.head(h)      # logits (NO softmax!)
        return z

class Normalize(nn.Module):
    """Standardize inputs with fixed constants mean=0.1307, std=0.3081.

    These are presumably MNIST's pixel statistics. The module has no
    parameters and is applied elementwise.
    """

    def forward(self, x):
        return x.sub(0.1307).div(0.3081)

# Normalization is folded into the model, so it always consumes raw [0, 1] pixels.
# NOTE(review): the next cell redefines Net/Normalize and rebinds `model`,
# so this instance is replaced before any training happens.
model = nn.Sequential(Normalize(), Net()).to(device)
model.train()

def adversarial_logit(bounds, y):
    """Build worst-case logits from interval bounds on the logits.

    For each sample, the true class ``y`` takes its lower bound and every other
    class takes its upper bound.

    Args:
        bounds: object exposing ``lower`` and ``upper`` tensors of shape
            (batch, num_classes).
        y: integer class labels of shape (batch,).

    Returns:
        Tensor of shape (batch, num_classes).
    """
    # Class count comes from the bounds instead of being hard-coded to 10.
    num_classes = bounds.lower.size(-1)
    classes = torch.arange(num_classes, device=y.device)
    is_true = classes.unsqueeze(0) == y.unsqueeze(-1)
    # torch.where selects rather than multiplies, so infinite bounds do not
    # turn into NaN via 0 * inf as the mask-arithmetic formulation would.
    return torch.where(is_true, bounds.lower, bounds.upper)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time

# Re-select the compute device for this cell (GPU if PyTorch can see one).
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

class Net(nn.Module):
    """Fully-connected ReLU classifier returning raw logits (no softmax).

    Architecture: ``depth`` hidden blocks of Linear(hidden) + ReLU, followed by a
    linear ``head``. The defaults (784 -> 50 -> 50 -> 50 -> 10) reproduce the
    original fixed network exactly, including parameter-initialization order.

    Args:
        in_features: flattened input size (28*28 for MNIST).
        hidden: width of every hidden layer.
        depth: number of hidden Linear+ReLU blocks.
        num_classes: number of output logits.
    """

    def __init__(self, in_features=28 * 28, hidden=50, depth=3, num_classes=10):
        super().__init__()
        layers = []
        d = in_features
        for _ in range(depth):
            layers.append(nn.Linear(d, hidden))
            layers.append(nn.ReLU(inplace=True))
            d = hidden
        # Plain Linear/ReLU layers only: the interval-bound code in this notebook
        # walks ``body`` layer by layer.
        self.body = nn.Sequential(*layers)
        self.head = nn.Linear(d, num_classes)

    def forward(self, x):
        # reshape (not view) so non-contiguous inputs are flattened too.
        x = x.reshape(x.size(0), -1)
        h = self.body(x)
        z = self.head(h)
        return z


class Normalize(nn.Module):
    """Parameter-free elementwise standardization: (x - 0.1307) / 0.3081.

    The constants are presumably MNIST's pixel mean/std. The map is increasing
    in x because the divisor is positive.
    """

    def forward(self, x):
        shifted = x - 0.1307
        return shifted / 0.3081


# The model that train_ibp trains: normalization first, then the MLP.
# interval_forward relies on this layout (index 0 = Normalize, index 1 = Net).
model = nn.Sequential(Normalize(), Net())
model = model.to(device)

def interval_linear(W, b, l_in, u_in):
    """Propagate the box [l_in, u_in] through the affine map ``x @ W.T + b``.

    Uses the centre/radius form of interval arithmetic. The box centre goes
    through the affine map and the radius through ``|W|``. This is mathematically
    identical to splitting W into positive/negative parts, but needs two matrix
    products instead of four.

    Args:
        W: weight matrix of shape (out_features, in_features).
        b: bias of shape (out_features,), or None for a bias-free layer.
        l_in: lower bounds of shape (batch, in_features).
        u_in: upper bounds of shape (batch, in_features); assumed >= l_in.

    Returns:
        (l_out, u_out), each of shape (batch, out_features).
    """
    center = (u_in + l_in) / 2
    radius = (u_in - l_in) / 2
    center_out = center @ W.T
    if b is not None:
        center_out = center_out + b
    # The radius is non-negative, so |W| gives the maximal spread in each output.
    radius_out = radius @ W.abs().T
    return center_out - radius_out, center_out + radius_out

def interval_relu(l_in, u_in):
    """Push interval bounds through ReLU.

    ReLU is monotone non-decreasing, so applying it to each endpoint gives the
    exact output interval.
    """
    lower = torch.relu(l_in)
    upper = torch.relu(u_in)
    return lower, upper

def interval_forward(model, x, eps):
    """Compute (l, u) interval bounds on the final logits of ``model``.

    ``model`` is expected to be ``nn.Sequential(Normalize(), Net())`` as built in
    this notebook. ``model[0]`` maps raw pixels to normalized inputs, and
    ``model[1]`` exposes ``body`` (Linear/ReLU layers) and ``head`` (a Linear).

    Args:
        model: the two-stage model described above.
        x: input batch with pixel values in [0, 1]; flattened to (batch, -1) here.
        eps: L-infinity perturbation radius (float or 0-dim tensor).

    Returns:
        (l, u): elementwise lower/upper bounds on the logits.

    Raises:
        NotImplementedError: if ``body`` contains a layer other than Linear or ReLU.
    """
    normalize, net = model[0], model[1]

    # Perturbation box, intersected with the valid pixel range [0, 1].
    l = torch.clamp(x - eps, 0, 1)
    u = torch.clamp(x + eps, 0, 1)
    # Normalize is the elementwise affine map (x - mean) / std with std > 0, so it
    # is increasing. Mapping both endpoints through the model's own module keeps
    # l <= u without duplicating its constants here.
    l = normalize(l).reshape(l.size(0), -1)
    u = normalize(u).reshape(u.size(0), -1)

    # Dispatch on layer type rather than assuming a strict Linear/ReLU stride.
    for layer in net.body:
        if isinstance(layer, nn.Linear):
            l, u = interval_linear(layer.weight, layer.bias, l, u)
        elif isinstance(layer, nn.ReLU):
            l, u = interval_relu(l, u)
        else:
            raise NotImplementedError(
                f"interval_forward: unsupported layer {type(layer).__name__}"
            )

    return interval_linear(net.head.weight, net.head.bias, l, u)


def worst_case_logits_from_bounds(lz, uz, labels):
    """Assemble pessimistic logits for the cross-entropy loss.

    The labelled class takes its lower bound from ``lz``; every other class
    takes its upper bound from ``uz``. ``lz`` and ``uz`` are (batch, classes);
    ``labels`` holds one class index per row.
    """
    is_target = F.one_hot(labels, num_classes=lz.size(1)).to(torch.bool)
    return lz.where(is_target, uz)

def kappa_schedule(epoch, total_epochs, start=1.0, end=0.5):
    """Clean-loss weight for a 0-based ``epoch``, linear from ``start`` to ``end``.

    The first epoch gets ``start`` and the last epoch (``total_epochs - 1``) gets
    ``end``. Epochs outside that range are clamped to the endpoints instead of
    extrapolating past ``end``. A single-epoch run (``total_epochs <= 1``) uses
    ``end`` directly rather than dividing by zero.
    """
    if total_epochs <= 1:
        return end
    epoch = min(max(epoch, 0), total_epochs - 1)
    return start - (start - end) * epoch / (total_epochs - 1)

def eps_schedule(epoch, total_epochs, eps_target):
    """Training radius for a 0-based ``epoch``, ramped linearly from 0 to ``eps_target``.

    The last epoch (``total_epochs - 1``) reaches ``eps_target``. Epochs outside
    the range are clamped, so the radius never exceeds the target. A
    single-epoch run uses ``eps_target`` directly rather than dividing by zero.
    """
    if total_epochs <= 1:
        return eps_target
    epoch = min(max(epoch, 0), total_epochs - 1)
    return eps_target * epoch / (total_epochs - 1)

def train_ibp(model, train_loader, num_epochs=20, lr=1e-3, eps_train_target=0.1):
    """Train ``model`` on a mixed clean/IBP objective and return elapsed seconds.

    Per epoch, ``kappa`` comes from ``kappa_schedule(epoch, num_epochs, 1.0, 0.5)``
    and the radius from ``eps_schedule(epoch, num_epochs, eps_train_target)``.
    The per-batch loss is
    ``kappa * CE(clean logits) + (1 - kappa) * CE(worst-case IBP logits)``,
    optimized with SGD (momentum 0.9). The mean loss is printed each epoch.
    """
    ce = nn.CrossEntropyLoss()
    opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    started = time.time()

    for epoch in range(num_epochs):
        model.train()
        kappa = kappa_schedule(epoch, num_epochs, 1.0, 0.5)
        eps = eps_schedule(epoch, num_epochs, eps_train_target)
        run_loss = 0.0

        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)

            clean_term = ce(model(xb), yb)
            lower, upper = interval_forward(model, xb, eps)
            robust_term = ce(worst_case_logits_from_bounds(lower, upper, yb), yb)
            batch_loss = kappa * clean_term + (1 - kappa) * robust_term

            opt.zero_grad(set_to_none=True)
            batch_loss.backward()
            opt.step()
            run_loss += batch_loss.item()

        print(f"[Epoch {epoch+1}/{num_epochs}] κ={kappa:.3f} ε={eps:.3f} Loss={run_loss/len(train_loader):.4f}")

    total_time = time.time() - started
    print(f"IBP training completed in {total_time:.2f} sec")
    return total_time


def verify_accuracy(model, test_loader, eps):
    """Return the percentage of test samples that IBP certifies at radius ``eps``.

    A sample counts as verified when the lower bound of its true-class logit
    strictly exceeds the upper bound of every other logit. The bounds come from
    ``interval_forward``. Returns 0.0 for an empty loader.
    """
    model.eval()
    verified, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            l_out, u_out = interval_forward(model, images, eps)
            # Vectorized over the batch. The class count comes from the bounds,
            # and the mask is built on the bounds' device.
            is_true = F.one_hot(labels, num_classes=l_out.size(1)).bool()
            true_lower = l_out.gather(1, labels.unsqueeze(1)).squeeze(1)
            other_upper = u_out.masked_fill(is_true, float("-inf")).amax(dim=1)
            verified += (true_lower > other_upper).sum().item()
            total += labels.size(0)
    if total == 0:
        return 0.0
    return 100.0 * verified / total


In [None]:
def pgd_attack(model, X, y, eps=0.1, alpha=0.01, iters=20):
    """Untargeted L-infinity PGD starting from the clean input (no random start).

    Each of ``iters`` steps moves by ``alpha * sign(grad)`` to increase the
    cross-entropy. The result is then projected onto the eps-ball around ``X``
    and clipped to [0, 1].

    Gradients are taken with ``torch.autograd.grad`` w.r.t. the input only, so
    the model's parameter ``.grad`` buffers are left untouched. Grad mode is
    forced on, so the attack also works when called inside ``torch.no_grad()``.

    Returns:
        Adversarial batch, detached from any graph, with ``requires_grad=True``
        (as in the original implementation).
    """
    X = X.detach()
    X_adv = X.clone().requires_grad_(True)
    for _ in range(iters):
        with torch.enable_grad():
            loss = F.cross_entropy(model(X_adv), y)
            (grad,) = torch.autograd.grad(loss, X_adv)
        with torch.no_grad():
            X_adv = X_adv + alpha * grad.sign()
            X_adv = torch.min(torch.max(X_adv, X - eps), X + eps)
            X_adv = X_adv.clamp(0, 1)
        X_adv = X_adv.detach().requires_grad_(True)
    return X_adv

In [None]:
def verified_accuracy(net, test_loader, eps_values, device='cuda'):
    """Certified accuracy of a bound_propagation model for each radius in ``eps_values``.

    For each sample, the worst-case logits from ``adversarial_logit`` hold the
    true class's lower bound and every other class's upper bound. The sample is
    certified only if the true class strictly WINS (beats every other entry);
    the previous ``argmin`` check tested the opposite.

    Prints and returns a list of ``(eps, fraction_verified)`` pairs. The fraction
    is 0.0 for an empty loader.
    """
    results = []
    with torch.no_grad():  # pure evaluation: no autograd graph needed
        for eps in eps_values:
            verified = 0
            total = 0
            for X, y in test_loader:
                X, y = X.to(device), y.to(device)
                bounds = net.ibp(HyperRectangle.from_eps(X, eps))
                adv_logits = adversarial_logit(bounds, y)
                is_true = F.one_hot(y, num_classes=adv_logits.size(1)).bool()
                true_logit = adv_logits.gather(1, y.unsqueeze(1)).squeeze(1)
                best_other = adv_logits.masked_fill(is_true, float('-inf')).amax(dim=1)
                # Strict comparison: a tie is not a certificate.
                verified += (true_logit > best_other).sum().item()
                total += y.size(0)
            acc = verified / total if total else 0.0
            print(f"eps={eps:.3f}, verified acc={acc:.3f}")
            results.append((eps, acc))
    return results


In [None]:
def train_standard(model, train_loader, num_epochs=20, lr=1e-3):
    """Baseline training: plain cross-entropy with Adam, no robustness term.

    Prints the mean batch loss each epoch and returns wall-clock seconds.
    """
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=lr)
    started = time.time()
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_loss = loss_fn(model(inputs), targets)
            opt.zero_grad()
            batch_loss.backward()
            opt.step()
            total_loss += batch_loss.item()
        print(f"[STD] Epoch {epoch+1}/{num_epochs}, Loss={total_loss/len(train_loader):.4f}")
    return time.time() - started

In [None]:
def evaluate_accuracy(model, loader, eps=0.1, pgd_steps=10):
    """Return (clean accuracy %, PGD-robust accuracy %) of ``model`` on ``loader``.

    Robust accuracy uses ``pgd_attack`` with radius ``eps``, step ``eps / 4`` and
    ``pgd_steps`` iterations. Clean and adversarial predictions come from the
    same single pass, so both counts refer to the same samples even if the
    loader shuffles. Returns (0.0, 0.0) for an empty loader.
    """
    model.eval()
    total, correct, correct_adv = 0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            correct += (model(x).argmax(1) == y).sum().item()
        # The attack needs input gradients, so it runs outside no_grad.
        adv = pgd_attack(model, x, y, eps, eps / 4, pgd_steps)
        with torch.no_grad():
            correct_adv += (model(adv).argmax(1) == y).sum().item()
        total += y.size(0)

    if total == 0:
        return 0.0, 0.0
    std_acc = 100 * correct / total
    rob_acc = 100 * correct_adv / total
    return std_acc, rob_acc

In [None]:
# Baseline: the same architecture trained with plain cross-entropy only.
model_std = nn.Sequential(Normalize(), Net())
model_std = model_std.to(device)
std_time = train_standard(model_std, train_loader, num_epochs=20)
std_acc, std_rob = evaluate_accuracy(model_std, test_loader, eps=0.1)
print(f"\nStandard Model: acc={std_acc:.2f}%  robust acc={std_rob:.2f}%  time={std_time:.1f}s")

In [None]:
# IBP-certified accuracy of the baseline over 10 radii from 0.01 to 0.1.
eps_grid = torch.linspace(0.01, 0.1, 10)
for eps in eps_grid:
    acc = verify_accuracy(model_std, test_loader, eps)
    print(f"Verified accuracy at eps={eps:.2f}: {acc:.2f}%")

In [None]:
# Mixed clean/IBP training of `model`; train_ibp returns wall-clock seconds.
train_time = train_ibp(
    model,
    train_loader,
    num_epochs=20,
    lr=1e-3,
    eps_train_target=0.1,
)

In [None]:
# Clean and PGD accuracy of the IBP-trained model, then certified accuracy per radius.
ibp_acc, ibp_rob = evaluate_accuracy(model, test_loader, eps=0.1)
print(f"\nIBP Model: acc={ibp_acc:.2f}%  robust acc={ibp_rob:.2f}%  time={train_time:.1f}s")
eps_grid = torch.linspace(0.01, 0.1, 10)
for eps in eps_grid:
    acc = verify_accuracy(model, test_loader, eps)
    print(f"Verified accuracy at eps={eps:.2f}: {acc:.2f}%")

In [None]:
def visualize_adversarial_examples(model, loader, eps_list=(0.05, 0.1, 0.2), num_images=6):
    """Plot PGD adversarial versions of the first images of one batch.

    One row per radius in ``eps_list`` and one column per image. Each attack
    uses step ``eps / 4`` and 20 iterations.

    Args:
        model: classifier to attack (put into eval mode here).
        loader: data source; only its first batch is used.
        eps_list: attack radii. A tuple default avoids the shared mutable-default
            pitfall; any sequence of floats works.
        num_images: images per row, capped at the batch size.
    """
    model.eval()
    x, y = next(iter(loader))
    n = min(num_images, x.size(0))
    # Slice before moving to the device: only n images are needed.
    x, y = x[:n].to(device), y[:n].to(device)
    # squeeze=False keeps ``axes`` 2-D even for a single radius or a single image.
    _, axes = plt.subplots(len(eps_list), n, figsize=(10, 6), squeeze=False)
    for row, eps in enumerate(eps_list):
        adv = pgd_attack(model, x, y, eps, eps / 4, iters=20)
        for i in range(n):
            ax = axes[row, i]
            ax.imshow(adv[i].detach().cpu().squeeze(), cmap="gray")
            ax.set_title(f"ϵ={eps:.2f}")
            ax.axis("off")
    plt.tight_layout()
    plt.show()

In [None]:
# Show PGD examples against the IBP-trained model at the default radii.
visualize_adversarial_examples(model, loader=test_loader)