In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import json

import llama_cpp
import torch

from curverag import utils
from curverag.curverag import CurveRAG, DEFAULT_ENTITY_TYPES, DEFAULT_GLINER_MODEL
from curverag.graph import KnowledgeGraph
from curverag.atth.kg_dataset import KGDataset
from curverag.atth.models.hyperbolic import AttH

## Dataset and LLM

In [6]:
max_tokens = 10000
n_ctx=10000

docs = [
    "The patient was diagnosed with type 2 diabetes mellitus and prescribed metformin 500mg twice daily.",
    "MRI scan revealed a small lesion in the left temporal lobe suggestive of low-grade glioma.",
    "Administer 5mg of lorazepam intravenously for acute seizure management.",
    "Blood tests showed elevated ALT and AST levels, indicating possible liver inflammation.",
    "The subject reported chronic lower back pain, managed with physical therapy and NSAIDs.",
    "CT angiography confirmed the presence of a pulmonary embolism in the right lower lobe.",
    "The patient underwent coronary artery bypass graft surgery without complications.",
    "Routine vaccination included MMR, tetanus, and influenza immunizations.",
    "Histopathology indicated ductal carcinoma in situ (DCIS) in the breast biopsy sample.",
    "The child presented with a persistent cough and fever, diagnosed as streptococcal pharyngitis."
]


In [None]:
model = utils.load_model(
    llm_model_path="./models/Meta-Llama-3-8B-Instruct.Q6_K.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct"),
    n_ctx=n_ctx,
    max_tokens=max_tokens
)

## Fit CurveRAG

In [8]:
rag = CurveRAG(llm=model)

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

gliner_config.json:   0%|          | 0.00/476 [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/781M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/781M [00:00<?, ?B/s]

README.md:   0%|          | 0.00/4.76k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/579 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]



In [None]:
rag.fit(docs, dataset_name='test_run')

In [10]:
len(rag.graph.nodes)

45

In [13]:
import json
with open('rag.graph.json', 'w') as f:
    json.dump(rag.graph.dict(), f, indent=4)

/var/folders/m6/12prs8yx2r90psgztq05hy_c0000gn/T/ipykernel_69670/946862156.py:3: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  json.dump(rag.graph.dict(), f, indent=4)


/var/folders/m6/12prs8yx2r90psgztq05hy_c0000gn/T/ipykernel_69670/3214747748.py:5: PydanticDeprecatedSince20: The `construct` method is deprecated; use `model_construct` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  kg = KnowledgeGraph.construct(**data)


In [18]:
kg.nodes

