# 1. Install Dependencies

In [None]:
# Install PyTorch 2.5.1 built for CUDA 11.8 (matching torchvision/torchaudio builds).
%pip install torch==2.5.1+cu118 torchvision==0.20.1+cu118 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu118

# PyG extension wheels for torch 2.5.x + cu118 (PyG publishes one wheel index per torch minor version).
%pip install torch-scatter -f https://data.pyg.org/whl/torch-2.5.0+cu118.html
%pip install torch-sparse -f https://data.pyg.org/whl/torch-2.5.0+cu118.html
%pip install torch-geometric
# numpy 1.25.x ships no wheels for Python 3.12 (this runtime installs cp312 wheels),
# so pin the 1.26 series, which supports 3.12 and still stays below numpy 2.
%pip install numpy==1.26.4 --quiet

Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting torch==2.5.1+cu118
  Downloading https://download.pytorch.org/whl/cu118/torch-2.5.1%2Bcu118-cp312-cp312-linux_x86_64.whl (838.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m838.3/838.3 MB[0m [31m910.8 kB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.20.1+cu118
  Downloading https://download.pytorch.org/whl/cu118/torchvision-0.20.1%2Bcu118-cp312-cp312-linux_x86_64.whl (6.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.5/6.5 MB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchaudio==2.5.1
  Downloading https://download.pytorch.org/whl/cu118/torchaudio-2.5.1%2Bcu118-cp312-cp312-linux_x86_64.whl (3.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m29.9 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch==2.5.1+cu118)
  Downloading https://download.pytorch.org/whl/cu118/nvidi

# 2. Mount Google Drive

In [None]:
# Attach Google Drive to the Colab VM; the project code, data and weights live there.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# 3. Navigate to Model Folder

In [None]:
import os

# Project root on Drive. It is reached through a shortcut, so os.getcwd()
# reports the resolved shortcut target rather than this literal path.
PROJECT_ROOT = "/content/drive/MyDrive/Colab Notebooks/RED-GNN_backup"

os.chdir(PROJECT_ROOT)
print("Now in:", os.getcwd())
print("Contain: ", os.listdir())  # expected: inductive/, transductive/, README.md

Now in: /content/drive/.shortcut-targets-by-id/1L4ON7WudkBp-Zjzj-Fn0FGqA35djsbeR/RED-GNN_backup
Contain:  ['inductive', 'README.md', 'transductive']


In [None]:
# Choose the RED-GNN variant to work in: "inductive" or "transductive".
# Module imports such as `load_data` / `base_model` below are expected to
# resolve from this folder, and later cells build paths from os.getcwd().
SCENARIO = "inductive"

# Always start from the project root so re-running this cell is idempotent.
os.chdir(os.path.join("/content/drive/MyDrive/Colab Notebooks/RED-GNN_backup", SCENARIO))
print("Now in:", os.getcwd())
# List files in the chosen scenario folder
print("Contain: ", os.listdir())

Now in: /content/drive/.shortcut-targets-by-id/1L4ON7WudkBp-Zjzj-Fn0FGqA35djsbeR/RED-GNN_backup/inductive
Contain:  ['results', 'weights', 'checkpoints', 'utils.py', 'base_model.py', 'models.py', '__pycache__', 'data', 'train.py', 'load_data.py']


# 4. Import Required Libraries

In [None]:
import argparse
import torch
import torch.nn as nn
import numpy as np
import random
import networkx as nx
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import math

from collections import deque, defaultdict
from load_data import DataLoader
from base_model import BaseModel
from utils import cal_ranks
from graphviz import Source
from IPython.display import display  # For rendering in Jupyter

# 5. Implement the KGInference class (inference and subgraph visualization)

