# Section Project:

For the final project for this section, you're going to train a DP model using this PATE method on the MNIST dataset, provided below.

## Import Modules

In [1]:
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
import torch.utils.data as data

import torchvision.datasets as datasets
import torchvision.transforms as transforms

from syft.frameworks.torch.differential_privacy import pate

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device



device(type='cuda', index=0)

## Prepare Data

### Load the MNIST Training & Test Datasets

In [2]:
mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=transforms.ToTensor())
mnist_testset  = datasets.MNIST(root='../data', train=False, download=True, transform=transforms.ToTensor())
mnist_testset.true_targets = mnist_testset.targets.clone() # data points that are considered "unlabeled" will be re-labeled by teachers later

print("Training Set Size:", len(mnist_trainset))
print("Test Set Size:", len(mnist_testset))
print()
print("Min Data Value:", torch.min(mnist_trainset.data.min(), mnist_testset.data.min()))
print("Max Data Value:", torch.max(mnist_trainset.data.max(), mnist_testset.data.max()))
print()
print("Train Label Counts:", {label.item():count.item() for label, count in zip(*torch.unique(mnist_trainset.targets, return_counts=True))})
print("Test Label Counts:", {label.item():count.item() for label, count in zip(*torch.unique(mnist_testset.targets, return_counts=True))})

Training Set Size: 60000
Test Set Size: 10000

Min Data Value: tensor(0, dtype=torch.uint8)
Max Data Value: tensor(255, dtype=torch.uint8)

Train Label Counts: {0: 5923, 1: 6742, 2: 5958, 3: 6131, 4: 5842, 5: 5421, 6: 5918, 7: 6265, 8: 5851, 9: 5949}
Test Label Counts: {0: 980, 1: 1135, 2: 1032, 3: 1010, 4: 982, 5: 892, 6: 958, 7: 1028, 8: 974, 9: 1009}


### Create Training & Test Datasets

In [3]:
n_teachers = 250

_teacher_dataset_len = len(mnist_trainset) // n_teachers

teacher_datasets = [data.Subset(mnist_trainset, list(range(i*_teacher_dataset_len, (i+1)*_teacher_dataset_len))) for i in range(n_teachers)]
student_dataset  = data.Subset(mnist_testset, list(range(int(len(mnist_testset) * 0.9))))
test_dataset     = data.Subset(mnist_testset, list(range(int(len(mnist_testset) * 0.9), len(mnist_testset))))

print("Size of each teacher dataset:", _teacher_dataset_len)
print("Size of the dataset available to the student:", len(student_dataset))
print("Size of the test dataset:", len(test_dataset))

Size of each teacher dataset: 240
Size of the dataset available to the student: 9000
Size of the test dataset: 1000


## Build Models

### Define Classifier

In [4]:
class MNISTClassifier(nn.Module):
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        
        # 1x28x28
        self.bn0        = nn.BatchNorm2d(1)
        self.conv0      = nn.Conv2d(1, 5, 3, padding=1)
        self.bn1        = nn.BatchNorm2d(5)
        self.maxpool0   = nn.MaxPool2d(2)
        # 5x14x14
        self.conv1      = nn.Conv2d(5, 5, 3, padding=1)
        self.bn2        = nn.BatchNorm2d(5)
        self.maxpool1   = nn.MaxPool2d(2)
        # 5x 7x 7
        self.conv2      = nn.Conv2d(5, 5, 3, padding=1)
        self.bn3        = nn.BatchNorm2d(5)
        self.maxpool2   = nn.MaxPool2d(2, padding=1)
        # 5x 4x 4 = 80
        self.fc         = nn.Linear(80, 10)

        self.activation = nn.ReLU()
        
    def forward(self, x):
        x = self.conv0(self.bn0(x))
        x = self.bn1(x)
        x = self.activation(x)
        x = self.maxpool0(x)
        x = self.conv1(x)
        x = self.bn2(x)
        x = self.activation(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.bn3(x)
        x = self.activation(x)
        x = self.maxpool2(x)
        x = self.fc(x.view(-1, 80))
        
        return x

### Create Teacher & DF Models

In [5]:
teachers      = [MNISTClassifier().to(device) for _ in range(n_teachers)]
student       = MNISTClassifier().to(device)

## Train Teachers

### Training Loop

In [6]:
lr         = 3e-2
n_epochs   = 10
batch_size = 30

teacher_dataloaders = [data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) for dataset in teacher_datasets]
teacher_optimizers  = [optim.Adam(model.parameters(), lr=lr) for model in teachers]
criterion           = nn.CrossEntropyLoss()

