In [12]:
import timeit
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms
from tqdm import trange
import time

from bound_propagation import BoundModelFactory, HyperRectangle


class IBPNet(nn.Sequential):
    """Fully connected classifier: in_features -> hidden -> hidden -> hidden -> num_classes.

    With the defaults this is 784 -> 50 -> 50 -> 50 -> 10 with a ReLU after each
    of the first three linear layers.

    Because the class subclasses ``nn.Sequential``, its registered children (in
    registration order) must describe exactly the computation ``forward`` does:
    code that walks the children -- ``for layer in net``, and presumably the
    ``bound_propagation`` factory when it builds the IBP model -- uses them
    instead of ``forward``. The ReLUs are therefore registered as ``nn.ReLU``
    modules between the linear layers rather than applied via ``F.relu``.

    Parameterised layers keep the names ``fc1``..``fc4`` (ReLU has no
    parameters), so state dicts keyed ``fc1.weight`` etc. still load.

    Args:
        in_features: size of the flattened input (28*28 for FashionMNIST).
        hidden_features: width of each hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, in_features=28 * 28, hidden_features=50, num_classes=10):
        super(IBPNet, self).__init__()
        # Attribute assignment registers modules in order, which defines the
        # Sequential order: fc1, relu1, fc2, relu2, fc3, relu3, fc4.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_features, hidden_features)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, in_features)`` and apply the layers in order."""
        # Accepts image-shaped input (e.g. (B, 1, 28, 28)) as well as (B, 784).
        x = x.reshape(-1, self.fc1.in_features)
        # nn.Sequential.forward applies the registered children in order.
        return super().forward(x)

In [5]:
def construct_transform():
    """Return the ``(transform, target_transform)`` pair used for FashionMNIST.

    ``transform`` converts a PIL image to a tensor, casts it to ``torch.float``
    via ``ConvertImageDtype`` (which, per torchvision, rescales integer images
    into [0, 1]) and flattens it to a 1-D vector. ``target_transform`` is an
    empty ``Compose``, i.e. labels pass through unchanged.
    """
    image_steps = [
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Lambda(torch.flatten),
    ]
    label_steps = []
    return transforms.Compose(image_steps), transforms.Compose(label_steps)


def adversarial_logit(y_hat, y):
    """Build worst-case logits from interval bounds on the network output.

    For every sample the true class keeps its *lower* bound and every other
    class takes its *upper* bound. Feeding the result to cross-entropy gives the
    IBP robust loss.

    Args:
        y_hat: bounds object exposing ``lower`` and ``upper`` tensors of shape
            ``(batch, num_classes)``; the class count is read from ``lower``.
        y: integer class labels of shape ``(batch,)``.

    Returns:
        Tensor of shape ``(batch, num_classes)`` with the same dtype as the bounds.
    """
    num_classes = y_hat.lower.size(-1)
    # (batch, num_classes) boolean mask that is True at each sample's label.
    is_true_class = torch.arange(num_classes, device=y.device) == y.unsqueeze(-1)
    # torch.where selects instead of blending, so an infinite bound in the
    # unselected branch cannot produce 0 * inf = NaN.
    return torch.where(is_true_class, y_hat.lower, y_hat.upper)

In [14]:
def _ibp_lr_at_step(step, base_lr, warmup_steps, decay_steps, decay_factor=0.1):
    """Learning rate for optimizer step ``step``: linear warmup, then step decay.

    During the first ``warmup_steps`` steps the rate grows linearly and reaches
    ``base_lr`` on the last warmup step. Afterwards it is ``base_lr`` multiplied
    by ``decay_factor`` once for every entry of ``decay_steps`` already reached.
    """
    if step < warmup_steps:
        # (step + 1) so the very first update does not run with lr == 0.
        return base_lr * (step + 1) / warmup_steps
    return base_lr * decay_factor ** sum(step >= s for s in decay_steps)


def train_ibp(net, device, num_epochs=100, epsilon_target=0.1,
              lr=1e-3, batch_size=100, log_every=50):
    """Train ``net`` with interval-bound-propagation (IBP) robust training.

    Per-step schedule over ``total_steps = num_epochs * len(train_loader)``:

    - learning rate: linear warmup over the first 10% of steps up to ``lr``,
      then divided by 10 at 60% and again at 90% of all steps;
    - kappa: linearly annealed from 1.0 towards 0.5 over the whole run;
    - epsilon: linearly ramped from 0 to ``epsilon_target`` over the first 50%
      of steps, then held constant;
    - loss: ``kappa * CE(z, y) + (1 - kappa) * CE(z_adv, y)``, where ``z_adv``
      are worst-case logits from ``adversarial_logit`` applied to the IBP bounds.

    Args:
        net: model that is callable, exposes ``parameters()`` and
            ``ibp(HyperRectangle)`` (as returned by ``BoundModelFactory.build``).
        device: torch device the batches are moved to.
        num_epochs: passes over the FashionMNIST training split.
        epsilon_target: final L-inf radius used for the robust term.
        lr: peak learning rate reached at the end of warmup.
        batch_size: training mini-batch size.
        log_every: print running means every this many batches.

    Returns:
        Wall-clock training time in seconds.
    """
    print('[IBP TRAINING - Following Paper Protocol]')

    transform, target_transform = construct_transform()
    train_data = datasets.FashionMNIST('./fashion_data', train=True, download=True,
                                       transform=transform, target_transform=target_transform)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    total_steps = num_epochs * len(train_loader)
    warmup_steps = int(0.1 * total_steps)   # 10% linear lr warmup
    rampup_steps = int(0.5 * total_steps)   # 50% epsilon ramp-up
    lr_decay_steps = (int(0.6 * total_steps), int(0.9 * total_steps))

    start_time = time.time()
    current_step = 0

    for epoch in trange(num_epochs):
        running_loss = 0.0
        running_ce = 0.0
        running_robust = 0.0

        for batch_idx, (X, y) in enumerate(train_loader):
            X, y = X.to(device), y.to(device)

            # The lr is a pure function of the step, so warmup and decay cannot interfere.
            step_lr = _ibp_lr_at_step(current_step, lr, warmup_steps, lr_decay_steps)
            for param_group in optimizer.param_groups:
                param_group['lr'] = step_lr

            kappa = max(1.0 - 0.5 * (current_step / total_steps), 0.5)
            if current_step < rampup_steps:
                eps_train = epsilon_target * (current_step / rampup_steps)
            else:
                eps_train = epsilon_target

            optimizer.zero_grad(set_to_none=True)

            # Nominal cross-entropy on clean inputs.
            ce_loss = criterion(net(X), y)

            # Robust cross-entropy on worst-case logits over the eps-box around X.
            bounds = net.ibp(HyperRectangle.from_eps(X, eps_train))
            robust_loss = criterion(adversarial_logit(bounds, y), y)

            loss = kappa * ce_loss + (1 - kappa) * robust_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_ce += ce_loss.item()
            running_robust += robust_loss.item()
            current_step += 1

            if (batch_idx + 1) % log_every == 0:
                print(f'Epoch [{epoch+1:3d}/{num_epochs}], Step [{batch_idx+1:3d}], '
                      f'Loss: {running_loss/log_every:.4f}, CE: {running_ce/log_every:.4f}, '
                      f'Robust: {running_robust/log_every:.4f}, ε: {eps_train:.4f}, κ: {kappa:.3f}')
                running_loss = 0.0
                running_ce = 0.0
                running_robust = 0.0

    training_time = time.time() - start_time
    print(f'IBP Training completed in {training_time:.2f} seconds')
    return training_time

In [15]:
def train_standard(net, device, num_epochs=100):
    """Baseline: train ``net`` with plain cross-entropy and no robustness term.

    Uses Adam at lr 1e-3 on shuffled FashionMNIST batches of 100. The learning
    rate is multiplied by 0.1 when the global step reaches 60% and 90% of all
    steps. The mean loss of each window of 50 batches is printed.

    Returns:
        Wall-clock training time in seconds.
    """
    print('[STANDARD TRAINING]')

    transform, target_transform = construct_transform()
    dataset = datasets.FashionMNIST('./fashion_data', train=True, download=True,
                                    transform=transform, target_transform=target_transform)
    loader = DataLoader(dataset, batch_size=100, shuffle=True, num_workers=4)

    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)

    t0 = time.time()
    n_steps = num_epochs * len(loader)
    # Global steps at which the lr is cut by 10x (a set: coinciding steps decay once).
    decay_at = {int(0.6 * n_steps), int(0.9 * n_steps)}
    step = 0

    for epoch in trange(num_epochs):
        window_loss = 0.0

        for batch_idx, (X, y) in enumerate(loader):
            X = X.to(device)
            y = y.to(device)

            if step in decay_at:
                for group in opt.param_groups:
                    group['lr'] *= 0.1

            opt.zero_grad(set_to_none=True)
            loss = loss_fn(net(X), y)
            loss.backward()
            opt.step()

            window_loss += loss.item()
            step += 1

            if (batch_idx + 1) % 50 == 0:
                print(f'Epoch [{epoch+1:3d}/{num_epochs}], Step [{batch_idx+1:3d}], Loss: {window_loss/50:.4f}')
                window_loss = 0.0

    elapsed = time.time() - t0
    print(f'Standard Training completed in {elapsed:.2f} seconds')
    return elapsed

