In [None]:
from datasets import load_dataset

# Every split of CoNLL-2003, returned as a DatasetDict keyed by split name.
dataset = load_dataset(path="conll2003")

In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
import string
import random

# Cased BERT checkpoint: bare encoder (AutoModel loads no task head) plus its tokenizer.
model_name = "bert-base-cased"
# .eval() returns the module itself; it switches off dropout for inference.
model = AutoModel.from_pretrained(model_name).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Only the first training example is needed for this probe.
dataset = load_dataset("conll2003", split="train[:1]")
# ClassLabel names: position i holds the tag string for class id i.
label_names = dataset.features["ner_tags"].feature.names
# Pre-split words and their integer NER tag ids, aligned one-to-one.
sentence, labels = dataset[0]["tokens"], dataset[0]["ner_tags"]

# Tokenize the already-split words so every sub-token can be traced back to its word.
encoding = tokenizer(
    sentence,
    is_split_into_words=True,
    return_tensors="pt",
    return_offsets_mapping=True,
    return_attention_mask=True,
)

input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
# Word index for each token position; None for special tokens such as [CLS]/[SEP].
word_ids = encoding.word_ids()

# BertModel.forward() has no `offset_mapping` parameter, so drop it before the
# forward pass; keeping it triggers "unexpected keyword argument" TypeError.
encoding_clean = dict(encoding)
encoding_clean.pop("offset_mapping", None)

# Reference run on the unmodified sentence.
with torch.no_grad():
    original_output = model(**encoding_clean)
# Batch of one: take the single sequence -> (num_tokens, hidden_size).
original_final_layer = original_output.last_hidden_state[0]

# Group token positions by the word they came from: {word index: [token positions]}.
# Positions are appended in increasing order, so each list is sorted.
word_to_tokens = {}
for token_pos, word_idx in enumerate(word_ids):
    if word_idx is None:
        continue  # special tokens belong to no word
    if word_idx not in word_to_tokens:
        word_to_tokens[word_idx] = []
    word_to_tokens[word_idx].append(token_pos)

# Prepare results: one (word, entity label, Euclidean distance) tuple per kept word.
results = []

# Random replacement happens in *word-embedding* space, which is what
# `inputs_embeds` expects. BERT adds position/token-type embeddings and
# LayerNorm on top of it. The noise is scaled to the word-embedding table
# because unit-variance vectors can be far larger in norm than real token
# embeddings (NOTE(review): confirm the chosen scale suits the experiment).
word_embedding_layer = model.get_input_embeddings()
with torch.no_grad():
    noise_scale = word_embedding_layer.weight.std()
# Dedicated seeded generator: reruns give identical distances, and the global
# torch RNG state is left untouched.
noise_generator = torch.Generator(device=input_ids.device).manual_seed(0)

# Loop over each word
for word_id, token_indices in word_to_tokens.items():
    word = sentence[word_id]
    label = label_names[labels[word_id]]

    # Skip words consisting solely of punctuation characters. A substring test
    # (`word in string.punctuation`) misses multi-character punctuation such as
    # "--" or "..." because those are not contiguous runs of string.punctuation.
    if (
        all(ch in string.punctuation for ch in word)
        or tokenizer.convert_tokens_to_string(tokenizer.tokenize(word)).strip()
        in tokenizer.all_special_tokens
    ):
        continue

    with torch.no_grad():
        # Raw word embeddings only. model.embeddings(input_ids) would already
        # include position + token-type embeddings and LayerNorm. The model
        # applies those again to `inputs_embeds`, so they would be added twice.
        # With this input, leaving `input_embeds` unmodified reproduces
        # original_final_layer.
        input_embeds = word_embedding_layer(input_ids)

        # Replace every position except the current word's sub-tokens (special
        # tokens included) with scaled random vectors.
        replace_mask = torch.ones(
            input_embeds.shape[1], dtype=torch.bool, device=input_embeds.device
        )
        replace_mask[token_indices] = False
        noise = torch.randn(
            (int(replace_mask.sum()), input_embeds.shape[-1]),
            generator=noise_generator,
            dtype=input_embeds.dtype,
            device=input_embeds.device,
        )
        input_embeds[0, replace_mask] = noise * noise_scale

        # Pass through model (via embeddings input)
        altered_output = model(inputs_embeds=input_embeds, attention_mask=attention_mask)
        altered_final_layer = altered_output.last_hidden_state.squeeze(0)

    # Mean pool over token indices for current word, in both runs
    original_vec = original_final_layer[token_indices].mean(dim=0)
    altered_vec = altered_final_layer[token_indices].mean(dim=0)

    # Exact Euclidean distance (F.pairwise_distance adds eps=1e-6 to the difference).
    distance = torch.dist(original_vec, altered_vec, p=2).item()
    results.append((word, label, distance))

# Largest contextual shift first (stable sort, so ties keep their original order).
results = sorted(results, key=lambda row: row[2], reverse=True)

# Fixed-width table: word, entity tag, and distance with six decimals.
print(f"{'Word':<15} {'Entity':<10} {'Contextual Shift (Euclidean)':>30}")
print("-" * 60)
for row_word, row_label, row_dist in results:
    print(f"{row_word:<15} {row_label:<10} {row_dist:>30.6f}")


TypeError: BertModel.forward() got an unexpected keyword argument 'offset_mapping'