In [1]:
import os
import sys
import time
import math
import pickle

from collections import Counter

import matplotlib.pyplot   as plt
import numpy as np
from sklearn.neighbors import NearestNeighbors,KNeighborsClassifier

import torch
import torch.nn            as nn
import torch.nn.functional as F
import torch.optim         as optim
from torch.distributions      import Categorical
from torch.utils.data.dataset import random_split
from torchvision              import datasets 
from torchvision              import transforms

from DkNN import CKNN
import utilities
from mnist_model import CNN
from attack import PGD

class GenAdapt:
    '''
    Core component of AISE B-cell generation: element-wise genetic operators
    (mutation and crossover) applied to a population of B cells.

    Every mutation operator clamps its result to the [0, 1] range.

    Args:
        mut_range: maximum absolute size of a mutation step.
        mut_prob: probability that an individual element is mutated.
        combine_rate: in "combined" mode, probability that an element takes the
            guided step instead of the purely random one (0 -> random mutation).
        hybrid_rate: in "hybrid" mode, probability that crossover keeps the
            parent's element rather than the target's.
        mode: one of "random", "guided", "combined", "hybrid".
    '''

    MODES = ("random", "guided", "combined", "hybrid")

    def __init__(self, mut_range, mut_prob, combine_rate=0.5, hybrid_rate=0.5, mode='random'):
        self.mut_range = mut_range
        self.mut_prob = mut_prob
        self.combine_rate = combine_rate
        self.hybrid_rate = hybrid_rate
        self.mode = mode

    @staticmethod
    def _bernoulli_mask(ref, prob):
        # boolean mask shaped like `ref`, True with probability `prob`; drawn on
        # ref's device so torch.where does not mix CPU and CUDA tensors
        return torch.rand(ref.shape, device=ref.device) < prob

    def _apply_mutation(self, base, mut):
        # mutate only where the mask fires (probability mut_prob), keep the rest
        mut_mask = self._bernoulli_mask(base, self.mut_prob)
        child = torch.where(mut_mask, base + mut, base)
        return torch.clamp(child, 0, 1)

    def crossover(self, base1, base2, select_prob):
        '''Element-wise crossover: take base1 with probability select_prob, else base2.'''
        assert base1.ndim == 2 and base2.ndim == 2, "Number of dimensions should be 2"
        crossover_mask = self._bernoulli_mask(base1, select_prob)
        return torch.where(crossover_mask, base1, base2)

    def mutate_random(self, base):
        '''Add uniform noise in (-mut_range, mut_range) to randomly selected elements.'''
        mut = self.mut_range * (2 * torch.rand_like(base) - 1)  # uniform (-1,1) scaled
        return self._apply_mutation(base, mut)

    def mutate_guided(self, base, target):
        '''Move randomly selected elements by a random fraction of (target - base).'''
        guidance = target - base
        mut = (2 * torch.rand_like(base) - 1) * guidance * self.mut_range
        return self._apply_mutation(base, mut)

    def mutate_combined(self, base, target):
        '''Per element, choose a guided step (prob. combine_rate) or a random step.'''
        guidance = target - base
        mut_random = 2 * torch.rand_like(base) - 1  # uniform (-1,1)
        mut_guided = (2 * torch.rand_like(base) - 1) * guidance
        # when self.combine_rate is 0 this degenerates into random mutation
        combine_mask = self._bernoulli_mask(base, self.combine_rate)
        mut = self.mut_range * torch.where(combine_mask, mut_guided, mut_random)
        return self._apply_mutation(base, mut)

    def hybrid(self, base, target):
        '''Crossover with the target followed by a guided mutation towards it.'''
        child = self.crossover(base, target, self.hybrid_rate)
        # mutate_guided computes (target - child) itself; passing the target
        # (not a difference) keeps the guidance pointing at the target
        return self.mutate_guided(child, target)

    def __call__(self, *args):
        '''
        Apply the operator selected by ``mode``.

        "random" takes (base, ...); every other mode takes exactly (base, target).
        Raises ValueError for an unknown mode.
        '''
        if self.mode not in self.MODES:
            raise ValueError("Unsupported mutation type!")
        if self.mode == "random":
            base, *_ = args
            return self.mutate_random(base)
        assert len(args) == 2
        base, target = args
        operators = {
            "guided": self.mutate_guided,
            "combined": self.mutate_combined,
            "hybrid": self.hybrid,
        }
        return operators[self.mode](base, target)

    def proliferate(self, p1, p2, select_prob, mut_prob):
        '''Placeholder: not implemented; returns None.'''
        # TODO: implement or remove; nothing in this notebook calls it
        pass

