In [1]:
import torch
import os
import json
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
cudnn.benchmark = True  # fire on all cylinders
from sklearn.metrics import roc_auc_score, roc_curve
import sys
from functools import partial

sys.path.insert(0, 'wrn.py')

print(torch.__version__)

1.11.0


## Create the dataset class

In [2]:
class NetworkDatasetDetection(torch.utils.data.Dataset):
    def __init__(self, model_folder):
        super().__init__()
        model_paths = []

        for x in sorted(os.listdir(os.path.join(model_folder, 'clean'))):
          if not x.startswith('.') and not x.endswith('(1)'):
            model_paths.append(os.path.join(model_folder, 'clean', x))

        for x in sorted(os.listdir(os.path.join(model_folder, 'trojan'))):
          if not x.startswith('.'):
            model_paths.append(os.path.join(model_folder, 'trojan', x))
            
        labels = []
        data_sources = []
        for p in model_paths:
            with open(os.path.join(p, 'info.json'), 'r') as f:
                info = json.load(f)
                data_sources.append(info['dataset'])
            if p.split('/')[-2] == 'clean':
                labels.append(0)
            elif p.split('/')[-2] == 'trojan':
                labels.append(1)
            else:
                raise ValueError('unexpected path {}'.format(p))
        self.model_paths = model_paths
        self.labels = labels
        self.data_sources = data_sources
    
    def __len__(self):
        return len(self.model_paths)
    
    def __getitem__(self, index):
        return torch.load(os.path.join(self.model_paths[index], 'model.pt')), \
               self.labels[index], self.data_sources[index]

def custom_collate(batch):
    return [x[0] for x in batch], [x[1] for x in batch], [x[2] for x in batch]

## Load data
Spliting off a validation set from the train set for testing purposes.

In [6]:
dataset_path = '../../tdc_datasets'
task = 'detection'
dataset = NetworkDatasetDetection(os.path.join(dataset_path, task, 'train'))

split = int(len(dataset) * 0.8)
rnd_idx = np.random.permutation(len(dataset))
train_dataset = torch.utils.data.Subset(dataset, rnd_idx[:split])
val_dataset = torch.utils.data.Subset(dataset, rnd_idx[split:])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True,
                                           num_workers=0, pin_memory=False, collate_fn=custom_collate)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=10,
                                           num_workers=0, pin_memory=False, collate_fn=custom_collate)

## Construct the MNTD network

In [38]:
data_sources = ['CIFAR-10', 'CIFAR-100', 'GTSRB', 'MNIST']
data_source_to_channel = {k: 1 if k == 'MNIST' else 3 for k in data_sources}
data_source_to_resolution = {k: 28 if k == 'MNIST' else 32 for k in data_sources}
data_source_to_depth = {k: 3 if k == 'GTSRB' else 1 for k in data_sources}
data_source_to_hidden_resolution = {'CIFAR-10': 64*128, 'CIFAR-100': 64*128, 'GTSRB': 128, 'MNIST': 128}
data_source_to_num_classes = {'CIFAR-10': 10, 'CIFAR-100': 100, 'GTSRB': 43, 'MNIST': 10}

