In [13]:
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# ─── Device ───
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

Using device: cuda


In [14]:
class ThresholdReLU(nn.Module):
    def __init__(self, threshold=0.001):
        super().__init__()
        self.threshold = threshold
    def forward(self, x):
        return torch.where(x > self.threshold, x, torch.zeros_like(x))

def quantize_input(x, levels=16, lo=0.0, hi=1.0):
    x = x.clamp(lo, hi)
    x = (x - lo) / (hi - lo)
    x = torch.round(x * (levels - 1)) / (levels - 1)
    return x * (hi - lo) + lo

def sparsity_loss_modified(activations, beta=20.0):
    total = sum(a.numel() for a in activations)
    loss = 0.0
    for a in activations:
        loss += torch.sum(1.0 - torch.tanh(beta * a.abs()))
    return loss / total

class BasicBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_ch)
        self.act   = ThresholdReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_ch)
        self.short = nn.Sequential()
        if stride!=1 or in_ch!=out_ch:
            self.short = nn.Sequential(
                nn.Conv2d(in_ch,out_ch,1,stride,bias=False),
                nn.BatchNorm2d(out_ch)
            )
    def forward(self, x):
        y1 = self.act(self.bn1(self.conv1(x)))
        y2 = self.bn2(self.conv2(y1)) + self.short(x)
        out = self.act(y2)
        return out, y1

class ResNet56(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.in_ch = 16
        self.conv1 = nn.Conv2d(3, 16, 3, 1, 1, bias=False)
        self.bn1   = nn.BatchNorm2d(16)
        self.act1  = ThresholdReLU()
        self.layer1 = self._make_layer(16, 9, stride=1)
        self.layer2 = self._make_layer(32, 9, stride=2)
        self.layer3 = self._make_layer(64, 9, stride=2)
        self.avg   = nn.AdaptiveAvgPool2d((1,1))
        self.fc    = nn.Linear(64, num_classes)

    def _make_layer(self, out_ch, blocks, stride):
        strides = [stride] + [1]*(blocks-1)
        layers = []
        for s in strides:
            layers.append(BasicBlock(self.in_ch, out_ch, s))
            self.in_ch = out_ch
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.act1(self.bn1(self.conv1(x)))
        acts = []
        for blk in self.layer1:
            x, a = blk(x); acts.append(a)
        for blk in self.layer2:
            x, a = blk(x); acts.append(a)
        for blk in self.layer3:
            x, a = blk(x); acts.append(a)
        x = self.avg(x)
        x = torch.flatten(x, 1)
        return self.fc(x), acts


In [15]:
def generate_sparsity_adversary(model, x_clean, y_clean, criterion,
                                epsilon=0.2, alpha=0.01, num_iter=50,
                                c=0.0, beta=20.0):
    model.eval()
    x_adv = x_clean.detach().clone().to(device)
    x_adv.requires_grad = True
    opt = optim.SGD([x_adv], lr=alpha, momentum=0.9)

    for _ in range(num_iter):
        logits, acts = model(x_adv)
        # use modified sparsity loss:
        loss_sp = sparsity_loss_modified(acts, beta=beta)
        loss_ce = criterion(logits, y_clean)
        loss    = loss_sp + c*loss_ce

        opt.zero_grad()
        loss.backward()

        # ascend on sparsity (c=0) or descend on CE if c>0
        x_adv.data += alpha * x_adv.grad.sign()
        # project back to L∞ ball
        delta = torch.clamp(x_adv - x_clean, -epsilon, epsilon)
        x_adv.data = torch.clamp(x_clean + delta, 0.0, 1.0)

    return x_adv.detach()

In [16]:
class SPSAAttack:
    def __init__(self, model, bounds=(0,1), sigma=1e-3, lr=1e-2,
                 max_iter=200, targeted=True, samples=1):
        self.model     = model
        self.lo, self.hi = bounds
        self.sigma     = sigma
        self.lr        = lr
        self.max_iter  = max_iter
        self.targeted  = targeted
        self.samples   = samples

    def attack(self, x_orig, tgt_label):
        device = x_orig.device
        delta  = torch.zeros_like(x_orig, device=device)
        tgt    = torch.tensor([tgt_label], device=device)

        for _ in range(self.max_iter):
            grad_est = torch.zeros_like(delta)
            for _ in range(self.samples):
                u = torch.randint(0,2, x_orig.shape, device=device).float()*2 - 1
                # two probe points
                x_p = torch.clamp(x_orig + delta + self.sigma*u, self.lo, self.hi)
                x_n = torch.clamp(x_orig + delta - self.sigma*u, self.lo, self.hi)

                lp, _ = self.model(x_p)
                ln, _ = self.model(x_n)
                l_p = F.cross_entropy(lp, tgt)
                l_n = F.cross_entropy(ln, tgt)

                diff = (l_p - l_n) if self.targeted else (l_n - l_p)
                grad_est += diff.view_as(delta) * u / (2*self.sigma)

            grad_est /= self.samples
            delta = delta - self.lr * grad_est
            # keep within bounds
            delta = torch.clamp(delta, self.lo - x_orig, self.hi - x_orig)

        return torch.clamp(x_orig + delta, self.lo, self.hi).detach()

In [17]:
mean = (0.4914, 0.4822, 0.4465)
std  = (0.2023, 0.1994, 0.2010)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

trainset = torchvision.datasets.CIFAR10(root='data', train=True,  download=True, transform=transform_train)
testset  = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True,  num_workers=2)
testloader  = torch.utils.data.DataLoader(testset,  batch_size=100, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [18]:
model     = ResNet56().to(device)
surrogate = model
target    = copy.deepcopy(model)
surrogate.eval()
target.eval()

criterion = nn.CrossEntropyLoss()

spsa = SPSAAttack(
    model=target,
    bounds=(0.0,1.0),
    sigma=1e-3,
    lr=1e-2,
    max_iter=150,
    targeted=True,
    samples=1
)

In [19]:
for imgs, _ in testloader:
    imgs = imgs.to(device)

    # (1) Clean predictions on target
    with torch.no_grad():
        logits_c, acts_c = target(quantize_input(imgs))
        preds_c = logits_c.argmax(1)

    # (2) Stage‑1: White‑box sparsity on surrogate
    imgs_q = quantize_input(imgs)
    imgs_s1 = generate_sparsity_adversary(
                  surrogate, imgs_q, preds_c, criterion,
                  epsilon=0.3, alpha=0.02, num_iter=60, c=0.0, beta=20.0
              )

    # (3) Stage‑2: Repair any flipped preds via SPSA
    imgs_adv = imgs_s1.clone()
    for i in range(imgs.size(0)):
        with torch.no_grad():
            pred_i = target(imgs_s1[i:i+1])[0].argmax(1)
        if pred_i != preds_c[i]:
            imgs_adv[i:i+1] = spsa.attack(imgs_s1[i:i+1], preds_c[i].item())

    # (4) Measure sparsity change
    with torch.no_grad():
        _, acts_a = target(quantize_input(imgs_adv))

    sp_c = sum((a!=0).float().sum() for a in acts_c) / sum(a.numel() for a in acts_c)
    sp_a = sum((a!=0).float().sum() for a in acts_a) / sum(a.numel() for a in acts_a)
    print(f"Sparsity clean: {sp_c:.4f}, adversarial: {sp_a:.4f}")
    break

Sparsity clean: 0.4564, adversarial: 0.4304