In [None]:
class KGInference:
    """Query a trained RED-GNN model and explain its answers as subgraphs.

    Wraps a ``BaseModel``-like object exposing ``.model`` (the GNN) and
    ``.loader`` (the DataLoader with ``entity2id`` / ``relation2id``) and offers:
      * ``predict_tail``  - best tail for (head, relation) plus the attention
        subgraph that connects head and predicted tail,
      * ``get_info``      - score/rank of a given tail plus its subgraph,
      * ``generate_random_predictions`` - best of several random queries,
      * ``visualize_subgraph`` - a layered Plotly drawing of a subgraph.
    """

    def __init__(self, base_model, n_layer=None):
        """
        Args:
            base_model: trained wrapper exposing ``.model`` and ``.loader``.
            n_layer: hop budget for path pruning and visualization. Defaults to
                ``base_model.model.n_layer`` so explanations cover exactly the
                hops the model propagates over (this used to be hard-coded to
                3, which mismatched models of any other depth).
        """
        self.base_model = base_model
        self.model = base_model.model
        self.loader = base_model.loader
        self.n_rel = self.loader.n_rel
        self.n_ent = self.loader.n_ent
        self.n_layer = self.model.n_layer if n_layer is None else n_layer

        self.id2entity = {v: k for k, v in self.loader.entity2id.items()}
        self.id2relation = {v: k for k, v in self.loader.relation2id.items()}

    # ------------------------------------------------------------------
    # Relation helpers
    # ------------------------------------------------------------------
    def _get_rel_name(self, rel_id, is_inverse=False):
        """Return a readable name for relation id ``rel_id``.

        Ids ``[0, n_rel)`` are the original relations, ``[n_rel, 2*n_rel)`` are
        rendered as their inverses (``_inv`` suffix) and ``2*n_rel`` is rendered
        as ``"self"``. ``is_inverse=True`` flips the rendered direction.
        """
        if rel_id == 2 * self.n_rel:
            return "self"
        if rel_id >= self.n_rel:
            base_rel = self.id2relation[rel_id - self.n_rel]
            return base_rel if is_inverse else base_rel + "_inv"
        base_rel = self.id2relation[rel_id]
        return base_rel + "_inv" if is_inverse else base_rel

    def _is_reverse_relation(self, r1, r2):
        """Return True if one id is the inverse id of the other (e.g. son / son_inv)."""
        if r1 >= self.n_rel and r2 < self.n_rel:
            return self.id2relation.get(r1 - self.n_rel) == self.id2relation.get(r2)
        if r2 >= self.n_rel and r1 < self.n_rel:
            return self.id2relation.get(r2 - self.n_rel) == self.id2relation.get(r1)
        return False

    # ------------------------------------------------------------------
    # Subgraph extraction
    # ------------------------------------------------------------------
    def _prune_paths_to_tail(self, edges, head_id, tail_id):
        """Keep only edges lying on a head->tail path of at most ``self.n_layer`` hops.

        Rules:
          - paths are simple (no repeated entity), except that an entity may
            follow its own self-edge (h -> h), which costs one hop;
          - if the shortest head->tail path leaves a hop to spare, the first
            tail self-edge found in ``edges`` is kept as well;
          - duplicate (h, r, t) triples (e.g. collected in several layers) are
            merged, keeping the largest attention value;
          - at most one self-edge is kept per entity.

        Args:
            edges: iterable of (h, r, t, attention) tuples with integer ids.
            head_id, tail_id: integer entity ids of the path endpoints.

        Returns:
            List of (h, r, t, attention) tuples sorted by (h, r, t).
        """
        # Materialize once: the edges are scanned twice below (a generator
        # would otherwise be exhausted before the tail self-edge lookup).
        edges = list(edges)

        # Full outgoing edge list (to rebuild kept edges) and distinct
        # successors (all the path search needs: every relation between two
        # consecutive path entities is kept afterwards anyway).
        out_edges = defaultdict(list)
        successors = defaultdict(set)
        for h, r, t, alpha in edges:
            out_edges[h].append((t, r, alpha))
            successors[h].add(t)

        max_hops = self.n_layer
        tail_paths = []
        min_hops_to_tail = float("inf")

        # BFS over partial paths; a path of k entities has used k - 1 hops.
        queue = deque([(head_id, (head_id,))])
        while queue:
            node, path = queue.popleft()
            hops = len(path) - 1
            if hops >= max_hops:
                continue
            for nxt in successors[node]:
                if nxt in path and nxt != node:  # no cycles; self-edges allowed
                    continue
                new_path = path + (nxt,)
                if nxt == tail_id:
                    tail_paths.append(new_path)
                    min_hops_to_tail = min(min_hops_to_tail, hops + 1)
                else:
                    queue.append((nxt, new_path))

        # Every edge joining two consecutive entities of a surviving path.
        hop_pairs = {(p[i], p[i + 1]) for p in tail_paths for i in range(len(p) - 1)}
        valid_edges = {
            (h, r, t, alpha)
            for h, t in hop_pairs
            for t2, r, alpha in out_edges[h]
            if t2 == t
        }

        # Keep one tail self-edge when the shortest path leaves a hop free.
        if min_hops_to_tail < max_hops:
            for h, r, t, alpha in edges:
                if h == t == tail_id:
                    valid_edges.add((h, r, t, alpha))
                    break

        # Merge duplicate triples, keeping the best attention value.
        best_alpha = {}
        for h, r, t, alpha in valid_edges:
            key = (h, r, t)
            if key not in best_alpha or alpha > best_alpha[key]:
                best_alpha[key] = alpha

        # Sorted for a deterministic edge order (it drives the plot layout).
        pruned_edges = []
        used_self_edges = set()
        for h, r, t in sorted(best_alpha):
            if h == t:
                if h in used_self_edges:  # at most one self-edge per entity
                    continue
                used_self_edges.add(h)
            pruned_edges.append((h, r, t, best_alpha[(h, r, t)]))
        return pruned_edges

    @staticmethod
    def _filter_edges_by_attention(edge_triples, weights, threshold):
        """Return ``[(h, r, t, weight)]`` for the rows whose weight is >= ``threshold``.

        Args:
            edge_triples: array-like of shape (E, 3) holding (h, r, t) ids.
            weights: array-like holding E attention values (e.g. shape (E, 1)).
            threshold: minimum weight for an edge to be kept.

        Returns:
            List of tuples with Python ``int`` ids and ``float`` weights, in row order.
        """
        triples = np.asarray(edge_triples).reshape(-1, 3)
        # float64 so the comparison equals comparing Python floats, independent
        # of numpy's scalar-promotion rules for float32 arrays.
        w = np.asarray(weights, dtype=np.float64).reshape(-1)
        keep = w >= threshold
        return [
            (int(h), int(r), int(t), float(a))
            for (h, r, t), a in zip(triples[keep].tolist(), w[keep].tolist())
        ]

    def _validate_query(self, head_name, rel_name, alpha, *tail_names):
        """Raise ValueError for unknown entity/relation names or alpha outside [0, 1]."""
        for name in (head_name,) + tail_names:
            if name not in self.loader.entity2id:
                raise ValueError(f"Unknown entity: {name}")
        if rel_name not in self.loader.relation2id:
            raise ValueError(f"Unknown relation: {rel_name}")
        if not 0.0 <= alpha <= 1.0:
            raise ValueError("Alpha must be between 0 and 1")

    def _forward_with_attention(self, head_id, rel_id, alpha, verbose=False):
        """Run a single (head, relation) query through the GNN, recording attention.

        Mirrors the model's forward pass layer by layer and, for each layer,
        recomputes the edge attention as
        ``sigmoid(w_alpha(ReLU(Ws_attn(h_s) + Wr_attn(h_r) + Wqr_attn(h_qr))))``.

        Args:
            head_id, rel_id: integer ids of the query head and relation.
            alpha: attention threshold; weaker edges are not recorded.
            verbose: print per-layer attention statistics.

        Returns:
            (scores, edges): ``scores`` is a 1-D numpy array of length
            ``self.n_ent``; ``edges`` lists (h, r, t, attention) tuples from
            all layers whose attention >= ``alpha``.

        NOTE(review): neighbours come from ``get_neighbors(..., mode='test')``
        but are scored against ``self.n_ent`` / ``entity2id``; for loaders with
        a separate test-graph vocabulary, confirm these ids line up.
        """
        model = self.model
        model.eval()
        device = next(model.parameters()).device

        n = 1
        q_sub = torch.tensor([head_id], dtype=torch.long, device=device)
        q_rel = torch.tensor([rel_id], dtype=torch.long, device=device)

        h0 = torch.zeros((1, n, model.hidden_dim), device=device)
        nodes = torch.cat([torch.arange(n, device=device).unsqueeze(1), q_sub.unsqueeze(1)], 1)
        hidden = torch.zeros(n, model.hidden_dim, device=device)

        collected = []
        with torch.no_grad():  # inference only: don't build an autograd graph
            for i in range(model.n_layer):
                layer = model.gnn_layers[i]
                tail_nodes, sampled_edges, old_nodes_new_idx = self.loader.get_neighbors(
                    nodes.data.cpu().numpy(), mode='test')

                # Columns read below: 0 = query index, 1..3 = (h, r, t) ids,
                # 4 = index of the source node in `hidden`.
                sub = sampled_edges[:, 4]
                rel = sampled_edges[:, 2]
                r_idx = sampled_edges[:, 0]
                hs = hidden[sub]
                hr = layer.rela_embed(rel)
                h_qr = layer.rela_embed(q_rel)[r_idx]
                alpha_raw = layer.w_alpha(
                    nn.ReLU()(layer.Ws_attn(hs) + layer.Wr_attn(hr) + layer.Wqr_attn(h_qr))
                )
                # Sigmoid gating (not softmax): each edge weight lies in (0, 1).
                alpha_weights = torch.sigmoid(alpha_raw).detach().cpu().numpy()

                if verbose and alpha_weights.size:
                    print(f"Layer {i+1}: Alpha weights - Min: {alpha_weights.min():.4f}, "
                          f"Max: {alpha_weights.max():.4f}, Mean: {alpha_weights.mean():.4f}, "
                          f"Num edges: {len(alpha_weights)}")

                # One bulk device->host copy instead of three .item() calls per edge.
                layer_edges = self._filter_edges_by_attention(
                    sampled_edges[:, 1:4].detach().cpu().numpy(), alpha_weights, alpha)
                collected.extend(layer_edges)
                if verbose:
                    print(f"Layer {i+1}: {len(layer_edges)} edges after alpha={alpha} pruning")

                hidden = layer(q_sub, q_rel, hidden, sampled_edges, tail_nodes.size(0), old_nodes_new_idx)
                h0 = torch.zeros(1, tail_nodes.size(0), hidden.size(1), device=device).index_copy_(
                    1, old_nodes_new_idx, h0)
                hidden = model.dropout(hidden)
                hidden, h0 = model.gate(hidden.unsqueeze(0), h0)
                hidden = hidden.squeeze(0)
                nodes = tail_nodes

            scores = model.W_final(hidden).squeeze(-1)
            scores_all = torch.zeros((n, self.n_ent), device=device)
            scores_all[nodes[:, 0], nodes[:, 1]] = scores

        return scores_all[0].cpu().numpy(), collected

    def _edges_to_triples(self, edges):
        """Map (h, r, t, attention) id tuples to (head, relation, tail) name triples."""
        return [(self.id2entity[h], self._get_rel_name(r), self.id2entity[t])
                for h, r, t, _ in edges]

    # ------------------------------------------------------------------
    # Public queries
    # ------------------------------------------------------------------
    def predict_tail(self, head_name, rel_name, alpha=0.0):
        """Predict the most likely tail for ``(head_name, rel_name, ?)``.

        Args:
            head_name: entity name (key of ``loader.entity2id``).
            rel_name: relation name (key of ``loader.relation2id``).
            alpha: attention threshold in [0, 1]; weaker edges are left out of
                the explanation subgraph.

        Returns:
            ``(best_name, best_rank, best_score, subgraph)``. ``best_rank`` is
            always 1 (the arg-max). ``subgraph`` is a list of
            ``(head, relation, tail)`` name triples on paths of at most
            ``self.n_layer`` hops from ``head_name`` to ``best_name``.

        Raises:
            ValueError: unknown entity/relation or alpha outside [0, 1].
        """
        self._validate_query(head_name, rel_name, alpha)
        head_id = self.loader.entity2id[head_name]
        rel_id = self.loader.relation2id[rel_name]

        scores_cpu, edges = self._forward_with_attention(head_id, rel_id, alpha, verbose=True)

        best_id = int(np.argmax(scores_cpu))
        best_name = self.id2entity[best_id]
        best_score = scores_cpu[best_id]
        best_rank = 1

        # Drop dead ends, over-long paths and duplicate triples.
        subgraph = self._edges_to_triples(self._prune_paths_to_tail(edges, head_id, best_id))

        print(f"Final subgraph size after alpha={alpha} and path pruning: {len(subgraph)} edges")
        return best_name, best_rank, best_score, subgraph

    def get_info(self, head_name, rel_name, tail_name, alpha=0.0):
        """Score a given triple ``(head_name, rel_name, tail_name)`` and explain it.

        Returns:
            ``(tail_name, rank, score, subgraph)``. ``rank`` is the 1-based
            position of ``tail_name`` when all entities are sorted by
            descending score. ``subgraph`` is the list of name triples on
            paths of at most ``self.n_layer`` hops from head to tail.

        Raises:
            ValueError: unknown entity/relation or alpha outside [0, 1].
        """
        self._validate_query(head_name, rel_name, alpha, tail_name)
        head_id = self.loader.entity2id[head_name]
        tail_id = self.loader.entity2id[tail_name]
        rel_id = self.loader.relation2id[rel_name]

        scores_cpu, edges = self._forward_with_attention(head_id, rel_id, alpha)

        score = scores_cpu[tail_id]
        rank = np.where(np.argsort(-scores_cpu) == tail_id)[0][0] + 1

        subgraph = self._edges_to_triples(self._prune_paths_to_tail(edges, head_id, tail_id))

        print(f"Final subgraph size after alpha={alpha} and path pruning: {len(subgraph)} edges")
        return tail_name, rank, score, subgraph

    def generate_random_predictions(self, num_times=10, alpha=0.0):
        """Run ``num_times`` random (head, relation) queries and return the best one.

        Returns:
            dict with keys ``'head'``, ``'relation'``, ``'best_tail'``,
            ``'rank'``, ``'score'`` and ``'subgraph'`` for the query whose
            predicted tail scored highest.

        Raises:
            ValueError: if no query succeeded (e.g. ``num_times <= 0``).
        """
        entities = list(self.loader.entity2id.keys())
        relations = list(self.loader.relation2id.keys())

        results = []
        for _ in range(num_times):
            rand_head = random.choice(entities)
            rand_rel = random.choice(relations)
            try:
                best_tail, rank, score, subgraph = self.predict_tail(rand_head, rand_rel, alpha)
            except ValueError:
                # Invalid query (e.g. alpha outside [0, 1]): skip this draw.
                continue
            results.append({
                'head': rand_head,
                'relation': rand_rel,
                'best_tail': best_tail,
                'rank': rank,
                'score': score,
                'subgraph': subgraph
            })

        if not results:
            raise ValueError("No valid random predictions generated")

        return max(results, key=lambda x: x['score'])

    def get_random_head(self):
        """Return a uniformly random entity name."""
        entities = list(self.loader.entity2id.keys())
        return random.choice(entities)

    def get_random_relation(self):
        """Return a uniformly random relation name."""
        relations = list(self.loader.relation2id.keys())
        return random.choice(relations)

    # ------------------------------------------------------------------
    # Visualization
    # ------------------------------------------------------------------
    @staticmethod
    def _edge_label(x0, y0, x1, y1, text, color):
        """Plotly annotation placing ``text`` horizontally at a segment's midpoint."""
        return dict(
            x=(x0 + x1) / 2, y=(y0 + y1) / 2,
            xref="x", yref="y",
            text=text,
            showarrow=False,
            font=dict(size=15, color=color),
            textangle=0,  # No rotation, always horizontal
            align='center',
            xanchor='center',
            yanchor='middle'
        )

    def visualize_subgraph(self, subgraph, head_name, rel_name, best_tail_name):
        """Draw ``subgraph`` as a left-to-right layered graph with Plotly.

        Entities are unrolled per layer as ``"<entity>@<layer>"`` nodes,
        starting from ``head_name`` at layer 0. The predicted tail appears only
        in the last layer (``self.n_layer``). The predicted link
        ``(head_name, rel_name, best_tail_name)`` is drawn as a dashed red
        edge, or written below the plot when it would cross another edge.

        Args:
            subgraph: iterable of (head, relation, tail) name triples, e.g. the
                fourth element returned by ``predict_tail`` / ``get_info``.
            head_name, rel_name, best_tail_name: the query and its answer.

        Returns:
            The figure as an HTML string. The figure is also shown inline.
        """
        # Adjacency: h -> list of (t, r)
        adj = defaultdict(list)
        for h, r, t in subgraph:
            adj[h].append((t, r))

        G = nx.DiGraph()
        used_self_global = set()         # each entity may use at most one self-edge
        used_edges_per_layer = set()     # dedupe per layer: (layer, h, r, t)
        layer_nodes = defaultdict(list)  # ordered, duplicate-free labels per layer
        reached_layer = {head_name: 0}   # deepest layer each entity appears in
        layer_nodes[0].append(head_name)

        final_step = self.n_layer - 1
        current_layer = {head_name}
        for layer in range(self.n_layer):
            next_layer = set()
            # Sorted so the layout does not depend on (randomized) str hashing.
            frontier = sorted(current_layer)
            for h in frontier:
                G.add_node(f"{h}@{layer}", label=h, layer=layer)

            for h in frontier:
                h_layer = f"{h}@{layer}"
                for t, r in adj.get(h, []):
                    # The predicted tail may appear only in the last layer, and
                    # the last step may lead only to the predicted tail.
                    if layer < final_step and t == best_tail_name:
                        continue
                    if layer == final_step and t != best_tail_name:
                        continue
                    if h == t and h in used_self_global:
                        continue

                    key = (layer, h, r, t)
                    if key in used_edges_per_layer:
                        continue
                    used_edges_per_layer.add(key)
                    if h == t:
                        used_self_global.add(h)

                    t_layer = f"{t}@{layer + 1}"
                    G.add_node(t_layer, label=t, layer=layer + 1)
                    # DiGraph keeps one edge per node pair: a later relation
                    # between the same pair overwrites the label.
                    G.add_edge(h_layer, t_layer, relation=r)

                    next_layer.add(t)
                    reached_layer[t] = max(reached_layer.get(t, -1), layer + 1)
                    # Record each entity once per layer. Duplicates used to
                    # inflate the layer size and leave gaps in the layout.
                    if t not in layer_nodes[layer + 1]:
                        layer_nodes[layer + 1].append(t)

            current_layer = next_layer

        # Layout: x by layer, nodes of a layer spread evenly around the middle.
        width = 800
        height = 600
        x_scale = max(self.n_layer, 1)
        pos = {}
        for node, attrs in G.nodes(data=True):
            labels = layer_nodes[attrs['layer']]
            layer_size = len(labels) or 1
            slot = labels.index(attrs['label'])
            y_offset = (slot - (layer_size - 1) / 2) * (height / (layer_size + 1)) * 2
            pos[node] = ((attrs['layer'] / x_scale) * width, height / 2 + y_offset)

        def edges_intersect(p1, p2, q1, q2, tolerance=5):
            """True if segments p1-p2 and q1-q2 properly cross (non-parallel)."""
            def ccw(A, B, C):
                return (C[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (C[0] - A[0])
            crosses = ccw(p1, q1, q2) != ccw(p2, q1, q2) and ccw(p1, p2, q1) != ccw(p1, p2, q2)
            cross_product = (p2[0] - p1[0]) * (q2[1] - q1[1]) - (q2[0] - q1[0]) * (p2[1] - p1[1])
            return crosses and abs(cross_product) > tolerance

        head_node = f"{head_name}@0"
        tail_node = f"{best_tail_name}@{self.n_layer}"
        tail_reached = reached_layer.get(best_tail_name) == self.n_layer

        # Would the straight predicted edge cross any drawn edge?
        pred_edge_collides = False
        if tail_reached:
            pred_start, pred_end = pos[head_node], pos[tail_node]
            for u, v in G.edges():
                if u == head_node and v == tail_node:
                    continue
                if edges_intersect(pred_start, pred_end, pos[u], pos[v]):
                    pred_edge_collides = True
                    break

        # Edge traces
        edge_x, edge_y, edge_text, edge_annotations = [], [], [], []
        for u, v, attrs in G.edges(data=True):
            (x0, y0), (x1, y1) = pos[u], pos[v]
            relation = attrs['relation']
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])
            # One hover label per plotted point (start, end, gap) so labels
            # stay on their own edge instead of drifting onto later edges.
            edge_text.extend([relation, relation, ''])
            edge_annotations.append(self._edge_label(x0, y0, x1, y1, relation, 'black'))

        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=1, color='black'),
            hoverinfo='text',
            text=edge_text,
            mode='lines+markers',
            marker=dict(symbol='arrow-bar-up', size=8, angleref='previous', color='black'),
            line_shape='spline',
            showlegend=False
        )

        # Predicted edge trace (omitted when it would collide; see annotation below)
        pred_edge_x, pred_edge_y, pred_edge_text, pred_edge_annotation = [], [], [], []
        if tail_reached and not pred_edge_collides:
            (x0, y0), (x1, y1) = pos[head_node], pos[tail_node]
            pred_edge_x = [x0, x1, None]
            pred_edge_y = [y0, y1, None]
            pred_edge_text = [rel_name, rel_name, '']
            pred_edge_annotation.append(self._edge_label(x0, y0, x1, y1, rel_name, 'red'))

        pred_edge_trace = go.Scatter(
            x=pred_edge_x, y=pred_edge_y,
            line=dict(width=2, color='red', dash='dash'),
            hoverinfo='text',
            text=pred_edge_text,
            mode='lines+markers',
            marker=dict(symbol='arrow-bar-up', size=8, angleref='previous', color='red'),
            line_shape='spline',
            name=f"{rel_name} (predicted)",
            showlegend=False
        )

        # Node trace: head in blue, predicted tail in green, others light blue.
        node_x, node_y, node_text, node_colors = [], [], [], []
        for node, attrs in G.nodes(data=True):
            x, y = pos[node]
            node_x.append(x)
            node_y.append(y)
            node_text.append(f"{attrs['label']} (Layer {attrs['layer']})")
            if node == head_node:
                node_colors.append('blue')
            elif node == tail_node:
                node_colors.append('green')
            else:
                node_colors.append('lightblue')

        node_trace = go.Scatter(
            x=node_x, y=node_y,
            mode='markers+text',
            text=node_text,
            textposition='top center',
            hoverinfo='text',
            marker=dict(
                showscale=False,
                color=node_colors,
                size=10,
                line=dict(width=2, color='white')
            ),
            textfont=dict(size=15),
            showlegend=False
        )

        annotations = edge_annotations + pred_edge_annotation
        # Write (h, r, t) below the plot when the predicted edge would collide.
        if pred_edge_collides and tail_reached:
            annotations.append(
                dict(
                    x=0.5, y=-0.1,
                    xref="paper", yref="paper",
                    text=f"Predicted: ({head_name}, {rel_name}, {best_tail_name})",
                    showarrow=False,
                    font=dict(size=14, color='red'),
                    align='center'
                )
            )

        # Horizontal padding of 8% around the node positions.
        all_x = [coord[0] for coord in pos.values()]
        min_x, max_x = min(all_x), max(all_x)
        padding_x = (max_x - min_x) * 0.08

        fig = go.Figure(
            data=[edge_trace, pred_edge_trace, node_trace],
            layout=go.Layout(
                # `titlefont` is deprecated in Plotly; title font goes under layout.title.
                title=dict(text='Subgraph Visualization', font=dict(size=24)),
                showlegend=False,
                hovermode='closest',
                margin=dict(b=50, l=5, r=5, t=60),
                xaxis=dict(
                    showgrid=False, zeroline=False, showticklabels=False,
                    range=[min_x - padding_x, max_x + padding_x],
                ),
                yaxis=dict(
                    showgrid=False, zeroline=False, showticklabels=False,
                    scaleanchor="x"  # preserve aspect ratio
                ),
                width=800,
                height=600,
                plot_bgcolor='white',
                annotations=annotations
            )
        )

        fig.show()
        return fig.to_html()

