<a href="https://colab.research.google.com/github/revathi-prasad/ai-safety-usc-fellowship/blob/main/Trojan_Detection_MNTD_Submission.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Remove the pre-installed copies of the packages that the following cell
# re-installs at pinned versions, so the two sets of versions cannot mix.
!pip uninstall matplotlib -y
!pip uninstall numpy -y
!pip uninstall scikit-learn -y
!pip uninstall tqdm -y
!pip uninstall torch -y
!pip uninstall torchvision -y
!pip uninstall vit_pytorch -y

Found existing installation: matplotlib 3.2.2
Uninstalling matplotlib-3.2.2:
  Successfully uninstalled matplotlib-3.2.2
Found existing installation: numpy 1.21.6
Uninstalling numpy-1.21.6:
  Successfully uninstalled numpy-1.21.6
Found existing installation: scikit-learn 1.0.2
Uninstalling scikit-learn-1.0.2:
  Successfully uninstalled scikit-learn-1.0.2
Found existing installation: tqdm 4.64.1
Uninstalling tqdm-4.64.1:
  Successfully uninstalled tqdm-4.64.1
Found existing installation: torch 1.12.1+cu113
Uninstalling torch-1.12.1+cu113:
  Successfully uninstalled torch-1.12.1+cu113
Found existing installation: torchvision 0.13.1+cu113
Uninstalling torchvision-0.13.1+cu113:
  Successfully uninstalled torchvision-0.13.1+cu113


In [None]:
# Install the exact package versions this notebook was developed against.
# NOTE(review): `%pip` is generally preferred over `!pip` inside notebooks because it
# targets the running kernel's environment; restart the runtime after this cell if
# any of these packages had already been imported.
!pip install matplotlib==3.2.1 numpy==1.19.1 scikit-learn==0.23.1 tqdm==4.62.3 torch==1.11.0 torchvision==0.12.0 vit_pytorch==0.35.5

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting matplotlib==3.2.1
  Downloading matplotlib-3.2.1-cp37-cp37m-manylinux1_x86_64.whl (12.4 MB)
