Copyright 2022 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Use Membership Inference and Secret Sharer to Test Word Embedding Models

This notebook shows how to run privacy tests for word2vec models trained with gensim. Models are trained using the procedure from https://arxiv.org/abs/2004.00053; its code (https://github.com/google/embedding-tests) was adapted for this notebook.

We run membership inference as well as secret sharer. Membership inference attempts to identify whether a given document was included in training. Secret sharer adds random "canary" documents into training, and identifies which canary was added.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelabs/word2vec_codelab.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelabs/word2vec_codelab.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

In [None]:
# Install dependencies: upgrade gensim and install TensorFlow Privacy straight
# from its GitHub repository (neither install pins a version).
!pip install gensim --upgrade
!pip install git+https://github.com/tensorflow/privacy

from IPython.display import clear_output
# Wipe the pip output from this cell after both installs have run.
clear_output()

In [None]:
# imports
import smart_open
import random
import gensim.utils
import os
import bz2
import multiprocessing
import logging
import tqdm
import xml
import numpy as np

from gensim.models import Word2Vec
from six import raise_from
from gensim.corpora.wikicorpus import WikiCorpus, init_to_ignore_interrupt, \
  ARTICLE_MIN_WORDS, _process_article, IGNORED_NAMESPACES, get_namespace
from pickle import PicklingError
from xml.etree.cElementTree import iterparse, ParseError

from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures as mia_data_structures
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import plotting as mia_plotting

from tensorflow_privacy.privacy.privacy_tests.secret_sharer.exposures import compute_exposure_interpolation, compute_exposure_extrapolation

In [None]:
# all the functions we need to get data and canary it
# we will use google drive to store data models to be able to reuse them
# you can change this to local directories by changing DATA_DIR and MODEL_DIR
# make sure to copy the data locally, otherwise training will be very slow

# code in this cell originates from https://github.com/google/embedding-tests
# some edits were made to allow saving to google drive, and to add canaries

from google.colab import drive
# Mount Google Drive; DATA_DIR and MODEL_DIR below live under this mount point.
drive.mount('/content/drive/')

# Directories on the local runtime disk (relative to the working directory).
LOCAL_DATA_DIR = 'data_dir'
# NOTE(review): LOCAL_MODEL_DIR is not referenced in the visible cells.
LOCAL_MODEL_DIR = 'model_dir'
# Google Drive directories for the dataset and for the trained models.
DATA_DIR = '/content/drive/MyDrive/w2v/data_dir/'
MODEL_DIR = '/content/drive/MyDrive/w2v/model_dir/'

# Twenty made-up tokens ("oongaboonga", "ooongaboonga", ...) used to build
# canaries; each one has one more leading "o" than the previous.
MADE_UP_WORDS = ["o" * extra_os + "oongaboonga" for extra_os in range(20)]

# deterministic dataset partitioning
def gen_seed(idx, n=10000):
  """Returns the `idx`-th value of a fixed pseudo-random sequence.

  The sequence is the first `n` outputs of a generator seeded with 12345, so
  the same `idx` always maps to the same value.

  A private `random.Random` instance is used so that calling this function
  does not reset the state of the module-level `random` generator; the
  returned values are the same as with `random.seed(12345)`.

  Args:
    idx: Position in the sequence (any valid list index for a list of
      length `n`).
    n: Length of the generated sequence.

  Returns:
    A float in [0, 1).

  Raises:
    IndexError: If `idx` is out of range for a sequence of length `n`.
  """
  rng = random.Random(12345)
  seeds = [rng.random() for _ in range(n)]
  return seeds[idx]


def make_wiki9_dirs(data_dir):
  """Creates the wiki9 directory layout under `data_dir` and returns its paths.

  Args:
    data_dir: Root directory for the dataset.

  Returns:
    A tuple `(wiki9_path, wiki9_dir, wiki9_split_dir)`:
      wiki9_path: `<data_dir>/wiki9/enwik9.bz2` (the file is not created).
      wiki9_dir: `<data_dir>/wiki9/articles` (created if missing).
      wiki9_split_dir: `<data_dir>/wiki9/split` (created if missing).
  """
  wiki9_root = os.path.join(data_dir, 'wiki9')
  wiki9_path = os.path.join(wiki9_root, 'enwik9.bz2')
  wiki9_dir = os.path.join(wiki9_root, 'articles')
  wiki9_split_dir = os.path.join(wiki9_root, 'split')
  for d in (wiki9_dir, wiki9_split_dir):
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs(d, exist_ok=True)
  return wiki9_path, wiki9_dir, wiki9_split_dir


