#### 1. Setup and imports


In [10]:
import os
import re
import json
import time
import requests
import wikipedia

from dotenv import load_dotenv
load_dotenv()  # Loads variables from .env into the environment

import nltk
from nltk.tokenize import word_tokenize

from transformers import BertTokenizer, BertForMaskedLM, pipeline
import torch
import torch.nn.functional as F

GOOGLE_API_KEY = os.getenv("API_KEY", "")
GOOGLE_CSE_ID = os.getenv("SEARCH_ENGINE_ID", "")

# For demonstration only; adapt as needed
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

#### 2. External context

In [11]:
# Initialize a global BERT tokenizer
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def tokenize_redacted_context(redacted_context_list):
    """
    Takes a list of redacted sentences (strings).
    Returns a list of dicts with:
      - 'original_text'
      - 'simple_tokens'  (for optional debugging)
      - 'input_ids'      (BERT-compatible IDs)
      - 'attention_mask' (BERT mask)
    """
    tokenized_result = []
    for sentence in redacted_context_list:
        # Simple NLTK-based tokenization for debugging/log
        simple_tokens = word_tokenize(sentence)
        
        # BERT-based tokenization
        bert_encoded = bert_tokenizer.encode_plus(
            sentence,
            add_special_tokens=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        
        tokenized_result.append({
            "original_text": sentence,
            "simple_tokens": simple_tokens,
            "input_ids": bert_encoded["input_ids"],
            "attention_mask": bert_encoded["attention_mask"]
        })
    return tokenized_result

# def fetch_wikipedia_context(query, max_chars=1000, timeout=5):
#     """
#     Simple Wikipedia snippet fetch with python-wikipedia library.
#     'query' can be a title or search term.
#     'max_chars' is how many characters to return from the summary.
#     'timeout' is a naive approach for demonstration.
#     """
#     try:
#         wikipedia.set_lang("en")
#         wikipedia.set_rate_limiting(True)
#         # We'll do a simple approach, ignoring advanced concurrency/timeouts
#         page_titles = wikipedia.search(query, results=1)
#         if not page_titles:
#             return ""
#         page_title = page_titles[0]
#         summary = wikipedia.summary(page_title, sentences=2)
#         return summary[:max_chars]
#     except Exception as e:
#         print(f"[WARN] Wikipedia fetch error: {e}")
#         return ""

# def fetch_google_context(query, api_key=None, cse_id=None, max_chars=1000, timeout=5):
#     """
#     Demonstration of using a Google Custom Search Engine (CSE).
#     'api_key' and 'cse_id' come from .env => (API_KEY, SEARCH_ENGINE_ID).
#     Returns top snippet or empty string if no results found.
#     """
#     if not api_key or not cse_id:
#         return ""

#     base_url = "https://www.googleapis.com/customsearch/v1"
#     params = {
#         "key": api_key,
#         "cx": cse_id,
#         "q": query
#     }

#     try:
#         r = requests.get(base_url, params=params, timeout=timeout)
#         if r.status_code == 200:
#             data = r.json()
#             items = data.get("items", [])
#             if not items:
#                 return ""
#             snippet = items[0].get("snippet", "")
#             return snippet[:max_chars]
#         else:
#             print(f"[WARN] Google search error: status={r.status_code}")
#             return ""
#     except Exception as e:
#         print(f"[WARN] Google search failed: {e}")
#         return ""

# def gather_external_context_from_title(json_file, api_key=None, cse_id=None):
#     """
#     1. Load the JSON data, extract a potential title (first sentence).
#     2. Use fetch_wikipedia_context or fetch_google_context to get external info.
#     3. Return combined text snippet.
#     """
#     with open(json_file, "r", encoding="utf-8") as f:
#         data = json.load(f)

#     # Heuristic: let's look at the json and find the doc title to be the 4th sentence
#     if data["original_sentences"]:
#         potential_title = data["original_sentences"][3]
#     else:
#         potential_title = "Untitled Document"

#     # Wikipedia
#     wiki_context = fetch_wikipedia_context(potential_title, max_chars=1000)

#     # Google
#     google_context = fetch_google_context(potential_title, api_key=api_key, cse_id=cse_id, max_chars=1000)

#     combined_context = wiki_context + "\n\n" + google_context
#     return combined_context

#### 3. Data prep

In [12]:
# def load_training_data(json_file):
#     """
#     Return lists of original sentences (train) and censored sentences (test).
#     """
#     with open(json_file, "r", encoding="utf-8") as f:
#         data = json.load(f)
#     original_sents = data["original_sentences"]
#     censored_sents = data["censored_sentences"]
#     return original_sents, censored_sents

# def create_masked_texts(original_sents):
#     """
#     We create masked versions of the sentences for a naive MLM approach.
#     Example: any word that starts uppercase we replace with [MASK].
#     Also store ground truths so we know what we replaced.
#     """
#     masked_texts = []
#     ground_truths = []

#     for sent in original_sents:
#         words = sent.split()
#         new_words = []
#         truth_words = []
#         for w in words:
#             # Simple heuristic: if w starts uppercase & length>3 => mask
#             # You can do more advanced checks or use spaCy to detect named entities
#             if w[0].isupper() and len(w) > 3:
#                 new_words.append("[MASK]")
#                 truth_words.append(w)
#             else:
#                 new_words.append(w)
#                 truth_words.append(None)
#         masked_sent = " ".join(new_words)
#         masked_texts.append(masked_sent)
#         ground_truths.append(truth_words)

#     return masked_texts, ground_truths

# Summarization pipeline
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

def summarize_if_needed(text_in, max_bert_length=512):
    """
    Summarizes text_in if it produces more than 'max_bert_length'
    tokens for BERT. Otherwise returns it unchanged.
    """
    tokens = bert_tokenizer.encode(text_in, add_special_tokens=False)
    if len(tokens) <= max_bert_length:
        return text_in  # No summarization required
    
    # Summarize
    summary_out = summarizer(
        text_in,
        max_length=150,  # tune as needed
        min_length=40,
        do_sample=False
    )
    return summary_out[0]["summary_text"]


Device set to use cuda:0


#### 4. BERT model setup

In [13]:
# from transformers import BertTokenizer, BertForMaskedLM

# def prepare_bert_model():
#     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#     model = BertForMaskedLM.from_pretrained("bert-base-uncased")
#     model.to(DEVICE)
#     return tokenizer, model

# Load the BERT masked LM
bert_masked_model = BertForMaskedLM.from_pretrained("bert-base-uncased")

##############################################
# External context gathering: Wikipedia example
##############################################
# Make sure you install the 'wikipedia' package:
#   pip install wikipedia


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


#### 5. Fine tuning the model

In [14]:

# import torch
# import torch.nn.functional as F

# def fine_tune_bert_maskedLM(tokenizer, model, masked_texts, epochs=1, batch_size=4):
#     """
#     Example training loop for masked language modeling.
#     This is a minimal approach, for demonstration only.
#     """
#     model.train()
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

#     for epoch in range(epochs):
#         batch_start = 0
#         while batch_start < len(masked_texts):
#             batch_end = batch_start + batch_size
#             batch_sents = masked_texts[batch_start:batch_end]

#             inputs = tokenizer(batch_sents, return_tensors="pt", padding=True, truncation=True)
#             input_ids = inputs["input_ids"].to(DEVICE)
#             attention_mask = inputs["attention_mask"].to(DEVICE)

#             # BERTForMaskedLM expects labels=input_ids for teacher-forcing
#             outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
#             loss = outputs.loss

#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()

#             print(f"Epoch {epoch} | Batch {batch_start} Loss={loss.item():.4f}")
#             batch_start += batch_size

#     model.eval()
#     return model

def gather_wikipedia_context(search_terms, max_chars=1000):
    """
    Searches Wikipedia for each term in 'search_terms' (a list of strings).
    Concatenates summaries up to 'max_chars' to avoid going overly long.
    """
    combined = ""
    for term in search_terms:
        try:
            # Get summary
            summary_txt = wikipedia.summary(term, sentences=3)
            # Accumulate
            if len(combined) + len(summary_txt) <= max_chars:
                combined += " " + summary_txt
            else:
                break
        except Exception as e:
            # If there's an error (page not found, disambiguation, etc.), skip
            print(f"[WARN] Could not retrieve Wikipedia for term '{term}' -> {e}")
            continue
    return combined.strip()

def iterative_decensoring(redacted_text, wikipedia_search_terms=None, epochs=3):
    """
    Iteratively replace [REDACTED] placeholders by:
      1) Optionally gather external context from Wikipedia (if search terms given).
      2) Combine external context + redacted text -> single string
      3) Summarize if needed
      4) Replace 1st occurrence of [REDACTED] with [MASK]
      5) BERT Masked LM to predict the masked token
      6) Insert predicted token
      7) Repeat for 'epochs' or until no more placeholders
    """
    current_text = redacted_text
    
    # 1) Optionally gather external context from Wikipedia
    if wikipedia_search_terms:
        ext_context = gather_wikipedia_context(wikipedia_search_terms, max_chars=1200)
        # 2) Combine them
        combined_input = f"{current_text}\n\nAdditionalContext:\n{ext_context}"
    else:
        combined_input = current_text
    
    for epoch in range(epochs):
        # 3) Summarize if needed
        short_text = summarize_if_needed(combined_input, max_bert_length=512)
        
        # If there's no [REDACTED] left, break
        if "[REDACTED]" not in short_text:
            print(f"No more [REDACTED] placeholders at epoch {epoch}. Done.")
            break
        
        # 4) Replace [REDACTED] with [MASK] (1 occurrence)
        masked_text = short_text.replace("[REDACTED]", "[MASK]", 1)
        
        # 5) BERT tokenization
        inputs = bert_tokenizer.encode_plus(masked_text, return_tensors='pt')
        input_ids = inputs["input_ids"]
        
        # Identify [MASK] location
        mask_indices = (input_ids == bert_tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
        if mask_indices.size(0) == 0:
            print(f"No [MASK] found at epoch {epoch}.")
            break
        
        # Forward pass
        with torch.no_grad():
            outputs = bert_masked_model(**inputs)
        logits = outputs.logits
        
        # 6) Get top predicted token
        mask_logits = logits[0, mask_indices, :]
        top_id = torch.argmax(mask_logits, dim=-1)
        predicted_token = bert_tokenizer.decode(top_id).strip()
        
        # Insert predicted token
        updated_text = masked_text.replace("[MASK]", predicted_token, 1)
        
        combined_input = updated_text  # update our text for the next epoch
        
        print(f"Epoch {epoch+1} => predicted: '{predicted_token}'\nResult:\n{combined_input}\n{'-'*40}")
    
    return combined_input

#### 6. Inference and top-K predictions

In [15]:
# def get_top_predictions_for_masked_sentence(tokenizer, model, masked_sentence, top_k=5):
#     model.eval()
#     inputs = tokenizer(masked_sentence, return_tensors="pt")
#     input_ids = inputs["input_ids"].to(DEVICE)
#     attention_mask = inputs["attention_mask"].to(DEVICE)

#     with torch.no_grad():
#         outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    
#     logits = outputs.logits  # [batch_size, seq_len, vocab_size]
#     mask_token_index = (input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]

#     predictions = {}
#     for idx in mask_token_index:
#         idx = idx.item()
#         logits_for_mask = logits[0, idx]
#         probs = F.softmax(logits_for_mask, dim=0)
#         top_probs, top_ids = probs.topk(top_k)

#         predicted_tokens = [tokenizer.convert_ids_to_tokens(int(i)) for i in top_ids]
#         predicted_scores = [float(tp) for tp in top_probs]

#         predictions[idx] = list(zip(predicted_tokens, predicted_scores))

#     return predictions

#### 7. Putting everything together

In [16]:
# def main(json_file="./data/processed/document_1_processed.json"):
#     # Step A: gather external context
#     external_context = gather_external_context_from_title(
#         json_file,
#         api_key=GOOGLE_API_KEY,
#         cse_id=GOOGLE_CSE_ID
#     )

#     # Step B: load data
#     original_sents, censored_sents = load_training_data(json_file)

#     # Merge external context
#     # We treat the external context as an additional "sentence" for training
#     original_plus_context = original_sents + [external_context]

#     # Step C: create masked data
#     masked_texts, ground_truths = create_masked_texts(original_plus_context)

#     # Step D: prepare & fine-tune BERT
#     tokenizer, model = prepare_bert_model()
#     model = fine_tune_bert_maskedLM(tokenizer, model, masked_texts, epochs=1, batch_size=2)

#     # Step E: test with a random masked sentence
#     test_masked = "areas where the [MASK] and the [MASK] attacked."
#     top_preds = get_top_predictions_for_masked_sentence(tokenizer, model, test_masked, top_k=5)
#     print("\n=== TOP PREDICTIONS FOR TEST MASK ===")
#     print(top_preds)

#     # Also see how it tries to reconstruct a censored sentence
#     # We do a quick hack: replace [REDACTED] with [MASK]
#     for i, cens in enumerate(censored_sents[:3]):
#         test_sent_masked = cens.replace("[REDACTED]", tokenizer.mask_token)
#         results = get_top_predictions_for_masked_sentence(tokenizer, model, test_sent_masked, top_k=5)
#         print(f"\nCensored Sentence {i}: {cens}")
#         print(f"Mask Predictions: {results}")
    
# if __name__ == "__main__":
#     main()

###############################################
# EXAMPLE USAGE (Just for demonstration)
###############################################
if __name__ == "__main__":
    # Some example redacted lines
    redacted_context_list_example = [
        "The small number of [REDACTED] troops initially deployed on the [REDACTED] held long enough for the mobilised force to get into position."
    ]
    
    # Section 2: tokenization
    tokenized_data_example = tokenize_redacted_context(redacted_context_list_example)
    print("Tokenized Data Example:")
    for entry in tokenized_data_example:
        print(entry)
    
    # Combine into single text if that's what your workflow does
    combined_redacted_text = " ".join(d["original_text"] for d in tokenized_data_example)
    
    # Potential Wikipedia terms
    wiki_terms = ["Israeli Army", "air force", "Golan Heights"]
    
    # Section 4: iterative decensoring with external context
    final_decensored = iterative_decensoring(
        redacted_text=combined_redacted_text,
        wikipedia_search_terms=wiki_terms,  # pass None if you don't want external
        epochs=3
    )
    print("\nFinal Decensored Text:\n", final_decensored)


Tokenized Data Example:
{'original_text': 'The small number of [REDACTED] troops initially deployed on the [REDACTED] held long enough for the mobilised force to get into position.', 'simple_tokens': ['The', 'small', 'number', 'of', '[', 'REDACTED', ']', 'troops', 'initially', 'deployed', 'on', 'the', '[', 'REDACTED', ']', 'held', 'long', 'enough', 'for', 'the', 'mobilised', 'force', 'to', 'get', 'into', 'position', '.'], 'input_ids': tensor([[  101,  1996,  2235,  2193,  1997,  1031,  2417, 18908,  2098,  1033,
          3629,  3322,  7333,  2006,  1996,  1031,  2417, 18908,  2098,  1033,
          2218,  2146,  2438,  2005,  1996, 11240, 21758,  2486,  2000,  2131,
          2046,  2597,  1012,   102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
[WARN] Could not retrieve Wikipedia for term 'air force' -> Page id "air for e" does not match any pages. Try another id!
Epoch 1 => predicted: