# Learning private models with multiple teachers

Ressources:

http://www.cleverhans.io/privacy/2018/04/29/privacy-and-machine-learning.html
https://github.com/tensorflow/models/tree/master/research/differential_privacy/multiple_teachers

Protocol:
1. Train teachers:
    - Devide training set into buckets (not overlapping)
    - Train a models (teacher) on each bucket
2. Train student:
    - Extract a share of the test set
    - Ensemble predictions from teachers: queries each teacher for predictions on the test set share
    - Aggregate teacher predictions to get student training labels using noising max: it
  adds Laplacian noise to label counts and returns the most frequent label
    - Train student with the aggregated label
    - Validate the student model on the remaining test data
    


## Train Teachers

In [1]:
import os

path = "/Users/yanndupis/Documents/Datascience/private-ml/models/research/"
os.chdir(path)


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

from differential_privacy.multiple_teachers import deep_cnn
from differential_privacy.multiple_teachers import input
from differential_privacy.multiple_teachers import metrics

from torch.utils.data import DataLoader,Dataset

  from ._conv import register_converters as _register_converters


In [2]:
tf.app.flags.DEFINE_string('f', '', 'kernel')

tf.flags.DEFINE_string('dataset', 'mnist', 'The name of the dataset to use')
tf.flags.DEFINE_integer('nb_labels', 10, 'Number of output classes')

tf.flags.DEFINE_string('data_dir','/tmp','Temporary storage')
tf.flags.DEFINE_string('train_dir','/tmp/train_dir',
                       'Where model ckpt are saved')

tf.flags.DEFINE_integer('max_steps', 3000, 'Number of training steps to run.')
tf.flags.DEFINE_integer('nb_teachers', 10, 'Teachers in the ensemble.')
tf.flags.DEFINE_integer('teacher_id', 0, 'ID of teacher being trained.')


FLAGS = tf.flags.FLAGS

In [3]:
train_data, train_labels, test_data, test_labels = input.ld_mnist()

# Reshape to have channel first
train_data = train_data.reshape(60000,1,28,28)
test_data = test_data.reshape(10000,1,28,28)

In [4]:
print(train_data.shape)
print(train_labels.shape)
print(test_data.shape)

(60000, 1, 28, 28)
(60000,)
(10000, 1, 28, 28)


In [5]:
dataset, nb_teachers, teacher_id = FLAGS.dataset, FLAGS.nb_teachers,  FLAGS.teacher_id

