In [1]:
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device:", device)

Using device: cuda


In [2]:
class ThresholdReLU(nn.Module):
    """Activation that keeps only the elements strictly greater than ``threshold``.

    Every element that is not strictly above the threshold (including values
    equal to it) is replaced by zero; the remaining elements pass through
    unchanged.
    """

    def __init__(self, threshold=0.001):
        super().__init__()
        self.threshold = threshold

    def forward(self, x):
        """Return a tensor shaped like ``x`` with sub-threshold entries zeroed."""
        keep = x > self.threshold
        zeros = torch.zeros_like(x)
        return torch.where(keep, x, zeros)

def quantize_input(x, levels=16, lo=0.0, hi=1.0):
    """Clamp ``x`` to ``[lo, hi]`` and snap it onto ``levels`` evenly spaced values.

    The grid includes both end points, so ``levels=2`` yields only ``lo`` and
    ``hi``.

    Args:
        x: Input tensor.
        levels: Number of grid points spanning ``[lo, hi]``; must be at least 2.
        lo: Lower clamp bound and smallest grid value.
        hi: Upper clamp bound and largest grid value; must be greater than ``lo``.

    Returns:
        A tensor shaped like ``x`` whose values lie on the quantisation grid.

    Raises:
        ValueError: If ``levels < 2`` or ``hi <= lo``. Either case would
            otherwise divide by zero and silently return NaNs.
    """
    if levels < 2:
        raise ValueError(f"levels must be at least 2, got {levels}")
    if hi <= lo:
        raise ValueError(f"hi must be greater than lo, got lo={lo}, hi={hi}")

    steps = levels - 1
    span = hi - lo
    # Map to [0, 1], round to the nearest of `levels` grid points, map back.
    x_norm = (x.clamp(lo, hi) - lo) / span
    x_q = torch.round(x_norm * steps) / steps
    return x_q * span + lo

def sparsity_loss_modified(acts, beta=20.0):
    """Mean of ``1 - tanh(beta * |a|)`` over every element of every tensor in ``acts``.

    Each element contributes a value close to 1 when it is (near) zero and
    close to 0 when its magnitude is large relative to ``1 / beta``.

    Args:
        acts: Iterable of activation tensors. One-shot iterables such as
            generators are supported; the input is materialised once.
        beta: Steepness of the ``tanh`` surrogate.

    Returns:
        A scalar tensor: the summed per-element terms divided by the total
        element count across all tensors.

    Raises:
        ValueError: If ``acts`` is empty or holds no elements at all.
    """
    # Materialise first: the tensors are traversed twice (count, then sum), and
    # a generator would be exhausted after the first pass.
    acts = list(acts)
    total = sum(a.numel() for a in acts)
    if total == 0:
        raise ValueError("sparsity_loss_modified requires at least one non-empty activation tensor")

    loss = sum(torch.sum(1.0 - torch.tanh(beta * a.abs())) for a in acts)
    return loss / total

In [3]:
class DenseNet_Sparse(nn.Module):
    """DenseNet-121 feature stack with a small MLP head that also reports activations.

    ``forward`` returns the class logits together with copies of the outputs
    of every *direct* child of ``self.features`` that is a ``BatchNorm2d`` or
    ``ReLU``. Modules nested inside composite children are not visited by the
    loop, and the ``ThresholdReLU`` inside the classifier head is not recorded.

    Args:
        num_classes: Width of the final linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Randomly initialised backbone. Relying on the constructor default
        # (no pretrained weights) avoids the `pretrained=` keyword, which
        # torchvision deprecated in favour of `weights=`, and works on both
        # old and new torchvision releases.
        base = torchvision.models.densenet121()
        self.features = base.features
        self.pool     = nn.AdaptiveAvgPool2d((1,1))
        # Adapt classifier for CIFAR‑10
        self.classifier = nn.Sequential(
            nn.Linear(base.classifier.in_features, 512),
            ThresholdReLU(),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Run the network.

        Returns:
            ``(logits, acts)`` where ``acts`` is the list of recorded
            activation tensors in the order the layers were applied.
        """
        acts = []
        for layer in self.features:
            x = layer(x)
            if isinstance(layer, (nn.BatchNorm2d, nn.ReLU)):
                # Store a copy so the recorded tensor cannot be altered by a
                # later layer that writes into `x` (presumably the backbone's
                # in-place ReLU — verify against torchvision).
                acts.append(x.clone())
        # NOTE(review): torchvision's own DenseNet.forward applies a ReLU to the
        # output of `features` before pooling; this wrapper pools the raw
        # output — confirm that is intended.
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, acts