class MetaNetwork(nn.Module):
    def __init__(self, num_queries, num_classes=1):
        super().__init__()
        self.queries = nn.ParameterDict(
            {k: nn.Parameter(torch.rand(num_queries,
                                        data_source_to_channel[k],
                                        data_source_to_resolution[k],
                                        data_source_to_resolution[k])) for k in data_sources}
        )
        #Method 1: Extract hidden layer of model and add convolutional layer for more info, reduce overfitting
        #Method 2: just put dropouts
        self.affines = nn.ModuleDict(
            {
                'CIFAR-10': nn.Linear(10*64*128, 512),
                'CIFAR-100': nn.Linear(10*64*128, 512),
                'GTSRB': nn.Linear(10 * 128, 512),
                'MNIST': nn.Linear(10 * 128, 512)
            }
        )
        
        self.layer_output = {}
        self.norm1 = nn.LayerNorm(512)
        self.norm2 = nn.LayerNorm(64)
        self.dropout = nn.Dropout(0.20)
        self.relu = nn.ReLU(True)
        self.linear1 = nn.Linear(512, 64)
        self.linear2 = nn.Linear(128, 32)
        self.final_output = nn.Linear(512, num_classes)
    
    def get_all_layers(self, net):
        def hook_fn(m, i, o, n=""):
            self.layer_output[n] = o

        for name, layer in net.named_modules():
            if hasattr(layer, "_module") and layer._module:
                self.get_all_layers(layer)
            elif hasattr(layer, "_parameters") and layer._parameters:
                # it's a non sequential. Register a hook
                layer.register_forward_hook(partial(hook_fn, n=name))

    def get_layer(self, depth):
        layers = []
        
        for k, v in self.layer_output.items():
            layers.append(v)

        self.layer_output.clear()
        return layers[-1 * depth]
    
    def forward(self, net, data_source):
        """
        :param net: an input network of one of the model_types specified at init
        :param data_source: the name of the data source
        :returns: a score for whether the network is a Trojan or not
        """
        query = self.queries[data_source]
        self.get_all_layers(net)
        net(query)
        out = self.get_layer(2)
        out = out.view(1, -1)
        out = self.affines[data_source](out)
        out = self.norm1(out)
        out = self.relu(out)
        return self.final_output(out)

    
# torch.Size([10, 128, 8, 8])
# torch.Size([10, 128, 8, 8])
# torch.Size([10, 3, 32, 32])

## Train the network

In [39]:
meta_network = MetaNetwork(10, num_classes=1).cpu().train()

num_epochs = 5
lr = 0.05
weight_decay = 0.
optimizer = torch.optim.Adam(meta_network.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs * len(train_dataset))

loss_ema = np.inf
for epoch in range(num_epochs):
    
    pbar = tqdm(train_loader)
    pbar.set_description(f"Epoch {epoch + 1}")
    for i, (net, label, data_source) in enumerate(pbar):
        net = net[0]
        label = label[0]
        data_source = data_source[0]
        net.cpu().eval()
        
        out = meta_network(net, data_source)
        
        loss = F.binary_cross_entropy_with_logits(out, torch.FloatTensor([label]).unsqueeze(0).cpu())
        
        optimizer.zero_grad()
        loss.backward(inputs=list(meta_network.parameters()))
        optimizer.step()
        scheduler.step()
        for k in meta_network.queries.keys():
            meta_network.queries[k].data = meta_network.queries[k].data.clamp(0, 1)
        loss_ema = loss.item() if loss_ema == np.inf else 0.95 * loss_ema + 0.05 * loss.item()

        pbar.set_postfix(loss=loss_ema)

  0%|          | 0/80 [00:00<?, ?it/s]

  0%|          | 0/80 [00:00<?, ?it/s]

  0%|          | 0/80 [00:00<?, ?it/s]

  0%|          | 0/80 [00:00<?, ?it/s]

  0%|          | 0/80 [00:00<?, ?it/s]

In [None]:
meta_network.eval()

