<a href="https://colab.research.google.com/github/raz0208/ModernBERT/blob/main/ModernBERT_TokenEmbedding_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Extract embeddings from input text using ModernBERT (Version 1)

In [2]:
# Install the Neo4j Python driver (uncomment when the package is missing, e.g. on Colab)
#!pip install neo4j

In [3]:
# import required libraries
import os
from getpass import getpass

import numpy as np
import pandas as pd
import torch
from neo4j import GraphDatabase
from transformers import AutoTokenizer, AutoModel

### Load the ModernBERT tokenizer and model

In [None]:
# Load the ModernBERT tokenizer and model from Hugging Face.
# MODEL_NAME is the checkpoint identifier passed to both from_pretrained() calls,
# so the tokenizer and the encoder always come from the same checkpoint.
MODEL_NAME = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

### Extract embeddings based on the full text

In [5]:
# Function to embed the input text as a whole (TODO: also support sentence-by-sentence embedding)
def get_text_embedding(text):
    """Return the embedding at token position 0 of the last hidden layer.

    Args:
        text: A string (or a list of strings) handed to the module-level
            ``tokenizer`` with ``truncation=True`` and ``padding=True``.

    Returns:
        numpy.ndarray: shape ``(hidden_size,)`` when the batch holds a single
        text, otherwise ``(batch_size, hidden_size)``.
    """
    # Run on whatever device the model currently lives on (CPU or GPU)
    device = next(model.parameters()).device

    # Tokenize input text and move every tensor next to the model
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # Forward pass to get hidden states (inference only, so no gradients)
    with torch.no_grad():
        outputs = model(**inputs)

    # Position 0 of each sequence is used as the sentence-level embedding
    cls_embedding = outputs.last_hidden_state[:, 0, :]  # shape: [batch_size, hidden_size]

    # squeeze(0) drops only a size-1 batch axis; .cpu() is required before
    # .numpy() whenever the model was moved to an accelerator
    return cls_embedding.squeeze(0).cpu().numpy()

## Connect to the Neo4j graph database

In [7]:
# Neo4j connection settings for the source graph database.
# Credentials are read from the environment (NEO4J_USER / NEO4J_PASSWORD) and
# requested interactively when unset, so no secret is stored in the notebook.
# The password that was previously committed in this cell should be rotated.
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://143.225.233.156:7687")
NEO4J_USER = os.environ.get("NEO4J_USER") or input("Neo4j user: ")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD") or getpass("Neo4j password: ")

# Initialize the driver
# NOTE(review): the recorded output of this cell shows "Unable to retrieve routing
# information" with the neo4j:// scheme; if the server is a single instance,
# bolt:// may be the right scheme - confirm against the deployment.
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

# Function to test connection
def test_connection():
    with driver.session() as session:
        greeting = session.run("RETURN 'Connected to Neo4j' AS message").single()["message"]
        print(greeting)

# Smoke-test the connection when this cell runs as the main program.
# The output recorded for this cell shows the call failing with ServiceUnavailable.
if __name__ == "__main__":
    test_connection()

ERROR:neo4j.pool:Unable to retrieve routing information


ServiceUnavailable: Unable to retrieve routing information

In [None]:
# Function to run Cypher query
def run_query(cypher_query, parameters=None, db_driver=None):
    """Execute a Cypher query and return every record as a plain dict.

    Args:
        cypher_query: Cypher statement to run.
        parameters: Optional dict of query parameters; ``None`` is sent as ``{}``.
        db_driver: Driver object to run against; it must provide ``session()``.
            Defaults to the module-level ``driver``.

    Returns:
        list: ``record.data()`` for each record, collected before the session
        is closed.
    """
    active_driver = driver if db_driver is None else db_driver
    with active_driver.session() as session:
        result = session.run(cypher_query, parameters or {})
        return [record.data() for record in result]

# Sanity check: fetch at most five arbitrary nodes and print each returned record
query = "MATCH (n) RETURN n LIMIT 5"
results = run_query(query)
for r in results:
    print(r)

In [None]:
# Function to find similar nodes using cosine similarity
# (TODO: use an index so the search is faster and stays a single query)
def find_similar_nodes(text_embedding, top_n=5):
    """Query the graph for the nodes whose abstract embedding best matches a vector.

    Args:
        text_embedding: The query embedding as a NumPy array or any array-like.
            A 2-D input with exactly one row is reduced to that row.
        top_n: Value bound to the query's ``LIMIT`` (default 5).

    Returns:
        The result of ``run_query`` for a query that returns the columns
        ``n``, ``labels`` and ``similarity``, ordered by descending similarity.

    Raises:
        ValueError: if ``text_embedding`` is not a single 1-D vector.
    """
    # Accept lists/tuples as well as arrays, and tolerate a [1, hidden] batch
    vector = np.asarray(text_embedding)
    if vector.ndim == 2 and vector.shape[0] == 1:
        vector = vector[0]
    if vector.ndim != 1:
        # Fail here instead of sending a malformed vector to the database
        raise ValueError(
            f"text_embedding must be a single 1-D vector, got shape {vector.shape}"
        )

    cypher_query = """
    MATCH (n)-[:HAS_EMBEDDING]->(e:ABSTRACT)
    WHERE e.embedding IS NOT NULL
    WITH n, labels(n) AS labels, gds.similarity.cosine($sent_embedding, e.embedding) AS similarity
    RETURN n, labels, similarity
    ORDER BY similarity DESC
    LIMIT $limit
    """
    # int() turns NumPy integer scalars into a plain Python int for the driver
    parameters = {"sent_embedding": vector.tolist(), "limit": int(top_n)}
    return run_query(cypher_query, parameters)

