# 3. Generating data files (.pt)
This notebook generates data files in the .pt format for training Graph Neural Networks (GNNs). The process involves setting up the dataset constants, converting each NetworkX graph to PyTorch Geometric data format, marking the training and test nodes, and saving the combined datasets.
### 3.1 Setup instructions
Make sure that the `FeatureData_FakeMatl_*.csv` files exist in the `generated_microstructures` directory; they are needed to generate the graphs.
Import the necessary modules and packages:

In [13]:
import import_ipynb
from _2_Visualize_Microstructures import gen_graph, DIR_LOC

import os

import numpy as np
from torch_geometric.utils import from_networkx
import torch
from torch_geometric.loader import DataLoader

GEN_STRUCTURES_FILE_BASE = os.path.join(DIR_LOC, "generated_microstructures", "FeatureData_FakeMatl_")

# Train and Test Dataset
NUM_MICROSTRUCTURES_START = 0 # Including start
NUM_MICROSTRUCTURES_END = 200 # Not including end
TRAIN_RATIO = 0.8 # Fraction of each graph's nodes used for training (80% training, 20% test)

# Validate Dataset
NUM_MICROSTRUCTURES_VALIDATE_START = 200 # Including start
NUM_MICROSTRUCTURES_VALIDATE_END = 220 # Not including end
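
Before running the later cells, it can help to confirm that every expected input file is present. The sketch below is a hypothetical helper (`missing_structure_files` is not part of the notebook) that assumes the files follow the `<base><index>.csv` naming pattern used when the graphs are loaded; it relies only on the standard library.

```python
import os


def missing_structure_files(file_base, start, end):
    """Return the paths in [start, end) that do not exist on disk.

    Hypothetical helper: assumes files are named file_base + index + ".csv".
    """
    missing = []
    for i in range(start, end):
        path = file_base + str(i) + ".csv"
        if not os.path.isfile(path):
            missing.append(path)
    return missing


if __name__ == "__main__":
    # Example base path; in the notebook this would be GEN_STRUCTURES_FILE_BASE.
    base = os.path.join("generated_microstructures", "FeatureData_FakeMatl_")
    gaps = missing_structure_files(base, 0, 200)
    print(f"{len(gaps)} of 200 expected files are missing")
```

An empty result means every index in the range has a matching CSV file.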

### 3.2 Convert NetworkX to PyTorch Geometric
The function `network_to_pyg_data` performs the following steps:
1. Generate the network graph using the `gen_graph` function from the `_2_Visualize_Microstructures` module.
2. Convert the NetworkX graph to PyTorch Geometric format using the `from_networkx` function, which groups the listed node attributes into a single tensor, `data.x`.
3. Set the target labels `data.y` for the graph.
4. Split the nodes into training and test sets with boolean masks, based on the specified `TRAIN_RATIO`.
5. Return the processed data.

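When converting, `from_networkx` groups the node attribute `pos` into a single tensor (`data.x`) and the edge attribute `weight` into a single tensor (`data.edge_attr`). The function then sets the target labels (`data.y`) from the `surfaceFeature` node attribute and casts them to integer type.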

In [14]:
def network_to_pyg_data(file):
    G = gen_graph(file)
    pyg_graph = from_networkx(G, group_node_attrs=["pos"], group_edge_attrs=["weight"])
    pyg_graph.y = pyg_graph["surfaceFeature"]
    del pyg_graph["surfaceFeature"]
    pyg_graph.y = pyg_graph.y.type(torch.LongTensor)

    # Split the nodes: shuffle the indices, then mark the first num_train as training
    train_ratio = TRAIN_RATIO
    num_nodes = pyg_graph.x.shape[0]
    num_train = int(num_nodes * train_ratio)
    idx = list(range(num_nodes))

    np.random.shuffle(idx)  # shuffles idx in place
    train_mask = torch.full_like(pyg_graph.y, False, dtype=bool)
    train_mask[idx[:num_train]] = True
    # The remaining nodes form the test set
    test_mask = torch.full_like(pyg_graph.y, False, dtype=bool)
    test_mask[idx[num_train:]] = True

    data = pyg_graph
    data.train_mask = train_mask
    data.test_mask = test_mask

    return data
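
The masking logic can be seen in isolation with plain Python lists. The sketch below is illustrative rather than the notebook's code: `split_masks` and its `seed` argument are hypothetical, and it uses `random.Random` instead of `np.random.shuffle` so that the split is repeatable.

```python
import random


def split_masks(num_nodes, train_ratio, seed=0):
    """Return (train_mask, test_mask) as lists of booleans.

    Hypothetical helper mirroring the shuffle-then-slice split, with a
    seeded generator so repeated calls give the same assignment.
    """
    num_train = int(num_nodes * train_ratio)
    idx = list(range(num_nodes))
    random.Random(seed).shuffle(idx)

    train_mask = [False] * num_nodes
    test_mask = [False] * num_nodes
    for i in idx[:num_train]:
        train_mask[i] = True
    for i in idx[num_train:]:
        test_mask[i] = True
    return train_mask, test_mask


train_mask, test_mask = split_masks(num_nodes=10, train_ratio=0.8)
print(sum(train_mask), sum(test_mask))  # 8 training nodes, 2 test nodes
```

Every node lands in exactly one of the two masks, so the masks are complementary.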

### 3.3 Combine data
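We need to load the graphs and then combine them into a single batched object. The function `combine_data` loads the graphs into a list and then uses a `DataLoader` to merge them. Note that `next(iter(loader))` returns only the first batch, so with `batch_size=16` the result contains at most 16 graphs; set `batch_size` to the number of graphs to keep them all.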

In [15]:
def combine_data(start, end):
    data_batch = []
    for i in range(start, end):
        file = GEN_STRUCTURES_FILE_BASE + str(i) + ".csv"
        print("Loading graph " + str(i) + "...")
        data_batch.append(network_to_pyg_data(file))

    # Collate the graphs into batches and keep the first batch
    print("Combining data...")

    loader = DataLoader(data_batch, batch_size=16)
    data = next(iter(loader))
    print(data)
    return data
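
Two details of the batching step are easy to check by hand: the first batch holds at most `batch_size` graphs, and a batched object's `ptr` tensor has one more entry than the number of graphs because it stores cumulative node offsets. The sketch below reproduces that arithmetic with plain lists; `first_batch` and `build_ptr` are hypothetical helpers, and the node counts are made up for illustration.

```python
from itertools import accumulate


def first_batch(items, batch_size):
    """Return the items that would fall into the first batch."""
    return list(items)[:batch_size]


def build_ptr(node_counts):
    """Return cumulative node offsets, starting at 0 (len = graphs + 1)."""
    return [0] + list(accumulate(node_counts))


# Made-up node counts for 20 graphs (illustration only).
node_counts = [5, 7, 6, 4] * 5
batch = first_batch(node_counts, batch_size=16)
ptr = build_ptr(batch)

print(len(batch))  # 16 graphs in the first batch
print(len(ptr))    # 17 offsets
print(ptr[-1])     # total number of nodes in the batch
```

Under these assumptions, passing 20 graphs with `batch_size=16` leaves the last four out of the returned batch.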

### 3.4 Save data
After combining the data, we save it to a .pt file using the `torch.save` function. We do this for both the train/test data and the validation dataset, which is used to validate the model after training.
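
`torch.save` does not create missing parent directories, so the `datasets` folder has to exist before the save cells run. A minimal sketch, using only the standard library; `ensure_parent_dir` is a hypothetical helper name, not part of the notebook:

```python
import os


def ensure_parent_dir(path):
    """Create the parent directory of path if needed and return path."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    return path


# Usage: wrap the target path before handing it to torch.save, e.g.
# torch.save(data_train_test, ensure_parent_dir("datasets/data.pt"))
```

Returning the path unchanged lets the helper be used inline without altering the save call's other arguments.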

In [16]:
data_train_test = combine_data(NUM_MICROSTRUCTURES_START, NUM_MICROSTRUCTURES_END)
torch.save(data_train_test, "datasets/data.pt")

Loading graph 0...
Loading graph 1...
Loading graph 2...
Loading graph 3...
Loading graph 4...
Loading graph 5...
Loading graph 6...
Loading graph 7...
Loading graph 8...
Loading graph 9...
Loading graph 10...
Loading graph 11...
Loading graph 12...
Loading graph 13...
Loading graph 14...
Loading graph 15...
Loading graph 16...
Loading graph 17...
Loading graph 18...
Loading graph 19...
Loading graph 20...
Loading graph 21...
Loading graph 22...
Loading graph 23...
Loading graph 24...
Loading graph 25...
Loading graph 26...
Loading graph 27...
Loading graph 28...
Loading graph 29...
Loading graph 30...
Loading graph 31...
Loading graph 32...
Loading graph 33...
Loading graph 34...
Loading graph 35...
Loading graph 36...
Loading graph 37...
Loading graph 38...
Loading graph 39...
Loading graph 40...
Loading graph 41...
Loading graph 42...
Loading graph 43...
Loading graph 44...
Loading graph 45...
Loading graph 46...
Loading graph 47...
Loading graph 48...
Loading graph 49...
Loading gr

In [17]:
data_validate = combine_data(NUM_MICROSTRUCTURES_VALIDATE_START, NUM_MICROSTRUCTURES_VALIDATE_END)
torch.save(data_validate, "datasets/data_validate.pt")

Loading graph 200...
Loading graph 201...
Loading graph 202...
Loading graph 203...
Loading graph 204...
Loading graph 205...
Loading graph 206...
Loading graph 207...
Loading graph 208...
Loading graph 209...
Loading graph 210...
Loading graph 211...
Loading graph 212...
Loading graph 213...
Loading graph 214...
Loading graph 215...
Loading graph 216...
Loading graph 217...
Loading graph 218...
Loading graph 219...
Combining data...
DataBatch(edge_index=[2, 54550], rot=[3945, 3], size=[3945], x=[3945, 3], edge_attr=[54550, 1], y=[3945], train_mask=[3945], test_mask=[3945], batch=[3945], ptr=[17])