# 6. Transductive Scenario

In [None]:
# Build the options object BaseModel expects for a *transductive* dataset and
# instantiate the model. Artifacts (results/, weights/, checkpoints/) are kept
# under the current working directory.
data_path = 'data/family'
SEED = 1234


class Options(object):
    """Attribute bag holding the hyper-parameters and file paths read by BaseModel."""
    pass


# Seed every RNG in use: numpy/torch for the model, `random` for
# KGInference's random queries.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dataset name = last path component ('data/family' and 'data/family/' -> 'family').
dataset = os.path.basename(os.path.normpath(data_path))

# Hyper-parameters per transductive dataset.
TRANSDUCTIVE_HPARAMS = {
    'family': dict(lr=0.0036, decay_rate=0.999, lamb=0.000017, hidden_dim=48, attn_dim=5,
                   n_layer=3, dropout=0.29, act='relu', n_batch=20, n_tbatch=50),
    'umls': dict(lr=0.0012, decay_rate=0.998, lamb=0.00014, hidden_dim=64, attn_dim=5,
                 n_layer=5, dropout=0.01, act='tanh', n_batch=10, n_tbatch=50),
    'WN18RR': dict(lr=0.0003, decay_rate=0.994, lamb=0.00014, hidden_dim=64, attn_dim=5,
                   n_layer=4, dropout=0.02, act='idd', n_batch=50, n_tbatch=50),
    'fb15k-237': dict(lr=0.0009, decay_rate=0.9938, lamb=0.000080, hidden_dim=48, attn_dim=5,
                      n_layer=3, dropout=0.0391, act='relu', n_batch=5, n_tbatch=1),
    'nell': dict(lr=0.0011, decay_rate=0.9938, lamb=0.000089, hidden_dim=48, attn_dim=5,
                 n_layer=3, dropout=0.2593, act='relu', n_batch=5, n_tbatch=1),
}
if dataset not in TRANSDUCTIVE_HPARAMS:
    raise ValueError(f"No transductive hyper-parameters for dataset '{dataset}'")

