In [14]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
import os

import numpy as np

from curverag.atth.kg_dataset import KGDataset

Load Dataset

In [17]:
dataset_path = './data/medical_docs'
dataset = KGDataset(dataset_path, debug=False)

In [19]:
dataset.get_shape()

(45, 48, 45)

# Load saved KG Model (AttH)model

In [20]:
import torch
from curverag.atth.models.hyperbolic import AttH

In [23]:
class ModelArgs:
    def __init__(self, dataset):
        """Knowledge Graph Embedding Model Configuration"""
        self.rank = 1000                # Embedding dimension
        self.learning_rate = 1e-1       # Initial learning rate
        self.batch_size = 1000          # Training batch size
        self.reg = 0.1                  # Regularization strength
        self.max_epochs = 10            # Maximum training epochs
        self.patience = 20              # Early stopping patience
        self.debug = False              # Debug mode flag
        self.data_type = 'double'           # Data type (float32/double)
        self.neg_sample_size = 50       # Negative samples per positive
        self.double_neg = True          # Use double negative sampling
        self.bias = 'constant'          # Bias type in model
        self.init_size = 1e-3           # Embedding initialization scale
        self.multi_c = True             # Multiple curvatures (hyperbolic)
        self.dropout = 0                # Dropout rate
        self.sizes = dataset.get_shape()# (n_entities, n_relations) from dataset
        self.gamma = 0

    def __repr__(self):
        return f"ModelArgs({vars(self)})"
model_args = ModelArgs(dataset)

In [25]:
model = AttH(model_args)
model_path = '/Users/nathan/Documents/projects/curve_rag/logs/05_01/medical_docs/AttH_23_03_31/model.pt'
model.load_state_dict(torch.load(model_path, weights_only=True))
model.eval()

AttH(
  (entity): Embedding(45, 1000)
  (rel): Embedding(48, 1000)
  (bh): Embedding(45, 1)
  (bt): Embedding(45, 1)
  (rel_diag): Embedding(48, 2000)
  (context_vec): Embedding(48, 1000)
  (act): Softmax(dim=1)
)

# Get Embeddings

In [28]:
# Get all entity embeddings (n_entities x embedding_dim)
entity_embeddings = model.entity.weight.data.cpu().numpy()

# Get relation embeddings (n_relations x embedding_dim)
relation_embeddings = model.rel.weight.data.cpu().numpy()

# Save to file
os.makedirs("./embeddings", exist_ok=True)

np.save("./embeddings/entity_emb.npy", entity_embeddings)
np.save("./embeddings/relation_emb.npy", relation_embeddings)

In [29]:
entity_embeddings

array([[-0.10088129,  0.09966055,  0.10105066, ...,  0.09940649,
        -0.09916997, -0.08717234],
       [-0.0993396 ,  0.10219772, -0.10032152, ..., -0.10193137,
         0.10069192,  0.09881262],
       [ 0.10167376, -0.10083576,  0.09849502, ..., -0.09702491,
         0.10110857, -0.10045891],
       ...,
       [ 0.09495796, -0.09677415,  0.09946039, ..., -0.09818554,
         0.09982164,  0.09723762],
       [ 0.09853571, -0.10173274,  0.09652006, ...,  0.09986235,
        -0.09847258, -0.09740754],
       [ 0.10125503,  0.10050082, -0.09904297, ..., -0.10006015,
         0.0951002 , -0.09867576]], shape=(45, 1000))

# Make prediction using graph

In [30]:
import torch
model.eval()

# Example prediction: (h, r) -> predict t
def predict_tail(h, r, top_k=5):
    with torch.no_grad():
        scores, _ = model.get_queries(torch.tensor([[h, r, 0]]))
        #print(type(scores[0]), scores)
        scores = scores[0]
        values, indices = torch.topk(scores, k=top_k)
        return [(idx, score.item()) 
                for idx, score in zip(indices[0], values[0])]

# Usage
predictions = predict_tail(h=0, r=0)  # (Paris, capital_of) -> top 5 countries

In [31]:
predictions

[(tensor(457), 0.05623408789488405),
 (tensor(465), 0.05482293428543924),
 (tensor(676), 0.05458844254021787),
 (tensor(495), 0.05448341048716934),
 (tensor(881), 0.05444969239179038)]