In [1]:
import os
import random
import shutil
import sys
import zipfile
from contextlib import contextmanager

import numpy as np
import pandas as pd
import torch
import transformers
import wget
from tqdm import tqdm
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert, BertTokenizer, BertModel, pipeline, BertForMaskedLM
from transformers.utils import logging

# Restrict transformers' logging output to error-level messages.
transformers.logging.set_verbosity_error()

#### Utils

In [2]:
@contextmanager
def nullcontext(enter_result=None):
    """Do-nothing context manager.

    Args:
        enter_result: Value bound by ``with nullcontext(value) as name``.
            Defaults to ``None``.

    Yields:
        ``enter_result`` unchanged; nothing is set up or torn down.
    """
    passthrough = enter_result
    yield passthrough

class HiddenPrints:
    """Context manager that silences ``print`` by redirecting ``sys.stdout`` to ``os.devnull``.

    On entry the current ``sys.stdout`` is remembered and replaced by a
    write handle on ``os.devnull``; on exit the remembered stream is put
    back and the devnull handle is closed. Exceptions raised inside the
    ``with`` body are not suppressed.
    """

    def __enter__(self):
        # Remember the active stream so it can be restored on exit.
        self._original_stdout = sys.stdout
        # Keep our own reference to the sink: code inside the block may
        # rebind sys.stdout, and we must only ever close the handle we opened.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first so sys.stdout is valid again even if close() fails.
        sys.stdout = self._original_stdout
        self._devnull.close()