# NOTE(review): os.getcwd() is whichever folder the navigation cells left the
# kernel in (currently the inductive folder). Confirm that is where the
# transductive data and artifacts are meant to live.
save_dir = os.getcwd() + '/'
results_dir = save_dir + 'results'
weights_dir = save_dir + "weights"
checkpoint_dir = save_dir + "checkpoints"
for directory in (save_dir, results_dir, weights_dir, checkpoint_dir):
    os.makedirs(directory, exist_ok=True)

# An instance, not the class itself, so settings don't linger on the class.
opts = Options()
opts.perf_file = os.path.join(results_dir, dataset + '_perf.txt')
opts.best_weight_file = os.path.join(weights_dir, dataset + '_weight.pt')
opts.checkpoint_weight_file = os.path.join(checkpoint_dir, dataset + '_checkpoint.pt')

loader = DataLoader(data_path)
opts.n_ent = loader.n_ent
opts.n_rel = loader.n_rel
for hp_name, hp_value in TRANSDUCTIVE_HPARAMS[dataset].items():
    setattr(opts, hp_name, hp_value)

config_str = '%.4f, %.4f, %.6f,  %d, %d, %d, %d, %.4f,%s\n' % (
    opts.lr, opts.decay_rate, opts.lamb, opts.hidden_dim, opts.attn_dim,
    opts.n_layer, opts.n_batch, opts.dropout, opts.act)
