In [1]:
import numpy as np
import os
import torch
import random
from networks.network import EmbeddedNet
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.nn import global_mean_pool
from preprocessing.dataset import OneDGraphDataset
from preprocessing.batching import metis_batching_torch
import matplotlib.pyplot as plt
os.environ["CUDA_VISIBLE_DEVICES"]="1"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Class to load 1D dataset

In [2]:
time_id = [str(i).zfill(3) for i in range(201)]
dataset = OneDGraphDataset(
    raw_file_dir='/data1/tam/datasets_cement',
    root_dir='/data1/tam/downloaded_datasets',
    data_names='all',
    time_id=time_id,
    transform=None,
    pre_transform=None,
    pre_filter=None,
    is_loader=True
)
print(dataset[10])
print(dataset.len)

TorchGraphData(x=[59815, 3], edge_index=[2, 59814], edge_attr=[59814, 5], pressure=[59815, 201], flowrate=[59814, 201])
41


Train/Test split

In [3]:
# train_percent = 0.8
case_name = dataset.data_names
# case_name = sorted(np.unique([name[4:7] for name in case_name]))
# print(case_name)
train_id = range(0, 31)
eval_id = range(31, 35)
val_id = range(35, 41)
print(f"Load dataset, split with {len(train_id)} training cases and {len(eval_id)} evaluating cases.")

Load dataset, split with 31 training cases and 4 evaluating cases.


In [4]:
learning_rate = 1e-7
decay = 5e-4

model = EmbeddedNet(
    input_dim_node=1, 
    input_dim_edge=5, 
    hidden_dim=128, 
    output_dim=1,  
    num_processor_layers = 10, 
    emb=False, 
    add_self_loops=True
)
model = model.to(device)
# model = torch.nn.DataParallel(model)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=decay)
# optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=decay)
print(f'Loaded model: {model}')
criterion = torch.nn.MSELoss()
# criterion = torch.nn.HuberLoss(delta=2.0)

