In [None]:
import pandas as pd
import torch
import gc
import Stemmer
import re

## Explore nodes and prompts

In [126]:
fbqa = pd.read_json("/datasets/FreebaseQA/FreebaseQA-eval.json")
fbqa.head()

Unnamed: 0,Dataset,Version,Questions
0,FreebaseQA-eval,1,"{'Question-ID': 'FreebaseQA-eval-0', 'RawQuest..."
1,FreebaseQA-eval,1,"{'Question-ID': 'FreebaseQA-eval-1', 'RawQuest..."
2,FreebaseQA-eval,1,"{'Question-ID': 'FreebaseQA-eval-2', 'RawQuest..."
3,FreebaseQA-eval,1,"{'Question-ID': 'FreebaseQA-eval-3', 'RawQuest..."
4,FreebaseQA-eval,1,"{'Question-ID': 'FreebaseQA-eval-4', 'RawQuest..."


In [127]:
def get_fbqa_data(question_row):
    """
    Takes in a dataset row and returns Q and A as strings
    """
    question = question_row.Questions.get("RawQuestion", None)
    parse = question_row.Questions.get("Parses", [None])[0]
    if not parse:
        print(f"error in question: {question}")
        return question, None
    answer = parse.get("Answers")
    return question, answer

In [128]:
get_fbqa_data(fbqa.iloc[0])

("Who is the female presenter of the Channel 4 quiz show '1001 things you should know'?",
 [{'AnswersMid': 'm.0216y_', 'AnswersName': ['sandi toksvig']}])

### Text preprocessing

In [129]:
stemmer = Stemmer.Stemmer("english")
alphanum = r"[a-zA-Z0-9_-]*"
non_alphanum = r"[^a-zA-Z0-9]"

def preprocess_text(text):
    # tokenize
    text = text.lower()
    text = re.sub(non_alphanum, "\n", text)
    tokens = list(filter(lambda token: bool(token), text.split("\n")))
    # normalize
    tokens = list(map(lambda token: stemmer.stemWord(token), tokens))
    return " ".join(tokens)

In [130]:
preprocess_text("Henry_Fonda")

'henri fonda'

## Load KB
combine train, dev and test

In [131]:
mid2name = pd.read_csv("/datasets/FB15k-237/mid2name.txt", sep="\t", header=None)
print(len(mid2name), "records")
mid2name[1] = mid2name[1].apply(preprocess_text)
mid2name.head()

14951 records


Unnamed: 0,0,1
0,/m/06rf7,schleswig holstein
1,/m/0c94fn,gari rydstrom
2,/m/016ywr,jeremi iron
3,/m/01yjl,chicago cub
4,/m/02hrh1q,actor gb


In [132]:
mid2name_dict = dict(zip(mid2name[0], mid2name[1]))

In [133]:
name2mid = dict(zip(mid2name[1], mid2name[0]))
name2mid

