In [13]:
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler

# -------------------------
# 1. Load LiH dataset
# -------------------------
positions = np.load("eqnn_force_field_data/Positions.npy")  # shape (N, 2, 3)
forces = np.load("eqnn_force_field_data/Forces.npy")        # shape (N, 2, 3)
energy = np.load("eqnn_force_field_data/Energy.npy")        # shape (N,)

# -------------------------
# 2. Preprocessing
# -------------------------
energy = energy.reshape(-1, 1)
scaler = MinMaxScaler((-1, 1))
energy_scaled = scaler.fit_transform(energy).flatten()
forces_scaled = forces * scaler.scale_[0]
positions_centered = positions - positions[:, 0:1, :]  # place first atom at origin

# Convert to torch tensors
positions_tensor = torch.tensor(positions_centered, dtype=torch.float32)
forces_tensor = torch.tensor(forces_scaled, dtype=torch.float32)
energy_tensor = torch.tensor(energy_scaled, dtype=torch.float32).unsqueeze(-1)

# -------------------------
# 3. EQNN Model
# -------------------------
class EQNN(nn.Module):
    def __init__(self, num_atoms=2, hidden_dim=64):
        super(EQNN, self).__init__()
        self.num_atoms = num_atoms
        self.mlp = nn.Sequential(
            nn.Linear(num_atoms*(num_atoms-1)//2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, positions):
        """
        positions: (batch, num_atoms, 3) with requires_grad=True
        Returns: energy (batch,1), forces (batch,num_atoms,3)
        """
        batch_size = positions.shape[0]
        # Differentiable pairwise distances
        dists = []
        for i in range(self.num_atoms):
            for j in range(i+1, self.num_atoms):
                diff = positions[:, i:i+1, :] - positions[:, j:j+1, :]  # (batch,1,3)
                dist = torch.sqrt(torch.sum(diff**2, dim=-1, keepdim=True))  # (batch,1,1)
                dists.append(dist)
        dists = torch.cat(dists, dim=1)  # (batch, num_pairs, 1)
        dists = dists.squeeze(-1)        # (batch, num_pairs)

        # Predict energy
        energy = self.mlp(dists)

        # Compute forces: -dE/dR
        forces = -torch.autograd.grad(
            energy.sum(), positions, create_graph=True
        )[0]

        return energy, forces

# -------------------------
# 4. Training setup
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
eqnn_model = EQNN(num_atoms=positions_tensor.shape[1], hidden_dim=64).to(device)
optimizer = torch.optim.Adam(eqnn_model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
epochs = 10000
batch_size = 8
N = positions_tensor.shape[0]

# -------------------------
# 5. Training loop
# -------------------------
eqnn_model.train()
for epoch in range(epochs):
    perm = torch.randperm(N)
    total_loss = 0.0
    for i in range(0, N, batch_size):
        idx = perm[i:i+batch_size]
        pos_batch = positions_tensor[idx].to(device)
        e_batch = energy_tensor[idx].to(device)
        f_batch = forces_tensor[idx].to(device)

        # Crucial: positions must require grad for forces
        pos_batch.requires_grad_(True)

        optimizer.zero_grad()
        e_pred, f_pred = eqnn_model(pos_batch)
        loss = loss_fn(e_pred, e_batch) + loss_fn(f_pred, f_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if (epoch+1) % 200 == 0:
        print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/(N//batch_size):.6f}")

# -------------------------
# 6. Prediction / Testing
# -------------------------
eqnn_model.eval()
pos_batch = positions_tensor.to(device)
pos_batch.requires_grad_(True)  # needed for autograd forces
energy_preds, force_preds = eqnn_model(pos_batch)
energy_preds = energy_preds.detach().cpu().numpy()
force_preds = force_preds.detach().cpu().numpy()
energy_true = energy_tensor.numpy()
force_true = forces_tensor.numpy()

energy_mae = np.mean(np.abs(energy_preds - energy_true))
force_mae = np.mean(np.abs(force_preds - force_true))

print("\n✅ Prediction complete!")
print(f"Energy MAE: {energy_mae:.6f}")
print(f"Force MAE:  {force_mae:.6f}")

# Inspect first sample
print("\nFirst sample:")
print("Predicted energy:", energy_preds[0])
print("True energy:     ", energy_true[0])
print("Predicted forces:\n", force_preds[0])
print("True forces:\n", force_true[0])


Epoch 200/10000 | Loss: 323.620860
Epoch 400/10000 | Loss: 368.923889
Epoch 600/10000 | Loss: 405.270325
Epoch 800/10000 | Loss: 326.677053
Epoch 1000/10000 | Loss: 384.755849
Epoch 1200/10000 | Loss: 400.796010
Epoch 1400/10000 | Loss: 423.364889
Epoch 1600/10000 | Loss: 420.909078
Epoch 1800/10000 | Loss: 420.817973
Epoch 2000/10000 | Loss: 420.815216
Epoch 2200/10000 | Loss: 420.816488
Epoch 2400/10000 | Loss: 420.818064
Epoch 2600/10000 | Loss: 420.815165
Epoch 2800/10000 | Loss: 420.813334
Epoch 3000/10000 | Loss: 420.814860
Epoch 3200/10000 | Loss: 420.813273
Epoch 3400/10000 | Loss: 420.813304
Epoch 3600/10000 | Loss: 420.814596
Epoch 3800/10000 | Loss: 420.813273
Epoch 4000/10000 | Loss: 420.813507
Epoch 4200/10000 | Loss: 420.816050
Epoch 4400/10000 | Loss: 420.813232
Epoch 4600/10000 | Loss: 420.814290
Epoch 4800/10000 | Loss: 420.813243
Epoch 5000/10000 | Loss: 420.816294
Epoch 5200/10000 | Loss: 420.818960
Epoch 5400/10000 | Loss: 420.813375
Epoch 5600/10000 | Loss: 420.815

In [15]:
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler

# -------------------------
# 1. Load LiH dataset
# -------------------------
positions = np.load("eqnn_force_field_data/Positions.npy")  # shape (N, num_atoms, 3)
forces = np.load("eqnn_force_field_data/Forces.npy")        # shape (N, num_atoms, 3)
energy = np.load("eqnn_force_field_data/Energy.npy")        # shape (N,)

# -------------------------
# 2. Preprocessing
# -------------------------
energy = energy.reshape(-1, 1)
scaler = MinMaxScaler((-1, 1))
energy_scaled = scaler.fit_transform(energy).flatten()
forces_scaled = forces * scaler.scale_[0]
positions_centered = positions - positions[:, 0:1, :]  # place first atom at origin

# Convert to torch tensors
positions_tensor = torch.tensor(positions_centered, dtype=torch.float32)
forces_tensor = torch.tensor(forces_scaled, dtype=torch.float32)
energy_tensor = torch.tensor(energy_scaled, dtype=torch.float32).unsqueeze(-1)

# -------------------------
# 3. Vectorized EQNN Model
# -------------------------
class EQNN(nn.Module):
    def __init__(self, num_atoms=2, hidden_dim=64):
        super(EQNN, self).__init__()
        self.num_atoms = num_atoms
        self.num_pairs = num_atoms * (num_atoms - 1) // 2
        self.mlp = nn.Sequential(
            nn.Linear(self.num_pairs, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, positions):
        """
        positions: (batch, num_atoms, 3)
        returns: energy (batch,1), forces (batch,num_atoms,3)
        """
        batch_size = positions.shape[0]

        # Vectorized pairwise distances
        diff = positions.unsqueeze(2) - positions.unsqueeze(1)  # (batch, N, N, 3)
        dists = torch.sqrt(torch.sum(diff**2, dim=-1) + 1e-12)  # (batch, N, N)
        
        # Take upper triangular without diagonal (i<j)
        triu_idx = torch.triu_indices(self.num_atoms, self.num_atoms, offset=1)
        dists_pairs = dists[:, triu_idx[0], triu_idx[1]]  # (batch, num_pairs)

        # Predict energy
        energy = self.mlp(dists_pairs)

        # Compute forces
        forces = -torch.autograd.grad(
            energy.sum(), positions, create_graph=True
        )[0]

        return energy, forces

# -------------------------
# 4. Training setup
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
eqnn_model = EQNN(num_atoms=positions_tensor.shape[1], hidden_dim=64).to(device)
optimizer = torch.optim.Adam(eqnn_model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
epochs = 10000
batch_size = 8
N = positions_tensor.shape[0]

# -------------------------
# 5. Training loop
# -------------------------
eqnn_model.train()
for epoch in range(epochs):
    perm = torch.randperm(N)
    total_loss = 0.0
    for i in range(0, N, batch_size):
        idx = perm[i:i+batch_size]
        pos_batch = positions_tensor[idx].to(device)
        e_batch = energy_tensor[idx].to(device)
        f_batch = forces_tensor[idx].to(device)

        pos_batch.requires_grad_(True)  # must require grad

        optimizer.zero_grad()
        e_pred, f_pred = eqnn_model(pos_batch)
        loss = loss_fn(e_pred, e_batch) + loss_fn(f_pred, f_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if (epoch+1) % 200 == 0:
        print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/(N//batch_size):.6f}")

# -------------------------
# 6. Prediction / Testing
# -------------------------
eqnn_model.eval()
pos_batch = positions_tensor.to(device)
pos_batch.requires_grad_(True)
energy_preds, force_preds = eqnn_model(pos_batch)
energy_preds = energy_preds.detach().cpu().numpy()
force_preds = force_preds.detach().cpu().numpy()
energy_true = energy_tensor.numpy()
force_true = forces_tensor.numpy()

energy_mae = np.mean(np.abs(energy_preds - energy_true))
force_mae = np.mean(np.abs(force_preds - force_true))

print("\n✅ Prediction complete!")
print(f"Energy MAE: {energy_mae:.6f}")
print(f"Force MAE:  {force_mae:.6f}")

# Inspect first sample
print("\nFirst sample:")
print("Predicted energy:", energy_preds[0])
print("True energy:     ", energy_true[0])
print("Predicted forces:\n", force_preds[0])
print("True forces:\n", force_true[0])


Epoch 200/10000 | Loss: 329.008860
Epoch 400/10000 | Loss: 371.334086
Epoch 600/10000 | Loss: 405.563853
Epoch 800/10000 | Loss: 403.149999
Epoch 1000/10000 | Loss: 424.659047
Epoch 1200/10000 | Loss: 422.150533
Epoch 1400/10000 | Loss: 422.983256
Epoch 1600/10000 | Loss: 421.623027
Epoch 1800/10000 | Loss: 421.076742
Epoch 2000/10000 | Loss: 420.889079
Epoch 2200/10000 | Loss: 420.833344
Epoch 2400/10000 | Loss: 420.823558
Epoch 2600/10000 | Loss: 420.814006
Epoch 2800/10000 | Loss: 420.831919
Epoch 3000/10000 | Loss: 420.817322
Epoch 3200/10000 | Loss: 420.815725
Epoch 3400/10000 | Loss: 420.814494
Epoch 3600/10000 | Loss: 420.827596
Epoch 3800/10000 | Loss: 420.819346
Epoch 4000/10000 | Loss: 420.815603
Epoch 4200/10000 | Loss: 420.813756
Epoch 4400/10000 | Loss: 420.821788
Epoch 4600/10000 | Loss: 420.816223
Epoch 4800/10000 | Loss: 420.834554
Epoch 5000/10000 | Loss: 420.819631
Epoch 5200/10000 | Loss: 420.815104
Epoch 5400/10000 | Loss: 420.814107
Epoch 5600/10000 | Loss: 420.833