This notebooks computed the embedding kernel for a pre-trained CNN on CIFAR10 badnet data. the dataset can be downloaded from https://github.com/Shawn-Shan/forensics, and the model training is described in XXX. 

The NTK computation is then described in XYZ.

In [1]:
import torch
import pickle
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from einops import rearrange
import os
import numpy as np
from scipy.stats import kendalltau, spearmanr

In [2]:
with open('/rcfs/projects/task0_pmml/traceback/forensics/results/cifar_cifar2_res.p','rb') as f:
    RES = pickle.load(f)

In [3]:
class Model(torch.nn.Module):
    def __init__(self,):
        super(Model, self).__init__()
        
        self.conv2d = torch.nn.Conv2d(3, 32, (3,3), padding=(1,1),)
        self.batch_normalization = torch.nn.BatchNorm2d(32,momentum=0.01,eps=1e-3)
        self.conv2d_1 = torch.nn.Conv2d(32, 32, (3,3), padding=(1,1))
        self.batch_normalization_1 = torch.nn.BatchNorm2d(32,momentum=0.01,eps=1e-3)
        self.max_pooling2d = torch.nn.MaxPool2d((2,2))
        
        self.conv2d_2 = torch.nn.Conv2d(32, 64, (3,3), padding=(1,1))
        self.batch_normalization_2 = torch.nn.BatchNorm2d(64,momentum=0.01,eps=1e-3)
        self.conv2d_3 = torch.nn.Conv2d(64, 64, (3,3), padding=(1,1))
        self.batch_normalization_3 = torch.nn.BatchNorm2d(64,momentum=0.01,eps=1e-3)
        self.max_pooling2d_1 = torch.nn.MaxPool2d((2,2))
        
        self.conv2d_4 = torch.nn.Conv2d(64, 128, (3,3), padding=(1,1))
        self.batch_normalization_4 = torch.nn.BatchNorm2d(128,momentum=0.01,eps=1e-3)
        self.conv2d_5 = torch.nn.Conv2d(128, 128, (3,3), padding=(1,1))
        self.batch_normalization_5 = torch.nn.BatchNorm2d(128,momentum=0.01,eps=1e-3)
        self.max_pooling2d_2 = torch.nn.MaxPool2d((2,2))

        self.flatten = torch.nn.Flatten()
        self.max_pooling1d = torch.nn.MaxPool1d((4))
        self.dropout = torch.nn.Dropout(0.2)
        
        self.dense = torch.nn.Linear(512,512,)
        self.batch_normalization_6 = torch.nn.BatchNorm1d(512,momentum=0.01,eps=1e-3)
        
        self.dense_1 = torch.nn.Linear(512,512,)
        self.batch_normalization_7 = torch.nn.BatchNorm1d(512,momentum=0.01,eps=1e-3)
        
        self.dense_2 = torch.nn.Linear(512,10,)
    
    def forward(self,x):
        x = self.conv2d(x)
        x = torch.nn.functional.relu(x)
        x = self.batch_normalization(x)
        x = self.conv2d_1(x)
        x = torch.nn.functional.relu(x)
        x = self.batch_normalization_1(x)
        x = self.max_pooling2d(x)
        
        x = self.conv2d_2(x)
        x = torch.nn.functional.relu(x)
        x = self.batch_normalization_2(x)
        x = self.conv2d_3(x)
        x = torch.nn.functional.relu(x)
        x = self.batch_normalization_3(x)
        x = self.max_pooling2d_1(x)
        
        x = self.conv2d_4(x)
        x = torch.nn.functional.relu(x)
        x = self.batch_normalization_4(x)
        x = self.conv2d_5(x)
        x = torch.nn.functional.relu(x)
        x = self.batch_normalization_5(x)
        x = self.max_pooling2d_2(x)
        
        
        x = self.flatten(x)
        x = self.max_pooling1d(x)
        x = self.dropout(x)
        
        x = self.dense(x)
        x = torch.nn.functional.relu(x)
        x = self.batch_normalization_6(x)
        x = self.dense_1(x)
        x = torch.nn.functional.relu(x)
        x = self.batch_normalization_7(x)
        x = self.dense_2(x)
        
        return x


In [4]:
model = Model()
model.to('cuda')
torch.manual_seed(1234)