def extract_pages(f, filter_namespaces=False, filter_articles=None):
  """Yields `(title, text, pageid)` for every page element of an XML dump.

  Args:
    f: Source handed unchanged to `iterparse`.
    filter_namespaces: If truthy, a container of namespace strings; a page
      whose `ns` text is not in it is yielded with empty text.
    filter_articles: Optional callable invoked as
      `filter_articles(elem, namespace=..., title=..., text=..., page_tag=...,
      text_path=..., title_path=..., ns_path=..., pageid_path=...)`; a falsy
      result makes the page be yielded with empty text.

  Yields:
    `(title, text, pageid)` tuples, where `text` is `""` for empty or
    filtered pages. If the XML cannot be parsed, a single
    `(None, "", None)` sentinel is yielded and iteration stops.
  """
  # Building this generator does not parse anything; parse errors only surface
  # while it is consumed, so every consumption point below is guarded.
  elems = (elem for _, elem in iterparse(f, events=("end",)))

  # The namespace is taken from the first element that finishes parsing.
  try:
    elem = next(elems)
  except (ParseError, StopIteration):
    # StopIteration must not escape a generator (it would become RuntimeError).
    yield None, "", None
    return

  namespace = get_namespace(elem.tag)
  ns_mapping = {"ns": namespace}
  page_tag = "{%(ns)s}page" % ns_mapping
  text_path = "./{%(ns)s}revision/{%(ns)s}text" % ns_mapping
  title_path = "./{%(ns)s}title" % ns_mapping
  ns_path = "./{%(ns)s}ns" % ns_mapping
  pageid_path = "./{%(ns)s}id" % ns_mapping

  try:
    for elem in elems:
      if elem.tag == page_tag:
        title = elem.find(title_path).text
        text = elem.find(text_path).text

        if filter_namespaces:
          ns = elem.find(ns_path).text
          if ns not in filter_namespaces:
            text = None

        if filter_articles is not None:
          if not filter_articles(
            elem, namespace=namespace, title=title,
            text=text, page_tag=page_tag,
            text_path=text_path, title_path=title_path,
            ns_path=ns_path, pageid_path=pageid_path):
            text = None

        pageid = elem.find(pageid_path).text
        yield title, text or "", pageid  # empty page will yield None

        # Drop the element's content once the page has been handed out.
        elem.clear()
  except ParseError:
    # Malformed or truncated dump: signal it with the sentinel and stop.
    yield None, "", None

class MyWikiCorpus(WikiCorpus):
  """`WikiCorpus` whose `get_texts` yields tokens together with page metadata."""

  def get_texts(self):
    """Iterates over the dump at `self.fname`, one article at a time.

    Pages are read with `extract_pages`, tokenized by `_process_article` in a
    pool of `self.processes` workers, and pruned when they have fewer than
    `self.article_min_tokens` tokens or a title starting with one of
    `IGNORED_NAMESPACES` followed by a colon.

    Yields:
      `(tokens, (pageid, title))` for every article that is kept.

    Raises:
      PicklingError: If `self.filter_articles` cannot be sent to the workers.
    """
    logger = logging.getLogger(__name__)

    articles, articles_all = 0, 0
    positions, positions_all = 0, 0

    tokenization_params = (
      self.tokenizer_func, self.token_min_len, self.token_max_len, self.lower)

    # The `with` block closes the dump when the generator finishes or is
    # closed; previously the BZ2File handle was never closed explicitly.
    with bz2.BZ2File(self.fname) as dump_file:
      texts = ((text, title, pageid, tokenization_params)
               for title, text, pageid in extract_pages(dump_file,
                                                        self.filter_namespaces,
                                                        self.filter_articles))
      print("got texts")
      pool = multiprocessing.Pool(self.processes, init_to_ignore_interrupt)

      try:
        # process the corpus in smaller chunks of docs,
        # because multiprocessing.Pool
        # is dumb and would load the entire input into RAM at once...
        for group in gensim.utils.chunkize(texts,
                                           chunksize=10 * self.processes,
                                           maxsize=1):
          for tokens, title, pageid in pool.imap(_process_article, group):
            articles_all += 1
            positions_all += len(tokens)
            # article redirects and short stubs are pruned here
            if len(tokens) < self.article_min_tokens or \
                any(title.startswith(ignore + ':') for ignore in
                    IGNORED_NAMESPACES):
              continue
            articles += 1
            positions += len(tokens)
            yield (tokens, (pageid, title))

      except KeyboardInterrupt:
        # Logger.warn is a deprecated alias; use Logger.warning. The reported
        # threshold is the one actually applied above.
        logger.warning(
          "user terminated iteration over Wikipedia corpus after %i"
          " documents with %i positions "
          "(total %i articles, %i positions before pruning articles"
          " shorter than %i words)",
          articles, positions, articles_all, positions_all,
          self.article_min_tokens
        )
      except PicklingError as exc:
        raise PicklingError(
          'Can not send filtering function {} to multiprocessing, '
          'make sure the function can be pickled.'.format(
            self.filter_articles)) from exc
      else:
        logger.info(
          "finished iterating over Wikipedia corpus of %i "
          "documents with %i positions "
          "(total %i articles, %i positions before pruning articles"
          " shorter than %i words)",
          articles, positions, articles_all, positions_all,
          self.article_min_tokens
        )
        self.length = articles  # cache corpus length
      finally:
        pool.terminate()


