In [35]:
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# Seed torch's RNGs (CPU and every CUDA device) so model initialisation and the
# random SPSA probe directions are reproducible under Restart & Run All.
# (cuDNN kernels may still introduce small nondeterministic differences.)
SEED = 0
torch.manual_seed(SEED)


Using device: cuda


In [36]:
class ThresholdReLU(nn.Module):
    """ReLU variant that zeroes every element not strictly above ``threshold``.

    Elements ``> threshold`` pass through unchanged. Everything else becomes 0,
    including small positive values and NaN (NaN fails the ``>`` test). With the
    default threshold of 0.001, tiny positive activations are therefore counted
    as exact zeros. Gradients flow only through the elements that are kept.
    """

    def __init__(self, threshold=0.001):
        super().__init__()
        self.threshold = threshold

    def forward(self, x):
        # Mask of positions that do NOT clear the threshold; those are zeroed.
        dropped = ~(x > self.threshold)
        return x.masked_fill(dropped, 0.0)

def quantize_input(x, levels=16, lo=-2.5, hi=2.5):
    """Snap ``x`` onto ``levels`` evenly spaced values spanning ``[lo, hi]``.

    Values are first clamped to ``[lo, hi]``, mapped to ``[0, 1]``, rounded to
    the nearest of ``levels`` grid points (``torch.round`` rounds halves to
    even), and mapped back. The grid spacing is ``(hi - lo) / (levels - 1)``.
    The rounding step has zero gradient almost everywhere, so gradients do not
    usefully flow through this function.

    Args:
        x: input tensor (any shape).
        levels: number of quantization levels; must be >= 2.
        lo, hi: clamp / grid bounds; must satisfy ``lo < hi``.

    Returns:
        Tensor of the same shape as ``x`` whose values lie on the grid.

    Raises:
        ValueError: if ``levels < 2`` or ``hi <= lo``. Either case would
            otherwise divide by zero and silently yield NaN.
    """
    if levels < 2:
        raise ValueError(f"levels must be >= 2, got {levels}")
    if not hi > lo:
        raise ValueError(f"hi must be greater than lo, got lo={lo}, hi={hi}")
    steps = levels - 1
    span = hi - lo
    x = x.clamp(lo, hi)
    x_norm = (x - lo) / span
    x_q = torch.round(x_norm * steps) / steps
    return x_q * span + lo

def sparsity_loss_modified(activations, beta=20.0):
    """Smooth surrogate for the fraction of (near-)zero activations.

    Each element contributes ``1 - tanh(beta * |a|)``. An exact zero contributes
    1, and the contribution falls towards 0 as ``|a|`` grows; a larger ``beta``
    makes the fall sharper. The sum over all tensors is divided by their total
    element count, giving a value in ``[0, 1]`` for ``beta >= 0``. Minimising it
    pushes activations away from zero, i.e. towards denser activation maps.

    Args:
        activations: sequence of tensors. It is iterated twice, so pass a
            list/tuple rather than a one-shot iterator.
        beta: sharpness of the tanh soft-step.

    Returns:
        Scalar tensor. An empty ``activations`` raises ``ZeroDivisionError``.
    """
    n_elements = sum(t.numel() for t in activations)
    per_tensor = [(1.0 - torch.tanh(beta * t.abs())).sum() for t in activations]
    return sum(per_tensor, 0.0) / n_elements

