In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json

import llama_cpp
import torch

from curverag import utils
from curverag.curverag import CurveRAG, DEFAULT_ENTITY_TYPES, DEFAULT_GLINER_MODEL
from curverag.graph import KnowledgeGraph
from curverag.atth.kg_dataset import KGDataset
from curverag.atth.models.hyperbolic import AttH

In [3]:
import os
from dotenv import load_dotenv
from openai import OpenAI
load_dotenv() 

True

## Dataset and LLM

In [4]:
max_tokens = 10000
n_ctx=10000

docs = [
    "The patient was diagnosed with type 2 diabetes mellitus and prescribed metformin 500mg twice daily.",
    "MRI scan revealed a small lesion in the left temporal lobe suggestive of low-grade glioma.",
    "Administer 5mg of lorazepam intravenously for acute seizure management.",
    "Blood tests showed elevated ALT and AST levels, indicating possible liver inflammation.",
    "The subject reported chronic lower back pain, managed with physical therapy and NSAIDs.",
    "CT angiography confirmed the presence of a pulmonary embolism in the right lower lobe.",
    "The patient underwent coronary artery bypass graft surgery without complications.",
    "Routine vaccination included MMR, tetanus, and influenza immunizations.",
    "Histopathology indicated ductal carcinoma in situ (DCIS) in the breast biopsy sample.",
    "The child presented with a persistent cough and fever, diagnosed as streptococcal pharyngitis."
]


In [None]:
model = utils.load_model(
    llm_model_path="./models/Meta-Llama-3-8B-Instruct.Q6_K.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct"),
    n_ctx=n_ctx,
    max_tokens=max_tokens
)

In [6]:
api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=api_key)

## Fit CurveRAG

In [7]:
entity_types = ['people', 'locations', 'entities', 'movies', 'directors']
rag = CurveRAG(
    openai_client=client,
    entity_types=entity_types,
)

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]



In [7]:
rag = CurveRAG(llm=model)

NameError: name 'model' is not defined

In [8]:
rag.fit(docs=docs, dataset_name='test_run')

creating graph


 10%|████████▎                                                                          | 1/10 [00:05<00:48,  5.37s/it]

upserting


 20%|████████████████▌                                                                  | 2/10 [00:10<00:43,  5.39s/it]

upserting


 30%|████████████████████████▉                                                          | 3/10 [00:24<01:05,  9.29s/it]

upserting


 40%|█████████████████████████████████▏                                                 | 4/10 [00:33<00:54,  9.14s/it]

upserting


 50%|█████████████████████████████████████████▌                                         | 5/10 [00:41<00:44,  8.82s/it]

upserting


 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:47<00:30,  7.59s/it]

upserting


 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:50<00:18,  6.13s/it]

upserting


 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:55<00:11,  5.91s/it]

upserting


 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:59<00:05,  5.14s/it]

upserting


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [01:03<00:00,  6.31s/it]
2025-06-12 20:14:30,417 INFO     Saving logs in: ./logs/06_12\test_run\AttH_20_14_30
2025-06-12 20:14:30,418 INFO     	 (32, 30, 32)
2025-06-12 20:14:30,433 INFO     Total number of parameters 182094
2025-06-12 20:14:30,435 INFO     	 Start training


