In [3]:
from poison_methods import *
import torch
import torchshow as ts
from torchvision import transforms
from util import *
import timm
import copy
import imageio as iio
import torchvision
from models import *

set_seed(0)

In [4]:
def PubFig_all2all():
  badnets = BadNets()
  
  def label_poi(label):
      return change_label_all2all(label, num_classes=83)  
  
  test_transform = transforms.Compose([
                  transforms.ToTensor(),
                  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])

  poison_method = ((badnets.img_poi, None), label_poi)
  val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('/home/minzhou/public_html/backdoor_compet/base_line/data/pubfig.npy', test_transform, poison_method, -1)
  
  model = get_model("vit_tiny", '/home/minzhou/public_html/backdoor_compet/base_line/checkpoint/pubfig_vittiny_all2all.pth', num_classes = test_dataset.num_classes, device = "cuda:0")

  return val_dataset, test_dataset, asr_dataset, pacc_dataset, model

In [5]:
def GTSRB_WaNetFrequency():
    ## WaNet 1
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((32, 32),antialias=True),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    wanet = WaNet("/home/minzhou/public_html/backdoor_compet/base_line/checkpoint/WaNet_identity_grid.pth", "/home/minzhou/public_html/backdoor_compet/base_line/checkpoint/WaNet_noise_grid.pth")
    poison_method = ((wanet.img_poi, None), None)
    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('/home/minzhou/public_html/backdoor_compet/base_line/data/gtsrb.npy', test_transform, poison_method, 2)

    ## Frequency 2
    trigger_transform = transforms.Compose([transforms.ToTensor(),])
    noisy = trigger_transform(np.load('/home/minzhou/public_html/backdoor_compet/base_line/checkpoint/gtsrb_universal.npy')[0])
    frequency_attack = Blended(noisy,clip_range = (-1,1), mode='torch')

    poison_method = ((None, frequency_attack.img_poi), None)
    _, _, asr_dataset2, pacc_dataset2 = get_dataset('/home/minzhou/public_html/backdoor_compet/base_line/data/gtsrb.npy', test_transform, poison_method, 13)
    
    
    net = GoogLeNet()
    net.load_state_dict(torch.load('/home/minzhou/public_html/backdoor_compet/base_line/checkpoint/gtsrb_googlenet_wantfrequency.pth',map_location='cuda:0'))
    net = net.cuda()
    
    return val_dataset, test_dataset, (asr_dataset, asr_dataset2), (pacc_dataset, pacc_dataset2), net

In [6]:

'''GoogLeNet with PyTorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class Inception(nn.Module):
    def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
        super(Inception, self).__init__()
        # 1x1 conv branch
        self.b1 = nn.Sequential(
            nn.Conv2d(in_planes, n1x1, kernel_size=1),
            nn.BatchNorm2d(n1x1),
            nn.ReLU(True),
        )

        # 1x1 conv -> 3x3 conv branch
        self.b2 = nn.Sequential(
            nn.Conv2d(in_planes, n3x3red, kernel_size=1),
            nn.BatchNorm2d(n3x3red),
            nn.ReLU(True),
            nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1),
            nn.BatchNorm2d(n3x3),
            nn.ReLU(True),
        )

        # 1x1 conv -> 5x5 conv branch
        self.b3 = nn.Sequential(
            nn.Conv2d(in_planes, n5x5red, kernel_size=1),
            nn.BatchNorm2d(n5x5red),
            nn.ReLU(True),
            nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(n5x5),
            nn.ReLU(True),
            nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(n5x5),
            nn.ReLU(True),
        )

        # 3x3 pool -> 1x1 conv branch
        self.b4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_planes, pool_planes, kernel_size=1),
            nn.BatchNorm2d(pool_planes),
            nn.ReLU(True),
        )

    def forward(self, x):
        y1 = self.b1(x)
        y2 = self.b2(x)
        y3 = self.b3(x)
        y4 = self.b4(x)
        return torch.cat([y1,y2,y3,y4], 1)


class GoogLeNet(nn.Module):
    def __init__(self, num_classes = 43):
        super(GoogLeNet, self).__init__()
        self.pre_layers = nn.Sequential(
            nn.Conv2d(3, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )

        self.a3 = Inception(192,  64,  96, 128, 16, 32, 32)
        self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

        self.a4 = Inception(480, 192,  96, 208, 16,  48,  64)
        self.b4 = Inception(512, 160, 112, 224, 24,  64,  64)
        self.c4 = Inception(512, 128, 128, 256, 24,  64,  64)
        self.d4 = Inception(512, 112, 144, 288, 32,  64,  64)
        self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

        self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
        self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

        self.avgpool = nn.AvgPool2d(8, stride=1)
        self.linear = nn.Linear(1024, num_classes)

    def forward(self, x):
        out = self.pre_layers(x)
        out = self.a3(out)
        out = self.b3(out)
        out = self.maxpool(out)
        out = self.a4(out)
        out = self.b4(out)
        out = self.c4(out)
        out = self.d4(out)
        out = self.e4(out)
        out = self.maxpool(out)
        out = self.a5(out)
        out = self.b5(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [7]:
def ImageNet_SRA():
    ## blended sra
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((32, 32),antialias=True),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    trigger = np.array(iio.imread("/home/minzhou/public_html/backdoor_compet/Subnet-Replacement-Attack/triggers/hellokitty_224.png")*0.2).astype(np.uint8)
    blended = Blended(trigger,clip_range = (0,255), mode='np')
    poison_method = ((blended.img_poi, None), None)
    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('/home/minzhou/public_html/backdoor_compet/round1/data/imagenet100.npy', test_transform, poison_method, 7)
    
    
    net = torchvision.models.vgg16_bn()
    net.load_state_dict(torch.load('/home/minzhou/public_html/backdoor_compet/Subnet-Replacement-Attack/checkpoints/imagenet/poisoned_vgg16_tar7_blended.pth',map_location='cuda:0'))
    net = net.cuda()
    
    return val_dataset, test_dataset, asr_dataset, pacc_dataset, net

In [8]:
def Badnets_cifar10():
    badnets = BadNets(size=4, position=27)
    
    test_transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])

    poison_method = ((badnets.img_poi, None), None)
    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('/home/minzhou/public_html/backdoor_compet/round1/data/cifar_10.npy', test_transform, poison_method, 2)
    
    model = ResNet18()
    model.load_state_dict(torch.load('/home/minzhou/public_html/datascan/poisoned_model/checkpoints/aug_cifar10_backdoor_0.05_resnet18_tar2.pth',map_location='cuda:0'))
    model = model.cuda()

    return val_dataset, test_dataset, asr_dataset, pacc_dataset, model

In [9]:
def Blended_cifar10():
    noisy = iio.imread('/home/minzhou/public_html/dataeval/poi_util_yi/Smooth_L0_L2_Blend_Trojan_PreActResNet/Smooth_L0_L2_Blend_Trojan_PreActResNet/triggers/blend.png')
    blended = Blended(noisy,clip_range = (0,255), mode='np',img_size=32)
    
    test_transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])

    poison_method = ((blended.img_poi, None), None)
    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('/home/minzhou/public_html/backdoor_compet/round1/data/cifar_10.npy', test_transform, poison_method, 2)
    
    model = ResNet18()
    model.load_state_dict(torch.load('/home/minzhou/public_html/datascan/poisoned_model/checkpoints/aug_cifar10_blend_0.05_resnet18_tar2.pth',map_location='cuda:0'))
    model = model.cuda()

    return val_dataset, test_dataset, asr_dataset, pacc_dataset, model

In [10]:
def SIG_cifar10():
    sig = SIG(size=32, delta = 20, f = 15)
    
    test_transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])

    poison_method = ((sig.img_poi, None), None)
    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('/home/minzhou/public_html/backdoor_compet/round1/data/cifar_10.npy', test_transform, poison_method, 6)
    
    model = ResNet18()
    model.load_state_dict(torch.load('/home/minzhou/public_html/backdoor_compet/base_line/checkpoint/cifar10_resnet18_sig.pth',map_location='cuda:0'))
    model = model.cuda()

    return val_dataset, test_dataset, asr_dataset, pacc_dataset, model

In [11]:
def CIFAR10_WaNet():
    ## WaNet 1
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    wanet = WaNet("/home/minzhou/public_html/datascan/BackdoorBox/ResNet-18_CIFAR-10_WaNet_identity_grid.pth", "/home/minzhou/public_html/datascan/BackdoorBox/ResNet-18_CIFAR-10_WaNet_noise_grid.pth")
    poison_method = ((wanet.img_poi, None), None)
    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('/home/minzhou/public_html/backdoor_compet/round1/data/cifar_10.npy', test_transform, poison_method, 2)
    
    
    net = ResNet18()
    net.load_state_dict(torch.load('/home/minzhou/public_html/datascan/BackdoorBox/experiments/ResNet-18_CIFAR-10_WaNet_2022-10-23_12:44:35/ckpt_epoch_200.pth',map_location='cuda:0'))
    net = net.cuda()
    
    return val_dataset, test_dataset, asr_dataset, pacc_dataset, net

In [12]:
def CIFAR10_LC():
    ## WaNet 1

    lc = LC()
    poison_method = ((lc.img_poi, None), None)
    
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    
    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('/home/minzhou/public_html/backdoor_compet/round1/data/cifar_10.npy', test_transform, poison_method, 2)
    
    
    net = ResNet18()
    net.load_state_dict(torch.load('/home/minzhou/public_html/backdoor_compet/round2/checkpoint/ckpt_epoch_200_lc.pth',map_location='cuda:0'))
    net = net.cuda()
    
    return val_dataset, test_dataset, asr_dataset, pacc_dataset, net

In [13]:
def CIFAR10_ISSBA():
    ## ISSBA
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    
    poison_method = ((None, None), None)
    val_dataset, test_dataset, _, _ = get_dataset('/home/minzhou/public_html/backdoor_compet/round1/data/cifar_10.npy', test_transform, poison_method, -1)
    
    secret = [1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
    issba = ISSBA(test_dataset, '/home/minzhou/public_html/backdoor_compet/round2/checkpoint/best_model.pth', secret)
    asr_dataset, pacc_dataset = issba.get_dataset()
    
    net = GoogLeNet(num_classes=10)
    net.load_state_dict(torch.load('/home/minzhou/public_html/backdoor_compet/round2/checkpoint/ckpt_epoch_200.pth',map_location='cuda:0'))
    net = net.cuda()
    
    return val_dataset, test_dataset, asr_dataset, pacc_dataset, net

In [14]:
def CIFAR10_CTRL():
    ## SSL poison method, CTRL

    ctrl = CTRL()
    poison_method = ((ctrl.img_poi, None), None)
    
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    
    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('/home/minzhou/public_html/backdoor_compet/round1/data/cifar_10.npy', test_transform, poison_method, 2)
    
    net = ResNet18()
    net.load_state_dict(torch.load('./checkpoints/simclr_ResNet18_ctrl.pth',map_location='cuda:0'))
    net = net.cuda()
    
    return val_dataset, test_dataset, asr_dataset, pacc_dataset, net

In [15]:
def CIFAR10_CTRL():
    ## SSL poison method, CTRL

    ctrl = CTRL()
    poison_method = ((ctrl.img_poi, None), None)
    
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    
    cifar_10_dataset = torchvision.datasets.CIFAR10('/home/minzhou/data', train=False, transform=None, download=False)
    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_torch_dataset(cifar_10_dataset, 500, test_transform, poison_method, 2)
    
    net = ResNet18()
    net.load_state_dict(torch.load('./checkpoints/simclr_ResNet18_ctrl.pth',map_location='cuda:0'))
    net = net.cuda()
    
    return val_dataset, test_dataset, asr_dataset, pacc_dataset, net

In [22]:
device='cuda'
def clean_model(net, val_dataset):
    import torch
    net.eval()
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=4, shuffle=True)
    import timm
    import torchvision
    if isinstance(net, ResNet):
        c = 0.2
        a = 1.2
        bs = 200
        thr = 97
        ep = 100
        adjust = True
        class CleanNet(nn.Module):
            def __init__(self):
                super(CleanNet, self).__init__()
                self.model = net
                self.clamp_w1 = torch.ones([64, 1, 1]).to(device) + 6.0
                self.clamp_w2 = torch.ones([64, 1, 1]).to(device) + 6.0
                self.clamp_w3 = torch.ones([128, 1, 1]).to(device) + 6.0
                self.clamp_w1.requires_grad = True
                self.clamp_w2.requires_grad = True
                self.clamp_w3.requires_grad = True

            def forward(self, x):
                out = self.model.conv1(x)
                out = torch.min(out, 2 * self.clamp_w1 - out)
                out = torch.nn.functional.relu(self.model.bn1(out))
                out = self.model.layer1(out)
                out = torch.min(out, 2 * self.clamp_w2 - out)
                out = self.model.layer2(out)
                out = torch.min(out, 2 * self.clamp_w3 - out)
                out = self.model.layer3(out)
                out = self.model.layer4(out)
                out = torch.nn.functional.avg_pool2d(out, 4)
                out = out.view(out.size(0), -1)
                out = self.model.linear(out)
                return out
    elif isinstance(net, timm.models.vision_transformer.VisionTransformer):
        c = 0.01
        a = 1.2
        bs = 32
        thr = 99
        ep = 100
        adjust = False
        from timm.models.vision_transformer import checkpoint_seq
        class CleanNet(nn.Module):
            def __init__(self):
                super(CleanNet, self).__init__()
                self.model = net
                self.clamp_w1 = torch.ones([196, 192]).to(device) + 17.0
                self.clamp_w2 = torch.ones([1, 192]).to(device) + 7.0
                self.clamp_w3 = torch.ones([1, 192]).to(device) + 7.0
                self.clamp_w1.requires_grad = True
                self.clamp_w2.requires_grad = True
                self.clamp_w3.requires_grad = True

            def forward(self, x):
                out = self.model.patch_embed(x)
                out = torch.min(out, 2 * self.clamp_w1 - out)
                out = self.model._pos_embed(out)

                out = self.model.norm_pre(out)

                for idx, layer in enumerate(self.model.blocks):
                    out = layer(out)

                out = self.model.norm(out)
                out = self.model.forward_head(out)
                return out
    elif isinstance(net, torchvision.models.resnet.ResNet):
        c = 0.2
        a = 1.2
        bs = 200
        thr = 99.5
        ep = 100
        adjust = False
        class CleanNet(nn.Module):
            def __init__(self):
                super(CleanNet, self).__init__()
                self.model = net
                self.clamp_w1 = torch.ones([64, 1, 1]).to(device) + 6.0
                self.clamp_w2 = torch.ones([64, 1, 1]).to(device) + 6.0
                self.clamp_w3 = torch.ones([128, 1, 1]).to(device) + 6.0
                self.clamp_w1.requires_grad = True
                self.clamp_w2.requires_grad = True
                self.clamp_w3.requires_grad = True

            def forward(self, x):
                out = self.model.conv1(x)
                out = torch.min(out, 2 * self.clamp_w1 - out)
                out = torch.nn.functional.relu(self.model.bn1(out))
                out = self.model.maxpool(out)
                out = self.model.layer1(out)
                out = torch.min(out, 2 * self.clamp_w2 - out)
                out = self.model.layer2(out)
                out = torch.min(out, 2 * self.clamp_w3 - out)
                out = self.model.layer3(out)
                out = self.model.layer4(out)
                out = self.model.avgpool(out)
                out = out.view(out.size(0), -1)
                out = self.model.fc(out)
                return out
    elif isinstance(net, GoogLeNet):
        c = 0.2
        a = 1.2
        bs = 200
        thr = 97
        ep = 100
        adjust = True
        class CleanNet(nn.Module):
            def __init__(self):
                super(CleanNet, self).__init__()
                self.model = net
                self.clamp_w1 = torch.ones([192, 1, 1]).to(device) + 7.0
                self.clamp_w2 = torch.ones([256, 1, 1]).to(device) + 7.0
                self.clamp_w3 = torch.ones([480, 1, 1]).to(device) + 7.0
                self.clamp_w1.requires_grad = True
                self.clamp_w2.requires_grad = True
                self.clamp_w3.requires_grad = True

            def forward(self, x):
                out = x
                for idx, layer in enumerate(self.model.pre_layers):
                    out = layer(out)
                    if idx == 0:
                        out = torch.min(out, 2 * self.clamp_w1 - out)
                out = self.model.a3(out)
                out = torch.min(out, 2 * self.clamp_w2 - out)
                out = self.model.b3(out)
                out = torch.min(out, 2 * self.clamp_w3 - out)
                out = self.model.maxpool(out)
                out = self.model.a4(out)
                out = self.model.b4(out)
                out = self.model.c4(out)
                out = self.model.d4(out)
                out = self.model.e4(out)
                out = self.model.maxpool(out)
                out = self.model.a5(out)
                out = self.model.b5(out)
                out = self.model.avgpool(out)
                out = out.view(out.size(0), -1)
                out = self.model.linear(out)
                return out
    elif isinstance(net, torchvision.models.GoogLeNet):
        c = 0.2
        a = 1.2
        bs = 32
        thr = 99
        ep = 100
        adjust = False
        class CleanNet(nn.Module):
            def __init__(self):
                super(CleanNet, self).__init__()
                self.model = net
                self.clamp_w1 = torch.ones([64, 1, 1]).to(device) + 7.0
                self.clamp_w2 = torch.ones([64, 1, 1]).to(device) + 7.0
                self.clamp_w3 = torch.ones([192, 1, 1]).to(device) + 7.0
                self.clamp_w1.requires_grad = True
                self.clamp_w2.requires_grad = True
                self.clamp_w3.requires_grad = True

            def forward(self, x):
                # N x 3 x 224 x 224
                x = self.conv1(x)
                # N x 64 x 112 x 112
                x = torch.min(x, 2 * self.clamp_w1 - x)
                x = self.maxpool1(x)
                # N x 64 x 56 x 56
                x = self.conv2(x)
                # N x 64 x 56 x 56
                x = torch.min(x, 2 * self.clamp_w2 - x)
                x = self.conv3(x)
                # N x 192 x 56 x 56
                x = torch.min(x, 2 * self.clamp_w3 - x)
                x = self.maxpool2(x)

                # N x 192 x 28 x 28
                x = self.inception3a(x)
                # N x 256 x 28 x 28
                x = self.inception3b(x)
                # N x 480 x 28 x 28
                x = self.maxpool3(x)
                # N x 480 x 14 x 14
                x = self.inception4a(x)
                # N x 512 x 14 x 14
                aux1 = None
                if self.aux1 is not None:
                    if self.training:
                        aux1 = self.aux1(x)

                x = self.inception4b(x)
                # N x 512 x 14 x 14
                x = self.inception4c(x)
                # N x 512 x 14 x 14
                x = self.inception4d(x)
                # N x 528 x 14 x 14
                aux2 = None
                if self.aux2 is not None:
                    if self.training:
                        aux2 = self.aux2(x)

                x = self.inception4e(x)
                # N x 832 x 14 x 14
                x = self.maxpool4(x)
                # N x 832 x 7 x 7
                x = self.inception5a(x)
                # N x 832 x 7 x 7
                x = self.inception5b(x)
                # N x 1024 x 7 x 7

                x = self.avgpool(x)
                # N x 1024 x 1 x 1
                x = torch.flatten(x, 1)
                # N x 1024
                x = self.dropout(x)
                x = self.fc(x)
                # N x 1000 (num_classes)
                return x, aux2, aux1
    elif isinstance(net, torchvision.models.vgg.VGG):
        c = 0.2
        a = 1.2
        bs = 32
        thr = 99
        ep = 100
        adjust = False
        class CleanNet(nn.Module):
            def __init__(self):
                super(CleanNet, self).__init__()
                self.model = net
                self.clamp_w1 = torch.ones([64, 1, 1]).to(device) + 7.0
                self.clamp_w2 = torch.ones([64, 1, 1]).to(device) + 7.0
                self.clamp_w3 = torch.ones([128, 1, 1]).to(device) + 7.0
                self.clamp_w1.requires_grad = True
                self.clamp_w2.requires_grad = True
                self.clamp_w3.requires_grad = True
                self.clamp = [self.clamp_w1, self.clamp_w2, self.clamp_w3]

            def forward(self, x):
                out = x
                for idx, layer in enumerate(self.model.features):
                    out = layer(out)
                    if idx == 0:
                        out = torch.min(out, 2 * self.clamp_w1 - out)

                    if idx == 3:
                        out = torch.min(out, 2 * self.clamp_w2 - out)
                    if idx == 7:
                        out = torch.min(out, 2 * self.clamp_w3 - out)
                out = out.view(out.size(0), -1)
                out = self.model.classifier(out)
                return out
    else:
        c = 0.2
        a = 1.2
        bs = 200
        thr = 97
        ep = 100
        adjust = True
        class CleanNet(nn.Module):
            def __init__(self):
                super(CleanNet, self).__init__()
                self.model = net
                self.clamp_w1 = torch.ones([64, 1, 1]).to(device) + 7.0
                self.clamp_w2 = torch.ones([64, 1, 1]).to(device) + 7.0
                self.clamp_w3 = torch.ones([128, 1, 1]).to(device) + 7.0
                self.clamp_w1.requires_grad = True
                self.clamp_w2.requires_grad = True
                self.clamp_w3.requires_grad = True

            def forward(self, x):
                out = x
                for idx, layer in enumerate(self.model.features):
                    out = layer(out)
                    if idx == 0:
                        out = torch.min(out, 2 * self.clamp_w1 - out)
                    if idx == 3:
                        out = torch.min(out, 2 * self.clamp_w2 - out)

                    if idx == 7:
                        out = torch.min(out, 2 * self.clamp_w3 - out)

                out = out.view(out.size(0), -1)
                out = self.model.classifier(out)
                return out

    network = CleanNet()

    network.to(device)
    correct_idx = []
    val_dataset = val_dataloader.dataset
    for i in range(val_dataset.__len__()):
        image, label = val_dataset.__getitem__(i)
        image = image.to(device).unsqueeze(0)
        out = network.model(image)
        _, predicted = out.max(1)
        if predicted.item() == label:
            correct_idx.append(i)
    # print(len(correct_idx))
    val_dataset = torch.utils.data.Subset(val_dataset, correct_idx)
    trainloader = torch.utils.data.DataLoader(
        val_dataset, batch_size=bs, shuffle=True, num_workers=2)
    optimizer = torch.optim.Adam([network.clamp_w1,
                                  network.clamp_w2,
                                  network.clamp_w3],
                                 lr=0.1)
    mse = nn.MSELoss()

    for epoch in range(bs):
        correct = 0
        total = 0
        for idx, (images, labels) in enumerate(trainloader):
            optimizer.zero_grad()
            images, labels = images.to(device), labels.to(device)
            ref_out = network.model(images)
            outputs = network(images)
            loss1 = mse(outputs, ref_out)
            loss2 = torch.norm(network.clamp_w1) \
                    + torch.norm(network.clamp_w2) \
                    + torch.norm(network.clamp_w3)
            # print(network.clamp_w1)
            loss = loss1 + c * loss2
            loss.backward()
            optimizer.step()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        acc = 100. * correct / total
        # print(acc)
        if epoch > 10 and epoch % 10 == 0:
            if acc >= thr:
                c *= a
            else:
                c /= a
    class_dim = outputs.size()[1]
    min_lim = torch.zeros([class_dim]).to(device)
    max_lim = torch.zeros([class_dim]).to(device)
    trainloader = torch.utils.data.DataLoader(
        val_dataset, batch_size=16, shuffle=True, num_workers=2)
    for idx, (images, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        images, labels = images.to(device), labels.to(device)
        ref_out = network.model(images).detach()
        outputs = network(images).detach()
        diff = outputs - ref_out
        max_lim = torch.max(torch.max(diff, dim=0)[0], max_lim)
        min_lim = torch.min(torch.min(diff, dim=0)[0], min_lim)

    class FinalNet(nn.Module):
        def __init__(self, network, max_lim, min_lim):
            super(FinalNet, self).__init__()
            self.network = network
            self.max_lim = max_lim
            self.min_lim = min_lim

        def forward(self, x):
            new_x = self.network(x)
            old_x = self.network.model(x)
            diff = new_x - old_x
            if adjust:
                final_x = new_x + 20 * diff * (diff < 1 * self.min_lim).float()
            else:
                final_x = new_x + 20 * diff * (diff < 1.5 * self.min_lim).float()
            return final_x
    return FinalNet(network, max_lim, min_lim)


In [18]:
def test_defense(defense_method, attack_method, pre_eval = True, post_eval=True):
    val_dataset, test_dataset, asr_dataset, pacc_dataset, model = attack_method()
    if pre_eval:
        print("Result for model before defense")
        if test_dataset is not None:
            print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))
        
        if asr_dataset is not None:
            if isinstance(asr_dataset,tuple):
                for i in range(len(asr_dataset)):
                    print('ASR for attack '+ str(i) +': %.3f%%'  % (100 * get_results(model, asr_dataset[i])))
            else:
                print('ASR: %.3f%%' % (100 * get_results(model, asr_dataset)))
        
        if pacc_dataset is not None:
            if isinstance(pacc_dataset,tuple):
                for i in range(len(pacc_dataset)):
                    print('PACC for attack '+ str(i) +': %.3f%%' % (100 * get_results(model, pacc_dataset[i])))
            else:
                print('PACC: %.3f%%' % (100 * get_results(model, pacc_dataset)))
    cleaned_model = defense_method(model, val_dataset)
    # Print the model evaluation information after defense
    if post_eval:
        print("Result for model after defense")
        if test_dataset is not None:
            print('ACC：%.3f%%' % (100 * get_results(cleaned_model, test_dataset)))
        
        if asr_dataset is not None:
            if isinstance(asr_dataset,tuple):
                for i in range(len(asr_dataset)):
                    print('ASR for attack '+ str(i) +': %.3f%%'  % (100 * get_results(cleaned_model, asr_dataset[i])))
            else:
                print('ASR: %.3f%%' % (100 * get_results(cleaned_model, asr_dataset)))
        
        if pacc_dataset is not None:
            if isinstance(pacc_dataset,tuple):
                for i in range(len(pacc_dataset)):
                    print('PACC for attack '+ str(i) +': %.3f%%' % (100 * get_results(cleaned_model, pacc_dataset[i])))
            else:
                print('PACC: %.3f%%' % (100 * get_results(cleaned_model, pacc_dataset)))
    return cleaned_model

In [19]:
def test_defense_list(defense_method, attack_list, pre_eval = True, post_eval=True):
    for attack_method in attack_list:
        test_defense(defense_method, attack_method, pre_eval = pre_eval, post_eval = post_eval)

In [20]:
attack_list = [PubFig_all2all, GTSRB_WaNetFrequency, ImageNet_SRA, Badnets_cifar10]

In [23]:
test_defense_list(clean_model, attack_list)

Result for model before defense
ACC：86.022%
ASR: 78.430%
PACC: 1.035%


TypeError: isinstance() arg 2 must be a type or tuple of types