# Section Project:

For the final project for this section, you're going to train a DP model using this PATE method on the MNIST dataset, provided below.

## Import Modules

In [1]:
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
import torch.utils.data as data

import torchvision.datasets as datasets
import torchvision.transforms as transforms

from syft.frameworks.torch.differential_privacy import pate

# Seed every RNG the notebook draws from: Python's `random`, NumPy, and PyTorch.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Use the first CUDA device when torch reports one as available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device



device(type='cuda', index=0)

## Prepare Data

### Load the MNIST Training & Test Datasets

In [2]:
# Both MNIST splits live under ../data (downloaded on first use) and yield tensors via ToTensor().
_mnist_kwargs  = dict(root='../data', download=True, transform=transforms.ToTensor())
mnist_trainset = datasets.MNIST(train=True, **_mnist_kwargs)
mnist_testset  = datasets.MNIST(train=False, **_mnist_kwargs)

# Keep a pristine copy of the test labels: data points that are considered "unlabeled"
# will be re-labeled by teachers later, overwriting entries of `mnist_testset.targets`.
mnist_testset.true_targets = mnist_testset.targets.clone()

print("Training Set Size:", len(mnist_trainset))
print("Test Set Size:", len(mnist_testset))
print()

# Value range of the raw pixel data across both splits.
_data_min = torch.min(mnist_trainset.data.min(), mnist_testset.data.min())
_data_max = torch.max(mnist_trainset.data.max(), mnist_testset.data.max())
print("Min Data Value:", _data_min)
print("Max Data Value:", _data_max)
print()

# Per-class sample counts of each split, as {label: count}.
_train_labels, _train_counts = torch.unique(mnist_trainset.targets, return_counts=True)
_test_labels, _test_counts   = torch.unique(mnist_testset.targets, return_counts=True)
print("Train Label Counts:", dict(zip(_train_labels.tolist(), _train_counts.tolist())))
print("Test Label Counts:", dict(zip(_test_labels.tolist(), _test_counts.tolist())))

Training Set Size: 60000
Test Set Size: 10000

Min Data Value: tensor(0, dtype=torch.uint8)
Max Data Value: tensor(255, dtype=torch.uint8)

Train Label Counts: {0: 5923, 1: 6742, 2: 5958, 3: 6131, 4: 5842, 5: 5421, 6: 5918, 7: 6265, 8: 5851, 9: 5949}
Test Label Counts: {0: 980, 1: 1135, 2: 1032, 3: 1010, 4: 982, 5: 892, 6: 958, 7: 1028, 8: 974, 9: 1009}


### Create Training & Test Datasets

In [3]:
n_teachers = 250

# Every teacher receives an equally sized, contiguous, non-overlapping slice of the training set.
_teacher_dataset_len = len(mnist_trainset) // n_teachers

teacher_datasets = []
for _first in [i * _teacher_dataset_len for i in range(n_teachers)]:
    _slice_indices = list(range(_first, _first + _teacher_dataset_len))
    teacher_datasets.append(data.Subset(mnist_trainset, _slice_indices))

# The first 90% of the MNIST test split is the student's pool; the remaining 10% is held out for testing.
_n_student_samples = int(len(mnist_testset) * 0.9)
student_dataset    = data.Subset(mnist_testset, list(range(_n_student_samples)))
test_dataset       = data.Subset(mnist_testset, list(range(_n_student_samples, len(mnist_testset))))

print("Size of each teacher dataset:", _teacher_dataset_len)
print("Size of the dataset available to the student:", len(student_dataset))
print("Size of the test dataset:", len(test_dataset))

Size of each teacher dataset: 240
Size of the dataset available to the student: 9000
Size of the test dataset: 1000


## Build Models

### Define Classifier

In [4]:
class MNISTClassifier(nn.Module):
    """Small CNN digit classifier: three conv / batch-norm / ReLU / max-pool stages and a linear head.

    The input is batch-normalised first. `forward` maps a batch of 1x28x28 images to
    one row of 10 unnormalised class scores per image.
    """

    def __init__(self):
        super().__init__()

        # Input normalisation + stage 0: 1x28x28 -> 5x14x14
        self.bn0      = nn.BatchNorm2d(num_features=1)
        self.conv0    = nn.Conv2d(in_channels=1, out_channels=5, kernel_size=3, padding=1)
        self.bn1      = nn.BatchNorm2d(num_features=5)
        self.maxpool0 = nn.MaxPool2d(kernel_size=2)

        # Stage 1: 5x14x14 -> 5x7x7
        self.conv1    = nn.Conv2d(in_channels=5, out_channels=5, kernel_size=3, padding=1)
        self.bn2      = nn.BatchNorm2d(num_features=5)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        # Stage 2: 5x7x7 -> 5x4x4 (the padded pooling rounds 7 up to 4 instead of down to 3)
        self.conv2    = nn.Conv2d(in_channels=5, out_channels=5, kernel_size=3, padding=1)
        self.bn3      = nn.BatchNorm2d(num_features=5)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, padding=1)

        # Head: 5*4*4 = 80 flattened features -> 10 class scores
        self.fc = nn.Linear(in_features=80, out_features=10)

        self.activation = nn.ReLU()

    def forward(self, x):
        """Return the class scores for the image batch `x`."""
        x = self.bn0(x)

        stages = (
            (self.conv0, self.bn1, self.maxpool0),
            (self.conv1, self.bn2, self.maxpool1),
            (self.conv2, self.bn3, self.maxpool2),
        )
        for conv, norm, pool in stages:
            x = pool(self.activation(norm(conv(x))))

        # Flatten each sample's 5x4x4 feature map before the linear head.
        return self.fc(x.view(-1, 80))

### Create Teacher & Student (DP) Models

In [5]:
# One separately constructed classifier per teacher partition, each moved to `device`.
teachers      = [MNISTClassifier().to(device) for _ in range(n_teachers)]
# Student network; a later cell rebinds `student` to a fresh model before training it.
student       = MNISTClassifier().to(device)

## Train Teachers

In [6]:
lr         = 3e-2
n_epochs   = 10
batch_size = 30

# Each teacher owns a shuffling loader over its partition and its own Adam optimizer;
# the loss function is shared.
teacher_dataloaders = [
    data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    for dataset in teacher_datasets
]
teacher_optimizers = [optim.Adam(model.parameters(), lr=lr) for model in teachers]
criterion          = nn.CrossEntropyLoss()

for model in teachers:
    model.train()