### Execute the app and get the output

In [None]:
### --- ### Sample text for test ### --- ###

# 1- This is an application about Breast Cancer.
# 2- Treating high blood pressure, high blood lipids, diabetes.
# 3- Heart failure, heart attack, stroke, aneurysm, peripheral artery disease, sudden cardiac arrest. Deaths: 17.9 million / 32% (2015)
# 4- Heart failure and stroke are common causes of death.

In [None]:
# Example usage (Sentence: This is an application about Breast Cancer.)
# Reads a text interactively, embeds it, queries the five most similar nodes and prints them.
# `similar_nodes` stays at module level; the later import cell passes it to
# import_nodes_to_second_db().
if __name__ == "__main__":
    user_text = input("Enter your text: ")

    # Get the embedding of the whole input text
    full_text_embedding = get_text_embedding(user_text)
    print("\nSentence Embedding vector shape:", full_text_embedding.shape)
    print("Sentence Embedding (first 10 values):", full_text_embedding[:10])

    # Run the similarity query against the graph database
    similar_nodes = find_similar_nodes(full_text_embedding, top_n=5)

    # Show each returned node with its similarity score (4 decimal places)
    print(f"\nTop {len(similar_nodes)} similar nodes:")
    for node_data in similar_nodes:
      print(f"Node: {node_data['n']}, Similarity: {node_data['similarity']:.4f}")

### Import the extracted data into a local dataset (optional)

In [None]:
# Define Neo4j connection credentials
# --- Connect to Second Database ---
# The URI and user keep their previous values as defaults; the password is read
# from SECOND_NEO4J_PASSWORD or requested interactively, so no secret is stored
# in the notebook. The password previously committed here should be rotated.
SECOND_NEO4J_URI = os.environ.get("SECOND_NEO4J_URI", "neo4j+s://17c6383a.databases.neo4j.io")
SECOND_NEO4J_USER = os.environ.get("SECOND_NEO4J_USER", "neo4j")
SECOND_NEO4J_PASSWORD = os.environ.get("SECOND_NEO4J_PASSWORD") or getpass("Second Neo4j password: ")
second_driver = GraphDatabase.driver(SECOND_NEO4J_URI, auth=(SECOND_NEO4J_USER, SECOND_NEO4J_PASSWORD))

# Function to test connection
def test_connection():
    with second_driver.session() as session:
        greeting = session.run("RETURN 'Connected to Neo4j' AS message").single()["message"]
        print(greeting)

# Smoke-test the connection to the second database when this cell runs as the main program
if __name__ == "__main__":
    test_connection()

In [None]:
# Function to import the extracted similar nodes into the second graph database
def _quote_label(label):
    """Backtick-quote a label so it is always parsed as one Cypher identifier."""
    return "`" + str(label).replace("`", "``") + "`"


def import_nodes_to_second_db(nodes_data, db_driver=None):
    """Create one node in the target database for every record in ``nodes_data``.

    Args:
        nodes_data: Sized iterable of dicts with the keys ``n`` (mapping of node
            properties) and ``labels`` (list of label names).
        db_driver: Driver object to write to; it must provide ``session()``.
            Defaults to the module-level ``second_driver``.

    Note:
        Nodes are written with ``CREATE``, so importing the same data twice
        creates duplicates.
    """
    active_driver = second_driver if db_driver is None else db_driver
    with active_driver.session() as session:
        for record in nodes_data:
            node_props = record['n']
            labels = record['labels']  # From the modified query

            # Labels cannot be sent as query parameters, so they are interpolated;
            # quoting keeps unusual characters from breaking or altering the query.
            # An empty label list yields a valid label-less CREATE (n).
            label_clause = "".join(f":{_quote_label(label)}" for label in labels)

            # Properties travel as ONE map parameter, so property names never
            # appear in the query text at all.
            cypher_query = f"CREATE (n{label_clause}) SET n = $props"

            session.run(cypher_query, {"props": node_props})

    print(f"✅ Imported {len(nodes_data)} nodes to the second Neo4j database.")

In [None]:
# Import to second Neo4j
# `similar_nodes` is the list assigned by the example-usage cell, so that cell must run first
import_nodes_to_second_db(similar_nodes)

In [None]:
# Function to run Cypher query
# NOTE: this rebinds the name `run_query` that an earlier cell defines for the
# first database. find_similar_nodes() looks `run_query` up by name, so after
# this cell runs it queries `second_driver` unless the earlier cell is re-run.
def run_query(cypher_query, parameters=None, db_driver=None):
    """Execute a Cypher query and return every record as a plain dict.

    Args:
        cypher_query: Cypher statement to run.
        parameters: Optional dict of query parameters; ``None`` is sent as ``{}``.
        db_driver: Driver object to run against; it must provide ``session()``.
            Defaults to the module-level ``second_driver``.

    Returns:
        list: ``record.data()`` for each record, collected before the session
        is closed.
    """
    active_driver = second_driver if db_driver is None else db_driver
    with active_driver.session() as session:
        result = session.run(cypher_query, parameters or {})
        return [record.data() for record in result]

# Sanity check: fetch at most five arbitrary nodes and print each returned record
# (uses whichever `run_query` is currently bound; this cell defines one above)
query = "MATCH (n) RETURN n LIMIT 5"
results = run_query(query)
for r in results:
    print(r)