# Learning private models with multiple teachers

Protocol:
1. Train teachers:
    - Devide training set into buckets (not overlapping)
    - Train a models (teacher) on each bucket
2. Train student:
    - Extract a share of the test set
    - Ensemble predictions from teachers: queries each teacher for predictions on the test set share
    - Aggregate teacher predictions to get student training labels using noising max: it
  adds Laplacian noise to label counts and returns the most frequent label
    - Train student with the aggregated label
    - Validate the student model on the remaining test data

http://www.cleverhans.io/privacy/2018/04/29/privacy-and-machine-learning.html
https://github.com/tensorflow/models/tree/master/research/differential_privacy/multiple_teachers

In [1]:
import os
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader



def prepare_mnist():
    kwargs = {"num_workers": 1}

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "./data",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ),
        batch_size=60000,
        shuffle=True,
        **kwargs,
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "./data",
            train=False,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ),
        batch_size=10000,
        shuffle=False,
        **kwargs,
    )

    train_data, train_labels = next(iter(train_loader))
    test_data, test_labels = next(iter(test_loader))

    return train_data, train_labels, test_data, test_labels

In [2]:
# For this demo, we use MNIST dataset
train_data, train_labels, test_data, test_labels = prepare_mnist()

In [3]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader

print(torch.__version__)
# Training settings
parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 8)')
parser.add_argument('--test-batch-size', type=int, default=8, metavar='N',
                    help='input batch size for testing (default: 8)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001)')
parser.add_argument('--momentum', type=float, default=0.0, metavar='M',
                    help='SGD momentum (default: 0.0)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args([])

torch.manual_seed(args.seed)
kwargs = {}

0.3.1


In [4]:
dataset = "mnist" 
nb_labels = 10
nb_teachers = 100 
stdnt_share = 1000
lap_scale = 10

In [5]:
import syft as sy
from syft import Variable as Var
from syft import nn
from syft import optim

from syft.dp.pate import partition_dataset

In [6]:
hook = sy.TorchHook(verbose=False)
me = hook.local_worker

teacher_nodes = []
for i in range(nb_teachers):
    teacher_nodes.append(sy.VirtualWorker(id=str(i), hook=hook))

In [7]:
def distribute_training_data_across_teachers(train_data, train_labels, nb_teachers, teacher_nodes):
    train_distributed_dataset_all_teachers = {}

    for i in range(len(teacher_nodes[:11])):
        train_distributed_dataset = []
        teacher_id = int(teacher_nodes[i].id)
        data, labels = partition_dataset(train_data, train_labels, nb_teachers, teacher_id)
        train = TensorDataset(data, labels)
        train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)
        for batch_idx, (data,labels) in enumerate(train_loader):
            data = Variable(data)
            labels = Variable(labels.type(torch.LongTensor))
            data.send(teacher_nodes[teacher_id])
            labels.send(teacher_nodes[teacher_id])
            train_distributed_dataset.append((data, labels))

        train_distributed_dataset_all_teachers[teacher_id]= train_distributed_dataset
        
    return train_distributed_dataset_all_teachers
    

In [8]:
train_distributed_dataset_all_teachers = distribute_training_data_across_teachers(train_data, train_labels, nb_teachers, teacher_nodes)

test = TensorDataset(test_data, test_labels)
test_loader = DataLoader(test, batch_size=args.batch_size, shuffle=True, **kwargs)

