In [30]:

node_embeddings_list = []
edge_embeddings_list = []


In [31]:
from rdflib import Graph, Literal, RDF, RDFS, URIRef, Namespace
import random
from pathlib import Path
import os

# Define the namespace again
ex = Namespace("http://example.org/")

# Create a new RDF graph
g = Graph()
file_path = Path("data/rdf_graph.nt")

# if file_path.exists():
#     # Read the graph from the existing file if it exists
#     g.parse(file_path, format="nt")
#     print(f"Graph loaded from {file_path}.")
# else:
    # Define classes
g.add((ex.A, RDF.type, RDFS.Class))
g.add((ex.B, RDF.type, RDFS.Class))
g.add((ex.C, RDF.type, RDFS.Class))

# Define possible colors
colors = ["red", "blue", "green", "yellow", "purple", "orange"]

# Helper function to generate nodes with instance number, name, and color
def create_node(g, class_type, node_id, color):
    node_uri = URIRef(f"{class_type}{node_id}")
    g.add((node_uri, RDF.type, class_type))
    g.add((node_uri, ex.instance_number, Literal(node_id)))
    g.add((node_uri, ex.name, Literal(f"{class_type}{node_id}")))
    g.add((node_uri, ex.color, Literal(color)))
    return node_uri

# Create 5 nodes for class A, B, and C with instance number, name, and color
for i in range(1, 6):
    color_A = colors[i % len(colors)]  # Assign color based on modulo to reuse colors
    color_B = colors[(i + 1) % len(colors)]
    color_C = colors[(i + 2) % len(colors)]
    
    node_B = create_node(g, ex.A, i, color_A)  # Now A is being used for node B
    node_A = create_node(g, ex.B, i, color_B)  # Now B is being used for node A
    node_C = create_node(g, ex.C, i, color_C)  # C remains the same

    # First 3 nodes of A are connected to C nodes (positive classification)
    if i <= 3:
        g.add((node_A, ex.connected_to, node_C))
        g.add((node_B, ex.connected_to, node_C))
    else:
        # For nodes A4 and A5, no connections to C nodes (negative classification)
        # Ensure no connection to C nodes (negative classification)
        pass

# Add extra relations to increase complexity
for i in range(1, 6):
    node_A = URIRef(f"http://example.org/A{i}")
    node_B = URIRef(f"http://example.org/B{i}")
    g.add((node_A, ex.related_to, node_B))


def classify_node_A(graph, node_A):
    """
    Determines if a node of type A is connected to any node of type C.
    If A has an outgoing edge to any C node, it's classified as 'positive',
    otherwise, it's 'negative'.
    """
    connected_to_C = False
    
    # Iterate over all triples where node_A is the subject and 'connected_to' is the predicate
    for _, _, obj in graph.triples((node_A, ex.connected_to, None)):  # Ensure A -> C
        if obj.startswith(f"{ex}C"):  # Ensure the target is a C node
            connected_to_C = True
            break  # No need to check further if a connection exists
    
    return "positive" if connected_to_C else "negative"

        
# Add classification triples for node A
for i in range(1, 6):
    node_A = URIRef(f"http://example.org/A{i}")
    classification = classify_node_A(g, node_A)
    g.add((node_A, ex.classification, Literal(classification)))

# Save the graph to a file in NT format (triples)
graph_file_path = "./data/rdf_graph.nt"

# Create the directory if it doesn't exist
os.makedirs(os.path.dirname(graph_file_path), exist_ok=True)

# Print all edges (triples) where predicate is 'connected_to' or 'related_to'
for subj, pred, obj in g:
    if str(pred) in ["http://example.org/connected_to", "http://example.org/related_to"]:
        print(f"Subject: {subj}, Predicate: {pred}, Object: {obj}")


# Serialize and write the graph to the file properly
g.serialize(destination=graph_file_path, format="nt")
print(f"Graph serialized and saved to {graph_file_path}.")



        
    # # Print the graph in NT format (triples)
    # print(g.serialize(format="nt"))


# Print all edges (triples) in the graph




Subject: http://example.org/A2, Predicate: http://example.org/connected_to, Object: http://example.org/C2
Subject: http://example.org/A1, Predicate: http://example.org/connected_to, Object: http://example.org/C1
Subject: http://example.org/A3, Predicate: http://example.org/connected_to, Object: http://example.org/C3
Subject: http://example.org/A1, Predicate: http://example.org/related_to, Object: http://example.org/B1
Subject: http://example.org/A4, Predicate: http://example.org/related_to, Object: http://example.org/B4
Subject: http://example.org/B1, Predicate: http://example.org/connected_to, Object: http://example.org/C1
Subject: http://example.org/B3, Predicate: http://example.org/connected_to, Object: http://example.org/C3
Subject: http://example.org/A5, Predicate: http://example.org/related_to, Object: http://example.org/B5
Subject: http://example.org/A3, Predicate: http://example.org/related_to, Object: http://example.org/B3
Subject: http://example.org/B2, Predicate: http://exam



In [32]:
import torch
from torch_geometric.data import HeteroData
from rdflib import Graph, Namespace, RDF, URIRef, Literal
from rdflib.namespace import RDFS
import random

# Define the RDF namespace
ex = Namespace("http://example.org/")

# Initialize the HeteroData object
hetero_data = HeteroData()

# Create dictionaries to map node URIs to indices
node_mapping = {
    'A': {},
    'B': {},
    'C': {}
}

# Function to safely extract literal values
def get_literal_value(graph, subject, predicate, default_value=None):
    value = graph.value(subject, predicate)
    return value if value is not None else default_value

# Function to extract node features from the RDF graph
def extract_node_features(graph, class_type, node_type):
    node_features = []
    index = 0
    for s, p, o in graph.triples((None, RDF.type, class_type)):
        instance_number = get_literal_value(graph, s, ex.instance_number, 0)
        name = str(get_literal_value(graph, s, ex.name, "Unknown"))
        color = str(get_literal_value(graph, s, ex.color, "black"))

        node_id = s.split("/")[-1]
        node_mapping[node_type][node_id] = index

        instance_tensor = torch.tensor([float(instance_number)], dtype=torch.float32)
        color_tensor = torch.tensor([hash(color) % 1000], dtype=torch.float32)
        name_tensor = torch.tensor([hash(name) % 1000], dtype=torch.float32)

        node_features.append(torch.cat([instance_tensor, color_tensor, name_tensor], dim=0))
        index += 1

    if node_features:
        hetero_data[node_type].x = torch.stack(node_features)

# Extract node features for classes A, B, and C
extract_node_features(g, ex.A, 'A')
extract_node_features(g, ex.B, 'B')
extract_node_features(g, ex.C, 'C')

# Edge extraction function ensuring correct mapping
def extract_edges(graph, source_class, target_class, relation):
    edge_list = []
    for s, p, o in graph.triples((None, relation, None)):
        source_id = s.split("/")[-1]
        target_id = o.split("/")[-1]
        
        if source_id in node_mapping[source_class] and target_id in node_mapping[target_class]:
            source_idx = node_mapping[source_class][source_id]
            target_idx = node_mapping[target_class][target_id]
            edge_list.append((source_idx, target_idx))
    
    return edge_list

# Extract 'connected_to' edges (A -> C, B -> C)
connected_to_edges_a_c = extract_edges(g, "A", "C", ex.connected_to)
connected_to_edges_b_c = extract_edges(g, "B", "C", ex.connected_to)