for model in teachers:
    model.train()

teachers_train_history = {'avg_losses':{}, 'avg_accuracies': {}}
for i_epoch in range(n_epochs):
    avg_losses      = []
    avg_accuracies  = []

    for i_model in range(n_teachers):
        instance_count = 0
        total_loss     = 0.
        correct_count  = 0

        model      = teachers[i_model]
        dataloader = teacher_dataloaders[i_model]
        optimizer  = teacher_optimizers[i_model]

        n_batches = len(dataloader)
        _prev_str_len = 0
        for i, (imgs, labels) in enumerate(dataloader):
            _batch_str = "Teacher {:d}/{:d}: ({:d}/{:d})".format(i_model, n_teachers-1, i, n_batches-1)
            print(_batch_str + ' ' * (_prev_str_len - len(_batch_str)), end='\r')
            _prev_str_len = len(_batch_str)

            instance_count += imgs.size(0)

            imgs   = imgs.to(device)
            labels = labels.to(device)
            
            outs  = model(imgs)
            preds = torch.argmax(outs, dim=1)
            
            optimizer.zero_grad()
            loss = criterion(outs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * imgs.size(0)
            
            correct_count += (preds == labels).sum().item()

        avg_losses.append(total_loss / instance_count)
        avg_accuracies.append(correct_count / instance_count)

    _epoch_str = "Epoch {:d}/{:d}".format(i_epoch, n_epochs-1)
    _epoch_str += ' ' * (_prev_str_len - len(_epoch_str))
    print(_epoch_str)
    print("    Avg Losses:", [round(avg_loss, 5) for avg_loss in avg_losses])
    print("    Avg Accuracies:", [round(avg_acc, 4) for avg_acc in avg_accuracies])
    print()

    teachers_train_history['avg_losses'][i_epoch]     = avg_losses
    teachers_train_history['avg_accuracies'][i_epoch] = avg_accuracies

Epoch 0/9             
    Avg Losses: [2.21073, 2.23918, 2.28086, 2.15874, 2.23866, 2.31621, 2.30713, 2.18633, 2.17418, 2.16536, 2.21159, 2.13759, 2.29504, 2.13437, 2.26481, 2.26822, 2.36746, 1.982, 2.2323, 2.13046, 2.18558, 2.3023, 2.18393, 2.23326, 2.22305, 2.19719, 2.31659, 2.25757, 2.18945, 2.30273, 2.24863, 2.29796, 2.19766, 2.19736, 2.23516, 2.10273, 2.25671, 2.16019, 2.15332, 2.15617, 2.16629, 2.16903, 2.29986, 2.15375, 2.35322, 2.26402, 2.16145, 2.11861, 1.9837, 2.22307, 2.18739, 2.1666, 2.20664, 2.20077, 2.16129, 2.13993, 2.22303, 2.21677, 2.25294, 2.318, 2.30551, 2.11383, 2.02318, 2.24366, 2.12847, 2.18998, 2.38609, 2.24913, 2.11516, 2.1852, 2.28901, 2.16478, 2.27972, 2.20896, 2.18185, 2.15838, 2.26615, 2.29928, 1.90267, 2.15768, 2.23865, 2.09571, 2.17137, 2.01474, 2.18474, 2.08203, 2.26916, 2.14139, 2.13376, 2.00431, 2.27429, 2.25236, 2.2412, 2.15806, 2.06265, 2.15356, 2.22228, 2.06031, 2.17198, 2.18582, 2.32406, 2.32279, 2.16512, 2.31682, 2.1459, 2.20135, 2.27286, 2.17005,

    Avg Accuracies: [0.5667, 0.6292, 0.4625, 0.6042, 0.5083, 0.3833, 0.4833, 0.5375, 0.6708, 0.6333, 0.5708, 0.6292, 0.4917, 0.6375, 0.7042, 0.6333, 0.4708, 0.7708, 0.5583, 0.7542, 0.5, 0.3833, 0.5417, 0.4417, 0.5, 0.7458, 0.375, 0.4208, 0.6583, 0.4542, 0.5583, 0.5, 0.6792, 0.4875, 0.3792, 0.5375, 0.4833, 0.6292, 0.6083, 0.5458, 0.675, 0.625, 0.4042, 0.6292, 0.3167, 0.5583, 0.6125, 0.6125, 0.7375, 0.6125, 0.5417, 0.5417, 0.6542, 0.6, 0.5792, 0.7667, 0.5417, 0.5042, 0.4083, 0.4125, 0.4583, 0.6542, 0.7042, 0.5125, 0.6083, 0.525, 0.3542, 0.6208, 0.5292, 0.5875, 0.325, 0.6375, 0.2875, 0.4833, 0.5458, 0.6083, 0.5917, 0.4, 0.7417, 0.6958, 0.4833, 0.6792, 0.475, 0.6792, 0.5875, 0.6917, 0.475, 0.5542, 0.6167, 0.7833, 0.6125, 0.5208, 0.5583, 0.5708, 0.6333, 0.6792, 0.6292, 0.725, 0.5958, 0.5208, 0.3375, 0.3917, 0.5833, 0.3958, 0.7, 0.5167, 0.5042, 0.5625, 0.5042, 0.5625, 0.5875, 0.4042, 0.5667, 0.6, 0.575, 0.4917, 0.7375, 0.525, 0.325, 0.5333, 0.6167, 0.4708, 0.5625, 0.4833, 0.5542, 0.525, 0.52

    Avg Losses: [0.48196, 0.39495, 0.64809, 0.63741, 0.49076, 0.60653, 0.66273, 0.3823, 0.34925, 0.33786, 0.29736, 0.29504, 0.48175, 0.46401, 0.42631, 0.31803, 0.68814, 0.28315, 0.42986, 0.30316, 0.58328, 0.71571, 0.51409, 0.66945, 0.57366, 0.25209, 0.52692, 0.42819, 0.38539, 0.62895, 0.4631, 0.4921, 0.43013, 0.5892, 0.69549, 0.62799, 0.70596, 0.38762, 0.43073, 0.6265, 0.25464, 0.34913, 0.82218, 0.27526, 0.57192, 0.63084, 0.34517, 0.3787, 0.4145, 0.31513, 0.41947, 0.45023, 0.34627, 0.39547, 0.66425, 0.26189, 0.59932, 0.61769, 0.72941, 0.82716, 0.54384, 0.36854, 0.39575, 0.58415, 0.3106, 0.4627, 0.78898, 0.39925, 0.39961, 0.41193, 0.70602, 0.29887, 0.71757, 0.65548, 0.40805, 0.3696, 0.31718, 0.88586, 0.25379, 0.34058, 0.54434, 0.37237, 0.61531, 0.26884, 0.47284, 0.28939, 0.73809, 0.45363, 0.31637, 0.24237, 0.32735, 0.28513, 0.52126, 0.46653, 0.54293, 0.34268, 0.41552, 0.27387, 0.3654, 0.53991, 0.7238, 0.93889, 0.44812, 0.71859, 0.30867, 0.3954, 0.53558, 0.53909, 0.42298, 0.64224, 0.5676

    Avg Accuracies: [0.9042, 0.9375, 0.8875, 0.8958, 0.8875, 0.8625, 0.8292, 0.9583, 0.9292, 0.9125, 0.9542, 0.9292, 0.925, 0.9, 0.9042, 0.9417, 0.8833, 0.9417, 0.9042, 0.9542, 0.875, 0.8583, 0.8958, 0.8417, 0.8958, 0.95, 0.8958, 0.9167, 0.9167, 0.8833, 0.9375, 0.9, 0.9125, 0.8167, 0.8708, 0.8708, 0.8042, 0.9, 0.9042, 0.8708, 0.975, 0.9333, 0.8208, 0.9458, 0.8792, 0.9042, 0.8958, 0.875, 0.9042, 0.95, 0.9083, 0.9292, 0.9458, 0.9375, 0.8417, 0.9458, 0.875, 0.8542, 0.875, 0.8125, 0.875, 0.8708, 0.8958, 0.8625, 0.95, 0.9042, 0.8333, 0.8875, 0.9208, 0.9083, 0.8333, 0.9292, 0.8833, 0.8708, 0.9208, 0.9333, 0.925, 0.8083, 0.9625, 0.95, 0.875, 0.9375, 0.8708, 0.9292, 0.9125, 0.9583, 0.8417, 0.8958, 0.9542, 0.9667, 0.9625, 0.8917, 0.8875, 0.925, 0.875, 0.9458, 0.9042, 0.9458, 0.8875, 0.8833, 0.8542, 0.8208, 0.9042, 0.85, 0.9458, 0.9292, 0.8917, 0.8792, 0.9042, 0.9, 0.8667, 0.8375, 0.9375, 0.9042, 0.8542, 0.9125, 0.9125, 0.9, 0.7917, 0.9167, 0.8917, 0.8667, 0.9042, 0.9167, 0.8875, 0.8917, 0.9042,

    Avg Losses: [0.1074, 0.06861, 0.12346, 0.16572, 0.2016, 0.15397, 0.19522, 0.09769, 0.07793, 0.10249, 0.11256, 0.08473, 0.13709, 0.10065, 0.08645, 0.10457, 0.20076, 0.06277, 0.15415, 0.12645, 0.17388, 0.20123, 0.1143, 0.34864, 0.13082, 0.03396, 0.15364, 0.1132, 0.0955, 0.20392, 0.13981, 0.16824, 0.11068, 0.25701, 0.1951, 0.25079, 0.27671, 0.12609, 0.13564, 0.16434, 0.05833, 0.07401, 0.23937, 0.064, 0.17943, 0.11122, 0.10858, 0.21967, 0.18916, 0.08947, 0.07911, 0.09582, 0.07157, 0.20472, 0.20377, 0.08675, 0.17988, 0.27651, 0.2601, 0.23554, 0.14984, 0.1118, 0.18287, 0.23356, 0.09189, 0.12019, 0.25232, 0.10465, 0.09447, 0.23183, 0.27649, 0.12257, 0.24177, 0.17097, 0.19736, 0.14328, 0.09079, 0.30777, 0.05961, 0.06416, 0.12345, 0.05246, 0.1385, 0.14273, 0.16574, 0.04048, 0.24411, 0.27874, 0.07252, 0.08183, 0.06831, 0.08073, 0.16821, 0.20314, 0.23753, 0.04492, 0.08764, 0.0755, 0.07873, 0.15205, 0.1415, 0.23516, 0.12427, 0.30495, 0.05272, 0.11172, 0.16894, 0.19225, 0.11059, 0.17298, 0.2046

    Avg Accuracies: [0.9875, 0.9833, 0.975, 0.9875, 0.9417, 0.975, 0.9708, 0.9583, 0.9958, 0.975, 0.9833, 0.9708, 0.975, 0.9792, 0.9792, 0.975, 0.9375, 0.9958, 0.95, 0.9583, 0.9542, 0.95, 0.9917, 0.8875, 0.9792, 0.9958, 0.9958, 0.9875, 0.9875, 0.9667, 0.9875, 0.9625, 0.9833, 0.9333, 0.95, 0.9417, 0.95, 0.975, 0.9833, 0.9667, 0.9917, 0.9833, 0.9667, 0.9917, 0.9708, 0.9792, 0.9833, 0.9667, 0.9583, 0.9667, 0.9875, 0.9708, 0.9875, 0.9542, 0.9417, 0.9917, 0.975, 0.9708, 0.9458, 0.9542, 0.9583, 0.9917, 0.9458, 0.95, 0.9958, 0.9667, 0.9625, 1.0, 0.9917, 0.925, 0.9375, 0.9625, 0.9625, 0.9792, 0.975, 0.9667, 0.9792, 0.9417, 0.9875, 0.9958, 0.9833, 0.9917, 0.9875, 0.9792, 0.9708, 1.0, 0.95, 0.9208, 0.9917, 0.9667, 0.9792, 0.9917, 0.9625, 0.95, 0.9458, 1.0, 0.9792, 1.0, 0.9833, 0.9625, 0.975, 0.9667, 0.9667, 0.9375, 0.9917, 0.9917, 0.975, 0.9833, 0.9667, 0.9667, 0.95, 0.975, 0.9667, 0.975, 0.9833, 0.9917, 0.975, 0.9458, 0.9375, 0.9667, 0.9917, 0.975, 0.9667, 0.9792, 0.9708, 0.9958, 0.9875, 0.9875

    Avg Losses: [0.031, 0.04517, 0.03066, 0.03853, 0.07917, 0.04407, 0.06208, 0.04161, 0.02878, 0.03631, 0.04314, 0.05493, 0.04891, 0.06261, 0.03742, 0.08305, 0.05574, 0.03096, 0.10272, 0.03162, 0.06426, 0.05753, 0.034, 0.15973, 0.04476, 0.01089, 0.04388, 0.04683, 0.03168, 0.05778, 0.02028, 0.09152, 0.03384, 0.07283, 0.17371, 0.10615, 0.16491, 0.03162, 0.01737, 0.0596, 0.01094, 0.04787, 0.08315, 0.01412, 0.06868, 0.02896, 0.03201, 0.05543, 0.06618, 0.08458, 0.02304, 0.04318, 0.01237, 0.07127, 0.04412, 0.01501, 0.05284, 0.06106, 0.07912, 0.10177, 0.05127, 0.05591, 0.04659, 0.07581, 0.02393, 0.03881, 0.06762, 0.01572, 0.01738, 0.08198, 0.06707, 0.041, 0.0525, 0.05158, 0.05392, 0.04268, 0.0377, 0.17164, 0.01386, 0.01488, 0.04279, 0.01018, 0.04226, 0.02558, 0.03742, 0.00893, 0.05678, 0.1312, 0.02546, 0.03908, 0.03055, 0.02347, 0.05537, 0.0586, 0.07012, 0.01602, 0.07654, 0.0159, 0.02497, 0.11268, 0.03171, 0.09127, 0.04479, 0.10482, 0.01504, 0.02406, 0.06633, 0.04686, 0.024, 0.07013, 0.12874

### Teacher Evaluation (Average and Aggregate)

In [7]:
for model in teachers:
    model.eval()

test_dataloader = data.DataLoader(test_dataset, batch_size=1024, shuffle=False, drop_last=False)
criterion = nn.CrossEntropyLoss(reduction='sum')

instance_count = 0
total_loss     = 0.
correct_count  = 0.

preds_lists_list = []
labels_list      = []

n_batches = len(test_dataloader)
_prev_str_len = 0
for i, (imgs, labels) in enumerate(test_dataloader):

    imgs   = imgs.to(device)
    labels_list.append(labels)
    labels = labels.to(device)

    preds_list = []
    with torch.no_grad():
        for j, model in enumerate(teachers):
            instance_count += imgs.size(0)

            _progress_str = "Batch {:d}/{:d} - Teacher {:d}/{:d}".format(i, n_batches-1, j, n_teachers-1)
            print(_progress_str + ' ' * (_prev_str_len - len(_progress_str)), end='\r')
            _prev_str_len = len(_progress_str)

            outs  = model(imgs)
            preds = outs.argmax(dim=1)
            preds_list.append(preds.cpu())

            total_loss += criterion(outs, labels).item()
            correct_count += (preds == labels).sum().item()
    preds_lists_list.append(preds_list)

preds_tensor = torch.cat([torch.stack(preds_list, dim=0) for preds_list in preds_lists_list], dim=1)
preds_counts = torch.from_numpy(np.apply_along_axis(lambda x: np.bincount(x, minlength=10), axis=0, arr=preds_tensor.numpy()))

aggregate_preds = preds_counts.argmax(dim=0)
labels          = torch.cat(labels_list, dim=0)

aggregate_acc = (aggregate_preds == labels).float().mean().item()

print('\n')
print("Average Loss:", total_loss / instance_count)
print("Average Accuracy:", correct_count / instance_count)
print()
print("Aggregate Accuracy:", aggregate_acc)

Batch 0/0 - Teacher 249/249

Average Loss: 0.7392376829833984
Average Accuracy: 0.815656

Aggregate Accuracy: 0.9629999995231628


## Aggregate Teacher Models

In [8]:
def aggregate_counts(img, return_label_preds=False):
    assert 3 <= img.dim() <= 4
    if img.dim() == 3:
        img = img.unsqueeze(0)
    else:
        assert img.size(0) == 1

    img = img.to(device)

    preds_list = []

    with torch.no_grad():
        for model in teachers:
            model.eval()
            preds_list.append(model(img).argmax(dim=1).view(1).cpu())

    preds_tensor = torch.cat(preds_list, dim=0)
    
    counts = torch.bincount(preds_tensor, minlength=10)
    
    if return_label_preds:
        return counts, preds_tensor
    return counts

## Train the Student

In [9]:
max_train_subset_size  = 100
n_new_labels_per_epoch = 5
n_updates_per_epoch    = 40
lr                     = 2e-2
weight_decay           = 0
epsilon                = 0.05

student = MNISTClassifier().to(device)

laplace_noise = dists.Laplace(torch.zeros([], dtype=torch.float), torch.tensor(1 / epsilon, dtype=torch.float))

init_subset_size = n_new_labels_per_epoch + (max_train_subset_size % n_new_labels_per_epoch)
n_total_epochs   = max_train_subset_size // n_new_labels_per_epoch

student_unlabeled_dataset    = data.Subset(mnist_testset, list(student_dataset.indices))
student_unlabeled_dataloader = data.DataLoader(student_unlabeled_dataset, batch_size=1024, shuffle=False, drop_last=False)
student_labeled_dataset      = data.Subset(mnist_testset, [])
student_labeled_dataloader   = data.DataLoader(student_labeled_dataset,
                                               batch_sampler=data.BatchSampler(data.SequentialSampler(student_unlabeled_dataset), init_subset_size, True))

student_optimizer            = optim.Adam(student.parameters(), lr=lr, weight_decay=weight_decay)
criterion                    = nn.CrossEntropyLoss()

student.train()

teacher_preds = []
noisy_labels  = []
student_train_history = {'n_labels': {}, 'avg_losses':{}, 'avg_accuracies': {}}
for i_epoch in range(n_total_epochs):
    if i_epoch == 0:
        new_label_indices = random.sample(student_unlabeled_dataset.indices, init_subset_size)

    else:
        max_probs_list = []

        with torch.no_grad():
            for imgs, _ in student_unlabeled_dataloader:
                imgs = imgs.to(device)

                outs  = student(imgs)
                probs = outs.softmax(dim=1)

                max_probs_list.append(probs.max(dim=1)[0].cpu())

        max_probs_tensor = torch.cat(max_probs_list, dim=0)
        
        new_label_indices = [student_unlabeled_dataset.indices[idx] for idx in
                             max_probs_tensor.topk(n_new_labels_per_epoch, largest=False, sorted=False)[1]]

    for idx in new_label_indices:
        student_labeled_dataset.indices.append(idx)
        student_unlabeled_dataset.indices.remove(idx)
        label_pred_counts, label_preds = aggregate_counts(mnist_testset[idx][0].view(1, 1, 28, 28), True)
        noisy_label = (label_pred_counts.float() + laplace_noise.sample(label_pred_counts.size())).argmax(dim=0)
        mnist_testset.targets[idx] = noisy_label

        teacher_preds.append(label_preds)
        noisy_labels.append(noisy_label)
        
    student_labeled_dataloader.batch_sampler.batch_size = len(student_labeled_dataset)
    
    print("Epoch {:d}/{:d} - Number of Labeled Data: {:d}".format(i_epoch, n_total_epochs-1, len(student_labeled_dataset)))

    student_train_history['n_labels'][i_epoch] = init_subset_size + (i_epoch * n_new_labels_per_epoch)
    losses = []
    accuracies = []
    for i_update in range(n_updates_per_epoch):
        imgs, labels = next(iter(student_labeled_dataloader))

        imgs   = imgs.to(device)
        labels = labels.to(device)

        outs  = student(imgs)
        preds = torch.argmax(outs, dim=1)

        student_optimizer.zero_grad()
        loss = criterion(outs, labels)
        loss.backward()
        student_optimizer.step()

        accuracy = (preds == labels).float().mean().item()

        losses.append(loss.item())
        accuracies.append(accuracy)

        print("    Update {:d} - Loss = {:.6f}, Accuracy = {:.4}".format(i_update, loss.item(), accuracy))

    student_train_history['avg_losses'][i_epoch]     = losses
    student_train_history['avg_accuracies'][i_epoch] = accuracies

teacher_preds = torch.stack(teacher_preds, dim=1)
noisy_labels  = torch.tensor(noisy_labels)

Epoch 0/19 - Number of Labeled Data: 5
    Update 0 - Loss = 2.516278, Accuracy = 0.0
    Update 1 - Loss = 1.813576, Accuracy = 0.2
    Update 2 - Loss = 1.451643, Accuracy = 1.0
    Update 3 - Loss = 1.171823, Accuracy = 1.0
    Update 4 - Loss = 0.919899, Accuracy = 1.0
    Update 5 - Loss = 0.684254, Accuracy = 1.0
    Update 6 - Loss = 0.458117, Accuracy = 1.0
    Update 7 - Loss = 0.295014, Accuracy = 1.0
    Update 8 - Loss = 0.171470, Accuracy = 1.0
    Update 9 - Loss = 0.095434, Accuracy = 1.0
    Update 10 - Loss = 0.055843, Accuracy = 1.0
    Update 11 - Loss = 0.034436, Accuracy = 1.0
    Update 12 - Loss = 0.020803, Accuracy = 1.0
    Update 13 - Loss = 0.012203, Accuracy = 1.0
    Update 14 - Loss = 0.007342, Accuracy = 1.0
    Update 15 - Loss = 0.004648, Accuracy = 1.0
    Update 16 - Loss = 0.003143, Accuracy = 1.0
    Update 17 - Loss = 0.002238, Accuracy = 1.0
    Update 18 - Loss = 0.001647, Accuracy = 1.0
    Update 19 - Loss = 0.001253, Accuracy = 1.0
    Update 

    Update 6 - Loss = 0.077698, Accuracy = 1.0
    Update 7 - Loss = 0.061360, Accuracy = 1.0
    Update 8 - Loss = 0.043584, Accuracy = 1.0
    Update 9 - Loss = 0.033311, Accuracy = 1.0
    Update 10 - Loss = 0.029581, Accuracy = 1.0
    Update 11 - Loss = 0.029295, Accuracy = 1.0
    Update 12 - Loss = 0.026632, Accuracy = 1.0
    Update 13 - Loss = 0.023628, Accuracy = 1.0
    Update 14 - Loss = 0.021323, Accuracy = 1.0
    Update 15 - Loss = 0.018907, Accuracy = 1.0
    Update 16 - Loss = 0.016757, Accuracy = 1.0
    Update 17 - Loss = 0.015043, Accuracy = 1.0
    Update 18 - Loss = 0.013442, Accuracy = 1.0
    Update 19 - Loss = 0.011802, Accuracy = 1.0
    Update 20 - Loss = 0.010276, Accuracy = 1.0
    Update 21 - Loss = 0.008963, Accuracy = 1.0
    Update 22 - Loss = 0.007841, Accuracy = 1.0
    Update 23 - Loss = 0.006883, Accuracy = 1.0
    Update 24 - Loss = 0.006069, Accuracy = 1.0
    Update 25 - Loss = 0.005384, Accuracy = 1.0
    Update 26 - Loss = 0.004865, Accuracy = 

    Update 13 - Loss = 0.010081, Accuracy = 1.0
    Update 14 - Loss = 0.008916, Accuracy = 1.0
    Update 15 - Loss = 0.007898, Accuracy = 1.0
    Update 16 - Loss = 0.007060, Accuracy = 1.0
    Update 17 - Loss = 0.006424, Accuracy = 1.0
    Update 18 - Loss = 0.005910, Accuracy = 1.0
    Update 19 - Loss = 0.005481, Accuracy = 1.0
    Update 20 - Loss = 0.005076, Accuracy = 1.0
    Update 21 - Loss = 0.004708, Accuracy = 1.0
    Update 22 - Loss = 0.004364, Accuracy = 1.0
    Update 23 - Loss = 0.004089, Accuracy = 1.0
    Update 24 - Loss = 0.003843, Accuracy = 1.0
    Update 25 - Loss = 0.003618, Accuracy = 1.0
    Update 26 - Loss = 0.003394, Accuracy = 1.0
    Update 27 - Loss = 0.003178, Accuracy = 1.0
    Update 28 - Loss = 0.002973, Accuracy = 1.0
    Update 29 - Loss = 0.002787, Accuracy = 1.0
    Update 30 - Loss = 0.002615, Accuracy = 1.0
    Update 31 - Loss = 0.002458, Accuracy = 1.0
    Update 32 - Loss = 0.002317, Accuracy = 1.0
    Update 33 - Loss = 0.002197, Accurac

    Update 20 - Loss = 0.002718, Accuracy = 1.0
    Update 21 - Loss = 0.002507, Accuracy = 1.0
    Update 22 - Loss = 0.002357, Accuracy = 1.0
    Update 23 - Loss = 0.002251, Accuracy = 1.0
    Update 24 - Loss = 0.002177, Accuracy = 1.0
    Update 25 - Loss = 0.002125, Accuracy = 1.0
    Update 26 - Loss = 0.002082, Accuracy = 1.0
    Update 27 - Loss = 0.002045, Accuracy = 1.0
    Update 28 - Loss = 0.002007, Accuracy = 1.0
    Update 29 - Loss = 0.001959, Accuracy = 1.0
    Update 30 - Loss = 0.001898, Accuracy = 1.0
    Update 31 - Loss = 0.001826, Accuracy = 1.0
    Update 32 - Loss = 0.001746, Accuracy = 1.0
    Update 33 - Loss = 0.001662, Accuracy = 1.0
    Update 34 - Loss = 0.001578, Accuracy = 1.0
    Update 35 - Loss = 0.001498, Accuracy = 1.0
    Update 36 - Loss = 0.001420, Accuracy = 1.0
    Update 37 - Loss = 0.001347, Accuracy = 1.0
    Update 38 - Loss = 0.001279, Accuracy = 1.0
    Update 39 - Loss = 0.001217, Accuracy = 1.0
Epoch 13/19 - Number of Labeled Data: 70

    Update 27 - Loss = 0.001297, Accuracy = 1.0
    Update 28 - Loss = 0.001211, Accuracy = 1.0
    Update 29 - Loss = 0.001129, Accuracy = 1.0
    Update 30 - Loss = 0.001055, Accuracy = 1.0
    Update 31 - Loss = 0.000988, Accuracy = 1.0
    Update 32 - Loss = 0.000930, Accuracy = 1.0
    Update 33 - Loss = 0.000879, Accuracy = 1.0
    Update 34 - Loss = 0.000836, Accuracy = 1.0
    Update 35 - Loss = 0.000799, Accuracy = 1.0
    Update 36 - Loss = 0.000768, Accuracy = 1.0
    Update 37 - Loss = 0.000740, Accuracy = 1.0
    Update 38 - Loss = 0.000716, Accuracy = 1.0
    Update 39 - Loss = 0.000694, Accuracy = 1.0
Epoch 17/19 - Number of Labeled Data: 90
    Update 0 - Loss = 0.095339, Accuracy = 0.9556
    Update 1 - Loss = 0.055714, Accuracy = 0.9778
    Update 2 - Loss = 0.021279, Accuracy = 1.0
    Update 3 - Loss = 0.008717, Accuracy = 1.0
    Update 4 - Loss = 0.007046, Accuracy = 1.0
    Update 5 - Loss = 0.010180, Accuracy = 1.0
    Update 6 - Loss = 0.013835, Accuracy = 1.0


## Evaluate the Final Student Model on the Test Set

In [10]:
student.eval()

criterion = nn.CrossEntropyLoss(reduction='sum')

instance_count = 0
total_loss     = 0.
correct_count  = 0.

test_dataloader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

n_batches = len(test_dataloader)
for i, (imgs, labels) in enumerate(test_dataloader):
    print("Batch {:d}/{:d}".format(i, n_batches-1), end='\r')

    instance_count += imgs.size(0)
    
    imgs   = imgs.to(device)
    labels = labels.to(device)
    
    with torch.no_grad():
        outs  = student(imgs)
    
    total_loss += criterion(outs, labels).item()

    preds = outs.argmax(dim=1)
    
    correct_count += (preds == labels).sum().item()

print()
print("Average Loss:", total_loss / instance_count)
print("Average Accuracy:", correct_count / instance_count)

Batch 33/33
Average Loss: 1.434910111427307
Average Accuracy: 0.782


## Perform PATE Analysis to Measure the Information Leak

In [16]:
# Noisy Predictions as indices
data_dep_eps, data_indep_eps = pate.perform_analysis(teacher_preds.numpy(), noisy_labels.numpy(), epsilon, moments=18)
print("Data Independent Epsilon:", data_indep_eps)
print("Data Dependent Epsilon:  ", data_dep_eps)

Data Independent Epsilon: 5.302585092994046
Data Dependent Epsilon:   1.6973147844840117


In [17]:
# Predictions without Noise as indices
data_dep_eps, data_indep_eps = pate.perform_analysis(teacher_preds.numpy(), np.argmax(np.apply_along_axis(lambda x: np.bincount(x, None, 10), 0, teacher_preds.numpy()), 0), epsilon, moments=18)
print("Data Independent Epsilon:", data_indep_eps)
print("Data Dependent Epsilon:  ", data_dep_eps)

Data Independent Epsilon: 5.302585092994046
Data Dependent Epsilon:   1.6174774676520607


In [18]:
# True Labels as indices
data_dep_eps, data_indep_eps = pate.perform_analysis(teacher_preds.numpy(), mnist_testset.true_targets[student_labeled_dataset.indices], epsilon, moments=18)
print("Data Independent Epsilon:", data_indep_eps)
print("Data Dependent Epsilon:  ", data_dep_eps)

Data Independent Epsilon: 5.302585092994046
Data Dependent Epsilon:   1.5258742466643267


---