In [9]:
class CNN_Model(nn.Module):
    def __init__(self, num_classes):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv2d(1,16,5,stride=1)
        self.relu1 = nn.ReLU()
        self.avgpool1 = nn.AvgPool2d(2)
        self.linear1 = nn.Linear(2304, 100)
        self.linear2 = nn.Linear(100, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.avgpool1(x)
        x = x.view(-1, 2304)
        x = self.linear1(x)
        out = self.linear2(x)
        return out


In [10]:
def train_teachers(
    model,
    train_data,
    train_labels,
    test_data,
    test_labels,
    nb_teachers,
    teacher_id,
    filename,
):
    data, labels = partition_dataset(train_data, train_labels, nb_teachers, teacher_id)

    train_prep = PrepareData(data, labels)
    train_loader = DataLoader(train_prep, batch_size=64, shuffle=True)

    test_prep = PrepareData(test_data, test_labels)
    test_loader = DataLoader(test_prep, batch_size=64, shuffle=False)

    print("\nTrain teacher ID: " + str(teacher_id))

    train(model, train_loader, test_loader, ckpt_path, filename)

In [11]:
def train_teacher(model, train_distributed_dataset, test_loader, ckpt_path, filename):
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    for epoch in range(args.epochs):
        model.train()

        train_num = 0
        correct = 0
        for batch_idx, (data,target) in enumerate(train_distributed_dataset):

            worker = data.location
            model.send(worker)

            optimizer.zero_grad()
            # update the model
            output = model(data)
            loss = F.cross_entropy(output, target, size_average=False)
            loss.backward()
            model.get()
            optimizer.step()

            pred_label = output.max(1, keepdim=True)[1]  # get the index of the max logit
            pred_label.get()
            target.get()

            train_num += len(target)
            correct += int(
                    pred_label.eq(target.view_as(pred_label)).sum()
                 )  # add to running total of hits
            target.send(worker)


        print("Train Accuracy: {}/{} ({:.0f}%)".format(
                correct, int(train_num), 100.0 * float(correct / train_num)))

    
    # set up training metrics we want to track
    test_correct = 0
    test_num = len(test_loader.sampler)

    for ix, (img, label) in enumerate(test_loader):  # iterate over training batches
        img = Var(img.float())
        label = Var(label.type(torch.LongTensor))
        optimizer.zero_grad()  # clear parameter gradients from previous training update
        output = model(img)  # forward pass
        # output = output.type(torch.float32)
        loss = F.cross_entropy(
            output, label, size_average=False
        )  # calculate network loss

        pred = output.max(1, keepdim=True)[1]  # get the index of the max logit
        test_correct += int(
            pred.eq(label.view_as(pred)).sum()
        )  # add to running total of hits

        # print whole epoch's training accuracy; useful for monitoring overfitting
    print(
        "Test Accuracy: {}/{} ({:.0f}%)".format(
            test_correct, test_num, 100.0 * test_correct / test_num
        )
    )

    if not os.path.isdir(ckpt_path):
        os.makedirs(ckpt_path)
        
    #torch.save(model.state_dict(), ckpt_path + filename)

In [12]:
args.epochs = 2
nb_teachers = 10
#for i in range(nb_teachers):
    
for i in range(len(teacher_nodes[:nb_teachers])):
    teacher_id = int(teacher_nodes[i].id)
    
    ckpt_path = 'checkpoint/'
    
    filename = str(dataset) + '_' + str(nb_teachers) + '_teachers_' + str(teacher_id) + '.pth'

    train_distributed_dataset = train_distributed_dataset_all_teachers[teacher_id]
    
    model = CNN_Model(10)
    
    print("\nTrain teacher ID: " + str(teacher_id))
    
    train_teacher(model, train_distributed_dataset, test_loader, ckpt_path, filename)
    


Train teacher ID: 0
Train Accuracy: 333/600 (56%)
Train Accuracy: 517/600 (86%)
Test Accuracy: 8544/10000 (85%)

Train teacher ID: 1
Train Accuracy: 283/600 (47%)
Train Accuracy: 519/600 (86%)
Test Accuracy: 8686/10000 (87%)

Train teacher ID: 2
Train Accuracy: 316/600 (53%)
Train Accuracy: 503/600 (84%)
Test Accuracy: 8688/10000 (87%)

Train teacher ID: 3
Train Accuracy: 287/600 (48%)
Train Accuracy: 505/600 (84%)
Test Accuracy: 8181/10000 (82%)

Train teacher ID: 4
Train Accuracy: 338/600 (56%)
Train Accuracy: 532/600 (89%)
Test Accuracy: 8694/10000 (87%)

Train teacher ID: 5
Train Accuracy: 339/600 (56%)
Train Accuracy: 538/600 (90%)
Test Accuracy: 8833/10000 (88%)

Train teacher ID: 6
Train Accuracy: 324/600 (54%)
Train Accuracy: 510/600 (85%)
Test Accuracy: 8672/10000 (87%)

Train teacher ID: 7
Train Accuracy: 290/600 (48%)
Train Accuracy: 514/600 (86%)
Test Accuracy: 8450/10000 (84%)

Train teacher ID: 8
Train Accuracy: 333/600 (56%)
Train Accuracy: 531/600 (88%)
Test Accuracy: 

In [13]:
# ckpt_path = 'checkpoint/'
    
# filename = str("mnist") + '_' + str(10) + '_teachers_' + str(1) + '.pth'

In [14]:
# model = CNN_Model(10)
# torch.save(model.state_dict(), ckpt_path + filename)

In [15]:
len(teacher_nodes[:nb_teachers])

10