# Per-epoch lists holding one average loss / accuracy per teacher.
teachers_train_history = {'avg_losses':{}, 'avg_accuracies': {}}
for i_epoch in range(n_epochs):
    avg_losses, avg_accuracies = [], []

    for i_model, (model, dataloader, optimizer) in enumerate(
            zip(teachers, teacher_dataloaders, teacher_optimizers)):
        instance_count, total_loss, correct_count = 0, 0., 0

        n_batches     = len(dataloader)
        _prev_str_len = 0
        for i, (imgs, labels) in enumerate(dataloader):
            # Single-line progress display: pad to the previous width so leftovers are overwritten.
            _batch_str = "Teacher {:d}/{:d}: ({:d}/{:d})".format(i_model, n_teachers-1, i, n_batches-1)
            print(_batch_str.ljust(_prev_str_len), end='\r')
            _prev_str_len = len(_batch_str)

            n_in_batch   = imgs.size(0)
            imgs, labels = imgs.to(device), labels.to(device)

            optimizer.zero_grad()
            outs = model(imgs)
            loss = criterion(outs, labels)
            loss.backward()
            optimizer.step()

            # Statistics use the scores computed before this step's parameter update.
            preds = outs.argmax(dim=1)

            instance_count += n_in_batch
            total_loss     += loss.item() * n_in_batch  # undo the criterion's mean reduction
            correct_count  += (preds == labels).sum().item()

        avg_losses.append(total_loss / instance_count)
        avg_accuracies.append(correct_count / instance_count)

    _epoch_str = "Epoch {:d}/{:d}".format(i_epoch, n_epochs-1).ljust(_prev_str_len)
    print(_epoch_str)
    print("    Avg Losses:", [round(avg_loss, 5) for avg_loss in avg_losses])
    print("    Avg Accuracies:", [round(avg_acc, 4) for avg_acc in avg_accuracies])
    print()

    teachers_train_history['avg_losses'][i_epoch]     = avg_losses
    teachers_train_history['avg_accuracies'][i_epoch] = avg_accuracies