# Extract 'related_to' edges (A -> B)
related_to_edges_a_b = extract_edges(g, "A", "B", ex.related_to)

# Add edge indices and edge types to HeteroData
if connected_to_edges_a_c:
    hetero_data['A', 'connected_to', 'C'].edge_index = torch.tensor(connected_to_edges_a_c, dtype=torch.long).t().contiguous()
    hetero_data['A', 'connected_to', 'C'].edge_type = torch.zeros(len(connected_to_edges_a_c), dtype=torch.long)
    print(f"Added {len(connected_to_edges_a_c)} edge indices for 'A -> C'.")

if connected_to_edges_b_c:
    hetero_data['B', 'connected_to', 'C'].edge_index = torch.tensor(connected_to_edges_b_c, dtype=torch.long).t().contiguous()
    hetero_data['B', 'connected_to', 'C'].edge_type = torch.ones(len(connected_to_edges_b_c), dtype=torch.long)
    print(f"Added {len(connected_to_edges_b_c)} edge indices for 'B -> C'.")

if related_to_edges_a_b:
    hetero_data['A', 'related_to', 'B'].edge_index = torch.tensor(related_to_edges_a_b, dtype=torch.long).t().contiguous()
    hetero_data['A', 'related_to', 'B'].edge_type = torch.full((len(related_to_edges_a_b),), 2, dtype=torch.long)
    print(f"Added {len(related_to_edges_a_b)} edge indices for 'A -> B'.")

# Create labels for node type A based on RDF graph classification
labels_A = []
for s, p, o in g.triples((None, ex.classification, None)):
    if str(s).startswith("http://example.org/A"):
        label = 1 if str(o) == "positive" else 0
        labels_A.append(label)

if labels_A:
    hetero_data['A'].y = torch.tensor(labels_A, dtype=torch.long)

# Create train and test masks for node type 'A'
num_nodes_A = len(hetero_data['A'].y)
indices_A = list(range(num_nodes_A))
random.shuffle(indices_A)

train_size_A = int(0.8 * num_nodes_A)
train_indices_A = indices_A[:train_size_A]
test_indices_A = indices_A[train_size_A:]

train_mask_A = torch.zeros(num_nodes_A, dtype=torch.bool)
train_mask_A[train_indices_A] = True

test_mask_A = torch.zeros(num_nodes_A, dtype=torch.bool)
test_mask_A[test_indices_A] = True

hetero_data['A'].train_mask = train_mask_A
hetero_data['A'].test_mask = test_mask_A

# Ensure edges exist before extracting them
if ('A', 'connected_to', 'C') in hetero_data and ('B', 'connected_to', 'C') in hetero_data:
    edge_index_A_to_C = hetero_data[('A', 'connected_to', 'C')].edge_index
    edge_index_B_to_C = hetero_data[('B', 'connected_to', 'C')].edge_index

    edge_index = torch.cat([edge_index_A_to_C, edge_index_B_to_C], dim=1)

    edge_type_A_to_C = hetero_data[('A', 'connected_to', 'C')].edge_type
    edge_type_B_to_C = hetero_data[('B', 'connected_to', 'C')].edge_type

    edge_type = torch.cat([edge_type_A_to_C, edge_type_B_to_C], dim=0)
    print(f"Final combined edge index shape: {edge_index.shape}")

print(hetero_data)


Added 3 edge indices for 'A -> C'.
Added 3 edge indices for 'B -> C'.
Added 5 edge indices for 'A -> B'.
HeteroData(
  A={
    x=[5, 3],
    y=[5],
    train_mask=[5],
    test_mask=[5],
  },
  B={ x=[5, 3] },
  C={ x=[5, 3] },
  (A, connected_to, C)={
    edge_index=[2, 3],
    edge_type=[3],
  },
  (B, connected_to, C)={
    edge_index=[2, 3],
    edge_type=[3],
  },
  (A, related_to, B)={
    edge_index=[2, 5],
    edge_type=[5],
  }
)


In [33]:
# Assuming 'data' is your HeteroData object
# Accessing node features
x_A = hetero_data['A'].x  # Node features for type A
x_B = hetero_data['B'].x  # Node features for type B
x_C = hetero_data['C'].x  # Node features for type C
y_A = hetero_data['A'].y  # Labels for type C nodes

# Accessing edge indices for different relationships
edge_index_A_to_C = hetero_data[('A', 'connected_to', 'C')].edge_index  # Edge index from C to A
edge_index_B_to_C = hetero_data[('B', 'connected_to', 'C')].edge_index  # Edge index from C to B
edge_index_A_to_B = hetero_data[('A', 'related_to', 'B')].edge_index  # Edge index from A to B

# Example: Printing the extracted data
print("Node features for A:", x_A.shape)
print("Node features for B:", x_B.shape)
print("Node features for C:", x_C.shape)
print("Labels for A:", y_A.shape)

print("Edge index from A to C:", edge_index_A_to_C)
print("Edge index from B to C:", edge_index_B_to_C)
print("Edge index from A to B:", edge_index_A_to_B)

Node features for A: torch.Size([5, 3])
Node features for B: torch.Size([5, 3])
Node features for C: torch.Size([5, 3])
Labels for A: torch.Size([5])
Edge index from A to C: tensor([[0, 1, 2],
        [0, 1, 2]])
Edge index from B to C: tensor([[0, 1, 2],
        [0, 1, 2]])
Edge index from A to B: tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]])


In [62]:

# import torch
# import torch.nn.functional as F
# from torch_geometric.nn import FastRGCNConv

# # Define the FastRGCN-based model
# class FastRGCNGNN(torch.nn.Module):
#     def __init__(self, num_relations):
#         super().__init__()
#         self.conv1 = FastRGCNConv(in_channels=3, out_channels=32, num_relations=num_relations)
#         self.conv2 = FastRGCNConv(in_channels=32, out_channels=64, num_relations=num_relations)
#         self.lin = torch.nn.Linear(64, 2)  # Binary classification (positive or negative)

#     def forward(self, x, edge_index, edge_type=None):
#         x = F.relu(self.conv1(x, edge_index, edge_type))  # Pass node features x to the first layer
#         x = self.conv2(x, edge_index, edge_type)
#         return self.lin(x)  # Apply the final linear layer

# # Initialize the model
# num_relations = 2  # Adjust according to your relations
# model = FastRGCNGNN(num_relations)

# # Define the optimizer and loss function
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# criterion = torch.nn.CrossEntropyLoss()

# def train(hetero_data):
#     model.train()  # Set the model to training mode
#     optimizer.zero_grad()  # Clear gradients

#     # Extracting node features and edge indices for node type A
#     x_A = hetero_data['A'].x  # Node features for A

#     # Ensure correct shape for x_A
#     if len(x_A.shape) != 2 or x_A.shape[1] != 3:  # Should be [num_nodes, num_features]
#         print(f"Unexpected shape for x_A: {x_A.shape}")

#     # Forward pass for node type A
#     out = model(x_A, edge_index, edge_type)  # Pass edge_type to the model

#     # Compute loss using masks for node type A
#     loss = criterion(out[hetero_data['A'].train_mask], hetero_data['A'].y[hetero_data['A'].train_mask])

#     # Backward pass: Compute gradients and update weights
#     loss.backward()
#     optimizer.step()