print(config_str)
with open(opts.perf_file, 'a+') as f:
    f.write(config_str)

model = BaseModel(opts, loader)

# To restore trained weights, uncomment:
# start_epoch, best_mrr = model.load_checkpoint(opts.best_weight_file)

# To train from scratch instead, uncomment:
# best_mrr = 0
# for epoch in range(50):
#     mrr, out_str = model.train_batch()
#     with open(opts.perf_file, 'a+') as f:
#         f.write(out_str)
#     if mrr > best_mrr:
#         best_mrr = mrr
#         best_str = out_str
#         print(str(epoch) + '\t' + best_str)
# print(best_str)

n_train: 11736 n_valid: 3564 n_test: 4751
0.0036, 0.9990, 0.000017,  48, 5, 3, 20, 0.2900,relu

Checkpoint loaded from /content/drive/.shortcut-targets-by-id/1L4ON7WudkBp-Zjzj-Fn0FGqA35djsbeR/RED-GNN_backup/transductive/weights/family_weight.pt


# 7. Inductive Scenario

In [None]:
# Dataset directory for the inductive run; its final path component (e.g.
# "nell_v4") selects the hyper-parameter preset further down.
data_path = "data/nell_v4"


class Options:
    """Plain attribute container for the run's file paths and hyper-parameters."""


