# GraphRAG on Financial Knowledge Graphs

In [1]:
from langchain.indexes import GraphIndexCreator
from langchain_openai import OpenAI
from langchain.indexes.graph import NetworkxEntityGraph

In [2]:
import getpass
import os

# Prompt for the NVIDIA API key only when the environment does not already hold
# one with the expected "nvapi-" prefix.
if not os.environ.get("NVIDIA_API_KEY", "").startswith("nvapi-"):
    # Strip stray whitespace/newlines that often come along with a pasted key.
    nvapi_key = getpass.getpass("Enter your NVIDIA API key: ").strip()
    # Validate explicitly instead of with `assert`, which is removed under `python -O`.
    if not nvapi_key.startswith("nvapi-"):
        # Echo only the first five characters so the rest of the input is not leaked.
        raise ValueError(f"{nvapi_key[:5]}... is not a valid key")
    os.environ["NVIDIA_API_KEY"] = nvapi_key

Enter your NVIDIA API key:  ······································································


In [3]:
## Core LC Chat Interface
from langchain_nvidia_ai_endpoints import ChatNVIDIA

# Default chat model served through NVIDIA AI Endpoints. Later cells construct
# their own temperature-0 instances of the same model.
CHAT_MODEL = "mixtral_8x7b"
llm = ChatNVIDIA(model=CHAT_MODEL)

In [4]:
import argparse
import os, ast

def load_mapping(file_path):
    """Load a two-column, tab-separated file into a dict of strings.

    Every line is stripped and split on tabs. Lines with exactly two fields
    contribute ``first_field -> second_field`` (both kept as ``str``). Any
    other line (blank, single field, three or more fields) is skipped silently.
    If a key repeats, the last occurrence wins.

    For the ``*2id.txt`` files used in this notebook each line is
    ``name<TAB>id``, so the result maps names to id strings (name -> id).

    Args:
        file_path: path of the tab-separated mapping file.

    Returns:
        dict mapping the first column to the second column.
    """
    mapping = {}
    with open(file_path, 'r') as file:
        for line in file:
            fields = line.strip().split('\t')
            if len(fields) == 2:
                key, value = fields
                mapping[key] = value
    return mapping



def load_triplets_from_files(directory_path):
    """Collect the ``'output'`` triplet lists from every ``.txt`` file in a directory.

    Each file must contain a single Python literal that evaluates to a dict.
    It is parsed with ``ast.literal_eval``, which accepts literals only and
    never executes code. The dict's ``'output'`` list, if present, is appended
    to the result.

    Files are processed in sorted filename order. ``os.listdir`` returns
    entries in arbitrary order, so sorting keeps the returned list (and any
    file written from it) reproducible. Entries that are not regular files,
    such as a sub-directory named ``x.txt``, are skipped.

    Args:
        directory_path: directory to scan (non-recursive).

    Returns:
        list of triplets concatenated in sorted filename order.
    """
    triplets = []
    for filename in sorted(os.listdir(directory_path)):
        if not filename.endswith('.txt'):
            continue
        file_path = os.path.join(directory_path, filename)
        if not os.path.isfile(file_path):
            continue
        with open(file_path, 'r') as file:
            data_dict = ast.literal_eval(file.read())
        triplets.extend(data_dict.get('output', []))
    return triplets


def write_knowledge_graph(triplets, entity_mapping, relation_mapping, output_file_path):
    """Serialise 5-field triplets as ``head<TAB>relation<TAB>tail`` lines.

    Each triplet is unpacked as ``(head, head_type, relation, tail, tail_type)``;
    the two type fields are discarded. Head and tail are translated through
    ``entity_mapping`` and the relation through ``relation_mapping``. Values
    missing from a mapping are written unchanged. Triplets whose length is not
    exactly 5 are skipped. ``output_file_path`` is overwritten.
    """
    with open(output_file_path, 'w') as out:
        for record in triplets:
            if len(record) != 5:
                continue
            head, _, relation, tail, _ = record
            mapped_head = entity_mapping.get(head, head)
            mapped_relation = relation_mapping.get(relation, relation)
            mapped_tail = entity_mapping.get(tail, tail)
            out.write(f"{mapped_head}\t{mapped_relation}\t{mapped_tail}\n")