def gram_matrix(input, batch_size=64):
    '''
    Per-sample Gram matrices of a batch of feature maps.

    Args:
        input: 4-D tensor (n_samples, n_channels, height, width).
        batch_size: number of samples processed per chunk (bounds peak memory).

    Returns:
        Tensor (n_samples, n_channels, n_channels) whose entry [i, j, k] is the
        inner product of channels j and k of sample i, divided by
        n_channels * height * width.
    '''
    n_samples, n_channels, height, width = input.size()
    if n_samples == 0:
        return input.new_zeros((0, n_channels, n_channels))

    grams = []
    for start in range(0, n_samples, batch_size):
        # reshape (not view) so non-contiguous inputs are accepted
        feats = input[start:start + batch_size].reshape(-1, n_channels, height * width)
        # batched product keeps samples separate; one mm over all rows would
        # also multiply channels of different samples with each other
        grams.append(torch.bmm(feats, feats.transpose(1, 2)))
    # 'normalize' by the number of elements in each sample's feature maps
    return torch.cat(grams, dim=0).div(n_channels * height * width)

class L2NearestNeighbors(NearestNeighbors):
    '''
    Callable query object on top of scikit-learn's ``NearestNeighbors``
    (default metric, intended as the euclidean/"l2" query class).

    Calling an instance with a query matrix returns only the indices of the
    fitted samples closest to each query row, without the distances.
    '''

    def __call__(self, X):
        neighbour_indices = self.kneighbors(X, return_distance=False)
        return neighbour_indices


def recip_l2_dist(X, Y, eps=1e-6):
    '''
    Fitness as the reciprocal euclidean distance: 1 / sqrt(||x - y||^2 + eps).

    Args:
        X: 2-D array-like (n_x, n_features); CPU torch tensors are accepted.
        Y: 2-D array-like (n_y, n_features).
        eps: added to the squared distance so identical points give a finite score.

    Returns:
        numpy array (n_x, n_y) of scores; larger means closer.
    '''
    # scipy is always installed alongside scikit-learn; the original relied on
    # sklearn's euclidean_distances, which was never imported (NameError)
    from scipy.spatial.distance import cdist
    sq_dist = cdist(np.asarray(X, dtype=np.float64), np.asarray(Y, dtype=np.float64), "sqeuclidean")
    return 1 / np.sqrt(sq_dist + eps)


def neg_l2_dist(X, Y):
    '''
    Fitness as the negated euclidean distance between rows of X and rows of Y.

    Args:
        X: 2-D array-like (n_x, n_features); CPU torch tensors are accepted.
        Y: 2-D array-like (n_y, n_features).

    Returns:
        numpy array (n_x, n_y) of non-positive scores; larger means closer.
    '''
    # euclidean_distances was never imported; scipy ships with scikit-learn
    from scipy.spatial.distance import cdist
    return -cdist(np.asarray(X, dtype=np.float64), np.asarray(Y, dtype=np.float64), "euclidean")

def feature_space(net, n_layers, inputs, device, batch_size=128):
    '''
    Collect the hidden activations of ``net`` for every sample in ``inputs``.

    ``net`` must return ``(*hidden_outputs, final_output)`` with exactly
    ``n_layers`` hidden outputs. The net is switched to eval mode (side effect)
    and run batch by batch on ``device`` without gradient tracking.

    Args:
        net: torch module returning a tuple as described above.
        n_layers: expected number of hidden outputs.
        inputs: tensor of samples indexed along dim 0.
        device: device the batches are moved to.
        batch_size: samples per forward pass.

    Returns:
        list of ``n_layers`` CPU tensors, each the concatenation over all
        batches (dim 0) of one hidden output.

    Raises:
        ValueError: if ``net`` returns a different number of hidden outputs.
    '''
    conv_features = [[] for _ in range(n_layers)]
    net.eval()
    # features only: skip building an autograd graph for every batch
    with torch.no_grad():
        for start in range(0, inputs.size(0), batch_size):
            batch = inputs[start:start + batch_size].to(device)
            *hidden_outs, _ = net(batch)
            if len(hidden_outs) != n_layers:
                raise ValueError("net returned {} hidden outputs, expected n_layers={}".format(
                    len(hidden_outs), n_layers))
            for layer_feats, hidden in zip(conv_features, hidden_outs):
                layer_feats.append(hidden.detach().cpu())
    print('\tConcatenating results')
    return [torch.cat(layer_feats) for layer_feats in conv_features]

