In [3]:
!pip install --upgrade pip setuptools wheel torch-geometric pykeen neo4j pyro-ppl



# **Imports**

In [4]:
import getpass
import os
from collections import deque

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyro
import pyro.distributions as dist
import torch
import torch.nn.functional as F
from neo4j import GraphDatabase
from pykeen.datasets import PrimeKG
from pyro.infer import MCMC, NUTS
from torch_geometric.nn.kge import DistMult

INFO:pykeen.utils:Using opt_einsum


# **Loading and inspecting the PrimeKG dataset**

In [5]:
# Instantiate PrimeKG and keep a handle on each split's triples factory;
# later cells refer to `dataset` and `train_factory` by these exact names.
dataset = PrimeKG()
print(dataset)

train_factory, valid_factory, test_factory = (
    dataset.training,
    dataset.validation,
    dataset.testing,
)

# Report the size of every split.
for split_name, split_factory in (
    ("Training", train_factory),
    ("Validation", valid_factory),
    ("Test", test_factory),
):
    print(f"{split_name} triples:", split_factory.num_triples)

# Show the first five training triples with ids translated back to labels.
print("\nSample training triples:")
id_to_entity = train_factory.entity_id_to_label
id_to_relation = train_factory.relation_id_to_label
for head_id, relation_id, tail_id in train_factory.mapped_triples[:5].tolist():
    print(
        id_to_entity[head_id],
        "--",
        id_to_relation[relation_id],
        "-->",
        id_to_entity[tail_id],
    )


INFO:pykeen.datasets.base:reordering columns: ['x_name', 'relation', 'y_name']
INFO:pykeen.triples.splitting:done splitting triples to groups of sizes [6350753, 809999, 810000]


PrimeKG(num_entities=129262, num_relations=30, create_inverse_triples=False)
Training triples: 6479992
Validation triples: 810000
Test triples: 809999

Sample training triples:
'de novo' AMP biosynthetic process -- bioprocess_bioprocess --> AMP biosynthetic process
'de novo' AMP biosynthetic process -- bioprocess_protein --> ADSL
'de novo' CTP biosynthetic process -- bioprocess_bioprocess --> CTP biosynthetic process
'de novo' GDP-L-fucose biosynthetic process -- bioprocess_bioprocess --> GDP-L-fucose biosynthetic process
'de novo' IMP biosynthetic process -- bioprocess_bioprocess --> IMP biosynthetic process


# **Creating a readable DataFrame of Knowledge Graph Triples**

In [6]:
# Materialise the id triples once, then translate each id column to its label.
triples = train_factory.mapped_triples.tolist()
id_triples = pd.DataFrame(triples, columns=["head_id", "relation_id", "tail_id"])

entity_labels = train_factory.entity_id_to_label
relation_labels = train_factory.relation_id_to_label

# Only the human-readable columns are kept in `df`.
df = pd.DataFrame(
    {
        "head": id_triples["head_id"].map(entity_labels),
        "relation": id_triples["relation_id"].map(relation_labels),
        "tail": id_triples["tail_id"].map(entity_labels),
    }
)
print(df.head(10))

                                                head               relation  \
0                 'de novo' AMP biosynthetic process  bioprocess_bioprocess   
1                 'de novo' AMP biosynthetic process     bioprocess_protein   
2                 'de novo' CTP biosynthetic process  bioprocess_bioprocess   
3        'de novo' GDP-L-fucose biosynthetic process  bioprocess_bioprocess   
4                 'de novo' IMP biosynthetic process  bioprocess_bioprocess   
5        'de novo' L-methionine biosynthetic process  bioprocess_bioprocess   
6                 'de novo' NAD biosynthetic process  bioprocess_bioprocess   
7  'de novo' NAD biosynthetic process from aspartate  bioprocess_bioprocess   
8  'de novo' NAD biosynthetic process from trypto...  bioprocess_bioprocess   
9                 'de novo' UMP biosynthetic process  bioprocess_bioprocess   

                                                tail  
0                           AMP biosynthetic process  
1                   

# **Preparing Triples for Embedding Models**

In [7]:
# `mapped_triples` is already a tensor, so wrapping it in torch.tensor(...)
# copy-constructs from a tensor and emits a UserWarning. detach().clone() makes
# the same independent copy without the warning; .to(torch.long) keeps the
# explicit integer dtype the original call requested.
triples_tensor = train_factory.mapped_triples.detach().clone().to(torch.long)
num_entities = dataset.num_entities
num_relations = dataset.num_relations

print("Entities:", num_entities)
print("Relations:", num_relations)
print("Triples:", triples_tensor.shape)

Entities: 129262
Relations: 30
Triples: torch.Size([6479992, 3])


  triples_tensor = torch.tensor(train_factory.mapped_triples, dtype=torch.long)


# **Defining the DistMult model**

In [8]:
# DistMult hyper-parameters gathered in one mapping so they are easy to tune.
distmult_kwargs = {
    "num_nodes": num_entities,
    "num_relations": num_relations,
    "hidden_channels": 128,
    "margin": 1.0,
}
model = DistMult(**distmult_kwargs)

# **Training the DistMult model with Negative Sampling**

In [9]:
def train(model, triples, num_entities, epochs=100, lr=0.01, weight_decay=1e-5):
    """Fit an embedding model on ``triples`` with full-batch negative sampling.

    Every epoch scores all triples in a single batch, builds one corrupted
    triple per positive by replacing either its head or its tail (each chosen
    with probability 0.5) with a uniformly sampled entity id, and minimises
    ``softplus(neg_score + model.margin - pos_score)`` averaged over triples.
    After each optimiser step the rows of ``model.node_emb.weight`` and
    ``model.rel_emb.weight`` are rescaled in place to unit L2 norm.

    Args:
        model: Module callable as ``model(head, relation, tail)`` that exposes
            ``margin``, ``node_emb.weight`` and ``rel_emb.weight``.
        triples: Integer tensor whose columns 0, 1 and 2 hold the head,
            relation and tail ids.
        num_entities: Exclusive upper bound for the sampled replacement ids.
        epochs: Number of full-batch optimisation steps.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.

    Returns:
        None. The model is updated in place; the loss is printed every fifth
        epoch and on the last epoch.

    Raises:
        ValueError: If ``triples`` has no rows (the mean loss over zero
            triples is NaN, so there is nothing meaningful to optimise).
    """
    if triples.shape[0] == 0:
        raise ValueError("triples must contain at least one (head, relation, tail) row")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    model.train()

    # The positive triples never change, so slice the columns once.
    h = triples[:, 0]
    r = triples[:, 1]
    t = triples[:, 2]
    n_triples = h.shape[0]
    # Create the random tensors on the same device as the triples; otherwise
    # indexing h_neg / t_neg with a CPU mask fails when triples live elsewhere.
    device = h.device

    for epoch in range(epochs):
        optimizer.zero_grad()

        pos_score = model(h, r, t)

        # mask == True -> corrupt the head, mask == False -> corrupt the tail.
        mask = torch.rand(n_triples, device=device) < 0.5
        h_neg = h.clone()
        t_neg = t.clone()
        h_neg[mask] = torch.randint(
            0, num_entities, (int(mask.sum().item()),), dtype=torch.long, device=device
        )
        t_neg[~mask] = torch.randint(
            0, num_entities, (int((~mask).sum().item()),), dtype=torch.long, device=device
        )

        neg_score = model(h_neg, r, t_neg)

        loss = F.softplus(neg_score + model.margin - pos_score).mean()

        loss.backward()
        optimizer.step()

        # Project every embedding row back onto the unit sphere; the epsilon
        # guards against division by zero for an all-zero row.
        with torch.no_grad():
            model.node_emb.weight.div_(model.node_emb.weight.norm(dim=1, keepdim=True) + 1e-9)
            model.rel_emb.weight.div_(model.rel_emb.weight.norm(dim=1, keepdim=True) + 1e-9)

        if epoch % 5 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch+1}/{epochs} | Loss: {loss.item():.4f}")

# Train on the first 100,000 training triples only. NOTE: this rebinds
# `triples_tensor` (previously the full training set), and the later cells
# that read `triples_tensor` therefore see this subset.
# `mapped_triples` is already a tensor, so slice it and detach().clone()
# instead of calling torch.tensor(<tensor>), which emits a UserWarning.
triples_tensor = train_factory.mapped_triples[:100000].detach().clone().to(torch.long)
train(model, triples_tensor, num_entities=num_entities, epochs=100, lr=0.01)


  triples_tensor = torch.tensor(train_factory.mapped_triples[:100000], dtype=torch.long)