In [4]:
# 4) White‑Box Sparsity Adversary
def generate_sparsity_adversary(model, x_clean, y_clean, criterion,
                                epsilon=0.5, alpha=0.05, num_iter=120,
                                c=0.0, beta=50.0):
    """Iterative signed-gradient descent on the sparsity loss within an L-inf ball.

    Each step evaluates ``sparsity_loss_modified`` on the activations returned
    by ``model``, moves ``x_adv`` by ``alpha`` against the sign of the input
    gradient, projects back into ``[x_clean - epsilon, x_clean + epsilon]`` and
    finally clamps to ``[0, 1]``.

    Args:
        model: Module called as ``model(x)`` and returning ``(logits, acts)``.
            It is switched to eval mode and left there.
        x_clean: Starting images; also the centre of the ``epsilon`` ball.
        y_clean: Labels for the optional cross-entropy term (used only when
            ``c != 0``).
        criterion: Loss called as ``criterion(logits, y_clean)`` (used only
            when ``c != 0``).
        epsilon: Radius of the L-inf ball around ``x_clean``.
        alpha: Step size of each signed-gradient update.
        num_iter: Number of update steps.
        c: Weight of the ``criterion`` term added to the sparsity loss. The
            default ``0.0`` optimises the sparsity loss alone.
        beta: Forwarded to ``sparsity_loss_modified``.

    Returns:
        The detached adversarial batch, on the same device as the model.
    """
    model.eval()

    # Keep every tensor on the model's device so the projection below never
    # mixes devices; fall back to the input's device for parameter-free models.
    try:
        run_device = next(model.parameters()).device
    except StopIteration:
        run_device = x_clean.device
    x_clean = x_clean.detach().to(run_device)
    use_ce = c != 0
    if use_ce:
        y_clean = y_clean.to(run_device)

    # Loop-invariant bounds of the epsilon ball.
    lower = x_clean - epsilon
    upper = x_clean + epsilon

    x_adv = x_clean.clone()
    for i in range(num_iter):
        x_adv.requires_grad_(True)
        logits, acts = model(x_adv)
        l_sp = sparsity_loss_modified(acts, beta=beta)
        loss = l_sp
        if use_ce:
            loss = loss + c * criterion(logits, y_clean)

        # Only the input gradient is needed; asking autograd for it directly
        # avoids computing and storing gradients for the model parameters.
        (grad,) = torch.autograd.grad(loss, x_adv)

        with torch.no_grad():
            x_adv = x_adv - alpha * grad.sign()
            x_adv = torch.max(torch.min(x_adv, upper), lower).clamp(0.0, 1.0)

        if (i+1) % 30 == 0:
            print(f"WB iter {i+1}/{num_iter}, SP loss: {l_sp.item():.4f}")

    return x_adv.detach()


In [5]:
# ─── 5) SPSA Black‑Box Attack ──
class SPSAAttack:
    """Gradient-free attack that estimates the loss gradient from paired queries.

    Every iteration draws ``samples`` random ±1 direction tensors, evaluates
    the cross-entropy at ``x + delta ± sigma * u`` and averages the resulting
    finite-difference estimates before taking one step on ``delta``.

    Args:
        model: Callable used as ``model(x)``; its first return value is treated
            as the logits.
        bounds: ``(lo, hi)`` range that perturbed inputs are clamped to.
        sigma: Magnitude of the probing perturbation.
        lr: Step size applied to the estimated gradient.
        max_iter: Number of update steps.
        targeted: If true, the loss towards ``tgt_label`` is decreased;
            otherwise the estimate's sign is flipped so the loss is increased.
        samples: Number of random directions averaged per step.
    """

    def __init__(self, model, bounds=(0,1), sigma=2e-3, lr=1e-2,
                 max_iter=500, targeted=True, samples=4):
        self.model    = model
        self.lo, self.hi = bounds
        self.sigma    = sigma
        self.lr       = lr
        self.max_iter = max_iter
        self.targeted = targeted
        self.samples  = samples

    # The attack only ever queries the model for loss values. Without no_grad
    # each query would build an autograd graph that `delta` keeps alive and
    # extends on every iteration, growing memory for no benefit.
    @torch.no_grad()
    def attack(self, x_orig, tgt_label):
        """Perturb ``x_orig`` using only forward queries of the model.

        Args:
            x_orig: Input to perturb. The label tensor built here has a single
                entry, so a batch of size 1 is presumably expected — confirm
                against callers.
            tgt_label: Class index used as the cross-entropy target.

        Returns:
            The detached perturbed input, clamped to ``[lo, hi]``.
        """
        delta = torch.zeros_like(x_orig, device=x_orig.device)
        tgt   = torch.tensor([tgt_label], device=x_orig.device)

        for _ in range(self.max_iter):
            grad_est = torch.zeros_like(delta)
            for _sample in range(self.samples):
                # Random direction with entries in {-1, +1}.
                u = torch.randint(0,2,x_orig.shape,device=x_orig.device).float()*2 - 1
                x_p = torch.clamp(x_orig+delta+self.sigma*u, self.lo, self.hi)
                x_n = torch.clamp(x_orig+delta-self.sigma*u, self.lo, self.hi)
                lp, _acts_p = self.model(x_p)
                ln, _acts_n = self.model(x_n)
                l_p = F.cross_entropy(lp, tgt)
                l_n = F.cross_entropy(ln, tgt)
                diff = (l_p - l_n) if self.targeted else (l_n - l_p)
                # Two-sided finite difference along u.
                grad_est += diff * u / (2*self.sigma)
            grad_est /= self.samples
            delta = delta - self.lr * grad_est
            # Keep x_orig + delta inside [lo, hi].
            delta = torch.clamp(delta, self.lo - x_orig, self.hi - x_orig)

        return torch.clamp(x_orig + delta, self.lo, self.hi).detach()