In [10]:
@torch.no_grad()
def test_standard_accuracy(net, device):
    """Return the clean top-1 accuracy of ``net`` on the FashionMNIST test split.

    Gradient tracking is disabled for the whole call by the decorator. The
    accuracy is also printed with four decimals.
    """
    print('[STANDARD ACCURACY TEST]')

    transform, target_transform = construct_transform()
    loader = DataLoader(
        datasets.FashionMNIST('./fashion_data', train=False, download=True,
                              transform=transform, target_transform=target_transform),
        batch_size=256,
        shuffle=False,
        num_workers=4,
    )

    n_correct, n_seen = 0, 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        hits = net(images).argmax(dim=1) == labels
        n_correct += hits.sum().item()
        n_seen += labels.size(0)

    accuracy = n_correct / n_seen
    print(f'Standard Accuracy: {accuracy:.4f}')
    return accuracy



def _pgd_attack(net, X, y, epsilon, num_steps, step_size):
    """Return L-inf PGD adversarial examples for ``(X, y)``, clipped to [0, 1].

    Starts from a uniform random point in the epsilon-ball, then takes
    ``num_steps`` signed-gradient ascent steps on the cross-entropy. After each
    step it projects back into the epsilon-ball around ``X`` and into [0, 1].
    Only the input gradient is computed; parameter ``.grad`` fields are untouched.
    """
    X_adv = torch.clamp(X + torch.empty_like(X).uniform_(-epsilon, epsilon), 0, 1)
    for _ in range(num_steps):
        # Fresh leaf each step instead of mutating ``.data`` in place.
        X_adv = X_adv.detach().requires_grad_(True)
        with torch.enable_grad():
            loss = F.cross_entropy(net(X_adv), y)
            (grad,) = torch.autograd.grad(loss, X_adv)
        with torch.no_grad():
            X_adv = X_adv + step_size * grad.sign()
            X_adv = torch.clamp(X_adv, X - epsilon, X + epsilon)
            X_adv = torch.clamp(X_adv, 0, 1)
    return X_adv.detach()


def test_robust_accuracy_pgd(net, device, epsilon=0.1, num_steps=20, step_size=None,
                             batch_size=256):
    """Evaluate accuracy of ``net`` on PGD-perturbed FashionMNIST test images.

    Args:
        net: classifier mapping a batch of inputs to logits.
        device: torch device the batches are moved to.
        epsilon: L-inf radius of the attack.
        num_steps: number of PGD iterations per batch.
        step_size: per-iteration step; defaults to ``epsilon / 4``.
        batch_size: evaluation batch size.

    Returns:
        Fraction of test samples still classified correctly after the attack.
    """
    print(f'[ROBUST ACCURACY TEST - PGD ε={epsilon}]')

    if step_size is None:
        step_size = epsilon / 4

    transform, target_transform = construct_transform()
    test_data = datasets.FashionMNIST('./fashion_data', train=False, download=True,
                                      transform=transform, target_transform=target_transform)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)

    correct = 0
    total = 0

    for X, y in test_loader:
        X, y = X.to(device), y.to(device)
        X_adv = _pgd_attack(net, X, y, epsilon, num_steps, step_size)

        with torch.no_grad():
            predicted = net(X_adv).argmax(dim=1)
        correct += (predicted == y).sum().item()
        total += y.size(0)

    accuracy = correct / total
    print(f'Robust Accuracy (PGD ε={epsilon}): {accuracy:.4f}')
    return accuracy



In [16]:
# Experiment configuration shared by both models.
SEED = 0
NUM_EPOCHS = 20
EPSILON = 0.1

# Seed torch's global RNG so weight init, DataLoader shuffling and PGD random
# starts are reproducible when the notebook is re-run on a fresh kernel.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')


def print_banner(title):
    """Print ``title`` between two 60-character '=' rules, preceded by a blank line."""
    print('\n' + '=' * 60)
    print(title)
    print('=' * 60)


# Train standard model
print_banner('STANDARD TRAINING')
net_standard = IBPNet().to(device)
std_train_time = train_standard(net_standard, device, num_epochs=NUM_EPOCHS)
std_accuracy = test_standard_accuracy(net_standard, device)
std_robust_accuracy = test_robust_accuracy_pgd(net_standard, device, epsilon=EPSILON)

# Train IBP model; train_ibp calls `net.ibp(...)`, which the factory-built model provides.
print_banner('IBP ROBUST TRAINING')
factory = BoundModelFactory()
net_ibp = factory.build(IBPNet().to(device))

ibp_train_time = train_ibp(net_ibp, device, num_epochs=NUM_EPOCHS, epsilon_target=EPSILON)
ibp_accuracy = test_standard_accuracy(net_ibp, device)
ibp_robust_accuracy = test_robust_accuracy_pgd(net_ibp, device, epsilon=EPSILON)

# Report results
print_banner('RESULTS SUMMARY')
print(f'\n{"Model":<15} {"Std Accuracy":<15} {"Rob Accuracy":<15} {"Train Time":<15}')
print('-' * 60)
# The unit suffix is formatted together with the number so the column stays aligned.
print(f'{"Standard":<15} {std_accuracy:<15.4f} {std_robust_accuracy:<15.4f} {f"{std_train_time:.2f}s":<15}')
print(f'{"IBP Robust":<15} {ibp_accuracy:<15.4f} {ibp_robust_accuracy:<15.4f} {f"{ibp_train_time:.2f}s":<15}')
# Accuracies are fractions, so their difference x100 is in percentage points, not percent.
print(f'\nRobustness Improvement: {(ibp_robust_accuracy - std_robust_accuracy) * 100:.2f} percentage points')
print(f'Training Time Ratio (IBP/Standard): {ibp_train_time / std_train_time:.2f}x')

Using device: cuda

STANDARD TRAINING
[STANDARD TRAINING]


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch [  1/20], Step [ 50], Loss: 1.7397
Epoch [  1/20], Step [100], Loss: 0.8435
Epoch [  1/20], Step [150], Loss: 0.7270
Epoch [  1/20], Step [200], Loss: 0.6403
Epoch [  1/20], Step [250], Loss: 0.6333
Epoch [  1/20], Step [300], Loss: 0.5525
Epoch [  1/20], Step [350], Loss: 0.5439
Epoch [  1/20], Step [400], Loss: 0.5287
Epoch [  1/20], Step [450], Loss: 0.5286
Epoch [  1/20], Step [500], Loss: 0.5302
Epoch [  1/20], Step [550], Loss: 0.4902


  5%|▌         | 1/20 [00:03<01:04,  3.41s/it]

Epoch [  1/20], Step [600], Loss: 0.4894
Epoch [  2/20], Step [ 50], Loss: 0.4708
Epoch [  2/20], Step [100], Loss: 0.4765
Epoch [  2/20], Step [150], Loss: 0.4739
Epoch [  2/20], Step [200], Loss: 0.4496
Epoch [  2/20], Step [250], Loss: 0.4498
Epoch [  2/20], Step [300], Loss: 0.4389
Epoch [  2/20], Step [350], Loss: 0.4293
Epoch [  2/20], Step [400], Loss: 0.4422
Epoch [  2/20], Step [450], Loss: 0.4223
Epoch [  2/20], Step [500], Loss: 0.4306
Epoch [  2/20], Step [550], Loss: 0.4100


 10%|█         | 2/20 [00:06<01:00,  3.39s/it]

Epoch [  2/20], Step [600], Loss: 0.4279
Epoch [  3/20], Step [ 50], Loss: 0.4173
Epoch [  3/20], Step [100], Loss: 0.4084
Epoch [  3/20], Step [150], Loss: 0.3923
Epoch [  3/20], Step [200], Loss: 0.4225
Epoch [  3/20], Step [250], Loss: 0.4013
Epoch [  3/20], Step [300], Loss: 0.4058
Epoch [  3/20], Step [350], Loss: 0.3935
Epoch [  3/20], Step [400], Loss: 0.4016
Epoch [  3/20], Step [450], Loss: 0.3936
Epoch [  3/20], Step [500], Loss: 0.3733
Epoch [  3/20], Step [550], Loss: 0.3957


 15%|█▌        | 3/20 [00:10<00:56,  3.34s/it]

Epoch [  3/20], Step [600], Loss: 0.3831
Epoch [  4/20], Step [ 50], Loss: 0.3716
Epoch [  4/20], Step [100], Loss: 0.3783
Epoch [  4/20], Step [150], Loss: 0.3482
Epoch [  4/20], Step [200], Loss: 0.3796
Epoch [  4/20], Step [250], Loss: 0.3736
Epoch [  4/20], Step [300], Loss: 0.3708
Epoch [  4/20], Step [350], Loss: 0.3663
Epoch [  4/20], Step [400], Loss: 0.3682
Epoch [  4/20], Step [450], Loss: 0.3542
Epoch [  4/20], Step [500], Loss: 0.3787
Epoch [  4/20], Step [550], Loss: 0.3760


 20%|██        | 4/20 [00:13<00:54,  3.38s/it]

