<a href="https://colab.research.google.com/github/raz0208/ModernBERT/blob/main/ModernBERT_TokenEmbedding_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Extract embeddings from input text using ModernBERT (version 1)

In [1]:
# import required libraries
import os
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel

### Load the ModernBERT tokenizer and model

In [None]:
# ModernBERT-base checkpoint on the Hugging Face Hub; the tokenizer and the
# encoder are both loaded from this same repository id.
MODEL_NAME = "answerdotai/ModernBERT-base"

tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModel.from_pretrained(MODEL_NAME),
)

### Extract an embedding for the full text

In [3]:
# Function to take input text and return a single full-text embedding vector
def get_text_embedding(text, pooling="cls", *, hf_tokenizer=None, hf_model=None):
    """Encode ``text`` with ModernBERT and return a sentence-level embedding.

    Args:
        text: A string, or a list of strings (batched by the tokenizer with padding).
        pooling: ``"cls"`` (default) takes the hidden state of the first token;
            ``"mean"`` averages hidden states over non-padding tokens using the
            attention mask.
        hf_tokenizer: Optional tokenizer to use instead of the module-level ``tokenizer``.
        hf_model: Optional model to use instead of the module-level ``model``.

    Returns:
        numpy.ndarray of shape ``(hidden_size,)`` for a single input, or
        ``(batch_size, hidden_size)`` for a batch of several inputs.

    Raises:
        ValueError: if ``pooling`` is not ``"cls"`` or ``"mean"``.
    """
    if pooling not in ("cls", "mean"):
        raise ValueError(f"pooling must be 'cls' or 'mean', got {pooling!r}")

    # Fall back to the tokenizer/model loaded earlier in the notebook.
    tok = hf_tokenizer if hf_tokenizer is not None else tokenizer
    encoder = hf_model if hf_model is not None else model

    # Tokenize (truncation=True with no max_length uses the tokenizer's
    # model_max_length) and move every tensor to the model's device so this
    # also works when the model has been moved to a GPU.
    device = next(encoder.parameters()).device
    encoded = tok(text, return_tensors="pt", truncation=True, padding=True)
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Forward pass without building an autograd graph.
    with torch.no_grad():
        hidden = encoder(**inputs).last_hidden_state  # [batch_size, seq_len, hidden_size]

    if pooling == "cls":
        # Hidden state of the first ([CLS]) token for each sequence.
        pooled = hidden[:, 0, :]
    else:
        attention = inputs.get("attention_mask")
        if attention is None:
            pooled = hidden.mean(dim=1)
        else:
            # Zero out padding positions, then divide by the count of real tokens.
            weights = attention.unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)

    # NumPy has no bfloat16, so upcast before converting.
    if pooled.dtype == torch.bfloat16:
        pooled = pooled.float()

    # squeeze(0) drops only the batch axis for a single input; .cpu() is needed
    # before .numpy() whenever the tensors live on a GPU.
    return pooled.squeeze(0).cpu().numpy()

In [4]:
# --- Sample texts for testing (type one of these at the prompt below) ---

# 1- This is an application about Breast Cancer.
# 2- Treating high blood pressure, high blood lipids, diabetes.
# 3- Heart failure, heart attack, stroke, aneurysm, peripheral artery disease, sudden cardiac arrest. Deaths: 17.9 million / 32% (2015)
# 4- Heart failure and stroke are common causes of death.

### Run the app and show the output

In [5]:
# Example usage: embed one sentence typed at the prompt (e.g. one of the samples
# listed above). The recorded output below used
# "Heart failure and stroke are common causes of death."
if __name__ == "__main__":
    # Fallback for kernels without interactive stdin (headless "Run All",
    # nbconvert): there input() raises instead of prompting (IPython's
    # StdinNotImplementedError subclasses NotImplementedError; plain Python
    # raises EOFError when stdin is exhausted).
    DEFAULT_TEXT = "Heart failure and stroke are common causes of death."
    try:
        user_text = input("Enter your text: ")
    except (EOFError, NotImplementedError):
        user_text = DEFAULT_TEXT
        print("No interactive input available; using default text:", user_text)

    # Get sentence embedding
    full_text_embedding = get_text_embedding(user_text)
    print("\nSentence Embedding vector shape:", full_text_embedding.shape)
    print("Sentence Embedding (first 10 values):", full_text_embedding[:10])

Enter your text: Heart failure and stroke are common causes of death.

Sentence Embedding vector shape: (768,)
Sentence Embedding (first 10 values): [ 0.2938816  -1.0113076  -0.8573238  -0.06944127 -0.7596021  -0.7222282
 -1.1270422  -1.2861091   0.33987728 -0.522541  ]


## Connect to the Neo4j graph database

In [6]:
!pip install neo4j

Collecting neo4j
  Downloading neo4j-5.28.1-py3-none-any.whl.metadata (5.9 kB)
Downloading neo4j-5.28.1-py3-none-any.whl (312 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/312.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m312.3/312.3 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: neo4j
Successfully installed neo4j-5.28.1


In [None]:
from neo4j import GraphDatabase

# Neo4j connection settings. They are read from environment variables so real
# credentials are never saved in the notebook. The placeholder strings are only
# fallbacks: set NEO4J_URI / NEO4J_USER / NEO4J_PASSWORD before running.
# URI scheme: bolt://, bolt+s://, neo4j://, or neo4j+s:// depending on your setup.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt+s://<HOST>:<PORT>")
NEO4J_USER = os.environ.get("NEO4J_USER", "<USERNAME>")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "<PASSWORD>")

# Initialize the driver (not closed in this cell, so later cells can reuse it)
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

def test_connection():
    """Smoke-test the Neo4j driver with a trivial one-row Cypher query.

    Opens a session on the module-level ``driver``, runs
    ``RETURN 'Connected to Neo4j' AS message`` and prints the ``message``
    field of the single record it returns.
    """
    cypher = "RETURN 'Connected to Neo4j' AS message"
    with driver.session() as session:
        record = session.run(cypher).single()
        print(record["message"])


# In a Jupyter kernel __name__ is "__main__", so the check runs whenever this cell runs.
if __name__ == "__main__":
    test_connection()