{'schleswig holstein': '/m/06rf7',
 'gari rydstrom': '/m/0c94fn',
 'jeremi iron': '/m/016ywr',
 'chicago cub': '/m/01yjl',
 'actor gb': '/m/02hrh1q',
 'anthoni horowitz': '/m/03ftmg',
 'chapman univers': '/m/04p_hy',
 'freddi rodriguez': '/m/06151l',
 'spongebob squarep': '/m/07vqnc',
 'dream theater': '/m/0jn38',
 'focus featur': '/m/024rbz',
 'henri v': '/m/016ywb',
 'himesh reshammiya': '/m/01pr_j6',
 'maggi gyllenha': '/m/023mdt',
 'iron': '/m/025rw19',
 'brompton cemeteri': '/m/034wx3',
 'univers of california irvin': '/m/0c5x_',
 'aacta award for best guest or support actor in a televis drama': '/m/0fqpg6b',
 'hunter s thompson': '/m/03qcq',
 'hang drawn and quarter': '/m/0lf_w',
 'usc trojan footbal': '/m/07k53y',
 'the pirat movi': '/m/06zn1c',
 'fantast four rise of the silver surfer': '/m/09sh8k',
 'port elizabeth': '/m/06f0y3',
 'univers of the pacif': '/m/033x5p',
 'univers of idaho': '/m/0f102',
 'columbia busi school': '/m/095kp',
 'fame': '/m/02zk08',
 'rector': '/m/0hm4

In [134]:
fbkb = pd.DataFrame()
for split in ["train", "valid", "test"]:
    path = f"/datasets/FB15k-237/{split}.txt"
    split_df = pd.read_csv(path, sep="\t", header=None)
    print(split, len(split_df), "records")
    fbkb = pd.concat([fbkb, split_df])
print("total", len(fbkb))
fbkb.rename(columns={0: "subject", 1: "relation", 2: "object"}, inplace=True)
fbkb.head()

train 272115 records
valid 17535 records
test 20466 records
total 310116


Unnamed: 0,subject,relation,object
0,/m/027rn,/location/country/form_of_government,/m/06cx9
1,/m/017dcd,/tv/tv_program/regular_cast./tv/regular_tv_app...,/m/06v8s0
2,/m/07s9rl0,/media_common/netflix_genre/titles,/m/0170z3
3,/m/01sl1q,/award/award_winner/awards_won./award/award_ho...,/m/044mz_
4,/m/0cnk2q,/soccer/football_team/current_roster./sports/s...,/m/02nzb8


In [135]:
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, KnowledgeTriple
fbkb_graph = NetworkxEntityGraph()

In [136]:
for i, r in fbkb.iterrows():
    # subject = preprocess_text(mid2name_dict[r.subject])
    # object_ = preprocess_text(mid2name_dict[r.object])
    triplet = KnowledgeTriple(
        r.subject,
        r.relation,
        r.object,
    )
    fbkb_graph.add_triple(triplet)


In [137]:
print(len(fbkb_graph.get_triples()))
fbkb_graph.get_triples()[1]

283868


('/m/027rn',
 '/m/0l6ny',
 '/olympics/olympic_participating_country/medals_won./olympics/olympic_medal_honor/olympics')

In [14]:
from typing import Any, Dict, Iterator, List, Mapping, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain.chains import GraphQAChain

In [15]:
from mistral_inference.transformer import Transformer
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

from mistral_inference.generate import generate
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest

In [44]:
# device="cuda:1"
class MistralLLM(LLM):
    tokenizer: MistralTokenizer
    model: Transformer

    def __init__(self):
        tokenizer = MistralTokenizer.from_file("/models/M7B/tokenizer.model.v3")  # change to extracted tokenizer file
        model = Transformer.from_folder("/models/M7B")  # change to extracted model dir
        super().__init__(model=model, tokenizer=tokenizer) 

    def get_response(self, prompt: str) -> str:
        tokenizer = self.tokenizer
        model = self.model
        completion_request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
        tokens = tokenizer.encode_chat_completion(completion_request).tokens
        out_tokens, _ = generate([tokens], model, max_tokens=1024, temperature=0.35, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
        result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
        return result

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        clear_cache()
        return self.get_response(prompt)
    
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters."""
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per token pricing for their model and monitor
            # costs for the given LLM.)
            "model_name": "Mistral-7B-0.3",
        }

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model. Used for logging purposes only."""
        return "mistral-7b-0.3"

In [17]:
# torch.set_default_device('cuda:1')
# torch.get_default_device()

In [None]:
mistral = MistralLLM()

In [43]:
# del mistral
def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()

In [20]:
from langchain_core.prompts.prompt import PromptTemplate

EXTRACT_ENTITIES_TEMPLATE = """
Below is a question, extract all entities and return them separeted by a coma.
Only return the entities with no extra text and separators. 
Do not introduce or deduce anything not present in the question.
Question:
{question}"""

entity_prompt = PromptTemplate(
    input_variables=["question"],
    template=EXTRACT_ENTITIES_TEMPLATE
)


In [21]:
LEN_LIMITED_PROMPT = """
You will be given a context containing relevant entity triplets and a question.

Using the context, you should answer the question below in under a sentence with no other text but the answer.
Context:
{context}

Question:
{question}
"""

GRAPH_QA_PROMPT = PromptTemplate(
    template=LEN_LIMITED_PROMPT, input_variables=["context", "question"]
)


In [22]:
def get_closest_entity_mid(entity):
    entity = preprocess_text(entity)
    

In [None]:
!pip install tabulate

: 

In [49]:
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_community.graphs.networkx_graph import get_entities

class GraphChain(GraphQAChain):
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Extract entities, look up info and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]

        entity_string = self.entity_extraction_chain.run(question)

        _run_manager.on_text("Entities Extracted:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            entity_string, color="green", end="\n", verbose=self.verbose
        )
        entities = get_entities(entity_string)
        context = ""
        all_triplets = []
        for entity in entities:
            # introduce preprocessing
            processed_entity = preprocess_text(entity)
            entity_mid = name2mid.get(processed_entity, None)
            # continue
            triplets = self.graph.get_entity_knowledge(entity_mid)
            for triplet in triplets:
                t = triplet.split()
                all_triplets.append(" ".join([mid2name_dict.get(t[0], ""), t[1], mid2name_dict.get(t[2], "")]))
            print(f"Extracted {len(all_triplets)} triplets")
        # truncate at 600 triplets
        context = "\n".join(all_triplets[:600])
        _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        _run_manager.on_text(context, color="green", end="\n", verbose=self.verbose)
        result = self.qa_chain(
            {"question": question, "context": context},
            callbacks=_run_manager.get_child(),
        )
        return {self.output_key: result[self.qa_chain.output_key]}

In [50]:
chain = GraphChain.from_llm(
    llm=mistral, 
    graph=fbkb_graph, 
    qa_prompt=GRAPH_QA_PROMPT,
    entity_prompt=entity_prompt,
    verbose=True,
)


In [33]:
N = 84
print(fbqa.iloc[N].Questions["Parses"])
q = get_fbqa_data(fbqa.iloc[N])
print(q[1])
q[0]

[{'Parse-Id': 'FreebaseQA-eval-84.P0', 'PotentialTopicEntityMention': 'africa', 'TopicEntityName': 'africa', 'TopicEntityMid': 'm.0dg3n1', 'InferentialChain': 'base.locations.continents.countries_within', 'Answers': [{'AnswersMid': 'm.0h3y', 'AnswersName': ['algeria']}]}]
[{'AnswersMid': 'm.0h3y', 'AnswersName': ['algeria']}]


'Following the creation of South Sudan, which is now the largest country by area in Africa?'

In [34]:
mistral(q[0])

"The largest country by area in Africa after the creation of South Sudan is Algeria. It is approximately 2.38 million square kilometers, significantly larger than South Sudan's 644,329 square kilometers."

In [54]:
clear_cache()
r = chain.invoke(q[0])
print(r["result"])



[1m> Entering new GraphChain chain...[0m
Entities Extracted:
[32;1m[1;3mAlgeria, Democratic Republic of the Congo, Sudan, Libya, Chad, Mali, Niger, Angola, Republic of the Congo, Central African Republic, South Africa, Zambia, Tanzania, Ethiopia, Nigeria, Egypt, Tunisia, Morocco, Mauritania, Senegal, Gabon, Botswana, Namibia, Zimbabwe, Uganda, Kenya, Somalia, Eritrea, Benin, Burkina Faso, Ghana, Guinea, Cote d'Ivoire, Malawi, Mozambique, Lesotho, Swaziland, Gambia, Liberia, Sierra Leone, Rwanda, Burundi, Cape Verde, Comoros, Djibouti, Equatorial Guinea, Guinea-Bissau, Madagascar, Mauritius, Seychelles, Sao Tome and Principe, Togo, and South Sudan (entities present in the question but not the largest country by area in Africa are not included in the answer)[0m
Extracted 38 triplets
Extracted 60 triplets
Extracted 90 triplets
Extracted 108 triplets
Extracted 126 triplets
Extracted 147 triplets
Extracted 165 triplets
Extracted 191 triplets
Extracted 191 triplets
Extracted 207 tripl

In [29]:
q0 = fbqa.iloc[1]
q = q0.Questions["RawQuestion"]

In [30]:
q0.Questions

{'Question-ID': 'FreebaseQA-eval-1',
 'RawQuestion': 'Who produced the film 12 Angry Men, which was scripted by Reginald Rose, starred Henry Fonda and was directed by Sidney Lumet?',
 'ProcessedQuestion': 'who produced the film 12 angry men, which was scripted by reginald rose, starred henry fonda and was directed by sidney lumet',
 'Parses': [{'Parse-Id': 'FreebaseQA-eval-1.P0',
   'PotentialTopicEntityMention': '12 angry men',
   'TopicEntityName': '12 angry men',
   'TopicEntityMid': 'm.0m_tj',
   'InferentialChain': 'film.film.produced_by',
   'Answers': [{'AnswersMid': 'm.0cj8x', 'AnswersName': ['henry fonda']}]}]}

In [31]:
mid2name[mid2name[1].str.contains("Lumet")]

Unnamed: 0,0,1


In [32]:
r = mistral()
print(r)

TypeError: BaseLLM.__call__() missing 1 required positional argument: 'prompt'

In [None]:
chain.invoke(q)

In [None]:
mid2name[mid2name[1].str.contains("")]

In [None]:
list(fbkb_graph.get_neighbors("Channel_4"))

In [None]:
mid = name2mid["1001 things you should know"]
fbkb_graph.get_neighbors(mid)