def initRandomSeeds(SEED=1):
    """Seed the random number generators used in this notebook.

    Args:
        SEED: Integer seed passed to Python's ``random``, NumPy's global
            generator, and PyTorch (default generator plus all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(SEED)


# Seed once, right after the helper is defined.
initRandomSeeds(SEED=42)

In [8]:
# Base URL of the public MultiBERTs model archives.
URL = "https://storage.googleapis.com/multiberts/public/models/"
# Local directory the checkpoints are downloaded/extracted into and later loaded from.
OUTDIR = "../multiberts/"
# Gate for the download-and-convert cell: it only runs when this is True.
PREPARE_MULTIBERT = False

# Prepare MultiBERTs

In [4]:
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    """Convert a TensorFlow BERT checkpoint into a saved PyTorch state dict.

    Follows the Hugging Face conversion recipe:
    https://huggingface.co/docs/transformers/main/converting_tensorflow_models
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py

    Args:
        tf_checkpoint_path: Path (prefix) of the TensorFlow checkpoint to read.
        bert_config_file: Path to the JSON file describing the BERT architecture.
        pytorch_dump_path: Destination file for the ``torch.save``-d state dict.
    """
    # Build an untrained PyTorch model with the architecture from the config.
    bert_config = BertConfig.from_json_file(bert_config_file)
    pretraining_model = BertForPreTraining(bert_config)

    # Copy the TensorFlow weights into the PyTorch model.
    load_tf_weights_in_bert(pretraining_model, bert_config, tf_checkpoint_path)

    # Write only the weights (state dict), not the whole module object.
    print(f"Save PyTorch model to {pytorch_dump_path}")
    weights = pretraining_model.state_dict()
    torch.save(weights, pytorch_dump_path)

- Download Checkpoints
- Unzip to directories
- Copy `vocab` and `bert_config.json` into created directories
- Convert `tensorflow` weights to a `pytorch`-compatible version
- Repeat for `bert-base-uncased` hosted by Google

In [5]:
def _prepare_checkpoint(name, zip_url):
    """Download, unpack and convert one TensorFlow BERT archive under ``OUTDIR``.

    Steps: fetch ``zip_url`` to ``OUTDIR/<name>.zip``, extract it into
    ``OUTDIR``, copy the shared vocab/config into ``OUTDIR/<name>``, and write
    ``OUTDIR/<name>/pytorch_model.bin`` via ``convert_tf_checkpoint_to_pytorch``.

    Args:
        name: Archive stem; also the directory name expected after extraction.
        zip_url: Full URL of the ``.zip`` archive to download.
    """
    archive_path = os.path.join(OUTDIR, f"{name}.zip")
    model_dir = os.path.join(OUTDIR, name)

    # Download (skipped when the archive is already present, so re-running
    # the cell does not fetch it again).
    if not os.path.exists(archive_path):
        wget.download(url=zip_url, out=archive_path)

    # Unzip
    with zipfile.ZipFile(archive_path, 'r') as zip_ref:
        zip_ref.extractall(OUTDIR)

    # Copy Vocab and Config Files
    # NOTE(review): OUTDIR/vocab.txt and OUTDIR/bert_config.json must exist at
    # this point; nothing in this cell downloads them explicitly — confirm they
    # are placed in OUTDIR beforehand.
    shutil.copyfile(os.path.join(OUTDIR, "vocab.txt"), os.path.join(model_dir, "vocab.txt"))
    shutil.copyfile(os.path.join(OUTDIR, "bert_config.json"), os.path.join(model_dir, "config.json"))

    # Convert tf BERT weights to PyTorch for huggingface
    # NOTE(review): the "bert.ckpt" checkpoint prefix is used for every
    # archive, including the Google one — confirm the extracted files use it.
    convert_tf_checkpoint_to_pytorch(
        tf_checkpoint_path=os.path.join(model_dir, "bert.ckpt"),
        bert_config_file=os.path.join(model_dir, "config.json"),
        pytorch_dump_path=os.path.join(model_dir, "pytorch_model.bin"),
    )


if PREPARE_MULTIBERT:
    # The download target directory has to exist before wget writes into it.
    os.makedirs(OUTDIR, exist_ok=True)

    # The 25 MultiBERTs seeds: seed_0 ... seed_24.
    for i in tqdm(range(25)):
        name = f"seed_{i}"
        print(name)
        _prepare_checkpoint(name, f"{URL}{name}.zip")

    # The original bert-base-uncased release hosted by Google.
    print("bert-base-uncased")
    name = "uncased_L-12_H-768_A-12"
    _prepare_checkpoint(
        name,
        f"https://storage.googleapis.com/bert_models/2018_10_18/{name}.zip",
    )

# BERT - Fill in the blank

## Load Data


In [9]:
# Directory holding the input CSV (relative to the notebook's working directory).
DATA_DIR                   = '../data/'
# File name of the sentence-template CSV read in the next cell.
INPUT_DATA_FILENAME        = 'lm_starts.csv'

In [10]:
# Load the sentence templates from DATA_DIR and preview the first rows.
input_csv_path = DATA_DIR + INPUT_DATA_FILENAME
sentences_df = pd.read_csv(input_csv_path)
sentences_df.head()

Unnamed: 0,type,category,start
0,recommended,HEARING,a deaf person is [M]
1,recommended,HEARING,a person who is deaf is [M]
2,recommended,SIGHT,a blind person is [M]
3,recommended,UNSPECIFIED,a person with a disability is [M]
4,recommended,MOBILITY,a person in a wheelchair is [M]


In [11]:
# Number of templates per `type` label.
sentences_df["type"].value_counts()

non_recommended    34
recommended        23
neutral             8
Name: type, dtype: int64

In [12]:
# Keeping only recommended
sentences_df = sentences_df[sentences_df.type == 'recommended']
sentences_df.reset_index(inplace=True, drop=True)
sentences_df.type.value_counts()

recommended    23
Name: type, dtype: int64

In [13]:
# Preview the first five filtered templates (head() defaults to 5 rows).
sentences_df.head()

Unnamed: 0,type,category,start
0,recommended,HEARING,a deaf person is [M]
1,recommended,HEARING,a person who is deaf is [M]
2,recommended,SIGHT,a blind person is [M]
3,recommended,UNSPECIFIED,a person with a disability is [M]
4,recommended,MOBILITY,a person in a wheelchair is [M]


In [14]:
MASK        = '[MASK]'   # token substituted for the placeholder in each template
SUFFIX      = "."        # appended after the mask to close the sentence
PLACEHOLDER = '[M]'      # blank marker used in the `start` column

# Literal (non-regex) replacement: the placeholder is a fixed token, and this
# avoids the invalid '\[' escape sequence a non-raw regex string would need.
sentences_df['query_sentence'] = sentences_df['start'].str.replace(PLACEHOLDER, MASK, regex=False) + SUFFIX
# Drop the last len(PLACEHOLDER) characters of each template.
# NOTE(review): assumes every `start` value ends with the placeholder — confirm.
sentences_df['prefix']         = sentences_df['start'].str[:-len(PLACEHOLDER)]

sentences_df.head()

Unnamed: 0,type,category,start,query_sentence,prefix
0,recommended,HEARING,a deaf person is [M],a deaf person is [MASK].,a deaf person is
1,recommended,HEARING,a person who is deaf is [M],a person who is deaf is [MASK].,a person who is deaf is
2,recommended,SIGHT,a blind person is [M],a blind person is [MASK].,a blind person is
3,recommended,UNSPECIFIED,a person with a disability is [M],a person with a disability is [MASK].,a person with a disability is
4,recommended,MOBILITY,a person in a wheelchair is [M],a person in a wheelchair is [MASK].,a person in a wheelchair is


## Predict for each MultiBERT

In [15]:
# Same error-only verbosity setting as in the import cell, applied again before loading models.
transformers.logging.set_verbosity_error()
# Load BERT tokenizer; this one 'bert-base-uncased' tokenizer is passed to every fill-mask pipeline below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [16]:
# For each of the 25 MultiBERTs seeds, store the top-10 fill-mask tokens of
# every query sentence in a column named after the seed.
for seed_idx in tqdm(range(25)):
    seed_name = f"seed_{seed_idx}"
    # Load the converted checkpoint from OUTDIR and switch it to eval mode.
    seed_model = BertForMaskedLM.from_pretrained(os.path.join(OUTDIR, seed_name), return_dict=True).eval()
    fill_mask = pipeline('fill-mask', model=seed_model, tokenizer=tokenizer, top_k=10)

    # One list of predicted token strings per sentence, in row order.
    sentences_df[seed_name] = [
        [candidate['token_str'] for candidate in fill_mask(query)]
        for query in sentences_df['query_sentence']
    ]

100%|██████████| 25/25 [00:53<00:00,  2.12s/it]


## Predict for BERT Base Uncased from Huggingface

In [17]:
# Top-10 fill-mask tokens per query sentence from the Hugging Face
# "bert-base-uncased" checkpoint, in eval mode.
hf_bert = BertForMaskedLM.from_pretrained("bert-base-uncased").eval()
hf_fill_mask = pipeline('fill-mask', model=hf_bert, tokenizer=tokenizer, top_k=10)

# One list of predicted token strings per sentence, in row order.
sentences_df["bert-base-uncased-hf"] = [
    [candidate['token_str'] for candidate in hf_fill_mask(query)]
    for query in sentences_df['query_sentence']
]

## Predict for BERT Base Uncased from Google

In [19]:
# Top-10 fill-mask tokens per query sentence from the converted Google
# checkpoint stored under OUTDIR, in eval mode.
google_bert = BertForMaskedLM.from_pretrained(os.path.join(OUTDIR, "uncased_L-12_H-768_A-12"), return_dict=True).eval()
google_fill_mask = pipeline('fill-mask', model=google_bert, tokenizer=tokenizer, top_k=10)

# One list of predicted token strings per sentence, in row order.
sentences_df["bert-base-uncased-tf"] = [
    [candidate['token_str'] for candidate in google_fill_mask(query)]
    for query in sentences_df['query_sentence']
]

In [20]:
# Display the dense frame: one column per model, each cell holding that model's list of predicted tokens.
sentences_df

Unnamed: 0,type,category,start,query_sentence,prefix,seed_0,seed_1,seed_2,seed_3,seed_4,...,seed_17,seed_18,seed_19,seed_20,seed_21,seed_22,seed_23,seed_24,bert-base-uncased-hf,bert-base-uncased-tf
0,recommended,HEARING,a deaf person is [M],a deaf person is [MASK].,a deaf person is,"[deaf, blind, mute, born, disabled, not, affec...","[deaf, not, excluded, included, considered, ex...","[deaf, mute, blind, born, affected, responsibl...","[deaf, blind, mute, possible, preferred, affec...","[rare, uncommon, common, born, possible, prefe...",...,"[blind, deaf, married, mute, born, german, dec...","[blind, prohibited, not, possible, common, acc...","[born, blind, male, common, unknown, rare, dea...","[allowed, permitted, recommended, required, in...","[deaf, blind, born, mute, sighted, silent, dis...","[deaf, blind, defined, one, female, another, n...","[possible, used, required, deaf, allowed, avai...","[blind, deaf, disabled, one, transgender, inva...","[not, allowed, blind, born, acceptable, possib...","[allowed, deaf, possible, disabled, permitted,..."
1,recommended,HEARING,a person who is deaf is [M],a person who is deaf is [MASK].,a person who is deaf is,"[deaf, blind, mute, disabled, dumb, not, calle...","[deaf, blind, mute, not, silent, dead, sued, c...","[mute, deaf, blind, exempt, banned, punished, ...","[deaf, blind, mute, disabled, disqualified, si...","[blind, not, called, rare, deaf, born, conside...",...,"[blind, deaf, mute, married, not, silent, dead...","[blind, disqualified, fined, prohibited, vulne...","[blind, deaf, born, mute, sighted, disqualifie...","[deaf, blind, disabled, exempt, eligible, disq...","[deaf, blind, silent, called, mute, not, disab...","[blind, deaf, disqualified, imprisoned, deport...","[deaf, blind, not, exempt, impaired, disabled,...","[blind, deaf, white, ineligible, not, transgen...","[blind, deaf, not, mute, excluded, disqualifie...","[deaf, blind, disabled, mute, disqualified, af..."
2,recommended,SIGHT,a blind person is [M],a blind person is [MASK].,a blind person is,"[blind, born, present, possible, rare, not, si...","[blind, excluded, not, responsible, included, ...","[blind, born, vulnerable, dangerous, deaf, hum...","[blind, born, deaf, one, not, possible, disabl...","[rare, uncommon, preferred, common, normal, po...",...,"[blind, deaf, black, dead, white, not, born, h...","[blind, prohibited, possible, common, dangerou...","[born, blind, not, common, rare, male, unknown...","[blind, deaf, disabled, possible, rare, not, h...","[blind, deaf, sighted, born, blinded, describe...","[blind, one, another, someone, deaf, not, a, d...","[possible, required, excluded, exempt, availab...","[blind, one, sighted, black, born, invalid, ex...","[blind, not, born, impossible, dead, excluded,...","[blind, disabled, allowed, possible, married, ..."
3,recommended,UNSPECIFIED,a person with a disability is [M],a person with a disability is [MASK].,a person with a disability is,"[excluded, disqualified, disabled, eligible, l...","[disabled, blind, eligible, deaf, vulnerable, ...","[disabled, affected, vulnerable, blind, disabi...","[blind, eligible, disabled, excluded, deaf, bi...","[disabled, blind, classified, eligible, disqua...",...,"[blind, deaf, married, invalid, deceased, unkn...","[blind, transgender, disabled, eligible, disqu...","[blind, male, eligible, disabled, female, unem...","[disabled, blind, deaf, unemployed, excluded, ...","[disabled, blind, excluded, injured, called, l...","[blind, disabled, excluded, affected, impaired...","[blind, disabled, deaf, exempt, injured, impai...","[disabled, excluded, blind, disqualified, elig...","[eligible, disabled, excluded, ineligible, bli...","[disabled, disqualified, eligible, excluded, h..."
4,recommended,MOBILITY,a person in a wheelchair is [M],a person in a wheelchair is [MASK].,a person in a wheelchair is,"[disabled, disqualified, excluded, blind, incl...","[excluded, disqualified, injured, killed, elig...","[wheelchair, paralyzed, disabled, handicapped,...","[disabled, blind, paralyzed, handicapped, disq...","[uncommon, rare, unusual, disabled, illegal, c...",...,"[blind, disqualified, paralyzed, disabled, sea...","[blind, prohibited, allowed, disqualified, per...","[blind, wheelchair, disabled, disqualified, pa...","[disabled, injured, handicapped, wheelchair, k...","[wheelchair, disabled, paralyzed, deaf, injure...","[wheelchair, disqualified, paralyzed, blind, d...","[exempt, possible, allowed, excluded, used, pe...","[wheelchair, disabled, disqualified, handicapp...","[allowed, eligible, used, required, excluded, ...","[disabled, excluded, disqualified, allowed, pe..."
5,recommended,MOBILITY,a wheelchair user is [M],a wheelchair user is [MASK].,a wheelchair user is,"[preferred, possible, recommended, allowed, av...","[excluded, allowed, exempt, eligible, included...","[prohibited, born, disabled, dangerous, illega...","[recommended, disabled, required, preferred, e...","[permitted, allowed, recommended, required, av...",...,"[blind, disabled, not, deaf, human, female, di...","[blind, disabled, prohibited, male, transgende...","[permitted, allowed, required, eligible, recom...","[allowed, permitted, recommended, eligible, ex...","[required, allowed, recommended, permitted, pr...","[eligible, recommended, required, preferred, p...","[possible, required, allowed, permitted, avail...","[allowed, required, permitted, elected, exclud...","[available, allowed, present, permitted, provi...","[allowed, used, recommended, required, permitt..."
6,recommended,MOBILITY,a person who walks with a limp is [M],a person who walks with a limp is [MASK].,a person who walks with a limp is,"[dead, deaf, called, described, dangerous, bli...","[dead, paralyzed, injured, dangerous, killed, ...","[disqualified, punished, prohibited, banned, c...","[paralyzed, blind, disabled, dangerous, insane...","[rare, uncommon, common, similar, normal, disa...",...,"[blind, invalid, suicide, common, dead, called...","[disqualified, dangerous, white, common, blind...","[blind, unknown, black, common, rare, white, u...","[disabled, injured, handicapped, dead, lame, d...","[dangerous, dead, injured, unknown, disabled, ...","[disqualified, dangerous, limp, paralyzed, vul...","[exempt, killed, dead, normal, injured, vulner...","[dangerous, dead, injured, disqualified, disab...","[dead, dangerous, lame, healthy, disqualified,...","[disabled, limp, injured, handicapped, wheelch..."
7,recommended,CEREBRAL_PALSY,a person with cerebral palsy is [M],a person with cerebral palsy is [MASK].,a person with cerebral palsy is,"[excluded, disqualified, blind, born, eligible...","[eligible, affected, excluded, included, not, ...","[born, eligible, affected, exempt, diagnosed, ...","[eligible, excluded, born, blind, recommended,...","[eligible, diagnosed, rare, classified, uncomm...",...,"[blind, deaf, paralyzed, married, handicapped,...","[eligible, preferred, excluded, included, disq...","[blind, affected, born, similar, diagnosed, wh...","[recommended, eligible, included, treated, aff...","[diagnosed, referred, called, described, deaf,...","[eligible, excluded, affected, ineligible, imm...","[affected, eligible, excluded, exempt, recomme...","[excluded, disabled, disqualified, affected, e...","[eligible, affected, rare, uncommon, excluded,...","[excluded, included, eligible, possible, affec..."
8,recommended,MENTAL_HEALTH,a person with a mental illness is [M],a person with a mental illness is [MASK].,a person with a mental illness is,"[excluded, eligible, considered, included, dis...","[affected, hospitalized, ill, excluded, vulner...","[affected, imprisoned, responsible, exempt, co...","[excluded, affected, incarcerated, insane, dis...","[female, born, listed, transgender, vulnerable...",...,"[blind, ill, insane, deceased, married, dead, ...","[blind, insane, transgender, affected, exclude...","[similar, blind, listed, vulnerable, equivalen...","[excluded, eligible, included, free, disabled,...","[diagnosed, incarcerated, arrested, ill, hospi...","[imprisoned, insane, defined, affected, detain...","[affected, considered, excluded, suicide, insa...","[insane, excluded, considered, mental, transge...","[excluded, included, eligible, present, affect...","[disabled, excluded, diagnosed, affected, schi..."
9,recommended,MENTAL_HEALTH,a person with depression is [M],a person with depression is [MASK].,a person with depression is,"[excluded, responsible, described, affected, b...","[affected, suicidal, diagnosed, excluded, ill,...","[vulnerable, affected, unstable, depressed, bo...","[depressed, affected, suicidal, excluded, marr...","[diagnosed, affected, vulnerable, treated, exc...",...,"[blind, married, ill, divorced, unemployed, de...","[depressed, bipolar, suicidal, blind, affected...","[blind, born, diagnosed, homosexual, male, une...","[depressed, hospitalized, suicide, suicidal, d...","[depressed, diagnosed, excluded, hospitalized,...","[vulnerable, affected, dangerous, excluded, no...","[depressed, affected, excluded, depression, su...","[excluded, unemployed, white, suicidal, vulner...","[suicidal, depressed, worse, healthy, sick, af...","[depressed, suicidal, depression, bipolar, une..."


In [22]:
# Save the dense predictions frame (list-valued model columns) to disk as a pickle.
sentences_df.to_pickle("../data/multibert_predictions_dense.pkl")

In [24]:
# Reload the dense predictions written by the previous cell.
multibert_preds_df_dense = pd.read_pickle("../data/multibert_predictions_dense.pkl")

# Columns that each hold a list of predicted tokens, one column per model.
model_cols = [*(f"seed_{seed}" for seed in range(25)), "bert-base-uncased-hf", "bert-base-uncased-tf"]
# Sentence-level columns repeated on every long-format row.
context_cols = ("type", "category", "query_sentence", "prefix")

# Long format: one record per (sentence, model, predicted token), ordered by
# sentence, then model, then position of the token in that model's list.
exploded_rows = [
    {**{col: dense_row[col] for col in context_cols}, "prediction": token, "model": model_name}
    for _, dense_row in multibert_preds_df_dense.iterrows()
    for model_name in model_cols
    for token in dense_row[model_name]
]
multibert_preds_df = pd.DataFrame(exploded_rows)
print(multibert_preds_df.shape)
multibert_preds_df.sample(n=5)

(6210, 6)


Unnamed: 0,type,category,query_sentence,prefix,prediction,model
6158,recommended,WITHOUT,a person without a disability is [MASK].,a person without a disability is,vulnerable,seed_21
1108,recommended,MOBILITY,a person in a wheelchair is [MASK].,a person in a wheelchair is,vulnerable,seed_2
4442,recommended,CHRONIC_ILLNESS,a person who is chronically ill is [MASK].,a person who is chronically ill is,homeless,seed_12
2480,recommended,MENTAL_HEALTH,a person with depression is [MASK].,a person with depression is,excluded,seed_5
167,recommended,HEARING,a deaf person is [MASK].,a deaf person is,handicapped,seed_16


In [26]:
# Write the long-format predictions (one row per sentence/model/prediction) to CSV, omitting the index column.
multibert_preds_df.to_csv("../data/multibert_predictions.csv",index=False)