Epoch 1/100 | Loss: 1.3133
Epoch 6/100 | Loss: 1.3195
Epoch 11/100 | Loss: 1.3144
Epoch 16/100 | Loss: 1.3072
Epoch 21/100 | Loss: 1.3031
Epoch 26/100 | Loss: 1.3009
Epoch 31/100 | Loss: 1.2989
Epoch 36/100 | Loss: 1.2979
Epoch 41/100 | Loss: 1.2977
Epoch 46/100 | Loss: 1.2980
Epoch 51/100 | Loss: 1.2983
Epoch 56/100 | Loss: 1.2984
Epoch 61/100 | Loss: 1.2980
Epoch 66/100 | Loss: 1.2972
Epoch 71/100 | Loss: 1.2960
Epoch 76/100 | Loss: 1.2949
Epoch 81/100 | Loss: 1.2938
Epoch 86/100 | Loss: 1.2930
Epoch 91/100 | Loss: 1.2924
Epoch 96/100 | Loss: 1.2918
Epoch 100/100 | Loss: 1.2916


# **Computing Probabilistic Priors from the trained model**

In [10]:
def compute_priors(model, triples_tensor, factory, sample_size=200):
    """Score the leading triples and squash the scores into probabilities.

    The first ``min(sample_size, len(triples_tensor))`` rows are scored with
    ``model(head, relation, tail)``; the scores are clamped to ``[-10, 10]``
    and passed through a sigmoid.

    Args:
        model: Module callable as ``model(head, relation, tail)``. It is
            switched to eval mode and left in that mode.
        triples_tensor: Integer tensor whose columns 0, 1 and 2 hold the head,
            relation and tail ids.
        factory: Object exposing ``entity_id_to_label`` and
            ``relation_id_to_label`` lookups keyed by integer id.
        sample_size: Maximum number of leading rows to score.

    Returns:
        A list with one dict per scored triple, with keys ``"head"``,
        ``"relation"``, ``"tail"`` (labels) and ``"prior"`` (a Python float).
    """
    model.eval()
    # max(0, ...) keeps a non-positive sample_size meaning "no rows".
    n = max(0, min(sample_size, triples_tensor.size(0)))
    subset = triples_tensor[:n]

    with torch.no_grad():
        scores = model(subset[:, 0], subset[:, 1], subset[:, 2])
        # Clamp before the sigmoid so extreme scores do not saturate to 0 or 1.
        probs = torch.sigmoid(torch.clamp(scores, min=-10.0, max=10.0))

    # Convert each tensor to Python scalars in one call instead of issuing
    # four .item() calls per row inside the loop.
    id_rows = subset.tolist()
    prior_values = probs.tolist()

    id_to_entity = factory.entity_id_to_label
    id_to_relation = factory.relation_id_to_label
    return [
        {
            "head": id_to_entity[head_id],
            "relation": id_to_relation[relation_id],
            "tail": id_to_entity[tail_id],
            "prior": float(prior),
        }
        for (head_id, relation_id, tail_id), prior in zip(id_rows, prior_values)
    ]

# compute_priors caps sample_size at the number of rows in triples_tensor.
subset_priors = compute_priors(model, triples_tensor, train_factory, sample_size=200000)

# Preview the first ten scored triples.
for record in subset_priors[:10]:
    print("{head} --{relation}--> {tail}  [prior={prior:.3f}]".format(**record))