def write_wiki9_articles(data_dir):
  """Writes every article of the wiki9 dump to its own file.

  Articles come from `MyWikiCorpus.get_texts()`; each one is stored in the
  articles directory under `data_dir`, in a file named after its page id and
  containing its tokens joined by single spaces (UTF-8). Entries without a
  title and articles whose file already exists are skipped.

  Args:
    data_dir: Root data directory, as accepted by `make_wiki9_dirs`.
  """
  wiki9_path, wiki9_dir, wiki9_split_dir = make_wiki9_dirs(data_dir)
  corpus = MyWikiCorpus(wiki9_path, dictionary={}, filter_namespaces=False)

  num_seen = 0
  for tokens, (page_id, page_title) in tqdm.tqdm(corpus.get_texts()):
    num_seen += 1
    if page_title is not None:
      target = os.path.join(wiki9_dir, page_id)
      if not os.path.exists(target):
        with open(target, 'wb') as out:
          out.write(' '.join(tokens).encode("utf-8"))
  print("done", num_seen)

def split_wiki9_articles(data_dir, exp_id=0):
  """Splits the stored articles into two halves, determined by `exp_id`.

  Note: the listing is sorted before shuffling, so splits computed before this
  change (from an unsorted listing) may differ from the ones returned now.

  Args:
    data_dir: Root data directory, as accepted by `make_wiki9_dirs`.
    exp_id: Experiment id; selects the shuffle seed through `gen_seed`.

  Returns:
    `(train_docs, test_docs)`: two lists of article file names; the first has
    `len(all_docs) // 2` entries and the second has the rest.
  """
  wiki9_path, wiki9_dir, wiki9_split_dir = make_wiki9_dirs(data_dir)
  # os.listdir returns entries in arbitrary order; sort first so that the
  # seeded shuffle yields the same split on every machine and filesystem.
  all_docs = sorted(os.listdir(wiki9_dir))
  print("wiki9 len", len(all_docs))
  print(wiki9_dir)
  s = gen_seed(exp_id)
  random.seed(s)
  random.shuffle(all_docs)
  # Re-seed without an argument so later users of `random` are not left with
  # the fixed experiment seed.
  random.seed()

  n = len(all_docs) // 2
  return all_docs[:n], all_docs[n:]