optim = torch.optim.SGD(model.parameters(),1e-2,momentum=0.9,nesterov=True)
loss = torch.nn.CrossEntropyLoss()

In [5]:
RES.keys()

train_x = rearrange(torch.tensor(RES['injected_X'],dtype=torch.float32),'b h w c -> b c h w')
train_y = torch.tensor(RES['injected_Y'],dtype=torch.float32)

test_x = rearrange(torch.tensor(RES['injected_X_test'],dtype=torch.float32),'b h w c -> b c h w')
test_y = torch.tensor(RES['injected_Y_test'],dtype=torch.float32)

test_x_og = rearrange(torch.tensor(RES['X_test'],dtype=torch.float32),'b h w c -> b c h w')
test_y_og = torch.tensor(RES['Y_test'],dtype=torch.float32)



train_loader = DataLoader(TensorDataset(train_x,train_y),batch_size=128,shuffle=True)
test_loader = DataLoader(TensorDataset(test_x,test_y),batch_size=1,shuffle=False)
test_loader_og = DataLoader(TensorDataset(test_x_og,test_y_og),batch_size=1,shuffle=False)

In [6]:
all_X = torch.cat([train_x,test_x,test_x_og]).to('cuda')

In [7]:
model.load_state_dict(torch.load('/rcfs/projects/task0_pmml/MODELS/poisoned_CNN.pt'))

<All keys matched successfully>

In [8]:
all_X.shape

torch.Size([70000, 3, 32, 32])

In [9]:
model.hooks = {}
for key in list(model.hooks.keys()):
    model.hooks[key].remove()

In [10]:
ALL_NAMES = []
for name, module in model.named_modules():
    print(name)
    if '' == name:
        continue
    if 'dropout' in name:
        continue
    if 'flatten' in name:
        continue
        
    ALL_NAMES.append(name)


conv2d
batch_normalization
conv2d_1
batch_normalization_1
max_pooling2d
conv2d_2
batch_normalization_2
conv2d_3
batch_normalization_3
max_pooling2d_1
conv2d_4
batch_normalization_4
conv2d_5
batch_normalization_5
max_pooling2d_2
flatten
max_pooling1d
dropout
dense
batch_normalization_6
dense_1
batch_normalization_7
dense_2


In [11]:
for name in ALL_NAMES:
    print(name)

conv2d
batch_normalization
conv2d_1
batch_normalization_1
max_pooling2d
conv2d_2
batch_normalization_2
conv2d_3
batch_normalization_3
max_pooling2d_1
conv2d_4
batch_normalization_4
conv2d_5
batch_normalization_5
max_pooling2d_2
max_pooling1d
dense
batch_normalization_6
dense_1
batch_normalization_7
dense_2


In [11]:
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

In [12]:
total_em = 0
model.eval()
for k,NAME in enumerate(ALL_NAMES):
    EM_Component = torch.zeros((70_000,70_000),device='cpu')
    
    for name, module in model.named_modules():
        if name == NAME:
            model.hooks[name] = module.register_forward_hook(get_activation(name))
   
    with torch.no_grad():
        for i in tqdm(range(7)):
            activation = {}
            X1 =  model.forward(all_X[i*10_000:(i+1)*10_000,])
            X1_activation = activation[NAME].reshape(10_000,-1)
            for j in range(7):
                activation = {}
                X2 =  model.forward(all_X[j*10_000:(j+1)*10_000,])
                X2_activation = activation[NAME].reshape(10_000,-1)
                
                component = torch.matmul(X1_activation,X2_activation.T).cpu()
                EM_Component[i*10_000:(i+1)*10_000,j*10_000:(j+1)*10_000] = component
    
    torch.save(EM_Component,f'/rcfs/projects/task0_pmml/traceback/kernels/Em_comp/{NAME}-{k}.pt')
    total_em+=EM_Component
    model.hooks[NAME].remove()

100%|█████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:28<00:00,  4.11s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:26<00:00,  3.85s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:29<00:00,  4.18s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:27<00:00,  3.87s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:12<00:00,  1.75s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:16<00:00,  2.37s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:16<00:00,  2.38s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:16<00:00,  2.37s/it]
100%|███████████████████████████

In [13]:
np.save('/rcfs/projects/task0_pmml/traceback/kernels/embedding.npy',total_em.detach().cpu().numpy())