## Overview

This notebook shows how to use LlamaIndex to combine a chatbot with an RDF knowledge graph.

The RDF data is converted into human-readable triples, which are then used to build a vector index. To do so, the triples are grouped into documents (one per subject), an embedding is computed for each document, and the embeddings are stored in a key-value store that maps them to the document contents.
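The indexing step can be illustrated with a minimal, dependency-free sketch. It assumes a hypothetical `toy_embed` function (a normalised bag-of-words vector over a fixed vocabulary) in place of the real sentence-transformers model, and a plain dict in place of LlamaIndex's `GPTSimpleVectorIndex`. The document texts are illustrative.

```python
import math
import re
from collections import Counter


# Hypothetical stand-in for the real embedding model: a unit-normalised
# bag-of-words vector over a fixed vocabulary.
def toy_embed(text: str, vocab: list[str]) -> list[float]:
    counts = Counter(re.findall(r"[a-z]+", text.lower()))
    vec = [float(counts[word]) for word in vocab]
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]


def build_store(documents: dict[str, str], vocab: list[str]) -> dict[str, tuple[list[float], str]]:
    """Map each document ID (here, the subject) to its (embedding, text) pair."""
    return {doc_id: (toy_embed(text, vocab), text) for doc_id, text in documents.items()}


# One document per subject, holding that subject's human-readable triples.
documents = {
    "Pikachu": "<Pikachu> <has Type> <Electric>\n<Pikachu> <has colour> <Yellow>",
    "Bulbasaur": "<Bulbasaur> <has Type> <Grass>\n<Bulbasaur> <has Type> <Poison>",
}
vocab = ["pikachu", "bulbasaur", "electric", "grass", "poison", "yellow", "type"]
store = build_store(documents, vocab)
```

Keying the store by subject keeps each document traceable back to the RDF resource it describes.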

When the user asks a question, LlamaIndex computes the question's embedding, looks up the top k most similar documents, and injects their content (triples) into the prompt as context.
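The lookup can be sketched as a linear scan ranked by cosine similarity. This is an assumption about the general technique, not LlamaIndex's internal code, and the 3-dimensional embeddings below are made up for illustration.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norms if norms else 0.0


def top_k(query_embedding: list[float], store: dict[str, tuple[list[float], str]], k: int = 2) -> list[str]:
    """Return the texts of the k documents most similar to the query (linear scan)."""
    ranked = sorted(store.values(), key=lambda entry: cosine(query_embedding, entry[0]), reverse=True)
    return [text for _, text in ranked[:k]]


# Hypothetical embeddings, for illustration only.
store = {
    "Pikachu": ([1.0, 0.0, 0.2], "<Pikachu> <has Type> <Electric>"),
    "Bulbasaur": ([0.0, 1.0, 0.1], "<Bulbasaur> <has Type> <Grass>"),
    "Raichu": ([0.9, 0.1, 0.3], "<Raichu> <has Type> <Electric>"),
}
context = top_k([1.0, 0.0, 0.0], store, k=2)
```

Here `context` holds the Pikachu and Raichu triples, which would then be placed in the prompt.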

The triples used as context can then be returned along with the answer, so that its reliability can be verified.

In [7]:

import os
from importlib import reload
from pathlib import Path
from typing import Optional, List, Mapping, Any

from dotenv import load_dotenv
import torch
from langchain.embeddings import HuggingFaceEmbeddings 
from llama_index import LangchainEmbedding, PromptHelper
from llama_index import GPTSimpleVectorIndex
from llama_index import LLMPredictor, ServiceContext
from rdflib import ConjunctiveGraph, Graph

import aikg.utils.llm as akllm
import aikg.utils.rdf as akrdf
reload(akrdf)
reload(akllm)
load_dotenv()

Loading checkpoint shards: 100%|██████████| 39/39 [00:22<00:00,  1.70it/s]


True

In [2]:

question = "What is the tallest Pokemon?"
prompt_template = """Question: {question}

Answer: Let's think step by step."""



## Test data

For the sake of this example, we use the [pokemon knowledge graph](https://www.pokemonkg.org/) developed by Kevin Haller, which aggregates data from multiple sources into a single RDF graph.

In [64]:
ontology = '/tmp/ontology.nt'
instances = '/tmp/instances.nq'

In [65]:
!curl https://www.pokemonkg.org/ontology/ontology.nt -o $ontology
!curl https://www.pokemonkg.org/download/dump/poke-a.nq.gz -o - | gzip -dc  > $instances

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  108k  100  108k    0     0   385k      0 --:--:-- --:--:-- --:--:--  386k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1599k  100 1599k    0     0  3869k      0 --:--:-- --:--:-- --:--:-- 3871k


In [66]:
!head $instances

<https://pokemonkg.org/dataset/artwork/sugimori-early-japan> <http://www.w3.org/1999/02/22-rdf-syntax-ns#type> <http://www.w3.org/ns/dcat#Dataset> <https://pokemonkg.org/dataset/artwork/sugimori-early-japan> .
<https://pokemonkg.org/dataset/artwork/sugimori-early-japan> <http://www.w3.org/1999/02/22-rdf-syntax-ns#type> <http://www.w3.org/ns/prov#Entity> <https://pokemonkg.org/dataset/artwork/sugimori-early-japan> .
<https://pokemonkg.org/dataset/artwork/sugimori-early-japan> <http://purl.org/dc/terms/accrualPeriodicity> <http://purl.org/linked-data/sdmx/2009/code#freq-A> <https://pokemonkg.org/dataset/artwork/sugimori-early-japan> .
<https://pokemonkg.org/dataset/artwork/sugimori-early-japan> <http://purl.org/dc/terms/description> "This dataset provides meta information about early version of official Pokémon artwork in Japan of Pokémon in the national Pokédex."@en <https://pokemonkg.org/dataset/artwork/sugimori-early-japan> .
<https://pokemonkg.org/dataset/artwork/sugimori-early-japan

The instance data is composed of multiple "contexts", or _named graphs_.
Here, we drop the contexts and split each graph by subject to create "instance graphs".
The ontology is loaded into a separate graph, which will be used later to obtain the labels of properties and classes.
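The idea of dropping contexts and grouping by subject can be sketched in plain Python on `(subject, predicate, object, graph)` tuples. `group_triples_by_subject` is a hypothetical helper, not the actual `aikg.utils.rdf.split_by_subject`. Note one difference: the notebook splits each named graph separately, so a subject spread over several graphs can yield several instance graphs, whereas this sketch merges them.

```python
from collections import defaultdict

# Hypothetical quads (subject, predicate, object, named graph), with
# abbreviated IRIs for readability.
quads = [
    ("ex:Pikachu", "rdf:type", "ex:Species", "ex:graphA"),
    ("ex:Pikachu", "ex:hasType", "ex:Electric", "ex:graphB"),
    ("ex:Bulbasaur", "ex:hasType", "ex:Grass", "ex:graphB"),
    ("ex:Pikachu", "ex:hasType", "ex:Electric", "ex:graphC"),  # same triple, other graph
]


def group_triples_by_subject(quads):
    """Drop the named-graph component, then group the remaining triples by subject.

    Using a set per subject also removes triples duplicated across graphs.
    """
    groups = defaultdict(set)
    for subject, predicate, obj, _graph in quads:
        groups[subject].add((subject, predicate, obj))
    return dict(groups)


instance_triples = group_triples_by_subject(quads)
```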

In [67]:
reload(akrdf)
# From quads to triples (dropping named graphs context)
instance_quads = ConjunctiveGraph()
instance_quads.parse(instances, format='nquads')

# One graph per subject (to generate documents)
instance_graphs = [s for g in instance_quads.contexts() for s in akrdf.split_by_subject(g)]
schema_graph = Graph().parse(ontology, format='nt')

LlamaIndex lets us specify custom models (instead of ChatGPT) both for text generation (i.e. question answering) and for computing the embeddings used for similarity search.

Here, we define our own LLM class for text generation, which uses the `Alpaca-lora-7b` model. For the embeddings, we use the default model of `HuggingFaceEmbeddings`, `sentence-transformers/all-mpnet-base-v2`.

In [9]:
reload(akllm)
# define prompt helper
# set maximum input size
max_input_size = 2048
# set number of output tokens
num_output = 256
# set maximum chunk overlap
max_chunk_overlap = 20

prompt_helper = PromptHelper(max_input_size, num_output, max_chunk_overlap)


# define our LLM
llm_predictor = LLMPredictor(llm=akllm.CustomLLM())
embed_model = LangchainEmbedding(HuggingFaceEmbeddings())

service_context = ServiceContext.from_defaults(
    llm_predictor=llm_predictor,
    prompt_helper=prompt_helper,
    embed_model=embed_model,
)


Loading checkpoint shards: 100%|██████████| 39/39 [00:37<00:00,  1.03it/s]
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: sentence-transformers/all-mpnet-base-v2
INFO:sentence_transformers.SentenceTransformer:Use pytorch device: cpu


The instance graphs are then converted to human-readable triples and loaded into a vector store.

During conversion, we also include the labels of the classes and properties, which are stored in the ontology graph.
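A possible way to render a triple with ontology labels is sketched below: use the label when one exists and fall back to the IRI's local name otherwise. `humanize`, `local_name`, and the `example.org` IRIs are hypothetical; the actual `CustomRDFReader` may work differently. The sketch treats every term as an IRI, so literals containing `/` or `#` would be mangled.

```python
import re


def local_name(iri: str) -> str:
    """Fallback label: the fragment after the last '#' or '/' of an IRI."""
    return re.split(r"[#/]", iri.rstrip("#/"))[-1]


def humanize(triple: tuple[str, str, str], labels: dict[str, str]) -> str:
    """Render a triple as '<s> <p> <o>', preferring ontology labels over local names."""
    return " ".join(f"<{labels.get(term, local_name(term))}>" for term in triple)


# Hypothetical labels, as they might be read from rdfs:label in the ontology.
labels = {"http://example.org/ontology#hasType": "has Type"}
triple = (
    "http://example.org/pokemon/Pikachu",
    "http://example.org/ontology#hasType",
    "http://example.org/type/Electric",
)
print(humanize(triple, labels))  # <Pikachu> <has Type> <Electric>
```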

In [74]:
reload(akrdf)
import joblib
# Load each instance graph as a document
# The schema is injected in each graph to
# provide human readable context
os.environ["TOKENIZERS_PARALLELISM"] = "false"
loader = akrdf.CustomRDFReader()
doc_graphs = map(lambda g: g | schema_graph, instance_graphs)
runner = joblib.Parallel(n_jobs=12)
documents = runner(joblib.delayed(loader.load_data)(g) for g in doc_graphs)
documents = [doc for doc in documents if doc.text]
index = GPTSimpleVectorIndex.from_documents(documents, service_context=service_context)



In [76]:
len(documents)

7045

When querying the index, the query embedding is computed and used to look up the index. The top k most similar documents are then selected and injected into the prompt as context.
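The injection step can be sketched as filling a question-answering template with the retrieved triples. `QA_TEMPLATE` is a hypothetical template loosely modelled on a typical QA prompt; it is not LlamaIndex's exact default.

```python
# Hypothetical QA template; LlamaIndex uses its own default prompt.
QA_TEMPLATE = (
    "Context information is below.\n"
    "---------------------\n"
    "{context}\n"
    "---------------------\n"
    "Given the context information, answer the question: {question}\n"
)


def build_prompt(question: str, retrieved_docs: list[str]) -> str:
    """Concatenate the retrieved documents and inject them into the template."""
    return QA_TEMPLATE.format(context="\n".join(retrieved_docs), question=question)


prompt = build_prompt(
    "What abilities does Pikachu have?",
    ["<Pikachu> <may have ability> <static>", "<Pikachu> <may have hidden ability> <lightning-rod>"],
)
print(prompt)
```

This shows only the simplest variant, where all retrieved documents go into a single prompt. Depending on the response mode, LlamaIndex may instead send the retrieved chunks one at a time and refine the answer.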

In [118]:
response = index.query(
    "What abilities does Pikachu have?",
    similarity_top_k=2
)

Batches: 100%|██████████| 1/1 [00:00<00:00,  8.21it/s]
INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 265 tokens
INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 6 tokens


In [119]:
import textwrap
textwrap.wrap(response.response, width=70)

[' Answer: Pikachu has the ability to use electricity, and may have the',
 'hidden ability of lightning-rod.']

The triples used to generate the answer can be extracted from the response's source nodes to verify its accuracy.

In [121]:
for node in response.source_nodes:
    print(node.node.text)

<Pikachu> <has Type> <Electric>
<Pikachu> <may have ability> <static>
<Pikachu> <in egg group> <Field>
<Pikachu> <in egg group> <Fairy>
<Pikachu> <has Type> <Electric>
<Pikachu> <has shape> <Quadruped>
<Pikachu> <has colour> <Yellow>
<Pikachu> <may have hidden ability> <lightning-rod>
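To check an answer more systematically, the printed source text can be parsed back into tuples. The sketch below assumes every line has the `<s> <p> <o>` shape shown above and that labels never contain `>`. `parse_triples` is a hypothetical helper.

```python
import re

TRIPLE_RE = re.compile(r"^<([^>]*)> <([^>]*)> <([^>]*)>$")


def parse_triples(text: str) -> set[tuple[str, str, str]]:
    """Parse '<s> <p> <o>' lines back into (s, p, o) tuples, skipping other lines."""
    triples = set()
    for line in text.splitlines():
        match = TRIPLE_RE.match(line.strip())
        if match:
            triples.add(match.groups())
    return triples


source_text = """<Pikachu> <has Type> <Electric>
<Pikachu> <may have ability> <static>
<Pikachu> <may have hidden ability> <lightning-rod>"""

triples = parse_triples(source_text)
abilities = sorted(o for _, p, o in triples if p in {"may have ability", "may have hidden ability"})
```

Comparing `abilities` with the generated answer shows that the model mentioned `lightning-rod` but not `static`. It also described "the ability to use electricity", which matches Pikachu's type rather than one of its listed abilities.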

The index can be persisted to disk for reuse. Since we use `GPTSimpleVectorIndex`, it is stored as a plain JSON file. Vector databases are also supported and should offer much better performance on larger datasets.

In [79]:
index.save_to_disk('/data/cyril/pokemon_index.json')
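A minimal sketch of persisting an in-memory store of `{doc_id: (embedding, text)}` as JSON and reading it back follows. The JSON layout is hypothetical and is not the file format LlamaIndex writes.

```python
import json
import tempfile
from pathlib import Path


def save_store(store: dict[str, tuple[list[float], str]], path: Path) -> None:
    """Serialise {doc_id: (embedding, text)} to a JSON file."""
    payload = {doc_id: {"embedding": emb, "text": text} for doc_id, (emb, text) in store.items()}
    path.write_text(json.dumps(payload), encoding="utf-8")


def load_store(path: Path) -> dict[str, tuple[list[float], str]]:
    """Read the JSON file back into the same in-memory structure."""
    payload = json.loads(path.read_text(encoding="utf-8"))
    return {doc_id: (entry["embedding"], entry["text"]) for doc_id, entry in payload.items()}


store = {"Pikachu": ([0.25, -0.5, 0.75], "<Pikachu> <has Type> <Electric>")}
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "pokemon_index.json"
    save_store(store, path)
    restored = load_store(path)
```

With this approach, the whole file must be loaded into memory and every vector scanned at query time. That is one reason a dedicated vector database tends to scale better for larger datasets.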

In [81]:
embed_model.get_query_embedding(question)

Batches: 100%|██████████| 1/1 [00:00<00:00, 11.39it/s]


[0.06434361636638641,
 -0.023223863914608955,
 -0.020594755187630653,
 0.06523242592811584,
 0.02275378443300724,
 ...]