# BERT (zero-shot) for MLM

In [None]:
!pip install torch

In [None]:
!pip install transformers

In [None]:
!pip install matplotlib

## Import required libraries

In [None]:
import site
import sys

# Put the per-user site-packages directory at the front of the import path.
# It is looked up for the running interpreter instead of hard-coding one
# person's home directory and Python version, and any existing entry is removed
# first so that re-running this cell does not pile up duplicates.
user_site = site.getusersitepackages()
if user_site in sys.path:
    sys.path.remove(user_site)
sys.path.insert(0, user_site)

In [None]:
import pandas as pd
import numpy as np
import re

In [None]:
import torch
from transformers import BertTokenizer, BertForMaskedLM
from transformers.models.bert.modeling_bert import BertModel
#from transformers import BertTokenizer, BertModel

import matplotlib.pyplot as plt
%matplotlib inline

# Checkpoint name shared by the tokenizer and the masked-LM model
pretrained_name = 'bert-base-uncased'

# Tokenizer (vocabulary) -- IMPORTANT: must come from the same checkpoint as the model
tokenizer = BertTokenizer.from_pretrained(pretrained_name)

# Pre-trained masked-language model, switched to evaluation mode
model = BertForMaskedLM.from_pretrained(pretrained_name)
model = model.eval()

## Load data from csv file

In [None]:
# Location of the human-designed questions file (relative path)
csv_path = 'HumanDesignQue.csv'

# Parse it into a DataFrame
df = pd.read_csv(filepath_or_buffer=csv_path)

# Show the first five rows as the cell output
df.head(n=5)

In [None]:
# Preview only the 'metaphors' column (kept as a one-column DataFrame)
df.loc[:, ['metaphors']].head()

Unnamed: 0,metaphors
0,Thomas is a _ lark.
1,The library is a _ grave.
2,Are you feeling ill? You are a _ ghost.
3,She was a _ mouse.
4,The cave was a _ night so we could not see any...


## Input preparation

In [None]:
def prep_input(input_sents, tokenizer,bert=True):
    """Lazily convert cloze sentences into model-ready inputs.

    For every sentence the ``_`` placeholder is replaced by ``[MASK]``, the text
    is split into sentences on whitespace that follows ``.``, ``!`` or ``?``, and
    the pieces are laid out as ``[CLS] s1 [SEP] s2 [SEP] ...`` before tokenizing.

    Args:
        input_sents: iterable of strings that use ``_`` for the word to predict.
        tokenizer: object exposing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        bert: not read by this function; kept so existing calls keep working.

    Yields:
        tuple ``(tokens_tensor, masked_index, tokenized_text)``:
        ``tokens_tensor`` is a torch tensor holding a single row of token ids,
        ``masked_index`` is the position of the first ``[MASK]`` in
        ``tokenized_text`` (``None`` when there is none), and ``tokenized_text``
        is the list of tokens returned by the tokenizer.
    """
    masked_tok = '[MASK]'

    for sent in input_sents:
        # replace masked token '_' with [MASK]
        sent = sent.replace('_', masked_tok)

        # Split after sentence-final punctuation, keeping the punctuation itself
        sentences = re.split(r'(?<=[.!?])\s+', sent.strip())

        text = ['[CLS]']
        for sentence in sentences:
            text += sentence.strip().split()
            # Close every segment with [SEP]. The final segment gets one too,
            # whatever it ends with: previously only inputs ending in '.' were
            # terminated, leaving '?'/'!'/unpunctuated inputs without [SEP].
            text.append('[SEP]')

        tokenized_text = tokenizer.tokenize(' '.join(text))

        # Find the index of the masked token
        masked_index = tokenized_text.index(masked_tok) if masked_tok in tokenized_text else None

        # Convert tokens to indices and wrap them as a one-row batch
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens])

        yield tokens_tensor, masked_index, tokenized_text


In [None]:
# Smoke test for prep_input on a few hand-picked sentences
input_sentences = [
    "Thomas is a _ lark.",
    "The library is a _ grave.",
    "Are you feeling ill? You are a _ ghost."
    ]