'de novo' AMP biosynthetic process --bioprocess_bioprocess--> AMP biosynthetic process  [prior=0.518]
'de novo' AMP biosynthetic process --bioprocess_protein--> ADSL  [prior=0.504]
'de novo' CTP biosynthetic process --bioprocess_bioprocess--> CTP biosynthetic process  [prior=0.514]
'de novo' GDP-L-fucose biosynthetic process --bioprocess_bioprocess--> GDP-L-fucose biosynthetic process  [prior=0.514]
'de novo' IMP biosynthetic process --bioprocess_bioprocess--> IMP biosynthetic process  [prior=0.515]
'de novo' L-methionine biosynthetic process --bioprocess_bioprocess--> L-methionine biosynthetic process  [prior=0.513]
'de novo' NAD biosynthetic process --bioprocess_bioprocess--> 'de novo' NAD biosynthetic process from aspartate  [prior=0.508]
'de novo' NAD biosynthetic process from aspartate --bioprocess_bioprocess--> 'de novo' NAD biosynthetic process  [prior=0.508]
'de novo' NAD biosynthetic process from tryptophan --bioprocess_bioprocess--> 'de novo' NAD biosynthetic process  [prior=

# **Convert priors into a Tensor for inference**

In [11]:
# Collect the scalar priors into a 1-D float tensor for the Pyro model.
prior_values = [record["prior"] for record in subset_priors]
priors = torch.tensor(prior_values, dtype=torch.float32)
n_triples = priors.shape[0]
print(f"Using {n_triples} triples for probabilistic inference.")

Using 100000 triples for probabilistic inference.


# **Bayesian Posterior Inference using Pyro (Per Triple)**

In [12]:
def beta_bernoulli_model(prior_probs, concentration=10.0):
    """Pyro model: one Beta-distributed probability per triple, each observed as true.

    For every entry of ``prior_probs`` a latent site ``"p"`` is drawn from
    ``Beta(prior * concentration + 1, (1 - prior) * concentration + 1)`` and a
    Bernoulli site ``"obs"`` with that probability is conditioned on the
    value 1.

    Args:
        prior_probs: Tensor of prior probabilities, one per triple.
        concentration: Scale of the pseudo-counts derived from the prior.
    """
    num_sites = len(prior_probs)
    pseudo_true = 1.0 + concentration * prior_probs
    pseudo_false = 1.0 + concentration * (1.0 - prior_probs)
    # Every triple is treated as an observed positive.
    observed = torch.ones(num_sites)

    with pyro.plate("triples", num_sites):
        success_prob = pyro.sample("p", dist.Beta(pseudo_true, pseudo_false))
        pyro.sample("obs", dist.Bernoulli(probs=success_prob), obs=observed)

# **Run MCMC with NUTS**

In [13]:
# Seed the random number generators so the MCMC draws, and therefore the
# posterior means reported below, are reproducible across kernel restarts.
MCMC_SEED = 0
pyro.set_rng_seed(MCMC_SEED)

# NOTE(review): with a Beta prior and a single Bernoulli observation of 1 the
# posterior is available in closed form as Beta(alpha + 1, beta); NUTS is kept
# here as written, but the analytic mean is a useful sanity check.
nuts_kernel = NUTS(beta_bernoulli_model)
mcmc = MCMC(nuts_kernel, num_samples=500, warmup_steps=100)
mcmc.run(priors)

# Samples for site "p" are stacked along dim 0; average over the draws.
posterior_samples = mcmc.get_samples()["p"]
# .cpu() is a no-op for CPU tensors and keeps .numpy() valid otherwise.
posterior_means = posterior_samples.mean(dim=0).detach().cpu().numpy()


Sample: 100%|██████████| 600/600 [10:33,  1.06s/it, step size=9.96e-02, acc. prob=0.798]


# **Save Posterior estimates**

In [14]:
import pandas as pd

# One row per scored triple: labels and prior come straight from the records,
# and the posterior means are attached as an extra column in the same order.
df_posteriors = pd.DataFrame(
    subset_priors, columns=["head", "relation", "tail", "prior"]
).assign(posterior_mean=posterior_means)

print(df_posteriors.head())


                                          head               relation  \
0           'de novo' AMP biosynthetic process  bioprocess_bioprocess   
1           'de novo' AMP biosynthetic process     bioprocess_protein   
2           'de novo' CTP biosynthetic process  bioprocess_bioprocess   
3  'de novo' GDP-L-fucose biosynthetic process  bioprocess_bioprocess   
4           'de novo' IMP biosynthetic process  bioprocess_bioprocess   

                                tail     prior  posterior_mean  
0           AMP biosynthetic process  0.518474        0.551360  
1                               ADSL  0.503805        0.535268  
2           CTP biosynthetic process  0.514375        0.547528  
3  GDP-L-fucose biosynthetic process  0.514290        0.551204  
4           IMP biosynthetic process  0.514777        0.550747  


# **Connect to Neo4j**

In [15]:
# Connection settings come from the environment so that no secret is stored in
# the notebook. The URI and username fall back to the values used previously;
# the password has no fallback and is prompted for when NEO4J_PASSWORD is unset.
# SECURITY: a password used to be hard-coded in this cell. Treat it as leaked
# and rotate it in the Neo4j console.
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j+s://36039b8d.databases.neo4j.io")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD") or getpass.getpass("Neo4j password: ")

driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))


