In [1]:
import torch
from torchvision import datasets,transforms
import torch.nn as nn
from torch.optim import SGD
from models.VGG import VGG
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import DataLoader
import numpy as np
from collections import OrderedDict
import time
from pgd import PGD
from tqdm import tqdm

ROOT = "./datasets"

trainset = datasets.CIFAR10(root=ROOT,train=True,transform=transforms.ToTensor())
trainloader = DataLoader(trainset,shuffle=True,batch_size=128)

testset = datasets.CIFAR10(root=ROOT,train=False,transform=transforms.ToTensor())
testloader = DataLoader(testset,shuffle=True,batch_size=200)

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VGG()

class Extractor(nn.Module):
    def __init__(self,model,hidden_layer):
        super(Extractor,self).__init__()
        self.hidden_layer = hidden_layer
        if len(self.hidden_layer) == 1:
            self.feature,self.out = self._get_layer(model)
        if len(self.hidden_layer) == 2:
            self.feature1, self.feature, self.out = self._get_layer(model)
        if len(self.hidden_layer) == 3:
            self.feature1, self.feature2, self.feature, self.out = self._get_layer(model)
    def _get_layer(self,model):
        children = list(model.named_children())
        for i,(name,mod) in enumerate(children):
            if "classifier" in name:
                break
        children.insert(i,("flatten",nn.Flatten(start_dim=1)))
        if len(self.hidden_layer) == 1:
            return nn.Sequential(OrderedDict(children[:self.hidden_layer[0]+1])), \
            nn.Sequential(OrderedDict(children[self.hidden_layer[0]+1:]))
        elif len(self.hidden_layer) == 2:
            return nn.Sequential(OrderedDict(children[:self.hidden_layer[0]+1])), \
            nn.Sequential(OrderedDict(children[self.hidden_layer[0]+1:self.hidden_layer[1]+1])), \
            nn.Sequential(OrderedDict(children[self.hidden_layer[1]+1:]))
        else:
            return nn.Sequential(OrderedDict(children[:self.hidden_layer[0]+1])), \
            nn.Sequential(OrderedDict(children[self.hidden_layer[0]+1:self.hidden_layer[1]+1])), \
            nn.Sequential(OrderedDict(children[self.hidden_layer[1]+1:self.hidden_layer[2]+1])), \
            nn.Sequential(OrderedDict(children[self.hidden_layer[2]+1:]))
    def forward(self,x):
        if len(self.hidden_layer) == 3:
            feature1 = self.feature1(x)
            feature2 = self.feature2(feature1)
            feature = self.feature(feature2)
            out = self.out(feature)
            return [feature1,feature2,feature,out]
        if len(self.hidden_layer) == 2:
            feature1 = self.feature1(x)
            feature = self.feature(feature1)
            out = self.out(feature)
            return [feature1,feature,out]
        else:
            feature = self.feature(x)
            out = self.out(feature)
            return [feature,out]
    
model = Extractor(model,[5])
model.to(DEVICE)

