#### 1. Setup and imports


In [2]:
import os
import re
import json
import time
import requests
import wikipedia

from dotenv import load_dotenv
load_dotenv()  # Loads variables from .env into the environment

from transformers import BertTokenizer, BertForMaskedLM
import torch
import torch.nn.functional as F

GOOGLE_API_KEY = os.getenv("API_KEY", "")
GOOGLE_CSE_ID = os.getenv("SEARCH_ENGINE_ID", "")

# For demonstration only; adapt as needed
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


#### 2. External context

In [3]:
def fetch_wikipedia_context(query, max_chars=1000, timeout=5):
    """
    Simple Wikipedia snippet fetch with python-wikipedia library.
    'query' can be a title or search term.
    'max_chars' is how many characters to return from the summary.
    'timeout' is a naive approach for demonstration.
    """
    try:
        wikipedia.set_lang("en")
        wikipedia.set_rate_limiting(True)
        # We'll do a simple approach, ignoring advanced concurrency/timeouts
        page_titles = wikipedia.search(query, results=1)
        if not page_titles:
            return ""
        page_title = page_titles[0]
        summary = wikipedia.summary(page_title, sentences=2)
        return summary[:max_chars]
    except Exception as e:
        print(f"[WARN] Wikipedia fetch error: {e}")
        return ""

def fetch_google_context(query, api_key=None, cse_id=None, max_chars=1000, timeout=5):
    """
    Demonstration of using a Google Custom Search Engine (CSE).
    'api_key' and 'cse_id' come from .env => (API_KEY, SEARCH_ENGINE_ID).
    Returns top snippet or empty string if no results found.
    """
    if not api_key or not cse_id:
        return ""

    base_url = "https://www.googleapis.com/customsearch/v1"
    params = {
        "key": api_key,
        "cx": cse_id,
        "q": query
    }

    try:
        r = requests.get(base_url, params=params, timeout=timeout)
        if r.status_code == 200:
            data = r.json()
            items = data.get("items", [])
            if not items:
                return ""
            snippet = items[0].get("snippet", "")
            return snippet[:max_chars]
        else:
            print(f"[WARN] Google search error: status={r.status_code}")
            return ""
    except Exception as e:
        print(f"[WARN] Google search failed: {e}")
        return ""

def gather_external_context_from_title(json_file, api_key=None, cse_id=None):
    """
    1. Load the JSON data, extract a potential title (first sentence).
    2. Use fetch_wikipedia_context or fetch_google_context to get external info.
    3. Return combined text snippet.
    """
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Heuristic: let's look at the json and find the doc title to be the 4th sentence
    if data["original_sentences"]:
        potential_title = data["original_sentences"][3]
    else:
        potential_title = "Untitled Document"

    # Wikipedia
    wiki_context = fetch_wikipedia_context(potential_title, max_chars=1000)

    # Google
    google_context = fetch_google_context(potential_title, api_key=api_key, cse_id=cse_id, max_chars=1000)

    combined_context = wiki_context + "\n\n" + google_context
    return combined_context

#### 3. Data prep

In [4]:
def load_training_data(json_file):
    """
    Return lists of original sentences (train) and censored sentences (test).
    """
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    original_sents = data["original_sentences"]
    censored_sents = data["censored_sentences"]
    return original_sents, censored_sents

def create_masked_texts(original_sents):
    """
    We create masked versions of the sentences for a naive MLM approach.
    Example: any word that starts uppercase we replace with [MASK].
    Also store ground truths so we know what we replaced.
    """
    masked_texts = []
    ground_truths = []

    for sent in original_sents:
        words = sent.split()
        new_words = []
        truth_words = []
        for w in words:
            # Simple heuristic: if w starts uppercase & length>3 => mask
            # You can do more advanced checks or use spaCy to detect named entities
            if w[0].isupper() and len(w) > 3:
                new_words.append("[MASK]")
                truth_words.append(w)
            else:
                new_words.append(w)
                truth_words.append(None)
        masked_sent = " ".join(new_words)
        masked_texts.append(masked_sent)
        ground_truths.append(truth_words)

    return masked_texts, ground_truths

#### 4. BERT model setup

In [5]:
from transformers import BertTokenizer, BertForMaskedLM

def prepare_bert_model():
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
    model.to(DEVICE)
    return tokenizer, model

#### 5. Fine tuning the model

In [6]:

import torch
import torch.nn.functional as F

def fine_tune_bert_maskedLM(tokenizer, model, masked_texts, epochs=1, batch_size=4):
    """
    Example training loop for masked language modeling.
    This is a minimal approach, for demonstration only.
    """
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    for epoch in range(epochs):
        batch_start = 0
        while batch_start < len(masked_texts):
            batch_end = batch_start + batch_size
            batch_sents = masked_texts[batch_start:batch_end]

            inputs = tokenizer(batch_sents, return_tensors="pt", padding=True, truncation=True)
            input_ids = inputs["input_ids"].to(DEVICE)
            attention_mask = inputs["attention_mask"].to(DEVICE)

            # BERTForMaskedLM expects labels=input_ids for teacher-forcing
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(f"Epoch {epoch} | Batch {batch_start} Loss={loss.item():.4f}")
            batch_start += batch_size

    model.eval()
    return model