def read_wiki9_train_split(data_dir, exp_id=0):
  """Returns the path of a single file holding the training half for `exp_id`.

  The file is built on first use by concatenating the training articles,
  each followed by a space, and is reused on later calls.

  Args:
    data_dir: Root data directory, as accepted by `make_wiki9_dirs`.
    exp_id: Experiment id passed on to `split_wiki9_articles`.

  Returns:
    Path of `split<exp_id>.train` inside the split directory.
  """
  wiki9_path, wiki9_dir, wiki9_split_dir = make_wiki9_dirs(data_dir)

  split_path = os.path.join(wiki9_split_dir, 'split{}.train'.format(exp_id))
  if not os.path.exists(split_path):
    # data_dir is a required argument of split_wiki9_articles; omitting it
    # raised TypeError.
    train_docs, _ = split_wiki9_articles(data_dir, exp_id=exp_id)
    # Write to a temporary name and rename at the end, so an interrupted run
    # never leaves a truncated file that later calls would reuse.
    tmp_path = split_path + '.tmp'
    # Articles are stored as UTF-8, so read and write them as UTF-8 instead of
    # relying on the platform default encoding.
    with open(tmp_path, 'w', encoding='utf-8') as f:
      for doc in tqdm.tqdm(train_docs):
        with open(os.path.join(wiki9_dir, doc), 'r', encoding='utf-8') as fd:
          f.write(fd.read())
        f.write(' ')
    os.replace(tmp_path, split_path)

  return split_path

def build_vocab(word2vec_model):
  """Returns the model's vocabulary ordered by ascending "count" attribute.

  Args:
    word2vec_model: Object whose `wv` exposes `index_to_key` and
      `get_vecattr(word, "count")`.

  Returns:
    List of the words in `wv.index_to_key`, reordered by `np.argsort` of
    their counts.
  """
  keyed_vectors = word2vec_model.wv
  words = keyed_vectors.index_to_key
  order = np.argsort([keyed_vectors.get_vecattr(w, "count") for w in words])
  return [words[position] for position in order]

def sample_words(vocab, count, rng):
  """Draws `count` distinct entries of `vocab` using `rng`.

  Args:
    vocab: Indexable sequence of words.
    count: Number of words to draw (without replacement).
    rng: Generator providing `choice(n, size, replace=...)`.

  Returns:
    `(words, rng)`: the drawn words and the same `rng` object.
  """
  positions = rng.choice(len(vocab), count, replace=False)
  words = [vocab[position] for position in positions]
  return words, rng


def gen_canaries(num_canaries, canary_repeat, vocab_model_path, seed=0):
  """Creates canaries: four sampled vocabulary words around a made-up word.

  Canary `i` is `[w1, w2, made_up, w3, w4]`, where `made_up` is
  `MADE_UP_WORDS[i % len(MADE_UP_WORDS)]` and `w1..w4` are drawn without
  replacement from the vocabulary of the model saved at `vocab_model_path`.

  Args:
    num_canaries: Number of distinct canaries to generate.
    canary_repeat: Number of times the whole canary list is repeated.
    vocab_model_path: Path of a saved Word2Vec model supplying the vocabulary.
    seed: Seed for the `np.random.RandomState` used for sampling.

  Returns:
    List of `num_canaries * canary_repeat` canaries (lists of 5 tokens).
  """
  existing_w2v = Word2Vec.load(vocab_model_path)
  existing_vocab = build_vocab(existing_w2v)
  # Set for the membership check below; testing against the list made every
  # canary scan the whole vocabulary.
  vocab_lookup = set(existing_vocab)
  rng = np.random.RandomState(seed)

  all_canaries = []
  for i in range(num_canaries):
    new_word = MADE_UP_WORDS[i % len(MADE_UP_WORDS)]
    assert new_word not in vocab_lookup
    canary_words, rng = sample_words(existing_vocab, 4, rng)
    canary = canary_words[:2] + [new_word] + canary_words[2:]
    all_canaries.append(canary)
  return all_canaries * canary_repeat