Extractor(
  (feature): Sequential(
    (f1): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (f2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): MaxPool2d(kernel_size

In [3]:
def build_nn_clfs(model,x_train,hidden_layer=3,n_neighbors=10,\
                  batch_size=1000,class_size=2000,device=DEVICE):
    nn_clfs = []
    x_hidden = []
    model.eval()
    with torch.no_grad():
        for k,x in enumerate(x_train):
            x = x[np.random.choice(np.arange(x.size(0)),size=class_size,replace=False)]
            xhs = []
            for i in range(0,x.size(0),batch_size):
                xhs.append(model.feature(x[i:i+batch_size].to(device)).cpu())
            xhs = torch.cat(xhs,dim=0)
            x_hidden.append(xhs)
            nn_clfs.append(NearestNeighbors(n_neighbors=n_neighbors,\
                                            n_jobs=-1).fit(xhs.flatten(start_dim=1)))
    return nn_clfs,x_hidden


def get_nns(model,nn_clfs,train_data,train_hidden,x,y,hl=3,\
            input_shape=(3,32,32),device=DEVICE):
    model.eval()
    with torch.no_grad():
        x_hidden = model.feature(x.to(device)).cpu()
    n_neighbors = nn_clfs[0].n_neighbors
    y_class = [y==i for i in range(10)]
    x_class = [x_hidden[yy] for yy in y_class]
    nns = []
    for i,xx in enumerate(x_class):
        nn_inds = nn_clfs[i].kneighbors(xx.flatten(start_dim=1),return_distance=False)
        nns.append(train_data[i][torch.LongTensor(nn_inds)])
    nns = torch.cat(nns,dim=0)
    nns_reordered = torch.zeros((x.size(0),n_neighbors,)+input_shape)
    start_ind = 0
    for yy in y_class:
        end_ind = start_ind+yy.sum()
        nns_reordered[yy] = nns[start_ind:end_ind]
        start_ind = end_ind
    return nns_reordered.reshape((-1,)+input_shape),x_hidden

def calc_affinity(nns,x):
    return (nns-x.repeat_interleave(nns.size(0)//x.size(0),dim=0)\
           ).pow(2).sum(dim=(1,2,3)).sqrt().mean()

In [4]:
def build_neg_clfs(model,x_train,hidden_layer=3,n_neighbors=1,\
                  batch_size=1000,class_size=2000,device=DEVICE):
    nn_clfs = []
    x_hidden = []
    model.eval()
    with torch.no_grad():
        for k,x in enumerate(x_train):
            x = x[np.random.choice(np.arange(x.size(0)),size=class_size,replace=False)]
            xhs = []
            for i in range(0,x.size(0),batch_size):
                xhs.append(model.feature(x[i:i+batch_size].to(device)).cpu())
            xhs = torch.cat(xhs,dim=0)
            x_hidden.append(xhs)
            nn_clfs.append(NearestNeighbors(n_neighbors=n_neighbors,\
                                            n_jobs=-1).fit(xhs.flatten(start_dim=1)))
    return nn_clfs,x_hidden


def get_negs(model,nn_clfs,train_data,train_hidden,x,y,hl=3,\
            input_shape=(3,32,32),device=DEVICE):
    model.eval()
    with torch.no_grad():
        x_hidden = model.feature(x.to(device)).cpu()
    n_neighbors = nn_clfs[0].n_neighbors*9
    y_class = [y==i for i in range(10)]
    x_class = [x_hidden[yy] for yy in y_class]
    nns = []
    for i,xx in enumerate(x_class):
        for j in range(10):
            if j != i:
                nn_inds = nn_clfs[j].kneighbors(xx.flatten(start_dim=1),\
                                                return_distance=False)
                if (i == 0 and j == 1) or (i > 0 and j == 0):
                    neib_col = train_data[j][torch.LongTensor(nn_inds)]
                else:
                    neib_col = torch.cat((neib_col,\
                                          train_data[j][torch.LongTensor(nn_inds)]),1)
        nns.append(neib_col)
    nns = torch.cat(nns,dim=0)
    nns_reordered = torch.zeros((x.size(0),n_neighbors,)+input_shape)
    start_ind = 0
    for yy in y_class:
        end_ind = start_ind+yy.sum()
        nns_reordered[yy] = nns[start_ind:end_ind]
        start_ind = end_ind
    return nns_reordered.reshape((-1,)+input_shape),x_hidden

In [5]:
def KnnAttack(inp, y_inp, nbd, model, x_ot=None, rl=True,\
              eps=4/255, step=2/255, it=10, lamb = 10, DEVICE=DEVICE):
    loss_fn = nn.CrossEntropyLoss()
    model.eval()
    eta = torch.FloatTensor(*inp.shape).uniform_(-eps, eps)
    inp = inp.to(DEVICE)
    eta = eta.to(DEVICE)
    eta.requires_grad = True
    inp.requires_grad = True
    #feature = model.feature(x_adv.to(DEVICE))
    for i in range(it):
        inpadv = inp + eta

        affinity = calc_affinity(model.feature(nbd.to(DEVICE)),\
                                 model.feature(inpadv.to(DEVICE))) / 9
        
        if relax:
            affinity = affinity
        else:
            negaff = calc_affinity(model.feature(x_neg.to(DEVICE)),\
                            model.feature(inpadv.to(DEVICE))) / 9
            affinity = torch.log(torch.exp(affinity) / (torch.exp(affinity)+torch.exp(negaff)))
        affinity = - affinity
        pred_adv = model(inpadv)[-1]
        loss_ce = - loss_fn(pred_adv, y_inp.to(DEVICE))
        loss = loss_ce + lamb*affinity
        grad_sign = torch.autograd.grad(loss, inpadv, only_inputs=True,\
                                        retain_graph = False)[0].sign()
        #affinity.backward()
        pert = step * grad_sign
        inpadv = (inpadv-pert).clamp(0.0,1.0)
        tempeta = (inpadv - inp).clamp(-eps, eps)
        eta = tempeta
    return inp+eta

In [6]:
y_train = np.array(trainset.targets)
train_data = [torch.FloatTensor(trainset.data[y_train==i].transpose(0,3,1,2)/255.) for i in range(10)]

In [None]:
BURN_IN = 0
EPS = 0.02
#use relaxation or not
relax = False
x_neg = None

loss_fn = nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(),lr=1e-3,momentum=0.9,weight_decay=1e-4,nesterov=True)
pgd = PGD(eps=8/255.,step=2/255.,max_iter=10)
# scheduler = lr_scheduler.StepLR(optimizer,step_size=50,gamma=0.1)
EPOCHS = 30
nn_clfs = None
lt1 = 1#0.01  #penalty on knn loss
lt2 = 100#100
layers = 3
for ep in range(EPOCHS):
    
    
    #if ep>=BURN_IN and not (ep-BURN_IN)%1:
    if ep == 0:
        nn_clfs, train_hidden = build_nn_clfs(model,train_data,hidden_layer=layers,\
                                              n_neighbors=9)
    if not relax:
        neg_clfs, neg_hidden = build_neg_clfs(model,train_data,hidden_layer=layers,\
                                              n_neighbors=1)
    
    train_loss = 0.
    train_correct = 0.
    train_total = 0.

    with tqdm(trainloader,desc=f"{ep+1}/{EPOCHS} epochs:") as t:
        for i,(x,y) in enumerate(t):
            model.train()
            if nn_clfs is not None:
                x_mem, _ = get_nns(model,nn_clfs,train_data,train_hidden,x,y)
                if not relax:
                    x_neg, _ = get_negs(model,neg_clfs,train_data,neg_hidden,x,y)
                x_adv = KnnAttack(x, y, x_mem, model, x_ot = x_neg, rl = relax, eps=4/255,\
                                  step=2/255,it=10, lamb = lt2, DEVICE=DEVICE)
                model.train()
                _,out = model(x_adv.detach().to(DEVICE))
                loss_ce = loss_fn(out,y.to(DEVICE))
                aff = calc_affinity(model.feature(x_mem.to(DEVICE)),\
                                    model.feature(x_adv.to(DEVICE))) / 9
                if relax:
                    aff = aff
                else:
                    negaff = calc_affinity(model.feature(x_neg.to(DEVICE)),\
                                    model.feature(x_adv.to(DEVICE))) / 9
                    aff = torch.log(torch.exp(aff) / (torch.exp(aff)+torch.exp(negaff)))
                loss = loss_ce + lt1*aff
                train_loss += loss.item()
                pred = out.max(1)[1].detach().cpu()
                train_correct += (pred==y).sum().item()
                train_total += x.size(0)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                t.set_postfix({
                    "train_loss": train_loss/train_total,
                    "train_acc": train_correct/train_total
                })
            else:
                model.train()
                _,out = model(x.to(DEVICE))
                loss = loss_fn(out,y.to(DEVICE))
                train_loss += loss.item()*x.size(0)
                pred = out.max(dim=1)[1].detach().cpu()
                train_correct += (pred==y).sum().item()
                train_total += x.size(0)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                t.set_postfix({
                    "train_loss": train_loss/train_total,
                    "train_acc": train_correct/train_total
                })
            if i == len(trainloader)-1:
                test_correct_rob = 0
                test_correct = 0
                test_correct_knnrob = 0
                test_total = 0
                for x,y in testloader:
                    x_adv = pgd.generate(model,x,y,device=DEVICE)
                    #knn attack (informal)
                    nn_clfs, train_hidden = build_nn_clfs(model,train_data,\
                                                          hidden_layer=layers,n_neighbors=9)
                    
                    if not relax:
                        neg_clfs, neg_hidden = build_neg_clfs(model,train_data,\
                                                          hidden_layer=layers,n_neighbors=1)
                    
                    x_mem, _ = get_nns(model,nn_clfs,train_data,train_hidden,x,y)
                    if not relax:
                        x_neg, _ = get_negs(model,neg_clfs,train_data,neg_hidden,x,y)
                    x_adv = KnnAttack(x, y, x_mem, model, x_ot = x_neg, rl = relax, eps=4/255,\
                                  step=2/255,it=10, lamb = 10000, DEVICE=DEVICE)
#                     x_knnadv = KnnAttack(x, y, x_mem, model, eps=4/255, step=2/255,\
#                                          it=10, lamb = 100, DEVICE=DEVICE)
                    ####
                    model.eval()
                    with torch.no_grad():
                        pred = model(x.to(DEVICE))[-1].max(dim=1)[1]
                        test_correct += (pred==y.to(DEVICE)).sum().item()
                        pred_adv = model(x_adv.to(DEVICE))[-1].max(dim=1)[1]
                        test_correct_rob += (pred_adv==y.to(DEVICE)).sum().item()
                        test_total += x.size(0)
                        #knn attack acc
                        pred_knnadv = model(x_knnadv.to(DEVICE))[-1].max(dim=1)[1]
                        test_correct_knnrob += (pred_knnadv==y.to(DEVICE)).sum().item()
                        #
                t.set_postfix({
                    "train_loss": train_loss/train_total,
                    "train_acc": train_correct/train_total,
                    "test_acc": test_correct/test_total,
                    "test_acc_rob": test_correct_rob/test_total,
                    "test_acc_knnrob": test_correct_knnrob/test_total
                })
#     scheduler.step()

1/30 epochs::  75%|███████▌  | 294/391 [59:36<19:40, 12.17s/it, train_loss=-.0127, train_acc=0.11]   