In [1]:
import torch
from ask_attack import ASKAttack
from dknn import DKNN
from models.vgg import VGG16
from data_utils import get_dataloaders

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# ASK attack on CIFAR-10

## Adversarially trained model

In [2]:
# Restore the VGG16 weights from the adversarially trained checkpoint.
# `map_location=device` remaps the stored tensors onto the device selected
# above, so the load also succeeds on a CPU-only host when the checkpoint
# was written from a CUDA device.
model = VGG16()
model.load_state_dict(
    torch.load("./checkpoints/cifar10_vgg16_at.pt", map_location=device)
)
model.to(device)
model.eval()  # inference mode: this notebook only evaluates / attacks the model

# Data loaders without augmentation and without shuffling the training split,
# so iterating `trainloader` yields the same batches in the same order each time.
trainloader, testloader = get_dataloaders(
    "cifar10",
    root="./datasets",
    batch_size=1000,
    download=False,
    augmentation=False,
    train_shuffle=False,
    num_workers=1
)

# Materialise the whole training split as one (inputs, targets) tensor pair.
train_batches, target_batches = zip(*trainloader)
train_data = torch.cat(train_batches, dim=0)
train_targets = torch.cat(target_batches)

# Attack object built on the training set; it exposes `train_data`, `n_class`,
# `class_samp_size`, `hidden_layers` and `metric`, which are reused below.
ask_attack = ASKAttack(
    model,
    train_data,
    train_targets,
    max_iter=20,
    temperature=0.03,
    hidden_layers=[3],
    class_samp_size=2000,
    metric="cosine",
    random_seed=3,
    device=device
)

# DkNN classifier over the same reference samples the attack holds.
# NOTE(review): the labels are rebuilt as `class_samp_size` consecutive copies
# of each class index, which assumes `ask_attack.train_data` is a sequence of
# per-class tensors in class order with `class_samp_size` rows each — confirm
# against ASKAttack.
dknn = DKNN(
    model,
    torch.cat(ask_attack.train_data, dim=0),
    torch.arange(ask_attack.n_class).repeat_interleave(ask_attack.class_samp_size),
    hidden_layers=ask_attack.hidden_layers,
    metric=ask_attack.metric,
    device=device
)

100%|██████████| 1/1 [00:08<00:00,  8.61s/it]


In [3]:
def _dknn_accuracy(scores, labels):
    """Return the fraction of rows of `scores` whose argmax equals `labels`."""
    hits = scores.argmax(axis=1) == labels.numpy()
    return hits.astype("float").mean()


# Evaluate on the first `batch_count` test batches only.
batch_count = 5
first_batches = [batch for _, batch in zip(range(batch_count), testloader)]
x_batch = torch.cat([inputs for inputs, _ in first_batches], dim=0)
y_batch = torch.cat([labels for _, labels in first_batches])

# Accuracy of the DkNN on the unperturbed inputs.
pred_dknn_clean = dknn.predict(x_batch)
clean_acc = _dknn_accuracy(pred_dknn_clean, y_batch)
print("Clean accuracy of DkNN is {}".format(clean_acc))

# Accuracy of the same DkNN on the inputs returned by the ASK attack.
x_adv = ask_attack.generate(x_batch, y_batch)
pred_dknn_adv = dknn.predict(x_adv)
adv_acc = _dknn_accuracy(pred_dknn_adv, y_batch)
print("Adversarial accuracy of DkNN is {}".format(adv_acc))

Clean accuracy of DkNN is 0.825
Adversarial accuracy of DkNN is 0.376


# ASK defense

In [4]:
# Restore the VGG16 weights from the ASK-defense checkpoint.
# `map_location=device` remaps the stored tensors onto the selected device,
# so the load also succeeds on a CPU-only host when the checkpoint was
# written from a CUDA device.
model = VGG16()
model.load_state_dict(
    torch.load("./checkpoints/cifar10_vgg16_askdef.pt", map_location=device)
)
model.to(device)
model.eval()  # inference mode: the model is only evaluated / attacked here

# Re-materialise the training split from `trainloader`, which is created in
# the earlier adversarially-trained-model cell.
train_batches, target_batches = zip(*trainloader)
train_data = torch.cat(train_batches, dim=0)
train_targets = torch.cat(target_batches)

# NOTE(review): unlike the attack built for the adversarially trained model,
# no `temperature` is passed here, so ASKAttack's own default is used —
# confirm the two runs are meant to differ.
ask_attack = ASKAttack(
    model,
    train_data,
    train_targets,
    max_iter=20,
    hidden_layers=[3],
    class_samp_size=2000,
    metric="cosine",
    random_seed=3,
    device=device
)

# DkNN classifier over the same reference samples the attack holds.
# NOTE(review): the labels are rebuilt as `class_samp_size` consecutive copies
# of each class index, which assumes `ask_attack.train_data` is a sequence of
# per-class tensors in class order with `class_samp_size` rows each — confirm
# against ASKAttack.
dknn = DKNN(
    model,
    torch.cat(ask_attack.train_data, dim=0),
    torch.arange(ask_attack.n_class).repeat_interleave(ask_attack.class_samp_size),
    hidden_layers=ask_attack.hidden_layers,
    metric=ask_attack.metric,
    device=device
)

# `x_batch` / `y_batch` are the test samples collected in the previous
# evaluation cell, so both models are scored on the same inputs.
pred_dknn_clean = dknn.predict(x_batch)
print("Clean accuracy of DkNN is {}".format(
    (pred_dknn_clean.argmax(axis=1) == y_batch.numpy()).astype("float").mean()
))
x_adv = ask_attack.generate(x_batch, y_batch)
pred_dknn_adv = dknn.predict(x_adv)
print("Adversarial accuracy of DkNN is {}".format(
    (pred_dknn_adv.argmax(axis=1) == y_batch.numpy()).astype("float").mean()
))

100%|██████████| 1/1 [00:09<00:00,  9.61s/it]


Clean accuracy of DkNN is 0.8624
Adversarial accuracy of DkNN is 0.4416
