# Nullspace BERT Demonstration

##### Install Requirements

In [1]:
!pip install -r requirements.txt

[31mERROR: Could not open requirements file: [Errno 2] No such file or directory: 'requirements.txt'[0m


---

In [None]:
import numpy as np
from tqdm import tqdm
import transformers
from transformers import (
    BertTokenizer,
    BertForMaskedLM
)
import pandas as pd
import torch
from torch.nn import functional as F
from transformers import logging

logging.set_verbosity_error()

# BERT Transformer

In [None]:
# Tokenizer and masked-LM model are loaded from the same pretrained checkpoint.
BERT_CHECKPOINT = 'bert-base-uncased'
bert_tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
# return_dict=True: forward passes return an output object, so later cells can
# read `.hidden_states` by attribute.
bert_mlm = BertForMaskedLM.from_pretrained(BERT_CHECKPOINT, return_dict=True)

![BERT architecture](images/bert_architecture.png)

Source: [**BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding**: Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova](https://arxiv.org/abs/1810.04805)

In [None]:
# Index into the model's `hidden_states` tuple that get_embeddings() reads.
# NOTE(review): generate_next_word() defaults to biased_layer_index=1 and the
# embeddings file path is named "..._LAYER_ONE", whereas this constant is 12 —
# confirm which layer is actually intended.
BERT_BIAS_LAYER = 12

def chunker(input_list, chunk_size):
    """Lazily yield consecutive slices of `input_list`, each at most `chunk_size` long.

    The final slice is shorter when the length is not a multiple of `chunk_size`.
    """
    slice_starts = range(0, len(input_list), chunk_size)
    yield from (input_list[start:start + chunk_size] for start in slice_starts)
        
def get_embeddings(input_sequences, model, tokenizer, layer_index=None):
    """Return one layer's hidden states for a batch of input sequences.

    Args:
        input_sequences: list of strings to encode as a single padded batch.
        model: model whose forward pass accepts `output_hidden_states=True`
            and returns an object exposing `.hidden_states`.
        tokenizer: tokenizer providing `batch_encode_plus`.
        layer_index: index into `hidden_states`; defaults to BERT_BIAS_LAYER.

    Returns:
        Tuple `(embeddings, input_ids)` of numpy arrays: the selected hidden
        states and the token ids of the padded batch.
    """
    if layer_index is None:
        layer_index = BERT_BIAS_LAYER
    # Use the tokenizer/model that were passed in rather than the notebook globals.
    tokenized_input = tokenizer.batch_encode_plus(
        input_sequences, return_tensors="pt", padding=True, truncation=False
    )
    # Pure inference: skip autograd bookkeeping to keep memory down on large batches.
    with torch.no_grad():
        model_output = model(**tokenized_input, output_hidden_states=True)
    embeddings = model_output.hidden_states[layer_index]
    return embeddings.detach().numpy(), tokenized_input["input_ids"].detach().numpy()

def extract_token_embeddings(embeddings, input_ids):
    """Mean-pool each sequence's hidden states over its non-special tokens.

    Position 0 and everything from the end-of-sequence marker onwards
    (including padding) are excluded from the average.

    Returns:
        numpy array with one pooled vector per row of `embeddings`.
    """
    def _end_of_tokens(row_ids):
        ids = list(row_ids)
        # With padding (id 0) the end marker sits just before the first pad;
        # otherwise locate the end marker (id 102) directly.
        return ids.index(0) - 1 if 0 in ids else ids.index(102)

    pooled_vectors = [
        embeddings[row][1:_end_of_tokens(input_ids[row])].mean(axis=0)
        for row in range(embeddings.shape[0])
    ]
    return np.array(pooled_vectors)

# BERT Vector Generation

![Gender bias across BERT layers](images/bert_layers_gender_bias.png)

Source: [**Investigating Gender Bias in BERT**: Rishabh Bhardwaj, Navonil Majumder, Soujanya Poria](https://arxiv.org/abs/2009.05021)

In [None]:
from pathlib import Path

import gensim
from gensim import downloader
from gensim.models import KeyedVectors

# Cache the GloVe vectors locally: load the saved copy when it exists, otherwise
# download once and save it, so the cell also works on a fresh checkout.
VECTORS_PATH = Path('vectors.kv')
if VECTORS_PATH.exists():
    word2vec = KeyedVectors.load(str(VECTORS_PATH))
else:
    word2vec = gensim.downloader.load('glove-wiki-gigaword-300')
    word2vec.save(str(VECTORS_PATH))
print(type(word2vec.key_to_index))
print(len(word2vec.key_to_index))
VOCABULARY = list(word2vec.key_to_index.keys())

<class 'dict'>
400000


In [None]:
# Embed the first 10,000 vocabulary words, 1,000 words per forward pass.
sub_vocab = VOCABULARY[0:10000]
# Start from an empty matrix with 768 columns (presumably the model's hidden
# size — confirm); the following cells append further rows to this variable.
bert_vocab_embedding_list = np.empty((0, 768))
for chunk in tqdm(chunker(sub_vocab, 1000)):
    embeddings, input_ids = get_embeddings(chunk, bert_mlm, bert_tokenizer)
    # Collapse each word's per-token states into a single mean vector.
    embeddings = extract_token_embeddings(embeddings, input_ids)
    bert_vocab_embedding_list = np.concatenate((bert_vocab_embedding_list, embeddings), axis=0)

9it [05:31, 36.57s/it]

In [None]:
# Embed vocabulary words 10000-19999 and append them as rows 10000-19999.
sub_vocab = VOCABULARY[10000:20000]
# This cell appends to state built by earlier cells, so pin its starting point:
# fail loudly if fewer than 10000 rows exist yet, and drop any rows past 10000
# left over from a previous run so re-running cannot duplicate or misalign rows.
assert len(bert_vocab_embedding_list) >= 10000, "run the preceding embedding cell(s) first"
bert_vocab_embedding_list = bert_vocab_embedding_list[:10000]
for chunk in tqdm(chunker(sub_vocab, 1000)):
    embeddings, input_ids = get_embeddings(chunk, bert_mlm, bert_tokenizer)
    embeddings = extract_token_embeddings(embeddings, input_ids)
    bert_vocab_embedding_list = np.concatenate((bert_vocab_embedding_list, embeddings), axis=0)

In [None]:
# Embed vocabulary words 20000-29999 and append them as rows 20000-29999.
sub_vocab = VOCABULARY[20000:30000]
# This cell appends to state built by earlier cells, so pin its starting point:
# fail loudly if fewer than 20000 rows exist yet, and drop any rows past 20000
# left over from a previous run so re-running cannot duplicate or misalign rows.
assert len(bert_vocab_embedding_list) >= 20000, "run the preceding embedding cell(s) first"
bert_vocab_embedding_list = bert_vocab_embedding_list[:20000]
for chunk in tqdm(chunker(sub_vocab, 1000)):
    embeddings, input_ids = get_embeddings(chunk, bert_mlm, bert_tokenizer)
    embeddings = extract_token_embeddings(embeddings, input_ids)
    bert_vocab_embedding_list = np.concatenate((bert_vocab_embedding_list, embeddings), axis=0)

In [None]:
# Embed vocabulary words 30000-39999 and append them as rows 30000-39999.
sub_vocab = VOCABULARY[30000:40000]
# This cell appends to state built by earlier cells, so pin its starting point:
# fail loudly if fewer than 30000 rows exist yet, and drop any rows past 30000
# left over from a previous run so re-running cannot duplicate or misalign rows.
assert len(bert_vocab_embedding_list) >= 30000, "run the preceding embedding cell(s) first"
bert_vocab_embedding_list = bert_vocab_embedding_list[:30000]
for chunk in tqdm(chunker(sub_vocab, 1000)):
    embeddings, input_ids = get_embeddings(chunk, bert_mlm, bert_tokenizer)
    embeddings = extract_token_embeddings(embeddings, input_ids)
    bert_vocab_embedding_list = np.concatenate((bert_vocab_embedding_list, embeddings), axis=0)

In [None]:
# Write the embeddings as text: a "<rows> <dims>" header line followed by one
# "<word> <v1> ... <vN>" line per word.
# NOTE(review): the directory is named "..._LAYER_ONE" while BERT_BIAS_LAYER is 12 —
# confirm the path still describes what is being written.
embedding_shape = bert_vocab_embedding_list.shape
with open("data/embeddings/BERTLM_ENCODER_LAYER_ONE/bert-base-uncased-embeddings.txt", "w") as bert_file:
    # The header is derived from the actual matrix and newline-terminated; the
    # previous hard-coded "40000 768" had no newline and ran into the first row.
    bert_file.write(f"{embedding_shape[0]} {embedding_shape[1]}\n")
    for word, embedding in zip(VOCABULARY, bert_vocab_embedding_list):
        bert_file.write(f"{word} {' '.join(map(str, list(embedding)))}\n")

In [None]:
# Run the project script get_bias_sensitive_tokens.py in a subprocess; its printed
# output (PCA explained-variance ratios and gender-sensitive token lists) is shown below.
# Presumably it reads the embeddings file written by the previous cell — confirm against the script.
!python get_bias_sensitive_tokens.py

pca explained variance ratio: [0.32702205 0.20407498 0.15900119 0.08174839 0.05758674 0.05069627
 0.04543993 0.03472263 0.02933716 0.01037071] 

TOP 100 MALE SENSITIVE TOKENS 
 ('man', 'john', 'mano', 'he', 'boy', 'guy', 'mancuso', 'son', 'his', 'mancini', 'manhunt', 'him', 'manger', 'boynton', 'manx', 'manmohan', 'hebei', 'heinous', 'hisham', 'boyce', 'manley', 'himself', 'heisman', 'hebron', 'jackman', 'hester', 'manu', 'mandating', 'heim', 'sons', 'kidman', 'mandel', 'mangled', 'heston', 'hebert', '10-man', 'man-made', 'helton', 'brothers', 'heist', 'dutchman', 'father', 'hitman', 'brother', 'hoo', 'mike', 'menlo', 'boyish', 'boyz', 'manhood', 'fatherland', 'bradman', 'beckman', 'handyman', 'himachal', 'dude', 'mendocino', 'hillman', 'jason', 'johndroe', 'paceman', 'mr', 'scotsman', 'boys', 'redman', 'linesman', 'hectic', 'mrt', 'mandelson', 'charles', 'rodman', 'manure', 'hezb', 'edgardo', 'cashman', 'sonoma', 'horan', 'handsome', 'boyhood', 'ackerman', 'pittman', 'siro', 'horrendo

# Null-Space Projection

In [None]:
# Run the project script context_nullspace_projection.py in a subprocess; it prints
# split sizes, per-iteration accuracy and V-measure scores (shown below).
# Presumably it also writes data/nullspace_vector.npy, which a later cell loads — confirm against the script.
!python context_nullspace_projection.py

Train size: 147; Dev size: 63; Test size: 90
iteration: 24, accuracy: 0.36507936507936506: 100%|█| 25/25 [00:08<00:00,  3.12i
Figure(600x500)
Figure(600x500)
V-measure-before (TSNE space): 0.778190793392485
V-measure-after (TSNE space): 0.0011550932483761207
V-measure-before (original space): 1.0
V-measure-after (original space): 0.0007205831499929152


![t-SNE projections](images/tsne_projections.png)

# Transformer Encoder / Decoder Generation

In [None]:
# Projection matrix loaded from disk; guard_vector() left-multiplies vectors by it.
# Presumably produced by context_nullspace_projection.py — confirm.
NULL_PROJECTION = np.load("data/nullspace_vector.npy")

In [None]:
def guard_vector(layer, projection=None):
    """Apply the null-space projection to a vector or a stack of row vectors.

    Args:
        layer: 1-D vector of shape (dim,) or 2-D array of shape (n, dim).
        projection: (dim, dim) projection matrix; defaults to NULL_PROJECTION.

    Returns:
        numpy array with the same shape as `layer`, each vector projected.
    """
    if projection is None:
        projection = NULL_PROJECTION
    return np.asarray(projection).dot(layer.T).T


def guard_embedding(hidden_state, tokenized_input, projection=None,
                    special_token_ids=(101, 103, 102, 0)):
    """Apply the linear guarding function to every non-special token's hidden state.

    Args:
        hidden_state: tensor of shape (batch, seq_len, dim).
        tokenized_input: mapping whose "input_ids" tensor has shape (batch, seq_len).
        projection: optional projection matrix forwarded to guard_vector.
        special_token_ids: token ids left untouched (the defaults are the ids the
            original code skipped: 101, 103, 102 and 0).

    Returns:
        A new float tensor; `hidden_state` itself is not modified.
    """
    input_ids = tokenized_input["input_ids"].detach().numpy()
    # Select by position, not by `list.index(token_id)`: the latter always returns
    # the first occurrence, so a repeated token was never guarded at later positions.
    word_mask = ~np.isin(input_ids, list(special_token_ids))
    # `.numpy()` shares memory with the tensor; copy so the caller's hidden state
    # is not silently overwritten.
    guarded = hidden_state.detach().numpy().copy()
    guarded[word_mask] = guard_vector(guarded[word_mask], projection)
    return torch.Tensor(guarded)


def run_post_bias_encoder_layers(encoder_layers_list, previous_hidden_state):
    """Feed a hidden state through the given encoder blocks, one after another.

    Each block's `forward` is called with `hidden_states=`; the first element of
    its result becomes the input of the next block. The final state is returned.
    """
    current_state = previous_hidden_state
    for encoder_block in encoder_layers_list:
        block_outputs = encoder_block.forward(hidden_states=current_state)
        current_state = block_outputs[0]
    return current_state


def get_next_word(logits, tokenizer, mask_index):
    """Decode the highest-likelihood token at the masked position(s) of the first sequence."""
    probabilities = torch.softmax(logits, dim=-1)
    best_token_ids = probabilities[0, mask_index, :].argmax(dim=1)
    return tokenizer.decode(best_token_ids)


def generate_next_word(input_sequence, model, tokenizer, guard_flag=False, biased_layer_index=1):
    """Predict the word at the [MASK] position, optionally guarding an intermediate layer.

    The hidden state at `biased_layer_index` is taken from a normal forward pass,
    optionally passed through guard_embedding, then run through the remaining
    encoder layers and the masked-LM head.

    Args:
        input_sequence: text containing the tokenizer's mask token.
        model: BertForMaskedLM-style model exposing `bert.encoder.layer` and `cls`.
        tokenizer: tokenizer matching `model`.
        guard_flag: when True, apply guard_embedding to the intermediate hidden state.
        biased_layer_index: index into `hidden_states` at which guarding is applied.

    Returns:
        The decoded highest-likelihood token(s) for the masked position.
    """
    # Address sub-modules by name instead of by offsets into `modules()`: the flat
    # module list changes between transformers versions. hidden_states[i] is the
    # output after i encoder layers (index 0 is the embedding output), so layers
    # i, i+1, ... are the ones still to be run.
    remaining_encoder_layers = list(model.bert.encoder.layer[biased_layer_index:])
    mlm_head = model.cls

    # tokenize input sequence with the tokenizer that was passed in
    tokenized_input = tokenizer.encode_plus(input_sequence, return_tensors="pt")
    mask_index = torch.where(tokenized_input["input_ids"][0] == tokenizer.mask_token_id)

    # inference only: no autograd graph needed
    with torch.no_grad():
        # extract the intermediate encoding from a regular forward pass
        hidden_state = model(**tokenized_input, output_hidden_states=True).hidden_states[biased_layer_index]

        # apply guarding function to hidden state
        if guard_flag:
            hidden_state = guard_embedding(hidden_state, tokenized_input)

        # run (guarded) hidden state through the remaining encoder layers
        encoder_output = run_post_bias_encoder_layers(remaining_encoder_layers, hidden_state)

        # pass encoder output into the LM head to generate logits
        output_logits = mlm_head.forward(sequence_output=encoder_output)

    # generate the highest likelihood word
    return get_next_word(output_logits, tokenizer, mask_index)


def complete_the_sentence(female_variant, male_variant):
    """Print guarded and unguarded [MASK] completions for a female/male sentence pair.

    Each variant is completed twice with generate_next_word (guarded and unguarded);
    the predicted word is upper-cased and substituted for "[MASK]" in the sentence.
    """
    def _fill_mask(sentence, guard):
        predicted = generate_next_word(sentence, bert_mlm, bert_tokenizer, guard_flag=guard)
        return sentence.replace("[MASK]", predicted.upper())

    guarded_female = _fill_mask(female_variant, True)
    unguarded_female = _fill_mask(female_variant, False)
    guarded_male = _fill_mask(male_variant, True)
    unguarded_male = _fill_mask(male_variant, False)

    unguarded_report = f"unguarded outputs: \n \t{unguarded_female} \n \t{unguarded_male} \n\n"
    guarded_report = f"guarded outputs: \n \t{guarded_female} \n \t{guarded_male} \n"
    print(unguarded_report + guarded_report)



In [None]:
# Compare completions for a sentence pair that differs only in the pronoun.
mask_token = bert_tokenizer.mask_token
print("\n-------------- people thought <she/he> was very [MASK] during the meeting. -----------------\n")
complete_the_sentence(f"people thought she was very {mask_token} during the meeting.",
                      f"people thought he was very {mask_token} during the meeting.")


-------------- people thought <she/he> was very [MASK] during the meeting. -----------------

unguarded outputs: 
 	people thought she was very BEAUTIFUL during the meeting. 
 	people thought he was very POPULAR during the meeting. 

guarded outputs: 
 	people thought she was very VERY during the meeting. 
 	people thought he was very VERY during the meeting. 

