<a href="https://colab.research.google.com/github/tomasonjo/blogs/blob/master/llm/diffbot_rag_ingest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip --quiet install langchain-community langchain-core neo4j langchain-openai langchain-text-splitters tiktoken langchain-experimental

In [2]:
import os
from getpass import getpass

from langchain_community.graphs import Neo4jGraph
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import TokenTextSplitter

# --- Credentials ----------------------------------------------------------------
# Secrets come from the environment (or an interactive prompt) so they are never
# stored in the notebook. An OPENAI_API_KEY that is already exported is left as-is.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

url = os.environ.get("NEO4J_URI", "bolt://44.212.20.20:7687")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get("NEO4J_PASSWORD") or getpass("Neo4j password: ")

# --- Tunables -------------------------------------------------------------------
CHUNK_SIZE_TOKENS = 500      # size of each text chunk, in tokens
CHUNK_OVERLAP_TOKENS = 50    # tokens shared between consecutive chunks
EMBEDDING_MODEL = "text-embedding-3-small"

# --- Clients --------------------------------------------------------------------
graph = Neo4jGraph(
    url=url,
    username=username,
    password=password,
    refresh_schema=False,  # don't refresh the graph schema when constructing the client
)

text_splitter = TokenTextSplitter(
    chunk_size=CHUNK_SIZE_TOKENS,
    chunk_overlap=CHUNK_OVERLAP_TOKENS,
)

embeddings = OpenAIEmbeddings(model=EMBEDDING_MODEL)

In [3]:
import os
from getpass import getpass

import requests
import pandas as pd

# Diffbot token is read from the environment (or prompted for) instead of being hardcoded.
DIFF_TOKEN = os.environ.get("DIFFBOT_API_KEY") or getpass("Diffbot token: ")

DIFFBOT_DQL_ENDPOINT = "https://kg.diffbot.com/kg/v3/dql"