# **Upload the triples**

In [16]:
def upload_triples_with_confidence(driver, records, batch_size=1000, mode="prior"):
    """Merge triples into Neo4j in batches and store their confidence values.

    For every record the statement merges two ``Entity`` nodes keyed by
    ``name`` and a ``RELATION`` relationship keyed by ``type``, then sets the
    confidence properties selected by ``mode`` on that relationship.

    Args:
        driver: Object whose ``session()`` returns a context manager with a
            ``run(query, **params)`` method.
        records: Sliceable sequence of mappings with keys ``head``,
            ``relation``, ``tail`` and, depending on ``mode``, ``prior``
            and/or ``posterior``.
        batch_size: Number of records sent per statement; must be positive.
        mode: ``"prior"``, ``"posterior"`` or ``"both"`` - which properties to
            set on the relationship.

    Raises:
        AssertionError: If ``mode`` is not one of the supported values.
        ValueError: If ``batch_size`` is not positive.
    """
    assert mode in {"prior", "posterior", "both"}, "mode must be 'prior', 'posterior', or 'both'"
    # A negative step would make range() yield nothing and silently upload no
    # records; zero already raised ValueError inside range().
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    set_clause = {
        "prior": " SET r.prior = rec.prior",
        "posterior": " SET r.posterior = rec.posterior",
        "both": " SET r.prior = rec.prior, r.posterior = rec.posterior",
    }[mode]

    # The statement does not depend on the batch, so build it once.
    cypher = """
                UNWIND $batch AS rec
                MERGE (h:Entity {name: rec.head})
                MERGE (t:Entity {name: rec.tail})
                MERGE (h)-[r:RELATION {type: rec.relation}]->(t)
            """ + set_clause

    with driver.session() as session:
        for start in range(0, len(records), batch_size):
            session.run(cypher, batch=records[start: start + batch_size])


# **Build an in-memory graph**

In [17]:
from collections import defaultdict

def build_graph_with_posteriors(df):
    """Build an undirected adjacency map from a table of scored triples.

    Each row contributes two entries, one under its head and one under its
    tail, so that paths can be followed in either direction.

    Args:
        df: DataFrame with columns ``head``, ``tail``, ``relation`` and
            ``posterior_mean``.

    Returns:
        A ``defaultdict(list)`` mapping a node label to a list of
        ``(neighbor, relation, posterior_mean)`` tuples in row order.
    """
    graph = defaultdict(list)
    # Iterate the four columns in lock-step instead of using iterrows(), which
    # builds a Series object for every row.
    columns = (df["head"], df["tail"], df["relation"], df["posterior_mean"])
    for head, tail, relation, posterior in zip(*(column.tolist() for column in columns)):
        graph[head].append((tail, relation, posterior))
        graph[tail].append((head, relation, posterior))
    return graph

# Build the adjacency map once from the posterior table; the query cells
# further down read this module-level `graph`.
graph = build_graph_with_posteriors(df_posteriors)


# **Define path-finding with Probabilities**

In [18]:
def find_paths(graph, start, end, max_hops=4):
    """Enumerate simple paths from ``start`` to ``end`` using at most ``max_hops`` edges.

    The graph is explored breadth-first. A path never revisits a node, and its
    probability is the product of the edge probabilities along it.

    Args:
        graph: Mapping from a node to an iterable of
            ``(neighbor, relation, probability)`` tuples; it must support
            ``.get(node, default)``.
        start: Node the paths start from.
        end: Node the paths must reach.
        max_hops: Maximum number of edges allowed in a returned path.

    Returns:
        A list of ``(nodes, relations, probability)`` tuples in breadth-first
        discovery order. No path is returned when ``start == end`` because
        nodes are not revisited.
    """
    queue = deque([(start, [start], [], 1.0)])
    paths = []

    while queue:
        node, path, rels, prob = queue.popleft()

        # A partial path that already uses max_hops edges must not be extended.
        # The previous check (`len(path) - 1 > max_hops`) still expanded such
        # paths and therefore returned results with max_hops + 1 edges.
        if len(rels) >= max_hops:
            continue

        for neighbor, relation, p in graph.get(node, []):
            if neighbor in path:
                continue  # keep paths simple (no repeated nodes)
            new_path = path + [neighbor]
            new_rels = rels + [relation]
            new_prob = prob * p
            if neighbor == end:
                paths.append((new_path, new_rels, new_prob))
            else:
                queue.append((neighbor, new_path, new_rels, new_prob))
    return paths

