In [27]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [1]:
import llama_cpp
from curverag import utils
from curverag.curverag import CurveRAG

## Dataset and LLM

In [2]:
# Generation / context limits handed to `utils.load_model` in the next cell.
# NOTE(review): max_tokens equals n_ctx; if max_tokens caps generated tokens, prompt plus
# completion could exceed the context window — confirm how load_model uses it.
max_tokens = 10_000
n_ctx = 10_000

# Small toy corpus of clinical sentences used to fit CurveRAG below.
docs = [
    "The patient was diagnosed with type 2 diabetes mellitus and prescribed metformin 500mg twice daily.",
    "MRI scan revealed a small lesion in the left temporal lobe suggestive of low-grade glioma.",
    "Administer 5mg of lorazepam intravenously for acute seizure management.",
    "Blood tests showed elevated ALT and AST levels, indicating possible liver inflammation.",
    "The subject reported chronic lower back pain, managed with physical therapy and NSAIDs.",
    "CT angiography confirmed the presence of a pulmonary embolism in the right lower lobe.",
    "The patient underwent coronary artery bypass graft surgery without complications.",
    "Routine vaccination included MMR, tetanus, and influenza immunizations.",
    "Histopathology indicated ductal carcinoma in situ (DCIS) in the breast biopsy sample.",
    "The child presented with a persistent cough and fever, diagnosed as streptococcal pharyngitis.",
]


In [None]:
# Tokenizer is fetched by Hugging Face repo id; the quantized GGUF weights come from a local path.
# NOTE(review): the tokenizer repo (Llama 3.2 1B Instruct) differs from the GGUF model
# (Llama 3 8B Instruct); presumably they share a vocabulary — confirm.
llm_tokenizer = llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
    "unsloth/Llama-3.2-1B-Instruct"
)
model = utils.load_model(
    llm_model_path="./models/Meta-Llama-3-8B-Instruct.Q6_K.gguf",
    tokenizer=llm_tokenizer,
    n_ctx=n_ctx,
    max_tokens=max_tokens,
)

## Fit CurveRAG

In [4]:
# Wrap the loaded LLM in a CurveRAG pipeline (not fitted yet).
rag = CurveRAG(
    llm=model,
)

In [None]:
# Fit CurveRAG on the toy corpus.
# NOTE(review): 'test_run' presumably labels this dataset's artifacts — confirm in CurveRAG.fit.
rag.fit(
    docs,
    dataset_name='test_run',
)

In [9]:
# Node count of the fitted RAG object's graph.
n_graph_nodes = len(rag.graph.nodes)
n_graph_nodes

45

In [None]:
# Read the learned embeddings from the fitted CurveRAG object (`rag.model`), the same
# object the node-embedding cell below reads from. `model` in this notebook is the LLM
# returned by `utils.load_model`, not the embedding model, so `model.entity` is wrong here.

# Get all entity embeddings (n_entities x embedding_dim)
entity_embeddings = rag.model.entity.weight.data.cpu().numpy()

# Get relation embeddings (n_relations x embedding_dim)
# NOTE(review): assumes rag.model exposes a `rel` layer alongside `entity` — confirm.
relation_embeddings = rag.model.rel.weight.data.cpu().numpy()

In [11]:
# Lookup from the fitted dataset, used below to pick rows of the entity embedding matrix.
# NOTE(review): presumably maps node id -> embedding row index — confirm.
fitted_dataset = rag.dataset
nodes_id_idx = fitted_dataset.nodes_id_idx

In [17]:
# Copy the entity embedding weights to a NumPy array on the CPU.
entity_weight = rag.model.entity.weight.data
node_embs = entity_weight.cpu().numpy()
print(len(node_embs))
# Display the embedding row for node id 1.
node_embs[nodes_id_idx[1]]

45


array([ 5.81928278e-03, -2.81952184e-02,  1.90265952e-02, -3.19721146e-03,
       -8.93134130e-03,  1.43143845e-02, -2.16812938e-02, -1.23798875e-02,
       -7.85401404e-03, -3.40788820e-02,  1.50577328e-02,  1.58859481e-02,
       -9.34734813e-02, -3.82454758e-02,  1.76710727e-02,  2.20435778e-03,
        6.64168193e-04, -3.41603290e-02, -3.09017472e-02, -1.60718260e-02,
        5.08408444e-03, -2.70790972e-02, -3.00340165e-03, -3.25994164e-02,
        9.25533052e-04, -9.53712346e-03,  1.08343951e-02, -3.17294051e-03,
        1.82067312e-02,  2.57764442e-02, -8.24327401e-03, -4.88753555e-04,
       -4.17824340e-04,  3.68967587e-02, -1.62313360e-02, -5.70871814e-03,
       -2.35310674e-02, -7.44719694e-03,  1.47713502e-02, -1.01373686e-02,
       -1.05396421e-02, -2.19445735e-05,  8.50790638e-03,  1.04352489e-02,
       -1.67183165e-03, -1.19690442e-02, -1.29760279e-02,  2.76977812e-02,
       -4.83249188e-03, -4.59397969e-03,  9.55424974e-03, -9.39276756e-03,
       -8.52880972e-03, -