# Create and query knowledge graph with LLM. (CosmosDB version)

In the notebook [Create knowledge graph from PDF](./knowledgegraph.ipynb) we created a pickle file containing the extracted graph documents. This notebook loads that file and writes its nodes and relationships to CosmosDB.

In [None]:
# Install the notebook's dependencies; %pip (rather than !pip) installs into the running kernel's environment.
%pip install -r requirements.txt

# Load env variables and connect to CosmosDB database

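You need a CosmosDB Gremlin graph database in Azure.
The database is expected to be named 'rag' and the graph 'kg'.
The connection settings are read from the environment variables `GREMLIN_URI` and `GREMLIN_PASSWORD` (for example from a `.env` file).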

In [None]:
import os
from dotenv import load_dotenv
from gremlin_python.driver import client, serializer
import nest_asyncio


# Allow asyncio event loops to be nested before the Gremlin client is created
# (presumably needed because the notebook kernel already runs an event loop).
nest_asyncio.apply()

# Populate os.environ from a local .env file so GREMLIN_URI / GREMLIN_PASSWORD
# can be read below.
load_dotenv()

# NOTE: this rebinds the name `client` from the gremlin_python module to the
# connected Client instance; later cells use `client` as the connection.
client = client.Client(
    os.environ["GREMLIN_URI"],  # url
    "g",  # traversal_source
    # If you want to use different database/graph names, edit this
    username="/dbs/rag/colls/kg",
    password=os.environ["GREMLIN_PASSWORD"],
    message_serializer=serializer.GraphSONSerializersV2d0(),
)

In [None]:
import pickle
# Load the graph documents that the knowledge-graph notebook pickled to disk.
# NOTE: unpickling can execute arbitrary code, so only load files you created yourself.
with open('./data/graph_docs.pkl', 'rb') as pickle_file:
    graph_docs = pickle.load(pickle_file)

In [None]:
import os
from gremlin_python.driver import client, serializer
import nest_asyncio
# Allow asyncio event loops to be nested (presumably needed because the
# notebook kernel already runs an event loop).
nest_asyncio.apply()

# Connect to the Gremlin endpoint with the same settings as the connection
# cell earlier in the notebook. This rebinds `client` from the gremlin_python
# module to the connected Client instance used by the import loop below.
client = client.Client(
    os.environ["GREMLIN_URI"],  # url
    "g",  # traversal_source
    username="/dbs/rag/colls/kg",
    password=os.environ["GREMLIN_PASSWORD"],
    message_serializer=serializer.GraphSONSerializersV2d0(),
)

def build_node_update_query(label_value, id_value, properties):
    """Build a Gremlin script that returns a vertex, creating it if it is missing.

    The script looks up a vertex by label and 'id'. When none exists it adds
    one with that label and 'id', a 'type' property equal to the label, and
    one property per entry in ``properties``. Properties are only written when
    the vertex is created; an existing vertex is returned unchanged.

    Args:
        label_value: Vertex label (also stored in the 'type' property).
        id_value: Value of the vertex 'id' property.
        properties: Mapping of extra property names to values; may be None.
            Values are converted with ``str()``.

    Returns:
        The Gremlin query as a string.
    """
    def quote(value):
        # The query is assembled as text, so backslashes and single quotes in
        # the data must be escaped or they terminate the string literal early.
        # Backslashes go first so the ones added for quotes are not doubled.
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    label = quote(label_value)
    vertex_id = quote(id_value)
    property_steps = "".join(
        f".property({quote(key)}, {quote(value)})"
        for key, value in (properties or {}).items()
    )

    return (
        f"g.V().hasLabel({label}).has('id',{vertex_id}).fold()"
        f".coalesce(unfold(),addV({label}).property('id',{vertex_id}).property('type', {label})"
        f"{property_steps})"
    )

def build_source_to_target_query(type, source, target, properties):
    """Build a Gremlin script that returns the edge source -> target, creating it if missing.

    The script starts at the source vertex, looks for an outgoing edge with
    label ``type`` whose in-vertex is the target, and adds that edge when none
    is found. The ``properties`` are then set on the resulting edge.

    Args:
        type: Edge label. (The name shadows the builtin but is kept so that
            existing callers keep working.)
        source: Mapping with 'label' and 'id' keys identifying the source vertex.
        target: Mapping with 'label' and 'id' keys identifying the target vertex.
        properties: Mapping of edge property names to values; may be None.
            Values are converted with ``str()``.

    Returns:
        The Gremlin query as a string.
    """
    def quote(value):
        # The query is assembled as text, so backslashes and single quotes in
        # the data must be escaped or they terminate the string literal early.
        # Backslashes go first so the ones added for quotes are not doubled.
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    edge_label = quote(type)
    source_filter = f".hasLabel({quote(source['label'])}).has('id',{quote(source['id'])})"
    target_filter = f".hasLabel({quote(target['label'])}).has('id',{quote(target['id'])})"

    # The existing-edge lookup filters edges with where(inV()...) instead of
    # stepping onto the target vertex, so unfold() yields the edge itself and
    # the trailing property() steps update the edge, not the target vertex.
    # NOTE(review): fold() is a barrier step; confirm that the 'source' label
    # is still resolvable by from('source') on the target Gremlin server.
    query = f"""
    g.V(){source_filter}
    .as('source')
    .outE({edge_label}).where(__.inV(){target_filter})
    .fold()
    .coalesce(
        __.unfold(),
        __.addE({edge_label}).from('source').to(
            g.V(){target_filter}
        )
    )
    """
    property_steps = "".join(
        f".property({quote(key)}, {quote(value)})"
        for key, value in (properties or {}).items()
    )

    return query + property_steps

# Write every graph document to the database: first its nodes, then its
# relationships. All queries are find-or-create, so re-running the cell does
# not duplicate vertices or edges.
for document in graph_docs:
    # Import nodes. Submitting here (instead of only building the query) makes
    # sure nodes that take part in no relationship are stored as well.
    for node in document.nodes:
        # .all().result() blocks until the server answers and raises on failure.
        client.submit(build_node_update_query(node.type, node.id, node.properties)).all().result()

    for relationship in document.relationships:
        print(relationship)
        # Find or create the source vertex; the returned vertex supplies the
        # 'label' and 'id' used to address it in the edge query.
        source = client.submit(
            build_node_update_query(
                relationship.source.type, relationship.source.id, relationship.source.properties
            )
        ).all().result()[0]
        # Find or create the target vertex.
        target = client.submit(
            build_node_update_query(
                relationship.target.type, relationship.target.id, relationship.target.properties
            )
        ).all().result()[0]
        # Find or create the edge and wait for the outcome, so failures are
        # raised here and the stored edge (not a pending ResultSet) is printed.
        edge = client.submit(
            build_source_to_target_query(relationship.type, source, target, relationship.properties)
        ).all().result()
        print(edge)


In [None]:
import ssl
from httpx import HTTPTransport
from GremlinGraph import GremlinGraph
from gremlin_python.process.anonymous_traversal import traversal
from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection

from gremlin_python.process.graph_traversal import __

# Open a bytecode-based traversal source against the same graph.
# The earlier variant that passed `transport_factory=lambda: HTTPTransport(...)`
# was dropped: its result was overwritten on the very next statement, and the
# keyword arguments it used (read_timeout, heartbeat, ssl_options, ...) are not
# parameters of httpx.HTTPTransport.
# NOTE(review): Azure Cosmos DB's Gremlin API is documented as not supporting
# bytecode traversals - confirm this cell can run against the target account;
# the string-based `client.submit(...)` import is the supported route.
g = traversal().with_remote(
    DriverRemoteConnection(
        url=os.environ["GREMLIN_URI"],
        username="/dbs/rag/colls/kg",
        password=os.environ["GREMLIN_PASSWORD"],
    )
)

# Smoke test: show at most one 'breed' vertex. to_list() returns an empty list
# instead of raising StopIteration when nothing matches.
print(g.V().has_label('breed').limit(1).to_list())

for document in graph_docs:
    # Import nodes
    for el in document.nodes:
        # Find the vertex or create it. gremlin_python has no
        # tryNext()/orElseGet(), so the upsert uses fold().coalesce().
        node = (
            g.V().has_label(el.type).has('id', el.id)
            .fold()
            .coalesce(__.unfold(), __.add_v(el.type).property('id', el.id))
            .next()
        )
        print(node)
                print(node)
                    
                    