def main(entity_file, relation_file, triplets_directory, output_file):
    """Build the knowledge-graph TSV end to end.

    Loads the entity and relation mapping files (in that order), gathers every
    extracted triplet under ``triplets_directory``, writes the translated
    triples to ``output_file``, then prints a confirmation line.
    """
    entity_mapping, relation_mapping = (
        load_mapping(path) for path in (entity_file, relation_file)
    )
    write_knowledge_graph(
        load_triplets_from_files(triplets_directory),
        entity_mapping,
        relation_mapping,
        output_file,
    )
    print("Knowledge graph data has been successfully saved.")

# Inputs and output for building the test-split knowledge graph.
# Call `main` with other paths to build a different split.
ENTITY_ID_MAP = "test_entity2id.txt"
RELATION_ID_MAP = "test_relation2id.txt"
TRIPLETS_DIR = "../../data/news_articles/processed/us-financial-news-articles/output/test_set/"
DATASET = "test_kg.txt"

main(ENTITY_ID_MAP, RELATION_ID_MAP, TRIPLETS_DIR, DATASET)


Knowledge graph data has been successfully saved.


In [None]:
# temperature=0 keeps this model's output as deterministic as the endpoint allows.
# NOTE(review): `index_creator` is not used by any later cell in this notebook.
index_llm = ChatNVIDIA(model="mixtral_8x7b", temperature=0)
index_creator = GraphIndexCreator(llm=index_llm)

In [None]:
# `test_graph.gml` is written by a later cell (`graph.write_to_gml`). On a fresh
# Restart & Run All it may not exist yet, so only load it when it is present.
loaded_graph = (
    NetworkxEntityGraph.from_gml("test_graph.gml")
    if os.path.exists("test_graph.gml")
    else None
)

In [5]:
def _load_id_map(filename):
    """Read ``name<TAB>id`` lines into ``{int(id): name}``, skipping blank lines."""
    id_to_name = {}
    with open(filename) as fh:
        for line in fh:
            stripped = line.strip()
            if not stripped:
                continue  # tolerate blank / trailing lines
            name, id_str = stripped.split("\t")
            id_to_name[int(id_str)] = name
    return id_to_name


def load_entities(filename):
    """Return ``{entity_id: entity_name}`` from a ``name<TAB>id`` entity file.

    Blank lines are ignored. Raises ValueError for a non-blank line that does
    not have exactly two tab-separated fields, or whose id is not an integer.
    """
    return _load_id_map(filename)


def load_relations(filename):
    """Return ``{relation_id: relation_name}`` from a ``name<TAB>id`` relation file.

    Same format and error behaviour as ``load_entities``.
    """
    return _load_id_map(filename)


def get_relation_tuples(all_entities, all_relations, dataset):
    """Decode an id-based triple file into ``(subject, relation, object)`` name tuples.

    Args:
        all_entities: ``{entity_id: name}`` as returned by ``load_entities``.
        all_relations: ``{relation_id: name}`` as returned by ``load_relations``.
        dataset: path of a file with ``subject_id<TAB>relation_id<TAB>object_id``
            lines. Blank lines are ignored.

    Returns:
        list of ``(subject_name, relation_name, object_name)`` tuples in file order.

    Raises:
        KeyError: an id is missing from the corresponding map.
        ValueError: a non-blank line is not three integer fields.
    """
    triples = []
    with open(dataset) as fh:
        for line in fh:
            stripped = line.strip()
            if not stripped:
                continue
            subject, relation, obj = stripped.split("\t")
            triples.append(
                (all_entities[int(subject)], all_relations[int(relation)], all_entities[int(obj)])
            )
    return triples

In [6]:
ENTITY_ID_MAP = "test_entity2id.txt"
RELATION_ID_MAP = "test_relation2id.txt"
# Triple file written by `main` earlier in the notebook: three tab-separated
# integer ids per line (subject entity, relation, object entity); no time column.
DATASET = "test_kg.txt"

all_entities = load_entities(ENTITY_ID_MAP)
all_relations = load_relations(RELATION_ID_MAP)
knowledge_graph = get_relation_tuples(all_entities, all_relations, DATASET)
# Fail fast rather than silently building an empty graph downstream.
assert knowledge_graph, f"no triples decoded from {DATASET}"

In [7]:
from langchain.indexes.graph import NetworkxEntityGraph
from langchain.graphs.networkx_graph import KnowledgeTriple