#     # Calculate accuracy for node type A
#     preds = out.argmax(dim=1)  # Get predicted class labels
#     correct = (preds[hetero_data['A'].train_mask] == hetero_data['A'].y[hetero_data['A'].train_mask]).sum().item()  # Count correct predictions
#     total_samples = hetero_data['A'].train_mask.sum().item()  # Count total samples in training

#     accuracy = correct / total_samples if total_samples > 0 else 0  # Prevent division by zero

#     return loss.item(), accuracy  # Return loss and accuracy


# # Train the model for a number of epochs
# for epoch in range(1, 20):  # Train for 20 epochs
#     loss, accuracy = train(hetero_data)
#     print(f'Epoch {epoch}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}')
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.nn import FastRGCNConv

# Define the FastRGCN-based model
class FastRGCNGNN(torch.nn.Module):
    def __init__(self, num_relations, in_channels=3, hidden_dim=32, out_dim=64):
        super().__init__()
        self.conv1 = FastRGCNConv(in_channels, hidden_dim, num_relations=num_relations)
        self.conv2 = FastRGCNConv(hidden_dim, out_dim, num_relations=num_relations)
        self.lin = torch.nn.Linear(out_dim, 2)  # Binary classification (positive or negative)

    def forward(self, x, edge_index, edge_type=None):
        x = F.relu(self.conv1(x, edge_index, edge_type))
        x = self.conv2(x, edge_index, edge_type)
        return self.lin(x)

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.nn import FastRGCNConv

# Define the FastRGCN-based model
class FastRGCNGNN(torch.nn.Module):
    def __init__(self, num_relations, in_channels=3, hidden_dim=32, out_dim=64):
        super().__init__()
        self.conv1 = FastRGCNConv(in_channels, hidden_dim, num_relations=num_relations)
        self.conv2 = FastRGCNConv(hidden_dim, out_dim, num_relations=num_relations)
        self.lin = torch.nn.Linear(out_dim, 2)  # Binary classification (positive or negative)

    def forward(self, x, edge_index, edge_type=None):
        x = F.relu(self.conv1(x, edge_index, edge_type))
        x = self.conv2(x, edge_index, edge_type)
        return self.lin(x)

# Define the GNN Trainer Class
class GNNTrainer:
    def __init__(self, hetero_data, node_type, in_channels=3, hidden_dim=32, out_dim=64):
        self.hetero_data = hetero_data
        self.node_type = node_type

        # ✅ Define model path
        self.model_dir = "./saved_models"
        self.model_path = f"{self.model_dir}/{node_type}_gnn.pth"

        # ✅ Ensure directory exists
        os.makedirs(self.model_dir, exist_ok=True)

        self.model = FastRGCNGNN(num_relations=len(hetero_data.edge_types), in_channels=in_channels, hidden_dim=hidden_dim, out_dim=out_dim)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.01)
        self.criterion = torch.nn.CrossEntropyLoss()

    def save_model(self):
        """Saves the trained model to a file."""
        torch.save(self.model.state_dict(), self.model_path)
        print(f"GNN model saved to {self.model_path}")

    def load_model(self):
        """Loads the trained model if it exists."""
        if os.path.exists(self.model_path):
            self.model.load_state_dict(torch.load(self.model_path, weights_only=True))
            print(f" Loaded GNN model from {self.model_path}")

    def train(self, epochs=20):
        """Train the GNN model and save it."""
        if os.path.exists(self.model_path):
            self.load_model()
            return self.model

        for epoch in range(epochs):
            self.model.train()
            self.optimizer.zero_grad()
            x = self.hetero_data[self.node_type].x
            edge_index = self.hetero_data[('A', 'connected_to', 'C')].edge_index
            edge_type = self.hetero_data[('A', 'connected_to', 'C')].edge_type
            out = self.model(x, edge_index, edge_type)
            loss = self.criterion(out[self.hetero_data[self.node_type].train_mask], self.hetero_data[self.node_type].y[self.hetero_data[self.node_type].train_mask])
            loss.backward()
            self.optimizer.step()
            preds = out.argmax(dim=1)
            accuracy = (preds[self.hetero_data[self.node_type].train_mask] == self.hetero_data[self.node_type].y[self.hetero_data[self.node_type].train_mask]).float().mean().item()
            print(f'Epoch {epoch+1}/{epochs}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}')

        return self.model

    
    def get_positive_nodes(self):
        """Returns the indices of nodes classified as positive by the trained GNN model."""
        self.model.eval()  # Ensure the model is in evaluation mode

        with torch.no_grad():
            # Get features & edges
            x = self.hetero_data[self.node_type].x
            edge_index = self.hetero_data[('A', 'connected_to', 'C')].edge_index
            edge_type = self.hetero_data[('A', 'connected_to', 'C')].edge_type
            
            # Forward pass through the model
            out = self.model(x, edge_index, edge_type)

            # Apply softmax to get probabilities
            probs = torch.softmax(out, dim=1)

            # Classify nodes: 1 = Positive, 0 = Negative
            preds = torch.argmax(probs, dim=1)

            # Filter only positively classified nodes
            pos_nodes = torch.where(preds == 1)[0].tolist()

            # print(f"Predicted Class Labels: {preds.tolist()}")  # Debugging print
            # print(f"Positive Nodes: {pos_nodes}")  # Debugging print

            return pos_nodes

    def get_negative_nodes(self):
        """Returns the indices of nodes classified as negative by the trained GNN model."""
        self.model.eval()  # Ensure the model is in evaluation mode

        with torch.no_grad():
            # Get features & edges
            x = self.hetero_data[self.node_type].x
            edge_index = self.hetero_data[('A', 'connected_to', 'C')].edge_index
            edge_type = self.hetero_data[('A', 'connected_to', 'C')].edge_type
            
            # Forward pass through the model
            out = self.model(x, edge_index, edge_type)

            # Apply softmax to get probabilities
            probs = torch.softmax(out, dim=1)

            # Classify nodes: 1 = Positive, 0 = Negative
            preds = torch.argmax(probs, dim=1)

            # Filter only negatively classified nodes
            neg_nodes = torch.where(preds == 0)[0].tolist()

            return neg_nodes
    
gnn = GNNTrainer(hetero_data,'A')
gnn.train()
gnn.save_model()
print(gnn.get_positive_nodes())
print(gnn.get_negative_nodes())

 Loaded GNN model from ./saved_models/A_gnn.pth
GNN model saved to ./saved_models/A_gnn.pth
[0, 1, 2]
[3, 4]


In [35]:
# import os
# import pandas as pd
# import torch
# import torch.nn as nn
# import torch.optim as optim
# import numpy as np

# # Define data directory
# data_dir = "./data"
# os.makedirs(data_dir, exist_ok=True)  # Ensure the directory exists

# unique_nodes = set()
# node_mapping = {}  # Dictionary to map (node type, local index) -> global index
# global_index = 0   # Start from global index 0

# for node_type in hetero_data.node_types:
#     num_nodes = hetero_data[node_type].x.shape[0] if 'x' in hetero_data[node_type] else 0
#     for local_idx in range(num_nodes):
#         node_mapping[(node_type, local_idx)] = global_index  # Assign unique index
#         unique_nodes.add(global_index)
#         global_index += 1  # Increment global index

# unique_nodes = sorted(list(unique_nodes))  # Ensure sorted list for indexing

# # Create a node index mapping
# node_to_index = {node: i for i, node in enumerate(unique_nodes)}

