## LIRA original paper experiments

In [1]:
import os

import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torchvision import datasets, transforms, models
from torchvision.utils import make_grid

from backdoor.attacks import LiraAttack

# Device string shared by every cell below: the first CUDA GPU when PyTorch can see one, otherwise the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [2]:
# Get all the datasets used in the original paper (MNIST, CIFAR10, GTSRB, T-ImageNet)

# all my datasets are in '/data/' folder
root = '/data/'

# MNIST dataset (single-channel; normalised with the constants given below)
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
mnist_trainset = datasets.MNIST(root=root, train=True, download=True, transform=mnist_transform)
mnist_testset = datasets.MNIST(root=root, train=False, download=True, transform=mnist_transform)

# CIFAR10 dataset
cifar10_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                        (0.247, 0.243, 0.261))
])
cifar10_trainset = datasets.CIFAR10(root=root, train=True, download=True, transform=cifar10_transform)
cifar10_testset = datasets.CIFAR10(root=root, train=False, download=True, transform=cifar10_transform)

# GTSRB dataset (images are resized to 32x32)
gtsrb_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.3401, 0.3120, 0.3212), (0.2725, 0.2609, 0.2669))
])
# os.path.join keeps the paths valid whether or not `root` ends with a slash.
gtsrb_trainset = datasets.ImageFolder(root=os.path.join(root, 'gtsrb', 'GTSRB', 'Training'), transform=gtsrb_transform)
# NOTE(review): ImageFolder derives labels from sub-directory names; confirm that Final_Test is
# organised into one folder per class, otherwise the test labels will not match the training labels.
gtsrb_testset = datasets.ImageFolder(root=os.path.join(root, 'gtsrb', 'GTSRB', 'Final_Test'), transform=gtsrb_transform)

