In [1]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch_geometric.data import Dataset, Data

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class MetamatDataset(Dataset):
    """PyTorch Geometric dataset over an in-memory list of graph records.

    Every record is a mapping holding 'Nodal positions' (used as node
    features ``x``), 'Edge index' (transposed, then used as ``edge_index``)
    and 'Young' (averaged into the scalar regression target ``y``).
    """

    def __init__(self, data_list, transform=None, pre_transform=None):
        """
        Args:
            data_list (list): Graph records, each containing 'Nodal positions',
                'Edge index' and 'Young'.
            transform (callable, optional): Forwarded to the base ``Dataset``.
            pre_transform (callable, optional): Forwarded to the base ``Dataset``.
        """
        # Keep the record list on the instance before the base initializer runs.
        self.data_list = data_list
        super().__init__(None, transform, pre_transform)

    def len(self):
        """Number of graph records held by the dataset."""
        num_records = len(self.data_list)
        return num_records

    def get(self, idx):
        """
        Build the ``Data`` object for one record.

        Args:
            idx (int): Position of the record in ``data_list``.

        Returns:
            Data: Graph with ``x`` (float node positions), ``edge_index``
            (long) and ``y`` (mean of the record's 'Young' values).
        """
        record = self.data_list[idx]
        positions = torch.tensor(record['Nodal positions'], dtype=torch.float)
        # NOTE(review): assumes edges are stored as (num_edges, 2) pairs, so the
        # transpose yields the (2, num_edges) layout PyG expects — confirm.
        edges = torch.tensor(record['Edge index'], dtype=torch.long)
        edges = edges.t().contiguous()
        # Collapse all 'Young' entries of the record into one scalar target.
        target = torch.tensor(record['Young'], dtype=torch.float).mean()
        return Data(x=positions, edge_index=edges, y=target)

In [9]:
import torch.nn as nn
from torch_geometric.nn import GCNConv

class PredictorGNN(nn.Module):
    """Two GCN layers, mean pooling over nodes, then a linear head (graph-level regression)."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        """
        GNN model for predicting 'Young' based on 3D graph data.

        Args:
            input_dim (int): Dimension of input node features.
            hidden_dim (int): Dimension of hidden layers.
            output_dim (int): Dimension of output (1 for regression).
        """
        super(PredictorGNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch=None):
        """
        Forward pass of the GNN model.

        Args:
            x (Tensor): Node features, shape (num_nodes, input_dim).
            edge_index (Tensor): Edge indices.
            batch (Tensor, optional): Graph id of every node (as produced by
                PyG mini-batching). When ``None`` (default), all nodes are
                treated as a single graph.

        Returns:
            Tensor: Predicted 'Young' value(s): shape (output_dim,) when
            ``batch`` is None, otherwise (num_graphs, output_dim).
        """
        x = torch.relu(self.conv1(x, edge_index))
        x = torch.relu(self.conv2(x, edge_index))
        if batch is None:
            # Single graph: global mean over every node.
            x = torch.mean(x, dim=0)
        else:
            # Mini-batch: average nodes per graph so graphs are not mixed together.
            x = self._mean_pool(x, batch)
        return self.fc(x)

    @staticmethod
    def _mean_pool(x, batch):
        """Average rows of ``x`` grouped by the graph ids in ``batch``."""
        num_graphs = int(batch.max().item()) + 1 if batch.numel() > 0 else 0
        sums = x.new_zeros((num_graphs, x.size(-1))).index_add(0, batch, x)
        # clamp(min=1) guards graph ids with no nodes against division by zero.
        counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
        return sums / counts.unsqueeze(-1).to(x.dtype)

In [10]:
def read_pkl_file(file_path):
    """Deserialize and return the object stored in the pickle file at ``file_path``.

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code; only use this
    on files from a trusted source.
    """
    with open(file_path, mode='rb') as handle:
        return pickle.load(handle)

# Location of the pickled dataset. Set METAMAT_DATA_PATH to point elsewhere;
# the default is the original author's local path.
DATA_PATH = os.environ.get("METAMAT_DATA_PATH", "/home/wzhan24/MetaMatDiff/datacreate/data.pkl")
data = read_pkl_file(DATA_PATH)

In [11]:
# Build dataset from 'data'
dataset = MetamatDataset(data)

# Split dataset into train (60%), validation (20%) and test (remainder) sets.
# A fixed-seed generator keeps the split identical across kernel restarts.
SPLIT_SEED = 42
train_size = int(0.6 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size
assert train_size + val_size + test_size == len(dataset)

split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

Train dataset size: 10357
Validation dataset size: 3452
Test dataset size: 3453


In [None]:
# Initialize the model, optimizer, and loss function.
# Seeding before construction makes the initial layer weights reproducible.
SEED = 42
HIDDEN_DIM = 64
LEARNING_RATE = 0.001

torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# input_dim=3 presumably matches the xyz 'Nodal positions' used as node features — confirm.
model = PredictorGNN(input_dim=3, hidden_dim=HIDDEN_DIM, output_dim=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

# Training loop with validation
def train_model(model, train_dataset, val_dataset, optimizer, criterion, device, epochs=50, val_every=5):
    """Train ``model`` one graph at a time and report mean losses.

    Args:
        model: Module called as ``model(x, edge_index)``.
        train_dataset: Iterable of graphs exposing ``.x``, ``.edge_index``,
            ``.y`` and ``.to(device)``; must also support ``len()``.
        val_dataset: Same kind of iterable, evaluated under ``torch.no_grad()``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss callable ``criterion(prediction, target)``.
        device: Device each graph is moved to before the forward pass.
        epochs (int): Number of passes over ``train_dataset``.
        val_every (int): Validate after every ``val_every``-th epoch;
            values <= 0 disable validation.

    Prints the mean training loss each epoch and the mean validation loss on
    validation epochs. The model is left in training mode.
    """
    def _graph_loss(graph):
        graph = graph.to(device)
        output = model(graph.x, graph.edge_index)
        # PredictorGNN returns shape (1,) per graph while MetamatDataset targets are
        # 0-d; flatten both so MSELoss compares equal shapes instead of broadcasting.
        return criterion(output.view(-1), graph.y.view(-1))

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for graph in train_dataset:
            optimizer.zero_grad()
            loss = _graph_loss(graph)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # max(..., 1) avoids ZeroDivisionError on an empty split.
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/max(len(train_dataset), 1):.4f}")

        if val_every > 0 and (epoch + 1) % val_every == 0:
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for graph in val_dataset:
                    val_loss += _graph_loss(graph).item()
            print(f"Validation Loss after Epoch {epoch+1}: {val_loss/max(len(val_dataset), 1):.4f}")
            model.train()

# Train the model for 50 epochs on the training split, validating on val_dataset
train_model(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    epochs=50,
)