# # Extract unique edges from HeteroData before using them
# unique_edges = []
# for edge_type in hetero_data.edge_types:
#     edge_index = hetero_data[edge_type].edge_index
#     for i in range(edge_index.shape[1]):  # Iterate over edges
#         src = edge_index[0, i].item()  # Source node
#         tgt = edge_index[1, i].item()  # Target node
#         unique_edges.append((src, edge_type, tgt))  # Store edge as tuple (src, relation, tgt)


# # Create mappings for edges and relations
# relation_to_index = {rel: i for i, rel in enumerate(hetero_data.edge_types)}

# # Convert edges to tensors
# train_edges = torch.tensor([(node_to_index[h], relation_to_index[r], node_to_index[t]) for (h, r, t) in unique_edges])

# # Define the TransE model again
# class TransE(nn.Module):
#     def __init__(self, num_entities, num_relations, embedding_dim):
#         super(TransE, self).__init__()
#         self.entity_embeddings = nn.Embedding(num_entities, embedding_dim)
#         self.relation_embeddings = nn.Embedding(num_relations, embedding_dim)
#         nn.init.xavier_uniform_(self.entity_embeddings.weight)
#         nn.init.xavier_uniform_(self.relation_embeddings.weight)

#     def forward(self, head, relation):
#         head_emb = self.entity_embeddings(head)
#         relation_emb = self.relation_embeddings(relation)
#         return head_emb + relation_emb

#     def score(self, head, relation, tail):
#         """ Compute TransE score (L2 norm of (h + r - t)) """
#         h_r = self.forward(head, relation)
#         t_emb = self.entity_embeddings(tail)
#         return -torch.norm(h_r - t_emb, p=2, dim=1)  # Negative distance (higher = better)

# # Initialize model
# num_entities = len(unique_nodes)
# num_relations = len(hetero_data.edge_types)
# embedding_dim = 128

# transe_model = TransE(num_entities, num_relations, embedding_dim)
# optimizer = optim.Adam(transe_model.parameters(), lr=0.01)

# # Define training loop
# num_epochs = 100
# batch_size = 32
# loss_fn = nn.MarginRankingLoss(margin=1.0)

# # Convert edges to tensors
# train_edges = torch.tensor([(node_to_index[h], relation_to_index[r], node_to_index[t]) for (h, r, t) in unique_edges])

# # Training Loop
# for epoch in range(num_epochs):
#     transe_model.train()
#     optimizer.zero_grad()

#     # Sample positive triplets
#     idx = torch.randint(0, train_edges.shape[0], (batch_size,))
#     pos_triplets = train_edges[idx]

#     # Generate negative samples (corrupt the tail entity)
#     neg_triplets = pos_triplets.clone()
#     neg_triplets[:, 2] = torch.randint(0, num_entities, (batch_size,))

#     # Compute scores
#     pos_scores = transe_model.score(pos_triplets[:, 0], pos_triplets[:, 1], pos_triplets[:, 2])
#     neg_scores = transe_model.score(neg_triplets[:, 0], neg_triplets[:, 1], neg_triplets[:, 2])

#     # Compute loss (Ranking loss)
#     loss = loss_fn(pos_scores, neg_scores, torch.ones_like(pos_scores))
#     loss.backward()
#     optimizer.step()

#     # Print loss every 10 epochs
#     if epoch % 10 == 0:
#         print(f"Epoch {epoch}: Loss = {loss.item():.4f}")
    

# # Evaluation: Compute MRR and Hits@k
# transe_model.eval()

# ranks = []
# for (head, relation, tail) in unique_edges:
#     head_idx = torch.tensor([node_to_index[head]])
#     relation_idx = torch.tensor([relation_to_index[relation]])
#     tail_idx = torch.tensor([node_to_index[tail]])

#     # Compute the score for the correct entity
#     correct_score = transe_model.score(head_idx, relation_idx, tail_idx).item()

#     # Compute scores for all possible tail entities
#     all_tail_indices = torch.arange(num_entities)
#     all_scores = transe_model.score(head_idx.repeat(num_entities), relation_idx.repeat(num_entities), all_tail_indices)

#     # Rank the correct tail
#     sorted_scores, sorted_indices = torch.sort(all_scores, descending=True)
#     rank = (sorted_indices == tail_idx).nonzero(as_tuple=True)[0].item() + 1  # Convert 0-based to 1-based rank

#     ranks.append(rank)

# # Compute evaluation metrics
# MR = np.mean(ranks)
# MRR = np.mean([1.0 / r for r in ranks])
# Hits_1 = np.mean([1 if r <= 1 else 0 for r in ranks])
# Hits_3 = np.mean([1 if r <= 3 else 0 for r in ranks])
# Hits_10 = np.mean([1 if r <= 10 else 0 for r in ranks])

# # Save evaluation metrics
# metrics_path = os.path.join(data_dir, "transe_metrics.txt")
# with open(metrics_path, "w") as f:
#     f.write(f"Mean Rank (MR): {MR:.2f}\n")
#     f.write(f"Mean Reciprocal Rank (MRR): {MRR:.4f}\n")
#     f.write(f"Hits@1: {Hits_1:.4f}\n")
#     f.write(f"Hits@3: {Hits_3:.4f}\n")
#     f.write(f"Hits@10: {Hits_10:.4f}\n")

# # Compute embeddings for all nodes
# node_indices = torch.tensor([node_to_index[n] for n in unique_nodes])
# node_embeddings = transe_model.entity_embeddings(node_indices).detach().numpy()

# # Compute embeddings for all relations
# relation_indices = torch.tensor([relation_to_index[r] for r in hetero_data.edge_types])
# relation_embeddings = transe_model.relation_embeddings(relation_indices).detach().numpy()

# # Compute embeddings for all edges
# edge_embeddings = []
# for (head, relation, tail) in unique_edges:
#     head_idx = torch.tensor(node_to_index[head])
#     relation_idx = torch.tensor(relation_to_index[relation])
#     tail_idx = torch.tensor(node_to_index[tail])

#     head_emb = transe_model.entity_embeddings(head_idx)
#     relation_emb = transe_model.relation_embeddings(relation_idx)
#     tail_emb = transe_model.entity_embeddings(tail_idx)

#     # Compute edge embedding using TransE formulation
#     edge_embedding = (head_emb + relation_emb - tail_emb).detach().numpy()
#     edge_embeddings.append(edge_embedding)

# # Convert to DataFrames
# node_df = pd.DataFrame(node_embeddings, index=unique_nodes)
# relation_df = pd.DataFrame(relation_embeddings, index=hetero_data.edge_types)
# edge_df = pd.DataFrame(edge_embeddings, index=[f"{h}-{r}-{t}" for (h, r, t) in unique_edges])

# # Save embeddings to files
# node_embeddings_path = os.path.join(data_dir, "node_embeddings.csv")
# relation_embeddings_path = os.path.join(data_dir, "relation_embeddings.csv")
# edge_embeddings_path = os.path.join(data_dir, "edge_embeddings.csv")

# node_df.to_csv(node_embeddings_path)
# relation_df.to_csv(relation_embeddings_path)
# edge_df.to_csv(edge_embeddings_path)

# # Return saved file paths
# node_embeddings_path, relation_embeddings_path, edge_embeddings_path, metrics_path



import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Define TransE Model
class TransE(nn.Module):
    def __init__(self, num_entities, num_relations, embedding_dim):
        super(TransE, self).__init__()
        self.entity_embeddings = nn.Embedding(num_entities, embedding_dim)
        self.relation_embeddings = nn.Embedding(num_relations, embedding_dim)
        nn.init.xavier_uniform_(self.entity_embeddings.weight)
        nn.init.xavier_uniform_(self.relation_embeddings.weight)

    def forward(self, head, relation):
        head_emb = self.entity_embeddings(head)
        relation_emb = self.relation_embeddings(relation)
        return head_emb + relation_emb

    def score(self, head, relation, tail):
        h_r = self.forward(head, relation)
        t_emb = self.entity_embeddings(tail)
        return -torch.norm(h_r - t_emb, p=2, dim=1)

class TransETrainer:
    def __init__(self, hetero_data, embedding_dim=128, num_epochs=100, batch_size=32, lr=0.01):
        self.data_dir = "./data"
        os.makedirs(self.data_dir, exist_ok=True)
        self.node_embeddings_path = os.path.join(self.data_dir, "node_embeddings.csv")
        self.relation_embeddings_path = os.path.join(self.data_dir, "relation_embeddings.csv")
        self.edge_embeddings_path = os.path.join(self.data_dir, "edge_embeddings.csv")
        self.metrics_path = os.path.join(self.data_dir, "transe_metrics.txt")
        self.hetero_data = hetero_data
        self.embedding_dim = embedding_dim
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.unique_nodes, self.unique_edges, self.node_to_index, self.relation_to_index = self._extract_graph_info()
        self.num_entities = len(self.unique_nodes)
        self.num_relations = len(self.hetero_data.edge_types)
        self.transe_model = TransE(self.num_entities, self.num_relations, self.embedding_dim)
        self.optimizer = optim.Adam(self.transe_model.parameters(), lr=self.lr)
        self.loss_fn = nn.MarginRankingLoss(margin=1.0)
    
    def _extract_graph_info(self):
        """Extracts unique nodes, edges, and their mappings."""
        unique_nodes = set()
        self.node_mapping = {}  # (node_type, local_index) -> global_index
        self.global_to_node = {}  # global_index -> (node_type, local_index)
        global_index = 0  

        for node_type in self.hetero_data.node_types:
            num_nodes = self.hetero_data[node_type].x.shape[0] if hasattr(self.hetero_data[node_type], "x") else 0
            for local_idx in range(num_nodes):
                self.node_mapping[(node_type, local_idx)] = global_index  # Forward mapping
                self.global_to_node[global_index] = (node_type, local_idx)  # Reverse mapping
                unique_nodes.add(global_index)
                global_index += 1  # Increment global index

        unique_nodes = sorted(list(unique_nodes))  # Ensure sorted list
        node_to_index = {node: i for i, node in enumerate(unique_nodes)}

        unique_edges = []
        for edge_type in self.hetero_data.edge_types:
            edge_index = self.hetero_data[edge_type].edge_index

            for i in range(edge_index.shape[1]):
                src_local = edge_index[0, i].item()
                tgt_local = edge_index[1, i].item()

                if (edge_type[0], src_local) in self.node_mapping and (edge_type[2], tgt_local) in self.node_mapping:
                    src_global = self.node_mapping[(edge_type[0], src_local)]
                    tgt_global = self.node_mapping[(edge_type[2], tgt_local)]
                    unique_edges.append((src_global, edge_type, tgt_global))  # Store full tuple

        relation_to_index = {rel: i for i, rel in enumerate(self.hetero_data.edge_types)}  # Store full tuple

        return unique_nodes, unique_edges, node_to_index, relation_to_index

    def get_embedding(self, global_index):
        """Retrieve TransE embedding and original node name for a given global index."""
        if global_index not in self.global_to_node:
            raise ValueError(f"Global index {global_index} not found in mapping.")

        node_type, local_index = self.global_to_node[global_index]
        embedding = self.transe_model.entity_embeddings(torch.tensor([global_index])).detach().numpy()
        
        return embedding, f"{node_type}{local_index}"

    def _save_embeddings(self):
        """Saves node, relation, and edge embeddings to CSV files."""
        print("Saving embeddings to disk...")

        node_indices = torch.arange(self.num_entities)
        node_embeddings = self.transe_model.entity_embeddings(node_indices).detach().numpy()
        node_df = pd.DataFrame(node_embeddings, index=self.unique_nodes)
        node_df.to_csv(self.node_embeddings_path)

        relation_indices = torch.arange(self.num_relations)
        relation_embeddings = self.transe_model.relation_embeddings(relation_indices).detach().numpy()
        relation_df = pd.DataFrame(relation_embeddings, index=self.hetero_data.edge_types)
        relation_df.to_csv(self.relation_embeddings_path)

        edge_embeddings = []
        edge_index_list = []
        for (head, relation, tail) in self.unique_edges:
            head_idx = torch.tensor(self.node_to_index[head])
            relation_idx = torch.tensor(self.relation_to_index[relation])
            tail_idx = torch.tensor(self.node_to_index[tail])

            head_emb = self.transe_model.entity_embeddings(head_idx)
            relation_emb = self.transe_model.relation_embeddings(relation_idx)
            tail_emb = self.transe_model.entity_embeddings(tail_idx)

            edge_embedding = (head_emb + relation_emb - tail_emb).detach().numpy()
            edge_embeddings.append(edge_embedding)
            edge_index_list.append(f"{head}-{relation}-{tail}")

        edge_df = pd.DataFrame(edge_embeddings, index=edge_index_list)
        edge_df.to_csv(self.edge_embeddings_path)

        print(f"Embeddings saved:\n - Nodes: {self.node_embeddings_path}\n - Relations: {self.relation_embeddings_path}\n - Edges: {self.edge_embeddings_path}")

    def _evaluate(self):
        """Computes evaluation metrics (MR, MRR, Hits@K)."""
        self.transe_model.eval()
        ranks = []

        for (head, relation, tail) in self.unique_edges:
            head_idx = torch.tensor([self.node_to_index[head]])
            relation_idx = torch.tensor([self.relation_to_index[relation]])
            tail_idx = torch.tensor([self.node_to_index[tail]])

            correct_score = self.transe_model.score(head_idx, relation_idx, tail_idx).item()
            all_tail_indices = torch.arange(self.num_entities)
            all_scores = self.transe_model.score(head_idx.repeat(self.num_entities), relation_idx.repeat(self.num_entities), all_tail_indices)

            sorted_scores, sorted_indices = torch.sort(all_scores, descending=True)
            rank = (sorted_indices == tail_idx).nonzero(as_tuple=True)[0].item() + 1

            ranks.append(rank)

        MR = np.mean(ranks)
        MRR = np.mean([1.0 / r for r in ranks])
        Hits_1 = np.mean([1 if r <= 1 else 0 for r in ranks])
        Hits_3 = np.mean([1 if r <= 3 else 0 for r in ranks])
        Hits_10 = np.mean([1 if r <= 10 else 0 for r in ranks])

        with open(self.metrics_path, "w") as f:
            f.write(f"Mean Rank (MR): {MR:.2f}\n")
            f.write(f"Mean Reciprocal Rank (MRR): {MRR:.4f}\n")
            f.write(f"Hits@1: {Hits_1:.4f}\n")
            f.write(f"Hits@3: {Hits_3:.4f}\n")
            f.write(f"Hits@10: {Hits_10:.4f}\n")

        print("Evaluation complete.")

    def train(self):
        if os.path.exists(self.node_embeddings_path):
            print("Loaded TransE embeddings from file.")
            return
        
        print("TransE embeddings not found. Computing embeddings...")
        train_edges = torch.tensor([(self.node_to_index[h], self.relation_to_index[r], self.node_to_index[t]) for (h, r, t) in self.unique_edges])
        
        for epoch in range(self.num_epochs):
            self.transe_model.train()
            self.optimizer.zero_grad()
            idx = torch.randint(0, train_edges.shape[0], (self.batch_size,))
            pos_triplets = train_edges[idx]
            neg_triplets = pos_triplets.clone()
            neg_triplets[:, 2] = torch.randint(0, self.num_entities, (self.batch_size,))
            pos_scores = self.transe_model.score(pos_triplets[:, 0], pos_triplets[:, 1], pos_triplets[:, 2])
            neg_scores = self.transe_model.score(neg_triplets[:, 0], neg_triplets[:, 1], neg_triplets[:, 2])
            loss = self.loss_fn(pos_scores, neg_scores, torch.ones_like(pos_scores))
            loss.backward()
            self.optimizer.step()

        self._save_embeddings()
        self._evaluate()