[K     |████████████████████████████████| 12.4 MB 4.3 MB/s 
[?25hCollecting numpy==1.19.1
  Downloading numpy-1.19.1-cp37-cp37m-manylinux2010_x86_64.whl (14.5 MB)
[K     |████████████████████████████████| 14.5 MB 34.6 MB/s 
[?25hCollecting scikit-learn==0.23.1
  Downloading scikit_learn-0.23.1-cp37-cp37m-manylinux1_x86_64.whl (6.8 MB)
[K     |████████████████████████████████| 6.8 MB 18.7 MB/s 
[?25hCollecting tqdm==4.62.3
  Downloading tqdm-4.62.3-py2.py3-none-any.whl (76 kB)
[K     |████████████████████████████████| 76 kB 3.1 MB/s 
[?25hCollecting torch==1.11.0
  Downloading torch-1.11.0-cp37-cp37m-manylinux1_x86_64.whl (750.6 MB)
[K     |████████████████████████████████| 750.6 MB 12 kB/s 
[?25hCollecting torchvision==0.12.0
  Downloading torchvision-0.12.0-cp37-cp37m-manylinu

In [None]:
import torch
import os
import json
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
cudnn.benchmark = True  # fire on all cylinders
from sklearn.metrics import roc_auc_score, roc_curve
import sys

sys.path.insert(0, '..')

In [None]:
# Mount Google Drive: the starter-kit modules and the datasets used below are read
# from paths under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Sanity check: display the version of the torch build that is actually imported.
torch.__version__


'1.11.0+cu102'

In [None]:
import importlib.util
import sys

# Folder on the mounted Drive that holds the starter-kit source files.
STARTER_KIT_DIR = "/content/drive/MyDrive/trojan_detection/tdc-starter-kit"


def _load_module_from_path(module_name, file_path):
    """Import the Python file at ``file_path`` and register it as ``module_name``.

    :param module_name: name under which the module is stored in ``sys.modules``
    :param file_path: path of the ``.py`` file to execute
    :returns: the loaded module object
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module is importable by name afterwards.
    # NOTE(review): presumably required so that torch.load can unpickle the saved
    # networks, whose classes would live in these modules -- confirm.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


wrn = _load_module_from_path("wrn", STARTER_KIT_DIR + "/wrn.py")
utils = _load_module_from_path("utils", STARTER_KIT_DIR + "/utils.py")

## Create the dataset class

In [None]:
class NetworkDatasetDetection(torch.utils.data.Dataset):
    """Labelled dataset of saved networks for Trojan detection.

    Expects ``model_folder`` to contain a ``clean`` and a ``trojan`` sub-folder.
    Every entry inside those sub-folders is treated as one model directory holding
    an ``info.json`` (with a ``'dataset'`` key) and a ``model.pt`` file.

    Attributes set by ``__init__`` (all lists aligned by index):
      * ``model_paths``  -- model directories, clean ones first, each group sorted by name
      * ``labels``       -- 0 for entries under ``clean``, 1 for entries under ``trojan``
      * ``data_sources`` -- the ``'dataset'`` value read from each ``info.json``
    """

    # Sub-folder name -> label, in the order in which the folders are enumerated.
    _FOLDER_LABELS = (('clean', 0), ('trojan', 1))

    def __init__(self, model_folder):
        """
        :param model_folder: directory containing the ``clean`` and ``trojan`` sub-folders
        """
        super().__init__()
        model_paths = []
        labels = []
        data_sources = []
        for folder_name, label in self._FOLDER_LABELS:
            class_dir = os.path.join(model_folder, folder_name)
            for entry in sorted(os.listdir(class_dir)):
                model_path = os.path.join(class_dir, entry)
                with open(os.path.join(model_path, 'info.json'), 'r') as f:
                    data_sources.append(json.load(f)['dataset'])
                model_paths.append(model_path)
                # The label comes from the folder being enumerated instead of being
                # parsed back out of the path string, so it does not depend on the
                # platform's path separator.
                labels.append(label)
        self.model_paths = model_paths
        self.labels = labels
        self.data_sources = data_sources

    def __len__(self):
        """Return the number of model directories found."""
        return len(self.model_paths)

    def __getitem__(self, index):
        """Load one network from disk.

        :returns: ``(loaded object from model.pt, label, data source)``
        """
        # SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
        # code -- only point this dataset at model files from a trusted source.
        model = torch.load(os.path.join(self.model_paths[index], 'model.pt'))
        return model, self.labels[index], self.data_sources[index]

def custom_collate(batch):
    """Collate ``(network, label, data_source)`` samples without stacking them.

    :param batch: list of samples, each indexable at positions 0, 1 and 2
    :returns: three parallel lists ``(networks, labels, data_sources)``
    """
    networks, labels, sources = [], [], []
    for sample in batch:
        networks.append(sample[0])
        labels.append(sample[1])
        sources.append(sample[2])
    return networks, labels, sources

## Load data
Split off 20% of the training set as a validation set, so the detector can be evaluated on networks it was not trained on.

In [None]:
dataset_path = '/content/drive/MyDrive/trojan_detection/tdc_datasets'
task = 'detection'
dataset = NetworkDatasetDetection(os.path.join(dataset_path, task, 'train'))

# 80/20 train/validation split. The permutation is drawn from a seeded generator so
# that the split -- and therefore every metric reported below -- is the same on each run.
split_rng = np.random.default_rng(0)
split = int(len(dataset) * 0.8)
rnd_idx = split_rng.permutation(len(dataset))
train_dataset = torch.utils.data.Subset(dataset, rnd_idx[:split].tolist())
val_dataset = torch.utils.data.Subset(dataset, rnd_idx[split:].tolist())

# batch_size=1 together with custom_collate: each batch is a tuple of one-element
# lists, so whole networks are passed through without any tensor stacking.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True,
                                           num_workers=0, pin_memory=False, collate_fn=custom_collate)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1,
                                           num_workers=0, pin_memory=False, collate_fn=custom_collate)

## Construct the MNTD network

In [None]:
# #Inputs where out = out/sum(out)
# from pandas.core.internals.construction import nested_data_to_arrays
# data_sources = ['CIFAR-10', 'CIFAR-100', 'GTSRB', 'MNIST']
# data_source_to_channel = {k: 1 if k == 'MNIST' else 3 for k in data_sources}
# data_source_to_resolution = {k: 28 if k == 'MNIST' else 32 for k in data_sources}
# data_source_to_num_classes = {'CIFAR-10': 8192, 'CIFAR-100': 8192, 'GTSRB': 43, 'MNIST': 128}

# class MetaNetwork(nn.Module):
#     def __init__(self, num_queries, num_classes=1):
#         super().__init__()
#         self.queries = nn.ParameterDict(
#             {k: nn.Parameter(torch.rand(num_queries,
#                                         data_source_to_channel[k],
#                                         data_source_to_resolution[k],
#                                         data_source_to_resolution[k])) for k in data_sources}
#         )
#         print('In MetaNetwork class now')
#         #hidden_state_size = list(net.parameters())[-2].size()
#         self.affines = nn.ModuleDict(
#             {k: nn.Linear(data_source_to_num_classes[k]*num_queries, 32) for k in data_sources}
#         )
#         self.norm = nn.LayerNorm(32) 
#         self.gelu = nn.GELU()
#         self.final_output = nn.Linear(32, num_classes)
    
#     def forward(self, net, data_source):
#         """
#         :param net: an input network of one of the model_types specified at init
#         :param data_source: the name of the data source
#         :returns: a score for whether the network is a Trojan or not
#         """
#         query = self.queries[data_source]
#         # print('queries statement passed')
#         if(data_source=='MNIST'):
#           # print('now in mnist if else for net')
#           new_model_1 = nn.Sequential(*list(net.main.children())[:-1])
#           new_model_2 = nn.Sequential(*list(net.main.children())[:-2])
#         elif(data_source == 'CIFAR-10'):
#           # print('now in cifar if else for net')
#           new_model_block_10_1 = nn.Sequential(*list(net.children())[:-2])
#           new_model_block_10_2 = nn.Sequential(*list(net.children())[:-3])
#           # new_model=new_model_block_10[4]
#           # print(new_model.eval())
#         elif(data_source=='CIFAR-100'):
#           new_model_block_100_1 = nn.Sequential(*list(net.children())[:-2])
#           new_model_block_100_2 = nn.Sequential(*list(net.children())[:-3])
#           # print(new_model.eval())
#         else:
#           # new_model = nn.Sequential(*list(net.children())[:-1])
#           # print('now in gtsrb if else for net')
#           new_model = net

#         if(data_source=='MNIST'):
#           # print('now in mnist if else for net')
#           out_1 = new_model_1(query)
#           out_2 = new_model_2(query)
#           out = out_2/torch.sum(out_1)
#         elif(data_source == 'CIFAR-10'):
#           out_1 = new_model_block_10_1(query)
#           out_2 = new_model_block_10_2(query)
#           out = out_2/torch.sum(out_1)
#         elif(data_source=='CIFAR-100'):
#           out_1 = new_model_block_100_1(query)
#           out_2 = new_model_block_100_2(query)
#           out = out_2/torch.sum(out_1)
#         else:
#           out_1 = new_model(query)
#           out = out_1

#         # out = new_model(query)
#         out = self.affines[data_source](out.view(1,-1))
#         out = self.norm(out)
#         out = self.gelu(out)
#         return self.final_output(out)

In [None]:
# Inputs where out = SVD(out)
from pandas.core.internals.construction import nested_data_to_arrays
data_sources = ['CIFAR-10', 'CIFAR-100', 'GTSRB', 'MNIST']
# Shape of one query image per data source: MNIST is 1x28x28, the others 3x32x32.
data_source_to_channel = {k: 1 if k == 'MNIST' else 3 for k in data_sources}
data_source_to_resolution = {k: 28 if k == 'MNIST' else 32 for k in data_sources}
# Despite the name, these values are multiplied by num_queries to give the number of
# input features of each per-source linear layer, i.e. the number of singular values
# produced per query.
# NOTE(review): 1024 / 1 are tied to the architectures of the loaded networks, which are
# not visible here -- confirm against the model definitions.
data_source_to_num_classes = {'CIFAR-10': 1024, 'CIFAR-100': 1024, 'GTSRB': 1, 'MNIST': 1}


class MetaNetwork(nn.Module):
    """Meta-classifier that scores a network as Trojaned or clean.

    For every data source it keeps a learnable batch of query images and a linear
    layer. ``forward`` feeds the queries through (part of) the inspected network,
    takes the singular values of the result and maps them to a logit through
    linear -> LayerNorm -> GELU -> linear.

    NOTE(review): if the inspected network's output for a source is 2-D
    ``(num_queries, features)``, ``svdvals`` yields ``min(num_queries, features)``
    values for the whole batch; that matches the ``1 * num_queries`` input size only
    while ``num_queries <= features`` -- confirm before changing ``num_queries``.
    """

    def __init__(self, num_queries, num_classes=1):
        """
        :param num_queries: number of learnable query images per data source
        :param num_classes: size of the output (1 gives a single Trojan logit)
        """
        super().__init__()
        # Queries start uniformly random in [0, 1).
        self.queries = nn.ParameterDict(
            {k: nn.Parameter(torch.rand(num_queries,
                                        data_source_to_channel[k],
                                        data_source_to_resolution[k],
                                        data_source_to_resolution[k])) for k in data_sources}
        )
        self.affines = nn.ModuleDict(
            {k: nn.Linear(data_source_to_num_classes[k] * num_queries, 32) for k in data_sources}
        )
        self.norm = nn.LayerNorm(32)
        self.gelu = nn.GELU()
        self.final_output = nn.Linear(32, num_classes)

    @staticmethod
    def _feature_extractor(net, data_source):
        """Return the module whose output is analysed for ``data_source``.

        * MNIST: the children of ``net.main`` with the last one dropped
        * CIFAR-10 / CIFAR-100: the children of ``net`` with the last two dropped
        * anything else (GTSRB): the full network

        NOTE(review): the dropped children are presumably the classification head /
        pooling layers -- confirm against the model definitions.
        """
        if data_source == 'MNIST':
            return nn.Sequential(*list(net.main.children())[:-1])
        if data_source in ('CIFAR-10', 'CIFAR-100'):
            return nn.Sequential(*list(net.children())[:-2])
        return net

    def forward(self, net, data_source):
        """
        :param net: an input network of one of the model_types specified at init
        :param data_source: the name of the data source
        :returns: a score for whether the network is a Trojan or not
        """
        query = self.queries[data_source]
        features = self._feature_extractor(net, data_source)(query)
        # Singular values over the last two dimensions of the feature tensor.
        out = torch.linalg.svdvals(features)
        # Flatten everything into a single row: one score per inspected network.
        out = self.affines[data_source](out.reshape(1, -1))
        out = self.norm(out)
        out = self.gelu(out)
        return self.final_output(out)

In [None]:
# from torchsummary import summary
# print(type(train_loader))
# for i, (net, label, data_source) in enumerate(train_loader):
#       net = net[0]
#       # print(net)
#       new_model = nn.Sequential(*list(net.children())[1:-1])
#       # print(data_source, new_model[3])
#       break

In [None]:
# from torchsummary import summary
# print(type(train_loader))
# for i, (net, label, data_source) in enumerate(train_loader):
#       net = net[0]
#       # print(net)
#       new_model = nn.Sequential(*list(net.children())[1:-1])
#       print(type(new_model[3]))
#       break

<class 'torch.utils.data.dataloader.DataLoader'>
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>


## Train the network

In [None]:
# Seed PyTorch's global RNG so the random query initialisation and the DataLoader
# shuffle order are repeatable across runs (GPU kernels may still add small
# non-deterministic differences).
torch.manual_seed(0)

meta_network = MetaNetwork(10, num_classes=1).cuda().train()

num_epochs = 10
lr = 0.01
weight_decay = 0.
optimizer = torch.optim.Adam(meta_network.parameters(), lr=lr, weight_decay=weight_decay)
# The scheduler is stepped once per training sample, so its period is the total
# number of optimisation steps in the whole run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs * len(train_dataset))

loss_ema = np.inf
for epoch in range(num_epochs):

    pbar = tqdm(train_loader)
    pbar.set_description(f"Epoch {epoch + 1}")
    for net, label, data_source in pbar:
        # batch_size=1: every field arrives as a one-element list.
        net, label, data_source = net[0], label[0], data_source[0]
        net.cuda().eval()

        out = meta_network(net, data_source)

        target = torch.tensor([[float(label)]], device=out.device)
        loss = F.binary_cross_entropy_with_logits(out, target)

        optimizer.zero_grad()
        # Only accumulate gradients into the meta-network's own parameters, not into
        # the parameters of the network being inspected.
        loss.backward(inputs=list(meta_network.parameters()))
        optimizer.step()
        scheduler.step()

        # Project the learned queries back into the [0, 1] range after each update.
        with torch.no_grad():
            for query in meta_network.queries.values():
                query.clamp_(0, 1)

        # Exponential moving average of the loss, shown on the progress bar.
        loss_ema = loss.item() if loss_ema == np.inf else 0.95 * loss_ema + 0.05 * loss.item()

        pbar.set_postfix(loss=loss_ema)

In MetaNetwork class now


  0%|          | 0/800 [00:00<?, ?it/s]

  0%|          | 0/800 [00:00<?, ?it/s]

  0%|          | 0/800 [00:00<?, ?it/s]

  0%|          | 0/800 [00:00<?, ?it/s]

  0%|          | 0/800 [00:00<?, ?it/s]

  0%|          | 0/800 [00:00<?, ?it/s]

  0%|          | 0/800 [00:00<?, ?it/s]

  0%|          | 0/800 [00:00<?, ?it/s]

  0%|          | 0/800 [00:00<?, ?it/s]

  0%|          | 0/800 [00:00<?, ?it/s]

In [None]:
# Put the trained meta-network into evaluation mode before computing metrics.
meta_network.eval()

## Evaluate the network

In [None]:
def evaluate(meta_network, loader):
    """Score every labelled network in ``loader`` with ``meta_network``.

    :param meta_network: the trained meta-classifier; each inspected network is moved
        to the device that holds its parameters
    :param loader: yields ``(networks, labels, data_sources)`` batches of size 1
    :returns: ``(mean loss, accuracy, confusion matrix, labels, scores)`` where the
        confusion matrix is a 2x2 tensor indexed as ``[predicted, true label]`` and
        ``scores`` are the raw logits (a logit > 0 counts as a Trojan prediction)
    """
    device = next(meta_network.parameters()).device
    loss_list = []
    correct_list = []
    confusion_matrix = torch.zeros(2, 2)
    all_scores = []
    all_labels = []

    for net, label, data_source in tqdm(loader):
        # batch_size=1: every field arrives as a one-element list.
        net, label, data_source = net[0], label[0], data_source[0]
        net.to(device).eval()
        with torch.no_grad():
            out = meta_network(net, data_source)
            target = torch.tensor([[float(label)]], device=out.device)
            loss = F.binary_cross_entropy_with_logits(out, target)

        score = out.squeeze().item()
        predicted = int(score > 0)

        loss_list.append(loss.item())
        correct_list.append(int(predicted == label))
        confusion_matrix[predicted, label] += 1
        all_scores.append(score)
        all_labels.append(label)

    return np.mean(loss_list), np.mean(correct_list), confusion_matrix, all_labels, all_scores

In [None]:
# Metrics on the training split (the networks the meta-network was fitted on).
loss, acc, cmat, _, _ = evaluate(meta_network, train_loader)
print(f'Train Loss: {loss:.3f}, Train Accuracy: {acc*100:.2f}')
print('Confusion Matrix:\n', cmat.numpy())

  0%|          | 0/800 [00:00<?, ?it/s]

Train Loss: 0.003, Train Accuracy: 100.00
Confusion Matrix:
 [[399.   0.]
 [  0. 401.]]


In [None]:
# Metrics on the held-out validation split. `all_preds` holds the raw logit scores
# returned by evaluate (not thresholded 0/1 predictions); they are reused for AUROC.
loss, acc, cmat, all_labels, all_preds = evaluate(meta_network, val_loader)
print(f'Val Loss: {loss:.3f}, Val Accuracy: {acc*100:.2f}')
print('Confusion Matrix:\n', cmat.numpy())

  0%|          | 0/200 [00:00<?, ?it/s]

Val Loss: 1.534, Val Accuracy: 60.00
Confusion Matrix:
 [[72. 51.]
 [29. 48.]]


In [None]:
# AUROC is computed from the raw validation scores, so it does not depend on the
# decision threshold used for the accuracy above.
print(f'Val AUROC: {roc_auc_score(all_labels, all_preds):.3f}')

Val AUROC: 0.642


## Make submission

In [None]:
# Dataset class for the validation/test sets, which contain all networks in a single folder

class NetworkDatasetDetectionTest(torch.utils.data.Dataset):
    """Unlabelled dataset of saved networks that all live directly in one folder.

    Every entry of ``model_folder`` is treated as a model directory holding an
    ``info.json`` (with a ``'dataset'`` key) and a ``model.pt`` file. Entries are
    visited in sorted name order.
    """

    def __init__(self, model_folder):
        """
        :param model_folder: directory whose entries are the model directories
        """
        super().__init__()
        entries = sorted(os.listdir(os.path.join(model_folder)))
        self.model_paths = [os.path.join(model_folder, entry) for entry in entries]
        self.data_sources = [self._read_data_source(path) for path in self.model_paths]

    @staticmethod
    def _read_data_source(model_path):
        """Return the ``'dataset'`` field stored in ``<model_path>/info.json``."""
        with open(os.path.join(model_path, 'info.json'), 'r') as f:
            return json.load(f)['dataset']

    def __len__(self):
        """Return the number of model directories found."""
        return len(self.model_paths)

    def __getitem__(self, index):
        """Load one network from disk and return it with its data source name."""
        checkpoint = os.path.join(self.model_paths[index], 'model.pt')
        return torch.load(checkpoint), self.data_sources[index]

def custom_collate(batch):
    """Collate samples into one list per field, without stacking anything.

    This definition replaces the earlier ``custom_collate`` in the notebook namespace,
    so it accepts any number of fields per sample: ``(network, data_source)`` samples
    give two lists, ``(network, label, data_source)`` samples give three. That keeps
    the labelled loaders working if their cells are re-run after this one.

    :param batch: list of equally sized sample tuples
    :returns: tuple with one list per sample field; ``([], [])`` for an empty batch
    """
    if not batch:
        return [], []
    return tuple(list(field) for field in zip(*batch))

In [None]:
# Loader over the unlabelled networks stored in the task's 'val' folder.
# shuffle=False keeps the produced scores in the dataset's sorted-folder order.
dataset_path = '/content/drive/MyDrive/trojan_detection/tdc_datasets'
task = 'detection'

test_dataset = NetworkDatasetDetectionTest(os.path.join(dataset_path, task, 'val'))
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False,
                                          num_workers=0, pin_memory=False, collate_fn=custom_collate)

In [None]:
def predict(meta_network, loader):
    """Return the raw Trojan score (logit) for every network in ``loader``.

    :param meta_network: the trained meta-classifier; each inspected network is moved
        to the device that holds its parameters
    :param loader: yields ``(networks, data_sources)`` batches of size 1
    :returns: list of float scores, in the order the loader yields the networks
    """
    device = next(meta_network.parameters()).device
    all_scores = []
    for net, data_source in tqdm(loader):
        # batch_size=1: every field arrives as a one-element list.
        net, data_source = net[0], data_source[0]
        net.to(device).eval()
        with torch.no_grad():
            out = meta_network(net, data_source)
        all_scores.append(out.squeeze().item())

    return all_scores

In [None]:
# Score every network yielded by test_loader; these scores are what gets saved below.
scores = predict(meta_network, test_loader)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [None]:
if not os.path.exists('mntd_submission'):
    os.makedirs('mntd_submission')

with open(os.path.join('mntd_submission', 'predictions.npy'), 'wb') as f:
    np.save(f, np.array(scores))

!cd mntd_submission && zip ../mntd_submission.zip ./* && cd ..

  adding: predictions.npy (deflated 43%)


In [None]:
# List the working directory, e.g. to confirm that mntd_submission.zip was written.
!ls

drive  mntd_submission	mntd_submission.zip  sample_data


In [None]:
# for i,(net, label, data_source) in enumerate(pbar): 
#   if(i==4):
#       print(label[0])
#       break
  

In [None]:
# dataset_path = '/content/drive/MyDrive/trojan_detection/tdc_datasets'
# task = 'detection'
# dataset = NetworkDatasetDetection(os.path.join(dataset_path, task, 'val'))
# train_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True,
#                                            num_workers=0, pin_memory=False, collate_fn=custom_collate)