# iterator for training documents, with an option to canary
class WIKI9Articles:
  """Iterable over training sentences, optionally mixed with canaries."""

  def __init__(self, docs, data_dir, verbose=0, ssharer=False, num_canaries=0,
               canary_repeat=0, canary_seed=0, vocab_model_path=None):
    """Prepares the list of items to iterate over.

    Args:
      docs: Article file names (inside the articles directory of `data_dir`).
      data_dir: Root data directory, as accepted by `make_wiki9_dirs`.
      verbose: If truthy, iteration shows a tqdm progress bar.
      ssharer: If true, canaries from `gen_canaries` are added and the
        combined list is shuffled with a fixed seed.
      num_canaries: Passed to `gen_canaries`.
      canary_repeat: Passed to `gen_canaries`.
      canary_seed: Passed to `gen_canaries` as its seed.
      vocab_model_path: Passed to `gen_canaries`.
    """
    # Each entry is (is_canary, payload): a file name for articles, a token
    # list for canaries.
    self.docs = [(0, doc) for doc in docs]
    if ssharer:
      all_canaries = gen_canaries(
          num_canaries, canary_repeat, vocab_model_path, canary_seed)
      self.docs.extend([(1, canary) for canary in all_canaries])
      np.random.RandomState(0).shuffle(self.docs)

    wiki9_path, wiki9_dir, wiki9_split_dir = make_wiki9_dirs(data_dir)
    self.dirname = wiki9_dir
    self.verbose = verbose

  def __iter__(self):
    """Yields one token list per article line and one per canary."""
    entries = tqdm.tqdm(self.docs) if self.verbose else self.docs
    for is_canary, payload in entries:
      if is_canary:
        yield payload
        continue
      # Close each article file as soon as it has been read instead of
      # leaving the handle open until garbage collection.
      with smart_open.open(os.path.join(self.dirname, payload),
                           'r', encoding='utf-8') as article_file:
        for line in article_file:
          yield line.split()


def train_word_embedding(data_dir, model_dir, exp_id=0, use_secret_sharer=False,
                         num_canaries=0, canary_repeat=1, canary_seed=0,
                         vocab_model_path=None):
  """Trains a Word2Vec model on the training half of split `exp_id` and saves it.

  Args:
    data_dir: Root data directory, as accepted by `make_wiki9_dirs`.
    model_dir: Directory the model is saved into (created if missing).
    exp_id: Experiment id selecting the train/test split.
    use_secret_sharer: If true, canaries are added to the training data and
      the canary settings become part of the saved file name.
    num_canaries: Forwarded to `WIKI9Articles`.
    canary_repeat: Forwarded to `WIKI9Articles`.
    canary_seed: Forwarded to `WIKI9Articles`.
    vocab_model_path: Forwarded to `WIKI9Articles`.

  Returns:
    `(model_path, train_docs, test_docs)`.
  """
  logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',
                      level=logging.INFO)

  # Keyword arguments handed to the Word2Vec constructor.
  w2v_kwargs = dict(
      sg=1,
      negative=25,
      alpha=0.05,
      sample=1e-4,
      workers=48,
      epochs=5,
      window=5,
  )

  train_docs, test_docs = split_wiki9_articles(data_dir, exp_id)
  print(len(train_docs), len(test_docs))
  corpus = WIKI9Articles(
      train_docs, data_dir,
      ssharer=use_secret_sharer,
      num_canaries=num_canaries,
      canary_repeat=canary_repeat,
      canary_seed=canary_seed,
      vocab_model_path=vocab_model_path)

  if not os.path.exists(model_dir):
    os.makedirs(model_dir)

  trained = Word2Vec(corpus, **w2v_kwargs)

  if use_secret_sharer:
    file_name = 'wiki9_w2v_{}_{}_{}_{}.model'.format(
        exp_id, num_canaries, canary_repeat, canary_seed)
  else:
    file_name = 'wiki9_w2v_{}.model'.format(exp_id)
  model_path = os.path.join(model_dir, file_name)

  trained.save(model_path)
  return model_path, train_docs, test_docs

In [None]:
# Set up directories: create the wiki9 layout under both the Google Drive data
# directory and the local data directory, keeping the returned
# (dump path, articles dir, split dir) triple for each.
wiki9_path, wiki9_dir, wiki9_split_dir = make_wiki9_dirs(DATA_DIR)
local_wiki9_path, local_wiki9_dir, local_wiki9_splitdir = make_wiki9_dirs(LOCAL_DATA_DIR)

In [None]:
# Download and format documents.
# Fetch the enwik9 archive and unpack it into the working directory.
!wget http://mattmahoney.net/dc/enwik9.zip
!unzip enwik9.zip
# Compress the unpacked file to enwik9.bz2.
!bzip2 enwik9
# Copy the dump to the Google Drive path, then from there to the local path.
!cp enwik9.bz2 $wiki9_path
!cp $wiki9_path $local_wiki9_path
# Write one file per article under the local data directory.
write_wiki9_articles(LOCAL_DATA_DIR)  # need local data for fast training

# Membership Inference Attacks