# **Aggregate Path Probabilities**

In [19]:
def aggregate_path_probs(paths):
    """Combine per-path probabilities with a noisy-OR.

    Computes ``1 - prod(1 - p_i)`` over the probability stored as the third
    element of every path tuple, i.e. the chance that at least one path holds
    when the paths are treated as independent.

    Args:
        paths: Sequence of ``(nodes, relations, probability)`` tuples.

    Returns:
        ``0.0`` when ``paths`` is empty, otherwise the combined probability.
    """
    if paths:
        miss_probs = [1.0 - path_prob for _nodes, _rels, path_prob in paths]
        return 1.0 - np.prod(miss_probs)
    return 0.0


# **Multihop Inference Query function**

In [28]:
def query_multihop(graph, start, end, max_hops=5, verbose=True):
    """Estimate how strongly ``start`` is connected to ``end`` over multi-hop paths.

    Paths are collected with ``find_paths`` and their probabilities combined
    with ``aggregate_path_probs``.

    Args:
        graph: Adjacency mapping passed through to ``find_paths``.
        start: Source node.
        end: Target node.
        max_hops: Hop limit passed through to ``find_paths``.
        verbose: When true, print a summary line followed by up to five of the
            discovered paths.

    Returns:
        The aggregated probability.
    """
    found = find_paths(graph, start, end, max_hops)
    combined = aggregate_path_probs(found)

    if not verbose:
        return combined

    print(f"P({start} → {end}) ≈ {combined:.3f} via {len(found)} path(s)")
    for nodes, edge_labels, path_prob in found[:5]:
        hop_count = len(edge_labels)
        segments = [f"{nodes[i]} -[{edge_labels[i]}]" for i in range(hop_count)]
        steps = " → ".join(segments) + f" → {nodes[-1]}"
        plural = "s" if hop_count > 1 else ""
        print(f"  Path ({hop_count} hop{plural}): {steps}  (p ≈ {path_prob:.3f})")

    return combined


# **Run multihop inference**

In [35]:
# Every query starts from the same source node; only the target changes.
QUERY_SOURCE = "'de novo' AMP biosynthetic process"
QUERY_TARGETS = [
    "ADSL",
    "ADSS2",
    "AMP biosynthetic process",
    "AMP metabolic process",
    "AMP catabolic process",
    "ATP biosynthetic process",
    "ATP metabolic process",
]

*leading_targets, final_target = QUERY_TARGETS
for target in leading_targets:
    query_multihop(graph, QUERY_SOURCE, target, max_hops=10)
    print("\n")
# Kept as the cell's final bare expression so its return value is displayed.
query_multihop(graph, QUERY_SOURCE, final_target, max_hops=10)


P('de novo' AMP biosynthetic process → ADSL) ≈ 0.535 via 1 path(s)
  Path (1 hop): 'de novo' AMP biosynthetic process -[bioprocess_protein] → ADSL  (p ≈ 0.535)


P('de novo' AMP biosynthetic process → ADSS2) ≈ 0.542 via 1 path(s)
  Path (1 hop): 'de novo' AMP biosynthetic process -[bioprocess_protein] → ADSS2  (p ≈ 0.542)


P('de novo' AMP biosynthetic process → AMP biosynthetic process) ≈ 0.551 via 1 path(s)
  Path (1 hop): 'de novo' AMP biosynthetic process -[bioprocess_bioprocess] → AMP biosynthetic process  (p ≈ 0.551)


P('de novo' AMP biosynthetic process → AMP metabolic process) ≈ 0.295 via 1 path(s)
  Path (2 hops): 'de novo' AMP biosynthetic process -[bioprocess_bioprocess] → AMP biosynthetic process -[bioprocess_bioprocess] → AMP metabolic process  (p ≈ 0.295)


P('de novo' AMP biosynthetic process → AMP catabolic process) ≈ 0.160 via 1 path(s)
  Path (3 hops): 'de novo' AMP biosynthetic process -[bioprocess_bioprocess] → AMP biosynthetic process -[bioprocess_bioprocess] → AM

np.float64(0.047802575437012496)