# normpath strips any trailing separators, so "data/nell_v4", "data/nell_v4/"
# and "data/nell_v4//" all yield "nell_v4" (the old split('/') approach returned
# '' for a doubled trailing slash and briefly bound `dataset` to a list).
dataset = os.path.basename(os.path.normpath(data_path))

# Root folder for inductive-run artifacts (perf logs, best weights, checkpoints).
# NOTE(review): absolute Drive path under MyDrive/, which differs from the
# "Colab Notebooks/RED-GNN_backup" working directory selected earlier — confirm
# this location is intended.
save_dir = '/content/drive/MyDrive/RED-GNN_backup/inductive/'

results_dir = os.path.join(save_dir, 'results')
weights_dir = os.path.join(save_dir, "weights")
checkpoint_dir = os.path.join(save_dir, "checkpoints")

# exist_ok=True makes re-running the cell a no-op for existing folders and
# avoids the separate exists() check.
for output_dir in (save_dir, results_dir, weights_dir, checkpoint_dir):
    os.makedirs(output_dir, exist_ok=True)

# Use an Options *instance*; the previous `opts = Options` stored every setting
# as an attribute on the class object itself.
opts = Options()
opts.perf_file = os.path.join(results_dir,  dataset + '_perf.txt')
opts.best_weight_file = os.path.join(weights_dir, dataset + '_weight.pt')
opts.checkpoint_file = os.path.join(checkpoint_dir, dataset + '_checkpoint.pt')