Let's start by running membership inference on a word2vec model.

We'll start by training a bunch of word2vec models with different train/test splits. This can take a long time, so be patient!

In [None]:
# Train one model per experiment id, skipping ids whose model file already
# exists in MODEL_DIR.
# NOTE(review): this trains ids 0-9, while the loss-collection and attack cells
# further down load ids 1000-1019 — confirm which range is intended.
for i in range(10):
  if os.path.exists(os.path.join(MODEL_DIR, f"wiki9_w2v_{i}.model")):
    print("done", i)
    continue
  model_path, train_docs, test_docs = train_word_embedding(LOCAL_DATA_DIR, MODEL_DIR, exp_id=i)
  print(model_path)

We now define our loss function. We follow https://arxiv.org/abs/2004.00053, computing the loss of a document as the average loss over all 5-token windows in the document.

In [None]:
from re import split

def loss(model, window):
  """Computes the loss of a single window of 5 tokens.

  The loss is the Euclidean distance between the embedding of the middle
  token (`window[2]`) and the average of the other four embeddings.

  Args:
    model: Object whose `wv` maps a token to its embedding vector and raises
      `KeyError` for tokens it does not know.
    window: Sequence of 5 tokens.

  Returns:
    The distance, or `np.nan` if any token of the window is unknown.
  """
  try:
    embeddings = np.array([model.wv[word] for word in window])
  except KeyError:
    # Only a missing token means "no loss for this window"; any other error
    # is a real failure and must not be hidden by a bare `except`.
    return np.nan
  middle_embedding = embeddings[2]
  # Sum of all five minus the middle leaves the four context embeddings.
  context_embedding = 0.25 * (embeddings.sum(axis=0) - middle_embedding)
  return np.linalg.norm(middle_embedding - context_embedding)

def loss_per_article(model, article):
  """Computes the loss of a full document.

  The document loss is the mean, over every window of 5 consecutive tokens
  whose tokens are all known to the model, of the distance between the middle
  token's embedding and the average of the four surrounding embeddings (the
  same quantity `loss` computes for one window).

  Previously the context was `np.mean` over a list of vectors without an
  axis, which collapsed it to one scalar, and unknown tokens were stored as a
  float NaN next to vectors, which newer NumPy rejects as a ragged array.
  Losses produced by the earlier code are therefore not comparable.

  Args:
    model: Object whose `wv` supports `word in wv` and `wv[word]`.
    article: Document text with tokens separated by single spaces.

  Returns:
    The mean window loss, or `np.nan` if the document has no window made
    entirely of known tokens (including documents shorter than 5 tokens).
  """
  tokens = article.split(' ')
  # None marks a token the model has no embedding for.
  embs = [model.wv[word] if word in model.wv else None for word in tokens]

  losses = []
  for i in range(len(tokens) - 4):
    window = embs[i:i + 5]
    if any(emb is None for emb in window):
      continue  # windows with unknown tokens do not contribute
    middle_embedding = window[2]
    context_embedding = 0.25 * (window[0] + window[1] + window[3] + window[4])
    losses.append(np.linalg.norm(middle_embedding - context_embedding))

  if not losses:
    return np.nan
  return np.mean(losses)

Let's now get the losses of all models on all documents. This also takes a while, so we'll only get a subset.

In [None]:
# Load every available model with id 1000-1019.
# NOTE(review): later cells index columns as `seed - 1000`, which assumes none
# of these model files is missing — confirm before relying on the columns.
all_models = []
for i in range(1000, 1020):
  model_path = os.path.join(MODEL_DIR, f"wiki9_w2v_{i}.model")
  if not os.path.exists(model_path):
    continue
  all_models.append(Word2Vec.load(model_path))

# Any split contains every document, so this recovers the full document list;
# sorting fixes the row order of `all_losses`.
train_docs, test_docs = split_wiki9_articles(LOCAL_DATA_DIR, 0)
all_docs = sorted(train_docs + test_docs)
all_losses = np.zeros((len(all_docs), len(all_models)))

# Only the first 1001 documents are scored (rows beyond that stay zero).
# Slicing stops the loop there instead of stepping through the whole corpus.
scored_docs = all_docs[:1001]
for i, doc in enumerate(tqdm.tqdm(scored_docs)):
  # Articles were written as UTF-8, so read them back as UTF-8.
  with open(os.path.join(local_wiki9_dir, doc), 'r', encoding='utf-8') as fd:
    doc_text = fd.read()
  for j, model in enumerate(all_models):
    all_losses[i,j] = loss_per_article(model, doc_text)

