In [2]:
import csv
import gc
import os
import re
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
import Stemmer
import torch
from tqdm import tqdm

from utils.graph import KGraphPreproc
from utils.llm.mistral import MistralLLM
from utils.prompt import GRAPH_QA_PROMPT, ENTITY_PROMPT, NO_CONTEXT_PROMPT, EVALUATE_CONTEXT_PROMPT

# ToG Algorithm pseudocode
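Given the list of entities $E^{D-1}$ reached at depth $D-1$:
1. Select all relations $R^D$ attached to those entities.
2. Select the top-N most relevant triplets of the form $E^{D-1}$-$R^D$-? or ?-$R^D$-$E^{D-1}$.
3. Extend the reasoning paths $P$ with the top-N triplets.
4. Decide whether $P$ is enough to answer the question (YES/NO).
5. If NO, repeat from step 1 at the next depth; if YES, return the answer.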


In [1]:
# Scratch check of set semantics: which elements of `b` are missing from `a`?
a = {1, 2, 3}
b = {3, 4, 5}
b - a

{4, 5}

In [None]:
class ToGChain(GraphQAChain):
    """Think-on-Graph (ToG) question answering chain.

    Starting from the entities mentioned in the question, the chain repeatedly
    (1) extracts the triplets around the current frontier entities,
    (2) keeps the ``beam_width`` most relevant ones, (3) appends them to the
    reasoning paths and (4) asks the LLM whether the collected context is
    enough to answer. If it is, the answer is generated from that context;
    if the exploration budget is exhausted, the LLM answers without context.

    Triplets are handled as ``(head, relation, tail)`` tuples and are only
    rendered as ``"head relation tail"`` text when a prompt context is built.

    NOTE(review): this class relies on names that the visible import cell does
    not provide (``GraphQAChain``, ``LLMChain``, ``CallbackManagerForChainRun``,
    ``BaseLanguageModel``, ``BasePromptTemplate``, ``nx``, ``preprocess_text``,
    ``get_entities``) - make sure they are imported before running this cell.
    """

    # Number of expansion steps performed after the seed step.
    max_exploration_depth: int = 7
    # Number of triplets kept per exploration step.
    beam_width: int = 5
    # Not referenced by any method of this class yet.
    max_triplets: int = 250
    plain_qa_chain: LLMChain
    evaluate_context_chain: LLMChain
    # Optional relevance ranker exposing ``rank(question, triplets, top_k=...)``
    # and returning a list of the given triplets. When left as None, a simple
    # token-overlap ranking is used instead (see ``_rank_triplets``).
    ranker: Optional[Any] = None

    def tog_answer(self, question, run_manager=None):
        """
        Think-on-Graph QA implementation

        Args:
            question: natural-language question to answer.
            run_manager: optional callback manager used for logging and for
                child-chain callbacks; a no-op manager is used when omitted.

        Returns:
            The output of the QA chain that produced the answer (context-based
            when the explored triplets were judged sufficient, context-free
            otherwise).
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        explored_entities = set()
        P = [
            # step: [triplet1, ...],
        ]  # at most (max_depth + 1) X beam_width

        # Depth 0 is the seed step (entities found in the question); it is
        # followed by up to ``max_exploration_depth`` expansion steps.
        frontier = set(self.extract_entities(question, run_manager=_run_manager))
        for depth in range(self.max_exploration_depth + 1):
            _run_manager.on_text(f"Exploring depth: {depth}", end="\n", verbose=self.verbose)
            ######## Extraction
            # only expand entities that have not been visited before
            unexplored_entities = frontier.difference(explored_entities)
            if not unexplored_entities:
                # nothing new to expand: further iterations cannot add context
                break
            explored_entities.update(unexplored_entities)
            # extract all triplets and keep the most relevant ones
            triplets = self.extract_triplets(unexplored_entities)
            ranked_triplets = self._rank_triplets(question, triplets, top_k=self.beam_width)
            if not ranked_triplets:
                break
            # extend the paths by next step
            P.append(ranked_triplets)
            ######## Reasoning
            is_enough = self.check_enough_context_to_answer(
                question, reasoning_chain=P, run_manager=_run_manager
            )
            if is_enough:
                return self.generate_answer_with_context(
                    question, reasoning_chain=P, run_manager=_run_manager
                )
            # The tails of the step just added form the next frontier; empty
            # names (nodes missing from mid2name) cannot be looked up again.
            frontier = {triplet[-1] for triplet in P[-1] if triplet[-1]}
        # otherwise, use internal knowledge
        return self.generate_answer_no_context(question, run_manager=_run_manager)

    #### GRAPH OPERATIONS
    def extract_triplets(self, entities):
        """Collect the triplets of the subgraph surrounding ``entities``.

        Each entity is normalised with ``preprocess_text`` and looked up in
        ``self.graph.preprocessed_nodes``; entities without a matching node are
        skipped. All nodes within ``self.exploration_depth`` hops of a matched
        node are gathered and every edge of the induced subgraph is returned.

        Returns:
            A list of ``(head_name, relation, tail_name)`` tuples. Names come
            from ``self.graph.mid2name`` ('' when a node has no name) and the
            relation is the edge's ``"relation"`` attribute (None if absent).
        """
        subgraph_entities = set()
        for entity in entities:
            processed_entity = preprocess_text(entity)
            node = self.graph.preprocessed_nodes.get(processed_entity)
            if node is None:
                continue
            # NOTE(review): ``exploration_depth`` is not declared on this class
            # (only ``max_exploration_depth`` is); presumably it is inherited
            # from GraphQAChain - confirm.
            bfs_nodes = nx.single_source_shortest_path_length(
                self.graph._graph,
                node,
                cutoff=self.exploration_depth
            ).keys()
            subgraph_entities.update(bfs_nodes)
        subgraph = self.graph._graph.subgraph(subgraph_entities)
        mid2name_dict = self.graph.mid2name
        return [
            (
                mid2name_dict.get(head, ''),
                attrs.get("relation", None),
                mid2name_dict.get(tail, ''),
            )
            for head, tail, attrs in subgraph.edges(data=True)
        ]

    def _rank_triplets(self, question, triplets, top_k):
        """Return the ``top_k`` triplets most relevant to ``question``.

        Delegates to ``self.ranker.rank`` when a ranker is configured.
        Otherwise triplets are scored by how many distinct word tokens of the
        question occur in their text; ties keep the extraction order.
        """
        if not triplets:
            return []
        if self.ranker is not None:
            return list(self.ranker.rank(question, triplets, top_k=top_k))
        question_tokens = set(re.findall(r"\w+", question.lower()))

        def overlap(triplet):
            triplet_tokens = re.findall(r"\w+", self._format_triplet(triplet).lower())
            return len(question_tokens.intersection(triplet_tokens))

        # sorted() is stable, also with reverse=True, so equal scores keep order
        return sorted(triplets, key=overlap, reverse=True)[:top_k]

    @staticmethod
    def _format_triplet(triplet):
        """Render a ``(head, relation, tail)`` tuple as ``"head relation tail"``."""
        return " ".join(str(part) for part in triplet)

    def _format_context(self, reasoning_chain):
        """Flatten the per-step triplet lists into one newline-separated string."""
        return "\n".join(
            self._format_triplet(triplet)
            for step in reasoning_chain
            for triplet in step
        )

    #### PROMPTING AND CHAINS
    def generate_answer_with_context(self, question, reasoning_chain, run_manager=None):
        """Answer ``question`` with ``self.qa_chain`` using the explored triplets as context."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        context = self._format_context(reasoning_chain)
        _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        _run_manager.on_text(context, color="green", end="\n", verbose=self.verbose)
        return self.qa_chain(
            {"question": question, "context": context},
            callbacks=_run_manager.get_child(),
        )

    def generate_answer_no_context(self, question, run_manager=None):
        """Answer ``question`` with ``self.plain_qa_chain``, i.e. without graph context."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        return self.plain_qa_chain(
            {"question": question},
            callbacks=_run_manager.get_child(),
        )

    def check_enough_context_to_answer(self, question, reasoning_chain, run_manager=None):
        """Ask ``self.evaluate_context_chain`` whether the context suffices.

        Returns:
            True when the chain's text output contains "yes" (case-insensitive),
            False otherwise.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        context = self._format_context(reasoning_chain)
        _run_manager.on_text("Checking if the context is sufficient:", end="\n", verbose=self.verbose)
        _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        _run_manager.on_text(context, color="green", end="\n", verbose=self.verbose)
        result = self.evaluate_context_chain(
            {"question": question, "context": context},
            callbacks=_run_manager.get_child(),
        )
        # Calling a chain returns a dict of outputs; the verdict is its text entry.
        answer = result[self.evaluate_context_chain.output_key]
        return "yes" in answer.strip().lower()

    def extract_entities(self, string, run_manager=None):
        """Extract entity mentions from ``string`` via ``self.entity_extraction_chain``."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        entity_string = self.entity_extraction_chain.run(string)
        _run_manager.on_text("Entities Extracted:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            entity_string, color="green", end="\n", verbose=self.verbose
        )
        entities = get_entities(entity_string)
        return entities

    #### INITIALIZATION AND INVOCATION

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Extract entities, look up info and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        # the run manager is handed down explicitly so that helper methods can log
        result = self.tog_answer(question, run_manager=_run_manager)
        return {self.output_key: result[self.qa_chain.output_key]}

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        qa_prompt: BasePromptTemplate = GRAPH_QA_PROMPT,
        qa_no_context_prompt: BasePromptTemplate = NO_CONTEXT_PROMPT,
        evaluate_context_prompt: BasePromptTemplate = EVALUATE_CONTEXT_PROMPT,
        entity_prompt: BasePromptTemplate = ENTITY_PROMPT,
        **kwargs: Any,
    ) -> GraphQAChain:
        """Initialize from LLM.

        Builds one ``LLMChain`` per prompt (context QA, context-free QA,
        context evaluation, entity extraction) on top of ``llm`` and forwards
        any remaining keyword arguments to the class constructor.
        """
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        plain_qa_chain = LLMChain(llm=llm, prompt=qa_no_context_prompt)
        evaluate_context_chain = LLMChain(llm=llm, prompt=evaluate_context_prompt)
        entity_chain = LLMChain(llm=llm, prompt=entity_prompt)

        return cls(
            qa_chain=qa_chain,
            plain_qa_chain=plain_qa_chain,
            evaluate_context_chain=evaluate_context_chain,
            entity_extraction_chain=entity_chain,
            **kwargs,
        )

In [None]:
# Load the FreebaseQA evaluation questions and the preprocessed Freebase knowledge graph.
fbqa = pd.read_json(path_or_buf="/datasets/FreebaseQA/FreebaseQA-eval.json")
fbkb_graph = KGraphPreproc.get_fbkb_graph()

In [None]:
# Instantiate the LLM wrapper and build the Think-on-Graph chain over the Freebase graph.
mistral = MistralLLM()
chain = ToGChain.from_llm(mistral, graph=fbkb_graph, verbose=False)