In [1]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from curv_scoring import get_curv_scores_for_net
from utils import full_train

In [2]:
class SubsetTransformDataset(Dataset):
    """Dataset wrapper that applies an extra transform to a chosen subset of items.

    Every item is read from the wrapped dataset as an ``(image, label)`` pair.
    ``default_transform`` (if given) is applied to every image; ``subset_transform``
    (if given) is then applied only to images whose index is in ``subset_indices``.
    Labels are returned unchanged.
    """

    def __init__(self, dataset, subset_indices, subset_transform=None, default_transform=None):
        """
        Args:
            dataset (Dataset): The original dataset; indexing it must yield an
                ``(image, label)`` pair.
            subset_indices (list, range, tensor or None): Integer indices of the
                items that receive ``subset_transform``. ``None`` means "no items".
                The indices are snapshotted here; later mutation of the passed
                container does not change which items are transformed.
            subset_transform (callable, optional): A function/transform to apply to the subset.
            default_transform (callable, optional): A function/transform to apply to every
                image of the dataset first (before ``subset_transform``).
        """
        self.dataset = dataset
        self.subset_indices = subset_indices
        self.subset_transform = subset_transform
        self.default_transform = default_transform

        # Membership is tested on every __getitem__ call. Testing against a
        # list or tensor is a linear scan per item, so build a hash set once
        # for constant-time lookups.
        if subset_indices is None:
            index_values = ()
        elif hasattr(subset_indices, 'tolist'):
            # torch tensors / numpy arrays: flatten so any-dimensional input works.
            index_values = subset_indices.reshape(-1).tolist()
        else:
            index_values = subset_indices
        self._subset_lookup = frozenset(int(value) for value in index_values)

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx`` with the configured transforms applied."""
        image, label = self.dataset[idx]
        if self.default_transform:
            image = self.default_transform(image)

        # Apply the transform only to the subset. int() normalises numpy
        # integers and 0-dim tensors so they hash like plain Python ints.
        if self.subset_transform and int(idx) in self._subset_lookup:
            image = self.subset_transform(image)

        return image, label


In [8]:
class Pseudoinverse:
    """Transform that replaces an image tensor with its matrix pseudoinverse."""

    def __call__(self, img):
        """
        Replace the given image with its Moore-Penrose pseudoinverse.

        Args:
            img: Object exposing ``.numpy()`` (e.g. a CPU torch tensor).
                 # assumes a (C, H, W) image tensor as produced by ToTensor - TODO confirm

        Returns:
            torch.Tensor built from the result of ``np.linalg.pinv``; NumPy
            inverts over the last two axes, so those two axes come back swapped.
        """
        as_array = img.numpy()
        inverted = np.linalg.pinv(as_array)
        return torch.from_numpy(inverted)

In [9]:
# MNIST training split read from ./data, with images converted by ToTensor.
# download=False: the dataset files must already be present under ./data.
mnist = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=False,
)

In [27]:
# Synthetic replacement pool: as many items as MNIST, same 1x28x28 image size
# and the same number of classes, converted by ToTensor like the real data.
fake_data = torchvision.datasets.FakeData(
    len(mnist),
    image_size=(1, 28, 28),
    num_classes=10,
    transform=transforms.ToTensor(),
)

In [24]:
class ReplaceWithDataset:
    """Transform that discards its input and returns an image from another dataset."""

    def __init__(self, replace_dataset):
        """
        Args:
            replace_dataset (Dataset): Source of replacement images; must support
                ``len()`` and integer indexing that yields a 2-item pair whose
                first element is the image.
        """
        self.replace_dataset = replace_dataset

    def __call__(self, img):
        """
        Return a uniformly random image from the replacement dataset.

        The incoming ``img`` is ignored. The position is drawn from NumPy's
        global RNG on every call.
        NOTE(review): because a new position is drawn per call, the same source
        index can map to a different replacement each time it is accessed -
        confirm this is intended.
        """
        pool = self.replace_dataset
        position = np.random.randint(0, len(pool))
        replacement, _ = pool[position]
        return replacement

In [28]:
# Experiment: for each corruption size, run 5 independent trials in which a
# random subset of MNIST items has its image replaced by a FakeData image,
# train a network on the result, and save per-sample curvature scores together
# with the indices of the corrupted subset.
# NOTE(review): ReplaceWithDataset draws a new random replacement on every
# access, so the image scored here may differ from the one(s) seen in training.
sizes = [1, 10, 100, 1000, 10000]

for size in sizes:
    print(f'Size {size}:')
    # np.savez does not create missing parent directories; create the target
    # up front so a fresh checkout does not fail after a full training run.
    out_dir = f'mnist_curv_scores/fakedata_{size}'
    os.makedirs(out_dir, exist_ok=True)
    for i in range(5):
        print(f'Saving scores for run {i+1}...')
        # First `size` entries of a random permutation = `size` distinct indices.
        subset_idx = torch.randperm(len(mnist))[:size]
        fake_dataset = SubsetTransformDataset(mnist, subset_idx, ReplaceWithDataset(fake_data))
        net = full_train(fake_dataset)
        scores = get_curv_scores_for_net(fake_dataset, net)
        score_dict = dict(subset=subset_idx, scores=scores)
        # Written as <out_dir>/run_<k>.npz (np.savez appends the extension).
        np.savez(f'{out_dir}/run_{i+1}', **score_dict)

Size 1:
Saving scores for run 1...
Saving scores for run 2...
Saving scores for run 3...
Saving scores for run 4...
Saving scores for run 5...
Size 10:
Saving scores for run 1...
Saving scores for run 2...
Saving scores for run 3...
Saving scores for run 4...
Saving scores for run 5...
Size 100:
Saving scores for run 1...
Saving scores for run 2...
Saving scores for run 3...
Saving scores for run 4...
Saving scores for run 5...
Size 1000:
Saving scores for run 1...
Saving scores for run 2...
Saving scores for run 3...
Saving scores for run 4...
Saving scores for run 5...
Size 10000:
Saving scores for run 1...
Saving scores for run 2...
Saving scores for run 3...
Saving scores for run 4...
Saving scores for run 5...


In [29]:
# Experiment: for each corruption size, run 5 independent trials in which a
# random subset of MNIST items has its image replaced by the image's matrix
# pseudoinverse, train a network on the result, and save per-sample curvature
# scores together with the indices of the corrupted subset.
sizes = [1, 10, 100, 1000, 10000]

for size in sizes:
    print(f'Size {size}:')
    # np.savez does not create missing parent directories; create the target
    # up front so a fresh checkout does not fail after a full training run.
    out_dir = f'mnist_curv_scores/pinv_{size}'
    os.makedirs(out_dir, exist_ok=True)
    for i in range(5):
        print(f'Saving scores for run {i+1}...')
        # First `size` entries of a random permutation = `size` distinct indices.
        subset_idx = torch.randperm(len(mnist))[:size]
        fake_dataset = SubsetTransformDataset(mnist, subset_idx, Pseudoinverse())
        net = full_train(fake_dataset)
        scores = get_curv_scores_for_net(fake_dataset, net)
        score_dict = dict(subset=subset_idx, scores=scores)
        # Written as <out_dir>/run_<k>.npz (np.savez appends the extension).
        np.savez(f'{out_dir}/run_{i+1}', **score_dict)

Size 1:
Saving scores for run 1...
Saving scores for run 2...
Saving scores for run 3...
Saving scores for run 4...
Saving scores for run 5...
Size 10:
Saving scores for run 1...
Saving scores for run 2...
Saving scores for run 3...
Saving scores for run 4...
Saving scores for run 5...
Size 100:
Saving scores for run 1...
Saving scores for run 2...
Saving scores for run 3...
Saving scores for run 4...
Saving scores for run 5...
Size 1000:
Saving scores for run 1...
Saving scores for run 2...
Saving scores for run 3...
Saving scores for run 4...
Saving scores for run 5...
Size 10000:
Saving scores for run 1...
Saving scores for run 2...
Saving scores for run 3...
Saving scores for run 4...
Saving scores for run 5...
