# In [None]:
# --- Installation (run these first if needed) ---
# !pip install torch transformers
# !pip install bertviz

# --- Imports ---
import torch
from transformers import BertTokenizer, BertForMaskedLM, BertModel
from bertviz import head_view

# --- Step 2: Load Model and Helper Function ---

# Load the pre-trained BERT tokenizer and masked-LM model that
# predict_masked_word (below) reads as module-level globals.
print("Loading models for Masked Language Modeling...")
_MLM_CHECKPOINT = 'bert-base-uncased'
tokenizer_mlm = BertTokenizer.from_pretrained(_MLM_CHECKPOINT)
model_mlm = BertForMaskedLM.from_pretrained(_MLM_CHECKPOINT)
# Inference only: evaluation mode disables dropout so repeated predictions
# for the same text are stable.
model_mlm.eval()

def predict_masked_word(text, top_k=5):
    """
    Predict the ``top_k`` most likely tokens for the single [MASK] in ``text``.

    Uses the module-level ``tokenizer_mlm`` and ``model_mlm``.

    Args:
        text: Input string containing exactly one ``[MASK]`` token.
        top_k: Number of candidate tokens to return, highest-scoring first.

    Returns:
        A list of ``top_k`` token strings, best first. These are tokenizer
        vocabulary entries, so sub-word pieces may appear.

    Raises:
        ValueError: If ``text`` does not contain exactly one mask token.
    """
    inputs = tokenizer_mlm(text, return_tensors="pt")

    # Positions of every mask token in the single encoded sequence. Validate
    # before running the model so a bad prompt fails fast with a clear message
    # instead of an opaque error from ``.item()``.
    mask_positions = torch.where(inputs["input_ids"][0] == tokenizer_mlm.mask_token_id)[0]
    if mask_positions.numel() != 1:
        raise ValueError(
            f"Expected exactly one {tokenizer_mlm.mask_token} token in the text, "
            f"found {mask_positions.numel()}: {text!r}"
        )
    mask_index = mask_positions.item()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model_mlm(**inputs).logits

    # Scores over the vocabulary at the masked position; keep the top_k ids.
    masked_token_logits = logits[0, mask_index]
    top_k_indices = torch.topk(masked_token_logits, top_k, dim=0).indices
    return tokenizer_mlm.convert_ids_to_tokens(top_k_indices.tolist())

print("MLM Model and helper function ready.\n")

# --- Step 3: Exploration and Analysis ---

print("--- Running Exploration and Analysis ---")


def _run_case(sentence):
    """Predict fillers for the [MASK] in ``sentence``, print both, return the predictions."""
    guesses = predict_masked_word(sentence)
    print(f"Text: {sentence}")
    print(f"Predictions: {guesses}\n")
    return guesses


# Test Case 1: Basic Factual Knowledge
text_1 = "The capital of France is [MASK]."
predictions_1 = _run_case(text_1)

# Test Case 2a: Semantic Context (Medical)
text_2a = "The doctor prescribed the [MASK] to the patient."
predictions_2a = _run_case(text_2a)

# Test Case 2b: Semantic Context (Automotive)
text_2b = "The mechanic checked the [MASK] of the car."
predictions_2b = _run_case(text_2b)

# Test Case 3: Syntactic and Long-Range Context
text_3 = "All the players on the team celebrated after [MASK] won the championship."
predictions_3 = _run_case(text_3)

print("--- Analysis Complete ---\n")

# --- Step 5: Visualization of Results ---

print("--- Setting up Visualization ---")
# Only the encoder's attention weights are needed here, so a bare BertModel
# (no masked-LM head) suffices; output_attentions=True makes its forward pass
# return them.
model_viz = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
tokenizer_viz = BertTokenizer.from_pretrained('bert-base-uncased')

# Process input
text_viz = "He sat on the river bank."
inputs_viz = tokenizer_viz(text_viz, return_tensors='pt')
# Inference only: no_grad avoids building an autograd graph that is never
# used (same as in predict_masked_word).
with torch.no_grad():
    outputs_viz = model_viz(**inputs_viz)
# Tuple with one attention tensor per encoder layer (12 for bert-base).
# Per the transformers docs each is (batch, heads, seq_len, seq_len).
attention = outputs_viz.attentions

print(f"Visualization ready for text: '{text_viz}'")
print("Run the 'head_view' function below in a Jupyter Notebook to see the visualization.")

# Display visualization.
# head_view renders interactive HTML, so it only displays inside a Jupyter
# Notebook cell. Token labels come from the single encoded sequence.
head_view(attention, tokenizer_viz.convert_ids_to_tokens(inputs_viz['input_ids'][0].tolist()))