# Walk the generator and print every field it yields, one block per sentence
separator = "=" * 50
for tokens_tensor, masked_index, tokenized_text in prep_input(input_sentences, tokenizer):
    fields = (
        ("Tokens Tensor:", tokens_tensor),
        ("Masked Index:", masked_index),
        ("Tokenized Text:", tokenized_text),
    )
    for label, value in fields:
        print(label, value)
    print(separator)

## Get predictions of words

In [None]:
def get_predictions(input_sents, model, Tokenizer, k=5, bert=True):
    """Return the k most probable fill-ins for the masked slot of each sentence.

    Args:
        input_sents: iterable of strings using ``_`` for the word to predict.
        model: callable masked-LM; ``model(tensor)[0]`` must hold the logits.
        Tokenizer: tokenizer passed on to ``prep_input`` and used to turn the
            predicted ids back into tokens.
        k: number of candidates to keep per sentence.
        bert: forwarded to ``prep_input``.

    Returns:
        ``(token_preds, token_probs)``: two lists with one entry per input
        sentence, in input order. Each entry is a list of up to ``k`` tokens /
        probabilities sorted from most to least probable. A sentence without a
        mask token gets empty lists so the results stay aligned with the input.
    """
    token_preds = [] # top-k predicted tokens for each input sentence
    token_probs = [] # probabilities matching token_preds

    for tokensTensor, maskedIndex, tokenizedText in prep_input(input_sents, Tokenizer, bert=bert):
        # Nothing to predict without a mask: report it and keep a placeholder
        # entry so callers can still zip the results with their sentences.
        if maskedIndex is None:
            print("Warning: maskedIndex is None.")
            print(tokenizedText) # debug if maskedIndex cannot be found
            token_preds.append([])
            token_probs.append([])
            continue

        with torch.no_grad():
            predictions = model(tokensTensor)  # Get model predictions for the input sentence

        # Distribution over the vocabulary at the masked position
        softmax_pred = torch.softmax(predictions[0][0,maskedIndex],0)

        # k best entries, already sorted in descending order of probability;
        # k is capped so asking for more than the vocabulary size cannot fail
        top_probs, top_inds = torch.topk(softmax_pred, min(k, softmax_pred.size(0)))

        token_preds.append(Tokenizer.convert_ids_to_tokens(top_inds.tolist()))
        token_probs.append(top_probs.tolist())

    return token_preds, token_probs



In [None]:
# Testing the get_predictions function
input_sentences = [
    "Thomas is a _ lark.",
    "The library is so quiet. It is a _ grave.",
    "Are you feeling ill? You are a _ ghost.",
    "She was a _ mouse.",
    "The cave was a _ night so we could not see anything."
]

# Call the get_predictions function
token_preds, token_probs = get_predictions(input_sentences, model, tokenizer, k=5, bert=True)

# Display the results, pairing each sentence directly with its predictions
for sentence, tokens, probs in zip(input_sentences, token_preds, token_probs):
    print(f"Input Sentence: {sentence}")
    print(f"Top-k Predicted Tokens: {tokens}")
    print(f"Associated Probabilities: {probs}")
    print("="*50)

## Get prediction probabilities

In [None]:
def get_probabilities(input_sents,tgtlist,model,tokenizer,bert=True):
    """Return the model's probability for a given target word in each masked slot.

    Args:
        input_sents: iterable of strings using ``_`` for the word to predict.
        tgtlist: sequence of target tokens, one per sentence, in the same order.
        model: callable masked-LM; ``model(tensor)[0]`` must hold the logits.
        tokenizer: tokenizer passed on to ``prep_input`` and used to look up
            the id of each target token.
        bert: forwarded to ``prep_input``.

    Returns:
        list with one float per input sentence. The value is ``np.nan`` when
        the sentence has no mask token or the target is not in the vocabulary.
    """
    token_probs = []
    unk_id = getattr(tokenizer, 'unk_token_id', None)
    unk_tok = getattr(tokenizer, 'unk_token', None)

    for i,(tokensTensor, maskedIndex, tokenizedText) in enumerate(prep_input(input_sents,tokenizer,bert=bert)):
        tgt = tgtlist[i]

        try:
            tgt_ind = tokenizer.convert_tokens_to_ids([tgt])[0]
        except (KeyError, TypeError, ValueError):
            tgt_ind = None

        # An unknown word is normally mapped to the unknown-token id instead of
        # raising, so that case has to be detected explicitly; otherwise the
        # probability of the unknown token itself would be reported.
        out_of_vocab = tgt_ind is None or (
            unk_id is not None and tgt_ind == unk_id and tgt != unk_tok
        )

        if maskedIndex is None or out_of_vocab:
            token_probs.append(np.nan)
            continue

        with torch.no_grad():
            predictions = model(tokensTensor)

        # Distribution over the vocabulary at the masked position
        softmax_pred = torch.softmax(predictions[0][0,maskedIndex],0)
        token_probs.append(softmax_pred[tgt_ind].item())

    return token_probs

In [None]:
# Smoke test for get_probabilities
input_sentences = [
    "Thomas is a _ lark.",
    "The library is a _ grave.",
    "Are you feeling ill? You are a _ ghost."
]

# One candidate filler per sentence, in the same order as the sentences
target_tokens = ["happy", "silent", "pale"]

# Probability the model assigns to each candidate at the masked position
probs = get_probabilities(input_sentences, target_tokens, model, tokenizer, bert=True)

# Print a small report per (sentence, target, probability) triple
triples = zip(input_sentences, target_tokens, probs)
for i, (sent, tgt, prob) in enumerate(triples):
    report = [
        f"Input Sentence: {sent}",
        f"Target Token: {tgt}",
        f"Probability: {prob}",
        "=" * 50,
    ]
    print(*report, sep="\n")

## Get model responses

In [None]:
def get_model_responses(inputlist,tgtlist,modeliname,model,tokenizer,k=5,bert=True):
    """Collect top-k predictions and target-word probabilities in one call.

    Args:
        inputlist: iterable of strings using ``_`` for the word to predict.
        tgtlist: sequence of target tokens, one per sentence.
        modeliname: not read by this function; kept for signature compatibility.
        model: masked-LM passed on to the helper functions.
        tokenizer: tokenizer passed on to the helper functions.
        k: number of top predictions to keep per sentence.
        bert: forwarded to the helper functions.

    Returns:
        ``(top_preds, top_probs, tgt_probs)``: the two lists returned by
        ``get_predictions`` followed by the list from ``get_probabilities``.
    """
    # The helpers are defined in this notebook, so they are called directly
    # rather than through a `tp` module prefix that is not imported here.
    top_preds,top_probs = get_predictions(inputlist,model,tokenizer,k=k,bert=bert)
    tgt_probs = get_probabilities(inputlist,tgtlist,model,tokenizer,bert=bert)

    return top_preds,top_probs,tgt_probs

# Implementation

### 1. Functions Implementation on HumanQue dataset

In [None]:
# Human-designed questions dataset (adjust the path if the file lives elsewhere)
csv_file_path = 'HumanDesignQue.csv'
df = pd.read_csv(filepath_or_buffer=csv_file_path)

# The masked sentences are stored in the 'metaphors' column
input_sentences = df.loc[:, 'metaphors'].tolist()


In [None]:
# Top-5 fill-ins and their probabilities for every sentence
token_preds, token_probs = get_predictions(input_sentences, model, tokenizer, k=5, bert=True)

# Assemble the results table column by column
result_columns = {
    'Input Sentence': input_sentences,
    'Top-k Predicted Tokens': token_preds,
    'Associated Probabilities': token_probs,
}
results_df = pd.DataFrame(data=result_columns)



In [None]:
# Destination file for the results table (adjust the path as needed)
output_csv_file_path = 'HumanQue_results.csv'

# Write the table without the row index, then report where it went
results_df.to_csv(path_or_buf=output_csv_file_path, index=False)
print(f"Results saved to {output_csv_file_path}")

### 2. Functions Implementation on General Corpus dataset

In [None]:
# General corpus dataset (adjust the path if the file lives elsewhere)
csv_file_path = 'GeneralCorpus.csv'
df = pd.read_csv(filepath_or_buffer=csv_file_path)

