In [30]:
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import RGCNConv
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import add_self_loops
import random

class RGCN(torch.nn.Module):
    """Two-layer relational GCN node classifier for mixed-width node features.

    Every node row has width N, but some rows only carry M meaningful
    entries. Those rows are projected from their first M columns up to
    width N by a linear layer; the resulting feature matrix is then passed
    through two ``RGCNConv`` layers (ReLU in between) and a linear head.

    Args:
        M: Width of the smaller feature vectors (input size of ``upscale``).
        N: Width of the stored feature rows (input size of ``conv1``).
        hidden_channels: Output width of the first ``RGCNConv`` layer.
        out_channels: Output width of the second ``RGCNConv`` layer.
        num_relations: Number of relation types handed to both conv layers.
        num_classes: Number of logits produced per node by the final layer.
    """

    def __init__(self, M, N, hidden_channels, out_channels, num_relations, num_classes):
        super().__init__()
        self.M, self.N = M, N

        # Projection that lifts M-wide features to the common width N.
        self.upscale = Linear(M, N)

        # Both relational convolutions are built with num_bases=2 and operate
        # on features that already share the common width.
        self.conv1 = RGCNConv(N, hidden_channels, num_relations, num_bases=2)
        self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations, num_bases=2)

        # Per-node classification head.
        self.fc = Linear(out_channels, num_classes)

    def forward(self, x, edge_index, edge_type, is_m_dim):
        """Compute per-node class logits.

        Args:
            x: Node feature matrix; expected to be [num_nodes, N].
            edge_index: Edge list forwarded unchanged to both conv layers.
            edge_type: Relation id per edge, forwarded to both conv layers.
            is_m_dim: Boolean row mask, True where the node's features were
                originally M-dimensional.

        Returns:
            Output of the final linear layer applied to the second conv
            layer's result.
        """
        # Only the first M columns of the flagged rows carry information;
        # project exactly those up to width N.
        lifted = self.upscale(x[is_m_dim, :self.M])

        # Write the projected rows into a copy so the caller's tensor is
        # left untouched.
        features = x.clone()
        features[is_m_dim] = lifted

        hidden = F.relu(self.conv1(features, edge_index, edge_type))
        return self.fc(self.conv2(hidden, edge_index, edge_type))

def create_graph(
    min_nodes,
    max_nodes,
    min_edges,
    max_edges,
    num_relations,
    M,
    N,
    num_classes,
    relation_mode,
    p_m=0.5
):
    """
    Create one random graph with mixed-width node features.

    Each node feature is either M-dim or N-dim; all are stored in an N-dim
    row. For M-dim nodes the last (N - M) entries are zero. A boolean mask
    recording which nodes are M-dim is attached to the result as
    ``is_m_dim``.

    Args:
        min_nodes, max_nodes: Inclusive bounds for the sampled node count.
        min_edges, max_edges: Inclusive bounds for the number of edges that
            are sampled. Sampled self-loops are discarded, so the graph may
            end up with fewer edges than were sampled.
        num_relations: Number of base relation types; sampled edge types lie
            in ``[0, num_relations)``.
        M: Width of the smaller feature vectors (must not exceed ``N``).
        N: Width of the stored feature rows.
        num_classes: Node labels are sampled from ``[0, num_classes)``.
        relation_mode: ``1`` adds a self-loop on every node with relation id
            ``num_relations``; ``2`` adds an inverse edge for every edge
            (relation id shifted by ``num_relations``) and then a self-loop
            on every node with relation id ``2 * num_relations``; any other
            value leaves the sampled edges as they are.
        p_m: Probability that a node is M-dim.

    Returns:
        A ``Data`` object with ``x``, ``edge_index``, ``edge_type``, ``y``
        and the extra attribute ``is_m_dim``.

    Raises:
        ValueError: If ``M`` is larger than ``N``.
    """
    if M > N:
        raise ValueError(f"M ({M}) must not exceed N ({N})")

    num_nodes = random.randint(min_nodes, max_nodes)
    num_edges = random.randint(min_edges, max_edges)

    # Generate random edges and types
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    edge_type = torch.randint(0, num_relations, (num_edges,))

    # Remove self-loops introduced by random sampling
    not_loop = edge_index[0] != edge_index[1]
    edge_index = edge_index[:, not_loop]
    edge_type = edge_type[not_loop]

    # NOTE: num_nodes is passed explicitly to add_self_loops below. Without
    # it the node count is inferred from the largest index present in
    # edge_index, so trailing nodes that received no edge would silently get
    # no self-loop.
    if relation_mode == 1:
        # Add self-loops with a new relation (num_relations-th type)
        edge_index, edge_type = add_self_loops(
            edge_index, edge_type, fill_value=num_relations, num_nodes=num_nodes
        )
    elif relation_mode == 2:
        # Add inverse relations, typed as (original type + num_relations)
        src, dst = edge_index
        inv_edge_index = torch.stack([dst, src], dim=0)
        inv_edge_type = edge_type + num_relations
        edge_index = torch.cat([edge_index, inv_edge_index], dim=1)
        edge_type = torch.cat([edge_type, inv_edge_type], dim=0)

        # Add self-loops with new relation type = 2 * num_relations
        edge_index, edge_type = add_self_loops(
            edge_index, edge_type, fill_value=2 * num_relations, num_nodes=num_nodes
        )

    # Create features
    # With probability p_m, node has M-dim features, else N-dim.
    # All rows are drawn at once; the tail of every M-dim row is then zeroed,
    # which gives the same layout as filling each row separately.
    is_m_dim = torch.rand(num_nodes) < p_m
    x = torch.rand(num_nodes, N)
    x[is_m_dim, M:] = 0.0

    y = torch.randint(0, num_classes, (num_nodes,))

    data = Data(x=x, edge_index=edge_index, edge_type=edge_type, y=y)
    # Store the mask in the data object so we know which nodes were M-dim originally
    data.is_m_dim = is_m_dim
    return data

def create_dataset(
    num_graphs,
    min_nodes,
    max_nodes,
    min_edges,
    max_edges,
    num_relations,
    M,
    N,
    num_classes,
    relation_mode,
    p_m=0.5,
):
    """
    Build a list of independently sampled random graphs.

    Every argument except ``num_graphs`` is forwarded unchanged to
    ``create_graph`` for each graph.

    Args:
        num_graphs: Number of graphs to generate; values <= 0 give an empty
            list.
        min_nodes, max_nodes, min_edges, max_edges, num_relations, M, N,
        num_classes, relation_mode: Passed positionally to ``create_graph``.
        p_m: Passed to ``create_graph`` as its ``p_m`` keyword. The default
            matches ``create_graph``'s own default, so existing callers see
            no change.

    Returns:
        A list with one ``create_graph`` result per requested graph.
    """
    return [
        create_graph(
            min_nodes,
            max_nodes,
            min_edges,
            max_edges,
            num_relations,
            M,
            N,
            num_classes,
            relation_mode,
            p_m=p_m,
        )
        for _ in range(num_graphs)
    ]

def train_rgcn(
    dataset,
    num_relations,
    batch_size,
    num_epochs,
    M,
    N,
    device,
    num_classes,
    relation_mode,
    hidden_dim=16,
    out_dim=16,
    lr=0.01,
    log_every=100,
):
    """
    Train an ``RGCN`` node classifier on ``dataset`` and return it.

    Args:
        dataset: Collection of graphs handed to ``DataLoader``. Each batch
            must expose ``x``, ``edge_index``, ``edge_type``, ``y`` and
            ``is_m_dim``.
        num_relations: Number of base relation types in the data.
        batch_size: Number of graphs per mini-batch.
        num_epochs: Number of passes over the dataset.
        M: Width of the smaller node feature vectors.
        N: Width of the stored node feature rows.
        device: Device the model and every batch are moved to.
        num_classes: Number of node classes.
        relation_mode: Determines how many relation types the model is built
            for: ``0`` -> ``num_relations``; ``1`` -> ``num_relations + 1``
            (extra self-loop relation); anything else ->
            ``2 * num_relations + 1`` (inverse relations plus self-loop).
        hidden_dim: Output width of the first conv layer.
        out_dim: Output width of the second conv layer.
        lr: Learning rate for the Adam optimizer.
        log_every: Print loss/accuracy every this many epochs; a falsy value
            (``0`` or ``None``) disables printing.

    Returns:
        The trained model. Callers that ignore the return value are
        unaffected; previously the model was discarded when the function
        finished.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    if relation_mode == 0:
        total_relations = num_relations
    elif relation_mode == 1:
        total_relations = num_relations + 1
    else:  # relation_mode == 2
        total_relations = 2 * num_relations + 1

    model = RGCN(
        M=M,
        N=N,
        hidden_channels=hidden_dim,
        out_channels=out_dim,
        num_relations=total_relations,
        num_classes=num_classes,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        correct_predictions = 0
        total_samples = 0

        for batch in loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index, batch.edge_type, batch.is_m_dim)

            # Compute loss
            loss = loss_fn(out, batch.y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Compute accuracy. It is measured on the logits produced before
            # this step's parameter update, i.e. it is a running training
            # accuracy rather than a separate evaluation pass.
            _, predicted = torch.max(out, dim=1)  # Get predicted class indices
            correct_predictions += (predicted == batch.y).sum().item()
            total_samples += batch.y.size(0)

        # Average loss is per batch; accuracy is per node.
        avg_loss = total_loss / len(loader)
        accuracy = correct_predictions / total_samples

        if log_every and (epoch + 1) % log_every == 0:
            print(f"Mode {relation_mode}, Epoch {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

    return model

# Example usage
# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Graph-generation settings: a tiny dataset of fixed-size graphs.
num_relations = 10
num_classes = 2
num_graphs = 2
min_nodes = max_nodes = 8
min_edges = max_edges = 8

# Node feature widths.
M = 8   # smaller dimension
N = 16  # larger dimension

# Training settings.
batch_size = 128
num_epochs = 512 * 12

# One dataset per relation mode, generated in the order 0, 1, 2.
dataset_mode0, dataset_mode1, dataset_mode2 = (
    create_dataset(num_graphs, min_nodes, max_nodes, min_edges, max_edges, num_relations, M, N, num_classes, mode)
    for mode in (0, 1, 2)
)

# print("Training with only N relations:")
# train_rgcn(
#     dataset=dataset_mode0,
#     num_relations=num_relations,
#     batch_size=batch_size,
#     num_epochs=num_epochs,
#     M=M,
#     N=N,
#     device=device,
#     num_classes=num_classes,
#     relation_mode=0,
# )

# print("\nTraining with N+1 relations (self-loop added):")
# train_rgcn(
#     dataset=dataset_mode1,
#     num_relations=num_relations,
#     batch_size=batch_size,
#     num_epochs=num_epochs,
#     M=M,
#     N=N,
#     device=device,
#     num_classes=num_classes,
#     relation_mode=1,
# )

# print("\nTraining with 2N+1 relations (self-loop & inverse):")
# train_rgcn(
#     dataset=dataset_mode2,
#     num_relations=num_relations,
#     batch_size=batch_size,
#     num_epochs=num_epochs,
#     M=M,
#     N=N,
#     device=device,
#     num_classes=num_classes,
#     relation_mode=2,
# )


In [31]:
# Inspect the first graph of the mode-0 dataset (the notebook shows its repr).
dataset_mode0[0]

Data(x=[8, 16], edge_index=[2, 7], y=[8], edge_type=[7], is_m_dim=[8])

In [32]:
# Show the edge list of that graph: row 0 holds sources, row 1 holds targets.
dataset_mode0[0].edge_index

tensor([[4, 3, 4, 0, 5, 3, 7],
        [2, 7, 2, 2, 1, 5, 2]])

In [33]:
# Show the relation id of each edge, aligned with the columns of edge_index.
dataset_mode0[0].edge_type

tensor([2, 8, 6, 3, 3, 2, 1])

In [13]:
import torch
from torch_geometric.utils import add_self_loops

def add_inverse_edges(edge_index, edge_type, inv_offset):
    """
    Append a reversed copy of every edge to the graph.

    The reversed edges are placed after the original ones, in the same
    order, and each receives the relation type of its original edge shifted
    by ``inv_offset``. The input tensors are not modified.

    Args:
        edge_index (torch.Tensor): Edge index of shape [2, num_edges].
        edge_type (torch.Tensor): Relation type per edge, shape [num_edges].
        inv_offset (int): Amount added to a relation type to obtain the type
            of the corresponding inverse edge.

    Returns:
        torch.Tensor: Edge index with the inverse edges appended along dim 1.
        torch.Tensor: Edge types with the inverse edge types appended.
    """
    # Flipping along dim 0 exchanges the source and target rows.
    reversed_index = torch.flip(edge_index, dims=[0])

    # Inverse edges get their own block of relation ids.
    shifted_type = edge_type + inv_offset

    return (
        torch.cat((edge_index, reversed_index), dim=1),
        torch.cat((edge_type, shifted_type), dim=0),
    )

# Example: Add inverse edges first
edge_index = torch.tensor([
    [0, 1, 2, 2],  # Source nodes
    [1, 2, 0, 0],  # Target nodes
])
edge_type = torch.tensor([0, 1, 2, 2])  # Relation types
print("Original Edge Index:", edge_index, sep="\n")
print("Original Edge Types:", edge_type, sep="\n")
print()

# Inverse relation ids start right after the largest relation id present.
inv_offset = edge_type.max().item() + 1
edge_index, edge_type = add_inverse_edges(edge_index, edge_type, inv_offset)
print("Edge Index after Adding Inverse Edges:", edge_index, sep="\n")
print("Edge Types after Adding Inverse Edges:", edge_type, sep="\n")
print()

# Self-loops go last so that they receive the next unused relation id and
# are not themselves mirrored by the inverse-edge step.
edge_index, edge_type = add_self_loops(
    edge_index,
    edge_type,
    fill_value=edge_type.max().item() + 1,
)
print("Final Edge Index with Self-Loops:", edge_index, sep="\n")
print("Final Edge Types with Self-Loops:", edge_type, sep="\n")


Original Edge Index:
tensor([[0, 1, 2, 2],
        [1, 2, 0, 0]])
Original Edge Types:
tensor([0, 1, 2, 2])

Edge Index after Adding Inverse Edges:
tensor([[0, 1, 2, 2, 1, 2, 0, 0],
        [1, 2, 0, 0, 0, 1, 2, 2]])
Edge Types after Adding Inverse Edges:
tensor([0, 1, 2, 2, 3, 4, 5, 5])

Final Edge Index with Self-Loops:
tensor([[0, 1, 2, 2, 1, 2, 0, 0, 0, 1, 2],
        [1, 2, 0, 0, 0, 1, 2, 2, 0, 1, 2]])
Final Edge Types with Self-Loops:
tensor([0, 1, 2, 2, 3, 4, 5, 5, 6, 6, 6])


In [4]:
from datetime import datetime

# Current local time (no timezone attached) as an ISO-8601 string,
# truncated to whole seconds.
datetime.now().isoformat(timespec="seconds")

'2024-12-17T17:09:19'