<a href="https://colab.research.google.com/github/raz0208/ModernBERT/blob/main/ModernBERT_TokenEmbedding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Extract embeddings from input text using ModernBERT

In [90]:
# import required libraries
import os
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import spacy
from keybert import KeyBERT
import re

### Load NLP and ModernBERT models

In [81]:
# Load the spaCy English pipeline. It is used further down by extract_topics()
# for noun chunks, part-of-speech tags and stop-word flags.
# NOTE(review): presumably the "en_core_web_sm" package must already be
# installed in the environment for this load to succeed — confirm for Colab.
nlp = spacy.load("en_core_web_sm")

# Load the ModernBERT tokenizer and encoder from the Hugging Face Hub.
# MODEL_NAME is the single place to change the checkpoint; both the tokenizer
# and the model are loaded from it so that they stay in sync.
# These module-level objects are shared by get_text_embedding() and
# get_token_embeddings().
MODEL_NAME = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

### Corpus (text) cleaning

In [110]:
# Function to clean text
def clean_text_for_topics(text):
    """Strip numeric and statistical noise from ``text`` before topic extraction.

    The following steps are applied in order:

    1. Drop parenthesised asides, e.g. ``(2015)`` or ``(note)``.
    2. Drop quantities that carry a unit, e.g. ``32%``, ``17.9 million``, ``5 mg``.
    3. Drop numeric ranges, e.g. ``2020–2023`` or ``10-15``.
    4. Drop any remaining standalone numbers.
    5. Drop ``Label: ...`` metadata (``Deaths:``, ``Rate:``, ``Total:`` ...)
       from the label to the end of the line.
    6. Turn ``:`` and ``/`` into spaces and remove all punctuation except commas.
    7. Collapse whitespace runs into a single space and trim the ends.

    Args:
        text: Raw input string.

    Returns:
        The cleaned string (may be empty).
    """
    # Remove content in parentheses (e.g., (2015), (note))
    text = re.sub(r'\([^)]*\)', '', text)

    # Remove number + unit. "%" is kept outside the trailing \b: a word boundary
    # never exists between "%" and a following space or end of text, so the
    # percent sign would otherwise only match when glued to a word character.
    text = re.sub(
        r'\b\d+(?:\.\d+)?\s*'
        r'(?:%|(?:percent|million|billion|thousand|mg|kg|g|km|cm|years?|months?|days?)\b)',
        '',
        text,
        flags=re.IGNORECASE,
    )

    # Remove numeric ranges (e.g., 2020–2023 or 10-15). This has to run before
    # the standalone-number step, otherwise both endpoints are already gone and
    # this pattern can never match.
    text = re.sub(r'\b\d{2,4}\s*[-–—]\s*\d{2,4}\b', '', text)

    # Remove standalone numbers (integers, decimals, years)
    text = re.sub(r'\b\d+(?:\.\d+)?\b', '', text)

    # Remove metadata such as "Deaths: ...", "Rate: ...", "Total: ...".
    # The colon is required: with an optional colon, ordinary words like
    # "rate" or "total" inside a sentence would delete the rest of the line.
    text = re.sub(
        r'\b(?:Deaths?|Rate|Cases?|Total|Incidence|Prevalence|Statistics)\s*:.*',
        '',
        text,
        flags=re.IGNORECASE,
    )

    # Replace slashes and colons separating stats or metadata with spaces
    text = re.sub(r'[:/]', ' ', text)

    # Remove all punctuation except commas
    text = re.sub(r'[^\w\s,]', '', text)

    # Collapse extra spaces
    text = re.sub(r'\s{2,}', ' ', text)

    return text.strip()

### Extract embeddings for the full text and for each token separately

