In [14]:
!pip install --upgrade pip setuptools wheel torch-geometric pykeen neo4j



# **Imports**

In [15]:
from pykeen.datasets import PrimeKG
import pandas as pd
import torch
import torch.nn.functional as F
from torch_geometric.nn.kge import DistMult
from neo4j import GraphDatabase

# **Loading and inspecting the PrimeKG dataset**

In [16]:
dataset = PrimeKG()

print(dataset)

train_factory = dataset.training
valid_factory = dataset.validation
test_factory = dataset.testing

print("Training triples:", train_factory.num_triples)
print("Validation triples:", valid_factory.num_triples)
print("Test triples:", test_factory.num_triples)

print("\nSample training triples:")
for triple in train_factory.mapped_triples[:5]:
    h, r, t = triple.tolist()
    print(
        train_factory.entity_id_to_label[h],
        "--",
        train_factory.relation_id_to_label[r],
        "-->",
        train_factory.entity_id_to_label[t],
    )


INFO:pykeen.datasets.base:reordering columns: ['x_name', 'relation', 'y_name']
INFO:pykeen.triples.splitting:done splitting triples to groups of sizes [6350753, 809999, 810000]


PrimeKG(num_entities=129262, num_relations=30, create_inverse_triples=False)
Training triples: 6479992
Validation triples: 810000
Test triples: 809999

Sample training triples:
'de novo' AMP biosynthetic process -- bioprocess_bioprocess --> AMP biosynthetic process
'de novo' AMP biosynthetic process -- bioprocess_protein --> ADSL
'de novo' CTP biosynthetic process -- bioprocess_bioprocess --> CTP biosynthetic process
'de novo' GDP-L-fucose biosynthetic process -- bioprocess_bioprocess --> GDP-L-fucose biosynthetic process
'de novo' IMP biosynthetic process -- bioprocess_bioprocess --> IMP biosynthetic process


# **Creating a readable DataFrame of Knowledge Graph Triples**

In [17]:
triples = train_factory.mapped_triples.tolist()
df = pd.DataFrame(triples, columns=["head_id", "relation_id", "tail_id"])

df["head"] = df["head_id"].map(train_factory.entity_id_to_label)
df["relation"] = df["relation_id"].map(train_factory.relation_id_to_label)
df["tail"] = df["tail_id"].map(train_factory.entity_id_to_label)

df = df[["head", "relation", "tail"]]
print(df.head(10))

                                                head               relation  \
0                 'de novo' AMP biosynthetic process  bioprocess_bioprocess   
1                 'de novo' AMP biosynthetic process     bioprocess_protein   
2                 'de novo' CTP biosynthetic process  bioprocess_bioprocess   
3        'de novo' GDP-L-fucose biosynthetic process  bioprocess_bioprocess   
4                 'de novo' IMP biosynthetic process  bioprocess_bioprocess   
5        'de novo' L-methionine biosynthetic process  bioprocess_bioprocess   
6                 'de novo' NAD biosynthetic process  bioprocess_bioprocess   
7  'de novo' NAD biosynthetic process from aspartate  bioprocess_bioprocess   
8  'de novo' NAD biosynthetic process from trypto...  bioprocess_bioprocess   
9                 'de novo' UMP biosynthetic process  bioprocess_bioprocess   

                                                tail  
0                           AMP biosynthetic process  
1                   

# **Preparing Triples for Embedding Models**

In [18]:
triples_tensor = torch.tensor(train_factory.mapped_triples, dtype=torch.long)
num_entities = dataset.num_entities
num_relations = dataset.num_relations

print("Entities:", num_entities)
print("Relations:", num_relations)
print("Triples:", triples_tensor.shape)

Entities: 129262
Relations: 30
Triples: torch.Size([6479992, 3])


  triples_tensor = torch.tensor(train_factory.mapped_triples, dtype=torch.long)


# **Defining the DistMult model**

In [19]:
model = DistMult(
    num_nodes = num_entities,
    num_relations = num_relations,
    hidden_channels = 128,
    margin = 1.0,
)

# **Training the DistMult model with Negative Sampling**

In [20]:
def train(model, triples, num_entities, epochs=100, lr=0.01, weight_decay=1e-5):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    model.train()

    for epoch in range(epochs):
        optimizer.zero_grad()

        h = triples[:, 0]
        r = triples[:, 1]
        t = triples[:, 2]

        pos_score = model(h, r, t)

        mask = torch.rand(len(h)) < 0.5
        h_neg = h.clone()
        t_neg = t.clone()
        h_neg[mask] = torch.randint(0, num_entities, (mask.sum().item(),), dtype=torch.long)
        t_neg[~mask] = torch.randint(0, num_entities, ((~mask).sum().item(),), dtype=torch.long)

        neg_score = model(h_neg, r, t_neg)

        loss = F.softplus(neg_score + model.margin - pos_score).mean()

        loss.backward()
        optimizer.step()

        with torch.no_grad():
            model.node_emb.weight.div_(model.node_emb.weight.norm(dim=1, keepdim=True) + 1e-9)
            model.rel_emb.weight.div_(model.rel_emb.weight.norm(dim=1, keepdim=True) + 1e-9)

        if epoch % 5 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch+1}/{epochs} | Loss: {loss.item():.4f}")