# The masked sentences are stored in the 'metaphor' column
input_sentences = df.loc[:, 'metaphor'].tolist()

In [None]:
# Top-5 fill-ins and their probabilities for every sentence
token_preds, token_probs = get_predictions(input_sentences, model, tokenizer, k=5, bert=True)

# Assemble the results table column by column
result_columns = {
    'Input Sentence': input_sentences,
    'Top-k Predicted Tokens': token_preds,
    'Associated Probabilities': token_probs,
}
results_df = pd.DataFrame(data=result_columns)

In [None]:
# Destination file for the results table (adjust the path as needed)
output_csv_file_path = 'GeneralCorpus_results.csv'

# Write the table without the row index, then report where it went
results_df.to_csv(path_or_buf=output_csv_file_path, index=False)
print(f"Results saved to {output_csv_file_path}")

### 3. Functions Implementation on non-metaphor dataset

In [None]:
# Non-metaphor dataset (adjust the path if the file lives elsewhere)
csv_file_path = 'non-metaphor_COCA.csv'
df = pd.read_csv(filepath_or_buffer=csv_file_path)

# The masked sentences are stored in the 'Non-Metaphor' column
input_sentences = df.loc[:, 'Non-Metaphor'].tolist()

In [None]:
# Top-5 fill-ins and their probabilities for every sentence
token_preds, token_probs = get_predictions(input_sentences, model, tokenizer, k=5, bert=True)

# Assemble the results table column by column
result_columns = {
    'Input Sentence': input_sentences,
    'Top-k Predicted Tokens': token_preds,
    'Associated Probabilities': token_probs,
}
results_df = pd.DataFrame(data=result_columns)

In [None]:
# Destination file for the results table (adjust the path as needed)
output_csv_file_path = 'non-metaphor_results.csv'

# Write the table without the row index, then report where it went
results_df.to_csv(path_or_buf=output_csv_file_path, index=False)
print(f"Results saved to {output_csv_file_path}")

# Evaluation: Attention Mechanism Visualization (zero-shot)

### 1. Human-Designed Questions Metaphor

In [None]:
import torch
from transformers import BertTokenizer, BertForMaskedLM
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Fresh tokenizer and model; output_attentions=True asks the model to return attention weights as well
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', output_attentions=True)

# Sentence to probe; the first [MASK] is the slot whose prediction is reported
text = "Are you feeling ill? You are a [MASK] ghost."
input_ids = tokenizer.encode(text, return_tensors='pt')
mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
mask_index = mask_positions[0]

# Forward pass without gradient tracking, keeping attention weights and logits
with torch.no_grad():
    outputs = model(input_ids)
attention = outputs['attentions']
prediction_scores = outputs['logits']

# Highest-scoring vocabulary entry at the masked position
predicted_id = torch.argmax(prediction_scores[0, mask_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_id])[0]

# Reduce the attention weights to a single matrix
all_layers_attention = torch.stack(attention).mean(dim=0)  # mean over layers
all_heads_attention = all_layers_attention.mean(dim=1)  # mean over heads
avg_attention = all_heads_attention[0].detach().numpy()  # first (and only) item of the batch

# Heatmap of the averaged attention, with the input tokens labelling both axes
token_labels = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
plt.figure(figsize=(5,4))
sns.heatmap(avg_attention, annot=False, cmap='viridis', xticklabels=token_labels, yticklabels=token_labels)
plt.title(f'Aggregated Attention, Predicted: {predicted_token}')
plt.show()


In [None]:
# Sentence to probe; the first [MASK] is the slot whose prediction is reported
text = "Peter is a [MASK] beanpole."
input_ids = tokenizer.encode(text, return_tensors='pt')
mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
mask_index = mask_positions[0]

# Forward pass without gradient tracking, keeping attention weights and logits
with torch.no_grad():
    outputs = model(input_ids)
attention = outputs['attentions']
prediction_scores = outputs['logits']

# Highest-scoring vocabulary entry at the masked position
predicted_id = torch.argmax(prediction_scores[0, mask_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_id])[0]

