## Overview

This notebook showcases the use of LlamaIndex to combine a chatbot with an RDF knowledge graph.

The RDF data is converted into human-readable triples, which are then used to build a vector index. To do this, the triples are split into documents (one per subject), an embedding is computed for each document, and the embeddings are stored in a key-value store that maps each embedding to its document's contents.

When the user enters a question, LlamaIndex computes the question embedding, looks up the top k most similar documents, and injects their contents (triples) as context into the prompt.

The triples used as context can then be returned along with the question to verify the reliability of the answer.
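
To make the retrieval step concrete, here is a minimal sketch of a top-k lookup by cosine similarity. It is an illustration only, not LlamaIndex's implementation: the three-dimensional vectors are made-up toy values (real embeddings from `all-mpnet-base-v2` are much longer), and `cosine`, `top_k`, `store` and `query` are hypothetical names.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def top_k(query_vec, store, k=2):
    """Return the texts of the k documents most similar to query_vec.

    `store` is a list of (embedding, text) pairs, standing in for the
    key-value store described above.
    """
    ranked = sorted(store, key=lambda item: cosine(query_vec, item[0]), reverse=True)
    return [text for _, text in ranked[:k]]

# Toy embeddings (assumed values, for illustration only)
store = [
    ([0.9, 0.1, 0.0], "<Psychic Terrain> <has Type> <Psychic>"),
    ([0.7, 0.3, 0.1], "<Psychic Surge> <effect description> <Turns the ground into Psychic Terrain>"),
    ([0.0, 0.2, 0.9], "<#039 Central Kalos> <describes Pokémon> <bibarel>"),
]
query = [1.0, 0.0, 0.0]  # pretend embedding of the user's question

for text in top_k(query, store, k=2):
    print(text)
```

The documents returned by `top_k` are what would be injected into the prompt as context.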

In [2]:

import os
from importlib import reload
from pathlib import Path
from typing import Optional, List, Mapping, Any

from dotenv import load_dotenv
import torch
from langchain.embeddings import HuggingFaceEmbeddings 
from llama_index import LangchainEmbedding, PromptHelper
from llama_index import GPTVectorStoreIndex
from llama_index import LLMPredictor, ServiceContext
from rdflib import ConjunctiveGraph, Graph

import aikg.utils.llm as akllm
import aikg.utils.rdf as akrdf
load_dotenv()

Loading checkpoint shards: 100%|██████████| 39/39 [00:37<00:00,  1.05it/s]
Loading checkpoint shards: 100%|██████████| 39/39 [00:17<00:00,  2.21it/s]


True

## Test data

For the sake of this example, we use the [Pokémon knowledge graph](https://www.pokemonkg.org/) developed by Kevin Haller, which aggregates data from multiple sources into a single RDF graph.

In [4]:
ontology = '/tmp/ontology.nt'
instances = '/tmp/instances.nq'

In [5]:
!curl https://www.pokemonkg.org/ontology/ontology.nt -o $ontology
!curl https://www.pokemonkg.org/download/dump/poke-a.nq.gz -o - | gzip -dc  > $instances

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  108k  100  108k    0     0   393k      0 --:--:-- --:--:-- --:--:--  394k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1599k  100 1599k    0     0  3846k      0 --:--:-- --:--:-- --:--:-- 3853k


In [6]:
!head $instances

<https://pokemonkg.org/dataset/artwork/sugimori-early-japan> <http://www.w3.org/1999/02/22-rdf-syntax-ns#type> <http://www.w3.org/ns/dcat#Dataset> <https://pokemonkg.org/dataset/artwork/sugimori-early-japan> .
<https://pokemonkg.org/dataset/artwork/sugimori-early-japan> <http://www.w3.org/1999/02/22-rdf-syntax-ns#type> <http://www.w3.org/ns/prov#Entity> <https://pokemonkg.org/dataset/artwork/sugimori-early-japan> .
<https://pokemonkg.org/dataset/artwork/sugimori-early-japan> <http://purl.org/dc/terms/accrualPeriodicity> <http://purl.org/linked-data/sdmx/2009/code#freq-A> <https://pokemonkg.org/dataset/artwork/sugimori-early-japan> .
<https://pokemonkg.org/dataset/artwork/sugimori-early-japan> <http://purl.org/dc/terms/description> "This dataset provides meta information about early version of official Pokémon artwork in Japan of Pokémon in the national Pokédex."@en <https://pokemonkg.org/dataset/artwork/sugimori-early-japan> .
<https://pokemonkg.org/dataset/artwork/sugimori-early-japan

The instance graph is composed of multiple "contexts", or _named graphs_.
Here, we remove the contexts and split graphs by subject to create "instance graphs".
The ontology is stored in a separate graph which will be used later to obtain the labels of properties and classes.
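
The idea of dropping contexts and grouping by subject can be sketched with plain tuples. This is a simplified illustration, not the code of `akrdf.split_graph_by_subject`: it merges triples across all named graphs, whereas the cell below splits each context separately. The `ex:` identifiers and the `split_by_subject` helper are hypothetical.

```python
from collections import defaultdict

def split_by_subject(quads):
    """Group (subject, predicate, object, context) quads by subject.

    The context is discarded, so identical triples that appear in
    several named graphs collapse into a single triple.
    """
    groups = defaultdict(set)
    for subject, predicate, obj, _context in quads:
        groups[subject].add((subject, predicate, obj))
    return dict(groups)

# Hypothetical quads; the fourth element is the named graph
quads = [
    ("ex:s1", "rdf:type", "dcat:Dataset", "ex:g1"),
    ("ex:s1", "rdf:type", "dcat:Dataset", "ex:g2"),  # same triple, other context
    ("ex:s1", "dct:description", "some text", "ex:g1"),
    ("ex:s2", "rdf:type", "prov:Entity", "ex:g1"),
]

groups = split_by_subject(quads)
for subject, triples in groups.items():
    print(subject, len(triples))
```

Each value in `groups` plays the role of one "instance graph", which later becomes one document.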

In [7]:
reload(akrdf)
# From quads to triples (dropping named graphs context)
instance_quads = ConjunctiveGraph()
instance_quads.parse(instances, format='nquads')

# One graph per subject (to generate documents)
instance_graphs = [s for g in instance_quads.contexts() for s in akrdf.split_graph_by_subject(g)]
schema_graph = Graph().parse(ontology, format='nt')

LlamaIndex allows us to specify custom models (instead of ChatGPT) for both text generation (i.e. question answering) and computing the embeddings used for similarity search.

Here, we define our own LLM class for text generation, which uses the `Alpaca-lora-7b` model. For the embeddings, we use the default model from `HuggingFaceEmbeddings`, which is `sentence-transformers/all-mpnet-base-v2`.

In [9]:
# define prompt helper
# set maximum input size
max_input_size = 2048
# set number of output tokens
num_output = 256
# set maximum chunk overlap
max_chunk_overlap = 20

prompt_helper = PromptHelper(max_input_size, num_output, max_chunk_overlap)
import aikg.config.chat
from transformers import pipeline
from langchain.llms.base import LLM
LLM_CONFIG = aikg.config.chat.ChatConfig()


class CustomLLM(LLM):
    model_id: str = LLM_CONFIG.model_id
    pipeline = pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
    )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        max_new_tokens: int = LLM_CONFIG.max_new_tokens,
    ) -> str:
        prompt_length = len(prompt)
        response = self.pipeline(prompt, max_new_tokens=max_new_tokens)[0][
            "generated_text"
        ]

        # only return newly generated tokens
        return response[prompt_length:]

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"name_of_model": self.model_id}

    @property
    def _llm_type(self) -> str:
        return "custom"

# define our LLM
llm_predictor = LLMPredictor(llm=CustomLLM())
embed_model = LangchainEmbedding(HuggingFaceEmbeddings())

service_context = ServiceContext.from_defaults(
    llm_predictor=llm_predictor,
    prompt_helper=prompt_helper,
    embed_model=embed_model,
)


INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: sentence-transformers/all-mpnet-base-v2
INFO:sentence_transformers.SentenceTransformer:Use pytorch device: cpu


The instance graphs are then converted to human-readable triples and loaded into a vector store.

During conversion, we also include the labels of the classes and properties, which are stored in the ontology graph.
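
As a rough sketch of what this conversion might look like, the snippet below replaces each URI with its label when one is known and falls back to the last segment of the URI otherwise. It is not the implementation of `akrdf.CustomRDFReader`; the `example.org` URIs, the `labels` dictionary and the helper names are assumptions made for illustration.

```python
def local_name(uri):
    """Fallback label: the fragment or last path segment of a URI."""
    return uri.rstrip("/#").replace("#", "/").rsplit("/", 1)[-1]

def readable(term, labels):
    """Use the ontology label if there is one, else the local name."""
    return labels.get(term, local_name(term))

def triple_to_text(triple, labels):
    """Render a (subject, predicate, object) triple as '<s> <p> <o>'."""
    return " ".join(f"<{readable(term, labels)}>" for term in triple)

# Hypothetical labels, as they might be read from rdfs:label in the ontology
labels = {
    "https://example.org/move/psychic-terrain": "Psychic Terrain",
    "https://example.org/onto/hasType": "has Type",
}
triple = (
    "https://example.org/move/psychic-terrain",
    "https://example.org/onto/hasType",
    "https://example.org/type#Psychic",  # no label: falls back to "Psychic"
)
print(triple_to_text(triple, labels))
```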

The index is then built using the vector store, as a key-value store mapping the embeddings to the human-readable triples. Each document in the dataset is related to one subject.

The index can be persisted to disk for reuse. As we are using the default simple vector store, it is stored as plain JSON. Vector databases are also supported and should provide much better performance for larger datasets.

In [None]:
import joblib
# Load each instance graph as a document
# The schema is injected in each graph to
# provide human readable context
os.environ["TOKENIZERS_PARALLELISM"] = "false"
loader = akrdf.CustomRDFReader()
doc_graphs = map(lambda g: g | schema_graph, instance_graphs)
runner = joblib.Parallel(n_jobs=12)
documents = runner(joblib.delayed(loader.load_data)(g) for g in doc_graphs)
documents = [doc for doc in documents if doc.text]
index = GPTVectorStoreIndex.from_documents(documents, service_context=service_context)
index.storage_context.persist(persist_dir=os.environ["INDEX_PATH"])

# Or load from disk if pre-computed
# from llama_index import StorageContext, load_index_from_storage
# storage_context = StorageContext.from_defaults(persist_dir=os.environ["INDEX_PATH"])
# index = load_index_from_storage(storage_context)
# index = GPTSimpleVectorIndex.load_from_disk('<path>/pokemon_index.json', service_context=service_context)


In [93]:
len(documents)

Document(text='<#039 Central Kalos> <describes Pokémon> <bibarel>\n<#039 Central Kalos> <has entry number> <39>', doc_id='2447e6a4-967f-4ebb-9c76-8d9556f96a30', embedding=None, doc_hash='64f60290aac9456e62f68ef2269cfab135bd48bc14be5fc3d21cdc13b0f34478', extra_info=None)

In [65]:
from llama_index import QuestionAnswerPrompt
QA_PROMPT_TMLP = (
    "We have provided the contextual facts below in the form of "
    "<subject> <predicate> <object> triples. \n"
    "-----------------\n"
    "{context_str}"
    "\n-----------------\n"
    "Please answer the question using only the context and no "
    "prior knowledge. If the context is insufficient to answer "
    "the question, simply answer the words 'Not found'.\n"
    "Question: {query_str}\n"
)

QA_PROMPT = QuestionAnswerPrompt(QA_PROMPT_TMLP)
query_engine = index.as_query_engine(
    text_qa_template=QA_PROMPT,
    similarity_top_k=2
)


When querying the index, the query embedding is computed and used to search the index. The top k most similar documents are then selected and injected as context into the prompt.
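
The context injection itself amounts to string templating. The sketch below shows the idea with `str.format`; LlamaIndex does this internally, so `TEMPLATE` and `build_prompt` are hypothetical names, and the template is a shortened variant of the one defined above.

```python
TEMPLATE = (
    "We have provided the contextual facts below in the form of "
    "<subject> <predicate> <object> triples.\n"
    "-----------------\n"
    "{context_str}\n"
    "-----------------\n"
    "Question: {query_str}\n"
)

def build_prompt(retrieved_docs, question):
    """Join the retrieved documents and fill in the prompt template."""
    context = "\n".join(retrieved_docs)
    return TEMPLATE.format(context_str=context, query_str=question)

# Documents as they might come back from the similarity search
retrieved = [
    "<Psychic Terrain> <has Type> <Psychic>",
    "<Psychic Surge> <effect description> <Turns the ground into Psychic Terrain>",
]
prompt = build_prompt(retrieved, "What does Psychic terrain do?")
print(prompt)
```

With `similarity_top_k=2`, as configured above, two documents are placed in the context block.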

In [68]:
response = query_engine.query(
    "What does Psychic terrain do?",
)

Batches: 100%|██████████| 1/1 [00:00<00:00, 16.14it/s]
INFO:llama_index.token_counter.token_counter:> [retrieve] Total LLM token usage: 0 tokens
INFO:llama_index.token_counter.token_counter:> [retrieve] Total embedding token usage: 6 tokens
INFO:llama_index.token_counter.token_counter:> [get_response] Total LLM token usage: 176 tokens
INFO:llama_index.token_counter.token_counter:> [get_response] Total embedding token usage: 0 tokens
INFO:llama_index.token_counter.token_counter:> [get_response] Total LLM token usage: 176 tokens
INFO:llama_index.token_counter.token_counter:> [get_response] Total embedding token usage: 0 tokens


In [69]:
import textwrap
textwrap.wrap(response.response, width=70)

['Answer: Psychic terrain protects Pokémon on the ground from priority',
 'moves and powers up Psychic-type moves for five turns.']

The triples used to generate the answer can be extracted from the result to verify the accuracy of the answer.

In [64]:
# Print the text of each source document used as context
for node in response.source_nodes:
    print(node.node.text)

<Psychic Terrain> <has Type> <Psychic>
<Psychic Terrain> <effect description> <This protects Pokémon on the ground from priority
moves and powers up Psychic-type moves for
five turns.>
<Psychic Surge> <effect description> <Turns the ground into Psychic Terrain when
the Pokémon enters a battle.>

In [73]:
embed_model.get_query_embedding("What color is Pikachu?")

Batches: 100%|██████████| 1/1 [00:00<00:00, 11.45it/s]


[0.013918144628405571,
 0.010103491134941578,
 -0.0029478929936885834,
 0.025255318731069565,
 -0.024396510794758797,
 0.010354544967412949,
 -0.012173263356089592,
 0.0032943696714937687,
 -0.09720435738563538,
 0.012091472744941711,
 -0.008735760115087032,
 0.01527224387973547,
 0.007638393435627222,
 -0.021960949525237083,
 0.03280056267976761,
 -0.08146408200263977,
 0.0067080180160701275,
 -0.026102768257260323,
 0.019261838868260384,
 0.00051346723921597,
 -0.02615729346871376,
 0.02991742081940174,
 0.012400462292134762,
 0.035565562546253204,
 0.006379717495292425,
 -0.04566677659749985,
 -0.07444821298122406,
 -0.009689990431070328,
 -0.04008793458342552,
 0.01361369900405407,
 0.024228952825069427,
 -0.04263141751289368,
 -0.0535460039973259,
 -0.02025439217686653,
 1.2996208624826977e-06,
 -0.03838735818862915,
 0.044721636921167374,
 0.02928580902516842,
 0.06678709387779236,
 0.0011090098414570093,
 -0.014865274541079998,
 -0.02515535056591034,
 0.03177135810256004,
 -0.01