This notebook contains experiment with the mapping function.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import configparser
import copy
import datetime
import matplotlib.pyplot as plt
import os
import pickle
import random
from typing import Dict, Any, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.utils.data as data
import torchvision as tv
from tqdm import tqdm_notebook as tqdm
import config_helper
import filter as watermark_filter
import logger
import models
import score

random.seed(42)

log = logger.Logger(prefix=">>>")

class SimpleDataset(data.Dataset):
    def __init__(self, dataset: List[Tuple[Any, int]]) -> None:
        self.data, self.labels = zip(*dataset)
        self.count = len(self.labels)

    def __getitem__(self, index: int) -> (Any, int):
        return self.data[index], self.labels[index]

    def __len__(self) -> int:
        return self.count

In [3]:
def download_data(dataset_name: str, victim_data_path: str, input_size: int) -> (data.Dataset, data.Dataset):
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]

    if dataset_name == "MNIST":
        dataset = tv.datasets.MNIST
        transformations = tv.transforms.Compose([
            tv.transforms.ToTensor(),
            tv.transforms.Normalize(mean, std)
        ])
    else:
        log.error("MNIST is the only supported datasets at the moment. Throwing...")
        raise ValueError(dataset_name)

    train_set = dataset(victim_data_path, train=True, transform=transformations, download=True)
    test_set = dataset(victim_data_path, train=False, transform=transformations, download=True)
    
    log.info("Training ({}) samples: {}\nTest samples: {}\nSaved in: {}".format(dataset_name, len(train_set), len(test_set), victim_data_path))
    return train_set, test_set


def setup_victim_model(model_architecture: str, model_path: str, number_of_classes: int) -> nn.Module:
    available_models = {
        "MNIST_L5": models.MNIST_L5_with_latent,
    }

    model = available_models[model_architecture]()

    if model is None:
        log.error("Incorrect model architecture specified or architecture not available.")
        raise ValueError(model_architecture)

    models.load_state(model, model_path)

    return model


def load_file(file_path: str) -> List[Tuple]:
    with open(file_path, "rb") as f:
        return pickle.load(f)


def test_model(model: nn.Module, test_set: data.DataLoader, number_of_classes: int) -> (score.FloatScore, score.DictScore):
    """Test the model on the test dataset."""
    # model.eval is used for ImageNet models, batchnorm or dropout layers will work in eval mode.
    model.eval()

    def test_average() -> score.FloatScore:
        correct = 0
        total = 0

        with torch.set_grad_enabled(False):
            for (inputs, yreal) in tqdm(test_set, unit="images", desc="Testing model (average)", leave=True, ascii=True):
                inputs, yreal = inputs.cuda(), yreal.cuda()

                ypred, _ = model(inputs)
                _, predicted = torch.max(ypred.data, 1)

                total += yreal.size(0)
                correct += (predicted == yreal).sum().item()

        accuracy = 100 * correct / total
        log.info("Accuracy of the network on the {} test images (average): {}".format(total, accuracy))
        with open('epoch_logs.txt', 'a+') as file:
            file.write('Test Acc: {}\n'.format(accuracy))
        return score.FloatScore(accuracy)

    def test_per_class() -> score.DictScore:
        class_correct = list(0. for _ in range(number_of_classes))
        class_total = list(0. for _ in range(number_of_classes))
        total = 0

        with torch.no_grad():
            for (inputs, yreal) in tqdm(test_set, unit="images", desc="Testing model (per class)", leave=True, ascii=True):
                inputs, yreal = inputs.cuda(), yreal.cuda()

                total += yreal.size(0)

                ypred, _ = model(inputs)
                _, predicted = torch.max(ypred, 1)
                c = (predicted == yreal).squeeze()
                for i in range(yreal.shape[0]):
                    label = yreal[i]
                    class_correct[label] += c[i].item()
                    class_total[label] += 1

        log.info("Accuracy of the network on the {} test images (per-class):".format(total))

        per_class_accuracy = {}
        for i in range(number_of_classes):
            accuracy = 100 * class_correct[i] / (class_total[i] + 0.0001)
            per_class_accuracy[i] = accuracy
            print('Accuracy of %5s : %2d %%' % (
                i, accuracy))

        return score.DictScore(per_class_accuracy)

    return test_average(), test_per_class()