Epoch [  4/20], Step [600], Loss: 0.3579
Epoch [  5/20], Step [ 50], Loss: 0.3674
Epoch [  5/20], Step [100], Loss: 0.3486
Epoch [  5/20], Step [150], Loss: 0.3466
Epoch [  5/20], Step [200], Loss: 0.3615
Epoch [  5/20], Step [250], Loss: 0.3586
Epoch [  5/20], Step [300], Loss: 0.3425
Epoch [  5/20], Step [350], Loss: 0.3394
Epoch [  5/20], Step [400], Loss: 0.3428
Epoch [  5/20], Step [450], Loss: 0.3374
Epoch [  5/20], Step [500], Loss: 0.3421
Epoch [  5/20], Step [550], Loss: 0.3354


 25%|██▌       | 5/20 [00:16<00:50,  3.37s/it]

Epoch [  5/20], Step [600], Loss: 0.3454
Epoch [  6/20], Step [ 50], Loss: 0.3293
Epoch [  6/20], Step [100], Loss: 0.3269
Epoch [  6/20], Step [150], Loss: 0.3356
Epoch [  6/20], Step [200], Loss: 0.3431
Epoch [  6/20], Step [250], Loss: 0.3501
Epoch [  6/20], Step [300], Loss: 0.3448
Epoch [  6/20], Step [350], Loss: 0.3457
Epoch [  6/20], Step [400], Loss: 0.3195
Epoch [  6/20], Step [450], Loss: 0.3110
Epoch [  6/20], Step [500], Loss: 0.3097
Epoch [  6/20], Step [550], Loss: 0.3269


 30%|███       | 6/20 [00:20<00:47,  3.40s/it]

Epoch [  6/20], Step [600], Loss: 0.3344
Epoch [  7/20], Step [ 50], Loss: 0.3007
Epoch [  7/20], Step [100], Loss: 0.3236
Epoch [  7/20], Step [150], Loss: 0.3369
Epoch [  7/20], Step [200], Loss: 0.3076
Epoch [  7/20], Step [250], Loss: 0.3230
Epoch [  7/20], Step [300], Loss: 0.3067
Epoch [  7/20], Step [350], Loss: 0.3175
Epoch [  7/20], Step [400], Loss: 0.3018
Epoch [  7/20], Step [450], Loss: 0.3161
Epoch [  7/20], Step [500], Loss: 0.3321
Epoch [  7/20], Step [550], Loss: 0.3328


 35%|███▌      | 7/20 [00:23<00:44,  3.42s/it]

Epoch [  7/20], Step [600], Loss: 0.3136
Epoch [  8/20], Step [ 50], Loss: 0.3219
Epoch [  8/20], Step [100], Loss: 0.3048
Epoch [  8/20], Step [150], Loss: 0.3110
Epoch [  8/20], Step [200], Loss: 0.3013
Epoch [  8/20], Step [250], Loss: 0.3007
Epoch [  8/20], Step [300], Loss: 0.3028
Epoch [  8/20], Step [350], Loss: 0.3121
Epoch [  8/20], Step [400], Loss: 0.3129
Epoch [  8/20], Step [450], Loss: 0.3011
Epoch [  8/20], Step [500], Loss: 0.3025
Epoch [  8/20], Step [550], Loss: 0.3195


 40%|████      | 8/20 [00:27<00:40,  3.41s/it]

Epoch [  8/20], Step [600], Loss: 0.3043
Epoch [  9/20], Step [ 50], Loss: 0.3000
Epoch [  9/20], Step [100], Loss: 0.3043
Epoch [  9/20], Step [150], Loss: 0.2828
Epoch [  9/20], Step [200], Loss: 0.3020
Epoch [  9/20], Step [250], Loss: 0.3092
Epoch [  9/20], Step [300], Loss: 0.3022
Epoch [  9/20], Step [350], Loss: 0.2929
Epoch [  9/20], Step [400], Loss: 0.2936
Epoch [  9/20], Step [450], Loss: 0.2908
Epoch [  9/20], Step [500], Loss: 0.3156
Epoch [  9/20], Step [550], Loss: 0.3076


 45%|████▌     | 9/20 [00:30<00:37,  3.41s/it]

Epoch [  9/20], Step [600], Loss: 0.2764
Epoch [ 10/20], Step [ 50], Loss: 0.2633
Epoch [ 10/20], Step [100], Loss: 0.2800
Epoch [ 10/20], Step [150], Loss: 0.2679
Epoch [ 10/20], Step [200], Loss: 0.2856
Epoch [ 10/20], Step [250], Loss: 0.3111
Epoch [ 10/20], Step [300], Loss: 0.3051
Epoch [ 10/20], Step [350], Loss: 0.2869
Epoch [ 10/20], Step [400], Loss: 0.3070
Epoch [ 10/20], Step [450], Loss: 0.2814
Epoch [ 10/20], Step [500], Loss: 0.3107
Epoch [ 10/20], Step [550], Loss: 0.2810
Epoch [ 10/20], Step [600], Loss: 0.2870


 50%|█████     | 10/20 [00:33<00:33,  3.34s/it]

Epoch [ 11/20], Step [ 50], Loss: 0.2761
Epoch [ 11/20], Step [100], Loss: 0.2638
Epoch [ 11/20], Step [150], Loss: 0.2783
Epoch [ 11/20], Step [200], Loss: 0.2889
Epoch [ 11/20], Step [250], Loss: 0.2744
Epoch [ 11/20], Step [300], Loss: 0.2763
Epoch [ 11/20], Step [350], Loss: 0.2917
Epoch [ 11/20], Step [400], Loss: 0.2940
Epoch [ 11/20], Step [450], Loss: 0.2893
Epoch [ 11/20], Step [500], Loss: 0.3029
Epoch [ 11/20], Step [550], Loss: 0.2877


 55%|█████▌    | 11/20 [00:36<00:29,  3.25s/it]

Epoch [ 11/20], Step [600], Loss: 0.2713
Epoch [ 12/20], Step [ 50], Loss: 0.2751
Epoch [ 12/20], Step [100], Loss: 0.2670
Epoch [ 12/20], Step [150], Loss: 0.2638
Epoch [ 12/20], Step [200], Loss: 0.2705
Epoch [ 12/20], Step [250], Loss: 0.2731
Epoch [ 12/20], Step [300], Loss: 0.2814
Epoch [ 12/20], Step [350], Loss: 0.2913
Epoch [ 12/20], Step [400], Loss: 0.2764
Epoch [ 12/20], Step [450], Loss: 0.2802
Epoch [ 12/20], Step [500], Loss: 0.2700
Epoch [ 12/20], Step [550], Loss: 0.2646


 60%|██████    | 12/20 [00:40<00:25,  3.25s/it]

Epoch [ 12/20], Step [600], Loss: 0.2828
Epoch [ 13/20], Step [ 50], Loss: 0.2442
Epoch [ 13/20], Step [100], Loss: 0.2542
Epoch [ 13/20], Step [150], Loss: 0.2368
Epoch [ 13/20], Step [200], Loss: 0.2637
Epoch [ 13/20], Step [250], Loss: 0.2509
Epoch [ 13/20], Step [300], Loss: 0.2331
Epoch [ 13/20], Step [350], Loss: 0.2499
Epoch [ 13/20], Step [400], Loss: 0.2437
Epoch [ 13/20], Step [450], Loss: 0.2245
Epoch [ 13/20], Step [500], Loss: 0.2237
Epoch [ 13/20], Step [550], Loss: 0.2455


 65%|██████▌   | 13/20 [00:43<00:23,  3.31s/it]

Epoch [ 13/20], Step [600], Loss: 0.2492
Epoch [ 14/20], Step [ 50], Loss: 0.2397
Epoch [ 14/20], Step [100], Loss: 0.2490
Epoch [ 14/20], Step [150], Loss: 0.2361
Epoch [ 14/20], Step [200], Loss: 0.2373
Epoch [ 14/20], Step [250], Loss: 0.2340
Epoch [ 14/20], Step [300], Loss: 0.2404
Epoch [ 14/20], Step [350], Loss: 0.2446
Epoch [ 14/20], Step [400], Loss: 0.2324
Epoch [ 14/20], Step [450], Loss: 0.2450
Epoch [ 14/20], Step [500], Loss: 0.2293
Epoch [ 14/20], Step [550], Loss: 0.2425


 70%|███████   | 14/20 [00:46<00:19,  3.29s/it]

Epoch [ 14/20], Step [600], Loss: 0.2358
Epoch [ 15/20], Step [ 50], Loss: 0.2369
Epoch [ 15/20], Step [100], Loss: 0.2281
Epoch [ 15/20], Step [150], Loss: 0.2387
Epoch [ 15/20], Step [200], Loss: 0.2541
Epoch [ 15/20], Step [250], Loss: 0.2297
Epoch [ 15/20], Step [300], Loss: 0.2461
Epoch [ 15/20], Step [350], Loss: 0.2391
Epoch [ 15/20], Step [400], Loss: 0.2310
Epoch [ 15/20], Step [450], Loss: 0.2541
Epoch [ 15/20], Step [500], Loss: 0.2120
Epoch [ 15/20], Step [550], Loss: 0.2349


 75%|███████▌  | 15/20 [00:49<00:16,  3.27s/it]