# Per-dataset hyper-parameter presets for the inductive splits, keyed by the
# dataset name derived from `data_path`.
INDUCTIVE_HPARAMS = {
    'WN18RR_v1': dict(lr=0.005, lamb=0.0002, decay_rate=0.991, hidden_dim=64, attn_dim=5,
                      dropout=0.21, act='idd', n_layer=5, n_batch=100),
    'fb237_v1': dict(lr=0.0092, lamb=0.0003, decay_rate=0.994, hidden_dim=32, attn_dim=5,
                     dropout=0.23, act='relu', n_layer=3, n_batch=20),
    'nell_v1': dict(lr=0.0021, lamb=0.000189, decay_rate=0.9937, hidden_dim=48, attn_dim=5,
                    dropout=0.2460, act='relu', n_layer=5, n_batch=10),

    'WN18RR_v2': dict(lr=0.0016, lamb=0.0004, decay_rate=0.994, hidden_dim=48, attn_dim=3,
                      dropout=0.02, act='relu', n_layer=5, n_batch=20),
    'fb237_v2': dict(lr=0.0077, lamb=0.0002, decay_rate=0.993, hidden_dim=48, attn_dim=5,
                     dropout=0.3, act='relu', n_layer=3, n_batch=10),
    'nell_v2': dict(lr=0.0075, lamb=0.000066, decay_rate=0.9996, hidden_dim=48, attn_dim=5,
                    dropout=0.2881, act='relu', n_layer=3, n_batch=100),

    'WN18RR_v3': dict(lr=0.0014, lamb=0.000034, decay_rate=0.991, hidden_dim=64, attn_dim=5,
                      dropout=0.28, act='tanh', n_layer=5, n_batch=20),
    'fb237_v3': dict(lr=0.0006, lamb=0.000023, decay_rate=0.994, hidden_dim=48, attn_dim=3,
                     dropout=0.27, act='relu', n_layer=3, n_batch=20),
    'nell_v3': dict(lr=0.0008, lamb=0.0004, decay_rate=0.995, hidden_dim=16, attn_dim=3,
                    dropout=0.06, act='relu', n_layer=3, n_batch=10),

    'WN18RR_v4': dict(lr=0.006, lamb=0.000132, decay_rate=0.991, hidden_dim=32, attn_dim=5,
                      dropout=0.11, act='relu', n_layer=5, n_batch=10),
    'fb237_v4': dict(lr=0.0052, lamb=0.000018, decay_rate=0.999, hidden_dim=48, attn_dim=5,
                     dropout=0.07, act='idd', n_layer=5, n_batch=20),
    'nell_v4': dict(lr=0.0005, lamb=0.000398, decay_rate=1, hidden_dim=16, attn_dim=5,
                    dropout=0.1472, act='tanh', n_layer=5, n_batch=20),
}

# Fail fast on an unknown dataset: without a preset, `opts.lr` etc. would never
# be set and the config line below would die with an opaque AttributeError —
# after the data had already been loaded.
if dataset not in INDUCTIVE_HPARAMS:
    raise ValueError(
        f"No hyper-parameters defined for dataset {dataset!r}; "
        f"expected one of {sorted(INDUCTIVE_HPARAMS)}"
    )

loader = DataLoader(data_path)
opts.n_ent = loader.n_ent
opts.n_rel = loader.n_rel

# setattr (rather than vars(opts).update) works whether `opts` is an Options
# instance or the Options class object.
for hp_name, hp_value in INDUCTIVE_HPARAMS[dataset].items():
    setattr(opts, hp_name, hp_value)