def get_shapes(model: nn.Module, test_set: data.DataLoader) -> (torch.Size, List[torch.Size]):
    """Returns input and latent sizes."""

    model.eval()
    with torch.set_grad_enabled(False):
        for (inputs, yreal) in test_set:
            inputs, yreal = inputs.cuda(), yreal.cuda()

            ypred, latents = model(inputs)
            watermark_shape = inputs[0].cpu().shape
            latents_shapes = [torch.Size([l.cpu().shape[1]]) for l in latents]
            break

    return watermark_shape, latents_shapes


def compare_distributions(
    model: nn.Module, test_set: data.DataLoader,
    wf: watermark_filter.WatermarkFilter,
    wf_latents: List[watermark_filter.WatermarkFilter]) -> List[List]:
    
    with_wm_orig = 0
    without_wm_orig = 0

    latent_n = len(wf_latents)
    latent_batches = [[] for _ in range(latent_n)]
    with_without = [
    {
        "with_wm_latent": 0,
        "without_wm_latent": 0
    } 
    for _ in range(latent_n)]

    with torch.no_grad():
        for (inputs, _) in tqdm(test_set, unit="images", desc="Watermark Filter", leave=True, ascii=True):
            inputs = inputs.cuda()

            model.eval()
            _, latents = model(inputs)
            inputs = inputs.cpu()


            for x in inputs:
                if wf.is_watermark(x):
                    with_wm_orig += 1
                else:
                    without_wm_orig += 1

            for i in range(latent_n):
                lat_repr = latents[i].cpu()
                latent_batches[i].append(lat_repr)
                
                for x in lat_repr:
                    if wf_latents[i].is_watermark(x):
                        with_without[i]["with_wm_latent"] += 1
                    else:
                        with_without[i]["without_wm_latent"] += 1

    log.info("Watermarked: {}".format(with_wm_orig))
    log.info("Not watermarked: {}".format(without_wm_orig))
    log.info("Ratio: {}".format(with_wm_orig * 100 / without_wm_orig))

    for i in range(latent_n):
        log.info("Watermarked latent: {}".format(with_without[i]["with_wm_latent"]))
        log.info("Not watermarked latent: {}".format(with_without[i]["without_wm_latent"]))
        log.info("Ratio latent: {}".format(with_without[i]["with_wm_latent"] * 100 / with_without[i]["without_wm_latent"]))

    
    return latent_batches


def perturb(img, e, min_pixel=-1., max_pixel=1.):
    r = max_pixel - min_pixel
    b = r * torch.rand(img.shape)
    b += min_pixel
    noise = e * b
    noise = noise.cuda()

    return torch.clamp(img + noise, min_pixel, max_pixel)

In [4]:
config = config_helper.load_config("configurations/mapping/mapping-mnist-l5.ini")

victim_path = "data/models/victim_mnist_l5.pt"

config_helper.print_config(config)
log.info("Victim model path: {}.".format(victim_path))

[DEFAULT]
batch_size: 1024
dataset_name: MNIST
input_size: 28
number_of_classes: 10
model_architecture: MNIST_L5
test_save_path: data/datasets/MNIST
>>> INFO: Victim model path: data/models/victim_mnist_l5.pt.


In [5]:
#  Setup model architecture and load models from file.
model_victim = setup_victim_model(
    config["DEFAULT"]["model_architecture"],
    victim_path,
    int(config["DEFAULT"]["number_of_classes"]))

device_string = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_string)
log.info("Using device: {}".format(device_string))

model_victim = model_victim.to(device=device)

Loading state from: data/models/victim_mnist_l5.pt
>>> INFO: Using device: cuda


In [6]:
#  Load test set and transform it.
train_set, test_set = download_data(
    config["DEFAULT"]["dataset_name"],
    config["DEFAULT"]["test_save_path"],
    int(config["DEFAULT"]["input_size"])
)

batch_size = config["DEFAULT"]["batch_size"]
train_set = data.DataLoader(train_set, batch_size=int(batch_size), shuffle=True, num_workers=4)
test_set = data.DataLoader(test_set, batch_size=int(batch_size), shuffle=False, num_workers=4)
_, _ = test_model(model_victim, test_set, int(config["DEFAULT"]["number_of_classes"]))

#  Determine size of the watermark filter
watermark_shape, watermark_latent_shapes = get_shapes(model_victim, test_set)
log.info("Input shape: {}".format(watermark_shape))
for latent_shape in watermark_latent_shapes:
    log.info("Latent shape: {}".format(latent_shape))

