<a href="https://colab.research.google.com/github/sujitpal/nlp-deeplearning-ai-examples/blob/master/arxiv_1909_01066_lm_as_kb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Impl: Language Models as Knowledge Bases?

An implementation of ideas from [Language Models as Knowledge Bases?](https://arxiv.org/pdf/1909.01066.pdf) (Petroni et al., 2019) using pre-trained models from the HuggingFace `transformers` library.

The paper's method expresses a (subject, relation, object) fact as a simple cloze-style sentence, masks the object, and has the masked language model predict it. The model is credited with knowing the fact when it ranks the correct object highly.

We do not go that far here. Instead, we take one of the datasets provided in the GitHub repository referenced by the paper and run its masked sentences through a masked language model based on `bert-base-uncased`.
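To make the cloze format concrete, here is a minimal sketch of turning a fact into a masked sentence. The `[X]`/`[Y]` placeholder convention and the helper name `make_cloze` are assumptions made for illustration; they are not taken from this notebook's code.

```python
def make_cloze(template, subject, mask_token="[MASK]"):
  """Fill the subject slot of a template and mask the slot to be predicted."""
  # Assumed convention: "[X]" marks the subject, "[Y]" the token to predict.
  return template.replace("[X]", subject).replace("[Y]", mask_token)


sentence = make_cloze("The capital of [X] is [Y].", "France")
print(sentence)
```

The resulting string has the same shape as the `masked_sentences` entries used later in this notebook, so it could be passed to the tokenizer unchanged.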

In [1]:
!pip install transformers

Collecting transformers
  Downloading https://files.pythonhosted.org/packages/3a/83/e74092e7f24a08d751aa59b37a9fc572b2e4af3918cb66f7766c3affb1b4/transformers-3.5.1-py3-none-any.whl (1.3MB)
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
Collecting tokenizers==0.9.3
  Downloading https://files.pythonhosted.org/packages/4c/34/b39eb9994bc3c999270b69c9eea40ecc6f0e97991dba28282b9fd32d44ee/tokenizers-0.9.3-cp36-cp36m-manylinux1_x86_64.whl (2.9MB)
Collecting sentencepiece==0.1.91
  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)

## Model and Tokenizer

The task is to predict masked words with BERT, so we use the [BertForMaskedLM](https://huggingface.co/transformers/model_doc/bert.html#bertformaskedlm) model and the [BertTokenizer](https://huggingface.co/transformers/model_doc/bert.html#berttokenizer), both loaded from the pre-trained `bert-base-uncased` checkpoint.

In [2]:
import json
import pandas as pd
import torch

from transformers import BertTokenizer, BertForMaskedLM

In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Data

The data comes from the download link given in the paper's repository, [facebookresearch/LAMA](https://github.com/facebookresearch/LAMA).

In [4]:
%%bash
(
if [ ! -d "data" ]; then
  echo "downloading and unzipping data..."
  wget -q https://dl.fbaipublicfiles.com/LAMA/data.zip
  unzip -a data.zip
else
  echo "data already available, skipping step..."
fi
)

downloading and unzipping data...
Archive:  data.zip
   creating: data/
   creating: data/Google_RE/
  inflating: data/Google_RE/place_of_birth_test.jsonl  [text]  
  inflating: data/Google_RE/date_of_birth_test.jsonl  [text]  
  inflating: data/Google_RE/place_of_death_test.jsonl  [text]  
   creating: data/Squad/
  inflating: data/Squad/test.jsonl   [text]  
  inflating: data/relations.jsonl    [text]  
   creating: data/ConceptNet/
  inflating: data/ConceptNet/test.jsonl  [text]  
   creating: data/TREx/
  inflating: data/TREx/P740.jsonl    [text]  
  inflating: data/TREx/P108.jsonl    [text]  
  inflating: data/TREx/P190.jsonl    [text]  
  inflating: data/TREx/P27.jsonl     [text]  
  inflating: data/TREx/P1376.jsonl   [text]  
  inflating: data/TREx/P131.jsonl    [text]  
  inflating: data/TREx/P937.jsonl    [text]  
  inflating: data/TREx/P176.jsonl    [text]  
  inflating: data/TREx/P463.jsonl    [text]  
  inflating: data/TREx/P20.jsonl     [text]  
  inflating: data/TREx/P136

In [5]:
num_read = 0
records = []
with open("data/ConceptNet/test.jsonl", mode="r", encoding="utf-8") as fdata:
  for line in fdata:
    line_json = json.loads(line.strip())
    # a record may hold several masked sentences; we use the first one
    masked_sentence = line_json["masked_sentences"][0]
    label = line_json["obj_label"]
    records.append((masked_sentence, label))
    num_read += 1
    # report progress only after the record has actually been read
    if num_read % 10000 == 0:
      print("{:d} records read".format(num_read))

print("{:d} records read, COMPLETE".format(num_read))

10000 records read
20000 records read
29774 records read, COMPLETE


In [6]:
pd.set_option('display.max_columns', None)  
pd.set_option('display.expand_frame_repr', False)
pd.set_option('display.max_colwidth', None)  # None means no truncation of cell text

data_df = pd.DataFrame(records, columns=["masked_sentence", "label"])
data_df.head()

| | masked_sentence | label |
|---|---|---|
| 0 | One of the things you do when you are alive is [MASK]. | think |
| 1 | Something that might happen when you analyse something is [MASK]. | paralysis |
| 2 | Something that might happen while analysing something is [MASK]. | analysis |
| 3 | Something that might happen while analysing something is [MASK]. | education |
| 4 | Something that might happen while analysing something is coming up with a new [MASK]. | idea |


## Functions

We are going to use the pre-trained BERT language model in inference mode only.

The tokenizer splits the input sequence into WordPiece tokens and wraps it in the special `[CLS]` and `[SEP]` tokens.

The output produced by the model has a `logits` component, plus a `loss` component when `labels` are passed in (we pass none, so only `logits` is returned). The `logits` tensor has shape (1, `number_of_tokens`, `vocab_size`), where the leading 1 is the batch dimension for our single input sentence.

We will identify the logits corresponding to the position of our masked token, identify the top 5 vocabulary words predicted for that position, and return the softmax probabilities for each of the top 5 predicted words.
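The softmax and top-k step can be sketched without any deep learning library. The toy vocabulary and logits below are invented for illustration, and the helper names `softmax` and `top_k` are hypothetical; the notebook itself uses the `torch` equivalents on the real BERT logits.

```python
import math


def softmax(logits):
  """Convert raw scores into probabilities that sum to 1."""
  # Subtracting the maximum keeps exp() numerically stable
  # and does not change the resulting probabilities.
  peak = max(logits)
  exps = [math.exp(v - peak) for v in logits]
  total = sum(exps)
  return [e / total for e in exps]


def top_k(probs, vocab, k):
  """Return the k (token, probability) pairs with the highest probability."""
  ranked = sorted(zip(vocab, probs), key=lambda pair: pair[1], reverse=True)
  return ranked[:k]


# Toy vocabulary and logits for a single masked position (illustrative only).
toy_vocab = ["paris", "lyon", "london", "cheese"]
toy_logits = [4.0, 2.0, 1.0, -1.0]

toy_probs = softmax(toy_logits)
for token, prob in top_k(toy_probs, toy_vocab, 2):
  print("{:s} ({:.2f}%)".format(token, 100 * prob))
```

Because softmax preserves the ordering of the logits, taking the top k before or after the softmax selects the same tokens; applying the softmax over the full vocabulary is what makes the reported values proper probabilities.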


In [7]:
model = model.eval()

In [8]:
inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
outputs = model(**inputs)

In [9]:
tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

['[CLS]', 'the', 'capital', 'of', 'france', 'is', '[MASK]', '.', '[SEP]']

In [10]:
outputs

MaskedLMOutput([('logits',
                 tensor([[[ -6.4346,  -6.4063,  -6.4097,  ...,  -5.7691,  -5.6326,  -3.7883],
                          [-14.0119, -14.7241, -14.2120,  ..., -11.6976, -10.7304, -12.7618],
                          [ -9.6561, -10.3125,  -9.7459,  ...,  -8.7782,  -6.6036, -12.6596],
                          ...,
                          [ -3.7861,  -3.8572,  -3.5644,  ...,  -2.5593,  -3.1093,  -4.3820],
                          [-11.6598, -11.4274, -11.9267,  ...,  -9.8772, -10.2103,  -4.7594],
                          [-11.7267, -11.7509, -11.8040,  ..., -10.5943, -10.9407,  -7.5151]]],
                        grad_fn=<AddBackward0>))])

In [11]:
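def get_mask_index(input_ids, tokenizer):
  # input_ids has shape (1, num_tokens); take the single sentence in the batch
  x = input_ids[0]
  # positions whose token id equals the [MASK] token id
  mask_idx = torch.nonzero(x == tokenizer.mask_token_id)
  # .item() assumes the sentence contains exactly one [MASK] token
  return mask_idx.item()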


mask_idx = get_mask_index(inputs.input_ids, tokenizer)
mask_idx

6
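Since `get_mask_index` ends with `.item()`, it only works when a sentence contains exactly one `[MASK]` token. A minimal guard, sketched here in plain Python with a hypothetical helper name `keep_single_mask`, could filter sentences before they reach the model. Whether the dataset contains any sentences with zero or several masks is not checked in this notebook, so treat this as a precaution rather than a required step.

```python
def keep_single_mask(sentences, mask_token="[MASK]"):
  """Keep only the sentences that contain exactly one mask token."""
  return [s for s in sentences if s.count(mask_token) == 1]


candidates = [
  "The capital of France is [MASK].",
  "[MASK] is the capital of [MASK].",  # two masks: would break .item()
  "Paris is the capital of France.",   # no mask: nothing to predict
]
usable = keep_single_mask(candidates)
print(usable)
```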

In [12]:
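def get_top_k_predictions(pred_logits, mask_idx, top_k):
  # softmax over the vocabulary at the masked position (batch index 0)
  probs = torch.nn.functional.softmax(pred_logits[0, mask_idx, :], dim=-1)
  # highest-probability vocabulary ids, in descending order
  top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)
  top_k_pct_weights = [100 * x.item() for x in top_k_weights]
  # note: relies on the notebook-level `tokenizer` defined earlier
  top_k_tokens = tokenizer.convert_ids_to_tokens(top_k_indices.tolist())
  return list(zip(top_k_tokens, top_k_pct_weights))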


get_top_k_predictions(outputs.logits, mask_idx, 5)

[('paris', 41.678863763809204),
 ('lille', 7.141648232936859),
 ('lyon', 6.339278072118759),
 ('marseille', 4.444753006100655),
 ('tours', 3.02971713244915)]

### Masked Word Prediction

We sample 100 masked sentences from the dataset and compute the top 5 predictions for each. The first 20 sampled rows are displayed with their top-5 predictions and probabilities.



In [13]:
def get_predictions(row, tokenizer, model):
  inputs = tokenizer(row.masked_sentence, return_tensors="pt")
  outputs = model(**inputs)
  mask_idx = get_mask_index(inputs.input_ids, tokenizer)
  top_preds = get_top_k_predictions(outputs.logits, mask_idx, 5)
  formatted_top_preds = ", ".join(["{:s} ({:.2f}%)".format(t, w) for t, w in top_preds])
  return formatted_top_preds

data_df = data_df.sample(n=100, axis=0)
data_df["top_preds"] = data_df.apply(lambda row: get_predictions(row, tokenizer, model), axis=1)
data_df.head(20)

| | masked_sentence | label | top_preds |
|---|---|---|---|
| 7740 | Grass can [MASK]. | burn | grow (57.51%), die (1.43%), kill (1.33%), bloom (1.29%), fall (1.27%) |
| 21833 | Kirstin wants [MASK]. | sex | me (24.29%), her (10.95%), him (10.22%), this (6.22%), more (5.73%) |
| 13394 | Deserts are [MASK]. | barren | common (14.29%), rare (8.39%), dry (8.11%), hot (4.34%), cold (3.92%) |
| 4560 | Partygoers are [MASK]. | people | invited (42.35%), welcome (12.66%), encouraged (6.52%), free (5.70%), excluded (2.47%) |
| 29206 | Another way to say "teenagers like to hang out in pizzeria's" is "A pizza parlor is [MASK] with teenagers". | popular | popular (68.71%), packed (10.83%), filled (4.30%), crowded (3.01%), crawling (2.25%) |
| 26333 | Lying is for [MASK] someone. | decieving | trusting (9.96%), hurting (9.84%), protecting (7.64%), ##giving (7.17%), killing (6.64%) |
| 5848 | A mug is a type of [MASK]. | container | mug (92.53%), coffee (2.72%), drink (0.60%), beer (0.39%), beverage (0.37%) |
| 522 | Something you might do while shopping is [MASK]. | selecting | important (10.79%), dangerous (3.51%), done (3.11%), forbidden (2.74%), over (2.65%) |
| 24232 | There are people with fillings in their [MASK]. | teeth | mouths (21.84%), hands (15.40%), pockets (7.52%), mouth (5.10%), faces (3.18%) |
| 22273 | A spy wants to [MASK]. | snoop | spy (29.67%), know (13.37%), die (8.41%), kill (4.92%), escape (2.85%) |
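The paper scores models with precision at k (P@k): the fraction of queries whose gold label appears among the top k predictions. This notebook does not compute it, so the following is only a sketch of how it might be done. The helper name `precision_at_k` is hypothetical, and the four examples are copied from the rows shown above, so the resulting number says nothing about the full dataset.

```python
def precision_at_k(labels, predictions, k):
  """Fraction of examples whose label is among the first k predicted tokens."""
  if not labels:
    raise ValueError("need at least one example")
  hits = sum(1 for label, preds in zip(labels, predictions) if label in preds[:k])
  return hits / len(labels)


# Gold labels and ranked top-5 tokens taken from four of the rows shown above.
labels = ["burn", "barren", "popular", "teeth"]
predictions = [
  ["grow", "die", "kill", "bloom", "fall"],
  ["common", "rare", "dry", "hot", "cold"],
  ["popular", "packed", "filled", "crowded", "crawling"],
  ["mouths", "hands", "pockets", "mouth", "faces"],
]

print("P@1 = {:.2f}".format(precision_at_k(labels, predictions, 1)))
print("P@5 = {:.2f}".format(precision_at_k(labels, predictions, 5)))
```

Exact string matching is strict: a prediction such as `mouths` for the label `teeth` counts as a miss even though it is a reasonable completion, which is worth keeping in mind when reading the table.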
