# Section Project:

For the final project for this section, you're going to train a DP model using this PATE method on the MNIST dataset, provided below.

## Import Modules

In [1]:
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
import torch.utils.data as data

import torchvision.datasets as datasets
import torchvision.transforms as transforms

from syft.frameworks.torch.differential_privacy import pate

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device



device(type='cuda', index=0)

## Prepare Data

### Load the MNIST Training & Test Datasets

In [2]:
mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=transforms.ToTensor())
mnist_testset  = datasets.MNIST(root='../data', train=False, download=True, transform=transforms.ToTensor())
mnist_testset.true_targets = mnist_testset.targets.clone() # data points that are considered "unlabeled" will be re-labeled by teachers later

print("Training Set Size:", len(mnist_trainset))
print("Test Set Size:", len(mnist_testset))
print()
print("Min Data Value:", torch.min(mnist_trainset.data.min(), mnist_testset.data.min()))
print("Max Data Value:", torch.max(mnist_trainset.data.max(), mnist_testset.data.max()))
print()
print("Train Label Counts:", {label.item():count.item() for label, count in zip(*torch.unique(mnist_trainset.targets, return_counts=True))})
print("Test Label Counts:", {label.item():count.item() for label, count in zip(*torch.unique(mnist_testset.targets, return_counts=True))})

Training Set Size: 60000
Test Set Size: 10000

Min Data Value: tensor(0, dtype=torch.uint8)
Max Data Value: tensor(255, dtype=torch.uint8)

Train Label Counts: {0: 5923, 1: 6742, 2: 5958, 3: 6131, 4: 5842, 5: 5421, 6: 5918, 7: 6265, 8: 5851, 9: 5949}
Test Label Counts: {0: 980, 1: 1135, 2: 1032, 3: 1010, 4: 982, 5: 892, 6: 958, 7: 1028, 8: 974, 9: 1009}


### Create Training & Test Datasets

In [3]:
n_teachers = 250

_teacher_dataset_len = len(mnist_trainset) // n_teachers

teacher_datasets = [data.Subset(mnist_trainset, list(range(i*_teacher_dataset_len, (i+1)*_teacher_dataset_len))) for i in range(n_teachers)]
student_dataset  = data.Subset(mnist_testset, list(range(int(len(mnist_testset) * 0.9))))
test_dataset     = data.Subset(mnist_testset, list(range(int(len(mnist_testset) * 0.9), len(mnist_testset))))

print("Size of each teacher dataset:", _teacher_dataset_len)
print("Size of the dataset available to the student:", len(student_dataset))
print("Size of the test dataset:", len(test_dataset))

Size of each teacher dataset: 240
Size of the dataset available to the student: 9000
Size of the test dataset: 1000


## Build Models

### Define Classifier

In [4]:
class MNISTClassifier(nn.Module):
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        
        # 1x28x28
        self.bn0        = nn.BatchNorm2d(1)
        self.conv0      = nn.Conv2d(1, 5, 3, padding=1)
        self.bn1        = nn.BatchNorm2d(5)
        self.maxpool0   = nn.MaxPool2d(2)
        # 5x14x14
        self.conv1      = nn.Conv2d(5, 5, 3, padding=1)
        self.bn2        = nn.BatchNorm2d(5)
        self.maxpool1   = nn.MaxPool2d(2)
        # 5x 7x 7
        self.conv2      = nn.Conv2d(5, 5, 3, padding=1)
        self.bn3        = nn.BatchNorm2d(5)
        self.maxpool2   = nn.MaxPool2d(2, padding=1)
        # 5x 4x 4 = 80
        self.fc         = nn.Linear(80, 10)

        self.activation = nn.ReLU()
        
    def forward(self, x):
        x = self.conv0(self.bn0(x))
        x = self.bn1(x)
        x = self.activation(x)
        x = self.maxpool0(x)
        x = self.conv1(x)
        x = self.bn2(x)
        x = self.activation(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.bn3(x)
        x = self.activation(x)
        x = self.maxpool2(x)
        x = self.fc(x.view(-1, 80))
        
        return x

### Create Teacher & DF Models

In [5]:
teachers      = [MNISTClassifier().to(device) for _ in range(n_teachers)]
student       = MNISTClassifier().to(device)

## Train Teachers

In [6]:
lr         = 3e-2
n_epochs   = 10
batch_size = 30

teacher_dataloaders = [data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) for dataset in teacher_datasets]
teacher_optimizers  = [optim.Adam(model.parameters(), lr=lr) for model in teachers]
criterion           = nn.CrossEntropyLoss()

for model in teachers:
    model.train()