We're going to run the LiRA attack, so for each document we collect its losses under the models that were trained on it ("in") and its losses under the models that were not ("out").

In [None]:
# Keep only the first 500 rows (documents) of the loss matrix.
all_losses = all_losses[:500, :]
# Map each document file name to its row index in `all_docs`.
doc_lookup = {doc: i for i, doc in enumerate(all_docs)}

def compute_scores_in_out(losses, seeds, seed_offset=1000):
  """Groups each document's losses by whether the model was trained on it.

  For every seed, the split for that seed decides whether a document was a
  training ("in") or held-out ("out") document of the corresponding model,
  and the document's loss under that model is filed accordingly.

  Args:
    losses: 2-D array of losses, one row per document (row indices as given
      by the global `doc_lookup`) and one column per model.
    seeds: Experiment ids of the models to use.
    seed_offset: Value subtracted from a seed to obtain its column in
      `losses`.

  Returns:
    `(in_scores, out_scores)`: two lists with one array per row of `losses`,
    holding that document's losses as a column of single-element rows.
  """
  num_rows = losses.shape[0]
  in_scores = [[] for _ in range(num_rows)]
  out_scores = [[] for _ in range(num_rows)]
  for seed in seeds:
    column = seed - seed_offset
    train_docs, test_docs = split_wiki9_articles(LOCAL_DATA_DIR, seed)
    for docs, bucket in ((train_docs, in_scores), (test_docs, out_scores)):
      for doc in docs:
        ind = doc_lookup[doc]
        # Use the `losses` argument (not the global `all_losses`) so the
        # function works on whatever matrix it is given.
        if ind >= num_rows:
          continue
        bucket[ind].append([losses[ind, column]])
  in_scores = [np.array(s) for s in in_scores]
  out_scores = [np.array(s) for s in out_scores]
  print(in_scores[0].shape)
  return in_scores, out_scores
# We will do MI on model 0 (seed 1000), so the in/out score lists are built
# from the remaining seeds 1001-1019 only.
in_scores, out_scores = compute_scores_in_out(all_losses, list(range(1001, 1020)))

Now let's run the global threshold membership inference attack. It gets an advantage of around 0.07.

In [None]:
# Global-threshold MIA: compare the losses in column 0 of `all_losses` for the
# training and held-out documents of split 1000.
train_docs, test_docs = split_wiki9_articles(LOCAL_DATA_DIR, 1000)
num_scored = all_losses.shape[0]
train_losses = [all_losses[doc_lookup[name], 0]
                for name in train_docs if doc_lookup[name] < num_scored]
test_losses = [all_losses[doc_lookup[name], 0]
               for name in test_docs if doc_lookup[name] < num_scored]

# NaN losses are replaced by 0 and all losses are negated before the attack.
baseline_input = mia_data_structures.AttackInputData(
    loss_train=-np.nan_to_num(train_losses),
    loss_test=-np.nan_to_num(test_losses))
attacks_result_baseline = mia.run_attacks(baseline_input).single_attack_results[0]
print('Global Threshold MIA attack:',
      f'auc = {attacks_result_baseline.get_auc():.4f}',
      f'adv = {attacks_result_baseline.get_attacker_advantage():.4f}')

And now we run LiRA. First we need to compute LiRA scores.

In [None]:
# run LiRA
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import advanced_mia as amia
# LiRA needs at least one "in" and one "out" loss per document; keep only the
# rows for which both exist.
good_inds = []
for i, (in_s, out_s) in enumerate(zip(in_scores, out_scores)):
  if len(in_s) > 0 and len(out_s) > 0:
    good_inds.append(i)

for i in good_inds:
  assert len(in_scores[i]) > 0
  # This check previously repeated the in_scores assertion.
  assert len(out_scores[i]) > 0

scores = amia.compute_score_lira(all_losses[good_inds, 0],
                                 [in_scores[i] for i in good_inds],
                                 [out_scores[i] for i in good_inds],
                                 fix_variance=True)

