In [34]:
from sentence_transformers import SentenceTransformer
import requests
import pandas as pd

from langchain_chroma import Chroma
import chromadb
from tqdm import tqdm

from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser


In [35]:
CRISP_ENDPOINT = 'http://crisp.ai.wu.ac.at/blazegraph/namespace/crisp/sparql'
OLLAMA_ENDPOINT = '127.0.0.1:11434'
CRISP_NAMESPACE = 'http://crisp.ai.wu.ac.at/crisp/'

In [36]:
def sparql_query(query: str) -> pd.DataFrame:
    """
    Executes a SPARQL query on a pre-loaded RDF graph and returns the results as a DataFrame.
    """
    try:
        # Prepare and execute the query        
        # query = prepareQuery(query)
        # results = rdf_graph.query(query)

        response = requests.get(CRISP_ENDPOINT, params={'query': query, 'format': 'json'})
        results = response.json()

        
        # Extract variable (column) names from the query result
        columns = results['head']['vars']  # Get the variable names from the query results
        
        # Process the results and convert them into a list of dictionaries
        data = []
        for row in results['results']['bindings']:
            row_data = {str(var): row[var]['value'].replace(CRISP_NAMESPACE, "") for var in columns}  # Dynamically build a row dict
            # .replace(CRISP_NS, "")
            data.append(row_data)
        
        # Convert the data into a DataFrame
        df = pd.DataFrame(data, columns=[str(var) for var in columns])
        return df

    except Exception as e:
        print(f"An error occurred while executing the SPARQL query: {e}")
        return pd.DataFrame()

In [37]:

# A pretrained Sentence Transformer model
model = SentenceTransformer("all-MiniLM-L6-v2")

In [38]:
model.encode("Why sky is blue?")