#### 6. Inference and top-K predictions

In [7]:
def get_top_predictions_for_masked_sentence(tokenizer, model, masked_sentence, top_k=5):
    model.eval()
    inputs = tokenizer(masked_sentence, return_tensors="pt")
    input_ids = inputs["input_ids"].to(DEVICE)
    attention_mask = inputs["attention_mask"].to(DEVICE)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    
    logits = outputs.logits  # [batch_size, seq_len, vocab_size]
    mask_token_index = (input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]

    predictions = {}
    for idx in mask_token_index:
        idx = idx.item()
        logits_for_mask = logits[0, idx]
        probs = F.softmax(logits_for_mask, dim=0)
        top_probs, top_ids = probs.topk(top_k)

        predicted_tokens = [tokenizer.convert_ids_to_tokens(int(i)) for i in top_ids]
        predicted_scores = [float(tp) for tp in top_probs]

        predictions[idx] = list(zip(predicted_tokens, predicted_scores))

    return predictions

#### 7. Putting everything together

In [8]:
def main(json_file="./data/processed/document_1_processed.json"):
    # Step A: gather external context
    external_context = gather_external_context_from_title(
        json_file,
        api_key=GOOGLE_API_KEY,
        cse_id=GOOGLE_CSE_ID
    )

    # Step B: load data
    original_sents, censored_sents = load_training_data(json_file)

    # Merge external context
    # We treat the external context as an additional "sentence" for training
    original_plus_context = original_sents + [external_context]

    # Step C: create masked data
    masked_texts, ground_truths = create_masked_texts(original_plus_context)

    # Step D: prepare & fine-tune BERT
    tokenizer, model = prepare_bert_model()
    model = fine_tune_bert_maskedLM(tokenizer, model, masked_texts, epochs=1, batch_size=2)

    # Step E: test with a random masked sentence
    test_masked = "areas where the [MASK] and the [MASK] attacked."
    top_preds = get_top_predictions_for_masked_sentence(tokenizer, model, test_masked, top_k=5)
    print("\n=== TOP PREDICTIONS FOR TEST MASK ===")
    print(top_preds)

    # Also see how it tries to reconstruct a censored sentence
    # We do a quick hack: replace [REDACTED] with [MASK]
    for i, cens in enumerate(censored_sents[:3]):
        test_sent_masked = cens.replace("[REDACTED]", tokenizer.mask_token)
        results = get_top_predictions_for_masked_sentence(tokenizer, model, test_sent_masked, top_k=5)
        print(f"\nCensored Sentence {i}: {cens}")
        print(f"Mask Predictions: {results}")
    
if __name__ == "__main__":
    main()

[WARN] Wikipedia fetch error: Page id "structured equation modeling" does not match any pages. Try another id!
[WARN] Google search error: status=403


BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another archite

Epoch 0 | Batch 0 Loss=9.8682
Epoch 0 | Batch 2 Loss=5.9569
Epoch 0 | Batch 4 Loss=6.5115
Epoch 0 | Batch 6 Loss=6.5036
Epoch 0 | Batch 8 Loss=8.4266
Epoch 0 | Batch 10 Loss=7.5029
Epoch 0 | Batch 12 Loss=6.4291
Epoch 0 | Batch 14 Loss=3.2505
Epoch 0 | Batch 16 Loss=4.9005
Epoch 0 | Batch 18 Loss=7.0371
Epoch 0 | Batch 20 Loss=8.4356
Epoch 0 | Batch 22 Loss=5.3258
Epoch 0 | Batch 24 Loss=4.0817
Epoch 0 | Batch 26 Loss=3.6666
Epoch 0 | Batch 28 Loss=2.6724
Epoch 0 | Batch 30 Loss=3.9941
Epoch 0 | Batch 32 Loss=7.8576
Epoch 0 | Batch 34 Loss=7.1682
Epoch 0 | Batch 36 Loss=6.0278
Epoch 0 | Batch 38 Loss=3.9803
Epoch 0 | Batch 40 Loss=5.5448
Epoch 0 | Batch 42 Loss=5.5054
Epoch 0 | Batch 44 Loss=4.0677
Epoch 0 | Batch 46 Loss=3.8805
Epoch 0 | Batch 48 Loss=3.4331
Epoch 0 | Batch 50 Loss=4.9956
Epoch 0 | Batch 52 Loss=4.5139
Epoch 0 | Batch 54 Loss=6.8744
Epoch 0 | Batch 56 Loss=8.7369
Epoch 0 | Batch 58 Loss=5.0363
Epoch 0 | Batch 60 Loss=4.6140
Epoch 0 | Batch 62 Loss=5.5696
Epoch 0 | Bat