# Initialize and train TransE model
transe_trainer = TransETrainer(hetero_data)
transe_trainer.train()

Loaded TransE embeddings from file.


In [65]:
import functools

import numpy as np
from owlapy.class_expression import OWLClassExpression, OWLObjectComplementOf, OWLObjectUnionOf, \
    OWLObjectIntersectionOf, OWLObjectSomeValuesFrom, OWLObjectAllValuesFrom, OWLObjectMaxCardinality, \
    OWLObjectMinCardinality, OWLClass, OWLDataSomeValuesFrom, OWLObjectOneOf
from owlapy.owl_property import OWLObjectProperty
from torch_geometric.data import HeteroData


class Evaluator:
    """ An evaluator which is able to evaluate the accuracy of a given logical formula based on a given dataset."""

    def __init__(self, data: HeteroData, labeled_nodeset: set[tuple[int, str]] = set()):
        """"
        Initializes the evaluator based on the given dataset. After the initialization the object should be able to
        evaluate logical formulas based on the dataset.

        Args:
            data: The dataset which should be used for evaluation.
        """""
        self._data = data
        self._nodeset = self._get_nodeset()
        self._labeled_nodeset = labeled_nodeset

        self.owl_mapping = {
            OWLObjectComplementOf: self._eval_complement,
            OWLObjectUnionOf: self._eval_union,
            OWLObjectIntersectionOf: self._eval_intersection,
            OWLObjectSomeValuesFrom: self._eval_existential,
            OWLObjectAllValuesFrom: self._eval_universal,
            OWLObjectMaxCardinality: self._eval_max_cardinality,
            OWLObjectMinCardinality: self._eval_min_cardinality,
            OWLClass: self._eval_class,
            OWLDataSomeValuesFrom: self._eval_property_value,
            OWLObjectOneOf: self._eval_object_one_of
        }

    @property
    def data(self) -> HeteroData:
        """
        The dataset which should be used for evaluation.

        Returns:
            The dataset which should be used for evaluation.
        """
        return self._data

    @data.setter
    def data(self, val: HeteroData) -> None:
        """
        Sets the dataset which should be used for evaluation to the given value.

        Args:
            val: The dataset which should be used for evaluation.
        """
        self._data = val

    def explanation_accuracy(self, ground_truth: set[tuple[int, str]],
                             logical_formula: OWLClassExpression) -> tuple[float, float, float]:
        """
        Calculates the explanation accuracy of the given logical formula based on the given ground truth.

        Args:
            ground_truth: The ground truth which should be used for evaluation.
            logical_formula: The logical formula which should be evaluated.

        Returns:
            A triple containing the precision, the recall and the accuracy of the given logical formula based on the given
            ground truth.
        """
        tp, fp, tn, fn = self._get_positive_negatives(ground_truth, logical_formula)
        precision = tp / (tp + fp) if tp + fp > 0 else 0
        recall = tp / (tp + fn) if tp + fn > 0 else 0
        accuracy = (tp + tn) / (tp + fp + tn + fn)
        return precision, recall, accuracy

    def f1_score(self, ground_truth: set[tuple[int, str]], logical_formula: OWLClassExpression) -> float:
        """
        Calculates the F1 score of the given logical formula based on the given ground truth.
        Args:
            ground_truth: The ground truth which should be used for evaluation.
            logical_formula: The logical formula which should be evaluated.

        Returns:
            The F1 score of the given logical formula based on the given ground truth.
        """
        tp, fp, _, fn = self._get_positive_negatives(ground_truth, logical_formula)

        return (2 * tp) / (2 * tp + fp + fn)

    def _get_positive_negatives(self, ground_truth: set[tuple[int, str]], logical_formula: OWLClassExpression) \
            -> tuple[float, float, float, float]:
        """
        Calculates the sizes of the true positives, false positives, false negatives and true negatives of the given
        logical formula.
        Args:
            ground_truth: The ground truth which should be used for evaluation.
            logical_formula: The logical formula which should be evaluated.

        Returns:
            A tuple containing the sizes of the true positives, false positives, true negatives and false negatives of
            the given logical formula.
        """
        explanation_set = self._eval_formula(logical_formula) & self._labeled_nodeset # we need to filter out every node that is not in the test set, or we overestimate the false positives
        true_positives = len(explanation_set & ground_truth)
        false_positives = len(explanation_set - ground_truth)
        false_negatives = len(ground_truth - explanation_set)
        true_negatives = len(self._labeled_nodeset) - true_positives - false_positives - false_negatives # replace self.data.num_nodes with the size of the test set

        return true_positives, false_positives, true_negatives, false_negatives

    #@functools.lru_cache(maxsize=100)
    def _eval_formula(self, logical_formula: OWLClassExpression) -> set[tuple[int, str]]:
        """
        Evaluates the given logical formula based on the given dataset and returns the set of matching nodes.

        Args:
            logical_formula: The logical formula which should be evaluated.

        Returns:
            A set of nodes which are the result of the evaluation.
        """
        return self.owl_mapping[type(logical_formula)](logical_formula)

    def _eval_complement(self, logical_formula: OWLObjectComplementOf) -> set[tuple[int, str]]:
        """
        Evaluates the given complement based on the given dataset and returns the set of matching nodes.
        which are the complement of the inner set.

        Args:
            logical_formula: The complement which should be evaluated.

        Returns:
            A set of nodes which are the result of the evaluation.
        """
        inner_set = self._eval_formula(logical_formula.get_operand())
        return self._nodeset - inner_set

    def _eval_union(self, logical_formula: OWLObjectUnionOf) -> set[tuple[int, str]]:
        """
        Evaluates the given union based on the given dataset and returns the set of matching nodes.
        Args:
            logical_formula: The union which should be evaluated.

        Returns:
            A set of nodes which are the result of the evaluation.
        """
        operands = list(logical_formula.operands())
        result = set()
        for i in operands:
            result = result | self._eval_formula(i)
        return result

    def _eval_intersection(self, logical_formula: OWLObjectIntersectionOf) -> set[tuple[int, str]]:
        """
        Evaluates the given intersection based on the given dataset and returns the set of matching nodes.
        Args:
            logical_formula: The intersection which should be evaluated.

        Returns:
            A set of nodes which are the result of the evaluation.
        """
        operands = list(logical_formula.operands())
        result = self._eval_formula(operands[0])
        for i in operands[1:]:
            result = result & self._eval_formula(i)
        return result

    def _eval_existential(self, logical_formula: OWLObjectSomeValuesFrom) -> set[tuple[int, str]]:
        """
        Evaluates the given existential based on the given dataset and returns the set of matching nodes.
        Args:
            logical_formula: The existential restriction which should be evaluated.

        Returns:
            A set of nodes which are the result of the evaluation.
        """
        dest = self._eval_formula(logical_formula.get_filler())
        edge_type = self._eval_property(logical_formula.get_property())
        dest_first_elements = np.array([b[0] for b in dest])
        selection = np.isin(self.data[edge_type]['edge_index'][1].cpu(), dest_first_elements)
        origin = self.data[edge_type]['edge_index'][0][selection].cpu().numpy()
        return set(zip(origin, [edge_type[0], ] * len(origin)))

    def _eval_object_one_of(self, logical_formula: OWLObjectOneOf) -> set[tuple[int, str]]:
        """
        Evaluate an OWL ObjectOneOf logical formula and return a set of tuples
        representing nodes that match the condition.

        Args:
            logical_formula: The OWL ObjectOneOf logical formula to evaluate.

        Returns:
            A set of tuples where each tuple represents a node that matches the condition.
            Each tuple contains two elements: an integer representing the index and a string representing the node type.
        """
        nodes = set()
        individuals = list(logical_formula.individuals())
        for individual in individuals:
            node_type, index = individual.get_iri().get_remainder().split('#')
            nodes.add((int(index), node_type))
        return nodes

    def _eval_universal(self, logical_formula: OWLObjectAllValuesFrom) -> set[tuple[int, str]]:
        """
        Evaluates the given universal based on the given dataset and returns the set of matching nodes.
        Args:
            logical_formula: The universal restriction which should be evaluated.

        Returns:
            A set of nodes which are the result of the evaluation.
        """
        dest = set(self._eval_formula(logical_formula.get_filler()))
        edge_type = self._eval_property(logical_formula.get_property())
        result = set()

        mapping = dict()

        # Convert edge_index arrays to NumPy arrays for better performance
        edge_index_0 = self.data[edge_type]["edge_index"][0].cpu().numpy()
        edge_index_1 = self.data[edge_type]["edge_index"][1].cpu().numpy()

        for i in range(len(edge_index_0)):
            idx_0 = edge_index_0[i].item()
            idx_1 = edge_index_1[i].item()

            if idx_0 not in mapping:
                mapping[idx_0] = [idx_1]
            else:
                mapping[idx_0].append(idx_1)

        for i, indices in mapping.items():
            check_set = {(idx, edge_type[2]) for idx in indices}
            if check_set.issubset(dest):
                result.add((i, edge_type[0]))

        return result

    def _eval_max_cardinality(self, logical_formula: OWLObjectMaxCardinality) -> set[tuple[int, str]]:
        """
        Evaluates the given max cardinality restriction based on the given dataset and returns the set of matching
        nodes.

        Args:
            logical_formula: The max cardinality restriction which should be evaluated.

        Returns:
            A set of nodes which are the result of the evaluation.
        """
        dest = set(self._eval_formula(logical_formula.get_filler()))
        edge_type = self._eval_property(logical_formula.get_property())
        cardinality = logical_formula.get_cardinality()
        result = set()

        mapping = dict()

        # Convert edge_index arrays to NumPy arrays for better performance
        edge_index_0 = self.data[edge_type]["edge_index"][0].cpu().numpy()
        edge_index_1 = self.data[edge_type]["edge_index"][1].cpu().numpy()

        for i in range(len(edge_index_0)):
            idx_0 = edge_index_0[i].item()
            idx_1 = edge_index_1[i].item()

            if idx_0 not in mapping:
                mapping[idx_0] = [idx_1]
            else:
                mapping[idx_0].append(idx_1)

        for i, indices in mapping.items():
            check_set = {(idx, edge_type[2]) for idx in indices}
            if len(check_set) <= cardinality and check_set.issubset(dest):
                result.add((i, edge_type[0]))

        return result

    def _eval_min_cardinality(self, logical_formula: OWLObjectMinCardinality) -> set[tuple[int, str]]:
        """
        Evaluates the given min cardinality restriction based on the given dataset and returns the
        set of matching nodes.

        Args:
            logical_formula: The min cardinality restriction which should be evaluated.

        Returns:
            A set of nodes which are the result of the evaluation.
        """
        dest = set(self._eval_formula(logical_formula.get_filler()))
        edge_type = self._eval_property(logical_formula.get_property())
        cardinality = logical_formula.get_cardinality()
        result = set()

        mapping = dict()

        # Convert edge_index arrays to NumPy arrays for better performance
        edge_index_0 = self.data[edge_type]["edge_index"][0].cpu().numpy()
        edge_index_1 = self.data[edge_type]["edge_index"][1].cpu().numpy()

        for i in range(len(edge_index_0)):
            idx_0 = edge_index_0[i].item()
            idx_1 = edge_index_1[i].item()

            if idx_0 not in mapping:
                mapping[idx_0] = [idx_1]
            else:
                mapping[idx_0].append(idx_1)

        for i, indices in mapping.items():
            check_set = {(idx, edge_type[2]) for idx in indices}
            if len(check_set) >= cardinality and check_set.issubset(dest):
                result.add((i, edge_type[0]))

        return result

    def _eval_class(self, logical_formula: OWLClass) -> set[tuple[int, str]]:
        """
        Evaluates the given class based on the given dataset and returns the set of matching nodes.
        Args:
            logical_formula: The class which should be evaluated.

        Returns:
            A set of nodes which are the result of the evaluation.
        """
        return self._get_nodeset([logical_formula.get_iri().get_remainder(), ]) # mask with train/test mask

    def _eval_property_value(self, logical_formula: OWLDataSomeValuesFrom) -> set[tuple[int, str]]:
        """
        Evaluates the given OWLDataSomeValuesFrom logical formula based on the dataset and returns the set of nodes
        that satisfy the specified property value condition.

        Args:
            logical_formula: The OWLDataSomeValuesFrom expression representing a property value condition.

        Returns:
            A set of nodes that satisfy the specified property value condition.
                                 Each tuple contains the node index and node type.
        """
        nodes_matching_condition = set()

        # Extract information from the logical formula
        property_iri = logical_formula.get_property().get_iri().get_remainder()
        facet_restriction = logical_formula.get_filler().get_facet_restrictions()[0]

        # Parse property information
        property_split = property_iri.split('_')
        node_type = property_split[0]
        feature_index = int(property_split[-1]) - 1

        # Extract operator and comparison value from facet restriction
        operator = facet_restriction.get_facet().operator
        comparison_value = facet_restriction.get_facet_value()._v

        # Retrieve nodes and evaluate the condition
        nodes = self.data[node_type]['x'].cpu().numpy()
        for index, node in enumerate(nodes):
            if operator(node[feature_index], comparison_value):
                nodes_matching_condition.add((index, node_type))

        return nodes_matching_condition

    def _eval_property(self, property: OWLObjectProperty) -> tuple[str, str, str]:
        """
        Evaluates the given property based on the given dataset and returns the edge type.
        Args:
            property: The property which should be evaluated.

        Returns:
            The edge type which is the result of the evaluation.
        """
        for i in self.data.edge_types:
            if i[1] == property.get_iri().get_remainder():
                return i

    def _get_nodeset(self, node_types: list[str] = None) -> set[tuple[int, str]]:
        """
        Returns the set of nodes of the given node types.
        Args:
            node_types: The node types for which the nodes should be returned.

        Returns:
            The set of nodes of the given node types.
        """
        if node_types is None or node_types == ['Thing', ]:
            node_types = self.data.node_types
        if node_types == ['Nothing', ]:
            return set()
        result = set()
        for i in node_types:
            result = result | set(enumerate([i] * self.data[i]["x"].shape[0]))
        return result


In [68]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch_geometric.data import HeteroData

# ==================== RL Graph Environment ====================
class RLGraphEnv:
    def __init__(self, hetero_data, trainer):
        self.hetero_data = hetero_data
        self.trainer = trainer
        self.current_node = None
        self.path = []
        self.tracked_paths = []

    def reset(self, start_node):
        self.current_node = start_node
        self.path = [start_node]
        return self.current_node

    def step(self, action):
        next_node = action
        self.path.append(next_node)
        self.current_node = next_node
        done = len(self.path) >= 2
        if done:
            self.tracked_paths.append(self.path.copy())
        return next_node, done

    def get_neighbors(self, node):
        neighbors = set()
        for (src, _, tgt) in self.trainer.unique_edges:
            if src == node:
                neighbors.add(tgt)
        return list(neighbors) if neighbors else list(self.trainer.global_to_node.keys())

    def get_node_embedding(self, node):
        return self.trainer.get_embedding(node)

# ==================== Policy Network (REINFORCE) ====================
class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_space, learning_rate=0.001):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, action_space)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        return torch.softmax(self.fc2(x), dim=-1).squeeze(0)  # Remove batch dim


    def update(self, state_batch, action_batch, rewards):
        action_probs = self.forward(state_batch)

        # Ensure action_probs is at least 2D
        if action_probs.dim() == 1:
            action_probs = action_probs.unsqueeze(0)

        # Ensure action_batch is (batch_size, 1)
        action_batch = action_batch.view(-1, 1)

        # # Debugging Prints
        # print(f"action_probs shape: {action_probs.shape}")  # (batch_size, action_space)
        # print(f"action_batch shape: {action_batch.shape}")  # (batch_size, 1)
        # print(f"action_probs: {action_probs}")  
        # print(f"action_batch: {action_batch}")  

        # Validate indices before gathering
        max_index = action_probs.shape[1] - 1
        action_batch = torch.clamp(action_batch, 0, max_index)

        # Apply gather and compute log probability
        action_log_probs = torch.log(action_probs.gather(1, action_batch).squeeze())

        loss = -torch.sum(action_log_probs * rewards)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