triples_tensor = torch.tensor(train_factory.mapped_triples[:100000], dtype=torch.long)
train(model, triples_tensor, num_entities=num_entities, epochs=300, lr=0.01)


  triples_tensor = torch.tensor(train_factory.mapped_triples[:100000], dtype=torch.long)


Epoch 1/300 | Loss: 1.3133
Epoch 6/300 | Loss: 1.3195
Epoch 11/300 | Loss: 1.3144
Epoch 16/300 | Loss: 1.3072
Epoch 21/300 | Loss: 1.3031
Epoch 26/300 | Loss: 1.3009
Epoch 31/300 | Loss: 1.2989
Epoch 36/300 | Loss: 1.2979
Epoch 41/300 | Loss: 1.2977
Epoch 46/300 | Loss: 1.2980
Epoch 51/300 | Loss: 1.2983
Epoch 56/300 | Loss: 1.2984
Epoch 61/300 | Loss: 1.2980
Epoch 66/300 | Loss: 1.2972
Epoch 71/300 | Loss: 1.2960
Epoch 76/300 | Loss: 1.2949
Epoch 81/300 | Loss: 1.2938
Epoch 86/300 | Loss: 1.2930
Epoch 91/300 | Loss: 1.2924
Epoch 96/300 | Loss: 1.2918
Epoch 101/300 | Loss: 1.2915
Epoch 106/300 | Loss: 1.2913
Epoch 111/300 | Loss: 1.2912
Epoch 116/300 | Loss: 1.2911
Epoch 121/300 | Loss: 1.2909
Epoch 126/300 | Loss: 1.2908
Epoch 131/300 | Loss: 1.2908
Epoch 136/300 | Loss: 1.2907
Epoch 141/300 | Loss: 1.2905
Epoch 146/300 | Loss: 1.2904
Epoch 151/300 | Loss: 1.2903
Epoch 156/300 | Loss: 1.2902
Epoch 161/300 | Loss: 1.2901
Epoch 166/300 | Loss: 1.2899
Epoch 171/300 | Loss: 1.2896
Epoch 1

# **Computing Probabilistic Priors from the trained model**

In [21]:
def compute_priors(model, triples_tensor, factory, sample_size=200):
    model.eval()
    with torch.no_grad():
        n = min(sample_size, triples_tensor.size(0))
        h = triples_tensor[:n, 0]
        r = triples_tensor[:n, 1]
        t = triples_tensor[:n, 2]

        scores = model(h, r, t)
        scores = torch.clamp(scores, min=-10.0, max=10.0)
        probs = torch.sigmoid(scores)

    result = []
    for i in range(n):
        result.append({
            "head":    factory.entity_id_to_label[h[i].item()],
            "relation":factory.relation_id_to_label[r[i].item()],
            "tail":    factory.entity_id_to_label[t[i].item()],
            "prior":   float(probs[i].item())
        })
    return result

subset_priors = compute_priors(model, triples_tensor, train_factory, sample_size=200)
for row in subset_priors[:10]:
    print(f"{row['head']} --{row['relation']}--> {row['tail']}  [prior={row['prior']:.3f}]")


'de novo' AMP biosynthetic process --bioprocess_bioprocess--> AMP biosynthetic process  [prior=0.522]
'de novo' AMP biosynthetic process --bioprocess_protein--> ADSL  [prior=0.501]
'de novo' CTP biosynthetic process --bioprocess_bioprocess--> CTP biosynthetic process  [prior=0.514]
'de novo' GDP-L-fucose biosynthetic process --bioprocess_bioprocess--> GDP-L-fucose biosynthetic process  [prior=0.508]
'de novo' IMP biosynthetic process --bioprocess_bioprocess--> IMP biosynthetic process  [prior=0.518]
'de novo' L-methionine biosynthetic process --bioprocess_bioprocess--> L-methionine biosynthetic process  [prior=0.517]
'de novo' NAD biosynthetic process --bioprocess_bioprocess--> 'de novo' NAD biosynthetic process from aspartate  [prior=0.554]
'de novo' NAD biosynthetic process from aspartate --bioprocess_bioprocess--> 'de novo' NAD biosynthetic process  [prior=0.554]
'de novo' NAD biosynthetic process from tryptophan --bioprocess_bioprocess--> 'de novo' NAD biosynthetic process  [prior=

# **Connect to Neo4j**

In [22]:
NEO4J_URI="neo4j+s://36039b8d.databases.neo4j.io"
NEO4J_USERNAME="neo4j"
NEO4J_PASSWORD="2GSEZotjB6q0ZB8CBDP36BPfhOo4Ip60hPLTURPIqrU"

driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))


# **Upload the triplets**

In [23]:

def upload_priors(driver, records, batch_size=1000):
    with driver.session() as session:
        for i in range(0, len(records), batch_size):
            batch = records[i : i + batch_size]
            session.run(
                """
                UNWIND $batch AS rec
                MERGE (h:Entity {name: rec.head})
                MERGE (t:Entity {name: rec.tail})
                MERGE (h)-[r:RELATION {type: rec.relation}]->(t)
                SET r.prior = rec.prior
                """,
                batch=batch
            )
records = subset_priors
upload_priors(driver, records, batch_size=500)
driver.close()