[{'id': 1,
  'name': 'Patient',
  'description': 'A person who is receiving medical treatment',
  'alias': ['individual', 'person']},
 {'id': 2,
  'name': 'Diabetes Mellitus',
  'description': 'A group of metabolic disorders characterized by high blood sugar levels',
  'alias': ['diabetes', 'DM']},
 {'id': 3,
  'name': 'Metformin',
  'description': 'A medication used to treat type 2 diabetes',
  'alias': ['glucophage', 'metformin hydrochloride']},
 {'id': 4,
  'name': 'Type 2 Diabetes',
  'description': 'A type of diabetes that is caused by insulin resistance and impaired insulin secretion',
  'alias': ['T2D', 'non-insulin-dependent diabetes']},
 {'id': 5,
  'name': 'Daily',
  'description': 'A unit of time',
  'alias': ['day', '24 hours']},
 {'id': 6,
  'name': '500mg',
  'description': 'A dosage of medication',
  'alias': ['half a gram', '0.5 grams']},
 {'id': 7,
  'name': 'Twice',
  'description': 'A frequency of medication administration',
  'alias': ['every 12 hours', 'b.i.d.']},


In [11]:
rag.graph.dict()

/var/folders/m6/12prs8yx2r90psgztq05hy_c0000gn/T/ipykernel_69670/774845048.py:1: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  rag.graph.dict()


{'nodes': [{'id': 1,
   'name': 'Patient',
   'description': 'A person who is receiving medical treatment',
   'alias': ['individual', 'person']},
  {'id': 2,
   'name': 'Diabetes Mellitus',
   'description': 'A group of metabolic disorders characterized by high blood sugar levels',
   'alias': ['diabetes', 'DM']},
  {'id': 3,
   'name': 'Metformin',
   'description': 'A medication used to treat type 2 diabetes',
   'alias': ['glucophage', 'metformin hydrochloride']},
  {'id': 4,
   'name': 'Type 2 Diabetes',
   'description': 'A type of diabetes that is caused by insulin resistance and impaired insulin secretion',
   'alias': ['T2D', 'non-insulin-dependent diabetes']},
  {'id': 5,
   'name': 'Daily',
   'description': 'A unit of time',
   'alias': ['day', '24 hours']},
  {'id': 6,
   'name': '500mg',
   'description': 'A dosage of medication',
   'alias': ['half a gram', '0.5 grams']},
  {'id': 7,
   'name': 'Twice',
   'description': 'A frequency of medication administration',
   'al

In [11]:
nodes_id_idx = rag.dataset.nodes_id_idx

In [19]:
node_embs = rag.model.entity.weight.data.cpu().numpy()
print(len(node_embs))
node_embs[nodes_id_idx[1]][:5]

45


array([ 0.00581928, -0.02819522,  0.0190266 , -0.00319721, -0.00893134])

# Load CurveRag modules

In [24]:
# load KG
with open('rag.graph.json', 'r') as f:
    data = json.load(f)

#kg = KnowledgeGraph.construct(**data)
kg = KnowledgeGraph.parse_file("rag.graph.json")

/var/folders/m6/12prs8yx2r90psgztq05hy_c0000gn/T/ipykernel_82949/2232709649.py:6: PydanticDeprecatedSince20: The `parse_file` method is deprecated; load the data from file, then if your data is JSON use `model_validate_json`, otherwise `model_validate` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  kg = KnowledgeGraph.parse_file("rag.graph.json")


In [23]:
kg.nodes[0]

Node(id=1, name='Patient', description='A person who is receiving medical treatment', alias=['individual', 'person'])

In [6]:
# load dataset
dataset_path = './data/medical_docs'
dataset = KGDataset(dataset_path, debug=False, name='medical_docs')

In [7]:
# load embeding model
sizes = dataset.get_shape()
embedding_model = AttH(sizes=sizes)
model_path = '/Users/nathan/Documents/projects/curve_rag/logs/05_05/test_run/AttH_00_38_35/model.pt'
embedding_model.load_state_dict(torch.load(model_path, weights_only=True))
embedding_model.eval()

AttH(
  (entity): Embedding(45, 1000)
  (rel): Embedding(48, 1000)
  (bh): Embedding(45, 1)
  (bt): Embedding(45, 1)
  (rel_diag): Embedding(48, 2000)
  (context_vec): Embedding(48, 1000)
  (act): Softmax(dim=1)
)

In [None]:
# load llm
max_tokens = 10000
n_ctx=10000
llm, outlines_model = utils.load_model(
    llm_model_path="./models/Meta-Llama-3-8B-Instruct.Q6_K.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct"),
    n_ctx=n_ctx,
    max_tokens=max_tokens
)

In [67]:
resp = llm('Hi how are you?')

Llama.generate: 6 prefix-match hit, remaining 1 prompt tokens to eval
llama_perf_context_print:        load time =    2329.38 ms
llama_perf_context_print: prompt eval time =       0.00 ms /     1 tokens (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:        eval time =    5169.89 ms /    16 runs   (  323.12 ms per token,     3.09 tokens per second)
llama_perf_context_print:       total time =    5176.08 ms /    17 tokens


In [59]:
resp['choices'][0]['text']

" I hope you're doing well. I am from Mexico and I am a fan"

In [68]:
cg = CurveRAG.load_class(
    llm=llm,
    outlines_llm=outlines_model,
    entity_types=DEFAULT_ENTITY_TYPES,
    gliner_model_name=DEFAULT_GLINER_MODEL,
    graph=kg,
    graph_embedding_model=embedding_model,
    dataset=dataset,
)

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

In [69]:
entities = ['Patient', 'Diabetes Mellitus', 'Metformin', 'Type 2 Diabetes', 'Daily']

In [None]:
query = "What is used to treat Diabetes Mellitus and how does can a patient use the treatment?"
cg.query(query, max_tokens=1000)

entities ['patient']
node_ids []
node_ids []

You are a helpful assistant analyzing the given input data to provide an helpful response to the user query.

# USER QUERY
What is used to treat Diabetes Mellitus and how does can a patient use the treatment?

# Context:
KnowledgeGraph Overview
  There are 0 entities and 0 relationships in this graph.

Entities in the graph:

Relationships between nodes:



# INSTRUCTIONS
Your goal is to provide a response to the user query using the relevant information in the input data:
- the "Entities" and "Relationships" tables contain high-level information. Use these tables to identify the most important entities and relationships to respond to the query.
- the "Sources" list contains raw text sources to help answer the query. It may contain noisy data, so pay attention when analyzing it.

Follow these steps:
1. Read and understand the user query.
2. Look at the "Entities" and "Relationships" tables to get a general sense of the data and understand w

Llama.generate: 370 prefix-match hit, remaining 1 prompt tokens to eval


In [86]:
llm2("What is used to treat Diabetes Mellitus and how does can a patient use the treatment?",  max_tokens=1000)['choices'][0]['text']

Llama.generate: 10 prefix-match hit, remaining 10 prompt tokens to eval
llama_perf_context_print:        load time =    3865.92 ms
llama_perf_context_print: prompt eval time =    1877.04 ms /    10 tokens (  187.70 ms per token,     5.33 tokens per second)
llama_perf_context_print:        eval time =   51397.43 ms /   999 runs   (   51.45 ms per token,    19.44 tokens per second)
llama_perf_context_print:       total time =   54751.37 ms /  1009 tokens


" (2023)\nWhat is Diabetes Mellitus?\nDiabetes Mellitus, commonly referred to as diabetes, is a metabolic disorder that affects the way the body regulates blood sugar levels. There are several types of diabetes, but the most common ones are:\nType 1 diabetes: An autoimmune disease where the body's immune system attacks and destroys the cells in the pancreas that produce insulin, a hormone that regulates blood sugar levels.\nType 2 diabetes: A condition where the body becomes resistant to insulin, making it difficult for the body to use insulin effectively.\nGestational diabetes: A type of diabetes that develops during pregnancy due to hormonal changes.\nWhat are the symptoms of Diabetes Mellitus?\nThe symptoms of diabetes can vary from person to person, but common signs and symptoms include:\nIncreased thirst and urination\nFatigue and weakness\nBlurred vision\nSlow healing of cuts and wounds\n Tingling or numbness in the hands and feet\nHow is Diabetes Mellitus treated?\nThe treatment

In [None]:
from llama_cpp import Llama
llm2 = Llama(
        model_path="./models/Meta-Llama-3-8B-Instruct.Q6_K.gguf",
        tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct"),
        n_ctx=n_ctx,
        max_tokens=max_tokens
    )