class AISE:
    '''
    Adaptive Immune System Emulation (AISE).

    Training samples act as naive B cells and the inputs to classify act as
    antigens. Both are mapped to an inner representation (flattened Gram
    matrices of the model's hidden activations). For every antigen the
    nearest naive B cells (per class, or overall when ``n_class`` is 0) are
    evolved with ``GenAdapt`` towards the antigen, and the fittest individuals
    of the last generation are returned as plasma and memory B cells together
    with their labels.

    ``X_hidden``, ``weights`` and ``norm_order`` are accepted for backward
    compatibility but are not used by this implementation.
    '''

    def __init__(self, X_orig, y_orig, X_hidden=None, weights=None, model=None, input_shape=None,
                 device=torch.device("cuda"), n_class=10, n_neighbors=10, query_class="l2", norm_order=2,
                 fitness_function=recip_l2_dist, sampling_temperature=.3, max_generation=20, requires_init=False,
                 mut_range=(.1, .3), mut_prob=(.1, .3), mut_mode="combined", combine_rate=0.7, hybrid_rate=.9,
                 decay=(.9, .9), n_population=1000, memory_threshold=.25, plasma_threshold=.05, return_log=True,
                 n_hidden_layers=4, feature_batch_size=128):
        '''
        Args:
            X_orig: naive B cells indexed along dim 0 (mnist: (N, 1, 28, 28)).
            y_orig: labels of X_orig (torch CPU tensor or array-like of length N).
            model: network returning ``(*hidden_outputs, logits)``; required.
            input_shape: per-sample model input shape; defaults to X_orig.shape[1:].
            device: device the model is moved to and run on.
            n_class: number of classes; 0 builds one class-agnostic query object.
            n_neighbors: naive B cells retrieved per class for each antigen.
            query_class: nearest-neighbour backend; only "l2" is supported.
            fitness_function: callable(antigen_repr (1, F), bcell_repr (P, F)),
                whose first row holds the P fitness scores.
            sampling_temperature: softmax temperature for parent selection.
            max_generation: maximum number of generations per antigen.
            requires_init: tile the naive B cells up to n_population and mutate them.
            mut_range, mut_prob: (min, max) pair, or one number for both; evolution
                starts from max and decays towards min on fitness plateaus.
            mut_mode, combine_rate, hybrid_rate: forwarded to GenAdapt.
            decay: per-plateau multiplicative decay of (mut_range, mut_prob); falsy disables it.
            n_population: population size of every generation.
            memory_threshold, plasma_threshold: fractions of n_population kept from
                the top of the final ranking (plasma = top slice, memory = the next one).
            return_log: stored on the instance.
            n_hidden_layers: number of hidden outputs of ``model`` to use.
            feature_batch_size: batch size used when running ``model``.

        Raises:
            ValueError: if no model is given.
        '''
        if model is None:
            raise ValueError("AISE requires a model to compute inner representations")
        self.device = device
        # the model must be on `device` before X_cat is computed below
        self.model = model.to(device)

        if input_shape is None:
            self.input_shape = tuple(X_orig.shape[1:])  # mnist: (1,28,28)
        else:
            self.input_shape = tuple(input_shape)
        self.n_hidden_layers = n_hidden_layers
        self.feature_batch_size = feature_batch_size

        self.n_class = n_class
        self.n_neighbors = n_neighbors
        self.query_class = query_class
        self.norm_order = norm_order
        self.fitness_func = fitness_function
        self.sampl_temp = sampling_temperature
        self.max_generation = max_generation
        self.requires_init = requires_init

        self.mut_range = self._as_range(mut_range)
        self.mut_prob = self._as_range(mut_prob)

        self.mut_mode = mut_mode
        self.combine_rate = combine_rate
        self.hybrid_rate = hybrid_rate
        self.decay = decay
        self.n_population = n_population
        self.plasma_thres = plasma_threshold
        self.memory_thres = memory_threshold
        self.return_log = return_log

        self.X_orig = X_orig
        self.y_orig = y_orig
        self.X_cat = self._transform_to_inner_repr(self.X_orig)
        self.query_objects = self._build_all_query_objects()

    @staticmethod
    def _as_range(value):
        # a single number means a constant hyper-parameter: (min, max) = (v, v)
        if isinstance(value, (int, float)):
            return (value, value)
        return value

    def _build_class_query_object(self, class_label=-1):
        '''Fit a query object on one class (or on all samples when class_label == -1).'''
        if self.query_class != "l2":
            raise ValueError("Unsupported query class: {}".format(self.query_class))
        if class_label == -1:
            X_class = self.X_cat
        else:
            X_class = self.X_cat[self.y_orig == class_label]
        return L2NearestNeighbors(n_neighbors=self.n_neighbors).fit(X_class)

    def _build_all_query_objects(self):
        '''One query object per class, or a single one when n_class is 0.'''
        if self.n_class:
            print("Building query objects for {} classes {} samples...".format(self.n_class, self.X_orig.size(0)),
                  end="")
            query_objects = [self._build_class_query_object(i) for i in range(self.n_class)]
            print("done!")
        else:
            print("Building one single query object {} samples...".format(self.X_orig.size(0)), end="")
            query_objects = [self._build_class_query_object()]
            print("done!")
        return query_objects

    def _query_nns_ind(self, Q):
        '''
        Indices into X_orig of the nearest naive B cells of every query row.

        Returns a list (one entry per query object) of integer arrays of shape
        (n_queries, n_neighbors); per-class results are mapped back from
        within-class positions to absolute positions in X_orig.
        '''
        assert Q.ndim == 2, "Q: 2d array-like (n_queries,n_features)"
        if self.n_class:
            print("Searching {} naive B cells per class for each of {} antigens...".format(self.n_neighbors, Q.size(0)),
                  end="")
            labels = np.asarray(self.y_orig)
            abs_ind = []
            for c, query_obj in enumerate(self.query_objects):
                class_ind = np.flatnonzero(labels == c)
                abs_ind.append(class_ind[query_obj(Q)])
            print("done!")
        else:
            print("Searching {} naive B cells for each of {} antigens...".format(self.n_neighbors, Q.size(0)),
                  end="")
            abs_ind = [query_obj(Q) for query_obj in self.query_objects]
            print('done!')
        return abs_ind

    def _transform_to_inner_repr(self, X):
        '''
        Transform B cells and antigens into the inner representation of AISE.

        X is reshaped to (-1, *input_shape), so image-shaped and flattened
        samples are both accepted. Each sample is represented by its flattened
        Gram matrices, one per hidden layer of ``self.model``, concatenated
        along the feature axis.

        Returns:
            CPU tensor (n_samples, n_features).
        '''
        X = X.reshape((-1,) + self.input_shape)
        hidden_outs = feature_space(self.model, self.n_hidden_layers, X, self.device, self.feature_batch_size)
        grams = [gram_matrix(out).cpu().flatten(start_dim=1) for out in hidden_outs]
        # dim=1: join the layers per sample (dim=0 would stack them as extra rows)
        return torch.cat(grams, dim=1)

    def _fitness(self, ant_repr, bc_repr):
        # score every B cell (rows of bc_repr) against one antigen representation
        scores = self.fitness_func(ant_repr.unsqueeze(0), bc_repr)[0]
        return torch.as_tensor(scores, dtype=torch.float32)

    def _max_plateau(self):
        # number of decay steps after which both mut_range and mut_prob have
        # reached their minimum; a further plateau stops the evolution early
        assert len(self.decay) == 2
        return max(math.log(self.mut_range[0] / self.mut_range[1], self.decay[0]),
                   math.log(self.mut_prob[0] / self.mut_prob[1], self.decay[1]))

    def generate_b_cells(self, ant, ant_tran, nbc_ind, y_ant=None):
        '''
        Run affinity maturation for every antigen.

        Args:
            ant: antigens indexed along dim 0; flattened here to match the
                flattened B cells.
            ant_tran: inner representation of ``ant``, 2-D (n_antigens, n_features).
            nbc_ind: output of ``_query_nns_ind`` (indices into X_orig).
            y_ant: optional true antigen labels, used only for logging.

        Returns:
            (memory B cells, memory labels, plasma B cells, plasma labels, logs).
            B cells are flattened and concatenated over antigens; labels are
            stacked to shape (n_antigens, n_selected).
        '''
        assert ant_tran.ndim == 2, "ant: 2d tensor (n_antigens,n_features)"
        ant = ant.flatten(start_dim=1)
        bc_pool = self.X_orig.flatten(start_dim=1)  # naive B cells as flat vectors
        label_pool = np.asarray(self.y_orig)
        n_plasma = int(self.plasma_thres * self.n_population)
        n_memory = int(self.memory_thres * self.n_population)
        max_plateau = self._max_plateau() if self.decay else None

        mem_bc_batch, pla_bc_batch = [], []
        mem_lab_batch, pla_lab_batch = [], []
        ant_logs = []  # one history dict of metrics per antigen
        print("Affinity maturation process starts with population of {}...".format(self.n_population))
        for n in range(ant.size(0)):
            target = ant[n].unsqueeze(0)
            genadapt = GenAdapt(self.mut_range[1], self.mut_prob[1], self.combine_rate,
                                self.hybrid_rate, mode=self.mut_mode)
            curr_gen = torch.cat([bc_pool[ind[n]] for ind in nbc_ind])  # naive b cells
            labels = np.concatenate([label_pool[ind[n]] for ind in nbc_ind])
            if self.requires_init:
                n_seed = curr_gen.size(0)
                assert self.n_population % n_seed == 0, \
                    "n_population should be divisible by the product of n_class and n_neighbors"
                n_copies = self.n_population // n_seed
                # Tensor.repeat tiles whole rows, matching np.tile on the labels
                curr_gen = genadapt.mutate_random(curr_gen.repeat((n_copies, 1)))
                labels = np.tile(labels, n_copies)
            fitness_score = self._fitness(ant_tran[n], self._transform_to_inner_repr(curr_gen))

            best_pop_fitness = float('-inf')
            decay_coef = (1., 1.)
            num_plateau = 0
            fitness_pop_hist = []
            fitness_true_class_hist = []
            pct_true_class_hist = []
            for _ in range(self.max_generation):
                survival_prob = F.softmax(fitness_score / self.sampl_temp, dim=-1)
                parents_ind = Categorical(probs=survival_prob).sample((self.n_population,))
                curr_gen = genadapt(curr_gen[parents_ind], target)
                labels = labels[parents_ind.numpy()]
                fitness_score = self._fitness(ant_tran[n], self._transform_to_inner_repr(curr_gen))
                pop_fitness = fitness_score.sum().item()
                # logging
                fitness_pop_hist.append(pop_fitness)
                if y_ant is not None:
                    is_true_class = labels == y_ant[n]
                    fitness_true_class_hist.append(fitness_score[is_true_class].sum().item())
                    pct_true_class_hist.append(is_true_class.astype('float').mean().item())
                # adaptive shrinkage of mutation range / probability on plateaus
                if self.decay:
                    if pop_fitness < best_pop_fitness:
                        if num_plateau >= max_plateau:
                            break  # early stop
                        decay_coef = tuple(coef * rate for coef, rate in zip(decay_coef, self.decay))
                        num_plateau += 1
                        genadapt = GenAdapt(max(self.mut_range[0], self.mut_range[1] * decay_coef[0]),
                                            max(self.mut_prob[0], self.mut_prob[1] * decay_coef[1]),
                                            self.combine_rate, self.hybrid_rate, mode=self.mut_mode)
                    else:
                        best_pop_fitness = pop_fitness

            ant_log = {"fitness_pop": fitness_pop_hist}
            if y_ant is not None:
                ant_log["fitness_true_class"] = fitness_true_class_hist
                ant_log["pct_true_class"] = pct_true_class_hist
            ant_logs.append(ant_log)

            # ascending sort: the fittest individuals are at the end. Explicit
            # start positions avoid the `x[-0:]` pitfall when a count is 0.
            _, fitness_rank = torch.sort(fitness_score)
            n_total = fitness_rank.numel()
            pla_start = max(n_total - n_plasma, 0)
            mem_start = max(n_total - n_memory, 0)
            pla_ind = fitness_rank[pla_start:]
            mem_ind = fitness_rank[mem_start:pla_start]
            pla_bc_batch.append(curr_gen[pla_ind])
            pla_lab_batch.append(labels[pla_ind.numpy()])
            mem_bc_batch.append(curr_gen[mem_ind])
            mem_lab_batch.append(labels[mem_ind.numpy()])
        print("Memory & plasma B cells generated!")
        return torch.cat(mem_bc_batch), torch.tensor(np.stack(mem_lab_batch)), \
               torch.cat(pla_bc_batch), torch.tensor(np.stack(pla_lab_batch)), \
               ant_logs

    def clonal_expansion(self, ant, y_ant=None, return_log=False):
        '''
        Full AISE pipeline for a batch of antigens.

        Computes their inner representation, retrieves the nearest naive B cells
        and runs ``generate_b_cells``. Returns (mem_bcs, mem_labs, pla_bcs,
        pla_labs), plus the per-antigen logs when ``return_log`` is True.
        '''
        print("Clonal expansion starts...")
        ant_tran = self._transform_to_inner_repr(ant)
        nbc_ind = self._query_nns_ind(ant_tran)
        mem_bcs, mem_labs, pla_bcs, pla_labs, ant_logs = self.generate_b_cells(ant.flatten(start_dim=1), ant_tran,
                                                                               nbc_ind, y_ant)
        print("{} plasma B cells and {} memory generated!".format(pla_bcs.size(0), mem_bcs.size(0)))
        if return_log:
            return mem_bcs, mem_labs, pla_bcs, pla_labs, ant_logs
        return mem_bcs, mem_labs, pla_bcs, pla_labs