Loaded model: EmbeddedNet(
  (mesh_graph_net): MeshGraphNet(
    (node_encoder): Sequential(
      (0): Linear(in_features=1, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (edge_encoder): Sequential(
      (0): Linear(in_features=6, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (processor): ModuleList(
      (0): ProcessorLayer()
      (1): ProcessorLayer()
      (2): ProcessorLayer()
      (3): ProcessorLayer()
      (4): ProcessorLayer()
      (5): ProcessorLayer()
      (6): ProcessorLayer()
      (7): ProcessorLayer()
      (8): ProcessorLayer()
      (9): ProcessorLayer()
    )
    (node_decoder): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linea

Train/Test function

In [5]:
def train(x, f_prev, f, p_prev, p, edge_index, edge_attr):
    model.train()
    optimizer.zero_grad()
    # print(x.size(), f_prev.size(), p_prev.size(), edge_attr.size(), edge_index.size())
    print(x, f_prev, p_prev, edge_attr, edge_index)
    p_pred, f_pred = model(x.float(), f_prev.float(), p_prev.float(), edge_index, edge_attr.float())
    loss = criterion(f_pred, f.float())  + criterion(p_pred, p.float())
    loss.backward()
    optimizer.step()
    return loss.item()

def eval(x, f_prev, f, p_prev, p, edge_index, edge_attr):
    model.eval()
    with torch.no_grad():
        p_pred, f_pred = model(x.float(), f_prev.float(), p_prev.float(), edge_index, edge_attr.float())
        loss = criterion(f_pred, f.float()) + criterion(p_pred, p.float())
    return loss.item()

Normalize dataset

In [8]:
# data = []
# for i in range(dataset.len):
#     data += metis_batching_torch(dataset[i], relative_batch_size=100, recursive=True)
#     print(f'Finish data number {i}.')
# train_id = range(0, 15000)
# eval_id = range(15000, 24000)
# for i in range(len(data)):
#     data[i] = data[i].to(device)
data

[TorchGraphData(x=[106, 3], edge_index=[2, 103], edge_attr=[103, 5], pressure=[106, 201], flowrate=[103, 201]),
 TorchGraphData(x=[100, 3], edge_index=[2, 99], edge_attr=[99, 5], pressure=[100, 201], flowrate=[99, 201]),
 TorchGraphData(x=[104, 3], edge_index=[2, 103], edge_attr=[103, 5], pressure=[104, 201], flowrate=[103, 201]),
 TorchGraphData(x=[102, 3], edge_index=[2, 99], edge_attr=[99, 5], pressure=[102, 201], flowrate=[99, 201]),
 TorchGraphData(x=[108, 3], edge_index=[2, 106], edge_attr=[106, 5], pressure=[108, 201], flowrate=[106, 201]),
 TorchGraphData(x=[104, 3], edge_index=[2, 102], edge_attr=[102, 5], pressure=[104, 201], flowrate=[102, 201]),
 TorchGraphData(x=[100, 3], edge_index=[2, 99], edge_attr=[99, 5], pressure=[100, 201], flowrate=[99, 201]),
 TorchGraphData(x=[114, 3], edge_index=[2, 111], edge_attr=[111, 5], pressure=[114, 201], flowrate=[111, 201]),
 TorchGraphData(x=[100, 3], edge_index=[2, 98], edge_attr=[98, 5], pressure=[100, 201], flowrate=[98, 201]),
 Tor

Training

In [7]:
CUDA_LAUNCH_BLOCKING=1
torch.cuda.empty_cache()
print('Start training.')
train_loss_all = []
eval_loss_all = []
for epoch in range(100):
    train_loss = 0.0
    for i in train_id:
        x = data[i].x #.to(device)
        edge_index = data[i].edge_index #.to(device)
        edge_attr = data[i].edge_attr #.to(device)
        for time in range(1, len(time_id)):
            p_prev = data[i].pressure[:,time - 1].unsqueeze(1) #.to(device)
            p = data[i].pressure[:,time].unsqueeze(1) #.to(device)
            f_prev = data[i].flowrate[:,time - 1].unsqueeze(1) #.to(device)
            f = data[i].flowrate[:,time].unsqueeze(1) #.to(device)
            # print(f_prev.size(), edge_attr.size())
            train_loss += train(x, f_prev, f, p_prev, p, edge_index, edge_attr)
    train_loss /= len(train_id)
    train_loss_all.append(train_loss)
    
    eval_loss = 0.0
    for i in eval_id:
        x_eval = data[i].x
        edge_index_eval = data[i].edge_index
        edge_attr_eval = data[i].edge_attr
        for time in range(1, len(time_id)):
            p_prev_eval = data[i].pressure[:,time - 1].unsqueeze(1)
            p_eval = data[i].pressure[:,time].unsqueeze(1)
            f_prev_eval = data[i].flowrate[:,time - 1].unsqueeze(1)
            f_eval = data[i].flowrate[:,time].unsqueeze(1)
            eval_loss += eval(x_eval, f_prev_eval, f_eval, p_prev_eval, p_eval,\
                        edge_index_eval, edge_attr_eval)
    eval_loss /= len(eval_id)
    eval_loss_all.append(eval_loss)
    
    if epoch % 10 == 9:
        print(f'Epoch {epoch}: train loss={train_loss}; eval loss={eval_loss}')

Start training.


KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), 'models/mgn_v3.pth')
plt.plot(train_loss_all, label='train_loss')
plt.plot(eval_loss_all, label='eval_loss')
plt.xlabel('Epoch')
plt.legend()
plt.show()

f = open('./loss_mgn_v3','w+')
for i in range(len(train_loss_all)):
    f.write(str(train_loss_all[i]))
    f.write('    ')
    f.write(str(eval_loss_all[i]))
    f.write('\n')
f.close()

In [None]:
torch.cuda.empty_cache()
out_dir = './predict_v3/'
model = EmbeddedNet(
    input_dim_node=4, 
    input_dim_edge=1, 
    hidden_dim=128, 
    output_dim=1,  
    num_processor_layers = 10, 
    emb=False, 
    add_self_loops=True
)
device = torch.device("cpu")
model.load_state_dict(torch.load('models/mgn_v3.pth'))
model = model.to(device)
model.eval()


def print_1D(x, edge_index, aoa, left_elem, right_elem, num_elem, case, path):
    fo = open(path+'/output_'+case+'.dat','w+')
    fo.write(f'VARIABLES="CoordinateX"\n"CoordinateY"\n"CoordinateZ"\n"flag"\n"uds-0-scalar"\n')
    fo.write(f'ZONE T= "output_iso_{case}.dat"\n')
    fo.write('STRANDID=17, SOLUTIONTIME=0\n')
    fo.write(f'Nodes={np.shape(x)[0]}, Faces={np.shape(edge_index)[1]}, Elements={num_elem}, ZONETYPE=FEPolygon\n')
    fo.write(f'DATAPACKING=BLOCK\nNumConnectedBoundaryFaces=0, TotalNumBoundaryConnections=0\nAUXDATA Time="0.000000e+00"\nDT=(SINGLE SINGLE SINGLE SINGLE SINGLE )\n')
    line_count = 0
    for j in range(np.shape(x)[1]):
        for i in range(np.shape(x)[0]):
            fo.write(f' {x[i][j]}')
            line_count += 1
            if line_count == 5:
                fo.write('\n')
                line_count = 0
        if line_count != 0:
            fo.write('\n')
        line_count = 0
    for i in range(np.shape(aoa)[0]):
        fo.write(f' {aoa[i]}')
        line_count += 1
        if line_count == 5:
            fo.write('\n')
            line_count = 0
    if line_count != 0:
        fo.write('\n')
    line_count = 0
    
    fo.write('# face nodes\n')
    for i in range(edge_index.shape[1]):
        fo.write(f' {edge_index[0][i]+1}\t{edge_index[1][i]+1}')
        line_count += 1
        if line_count == 5:
            fo.write('\n')
            line_count = 0
    if line_count != 0:
        fo.write('\n')
    line_count = 0

    fo.write('# left elements\n')
    for i in range(left_elem.shape[0]):
        fo.write(f' {left_elem[i]}')
        line_count += 1
        if line_count == 10:
            fo.write('\n')
            line_count = 0
    if line_count != 0:
        fo.write('\n')
    line_count = 0

    fo.write('# right elements\n')
    for i in range(right_elem.shape[0]):
        fo.write(f' {right_elem[i]}')
        line_count += 1
        if line_count == 10:
            fo.write('\n')
            line_count = 0
    if line_count != 0:
        fo.write('\n')
    line_count = 0
    fo.close()
import os
os.system('rm -rf {out_dir}*')
os.system('mkdir '+out_dir)
f = open('mean_aoa.dat','w+')
mean_aoa_true = []
mean_aoa_pred = []
for i in eval_id:
    x = data[i].x.to(device)
    edge_index = data[i].edge_index.to(device)
    edge_attr = data[i].edge_attr.to(device)
    aoa = model(x.float(), edge_index, edge_attr.float())
    left_elem = data[i].left_element.cpu().numpy().astype(np.int32)
    right_elem = data[i].right_element.cpu().numpy().astype(np.int32)
    aoa = reverse_normalize(aoa.squeeze(1).detach().numpy(), mean_aoa, std_aoa)
    print_1D(
        x = dataset[i].x.cpu().numpy(), 
        edge_index=dataset[i].edge_index.cpu().numpy().astype(np.int32),
        aoa= aoa,
        left_elem=left_elem,
        right_elem=right_elem,
        num_elem=dataset[i].num_elem,
        case=dataset.data_names[i], 
        path=out_dir
    )
    batch = torch.zeros(size=(x.size()[0],)).type(torch.LongTensor).to(device)
    mean_aoa_true.append(global_mean_pool(dataset[i].age_of_air.squeeze(1), batch)[0])
    mean_aoa_pred.append(global_mean_pool(torch.tensor(aoa), batch)[0])
    f.write(f'{global_mean_pool(dataset[i].age_of_air.squeeze(1), batch)[0]}\t{global_mean_pool(torch.tensor(aoa), batch)[0]}\n')
f.close()

In [None]:

for sur_id in range(0, 2):
    y_true = []
    y_pred = []
    for i in range(len(eval_id)):
        if i % 2 == sur_id:
            y_true.append(mean_aoa_true[i])
            y_pred.append(mean_aoa_pred[i])

    # print(floor_mean_aoa_pred)
    _x = range(len(y_true))
    plt.plot(_x, y_true, label='true'+str(sur_id))
    plt.plot(_x, y_pred, label='pred'+str(sur_id))
plt.legend()
plt.show()

In [None]:
# time his
i = 35
print(dataset.subject_list[i])
path = out_dir+dataset.subject_list[i]+'/'
# path_ref = './predict_2/'+dataset.subject_list[i]+'/'
node_id = 50000
time_list = [i * 4.8 / 200 for i in range(201)]
def read_value(path, node_id):
    f = open(path, 'r')
    f.readline()
    f.readline()
    f.readline()
    for i in range(node_id + 1):
        line = f.readline()
    line = line[:-1].split('\t')
    f.close()
    return float(line[-2])
value = []
value_ref = []
for time in time_id:
    value.append(read_value(path+'plt_nd_000'+time+'.dat',node_id))
    # value_ref.append(read_value(path_ref+'plt_nd_000'+time+'.dat',node_id))
y_pred = np.array(value)
# y_ref = np.array(value_ref)
y_true = dataset[i].p[node_id].numpy()
# plt.plot(time_list, y_ref, c='black', label='GCN')
plt.plot(time_list, y_pred, c='red', label='ResGCN')
plt.plot(time_list, y_true, c='blue', linestyle='dashdot', label='ground_truth')
# plt.ylim([-50,50])
plt.legend(loc='upper left')
plt.ylabel('Pressure', fontsize=20)
plt.xlabel('Time', fontsize=20)
plt.show()