In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json

import llama_cpp
import torch

from curverag import utils
from curverag.curverag import CurveRAG, DEFAULT_ENTITY_TYPES, DEFAULT_GLINER_MODEL
from curverag.graph import KnowledgeGraph
from curverag.atth.kg_dataset import KGDataset
from curverag.atth.models.hyperbolic import AttH

In [3]:
import os
from dotenv import load_dotenv
from openai import OpenAI
load_dotenv() 

True

## Dataset and LLM

In [4]:
max_tokens = 10000
n_ctx=10000

docs = [
    "The patient was diagnosed with type 2 diabetes mellitus and prescribed metformin 500mg twice daily.",
    "MRI scan revealed a small lesion in the left temporal lobe suggestive of low-grade glioma.",
    "Administer 5mg of lorazepam intravenously for acute seizure management.",
    "Blood tests showed elevated ALT and AST levels, indicating possible liver inflammation.",
    "The subject reported chronic lower back pain, managed with physical therapy and NSAIDs.",
    "CT angiography confirmed the presence of a pulmonary embolism in the right lower lobe.",
    "The patient underwent coronary artery bypass graft surgery without complications.",
    "Routine vaccination included MMR, tetanus, and influenza immunizations.",
    "Histopathology indicated ductal carcinoma in situ (DCIS) in the breast biopsy sample.",
    "The child presented with a persistent cough and fever, diagnosed as streptococcal pharyngitis."
]


In [None]:
model = utils.load_model(
    llm_model_path="./models/Meta-Llama-3-8B-Instruct.Q6_K.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct"),
    n_ctx=n_ctx,
    max_tokens=max_tokens
)

In [5]:
api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=api_key)

## Fit CurveRAG

In [6]:
entity_types = ['people', 'locations', 'entities', 'movies', 'directors']
rag = CurveRAG(
    openai_client=client,
    entity_types=entity_types,
)

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]



In [7]:
rag = CurveRAG(llm=model)

NameError: name 'model' is not defined

In [7]:
rag.fit(docs=docs, dataset_name='test_run')

creating graph


 10%|████████▎                                                                          | 1/10 [00:09<01:27,  9.75s/it]

upserting


 20%|████████████████▌                                                                  | 2/10 [00:14<00:53,  6.74s/it]

upserting


 30%|████████████████████████▉                                                          | 3/10 [00:19<00:42,  6.02s/it]

upserting


 40%|█████████████████████████████████▏                                                 | 4/10 [00:26<00:37,  6.33s/it]

upserting


 50%|█████████████████████████████████████████▌                                         | 5/10 [00:31<00:28,  5.75s/it]

upserting


 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:33<00:18,  4.69s/it]

upserting


 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:36<00:12,  4.01s/it]

upserting


 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:41<00:08,  4.42s/it]

upserting


 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:46<00:04,  4.48s/it]

upserting


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:52<00:00,  5.24s/it]
2025-06-10 20:23:32,407 INFO     Saving logs in: ./logs/06_10\test_run\AttH_20_23_32
2025-06-10 20:23:32,407 INFO     	 (30, 28, 30)
2025-06-10 20:23:32,429 INFO     Total number of parameters 170088
2025-06-10 20:23:32,431 INFO     	 Start training


upserting
num nodes 30
unique node ids {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}
creating dataset
train kg embeddings
max_epochs 50


train loss: 100%|█████████████████████████████████████████████████████████| 42/42 [00:00<00:00, 80.31ex/s, loss=0.6884]
2025-06-10 20:23:32,956 INFO     	 Epoch 0 | average train loss: 0.6884
2025-06-10 20:23:32,962 INFO     	 Epoch 0 | average valid loss: 2.1058
train loss: 100%|█████████████████████████████████████████████████████████| 42/42 [00:00<00:00, 81.40ex/s, loss=2.1058]
2025-06-10 20:23:33,481 INFO     	 Epoch 1 | average train loss: 2.1058
2025-06-10 20:23:33,489 INFO     	 Epoch 1 | average valid loss: 1.9965
train loss: 100%|█████████████████████████████████████████████████████████| 42/42 [00:00<00:00, 81.55ex/s, loss=1.9965]
2025-06-10 20:23:34,007 INFO     	 Epoch 2 | average train loss: 1.9965
2025-06-10 20:23:34,012 INFO     	 Epoch 2 | average valid loss: 1.9274
2025-06-10 20:23:34,017 INFO     	 valid MR: 30.50 | MRR: 0.033 | H@1: 0.000 | H@3: 0.000 | H@10: 0.000
2025-06-10 20:23:34,018 INFO     	 Epoch 2 | average valid mmr: 0.0328
2025-06-10 20:23:34,018 INFO     