Epoch 0/9             
    Avg Losses: [2.21073, 2.23919, 2.28086, 2.15874, 2.23866, 2.31621, 2.30713, 2.18633, 2.1743, 2.16536, 2.21159, 2.13759, 2.29504, 2.13437, 2.26481, 2.26822, 2.36746, 1.982, 2.2323, 2.13046, 2.18558, 2.3023, 2.18393, 2.23326, 2.22305, 2.19719, 2.31659, 2.25757, 2.18945, 2.30273, 2.24863, 2.29796, 2.19766, 2.19736, 2.23516, 2.10273, 2.25671, 2.16018, 2.15332, 2.15617, 2.16629, 2.16903, 2.29991, 2.15375, 2.35322, 2.26402, 2.16145, 2.11861, 1.9837, 2.22307, 2.18739, 2.1666, 2.20664, 2.20077, 2.16129, 2.13993, 2.22303, 2.21677, 2.25294, 2.318, 2.30551, 2.1139, 2.02318, 2.24366, 2.12847, 2.18998, 2.38609, 2.24913, 2.11516, 2.1852, 2.28901, 2.16478, 2.27972, 2.20896, 2.18185, 2.15838, 2.26615, 2.29928, 1.90267, 2.15768, 2.23865, 2.09571, 2.17137, 2.01474, 2.18474, 2.08203, 2.26916, 2.14139, 2.13376, 2.00431, 2.27429, 2.25236, 2.2412, 2.15806, 2.06265, 2.15356, 2.22228, 2.06031, 2.17198, 2.18582, 2.32406, 2.3228, 2.16512, 2.31682, 2.1459, 2.20135, 2.27286, 2.17005, 2.

    Avg Accuracies: [0.5667, 0.6292, 0.4625, 0.6042, 0.5083, 0.3833, 0.4833, 0.5375, 0.6792, 0.6333, 0.5708, 0.6292, 0.4917, 0.6375, 0.7042, 0.6333, 0.4708, 0.7708, 0.5583, 0.7542, 0.5, 0.3833, 0.5417, 0.4417, 0.5, 0.7458, 0.375, 0.4208, 0.6583, 0.4542, 0.5583, 0.5, 0.6792, 0.4875, 0.3792, 0.5375, 0.4833, 0.6292, 0.6083, 0.5458, 0.6792, 0.625, 0.4125, 0.6292, 0.3167, 0.5583, 0.6125, 0.6125, 0.7375, 0.6125, 0.5417, 0.5417, 0.6542, 0.6, 0.5792, 0.7667, 0.5417, 0.5042, 0.4083, 0.4125, 0.4583, 0.6458, 0.7042, 0.5125, 0.6083, 0.525, 0.3542, 0.6208, 0.5292, 0.5875, 0.325, 0.6375, 0.2875, 0.4833, 0.5458, 0.6083, 0.5917, 0.4, 0.7417, 0.6958, 0.4833, 0.6792, 0.475, 0.6792, 0.5875, 0.6917, 0.475, 0.5542, 0.6167, 0.7833, 0.6125, 0.5208, 0.5583, 0.5708, 0.6333, 0.6792, 0.6292, 0.725, 0.5958, 0.5208, 0.3375, 0.3875, 0.5833, 0.3958, 0.7, 0.5167, 0.5042, 0.5625, 0.5042, 0.5625, 0.5875, 0.4042, 0.5667, 0.6, 0.575, 0.4917, 0.7375, 0.525, 0.325, 0.5333, 0.6167, 0.4708, 0.5625, 0.4833, 0.5542, 0.525, 0.5

    Avg Losses: [0.48196, 0.38293, 0.64809, 0.63741, 0.49077, 0.60653, 0.66273, 0.38232, 0.34557, 0.33786, 0.29736, 0.29568, 0.48175, 0.46401, 0.42631, 0.31803, 0.68814, 0.28315, 0.42986, 0.30316, 0.58328, 0.71571, 0.51475, 0.66945, 0.57366, 0.2522, 0.52692, 0.42819, 0.38539, 0.62895, 0.4631, 0.49213, 0.43013, 0.5892, 0.69507, 0.62799, 0.70596, 0.38837, 0.42889, 0.62651, 0.25776, 0.3473, 0.82233, 0.27526, 0.57192, 0.63084, 0.34517, 0.3787, 0.41386, 0.31513, 0.41947, 0.45023, 0.34627, 0.38999, 0.66425, 0.26189, 0.59932, 0.61769, 0.72941, 0.82716, 0.54384, 0.35755, 0.39575, 0.58415, 0.3106, 0.46085, 0.78898, 0.39925, 0.39961, 0.41193, 0.70602, 0.29887, 0.71757, 0.65548, 0.40805, 0.3696, 0.31718, 0.88587, 0.25379, 0.34058, 0.54434, 0.37237, 0.61531, 0.26922, 0.47284, 0.28939, 0.73809, 0.45363, 0.3145, 0.24237, 0.32734, 0.28541, 0.52126, 0.46532, 0.54293, 0.34269, 0.41552, 0.27459, 0.3654, 0.53991, 0.72576, 0.93968, 0.44891, 0.71731, 0.30885, 0.3954, 0.54019, 0.53909, 0.42298, 0.64235, 0.5

    Avg Accuracies: [0.9042, 0.9375, 0.8875, 0.8958, 0.8875, 0.8625, 0.8292, 0.9583, 0.9292, 0.9125, 0.9542, 0.9292, 0.925, 0.9, 0.9042, 0.9417, 0.8833, 0.9417, 0.9042, 0.9542, 0.875, 0.8583, 0.8958, 0.8417, 0.8958, 0.9542, 0.8958, 0.9167, 0.9167, 0.8833, 0.9375, 0.9, 0.9125, 0.8167, 0.8625, 0.8708, 0.8042, 0.9, 0.9083, 0.8708, 0.9625, 0.9333, 0.8208, 0.9458, 0.8792, 0.9042, 0.8958, 0.875, 0.8958, 0.95, 0.9083, 0.9292, 0.9458, 0.9417, 0.8417, 0.9458, 0.875, 0.8542, 0.875, 0.8125, 0.875, 0.8667, 0.8958, 0.8625, 0.95, 0.9083, 0.8333, 0.8875, 0.9208, 0.9083, 0.8333, 0.9292, 0.8833, 0.8708, 0.9208, 0.9333, 0.925, 0.8083, 0.9625, 0.95, 0.875, 0.9375, 0.8708, 0.9333, 0.9125, 0.9583, 0.8417, 0.8958, 0.9542, 0.9667, 0.9625, 0.8875, 0.8875, 0.9208, 0.875, 0.9458, 0.9042, 0.9458, 0.8875, 0.8833, 0.8542, 0.8208, 0.9125, 0.8625, 0.9542, 0.9292, 0.8875, 0.8792, 0.9042, 0.9042, 0.8667, 0.8375, 0.9375, 0.9, 0.8542, 0.9125, 0.9125, 0.9, 0.7917, 0.9083, 0.8917, 0.8667, 0.9042, 0.9167, 0.8875, 0.8833, 0

    Avg Losses: [0.1074, 0.06047, 0.12346, 0.16572, 0.19751, 0.15397, 0.19522, 0.0837, 0.06076, 0.10236, 0.11256, 0.08271, 0.13709, 0.10065, 0.08645, 0.10457, 0.20076, 0.06274, 0.15415, 0.12645, 0.17388, 0.20123, 0.11762, 0.34864, 0.13082, 0.03625, 0.15364, 0.1132, 0.0955, 0.20392, 0.13981, 0.16905, 0.11069, 0.25625, 0.19169, 0.25074, 0.27459, 0.12853, 0.13266, 0.16434, 0.05696, 0.07365, 0.2266, 0.064, 0.17944, 0.11122, 0.10853, 0.21967, 0.17725, 0.08947, 0.07911, 0.09606, 0.07128, 0.21202, 0.20377, 0.08675, 0.17988, 0.27651, 0.2601, 0.23553, 0.14984, 0.1217, 0.18287, 0.23356, 0.09189, 0.12574, 0.2568, 0.10465, 0.09447, 0.23183, 0.27649, 0.12257, 0.24177, 0.17097, 0.19736, 0.14328, 0.09079, 0.31043, 0.05974, 0.06421, 0.12343, 0.05246, 0.1385, 0.14781, 0.1659, 0.04048, 0.24354, 0.27874, 0.07164, 0.08183, 0.06819, 0.08019, 0.16821, 0.20635, 0.23753, 0.04492, 0.08764, 0.07426, 0.07873, 0.15216, 0.1528, 0.24838, 0.12397, 0.26659, 0.04785, 0.11172, 0.18045, 0.19225, 0.11157, 0.1721, 0.20463

    Avg Accuracies: [0.9875, 0.9875, 0.975, 0.9875, 0.9458, 0.975, 0.9708, 0.9583, 1.0, 0.975, 0.9833, 0.9875, 0.975, 0.9792, 0.9792, 0.975, 0.9375, 0.9958, 0.95, 0.9583, 0.9542, 0.95, 0.9792, 0.8875, 0.9792, 0.9958, 0.9958, 0.9875, 0.9875, 0.9667, 0.9875, 0.9583, 0.9833, 0.9333, 0.9583, 0.9458, 0.9583, 0.9792, 0.9875, 0.9667, 1.0, 0.9792, 0.9667, 0.9917, 0.9708, 0.9792, 0.9833, 0.9667, 0.9667, 0.9667, 0.9875, 0.9708, 0.9917, 0.9542, 0.9417, 0.9917, 0.975, 0.9708, 0.9458, 0.9542, 0.9583, 0.9625, 0.9458, 0.95, 0.9958, 0.9583, 0.9583, 1.0, 0.9917, 0.925, 0.9375, 0.9625, 0.9625, 0.9792, 0.975, 0.9667, 0.9792, 0.9375, 0.9875, 0.9958, 0.9875, 0.9917, 0.9875, 0.9792, 0.9583, 1.0, 0.9375, 0.9208, 0.9958, 0.9667, 0.9833, 0.9875, 0.9625, 0.9583, 0.9458, 1.0, 0.9792, 1.0, 0.9833, 0.9625, 0.9833, 0.9625, 0.975, 0.925, 0.9958, 0.9917, 0.9542, 0.9833, 0.9625, 0.9583, 0.95, 0.975, 0.9667, 0.9792, 0.9833, 0.9917, 0.975, 0.9458, 0.9375, 0.9583, 0.9917, 0.975, 0.9667, 0.9792, 0.9708, 0.975, 0.9875, 0.9

    Avg Losses: [0.031, 0.0667, 0.03066, 0.03853, 0.08108, 0.04407, 0.06208, 0.03372, 0.01364, 0.02693, 0.04314, 0.03273, 0.04891, 0.06261, 0.03742, 0.08304, 0.05574, 0.03059, 0.10272, 0.03162, 0.06426, 0.05777, 0.03432, 0.15986, 0.04476, 0.01067, 0.04388, 0.04683, 0.03168, 0.0578, 0.02028, 0.09421, 0.03294, 0.06395, 0.1063, 0.10196, 0.19524, 0.0377, 0.02042, 0.0596, 0.00923, 0.03828, 0.07761, 0.01412, 0.06926, 0.02896, 0.03351, 0.05543, 0.0568, 0.08458, 0.02304, 0.04574, 0.01196, 0.04394, 0.04412, 0.01501, 0.05284, 0.06106, 0.07912, 0.09912, 0.05127, 0.04609, 0.04659, 0.07581, 0.02377, 0.06683, 0.08279, 0.01572, 0.01738, 0.08198, 0.06707, 0.041, 0.0525, 0.05158, 0.05392, 0.04268, 0.0377, 0.16699, 0.01389, 0.01503, 0.04156, 0.01018, 0.04226, 0.02449, 0.04239, 0.00893, 0.05827, 0.1312, 0.01382, 0.03908, 0.02683, 0.01954, 0.05537, 0.06879, 0.07012, 0.01602, 0.07655, 0.01529, 0.02497, 0.1161, 0.03789, 0.06563, 0.04965, 0.10437, 0.01686, 0.02407, 0.0752, 0.04562, 0.02782, 0.0808, 0.12766, 

## Aggregate Teacher Models

In [7]:
def aggregate_counts(img):
    """Count how many teachers vote for each class on a single image.

    Args:
        img: one image, either as a 3-D tensor or as a 4-D tensor whose first
            (batch) dimension has size 1.

    Returns:
        A 1-D CPU tensor with at least 10 entries; entry `k` is the number of
        models in `teachers` whose arg-max prediction for `img` is class `k`.

    Raises:
        AssertionError: if `img` is neither 3-D nor 4-D, or is 4-D with more than one sample.
    """
    assert 3 <= img.dim() <= 4
    if img.dim() == 4:
        assert img.size(0) == 1
    else:
        # Add the missing batch dimension.
        img = img.unsqueeze(0)

    img = img.to(device)

    votes = []
    with torch.no_grad():
        for teacher in teachers:
            teacher.eval()
            votes.append(teacher(img).argmax(dim=1).view(1).cpu())

    return torch.bincount(torch.cat(votes, dim=0), minlength=10)

---

In [8]:
for model in teachers:
    model.eval()

# Teacher Test: score every teacher individually on the held-out split, then score the
# plurality vote of all teachers.
test_dataloader = data.DataLoader(test_dataset, batch_size=1024, shuffle=False, drop_last=False)
criterion = nn.CrossEntropyLoss(reduction='sum')

# Accumulated over every (teacher, sample) pair.
instance_count = 0
total_loss     = 0.
correct_count  = 0.

preds_lists_list = []
labels_list      = []

n_batches = len(test_dataloader)
_prev_str_len = 0
with torch.no_grad():
    for i, (imgs, labels) in enumerate(test_dataloader):
        labels_list.append(labels)  # CPU copy, compared against the aggregate vote below
        imgs, labels = imgs.to(device), labels.to(device)

        preds_list = []
        for j, model in enumerate(teachers):
            instance_count += imgs.size(0)

            _progress_str = "Batch {:d}/{:d} - Teacher {:d}/{:d}".format(i, n_batches-1, j, n_teachers-1)
            print(_progress_str.ljust(_prev_str_len), end='\r')
            _prev_str_len = len(_progress_str)

            outs  = model(imgs)
            preds = outs.argmax(dim=1)
            preds_list.append(preds.cpu())

            total_loss    += criterion(outs, labels).item()
            correct_count += (preds == labels).sum().item()

        preds_lists_list.append(preds_list)

# Rows are teachers, columns are test samples.
preds_tensor = torch.cat([torch.stack(batch_preds, dim=0) for batch_preds in preds_lists_list], dim=1)
# Per-sample vote histogram: one row per class, one column per test sample.
preds_counts = torch.from_numpy(
    np.apply_along_axis(lambda votes: np.bincount(votes, minlength=10), axis=0, arr=preds_tensor.numpy()))

aggregate_preds = preds_counts.argmax(dim=0)
labels          = torch.cat(labels_list, dim=0)

aggregate_acc = (aggregate_preds == labels).float().mean().item()

print('\n')
print("Average Loss:", total_loss / instance_count)
print("Average Accuracy:", correct_count / instance_count)
print()
print("Aggregate Accuracy:", aggregate_acc)

Batch 0/0 - Teacher 249/249

Average Loss: 0.7430603544921875
Average Accuracy: 0.814704

Aggregate Accuracy: 0.9639999866485596


In [None]:
Batch 0/0 - Teacher 249/249

Average Loss: 0.7504612662353516
Average Accuracy: 0.814112

Aggregate Accuracy: 0.9670000076293945

In [9]:
for model in teachers:
    model.eval()

# Collect every teacher's prediction for every sample in the student's pool.
dataloader = data.DataLoader(student_dataset, batch_size=512, shuffle=False, drop_last=False)

batches_of_preds = []

n_batches = len(dataloader)
_prev_str_len = 0
with torch.no_grad():
    for i, (imgs, _) in enumerate(dataloader):
        imgs = imgs.to(device)

        batch_of_preds = []
        for j, model in enumerate(teachers):
            _progress_str = "Batch {:d}/{:d} - Teacher {:d}/{:d}".format(i, n_batches-1, j, n_teachers-1)
            print(_progress_str.ljust(_prev_str_len), end='\r')
            _prev_str_len = len(_progress_str)

            outs  = model(imgs)
            preds = outs.argmax(dim=1)
            batch_of_preds.append(preds.cpu())

        batches_of_preds.append(batch_of_preds)

# Stack teachers along dim 0 within each batch, then join the batches along the sample axis:
# row = teacher, column = sample.
label_preds = torch.cat(
    [torch.stack(batch_of_preds, dim=0) for batch_of_preds in batches_of_preds],
    dim=1)

print()
print(label_preds)

Batch 17/17 - Teacher 249/249
tensor([[7, 2, 1,  ..., 6, 9, 0],
        [7, 2, 1,  ..., 5, 9, 0],
        [7, 2, 1,  ..., 6, 9, 0],
        ...,
        [7, 2, 1,  ..., 6, 9, 0],
        [7, 2, 1,  ..., 6, 4, 0],
        [7, 2, 1,  ..., 6, 9, 0]])


In [10]:
# Histogram the teacher votes of each sample (each column of `label_preds`):
# the result has one row per class and one column per sample.
_vote_histograms = np.apply_along_axis(np.bincount, 0, label_preds.numpy(), minlength=10)
label_counts = torch.from_numpy(_vote_histograms)
print(label_counts)

tensor([[  0,   1,   0,  ...,   0,   2, 238],
        [  0,   3, 249,  ...,   0,   0,   0],
        [  0, 225,   0,  ...,   0,   2,   4],
        ...,
        [250,   0,   0,  ...,   0,  16,   0],
        [  0,   7,   0,  ...,   0,   8,   2],
        [  0,   0,   0,  ...,   0, 206,   0]])


In [11]:
epsilon = 0.05

# Zero-centred Laplace noise with scale 1/epsilon, drawn independently for every entry of `label_counts`.
noise_dist = dists.Laplace(torch.zeros([], dtype=torch.float),
                           torch.full([], 1 / epsilon, dtype=torch.float))

_noise       = noise_dist.sample(label_counts.shape)
noisy_counts = label_counts.float() + _noise

# Noisy arg-max label per sample, compared with the noise-free plurality vote.
generated_labels = noisy_counts.argmax(dim=0)
_plurality_votes = label_counts.argmax(dim=0)

print(generated_labels)
print()
print("Noisy Accuracy Against Predictions:", (generated_labels == _plurality_votes).float().mean().item())

tensor([7, 2, 1,  ..., 6, 9, 0])

Noisy Accuracy Against Predictions: 0.9725555777549744


In [12]:
import itertools

In [13]:
def train(max_train_subset_size, n_new_labels_per_epoch, n_updates_per_epoch, lr, weight_decay, epsilon):
    """Train a fresh student on actively selected, noisily labeled samples and report test metrics.

    A new `MNISTClassifier` is trained for `max_train_subset_size // n_new_labels_per_epoch`
    epochs. Each epoch moves samples from the unlabeled pool (the indices of
    `student_dataset`) into the labeled set, labels them with the Laplace-noised arg-max of
    the teachers' vote counts (written into `mnist_testset.targets`), and then performs
    `n_updates_per_epoch` full-batch Adam updates on everything labeled so far. The first
    epoch draws its samples at random; later epochs take the samples on which the student's
    top softmax probability is lowest.

    Args:
        max_train_subset_size: total number of samples that end up labeled.
        n_new_labels_per_epoch: samples labeled per epoch (the first epoch also absorbs the
            remainder of `max_train_subset_size / n_new_labels_per_epoch`).
        n_updates_per_epoch: optimizer steps taken after each labeling round.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        epsilon: the Laplace noise added to the vote counts has scale `1 / epsilon`.

    Returns:
        dict: the printed hyper-parameter summary plus "avg_loss" and "avg_accuracy"
        measured on `test_dataset`. The same values are also printed.
    """
    spec_dict = {
        "max_size": max_train_subset_size,
        "n_new_labels": n_new_labels_per_epoch,
        "n_updates": n_updates_per_epoch,
        "lr": lr,
        "wd": weight_decay,
        "eps": epsilon
    }

    student = MNISTClassifier().to(device)

    laplace_noise = dists.Laplace(torch.zeros([], dtype=torch.float), torch.tensor(1 / epsilon, dtype=torch.float))

    # The first epoch takes the remainder so that exactly `max_train_subset_size` samples get labeled.
    init_subset_size = n_new_labels_per_epoch + (max_train_subset_size % n_new_labels_per_epoch)
    n_total_epochs   = max_train_subset_size // n_new_labels_per_epoch

    # Both subsets are views into `mnist_testset`; indices migrate from the unlabeled one to the labeled one.
    student_unlabeled_dataset    = data.Subset(mnist_testset, list(student_dataset.indices))
    student_unlabeled_dataloader = data.DataLoader(student_unlabeled_dataset, batch_size=1024, shuffle=False, drop_last=False)
    student_labeled_dataset      = data.Subset(mnist_testset, [])
    # The labeled loader yields the whole labeled set as a single batch (its batch size is
    # bumped after every labeling round). Its sampler walks the labeled subset itself and
    # does not drop short batches, so it keeps producing that batch no matter how small the
    # unlabeled pool has become.
    student_labeled_dataloader   = data.DataLoader(student_labeled_dataset,
                                                   batch_sampler=data.BatchSampler(data.SequentialSampler(student_labeled_dataset), init_subset_size, False))

    student_optimizer            = optim.Adam(student.parameters(), lr=lr, weight_decay=weight_decay)
    criterion                    = nn.CrossEntropyLoss()

    student.train()

    for i_epoch in range(n_total_epochs):
        if i_epoch == 0:
            # Nothing has been learned yet: start from a random draw.
            new_label_indices = random.sample(student_unlabeled_dataset.indices, init_subset_size)

        else:
            # Score the unlabeled pool by the student's confidence (top softmax probability).
            # NOTE(review): the student is still in train() mode here, so its BatchNorm layers
            # are not in inference mode during this scoring pass — confirm this is intended.
            max_probs_list = []

            with torch.no_grad():
                for imgs, _ in student_unlabeled_dataloader:
                    imgs = imgs.to(device)

                    outs  = student(imgs)
                    probs = outs.softmax(dim=1)

                    max_probs_list.append(probs.max(dim=1)[0].cpu())

            max_probs_tensor = torch.cat(max_probs_list, dim=0)

            # Positions (within the unlabeled subset) of the least confident samples,
            # translated to indices into `mnist_testset`.
            least_confident = max_probs_tensor.topk(n_new_labels_per_epoch, largest=False, sorted=False)[1].tolist()
            new_label_indices = [student_unlabeled_dataset.indices[position] for position in least_confident]

        for idx in new_label_indices:
            student_labeled_dataset.indices.append(idx)
            student_unlabeled_dataset.indices.remove(idx)
            # Query the teachers and release only the noisy arg-max of their vote counts.
            label_pred_counts = aggregate_counts(mnist_testset[idx][0].view(1, 1, 28, 28))
            noisy_label = (label_pred_counts.float() + laplace_noise.sample(label_pred_counts.size())).argmax(dim=0)
            # The student reads its labels through `mnist_testset`, so the noisy label is stored there.
            mnist_testset.targets[idx] = noisy_label

        student_labeled_dataloader.batch_sampler.batch_size = len(student_labeled_dataset)

        for i_update in range(n_updates_per_epoch):
            imgs, labels = next(iter(student_labeled_dataloader))

            imgs   = imgs.to(device)
            labels = labels.to(device)

            outs = student(imgs)

            student_optimizer.zero_grad()
            loss = criterion(outs, labels)
            loss.backward()
            student_optimizer.step()

    student.eval()

    # Sum-reduced so that dividing by the sample count gives the exact per-sample average.
    criterion = nn.CrossEntropyLoss(reduction='sum')

    instance_count = 0
    total_loss     = 0.
    correct_count  = 0.

    # `batch_size` is the notebook-level value defined in the teacher-training cell.
    test_dataloader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

    for imgs, labels in test_dataloader:
        instance_count += imgs.size(0)

        imgs   = imgs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            outs  = student(imgs)

        total_loss += criterion(outs, labels).item()

        preds = outs.argmax(dim=1)

        correct_count += (preds == labels).sum().item()

    avg_loss     = total_loss / instance_count
    avg_accuracy = correct_count / instance_count

    print(spec_dict)
    print("    Average Loss:", avg_loss)
    print("    Average Accuracy:", avg_accuracy)
    print()

    return {**spec_dict, "avg_loss": avg_loss, "avg_accuracy": avg_accuracy}

In [None]:
# Candidate values for every hyper-parameter of train().
grid_spec = {
        "max_train_subset_size":  [100, 150, 200],
        "n_new_labels_per_epoch": [3, 4],
        "n_updates_per_epoch":    [10, 15, 20, 25, 30],
        "lr":                     [5e-3, 1e-2, 2e-2],
        "weight_decay":           [0],
        "epsilon":                [0.05]
}

# train() receives each combination positionally, so the axes are listed in its parameter order.
_grid_axes = ("max_train_subset_size",
              "n_new_labels_per_epoch",
              "n_updates_per_epoch",
              "lr",
              "weight_decay",
              "epsilon")
spec_iter = itertools.product(*(grid_spec[axis] for axis in _grid_axes))

n_specs = np.prod([len(grid_spec[axis]) for axis in grid_spec])

for i, spec in enumerate(spec_iter):
    print("{:d}/{:d}".format(i, n_specs-1), end='\r')
    train(*spec)

{'max_size': 100, 'n_new_labels': 3, 'n_updates': 10, 'lr': 0.005, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.015296291589737
    Average Accuracy: 0.704

{'max_size': 100, 'n_new_labels': 3, 'n_updates': 10, 'lr': 0.01, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.2261069943904876
    Average Accuracy: 0.722

{'max_size': 100, 'n_new_labels': 3, 'n_updates': 10, 'lr': 0.02, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.0697530851364137
    Average Accuracy: 0.76

{'max_size': 100, 'n_new_labels': 3, 'n_updates': 15, 'lr': 0.005, 'wd': 0, 'eps': 0.05}
    Average Loss: 0.6763886561393738
    Average Accuracy: 0.8

{'max_size': 100, 'n_new_labels': 3, 'n_updates': 15, 'lr': 0.01, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.2667924003601074
    Average Accuracy: 0.715

{'max_size': 100, 'n_new_labels': 3, 'n_updates': 15, 'lr': 0.02, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.5068216601610183
    Average Accuracy: 0.746

{'max_size': 100, 'n_new_labels': 3, 'n_updates': 20, 'lr': 0.005, 'wd': 0, 'e

{'max_size': 150, 'n_new_labels': 4, 'n_updates': 20, 'lr': 0.02, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.0094034887552261
    Average Accuracy: 0.853

{'max_size': 150, 'n_new_labels': 4, 'n_updates': 25, 'lr': 0.005, 'wd': 0, 'eps': 0.05}
    Average Loss: 0.9413001240491867
    Average Accuracy: 0.749

{'max_size': 150, 'n_new_labels': 4, 'n_updates': 25, 'lr': 0.01, 'wd': 0, 'eps': 0.05}
    Average Loss: 0.712588849902153
    Average Accuracy: 0.839

{'max_size': 150, 'n_new_labels': 4, 'n_updates': 25, 'lr': 0.02, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.0195666980743407
    Average Accuracy: 0.849

{'max_size': 150, 'n_new_labels': 4, 'n_updates': 30, 'lr': 0.005, 'wd': 0, 'eps': 0.05}
    Average Loss: 0.8391722503900528
    Average Accuracy: 0.8

{'max_size': 150, 'n_new_labels': 4, 'n_updates': 30, 'lr': 0.01, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.1345086097717285
    Average Accuracy: 0.819

{'max_size': 150, 'n_new_labels': 4, 'n_updates': 30, 'lr': 0.02, 'wd': 0, 'e

---

In [38]:
# Rebind `student` to a newly constructed classifier on `device`, discarding any earlier binding.
student = MNISTClassifier().to(device)

In [39]:
max_train_subset_size  = 100
n_new_labels_per_epoch = 2
n_updates_per_epoch    = 10
lr                     = 5e-3
weight_decay           = 1e-3
epsilon                = 0.05

# Noise added to the teachers' vote counts before the arg-max: zero-centred Laplace with scale 1/epsilon.
laplace_noise = dists.Laplace(torch.zeros([], dtype=torch.float), torch.tensor(1 / epsilon, dtype=torch.float))

# The first epoch takes the remainder so that exactly `max_train_subset_size` samples get labeled.
init_subset_size = n_new_labels_per_epoch + (max_train_subset_size % n_new_labels_per_epoch)
n_total_epochs   = max_train_subset_size // n_new_labels_per_epoch

# Both subsets are views into `mnist_testset`; indices migrate from the unlabeled one to the labeled one.
student_unlabeled_dataset    = data.Subset(mnist_testset, list(student_dataset.indices))
student_unlabeled_dataloader = data.DataLoader(student_unlabeled_dataset, batch_size=1024, shuffle=False, drop_last=False)
student_labeled_dataset      = data.Subset(mnist_testset, [])
# The labeled loader yields the whole labeled set as a single batch (its batch size is bumped
# after every labeling round). Its sampler walks the labeled subset itself and does not drop
# short batches, so it keeps producing that batch no matter how small the unlabeled pool is.
student_labeled_dataloader   = data.DataLoader(student_labeled_dataset,
                                               batch_sampler=data.BatchSampler(data.SequentialSampler(student_labeled_dataset), init_subset_size, False))

student_optimizer            = optim.Adam(student.parameters(), lr=lr, weight_decay=weight_decay)
criterion                    = nn.CrossEntropyLoss()

student.train()

# One entry per teacher query: the raw per-class vote counts and the released noisy label.
teacher_preds = []
noisy_labels  = []
# Keyed by epoch. The 'avg_*' entries hold the list of per-update values of that epoch.
student_train_history = {'n_labels': {}, 'avg_losses':{}, 'avg_accuracies': {}}
for i_epoch in range(n_total_epochs):
    if i_epoch == 0:
        # Nothing has been learned yet: start from a random draw.
        new_label_indices = random.sample(student_unlabeled_dataset.indices, init_subset_size)

    else:
        # Score the unlabeled pool by the student's confidence (top softmax probability).
        # NOTE(review): the student is still in train() mode here, so its BatchNorm layers
        # are not in inference mode during this scoring pass — confirm this is intended.
        max_probs_list = []

        with torch.no_grad():
            for imgs, _ in student_unlabeled_dataloader:
                imgs = imgs.to(device)

                outs  = student(imgs)
                probs = outs.softmax(dim=1)

                max_probs_list.append(probs.max(dim=1)[0].cpu())

        max_probs_tensor = torch.cat(max_probs_list, dim=0)

        # Positions (within the unlabeled subset) of the least confident samples,
        # translated to indices into `mnist_testset`.
        least_confident = max_probs_tensor.topk(n_new_labels_per_epoch, largest=False, sorted=False)[1].tolist()
        new_label_indices = [student_unlabeled_dataset.indices[position] for position in least_confident]

    for idx in new_label_indices:
        student_labeled_dataset.indices.append(idx)
        student_unlabeled_dataset.indices.remove(idx)
        # Query the teachers and release only the noisy arg-max of their vote counts.
        label_pred_counts = aggregate_counts(mnist_testset[idx][0].view(1, 1, 28, 28))
        noisy_label = (label_pred_counts.float() + laplace_noise.sample(label_pred_counts.size())).argmax(dim=0)
        # The student reads its labels through `mnist_testset`, so the noisy label is stored there.
        mnist_testset.targets[idx] = noisy_label
        teacher_preds.append(label_pred_counts)
        noisy_labels.append(noisy_label)

    student_labeled_dataloader.batch_sampler.batch_size = len(student_labeled_dataset)

    print("Epoch {:d}/{:d} - Number of Labeled Data: {:d}".format(i_epoch, n_total_epochs-1, len(student_labeled_dataset)))

    student_train_history['n_labels'][i_epoch] = init_subset_size + (i_epoch * n_new_labels_per_epoch)
    losses = []
    accuracies = []
    for i_update in range(n_updates_per_epoch):
        imgs, labels = next(iter(student_labeled_dataloader))

        imgs   = imgs.to(device)
        labels = labels.to(device)

        outs  = student(imgs)
        preds = torch.argmax(outs, dim=1)

        student_optimizer.zero_grad()
        loss = criterion(outs, labels)
        loss.backward()
        student_optimizer.step()

        # Accuracy of the scores computed before this step's parameter update.
        accuracy = (preds == labels).float().mean().item()

        losses.append(loss.item())
        accuracies.append(accuracy)

        print("    Update {:d} - Loss = {:.6f}, Accuracy = {:.4}".format(i_update, loss.item(), accuracy))

    student_train_history['avg_losses'][i_epoch]     = losses
    student_train_history['avg_accuracies'][i_epoch] = accuracies

# `teacher_preds` becomes a vote-count matrix: one row per class, one column per labeled sample.
teacher_preds = torch.stack(teacher_preds, dim=1)
noisy_labels  = torch.tensor(noisy_labels)

Epoch 0/49 - Number of Labeled Data: 2
    Update 0 - Loss = 2.338653, Accuracy = 0.0
    Update 1 - Loss = 1.552873, Accuracy = 1.0
    Update 2 - Loss = 1.101398, Accuracy = 1.0
    Update 3 - Loss = 0.746303, Accuracy = 1.0
    Update 4 - Loss = 0.487096, Accuracy = 1.0
    Update 5 - Loss = 0.314154, Accuracy = 1.0
    Update 6 - Loss = 0.207640, Accuracy = 1.0
    Update 7 - Loss = 0.149181, Accuracy = 1.0
    Update 8 - Loss = 0.104536, Accuracy = 1.0
    Update 9 - Loss = 0.076262, Accuracy = 1.0
Epoch 1/49 - Number of Labeled Data: 4
    Update 0 - Loss = 1.301574, Accuracy = 0.5
    Update 1 - Loss = 1.089545, Accuracy = 0.5
    Update 2 - Loss = 0.858703, Accuracy = 1.0
    Update 3 - Loss = 0.648914, Accuracy = 1.0
    Update 4 - Loss = 0.468678, Accuracy = 1.0
    Update 5 - Loss = 0.327084, Accuracy = 1.0
    Update 6 - Loss = 0.225381, Accuracy = 1.0
    Update 7 - Loss = 0.156469, Accuracy = 1.0
    Update 8 - Loss = 0.112128, Accuracy = 1.0
    Update 9 - Loss = 0.08371

    Update 9 - Loss = 0.400655, Accuracy = 1.0
Epoch 16/49 - Number of Labeled Data: 34
    Update 0 - Loss = 0.489888, Accuracy = 0.9706
    Update 1 - Loss = 0.473369, Accuracy = 1.0
    Update 2 - Loss = 0.457185, Accuracy = 1.0
    Update 3 - Loss = 0.446634, Accuracy = 1.0
    Update 4 - Loss = 0.438530, Accuracy = 1.0
    Update 5 - Loss = 0.433696, Accuracy = 1.0
    Update 6 - Loss = 0.427910, Accuracy = 1.0
    Update 7 - Loss = 0.420705, Accuracy = 1.0
    Update 8 - Loss = 0.414927, Accuracy = 1.0
    Update 9 - Loss = 0.408941, Accuracy = 1.0
Epoch 17/49 - Number of Labeled Data: 36
    Update 0 - Loss = 0.490056, Accuracy = 0.9444
    Update 1 - Loss = 0.474666, Accuracy = 1.0
    Update 2 - Loss = 0.459653, Accuracy = 1.0
    Update 3 - Loss = 0.450884, Accuracy = 1.0
    Update 4 - Loss = 0.442996, Accuracy = 1.0
    Update 5 - Loss = 0.437017, Accuracy = 1.0
    Update 6 - Loss = 0.429970, Accuracy = 1.0
    Update 7 - Loss = 0.425239, Accuracy = 1.0
    Update 8 - Loss

    Update 8 - Loss = 0.430307, Accuracy = 1.0
    Update 9 - Loss = 0.431741, Accuracy = 1.0
Epoch 32/49 - Number of Labeled Data: 66
    Update 0 - Loss = 0.467694, Accuracy = 0.9848
    Update 1 - Loss = 0.454931, Accuracy = 1.0
    Update 2 - Loss = 0.447458, Accuracy = 1.0
    Update 3 - Loss = 0.446503, Accuracy = 1.0
    Update 4 - Loss = 0.444823, Accuracy = 1.0
    Update 5 - Loss = 0.439646, Accuracy = 1.0
    Update 6 - Loss = 0.432144, Accuracy = 1.0
    Update 7 - Loss = 0.430708, Accuracy = 1.0
    Update 8 - Loss = 0.430268, Accuracy = 1.0
    Update 9 - Loss = 0.422539, Accuracy = 1.0
Epoch 33/49 - Number of Labeled Data: 68
    Update 0 - Loss = 0.467577, Accuracy = 0.9706
    Update 1 - Loss = 0.458398, Accuracy = 1.0
    Update 2 - Loss = 0.450021, Accuracy = 1.0
    Update 3 - Loss = 0.446302, Accuracy = 1.0
    Update 4 - Loss = 0.448077, Accuracy = 1.0
    Update 5 - Loss = 0.441004, Accuracy = 1.0
    Update 6 - Loss = 0.433090, Accuracy = 1.0
    Update 7 - Loss

    Update 7 - Loss = 0.405405, Accuracy = 1.0
    Update 8 - Loss = 0.407001, Accuracy = 1.0
    Update 9 - Loss = 0.405819, Accuracy = 1.0
Epoch 48/49 - Number of Labeled Data: 98
    Update 0 - Loss = 0.436305, Accuracy = 0.9796
    Update 1 - Loss = 0.421083, Accuracy = 0.9898
    Update 2 - Loss = 0.427107, Accuracy = 0.9898
    Update 3 - Loss = 0.422049, Accuracy = 0.9898
    Update 4 - Loss = 0.415886, Accuracy = 0.9898
    Update 5 - Loss = 0.416208, Accuracy = 0.9898
    Update 6 - Loss = 0.416877, Accuracy = 0.9898
    Update 7 - Loss = 0.414058, Accuracy = 0.9898
    Update 8 - Loss = 0.407894, Accuracy = 0.9898
    Update 9 - Loss = 0.408899, Accuracy = 0.9898
Epoch 49/49 - Number of Labeled Data: 100
    Update 0 - Loss = 0.436317, Accuracy = 0.97
    Update 1 - Loss = 0.426153, Accuracy = 0.99
    Update 2 - Loss = 0.429137, Accuracy = 0.99
    Update 3 - Loss = 0.424593, Accuracy = 0.99
    Update 4 - Loss = 0.423200, Accuracy = 0.99
    Update 5 - Loss = 0.423244, Accu

In [41]:
# Score `student` on `test_dataset`: accumulate the summed cross-entropy and the
# number of correct predictions, then normalise both by the instance count.
student.eval()

# reduction='sum' lets per-batch losses be added up and divided once at the end,
# which stays exact even when the last batch is smaller than the others.
criterion = nn.CrossEntropyLoss(reduction='sum')

test_dataloader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
n_batches = len(test_dataloader)

instance_count, total_loss, correct_count = 0, 0., 0.

with torch.no_grad():
    for i, (imgs, labels) in enumerate(test_dataloader):
        print("Batch {:d}/{:d}".format(i, n_batches-1), end='\r')

        imgs, labels = imgs.to(device), labels.to(device)

        outs  = student(imgs)
        preds = outs.argmax(dim=1)

        instance_count += imgs.size(0)
        total_loss     += criterion(outs, labels).item()
        correct_count  += (preds == labels).sum().item()

print()
print("Average Loss:", total_loss / instance_count)
print("Average Accuracy:", correct_count / instance_count)

Batch 33/33
Average Loss: 1.386195161819458
Average Accuracy: 0.606


In [23]:
# Run the PATE privacy analysis; the returned pair is unpacked as
# (data-dependent epsilon, data-independent epsilon).
analysis_result = pate.perform_analysis(teacher_preds, noisy_labels, epsilon)
data_dep_eps, data_indep_eps = analysis_result

for caption, eps_value in (("Data Independent Epsilon:", data_indep_eps),
                           ("Data Dependent Epsilon:", data_dep_eps)):
    print(caption, eps_value)

Data Independent Epsilon: 5.302585092994046
Data Dependent Epsilon: 5.30258509299405


### Predict Labels Using Teacher Models

In [9]:
# Query every teacher on every image of `student_dataset` and gather the
# predicted class indices in `label_preds`: one row per teacher, one column
# per image (batches are concatenated in loader order, which is unshuffled).
for model in teachers:
    model.eval()

dataloader = data.DataLoader(student_dataset, batch_size=512, shuffle=False, drop_last=False)
n_batches  = len(dataloader)

batches_of_preds = []   # per batch: a list holding one prediction tensor per teacher
_prev_str_len    = 0    # length of the previous status line, used to blank leftovers

with torch.no_grad():
    for i, (imgs, _) in enumerate(dataloader):
        imgs = imgs.to(device)

        batch_of_preds = []
        for j, model in enumerate(teachers):
            # Single-line progress display: pad with spaces so a shorter status
            # fully overwrites the longer one printed before it.
            status  = "Batch {:d}/{:d} - Teacher {:d}/{:d}".format(i, n_batches-1, j, n_teachers-1)
            padding = ' ' * (_prev_str_len - len(status))
            print(status + padding, end='\r')
            _prev_str_len = len(status)

            # Keep results on the CPU so they do not pile up in device memory.
            batch_of_preds.append(model(imgs).argmax(dim=1).cpu())

        batches_of_preds.append(batch_of_preds)

# Stack teachers along dim 0 within a batch, then join batches along dim 1.
stacked_batches = [torch.stack(batch_of_preds, dim=0) for batch_of_preds in batches_of_preds]
label_preds = torch.cat(stacked_batches, dim=1)

print()
print(label_preds)

Batch 17/17 - Teacher 249/249
tensor([[7, 0, 1,  ..., 6, 9, 0],
        [7, 0, 1,  ..., 6, 9, 0],
        [7, 2, 1,  ..., 6, 9, 0],
        ...,
        [7, 2, 1,  ..., 6, 9, 0],
        [7, 2, 1,  ..., 6, 9, 0],
        [7, 0, 1,  ..., 6, 9, 0]])


### Get Label Counts for Each Image

In [10]:
# For each image (a column of `label_preds`) tally how many teachers voted for
# each class, padding the tally to at least 10 bins. The tallies are laid out
# as columns, so row k of `label_counts` holds the votes for class k.
per_image_votes = [np.bincount(votes, minlength=10) for votes in label_preds.numpy().T]
label_counts = torch.from_numpy(np.stack(per_image_votes, axis=1))
print(label_counts)

tensor([[  0,   4,   0,  ...,   0,   1, 232],
        [  0,   5, 248,  ...,   0,   0,   0],
        [  0, 213,   1,  ...,   0,   1,   7],
        ...,
        [248,   0,   0,  ...,   0,  26,   0],
        [  0,   7,   0,  ...,   0,   8,   0],
        [  0,   0,   0,  ...,   0, 192,   0]])


### Obtain Labels from Noisy Counts with a Certain $\epsilon$ Value

In [20]:
# Noisy vote aggregation: perturb every class count with Laplace(0, 1/epsilon)
# noise and label each image with the class whose noisy count is largest.
epsilon = 0.05

noise_dist = dists.Laplace(loc=torch.tensor(0., dtype=torch.float),
                           scale=torch.tensor(1 / epsilon, dtype=torch.float))

# One independent noise draw per (class, image) cell of the count table.
noise        = noise_dist.sample([10, label_counts.size(1)])
noisy_counts = label_counts.float() + noise

generated_labels = noisy_counts.argmax(dim=0)

# Fraction of images whose noisy winner matches the noise-free plurality vote.
plurality_labels = label_counts.argmax(dim=0)
noisy_agreement  = (generated_labels == plurality_labels).float().mean().item()

print(generated_labels)
print()
print("Noisy Accuracy Against Predictions:", noisy_agreement)

tensor([7, 2, 1,  ..., 6, 9, 0])

Noisy Accuracy Against Predictions: 0.9427777528762817


### Perform PATE Analysis to Check Information Leak

In [25]:
# PATE analysis restricted to the first 100 images (columns of `label_preds`)
# and their generated labels.
# NOTE(review): noisy labels are generated for every image in the cell above,
# so the value shown here covers only a subset of the queries - TODO confirm
# that analysing just this subset is intentional.
analysed_queries = slice(None, 100)
pate.perform_analysis(label_preds[:, analysed_queries], generated_labels[analysed_queries], epsilon, delta=1e-05)

(5.30258509299405, 5.302585092994046)

### Assign the Generated Labels to the DP Training Set

In [None]:
# Overwrite the stored targets at the positions listed in `dp_dataset.indices`
# with the generated labels (in-place edit of `mnist_testset.targets`).
relabel_positions = dp_dataset.indices
mnist_testset.targets[relabel_positions] = generated_labels

## Train the DP Model

In [None]:
# Fit `dp_model` on the batches from `dp_dataloader` with Adam and
# cross-entropy, recording the instance-weighted average loss and accuracy of
# every epoch in `dp_train_history`.
lr       = 3e-3
n_epochs = 10

criterion    = nn.CrossEntropyLoss()
dp_optimizer = optim.Adam(dp_model.parameters(), lr=lr)

dp_model.train()

dp_train_history = {'avg_losses':{}, 'avg_accuracies': {}}
n_batches = len(dp_dataloader)

for i_epoch in range(n_epochs):
    instance_count, total_loss, correct_count = 0, 0., 0.
    _prev_str_len = 0

    for i, (imgs, labels) in enumerate(dp_dataloader):
        # Single-line progress display, padded so shorter text erases longer text.
        status = "Epoch {:d}/{:d}: ({:d}/{:d})".format(i_epoch, n_epochs-1, i, n_batches-1)
        print(status + ' ' * (_prev_str_len - len(status)), end='\r')
        _prev_str_len = len(status)

        batch_len    = imgs.size(0)
        imgs, labels = imgs.to(device), labels.to(device)

        dp_optimizer.zero_grad()
        outs = dp_model(imgs)
        loss = criterion(outs, labels)
        loss.backward()
        dp_optimizer.step()

        # `outs` was computed before the step, so this is pre-update accuracy.
        preds = outs.argmax(dim=1)

        instance_count += batch_len
        # The criterion returns a batch mean; re-weight it by the batch size so
        # the epoch average is per instance rather than per batch.
        total_loss     += loss.item() * batch_len
        correct_count  += (preds == labels).sum().item()

    avg_loss     = total_loss / instance_count
    avg_accuracy = correct_count / instance_count

    print()
    print("    Avg Loss: {:.6f}".format(avg_loss))
    print("    Avg Accuracy: {:.4f}".format(avg_accuracy))
    print()

    dp_train_history['avg_losses'][i_epoch]     = avg_loss
    dp_train_history['avg_accuracies'][i_epoch] = avg_accuracy

## Evaluate the Resulting Model on the Test Set

In [None]:
# Score `dp_model` on `test_dataset`: accumulate the summed cross-entropy and
# the number of correct predictions, then normalise both by the instance count.
dp_model.eval()

# reduction='sum' lets per-batch losses be added up and divided once at the end.
criterion = nn.CrossEntropyLoss(reduction='sum')

instance_count = 0
total_loss     = 0.
correct_count  = 0.

# Build the loader once so the progress denominator and the loop both refer to
# the test data (not to the loader used for the teacher predictions).
test_dataloader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

n_batches = len(test_dataloader)
for i, (imgs, labels) in enumerate(test_dataloader):
    print("Batch {:d}/{:d}".format(i, n_batches-1), end='\r')

    instance_count += imgs.size(0)

    # Both tensors must live on the model's device for the loss and the
    # prediction comparison below.
    imgs   = imgs.to(device)
    labels = labels.to(device)

    with torch.no_grad():
        # Evaluate the DP model itself, not the `model` variable left over
        # from the teacher loops.
        outs = dp_model(imgs)

    total_loss += criterion(outs, labels).item()

    preds = outs.argmax(dim=1)

    correct_count += (preds == labels).sum().item()

print()
print("Average Loss:", total_loss / instance_count)
print("Average Accuracy:", correct_count / instance_count)

1. Prepare Data
    1. Split the training dataset into `n + 1` smaller datasets where `n` is the number of teacher models
    2. Define a Dataset class and a DataLoader that can give batches of data for all `n` teacher datasets
2. Define Model(s)
    1. A simple ConvNet for both the main model and all teacher models
    2. If too slow: custom `nn.Module` that can process `n` batches at once for all `n` teachers
3. Train Teachers
4. Label the Unlabeled Training Dataset in a Differentially Private Manner
    1. Generate raw labels
    2. PATE analysis to find a proper `epsilon` value
    3. Add proper noise to the label counts
    4. For each image, take the label with the highest noisy count
5. Train the Main Model on the Training Dataset with Generated Labels
6. Test on the Test Dataset