# Graph Neural Networks

In this notebook we give an example of a Graph Neural Network built on physics equations for a dynamic physical system.

The GNN is built with PyTorch, but we write the layers out explicitly instead of using `torch_geometric`.

## Spring-Mass System

This system has 3 masses connected by springs in a 1-dimensional chain. This means mass 1 is connected to mass 2, and mass 2 is connected to mass 3, but mass 1 is not directly connected to mass 3.

Our goal is to train the GNN to predict the next velocities (actually the change in velocities $dv$), given a spring-mass configuration of positions and current velocities.

## Generate training data

First we will generate some random configurations for training.
Note that these are *not* necessarily connected by time steps; these are simply random configurations and next velocities calculated from those configurations.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# ============================================
# 1. GENERATE PHYSICS DATA
# ============================================

def generate_spring_system_data(n_samples=1000, k=1.0, m=1.0, dt=0.01):
    """
    Create random snapshots of a 1D chain: mass -- spring -- mass -- spring -- mass

    There are 3 masses connected by 2 springs. For each sample we throw
    random positions and velocities and calculate the velocities one explicit
    Euler step later. Samples are independent: we are not evolving in time!

    Args:
        n_samples: number of independent samples to generate
        k: spring constant (the same for both springs)
        m: mass (the same for all three masses)
        dt: time step used for the Euler velocity update

    Returns:
        A list of n_samples dicts with the keys 'positions', 'velocities',
        'next_velocities' (tensors of shape [3]) and 'edges' (long tensor of
        shape [2, 4] holding both directions of each spring).
    """
    # The connectivity is the same for every sample, so build it only once.
    # Row 0 holds the source node of each edge, row 1 the target node.
    edge_index = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]], dtype=torch.long).T

    data = []

    for _ in range(n_samples):
        # Random positions (displacements from equilibrium).
        # There are 3 masses, but each has only an x coordinate.
        positions = torch.randn(3) * 0.5

        # Random velocities
        velocities = torch.randn(3) * 0.3

        # Calculate forces: F = -k * (x_i - x_j) for connected masses
        forces = torch.zeros(3)
        forces[0] = -k * (positions[0] - positions[1])  # spring 0-1
        forces[1] = -k * (positions[1] - positions[0]) - k * (positions[1] - positions[2])  # springs from both sides
        forces[2] = -k * (positions[2] - positions[1])  # spring 1-2

        # NOTE: to make the data more realistic, measurement noise or external
        # perturbations could be added to the forces at this point.

        # Acceleration from F = ma
        accelerations = forces / m

        # Next velocities (Euler integration)
        next_velocities = velocities + accelerations * dt

        data.append({
            'positions': positions,
            'velocities': velocities,
            'next_velocities': next_velocities,
            # Each sample gets its own copy so that samples never share storage
            'edges': edge_index.clone(),
        })

    return data

# Sanity check: draw a handful of samples and print the first one
vis_data = generate_spring_system_data(n_samples=10)
first_sample = vis_data[0]
print(first_sample)

The output is a dictionary for the first sample, containing its positions, velocities, next_velocities, and the edge list.

It is always a good idea to plot the data, so that we can see what it looks like.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Collect the position vector of every sample as a numpy array;
# the stacked result has shape (n_samples, 3)
all_positions = np.array([entry['positions'].numpy() for entry in vis_data])

# One curve per mass. The x axis is the sample index, not time:
# the samples are independent random draws.
plt.figure(figsize=(12, 6))
n_masses = all_positions.shape[1]
for i in range(n_masses):
    mass_positions = all_positions[:, i]
    plt.plot(mass_positions, label=f'Mass {i+1} Position')

plt.title('Random Positions of 3 Masses Across Samples')
plt.xlabel('Sample Index')
plt.ylabel('Position Value')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