In [6]:
class PrepareData(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [8]:
# class CNN_Model(nn.Module):
#     def __init__(self, num_classes):
#         super(CNN_Model, self).__init__()
#         self.conv1 = nn.Conv2d(1, 64, 5, stride=1)
#         self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
#         # Add local response normalization
#         self.conv2 = nn.Conv2d(64, 128, 5, stride=1)
#         # Add local response normalization
#         self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        
#         self.linear1 = nn.Linear(128*3*3,384)
#         self.relu1 = nn.ReLU()
        
#         self.linear2 = nn.Linear(384,192)
#         self.relu2 = nn.ReLU()
        
#         self.logit = nn.Linear(192, num_classes)
        
#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.maxpool1(x)
#         x = self.conv2(x)
#         x = self.maxpool2(x)
#         x = x.view(-1, 128*3*3)
#         x = self.linear1(x)
#         x = self.relu1(x)
#         x = self.linear2(x)
#         x = self.relu2(x)
#         out = self.logit(x)
#         return x

In [9]:
class CNN_Model(nn.Module):
    def __init__(self, num_classes):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, stride = 1)
        self.batchnorm1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU()
        self.avgpool1 = nn.AvgPool2d(2)
        self.conv2 = nn.Conv2d(16, 16, 5, stride = 1)
        self.batchnorm2 = nn.BatchNorm2d(16)
        self.relu2 = nn.ReLU()
        self.avgpool2 = nn.AvgPool2d(2)
        self.linear1 = nn.Linear(256, 100)
        self.batchnorm3 = nn.BatchNorm1d(100)
        self.relu3 = nn.ReLU()
        self.linear2 = nn.Linear(100, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.relu1(x)
        x = self.avgpool1(x)
        x = self.conv2(x)
        x = self.batchnorm2(x)
        x = self.relu2(x)
        x = self.avgpool2(x)
        x = x.view(-1, 256)
        x = self.linear1(x)
        x = self.batchnorm3(x)
        x = self.relu3(x)
        out = self.linear2(x)
        return out

In [10]:
model = CNN_Model(10)

In [11]:
def train(model, train_loader, ckpt_path, filename):
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    for epoch in range(10):
            model.train() # set model to training mode

            # set up training metrics we want to track
            correct = 0
            train_num = len(train_loader.sampler)

            for ix, (img, label) in enumerate(train_loader): # iterate over training batches
                #img, label = img.to(device), label.to(device) # get data, send to gpu if needed
                img = img.type(torch.float32)
                #label = label.type(torch.float32)
                label = label.type(torch.LongTensor)
                optimizer.zero_grad() # clear parameter gradients from previous training update
                output = model(img) # forward pass
                #output = output.type(torch.float32)
                loss = F.cross_entropy(output, label,size_average=False) # calculate network loss
                loss.backward() # backward pass
                optimizer.step() # take an optimization step to update model's parameters

                pred = output.max(1, keepdim=True)[1] # get the index of the max logit
                correct += pred.eq(label.view_as(pred)).sum().item() # add to running total of hits

            # print whole epoch's training accuracy; useful for monitoring overfitting
            print('Train Accuracy: {}/{} ({:.0f}%)'.format(
                correct, train_num, 100. * correct / train_num))

    if not os.path.isdir(ckpt_path):
        os.makedirs(ckpt_path)

    torch.save(model.state_dict(), ckpt_path + filename)

In [12]:
def train_teachers(train_data, train_labels, nb_teachers, teacher_id):
    
    data, labels = input.partition_dataset(train_data,
                                         train_labels,
                                         nb_teachers,
                                         teacher_id)
    
    train_prep = PrepareData(data, labels)
    
    train_loader = DataLoader(train_prep, batch_size=64, shuffle=True)
    
    print("\nTrain teacher ID: " + str(teacher_id))
    train(model, train_loader, ckpt_path, filename)

In [13]:
ckpt_path = 'differential_privacy/multiple_teachers/' + 'checkpoint/'

for teacher_id in range(nb_teachers):
    
    filename = str(dataset) + '_' + str(nb_teachers) + '_teachers_' + str(teacher_id) + '.pth'

    train_teachers(train_data, train_labels, 100, teacher_id)


Train teacher ID: 0




Train Accuracy: 400/600 (67%)
Train Accuracy: 551/600 (92%)
Train Accuracy: 575/600 (96%)
Train Accuracy: 588/600 (98%)
Train Accuracy: 595/600 (99%)
Train Accuracy: 595/600 (99%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 597/600 (100%)
Train Accuracy: 598/600 (100%)
Train Accuracy: 599/600 (100%)

Train teacher ID: 1
Train Accuracy: 546/600 (91%)
Train Accuracy: 575/600 (96%)
Train Accuracy: 585/600 (98%)
Train Accuracy: 598/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 599/600 (100%)

Train teacher ID: 2
Train Accuracy: 568/600 (95%)
Train Accuracy: 585/600 (98%)
Train Accuracy: 593/600 (99%)
Train Accuracy: 596/600 (99%)
Train Accuracy: 598/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 598/600 (100%)
Train Accuracy: 599/600 (100%)

Train teacher ID: 3
Train Accuracy: 572/

In [39]:
def softmax_preds(images_loader, ckpt_path, return_logits=False):
    """
    Compute softmax activations (probabilities) with the model saved in the path
    specified as an argument
    :param images: a np array of images
    :param ckpt_path: a TF model checkpoint
    :param logits: if set to True, return logits instead of probabilities
    :return: probabilities (or logits if logits is set to True)
    """
    # Compute nb samples and deduce nb of batches
    data_length = len(images_loader.dataset)
    preds = np.zeros((data_length, FLAGS.nb_labels), dtype=np.float32)
    start = 0
    
    check = torch.load(ckpt_path)
    model.load_state_dict(check)
    model.eval() # set model to evaluate mode
    
    with torch.no_grad():
        for img, label in images_loader:
            output = model(img)
            output_softmax = F.softmax(output).data.numpy()
            
            end = start + len(img)
            
            preds[start:end,:] = output_softmax
            
            start += len(img)

    return preds

In [44]:
test_prep = PrepareData(test_data, test_labels)
    
test_loader = DataLoader(test_prep, batch_size=64, shuffle=True)

softmax_preds(test_loader, './differential_privacy/multiple_teachers/checkpoint/mnist_10_teachers_0.pth')



array([[9.96129870e-01, 1.64601408e-06, 2.19379726e-04, ...,
        3.71756760e-05, 3.35881341e-04, 2.21464157e-04],
       [9.99902964e-01, 2.03707623e-06, 8.10717320e-05, ...,
        3.48550316e-06, 1.49741891e-06, 4.96857062e-08],
       [2.34841827e-05, 4.04260441e-04, 2.81177017e-06, ...,
        1.02101685e-05, 8.10734855e-05, 1.25520935e-04],
       ...,
       [2.08887202e-03, 1.78590519e-04, 6.77612145e-03, ...,
        1.73116918e-03, 1.62030421e-02, 9.26935375e-02],
       [6.17830853e-09, 9.99998450e-01, 3.21992445e-07, ...,
        2.24134070e-07, 3.18366517e-07, 5.19417975e-09],
       [1.53743476e-03, 7.17969425e-03, 2.01376592e-04, ...,
        1.10462606e-05, 1.37406483e-03, 1.55625865e-04]], dtype=float32)

## Train Student

In [None]:
def ensemble_preds(dataset, nb_teachers, stdnt_data_loader):
  """
  Given a dataset, a number of teachers, and some input data, this helper
  function queries each teacher for predictions on the data and returns
  all predictions in a single array. (That can then be aggregated into
  one single prediction per input using aggregation.py (cf. function
  prepare_student_data() below)
  :param dataset: string corresponding to mnist, cifar10, or svhn
  :param nb_teachers: number of teachers (in the ensemble) to learn from
  :param stdnt_data: unlabeled student training data
  :return: 3d array (teacher id, sample id, probability per class)
  """

  # Compute shape of array that will hold probabilities produced by each
  # teacher, for each training point, and each output class
  result_shape = (nb_teachers, len(stdnt_data_loader), FLAGS.nb_labels)

  # Create array that will hold result
  result = np.zeros(result_shape, dtype=np.float32)

  # Get predictions from each teacher
  for teacher_id in xrange(nb_teachers):
    # Compute path of checkpoint file for teacher model with ID teacher_id
    filename = str(dataset) + '_' + str(nb_teachers) + '_teachers_' + str(teacher_id) + '.pth'

    # Get predictions on our training data and store in result array
    result[teacher_id] = softmax_preds(stdnt_data_loader, ckpt_path)

    # This can take a while when there are a lot of teachers so output status
    print("Computed Teacher " + str(teacher_id) + " softmax predictions")

  return result