In [37]:
class Bottleneck(nn.Module):
    """ResNet bottleneck block built from ThresholdReLU activations.

    Main path: 1x1 conv -> BN -> ThresholdReLU -> 3x3 conv (``stride``) -> BN ->
    ThresholdReLU -> 1x1 conv to ``planes * expansion`` channels -> BN.
    A shortcut (identity, or 1x1 strided conv + BN when the shape changes) is
    added, then a final ThresholdReLU is applied.

    ``forward`` returns ``(out, y1)``. ``y1`` is the output of the block's FIRST
    activation only; callers that measure sparsity see this map, not ``y2`` or
    ``out``.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = planes * Bottleneck.expansion
        self.conv1 = nn.Conv2d(in_planes, planes, 1, bias=False)
        self.bn1   = nn.BatchNorm2d(planes)
        self.relu1 = ThresholdReLU()
        self.conv2 = nn.Conv2d(planes, planes, 3, stride, 1, bias=False)
        self.bn2   = nn.BatchNorm2d(planes)
        self.relu2 = ThresholdReLU()
        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)
        self.bn3   = nn.BatchNorm2d(out_planes)
        # Dedicated module for the post-residual activation. Reusing relu2 here
        # would make forward hooks on relu2 fire twice per pass and mix two
        # different activation maps. ThresholdReLU has no parameters or
        # buffers, so state_dict keys are unchanged.
        self.relu3 = ThresholdReLU()
        self.short = nn.Sequential()
        if stride != 1 or in_planes != out_planes:
            # Projection shortcut so the residual matches the main path's shape.
            self.short = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, 1, stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        y1 = self.relu1(self.bn1(self.conv1(x)))
        y2 = self.relu2(self.bn2(self.conv2(y1)))
        y3 = self.bn3(self.conv3(y2)) + self.short(x)
        out = self.relu3(y3)
        return out, y1

class ResNet164(nn.Module):
    """CIFAR-style bottleneck ResNet (164 layers with the default 18/18/18 blocks).

    Stem: 3x3 conv (3 -> 16 channels) -> BN -> ThresholdReLU. Three stages of
    Bottleneck blocks follow, with widths 16/32/64 (x4 expansion) and strides
    1/2/2. Then global average pooling and a linear classifier.

    ``forward`` returns ``(logits, acts)``. ``acts`` holds one tensor per
    Bottleneck block, in order: that block's first-activation map.
    """

    def __init__(self, num_blocks=(18, 18, 18), num_classes=10):
        super().__init__()
        self.in_planes = 16  # updated by _make_layer as stages are built
        self.conv1 = nn.Conv2d(3, 16, 3, 1, 1, bias=False)
        self.bn1   = nn.BatchNorm2d(16)
        self.relu1 = ThresholdReLU()
        self.layer1 = self._make_layer(16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64 * Bottleneck.expansion, num_classes)

    def _make_layer(self, planes, blocks, stride):
        """Build one stage of ``blocks`` Bottlenecks; only the first uses ``stride``."""
        if blocks < 1:
            # Without this check, [stride] + [1] * (blocks - 1) would silently
            # build one block.
            raise ValueError(f"each stage needs at least one block, got {blocks}")
        strides = [stride] + [1] * (blocks - 1)
        layers = []
        for s in strides:
            layers.append(Bottleneck(self.in_planes, planes, s))
            self.in_planes = planes * Bottleneck.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        acts = []
        # Iterate blocks manually because each Bottleneck returns (out, act),
        # which nn.Sequential cannot chain.
        for stage in (self.layer1, self.layer2, self.layer3):
            for blk in stage:
                x, a = blk(x)
                acts.append(a)
        # Global average pool; works for non-square feature maps too.
        x = F.adaptive_avg_pool2d(x, 1)
        x = torch.flatten(x, 1)
        return self.linear(x), acts


In [38]:
def generate_sparsity_adversary(model, x_clean, y_clean, criterion,
                                epsilon=0.3, alpha=0.01, num_iter=75,
                                c=5.0, beta=20.0, lo=-2.5, hi=2.5,
                                log_every=15):
    """White-box L-inf attack that lowers activation sparsity while keeping predictions.

    Runs ``num_iter`` signed-gradient DESCENT steps on
    ``sparsity_loss_modified(acts, beta) + c * criterion(logits, y_clean)``.
    After every step the iterate is projected into the L-inf ball of radius
    ``epsilon`` around ``x_clean`` and clamped to ``[lo, hi]``. The sparsity
    term is ~1 for (near-)zero activations, so descending it pushes activations
    away from zero. The CE term keeps predictions at ``y_clean``.

    Args:
        model: returns ``(logits, acts)``. It is switched to eval mode as a side
            effect.
        x_clean: starting inputs; moved to the model's device.
        y_clean: labels to preserve; moved to the model's device.
        criterion: classification loss, e.g. ``nn.CrossEntropyLoss()``.
        epsilon, alpha: L-inf radius and step size.
        num_iter: number of steps (an integral float is also accepted).
        c: weight of the CE term relative to the sparsity term.
        beta: sharpness passed to ``sparsity_loss_modified``.
        lo, hi: valid input range.
        log_every: print losses every this many iterations; 0/None disables.

    Returns:
        Detached adversarial batch on the model's device.
    """
    model.eval()
    param = next(model.parameters(), None)
    dev = param.device if param is not None else x_clean.device
    x_clean = x_clean.detach().to(dev)
    y_clean = y_clean.to(dev)
    n_iter = int(num_iter)  # range() rejects floats such as 75.0

    x_adv = x_clean.clone()
    for i in range(n_iter):
        x_adv.requires_grad_(True)
        # enable_grad so the attack also works when called under torch.no_grad().
        with torch.enable_grad():
            logits, acts = model(x_adv)
            l_ce = criterion(logits, y_clean)
            l_sp = sparsity_loss_modified(acts, beta=beta)
            loss = l_sp + c * l_ce
            # Gradient w.r.t. the input only. Unlike loss.backward(), this does
            # not populate (and leave behind) parameter .grad tensors.
            (grad,) = torch.autograd.grad(loss, x_adv)

        with torch.no_grad():
            x_adv = x_adv - alpha * grad.sign()
            x_adv = torch.max(torch.min(x_adv, x_clean + epsilon),
                              x_clean - epsilon).clamp(lo, hi)

        if log_every and (i + 1) % log_every == 0:
            print(f"Iter {i+1}/{n_iter}, CE: {l_ce.item():.4f}, SP: {l_sp.item():.4f}")
    return x_adv.detach()

In [39]:
class SPSAAttack:
    """Score-based black-box attack driven by SPSA gradient estimates.

    Each iteration draws ``samples`` Rademacher directions ``u`` (entries +/-1).
    It queries the model at ``x + delta +/- sigma * u`` and estimates the
    cross-entropy gradient as ``(L+ - L-) * u / (2 * sigma)``.

    - targeted=True: ``delta`` moves to DECREASE CE towards ``tgt_label``.
    - targeted=False: ``delta`` moves to INCREASE CE for ``tgt_label``.

    ``delta`` is kept inside ``[lo, hi] - x``. If ``epsilon`` is given, it is
    also kept inside an L-inf ball of that radius. The model must return
    ``(logits, ...)``. ``attack`` builds a single-element label tensor, so it
    expects a batch of one.
    """

    def __init__(self, model, bounds=(-2.5, 2.5), sigma=1e-3, lr=1e-2,
                 max_iter=200, targeted=True, samples=1, epsilon=None):
        if samples < 1:
            raise ValueError(f"samples must be >= 1, got {samples}")
        self.model     = model
        self.lo, self.hi = bounds
        self.sigma     = sigma
        self.lr        = lr
        self.max_iter  = max_iter
        self.targeted  = targeted
        self.samples   = samples
        self.epsilon   = epsilon  # optional L-inf budget; None = data bounds only

    # Only model outputs (scores) are used, so no autograd graph is needed.
    # Without no_grad, diff -> grad_est -> delta would carry grad history, and
    # every iteration's forward graph would stay chained to the next, so memory
    # grows for the whole attack.
    @torch.no_grad()
    def attack(self, x_orig, tgt_label):
        """Return a detached adversarial version of ``x_orig`` (batch of one)."""
        dev = x_orig.device
        x_orig = x_orig.detach()
        delta = torch.zeros_like(x_orig)
        tgt = torch.tensor([tgt_label], device=dev)
        for _ in range(self.max_iter):
            grad_est = torch.zeros_like(delta)
            for _ in range(self.samples):
                u = torch.randint(0, 2, x_orig.shape, device=dev).float() * 2 - 1
                x_p = torch.clamp(x_orig + delta + self.sigma * u, self.lo, self.hi)
                x_n = torch.clamp(x_orig + delta - self.sigma * u, self.lo, self.hi)
                loss_p = F.cross_entropy(self.model(x_p)[0], tgt)
                loss_n = F.cross_entropy(self.model(x_n)[0], tgt)
                diff = (loss_p - loss_n) if self.targeted else (loss_n - loss_p)
                grad_est += diff * u / (2 * self.sigma)
            grad_est /= self.samples
            delta = delta - self.lr * grad_est
            if self.epsilon is not None:
                delta = delta.clamp(-self.epsilon, self.epsilon)
            # Keep x_orig + delta inside the valid input range.
            delta = torch.clamp(delta, self.lo - x_orig, self.hi - x_orig)
        return torch.clamp(x_orig + delta, self.lo, self.hi).detach()

In [40]:
# CIFAR-10 test-set pipeline: ToTensor scales pixels to [0, 1], then each
# channel is standardised with the statistics below.
# NOTE(review): these std values are the widely circulated ones. The per-channel
# std of the CIFAR-10 training set is more often computed as about
# (0.2470, 0.2435, 0.2616). Keep whatever the model was trained with.
# With these constants the normalised range per channel is roughly
# [-2.43, 2.51], [-2.42, 2.60], [-2.22, 2.75], so a [-2.5, 2.5] clamp cuts off
# the brightest pixels.
mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)
transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean, std)]
)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)
# shuffle=False keeps batch order deterministic across runs.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2
)

Files already downloaded and verified


In [41]:
# Path to trained ResNet-164 weights. With None the networks keep their random
# initialisation, and every downstream number (predictions, CE, activation
# density) describes an UNTRAINED model.
CHECKPOINT_PATH = None

model = ResNet164().to(device)
if CHECKPOINT_PATH is not None:
    # weights_only=True (PyTorch >= 1.13) refuses to unpickle arbitrary objects.
    model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
    )
else:
    print("WARNING: CHECKPOINT_PATH is None - evaluating a randomly initialised ResNet164.")

# The white-box surrogate and the black-box target start from IDENTICAL weights.
# The target is an independent deep copy, so changes to one never reach the other.
surrogate = model
target    = copy.deepcopy(model)
surrogate.eval()
target.eval()

criterion = nn.CrossEntropyLoss()
spsa = SPSAAttack(
    model=target,
    bounds=(-2.5, 2.5),
    sigma=1e-3,
    lr=1e-2,
    max_iter=200,
    targeted=True,
    samples=1,
)


In [42]:
def activation_density(acts):
    """Fraction of non-zero entries across a list of activation tensors (1 - sparsity)."""
    nonzero = sum((a != 0).float().sum() for a in acts)
    total = sum(a.numel() for a in acts)
    return nonzero / total

# Evaluate on the first test batch only.
imgs, _ = next(iter(testloader))
imgs = imgs.to(device)
imgs_q = quantize_input(imgs)

# (1) Clean predictions of the target on the quantized input
with torch.no_grad():
    logits_c, acts_c = target(imgs_q)
    preds_c = logits_c.argmax(1)

# (2) Stage-1: white-box sparsity attack on the surrogate, preserving preds_c
imgs_s1 = generate_sparsity_adversary(
    surrogate, imgs_q, preds_c, criterion,
    epsilon=0.3, alpha=0.01, num_iter=75, c=5.0,
)

# (3) Stage-2: repair samples whose target prediction changed, via black-box SPSA
# NOTE(review): this check and the SPSA queries feed the target UN-quantized
# inputs, while (1) and (4) pass inputs through quantize_input first. A
# "repaired" sample may therefore still be mispredicted after quantization.
# Confirm this is intended.
with torch.no_grad():
    preds_s1 = target(imgs_s1)[0].argmax(1)  # one batched pass (eval-mode BN is per-sample)
mispredicted = (preds_s1 != preds_c).nonzero(as_tuple=True)[0].tolist()
imgs_adv = imgs_s1.clone()
for i in mispredicted:
    imgs_adv[i:i+1] = spsa.attack(imgs_s1[i:i+1], preds_c[i].item())

# (4) Measure activation density (fraction of non-zeros; higher = less sparse)
with torch.no_grad():
    _, acts_a = target(quantize_input(imgs_adv))
sp_c = activation_density(acts_c)
sp_a = activation_density(acts_a)

print(f"SPSA-repaired samples: {len(mispredicted)}/{imgs.size(0)}")
print(f"Activation density (non-zero fraction) clean: {sp_c:.4f}, adversarial: {sp_a:.4f}")

Iter 15/75, CE: 2.2420, SP: 0.6841
Iter 30/75, CE: 2.2361, SP: 0.6735
Iter 45/75, CE: 2.2349, SP: 0.6722
Iter 60/75, CE: 2.2346, SP: 0.6721
Iter 75/75, CE: 2.2344, SP: 0.6722
Sparsity clean: 0.4758, adversarial: 0.4792