teachers_train_history = {'avg_losses':{}, 'avg_accuracies': {}}
for i_epoch in range(n_epochs):
    avg_losses      = []
    avg_accuracies  = []

    for i_model in range(n_teachers):
        instance_count = 0
        total_loss     = 0.
        correct_count  = 0

        model      = teachers[i_model]
        dataloader = teacher_dataloaders[i_model]
        optimizer  = teacher_optimizers[i_model]

        n_batches = len(dataloader)
        _prev_str_len = 0
        for i, (imgs, labels) in enumerate(dataloader):
            _batch_str = "Teacher {:d}/{:d}: ({:d}/{:d})".format(i_model, n_teachers-1, i, n_batches-1)
            print(_batch_str + ' ' * (_prev_str_len - len(_batch_str)), end='\r')
            _prev_str_len = len(_batch_str)

            instance_count += imgs.size(0)

            imgs   = imgs.to(device)
            labels = labels.to(device)
            
            outs  = model(imgs)
            preds = torch.argmax(outs, dim=1)
            
            optimizer.zero_grad()
            loss = criterion(outs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * imgs.size(0)
            
            correct_count += (preds == labels).sum().item()

        avg_losses.append(total_loss / instance_count)
        avg_accuracies.append(correct_count / instance_count)

    _epoch_str = "Epoch {:d}/{:d}".format(i_epoch, n_epochs-1)
    _epoch_str += ' ' * (_prev_str_len - len(_epoch_str))
    print(_epoch_str)
    print("    Avg Losses:", [round(avg_loss, 5) for avg_loss in avg_losses])
    print("    Avg Accuracies:", [round(avg_acc, 4) for avg_acc in avg_accuracies])
    print()

    teachers_train_history['avg_losses'][i_epoch]     = avg_losses
    teachers_train_history['avg_accuracies'][i_epoch] = avg_accuracies

Epoch 0/9             
    Avg Losses: [2.21073, 2.23918, 2.28086, 2.15874, 2.23866, 2.31621, 2.30713, 2.18633, 2.17418, 2.16536, 2.21159, 2.13759, 2.29504, 2.13437, 2.26481, 2.26822, 2.36746, 1.982, 2.2323, 2.13046, 2.18558, 2.3023, 2.18393, 2.23326, 2.22305, 2.19732, 2.31659, 2.25757, 2.18945, 2.30273, 2.24863, 2.29796, 2.19766, 2.19736, 2.23516, 2.10273, 2.25671, 2.16019, 2.15332, 2.15617, 2.16629, 2.16903, 2.29991, 2.15375, 2.35322, 2.26402, 2.16145, 2.11861, 1.9837, 2.22307, 2.18739, 2.1666, 2.20664, 2.20077, 2.16129, 2.13993, 2.22303, 2.21677, 2.25294, 2.318, 2.30551, 2.1139, 2.02318, 2.24366, 2.12847, 2.18998, 2.38609, 2.24913, 2.11516, 2.1852, 2.28901, 2.16478, 2.27972, 2.20896, 2.18185, 2.15838, 2.26615, 2.29928, 1.90267, 2.15768, 2.23865, 2.09571, 2.17137, 2.01474, 2.18474, 2.08203, 2.26916, 2.14139, 2.13376, 2.00431, 2.27429, 2.25236, 2.2412, 2.15806, 2.06265, 2.15356, 2.22228, 2.06031, 2.17198, 2.18582, 2.32406, 2.32279, 2.16512, 2.31682, 2.1459, 2.20135, 2.27286, 2.17005, 

    Avg Accuracies: [0.5667, 0.6292, 0.4625, 0.6042, 0.5083, 0.3833, 0.4833, 0.5375, 0.6708, 0.6333, 0.5708, 0.6292, 0.4917, 0.6375, 0.7042, 0.6333, 0.4708, 0.7708, 0.5583, 0.7542, 0.5, 0.3833, 0.5417, 0.4417, 0.5, 0.7458, 0.375, 0.4208, 0.6583, 0.4542, 0.5583, 0.5, 0.6792, 0.4875, 0.3792, 0.5375, 0.4833, 0.6292, 0.6083, 0.5458, 0.6792, 0.625, 0.4125, 0.6292, 0.3167, 0.5583, 0.6125, 0.6125, 0.7375, 0.6125, 0.5417, 0.5417, 0.6542, 0.6, 0.5792, 0.7667, 0.5417, 0.5042, 0.4083, 0.4125, 0.4583, 0.6458, 0.7042, 0.5125, 0.6083, 0.525, 0.3542, 0.6208, 0.5292, 0.5875, 0.325, 0.6375, 0.2875, 0.4833, 0.5458, 0.6083, 0.5917, 0.4, 0.7417, 0.6958, 0.4833, 0.6792, 0.475, 0.6792, 0.5875, 0.6917, 0.475, 0.5542, 0.6167, 0.7833, 0.6125, 0.5208, 0.5583, 0.5708, 0.6333, 0.6792, 0.6292, 0.725, 0.5958, 0.5208, 0.3375, 0.3917, 0.5833, 0.3958, 0.7, 0.5167, 0.5042, 0.5625, 0.5042, 0.5625, 0.5875, 0.4042, 0.5667, 0.6, 0.575, 0.4917, 0.7375, 0.525, 0.325, 0.5333, 0.6167, 0.4708, 0.5625, 0.4833, 0.5542, 0.525, 0.5

    Avg Losses: [0.48196, 0.39495, 0.64809, 0.63741, 0.49077, 0.60653, 0.66273, 0.3823, 0.34925, 0.33786, 0.29736, 0.29569, 0.48175, 0.46401, 0.42631, 0.31803, 0.68814, 0.28315, 0.42986, 0.30587, 0.58295, 0.71571, 0.51475, 0.66945, 0.57366, 0.25418, 0.52692, 0.42819, 0.38567, 0.62894, 0.46286, 0.4921, 0.43013, 0.5892, 0.69507, 0.62799, 0.70596, 0.38762, 0.42889, 0.62651, 0.25701, 0.34896, 0.82233, 0.27526, 0.57192, 0.63084, 0.34517, 0.3787, 0.41386, 0.31513, 0.41947, 0.45023, 0.34627, 0.39547, 0.66425, 0.26189, 0.59932, 0.61769, 0.72941, 0.82716, 0.54384, 0.35755, 0.39575, 0.58415, 0.3106, 0.45865, 0.79431, 0.39925, 0.39961, 0.41193, 0.70602, 0.29866, 0.71757, 0.65548, 0.40805, 0.3696, 0.31718, 0.88587, 0.25379, 0.34058, 0.54434, 0.37237, 0.61531, 0.26922, 0.47284, 0.28939, 0.73809, 0.45363, 0.31637, 0.24237, 0.32734, 0.28693, 0.52126, 0.46532, 0.54293, 0.34269, 0.41552, 0.27459, 0.3654, 0.53991, 0.72576, 0.94148, 0.44812, 0.71731, 0.30885, 0.3954, 0.54082, 0.53909, 0.42298, 0.64224, 0

    Avg Accuracies: [0.9042, 0.9375, 0.8875, 0.8958, 0.8875, 0.8625, 0.8292, 0.9583, 0.9292, 0.9125, 0.9542, 0.9292, 0.925, 0.9, 0.9042, 0.9417, 0.8833, 0.9417, 0.9042, 0.95, 0.8708, 0.8583, 0.8958, 0.8417, 0.8958, 0.9458, 0.8958, 0.9167, 0.9167, 0.8833, 0.9375, 0.9, 0.9125, 0.8167, 0.8625, 0.8708, 0.8042, 0.9, 0.9083, 0.8708, 0.9667, 0.9333, 0.8208, 0.9458, 0.8792, 0.9042, 0.8958, 0.875, 0.8958, 0.95, 0.9083, 0.9292, 0.9458, 0.9375, 0.8417, 0.9458, 0.875, 0.8542, 0.875, 0.8125, 0.875, 0.8667, 0.8958, 0.8625, 0.95, 0.9167, 0.8458, 0.8875, 0.9208, 0.9083, 0.8333, 0.9375, 0.8833, 0.8708, 0.9208, 0.9333, 0.925, 0.8083, 0.9625, 0.95, 0.875, 0.9375, 0.8708, 0.9333, 0.9125, 0.9583, 0.8417, 0.8958, 0.9542, 0.9667, 0.9625, 0.8917, 0.8875, 0.9208, 0.875, 0.9458, 0.9042, 0.9458, 0.8875, 0.8833, 0.8542, 0.8167, 0.9042, 0.8625, 0.9542, 0.9292, 0.8792, 0.8792, 0.9042, 0.9, 0.8667, 0.8375, 0.9375, 0.9, 0.8542, 0.9125, 0.9125, 0.9, 0.7917, 0.9083, 0.8917, 0.8667, 0.9042, 0.9167, 0.8875, 0.8917, 0.891

    Avg Losses: [0.1074, 0.06861, 0.12346, 0.16572, 0.19751, 0.15397, 0.19522, 0.09769, 0.07793, 0.10236, 0.11256, 0.08271, 0.13709, 0.10065, 0.08645, 0.10457, 0.20076, 0.06274, 0.15415, 0.13035, 0.14771, 0.20123, 0.11762, 0.35824, 0.13082, 0.04693, 0.15365, 0.11319, 0.09722, 0.20756, 0.11444, 0.16824, 0.11068, 0.25624, 0.19169, 0.25073, 0.27459, 0.12609, 0.13266, 0.16434, 0.0484, 0.07022, 0.2266, 0.06401, 0.17943, 0.11122, 0.10853, 0.21967, 0.17725, 0.08947, 0.07911, 0.09582, 0.07157, 0.20472, 0.20377, 0.08675, 0.17988, 0.27651, 0.25836, 0.23588, 0.14984, 0.1217, 0.18287, 0.23356, 0.09189, 0.12174, 0.25915, 0.10465, 0.09447, 0.23169, 0.27607, 0.10575, 0.24177, 0.17098, 0.19736, 0.14328, 0.09079, 0.30686, 0.05961, 0.06423, 0.12343, 0.05246, 0.1385, 0.14781, 0.16574, 0.04048, 0.24366, 0.27874, 0.07252, 0.08183, 0.06819, 0.07301, 0.16821, 0.20635, 0.23753, 0.04492, 0.08764, 0.07426, 0.07873, 0.15205, 0.15731, 0.24319, 0.12427, 0.26653, 0.04785, 0.11176, 0.17957, 0.19215, 0.11156, 0.17336

    Avg Accuracies: [0.9875, 0.9833, 0.975, 0.9875, 0.9458, 0.975, 0.9708, 0.9583, 0.9958, 0.975, 0.9833, 0.9875, 0.975, 0.9792, 0.9792, 0.975, 0.9375, 0.9958, 0.95, 0.9542, 0.9708, 0.95, 0.9792, 0.8958, 0.9792, 0.9917, 0.9958, 0.9875, 0.9875, 0.9583, 0.9667, 0.9625, 0.9833, 0.9333, 0.9583, 0.9458, 0.9583, 0.975, 0.9875, 0.9667, 1.0, 0.9833, 0.9667, 0.9917, 0.9708, 0.9792, 0.9833, 0.9667, 0.9667, 0.9667, 0.9875, 0.9708, 0.9875, 0.9542, 0.9417, 0.9917, 0.975, 0.9708, 0.9375, 0.9542, 0.9583, 0.9625, 0.9458, 0.95, 0.9958, 0.9708, 0.9583, 1.0, 0.9917, 0.925, 0.9333, 0.975, 0.9625, 0.9792, 0.975, 0.9667, 0.9792, 0.9375, 0.9875, 0.9958, 0.9875, 0.9917, 0.9875, 0.9792, 0.9708, 1.0, 0.9375, 0.9208, 0.9917, 0.9667, 0.9833, 0.9875, 0.9625, 0.9583, 0.9458, 1.0, 0.9792, 1.0, 0.9833, 0.9625, 0.9833, 0.9792, 0.9667, 0.925, 0.9958, 0.9917, 0.9583, 0.9833, 0.9625, 0.9625, 0.95, 0.975, 0.9667, 0.9792, 0.9833, 0.9917, 0.975, 0.9458, 0.9375, 0.9583, 0.9917, 0.9708, 0.9667, 0.9792, 0.9708, 0.9833, 0.9917,

    Avg Losses: [0.031, 0.04517, 0.03066, 0.03853, 0.08106, 0.04407, 0.06208, 0.04161, 0.02878, 0.02693, 0.04314, 0.03276, 0.04891, 0.06261, 0.03742, 0.08305, 0.05574, 0.03059, 0.10272, 0.03446, 0.06582, 0.05777, 0.03435, 0.1887, 0.04476, 0.01933, 0.04384, 0.04694, 0.03277, 0.05583, 0.05274, 0.09152, 0.03383, 0.06364, 0.1063, 0.1065, 0.19398, 0.03162, 0.02038, 0.0596, 0.00964, 0.03247, 0.07761, 0.01417, 0.06868, 0.02896, 0.03351, 0.05543, 0.05676, 0.08458, 0.02304, 0.04318, 0.01237, 0.07209, 0.04412, 0.01501, 0.05284, 0.06108, 0.06457, 0.10537, 0.05127, 0.04609, 0.04659, 0.07577, 0.02377, 0.05176, 0.07767, 0.01572, 0.01738, 0.07799, 0.06597, 0.03406, 0.0525, 0.05407, 0.05392, 0.04268, 0.03771, 0.15664, 0.01386, 0.01508, 0.04156, 0.01018, 0.04226, 0.02449, 0.03742, 0.00893, 0.05948, 0.1312, 0.02544, 0.03908, 0.02674, 0.02006, 0.05536, 0.06879, 0.07012, 0.01602, 0.07654, 0.01529, 0.02497, 0.11268, 0.0302, 0.09495, 0.04479, 0.10638, 0.01686, 0.02393, 0.08488, 0.04661, 0.02749, 0.06641, 0.

## Aggregate Teacher Models

In [7]:
def aggregate_counts(img):
    assert 3 <= img.dim() <= 4
    if img.dim() == 3:
        img = img.unsqueeze(0)
    else:
        assert img.size(0) == 1

    img = img.to(device)

    preds_list = []

    with torch.no_grad():
        for model in teachers:
            model.eval()
            preds_list.append(model(img).argmax(dim=1).view(1).cpu())

    preds_tensor = torch.cat(preds_list, dim=0)
    
    counts = torch.bincount(preds_tensor, minlength=10)
    
    return counts

---

In [8]:
for model in teachers:
    model.eval()

# Teacher Test
if True:
    test_dataloader = data.DataLoader(test_dataset, batch_size=1024, shuffle=False, drop_last=False)
    criterion = nn.CrossEntropyLoss(reduction='sum')

    instance_count = 0
    total_loss     = 0.
    correct_count  = 0.

    preds_lists_list = []
    labels_list      = []

    n_batches = len(test_dataloader)
    _prev_str_len = 0
    for i, (imgs, labels) in enumerate(test_dataloader):

        imgs   = imgs.to(device)
        labels_list.append(labels)
        labels = labels.to(device)

        preds_list = []
        with torch.no_grad():
            for j, model in enumerate(teachers):
                instance_count += imgs.size(0)

                _progress_str = "Batch {:d}/{:d} - Teacher {:d}/{:d}".format(i, n_batches-1, j, n_teachers-1)
                print(_progress_str + ' ' * (_prev_str_len - len(_progress_str)), end='\r')
                _prev_str_len = len(_progress_str)

                outs  = model(imgs)
                preds = outs.argmax(dim=1)
                preds_list.append(preds.cpu())

                total_loss += criterion(outs, labels).item()
                correct_count += (preds == labels).sum().item()
        preds_lists_list.append(preds_list)

    preds_tensor = torch.cat([torch.stack(preds_list, dim=0) for preds_list in preds_lists_list], dim=1)
    preds_counts = torch.from_numpy(np.apply_along_axis(lambda x: np.bincount(x, minlength=10), axis=0, arr=preds_tensor.numpy()))

    aggregate_preds = preds_counts.argmax(dim=0)
    labels          = torch.cat(labels_list, dim=0)

    aggregate_acc = (aggregate_preds == labels).float().mean().item()

    print('\n')
    print("Average Loss:", total_loss / instance_count)
    print("Average Accuracy:", correct_count / instance_count)
    print()
    print("Aggregate Accuracy:", aggregate_acc)

Batch 0/0 - Teacher 249/249

Average Loss: 0.7388171423339843
Average Accuracy: 0.815684

Aggregate Accuracy: 0.9629999995231628


In [None]:
Batch 0/0 - Teacher 249/249

Average Loss: 0.7504612662353516
Average Accuracy: 0.814112

Aggregate Accuracy: 0.9670000076293945

In [9]:
for model in teachers:
    model.eval()

dataloader = data.DataLoader(student_dataset, batch_size=512, shuffle=False, drop_last=False)

batches_of_preds = []

n_batches = len(dataloader)
_prev_str_len = 0
for i, (imgs, _) in enumerate(dataloader):
    imgs = imgs.to(device)

    batch_of_preds = []
    
    with torch.no_grad():
        for j, model in enumerate(teachers):
            _progress_str = "Batch {:d}/{:d} - Teacher {:d}/{:d}".format(i, n_batches-1, j, n_teachers-1)
            print(_progress_str + ' ' * (_prev_str_len - len(_progress_str)), end='\r')
            _prev_str_len = len(_progress_str)

            outs  = model(imgs)
            preds = outs.argmax(dim=1)
            batch_of_preds.append(preds.cpu())
    
    batches_of_preds.append(batch_of_preds)
        
label_preds = torch.cat(
    [torch.stack([preds for preds in batch_of_preds], dim=0)
     for batch_of_preds in batches_of_preds],
    dim=1)

print()
print(label_preds)

Batch 17/17 - Teacher 249/249
tensor([[7, 2, 1,  ..., 6, 9, 0],
        [7, 2, 1,  ..., 6, 9, 0],
        [7, 2, 1,  ..., 6, 9, 0],
        ...,
        [7, 2, 1,  ..., 6, 9, 0],
        [7, 2, 1,  ..., 6, 4, 0],
        [7, 2, 1,  ..., 6, 9, 0]])


In [10]:
label_counts = torch.from_numpy(np.apply_along_axis(lambda x: np.bincount(x, minlength=10), axis=0, arr=label_preds.numpy()))
print(label_counts)

tensor([[  0,   1,   0,  ...,   0,   1, 237],
        [  0,   3, 248,  ...,   0,   0,   0],
        [  0, 225,   0,  ...,   0,   2,   5],
        ...,
        [250,   0,   0,  ...,   0,  19,   0],
        [  0,   6,   1,  ...,   0,   8,   2],
        [  0,   0,   0,  ...,   0, 204,   0]])


In [11]:
epsilon = 0.05

noise_dist = dists.Laplace(loc=torch.zeros([], dtype=torch.float),
                           scale=torch.full([], 1 / epsilon, dtype=torch.float))

noisy_counts = label_counts.float() + noise_dist.sample([10, label_counts.size(1)])

generated_labels = noisy_counts.argmax(dim=0)
print(generated_labels)
print()
print("Noisy Accuracy Against Predictions:", (generated_labels == label_counts.argmax(dim=0)).float().mean().item())

tensor([7, 2, 1,  ..., 6, 9, 0])

Noisy Accuracy Against Predictions: 0.973111093044281


In [12]:
import itertools

In [13]:
def train(max_train_subset_size, n_new_labels_per_epoch, n_updates_per_epoch, lr, weight_decay, epsilon):
    spec_dict = {
        "max_size": max_train_subset_size,
        "n_new_labels": n_new_labels_per_epoch,
        "n_updates": n_updates_per_epoch,
        "lr": lr,
        "wd": weight_decay,
        "eps": epsilon
    }

    student = MNISTClassifier().to(device)
    
    laplace_noise = dists.Laplace(torch.zeros([], dtype=torch.float), torch.tensor(1 / epsilon, dtype=torch.float))

    init_subset_size = n_new_labels_per_epoch + (max_train_subset_size % n_new_labels_per_epoch)
    n_total_epochs   = max_train_subset_size // n_new_labels_per_epoch

    student_unlabeled_dataset    = data.Subset(mnist_testset, list(student_dataset.indices))
    student_unlabeled_dataloader = data.DataLoader(student_unlabeled_dataset, batch_size=1024, shuffle=False, drop_last=False)
    student_labeled_dataset      = data.Subset(mnist_testset, [])
    student_labeled_dataloader   = data.DataLoader(student_labeled_dataset,
                                                   batch_sampler=data.BatchSampler(data.SequentialSampler(student_unlabeled_dataset), init_subset_size, True))

    student_optimizer            = optim.Adam(student.parameters(), lr=lr, weight_decay=weight_decay)
    criterion                    = nn.CrossEntropyLoss()

    student.train()

    for i_epoch in range(n_total_epochs):
        if i_epoch == 0:
            new_label_indices = random.sample(student_unlabeled_dataset.indices, init_subset_size)

        else:
            max_probs_list = []

            with torch.no_grad():
                for imgs, _ in student_unlabeled_dataloader:
                    imgs = imgs.to(device)

                    outs  = student(imgs)
                    probs = outs.softmax(dim=1)

                    max_probs_list.append(probs.max(dim=1)[0].cpu())

            max_probs_tensor = torch.cat(max_probs_list, dim=0)

            new_label_indices = [student_unlabeled_dataset.indices[idx] for idx in
                                 max_probs_tensor.topk(n_new_labels_per_epoch, largest=False, sorted=False)[1]]

        for idx in new_label_indices:
            student_labeled_dataset.indices.append(idx)
            student_unlabeled_dataset.indices.remove(idx)
            label_pred_counts = aggregate_counts(mnist_testset[idx][0].view(1, 1, 28, 28))
            noisy_label = (label_pred_counts.float() + laplace_noise.sample(label_pred_counts.size())).argmax(dim=0)
            mnist_testset.targets[idx] = noisy_label

        student_labeled_dataloader.batch_sampler.batch_size = len(student_labeled_dataset)

        for i_update in range(n_updates_per_epoch):
            imgs, labels = next(iter(student_labeled_dataloader))

            imgs   = imgs.to(device)
            labels = labels.to(device)

            outs  = student(imgs)
            preds = torch.argmax(outs, dim=1)

            student_optimizer.zero_grad()
            loss = criterion(outs, labels)
            loss.backward()
            student_optimizer.step()

    student.eval()

    criterion = nn.CrossEntropyLoss(reduction='sum')

    instance_count = 0
    total_loss     = 0.
    correct_count  = 0.

    test_dataloader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

    for imgs, labels in test_dataloader:
        instance_count += imgs.size(0)

        imgs   = imgs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            outs  = student(imgs)

        total_loss += criterion(outs, labels).item()

        preds = outs.argmax(dim=1)

        correct_count += (preds == labels).sum().item()

    print(spec_dict)
    print("    Average Loss:", total_loss / instance_count)
    print("    Average Accuracy:", correct_count / instance_count)
    print()

In [22]:
grid_spec = {
        "max_train_subset_size":  [50, 100, 150],
        "n_new_labels_per_epoch": [2, 3, 5, 10],
        "n_updates_per_epoch":    [10, 20, 30, 40],
        "lr":                     [5e-3, 1e-2, 2e-2, 3e-2],
        "weight_decay":           [0],
        "epsilon":                [0.05, 0.07, 0.1]
}

spec_iter = itertools.product(grid_spec["max_train_subset_size"],
                              grid_spec["n_new_labels_per_epoch"],
                              grid_spec["n_updates_per_epoch"],
                              grid_spec["lr"],
                              grid_spec["weight_decay"],
                              grid_spec["epsilon"])

n_specs = np.prod([len(options) for options in grid_spec.values()])

for i, spec in enumerate(spec_iter):
    print("{:d}/{:d}".format(i, n_specs-1), end='\r')
    train(*spec)

{'max_size': 50, 'n_new_labels': 2, 'n_updates': 10, 'lr': 0.005, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.7401377468109132
    Average Accuracy: 0.546

{'max_size': 50, 'n_new_labels': 2, 'n_updates': 10, 'lr': 0.005, 'wd': 0, 'eps': 0.07}
    Average Loss: 1.6192414531707764
    Average Accuracy: 0.598

{'max_size': 50, 'n_new_labels': 2, 'n_updates': 10, 'lr': 0.005, 'wd': 0, 'eps': 0.1}
    Average Loss: 2.3291448612213133
    Average Accuracy: 0.42

{'max_size': 50, 'n_new_labels': 2, 'n_updates': 10, 'lr': 0.01, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.6666455268859863
    Average Accuracy: 0.635

{'max_size': 50, 'n_new_labels': 2, 'n_updates': 10, 'lr': 0.01, 'wd': 0, 'eps': 0.07}
    Average Loss: 2.186118989944458
    Average Accuracy: 0.576

{'max_size': 50, 'n_new_labels': 2, 'n_updates': 10, 'lr': 0.01, 'wd': 0, 'eps': 0.1}
    Average Loss: 2.083424431800842
    Average Accuracy: 0.565

{'max_size': 50, 'n_new_labels': 2, 'n_updates': 10, 'lr': 0.02, 'wd': 0, 'eps': 0.0

    Average Loss: 1.9830959701538087
    Average Accuracy: 0.511

{'max_size': 50, 'n_new_labels': 3, 'n_updates': 10, 'lr': 0.02, 'wd': 0, 'eps': 0.05}
    Average Loss: 2.449305709838867
    Average Accuracy: 0.587

{'max_size': 50, 'n_new_labels': 3, 'n_updates': 10, 'lr': 0.02, 'wd': 0, 'eps': 0.07}
    Average Loss: 2.409902581214905
    Average Accuracy: 0.544

{'max_size': 50, 'n_new_labels': 3, 'n_updates': 10, 'lr': 0.02, 'wd': 0, 'eps': 0.1}
    Average Loss: 1.7087656302452088
    Average Accuracy: 0.633

{'max_size': 50, 'n_new_labels': 3, 'n_updates': 10, 'lr': 0.03, 'wd': 0, 'eps': 0.05}
    Average Loss: 2.440125967025757
    Average Accuracy: 0.606

{'max_size': 50, 'n_new_labels': 3, 'n_updates': 10, 'lr': 0.03, 'wd': 0, 'eps': 0.07}
    Average Loss: 3.882114326477051
    Average Accuracy: 0.492

{'max_size': 50, 'n_new_labels': 3, 'n_updates': 10, 'lr': 0.03, 'wd': 0, 'eps': 0.1}
    Average Loss: 2.2087627201080324
    Average Accuracy: 0.582

{'max_size': 50, 'n_ne

{'max_size': 50, 'n_new_labels': 5, 'n_updates': 10, 'lr': 0.03, 'wd': 0, 'eps': 0.1}
    Average Loss: 4.808572399139404
    Average Accuracy: 0.445

{'max_size': 50, 'n_new_labels': 5, 'n_updates': 20, 'lr': 0.005, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.0773349990844727
    Average Accuracy: 0.641

{'max_size': 50, 'n_new_labels': 5, 'n_updates': 20, 'lr': 0.005, 'wd': 0, 'eps': 0.07}
    Average Loss: 1.391078387737274
    Average Accuracy: 0.583

{'max_size': 50, 'n_new_labels': 5, 'n_updates': 20, 'lr': 0.005, 'wd': 0, 'eps': 0.1}
    Average Loss: 1.7164187660217285
    Average Accuracy: 0.547

{'max_size': 50, 'n_new_labels': 5, 'n_updates': 20, 'lr': 0.01, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.0420264544487
    Average Accuracy: 0.701

{'max_size': 50, 'n_new_labels': 5, 'n_updates': 20, 'lr': 0.01, 'wd': 0, 'eps': 0.07}
    Average Loss: 1.8005458850860596
    Average Accuracy: 0.549

{'max_size': 50, 'n_new_labels': 5, 'n_updates': 20, 'lr': 0.01, 'wd': 0, 'eps': 0.1}


    Average Loss: 1.7249375
    Average Accuracy: 0.595

{'max_size': 50, 'n_new_labels': 10, 'n_updates': 20, 'lr': 0.01, 'wd': 0, 'eps': 0.1}
    Average Loss: 3.671211252212524
    Average Accuracy: 0.432

{'max_size': 50, 'n_new_labels': 10, 'n_updates': 20, 'lr': 0.02, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.7682519617080688
    Average Accuracy: 0.616

{'max_size': 50, 'n_new_labels': 10, 'n_updates': 20, 'lr': 0.02, 'wd': 0, 'eps': 0.07}
    Average Loss: 2.35464435005188
    Average Accuracy: 0.585

{'max_size': 50, 'n_new_labels': 10, 'n_updates': 20, 'lr': 0.02, 'wd': 0, 'eps': 0.1}
    Average Loss: 2.2872038650512696
    Average Accuracy: 0.467

{'max_size': 50, 'n_new_labels': 10, 'n_updates': 20, 'lr': 0.03, 'wd': 0, 'eps': 0.05}
    Average Loss: 2.9423513832092287
    Average Accuracy: 0.536

{'max_size': 50, 'n_new_labels': 10, 'n_updates': 20, 'lr': 0.03, 'wd': 0, 'eps': 0.07}
    Average Loss: 2.565483946800232
    Average Accuracy: 0.578

{'max_size': 50, 'n_new_l

{'max_size': 100, 'n_new_labels': 2, 'n_updates': 20, 'lr': 0.03, 'wd': 0, 'eps': 0.07}
    Average Loss: 1.9050000867843628
    Average Accuracy: 0.765

{'max_size': 100, 'n_new_labels': 2, 'n_updates': 20, 'lr': 0.03, 'wd': 0, 'eps': 0.1}
    Average Loss: 1.9419033735990525
    Average Accuracy: 0.743

{'max_size': 100, 'n_new_labels': 2, 'n_updates': 30, 'lr': 0.005, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.0188584852218627
    Average Accuracy: 0.769

{'max_size': 100, 'n_new_labels': 2, 'n_updates': 30, 'lr': 0.005, 'wd': 0, 'eps': 0.07}
    Average Loss: 1.7361435236930847
    Average Accuracy: 0.718

{'max_size': 100, 'n_new_labels': 2, 'n_updates': 30, 'lr': 0.005, 'wd': 0, 'eps': 0.1}
    Average Loss: 1.2476792407035828
    Average Accuracy: 0.718

{'max_size': 100, 'n_new_labels': 2, 'n_updates': 30, 'lr': 0.01, 'wd': 0, 'eps': 0.05}
    Average Loss: 2.008960277557373
    Average Accuracy: 0.765

{'max_size': 100, 'n_new_labels': 2, 'n_updates': 30, 'lr': 0.01, 'wd': 0, '

{'max_size': 100, 'n_new_labels': 3, 'n_updates': 30, 'lr': 0.01, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.4294540231227875
    Average Accuracy: 0.741

{'max_size': 100, 'n_new_labels': 3, 'n_updates': 30, 'lr': 0.01, 'wd': 0, 'eps': 0.07}
    Average Loss: 1.097004559993744
    Average Accuracy: 0.794

{'max_size': 100, 'n_new_labels': 3, 'n_updates': 30, 'lr': 0.01, 'wd': 0, 'eps': 0.1}
    Average Loss: 1.2301035461425782
    Average Accuracy: 0.765

{'max_size': 100, 'n_new_labels': 3, 'n_updates': 30, 'lr': 0.02, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.6357698297500611
    Average Accuracy: 0.744

{'max_size': 100, 'n_new_labels': 3, 'n_updates': 30, 'lr': 0.02, 'wd': 0, 'eps': 0.07}
    Average Loss: 1.7134089258313179
    Average Accuracy: 0.757

{'max_size': 100, 'n_new_labels': 3, 'n_updates': 30, 'lr': 0.02, 'wd': 0, 'eps': 0.1}
    Average Loss: 1.100827702522278
    Average Accuracy: 0.775

{'max_size': 100, 'n_new_labels': 3, 'n_updates': 30, 'lr': 0.03, 'wd': 0, 'eps'

{'max_size': 100, 'n_new_labels': 5, 'n_updates': 30, 'lr': 0.02, 'wd': 0, 'eps': 0.1}
    Average Loss: 1.2363315143585205
    Average Accuracy: 0.779

{'max_size': 100, 'n_new_labels': 5, 'n_updates': 30, 'lr': 0.03, 'wd': 0, 'eps': 0.05}
    Average Loss: 2.9784038939476014
    Average Accuracy: 0.736

{'max_size': 100, 'n_new_labels': 5, 'n_updates': 30, 'lr': 0.03, 'wd': 0, 'eps': 0.07}
    Average Loss: 3.1566335678100588
    Average Accuracy: 0.649

{'max_size': 100, 'n_new_labels': 5, 'n_updates': 30, 'lr': 0.03, 'wd': 0, 'eps': 0.1}
    Average Loss: 1.9763067915439605
    Average Accuracy: 0.748

{'max_size': 100, 'n_new_labels': 5, 'n_updates': 40, 'lr': 0.005, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.086678268313408
    Average Accuracy: 0.742

{'max_size': 100, 'n_new_labels': 5, 'n_updates': 40, 'lr': 0.005, 'wd': 0, 'eps': 0.07}
    Average Loss: 0.9244445152282715
    Average Accuracy: 0.745

{'max_size': 100, 'n_new_labels': 5, 'n_updates': 40, 'lr': 0.005, 'wd': 0, '

{'max_size': 100, 'n_new_labels': 10, 'n_updates': 40, 'lr': 0.005, 'wd': 0, 'eps': 0.07}
    Average Loss: 0.8879588589668274
    Average Accuracy: 0.749

{'max_size': 100, 'n_new_labels': 10, 'n_updates': 40, 'lr': 0.005, 'wd': 0, 'eps': 0.1}
    Average Loss: 0.9231351819038391
    Average Accuracy: 0.726

{'max_size': 100, 'n_new_labels': 10, 'n_updates': 40, 'lr': 0.01, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.0679701032638549
    Average Accuracy: 0.703

{'max_size': 100, 'n_new_labels': 10, 'n_updates': 40, 'lr': 0.01, 'wd': 0, 'eps': 0.07}
    Average Loss: 0.7244054468870162
    Average Accuracy: 0.806

{'max_size': 100, 'n_new_labels': 10, 'n_updates': 40, 'lr': 0.01, 'wd': 0, 'eps': 0.1}
    Average Loss: 0.6762572232484817
    Average Accuracy: 0.838

{'max_size': 100, 'n_new_labels': 10, 'n_updates': 40, 'lr': 0.02, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.5854909162521362
    Average Accuracy: 0.716

{'max_size': 100, 'n_new_labels': 10, 'n_updates': 40, 'lr': 0.02, 'wd

{'max_size': 150, 'n_new_labels': 2, 'n_updates': 40, 'lr': 0.02, 'wd': 0, 'eps': 0.05}
    Average Loss: 1.6990147406458855
    Average Accuracy: 0.788

{'max_size': 150, 'n_new_labels': 2, 'n_updates': 40, 'lr': 0.02, 'wd': 0, 'eps': 0.07}
    Average Loss: 1.409909659266472
    Average Accuracy: 0.857

{'max_size': 150, 'n_new_labels': 2, 'n_updates': 40, 'lr': 0.02, 'wd': 0, 'eps': 0.1}
    Average Loss: 2.0123518241643907
    Average Accuracy: 0.811

{'max_size': 150, 'n_new_labels': 2, 'n_updates': 40, 'lr': 0.03, 'wd': 0, 'eps': 0.05}
    Average Loss: 3.335244556427002
    Average Accuracy: 0.721

{'max_size': 150, 'n_new_labels': 2, 'n_updates': 40, 'lr': 0.03, 'wd': 0, 'eps': 0.07}
    Average Loss: 1.2559509642124176
    Average Accuracy: 0.855

{'max_size': 150, 'n_new_labels': 2, 'n_updates': 40, 'lr': 0.03, 'wd': 0, 'eps': 0.1}
    Average Loss: 4.51027382183075
    Average Accuracy: 0.743

{'max_size': 150, 'n_new_labels': 3, 'n_updates': 10, 'lr': 0.005, 'wd': 0, 'eps':

{'max_size': 150, 'n_new_labels': 3, 'n_updates': 40, 'lr': 0.03, 'wd': 0, 'eps': 0.1}
    Average Loss: 3.0882794016897677
    Average Accuracy: 0.73

{'max_size': 150, 'n_new_labels': 5, 'n_updates': 10, 'lr': 0.005, 'wd': 0, 'eps': 0.05}
    Average Loss: 0.7720988602638245
    Average Accuracy: 0.773

{'max_size': 150, 'n_new_labels': 5, 'n_updates': 10, 'lr': 0.005, 'wd': 0, 'eps': 0.07}
    Average Loss: 1.047406002998352
    Average Accuracy: 0.743

{'max_size': 150, 'n_new_labels': 5, 'n_updates': 10, 'lr': 0.005, 'wd': 0, 'eps': 0.1}
    Average Loss: 0.7022094553709031
    Average Accuracy: 0.783

{'max_size': 150, 'n_new_labels': 5, 'n_updates': 10, 'lr': 0.01, 'wd': 0, 'eps': 0.05}
    Average Loss: 0.6987470054030418
    Average Accuracy: 0.817

{'max_size': 150, 'n_new_labels': 5, 'n_updates': 10, 'lr': 0.01, 'wd': 0, 'eps': 0.07}
    Average Loss: 0.8359587993621826
    Average Accuracy: 0.784

{'max_size': 150, 'n_new_labels': 5, 'n_updates': 10, 'lr': 0.01, 'wd': 0, 'e

{'max_size': 150, 'n_new_labels': 10, 'n_updates': 10, 'lr': 0.01, 'wd': 0, 'eps': 0.07}
    Average Loss: 0.7000357268452644
    Average Accuracy: 0.796

{'max_size': 150, 'n_new_labels': 10, 'n_updates': 10, 'lr': 0.01, 'wd': 0, 'eps': 0.1}
    Average Loss: 0.6823592698574066
    Average Accuracy: 0.791

{'max_size': 150, 'n_new_labels': 10, 'n_updates': 10, 'lr': 0.02, 'wd': 0, 'eps': 0.05}
    Average Loss: 0.980935721874237
    Average Accuracy: 0.764

{'max_size': 150, 'n_new_labels': 10, 'n_updates': 10, 'lr': 0.02, 'wd': 0, 'eps': 0.07}
    Average Loss: 0.6335765900611877
    Average Accuracy: 0.827

{'max_size': 150, 'n_new_labels': 10, 'n_updates': 10, 'lr': 0.02, 'wd': 0, 'eps': 0.1}
    Average Loss: 1.2554666647911072
    Average Accuracy: 0.747

{'max_size': 150, 'n_new_labels': 10, 'n_updates': 10, 'lr': 0.03, 'wd': 0, 'eps': 0.05}
    Average Loss: 0.6527966257333756
    Average Accuracy: 0.843

{'max_size': 150, 'n_new_labels': 10, 'n_updates': 10, 'lr': 0.03, 'wd': 

---

In [12]:
student = MNISTClassifier().to(device)

In [13]:
max_train_subset_size  = 100
n_new_labels_per_epoch = 5
n_updates_per_epoch    = 40
lr                     = 2e-2
weight_decay           = 0
epsilon                = 0.05

laplace_noise = dists.Laplace(torch.zeros([], dtype=torch.float), torch.tensor(1 / epsilon, dtype=torch.float))

init_subset_size = n_new_labels_per_epoch + (max_train_subset_size % n_new_labels_per_epoch)
n_total_epochs   = max_train_subset_size // n_new_labels_per_epoch

student_unlabeled_dataset    = data.Subset(mnist_testset, list(student_dataset.indices))
student_unlabeled_dataloader = data.DataLoader(student_unlabeled_dataset, batch_size=1024, shuffle=False, drop_last=False)
student_labeled_dataset      = data.Subset(mnist_testset, [])
student_labeled_dataloader   = data.DataLoader(student_labeled_dataset,
                                               batch_sampler=data.BatchSampler(data.SequentialSampler(student_unlabeled_dataset), init_subset_size, True))

student_optimizer            = optim.Adam(student.parameters(), lr=lr, weight_decay=weight_decay)
criterion                    = nn.CrossEntropyLoss()

student.train()

teacher_preds = []
noisy_labels  = []
student_train_history = {'n_labels': {}, 'avg_losses':{}, 'avg_accuracies': {}}
for i_epoch in range(n_total_epochs):
    if i_epoch == 0:
        new_label_indices = random.sample(student_unlabeled_dataset.indices, init_subset_size)

    else:
        max_probs_list = []

        with torch.no_grad():
            for imgs, _ in student_unlabeled_dataloader:
                imgs = imgs.to(device)

                outs  = student(imgs)
                probs = outs.softmax(dim=1)

                max_probs_list.append(probs.max(dim=1)[0].cpu())

        max_probs_tensor = torch.cat(max_probs_list, dim=0)
        
        new_label_indices = [student_unlabeled_dataset.indices[idx] for idx in
                             max_probs_tensor.topk(n_new_labels_per_epoch, largest=False, sorted=False)[1]]

    for idx in new_label_indices:
        student_labeled_dataset.indices.append(idx)
        student_unlabeled_dataset.indices.remove(idx)
        label_pred_counts = aggregate_counts(mnist_testset[idx][0].view(1, 1, 28, 28))
        noisy_label = (label_pred_counts.float() + laplace_noise.sample(label_pred_counts.size())).argmax(dim=0)
        mnist_testset.targets[idx] = noisy_label
        teacher_preds.append(label_pred_counts)
        noisy_labels.append(noisy_label)
        
    student_labeled_dataloader.batch_sampler.batch_size = len(student_labeled_dataset)
    
    print("Epoch {:d}/{:d} - Number of Labeled Data: {:d}".format(i_epoch, n_total_epochs-1, len(student_labeled_dataset)))

    student_train_history['n_labels'][i_epoch] = init_subset_size + (i_epoch * n_new_labels_per_epoch)
    losses = []
    accuracies = []
    for i_update in range(n_updates_per_epoch):
        imgs, labels = next(iter(student_labeled_dataloader))

        imgs   = imgs.to(device)
        labels = labels.to(device)

        outs  = student(imgs)
        preds = torch.argmax(outs, dim=1)

        student_optimizer.zero_grad()
        loss = criterion(outs, labels)
        loss.backward()
        student_optimizer.step()

        accuracy = (preds == labels).float().mean().item()

        losses.append(loss.item())
        accuracies.append(accuracy)

        print("    Update {:d} - Loss = {:.6f}, Accuracy = {:.4}".format(i_update, loss.item(), accuracy))

    student_train_history['avg_losses'][i_epoch]     = losses
    student_train_history['avg_accuracies'][i_epoch] = accuracies

teacher_preds = torch.stack(teacher_preds, dim=1)
noisy_labels  = torch.tensor(noisy_labels)

Epoch 0/19 - Number of Labeled Data: 5
    Update 0 - Loss = 2.220031, Accuracy = 0.0
    Update 1 - Loss = 1.248176, Accuracy = 1.0
    Update 2 - Loss = 0.754863, Accuracy = 1.0
    Update 3 - Loss = 0.422651, Accuracy = 1.0
    Update 4 - Loss = 0.251692, Accuracy = 1.0
    Update 5 - Loss = 0.154266, Accuracy = 1.0
    Update 6 - Loss = 0.085805, Accuracy = 1.0
    Update 7 - Loss = 0.044131, Accuracy = 1.0
    Update 8 - Loss = 0.023445, Accuracy = 1.0
    Update 9 - Loss = 0.013550, Accuracy = 1.0
    Update 10 - Loss = 0.008267, Accuracy = 1.0
    Update 11 - Loss = 0.005395, Accuracy = 1.0
    Update 12 - Loss = 0.003596, Accuracy = 1.0
    Update 13 - Loss = 0.002460, Accuracy = 1.0
    Update 14 - Loss = 0.001749, Accuracy = 1.0
    Update 15 - Loss = 0.001303, Accuracy = 1.0
    Update 16 - Loss = 0.000998, Accuracy = 1.0
    Update 17 - Loss = 0.000784, Accuracy = 1.0
    Update 18 - Loss = 0.000633, Accuracy = 1.0
    Update 19 - Loss = 0.000517, Accuracy = 1.0
    Update 

    Update 7 - Loss = 0.028449, Accuracy = 1.0
    Update 8 - Loss = 0.026527, Accuracy = 1.0
    Update 9 - Loss = 0.021562, Accuracy = 1.0
    Update 10 - Loss = 0.016046, Accuracy = 1.0
    Update 11 - Loss = 0.011540, Accuracy = 1.0
    Update 12 - Loss = 0.008572, Accuracy = 1.0
    Update 13 - Loss = 0.006749, Accuracy = 1.0
    Update 14 - Loss = 0.005537, Accuracy = 1.0
    Update 15 - Loss = 0.004710, Accuracy = 1.0
    Update 16 - Loss = 0.004162, Accuracy = 1.0
    Update 17 - Loss = 0.003767, Accuracy = 1.0
    Update 18 - Loss = 0.003516, Accuracy = 1.0
    Update 19 - Loss = 0.003339, Accuracy = 1.0
    Update 20 - Loss = 0.003224, Accuracy = 1.0
    Update 21 - Loss = 0.003144, Accuracy = 1.0
    Update 22 - Loss = 0.003071, Accuracy = 1.0
    Update 23 - Loss = 0.002996, Accuracy = 1.0
    Update 24 - Loss = 0.002908, Accuracy = 1.0
    Update 25 - Loss = 0.002797, Accuracy = 1.0
    Update 26 - Loss = 0.002668, Accuracy = 1.0
    Update 27 - Loss = 0.002523, Accuracy =

    Update 14 - Loss = 0.001915, Accuracy = 1.0
    Update 15 - Loss = 0.001692, Accuracy = 1.0
    Update 16 - Loss = 0.001565, Accuracy = 1.0
    Update 17 - Loss = 0.001499, Accuracy = 1.0
    Update 18 - Loss = 0.001473, Accuracy = 1.0
    Update 19 - Loss = 0.001486, Accuracy = 1.0
    Update 20 - Loss = 0.001520, Accuracy = 1.0
    Update 21 - Loss = 0.001556, Accuracy = 1.0
    Update 22 - Loss = 0.001568, Accuracy = 1.0
    Update 23 - Loss = 0.001548, Accuracy = 1.0
    Update 24 - Loss = 0.001499, Accuracy = 1.0
    Update 25 - Loss = 0.001425, Accuracy = 1.0
    Update 26 - Loss = 0.001337, Accuracy = 1.0
    Update 27 - Loss = 0.001244, Accuracy = 1.0
    Update 28 - Loss = 0.001153, Accuracy = 1.0
    Update 29 - Loss = 0.001069, Accuracy = 1.0
    Update 30 - Loss = 0.000994, Accuracy = 1.0
    Update 31 - Loss = 0.000926, Accuracy = 1.0
    Update 32 - Loss = 0.000867, Accuracy = 1.0
    Update 33 - Loss = 0.000816, Accuracy = 1.0
    Update 34 - Loss = 0.000772, Accurac

    Update 21 - Loss = 0.002630, Accuracy = 1.0
    Update 22 - Loss = 0.002374, Accuracy = 1.0
    Update 23 - Loss = 0.002140, Accuracy = 1.0
    Update 24 - Loss = 0.001901, Accuracy = 1.0
    Update 25 - Loss = 0.001656, Accuracy = 1.0
    Update 26 - Loss = 0.001436, Accuracy = 1.0
    Update 27 - Loss = 0.001251, Accuracy = 1.0
    Update 28 - Loss = 0.001108, Accuracy = 1.0
    Update 29 - Loss = 0.000995, Accuracy = 1.0
    Update 30 - Loss = 0.000906, Accuracy = 1.0
    Update 31 - Loss = 0.000835, Accuracy = 1.0
    Update 32 - Loss = 0.000779, Accuracy = 1.0
    Update 33 - Loss = 0.000734, Accuracy = 1.0
    Update 34 - Loss = 0.000697, Accuracy = 1.0
    Update 35 - Loss = 0.000666, Accuracy = 1.0
    Update 36 - Loss = 0.000638, Accuracy = 1.0
    Update 37 - Loss = 0.000614, Accuracy = 1.0
    Update 38 - Loss = 0.000592, Accuracy = 1.0
    Update 39 - Loss = 0.000573, Accuracy = 1.0
Epoch 13/19 - Number of Labeled Data: 70
    Update 0 - Loss = 0.171946, Accuracy = 0.92

    Update 28 - Loss = 0.000980, Accuracy = 1.0
    Update 29 - Loss = 0.000910, Accuracy = 1.0
    Update 30 - Loss = 0.000859, Accuracy = 1.0
    Update 31 - Loss = 0.000821, Accuracy = 1.0
    Update 32 - Loss = 0.000795, Accuracy = 1.0
    Update 33 - Loss = 0.000775, Accuracy = 1.0
    Update 34 - Loss = 0.000757, Accuracy = 1.0
    Update 35 - Loss = 0.000738, Accuracy = 1.0
    Update 36 - Loss = 0.000718, Accuracy = 1.0
    Update 37 - Loss = 0.000696, Accuracy = 1.0
    Update 38 - Loss = 0.000672, Accuracy = 1.0
    Update 39 - Loss = 0.000647, Accuracy = 1.0
Epoch 17/19 - Number of Labeled Data: 90
    Update 0 - Loss = 0.164338, Accuracy = 0.9444
    Update 1 - Loss = 0.104738, Accuracy = 0.9889
    Update 2 - Loss = 0.054309, Accuracy = 0.9889
    Update 3 - Loss = 0.011782, Accuracy = 1.0
    Update 4 - Loss = 0.011469, Accuracy = 1.0
    Update 5 - Loss = 0.031673, Accuracy = 1.0
    Update 6 - Loss = 0.027237, Accuracy = 1.0
    Update 7 - Loss = 0.011243, Accuracy = 1.

In [14]:
student.eval()

criterion = nn.CrossEntropyLoss(reduction='sum')

instance_count = 0
total_loss     = 0.
correct_count  = 0.

test_dataloader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

n_batches = len(test_dataloader)
for i, (imgs, labels) in enumerate(test_dataloader):
    print("Batch {:d}/{:d}".format(i, n_batches-1), end='\r')

    instance_count += imgs.size(0)
    
    imgs   = imgs.to(device)
    labels = labels.to(device)
    
    with torch.no_grad():
        outs  = student(imgs)
    
    total_loss += criterion(outs, labels).item()

    preds = outs.argmax(dim=1)
    
    correct_count += (preds == labels).sum().item()

print()
print("Average Loss:", total_loss / instance_count)
print("Average Accuracy:", correct_count / instance_count)

Batch 33/33
Average Loss: 1.3769125180244446
Average Accuracy: 0.815


In [15]:
data_dep_eps, data_indep_eps = pate.perform_analysis(teacher_preds, noisy_labels, epsilon)
print("Data Independent Epsilon:", data_indep_eps)
print("Data Dependent Epsilon:", data_dep_eps)

Data Independent Epsilon: 5.302585092994046
Data Dependent Epsilon: 5.30258509299405


### Predict Labels Using Teacher Models

In [9]:
for model in teachers:
    model.eval()

dataloader = data.DataLoader(student_dataset, batch_size=512, shuffle=False, drop_last=False)

batches_of_preds = []

n_batches = len(dataloader)
_prev_str_len = 0
for i, (imgs, _) in enumerate(dataloader):
    imgs = imgs.to(device)

    batch_of_preds = []
    
    with torch.no_grad():
        for j, model in enumerate(teachers):
            _progress_str = "Batch {:d}/{:d} - Teacher {:d}/{:d}".format(i, n_batches-1, j, n_teachers-1)
            print(_progress_str + ' ' * (_prev_str_len - len(_progress_str)), end='\r')
            _prev_str_len = len(_progress_str)

            outs  = model(imgs)
            preds = outs.argmax(dim=1)
            batch_of_preds.append(preds.cpu())
    
    batches_of_preds.append(batch_of_preds)
        
label_preds = torch.cat(
    [torch.stack([preds for preds in batch_of_preds], dim=0)
     for batch_of_preds in batches_of_preds],
    dim=1)

print()
print(label_preds)

Batch 17/17 - Teacher 249/249
tensor([[7, 0, 1,  ..., 6, 9, 0],
        [7, 0, 1,  ..., 6, 9, 0],
        [7, 2, 1,  ..., 6, 9, 0],
        ...,
        [7, 2, 1,  ..., 6, 9, 0],
        [7, 2, 1,  ..., 6, 9, 0],
        [7, 0, 1,  ..., 6, 9, 0]])


### Get Label Counts for Each Image

In [10]:
label_counts = torch.from_numpy(np.apply_along_axis(lambda x: np.bincount(x, minlength=10), axis=0, arr=label_preds.numpy()))
print(label_counts)

tensor([[  0,   4,   0,  ...,   0,   1, 232],
        [  0,   5, 248,  ...,   0,   0,   0],
        [  0, 213,   1,  ...,   0,   1,   7],
        ...,
        [248,   0,   0,  ...,   0,  26,   0],
        [  0,   7,   0,  ...,   0,   8,   0],
        [  0,   0,   0,  ...,   0, 192,   0]])


### Obtain Labels from Noisy Counts with a Certain $\epsilon$ Value

In [20]:
epsilon = 0.05

noise_dist = dists.Laplace(loc=torch.zeros([], dtype=torch.float),
                           scale=torch.full([], 1 / epsilon, dtype=torch.float))

noisy_counts = label_counts.float() + noise_dist.sample([10, label_counts.size(1)])

generated_labels = noisy_counts.argmax(dim=0)
print(generated_labels)
print()
print("Noisy Accuracy Against Predictions:", (generated_labels == label_counts.argmax(dim=0)).float().mean().item())

tensor([7, 2, 1,  ..., 6, 9, 0])

Noisy Accuracy Against Predictions: 0.9427777528762817


### Perform PATE Analysis to Check Information Leak

In [25]:
pate.perform_analysis(label_preds[:, :100], generated_labels[:100], epsilon, delta=1e-05)

(5.30258509299405, 5.302585092994046)

### Assign the Generated Labels to the DP Training Set

In [None]:
mnist_testset.targets[dp_dataset.indices] = generated_labels

## Train the DP Model

In [None]:
lr       = 3e-3
n_epochs = 10

dp_optimizer = optim.Adam(dp_model.parameters(), lr=lr)
criterion    = nn.CrossEntropyLoss()

dp_model.train()

dp_train_history = {'avg_losses':{}, 'avg_accuracies': {}}
for i_epoch in range(n_epochs):
    instance_count = 0
    total_loss     = 0.
    correct_count  = 0.

    n_batches = len(dp_dataloader)
    _prev_str_len = 0
    for i, (imgs, labels) in enumerate(dp_dataloader):
        _batch_str = "Epoch {:d}/{:d}: ({:d}/{:d})".format(i_epoch, n_epochs-1, i, n_batches-1)
        print(_batch_str + ' ' * (_prev_str_len - len(_batch_str)), end='\r')
        _prev_str_len = len(_batch_str)

        instance_count += imgs.size(0)

        imgs   = imgs.to(device)
        labels = labels.to(device)

        outs  = dp_model(imgs)
        preds = torch.argmax(outs, dim=1)

        dp_optimizer.zero_grad()
        loss = criterion(outs, labels)
        loss.backward()
        dp_optimizer.step()

        total_loss += loss.item() * imgs.size(0)

        correct_count += (preds == labels).sum().item()

    avg_loss = total_loss / instance_count
    avg_accuracy = correct_count / instance_count

    print()
    print("    Avg Loss: {:.6f}".format(avg_loss))
    print("    Avg Accuracy: {:.4f}".format(avg_accuracy))
    print()

    dp_train_history['avg_losses'][i_epoch]     = avg_loss
    dp_train_history['avg_accuracies'][i_epoch] = avg_accuracy

## Evaluate the Result Model on the Test Set

In [None]:
dp_model.eval()

criterion = nn.CrossEntropyLoss(reduction='sum')

instance_count = 0
total_loss     = 0.
correct_count  = 0.
# test_dataloader     = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
n_batches = len(dataloader)
for i, (imgs, labels) in enumerate(data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)):
    print("Batch {:d}/{:d}".format(i, n_batches-1), end='\r')

    instance_count += imgs.size(0)
    
    imgs = imgs.to(device)
    
    with torch.no_grad():
        outs  = model(imgs)
    
    total_loss += criterion(outs, labels).item()

    preds = outs.argmax(dim=1)
    
    correct_count += (preds == labels).sum().item()

print()
print("Average Loss:", total_loss / instance_count)
print("Average Accuracy:", correct_count / instance_count)

1. Prepare Data
    1. Split the training dataset into `n + 1` smaller datasets where `n` is the number of teacher models
    2. Define a Dataset class and a DataLoader that can give batches of data for all `n` teacher datasets
2. Define Model(s)
    1. A simple ConvNet for both the main model and all teacher models
    2. If too slow: custom `nn.Module` that can process `n` batches at once for all `n` teachers
3. Train Teachers
4. Label Unlabeled Training Dataset in a Differentially Private Manner
    1. Generate raw labels
    2. PATE analysis to find a proper `epsilon` value
    3. Add proper noise to the label counts
    4. Take the labels with most counts
5. Train the Main Model On the Training Dataset with Generated Labels
6. Test on the Test Dataset