In [2]:
DATADIR = "datasets/"
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
config = utilities.config_to_namedtuple(utilities.get_config('config_mnist.json'))

# Normalize with mean 0 / std 1 is an identity transform: pixels stay in [0, 1].
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.,), (1,))
])
mnist_trainset = datasets.MNIST(root=DATADIR, train=True, download=False, transform=mnist_transform)
mnist_testset = datasets.MNIST(root=DATADIR, train=False, download=False, transform=mnist_transform)

train_loader = torch.utils.data.DataLoader(mnist_trainset,
    shuffle = True,
    batch_size = 64
)

test_loader = torch.utils.data.DataLoader(mnist_testset,
    shuffle = False,
    batch_size = 64
)

filename = 'models/mnistmodel.pt'
model = CNN().to(DEVICE)

# Every later cell needs the trained weights; continuing with a randomly
# initialised CNN would silently produce meaningless results.
if not os.path.isfile(filename):
    raise FileNotFoundError("=> no checkpoint found at '{}'".format(filename))
print("=> loading checkpoint '{}'".format(filename))
checkpoint = torch.load(filename, map_location=DEVICE)  # local, trusted file (torch.load unpickles)
model.load_state_dict(checkpoint['state_dict'])
print("=> loaded checkpoint '{}' (epoch {})"
          .format(filename, checkpoint['epoch']))