# Tiny ImageNet dataset
# Training transform: random crop + horizontal flip augmentation.
tinyimagenet_transform = transforms.Compose([
    transforms.RandomResizedCrop(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
# Evaluation transform: no random augmentation, so test accuracy / attack success rate are
# reproducible from run to run.
tinyimagenet_test_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
tinyimagenet_trainset = datasets.ImageFolder(root=os.path.join(root, 'tiny-imagenet-200', 'train'), transform=tinyimagenet_transform)
# NOTE(review): same ImageFolder caveat as GTSRB - confirm that `val` has one sub-folder per class.
tinyimagenet_testset = datasets.ImageFolder(root=os.path.join(root, 'tiny-imagenet-200', 'val'), transform=tinyimagenet_test_transform)

Files already downloaded and verified
Files already downloaded and verified


## LIRA attack on MNIST dataset

In [3]:
# CNN model used in the original code
class MNISTBlock(nn.Module):
    """Pre-activation block: BatchNorm -> ReLU -> 3x3 convolution (no bias).

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the convolution (default 1).
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        # bn1 is created before conv1 so sub-module registration order stays the same.
        self.bn1 = nn.BatchNorm2d(num_features=in_planes)
        self.conv1 = nn.Conv2d(
            in_channels=in_planes,
            out_channels=planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.ind = None

    def forward(self, x):
        normalized = self.bn1(x)
        activated = torch.relu(normalized)
        return self.conv1(activated)


class MNISTClassifier(nn.Module):
    """CNN classifier producing 10 logits; the linear head expects a 64x4x4 feature map."""

    def __init__(self):
        super().__init__()

        # Three stride-2 stages; for 28x28 inputs the spatial size goes 28 -> 14 -> 7 -> 4.
        feature_layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            MNISTBlock(32, 64, stride=2),
            MNISTBlock(64, 64, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)

        # Flatten the 64x4x4 map, then a two-layer MLP head with dropout in between.
        head_layers = [
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 10),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        return self.classifier(self.features(x))

class MNISTAutoencoder(nn.Module):
    """The generator of backdoor trigger on MNIST.

    Convolutional encoder/decoder whose Tanh output lies in [-1, 1]. The shape
    comments below are for a (b, 1, 28, 28) input.
    """

    def __init__(self):
        super().__init__()
        encoder_layers = [
            nn.Conv2d(1, 16, kernel_size=3, stride=3, padding=1),   # b, 16, 10, 10
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # b, 16, 5, 5
            nn.Conv2d(16, 64, kernel_size=3, stride=2, padding=1),  # b, 64, 3, 3
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=1),                  # b, 64, 2, 2
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.ConvTranspose2d(64, 128, kernel_size=3, stride=2),             # b, 128, 5, 5
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=3, padding=1),  # b, 64, 15, 15
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2, padding=1),    # b, 1, 28, 28
            nn.BatchNorm2d(1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        return self.decoder(self.encoder(x))

In [4]:
# Configure and run the LIRA attack on MNIST: build the classifier and the trigger generator,
# create their optimizers, then hand everything to LiraAttack.
# hyperparameters based on the original paper
mnist_epochs = 10
mnist_finetune_epochs = 10

mnist_classifier = MNISTClassifier().to(device)
mnist_trigger = MNISTAutoencoder().to(device)

# create optimizers and schedulers for the attack
# Both classifier optimizers below are bound to the same mnist_classifier parameters; one is passed
# to attack() as `optimizer`, the other (with weight decay) as `finetune_optimizer`.
mnist_optimizer = optim.SGD(
    mnist_classifier.parameters(), 
    lr=0.01, 
    momentum=0.9
)
mnist_finetune_optimizer = optim.SGD(
    mnist_classifier.parameters(), 
    lr=0.01, 
    momentum=0.9, 
    weight_decay=5e-4
)
# NOTE(review): the first milestone (10) equals mnist_finetune_epochs; if the scheduler is stepped
# once per fine-tuning epoch the learning rate would never actually decay - confirm this is intended.
mnist_finetune_scheduler = optim.lr_scheduler.MultiStepLR(
    mnist_finetune_optimizer, 
    milestones=[10,20,30,40],
    gamma=0.1
)
# The trigger generator is trained with plain SGD (no momentum) at a much smaller learning rate.
mnist_trigger_optimizer = optim.SGD(
    mnist_trigger.parameters(), 
    lr=0.0001
)


mnist_lira_attack = LiraAttack(
    device,
    mnist_classifier,
    mnist_trigger,
    mnist_trainset,
    mnist_testset,
    target_class=1, # trigger class
    batch_size=128
)

# No eps / finetune_eps is passed here (the CIFAR10 cell does pass them).
# NOTE(review): presumably LiraAttack.attack falls back to its own defaults - confirm their values.
mnist_lira_attack.attack(
    epochs=mnist_epochs,
    finetune_epochs=mnist_finetune_epochs,
    optimizer=mnist_optimizer,
    trigger_optimizer=mnist_trigger_optimizer,
    finetune_optimizer=mnist_finetune_optimizer,
    finetune_scheduler=mnist_finetune_scheduler,
)


Stage I LIRA attack with alternating optimization

Epoch 1/20	|	Classifier Loss: 0.7758009691736591	|	Trigger Loss: 0.630481231632009
Test Accuracy: 0.1646
Attack success rate: 0.9421319796954315


Epoch 2/20	|	Classifier Loss: 0.19557451027661943	|	Trigger Loss: 0.10434169458319005
Test Accuracy: 0.9796
Attack success rate: 0.0012408347433728144


Epoch 3/20	|	Classifier Loss: 0.048814560062309574	|	Trigger Loss: 0.011894478145354574
Test Accuracy: 0.9827
Attack success rate: 0.0027072758037225042


Epoch 4/20	|	Classifier Loss: 0.023719016713763415	|	Trigger Loss: 0.001321015850991466
Test Accuracy: 0.9866
Attack success rate: 0.0010152284263959391


Epoch 5/20	|	Classifier Loss: 0.018079631870537044	|	Trigger Loss: 0.0011232952080319442
Test Accuracy: 0.9879
Attack success rate: 0.0007896221094190638


Epoch 6/20	|	Classifier Loss: 0.013452809841892342	|	Trigger Loss: 0.00022682833290719984
Test Accuracy: 0.9892
Attack success rate: 0.0012408347433728144


Epoch 7/20	|	Classifier L

In [5]:
# Persist the attacked classifier through the attack object's own helper.
mnist_lira_attack.save_model('../models/mnist_lira_attack')

# Store only the trigger generator's weights (its state_dict) alongside it.
trigger_model = mnist_lira_attack.trigger_model
torch.save(obj=trigger_model.state_dict(), f='../models/mnist_lira_trigger')


Model saved to ../models/mnist_lira_attack


In [None]:
def plot_losses(triggerlosses, classifierlosses, stageIenditr, trainlosses, stageIendepoch):
    """Draw three loss curves side by side and show the figure.

    Args:
        triggerlosses: sequence plotted in the left panel.
        classifierlosses: sequence plotted in the middle panel.
        stageIenditr: x position of the dashed red marker in the middle panel.
        trainlosses: sequence plotted in the right panel.
        stageIendepoch: x position of the dashed red marker in the right panel.
    """
    _, (trigger_ax, classifier_ax, epoch_ax) = plt.subplots(1, 3, figsize=(12, 4))

    # Left: trigger loss.
    trigger_ax.plot(triggerlosses)
    trigger_ax.set_title("Trigger Loss during Phase I")

    # Middle: classifier loss, with the end of stage I marked.
    classifier_ax.plot(classifierlosses)
    classifier_ax.axvline(x=stageIenditr, color='r', linestyle='--')
    classifier_ax.set_title("Classifier Loss during both Phases")

    # Right: per-epoch average loss, with the end of stage I marked.
    epoch_ax.plot(trainlosses)
    epoch_ax.axvline(x=stageIendepoch, color='r', linestyle='--')
    epoch_ax.set_title("Average Loss after every epoch")

    # Avoid overlapping titles/ticks, then render.
    plt.tight_layout()
    plt.show()
    
triggerlosses = mnist_lira_attack.triggerlosses
classifierlosses = mnist_lira_attack.classifierlosses
trainlosses = mnist_lira_attack.trainlosses

# Stage I lasts `mnist_epochs` epochs, so derive the markers from the configuration instead of a
# hard-coded 15. 128 is the batch_size given to LiraAttack.
# NOTE(review): assumes classifierlosses holds one entry per training batch - confirm.
mnist_stage1_end_itr = mnist_epochs * len(mnist_trainset) / 128
plot_losses(triggerlosses, classifierlosses, mnist_stage1_end_itr, trainlosses, mnist_epochs)

In [None]:
def apply_trigger(dataset, trigger_model, eps=0.01):
    """Return a poisoned copy of ``dataset``: ``image + eps * trigger_model(image)``.

    The source samples are left untouched (the perturbation is added out of
    place) and no autograd graph is recorded.

    Args:
        dataset: iterable of ``(image, label, poisoned_label)`` tuples, where
            ``image`` is an un-batched tensor.
        trigger_model: module mapping a batch of images to a perturbation of
            the same shape.
        eps: scale applied to the perturbation (default 0.01).

    Returns:
        list of ``(poisoned_image, label, poisoned_label)`` tuples. Images are
        on the trigger model's device; if the model has no parameters they
        stay on their original device.
    """
    # Follow the trigger model's device rather than a notebook-level global.
    first_param = next(trigger_model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    transformed_dataset = []
    with torch.no_grad():
        for image, label, poisoned_label in dataset:
            if target_device is not None:
                image = image.to(target_device)
            # squeeze(0) drops only the batch dimension that unsqueeze(0) added.
            perturbation = trigger_model(image.unsqueeze(0)).squeeze(0) * eps
            # Out-of-place add: on CPU `image` is the very tensor stored in `dataset`.
            transformed_dataset.append((image + perturbation, label, poisoned_label))
    return transformed_dataset

def plot(dataset, model, num):
    """Show the first num*num images of ``dataset`` in a num x num grayscale grid,
    then print the labels ``model`` predicts for every sample in ``dataset``."""
    fig, axs = plt.subplots(num, num, figsize=(5, 5))
    fig.tight_layout()

    for row in range(num):
        for col in range(num):
            # Samples are laid out row by row; element 0 of a sample is the image.
            sample_image = dataset[row * num + col][0]
            cell = axs[row, col]
            cell.imshow(sample_image.squeeze().detach().cpu(), cmap='gray')
            cell.axis('off')

    plt.tight_layout()
    plt.show()

    # Run the classifier over the whole dataset and report its predictions.
    predicted_labels = get_outputs(dataset, model)
    print("Labels predicted by the model:")
    print(predicted_labels)

def get_outputs(dataset, model):
    """Predict a class index for every sample of ``dataset``.

    Inference runs in evaluation mode (so Dropout/BatchNorm behave
    deterministically) and without gradient tracking. The model's previous
    train/eval mode is restored before returning.

    Args:
        dataset: iterable of ``(image, label, poisoned_label)`` tuples, where
            ``image`` is an un-batched tensor.
        model: classifier returning logits of shape ``(batch, num_classes)``.

    Returns:
        list of int predicted class indices, one per sample.
    """
    # Follow the model's device rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    was_training = model.training
    model.eval()
    outputs = []
    try:
        with torch.no_grad():
            for image, _, _ in dataset:
                batch = image.unsqueeze(0)
                if target_device is not None:
                    batch = batch.to(target_device)
                predicted_label = torch.argmax(model(batch), dim=1)
                outputs.append(predicted_label.item())
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)
    return outputs
    
# Pull the trained classifier, the trigger generator and the poisoned test set out of the attack.
# NOTE(review): the CIFAR10 cells read `cifar10_lira_attack.model` whereas this one reads
# `.classifier` - confirm which attribute LiraAttack actually exposes.
lira_model = mnist_lira_attack.classifier
trigger_model = mnist_lira_attack.trigger_model
poisoned_mnist_testset = mnist_lira_attack.get_poisoned_testset()

# Preview grid is num x num, so take the first num*num samples and poison a copy of them.
num = 5
mnist_subset = [poisoned_mnist_testset[sample_idx] for sample_idx in range(num * num)]
transformed_mnist_subset = apply_trigger(mnist_subset, trigger_model)

for heading, subset in (('Normal Images', mnist_subset), ('Poisoned Images', transformed_mnist_subset)):
    if subset is transformed_mnist_subset:
        # Blank line between the two reports.
        print()
    print(heading)
    plot(subset, lira_model, 5)

## LIRA Attack on CIFAR10

In [6]:
import torch.nn.functional as F

# Layer plans for the VGG variants, consumed by VGG._make_layers: an integer adds a
# 3x3 conv -> BatchNorm -> ReLU stage with that many output channels, "M" adds a 2x2 max-pool.
cfg = {
    "VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "VGG13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "VGG16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    "VGG19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


def double_conv(in_channels, out_channels):
    """Build two consecutive (3x3 conv, padding 1 -> BatchNorm -> in-place ReLU) stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Spatial size is preserved by the padding.
    """
    stages = []
    for stage_in in (in_channels, out_channels):
        stages.extend([
            nn.Conv2d(stage_in, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ])
    return nn.Sequential(*stages)

class VGG(nn.Module):
    """VGG-style classifier: conv feature extractor from ``cfg`` followed by one linear layer."""

    def __init__(self, vgg_name, num_classes=10, feature_dim=512):
        """
        for image size 32, feature_dim = 512
        for other sizes, feature_dim = 512 * (size//32)**2
        """
        super().__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        feature_map = self.features(x)
        # Flatten everything except the batch dimension before the linear head.
        flattened = feature_map.view(feature_map.size(0), -1)
        return self.classifier(flattened)

    def _make_layers(self, cfg):
        """Translate a layer plan (ints = conv widths, "M" = max-pool) into an nn.Sequential."""
        layers = []
        in_channels = 3
        for spec in cfg:
            if spec != "M":
                layers.extend([
                    nn.Conv2d(in_channels, spec, kernel_size=3, padding=1),
                    nn.BatchNorm2d(spec),
                    nn.ReLU(inplace=True),
                ])
                in_channels = spec
                continue
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        # Kernel-1 / stride-1 average pool closes the stack.
        layers.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*layers)

class UNet(nn.Module):
    """The generator of backdoor trigger on CIFAR10.

    Three-level encoder/decoder with skip connections. The output passes
    through BatchNorm and tanh, so it lies in [-1, 1] and has the input's
    spatial size (for sizes divisible by 8).
    """

    def __init__(self, out_channel):
        super().__init__()

        # Contracting path.
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        # Note: despite its name, `maxpool` is a 2x2 *average* pool.
        self.maxpool = nn.AvgPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # Expanding path; input widths are (skip channels + upsampled channels).
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        self.conv_last = nn.Sequential(
            nn.Conv2d(64, out_channel, 1),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        # Encoder: keep each level's features for the skip connections.
        skip1 = self.dconv_down1(x)
        skip2 = self.dconv_down2(self.maxpool(skip1))
        skip3 = self.dconv_down3(self.maxpool(skip2))
        bottom = self.dconv_down4(self.maxpool(skip3))

        # Decoder: upsample, concatenate the matching skip along channels, convolve.
        up = bottom
        for skip, block in (
            (skip3, self.dconv_up3),
            (skip2, self.dconv_up2),
            (skip1, self.dconv_up1),
        ):
            up = block(torch.cat([self.upsample(up), skip], dim=1))

        return torch.tanh(self.conv_last(up))

In [8]:
# Configure and run the LIRA attack on CIFAR10 with a VGG11 classifier and a UNet trigger generator.
# hyperparameters based on the original paper
cifar10_epochs = 50
cifar10_finetune_epochs = 250
cifar10_eps = 0.001 # to preserve stealthiness of the data

cifar10_classifier = VGG('VGG11', num_classes=10).to(device)
# 3 output channels so the generated trigger matches the RGB input.
cifar10_trigger = UNet(3).to(device)

# Optimizer passed to attack() as `optimizer`.
cifar10_optimizer = optim.SGD(
    cifar10_classifier.parameters(), 
    lr=1e-2, 
    momentum=0.9
)
# Plain SGD (no momentum) for the trigger generator.
cifar10_trigger_optimizer = optim.SGD(
    cifar10_trigger.parameters(), 
    lr=1e-4
)
# Second optimizer over the same classifier parameters, with weight decay, passed as `finetune_optimizer`.
cifar10_finetune_optimizer = optim.SGD(
    cifar10_classifier.parameters(),
    lr=1e-2,
    momentum=0.9,
    weight_decay=5e-4,
)
# Multiply the fine-tuning learning rate by 0.1 at each listed milestone.
cifar10_finetune_scheduler = optim.lr_scheduler.MultiStepLR(
    cifar10_finetune_optimizer,
    milestones=[50,100,150,200],
    gamma=0.1
)

cifar10_lira_attack = LiraAttack(
    device,
    cifar10_classifier,
    cifar10_trigger,
    cifar10_trainset,
    cifar10_testset,
    target_class=1, # trigger class
    batch_size=128,
)

# The same eps value is used for both the `eps` and `finetune_eps` arguments.
cifar10_lira_attack.attack(
    epochs=cifar10_epochs,
    finetune_epochs=cifar10_finetune_epochs,
    optimizer=cifar10_optimizer,
    trigger_optimizer=cifar10_trigger_optimizer,
    finetune_optimizer=cifar10_finetune_optimizer,
    finetune_scheduler=cifar10_finetune_scheduler,
    eps=cifar10_eps,
    finetune_eps=cifar10_eps,
)


Stage I LIRA attack with alternating optimization

Epoch 1/300	|	Classifier Loss: 1.6034014198058106	|	Trigger Loss: 0.11611821192951256
Test Accuracy: 0.4748
Attack success rate: 0.26811111111111113


Epoch 2/300	|	Classifier Loss: 0.8689635562407293	|	Trigger Loss: 0.001574744068285929
Test Accuracy: 0.4495
Attack success rate: 0.3532222222222222


Epoch 3/300	|	Classifier Loss: 0.6515118524936401	|	Trigger Loss: 0.0011288841752081758
Test Accuracy: 0.6246
Attack success rate: 0.08677777777777777


Epoch 4/300	|	Classifier Loss: 0.49473090194846	|	Trigger Loss: 0.000858424830448363
Test Accuracy: 0.5578
Attack success rate: 0.259


Epoch 5/300	|	Classifier Loss: 0.37794158237349035	|	Trigger Loss: 0.0007280083993785357
Test Accuracy: 0.6627
Attack success rate: 0.09322222222222222


Epoch 6/300	|	Classifier Loss: 0.27918556397048344	|	Trigger Loss: 0.0006511410679834921
Test Accuracy: 0.6595
Attack success rate: 0.14466666666666667


Epoch 7/300	|	Classifier Loss: 0.2058938567079399


Epoch 55/300	|	Classifier Loss: 1.5684609663409025e-05
Test Accuracy: 0.7914
Attack success rate: 0.058


Epoch 56/300	|	Classifier Loss: 1.5115122455554833e-05
Test Accuracy: 0.79
Attack success rate: 0.049666666666666665


Epoch 57/300	|	Classifier Loss: 1.4936586699254885e-05
Test Accuracy: 0.7892
Attack success rate: 0.051


Epoch 58/300	|	Classifier Loss: 1.4970521954560901e-05
Test Accuracy: 0.7898
Attack success rate: 0.05255555555555556


Epoch 59/300	|	Classifier Loss: 1.388995157967633e-05
Test Accuracy: 0.782
Attack success rate: 0.064


Epoch 60/300	|	Classifier Loss: 1.356608261914535e-05
Test Accuracy: 0.7979
Attack success rate: 0.04133333333333333


Epoch 61/300	|	Classifier Loss: 1.4129849925266264e-05
Test Accuracy: 0.7848
Attack success rate: 0.059111111111111114


Epoch 62/300	|	Classifier Loss: 1.3723745797108368e-05
Test Accuracy: 0.7832
Attack success rate: 0.06133333333333333


Epoch 63/300	|	Classifier Loss: 1.3242950641245017e-05
Test Accuracy: 0.7888
Attack 


Epoch 125/300	|	Classifier Loss: 6.5728088679803535e-06
Test Accuracy: 0.785
Attack success rate: 0.057666666666666665


Epoch 126/300	|	Classifier Loss: 6.582184579617934e-06
Test Accuracy: 0.7939
Attack success rate: 0.042888888888888886


Epoch 127/300	|	Classifier Loss: 6.993651692735051e-06
Test Accuracy: 0.7434
Attack success rate: 0.12244444444444444


Epoch 128/300	|	Classifier Loss: 6.5222036773496345e-06
Test Accuracy: 0.7763
Attack success rate: 0.069


Epoch 129/300	|	Classifier Loss: 6.565324962892777e-06
Test Accuracy: 0.7902
Attack success rate: 0.049


Epoch 130/300	|	Classifier Loss: 6.471077217739243e-06
Test Accuracy: 0.7978
Attack success rate: 0.036


Epoch 131/300	|	Classifier Loss: 6.293713757077522e-06
Test Accuracy: 0.7863
Attack success rate: 0.051666666666666666


Epoch 132/300	|	Classifier Loss: 7.301491931186943e-06
Test Accuracy: 0.771
Attack success rate: 0.0748888888888889


Epoch 133/300	|	Classifier Loss: 6.238066087071577e-06
Test Accuracy: 0.7888
At


Epoch 195/300	|	Classifier Loss: 4.103837645196997e-06
Test Accuracy: 0.794
Attack success rate: 0.04388888888888889


Epoch 196/300	|	Classifier Loss: 4.294947437027267e-06
Test Accuracy: 0.7833
Attack success rate: 0.059111111111111114


Epoch 197/300	|	Classifier Loss: 4.142457693464645e-06
Test Accuracy: 0.7875
Attack success rate: 0.048


Epoch 198/300	|	Classifier Loss: 4.261539915149677e-06
Test Accuracy: 0.7874
Attack success rate: 0.05355555555555556


Epoch 199/300	|	Classifier Loss: 4.166617116453099e-06
Test Accuracy: 0.777
Attack success rate: 0.06622222222222222


Epoch 200/300	|	Classifier Loss: 4.219177602935164e-06
Test Accuracy: 0.7859
Attack success rate: 0.050555555555555555


Epoch 201/300	|	Classifier Loss: 4.0506189945001305e-06
Test Accuracy: 0.7807
Attack success rate: 0.063


Epoch 202/300	|	Classifier Loss: 4.447647817585657e-06
Test Accuracy: 0.7956
Attack success rate: 0.04133333333333333


Epoch 203/300	|	Classifier Loss: 4.418515730570705e-06
Test Accura


Epoch 265/300	|	Classifier Loss: 3.31466699609497e-06
Test Accuracy: 0.7886
Attack success rate: 0.043555555555555556


Epoch 266/300	|	Classifier Loss: 3.333554890387624e-06
Test Accuracy: 0.7863
Attack success rate: 0.04822222222222222


Epoch 267/300	|	Classifier Loss: 3.2001068426904412e-06
Test Accuracy: 0.7893
Attack success rate: 0.046


Epoch 268/300	|	Classifier Loss: 3.2633240297879507e-06
Test Accuracy: 0.7864
Attack success rate: 0.051111111111111114


Epoch 269/300	|	Classifier Loss: 3.1728743038339274e-06
Test Accuracy: 0.7798
Attack success rate: 0.05455555555555556


Epoch 270/300	|	Classifier Loss: 3.1749745304863644e-06
Test Accuracy: 0.7936
Attack success rate: 0.041


Epoch 271/300	|	Classifier Loss: 3.1926413522537433e-06
Test Accuracy: 0.778
Attack success rate: 0.059


Epoch 272/300	|	Classifier Loss: 3.25321835590002e-06
Test Accuracy: 0.7875
Attack success rate: 0.04644444444444444


Epoch 273/300	|	Classifier Loss: 3.2927668661081334e-06
Test Accuracy: 0.7762

In [None]:
def cifar10_plot(dataset, model, num):
    """Show the first num*num images of ``dataset`` in a num x num grid, then print
    the labels ``model`` predicts for the samples of ``dataset``.

    Each image (element 0 of a sample, channels-first) is min-max rescaled to
    [0, 1] for display only; the data passed to the model is not changed.
    """
    # squeeze=False keeps `axs` two-dimensional even for a 1x1 grid.
    fig, axs = plt.subplots(num, num, figsize=(5, 5), squeeze=False)
    fig.tight_layout()

    for j in range(num*num):
        ax = axs[j // num, j % num]
        # detach()/cpu() instead of torch.tensor(tensor), which warns on copy-construction.
        img = dataset[j][0].detach().cpu()
        img = img.permute(1, 2, 0).squeeze().numpy()
        img_min, img_max = img.min(), img.max()
        value_range = img_max - img_min
        # A constant image has zero range; skip the division instead of producing NaNs.
        img = (img - img_min) / value_range if value_range > 0 else img - img_min
        ax.imshow(img)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    # Predict labels using the model
    predicted_labels = get_cifar10_outputs(dataset, model)
    print("Labels predicted by the model:")
    print(predicted_labels)

def get_cifar10_outputs(dataset, model):
    """Predict a class index for every sample of ``dataset``.

    Inference runs in evaluation mode (so BatchNorm/Dropout behave
    deterministically and single-image batches are accepted) and without
    gradient tracking. The model's previous train/eval mode is restored
    before returning.

    Args:
        dataset: iterable of samples whose element 0 is the image tensor,
            either un-batched or already carrying a batch dimension (4-D).
        model: classifier returning logits of shape ``(batch, num_classes)``.

    Returns:
        list of int predicted class indices, one per sample.
    """
    # Follow the model's device rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    was_training = model.training
    model.eval()
    outputs = []
    try:
        with torch.no_grad():
            for sample in dataset:
                image = sample[0]
                if target_device is not None:
                    image = image.to(target_device)
                # Add the batch dimension unless the sample already has one.
                if image.dim() < 4:
                    image = image.unsqueeze(0)
                predicted_label = torch.argmax(model(image), dim=1)
                outputs.append(predicted_label.item())
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)
    return outputs
    
# Pull the trained classifier, the trigger generator and the poisoned test set out of the attack.
# NOTE(review): the MNIST cell reads `mnist_lira_attack.classifier` whereas this one reads
# `.model` - confirm which attribute LiraAttack actually exposes.
cifar10_lira_model = cifar10_lira_attack.model
cifar10_trigger_model = cifar10_lira_attack.trigger_model
poisoned_cifar10_testset = cifar10_lira_attack.get_poisoned_testset()

# Side length of the preview grid; the first num*num samples are kept for plotting.
num = 5
cifar10_subset = [poisoned_cifar10_testset[i] for i in range(num*num)]

class PoisonedLiraSubset(torch.utils.data.Dataset):
    """Map-style view of ``dataset`` whose images carry the trigger perturbation.

    Indexing returns ``(img + eps * trigger(img), label, poisoned_label)``.
    This is a ``Dataset`` rather than a ``DataLoader`` so that iterating it
    goes through ``__getitem__`` and yields poisoned samples, instead of clean
    collated batches of the wrapped data. The wrapped samples are never
    modified.

    Args:
        device: device on which the trigger is evaluated.
        dataset: indexable collection of ``(img, label, poisoned_label)``.
        trigger: module mapping a batch of images to a same-shaped perturbation.
        eps: scale applied to the perturbation.
    """

    def __init__(self, device, dataset, trigger, eps):
        super().__init__()
        self.device = device
        self.dataset = dataset
        # Move the trigger once here instead of on every item access.
        self.trigger = trigger.to(device)
        self.eps = eps

    def __getitem__(self, idx):
        # Explicit bounds check so plain `for sample in subset` terminates with IndexError.
        size = len(self.dataset)
        if not -size <= idx < size:
            raise IndexError(f"index {idx} out of range for subset of size {size}")
        img, label, poisoned_label = self.dataset[idx]
        img = img.to(self.device)
        with torch.no_grad():
            # squeeze(0) drops only the batch dimension that unsqueeze(0) added.
            perturbation = self.trigger(img.unsqueeze(0)).squeeze(0) * self.eps
        # Out-of-place add: on CPU `img` is the very tensor stored in the wrapped
        # dataset, and `+=` would stack the trigger again on each access.
        return img + perturbation, label, poisoned_label

    def __len__(self):
        return len(self.dataset)
    
# Wrap the same samples so that indexing returns images with the trigger perturbation added,
# scaled by eps = 0.1.
# NOTE(review): the attack cell uses cifar10_eps = 0.001; 0.1 here is 100x larger - presumably to
# make the trigger visible in the plot; confirm, since predictions at this eps may not reflect the
# trained attack.
transformed_cifar10_subset = PoisonedLiraSubset(device, cifar10_subset, cifar10_trigger_model, 0.1)
    
# Show the subset as-is, then the same samples through the poisoning wrapper.
print('Normal Images')
cifar10_plot(cifar10_subset, cifar10_lira_model, 5)
print()
print('Poisoned Images')
cifar10_plot(transformed_cifar10_subset, cifar10_lira_model, 5)

In [None]:
c10_triggerlosses = cifar10_lira_attack.triggerlosses
c10_classifierlosses = cifar10_lira_attack.classifierlosses
c10_trainlosses = cifar10_lira_attack.trainlosses

# Stage I lasts `cifar10_epochs` epochs. The bare name `epochs` is only assigned later (GTSRB
# cell), and 60000 is the MNIST training-set size, so derive both markers from the CIFAR10
# configuration. 128 is the batch_size given to LiraAttack.
# NOTE(review): assumes c10_classifierlosses holds one entry per training batch - confirm.
c10_stage1_end_itr = cifar10_epochs * len(cifar10_trainset) / 128
plot_losses(c10_triggerlosses, c10_classifierlosses, c10_stage1_end_itr, c10_trainlosses, cifar10_epochs)

In [None]:
# Clean-data accuracy of the attacked CIFAR10 classifier on the (un-poisoned) test set.
# NOTE(review): the MNIST cell reads `.classifier` from the attack object whereas this reads
# `.model` - confirm which attribute LiraAttack actually exposes.
thismodel = cifar10_lira_attack.model
thismodel = thismodel.to(device)
# Evaluation mode: BatchNorm must use its running statistics (and not update them on test data).
thismodel.eval()
cifar10_testloader = torch.utils.data.DataLoader(cifar10_testset, batch_size=150, shuffle=False, num_workers=10)

correct, total = 0, 0
with torch.no_grad():
    for images, labels in cifar10_testloader:
        images, labels = images.to(device), labels.to(device)
        logits = thismodel(images)

        # Index of the largest logit is the predicted class.
        predicted = logits.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += len(labels)
print(correct/total)

## LIRA Attack on GTSRB Dataset

In [None]:
class GTSRBAutoencoder(nn.Module):
    """The generator of backdoor trigger on GTSRB.

    Four stride-2 convolutions halve the spatial size each time, and four
    stride-2 transposed convolutions restore it. The final Tanh keeps the
    output in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Encoder: 3 -> 16 -> 32 -> 64 -> 128 channels, each stage conv + BatchNorm + ReLU.
        encoder_layers = []
        encoder_widths = (3, 16, 32, 64, 128)
        for c_in, c_out in zip(encoder_widths[:-1], encoder_widths[1:]):
            encoder_layers += [
                nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: 128 -> 64 -> 32 -> 16 with BatchNorm + ReLU, then 16 -> 3 followed by Tanh.
        decoder_layers = []
        decoder_widths = (128, 64, 32, 16)
        for c_in, c_out in zip(decoder_widths[:-1], decoder_widths[1:]):
            decoder_layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        decoder_layers += [
            nn.ConvTranspose2d(16, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        return self.decoder(self.encoder(x))

In [None]:
# Configure and run the LIRA attack on GTSRB with a ResNet18 classifier (43 classes).
# NOTE(review): the executed MNIST/CIFAR10 cells build LiraAttack with keyword arguments
# (target_class=..., batch_size=...) and pass epochs/optimizers to attack(); this cell passes
# 17 positional arguments and calls attack() with none. Confirm it matches the current
# LiraAttack signature before running.
# hyperparameters based on the original paper
epochs = 50
finetune_epochs = 250
lr = 0.01
# NOTE(review): k, m and eps are assigned here but not referenced anywhere in this cell;
# the literals 1 / 0.01 / 0.5 are passed to LiraAttack instead.
k = 1 # update trigger function after 1 epoch
m = 1 # alternating updates for 50 epochs
eps = 0.005 # to preserve stealthiness of the data

gtsrb_classifier = models.resnet18(num_classes=43).to(device)
gtsrb_trigger = GTSRBAutoencoder().to(device)
gtsrb_optimizer = optim.SGD(gtsrb_classifier.parameters(), lr=lr, momentum=0.9)
gtsrb_trigger_optimizer = optim.SGD(gtsrb_trigger.parameters(), lr=1e-4)

# The meaning of the unnamed numeric arguments (1, 0.01, 0.5, 0.01, 0.5, 128) cannot be told from
# this notebook - TODO name them or switch to keyword arguments.
gtsrb_lira_attack = LiraAttack(
    device,
    gtsrb_classifier,
    gtsrb_trigger,
    gtsrb_trainset,
    gtsrb_testset,
    8, # trigger class
    epochs,
    1,
    finetune_epochs,
    0.01,
    0.5,
    0.01,
    0.5,
    128,
    gtsrb_optimizer,
    gtsrb_trigger_optimizer,
    nn.CrossEntropyLoss()
)

gtsrb_lira_attack.attack()