key = watermark_filter.default_key(256)
wf = watermark_filter.WatermarkFilter(key, watermark_shape, precision=16, probability=(5/1000))
wf_latents = [
    watermark_filter.WatermarkFilter(key, latent_shape, precision=16, probability=(50/1000))
    for latent_shape in watermark_latent_shapes]

>>> INFO: Training (MNIST) samples: 60000
Test samples: 10000
Saved in: data/datasets/MNIST


HBox(children=(IntProgress(value=0, description='Testing model (average)', max=10), HTML(value='')))


>>> INFO: Accuracy of the network on the 10000 test images (average): 99.18


HBox(children=(IntProgress(value=0, description='Testing model (per class)', max=10), HTML(value='')))


>>> INFO: Accuracy of the network on the 10000 test images (per-class):
Accuracy of     0 : 99 %
Accuracy of     1 : 99 %
Accuracy of     2 : 99 %
Accuracy of     3 : 99 %
Accuracy of     4 : 99 %
Accuracy of     5 : 99 %
Accuracy of     6 : 98 %
Accuracy of     7 : 99 %
Accuracy of     8 : 98 %
Accuracy of     9 : 98 %
>>> INFO: Input shape: torch.Size([1, 28, 28])
>>> INFO: Latent shape: torch.Size([200])
>>> INFO: Latent shape: torch.Size([10])


In [7]:
# Compare the distribution in the input space (image) to distribution of the latent representation
lat = compare_distributions(model_victim, test_set, wf, wf_latents)

HBox(children=(IntProgress(value=0, description='Watermark Filter', max=10), HTML(value='')))


>>> INFO: Watermarked: 58
>>> INFO: Not watermarked: 9942
>>> INFO: Ratio: 0.5833836250251458
>>> INFO: Watermarked latent: 529
>>> INFO: Not watermarked latent: 9471
>>> INFO: Ratio latent: 5.585471439129976
>>> INFO: Watermarked latent: 521
>>> INFO: Not watermarked latent: 9479
>>> INFO: Ratio latent: 5.496360375567043


In [8]:
def flatten(list_of_batches):
    flat = []
    for batch in list_of_batches:
        for x in batch:
            flat.append(x)
    return flat
            
lat_flat = [flatten(list_of_batches) for list_of_batches in lat]

In [9]:
def create_dist(latent_flat):
    l = latent_flat[0].shape[0]
    latent_dist = [[] for _ in range(l)]
    
    for single_lat in tqdm(latent_flat):
        for i in range(l):
            latent_dist[i].append(single_lat[i])
        
    return latent_dist

lat_dists = [create_dist(single_flat_list) for single_flat_list in lat_flat]

HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))




In [10]:
# Calculate medians that are then used to partition the latent space.

medians_for_lat = []
for shape, lat_dist in zip(watermark_latent_shapes, lat_dists):
    medians_for_single = []
    
    for dist in lat_dist:
        d = np.asarray(dist)
        median = np.median(d)
        medians_for_single.append(median)
#         Optional plotting
#         plt.hist(d)
#         plt.show()
        
    medians_for_lat.append(medians_for_single)

In [11]:
def median_featurize(tensor_vector, medians):
    for idx, v in enumerate(medians):
        tensor_vector[idx] = 0 if tensor_vector[idx] <= 0 else 1
    
    return tensor_vector

def do_mapping(
    model: nn.Module,
    test_set: data.DataLoader,
    wf_latent: watermark_filter.WatermarkFilter,
    medians: List,
    lat_idx,
    eps_test):

    matching = 0
    not_matching = 0
    matching_and_same_label = 0
    matching_and_diff_label = 0
    not_matching_and_same_label = 0
    not_matching_and_diff_label = 0
    to_wm_cnt = 0

    new_img_per_orig = 10

    with torch.no_grad():
        for (inputs, _) in tqdm(test_set, unit="images", desc="Watermark Filter", leave=True, ascii=True):
            inputs = inputs.cuda()

            model.eval()

            ypred, latents = model(inputs)
            _, predicted = torch.max(ypred.data, 1)
            lats = latents[idx]
                
            for x, l, yp in zip(inputs, lats, predicted):
                perturbed = perturb(x, eps_test)

                assert len(l.shape) == 1
                to_wm = wf_latent.is_watermark(median_featurize(l.cpu(), medians))

                if to_wm:
                    to_wm_cnt += 1

                for _ in range(new_img_per_orig):
                    input_star = perturb(x, eps_test)

                    ypred_star, lat_star = model(input_star.unsqueeze(0))
                    _, predicted_star = torch.max(ypred_star.data, 1)
                    predicted_star = predicted_star.squeeze()

                    lat_star = lat_star[idx].squeeze(0)
                    assert len(lat_star.shape) == 1
                    to_wm_star = wf_latent.is_watermark(median_featurize(lat_star.cpu(), medians))

                    if to_wm_star == to_wm:
                        matching += 1
                        if yp == predicted_star:
                            matching_and_same_label += 1
                        else:
                            matching_and_diff_label += 1
                    else:
                        not_matching += 1
                        if yp == predicted_star:
                            not_matching_and_same_label += 1
                        else:
                            not_matching_and_diff_label += 1

    log.info("to wm: {}".format(to_wm_cnt))
    log.info("matching: {} same label {} diff label {}".format(
        matching, matching_and_same_label, matching_and_diff_label))
    log.info("not matching: {} same label {} diff label {}".format(
        not_matching, not_matching_and_same_label, not_matching_and_diff_label))

