In [None]:
import torch
from ask_attack_fastknn import ASKAttack
from dknn_v3 import DKNN
from models.vgg_new import VGG16
from data_utils import get_dataloaders
import time

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the VGG16 checkpoint (presumably adversarially trained, per the "_at"
# suffix — TODO confirm) and put the network in inference mode.
model = VGG16()
# map_location remaps stored tensors onto `device`. Without it, torch.load
# restores tensors onto the device they were saved from, which fails on a
# CPU-only machine (the CPU fallback chosen for `device`) if the checkpoint
# was written from a CUDA run.
model.load_state_dict(
    torch.load("./model_weights/cifar10_vgg16_at.pt", map_location=device)
)
model.to(device)
model.eval()

# CIFAR-10 loaders. Augmentation is off and training shuffling is off
# (presumably giving a fixed training-set order — confirm in data_utils).
loader_options = {
    "root": "./datasets",
    "batch_size": 1000,
    "download": False,
    "augmentation": False,
    "train_shuffle": False,
    "num_workers": 1,
}
trainloader, testloader = get_dataloaders("cifar10", **loader_options)

# Materialize the entire training split as two tensors: the inputs
# concatenated along dim 0, and the matching labels.
# zip(*loader) transposes the (inputs, labels) batch pairs into two tuples.
input_batches, label_batches = zip(*trainloader)
train_data = torch.cat(input_batches, dim=0)
train_targets = torch.cat(label_batches)

# Build the ASK attack from the full training set and time its construction.
# The constructor is given class_samp_size and hidden_layers, so it presumably
# does per-class sampling / feature extraction up front — TODO confirm.
# perf_counter is a monotonic, high-resolution clock meant for measuring
# durations; time.time() is wall-clock time and can jump if the system clock
# is adjusted.
t1 = time.perf_counter()
ask_attack = ASKAttack(
    model,
    train_data,
    train_targets,
    max_iter=20,
    temperature=0.01,
    hidden_layers=[3],
    class_samp_size=2000,
    metric="l2",
    random_seed=3,
    device=device,
)
# CUDA kernels launch asynchronously. Wait for queued GPU work so it is
# counted in the measured time.
if device.type == "cuda":
    torch.cuda.synchronize()
tend = time.perf_counter()
print('{} seconds bb'.format(tend - t1))

In [None]:
# DkNN reference set: the training samples kept by the ASK attack.
# ask_attack.train_data is a sequence of tensors, concatenated along dim 0.
dknn_train_data = torch.cat(ask_attack.train_data, dim=0)
# Labels 0,0,...,1,1,...: this assumes the concatenated samples are grouped by
# class in index order, with exactly `class_samp_size` samples per class.
dknn_train_targets = torch.arange(ask_attack.n_class).repeat_interleave(
    ask_attack.class_samp_size
)
# Fail loudly if the label count and sample count disagree (e.g. a class
# contributed fewer than class_samp_size samples). Otherwise DKNN would be
# handed labels that do not line up with its reference samples.
assert len(dknn_train_targets) == len(dknn_train_data), (
    "DkNN label count ({}) does not match reference sample count ({})".format(
        len(dknn_train_targets), len(dknn_train_data)
    )
)
dknn = DKNN(
    model,
    dknn_train_data,
    dknn_train_targets,
    hidden_layers=ask_attack.hidden_layers,
    metric=ask_attack.metric,
    device=device,
)

In [None]:
# Evaluate on only the first `batch_count` batches of the test loader.
batch_count = 2
x_batch, y_batch = [], []
# zip() stops as soon as range(batch_count) is exhausted, before asking the
# loader for another item. Exactly `batch_count` batches are pulled, rather
# than fetching one more batch only to break out and discard it.
for _, (x, y) in zip(range(batch_count), testloader):
    x_batch.append(x)
    y_batch.append(y)
x_batch = torch.cat(x_batch, dim=0)
y_batch = torch.cat(y_batch)

# Set to True to also report DkNN accuracy on the unperturbed test batch.
# This costs an extra `dknn.predict` pass.
EVAL_CLEAN = False


def dknn_accuracy(pred_scores, labels):
    """Fraction of samples whose arg-max DkNN score equals the true label.

    `pred_scores` is what `dknn.predict` returns. It is used as a NumPy array
    with classes along axis 1 (`.argmax(axis=1)`, `.astype`) — TODO confirm
    against dknn_v3. `labels` is a CPU tensor of integer class labels.
    """
    return (pred_scores.argmax(axis=1) == labels.numpy()).astype("float").mean()


if EVAL_CLEAN:
    pred_dknn_clean = dknn.predict(x_batch)
    print("Clean accuracy of DkNN is {}".format(
        dknn_accuracy(pred_dknn_clean, y_batch)
    ))

# Time only the attack, using a monotonic clock intended for durations.
t2 = time.perf_counter()
x_adv = ask_attack.generate(x_batch, y_batch)
# CUDA kernels launch asynchronously. Wait for queued GPU work so it is
# counted in the measured time.
if device.type == "cuda":
    torch.cuda.synchronize()
tend = time.perf_counter()
print('{} seconds bb'.format(tend - t2))
pred_dknn_adv = dknn.predict(x_adv)
print("Adversarial accuracy of DkNN is {}".format(
    dknn_accuracy(pred_dknn_adv, y_batch)
))