<a href="https://colab.research.google.com/github/raz0208/ModernBERT/blob/main/ModernBERT_TokenEmbedding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
# import required libraries
import os
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import spacy
from keybert import KeyBERT

In [26]:
# Hugging Face checkpoint shared by the tokenizer and the encoder below.
MODEL_NAME = "answerdotai/ModernBERT-base"

# ModernBERT tokenizer + encoder weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

# Small English spaCy pipeline; extract_topics() reads its noun_chunks.
nlp = spacy.load("en_core_web_sm")

In [27]:
# Function to get text and return a sentence-level embedding
def get_text_embedding(text, pooling="cls"):
    """Encode ``text`` with the module-level ModernBERT model and pool it to one vector.

    Args:
        text: A string, or a list of strings (padded into one batch).
        pooling: ``"cls"`` (default, original behaviour) returns the hidden
            state of the first token; ``"mean"`` averages the hidden states
            of all non-padding tokens, weighted by the attention mask.

    Returns:
        NumPy array of shape ``(hidden_size,)`` for a single input, or
        ``(batch_size, hidden_size)`` for a batch of several inputs.

    Raises:
        ValueError: if ``pooling`` is not ``"cls"`` or ``"mean"``.
    """
    if pooling not in ("cls", "mean"):
        raise ValueError(f"pooling must be 'cls' or 'mean', got {pooling!r}")

    # Tokenize, then move every tensor to wherever the model lives (CPU or GPU).
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    # Forward pass without autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    hidden = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
    if pooling == "cls":
        pooled = hidden[:, 0, :]
    else:
        # Zero out padding positions, then divide by the real token count
        # (clamped so an all-padding row cannot divide by zero).
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    # squeeze(0) drops only the batch axis of a single input; .cpu() is
    # required before .numpy() when the model runs on a GPU.
    return pooled.squeeze(0).cpu().numpy()

In [28]:
# Function to get per-token embeddings
def get_token_embeddings(text):
    """Return ``(token_text, embedding)`` pairs for the sub-word tokens of ``text``.

    Special tokens are dropped, as are pieces whose decoded text is not purely
    alphanumeric (punctuation, whitespace). Sub-word pieces are reported
    separately, so one word may yield several entries.

    Args:
        text: A single input string. Only the first sequence of the batch is
            read, so pass one string rather than a list.

    Returns:
        List of ``(str, numpy.ndarray)`` tuples; each array is that token's
        row of the model's last hidden state.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, return_attention_mask=True)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # Move tensors to the model's device so this also works when the model is on GPU.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # First (only) sequence; .cpu() so .numpy() below works for GPU tensors too.
    embeddings = outputs.last_hidden_state[0].cpu()  # [seq_len, hidden_size]

    # Hoisted out of the loop: all_special_tokens builds a fresh list on each access.
    special_tokens = set(tokenizer.all_special_tokens)

    token_embeddings = []
    for token, emb in zip(tokens, embeddings):
        if token in special_tokens:
            continue
        # Decode the piece with the tokenizer's own decoder rather than stripping
        # "Ġ"/"▁" by hand: byte-level BPE stores non-ASCII characters as byte
        # symbols (e.g. "é" -> "Ã©"), which manual stripping leaves garbled and
        # isalnum() then rejects.
        word = tokenizer.convert_tokens_to_string([token]).strip()
        if not word.isalnum():  # skip punctuation / whitespace-only pieces
            continue
        token_embeddings.append((word, emb.numpy()))

    return token_embeddings

In [29]:
# Extract topics (noun phrases) from a sentence
def extract_topics(text):
    """Return the distinct noun phrases in ``text`` with determiners removed.

    Each spaCy noun chunk is rebuilt from its non-determiner tokens
    ("the blood pressure" -> "blood pressure"). Each token's trailing
    whitespace (``text_with_ws``) is kept, so "high-risk" is not rewritten as
    "high - risk". Duplicates are dropped case-insensitively: the first
    spelling seen is kept, in order of first appearance.

    Args:
        text: Input string, parsed with the module-level ``nlp`` pipeline.

    Returns:
        List of topic strings; empty if no noun chunk survives cleaning.
    """
    doc = nlp(text)
    topic_list = []
    seen = set()  # lower-cased keys of topics already emitted

    for chunk in doc.noun_chunks:
        # Remove articles (determiners) like "a", "an", "the"
        cleaned = "".join(token.text_with_ws for token in chunk if token.pos_ != "DET").strip()
        key = cleaned.lower()
        # Compare lower-cased keys against lower-cased keys; checking
        # cleaned.lower() against the original-case list never matched
        # capitalised phrases, so they were never de-duplicated.
        if cleaned and key not in seen:
            seen.add(key)
            topic_list.append(cleaned)

    return topic_list

In [30]:
# Example usage.
# A fixed EXAMPLE_TEXT replaces the blocking input() prompt so that
# "Restart & Run All" finishes unattended. Set it to None to be prompted for a
# sentence instead (e.g. "This is an application about Breast Cancer.").
EXAMPLE_TEXT = "Treating high blood pressure, high blood lipids, diabetes."

if __name__ == "__main__":
    if EXAMPLE_TEXT is None:
        user_text = input("Enter your text: ")
    else:
        user_text = EXAMPLE_TEXT
        print(f"Input text: {user_text}")

    # Sentence-level embedding
    embedding = get_text_embedding(user_text)
    print("\nSentence Embedding vector shape:", embedding.shape)
    print("Sentence Embedding (first 10 values):", embedding[:10])

    # Per-token embeddings
    token_embeddings = get_token_embeddings(user_text)
    print("\nToken-wise Embeddings:")
    for token, emb in token_embeddings:
        print(f"Token: {token:15} | Embedding (first 5 vals): {emb[:5]}")

    # Topics (noun phrases) and one embedding per topic
    topics = extract_topics(user_text)
    if topics:
        print(f"\nIdentified topics: {topics}")
        for topic in topics:
            topic_embedding = get_text_embedding(topic)
            print(f"\nTopic: '{topic}'")
            print("Embedding shape:", topic_embedding.shape)
            print("Embedding (first 5 values):", topic_embedding[:5])
    else:
        print("No topics found.")

Enter your text: Treating high blood pressure, high blood lipids, diabetes.

Sentence Embedding vector shape: (768,)
Sentence Embedding (first 10 values): [ 0.77355295 -0.9813839  -0.6875638  -0.36766133 -0.45456663 -1.140104
 -1.0688002  -0.98096883  0.03701263  0.48558575]

Token-wise Embeddings:
Token: Treat           | Embedding (first 5 vals): [ 0.9773896   1.1849611  -0.41335937 -0.42269483  0.30082476]
Token: ing             | Embedding (first 5 vals): [ 0.5608321  -1.0576426   0.23242453 -0.13428017 -0.4607672 ]
Token: high            | Embedding (first 5 vals): [ 2.578935   -1.1134218   0.05234717 -0.05968764 -0.13070378]
Token: blood           | Embedding (first 5 vals): [ 1.5280033   0.06228757 -0.822701    0.5588437   1.3647466 ]
Token: pressure        | Embedding (first 5 vals): [ 1.1171628 -1.0177413  0.0522917 -0.3610779  1.4506999]
Token: high            | Embedding (first 5 vals): [ 0.9468698  -0.3100607   0.46468502  1.1450183   0.40300363]
Token: blood           | Em