# Learning private models with multiple teachers

Protocol:
1. Train teachers:
    - Devide training set into buckets (not overlapping)
    - Train a models (teacher) on each bucket
2. Train student:
    - Extract a share of the test set
    - Ensemble predictions from teachers: queries each teacher for predictions on the test set share
    - Aggregate teacher predictions to get student training labels using noising max: it
  adds Laplacian noise to label counts and returns the most frequent label
    - Train student with the aggregated label
    - Validate the student model on the remaining test data

http://www.cleverhans.io/privacy/2018/04/29/privacy-and-machine-learning.html
https://github.com/tensorflow/models/tree/master/research/differential_privacy/multiple_teachers

In [1]:
from syft.dp.pate import train_teachers, train_student
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F

def prepare_mnist():
    kwargs = {"num_workers": 1}

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "./data",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ),
        batch_size=60000,
        shuffle=True,
        **kwargs,
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "./data",
            train=False,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ),
        batch_size=10000,
        shuffle=False,
        **kwargs,
    )

    train_data, train_labels = next(iter(train_loader))
    test_data, test_labels = next(iter(test_loader))

    return train_data, train_labels, test_data, test_labels

In [2]:
# For this demo, we use MNIST dataset
train_data, train_labels, test_data, test_labels = prepare_mnist()

In [3]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader

print(torch.__version__)
# Training settings
parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 8)')
parser.add_argument('--test-batch-size', type=int, default=8, metavar='N',
                    help='input batch size for testing (default: 8)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001)')
parser.add_argument('--momentum', type=float, default=0.0, metavar='M',
                    help='SGD momentum (default: 0.0)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args([])

torch.manual_seed(args.seed)
kwargs = {}

0.3.1


In [4]:
dataset = "mnist" 
nb_labels = 10
nb_teachers = 100 
stdnt_share = 1000
lap_scale = 10

In [5]:
import syft as sy
from syft import Variable as Var
from syft import nn
from syft import optim

In [6]:
# hook = sy.TorchHook(verbose=False)
# me = hook.local_worker
# bob = sy.VirtualWorker(id="bob",hook=hook, is_client_worker=False)
# alice = sy.VirtualWorker(id="alice",hook=hook, is_client_worker=False)
# me.is_client_worker = False

# compute_nodes = [bob, alice]

# bob.add_workers([alice])
# alice.add_workers([bob])

In [7]:
hook = sy.TorchHook(verbose=False)
me = hook.local_worker

compute_nodes = []
for i in range(nb_teachers):
    compute_nodes.append(sy.VirtualWorker(id=str(i), hook=hook))
    
for i in range(len(compute_nodes)):
#    compute_nodes[i].add_workers([compute_nodes[i+1]])
    me.add_worker(compute_nodes[i])





In [8]:
#compute_nodes[0].add_workers(compute_nodes[1:])

In [9]:
# from syft.dp.pate import partition_dataset

# train_distributed_dataset = []

# for i in range(len(compute_nodes)):
#     worker_id = int(compute_nodes[i].id)
#     data, labels = partition_dataset(train_data, train_labels, nb_teachers, worker_id)
#     data = Variable(data)
#     labels = Variable(labels.type(torch.LongTensor))
#     #print(len(labels))
#     data.send(compute_nodes[worker_id])
#     labels.send(compute_nodes[worker_id])
#     train_distributed_dataset.append((data, labels))

In [10]:
test = TensorDataset(test_data, test_labels)
test_loader = DataLoader(test, batch_size=args.batch_size, shuffle=True, **kwargs)

In [11]:
from syft.dp.pate import partition_dataset

train_distributed_dataset = []
train_distributed_dataset_all_teachers = []

for i in range(len(compute_nodes[:2])):
#for i in range(1):
    teacher_id = int(compute_nodes[i].id)
    data, labels = partition_dataset(train_data, train_labels, nb_teachers, teacher_id)
    train = TensorDataset(data, labels)
    train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)
    for batch_idx, (data,labels) in enumerate(train_loader):
        print(len(labels))
        data = Variable(data)
        labels = Variable(labels.type(torch.LongTensor))
        data.send(compute_nodes[teacher_id])
        labels.send(compute_nodes[teacher_id])
        train_distributed_dataset.append((data, labels))
        
    train_distributed_dataset_all_teachers.append(train_distributed_dataset)
    

32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
24
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
24


In [12]:
train_distributed_dataset[0]

(Variable containing:FloatTensor[_PointerTensor - id:21006087818 owner:me loc:0 id@loc:55719567081],
 Variable containing:LongTensor[_PointerTensor - id:15444007465 owner:me loc:0 id@loc:41056433571])