# Reduce the attention weights to a single matrix
all_layers_attention = torch.stack(attention).mean(dim=0)  # mean over layers
all_heads_attention = all_layers_attention.mean(dim=1)  # mean over heads
avg_attention = all_heads_attention[0].detach().numpy()  # first (and only) item of the batch

# Heatmap of the averaged attention, with the input tokens labelling both axes
token_labels = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
plt.figure(figsize=(5,4))
sns.heatmap(avg_attention, annot=False, cmap='viridis', xticklabels=token_labels, yticklabels=token_labels)
plt.title(f'Aggregated Attention, Predicted: {predicted_token}')
plt.show()


In [None]:
# Sentence to probe; the first [MASK] is the slot whose prediction is reported
text = "Jason was a [MASK] peacock after winning first place in the swimming competition."
input_ids = tokenizer.encode(text, return_tensors='pt')
mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
mask_index = mask_positions[0]

# Forward pass without gradient tracking, keeping attention weights and logits
with torch.no_grad():
    outputs = model(input_ids)
attention = outputs['attentions']
prediction_scores = outputs['logits']

# Highest-scoring vocabulary entry at the masked position
predicted_id = torch.argmax(prediction_scores[0, mask_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_id])[0]

# Reduce the attention weights to a single matrix
all_layers_attention = torch.stack(attention).mean(dim=0)  # mean over layers
all_heads_attention = all_layers_attention.mean(dim=1)  # mean over heads
avg_attention = all_heads_attention[0].detach().numpy()  # first (and only) item of the batch

# Heatmap of the averaged attention, with the input tokens labelling both axes
token_labels = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
plt.figure(figsize=(5,4))
sns.heatmap(avg_attention, annot=False, cmap='viridis', xticklabels=token_labels, yticklabels=token_labels)
plt.title(f'Aggregated Attention, Predicted: {predicted_token}')
plt.show()

### 2. General Corpus Metaphor

In [None]:
# Sentence to probe; the first [MASK] is the slot whose prediction is reported
text = "As long as you can drive away from their shop without toppling over or crashing into something or someone, you’re a [MASK] bird."
input_ids = tokenizer.encode(text, return_tensors='pt')
mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
mask_index = mask_positions[0]

# Forward pass without gradient tracking, keeping attention weights and logits
with torch.no_grad():
    outputs = model(input_ids)
attention = outputs['attentions']
prediction_scores = outputs['logits']

# Highest-scoring vocabulary entry at the masked position
predicted_id = torch.argmax(prediction_scores[0, mask_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_id])[0]

# Reduce the attention weights to a single matrix
all_layers_attention = torch.stack(attention).mean(dim=0)  # mean over layers
all_heads_attention = all_layers_attention.mean(dim=1)  # mean over heads
avg_attention = all_heads_attention[0].detach().numpy()  # first (and only) item of the batch

# Heatmap of the averaged attention, with the input tokens labelling both axes
token_labels = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
plt.figure(figsize=(5,4))
sns.heatmap(avg_attention, annot=False, cmap='viridis', xticklabels=token_labels, yticklabels=token_labels)
plt.title(f'Aggregated Attention, Predicted: {predicted_token}')
plt.show()

In [None]:
# Sentence to probe; the first [MASK] is the slot whose prediction is reported
text = "As I got closer to finally see what was in store for me, Dan said my eyes were the [MASK] saucers and my jaw dropped."
input_ids = tokenizer.encode(text, return_tensors='pt')
mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
mask_index = mask_positions[0]

# Forward pass without gradient tracking, keeping attention weights and logits
with torch.no_grad():
    outputs = model(input_ids)
attention = outputs['attentions']
prediction_scores = outputs['logits']

# Highest-scoring vocabulary entry at the masked position
predicted_id = torch.argmax(prediction_scores[0, mask_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_id])[0]

# Reduce the attention weights to a single matrix
all_layers_attention = torch.stack(attention).mean(dim=0)  # mean over layers
all_heads_attention = all_layers_attention.mean(dim=1)  # mean over heads
avg_attention = all_heads_attention[0].detach().numpy()  # first (and only) item of the batch

