In [6]:
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from utils.data_loader import FEMDataset
from model.GNN import MeshGraphNet


In [7]:
# --- Configuration -----------------------------------------------------------
# Paths are resolved relative to the current working directory, so the notebook
# must be launched from the directory that contains `data/`.
root = os.getcwd()
data_path = os.path.join(root, "data", "ball_plate_gnn_data.npz")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's global RNG before anything stochastic is created, so that the
# model's weight initialisation and the DataLoader's shuffling order are
# repeatable after a kernel restart.
SEED = 42
torch.manual_seed(SEED)

# --- Data --------------------------------------------------------------------
dataset = FEMDataset(data_path)
# Identity collate: every batch reaches the training loop as a plain Python
# list of dataset items instead of being stacked by the default collate.
loader = DataLoader(dataset, batch_size=5, shuffle=True, collate_fn=lambda batch: batch)

# --- Model -------------------------------------------------------------------
# Feature widths are read from the dataset rather than hard-coded.
# NOTE(review): assumes the last axis of `X_list` holds per-node features and
# the second axis of `edge_attr` holds per-edge features -- confirm in FEMDataset.
node_dim = dataset.X_list.shape[2]
# Fall back to a single edge feature when the dataset carries no edge attributes.
edge_dim = dataset.edge_attr.shape[1] if dataset.edge_attr is not None else 1
model = MeshGraphNet(node_dim=node_dim, edge_dim=edge_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [8]:
# Train for a fixed number of epochs and record the mean per-graph MSE of each
# epoch in `loss_history`.
#
# NOTE(review): the last recorded run of this cell ended with
# "TypeError: expected Tensor as element 1 in argument 0, but got tuple".
# Nothing in this cell builds a tensor list, so the failure presumably comes
# from code called here (the model's forward pass or the dataset) -- confirm.
epochs = 5
loss_history = []
model.train()
for epoch in range(epochs):
    total_loss, num_graphs = 0.0, 0
    for graphs in loader:
        optimizer.zero_grad()

        # Accumulate the per-graph losses so that a single backward pass and a
        # single optimizer step cover the whole batch.
        batch_loss = 0.0
        for graph in graphs:
            graph = graph.to(device)
            batch_loss = batch_loss + F.mse_loss(model(graph), graph.y)
        num_graphs += len(graphs)

        batch_loss.backward()
        optimizer.step()
        total_loss += batch_loss.item()

    # max(..., 1) guards the division when the loader yielded no graphs.
    avg_loss = total_loss / max(num_graphs, 1)
    loss_history.append(avg_loss)
    print(f"Epoch {epoch+1}/{epochs} - loss: {avg_loss:.6f}")


TypeError: expected Tensor as element 1 in argument 0, but got tuple

In [None]:
# Plot the per-epoch average training loss collected in `loss_history`.
fig, ax = plt.subplots()
# Epochs are numbered from 1 so the x-axis matches the "Epoch k/N" training log.
epoch_numbers = range(1, len(loss_history) + 1)
# "o" draws a point per epoch; "-" is a line style, not a marker, and
# matplotlib rejects it with "Unrecognized marker style".
ax.plot(epoch_numbers, loss_history, marker="o", linestyle="-")
ax.set_xlabel("Epoch")
ax.set_ylabel("Avg Loss")
ax.set_title("Training Loss")
ax.grid(True)
plt.show()


In [None]:
# Persist only the learned parameters (the state dict), written to the
# `model` directory under the working directory captured in `root`.
save_path = os.path.join(os.path.join(root, "model"), "meshgraphnet.pt")
trained_weights = model.state_dict()
torch.save(trained_weights, save_path)
print(f"Saved model to {save_path}")