Remember that these are completely random positions, not time evolution of positions. (They really shouldn't even be connected with a line.)

## Set up network

The GNN is very simple: just one hidden layer is enough to learn this linear physics $F=-k (x_1-x_2)$.

The key is the message passing step.

What is being learned by this network?

In [None]:
# ============================================
# 2. GRAPH NEURAL NETWORK
# ============================================

class SpringGNN(nn.Module):
    """
    A simple GNN that learns spring dynamics.

    Key idea: each node (mass) gets information from its neighbors
    through the edges (springs) to predict its next state.
    The goal is to predict what the change in velocities will be.

    One round of message passing is performed:
      1. an edge MLP turns the (source, target) node features of every
         edge into a message,
      2. the messages arriving at each node are summed,
      3. a node MLP maps (node features, summed messages) to one number
         per node.
    """

    def __init__(self, hidden_dim=32):
        """
        hidden_dim: width of the hidden layers and size of each edge message
        """
        super().__init__()

        # Node features: position + velocity = 2 features per node
        # Edge network: processes information flowing along edges
        self.edge_mlp = nn.Sequential(
            nn.Linear(4, hidden_dim),  # 2 features from source + 2 from target
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Node update network: aggregates edge messages and updates node state
        self.node_mlp = nn.Sequential(
            nn.Linear(2 + hidden_dim, hidden_dim),  # original features + aggregated messages
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # predict change in velocity
        )

    def forward(self, positions, velocities, edges):
        """
        positions: [n_nodes] - current positions
        velocities: [n_nodes] - current velocities
        edges: [2, n_edges] - edge connectivity (row 0 = source, row 1 = target)

        Returns: [n_nodes] - predicted change in velocity for each node
        """
        n_nodes = positions.shape[0]

        # Combine position and velocity as node features
        node_features = torch.stack([positions, velocities], dim=1)  # [n_nodes, 2]

        # MESSAGE PASSING STEP
        # For each edge, create a message based on source and target features
        source_nodes, target_nodes = edges[0], edges[1]
        edge_features = torch.cat(
            [node_features[source_nodes], node_features[target_nodes]], dim=1
        )  # [n_edges, 4]
        edge_messages = self.edge_mlp(edge_features)  # [n_edges, hidden_dim]

        # AGGREGATION STEP
        # Sum the messages arriving at each node in a single vectorized call.
        # new_zeros keeps the buffer on the same device and dtype as the
        # messages, so the model also works after .to(device) / .double().
        aggregated = edge_messages.new_zeros(n_nodes, edge_messages.shape[1])
        aggregated.index_add_(0, target_nodes, edge_messages)

        # UPDATE STEP
        # Combine original features with aggregated messages to predict update
        combined = torch.cat([node_features, aggregated], dim=1)  # [n_nodes, 2 + hidden_dim]

        # Drop only the trailing feature axis; a bare squeeze() would also
        # remove the node axis when the graph has a single node.
        velocity_updates = self.node_mlp(combined).squeeze(-1)  # [n_nodes]

        return velocity_updates


## Training

Now we present the random training data to the GNN for training.
The loss is the mean squared error between the predicted and the true change in velocity $dv$.

*Warning*: this is pretty slow to train, even with GPU.

In [None]:
# ============================================
# 3. TRAINING
# ============================================

# Build the training set and a held-out test set
print("Generating training data...")
train_data = generate_spring_system_data(1000)
test_data = generate_spring_system_data(200)

# Model and optimizer
model = SpringGNN(hidden_dim=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Optimize on one sample at a time (no batching)
print("\nTraining GNN...")
n_epochs = 100

for epoch in range(n_epochs):
    total_loss = 0

    for sample in train_data:
        # Regression target: the true change in velocity over one step
        true_dv = sample['next_velocities'] - sample['velocities']

        # Forward pass and mean squared error against the target
        predicted_dv = model(sample['positions'], sample['velocities'], sample['edges'])
        loss = F.mse_loss(predicted_dv, true_dv)

        # Clear old gradients, backpropagate, take one optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Report the mean per-sample loss every tenth epoch
    is_report_epoch = (epoch + 1) % 10 == 0
    if is_report_epoch:
        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {total_loss/len(train_data):.6f}")


The evaluation on the test data compares the expected $dv$ with the predicted $dv$. The result is almost perfect because it is just a linear function that is easily learned.

In [None]:
# ============================================
# 4. EVALUATION
# ============================================

print("\nEvaluating on test set...")
model.eval()

# No gradients are needed for evaluation
with torch.no_grad():
    per_sample_losses = [
        F.mse_loss(
            model(entry['positions'], entry['velocities'], entry['edges']),
            entry['next_velocities'] - entry['velocities'],
        ).item()
        for entry in test_data
    ]

# Accumulate the per-sample MSE; the mean is reported below
test_loss = sum(per_sample_losses)

print(f"Test Loss: {test_loss/len(test_data):.6f}")

Now we visualize the predictions by evolving a random sample system forward in time to see if the GNN learned its kinematic behavior.

In [None]:
# ============================================
# 5. VISUALIZE PREDICTIONS
# ============================================

# Take the first test sample and simulate forward in time
# by using the predicted change in velocities.
sample = test_data[0]

# State of the GNN-driven rollout
positions = sample['positions'].clone()
velocities = sample['velocities'].clone()

# State of the reference rollout; it starts from the same configuration
true_pos = sample['positions'].clone()
true_vel = sample['velocities'].clone()

predicted_positions = [positions.numpy().copy()]
true_positions = [true_pos.numpy().copy()]

# Simulation parameters.
# dt must match the time step used to generate the training data, because
# the network predicts the velocity change over exactly one such step.
k, m, dt = 1.0, 1.0, 0.01
n_steps = 500

with torch.no_grad():
    for step in range(n_steps):
        # GNN prediction of the velocity change for the current state
        dv_pred = model(positions, velocities, sample['edges'])

        # Feed the updated velocity back into the state. Without this
        # assignment every step would restart from the initial velocity and
        # the predicted velocity changes would never accumulate.
        velocities = velocities + dv_pred
        positions = positions + velocities * dt
        predicted_positions.append(positions.numpy().copy())

        # True physics (for comparison)
        # These are the same equations we used to generate the data
        forces = torch.zeros(3)
        forces[0] = -k * (true_pos[0] - true_pos[1])
        forces[1] = -k * (true_pos[1] - true_pos[0]) - k * (true_pos[1] - true_pos[2])
        forces[2] = -k * (true_pos[2] - true_pos[1])

        true_vel = true_vel + (forces / m) * dt
        true_pos = true_pos + true_vel * dt
        true_positions.append(true_pos.numpy().copy())

# Shape of both arrays: (n_steps + 1, 3)
predicted_positions = np.array(predicted_positions)
true_positions = np.array(true_positions)

# Plot the GNN rollout (dashed) against the reference (solid) for each mass
plt.figure(figsize=(12, 4))
for i in range(3):
    plt.plot(predicted_positions[:, i], label=f'GNN Mass {i}', linestyle='--')
    plt.plot(true_positions[:, i], label=f'True Mass {i}', alpha=0.7)

plt.xlabel('Time Step')
plt.ylabel('Position')
plt.title('GNN vs True Physics: Spring-Mass System')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Analysis: what went wrong in the GNN?

We trained the GNN to learn `dv`, the change in velocities. This is a linear function when the time steps `dt` are very small. But after many steps, the true motion curves away from this simple linear prediction. (Think, for example, of $dx = v\ dt$ but with $v=at$ changing with time.)

You can see that the GNN performs very well, until we run it for many steps and the error from the linear assumption accumulates.



# Alternative GNN to learn acceleration

In fact, the *acceleration* might be a more useful thing to learn. Let's write a GNN that will learn the acceleration of the system instead of the velocity change `dv`.

Note that this will give a true linear system, unlike the velocity changes which accumulate. Therefore we don't need a very deep GNN to learn the physics of this system.

In [None]:
# ============================================
# 1. GENERATE PHYSICS DATA WITH ACCELERATIONS
# ============================================

def generate_spring_system_acc_data(n_samples=1000, k=1.0, m=1.0):
    """
    Create random snapshots of a 1D chain: mass -- spring -- mass -- spring -- mass

    There are 3 masses connected by 2 springs. For each sample we throw
    random positions and velocities and calculate the accelerations caused
    by the spring forces. Samples are independent random configurations,
    not consecutive time steps.

    Args:
        n_samples: number of independent samples to generate
        k: spring constant (the same for both springs)
        m: mass (the same for all three masses)

    Returns:
        A list of n_samples dicts with the keys 'positions', 'velocities',
        'accelerations' (tensors of shape [3]) and 'edges' (long tensor of
        shape [2, 4] holding both directions of each spring).
    """
    # The connectivity is the same for every sample, so build it only once.
    # Row 0 holds the source node of each edge, row 1 the target node.
    edge_index = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]], dtype=torch.long).T

    data = []

    for _ in range(n_samples):
        # Random positions (displacements from equilibrium).
        # There are 3 masses, but each has only an x coordinate.
        positions = torch.randn(3) * 0.5

        # Random velocities. The acceleration target does not depend on them;
        # they are stored only because they are part of the node features.
        velocities = torch.randn(3) * 0.3

        # Calculate forces: F = -k * (x_i - x_j) for connected masses
        forces = torch.zeros(3)
        forces[0] = -k * (positions[0] - positions[1])  # spring 0-1
        forces[1] = -k * (positions[1] - positions[0]) - k * (positions[1] - positions[2])  # springs from both sides
        forces[2] = -k * (positions[2] - positions[1])  # spring 1-2

        # NOTE: to make the data more realistic, measurement noise or external
        # perturbations could be added to the forces at this point.

        # Acceleration from F = ma
        accelerations = forces / m

        data.append({
            'positions': positions,
            'velocities': velocities,
            'accelerations': accelerations,  # Target is now acceleration
            # Each sample gets its own copy so that samples never share storage
            'edges': edge_index.clone(),
        })

    return data