def get_articles(query: str, size: int = 5, offset: int = 0, timeout: float = 60.0) -> dict:
    """Fetch English articles mentioning `query` from the Diffbot Knowledge Graph DQL endpoint.

    Args:
        query: Term matched against the article text.
        size: Maximum number of articles to return.
        offset: Index of the first result to return (for pagination).
        timeout: Seconds to wait for the HTTP response.

    Returns:
        The decoded JSON response. The ingestion loop below reads the articles
        from its "data" key.

    Raises:
        requests.HTTPError: if Diffbot answers with a non-2xx status.
    """
    # NOTE(review): a '"' inside `query` would terminate the DQL string literal early;
    # confirm Diffbot's escaping rules if arbitrary user input is passed here.
    dql = f'type:Article text:"{query}" strict:language:"en" sortBy:date'
    response = requests.get(
        DIFFBOT_DQL_ENDPOINT,
        # requests URL-encodes every value, so queries containing spaces, '&' or '#' arrive intact.
        params={"query": dql, "token": DIFF_TOKEN, "from": offset, "size": size},
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()

In [4]:
# Fetch 50 English articles mentioning Nvidia (the query sorts them by date).
data = get_articles(query="Nvidia", size=50, offset=0)

In [5]:
from typing import List, Optional

# Minimum Diffbot category score for a category to be attached to an article.
CATEGORY_THRESHOLD = 0.50
params = []


def get_tag_type(types: Optional[List[str]]) -> str:
    """Return the short type name of a Diffbot tag.

    The type is the segment after the last "/" of the first entry in `types`
    (the entries are presumably URI-like strings). Returns 'Node' when `types`
    is missing or empty, or its first entry is not a string.
    """
    if not types or not isinstance(types[0], str):
        return 'Node'
    return types[0].split("/")[-1]

# Turn each Diffbot article into one parameter row for import_cypher_query.
for row in data['data']:
    article = row['entity']
    split_chunks = text_splitter.split_text(article['text'])
    # One batched embeddings request per article instead of one request per chunk.
    chunk_vectors = embeddings.embed_documents(split_chunks) if split_chunks else []
    params.append({
        'sentiment': article['sentiment'],
        # Converted from milliseconds to seconds for datetime({epochSeconds: ...}) in the import query.
        'date': int(article['date']['timestamp'] / 1000),
        'publisher_region': article.get('publisherRegion'),
        'site_name': article['siteName'],
        'language': article['language'],
        'title': article['title'],
        'text': article['text'],
        # Keep only categories Diffbot is reasonably confident about.
        'categories': [
            el['name']
            for el in article.get('categories', [])
            if el.get('score', 0) > CATEGORY_THRESHOLD
        ],
        'author': article.get('author'),
        'tags': [
            {'sentiment': el['sentiment'], 'name': el['label'], 'type': get_tag_type(el.get('types'))}
            for el in article.get('tags', [])
        ],
        'page_url': article['pageUrl'],
        'id': article['id'],
        'chunks': [
            {'text': chunk, 'embedding': vector, 'index': i}
            for i, (chunk, vector) in enumerate(zip(split_chunks, chunk_vectors))
        ],
    })

In [6]:
# Upserts a batch of articles ($data: one map per article, as built by the params loop)
# together with their site, categories, tags, author and embedded text chunks.
import_cypher_query = """
UNWIND $data AS row
MERGE (a:Article {id:row.id})
SET a.sentiment = toFloat(row.sentiment),
    a.title = row.title,
    a.text = row.text,
    a.language = row.language,
    a.pageUrl = row.page_url,
    a.date = datetime({epochSeconds: row.date})
MERGE (s:Site {name: row.site_name})
ON CREATE SET s.publisherRegion = row.publisher_region
MERGE (a)-[:ON_SITE]->(s)
FOREACH (category IN row.categories |
  MERGE (c:Category {name: category})
  MERGE (a)-[:IN_CATEGORY]->(c)
)
FOREACH (tag IN row.tags |
  MERGE (t:Tag {name: tag.name})
  ON CREATE SET t.type = tag.type
  // Sentiment is SET rather than merged on: MERGE rejects null property values and
  // would create a parallel edge whenever the sentiment changes between runs.
  MERGE (a)-[r:HAS_TAG]->(t)
  SET r.sentiment = tag.sentiment
)
FOREACH (i IN CASE WHEN row.author IS NOT NULL THEN [1] ELSE [] END |
  MERGE (au:Author {name: row.author})
  MERGE (a)-[:HAS_AUTHOR]->(au)
  MERGE (au)-[:WRITES_FOR]->(s)
)
WITH a, row
UNWIND row.chunks AS chunk
  MERGE (c:Chunk {id: row.id + '-' + toString(chunk.index)})
  SET c.text = chunk.text,
      c.index = chunk.index
  MERGE (a)-[:HAS_CHUNK]->(c)
  WITH c, chunk
  CALL db.create.setNodeVectorProperty(c, 'embedding', chunk.embedding)
"""

In [7]:
# Upsert every parsed article (with its chunk embeddings) in one query.
graph.query(import_cypher_query, {"data": params})

[]

In [8]:
from langchain_experimental.graph_transformers import DiffbotGraphTransformer

# Diffbot NLP client; its convert_to_graph_documents() turns article text into graph documents.
diffbot_nlp = DiffbotGraphTransformer(
    diffbot_api_key=DIFF_TOKEN,
)

In [9]:
import os

# os.cpu_count() returns None when the count cannot be determined; fall back to 1
# so the worker-count arithmetic in the next cell never sees None.
cpu_count_os = os.cpu_count() or 1
print("Number of CPUs available according to os module:", cpu_count_os)

# Articles that have not been flagged with a `processed` property yet.
texts = graph.query("MATCH (a:Article) WHERE a.processed IS NULL RETURN a.id AS id, a.text AS text")

Number of CPUs available according to os module: 2


In [10]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_core.documents import Document
from tqdm.auto import tqdm  # notebook-aware progress bar


def process_document(text):
    """Run Diffbot NLP extraction on a single article row.

    Args:
        text: one row of the `texts` query result, with 'id' and 'text' keys.

    Returns:
        The graph documents returned by `diffbot_nlp.convert_to_graph_documents`,
        or an empty list if extraction raised. The error is printed so that one
        failing article does not abort the whole batch.
    """
    try:
        source = Document(page_content=text['text'], metadata={'id': text['id']})
        return diffbot_nlp.convert_to_graph_documents([source])
    except Exception as e:
        print(f"Error processing document with ID {text['id']}: {e}")
        return []


graph_documents = []
# The work is network-bound (waiting on Diffbot), so we can use more threads than CPUs.
# `or 1` guards against the CPU count being None.
max_workers = (cpu_count_os or 1) * 5

with ThreadPoolExecutor(max_workers=max_workers) as executor:
    futures = [executor.submit(process_document, text) for text in texts]
    # Collect results in completion order while tracking progress.
    for future in tqdm(as_completed(futures), total=len(futures), desc="Processing documents"):
        graph_documents.extend(future.result())

Processing documents: 100%|██████████| 50/50 [00:55<00:00,  1.10s/it]


In [11]:
# Attach extracted entities to their source Article. Each entity is merged on its id
# under the shared `_Entity_` label and then given its Diffbot type as an extra label.
node_import_query = """
MATCH (a:Article {id: $article_id})
UNWIND $data AS row
MERGE (source:`_Entity_` {id: row.id})
SET source += row.properties
MERGE (a)-[:MENTIONS]->(source)
WITH source, row
CALL apoc.create.addLabels( source, [row.type] ) YIELD node
RETURN count(*)
"""

# Merge relationships between entities, using a dynamic relationship type.
rel_import_query = """
UNWIND $data AS row
MERGE (source:`_Entity_` {id: row.source})
MERGE (target:`_Entity_` {id: row.target})
WITH source, target, row
CALL apoc.merge.relationship(source, row.type,
{}, row.properties, target) YIELD rel
RETURN count(*)
"""

# Flag an article as extracted so the `a.processed IS NULL` filter skips it on later runs.
mark_processed_query = """
MATCH (a:Article {id: $article_id})
SET a.processed = true
"""


def _import_graph_document(document):
    """Write one graph document's entities and relationships, then mark its Article processed."""
    # The article id was stored in the source Document's metadata by process_document.
    article_id = document.source.metadata["id"]
    graph.query(
        node_import_query,
        {
            "article_id": article_id,
            "data": [
                {"id": node.id, "type": node.type, "properties": node.properties}
                for node in document.nodes
            ],
        },
    )
    graph.query(
        rel_import_query,
        {
            "data": [
                {
                    "source": rel.source.id,
                    "source_label": rel.source.type,  # not read by rel_import_query
                    "target": rel.target.id,
                    "target_label": rel.target.type,  # not read by rel_import_query
                    # Normalise relationship names that may contain spaces to UPPER_SNAKE_CASE.
                    "type": rel.type.replace(" ", "_").upper(),
                    "properties": rel.properties,
                }
                for rel in document.relationships
            ]
        },
    )
    # Only flag the article once both of its imports have succeeded.
    graph.query(mark_processed_query, {"article_id": article_id})


for document in graph_documents:
    _import_graph_document(document)