In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

Using device: cuda


In [25]:
# --- Data Transforms ---
mean = (0.4914, 0.4822, 0.4465)
std  = (0.2023, 0.1994, 0.2010)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [26]:
# --- Defense Modules ---
class ThresholdReLU(nn.Module):
    def __init__(self, threshold=0.001):
        super().__init__()
        self.threshold = threshold
    def forward(self, x):
        return torch.where(x > self.threshold, x, torch.zeros_like(x))

def quantize_input(x, levels=16, min_val=-2.5, max_val=2.5):
    x = torch.clamp(x, min_val, max_val)
    x_norm = (x - min_val) / (max_val - min_val)
    x_quant = torch.round(x_norm * (levels - 1)) / (levels - 1)
    return x_quant * (max_val - min_val) + min_val

# Normalized sparsity loss (L1 norm / num elements)
def sparsity_loss(activations, weight=1e-6):
    total_elems = sum(act.numel() for act in activations)
    l1 = sum(torch.norm(act, 1) for act in activations)
    return weight * (l1 / total_elems)

# Modified sparsity loss (tanh-approx to encourage nonzero activations)
def sparsity_loss_modified(activations, beta=10.0):
    total_elems = sum(act.numel() for act in activations)
    loss = 0.0
    for act in activations:
        # Approximate nonzero indicator: tanh(beta * |act|)
        loss += torch.sum(1.0 - torch.tanh(beta * torch.abs(act)))
    return loss / total_elems

In [27]:
# --- Model Definition ---
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = ThresholdReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    def forward(self, x):
        out1 = self.relu(self.bn1(self.conv1(x)))
        out2 = self.bn2(self.conv2(out1))
        out2 += self.shortcut(x)
        out = self.relu(out2)
        return out, out1


In [28]:
class ResNet56(nn.Module):
    def __init__(self, block=BasicBlock, num_classes=10):
        super().__init__()
        self.in_channels = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = ThresholdReLU()
        self.layer1 = self._make_layer(block, 16, blocks=9, stride=1)
        self.layer2 = self._make_layer(block, 32, blocks=9, stride=2)
        self.layer3 = self._make_layer(block, 64, blocks=9, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)
    def _make_layer(self, block, out_channels, blocks, stride):
        strides = [stride] + [1]*(blocks-1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels
        return nn.Sequential(*layers)
    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        activations = []
        for block in self.layer1:
            out, act = block(out); activations.append(act)
        for block in self.layer2:
            out, act = block(out); activations.append(act)
        for block in self.layer3:
            out, act = block(out); activations.append(act)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out, activations


In [29]:
# --- Training and Evaluation ---
def train_one_epoch(model, loader, criterion, optimizer):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        imgs = quantize_input(imgs)
        outputs, activations = model(imgs)
        loss_main = criterion(outputs, labels)
        loss_sparse = sparsity_loss(activations)
        loss = loss_main + loss_sparse
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()*imgs.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)
    return total_loss/total, 100*correct/total

In [30]:
def evaluate(model, loader, criterion):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            imgs = quantize_input(imgs)
            outputs, activations = model(imgs)
            loss_main = criterion(outputs, labels)
            loss_sparse = sparsity_loss(activations)
            loss = loss_main + loss_sparse
            total_loss += loss.item()*imgs.size(0)
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)
    return total_loss/total, 100*correct/total

In [31]:
# --- White-box Sparsity Attack ---
def generate_sparsity_adversary(model, x_clean, y_clean, criterion, epsilon=0.2, alpha=0.01, num_iter=50, c=1.0):
    model.eval()
    x_adv = x_clean.clone().detach().to(device); x_adv.requires_grad=True
    for i in range(num_iter):
        # Skip quantization here for clearer gradients
        outputs, activations = model(x_adv)
        loss_ce = criterion(outputs, y_clean)
        loss_sp = sparsity_loss_modified(activations)
        loss = loss_sp + c * loss_ce
        model.zero_grad(); loss.backward()
        grad = x_adv.grad.data
        x_adv = x_adv - alpha * grad.sign()
        x_adv = torch.max(torch.min(x_adv, x_clean+epsilon), x_clean-epsilon)
        x_adv = torch.clamp(x_adv, -2.5, 2.5)
        x_adv = x_adv.detach(); x_adv.requires_grad=True
        if (i+1)%10==0:
            print(f"Iter {(i+1)}/{num_iter}, Loss: {loss.item():.4f}, CE: {loss_ce.item():.4f}, SP: {loss_sp.item():.4f}")
    return x_adv.detach()

In [32]:
# --- Instantiate, Train/Load, and Test Attack ---
model = ResNet56().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

In [33]:
# Example: Test on one batch
model.eval()
for imgs, labels in testloader:
    imgs, labels = imgs.to(device), labels.to(device)
    imgs_q = quantize_input(imgs)
    out_c, acts_c = model(imgs_q); _, p_c = out_c.max(1)
    imgs_adv = generate_sparsity_adversary(model, imgs_q, p_c, criterion)
    out_a, acts_a = model(imgs_adv); _, p_a = out_a.max(1)
    print("Clean preds:", p_c[:10])
    print("Adv preds: ", p_a[:10])
    spar_c = sum((act!=0).float().sum() for act in acts_c)/sum(act.numel() for act in acts_c)
    spar_a = sum((act!=0).float().sum() for act in acts_a)/sum(act.numel() for act in acts_a)
    print(f"Sparsity clean: {spar_c:.4f}, adv: {spar_a:.4f}")
    break

Iter 10/50, Loss: 2.8668, CE: 2.1939, SP: 0.6729
Iter 20/50, Loss: 2.8424, CE: 2.1818, SP: 0.6607
Iter 30/50, Loss: 2.8385, CE: 2.1794, SP: 0.6590
Iter 40/50, Loss: 2.8377, CE: 2.1788, SP: 0.6589
Iter 50/50, Loss: 2.8373, CE: 2.1785, SP: 0.6588
Clean preds: tensor([0, 0, 0, 0, 5, 0, 0, 0, 0, 0], device='cuda:0')
Adv preds:  tensor([0, 0, 0, 0, 5, 0, 0, 0, 0, 0], device='cuda:0')
Sparsity clean: 0.5080, adv: 0.5110