# Mark which of the kept documents are training documents of split 1000.
train_docs, test_docs = split_wiki9_articles(LOCAL_DATA_DIR, 1000)
# Row index -> position in `good_inds`; replaces the repeated `in` / `.index`
# scans of the list. Rows outside `all_losses` are never in `good_inds`, so
# the lookup also covers the former range check.
position_of = {ind: pos for pos, ind in enumerate(good_inds)}
in_mask = np.zeros(len(good_inds), dtype=bool)
for doc in train_docs:
  pos = position_of.get(doc_lookup[doc])
  if pos is not None:
    in_mask[pos] = True


And now we threshold on the LiRA scores, as before. The advantage goes from 0.07 to 0.13 — it almost doubled!

In [None]:
# Run the threshold attack on the LiRA scores: scores of the documents marked
# in `in_mask` are passed as training losses, the remaining scores as test
# losses. Note that this reuses (overwrites) the name `attacks_result_baseline`.
attacks_result_baseline = mia.run_attacks(
    mia_data_structures.AttackInputData(
          loss_train = scores[in_mask],
          loss_test = scores[~in_mask])).single_attack_results[0]
print('Advanced MIA attack with Gaussian:',
          f'auc = {attacks_result_baseline.get_auc():.4f}',
          f'adv = {attacks_result_baseline.get_attacker_advantage():.4f}')

# Secret Sharer

Here, we're going to run a secret sharer attack on a word2vec model. Our canaries (generated above in gen_canaries) look like the following:

"word1 word2 made_up_word word3 word4",

where all the words except for the made up word are real words from the vocabulary. The model's decision on where to put the made up word in embedding space will depend solely on the canary, which will make this an effective attack. We insert canaries with various repetition counts, and train some models:

In [None]:
# The vocabulary for sampling canary words comes from a previously trained
# (non-canaried) model.
vocab_model_path = os.path.join(MODEL_DIR, 'wiki9_w2v_1.model')
interp_exposures = {}
extrap_exposures = {}
# 10000 candidate canaries; the first 20 are treated below as the inserted
# secrets, the rest as references.
all_canaries = gen_canaries(10000, 1, vocab_model_path, 0)

for repeat_count in [5, 10, 20]:
  model_path = os.path.join(MODEL_DIR, 'wiki9_w2v_0_20_{}_0.model'.format(repeat_count))
  already_trained = os.path.exists(model_path)
  print(already_trained)
  # Training is slow: reuse the saved model when it exists, as the membership
  # inference training loop does. Previously the existence check was only
  # printed and the model was always retrained.
  if not already_trained:
    model_path, _, _ = train_word_embedding(
        LOCAL_DATA_DIR, MODEL_DIR, exp_id=0, use_secret_sharer=True, num_canaries=20,
        canary_repeat=repeat_count, canary_seed=0, vocab_model_path=vocab_model_path)
  canaried_model = Word2Vec.load(model_path)
  canary_losses = [loss(canaried_model, canary) for canary in all_canaries]
  loss_secrets = np.array(canary_losses[:20])
  loss_ref = np.array(canary_losses[20:])
  # Drop canaries whose loss is NaN (a token was missing from the model).
  loss_secrets = {1: loss_secrets[~np.isnan(loss_secrets)]}
  loss_ref = loss_ref[~np.isnan(loss_ref)]
  exposure_interpolation = compute_exposure_interpolation(loss_secrets, loss_ref)
  exposure_extrapolation = compute_exposure_extrapolation(loss_secrets, loss_ref)
  interp_exposures[repeat_count] = exposure_interpolation[1]
  extrap_exposures[repeat_count] = exposure_extrapolation[1]


And now let's look at the secret sharer results! Exposure is quite high!

In [None]:
# For each repetition count, report the median exposure under both estimators.
for key in interp_exposures:
  print(f"Repeats: {key}, Interpolation Exposure: {np.median(interp_exposures[key])}, Extrapolation Exposure: {np.median(extrap_exposures[key])}")

Repeats: 5, Interpolation Exposure: 12.307770031890703, Extrapolation Exposure: 54.51861034822009
Repeats: 10, Interpolation Exposure: 12.290018846932618, Extrapolation Exposure: 56.91255812786129
Repeats: 20, Interpolation Exposure: 12.290018846932618, Extrapolation Exposure: 64.00837536957133 
