# Learning private models with multiple teachers

Ressources:

http://www.cleverhans.io/privacy/2018/04/29/privacy-and-machine-learning.html
https://github.com/tensorflow/models/tree/master/research/differential_privacy/multiple_teachers

Protocol:
1. Train teachers:
    - Devide training set into buckets (not overlapping)
    - Train a models (teacher) on each bucket
2. Train student:
    - Extract a share of the test set
    - Ensemble predictions from teachers: queries each teacher for predictions on the test set share
    - Aggregate teacher predictions to get student training labels using noising max: it
  adds Laplacian noise to label counts and returns the most frequent label
    - Train student with the aggregated label
    - Validate the student model on the remaining test data
    


## Train Teachers

In [1]:
import os

path = "/Users/yanndupis/Documents/Datascience/private-ml/models/research/"
os.chdir(path)


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

from differential_privacy.multiple_teachers import deep_cnn
from differential_privacy.multiple_teachers import input
from differential_privacy.multiple_teachers import metrics
from differential_privacy.multiple_teachers import aggregation

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader,Dataset

  from ._conv import register_converters as _register_converters


In [2]:
tf.app.flags.DEFINE_string('f', '', 'kernel')

tf.flags.DEFINE_string('dataset', 'mnist', 'The name of the dataset to use')
tf.flags.DEFINE_integer('nb_labels', 10, 'Number of output classes')

tf.flags.DEFINE_string('data_dir','/tmp','Temporary storage')
tf.flags.DEFINE_string('train_dir','/tmp/train_dir',
                       'Where model ckpt are saved')

tf.flags.DEFINE_integer('max_steps', 3000, 'Number of training steps to run.')
tf.flags.DEFINE_integer('nb_teachers', 100, 'Teachers in the ensemble.')
tf.flags.DEFINE_integer('teacher_id', 0, 'ID of teacher being trained.')


tf.flags.DEFINE_string('teachers_dir','/tmp/train_dir',
                       'Directory where teachers checkpoints are stored.')

tf.flags.DEFINE_integer('stdnt_share', 1000,
                        'Student share (last index) of the test data')
tf.flags.DEFINE_integer('lap_scale', 10,
                        'Scale of the Laplacian noise added for privacy')
tf.flags.DEFINE_boolean('save_labels', False,
                        'Dump numpy arrays of labels and clean teacher votes')
tf.flags.DEFINE_boolean('deeper', False, 'Activate deeper CNN model')


FLAGS = tf.flags.FLAGS

In [3]:
train_data, train_labels, test_data, test_labels = input.ld_mnist()

# Reshape to have channel first
train_data = train_data.reshape(60000,1,28,28)
test_data = test_data.reshape(10000,1,28,28)

In [4]:
print(train_data.shape)
print(train_labels.shape)
print(test_data.shape)

(60000, 1, 28, 28)
(60000,)
(10000, 1, 28, 28)


In [5]:
dataset, nb_teachers, teacher_id = FLAGS.dataset, FLAGS.nb_teachers,  FLAGS.teacher_id