In [12]:
for eps in [0.2, 0.1, 0.09, 0.075, 0.05]:
    print("---------------------------------------------------")
    print("+++ with eps: {}".format(eps))
    for idx, wf in enumerate(wf_latents):
        medians = medians_for_lat[idx]
        print("\nlatent size: {}".format(len(medians)))
        do_mapping(
            model_victim,
            test_set,
            wf,
            medians,
            idx,
            eps)

---------------------------------------------------
+++ with eps: 0.2

latent size: 200


HBox(children=(IntProgress(value=0, description='Watermark Filter', max=10), HTML(value='')))


>>> INFO: to wm: 532
>>> INFO: matching: 91836 same label 91604 diff label 232
>>> INFO: not matching: 8164 same label 8151 diff label 13

latent size: 10


HBox(children=(IntProgress(value=0, description='Watermark Filter', max=10), HTML(value='')))


>>> INFO: to wm: 1
>>> INFO: matching: 99982 same label 99734 diff label 248
>>> INFO: not matching: 18 same label 18 diff label 0
---------------------------------------------------
+++ with eps: 0.1

latent size: 200


HBox(children=(IntProgress(value=0, description='Watermark Filter', max=10), HTML(value='')))


>>> INFO: to wm: 532
>>> INFO: matching: 93906 same label 93786 diff label 120
>>> INFO: not matching: 6094 same label 6089 diff label 5

latent size: 10


HBox(children=(IntProgress(value=0, description='Watermark Filter', max=10), HTML(value='')))


>>> INFO: to wm: 1
>>> INFO: matching: 99988 same label 99839 diff label 149
>>> INFO: not matching: 12 same label 12 diff label 0
---------------------------------------------------
+++ with eps: 0.09

latent size: 200


HBox(children=(IntProgress(value=0, description='Watermark Filter', max=10), HTML(value='')))


>>> INFO: to wm: 532
>>> INFO: matching: 94173 same label 94061 diff label 112
>>> INFO: not matching: 5827 same label 5823 diff label 4

latent size: 10


HBox(children=(IntProgress(value=0, description='Watermark Filter', max=10), HTML(value='')))


>>> INFO: to wm: 1
>>> INFO: matching: 99990 same label 99855 diff label 135
>>> INFO: not matching: 10 same label 10 diff label 0
---------------------------------------------------
+++ with eps: 0.075

latent size: 200


HBox(children=(IntProgress(value=0, description='Watermark Filter', max=10), HTML(value='')))


>>> INFO: to wm: 532
>>> INFO: matching: 94844 same label 94754 diff label 90
>>> INFO: not matching: 5156 same label 5151 diff label 5

latent size: 10


HBox(children=(IntProgress(value=0, description='Watermark Filter', max=10), HTML(value='')))


>>> INFO: to wm: 1
>>> INFO: matching: 99985 same label 99871 diff label 114
>>> INFO: not matching: 15 same label 15 diff label 0
---------------------------------------------------
+++ with eps: 0.05

latent size: 200


HBox(children=(IntProgress(value=0, description='Watermark Filter', max=10), HTML(value='')))


>>> INFO: to wm: 532
>>> INFO: matching: 96184 same label 96122 diff label 62
>>> INFO: not matching: 3816 same label 3815 diff label 1

latent size: 10


HBox(children=(IntProgress(value=0, description='Watermark Filter', max=10), HTML(value='')))


>>> INFO: to wm: 1
>>> INFO: matching: 99988 same label 99920 diff label 68
>>> INFO: not matching: 12 same label 12 diff label 0