Epoch [ 15/20], Step [600], Loss: 0.2411
Epoch [ 16/20], Step [ 50], Loss: 0.2178
Epoch [ 16/20], Step [100], Loss: 0.2356
Epoch [ 16/20], Step [150], Loss: 0.2400
Epoch [ 16/20], Step [200], Loss: 0.2294
Epoch [ 16/20], Step [250], Loss: 0.2301
Epoch [ 16/20], Step [300], Loss: 0.2460
Epoch [ 16/20], Step [350], Loss: 0.2369
Epoch [ 16/20], Step [400], Loss: 0.2480
Epoch [ 16/20], Step [450], Loss: 0.2321
Epoch [ 16/20], Step [500], Loss: 0.2343
Epoch [ 16/20], Step [550], Loss: 0.2395


 80%|████████  | 16/20 [00:53<00:13,  3.28s/it]

Epoch [ 16/20], Step [600], Loss: 0.2401
Epoch [ 17/20], Step [ 50], Loss: 0.2346
Epoch [ 17/20], Step [100], Loss: 0.2499
Epoch [ 17/20], Step [150], Loss: 0.2392
Epoch [ 17/20], Step [200], Loss: 0.2202
Epoch [ 17/20], Step [250], Loss: 0.2279
Epoch [ 17/20], Step [300], Loss: 0.2302
Epoch [ 17/20], Step [350], Loss: 0.2368
Epoch [ 17/20], Step [400], Loss: 0.2309
Epoch [ 17/20], Step [450], Loss: 0.2311
Epoch [ 17/20], Step [500], Loss: 0.2420
Epoch [ 17/20], Step [550], Loss: 0.2274


 85%|████████▌ | 17/20 [00:56<00:09,  3.28s/it]

Epoch [ 17/20], Step [600], Loss: 0.2354
Epoch [ 18/20], Step [ 50], Loss: 0.2272
Epoch [ 18/20], Step [100], Loss: 0.2216
Epoch [ 18/20], Step [150], Loss: 0.2292
Epoch [ 18/20], Step [200], Loss: 0.2569
Epoch [ 18/20], Step [250], Loss: 0.2417
Epoch [ 18/20], Step [300], Loss: 0.2322
Epoch [ 18/20], Step [350], Loss: 0.2471
Epoch [ 18/20], Step [400], Loss: 0.2136
Epoch [ 18/20], Step [450], Loss: 0.2480
Epoch [ 18/20], Step [500], Loss: 0.2280
Epoch [ 18/20], Step [550], Loss: 0.2116


 90%|█████████ | 18/20 [00:59<00:06,  3.30s/it]

Epoch [ 18/20], Step [600], Loss: 0.2335
Epoch [ 19/20], Step [ 50], Loss: 0.2304
Epoch [ 19/20], Step [100], Loss: 0.2388
Epoch [ 19/20], Step [150], Loss: 0.2228
Epoch [ 19/20], Step [200], Loss: 0.2202
Epoch [ 19/20], Step [250], Loss: 0.2293
Epoch [ 19/20], Step [300], Loss: 0.2403
Epoch [ 19/20], Step [350], Loss: 0.2344
Epoch [ 19/20], Step [400], Loss: 0.2354
Epoch [ 19/20], Step [450], Loss: 0.2344
Epoch [ 19/20], Step [500], Loss: 0.2139
Epoch [ 19/20], Step [550], Loss: 0.2223


 95%|█████████▌| 19/20 [01:03<00:03,  3.31s/it]

Epoch [ 19/20], Step [600], Loss: 0.2239
Epoch [ 20/20], Step [ 50], Loss: 0.2443
Epoch [ 20/20], Step [100], Loss: 0.2353
Epoch [ 20/20], Step [150], Loss: 0.2212
Epoch [ 20/20], Step [200], Loss: 0.2222
Epoch [ 20/20], Step [250], Loss: 0.2272
Epoch [ 20/20], Step [300], Loss: 0.2338
Epoch [ 20/20], Step [350], Loss: 0.2264
Epoch [ 20/20], Step [400], Loss: 0.2293
Epoch [ 20/20], Step [450], Loss: 0.2134
Epoch [ 20/20], Step [500], Loss: 0.2300
Epoch [ 20/20], Step [550], Loss: 0.2269


100%|██████████| 20/20 [01:06<00:00,  3.33s/it]

Epoch [ 20/20], Step [600], Loss: 0.2269
Standard Training completed in 66.66 seconds
[STANDARD ACCURACY TEST]





Standard Accuracy: 0.8834
[ROBUST ACCURACY TEST - PGD ε=0.1]
Robust Accuracy (PGD ε=0.1): 0.0049

IBP ROBUST TRAINING
[IBP TRAINING - Following Paper Protocol]


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch [  1/20], Step [ 50], Loss: 2.3036, CE: 2.3030, Robust: 2.7552, ε: 0.0008, κ: 0.998
Epoch [  1/20], Step [100], Loss: 2.2906, CE: 2.2863, Robust: 3.6439, ε: 0.0017, κ: 0.996
Epoch [  1/20], Step [150], Loss: 2.2092, CE: 2.1980, Robust: 4.3338, ε: 0.0025, κ: 0.994
Epoch [  1/20], Step [200], Loss: 1.9065, CE: 1.8854, Robust: 4.7602, ε: 0.0033, κ: 0.992
Epoch [  1/20], Step [250], Loss: 1.4187, CE: 1.3820, Robust: 5.2863, ε: 0.0042, κ: 0.990
Epoch [  1/20], Step [300], Loss: 1.0805, CE: 1.0207, Robust: 6.2327, ε: 0.0050, κ: 0.988
Epoch [  1/20], Step [350], Loss: 0.9588, CE: 0.8728, Robust: 7.2229, ε: 0.0058, κ: 0.985
Epoch [  1/20], Step [400], Loss: 0.8979, CE: 0.7856, Robust: 7.9708, ε: 0.0067, κ: 0.983
Epoch [  1/20], Step [450], Loss: 0.8898, CE: 0.7519, Robust: 8.5375, ε: 0.0075, κ: 0.981
Epoch [  1/20], Step [500], Loss: 0.8806, CE: 0.7209, Robust: 8.7976, ε: 0.0083, κ: 0.979
Epoch [  1/20], Step [550], Loss: 0.8875, CE: 0.7069, Robust: 8.9703, ε: 0.0092, κ: 0.977


  5%|▌         | 1/20 [00:20<06:28, 20.46s/it]