In [13]:
class CNN_Model(nn.Module):
    def __init__(self, num_classes):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv2d(1,16,5,stride=1)
        self.relu1 = nn.ReLU()
        self.avgpool1 = nn.AvgPool2d(2)
        self.linear1 = nn.Linear(2304, 100)
        self.linear2 = nn.Linear(100, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.avgpool1(x)
        x = x.view(-1, 2304)
        x = self.linear1(x)
        out = self.linear2(x)
        return out


In [14]:
#yo = model(Variable(train_data))

In [15]:
#yo.size()

In [16]:
def train(epoch, train_distributed_dataset):
    model.train()
    
    train_num = 0
    correct = 0
    for batch_idx, (data,target) in enumerate(train_distributed_dataset):
            
        worker = data.location
        model.send(worker)

        optimizer.zero_grad()
        # update the model
        output = model(data)
        loss = F.cross_entropy(output, target, size_average=False)
        loss.backward()
        model.get()
        optimizer.step()
        
        pred_label = output.max(1, keepdim=True)[1]  # get the index of the max logit
        pred_label.get()
        target.get()
        
        train_num += len(target)
        correct += int(
                pred_label.eq(target.view_as(pred_label)).sum()
             )  # add to running total of hits
        target.send(worker)
        

    print("Train Accuracy: {}/{} ({:.0f}%)".format(
            correct, int(train_num), 100.0 * float(correct / train_num)))


In [17]:
args.epochs = 10
nb_teachers = 2
for i in range(nb_teachers):

    train_distributed_dataset = train_distributed_dataset_all_teachers[i]
    
    model = CNN_Model(10)
    optimizer = optim.SGD(model.parameters(), lr=args.lr)

    for epoch in range(1, args.epochs + 1):
        train(epoch, train_distributed_dataset)
        
    
    test_correct = 0
    test_num = len(test_loader.sampler)

    for ix, (img, label) in enumerate(test_loader):  # iterate over training batches
            # img, label = img.to(device), label.to(device) # get data, send to gpu if needed
            img = Var(img.float())
            # label = label.type(torch.float32)
            label = Var(label.type(torch.LongTensor))
            optimizer.zero_grad()  # clear parameter gradients from previous training update
            output = model(img)  # forward pass
            # output = output.type(torch.float32)
            loss = F.cross_entropy(
                output, label, size_average=False
            )  # calculate network loss

            pred = output.max(1, keepdim=True)[1]  # get the index of the max logit
            test_correct += int(
                pred.eq(label.view_as(pred)).sum()
            )  # add to running total of hits

            # print whole epoch's training accuracy; useful for monitoring overfitting
    print("Test Accuracy: {}/{} ({:.0f}%)".format(
                test_correct, test_num, 100.0 * test_correct / test_num
            )
        )

Train Accuracy: 757/1200 (63%)
Train Accuracy: 996/1200 (83%)
Train Accuracy: 1051/1200 (88%)
Train Accuracy: 1081/1200 (90%)
Train Accuracy: 1093/1200 (91%)
Train Accuracy: 1104/1200 (92%)
Train Accuracy: 1115/1200 (93%)
Train Accuracy: 1121/1200 (93%)
Train Accuracy: 1131/1200 (94%)
Train Accuracy: 1144/1200 (95%)
Test Accuracy: 8950/10000 (90%)
Train Accuracy: 725/1200 (60%)
Train Accuracy: 993/1200 (83%)
Train Accuracy: 1044/1200 (87%)
Train Accuracy: 1074/1200 (90%)
Train Accuracy: 1088/1200 (91%)
Train Accuracy: 1102/1200 (92%)
Train Accuracy: 1111/1200 (93%)
Train Accuracy: 1120/1200 (93%)
Train Accuracy: 1130/1200 (94%)
Train Accuracy: 1138/1200 (95%)
Test Accuracy: 8941/10000 (89%)


In [18]:
test_correct = 0
test_num = len(test_loader.sampler)

for ix, (img, label) in enumerate(test_loader):  # iterate over training batches
        # img, label = img.to(device), label.to(device) # get data, send to gpu if needed
        img = Var(img.float())
        # label = label.type(torch.float32)
        label = Var(label.type(torch.LongTensor))
        optimizer.zero_grad()  # clear parameter gradients from previous training update
        output = model(img)  # forward pass
        # output = output.type(torch.float32)
        loss = F.cross_entropy(
            output, label, size_average=False
        )  # calculate network loss

        pred = output.max(1, keepdim=True)[1]  # get the index of the max logit
        test_correct += int(
            pred.eq(label.view_as(pred)).sum()
        )  # add to running total of hits

        # print whole epoch's training accuracy; useful for monitoring overfitting
print("Test Accuracy: {}/{} ({:.0f}%)".format(
            test_correct, test_num, 100.0 * test_correct / test_num
        )
    )

Test Accuracy: 8941/10000 (89%)


In [19]:
len(train_distributed_dataset_all_teachers)

2