# ==================== Logical Expression Generator ====================
def paths_to_logical_expressions(paths, node_mapping):
    logical_expressions = []
    for path in paths:
        if len(path) < 2:
            continue
        owl_classes = [f"∃ connectedTo.{node_mapping[n][0]}{node_mapping[n][1]}" for n in path if n in node_mapping]
        logical_expressions.append(" AND ".join(owl_classes))
    return logical_expressions

# ==================== Train RL Agent ====================
def train_rl_agent_with_owl(hetero_data, num_episodes=100):
    trainer = TransETrainer(hetero_data)
    trainer.train()

    env = RLGraphEnv(hetero_data, trainer)

    state_dim = trainer.transe_model.entity_embeddings.weight.shape[1]
    policy_net = PolicyNetwork(state_dim, len(trainer.node_mapping))

    optimizer = optim.Adam(policy_net.parameters(), lr=0.01)

    gnn_trainer = GNNTrainer(hetero_data, node_type='A')
    gnn_trainer.train()
    positive_nodes = gnn_trainer.get_positive_nodes()

    if not positive_nodes:
        positive_nodes = list(trainer.global_to_node.keys())

    print("Starting RL with positive nodes:", positive_nodes)

    for episode in range(num_episodes):
        start_node = np.random.choice(positive_nodes)
        state = env.reset(start_node)

        episode_states, episode_actions, rewards = [], [], []

        for _ in range(2):
            neighbors = env.get_neighbors(state)

            if not neighbors:
                action = np.random.choice(list(trainer.global_to_node.keys()))
            else:
                state_emb, _ = env.get_node_embedding(state)
                state_emb = torch.tensor(state_emb, dtype=torch.float32)

                action_probs = policy_net.forward(state_emb).detach().numpy().flatten()

                # Debugging Prints
                # print(f"State {state}: Action probabilities shape: {action_probs.shape}")

                neighbor_indices = [n for n in neighbors if 0 <= n < len(action_probs)]
                # print(f"State {state}: Filtered neighbor indices: {neighbor_indices}")

                if neighbor_indices:
                    valid_probs = action_probs[neighbor_indices]

                    if valid_probs.sum() > 0:
                        valid_probs /= valid_probs.sum()
                        action = np.random.choice(neighbor_indices, p=valid_probs)
                    else:
                        action = np.random.choice(neighbor_indices)
                else:
                    action = np.random.choice(neighbors)

            episode_states.append(state_emb)
            episode_actions.append(action)

            state, done = env.step(action)
            if done:
                break

        if episode_states:
            action_batch = torch.tensor(episode_actions, dtype=torch.long)
            state_batch = torch.stack(episode_states)

            retrieved_nodes = set(env.path)
            reward = (len(set(trainer.node_mapping.keys()) & retrieved_nodes) - 
                      len(retrieved_nodes - set(trainer.node_mapping.keys()))) / len(trainer.node_mapping)
            rewards.append(reward)

            returns = torch.tensor([sum(rewards[i:]) for i in range(len(rewards))], dtype=torch.float32)
            policy_net.update(state_batch, action_batch, returns)

        print(f"Episode {episode + 1}/{num_episodes}: Path Taken = {env.path}, Reward = {reward}")

    logical_expressions = paths_to_logical_expressions(env.tracked_paths, trainer.node_mapping)
    return logical_expressions

# ==================== Run RL Training ====================
logical_expressions = train_rl_agent_with_owl(hetero_data)

for expr in logical_expressions:
    print(expr)


Loaded TransE embeddings from file.
 Loaded GNN model from ./saved_models/A_gnn.pth
Starting RL with positive nodes: [0, 1, 2]
Episode 1/100: Path Taken = [np.int64(1), np.int64(6)], Reward = -0.13333333333333333
Episode 2/100: Path Taken = [np.int64(2), np.int64(7)], Reward = -0.13333333333333333
Episode 3/100: Path Taken = [np.int64(1), np.int64(11)], Reward = -0.13333333333333333
Episode 4/100: Path Taken = [np.int64(1), np.int64(11)], Reward = -0.13333333333333333
Episode 5/100: Path Taken = [np.int64(1), np.int64(11)], Reward = -0.13333333333333333
Episode 6/100: Path Taken = [np.int64(2), np.int64(7)], Reward = -0.13333333333333333
Episode 7/100: Path Taken = [np.int64(1), np.int64(11)], Reward = -0.13333333333333333
Episode 8/100: Path Taken = [np.int64(2), np.int64(7)], Reward = -0.13333333333333333
Episode 9/100: Path Taken = [np.int64(1), np.int64(11)], Reward = -0.13333333333333333
Episode 10/100: Path Taken = [np.int64(2), np.int64(12)], Reward = -0.13333333333333333
Episod