Epoch [  1/20], Step [600], Loss: 0.8519, CE: 0.6590, Robust: 8.7167, ε: 0.0100, κ: 0.975
Epoch [  2/20], Step [ 50], Loss: 0.8479, CE: 0.6428, Robust: 8.5262, ε: 0.0108, κ: 0.973
Epoch [  2/20], Step [100], Loss: 0.8417, CE: 0.6315, Robust: 8.1128, ε: 0.0117, κ: 0.971
Epoch [  2/20], Step [150], Loss: 0.8501, CE: 0.6367, Robust: 7.7088, ε: 0.0125, κ: 0.969
Epoch [  2/20], Step [200], Loss: 0.8296, CE: 0.6171, Robust: 7.2060, ε: 0.0133, κ: 0.967
Epoch [  2/20], Step [250], Loss: 0.8503, CE: 0.6371, Robust: 6.8455, ε: 0.0141, κ: 0.965
Epoch [  2/20], Step [300], Loss: 0.8759, CE: 0.6648, Robust: 6.4627, ε: 0.0150, κ: 0.963
Epoch [  2/20], Step [350], Loss: 0.8381, CE: 0.6322, Robust: 5.9800, ε: 0.0158, κ: 0.960
Epoch [  2/20], Step [400], Loss: 0.8042, CE: 0.6016, Robust: 5.5911, ε: 0.0167, κ: 0.958
Epoch [  2/20], Step [450], Loss: 0.8071, CE: 0.6062, Robust: 5.3148, ε: 0.0175, κ: 0.956
Epoch [  2/20], Step [500], Loss: 0.8012, CE: 0.6030, Robust: 5.0303, ε: 0.0183, κ: 0.954
Epoch [  2

 10%|█         | 2/20 [00:40<06:05, 20.31s/it]

Epoch [  2/20], Step [600], Loss: 0.8031, CE: 0.6056, Robust: 4.6402, ε: 0.0200, κ: 0.950
Epoch [  3/20], Step [ 50], Loss: 0.7787, CE: 0.5844, Robust: 4.3945, ε: 0.0208, κ: 0.948
Epoch [  3/20], Step [100], Loss: 0.7965, CE: 0.6045, Robust: 4.2205, ε: 0.0217, κ: 0.946
Epoch [  3/20], Step [150], Loss: 0.7925, CE: 0.6043, Robust: 4.0154, ε: 0.0225, κ: 0.944
Epoch [  3/20], Step [200], Loss: 0.7689, CE: 0.5842, Robust: 3.8104, ε: 0.0233, κ: 0.942
Epoch [  3/20], Step [250], Loss: 0.7982, CE: 0.6130, Robust: 3.7349, ε: 0.0242, κ: 0.940
Epoch [  3/20], Step [300], Loss: 0.7788, CE: 0.5970, Robust: 3.5570, ε: 0.0250, κ: 0.938
Epoch [  3/20], Step [350], Loss: 0.7573, CE: 0.5749, Robust: 3.4465, ε: 0.0258, κ: 0.935
Epoch [  3/20], Step [400], Loss: 0.7750, CE: 0.5874, Robust: 3.4464, ε: 0.0267, κ: 0.933
Epoch [  3/20], Step [450], Loss: 0.7518, CE: 0.5684, Robust: 3.2782, ε: 0.0275, κ: 0.931
Epoch [  3/20], Step [500], Loss: 0.7579, CE: 0.5768, Robust: 3.1729, ε: 0.0283, κ: 0.929
Epoch [  3

 15%|█▌        | 3/20 [00:55<05:04, 17.91s/it]

Epoch [  3/20], Step [600], Loss: 0.7826, CE: 0.5979, Robust: 3.0965, ε: 0.0300, κ: 0.925
Epoch [  4/20], Step [ 50], Loss: 0.7593, CE: 0.5779, Robust: 2.9650, ε: 0.0308, κ: 0.923
Epoch [  4/20], Step [100], Loss: 0.7647, CE: 0.5775, Robust: 2.9732, ε: 0.0317, κ: 0.921
Epoch [  4/20], Step [150], Loss: 0.7659, CE: 0.5784, Robust: 2.9166, ε: 0.0325, κ: 0.919
Epoch [  4/20], Step [200], Loss: 0.7660, CE: 0.5793, Robust: 2.8477, ε: 0.0333, κ: 0.917
Epoch [  4/20], Step [250], Loss: 0.7811, CE: 0.5930, Robust: 2.8231, ε: 0.0342, κ: 0.915
Epoch [  4/20], Step [300], Loss: 0.7632, CE: 0.5743, Robust: 2.7598, ε: 0.0350, κ: 0.913
Epoch [  4/20], Step [350], Loss: 0.7939, CE: 0.6060, Robust: 2.7288, ε: 0.0358, κ: 0.910
Epoch [  4/20], Step [400], Loss: 0.7848, CE: 0.5928, Robust: 2.7121, ε: 0.0367, κ: 0.908
Epoch [  4/20], Step [450], Loss: 0.7575, CE: 0.5693, Robust: 2.6002, ε: 0.0375, κ: 0.906
Epoch [  4/20], Step [500], Loss: 0.7717, CE: 0.5803, Robust: 2.6001, ε: 0.0383, κ: 0.904
Epoch [  4

 20%|██        | 4/20 [01:12<04:40, 17.53s/it]

Epoch [  4/20], Step [600], Loss: 0.7719, CE: 0.5787, Robust: 2.5308, ε: 0.0400, κ: 0.900
Epoch [  5/20], Step [ 50], Loss: 0.7602, CE: 0.5663, Robust: 2.4855, ε: 0.0408, κ: 0.898
Epoch [  5/20], Step [100], Loss: 0.7808, CE: 0.5877, Robust: 2.4605, ε: 0.0416, κ: 0.896
Epoch [  5/20], Step [150], Loss: 0.7770, CE: 0.5791, Robust: 2.4608, ε: 0.0425, κ: 0.894
Epoch [  5/20], Step [200], Loss: 0.7863, CE: 0.5855, Robust: 2.4579, ε: 0.0433, κ: 0.892
Epoch [  5/20], Step [250], Loss: 0.7836, CE: 0.5847, Robust: 2.4036, ε: 0.0442, κ: 0.890
Epoch [  5/20], Step [300], Loss: 0.7901, CE: 0.5912, Robust: 2.3767, ε: 0.0450, κ: 0.888
Epoch [  5/20], Step [350], Loss: 0.7906, CE: 0.5858, Robust: 2.3899, ε: 0.0458, κ: 0.885
Epoch [  5/20], Step [400], Loss: 0.7966, CE: 0.5921, Robust: 2.3612, ε: 0.0467, κ: 0.883
Epoch [  5/20], Step [450], Loss: 0.7611, CE: 0.5605, Robust: 2.2647, ε: 0.0475, κ: 0.881
Epoch [  5/20], Step [500], Loss: 0.8066, CE: 0.5964, Robust: 2.3515, ε: 0.0483, κ: 0.879
Epoch [  5

 25%|██▌       | 5/20 [01:33<04:38, 18.60s/it]

Epoch [  5/20], Step [600], Loss: 0.7891, CE: 0.5771, Robust: 2.2879, ε: 0.0500, κ: 0.875
Epoch [  6/20], Step [ 50], Loss: 0.7918, CE: 0.5828, Robust: 2.2416, ε: 0.0508, κ: 0.873
Epoch [  6/20], Step [100], Loss: 0.7905, CE: 0.5808, Robust: 2.2173, ε: 0.0517, κ: 0.871
Epoch [  6/20], Step [150], Loss: 0.8088, CE: 0.5948, Robust: 2.2380, ε: 0.0525, κ: 0.869
Epoch [  6/20], Step [200], Loss: 0.7998, CE: 0.5860, Robust: 2.2026, ε: 0.0533, κ: 0.867
Epoch [  6/20], Step [250], Loss: 0.8127, CE: 0.5909, Robust: 2.2414, ε: 0.0542, κ: 0.865
Epoch [  6/20], Step [300], Loss: 0.7928, CE: 0.5698, Robust: 2.2041, ε: 0.0550, κ: 0.863
Epoch [  6/20], Step [350], Loss: 0.8492, CE: 0.6151, Robust: 2.3048, ε: 0.0558, κ: 0.860
Epoch [  6/20], Step [400], Loss: 0.8096, CE: 0.5887, Robust: 2.1597, ε: 0.0567, κ: 0.858
Epoch [  6/20], Step [450], Loss: 0.8232, CE: 0.5901, Robust: 2.2235, ε: 0.0575, κ: 0.856
Epoch [  6/20], Step [500], Loss: 0.8341, CE: 0.5995, Robust: 2.2198, ε: 0.0583, κ: 0.854
Epoch [  6

 30%|███       | 6/20 [01:52<04:26, 19.00s/it]

Epoch [  6/20], Step [600], Loss: 0.8223, CE: 0.5902, Robust: 2.1484, ε: 0.0600, κ: 0.850
Epoch [  7/20], Step [ 50], Loss: 0.8178, CE: 0.5839, Robust: 2.1331, ε: 0.0608, κ: 0.848
Epoch [  7/20], Step [100], Loss: 0.8121, CE: 0.5687, Robust: 2.1587, ε: 0.0617, κ: 0.846
Epoch [  7/20], Step [150], Loss: 0.8214, CE: 0.5811, Robust: 2.1294, ε: 0.0625, κ: 0.844
Epoch [  7/20], Step [200], Loss: 0.8437, CE: 0.5941, Robust: 2.1817, ε: 0.0633, κ: 0.842
Epoch [  7/20], Step [250], Loss: 0.8453, CE: 0.5977, Robust: 2.1512, ε: 0.0641, κ: 0.840
Epoch [  7/20], Step [300], Loss: 0.8693, CE: 0.6159, Robust: 2.1858, ε: 0.0650, κ: 0.838
Epoch [  7/20], Step [350], Loss: 0.8519, CE: 0.5998, Robust: 2.1417, ε: 0.0658, κ: 0.835
Epoch [  7/20], Step [400], Loss: 0.8876, CE: 0.6281, Robust: 2.1947, ε: 0.0667, κ: 0.833
Epoch [  7/20], Step [450], Loss: 0.8683, CE: 0.6063, Robust: 2.1692, ε: 0.0675, κ: 0.831
Epoch [  7/20], Step [500], Loss: 0.8625, CE: 0.5965, Robust: 2.1631, ε: 0.0683, κ: 0.829
Epoch [  7

 35%|███▌      | 7/20 [02:13<04:13, 19.48s/it]

Epoch [  7/20], Step [600], Loss: 0.8895, CE: 0.6239, Robust: 2.1509, ε: 0.0700, κ: 0.825
Epoch [  8/20], Step [ 50], Loss: 0.8820, CE: 0.6127, Robust: 2.1427, ε: 0.0708, κ: 0.823
Epoch [  8/20], Step [100], Loss: 0.8781, CE: 0.6060, Robust: 2.1335, ε: 0.0717, κ: 0.821
Epoch [  8/20], Step [150], Loss: 0.8664, CE: 0.5884, Robust: 2.1310, ε: 0.0725, κ: 0.819
Epoch [  8/20], Step [200], Loss: 0.8999, CE: 0.6197, Robust: 2.1573, ε: 0.0733, κ: 0.817
Epoch [  8/20], Step [250], Loss: 0.8990, CE: 0.6159, Robust: 2.1514, ε: 0.0742, κ: 0.815
Epoch [  8/20], Step [300], Loss: 0.8804, CE: 0.5925, Robust: 2.1368, ε: 0.0750, κ: 0.813
Epoch [  8/20], Step [350], Loss: 0.8984, CE: 0.6085, Robust: 2.1464, ε: 0.0758, κ: 0.810
Epoch [  8/20], Step [400], Loss: 0.9193, CE: 0.6241, Robust: 2.1726, ε: 0.0766, κ: 0.808
Epoch [  8/20], Step [450], Loss: 0.9098, CE: 0.6170, Robust: 2.1366, ε: 0.0775, κ: 0.806
Epoch [  8/20], Step [500], Loss: 0.9154, CE: 0.6161, Robust: 2.1523, ε: 0.0783, κ: 0.804
Epoch [  8

 40%|████      | 8/20 [02:27<03:31, 17.66s/it]

Epoch [  8/20], Step [600], Loss: 0.9481, CE: 0.6406, Robust: 2.1860, ε: 0.0800, κ: 0.800
Epoch [  9/20], Step [ 50], Loss: 0.9571, CE: 0.6490, Robust: 2.1817, ε: 0.0808, κ: 0.798
Epoch [  9/20], Step [100], Loss: 0.9378, CE: 0.6321, Robust: 2.1373, ε: 0.0817, κ: 0.796
Epoch [  9/20], Step [150], Loss: 0.9428, CE: 0.6305, Robust: 2.1524, ε: 0.0825, κ: 0.794
Epoch [  9/20], Step [200], Loss: 0.9314, CE: 0.6163, Robust: 2.1365, ε: 0.0833, κ: 0.792
Epoch [  9/20], Step [250], Loss: 0.9354, CE: 0.6161, Robust: 2.1415, ε: 0.0842, κ: 0.790
Epoch [  9/20], Step [300], Loss: 0.9337, CE: 0.6116, Robust: 2.1350, ε: 0.0850, κ: 0.788
Epoch [  9/20], Step [350], Loss: 0.9627, CE: 0.6232, Robust: 2.2130, ε: 0.0858, κ: 0.785
Epoch [  9/20], Step [400], Loss: 0.9601, CE: 0.6242, Robust: 2.1819, ε: 0.0867, κ: 0.783
Epoch [  9/20], Step [450], Loss: 0.9683, CE: 0.6309, Robust: 2.1805, ε: 0.0875, κ: 0.781
Epoch [  9/20], Step [500], Loss: 0.9697, CE: 0.6301, Robust: 2.1752, ε: 0.0883, κ: 0.779
Epoch [  9

 45%|████▌     | 9/20 [02:45<03:16, 17.89s/it]

Epoch [  9/20], Step [600], Loss: 0.9748, CE: 0.6344, Robust: 2.1544, ε: 0.0900, κ: 0.775
Epoch [ 10/20], Step [ 50], Loss: 0.9813, CE: 0.6310, Robust: 2.1810, ε: 0.0908, κ: 0.773
Epoch [ 10/20], Step [100], Loss: 0.9818, CE: 0.6323, Robust: 2.1645, ε: 0.0917, κ: 0.771
Epoch [ 10/20], Step [150], Loss: 0.9912, CE: 0.6347, Robust: 2.1836, ε: 0.0925, κ: 0.769
Epoch [ 10/20], Step [200], Loss: 1.0144, CE: 0.6476, Robust: 2.2269, ε: 0.0933, κ: 0.767
Epoch [ 10/20], Step [250], Loss: 1.0125, CE: 0.6433, Robust: 2.2189, ε: 0.0942, κ: 0.765
Epoch [ 10/20], Step [300], Loss: 1.0051, CE: 0.6343, Robust: 2.2027, ε: 0.0950, κ: 0.763
Epoch [ 10/20], Step [350], Loss: 1.0082, CE: 0.6377, Robust: 2.1910, ε: 0.0958, κ: 0.760
Epoch [ 10/20], Step [400], Loss: 1.0034, CE: 0.6328, Robust: 2.1731, ε: 0.0967, κ: 0.758
Epoch [ 10/20], Step [450], Loss: 1.0232, CE: 0.6401, Robust: 2.2187, ε: 0.0975, κ: 0.756
Epoch [ 10/20], Step [500], Loss: 1.0418, CE: 0.6557, Robust: 2.2333, ε: 0.0983, κ: 0.754
Epoch [ 10

 50%|█████     | 10/20 [03:05<03:05, 18.58s/it]

Epoch [ 10/20], Step [600], Loss: 1.0643, CE: 0.6746, Robust: 2.2403, ε: 0.1000, κ: 0.750
Epoch [ 11/20], Step [ 50], Loss: 1.0399, CE: 0.6512, Robust: 2.1996, ε: 0.1000, κ: 0.748
Epoch [ 11/20], Step [100], Loss: 1.0477, CE: 0.6546, Robust: 2.2078, ε: 0.1000, κ: 0.746
Epoch [ 11/20], Step [150], Loss: 1.0230, CE: 0.6323, Robust: 2.1635, ε: 0.1000, κ: 0.744
Epoch [ 11/20], Step [200], Loss: 1.0590, CE: 0.6529, Robust: 2.2311, ε: 0.1000, κ: 0.742
Epoch [ 11/20], Step [250], Loss: 1.0535, CE: 0.6556, Robust: 2.1896, ε: 0.1000, κ: 0.740
Epoch [ 11/20], Step [300], Loss: 1.0725, CE: 0.6744, Robust: 2.1972, ε: 0.1000, κ: 0.738
Epoch [ 11/20], Step [350], Loss: 1.0181, CE: 0.6275, Robust: 2.1098, ε: 0.1000, κ: 0.735
Epoch [ 11/20], Step [400], Loss: 1.0609, CE: 0.6576, Robust: 2.1763, ε: 0.1000, κ: 0.733
Epoch [ 11/20], Step [450], Loss: 1.0473, CE: 0.6451, Robust: 2.1476, ε: 0.1000, κ: 0.731
Epoch [ 11/20], Step [500], Loss: 1.0708, CE: 0.6671, Robust: 2.1638, ε: 0.1000, κ: 0.729
Epoch [ 11

 55%|█████▌    | 11/20 [03:25<02:50, 18.94s/it]

Epoch [ 11/20], Step [600], Loss: 1.0776, CE: 0.6812, Robust: 2.1282, ε: 0.1000, κ: 0.725
Epoch [ 12/20], Step [ 50], Loss: 1.0539, CE: 0.6550, Robust: 2.1005, ε: 0.1000, κ: 0.723
Epoch [ 12/20], Step [100], Loss: 1.0665, CE: 0.6606, Robust: 2.1204, ε: 0.1000, κ: 0.721
Epoch [ 12/20], Step [150], Loss: 1.0606, CE: 0.6520, Robust: 2.1103, ε: 0.1000, κ: 0.719
Epoch [ 12/20], Step [200], Loss: 1.0540, CE: 0.6505, Robust: 2.0803, ε: 0.1000, κ: 0.717
Epoch [ 12/20], Step [250], Loss: 1.0798, CE: 0.6680, Robust: 2.1164, ε: 0.1000, κ: 0.715
Epoch [ 12/20], Step [300], Loss: 1.0811, CE: 0.6713, Robust: 2.1022, ε: 0.1000, κ: 0.713
Epoch [ 12/20], Step [350], Loss: 1.0611, CE: 0.6554, Robust: 2.0613, ε: 0.1000, κ: 0.710
Epoch [ 12/20], Step [400], Loss: 1.0853, CE: 0.6729, Robust: 2.0923, ε: 0.1000, κ: 0.708
Epoch [ 12/20], Step [450], Loss: 1.0597, CE: 0.6598, Robust: 2.0258, ε: 0.1000, κ: 0.706
Epoch [ 12/20], Step [500], Loss: 1.0920, CE: 0.6660, Robust: 2.1113, ε: 0.1000, κ: 0.704
Epoch [ 12

 60%|██████    | 12/20 [03:41<02:24, 18.06s/it]

Epoch [ 12/20], Step [600], Loss: 1.0864, CE: 0.6668, Robust: 2.0704, ε: 0.1000, κ: 0.700
Epoch [ 13/20], Step [ 50], Loss: 1.0618, CE: 0.6546, Robust: 2.0074, ε: 0.1000, κ: 0.698
Epoch [ 13/20], Step [100], Loss: 1.0810, CE: 0.6868, Robust: 1.9875, ε: 0.1000, κ: 0.696
Epoch [ 13/20], Step [150], Loss: 1.0554, CE: 0.6658, Robust: 1.9423, ε: 0.1000, κ: 0.694
Epoch [ 13/20], Step [200], Loss: 1.0649, CE: 0.6684, Robust: 1.9588, ε: 0.1000, κ: 0.692
Epoch [ 13/20], Step [250], Loss: 1.0466, CE: 0.6562, Robust: 1.9180, ε: 0.1000, κ: 0.690
Epoch [ 13/20], Step [300], Loss: 1.0320, CE: 0.6410, Robust: 1.8965, ε: 0.1000, κ: 0.688
Epoch [ 13/20], Step [350], Loss: 1.0442, CE: 0.6460, Robust: 1.9161, ε: 0.1000, κ: 0.685
Epoch [ 13/20], Step [400], Loss: 1.0687, CE: 0.6653, Robust: 1.9435, ε: 0.1000, κ: 0.683
Epoch [ 13/20], Step [450], Loss: 1.0765, CE: 0.6688, Robust: 1.9520, ε: 0.1000, κ: 0.681
Epoch [ 13/20], Step [500], Loss: 1.0708, CE: 0.6596, Robust: 1.9455, ε: 0.1000, κ: 0.679
Epoch [ 13

 65%|██████▌   | 13/20 [03:57<02:01, 17.41s/it]

Epoch [ 13/20], Step [600], Loss: 1.0619, CE: 0.6497, Robust: 1.9222, ε: 0.1000, κ: 0.675
Epoch [ 14/20], Step [ 50], Loss: 1.0730, CE: 0.6550, Robust: 1.9371, ε: 0.1000, κ: 0.673
Epoch [ 14/20], Step [100], Loss: 1.0635, CE: 0.6474, Robust: 1.9154, ε: 0.1000, κ: 0.671
Epoch [ 14/20], Step [150], Loss: 1.0699, CE: 0.6546, Robust: 1.9124, ε: 0.1000, κ: 0.669
Epoch [ 14/20], Step [200], Loss: 1.0968, CE: 0.6712, Robust: 1.9521, ε: 0.1000, κ: 0.667
Epoch [ 14/20], Step [250], Loss: 1.0729, CE: 0.6465, Robust: 1.9218, ε: 0.1000, κ: 0.665
Epoch [ 14/20], Step [300], Loss: 1.0961, CE: 0.6680, Robust: 1.9404, ε: 0.1000, κ: 0.663
Epoch [ 14/20], Step [350], Loss: 1.0780, CE: 0.6488, Robust: 1.9168, ε: 0.1000, κ: 0.660
Epoch [ 14/20], Step [400], Loss: 1.0882, CE: 0.6590, Robust: 1.9189, ε: 0.1000, κ: 0.658
Epoch [ 14/20], Step [450], Loss: 1.1165, CE: 0.6711, Robust: 1.9707, ε: 0.1000, κ: 0.656
Epoch [ 14/20], Step [500], Loss: 1.0918, CE: 0.6539, Robust: 1.9240, ε: 0.1000, κ: 0.654
Epoch [ 14

 70%|███████   | 14/20 [04:16<01:47, 17.94s/it]

Epoch [ 14/20], Step [600], Loss: 1.1059, CE: 0.6611, Robust: 1.9358, ε: 0.1000, κ: 0.650
Epoch [ 15/20], Step [ 50], Loss: 1.0989, CE: 0.6614, Robust: 1.9078, ε: 0.1000, κ: 0.648
Epoch [ 15/20], Step [100], Loss: 1.1482, CE: 0.6895, Robust: 1.9884, ε: 0.1000, κ: 0.646
Epoch [ 15/20], Step [150], Loss: 1.0670, CE: 0.6297, Robust: 1.8609, ε: 0.1000, κ: 0.644
Epoch [ 15/20], Step [200], Loss: 1.1061, CE: 0.6611, Robust: 1.9068, ε: 0.1000, κ: 0.642
Epoch [ 15/20], Step [250], Loss: 1.1258, CE: 0.6771, Robust: 1.9256, ε: 0.1000, κ: 0.640
Epoch [ 15/20], Step [300], Loss: 1.1159, CE: 0.6577, Robust: 1.9253, ε: 0.1000, κ: 0.638
Epoch [ 15/20], Step [350], Loss: 1.1248, CE: 0.6621, Robust: 1.9351, ε: 0.1000, κ: 0.635
Epoch [ 15/20], Step [400], Loss: 1.1087, CE: 0.6589, Robust: 1.8893, ε: 0.1000, κ: 0.633
Epoch [ 15/20], Step [450], Loss: 1.1356, CE: 0.6700, Robust: 1.9363, ε: 0.1000, κ: 0.631
Epoch [ 15/20], Step [500], Loss: 1.1256, CE: 0.6660, Robust: 1.9090, ε: 0.1000, κ: 0.629
Epoch [ 15

 75%|███████▌  | 15/20 [04:36<01:32, 18.59s/it]

Epoch [ 15/20], Step [600], Loss: 1.1393, CE: 0.6653, Robust: 1.9327, ε: 0.1000, κ: 0.625
Epoch [ 16/20], Step [ 50], Loss: 1.1174, CE: 0.6603, Robust: 1.8760, ε: 0.1000, κ: 0.623
Epoch [ 16/20], Step [100], Loss: 1.1462, CE: 0.6708, Robust: 1.9281, ε: 0.1000, κ: 0.621
Epoch [ 16/20], Step [150], Loss: 1.1310, CE: 0.6597, Robust: 1.8994, ε: 0.1000, κ: 0.619
Epoch [ 16/20], Step [200], Loss: 1.1578, CE: 0.6801, Robust: 1.9299, ε: 0.1000, κ: 0.617
Epoch [ 16/20], Step [250], Loss: 1.1543, CE: 0.6787, Robust: 1.9161, ε: 0.1000, κ: 0.615
Epoch [ 16/20], Step [300], Loss: 1.1451, CE: 0.6669, Robust: 1.9044, ε: 0.1000, κ: 0.613
Epoch [ 16/20], Step [350], Loss: 1.1532, CE: 0.6721, Robust: 1.9105, ε: 0.1000, κ: 0.610
Epoch [ 16/20], Step [400], Loss: 1.1521, CE: 0.6696, Robust: 1.9048, ε: 0.1000, κ: 0.608
Epoch [ 16/20], Step [450], Loss: 1.1576, CE: 0.6669, Robust: 1.9165, ε: 0.1000, κ: 0.606
Epoch [ 16/20], Step [500], Loss: 1.1546, CE: 0.6745, Robust: 1.8907, ε: 0.1000, κ: 0.604
Epoch [ 16

 80%|████████  | 16/20 [04:56<01:16, 19.05s/it]

Epoch [ 16/20], Step [600], Loss: 1.1598, CE: 0.6685, Robust: 1.9000, ε: 0.1000, κ: 0.600
Epoch [ 17/20], Step [ 50], Loss: 1.1551, CE: 0.6683, Robust: 1.8822, ε: 0.1000, κ: 0.598
Epoch [ 17/20], Step [100], Loss: 1.1649, CE: 0.6667, Robust: 1.9025, ε: 0.1000, κ: 0.596
Epoch [ 17/20], Step [150], Loss: 1.1537, CE: 0.6677, Robust: 1.8673, ε: 0.1000, κ: 0.594
Epoch [ 17/20], Step [200], Loss: 1.1950, CE: 0.6938, Robust: 1.9243, ε: 0.1000, κ: 0.592
Epoch [ 17/20], Step [250], Loss: 1.1953, CE: 0.6934, Robust: 1.9194, ε: 0.1000, κ: 0.590
Epoch [ 17/20], Step [300], Loss: 1.1786, CE: 0.6801, Robust: 1.8917, ε: 0.1000, κ: 0.588
Epoch [ 17/20], Step [350], Loss: 1.1835, CE: 0.6805, Robust: 1.8968, ε: 0.1000, κ: 0.585
Epoch [ 17/20], Step [400], Loss: 1.1709, CE: 0.6734, Robust: 1.8705, ε: 0.1000, κ: 0.583
Epoch [ 17/20], Step [450], Loss: 1.1644, CE: 0.6662, Robust: 1.8590, ε: 0.1000, κ: 0.581
Epoch [ 17/20], Step [500], Loss: 1.1729, CE: 0.6728, Robust: 1.8643, ε: 0.1000, κ: 0.579
Epoch [ 17

 85%|████████▌ | 17/20 [05:11<00:52, 17.66s/it]

Epoch [ 17/20], Step [600], Loss: 1.1907, CE: 0.6817, Robust: 1.8823, ε: 0.1000, κ: 0.575
Epoch [ 18/20], Step [ 50], Loss: 1.1859, CE: 0.6754, Robust: 1.8736, ε: 0.1000, κ: 0.573
Epoch [ 18/20], Step [100], Loss: 1.1932, CE: 0.6832, Robust: 1.8745, ε: 0.1000, κ: 0.571
Epoch [ 18/20], Step [150], Loss: 1.1903, CE: 0.6797, Robust: 1.8666, ε: 0.1000, κ: 0.569
Epoch [ 18/20], Step [200], Loss: 1.1769, CE: 0.6662, Robust: 1.8476, ε: 0.1000, κ: 0.567
Epoch [ 18/20], Step [250], Loss: 1.2147, CE: 0.6887, Robust: 1.8997, ε: 0.1000, κ: 0.565
Epoch [ 18/20], Step [300], Loss: 1.2071, CE: 0.6920, Robust: 1.8724, ε: 0.1000, κ: 0.563
Epoch [ 18/20], Step [350], Loss: 1.1990, CE: 0.6817, Robust: 1.8615, ε: 0.1000, κ: 0.560
Epoch [ 18/20], Step [400], Loss: 1.2008, CE: 0.6786, Robust: 1.8639, ε: 0.1000, κ: 0.558
Epoch [ 18/20], Step [450], Loss: 1.2111, CE: 0.6900, Robust: 1.8671, ε: 0.1000, κ: 0.556
Epoch [ 18/20], Step [500], Loss: 1.2086, CE: 0.6847, Robust: 1.8626, ε: 0.1000, κ: 0.554
Epoch [ 18

 90%|█████████ | 18/20 [05:27<00:34, 17.35s/it]

Epoch [ 18/20], Step [600], Loss: 1.2263, CE: 0.6940, Robust: 1.8797, ε: 0.1000, κ: 0.550
Epoch [ 19/20], Step [ 50], Loss: 1.2138, CE: 0.6870, Robust: 1.8551, ε: 0.1000, κ: 0.548
Epoch [ 19/20], Step [100], Loss: 1.2189, CE: 0.6899, Robust: 1.8573, ε: 0.1000, κ: 0.546
Epoch [ 19/20], Step [150], Loss: 1.2129, CE: 0.6912, Robust: 1.8373, ε: 0.1000, κ: 0.544
Epoch [ 19/20], Step [200], Loss: 1.2012, CE: 0.6775, Robust: 1.8227, ε: 0.1000, κ: 0.542
Epoch [ 19/20], Step [250], Loss: 1.2022, CE: 0.6717, Robust: 1.8266, ε: 0.1000, κ: 0.540
Epoch [ 19/20], Step [300], Loss: 1.2204, CE: 0.6824, Robust: 1.8483, ε: 0.1000, κ: 0.538
Epoch [ 19/20], Step [350], Loss: 1.2214, CE: 0.6870, Robust: 1.8399, ε: 0.1000, κ: 0.535
Epoch [ 19/20], Step [400], Loss: 1.2288, CE: 0.6896, Robust: 1.8477, ε: 0.1000, κ: 0.533
Epoch [ 19/20], Step [450], Loss: 1.2356, CE: 0.6833, Robust: 1.8642, ε: 0.1000, κ: 0.531
Epoch [ 19/20], Step [500], Loss: 1.2226, CE: 0.6838, Robust: 1.8308, ε: 0.1000, κ: 0.529
Epoch [ 19

 95%|█████████▌| 19/20 [05:47<00:18, 18.01s/it]

Epoch [ 19/20], Step [600], Loss: 1.2750, CE: 0.7148, Robust: 1.8968, ε: 0.1000, κ: 0.525
Epoch [ 20/20], Step [ 50], Loss: 1.2495, CE: 0.6954, Robust: 1.8595, ε: 0.1000, κ: 0.523
Epoch [ 20/20], Step [100], Loss: 1.2502, CE: 0.6852, Robust: 1.8670, ε: 0.1000, κ: 0.521
Epoch [ 20/20], Step [150], Loss: 1.2658, CE: 0.7064, Robust: 1.8714, ε: 0.1000, κ: 0.519
Epoch [ 20/20], Step [200], Loss: 1.2329, CE: 0.6809, Robust: 1.8254, ε: 0.1000, κ: 0.517
Epoch [ 20/20], Step [250], Loss: 1.2680, CE: 0.7093, Robust: 1.8628, ε: 0.1000, κ: 0.515
Epoch [ 20/20], Step [300], Loss: 1.2345, CE: 0.6779, Robust: 1.8222, ε: 0.1000, κ: 0.513
Epoch [ 20/20], Step [350], Loss: 1.2633, CE: 0.6890, Robust: 1.8645, ε: 0.1000, κ: 0.510
Epoch [ 20/20], Step [400], Loss: 1.2732, CE: 0.7036, Robust: 1.8646, ε: 0.1000, κ: 0.508
Epoch [ 20/20], Step [450], Loss: 1.2728, CE: 0.7051, Robust: 1.8573, ε: 0.1000, κ: 0.506
Epoch [ 20/20], Step [500], Loss: 1.2719, CE: 0.6972, Robust: 1.8587, ε: 0.1000, κ: 0.504
Epoch [ 20

100%|██████████| 20/20 [06:07<00:00, 18.37s/it]

Epoch [ 20/20], Step [600], Loss: 1.2296, CE: 0.6696, Robust: 1.7920, ε: 0.1000, κ: 0.500
IBP Training completed in 367.37 seconds
[STANDARD ACCURACY TEST]





Standard Accuracy: 0.7728
[ROBUST ACCURACY TEST - PGD ε=0.1]
Robust Accuracy (PGD ε=0.1): 0.5037

RESULTS SUMMARY

Model           Std Accuracy    Rob Accuracy    Train Time     
------------------------------------------------------------
Standard        0.8834          0.0049          66.66          s
IBP Robust      0.7728          0.5037          367.37         s

Robustness Improvement: 49.88%
Training Time Ratio (IBP/Standard): 5.51x


In [17]:
from tqdm import trange, tqdm
def verify_with_boundprop(bound_model, test_loader, eps_list, device):
    """Verify L∞ robustness of a classifier with interval bound propagation.

    For each radius ``eps`` in ``eps_list``, every test input ``x`` is wrapped
    in an L∞ box via ``HyperRectangle.from_eps(x, eps)``. The box is pushed
    through ``bound_model.ibp``, and the sample counts as verified when

        lower_bound[true_class] > max(upper_bound[other_classes])

    i.e. no point in the box can make another class outscore the true one.

    Args:
        bound_model: Bound-propagation model exposing ``eval()`` and
            ``ibp(HyperRectangle)``. The returned object must have ``.lower``
            and ``.upper`` indexed as ``(batch, num_classes)``.
        test_loader: Iterable of ``(images, labels)`` batches, where labels
            are integer class indices.
        eps_list: Iterable of perturbation radii.
        device: Device the batches are moved to before propagation.

    Returns:
        A tuple ``(verified_accs, results)``:
        ``verified_accs`` is the verified accuracy in percent for each eps,
        in the same order as ``eps_list``.
        ``results`` maps each eps to a dict with keys ``'verified'``,
        ``'total'`` and ``'acc'``.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    bound_model.eval()
    verified_accs = []
    results = {}

    for eps in eps_list:
        total = 0
        verified = 0

        # The bounds are only inspected, never differentiated, so skip building
        # an autograd graph over the entire test set.
        with torch.no_grad():
            for images, labels in tqdm(test_loader, desc=f'Verifying ε={eps:.3f}'):
                images = images.to(device)
                labels = labels.to(device).long()  # gather() requires int64 indices
                images = images.view(images.size(0), -1)  # Flatten manually

                # Create L∞ bounds
                input_bounds = HyperRectangle.from_eps(images, eps)

                # Interval Bound Propagation
                bounds = bound_model.ibp(input_bounds)
                lower, upper = bounds.lower, bounds.upper

                # Verification condition, evaluated for the whole batch at once
                # (a per-sample loop with .item() forces one device sync per image):
                #   l[true] > max_{j != true} u[j]
                num_classes = upper.size(-1)
                classes = torch.arange(num_classes, device=upper.device)
                is_true = classes.unsqueeze(0) == labels.unsqueeze(-1)
                lower_true = lower.gather(1, labels.unsqueeze(-1)).squeeze(-1)
                # Masking the true class with -inf makes max() range over the other classes only.
                max_other_upper = upper.masked_fill(is_true, float('-inf')).max(dim=-1).values

                verified += (lower_true > max_other_upper).sum().item()
                total += labels.size(0)

        if total == 0:
            raise ValueError('test_loader yielded no samples; cannot compute verified accuracy')

        acc = 100.0 * verified / total
        verified_accs.append(acc)
        results[eps] = {'verified': verified, 'total': total, 'acc': acc}
        print(f"ε={eps:.3f} → Verified accuracy: {acc:.2f}% ({verified}/{total})")

    return verified_accs, results

In [19]:
# Compute verified accuracy of the IBP-trained model over a grid of L∞ radii.
import numpy as np
print('\n' + '='*70)
print('VERIFIED ACCURACY EVALUATION (Box Verification with IBP)')
print('='*70)

# np.linspace produces values such as 0.030000000000000002. Rounding keeps the
# keys of `results` equal to the literals a reader would type (e.g. results[0.03]).
eps_list = np.round(np.linspace(0.01, 0.1, 10), 3)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}\n')

# Load test data for verification
transform, target_transform = construct_transform()
test_data = datasets.FashionMNIST('./fashion_data', train=False, download=True,
                                    transform=transform, target_transform=target_transform)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=4)
# NOTE(review): net_ibp comes from an earlier cell and must be the bound-propagation
# wrapper (verify_with_boundprop calls its .ibp()), not the plain IBPNet — confirm.
verified_accs, results = verify_with_boundprop(net_ibp, test_loader, eps_list, device)


VERIFIED ACCURACY EVALUATION (Box Verification with IBP)
Using device: cuda



Verifying ε=0.010:   0%|          | 0/157 [00:00<?, ?it/s]

Verifying ε=0.010: 100%|██████████| 157/157 [00:33<00:00,  4.70it/s]


ε=0.010 → Verified accuracy: 74.12% (7412/10000)


Verifying ε=0.020: 100%|██████████| 157/157 [00:14<00:00, 10.95it/s]


ε=0.020 → Verified accuracy: 71.01% (7101/10000)


Verifying ε=0.030: 100%|██████████| 157/157 [00:33<00:00,  4.74it/s]


ε=0.030 → Verified accuracy: 67.66% (6766/10000)


Verifying ε=0.040: 100%|██████████| 157/157 [00:35<00:00,  4.39it/s]


ε=0.040 → Verified accuracy: 64.14% (6414/10000)


Verifying ε=0.050: 100%|██████████| 157/157 [00:34<00:00,  4.51it/s]


ε=0.050 → Verified accuracy: 60.29% (6029/10000)


Verifying ε=0.060: 100%|██████████| 157/157 [00:38<00:00,  4.11it/s]


ε=0.060 → Verified accuracy: 56.44% (5644/10000)


Verifying ε=0.070: 100%|██████████| 157/157 [00:20<00:00,  7.71it/s]


ε=0.070 → Verified accuracy: 52.09% (5209/10000)


Verifying ε=0.080: 100%|██████████| 157/157 [00:18<00:00,  8.67it/s]


ε=0.080 → Verified accuracy: 47.72% (4772/10000)


Verifying ε=0.090: 100%|██████████| 157/157 [00:37<00:00,  4.14it/s]


ε=0.090 → Verified accuracy: 42.83% (4283/10000)


Verifying ε=0.100: 100%|██████████| 157/157 [00:26<00:00,  5.97it/s]

ε=0.100 → Verified accuracy: 38.38% (3838/10000)