returning model


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [8]:
len(rag.graph.nodes)

30

In [9]:
import json
with open('rag.graph.json', 'w') as f:
    json.dump(rag.graph.dict(), f, indent=4)

C:\Users\natha\AppData\Local\Temp\ipykernel_33204\946862156.py:3: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  json.dump(rag.graph.dict(), f, indent=4)


In [12]:
rag.graph.nodes

[Node(id=1, name='Patient', description='The person receiving medical care The person receiving medical care', alias=['Individual', 'Client', 'Patient'], additional_information=['The patient has a diagnosis of type 2 diabetes.']),
 Node(id=2, name='Type 2 Diabetes Mellitus', description='A chronic condition that affects the way the body processes blood sugar (glucose)', alias=['Diabetes Type 2', 'T2DM'], additional_information=['Type 2 diabetes is characterized by insulin resistance.']),
 Node(id=3, name='Metformin', description='A medication used to treat type 2 diabetes', alias=[], additional_information=['Metformin is commonly prescribed for blood sugar control.']),
 Node(id=4, name='500mg', description='Dosage of metformin prescribed', alias=[], additional_information=['Dosage frequency is twice daily.']),
 Node(id=5, name='Twice Daily', description='Indicates the frequency of medication administration', alias=['BID'], additional_information=['Common dosage instruction for medicati

In [11]:
rag.graph.dict()

/var/folders/m6/12prs8yx2r90psgztq05hy_c0000gn/T/ipykernel_69670/774845048.py:1: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  rag.graph.dict()


{'nodes': [{'id': 1,
   'name': 'Patient',
   'description': 'A person who is receiving medical treatment',
   'alias': ['individual', 'person']},
  {'id': 2,
   'name': 'Diabetes Mellitus',
   'description': 'A group of metabolic disorders characterized by high blood sugar levels',
   'alias': ['diabetes', 'DM']},
  {'id': 3,
   'name': 'Metformin',
   'description': 'A medication used to treat type 2 diabetes',
   'alias': ['glucophage', 'metformin hydrochloride']},
  {'id': 4,
   'name': 'Type 2 Diabetes',
   'description': 'A type of diabetes that is caused by insulin resistance and impaired insulin secretion',
   'alias': ['T2D', 'non-insulin-dependent diabetes']},
  {'id': 5,
   'name': 'Daily',
   'description': 'A unit of time',
   'alias': ['day', '24 hours']},
  {'id': 6,
   'name': '500mg',
   'description': 'A dosage of medication',
   'alias': ['half a gram', '0.5 grams']},
  {'id': 7,
   'name': 'Twice',
   'description': 'A frequency of medication administration',
   'al

In [11]:
nodes_id_idx = rag.dataset.nodes_id_idx

In [19]:
node_embs = rag.model.entity.weight.data.cpu().numpy()
print(len(node_embs))
node_embs[nodes_id_idx[1]][:5]

45


array([ 0.00581928, -0.02819522,  0.0190266 , -0.00319721, -0.00893134])

# Create Network X Graph

In [13]:
import networkx as nx
G = nx.Graph()

In [16]:
G.add_nodes_from([n.id for n in rag.graph.nodes])

In [20]:
edges = [(e.source, e.target) for e in rag.graph.edges]
G.add_edges_from(edges)

In [21]:
G.number_of_nodes()

30

In [22]:
G.number_of_edges()

23

In [28]:
personalization = {21: 1, 22: 1}
pr = nx.pagerank(G, alpha=0.85, personalization=personalization)

In [34]:
pr_sorted = sorted(pr.items(), key=lambda x: x[1], reverse=True)


[(1, 0.40607836769153777),
 (21, 0.13252695416293836),
 (22, 0.13252695416293836),
 (3, 0.08020664145829767),
 (2, 0.057526954162938335),
 (23, 0.057526954162938335),
 (24, 0.057526954162938335),
 (4, 0.053365886787377234),
 (5, 0.02267968729535934),
 (27, 4.94942181944509e-06),
 (6, 3.2996145462967315e-06),
 (9, 3.2996145462967315e-06),
 (16, 3.2996145462967315e-06),
 (12, 1.6498072731483657e-06),
 (13, 1.6498072731483657e-06),
 (14, 1.6498072731483657e-06),
 (15, 1.6498072731483657e-06),
 (19, 1.6498072731483657e-06),
 (20, 1.6498072731483657e-06),
 (25, 1.6498072731483657e-06),
 (26, 1.6498072731483657e-06),
 (7, 8.249036365741829e-07),
 (8, 8.249036365741829e-07),
 (10, 8.249036365741829e-07),
 (11, 8.249036365741829e-07),
 (17, 8.249036365741829e-07),
 (18, 8.249036365741829e-07),
 (28, 5.49935757716121e-07),
 (29, 5.49935757716121e-07),
 (30, 5.49935757716121e-07)]

In [29]:
pr

{1: 0.40607836769153777,
 2: 0.057526954162938335,
 3: 0.08020664145829767,
 4: 0.053365886787377234,
 5: 0.02267968729535934,
 6: 3.2996145462967315e-06,
 7: 8.249036365741829e-07,
 8: 8.249036365741829e-07,
 9: 3.2996145462967315e-06,
 10: 8.249036365741829e-07,
 11: 8.249036365741829e-07,
 12: 1.6498072731483657e-06,
 13: 1.6498072731483657e-06,
 14: 1.6498072731483657e-06,
 15: 1.6498072731483657e-06,
 16: 3.2996145462967315e-06,
 17: 8.249036365741829e-07,
 18: 8.249036365741829e-07,
 19: 1.6498072731483657e-06,
 20: 1.6498072731483657e-06,
 21: 0.13252695416293836,
 22: 0.13252695416293836,
 23: 0.057526954162938335,
 24: 0.057526954162938335,
 25: 1.6498072731483657e-06,
 26: 1.6498072731483657e-06,
 27: 4.94942181944509e-06,
 28: 5.49935757716121e-07,
 29: 5.49935757716121e-07,
 30: 5.49935757716121e-07}

In [25]:
pr

{1: 0.10465670434595327,
 2: 0.019827111440711198,
 3: 0.03622576628851659,
 4: 0.03858331731416875,
 5: 0.021398654847805402,
 6: 0.0486480043320244,
 7: 0.025675997833987798,
 8: 0.025675997833987798,
 9: 0.0486480043320244,
 10: 0.025675997833987798,
 11: 0.025675997833987798,
 12: 0.03333333333333333,
 13: 0.03333333333333333,
 14: 0.03333333333333333,
 15: 0.03333333333333333,
 16: 0.0486480043320244,
 17: 0.025675997833987798,
 18: 0.025675997833987798,
 19: 0.03333333333333333,
 20: 0.03333333333333333,
 21: 0.019827111440711198,
 22: 0.019827111440711198,
 23: 0.019827111440711198,
 24: 0.019827111440711198,
 25: 0.03333333333333333,
 26: 0.03333333333333333,
 27: 0.06396267533071547,
 28: 0.02312355266753929,
 29: 0.02312355266753929,
 30: 0.02312355266753929}

# Load CurveRag modules

In [24]:
# load KG
with open('rag.graph.json', 'r') as f:
    data = json.load(f)

#kg = KnowledgeGraph.construct(**data)
kg = KnowledgeGraph.parse_file("rag.graph.json")

/var/folders/m6/12prs8yx2r90psgztq05hy_c0000gn/T/ipykernel_82949/2232709649.py:6: PydanticDeprecatedSince20: The `parse_file` method is deprecated; load the data from file, then if your data is JSON use `model_validate_json`, otherwise `model_validate` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  kg = KnowledgeGraph.parse_file("rag.graph.json")


In [23]:
kg.nodes[0]

Node(id=1, name='Patient', description='A person who is receiving medical treatment', alias=['individual', 'person'])

In [6]:
# load dataset
dataset_path = './data/medical_docs'
dataset = KGDataset(dataset_path, debug=False, name='medical_docs')

In [7]:
# load embeding model
sizes = dataset.get_shape()
embedding_model = AttH(sizes=sizes)
model_path = '/Users/nathan/Documents/projects/curve_rag/logs/05_05/test_run/AttH_00_38_35/model.pt'
embedding_model.load_state_dict(torch.load(model_path, weights_only=True))
embedding_model.eval()

AttH(
  (entity): Embedding(45, 1000)
  (rel): Embedding(48, 1000)
  (bh): Embedding(45, 1)
  (bt): Embedding(45, 1)
  (rel_diag): Embedding(48, 2000)
  (context_vec): Embedding(48, 1000)
  (act): Softmax(dim=1)
)

In [None]:
# load llm
max_tokens = 10000
n_ctx=10000
llm, outlines_model = utils.load_model(
    llm_model_path="./models/Meta-Llama-3-8B-Instruct.Q6_K.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct"),
    n_ctx=n_ctx,
    max_tokens=max_tokens
)

In [67]:
resp = llm('Hi how are you?')

Llama.generate: 6 prefix-match hit, remaining 1 prompt tokens to eval
llama_perf_context_print:        load time =    2329.38 ms
llama_perf_context_print: prompt eval time =       0.00 ms /     1 tokens (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:        eval time =    5169.89 ms /    16 runs   (  323.12 ms per token,     3.09 tokens per second)
llama_perf_context_print:       total time =    5176.08 ms /    17 tokens


In [59]:
resp['choices'][0]['text']

" I hope you're doing well. I am from Mexico and I am a fan"

In [68]:
cg = CurveRAG.load_class(
    llm=llm,
    outlines_llm=outlines_model,
    entity_types=DEFAULT_ENTITY_TYPES,
    gliner_model_name=DEFAULT_GLINER_MODEL,
    graph=kg,
    graph_embedding_model=embedding_model,
    dataset=dataset,
)

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

In [69]:
entities = ['Patient', 'Diabetes Mellitus', 'Metformin', 'Type 2 Diabetes', 'Daily']

In [None]:
query = "What is used to treat Diabetes Mellitus and how does can a patient use the treatment?"
cg.query(query, max_tokens=1000)

entities ['patient']
node_ids []
node_ids []

You are a helpful assistant analyzing the given input data to provide an helpful response to the user query.

# USER QUERY
What is used to treat Diabetes Mellitus and how does can a patient use the treatment?

# Context:
KnowledgeGraph Overview
  There are 0 entities and 0 relationships in this graph.

Entities in the graph:

Relationships between nodes:



# INSTRUCTIONS
Your goal is to provide a response to the user query using the relevant information in the input data:
- the "Entities" and "Relationships" tables contain high-level information. Use these tables to identify the most important entities and relationships to respond to the query.
- the "Sources" list contains raw text sources to help answer the query. It may contain noisy data, so pay attention when analyzing it.

Follow these steps:
1. Read and understand the user query.
2. Look at the "Entities" and "Relationships" tables to get a general sense of the data and understand w

Llama.generate: 370 prefix-match hit, remaining 1 prompt tokens to eval


In [86]:
llm2("What is used to treat Diabetes Mellitus and how does can a patient use the treatment?",  max_tokens=1000)['choices'][0]['text']

Llama.generate: 10 prefix-match hit, remaining 10 prompt tokens to eval
llama_perf_context_print:        load time =    3865.92 ms
llama_perf_context_print: prompt eval time =    1877.04 ms /    10 tokens (  187.70 ms per token,     5.33 tokens per second)
llama_perf_context_print:        eval time =   51397.43 ms /   999 runs   (   51.45 ms per token,    19.44 tokens per second)
llama_perf_context_print:       total time =   54751.37 ms /  1009 tokens


" (2023)\nWhat is Diabetes Mellitus?\nDiabetes Mellitus, commonly referred to as diabetes, is a metabolic disorder that affects the way the body regulates blood sugar levels. There are several types of diabetes, but the most common ones are:\nType 1 diabetes: An autoimmune disease where the body's immune system attacks and destroys the cells in the pancreas that produce insulin, a hormone that regulates blood sugar levels.\nType 2 diabetes: A condition where the body becomes resistant to insulin, making it difficult for the body to use insulin effectively.\nGestational diabetes: A type of diabetes that develops during pregnancy due to hormonal changes.\nWhat are the symptoms of Diabetes Mellitus?\nThe symptoms of diabetes can vary from person to person, but common signs and symptoms include:\nIncreased thirst and urination\nFatigue and weakness\nBlurred vision\nSlow healing of cuts and wounds\n Tingling or numbness in the hands and feet\nHow is Diabetes Mellitus treated?\nThe treatment

In [None]:
from llama_cpp import Llama
llm2 = Llama(
        model_path="./models/Meta-Llama-3-8B-Instruct.Q6_K.gguf",
        tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct"),
        n_ctx=n_ctx,
        max_tokens=max_tokens
    )