MetaNetwork(
  (queries): ParameterDict(
      (CIFAR-10): Parameter containing: [torch.FloatTensor of size 10x3x32x32]
      (CIFAR-100): Parameter containing: [torch.FloatTensor of size 10x3x32x32]
      (GTSRB): Parameter containing: [torch.FloatTensor of size 10x3x32x32]
      (MNIST): Parameter containing: [torch.FloatTensor of size 10x1x28x28]
  )
  (affines): ModuleDict(
    (CIFAR-10): Linear(in_features=81920, out_features=512, bias=True)
    (CIFAR-100): Linear(in_features=81920, out_features=512, bias=True)
    (GTSRB): Linear(in_features=1280, out_features=512, bias=True)
    (MNIST): Linear(in_features=1280, out_features=128, bias=True)
  )
  (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (relu): ReLU(inplace=True)
  (linear1): Linear(in_features=512, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=32, bias=True)
  (fin

## Evaluate the network

In [62]:
scores = []

def evaluate(meta_network, loader):
    loss_list = []
    correct_list = []
    confusion_matrix = torch.zeros(2,2)
    all_scores = []
    all_labels = []
    

    for i, (net, label, data_source) in enumerate(tqdm(loader)):
        net[0].cpu().eval()
        with torch.no_grad():
            out = meta_network(net[0], data_source[0])
            scores.append(out)
        loss = F.binary_cross_entropy_with_logits(out, torch.FloatTensor([label[0]]).unsqueeze(0).cpu())
        correct = int((out.squeeze() > 0).int().item() == label[0])
        loss_list.append(loss.item())
        correct_list.append(correct)
        confusion_matrix[(out.squeeze() > 0).int().item(), label[0]] += 1
        all_scores.append(out.squeeze().item())
        all_labels.append(label[0])
        
    
    return np.mean(loss_list), np.mean(correct_list), confusion_matrix, all_labels, all_scores

In [67]:
loss, acc, cmat, _, _ = evaluate(meta_network, train_loader)
print(f'Train Loss: {loss:.3f}, Train Accuracy: {acc*100:.2f}')
print('Confusion Matrix:\n', cmat.numpy())
np.save('predictions.npy', np.array(scores))

  0%|          | 0/80 [00:00<?, ?it/s]

Train Loss: 0.739, Train Accuracy: 50.00
Confusion Matrix:
 [[ 6.  5.]
 [35. 34.]]


  np.save('predictions.npy', np.array(scores))
  np.save('predictions.npy', np.array(scores))


In [42]:
net = torch.load('../../tdc_datasets/detection/train/clean/id-0000/model.pt')
print(meta_network(net, 'CIFAR-10'))

tensor([[0.0077]], grad_fn=<AddmmBackward0>)


In [43]:
loss, acc, cmat, all_labels, all_preds = evaluate(meta_network, val_loader)
print(f'Val Loss: {loss:.3f}, Val Accuracy: {acc*100:.2f}')
print('Confusion Matrix:\n', cmat.numpy())

  0%|          | 0/20 [00:00<?, ?it/s]

Val Loss: 0.725, Val Accuracy: 50.00
Confusion Matrix:
 [[2. 2.]
 [8. 8.]]


In [44]:
print(f'Val AUROC: {roc_auc_score(all_labels, all_preds):.3f}')

Val AUROC: 0.540


## Make submission

In [46]:
# Dataset class for the validation/test sets, which contain all networks in a single folder

class NetworkDatasetDetectionTest(torch.utils.data.Dataset):
    def __init__(self, model_folder):
        super().__init__()
        model_paths = [os.path.join(model_folder, x) for x in sorted(os.listdir(os.path.join(model_folder)))]
        data_sources = []
        for model_path in model_paths:
            with open(os.path.join(model_path, 'info.json'), 'r') as f:
                info = json.load(f)
                data_sources.append(info['dataset'])
        self.model_paths = model_paths
        self.data_sources = data_sources
    
    def __len__(self):
        return len(self.model_paths)
    
    def __getitem__(self, index):
        return torch.load(os.path.join(self.model_paths[index], 'model.pt')), self.data_sources[index]

def custom_collate(batch):
    return [x[0] for x in batch], [x[1] for x in batch]

In [47]:
dataset_path = '../../tdc_datasets'
task = 'detection'

test_dataset = NetworkDatasetDetectionTest(os.path.join(dataset_path, task, 'val'))
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False,
                                          num_workers=0, pin_memory=False, collate_fn=custom_collate)

In [48]:
def predict(meta_network, loader):
    
    all_scores = []
    for i, (net, data_source) in enumerate(tqdm(loader)):
        net[0].cpu().eval()
        with torch.no_grad():
            out = meta_network(net[0], data_source[0])
        all_scores.append(out.squeeze().item())
    
    return all_scores

In [57]:
torch.save(meta_network, 'meta_network.pt')

torch.save(.state_dict(), PATH)

In [49]:
scores = predict(meta_network, test_loader)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [50]:
if not os.path.exists('mntd_submission'):
    os.makedirs('mntd_submission')

with open(os.path.join('mntd_submission', 'predictions.npy'), 'wb') as f:
    np.save(f, np.array(scores))

!cd mntd_submission && zip ../mntd_submission.zip ./* && cd ..

  adding: predictions.npy (deflated 48%)


In [51]:
!ls

README.md                [34mmntd_submission[m[m          wrn.py
[34m__pycache__[m[m              mntd_submission.zip
example_submission.ipynb utils.py