The main difference in the GNN is that we will now update and learn accelerations instead of velocity changes.

In [None]:
class SpringGNN_Acc(nn.Module):
    """
    A GNN that learns to predict ACCELERATION from current state.

    This is more physics-inspired: a = F/m, and F depends on the positions
    of the masses joined by each spring.

    One round of message passing is performed:
      1. an edge MLP turns the (source, target) node features of every
         edge into a message,
      2. the messages arriving at each node are summed,
      3. a node MLP maps (node features, summed messages) to one number
         per node.
    """

    def __init__(self, hidden_dim=32):  # Even 32 might be overkill!
        """
        hidden_dim: width of the hidden layers and size of each edge message
        """
        super().__init__()

        # Edge network: sees the features of both end points of an edge, so it
        # can learn the dependence on their relative position on its own
        self.edge_mlp = nn.Sequential(
            nn.Linear(4, hidden_dim),  # 2 features from source + 2 from target
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Node update network: predicts acceleration from aggregated edge info
        self.node_mlp = nn.Sequential(
            nn.Linear(2 + hidden_dim, hidden_dim),  # original features + aggregated messages
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # predict acceleration
        )

    def forward(self, positions, velocities, edges):
        """
        positions: [n_nodes] - current positions
        velocities: [n_nodes] - current velocities
        edges: [2, n_edges] - edge connectivity (row 0 = source, row 1 = target)

        Returns: [n_nodes] - predicted acceleration for each node
        """
        n_nodes = positions.shape[0]

        # Node features
        node_features = torch.stack([positions, velocities], dim=1)  # [n_nodes, 2]

        # MESSAGE PASSING
        # Each message is computed from the features of both end points
        source_nodes, target_nodes = edges[0], edges[1]
        edge_features = torch.cat(
            [node_features[source_nodes], node_features[target_nodes]], dim=1
        )  # [n_edges, 4]
        edge_messages = self.edge_mlp(edge_features)  # [n_edges, hidden_dim]

        # AGGREGATION
        # Sum the messages arriving at each node in a single vectorized call.
        # new_zeros keeps the buffer on the same device and dtype as the
        # messages, so the model also works after .to(device) / .double().
        aggregated = edge_messages.new_zeros(n_nodes, edge_messages.shape[1])
        aggregated.index_add_(0, target_nodes, edge_messages)

        # UPDATE
        combined = torch.cat([node_features, aggregated], dim=1)  # [n_nodes, 2 + hidden_dim]

        # Drop only the trailing feature axis; a bare squeeze() would also
        # remove the node axis when the graph has a single node.
        accelerations = self.node_mlp(combined).squeeze(-1)  # [n_nodes]

        return accelerations

Now the training loss has to compare the predicted accelerations with the true accelerations instead of the changes in velocity.

In [None]:
# ============================================
# 3. TRAINING
# ============================================

print("Generating training data...")
train_data = generate_spring_system_acc_data(2000)  # More data for better learning
test_data = generate_spring_system_acc_data(200)

# Fresh model and optimizer for the acceleration target
model = SpringGNN_Acc(hidden_dim=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

print("\nTraining GNN to predict acceleration...")
n_epochs = 100

for epoch in range(n_epochs):
    total_loss = 0

    for sample in train_data:
        # Mean squared error between predicted and true accelerations
        target_accel = sample['accelerations']
        model_accel = model(sample['positions'], sample['velocities'], sample['edges'])
        loss = F.mse_loss(model_accel, target_accel)

        # Clear old gradients, backpropagate, take one optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Report the mean per-sample loss every tenth epoch
    is_report_epoch = (epoch + 1) % 10 == 0
    if is_report_epoch:
        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {total_loss/len(train_data):.6f}")

# ============================================
# 4. EVALUATION
# ============================================

print("\nEvaluating on test set...")
model.eval()

# No gradients are needed for evaluation
with torch.no_grad():
    per_sample_losses = [
        F.mse_loss(
            model(entry['positions'], entry['velocities'], entry['edges']),
            entry['accelerations'],
        ).item()
        for entry in test_data
    ]

# Accumulate the per-sample MSE; the mean is reported below
test_loss = sum(per_sample_losses)

print(f"Test Loss: {test_loss/len(test_data):.6f}")


We repeat the simulation, this time with acceleration updates at each step.

In [None]:
# ============================================
# 5. LONG-TERM SIMULATION (500 steps)
# ============================================

def simulate_with_gnn(model, initial_pos, initial_vel, edges, n_steps, dt):
    """
    Roll a system forward in time with model-predicted accelerations.

    model is called as model(positions, velocities, edges) and its output is
    used as the acceleration of every node. Each step first updates the
    velocities and then moves the positions with the updated velocities.
    The input tensors are cloned and never modified.

    Returns a numpy array with n_steps + 1 rows: the initial positions
    followed by the positions after every step.
    """
    pos = initial_pos.clone()
    vel = initial_vel.clone()

    # Snapshot of the start configuration
    frames = [pos.numpy().copy()]

    with torch.no_grad():
        for _step in range(n_steps):
            # Velocity first, then position with the new velocity
            vel = vel + model(pos, vel, edges) * dt
            pos = pos + vel * dt
            frames.append(pos.numpy().copy())

    return np.stack(frames)

def simulate_true_physics(initial_pos, initial_vel, n_steps, dt, k=1.0, m=1.0):
    """
    Simulate using true physics equations instead of random data snapshots
    like the training data.

    The system is a 1D chain of equal masses in which every mass is joined to
    its direct neighbor(s) by identical springs. The chain length is taken
    from initial_pos, so the function is not limited to 3 masses.

    Args:
        initial_pos: 1-D tensor with the starting position of every mass
        initial_vel: 1-D tensor with the starting velocity of every mass
        n_steps: number of integration steps
        dt: time step
        k: spring constant (the same for all springs)
        m: mass (the same for all masses)

    Returns:
        numpy array with n_steps + 1 rows: the initial positions followed by
        the positions after every step.
    """
    positions = initial_pos.clone()
    velocities = initial_vel.clone()
    trajectory = [positions.numpy().copy()]

    for _ in range(n_steps):
        # True forces: spring j joins mass j and mass j+1 and pulls them
        # towards each other in proportion to x[j+1] - x[j]
        stretch = positions[1:] - positions[:-1]
        forces = torch.zeros_like(positions)
        forces[:-1] += k * stretch  # pull on the left end of each spring
        forces[1:] -= k * stretch   # equal and opposite pull on the right end

        # Integration: update the velocity first, then move the masses
        # with the updated velocity
        accel = forces / m
        velocities = velocities + accel * dt
        positions = positions + velocities * dt

        trajectory.append(positions.numpy().copy())

    return np.array(trajectory)

# Roll out the first test configuration with the GNN and with the true physics
sample = test_data[0]
start_pos = sample['positions']
start_vel = sample['velocities']

# dt is needed here to turn accelerations into velocity and position updates.
dt = 0.01
n_steps = 500

print("\nRunning 500-step simulation...")
gnn_trajectory = simulate_with_gnn(
    model, start_pos, start_vel, sample['edges'], n_steps, dt
)
true_trajectory = simulate_true_physics(start_pos, start_vel, n_steps, dt)
print("done!")



In [None]:
# ============================================
# 6. VISUALIZATION
# ============================================

# One figure with a single set of axes for all six curves
fig, ax = plt.subplots(figsize=(10, 6))

# GNN rollout (dashed) and reference (solid) for each of the 3 masses
for i in range(3):
    ax.plot(gnn_trajectory[:, i], label=f'GNN Mass {i}', linestyle='--', linewidth=2)
    ax.plot(true_trajectory[:, i], label=f'True Mass {i}', alpha=0.7, linewidth=2)

ax.set_xlabel('Time Step', fontsize=11)
ax.set_ylabel('Position', fontsize=11)
ax.set_title('GNN vs True Physics: 500 Time Steps', fontsize=12, fontweight='bold')
ax.legend(fontsize=10, ncol=3)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()