In [6]:
class PrepareData(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [8]:
# class CNN_Model(nn.Module):
#     def __init__(self, num_classes):
#         super(CNN_Model, self).__init__()
#         self.conv1 = nn.Conv2d(1, 64, 5, stride=1)
#         self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
#         # Add local response normalization
#         self.conv2 = nn.Conv2d(64, 128, 5, stride=1)
#         # Add local response normalization
#         self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        
#         self.linear1 = nn.Linear(128*3*3,384)
#         self.relu1 = nn.ReLU()
        
#         self.linear2 = nn.Linear(384,192)
#         self.relu2 = nn.ReLU()
        
#         self.logit = nn.Linear(192, num_classes)
        
#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.maxpool1(x)
#         x = self.conv2(x)
#         x = self.maxpool2(x)
#         x = x.view(-1, 128*3*3)
#         x = self.linear1(x)
#         x = self.relu1(x)
#         x = self.linear2(x)
#         x = self.relu2(x)
#         out = self.logit(x)
#         return x

In [9]:
class CNN_Model(nn.Module):
    def __init__(self, num_classes):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, stride = 1)
        self.batchnorm1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU()
        self.avgpool1 = nn.AvgPool2d(2)
        self.conv2 = nn.Conv2d(16, 16, 5, stride = 1)
        self.batchnorm2 = nn.BatchNorm2d(16)
        self.relu2 = nn.ReLU()
        self.avgpool2 = nn.AvgPool2d(2)
        self.linear1 = nn.Linear(256, 100)
        self.batchnorm3 = nn.BatchNorm1d(100)
        self.relu3 = nn.ReLU()
        self.linear2 = nn.Linear(100, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.relu1(x)
        x = self.avgpool1(x)
        x = self.conv2(x)
        x = self.batchnorm2(x)
        x = self.relu2(x)
        x = self.avgpool2(x)
        x = x.view(-1, 256)
        x = self.linear1(x)
        x = self.batchnorm3(x)
        x = self.relu3(x)
        out = self.linear2(x)
        return out

In [10]:
model = CNN_Model(10)

In [11]:
def train(model, train_loader, test_loader, ckpt_path, filename):
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    for epoch in range(10):
            model.train() # set model to training mode

            # set up training metrics we want to track
            correct = 0
            train_num = len(train_loader.sampler)

            for ix, (img, label) in enumerate(train_loader): # iterate over training batches
                #img, label = img.to(device), label.to(device) # get data, send to gpu if needed
                img = img.type(torch.float32)
                #label = label.type(torch.float32)
                label = label.type(torch.LongTensor)
                optimizer.zero_grad() # clear parameter gradients from previous training update
                output = model(img) # forward pass
                #output = output.type(torch.float32)
                loss = F.cross_entropy(output, label,size_average=False) # calculate network loss
                loss.backward() # backward pass
                optimizer.step() # take an optimization step to update model's parameters

                pred = output.max(1, keepdim=True)[1] # get the index of the max logit
                correct += pred.eq(label.view_as(pred)).sum().item() # add to running total of hits

            # print whole epoch's training accuracy; useful for monitoring overfitting
            print('Train Accuracy: {}/{} ({:.0f}%)'.format(
                correct, train_num, 100. * correct / train_num))

# set up training metrics we want to track
    test_correct = 0
    test_num = len(test_loader.sampler)
    with torch.no_grad():
        for ix, (img, label) in enumerate(test_loader): # iterate over training batches
            #img, label = img.to(device), label.to(device) # get data, send to gpu if needed
            img = img.type(torch.float32)
            #label = label.type(torch.float32)
            label = label.type(torch.LongTensor)
            optimizer.zero_grad() # clear parameter gradients from previous training update
            output = model(img) # forward pass
            #output = output.type(torch.float32)
            loss = F.cross_entropy(output, label,size_average=False) # calculate network loss
            #loss.backward() # backward pass
            #optimizer.step() # take an optimization step to update model's parameters

            pred = output.max(1, keepdim=True)[1] # get the index of the max logit
            test_correct += pred.eq(label.view_as(pred)).sum().item() # add to running total of hits

            # print whole epoch's training accuracy; useful for monitoring overfitting
        print('Test Accuracy: {}/{} ({:.0f}%)'.format(
        correct, test_num, 100. * test_correct / test_num))


    if not os.path.isdir(ckpt_path):
        os.makedirs(ckpt_path)

    torch.save(model.state_dict(), ckpt_path + filename)

In [12]:
def train_teachers(train_data, train_labels, test_data, test_labels, nb_teachers, teacher_id):
    
    data, labels = input.partition_dataset(train_data,
                                         train_labels,
                                         nb_teachers,
                                         teacher_id)
    
    train_prep = PrepareData(data, labels)
    train_loader = DataLoader(train_prep, batch_size=64, shuffle=True)
    
    test_prep = PrepareData(test_data, test_labels)
    test_loader = DataLoader(test_prep, batch_size=64, shuffle=False)
    
    print("\nTrain teacher ID: " + str(teacher_id))
    train(model, train_loader, test_loader, ckpt_path, filename)

In [13]:
ckpt_path = 'differential_privacy/multiple_teachers/' + 'checkpoint/'

for teacher_id in range(nb_teachers):
    
    filename = str(dataset) + '_' + str(nb_teachers) + '_teachers_' + str(teacher_id) + '.pth'

    train_teachers(train_data, train_labels, test_data, test_labels, nb_teachers, teacher_id)


Train teacher ID: 0




Train Accuracy: 412/600 (69%)
Train Accuracy: 548/600 (91%)
Train Accuracy: 576/600 (96%)
Train Accuracy: 590/600 (98%)
Train Accuracy: 598/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Test Accuracy: 600/10000 (94%)

Train teacher ID: 1
Train Accuracy: 542/600 (90%)
Train Accuracy: 578/600 (96%)
Train Accuracy: 590/600 (98%)
Train Accuracy: 590/600 (98%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Test Accuracy: 600/10000 (95%)

Train teacher ID: 2
Train Accuracy: 572/600 (95%)
Train Accuracy: 583/600 (97%)
Train Accuracy: 595/600 (99%)
Train Accuracy: 598/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Ac

Train Accuracy: 600/600 (100%)
Test Accuracy: 600/10000 (98%)

Train teacher ID: 23
Train Accuracy: 583/600 (97%)
Train Accuracy: 598/600 (100%)
Train Accuracy: 597/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Test Accuracy: 600/10000 (98%)

Train teacher ID: 24
Train Accuracy: 582/600 (97%)
Train Accuracy: 592/600 (99%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 599/600 (100%)
Test Accuracy: 599/10000 (98%)

Train teacher ID: 25
Train Accuracy: 586/600 (98%)
Train Accuracy: 597/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Ac

Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Test Accuracy: 600/10000 (98%)

Train teacher ID: 46
Train Accuracy: 588/600 (98%)
Train Accuracy: 596/600 (99%)
Train Accuracy: 597/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 598/600 (100%)
Test Accuracy: 598/10000 (98%)

Train teacher ID: 47
Train Accuracy: 586/600 (98%)
Train Accuracy: 591/600 (98%)
Train Accuracy: 597/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 597/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Test Accuracy: 600/10000 (98%)

Train teacher ID: 48
Train Accuracy: 588/600 (98%)
Train Accuracy: 598/600 (100%)
Train Accuracy: 599/600 (100%)
Train Acc

Train Accuracy: 596/600 (99%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Test Accuracy: 600/10000 (98%)

Train teacher ID: 69
Train Accuracy: 590/600 (98%)
Train Accuracy: 597/600 (100%)
Train Accuracy: 597/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Test Accuracy: 600/10000 (98%)

Train teacher ID: 70
Train Accuracy: 591/600 (98%)
Train Accuracy: 597/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Test Accuracy: 600/10000 (99%)

Train teacher ID: 71
Train A

Train Accuracy: 581/600 (97%)
Train Accuracy: 598/600 (100%)
Train Accuracy: 598/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Test Accuracy: 600/10000 (98%)

Train teacher ID: 92
Train Accuracy: 587/600 (98%)
Train Accuracy: 597/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Test Accuracy: 600/10000 (98%)

Train teacher ID: 93
Train Accuracy: 588/600 (98%)
Train Accuracy: 596/600 (99%)
Train Accuracy: 597/600 (100%)
Train Accuracy: 599/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 598/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)
Train Accuracy: 600/600 (100%)

In [13]:
def softmax_preds(images_loader, ckpt_path, return_logits=False):
    """
    Compute softmax activations (probabilities) with the model saved in the path
    specified as an argument
    :param images: a np array of images
    :param ckpt_path: a TF model checkpoint
    :param logits: if set to True, return logits instead of probabilities
    :return: probabilities (or logits if logits is set to True)
    """
    # Compute nb samples and deduce nb of batches
    data_length = len(images_loader.dataset)
    preds = np.zeros((data_length, FLAGS.nb_labels), dtype=np.float32)
    start = 0
    
    check = torch.load(ckpt_path)
    model.load_state_dict(check)
    model.eval() # set model to evaluate mode
    
    with torch.no_grad():
        for img, label in images_loader:
            output = model(img)
            output_softmax = F.softmax(output).data.numpy()
            
            end = start + len(img)
            
            preds[start:end,:] = output_softmax
            
            start += len(img)

    return preds

## Train Student

In [15]:
def ensemble_preds(dataset, nb_teachers, stdnt_data_loader):
  """
  Given a dataset, a number of teachers, and some input data, this helper
  function queries each teacher for predictions on the data and returns
  all predictions in a single array. (That can then be aggregated into
  one single prediction per input using aggregation.py (cf. function
  prepare_student_data() below)
  :param dataset: string corresponding to mnist, cifar10, or svhn
  :param nb_teachers: number of teachers (in the ensemble) to learn from
  :param stdnt_data: unlabeled student training data
  :return: 3d array (teacher id, sample id, probability per class)
  """

  # Compute shape of array that will hold probabilities produced by each
  # teacher, for each training point, and each output class
  result_shape = (nb_teachers, len(stdnt_data_loader.dataset), FLAGS.nb_labels)

  # Create array that will hold result
  result = np.zeros(result_shape, dtype=np.float32)

  # Get predictions from each teacher
  for teacher_id in xrange(nb_teachers):
    # Compute path of checkpoint file for teacher model with ID teacher_id
    filename = str(dataset) + '_' + str(nb_teachers) + '_teachers_' + str(teacher_id) + '.pth'

    ckpt_path = 'differential_privacy/multiple_teachers/' + 'checkpoint/'

    # Get predictions on our training data and store in result array
    result[teacher_id] = softmax_preds(stdnt_data_loader, ckpt_path + filename)

    # This can take a while when there are a lot of teachers so output status
    print("Computed Teacher " + str(teacher_id) + " softmax predictions")

  return result

In [25]:
def prepare_student_data(dataset, nb_teachers, save=False):
  """
  Takes a dataset name and the size of the teacher ensemble and prepares
  training data for the student model, according to parameters indicated
  in flags above.
  :param dataset: string corresponding to mnist, cifar10, or svhn
  :param nb_teachers: number of teachers (in the ensemble) to learn from
  :param save: if set to True, will dump student training labels predicted by
               the ensemble of teachers (with Laplacian noise) as npy files.
               It also dumps the clean votes for each class (without noise) and
               the labels assigned by teachers
  :return: pairs of (data, labels) to be used for student training and testing
  """
  assert input.create_dir_if_needed(FLAGS.train_dir)

  # Load the dataset
  if dataset == 'svhn':
    test_data, test_labels = input.ld_svhn(test_only=True)
  elif dataset == 'cifar10':
    test_data, test_labels = input.ld_cifar10(test_only=True)
  elif dataset == 'mnist':
    test_data, test_labels = input.ld_mnist(test_only=True)
  else:
    print("Check value of dataset flag")
    return False

  test_data = test_data.reshape(10000,1,28,28)
  
  # Make sure there is data leftover to be used as a test set
  assert FLAGS.stdnt_share < len(test_data)

  # Prepare [unlabeled] student training data (subset of test set)
  stdnt_data = test_data[:FLAGS.stdnt_share]
  stdnt_label = test_labels[:FLAGS.stdnt_share]

  stdnt_prep = PrepareData(stdnt_data, stdnt_label)
    
  stdnt_loader = DataLoader(stdnt_prep, batch_size=64, shuffle=False)

  # Compute teacher predictions for student training data
  teachers_preds = ensemble_preds(dataset, nb_teachers, stdnt_loader)

  # Aggregate teacher predictions to get student training labels
  stdnt_labels = aggregation.noisy_max(teachers_preds, FLAGS.lap_scale)
  

  # Print accuracy of aggregated labels
  ac_ag_labels = metrics.accuracy(stdnt_labels, test_labels[:FLAGS.stdnt_share])
  print("\nAccuracy of the aggregated labels: " + str(ac_ag_labels) + "\n")

  # Store unused part of test set for use as a test set after student training
  stdnt_test_data = test_data[FLAGS.stdnt_share:]
  stdnt_test_labels = test_labels[FLAGS.stdnt_share:]

  return stdnt_data, stdnt_labels, stdnt_test_data, stdnt_test_labels

In [26]:
def train_student(dataset, nb_teachers):
  """
  This function trains a student using predictions made by an ensemble of
  teachers. The student and teacher models are trained using the same
  neural network architecture.
  :param dataset: string corresponding to mnist, cifar10, or svhn
  :param nb_teachers: number of teachers (in the ensemble) to learn from
  :return: True if student training went well
  """
  assert input.create_dir_if_needed(FLAGS.train_dir)

  # Call helper function to prepare student data using teacher predictions
  stdnt_dataset = prepare_student_data(dataset, nb_teachers, save=False)

  # Unpack the student dataset
  stdnt_data, stdnt_labels, stdnt_test_data, stdnt_test_labels = stdnt_dataset
  stdnt_data = stdnt_data.reshape(1000,1,28,28)
  stdnt_test_data = stdnt_test_data.reshape(9000,1,28,28)

  # Prepare checkpoint filename and path
  ckpt_path = 'differential_privacy/multiple_teachers/' + 'checkpoint/'

  filename = str(dataset) + '_' + str(nb_teachers) + '_student.ckpt'
  
  stdnt_prep = PrepareData(stdnt_data, stdnt_labels) 
  stdnt_loader = DataLoader(stdnt_prep, batch_size=64, shuffle=False)

  stdnt_test_prep = PrepareData(stdnt_test_data, stdnt_test_labels) 
  stdnt_test_loader = DataLoader(stdnt_test_prep, batch_size=64, shuffle=False)
  

  # Start student training
  train(model, stdnt_loader, stdnt_test_loader, ckpt_path, filename)

  # Compute final checkpoint name for student (with max number of steps)
  #ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)

  # Compute student label predictions on remaining chunk of test set
  student_preds = softmax_preds(stdnt_test_loader, ckpt_path + filename)

  # Compute teacher accuracy
  precision = metrics.accuracy(student_preds, stdnt_test_labels)
  print('\nPrecision of student after training: ' + str(precision))

  return True

In [27]:
train_student(dataset, nb_teachers)



Computed Teacher 0 softmax predictions
Computed Teacher 1 softmax predictions
Computed Teacher 2 softmax predictions
Computed Teacher 3 softmax predictions
Computed Teacher 4 softmax predictions
Computed Teacher 5 softmax predictions
Computed Teacher 6 softmax predictions
Computed Teacher 7 softmax predictions
Computed Teacher 8 softmax predictions
Computed Teacher 9 softmax predictions
Computed Teacher 10 softmax predictions
Computed Teacher 11 softmax predictions
Computed Teacher 12 softmax predictions
Computed Teacher 13 softmax predictions
Computed Teacher 14 softmax predictions
Computed Teacher 15 softmax predictions
Computed Teacher 16 softmax predictions
Computed Teacher 17 softmax predictions
Computed Teacher 18 softmax predictions
Computed Teacher 19 softmax predictions
Computed Teacher 20 softmax predictions
Computed Teacher 21 softmax predictions
Computed Teacher 22 softmax predictions
Computed Teacher 23 softmax predictions
Computed Teacher 24 softmax predictions
Computed T

True