In [6]:
# Per-channel statistics passed to transforms.Normalize below.
mean = (0.4914, 0.4822, 0.4465)
std  = (0.2023, 0.1994, 0.2010)

# NOTE(review): images are normalised with mean/std here, while quantize_input
# (default range) and both attacks clamp to [0, 1] — confirm the intended
# input range of the pipeline.
transform_test = transforms.Compose([
    transforms.Resize(size=224),      # match DenseNet input expectations
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# CIFAR-10 test split, fetched into ./data when it is not already there.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform_test,
    download=True,
)

# Fixed-order batches of 32 images.
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=32,
    num_workers=2,
    shuffle=False,
)

Files already downloaded and verified


In [7]:
# One network plays both roles: `surrogate` is the very same object as `model`,
# while `target` is an independent deep copy of it. No checkpoint is loaded in
# this cell. All of them are put in eval mode.
model     = DenseNet_Sparse().to(device).eval()
surrogate = model
target    = copy.deepcopy(model).eval()

criterion = nn.CrossEntropyLoss()
spsa      = SPSAAttack(model=target)



In [8]:
def _nonzero_fraction(acts):
    """Share of entries, across all tensors in ``acts``, that are not exactly zero."""
    nonzero = sum(a.ne(0).float().sum() for a in acts)
    n_elements = sum(a.numel() for a in acts)
    return nonzero / n_elements


# Only the first batch of the loader is processed (see the `break` at the end).
for imgs, _ in testloader:
    imgs = imgs.to(device)
    imgs_q = quantize_input(imgs)

    # (1) Clean predictions and activations of the target on the quantised batch
    with torch.no_grad():
        logits_c, acts_c = target(imgs_q)
    preds_c = logits_c.argmax(1)

    # (2) White‑box stage on the surrogate, starting from the quantised batch
    imgs_s1 = generate_sparsity_adversary(
        surrogate, imgs_q, preds_c, criterion,
        epsilon=0.5, alpha=0.05, num_iter=120, c=0.0, beta=50.0,
    )

    # (3) For every sample whose target prediction changed, run SPSA aimed at
    #     the clean prediction and keep its result in place of the stage-1 image
    imgs_adv = imgs_s1.clone()
    for i in range(imgs.size(0)):
        with torch.no_grad():
            pi = target(imgs_s1[i:i+1])[0].argmax(1)
        if pi.item() != preds_c[i].item():
            imgs_adv[i:i+1] = spsa.attack(imgs_s1[i:i+1], preds_c[i].item())

    # (4) Compare activation statistics of the clean and adversarial batches.
    # NOTE(review): the figure printed as "Sparsity" is the fraction of
    # non-zero activations — confirm the label matches the intended metric.
    with torch.no_grad():
        _, acts_a = target(quantize_input(imgs_adv))
    sp_c = _nonzero_fraction(acts_c)
    sp_a = _nonzero_fraction(acts_a)

    print(f"Sparsity clean: {sp_c:.4f}, adversarial: {sp_a:.4f}")
    break

WB iter 30/120, SP loss: 0.2428
WB iter 60/120, SP loss: 0.2402
WB iter 90/120, SP loss: 0.2394
WB iter 120/120, SP loss: 0.2390
Sparsity clean: 0.5109, adversarial: 0.7683