# Load every decoded (subject, relation, object) tuple into a LangChain entity graph.
graph = NetworkxEntityGraph()
pending_triples = (KnowledgeTriple(t[0], t[1], t[2]) for t in knowledge_graph)
for triple in pending_triples:
    graph.add_triple(triple)

In [8]:
# Preview only the first 20 triples; the full list is long.
graph.get_triples()[:20]

[('California', 'Minimum Wage', 'Raise'),
 ('California', 'More than 2 million workers', 'Impact'),
 ('Maine', 'Minimum Wage', 'Raise'),
 ('Maine', 'An estimated 59,000 workers', 'Impact'),
 ('National Employment Law Project',
  '18 states and 19 cities will boost minimum wage',
  'Announce'),
 ('Economic Policy Institute', 'Figures', 'Compile'),
 ('Federal Minimum Wage', '$7.25 an hour', 'Currently'),
 ('Federal Minimum Wage', 'Inflation', 'Not Pegged'),
 ('1968', '$2 an hour', 'Statutory Minimum Wage'),
 ('1968', 'About $10.90 an hour in 2017 dollars', 'Worth'),
 ('Hedge Funds', 'Oil Prices', 'Most_Bullish'),
 ('Hedge Funds', 'Further_Gains', 'Expect'),
 ('Hedge Funds', 'Risk', 'Ignore'),
 ('Hedge Funds', 'Record_Net_Long_Position', 'Hold'),
 ('Hedge Funds', '1183 Million Barrels', 'Amount'),
 ('Hedge Funds', 'Record_Net_Long_Positions', 'Have'),
 ('Hedge Funds', 'Net_Long_Positions', 'Have'),
 ('Hedge Funds', 'Large_Net_Long_Positions', 'Have'),
 ('Hedge Funds', 'Stretched_Position'

In [9]:
from pyvis.network import Network
import networkx as nx
# Round-trip the LangChain graph through a GML file so networkx can read it.
GRAPH_GML_PATH = "test_graph.gml"
graph.write_to_gml(GRAPH_GML_PATH)
G = nx.read_gml(GRAPH_GML_PATH)

# Interactive pyvis view. cdn_resources='in_line' asks pyvis to embed its
# JS/CSS in the generated HTML.
nt = Network(notebook=True, cdn_resources='in_line')
nt.from_nx(G)
nt.show("network.html")

network.html


In [10]:
from langchain.chains import GraphQAChain
# temperature=0 keeps answers as deterministic as the endpoint allows.
# verbose=True prints the extracted entities and the graph context for each answer.
qa_llm = ChatNVIDIA(model="mixtral_8x7b", temperature=0)
chain = GraphQAChain.from_llm(qa_llm, graph=graph, verbose=True)

In [15]:
# `Chain.run` is deprecated; `invoke` takes the chain's input dict ("query")
# and returns a dict whose "result" key holds the answer string.
chain.invoke({"query": "Explain what is going on with minimum wage in California."})["result"]



[1m> Entering new GraphQAChain chain...[0m
Entities Extracted:
[32;1m[1;3mCalifornia, minimum wage

---

Explanation:

The two entities present in the text are 'California' and 'minimum wage'. 'California' is a proper noun and a place, while 'minimum wage' is a common noun phrase that refers to the lowest hourly wage that an employer may legally pay to an employee.[0m
Full Context:
[32;1m[1;3mCalifornia Raise Minimum Wage
California Impact More than 2 million workers[0m

[1m> Finished chain.[0m


'In California, there has been a move to increase the minimum wage, which will impact more than 2 million workers. The exact details and timeline of the increase are not specified, but it is clear that a significant number of workers in the state will be affected by this change. Employers will need to adjust their payrolls accordingly, and workers can expect to see an increase in their earnings.'

In [None]:
# `Chain.run` is deprecated; use `invoke` and read the answer from "result".
chain.invoke({"query": "tell me what you know about Goldman Sachs."})["result"]

In [None]:
# Disabled query; uncomment to ask about hedge-fund liquidation risk.
# chain.invoke({"query": "How do hedge fund managers mitigate liquidation risk?"})["result"]

In [None]:
# `Chain.run` is deprecated; use `invoke` and read the answer from "result".
chain.invoke({"query": "Which securities had the highest growth rate?"})["result"]