## LIRA original paper experiments

In [1]:
import os

import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torchvision import datasets, transforms, models
from torchvision.utils import make_grid

from backdoor.attacks import LiraAttack

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [2]:
# Get all the datasets used in the original paper (MNIST, CIFAR10, GTSRB, T-ImageNet)

# all my datasets are in '/data/' folder
root = '/data/'

# MNIST dataset
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
mnist_trainset = datasets.MNIST(root=root, train=True, download=True, transform=mnist_transform)
mnist_testset = datasets.MNIST(root=root, train=False, download=True, transform=mnist_transform)

# CIFAR10 dataset
cifar10_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                        (0.247, 0.243, 0.261))
])
cifar10_trainset = datasets.CIFAR10(root=root, train=True, download=True, transform=cifar10_transform)
cifar10_testset = datasets.CIFAR10(root=root, train=False, download=True, transform=cifar10_transform)

# GTSRB dataset
gtsrb_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.3401, 0.3120, 0.3212), (0.2725, 0.2609, 0.2669))
])
gtsrb_trainset = datasets.ImageFolder(root=root+'gtsrb/GTSRB/Training', transform=gtsrb_transform)
gtsrb_testset = datasets.ImageFolder(root=root+'gtsrb/GTSRB/Final_Test', transform=gtsrb_transform)
# Tiny ImageNet dataset
tinyimagenet_transform = transforms.Compose([
    transforms.RandomResizedCrop(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
tinyimagenet_trainset = datasets.ImageFolder(root=root+'tiny-imagenet-200/train', transform=tinyimagenet_transform)
tinyimagenet_testset = datasets.ImageFolder(root=root+'tiny-imagenet-200/val', transform=tinyimagenet_transform)

Files already downloaded and verified
Files already downloaded and verified


## LIRA attack on MNIST dataset

In [3]:
# CNN model used in the original code
class MNISTBlock(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        super(MNISTBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.ind = None

    def forward(self, x):
        return self.conv1(torch.relu(self.bn1(x)))


class MNISTClassifier(nn.Module):
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),  # 14
            nn.ReLU(),
            MNISTBlock(32, 64, stride=2),  # 7
            MNISTBlock(64, 64, stride=2),  # 4
        )
        
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

class MNISTAutoencoder(nn.Module):
    """The generator of backdoor trigger on MNIST."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=3, padding=1),  # b, 16, 10, 10
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),  # b, 16, 5, 5
            nn.Conv2d(16, 64, 3, stride=2, padding=1),  # b, 8, 3, 3
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=1)  # b, 8, 2, 2
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 128, 3, stride=2),  # b, 16, 5, 5
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 5, stride=3, padding=1),  # b, 8, 15, 15
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 2, stride=2, padding=1),  # b, 1, 28, 28
            nn.BatchNorm2d(1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [5]:
# hyperparameters based on the original paper
mnist_epochs = 10
mnist_finetune_epochs = 10

mnist_classifier = MNISTClassifier().to(device)
mnist_trigger = MNISTAutoencoder().to(device)

# create optimizers and schedulers for the attack
mnist_optimizer = optim.SGD(
    mnist_classifier.parameters(), 
    lr=0.01, 
    momentum=0.9
)
mnist_finetune_optimizer = optim.SGD(
    mnist_classifier.parameters(), 
    lr=0.01, 
    momentum=0.9, 
    weight_decay=5e-4
)
mnist_finetune_scheduler = optim.lr_scheduler.MultiStepLR(
    mnist_finetune_optimizer, 
    milestones=[10,20,30,40],
    gamma=0.1
)
mnist_trigger_optimizer = optim.SGD(
    mnist_trigger.parameters(), 
    lr=0.0001
)


mnist_lira_attack = LiraAttack(
    device,
    mnist_classifier,
    mnist_trigger,
    mnist_trainset,
    mnist_testset,
    target_class=1, # trigger class
    batch_size=128
)

mnist_lira_attack.attack(
    epochs=mnist_epochs,
    finetune_epochs=mnist_finetune_epochs,
    optimizer=mnist_optimizer,
    trigger_optimizer=mnist_trigger_optimizer,
    finetune_test_eps=0.01,
    finetune_optimizer=mnist_finetune_optimizer,
    finetune_scheduler=mnist_finetune_scheduler,
)


Stage I LIRA attack with alternating optimization

Epoch 1/20	|	Classifier Loss: 0.6223769524791983	|	Trigger Loss: 0.45985645783491796
Test Accuracy: 0.96
Attack success rate: 0.0022560631697687537


Epoch 2/20	|	Classifier Loss: 0.04488226400017103	|	Trigger Loss: 0.004129834094459239
Test Accuracy: 0.9817
Attack success rate: 0.0010152284263959391


Epoch 3/20	|	Classifier Loss: 0.03695561706638539	|	Trigger Loss: 0.009449470602852157
Test Accuracy: 0.9809
Attack success rate: 0.003045685279187817



KeyboardInterrupt: 

In [None]:
mnist_lira_attack.save_model('../models/mnist_lira_attack')
trigger_model = mnist_lira_attack.trigger_model
torch.save(trigger_model.state_dict(), '../models/mnist_lira_trigger')

In [None]:
def plot_losses(triggerlosses, classifierlosses, stageIenditr, trainlosses, stageIendepoch):
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))

    axs[0].plot(triggerlosses)
    axs[0].set_title("Trigger Loss during Phase I")

    axs[1].plot(classifierlosses)
    axs[1].axvline(x=stageIenditr, color='r', linestyle='--')
    axs[1].set_title("Classifier Loss during both Phases")

    # Plotting third loss
    axs[2].plot(trainlosses)
    axs[2].axvline(x=stageIendepoch, color='r', linestyle='--')
    axs[2].set_title("Average Loss after every epoch")

    # Adjusting the spacing between subplots
    plt.tight_layout()

    # Display the plot
    plt.show()
    
triggerlosses = mnist_lira_attack.triggerlosses
classifierlosses = mnist_lira_attack.classifierlosses
trainlosses = mnist_lira_attack.trainlosses
plot_losses(triggerlosses, classifierlosses, 15*60000/128, trainlosses, 15)

In [None]:
# def apply_trigger(dataset, trigger_model):
#     # Apply trigger model to the dataset
#     transformed_dataset = []
#     for sample in dataset:
#         image, label, poisoned_label = sample
#         image = image.to(device)
#         image += trigger_model(image.unsqueeze(0)).squeeze() * 0.01 # eps
#         transformed_dataset.append((image, label, poisoned_label))
#     return transformed_dataset

# def plot(dataset, model, num):
#     # Plot 25 MNIST images
#     fig, axs = plt.subplots(num, num, figsize=(5, 5))
#     fig.tight_layout()
    
#     for j in range(num*num):
#         ax = axs[j // num, j % num]
#         ax.imshow(dataset[j][0].squeeze().detach().cpu(), cmap='gray')
#         ax.axis('off')
    
#     plt.tight_layout()
#     plt.show()
    
#     # Predict labels using the model
#     predicted_labels = get_outputs(dataset, model)
#     print("Labels predicted by the model:")
#     print(predicted_labels)

# def get_outputs(dataset, model):
#     outputs = []
#     for sample in dataset:
#         image, _, _ = sample
#         image = image.unsqueeze(0).to(device)
#         output = model(image)
#         predicted_label = torch.argmax(output, dim=1)
#         outputs.append(predicted_label.item())
#     return outputs
    
# lira_model = mnist_lira_attack.classifier
# trigger_model = mnist_lira_attack.trigger_model
# poisoned_mnist_testset = mnist_lira_attack.get_poisoned_testset()

# num = 5
# mnist_subset = [poisoned_mnist_testset[i] for i in range(num*num)]
# transformed_mnist_subset = apply_trigger(mnist_subset, trigger_model)

# print('Normal Images')
# plot(mnist_subset, lira_model, 5)
# print()
# print('Poisoned Images')
# plot(transformed_mnist_subset, lira_model, 5)

## LIRA Attack on CIFAR10

In [3]:
import torch.nn.functional as F

cfg = {
    "VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "VGG13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "VGG16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    "VGG19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )

class VGG(nn.Module):
    def __init__(self, vgg_name, num_classes=10, feature_dim=512):
        """
        for image size 32, feature_dim = 512
        for other sizes, feature_dim = 512 * (size//32)**2
        """
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == "M":
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [
                    nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                    nn.BatchNorm2d(x),
                    nn.ReLU(inplace=True),
                ]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

class UNet(nn.Module):
    """The generator of backdoor trigger on CIFAR10."""
    def __init__(self, out_channel):
        super().__init__()

        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        self.maxpool = nn.AvgPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        self.conv_last = nn.Sequential(
            nn.Conv2d(64, out_channel, 1),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        x = self.dconv_down4(x)

        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)

        x = self.dconv_up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)

        x = self.dconv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)

        x = self.dconv_up1(x)

        out = self.conv_last(x)

        out = F.tanh(out)

        return out

In [4]:
# hyperparameters based on the original paper
cifar10_epochs = 50
cifar10_finetune_epochs = 50
cifar10_eps = 0.01 # to preserve stealthiness of the data

cifar10_classifier = VGG('VGG11', num_classes=10).to(device)
cifar10_trigger = UNet(3).to(device)

cifar10_optimizer = optim.SGD(
    cifar10_classifier.parameters(), 
    lr=1e-2, 
    momentum=0.9
)
cifar10_trigger_optimizer = optim.SGD(
    cifar10_trigger.parameters(), 
    lr=1e-4
)
cifar10_finetune_optimizer = optim.SGD(
    cifar10_classifier.parameters(),
    lr=1e-2,
    momentum=0.9,
    weight_decay=5e-4,
)
cifar10_finetune_scheduler = optim.lr_scheduler.MultiStepLR(
    cifar10_finetune_optimizer,
    milestones=[50,100,150,200],
    gamma=0.1
)

cifar10_lira_attack = LiraAttack(
    device,
    cifar10_classifier,
    cifar10_trigger,
    cifar10_trainset,
    cifar10_testset,
    target_class=1, # trigger class
    batch_size=128,
)

cifar10_lira_attack.attack(
    epochs=cifar10_epochs,
    finetune_epochs=cifar10_finetune_epochs,
    optimizer=cifar10_optimizer,
    trigger_optimizer=cifar10_trigger_optimizer,
    finetune_optimizer=cifar10_finetune_optimizer,
    finetune_scheduler=cifar10_finetune_scheduler,
    eps=cifar10_eps,
    finetune_test_eps=0.01,
)


Stage I LIRA attack with alternating optimization

Epoch 1/100	|	Classifier Loss: 1.1275449306763652	|	Trigger Loss: 0.6380978882160333
Test Accuracy: 0.3472
Attack success rate: 0.6972222222222222


Epoch 2/100	|	Classifier Loss: 0.9305454712084797	|	Trigger Loss: 0.6342584800232401
Test Accuracy: 0.3583
Attack success rate: 0.694


Epoch 3/100	|	Classifier Loss: 0.8424044441993889	|	Trigger Loss: 0.6318282132106059
Test Accuracy: 0.3007
Attack success rate: 0.7626666666666667


Epoch 4/100	|	Classifier Loss: 0.7723663156599645	|	Trigger Loss: 0.6329417843038164
Test Accuracy: 0.2608
Attack success rate: 0.8023333333333333


Epoch 5/100	|	Classifier Loss: 0.7167808892172011	|	Trigger Loss: 0.6323264470643095
Test Accuracy: 0.3011
Attack success rate: 0.7632222222222222


Epoch 6/100	|	Classifier Loss: 0.676462191297575	|	Trigger Loss: 0.6322338225896401
Test Accuracy: 0.3451
Attack success rate: 0.7065555555555556


Epoch 7/100	|	Classifier Loss: 0.6515867922007276	|	Trigger Loss: 0.


Epoch 45/100	|	Classifier Loss: 4.6106206138128026e-05	|	Trigger Loss: 3.9240741148750137e-05
Test Accuracy: 0.1
Attack success rate: 1.0

Attack Successful at epoch 44 with ASR 1.0, so early stopping Stage I

Epoch 46/100	|	Classifier Loss: 2.4638086414925614e-05	|	Trigger Loss: 3.894534276075806e-06
Test Accuracy: 0.1
Attack success rate: 1.0

Attack Successful at epoch 45 with ASR 1.0, so early stopping Stage I

Epoch 47/100	|	Classifier Loss: 2.4878539167899998e-05	|	Trigger Loss: 6.296562738427305e-06
Test Accuracy: 0.1
Attack success rate: 1.0

Attack Successful at epoch 46 with ASR 1.0, so early stopping Stage I

Epoch 48/100	|	Classifier Loss: 3.3020407281211e-05	|	Trigger Loss: 2.0688400780645974e-05
Test Accuracy: 0.1
Attack success rate: 1.0

Attack Successful at epoch 47 with ASR 1.0, so early stopping Stage I

Epoch 49/100	|	Classifier Loss: 1.8181376211676653e-05	|	Trigger Loss: 4.590727068098506e-06
Test Accuracy: 0.1
Attack success rate: 1.0

Attack Successful at epoch

In [10]:
def cifar10_plot(dataset, model, num):
    # Plot 25 images
    fig, axs = plt.subplots(num, num, figsize=(5, 5))
    fig.tight_layout()
    
    for j in range(num*num):
        ax = axs[j // num, j % num]
        img = torch.tensor(dataset[j][0])
        img = img.permute(1,2,0).clone().detach().squeeze().cpu().numpy()
        img_min, img_max = img.min(), img.max()
        img = (img - img_min) / (img_max-img_min)
        ax.imshow(img)
        ax.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Predict labels using the model
    predicted_labels = get_cifar10_outputs(dataset, model)
    print("Labels predicted by the model:")
    print(predicted_labels)

def get_cifar10_outputs(dataset, model):
    outputs = []
    for sample in dataset:
        image = sample[0]
        image = image.to(device)
        if len(image.shape) < 4:
            image = image.unsqueeze(0)
        output = model(image)
        predicted_label = torch.argmax(output, dim=1)
        outputs.append(predicted_label.item())
    return outputs
    
cifar10_lira_model = cifar10_lira_attack.model
cifar10_trigger_model = cifar10_lira_attack.trigger_model
poisoned_cifar10_testset = cifar10_lira_attack.get_poisoned_testset()

num = 5
cifar10_subset = [poisoned_cifar10_testset[i] for i in range(num*num)]

class PoisonedLiraSubset(torch.utils.data.DataLoader):
    def __init__(self, device, dataset, trigger, eps):
        super().__init__(dataset)
        self.device = device
        self.trigger = trigger
        self.eps = eps
    
    def __getitem__(self, idx):
        img, label, poisoned_label = self.dataset[idx]
        img, trigger = img.to(device), self.trigger.to(device)
        img += self.trigger(img.unsqueeze(0)).squeeze() * self.eps
        return img, label, poisoned_label
    
    def __len__(self):
        return len(self.dataset)
    
transformed_cifar10_subset = PoisonedLiraSubset(device, cifar10_subset, cifar10_trigger_model, 0.1)
    
print('Normal Images')
cifar10_plot(cifar10_subset, cifar10_lira_model, 5)
print()
print('Poisoned Images')
cifar10_plot(transformed_cifar10_subset, cifar10_lira_model, 5)

AttributeError: 'LiraAttack' object has no attribute 'model'

In [None]:
c10_triggerlosses = cifar10_lira_attack.triggerlosses
c10_classifierlosses = cifar10_lira_attack.classifierlosses
c10_trainlosses = cifar10_lira_attack.trainlosses

plot_losses(c10_triggerlosses, c10_classifierlosses, epochs*60000/128, c10_trainlosses, epochs)

In [None]:
thismodel = cifar10_lira_attack.model
thismodel = thismodel.to(device)
cifar10_testloader = torch.utils.data.DataLoader(cifar10_testset, batch_size=150, shuffle=False, num_workers=10)

correct, total = 0, 0
with torch.no_grad():
    for images, labels in cifar10_testloader:
        images, labels = images.to(device), labels.to(device)
        logits = thismodel(images)
        
        _, predicted = torch.max(logits, 1)
        correct += (predicted == labels).sum().item()
        total += len(labels)
print(correct/total)

## LIRA Attack on GTSRB Dataset

In [None]:
class GTSRBAutoencoder(nn.Module):
    """The generator of backdoor trigger on GTSRB."""
    def __init__(self):
        super(GTSRBAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 3, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [None]:
# hyperparameters based on the original paper
gtsrb_epochs = 50
gtsrb_finetune_epochs = 250
gtsrb_eps = 0.005 # to preserve stealthiness of the data

gtsrb_classifier = models.resnet18(num_classes=43).to(device)
gtsrb_trigger = GTSRBAutoencoder().to(device)
gtsrb_optimizer = optim.SGD(gtsrb_classifier.parameters(), lr=1e-2, momentum=0.9)
gtsrb_finetune_optimizer = optim.SGD(gtsrb_classifier.parameters(), lr=1e-2, momentum=0.9)
gtsrb_finetune_scheduler = optim.lr_scheduler.MultiStepLR(gtsrb_finetune_optimizer, milestones=[50,100,150,200,250], gamma=0.1)
gtsrb_trigger_optimizer = optim.SGD(gtsrb_trigger.parameters(), lr=1e-4)

gtsrb_lira_attack = LiraAttack(
    device,
    gtsrb_classifier,
    gtsrb_trigger,
    gtsrb_trainset,
    gtsrb_testset,
    target_class=1,
    batch_size=128,
)

gtsrb_lira_attack.attack(
    epochs=gtsrb_epochs,
    finetune_epochs=gtsrb_finetune_epochs,
    optimizer=gtsrb_optimizer,
    finetune_optimizer=gtsrb_finetune_optimizer,
    finetune_scheduler= gtsrb_finetune_scheduler,
    trigger_optimizer=gtsrb_trigger_optimizer,
)