In [82]:
# Function to get a text and return its embedding
def get_text_embedding(text):
    """Encode ``text`` with the shared tokenizer/model and return one vector.

    The vector is the last-layer hidden state at sequence position 0 (the
    position the original comment refers to as the CLS token).

    Args:
        text: Input passed straight to the tokenizer.

    Returns:
        A NumPy array: the position-0 hidden state with all size-1 dimensions
        squeezed away.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Inference only, so gradient tracking is switched off for the forward pass.
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state

    # Slice out position 0 for every sequence: [batch_size, hidden_size]
    first_position = hidden_states[:, 0, :]
    return first_position.squeeze().numpy()

In [102]:
# Function to get per-token embeddings
def get_token_embeddings(text):
    """Return ``(token, vector)`` pairs for the alphanumeric tokens of ``text``.

    The text is first cleaned with ``clean_text_for_topics``, then tokenized
    and run through the model once. Special tokens and tokens that are not
    purely alphanumeric (after stripping leading ``Ġ``/``▁`` space markers)
    are left out.

    Args:
        text: Raw input string.

    Returns:
        A list of ``(token_string, numpy_vector)`` tuples in token order, one
        per kept token; the vector is that token's last-layer hidden state.
    """
    # Call text cleaning function
    text = clean_text_for_topics(text)

    # Tokenize with return_tensors and also get token strings
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, return_attention_mask=True)
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    # Forward pass (inference only, no gradients)
    with torch.no_grad():
        outputs = model(**inputs)

    embeddings = outputs.last_hidden_state.squeeze(0)  # shape: [seq_len, hidden_size]

    # Loop-invariant: build the special-token lookup once instead of reading
    # the tokenizer property and scanning a list for every token.
    special_tokens = set(tokenizer.all_special_tokens)

    # Prepare cleaned token list and their embeddings
    token_embeddings = []
    for token, emb in zip(tokens, embeddings):
        # Skip special tokens
        if token in special_tokens:
            continue
        token = token.lstrip("Ġ▁")  # Remove leading space markers from BPE/SentencePiece
        # NOTE(review): byte-level BPE vocabularies may render other bytes
        # (e.g. a newline) as letter-like marker characters that would pass
        # isalnum() — confirm against this tokenizer if inputs contain newlines.
        if not token.isalnum():  # Skip punctuation / whitespace / empty tokens
            continue
        token_embeddings.append((token, emb.numpy()))

    return token_embeddings

### Topic extraction: First approach

In [76]:
# ### --- ### Basic method for topic extraction ### --- ###

# # Extract topics (noun phrase) from sentence
# def extract_topics(text):
#     doc = nlp(text)
#     topics = []

#     for chunk in doc.noun_chunks:
#         # Remove articles/determiners (like "the", "an", "a")
#         words = [token.text for token in chunk if token.pos_ != "DET"]
#         if words:
#             cleaned_phrase = " ".join(words).strip()
#             topics.append(cleaned_phrase)

#     # Remove duplicates while preserving order
#     seen = set()
#     unique_topics = []
#     for topic in topics:
#         if topic.lower() not in seen:
#             unique_topics.append(topic)
#             seen.add(topic.lower())

#     return unique_topics

In [49]:
# ### --- ### Second method for topic extraction ### --- ###

# # Extract topics (noun phrase) from sentence
# def extract_topics(text):
#     doc = nlp(text)
#     topics = []

#     # 1. Add noun chunks (excluding determiners)
#     for chunk in doc.noun_chunks:
#         words = [token.text for token in chunk if token.pos_ != "DET"]
#         if words:
#             cleaned = " ".join(words).strip()
#             topics.append(cleaned)

#     # 2. Add standalone noun tokens (important for medical terms like "diabetes")
#     for token in doc:
#         if token.pos_ == "NOUN" and not token.is_stop:
#             topics.append(token.text.strip())

#     # 3. Add named entities related to medical types
#     for ent in doc.ents:
#         if ent.label_ in {"DISEASE", "CONDITION", "MEDICAL_CONDITION", "HEALTH_CONDITION", "SYMPTOM"}:
#             topics.append(ent.text.strip())

#     # 4. Clean and deduplicate
#     seen = set()
#     unique_topics = []
#     for topic in topics:
#         topic_clean = topic.lower()
#         if topic_clean not in seen and topic_clean.isalpha():
#             unique_topics.append(topic)
#             seen.add(topic_clean)

#     return unique_topics

In [107]:
### --- ### Third method for topic extraction ### --- ###

# Extract topics (noun phrase) from sentence
def extract_topics(text):
    """Extract candidate topics (noun phrases and standalone nouns) from ``text``.

    1. Every spaCy noun chunk whose non-stop-word tokens are all nouns, proper
       nouns or adjectives becomes a topic, with determiners removed. Chunks
       that consist only of stop words (e.g. a lone pronoun) are skipped.
    2. Any non-stop-word noun / proper noun that is not already a word of a
       collected topic is added on its own.
    3. Duplicates are removed case-insensitively, keeping the first occurrence
       and its original casing.

    Args:
        text: Input string to analyse.

    Returns:
        A list of unique topic strings in order of first appearance.
    """
    noun_like = {"NOUN", "PROPN", "ADJ"}

    doc = nlp(text)
    topics = []
    covered_words = set()  # exact token texts already contained in some topic

    for chunk in doc.noun_chunks:
        content_tokens = [token for token in chunk if not token.is_stop]
        # all() over an empty sequence is True, so a chunk made only of stop
        # words would pass a bare all(...) check and produce a pronoun or an
        # empty-string topic. Require at least one content token.
        if not content_tokens:
            continue
        # Keep only noun-based chunks
        if any(token.pos_ not in noun_like for token in content_tokens):
            continue
        words = [token.text for token in chunk if token.pos_ != "DET"]
        topics.append(" ".join(words).strip())
        covered_words.update(words)

    # Fallback: add standalone nouns that weren't in the collected noun chunks.
    # Whole-word membership is used rather than substring matching, so "age"
    # is not treated as covered just because "dosage" was collected.
    for token in doc:
        if token.pos_ in ("NOUN", "PROPN") and not token.is_stop:
            if token.text not in covered_words:
                topics.append(token.text)
                covered_words.add(token.text)

    # Remove duplicates (case-insensitive) while preserving order
    seen = set()
    unique_topics = []
    for topic in topics:
        key = topic.lower()
        if key not in seen:
            seen.add(key)
            unique_topics.append(topic)

    return unique_topics

### Run the code

## Topic extraction: Second approach
- Extract the representative topic based on the input text.

In [116]:
from sklearn.metrics.pairwise import cosine_similarity

In [117]:
# Function to extract representative topic
def extract_main_topic(text):
    """
    Pick the candidate topic whose embedding is most similar to that of ``text``.

    Candidates come from ``extract_topics``; each one is embedded with
    ``get_text_embedding`` and compared to the embedding of the full text by
    cosine similarity. On ties the earliest candidate wins.

    Returns:
        ``(best_topic, similarity)``, or ``(None, None)`` when no candidate
        topics were found.
    """
    candidates = extract_topics(text)
    if not candidates:
        return None, None

    # Embedding of the whole input, as a 1 x hidden_size row for sklearn
    reference = get_text_embedding(text).reshape(1, -1)

    # One similarity score per candidate, in candidate order
    scores = [
        cosine_similarity(reference, get_text_embedding(topic).reshape(1, -1))[0][0]
        for topic in candidates
    ]

    # Running maximum with a strict comparison, starting from the -1 floor
    best_topic, highest_similarity = None, -1
    for topic, score in zip(candidates, scores):
        if score > highest_similarity:
            best_topic, highest_similarity = topic, score

    return best_topic, highest_similarity

In [88]:
### --- ### Sample text for test ### --- ###

# 1- This is an application about Breast Cancer.
# 2- Treating high blood pressure, high blood lipids, diabetes
# 3- Heart failure, heart attack, stroke, aneurysm, peripheral artery disease, sudden cardiac arrest. Deaths: 17.9 million / 32% (2015)

In [127]:
# Interactive demo (sample sentence: This is an application about Breast Cancer.)
if __name__ == "__main__":
    user_text = input("Enter your text: ")

    # 1) Embedding of the whole input
    embedding = get_text_embedding(user_text)
    print("\nSentence Embedding vector shape:", embedding.shape)
    print("Sentence Embedding (first 10 values):", embedding[:10])

    print("\n", "#" * 40)

    # 2) One embedding per kept token
    token_embeddings = get_token_embeddings(user_text)
    print("\nToken-wise Embeddings:")
    for token, emb in token_embeddings:
        print(f"Token: {token:15} | Embedding (first 5 vals): {emb[:5]}")

    print("\n", "#" * 40)

    # 3) Every candidate topic together with its own embedding
    topics = extract_topics(user_text)
    if not topics:
        print("No topics found.")
    else:
        print(f"\nIdentified topics: {topics}")
        for topic in topics:
            topic_embedding = get_text_embedding(topic)
            print(f"\nTopic: '{topic}'")
            print("Embedding shape:", topic_embedding.shape)
            print("Embedding (first 5 values):", topic_embedding[:5])

    print("\n", "#" * 40)

    # 4) The single topic closest to the full text, and its embedding
    Representative_topic, similarity = extract_main_topic(user_text)
    if not Representative_topic:
        print("No main topic found.")
    else:
        print(f"\nRepresentative Topic: '{Representative_topic}' (similarity: {similarity:.4f})")

        # Embed the winning topic on its own
        RepTopic_embedding = get_text_embedding(Representative_topic)
        print("\nSentence Embedding vector shape:", RepTopic_embedding.shape)
        print("Sentence Embedding (first 10 values):", RepTopic_embedding[:10])

Enter your text: Heart failure, heart attack, stroke, aneurysm, peripheral artery disease, sudden cardiac arrest. Deaths: 17.9 million / 32% (2015)

Sentence Embedding vector shape: (768,)
Sentence Embedding (first 10 values): [ 0.43869126 -0.26175603 -0.7007977   0.22565924 -0.38932806 -0.41834965
 -1.1746391  -0.8032964   0.19345134 -0.10638831]

 ########################################

Token-wise Embeddings:
Token: Heart           | Embedding (first 5 vals): [ 2.1719382  -0.5325865  -0.69776136  0.8575297   0.06485313]
Token: failure         | Embedding (first 5 vals): [ 2.1571448  -2.0265424  -0.13466616 -0.95061946  0.09105098]
Token: heart           | Embedding (first 5 vals): [ 1.0728792  -0.02802396 -0.36255845  1.1986513   1.0817125 ]
Token: attack          | Embedding (first 5 vals): [ 2.1625645  -0.5557532  -0.5867817   0.46016052  1.0440698 ]
Token: stroke          | Embedding (first 5 vals): [ 1.8419516  -0.65602314 -0.7234627   0.06608873  1.2474613 ]
Token: aneurysm   