# Echo the active hyper-parameters (same comma-separated layout as the
# transductive cell).
config_str = '%.4f, %.4f, %.6f,  %d, %d, %d, %d, %.4f,%s\n' % (opts.lr, opts.decay_rate, opts.lamb, opts.hidden_dim, opts.attn_dim, opts.n_layer, opts.n_batch, opts.dropout, opts.act)
print(config_str)
# Also append it to the perf log, as the transductive cell does, so the
# per-epoch lines the training loop appends can be traced back to the
# hyper-parameters that produced them.
with open(opts.perf_file, 'a+') as f:
    f.write(config_str)

model = BaseModel(opts, loader)

# start_epoch, best_mrr = model.load_checkpoint(opts.best_weight_file)

# end_epoch = 50
# best_mrr = 0
# for epoch in range(end_epoch):
#     print(f"Epoch {epoch}")
#     mrr, out_str = model.train_batch()
#     with open(opts.perf_file, 'a+') as f:
#         f.write(f'epoch {epoch}  ' + out_str)
#     if mrr > best_mrr:
#         best_mrr = mrr
#         best_str = out_str
#         print(f"Best Epoch {epoch}:\t{best_str}")
#         model.save_checkpoint(opts.best_weight_file, epoch + 1, best_mrr)
#         print(f"Best model saved at epoch {epoch + 1}")
#     if (epoch + 1) % 4 == 0:
#         model.save_checkpoint(opts.checkpoint_file, epoch + 1, best_mrr)
#         print(f"Checkpoint saved at epoch {epoch + 1}")

Đường dẫn hiện tại: /content/drive/.shortcut-targets-by-id/1L4ON7WudkBp-Zjzj-Fn0FGqA35djsbeR/RED-GNN_backup/inductive
n_train: 1752 n_valid: 1063 n_test: 1882
0.0005, 1.0000, 0.000398,  16, 5, 5, 20, 0.1472,tanh

Epoch 0
Best Epoch 0:	[VALID] MRR:0.0723 H@1:0.0288 H@10:0.1522	 [TEST] MRR:0.0354 H@1:0.0187 H@10:0.0622 	[TIME] train:12.8732 inference:12.0386

Checkpoint saved at /content/drive/MyDrive/RED-GNN_backup/inductive/weights/nell_v4_weight.pt
Best model saved at epoch 1
Epoch 1
Best Epoch 1:	[VALID] MRR:0.1130 H@1:0.0709 H@10:0.1753	 [TEST] MRR:0.0643 H@1:0.0249 H@10:0.1240 	[TIME] train:24.9196 inference:11.7830

Checkpoint saved at /content/drive/MyDrive/RED-GNN_backup/inductive/weights/nell_v4_weight.pt
Best model saved at epoch 2
Epoch 2
Best Epoch 2:	[VALID] MRR:0.1502 H@1:0.0854 H@10:0.2964	 [TEST] MRR:0.1301 H@1:0.0743 H@10:0.2343 	[TIME] train:36.9394 inference:13.5304

Checkpoint saved at /content/drive/MyDrive/RED-GNN_backup/inductive/weights/nell_v4_weight.pt
Best mod

# 8. Visualization

In [None]:
# Wrap the current `model` in the KGInference helper that the prediction /
# visualization cell below queries.
# NOTE(review): `model` is whatever the most recently executed model cell bound.
# Run top to bottom, that is the inductive nell_v4 model from section 7. That
# cell has its load_checkpoint call and training loop commented out, so its
# weights are presumably untrained. The query below, however, uses relation
# "son", and its log reports 3 layers, which matches the transductive family
# config rather than nell_v4's n_layer=5. Confirm which model is intended, e.g.
# by keeping the transductive model under its own name.
kg_inference = KGInference(model)

In [None]:
# Ask the model for the most likely tail of the (head, relation, ?) query and
# draw the pruned subgraph behind that prediction. `alpha` is the pruning
# threshold reported in the log ("edges after alpha=0.5 pruning").
# NOTE(review): "1482"/"son" look like family-dataset names. They must exist in
# the vocabulary of whichever model backs `kg_inference`; otherwise the
# ValueError handler below just prints the error.
head_name, rel_name = "1482", "son"
# tail_name = "1480"
try:
    prediction = kg_inference.predict_tail(head_name, rel_name, alpha=0.5)
    best_tail, rank, score, subgraph = prediction
    # Optional textual dump of the prediction and its supporting edges:
    # print(f"Prediction for {head_name}, {rel_name}:")
    # print(f"Best Tail: {best_tail}, Rank: {rank}, Score: {score:.4f}")
    # print("Subgraph:")
    # for h, r, t in subgraph:
    #     print(f"  {h} --[{r}]--> {t}")

    # Render the reasoning subgraph for the predicted tail.
    graph = kg_inference.visualize_subgraph(subgraph, head_name, rel_name, best_tail)
    # display(graph)
except ValueError as err:
    print(f"Error: {err}")

Layer 1: Alpha weights - Min: 0.0000, Max: 0.0604, Mean: 0.0205, Num edges: 3
Layer 1: 0 edges after alpha=0.5 pruning
Layer 2: Alpha weights - Min: 0.0002, Max: 1.0000, Mean: 0.5310, Num edges: 16
Layer 2: 9 edges after alpha=0.5 pruning
Layer 3: Alpha weights - Min: 0.0001, Max: 1.0000, Mean: 0.9261, Num edges: 84
Layer 3: 78 edges after alpha=0.5 pruning
Final subgraph size after alpha=0.5 and path pruning: 19 edges
