In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

Using device: cuda


In [2]:
# --- Data Transforms ---
mean = (0.4914, 0.4822, 0.4465)
std  = (0.2023, 0.1994, 0.2010)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [3]:
# --- Defense Modules ---
class ThresholdReLU(nn.Module):
    def __init__(self, threshold=0.001):
        super().__init__()
        self.threshold = threshold
    def forward(self, x):
        return torch.where(x > self.threshold, x, torch.zeros_like(x))

def quantize_input(x, levels=16, min_val=-2.5, max_val=2.5):
    x = torch.clamp(x, min_val, max_val)
    x_norm = (x - min_val) / (max_val - min_val)
    x_quant = torch.round(x_norm * (levels - 1)) / (levels - 1)
    return x_quant * (max_val - min_val) + min_val

In [4]:
# --- Sparsity Losses ---
def sparsity_loss_modified(activations, beta=20.0):
    total = sum(act.numel() for act in activations)
    loss = 0.0
    for act in activations:
        # Encourage nonzero via tanh surrogate
        loss += torch.sum(1.0 - torch.tanh(beta * torch.abs(act)))
    return loss / total


In [5]:
# --- Model Definition ---
class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = ThresholdReLU()
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = ThresholdReLU()
        self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes * Bottleneck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes * Bottleneck.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * Bottleneck.expansion)
            )
    def forward(self, x):
        out1 = self.relu1(self.bn1(self.conv1(x)))  # record this activation
        out2 = self.relu2(self.bn2(self.conv2(out1)))
        out3 = self.bn3(self.conv3(out2))
        out3 += self.shortcut(x)
        out = self.relu2(out3)
        return out, out1

In [6]:
class ResNet164(nn.Module):
    def __init__(self, block=Bottleneck, num_blocks=[18,18,18], num_classes=10):
        super().__init__()
        self.in_planes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = ThresholdReLU()
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64 * block.expansion, num_classes)
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        activations = []
        for blk in self.layer1:
            out, act = blk(out); activations.append(act)
        for blk in self.layer2:
            out, act = blk(out); activations.append(act)
        for blk in self.layer3:
            out, act = blk(out); activations.append(act)
        out = F.avg_pool2d(out, out.size(3))
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out, activations


In [7]:
# --- Training/Eval ---
def train_one_epoch(model, loader, criterion, optimizer):
    model.train()
    total_loss, correct, total = 0,0,0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        imgs = quantize_input(imgs)
        out, acts = model(imgs)
        loss_main = criterion(out, labels)
        loss_sparse = sparsity_loss_modified(acts)
        loss = loss_main + loss_sparse
        optimizer.zero_grad(); loss.backward(); optimizer.step()
        total_loss += loss.item()*imgs.size(0)
        _, preds = out.max(1)
        correct += preds.eq(labels).sum().item(); total += labels.size(0)
    return total_loss/total, 100*correct/total

def evaluate(model, loader, criterion):
    model.eval()
    total_loss, correct, total = 0,0,0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            imgs = quantize_input(imgs)
            out, acts = model(imgs)
            loss = criterion(out, labels)
            loss += sparsity_loss_modified(acts)
            total_loss += loss.item()*imgs.size(0)
            _, preds = out.max(1)
            correct += preds.eq(labels).sum().item(); total += labels.size(0)
    return total_loss/total, 100*correct/total

In [8]:
# --- White-box Attack ---
def generate_sparsity_adversary(model, x_clean, y_clean, criterion, epsilon=0.3, alpha=0.01, num_iter=75, c=5.0):
    model.eval(); x_adv = x_clean.clone().detach().to(device); x_adv.requires_grad=True
    for i in range(num_iter):
        out, acts = model(x_adv)
        l_ce = criterion(out, y_clean)
        l_sp = sparsity_loss_modified(acts)
        loss = l_sp + c * l_ce
        model.zero_grad(); loss.backward()
        x_adv = x_adv - alpha * x_adv.grad.sign()
        x_adv = torch.max(torch.min(x_adv, x_clean+epsilon), x_clean-epsilon)
        x_adv = torch.clamp(x_adv, -2.5, 2.5)
        x_adv = x_adv.detach(); x_adv.requires_grad=True
        if (i+1)%15==0:
            print(f"Iter {i+1}/{num_iter}, CE: {l_ce.item():.4f}, SP: {l_sp.item():.4f}")
    return x_adv.detach()

In [9]:
# --- Setup and Example ---
model = ResNet164().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
# ... train or load ResNet164 weights here ...
model.eval()
for imgs, labels in testloader:
    imgs, labels = imgs.to(device), labels.to(device)
    imgs_q = quantize_input(imgs)
    out_c, acts_c = model(imgs_q); _, p_c = out_c.max(1)
    imgs_adv = generate_sparsity_adversary(model, imgs_q, p_c, criterion)
    out_a, acts_a = model(imgs_adv); _, p_a = out_a.max(1)
    print("Clean preds:", p_c[:10])
    print("Adv preds: ", p_a[:10])
    spar_c = sum((act!=0).float().sum() for act in acts_c)/sum(act.numel() for act in acts_c)
    spar_a = sum((act!=0).float().sum() for act in acts_a)/sum(act.numel() for act in acts_a)
    print(f"Sparsity clean: {spar_c:.4f}, adv: {spar_a:.4f}")
    break

Iter 15/75, CE: 2.2191, SP: 0.7486
Iter 30/75, CE: 2.2108, SP: 0.7299
Iter 45/75, CE: 2.2093, SP: 0.7276
Iter 60/75, CE: 2.2089, SP: 0.7273
Iter 75/75, CE: 2.2087, SP: 0.7272
Clean preds: tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
Adv preds:  tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
Sparsity clean: 0.4722, adv: 0.4768
