In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from torchvision import models
from tqdm import tqdm

# Pick the compute device once for the whole notebook: the default CUDA device
# when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device:", device)

Using device: cuda


In [12]:
# --- Data Transforms ---
# Per-channel mean/std fed to Normalize in both the train and test pipelines.
# NOTE(review): this std triple is the widely copied (0.2023, 0.1994, 0.2010); the
# per-pixel std of the CIFAR-10 training set is usually reported as roughly
# (0.247, 0.243, 0.261) — confirm which one is intended. The backbone used later is
# ImageNet-pretrained but is normalised with CIFAR statistics; confirm that is deliberate.
mean = (0.4914, 0.4822, 0.4465)
std  = (0.2023, 0.1994, 0.2010)

CIFAR_ROOT = './data'


def _tensor_and_normalize():
    """Return fresh ``ToTensor`` + ``Normalize(mean, std)`` steps shared by both pipelines."""
    return [transforms.ToTensor(), transforms.Normalize(mean, std)]


# Training pipeline: upsample to AlexNet's 224x224 input, then augment.
# NOTE(review): RandomCrop(padding=4) runs *after* Resize(224), so the random shift is
# at most 4 px at 224 scale (under 1 px of a 32x32 CIFAR-10 image). If the usual
# "pad 4 and crop" CIFAR augmentation was intended, the padding would need to scale
# with the resize (e.g. 4 * 224 / 32 = 28).
transform_train = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.RandomCrop(224, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    + _tensor_and_normalize()
)
# Evaluation pipeline: deterministic upsample + normalisation only.
transform_test = transforms.Compose([transforms.Resize(224)] + _tensor_and_normalize())

