Imports und Hilfsfunktionen

In [4]:
import torch 

import torch_geometric as pyg
import torch_geometric.nn as nn
import torch_geometric.transforms as T
import torch_geometric.utils as utils
import torch_geometric.data as data
import torch_geometric.datasets as datasets

import networkx as nx
import matplotlib.pyplot as plt

def visualize_graph(G, color):
    """Draw the NetworkX graph ``G`` without node labels, coloured by ``color``.

    The spring layout is seeded (seed=42), so repeated calls place the nodes identically.

    Args:
        G: A NetworkX graph.
        color: Per-node colour values passed to ``node_color`` and mapped through "Set2".
    """
    layout = nx.spring_layout(G, seed=42)

    plt.figure(figsize=(7, 7))
    # Hide the axis ticks: node positions carry no numeric meaning.
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(
        G,
        pos=layout,
        with_labels=False,
        node_color=color,
        cmap="Set2",
    )
    plt.show()

def visualize_embedding(h, color, epoch=None, loss=None):
    """Scatter-plot the first two columns of the node embedding ``h``.

    Args:
        h: Embedding tensor; one point is drawn per row, using columns 0 and 1.
        color: Per-point colours/labels passed to ``plt.scatter`` (mapped through "Set2").
        epoch: Optional epoch number shown in the x-axis label.
        loss: Optional loss shown in the x-axis label. May be a 0-dim tensor or a plain
            Python number. The label is only drawn when both ``epoch`` and ``loss`` are set.
    """
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    # detach() first so tensors that still require grad can be converted to NumPy.
    h = h.detach().cpu().numpy()
    plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
    if epoch is not None and loss is not None:
        # float() accepts both a single-element tensor and a plain number; .item() only the former.
        plt.xlabel(f'Epoch: {epoch}, Loss: {float(loss):.4f}', fontsize=16)
    plt.show()

In [None]:
# Example visualization: load the Karate Club graph and convert it to a NetworkX graph.
from torch_geometric.datasets import KarateClub

dataset = KarateClub()
# Take the first graph of the dataset.
# NOTE: this rebinds ``data``, which the import cell bound to the torch_geometric.data
# module; all following cells use ``data`` as this graph object.
data = dataset[0]
G = utils.to_networkx(data=data, to_undirected=True)
# Uncomment to draw the graph with nodes coloured by their label:
#visualize_graph(G, color=data.y)

In [None]:
#Example Model
class MPNN(pyg.nn.MessagePassing):
    """One message-passing layer with an MLP message function and an MLP node update.

    For every edge j -> i the layer computes ``msg_mlp(concat(x_i, x_j))``. It sums these
    messages per target node (``aggr="add"``), then computes the new node state as
    ``upd_mlp(concat(x_i, aggregated_messages))``.

    Args:
        in_dim: Number of input node features.
        hidden_dim: Size of each message, and therefore of the aggregated message.
        out_dim: Number of output features per node.

    Note:
        ``upd_mlp`` ends with a ReLU, so every output entry is >= 0. If the output is fed
        straight into a classification loss as logits, consider removing that ReLU.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__(aggr="add")  # classic sum aggregation

        # In this notebook ``nn`` is bound to torch_geometric.nn (not torch.nn), so the
        # dense building blocks are referenced through torch.nn explicitly.

        # MLP that turns a (target, source) feature pair into a message (per edge).
        self.msg_mlp = torch.nn.Sequential(
            torch.nn.Linear(2 * in_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
        )

        # MLP that combines a node's own features with its aggregated messages (per node).
        self.upd_mlp = torch.nn.Sequential(
            torch.nn.Linear(in_dim + hidden_dim, out_dim),
            torch.nn.ReLU(),
        )

    def forward(self, x, edge_index):
        """Return updated node features of shape [num_nodes, out_dim].

        Args:
            x: Node features, [num_nodes, in_dim].
            edge_index: Graph connectivity in PyG's COO format, [2, num_edges].
        """
        # propagate() runs message() -> aggregate() -> update() and returns update()'s output.
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        """Compute one message per edge.

        x_i holds the target-node features and x_j the source-node features of each edge.
        """
        msg_input = torch.cat([x_i, x_j], dim=-1)
        return self.msg_mlp(msg_input)

    def update(self, aggr_out, x):
        """Combine each node's features ``x`` with its summed messages ``aggr_out``."""
        upd_input = torch.cat([x, aggr_out], dim=-1)
        return self.upd_mlp(upd_input)


In [None]:
#Example Calling the Model
# MPNN.__init__ takes (in_dim, hidden_dim, out_dim). forward() returns ONE tensor of node
# representations with shape [num_nodes, out_dim], so there is nothing to unpack.
# out_dim = number of classes, so the output can serve directly as classifier logits.
model = MPNN(in_dim=dataset.num_features, hidden_dim=16, out_dim=dataset.num_classes)
h = model(data.x, data.edge_index)

# Call chain behind model(x, edge_index):
# torch.nn.Module.__call__()            (inherited via MessagePassing)
#   -> MPNN.forward()
#   -> MessagePassing.propagate(edge_index, size=None, **kwargs)
#        -> message()    MPNN.message(x_i, x_j)
#        -> aggregate()  MessagePassing.aggregate()   (here: sum, aggr="add")
#        -> update()     MPNN.update(aggr_out, x)
#                        If update() is not overridden, PyG's default just returns aggr_out.

In [None]:
#Training Loop Example
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.


def train(data):
    """Run one optimisation step on ``data`` and return ``(loss, h)``.

    The loss uses only the labelled nodes selected by ``data.train_mask``. If the graph
    has no such mask, all nodes are used. ``h`` is the model output for every node,
    e.g. for ``visualize_embedding``.
    """
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    # MPNN.forward returns a single tensor [num_nodes, out_dim]; it doubles as logits and h.
    h = model(data.x, data.edge_index)  # Perform a single forward pass.
    mask = getattr(data, "train_mask", None)
    if mask is None:
        loss = criterion(h, data.y)
    else:
        # Compute the loss solely based on the training nodes (semi-supervised setting).
        loss = criterion(h[mask], data.y[mask])
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss, h


for epoch in range(1000):
    loss, h = train(data)
    if epoch % 100 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