=> loading checkpoint 'models/mnistmodel.pt'
=> loaded checkpoint 'models/mnistmodel.pt' (epoch 55)


In [3]:
# Module.train(False) is the documented equivalent of eval(); it returns the
# module, so the cell still displays the architecture.
model.train(False)

CNN(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=6272, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=10, bias=True)
)

In [4]:
SEED = 1234
N_TRAIN_PARTIAL = 2000  # size of the naive B-cell pool drawn from the training set

np.random.seed(SEED)
# GenAdapt mutations and Categorical parent sampling in AISE use torch's RNG
torch.manual_seed(SEED)
ind_full = np.arange(len(mnist_trainset))
np.random.shuffle(ind_full)
ind_partial = ind_full[:N_TRAIN_PARTIAL]
X_train_partial = mnist_trainset.data[ind_partial].unsqueeze(1) / 255.
y_train_partial = mnist_trainset.targets[ind_partial]

In [5]:
# 200 shuffled indices right after the first 2000 (the naive B-cell pool)
ind_eval = ind_full[2000:2200]
X_eval = mnist_trainset.data[ind_eval].unsqueeze(1) / 255.
y_eval = mnist_trainset.targets[ind_eval]
X_adv = PGD(eps=40/255., sigma=20/255., nb_iter=20, DEVICE=DEVICE).attack_batch(
    model, X_eval.to(DEVICE), y_eval.to(DEVICE), batch_size=64)
# evaluation only: no autograd graph is needed for this forward pass
with torch.no_grad():
    *_, out = model(X_adv)
y_pred_adv = torch.max(out, 1)[1]
adv_accuracy = (y_eval.numpy() == y_pred_adv.detach().cpu().numpy()).mean()
print('The accuracy of plain cnn under PGD attacks is: {:f}'.format(adv_accuracy))

The accuracy of plain cnn under PGD attacks is: 0.180000


In [6]:
# Pass DEVICE explicitly: AISE's default device is CUDA, which fails on
# CPU-only machines and would move `model` away from DEVICE.
aise = AISE(X_train_partial, y_train_partial, model=model, device=DEVICE)
mem_bcs, mem_labs, pla_bcs, pla_labs, ant_logs = aise.clonal_expansion(
    X_adv.detach().cpu(), y_eval.numpy(), return_log=True)

	Concatenating results


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)