# Heatmap of the averaged attention, with the input tokens labelling both axes
token_labels = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
plt.figure(figsize=(5,4))
sns.heatmap(avg_attention, annot=False, cmap='viridis', xticklabels=token_labels, yticklabels=token_labels)
plt.title(f'Aggregated Attention, Predicted: {predicted_token}')
plt.show()

### 3. Non-metaphor

In [None]:
# Sentence to probe; the first [MASK] is the slot whose prediction is reported
text = "Now it was Cara’s turn to give back. She drew a [MASK] breath and opened her hands, which had been clenched into fists."
input_ids = tokenizer.encode(text, return_tensors='pt')
mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
mask_index = mask_positions[0]

# Forward pass without gradient tracking, keeping attention weights and logits
with torch.no_grad():
    outputs = model(input_ids)
attention = outputs['attentions']
prediction_scores = outputs['logits']

# Highest-scoring vocabulary entry at the masked position
predicted_id = torch.argmax(prediction_scores[0, mask_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_id])[0]

# Reduce the attention weights to a single matrix
all_layers_attention = torch.stack(attention).mean(dim=0)  # mean over layers
all_heads_attention = all_layers_attention.mean(dim=1)  # mean over heads
avg_attention = all_heads_attention[0].detach().numpy()  # first (and only) item of the batch

# Heatmap of the averaged attention, with the input tokens labelling both axes
token_labels = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
plt.figure(figsize=(5,4))
sns.heatmap(avg_attention, annot=False, cmap='viridis', xticklabels=token_labels, yticklabels=token_labels)
plt.title(f'Aggregated Attention, Predicted: {predicted_token}')
plt.show()

In [None]:
# Sentence to probe; the first [MASK] is the slot whose prediction is reported
text = "I was unbearably hot. I flung the blanket off and sat up. My [MASK] feet found relief on the cold hardwood floor, and I rubbed my eyes."
input_ids = tokenizer.encode(text, return_tensors='pt')
mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
mask_index = mask_positions[0]

# Forward pass without gradient tracking, keeping attention weights and logits
with torch.no_grad():
    outputs = model(input_ids)
attention = outputs['attentions']
prediction_scores = outputs['logits']

# Highest-scoring vocabulary entry at the masked position
predicted_id = torch.argmax(prediction_scores[0, mask_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_id])[0]

# Reduce the attention weights to a single matrix
all_layers_attention = torch.stack(attention).mean(dim=0)  # mean over layers
all_heads_attention = all_layers_attention.mean(dim=1)  # mean over heads
avg_attention = all_heads_attention[0].detach().numpy()  # first (and only) item of the batch

# Heatmap of the averaged attention, with the input tokens labelling both axes
token_labels = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
plt.figure(figsize=(5,4))
sns.heatmap(avg_attention, annot=False, cmap='viridis', xticklabels=token_labels, yticklabels=token_labels)
plt.title(f'Aggregated Attention, Predicted: {predicted_token}')
plt.show()

In [None]:
# Sentence to probe; the first [MASK] is the slot whose prediction is reported
text = "Brad felt like he was watching a train wreck develop in [MASK] motion and was powerless to stop it."
input_ids = tokenizer.encode(text, return_tensors='pt')
mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
mask_index = mask_positions[0]

# Forward pass without gradient tracking, keeping attention weights and logits
with torch.no_grad():
    outputs = model(input_ids)
attention = outputs['attentions']
prediction_scores = outputs['logits']

# Highest-scoring vocabulary entry at the masked position
predicted_id = torch.argmax(prediction_scores[0, mask_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_id])[0]

# Reduce the attention weights to a single matrix
all_layers_attention = torch.stack(attention).mean(dim=0)  # mean over layers
all_heads_attention = all_layers_attention.mean(dim=1)  # mean over heads
avg_attention = all_heads_attention[0].detach().numpy()  # first (and only) item of the batch

# Heatmap of the averaged attention, with the input tokens labelling both axes
token_labels = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
plt.figure(figsize=(5,4))
sns.heatmap(avg_attention, annot=False, cmap='viridis', xticklabels=token_labels, yticklabels=token_labels)
plt.title(f'Aggregated Attention, Predicted: {predicted_token}')
plt.show()