In [None]:
!pip install llama-index transformers neo4j

In [None]:
from transformers import pipeline

# Local relation extractor: REBEL generates relation triplets as marker-delimited
# text (parsed by extract_triplets). The same checkpoint provides both the model
# weights and its tokenizer. Inference runs on CPU here; pass a GPU index as
# `device` to speed up indexing if one is available.
_rebel_checkpoint = 'Babelscape/rebel-large'
triplet_extractor = pipeline(
    'text2text-generation',
    model=_rebel_checkpoint,
    tokenizer=_rebel_checkpoint,
    device='cpu',
)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import re
# Helpers that turn REBEL's generated text into (head, relation, tail) triplets.
# REBEL emits a linearized format with <triplet>/<subj>/<obj> markers; the parsing
# below is adapted from the Babelscape/rebel-large model card.

def clean_triplets(input_text, triplets):
    """Drop REBEL triplets that are probably hallucinated or degenerate.

    The model sometimes invents entities that do not occur in the source text,
    so a triplet is kept only when both its head and its tail appear in
    ``input_text`` (case-insensitive substring match; partial-word matches
    are accepted).

    Args:
        input_text: The text chunk REBEL was run on. ``None`` is treated as "".
        triplets: Iterable of dicts with ``"head"``, ``"type"`` and ``"tail"`` keys.

    Returns:
        List of ``(head, type, tail)`` tuples in input order, original casing
        preserved. A triplet is dropped when its head or tail is empty, when
        head and tail name the same entity (ignoring case), or when either
        entity does not occur in ``input_text``.
    """
    text = (input_text or "").lower()
    kept = []
    for triplet in triplets:
        head, relation, tail = triplet["head"], triplet["type"], triplet["tail"]
        head_key, tail_key = head.lower(), tail.lower()

        # An empty entity string is a substring of every text, so it would
        # otherwise pass the containment check and create a blank graph node.
        if not head_key.strip() or not tail_key.strip():
            continue
        # Self-loops carry no information; compare case-insensitively to match
        # the case-insensitive containment check below.
        if head_key == tail_key:
            continue
        if head_key not in text or tail_key not in text:
            continue

        kept.append((head, relation, tail))
    return kept

def _parse_rebel_output(decoded_text):
    """Split REBEL's linearized output into ``{"head", "type", "tail"}`` dicts.

    REBEL writes each triplet as ``<triplet> head <subj> tail <obj> relation``.
    A further ``<subj> tail <obj> relation`` right after a relation reuses the
    same head. Only triplets whose head, tail and relation are all non-empty
    are returned; incomplete fragments are discarded.
    """
    triplets = []
    head = tail = relation = ''
    mode = None  # slot that the next plain token extends: 'head', 'tail' or 'relation'

    def emit():
        # Fragments such as "<triplet> <subj> X <obj> r" would otherwise yield a
        # triplet with a blank head; require every slot to be filled.
        if head.strip() and tail.strip() and relation.strip():
            triplets.append({'head': head.strip(), 'type': relation.strip(), 'tail': tail.strip()})

    cleaned = decoded_text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in cleaned.split():
        if token == "<triplet>":
            emit()
            head, relation = '', ''
            mode = 'head'
        elif token == "<subj>":
            # Close the previous (head, relation, tail); the head carries over.
            emit()
            tail = ''
            mode = 'tail'
        elif token == "<obj>":
            relation = ''
            mode = 'relation'
        elif mode == 'head':
            head += ' ' + token
        elif mode == 'tail':
            tail += ' ' + token
        elif mode == 'relation':
            relation += ' ' + token
    emit()
    return triplets


def extract_triplets(input_text):
    """Extract ``(head, relation, tail)`` triplets from ``input_text`` with REBEL.

    Passed to ``KnowledgeGraphIndex`` as ``kg_triplet_extract_fn``. Runs the
    global ``triplet_extractor`` pipeline, parses its marker-delimited output
    and filters the result through ``clean_triplets``.

    Args:
        input_text: One text chunk.

    Returns:
        List of ``(head, relation, tail)`` tuples. Blank input returns ``[]``
        without running the model.
    """
    if not input_text or not input_text.strip():
        return []
    # Ask for raw token ids and decode them ourselves: batch_decode is called
    # without skip_special_tokens, so the <triplet>/<subj>/<obj> markers the
    # parser depends on are kept in the decoded text.
    generated_ids = triplet_extractor(
        input_text, return_tensors=True, return_text=False)[0]["generated_token_ids"]
    decoded = triplet_extractor.tokenizer.batch_decode([generated_ids])[0]
    return clean_triplets(input_text, _parse_rebel_output(decoded))

In [None]:
import os
from getpass import getpass

# OpenAI generates the natural-language responses and the Cypher statements.
# Keep a key that is already exported in the environment; otherwise prompt for
# one. Hardcoding it would leak the key with the notebook and would also
# overwrite a valid key set in the shell.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

In [None]:
import os
from getpass import getpass

from llama_index.graph_stores import Neo4jGraphStore
from llama_index.storage.storage_context import StorageContext

# Neo4j connection settings are read from the environment (prompting for the
# password when it is not set), so credentials and host addresses are never
# committed with the notebook. Export NEO4J_URI etc. to target a remote instance.
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get("NEO4J_PASSWORD") or getpass("Neo4j password: ")
url = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
database = os.environ.get("NEO4J_DATABASE", "neo4j")

graph_store = Neo4jGraphStore(
    username=username,
    password=password,
    url=url,
    database=database,
)

# Route graph storage to Neo4j; the indexing and query cells use this context.
storage_context = StorageContext.from_defaults(graph_store=graph_store)

In [None]:
from llama_index import KnowledgeGraphIndex, ServiceContext  # used by the indexing cell
from llama_index import download_loader

# Source corpus: the Wikipedia article on Skyrim, fetched with the LlamaHub
# WikipediaReader. auto_suggest=False is forwarded to the page lookup,
# presumably so the exact title is fetched rather than a suggested one — confirm
# against the loader.
WikipediaReader = download_loader("WikipediaReader")
loader = WikipediaReader()

wiki_pages = ["The Elder Scrolls V: Skyrim"]
documents = loader.load_data(pages=wiki_pages, auto_suggest=False)

In [None]:
from llama_index.llms import OpenAI

# The OpenAI LLM writes Cypher and phrases answers; triplet extraction itself
# runs locally through REBEL (extract_triplets). The llama-index OpenAI class
# takes `model` — `model_name` is not one of its fields and is silently ignored.
llm = OpenAI(model="gpt-3.5-turbo")

# REBEL accepts at most 512 input tokens. 256-token chunks stay well under that
# limit (chunk_size is measured with llama-index's tokenizer, not REBEL's, so
# the margin matters), and shorter inputs also extract well.
service_context = ServiceContext.from_defaults(llm=llm, chunk_size=256)

# Runs REBEL over every chunk and stores the triplets in Neo4j; this can take a while.
index = KnowledgeGraphIndex.from_documents(
    documents,
    storage_context=storage_context,
    kg_triplet_extract_fn=extract_triplets,
    service_context=service_context,
)

In [None]:
from llama_index.query_engine import KnowledgeGraphQueryEngine

# Question -> LLM-generated Cypher -> Neo4j result -> natural-language answer.
# verbose=True prints the generated Cypher, the raw graph response and the final
# answer. refresh_schema=True presumably re-reads the graph schema so the
# relations inserted above are visible to Cypher generation — confirm.
query_engine = KnowledgeGraphQueryEngine(
    llm=llm,
    storage_context=storage_context,
    service_context=service_context,
    refresh_schema=True,
    verbose=True,
)

question = "What is the highest point in Skyrim?"
response = query_engine.query(question)

[33;1m[1;3mGraph Store Query: 
MATCH (n:Entity)-[:HIGHEST_POINT]->(m:Entity) WHERE n.id = 'Skyrim' RETURN m.id
[0m[33;1m[1;3mGraph Store Response: [{'m.id': 'Throat of the World'}]
[0m[32;1m[1;3mFinal Response: 
The highest point in Skyrim is the Throat of the World.
[0m