upserting
num nodes 32
unique node ids {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
creating dataset
train kg embeddings
max_epochs 50


train loss: 100%|█████████████████████████████████████████████████████████| 48/48 [00:00<00:00, 55.49ex/s, loss=0.6883]
2025-06-12 20:14:31,471 INFO     	 Epoch 0 | average train loss: 0.6883
2025-06-12 20:14:31,483 INFO     	 Epoch 0 | average valid loss: 2.0935
train loss: 100%|█████████████████████████████████████████████████████████| 48/48 [00:00<00:00, 67.32ex/s, loss=2.1058]
2025-06-12 20:14:32,201 INFO     	 Epoch 1 | average train loss: 2.1058
2025-06-12 20:14:32,209 INFO     	 Epoch 1 | average valid loss: 1.9162
train loss: 100%|█████████████████████████████████████████████████████████| 48/48 [00:00<00:00, 68.28ex/s, loss=1.9965]
2025-06-12 20:14:32,916 INFO     	 Epoch 2 | average train loss: 1.9965
2025-06-12 20:14:32,926 INFO     	 Epoch 2 | average valid loss: 1.7142
2025-06-12 20:14:32,940 INFO     	 valid MR: 26.50 | MRR: 0.110 | H@1: 0.000 | H@3: 0.167 | H@10: 0.167
2025-06-12 20:14:32,941 INFO     	 Epoch 2 | average valid mmr: 0.1100
2025-06-12 20:14:32,942 INFO     

returning model


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [9]:
len(rag.graph.nodes)

32

In [10]:
import json
with open('rag.graph.json', 'w') as f:
    json.dump(rag.graph.dict(), f, indent=4)

C:\Users\natha\AppData\Local\Temp\ipykernel_41816\946862156.py:3: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  json.dump(rag.graph.dict(), f, indent=4)


In [11]:
rag.graph.nodes

[Node(id=1, name='Patient', description='A person receiving medical care. A person receiving medical care. A person receiving medical care. A person receiving medical care. A person receiving medical care. A person receiving medical care. A person receiving medical care. A person receiving medical care. A person receiving medical care.', alias=['Individual', 'Subject', 'Patient'], additional_information=['Diagnosed with type 2 diabetes mellitus']),
 Node(id=2, name='Type 2 Diabetes Mellitus', description='A chronic condition affecting the way the body processes blood sugar. A chronic condition affecting the way the body processes blood sugar.', alias=['Diabetes Type 2', 'Type 2 Diabetes Mellitus'], additional_information=[]),
 Node(id=3, name='Metformin', description='A medication used to improve blood sugar control in people with type 2 diabetes. A medication used to improve blood sugar control in people with type 2 diabetes.', alias=['Metformin'], additional_information=['Prescribed 

In [11]:
rag.graph.dict()

/var/folders/m6/12prs8yx2r90psgztq05hy_c0000gn/T/ipykernel_69670/774845048.py:1: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  rag.graph.dict()


{'nodes': [{'id': 1,
   'name': 'Patient',
   'description': 'A person who is receiving medical treatment',
   'alias': ['individual', 'person']},
  {'id': 2,
   'name': 'Diabetes Mellitus',
   'description': 'A group of metabolic disorders characterized by high blood sugar levels',
   'alias': ['diabetes', 'DM']},
  {'id': 3,
   'name': 'Metformin',
   'description': 'A medication used to treat type 2 diabetes',
   'alias': ['glucophage', 'metformin hydrochloride']},
  {'id': 4,
   'name': 'Type 2 Diabetes',
   'description': 'A type of diabetes that is caused by insulin resistance and impaired insulin secretion',
   'alias': ['T2D', 'non-insulin-dependent diabetes']},
  {'id': 5,
   'name': 'Daily',
   'description': 'A unit of time',
   'alias': ['day', '24 hours']},
  {'id': 6,
   'name': '500mg',
   'description': 'A dosage of medication',
   'alias': ['half a gram', '0.5 grams']},
  {'id': 7,
   'name': 'Twice',
   'description': 'A frequency of medication administration',
   'al

In [11]:
nodes_id_idx = rag.dataset.nodes_id_idx

In [19]:
node_embs = rag.model.entity.weight.data.cpu().numpy()
print(len(node_embs))
node_embs[nodes_id_idx[1]][:5]

45


array([ 0.00581928, -0.02819522,  0.0190266 , -0.00319721, -0.00893134])

# Create Network X Graph

In [12]:
import networkx as nx
G = nx.Graph()

In [13]:
G.add_nodes_from([n.id for n in rag.graph.nodes])

In [14]:
edges = [(e.source, e.target) for e in rag.graph.edges]
G.add_edges_from(edges)

In [15]:
G.number_of_nodes()

32

In [16]:
G.number_of_edges()

30

In [17]:
personalization = {21: 1, 22: 1}
pr = nx.pagerank(G, alpha=0.85, personalization=personalization)

pr_sorted = sorted(pr.items(), key=lambda x: x[1], reverse=True)


In [18]:
top_k = 5

In [21]:
[i[0] for i in pr_sorted[:top_k]]

[21, 22, 19, 18, 20]

In [25]:
pr

{1: 0.10465670434595327,
 2: 0.019827111440711198,
 3: 0.03622576628851659,
 4: 0.03858331731416875,
 5: 0.021398654847805402,
 6: 0.0486480043320244,
 7: 0.025675997833987798,
 8: 0.025675997833987798,
 9: 0.0486480043320244,
 10: 0.025675997833987798,
 11: 0.025675997833987798,
 12: 0.03333333333333333,
 13: 0.03333333333333333,
 14: 0.03333333333333333,
 15: 0.03333333333333333,
 16: 0.0486480043320244,
 17: 0.025675997833987798,
 18: 0.025675997833987798,
 19: 0.03333333333333333,
 20: 0.03333333333333333,
 21: 0.019827111440711198,
 22: 0.019827111440711198,
 23: 0.019827111440711198,
 24: 0.019827111440711198,
 25: 0.03333333333333333,
 26: 0.03333333333333333,
 27: 0.06396267533071547,
 28: 0.02312355266753929,
 29: 0.02312355266753929,
 30: 0.02312355266753929}

rag.graph.nodes #2 diabites, 3 metamorfin 25 tenaus cavvine 26 influenza vaccine 27 ductal in situ

In [28]:
rag.graph.edges

[Edge(source=1, target=2, name='Diagnosed With', is_directed=True, description='Patient diagnosed with condition', notes=[]),
 Edge(source=1, target=3, name='Prescribed', is_directed=True, description='Patient prescribed medication', notes=['500mg twice daily']),
 Edge(source=4, target=6, name='Located In', is_directed=True, description='Lesion found in a specific brain region', notes=[]),
 Edge(source=4, target=5, name='Suggestive Of', is_directed=True, description='Lesion indicates the presence of a tumor', notes=['Indicates low-grade glioma']),
 Edge(source=7, target=4, name='Revealed', is_directed=True, description='MRI scan found lesion', notes=['Small lesion noted']),
 Edge(source=1, target=8, name='Prescribed', is_directed=True, description='Patient administered medication', notes=['5mg intravenously']),
 Edge(source=9, target=1, name='Suffer From', is_directed=True, description='Patient experiences a medical condition', notes=['Acute seizure management']),
 Edge(source=10, targ

In [27]:
query = "What is used to treat Diabetes Mellitus and how does can a patient use the treatment?"
rag.query(query, traversal = 'pp')

entities []


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

graph nodes retrieved ['Type 2 Diabetes Mellitus']
similar_node_indexes [1, 2, 24, 25, 26]
similar_node_ids [2, 3, 25, 26, 27]
similar_node_ids graph nodes retrieved ['Type 2 Diabetes Mellitus', 'Metformin', 'Tetanus Vaccine', 'Influenza Vaccine', 'Ductal Carcinoma In Situ']
***************************************!!******************************
***************************************!!******************************
***************************************!!******************************
***************************************!!******************************
***************************************!!******************************

You are a helpful assistant analyzing the given input data to provide an helpful response to the user query.

# USER QUERY
What is used to treat Diabetes Mellitus and how does can a patient use the treatment?

# Context:
KnowledgeGraph Overview
  There are 5 entities and 1 relationships in this graph.

Entities in the graph:
  • 'Type 2 Diabetes Mellitus':
    

2025-06-12 20:46:22,857 INFO     HTTP Request: POST https://api.openai.com/v1/responses "HTTP/1.1 200 OK"


"To treat Type 2 Diabetes Mellitus, the primary medication used is Metformin. It is prescribed to improve blood sugar control in people with this condition, typically at a dosage of 500mg taken twice daily. Patients should follow their healthcare provider's instructions regarding the timing and dosage of Metformin to effectively manage their blood sugar levels and monitor their health closely. Adjustments may be necessary based on individual response and side effects[1]. \n\nNo additional information was found regarding other treatment options or methods of use beyond Metformin as it relates to the provided data."

# Load CurveRag modules

In [24]:
# load KG
with open('rag.graph.json', 'r') as f:
    data = json.load(f)

#kg = KnowledgeGraph.construct(**data)
kg = KnowledgeGraph.parse_file("rag.graph.json")

/var/folders/m6/12prs8yx2r90psgztq05hy_c0000gn/T/ipykernel_82949/2232709649.py:6: PydanticDeprecatedSince20: The `parse_file` method is deprecated; load the data from file, then if your data is JSON use `model_validate_json`, otherwise `model_validate` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  kg = KnowledgeGraph.parse_file("rag.graph.json")


In [23]:
kg.nodes[0]

Node(id=1, name='Patient', description='A person who is receiving medical treatment', alias=['individual', 'person'])

In [6]:
# load dataset
dataset_path = './data/medical_docs'
dataset = KGDataset(dataset_path, debug=False, name='medical_docs')

In [7]:
# load embeding model
sizes = dataset.get_shape()
embedding_model = AttH(sizes=sizes)
model_path = '/Users/nathan/Documents/projects/curve_rag/logs/05_05/test_run/AttH_00_38_35/model.pt'
embedding_model.load_state_dict(torch.load(model_path, weights_only=True))
embedding_model.eval()

AttH(
  (entity): Embedding(45, 1000)
  (rel): Embedding(48, 1000)
  (bh): Embedding(45, 1)
  (bt): Embedding(45, 1)
  (rel_diag): Embedding(48, 2000)
  (context_vec): Embedding(48, 1000)
  (act): Softmax(dim=1)
)

In [None]:
# load llm
max_tokens = 10000
n_ctx=10000
llm, outlines_model = utils.load_model(
    llm_model_path="./models/Meta-Llama-3-8B-Instruct.Q6_K.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct"),
    n_ctx=n_ctx,
    max_tokens=max_tokens
)

In [67]:
resp = llm('Hi how are you?')

Llama.generate: 6 prefix-match hit, remaining 1 prompt tokens to eval
llama_perf_context_print:        load time =    2329.38 ms
llama_perf_context_print: prompt eval time =       0.00 ms /     1 tokens (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:        eval time =    5169.89 ms /    16 runs   (  323.12 ms per token,     3.09 tokens per second)
llama_perf_context_print:       total time =    5176.08 ms /    17 tokens


In [59]:
resp['choices'][0]['text']

" I hope you're doing well. I am from Mexico and I am a fan"

In [68]:
cg = CurveRAG.load_class(
    llm=llm,
    outlines_llm=outlines_model,
    entity_types=DEFAULT_ENTITY_TYPES,
    gliner_model_name=DEFAULT_GLINER_MODEL,
    graph=kg,
    graph_embedding_model=embedding_model,
    dataset=dataset,
)

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

In [69]:
entities = ['Patient', 'Diabetes Mellitus', 'Metformin', 'Type 2 Diabetes', 'Daily']

In [22]:
query = "What is used to treat Diabetes Mellitus and how does can a patient use the treatment?"
rag.query(query)

NameError: name 'cg' is not defined

In [86]:
llm2("What is used to treat Diabetes Mellitus and how does can a patient use the treatment?",  max_tokens=1000)['choices'][0]['text']

Llama.generate: 10 prefix-match hit, remaining 10 prompt tokens to eval
llama_perf_context_print:        load time =    3865.92 ms
llama_perf_context_print: prompt eval time =    1877.04 ms /    10 tokens (  187.70 ms per token,     5.33 tokens per second)
llama_perf_context_print:        eval time =   51397.43 ms /   999 runs   (   51.45 ms per token,    19.44 tokens per second)
llama_perf_context_print:       total time =   54751.37 ms /  1009 tokens


" (2023)\nWhat is Diabetes Mellitus?\nDiabetes Mellitus, commonly referred to as diabetes, is a metabolic disorder that affects the way the body regulates blood sugar levels. There are several types of diabetes, but the most common ones are:\nType 1 diabetes: An autoimmune disease where the body's immune system attacks and destroys the cells in the pancreas that produce insulin, a hormone that regulates blood sugar levels.\nType 2 diabetes: A condition where the body becomes resistant to insulin, making it difficult for the body to use insulin effectively.\nGestational diabetes: A type of diabetes that develops during pregnancy due to hormonal changes.\nWhat are the symptoms of Diabetes Mellitus?\nThe symptoms of diabetes can vary from person to person, but common signs and symptoms include:\nIncreased thirst and urination\nFatigue and weakness\nBlurred vision\nSlow healing of cuts and wounds\n Tingling or numbness in the hands and feet\nHow is Diabetes Mellitus treated?\nThe treatment

In [None]:
from llama_cpp import Llama
llm2 = Llama(
        model_path="./models/Meta-Llama-3-8B-Instruct.Q6_K.gguf",
        tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct"),
        n_ctx=n_ctx,
        max_tokens=max_tokens
    )