trainset = torchvision.datasets.CIFAR10(root=CIFAR_ROOT, train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root=CIFAR_ROOT, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [13]:
# --- Defense Modules ---
class ThresholdReLU(nn.Module):
    """ReLU variant that zeroes every element not strictly greater than ``threshold``.

    Elements with ``x > threshold`` pass through unchanged; everything else
    (including values in ``(0, threshold]`` and NaNs) becomes exactly 0.

    Args:
        threshold: cut-off value, 0.001 by default.
    """

    def __init__(self, threshold=0.001):
        super().__init__()
        self.threshold = threshold

    def forward(self, x):
        # The "drop" mask is the complement of ``x > threshold`` (not ``x <= threshold``)
        # so NaNs, for which every comparison is False, are zeroed as well.
        drop = ~(x > self.threshold)
        return x.masked_fill(drop, 0.0)

def quantize_input(x, levels=16, min_val=-2.5, max_val=2.5):
    """Clamp ``x`` to ``[min_val, max_val]`` and snap it onto ``levels`` evenly spaced values.

    Input-quantisation defence: after clamping, the range is divided into
    ``levels - 1`` equal steps and every element is rounded to the nearest
    grid point (``torch.round``; exact halves round to even).
    Because ``torch.round`` has zero gradient almost everywhere, gradients do
    not flow back through this function to ``x``.

    NOTE(review): with this notebook's CIFAR normalisation, bright pixels map
    above +2.5 (e.g. (1 - 0.4465) / 0.2010 ≈ 2.75 in the blue channel), so the
    default upper bound clips them — confirm the range is intended.

    Args:
        x: floating-point tensor (assumed already normalised) of any shape.
        levels: number of quantisation levels; must be at least 2.
        min_val: lower clamp bound and lowest grid value.
        max_val: upper clamp bound and highest grid value; must exceed ``min_val``.

    Returns:
        Tensor with the same shape as ``x`` whose values all lie on the grid.

    Raises:
        ValueError: if ``levels < 2`` or ``max_val <= min_val``. Previously these
            divided by zero and silently produced NaN/inf.
    """
    if levels < 2:
        raise ValueError(f"quantize_input: levels must be >= 2, got {levels}")
    if not max_val > min_val:
        raise ValueError(
            f"quantize_input: max_val ({max_val}) must be greater than min_val ({min_val})"
        )
    span = max_val - min_val
    steps = levels - 1
    x = torch.clamp(x, min_val, max_val)
    x_norm = (x - min_val) / span                  # map to [0, 1]
    x_quant = torch.round(x_norm * steps) / steps  # snap to the nearest of `levels` points
    return x_quant * span + min_val                # map back to [min_val, max_val]

In [14]:
# --- Sparsity Loss ---
def sparsity_loss_modified(activations, beta=20.0):
    """Smooth, differentiable proxy for the share of activations that are (near) zero.

    Each element ``a`` contributes ``1 - tanh(beta * |a|)``. This is exactly 1 when
    ``a == 0`` and falls towards 0 as ``|a|`` grows; larger ``beta`` gives a sharper
    transition. The contributions are summed over every tensor and divided by the
    total element count, so for ``beta >= 0`` the result lies in ``[0, 1]``.

    The value is *high* when activations are zero. Both the training loss and the
    attack's descent step in this notebook minimise it, which pushes activations
    away from zero (towards denser activations).
    NOTE(review): confirm this is the intended direction for the training objective.

    Args:
        activations: iterable of tensors. It may be a generator; it is consumed once.
        beta: sharpness of the zero/non-zero transition.

    Returns:
        0-dim tensor holding the mean per-element term (differentiable w.r.t. the
        activations).

    Raises:
        ValueError: if the activations contain no elements at all. Previously this
            raised a bare ``ZeroDivisionError``.
    """
    # Materialise once: the count and the sum both iterate, and a generator would
    # otherwise be exhausted by the count, silently yielding a loss of 0.0.
    acts = list(activations)
    total = sum(a.numel() for a in acts)
    if total == 0:
        raise ValueError("sparsity_loss_modified: need at least one non-empty activation tensor")
    loss = 0.0
    for a in acts:
        loss = loss + torch.sum(1.0 - torch.tanh(beta * torch.abs(a)))
    return loss / total

In [15]:
# --- AlexNet Wrapper ---
class AlexNetSparse(nn.Module):
    """torchvision AlexNet with every ``nn.ReLU`` in ``features`` replaced by ``ThresholdReLU``.

    ``forward`` returns ``(logits, activations)``:
    - ``logits`` is the classifier output.
    - ``activations`` lists the tensors produced by each ``ThresholdReLU`` in
      ``self.features``, in order.

    The classifier's own ReLUs are left unchanged and are not reported.

    Args:
        num_classes: width of the replacement final linear layer (10 for CIFAR-10).
        threshold: cut-off handed to every ``ThresholdReLU`` (default 0.001, the
            value previously hard-wired).
        weights: backbone weights for ``models.alexnet``.
            - ``"DEFAULT"`` (the default) resolves to ``models.AlexNet_Weights.DEFAULT``,
              exactly as before.
            - Any other value is forwarded unchanged; e.g. ``None`` builds a randomly
              initialised backbone without fetching pretrained weights.
    """

    def __init__(self, num_classes=10, threshold=0.001, weights="DEFAULT"):
        super().__init__()
        if weights == "DEFAULT":
            weights = models.AlexNet_Weights.DEFAULT
        base = models.alexnet(weights=weights)
        # Replace the original head with a freshly (randomly) initialised
        # num_classes-way layer; in_features is taken from the layer being replaced.
        old_head = base.classifier[6]
        base.classifier[6] = nn.Linear(old_head.in_features, num_classes)
        # Swap ReLUs position-for-position so every other layer keeps its index.
        self.features = nn.Sequential(*[
            ThresholdReLU(threshold) if isinstance(layer, nn.ReLU) else layer
            for layer in base.features
        ])
        self.avgpool = base.avgpool
        self.classifier = base.classifier

    def forward(self, x):
        """Return ``(logits, activations)`` for input batch ``x``.

        ``activations`` holds the output of each ``ThresholdReLU`` in ``features``.
        These tensors are the ones passed on to the next layer, so they carry autograd
        history when gradients are enabled.
        """
        activations = []
        for layer in self.features:
            x = layer(x)
            if isinstance(layer, ThresholdReLU):
                activations.append(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x), activations

In [16]:
# --- Training/Eval Functions ---
def train_one_epoch(model, loader, criterion, optimizer, sparsity_weight=1.0):
    """Train ``model`` for one pass over ``loader``.

    For each batch:
    - move it to the module-level ``device``;
    - pass it through ``quantize_input``;
    - optimise the model (which must return ``(logits, activations)``) on
      ``criterion(logits, labels) + sparsity_weight * sparsity_loss_modified(activations)``.

    Args:
        model: network returning ``(logits, activations)``.
        loader: iterable of ``(images, labels)`` batches.
        criterion: classification loss applied to the logits.
        optimizer: optimiser over ``model``'s parameters; stepped once per batch.
        sparsity_weight: multiplier on the sparsity term. 1.0 reproduces the
            original unweighted sum.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen. The loss is the
        combined objective above, not CE alone.

    Raises:
        ValueError: if ``loader`` yields no samples. Previously this surfaced as a
            bare ``ZeroDivisionError``.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        imgs = quantize_input(imgs)
        outputs, acts = model(imgs)
        loss = criterion(outputs, labels) + sparsity_weight * sparsity_loss_modified(acts)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # loss.item() is a per-batch value (a mean, assuming a mean-reducing
        # criterion), so weight it by batch size for a per-sample epoch average.
        loss_sum += loss.item() * batch_size
        n_correct += (outputs.argmax(dim=1) == labels).sum().item()
        n_seen += batch_size
    if n_seen == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return loss_sum / n_seen, 100 * n_correct / n_seen


In [17]:
def evaluate(model, loader, criterion, sparsity_weight=1.0):
    """Score ``model`` on ``loader`` without updating it.

    The model is put in eval mode and run under ``torch.no_grad``. Each batch is
    moved to the module-level ``device`` and passed through ``quantize_input``
    before the forward pass.

    Args:
        model: network returning ``(logits, activations)``.
        loader: iterable of ``(images, labels)`` batches.
        criterion: classification loss applied to the logits.
        sparsity_weight: multiplier on ``sparsity_loss_modified(activations)`` in
            the reported loss. 1.0 reproduces the original unweighted sum.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples. The loss is
        ``CE + sparsity_weight * sparsity term``, not CE alone.

    Raises:
        ValueError: if ``loader`` yields no samples. Previously this surfaced as a
            bare ``ZeroDivisionError``.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            imgs = quantize_input(imgs)
            outputs, acts = model(imgs)
            loss = criterion(outputs, labels) + sparsity_weight * sparsity_loss_modified(acts)
            batch_size = labels.size(0)
            # Per-batch value weighted by batch size -> per-sample average overall.
            loss_sum += loss.item() * batch_size
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += batch_size
    if n_seen == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return loss_sum / n_seen, 100 * n_correct / n_seen

In [18]:
# --- White-box Sparsity Attack ---
def generate_sparsity_adversary(model, x_clean, y_clean, criterion,
                                epsilon=0.3, alpha=0.01, num_iter=75, c=5.0, beta=20.0,
                                clip_min=-2.5, clip_max=2.5, log_every=15):
    """Craft an L-inf-bounded input that minimises the sparsity proxy while keeping CE low.

    Runs ``num_iter`` steps of projected sign-gradient *descent* on
    ``sparsity_loss_modified(acts, beta) + c * criterion(logits, y_clean)``.
    After each step the input is:
    - projected back into the ``epsilon`` L-inf ball around ``x_clean``;
    - clamped to ``[clip_min, clip_max]``.
    Minimising the sparsity proxy drives activations away from zero, while the
    CE term keeps the predictions on ``y_clean``.

    The model is evaluated in eval mode; its previous train/eval mode is restored on
    exit. Gradients are taken w.r.t. the input only, so parameter ``.grad`` fields
    are neither cleared nor overwritten by the attack.

    Args:
        model: network returning ``(logits, activations)``.
        x_clean: starting inputs, in the same (normalised) space the model consumes.
            Moved to ``device``.
        y_clean: targets for the CE term (the test cell passes the model's own clean
            predictions). Moved to ``device``.
        criterion: classification loss.
        epsilon: L-inf radius of the perturbation.
        alpha: step size per iteration.
        num_iter: number of attack steps.
        c: weight of the CE term.
        beta: sharpness passed to ``sparsity_loss_modified``.
        clip_min, clip_max: box applied after every step (defaults match
            ``quantize_input``'s default range).
        log_every: print CE/SP every ``log_every`` iterations; a falsy value disables it.

    Returns:
        Detached adversarial tensor on ``device``.
    """
    was_training = model.training
    model.eval()
    # Keep everything on one device so the projection below cannot hit a device mismatch.
    x_clean = x_clean.detach().to(device)
    y_clean = y_clean.to(device)
    x_adv = x_clean.clone().requires_grad_(True)
    try:
        for i in range(num_iter):
            outputs, acts = model(x_adv)
            loss_ce = criterion(outputs, y_clean)
            loss_sp = sparsity_loss_modified(acts, beta=beta)
            loss = loss_sp + c * loss_ce
            # Input-only gradient: avoids accumulating into the model's parameter .grad.
            (grad,) = torch.autograd.grad(loss, x_adv)
            with torch.no_grad():
                x_adv = x_adv - alpha * grad.sign()
                x_adv = torch.max(torch.min(x_adv, x_clean + epsilon), x_clean - epsilon)
                x_adv = torch.clamp(x_adv, clip_min, clip_max)
            x_adv.requires_grad_(True)
            if log_every and (i + 1) % log_every == 0:
                print(f"Iter {i+1}/{num_iter}, CE: {loss_ce.item():.4f}, SP: {loss_sp.item():.4f}")
    finally:
        model.train(was_training)
    return x_adv.detach()

In [19]:
# --- Instantiate & Test ---
# Seed torch's global RNG so the randomly initialised classifier head (and later
# torch-RNG draws, e.g. dropout or loader shuffling) are reproducible across runs.
SEED = 0
torch.manual_seed(SEED)
model = AlexNetSparse().to(device)
criterion = nn.CrossEntropyLoss()
# NOTE(review): no visible cell calls train_one_epoch with this optimizer before the
# attack check below, so that check runs with an untrained classifier head — confirm
# a training cell is intended.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)


In [20]:
def _nonzero_fraction(acts):
    """Fraction of all entries in ``acts`` that are non-zero (i.e. activation *density*)."""
    nonzero = sum((a != 0).float().sum() for a in acts)
    return nonzero / sum(a.numel() for a in acts)


# Attack one test batch and compare predictions and activation density.
model.eval()
imgs, labels = next(iter(testloader))
imgs, labels = imgs.to(device), labels.to(device)
imgs_q = quantize_input(imgs)
# Plain forward passes need no autograd graph; only the attack itself needs gradients.
with torch.no_grad():
    out_c, acts_c = model(imgs_q)
p_c = out_c.argmax(1)
# The attack targets the model's own clean predictions, so success means the labels
# stay the same while activations become denser.
imgs_adv = generate_sparsity_adversary(model, imgs_q, p_c, criterion)
# NOTE(review): imgs_adv is fed to the model without re-applying quantize_input, i.e.
# the adversarial batch bypasses the quantisation defence — confirm this is intended.
with torch.no_grad():
    out_a, acts_a = model(imgs_adv)
p_a = out_a.argmax(1)
print("Clean preds:", p_c[:10])
print("Adv preds: ", p_a[:10])
# NOTE: despite the "Sparsity" label, these are non-zero fractions (density);
# a higher value means fewer zero activations.
spar_c = _nonzero_fraction(acts_c)
spar_a = _nonzero_fraction(acts_a)
print(f"Sparsity clean: {spar_c:.4f}, adv: {spar_a:.4f}")

Iter 15/75, CE: 0.0076, SP: 0.7270
Iter 30/75, CE: 0.0007, SP: 0.6972
Iter 45/75, CE: 0.0004, SP: 0.6818
Iter 60/75, CE: 0.0003, SP: 0.6720
Iter 75/75, CE: 0.0002, SP: 0.6649
Clean preds: tensor([7, 3, 0, 0, 4, 7, 2, 0, 7, 8], device='cuda:0')
Adv preds:  tensor([7, 3, 0, 0, 4, 7, 2, 0, 7, 8], device='cuda:0')
Sparsity clean: 0.2632, adv: 0.3402
