In [1]:
import pandas as pd
import numpy as np

In [3]:
# Import data and important features for graph construction
data = pd.read_pickle('../data/connectivity_compliance_matrices.pkl')
data = data.rename(columns={'œÅ': 'rho'})
display(data.head())
print(f"Data shape: {data.shape}")

# Extract relevant features
X = data[['rho', 'connectivity_matrix', 'compliance_matrix']]

Unnamed: 0,rho,connectivity_matrix,compliance_matrix,E1,E2,E3,G23,G13,G12,nu12,nu13,nu23,nu21,nu31,nu32,BV,GV,BR,GR,AU
0,0.3,"[[0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,...","[[16.2442166588418, -4.390589663678543, -4.392...",0.06156,0.061524,0.0615,0.025668,0.025591,0.025632,0.270286,0.270381,0.270225,0.270126,0.270115,0.27012,0.044626,0.025066,0.044626,0.025047,0.003852
1,0.3,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[22.194443522542244, -7.184492662134119, -7.1...",0.045056,0.045102,0.045165,0.026453,0.026486,0.026497,0.323707,0.322935,0.322292,0.324034,0.323716,0.322746,0.042532,0.022705,0.042532,0.021679,0.236657
2,0.3,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[21.44865815819801, -6.467827160497426, -6.44...",0.046623,0.046639,0.046645,0.024257,0.024269,0.024219,0.301549,0.300551,0.30127,0.301656,0.300691,0.301303,0.039092,0.021717,0.039092,0.021247,0.110572
3,0.3,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[15.247634909853652, -3.931170890685161, -3.9...",0.065584,0.065573,0.065544,0.027258,0.027189,0.027245,0.257822,0.25777,0.256933,0.257779,0.257614,0.25682,0.045055,0.026767,0.045055,0.026755,0.002277
4,0.3,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[23.16273271145176, -5.511611078765939, -5.66...",0.043173,0.043181,0.043323,0.01994,0.019963,0.019987,0.237952,0.244505,0.239779,0.237998,0.245358,0.240568,0.027819,0.018944,0.027818,0.01886,0.022473


Data shape: (2624, 20)


In [17]:
from torch_geometric.data import Data
import torch

In [28]:
# Construct sample graph data consisting of edge index and node features (rho)
sample = X.iloc[0]
rho = sample['rho']
connectivity_matrix = sample['connectivity_matrix']
compliance_matrix = sample['compliance_matrix']
edge_index = np.array(np.nonzero(connectivity_matrix)).T
node_features = np.full(connectivity_matrix.shape[0], rho, dtype=np.float32)

sample_data = Data(x=torch.tensor(node_features, dtype=torch.float32), edge_index=torch.tensor(edge_index, dtype=torch.long), y=torch.tensor(compliance_matrix, dtype=torch.float32))
print(f"Node features:\n{sample_data.x}")
print(f"Edge index:\n{sample_data.edge_index}")
print(f"Compliance matrix:\n{sample_data.y}")

Node features:
tensor([0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000,
        0.3000, 0.3000])
Edge index:
tensor([[ 0,  1],
        [ 0,  4],
        [ 1,  0],
        [ 1,  7],
        [ 2,  3],
        [ 2, 10],
        [ 3,  2],
        [ 3,  4],
        [ 4,  0],
        [ 4,  3],
        [ 4,  6],
        [ 4, 10],
        [ 6,  4],
        [ 6,  6],
        [ 7,  1],
        [ 7,  7],
        [10,  2],
        [10,  4]])
Compliance matrix:
tensor([[16.2442, -4.3906, -4.3921,  0.0000,  0.0000,  0.0000],
        [-4.3906, 16.2539, -4.3922,  0.0000,  0.0000,  0.0000],
        [-4.3921, -4.3922, 16.2602,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000, 38.9590,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, 39.0762,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 39.0137]])


In [31]:
def construct_graph_data(row):
    rho = row['rho']
    connectivity_matrix = row['connectivity_matrix']
    compliance_matrix = row['compliance_matrix']
    
    edge_index = np.array(np.nonzero(connectivity_matrix)).T
    node_features = np.full(connectivity_matrix.shape[0], rho, dtype=np.float32)
    
    graph_data = Data(x=torch.tensor(node_features, dtype=torch.float32), 
                      edge_index=torch.tensor(edge_index, dtype=torch.long), 
                      y=torch.tensor(compliance_matrix, dtype=torch.float32))
    return graph_data

# Apply the function to the entire DataFrame to create a list of graph data objects
graph_data_list = X.apply(construct_graph_data, axis=1).tolist()
print(f"Constructed {len(graph_data_list)} graph data objects.")
print(f"Sample graph data object:\n{graph_data_list[0]}")

Constructed 2624 graph data objects.
Sample graph data object:
Data(x=[11], edge_index=[18, 2], y=[6, 6])