array([ 1.56446267e-02, -1.52610093e-02,  4.47327122e-02,  3.29037718e-02,
        3.47187072e-02, -8.04994954e-04,  1.15529321e-01, -2.80523170e-02,
        1.08167142e-01,  2.17172634e-02, -2.75128148e-02, -1.61158666e-02,
       -8.94826744e-03, -6.10241033e-02, -7.60247651e-03,  7.03079477e-02,
       -2.53220275e-02, -1.41977936e-01, -9.10010561e-02, -6.26014546e-02,
       -5.69708869e-02,  4.83268835e-02, -7.19384253e-02,  5.36682829e-02,
       -3.64999361e-02,  5.27441874e-02, -3.15055251e-02,  8.93931277e-03,
        5.95199652e-02,  1.00659700e-02, -3.05421036e-02,  9.08340663e-02,
        4.35057469e-02,  2.59071775e-02, -4.55773063e-02, -5.03962412e-02,
        6.65812343e-02, -4.05957326e-02, -1.82360224e-02, -3.43648195e-02,
       -2.86470428e-02, -3.37902009e-02, -2.25397944e-02,  4.07451205e-03,
       -1.10364705e-02,  2.36096065e-02, -1.19074453e-02,  3.38454060e-02,
        7.42058977e-02, -1.52767990e-02, -3.38980816e-02,  7.99685717e-03,
       -8.87603760e-02,  

In [39]:
# Community information
community_query = """
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
prefix crisp: <http://crisp.ai.wu.ac.at/crisp/>

SELECT ?community_id ?community_name ?state_name

WHERE {
  ?community_id a crisp:Community;
    rdfs:label ?community_name ; 
    crisp:locatedIn ?district_id .
  ?district_id crisp:locatedIn ?state_id .
  ?state_id rdfs:label ?state_name .      
} LIMIT 10
"""

# Observations of a community
def observation_query(community_id):
  return f"""
    prefix sosa: <http://www.w3.org/ns/sosa/> 
    prefix crisp: <http://crisp.ai.wu.ac.at/crisp/>

    SELECT ?observation_id ?property ?value

    WHERE {{
      ?observation_id sosa:hasFeatureOfInterest <http://crisp.ai.wu.ac.at/crisp/{community_id}> ;
          sosa:observedProperty ?property; 
          sosa:hasSimpleResult ?value . 
    }}
    """




In [40]:
# Initialize Chroma Persistent Client
persistent_client = chromadb.PersistentClient()
collection = persistent_client.get_or_create_collection(name="graphrag_collection")

# Prepare data for Chroma collection
documents, metadatas, embeddings, ids = [], [], [], []


In [41]:

# Adding communities to Vector DB
community_df = sparql_query(community_query)
print(community_df)


# Using tqdm for progress tracking
print("Processing communities and adding to Chroma collection...")

communities = {}
for _, row in tqdm(community_df.iterrows(), total=len(community_df), desc="Communities Processed"):
    description = f"{row['community_name']}  is a community located in {row['state_name']} state in Austria"
    description_embedding = model.encode(description)
    documents.append(f"Name: {row['community_name']}, Type: 'Community', Description: {description}")
    metadatas.append({
        'subject': row['community_id'],
        'name': str(row['community_name']),
        'type': 'Community',
        'description': description
    })
    # embeddings.append([float(x) for x in row['description_embedding'].split()])
    embeddings.append(description_embedding)
    ids.append(str(row['community_id']))
    communities[row['community_id']] = row['community_name']


         community_id       community_name state_name
0  id/community/20502               Brückl    Kärnten
1  id/community/20503      Deutsch-Griffen    Kärnten
2  id/community/20504            Eberstein    Kärnten
3  id/community/20505             Friesach    Kärnten
4  id/community/20506             Glödnitz    Kärnten
5  id/community/20508                 Gurk    Kärnten
6  id/community/20509            Guttaring    Kärnten
7  id/community/20511           Hüttenberg    Kärnten
8  id/community/20512  Kappel am Krappfeld    Kärnten
9  id/community/20513       Klein St. Paul    Kärnten
Processing communities and adding to Chroma collection...


Communities Processed: 100%|██████████| 10/10 [00:00<00:00, 111.33it/s]


In [42]:
def observation_description(community_name, row):
    if row['property'] == "weeklyHeatdaysOver30":
        year, week = row['observation_id'].split('/')[-2:]
        return f"The {community_name} community experienced {row['value']} hot days during week {week} of {year}"
    else:
        year = row['observation_id'].split('/')[-1]
        return f"The population of the {community_name} community in {year} was {row['value']}"

In [43]:
# Adding community observations to Vector DB
for community_id in communities:
  observation_df = sparql_query(observation_query(community_id))

  # Using tqdm for progress tracking
  print("Processing observations and adding to Chroma collection...")

  for _, row in tqdm(observation_df.iterrows(), total=len(observation_df), desc="Observations Processed"):

      
      description = observation_description(communities[community_id], row)
      description_embedding = model.encode(description)
      documents.append(f"Type: 'Observation', Description: {description}")
      metadatas.append({
          'subject': row['observation_id'],
          'type': 'Observation',
          'description': description
      })
      embeddings.append(description_embedding)
      ids.append(str(row['observation_id']))



Processing observations and adding to Chroma collection...


Observations Processed: 100%|██████████| 111/111 [00:00<00:00, 122.17it/s]


Processing observations and adding to Chroma collection...


Observations Processed: 100%|██████████| 52/52 [00:00<00:00, 109.27it/s]


Processing observations and adding to Chroma collection...


Observations Processed: 100%|██████████| 102/102 [00:00<00:00, 122.41it/s]


Processing observations and adding to Chroma collection...


Observations Processed: 100%|██████████| 71/71 [00:00<00:00, 117.24it/s]


Processing observations and adding to Chroma collection...


Observations Processed: 100%|██████████| 60/60 [00:00<00:00, 122.41it/s]


Processing observations and adding to Chroma collection...


Observations Processed: 100%|██████████| 73/73 [00:00<00:00, 113.62it/s]


Processing observations and adding to Chroma collection...


Observations Processed: 100%|██████████| 84/84 [00:00<00:00, 119.51it/s]


Processing observations and adding to Chroma collection...


Observations Processed: 100%|██████████| 65/65 [00:00<00:00, 120.67it/s]


Processing observations and adding to Chroma collection...


Observations Processed: 100%|██████████| 111/111 [00:00<00:00, 116.72it/s]


Processing observations and adding to Chroma collection...


Observations Processed: 100%|██████████| 87/87 [00:00<00:00, 119.71it/s]


In [44]:

# Add the processed data to the Chroma collection
collection.add(
    documents=documents,
    metadatas=metadatas,
    embeddings=embeddings,
    ids=ids
)
print("Chroma collection populated successfully.")

# Initialize vector store using Chroma
vector_store = Chroma(client=persistent_client, collection_name="graphrag_collection")

# Verify the count of entries in the collection
print(f"Total entries in the collection: {vector_store._collection.count()}")


Add of existing embedding ID: id/community/20502
Add of existing embedding ID: id/community/20503
Add of existing embedding ID: id/community/20504
Add of existing embedding ID: id/community/20505
Add of existing embedding ID: id/community/20506
Add of existing embedding ID: id/community/20508
Add of existing embedding ID: id/community/20509
Add of existing embedding ID: id/community/20511
Add of existing embedding ID: id/community/20512
Add of existing embedding ID: id/community/20513
Add of existing embedding ID: id/community/20502/weeklyHeatdaysOver30/2022/20
Add of existing embedding ID: id/community/20502/weeklyHeatdaysOver30/2022/22
Add of existing embedding ID: id/community/20502/weeklyHeatdaysOver30/2022/24
Add of existing embedding ID: id/community/20502/weeklyHeatdaysOver30/2022/25
Add of existing embedding ID: id/community/20502/weeklyHeatdaysOver30/2022/26
Add of existing embedding ID: id/community/20502/weeklyHeatdaysOver30/2022/27
Add of existing embedding ID: id/community

Chroma collection populated successfully.
Total entries in the collection: 4220


# LLM Query

In [None]:
llm = ChatOllama(
    # model = "llama3.1",
    model = "gemma3:12b",
    temperature = 0.8,
    num_predict = 256,
    base_url = OLLAMA_ENDPOINT
    # other params ...
)

community_id, community_name = list(communities.items())[0]

question = f"What can you tell me about hot days in {community_name}?"

In [46]:
# Query without context
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant that answers questions about community observations.",
        ),
        ("human", "{input}"),
    ]
)

chain = prompt | llm | StrOutputParser()
chain.invoke(
    {
        "input": question,
    }
)

ResponseError: model "llama3.1" not found, try pulling it first (status code: 404)

In [None]:
# Calculate embedding of query
query_embedding_vector = model.encode(question)
print(query_embedding_vector.shape)

In [None]:
TOP_ENTITIES = 10
TOP_CHUNKS = 10
TOP_COMMUNITIES = 3
TOP_OUTGOING_RELATIONSHIPS = 10
TOP_INCOMING_RELATIONSHIPS = 10

In [None]:
results = vector_store.similarity_search_by_vector(
    embedding=query_embedding_vector, k=TOP_ENTITIES
)
entity_list = [doc.metadata['subject'] for doc in results]
descriptions = [doc.metadata['description'] for doc in results]

context = ". \n".join(descriptions)
print(context)

In [None]:
# Query with context from embeddings
prompt = ChatPromptTemplate.from_messages(
    [
         (
            "system",
            f"You are a helpful assistant that answers questions about community observations. You have the following context: {context}",
        ),
       ("human", "{input}"),
    ]
)
chain = prompt | llm | StrOutputParser()
chain.invoke(
    {
        "input": question,
    }
)