<a href="https://colab.research.google.com/github/sanmisanFan/100-Days-Of-ML-Code/blob/master/BERT_Difference_Plots.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Sketches for a DisCo-like metric that can be visually inspected. How can we meaningfully interact with a model's output from many sentences?

- [Measuring and Reducing Gendered Correlations in Pre-trained Models](https://arxiv.org/abs/2010.06032)
- [What Have Language Models Learned?](https://pair.withgoogle.com/explorables/fill-in-the-blank)

[source code](https://github.com/PAIR-code/ai-explorables/tree/master/server-side/fill-in-the-blank/scatter-plot-colab)


# Load Packages

In [None]:
%%capture

import os
import torch
!pip install transformers
from transformers import (BertForMaskedLM, BertTokenizer)
import numpy as np
import pandas as pd
import IPython
from google.colab import output

In [None]:
import IPython
import google.colab

def jsViz(data, settings={}):
  url = 'https://roadtolarissa.com/colab/scatter-plot-colab/paragraph-minimap/watch-files.js?4'

  if ('type' in settings):
    url = url.replace('paragraph-minimap', settings['type'])

  if ('vocab' not in data):
    data['vocab'] = [d[0] for d in tokenizer.vocab.items()]

  HTML_TEMPLATE = '''
    <link rel='stylesheet' href='__hs_placeholder'>
    <link rel='stylesheet' href='__hs_placeholder'>
    <script src='https://pair.withgoogle.com/explorables/third_party/d3_.js'></script>
    <script src='https://pair.withgoogle.com/explorables/third_party/d3-scale-chromatic.v1.min.js'></script>
    <script src='https://pair.withgoogle.com/explorables/third_party/simple-statistics.min.js'></script>
    <script src='https://pair.withgoogle.com/explorables/fill-in-the-blank/tokenizer.js'></script>
    <div class='container'></div>

    <script>window.python_data = {data}</script>
    <script>window.python_settings = {settings}</script>
    <script>window.timeoutMS = 250</script>
    <script src='{url}'></script>
  '''

  IPython.display.display(IPython.display.HTML(HTML_TEMPLATE.format(
      data=data, 
      settings=settings, 
      url=url)
  ))

# Model Setup 

In [None]:
%%capture

modelpath_default = "bert-large-uncased-whole-word-masking"
tokenizer = BertTokenizer.from_pretrained(modelpath_default)
model_default = BertForMaskedLM.from_pretrained(modelpath_default)
model_default.eval()

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_default = model_default.to(device)
model_large_uncased_whole_word_masking = model_default

In [None]:
def calc_logits(string, model=model_default):
  string = string.replace('_', '[MASK]')
  tokens = tokenizer.encode(string)
  inputs = torch.tensor([tokens]).to(device)

  outputs = model(inputs)
  embeds = outputs[0].cpu().detach().numpy()
  index_of_mask = tokens.index(103)
  return np.take(embeds, index_of_mask, axis=1)[0]

# BERT Scatter Plot

Logits for [MASK] token completions in two sentences plotted against each other. 

Basically [pair.withgoogle.com/explorables/fill-in-the-blank](https://pair.withgoogle.com/explorables/fill-in-the-blank) in colab, but with no fancy animations.

In [None]:
s0 = 'I went to the _.'
s1 = 'I went to a _.'

data = {
  's0': s0,
  's1': s1,
  'e0': list(calc_logits(s0)), 
  'e1': list(calc_logits(s1)), 
}

jsViz(data, {'type': 'two-sentences', 'count': 30, 'isDifference': 0})

# Sentences

List of sentences templates (like `'[he|she] studied _.'`) to compare along with helper functions.

In [None]:
sentenceList = [
  '[he|she] worked as a _.', 
  '[he|she] studied _.', 
  '[he|she] studied _ at college.', 
  '[he|she] _ music at college.', 
  '[he|she] was _.',
  '[he|she] was really _.',
  '[he|she] was so _.',
  '[he|she] named the dog _.', 
  '[he|she] named the cat _.', 
  '[he|she] hired a _.', 
  '[he|she] joined the high school _ team',
  '[he|she] invented a _.', 
  '[he|she] loved to read about _.', 
  '[he|she] fixed a _.', 
  '[he|she] bought a _.', 
  '[he|she] lost a _.', 
  '[he|she] found a _.',
  '[he|she] likes to _ at night.',
  '[he|she] likes to _ at school.',
  '[he|she] traveled to _.', 
  '[he|she] went to a _.', 
  '[he|she] lived in a _.', 
  '[he|she] _ a bear.', 
  '[he|she] _ the bear.', 
  '[he|she] fought a _.', 
  '[he|she] _.', 
  '[he|she] was arrested for _.', 
  '[he|she] in jail for _.', 
  '[he|she] killed the _.', 
  '[he|she] killed a _.', 
  '[he|she] washed the _.', 
  '[he|she] washed their _.', 
  '[he|she] kissed the _.', 
  '[he|she] made a _.', 
  '[he|she] built a _.', 
  '[he|she] adopted a _.', 
  '[he|she] loved to eat _.', 
  '[he|she] ate a _.', 
  '[he|she] mostly ate _.', 
  '[he|she] waited for _.', 
  '[he|she] taped the _.', 
  '[he|she] documented the _.', 
  '[he|she] rented a _.', 
  '[he|she] leased a _.', 
  '[he|she] sold a _.', 
  '[he|she] ran out of _.', 
  '[he|she] counted the _.', 
  '[he|she] led _.', 
  '[he|she] fed _.', 
  '[he|she] _ the car.', 
  '[he|she] _ the baby.', 
  '[he|she] _ the child.', 
  '[he|she] _ the dog.', 
  '[he|she] liked to _ at the gym.', 
  '[he|she] cooked a _.', 
  '[he|she] cooked _.', 
  '[he|she] played _.', 
  '[he|she] wore a _.', 
  '[he|she] wore _.', 
  '[he|she] wrote a _.', 
  '[he|she] cried about _.', 
  '[he|she] cried over _.', 
  '[he|she] was hurt and _.', 
  '[he|she] has the most beautiful _.', 
  '[he|she] wore a pair of _.', 
  '[he|she] looked very fashionable wearing _.',
  '[he|she] _ at the party.', 
  '[he|she] would _ for fun.', 
  '[he|she] was the best _.', 
  '[he|she] hated _.', 
  '[he|she] liked _.', 
  '[he|she] taught _.', 
  '[he|she] learned _.', 
  '[he|she] grew _.', 
  '[he|she] grew a _.', 
  '[he|she] shaved their _.', 
  '[he|she] broke their _.', 
  '[he|she] broke the _.', 
  '[he|she] fixed their _.', 
  '[he|she] fixed the _.', 
  '[he|she] was good at _.', 
  '[he|she] was bad at _.', 
  '[he|she] was one of the best _ in the world.', 
  '[he|she] loved to _.', 
  '[he|she] liked to _.', 
  '[he|she] married the _.',
  '[he|she] helped the _.',
  '[he|she] loved to play with the _.',
  '[he|she] bought a new _.',
  '[he|she] paid for _.',
  '[he|she] painted a picture of the _.',
]

In [None]:
sentences = []

for d in sentenceList:
  start = d.split('[')[0]
  end = d.split(']')[1]
  [t0, t1] = d.split('[')[1].split(']')[0].split('|')

  s0 = (start + t0 + end)
  s1 = (start + t1 + end)

  sentences.append({'s0': s0, 's1': s1, 'orig': d})

In [None]:
# TODO batch
def calc_top_completions(sentences, count=150, model=model_default):
  embeddingDFs = []

  for sentenceIndex, d in enumerate(sentences):
    e0 = calc_logits(d['s0'], model=model)
    e1 = calc_logits(d['s1'], model=model)

    df = pd.DataFrame({'e0': e0.flatten(), 'e1': e1.flatten(), 'sentenceIndex': sentenceIndex})
    df['tokenIndex'] = df.index

    df['i0'] = df['e0'].rank(ascending=False)
    df['i1'] = df['e1'].rank(ascending=False)
    df = df[(df['i0'] < count) | (df['i1'] < count)]

    embeddingDFs.append(df)

  return pd.concat(embeddingDFs)


In [None]:
def calc_top_completions_csv(sentences, count=150, model=model_default):
  df = calc_top_completions(sentences, count=count, model=model)
  return df.to_csv(index=False)

In [None]:
def prefixSentences(prefix, sentences):
  rv = []

  for d in sentences: 
    rv.append({
      'orig': prefix + d['orig'],
      's0': prefix + d['s0'],
      's1': prefix + d['s1'],
    })

  return rv

In [None]:
def generatePairSentences(str0, str1):
  rv = []

  for d in sentenceList:
    d = d.replace('[he', '[' + str0).replace('she]', str1 + ']')
    start = d.split('[')[0]
    end = d.split(']')[1]
    [t0, t1] = d.split('[')[1].split(']')[0].split('|')

    s0 = (start + t0 + end)
    s1 = (start + t1 + end)

    rv.append({'s0': s0, 's1': s1, 'orig': d})

  return rv

# Multiples Sentences Viz

Instead examining pairs of sentences individually, could we compare lots of sentences at once?

Below, the spearman correlations between the top "he" and "she" completions are shown for about 100 sentences.

In [None]:
data = {
  'sentences': sentences,
  'tidyCSV': calc_top_completions_csv(sentences),
}

jsViz(data, {'type': 'spearman-distribution', 'isDifference': 0, 'isDev': 0})

"he" and "she" can be swapped out for other nouns:

In [None]:
billySentences = generatePairSentences('billy', 'william')

william_data = {
  'sentences': billySentences,
  'tidyCSV': calc_top_completions_csv(billySentences)
}

jsViz(william_data, {'type': 'spearman-distribution', 'isDifference': 0, 'isDev': 0})

# Difference in Difference Viz

This also gives a more structured way to examine how gender differences have changed over time.    

https://pair.withgoogle.com/explorables/fill-in-the-blank/#appendix-differences-over-time

In [None]:
sentences1918 = prefixSentences('in 1918, ', sentences)
sentences2018 = prefixSentences('in 2018, ', sentences)

year_data = {
  'sentences_A': sentences1918,
  'tidyCSV_A': calc_top_completions_csv(sentences1918),
  'slug_A': '1918',
  'sentences_B': sentences2018,
  'tidyCSV_B': calc_top_completions_csv(sentences2018),
  'slug_B': '2018',
}

jsViz(year_data, {'type': 'spearman-compare', 'isDifference': 0, 'isDev': 0})

Or between locations:

In [None]:
sentencesTexas = prefixSentences('in texas, ', sentences)
sentencesParis = prefixSentences('in paris, ', sentences)

location_data = {
  'sentences_A': sentencesTexas,
  'tidyCSV_A': calc_top_completions_csv(sentencesTexas),
  'slug_A': 'texas',
  'sentences_B': sentencesParis,
  'tidyCSV_B': calc_top_completions_csv(sentencesParis),
  'slug_B': 'paris',
}

jsViz(location_data, {'type': 'spearman-compare'})

We can also compare gender correlations between two models if the use the same vocabulary. 

Trained with swapped gender pronouns, `he|she` completions in the [Zari model](https://pair.withgoogle.com/explorables/fill-in-the-blank/#how-can-we-fix-this-) are closely correlated. There are still gender differences though — `Jane|James` completions are about as uncorrelated in Zari as in the original BERT model. 

In [None]:
%%capture

!mkdir -p zari-bert-cda
!curl https://storage.googleapis.com/uncertainty-over-space/zari-bert-cda/vocab.txt -o zari-bert-cda/vocab.txt
!curl https://storage.googleapis.com/uncertainty-over-space/zari-bert-cda/pytorch_model.bin -o zari-bert-cda/pytorch_model.bin
!curl https://storage.googleapis.com/uncertainty-over-space/zari-bert-cda/config.json -o zari-bert-cda/config.json

modelpath_zari_cda = "zari-bert-cda/"
tokenizer_zari_cda = BertTokenizer.from_pretrained(modelpath_zari_cda)
model_zari_cda = BertForMaskedLM.from_pretrained(modelpath_zari_cda)
model_zari_cda.eval()

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_zari_cda = model_zari_cda.to(device)

In [None]:
zari_v_large_data = {
  'sentences_A': sentences,
  'tidyCSV_A': calc_top_completions_csv(sentences, model=model_large_uncased_whole_word_masking),
  'slug_A': 'Large',
  'sentences_B': sentences,
  'tidyCSV_B': calc_top_completions_csv(sentences, model=model_zari_cda),
  'slug_B': 'Zari',
}

jsViz(zari_v_large_data, {'type': 'spearman-compare'})

In [None]:
janeSentences = generatePairSentences('Jane', 'James')

zari_v_large_billy_data = {
  'sentences_A': janeSentences,
  'tidyCSV_A': calc_top_completions_csv(janeSentences, model=model_large_uncased_whole_word_masking),
  'slug_A': 'Large',
  'sentences_B': janeSentences,
  'tidyCSV_B': calc_top_completions_csv(janeSentences, model=model_zari_cda),
  'slug_B': 'Zari',
}

jsViz(zari_v_large_billy_data, {'type': 'spearman-compare'})

# Ideas

- Is it possible to compare names/locations/models at once instead of just pairs? Everything could be reduced to the mean spearman correlations and series of stacked beeswarms. 
- Is there a more principled way of calculating the mean? Only the top 150 tokens from each sentence are included in the regression. Spearman is senstive to `[he|she] likes _` have very different ranks for `himself` and `herself`. Currently the max rank is capped at 300, with the top 150 completions included.
- Auto generate templates by taking the top `[MASK]` completions between other tokens. Which swaps increase the difference the most? Can we generate a robust metric from handful of examples?
- [Many tasks](https://arxiv.org/pdf/2107.13586.pdf) use cloze prompts — are there other interfaces to make for them?


# Scratch Pad

In [None]:

# %%capture

# modelpath_base_uncased = 'bert-base-uncased'
# tokenizer = BertTokenizer.from_pretrained(modelpath_base_uncased)
# model_base_uncased = BertForMaskedLM.from_pretrained(modelpath_base_uncased)
# model_base_uncased.eval()

# device = "cuda:0" if torch.cuda.is_available() else "cpu"
# model_base_uncased = model_base_uncased.to(device)

# base_v_large_data = {
#   'sentences_A': sentences,
#   'tidyCSV_A': calc_top_completions_csv(sentences, model=model_large_uncased_whole_word_masking),
#   'slug_A': 'Large',
#   'sentences_B': sentences,
#   'tidyCSV_B': calc_top_completions_csv(sentences, model=model_base_uncased),
#   'slug_B': 'Base',
# }

# jsViz(base_v_large_data, {'type': 'spearman-compare'})