Import dataset

In [1]:
import math
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import torch

from data.data import TorchGraphData
from data.dataset import OneDDatasetLoader, DatasetLoader
from networks.loss import WeightedMSELoss
from preprocessing.batching import merge_graphs

In [3]:
def train_eval_split(dataset : DatasetLoader, train_id : List, eval_id : List):
    """Split ``dataset`` into a training list and an evaluation list of samples.

    When the dataset is the batched variant (``dataset._sub_dir == '/batched/'``),
    ``train_id`` / ``eval_id`` are interpreted as ids of the original samples:
    every batch whose entry in ``dataset.batching_id`` is one of those ids is
    selected. Otherwise the ids are used directly as indices into ``dataset``.

    Args:
        dataset: loader supporting integer indexing (``dataset[i]``).
        train_id: sample ids to put in the training split.
        eval_id: sample ids to put in the evaluation split.

    Returns:
        Tuple ``(train_dataset, eval_dataset)`` of Python lists of samples.
    """
    if dataset._sub_dir == '/batched/':
        source_of_batch = dataset.batching_id.numpy()

        def _batch_positions(sample_ids):
            # Positions (along axis 0) of every batch cut from one of `sample_ids`.
            return list(np.nonzero(np.isin(source_of_batch, sample_ids))[0])

        train_id, eval_id = _batch_positions(train_id), _batch_positions(eval_id)

    train_samples = [dataset[k] for k in train_id]
    eval_samples = [dataset[k] for k in eval_id]
    return train_samples, eval_samples

In [4]:
# Root directory of the transformed 1D dataset, shared by the raw and batched loaders.
DATA_ROOT = '/data1/tam/downloaded_datasets_transformed'
# Ids of the original (un-batched) samples used for training and evaluation.
TRAIN_IDS = list(range(0, 20))
EVAL_IDS = list(range(20, 40))

dataset = OneDDatasetLoader(
    root_dir=DATA_ROOT
)
print('Dataset loaded.')

# Loads the pre-built batched copy from DATA_ROOT + '/batched/'. If it does not
# exist yet, build it once with:
#   batched_dataset = dataset.batching(batch_size=2000, batch_n_times=10,
#                                      recursive=True, sub_dir='/batched/')
batched_dataset = OneDDatasetLoader(
    root_dir=DATA_ROOT,
    sub_dir='/batched/'
)

train_dataset, eval_dataset = train_eval_split(
    dataset=batched_dataset,
    train_id=TRAIN_IDS,
    eval_id=EVAL_IDS
)
print('Train/eval spliting finished.')

Dataset loaded.
Train/eval spliting finished.


Train

In [5]:
import os

# CUDA reads these variables when it initialises, so they must be set before the
# first CUDA call in this kernel (PyTorch initialises CUDA lazily, not on import).
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# Must be an environment variable: a bare `CUDA_LAUNCH_BLOCKING=1` assignment only
# creates an unused Python variable. Synchronous kernel launches make CUDA error
# tracebacks point at the failing op but slow training; remove once debugging is done.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
from networks.network_recurrent import RecurrentMeshGraphNet, objectview

torch.cuda.empty_cache()

# Model params
args = objectview({
    # +1: the current pressure column is appended to the node features each step.
    'input_dim_node' : dataset[0].x.size(1)+1,
    # +1: the current velocity column is appended to the edge features each step.
    'input_dim_edge' : dataset[0].edge_attr.size(1)+1,
    'output_dim_node' : 1,
    'output_dim_edge' : 1,
    'hidden_dim' : 128,
    'n_processors' : 10,
    # Number of time steps, taken from the pressure series of the first sample.
    'n_time' : dataset[0].pressure.size(1),
    # With CUDA_VISIBLE_DEVICES="3", "cuda:0" refers to physical GPU 3.
    'device' : torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    'lr' : 1e-7,
    'weight_decay' : 1e-3,
    'epoch' : 100
})

# Model initializing
model = RecurrentMeshGraphNet(
    input_dim_node = args.input_dim_node,
    input_dim_edge = args.input_dim_edge,
    output_dim_node = args.output_dim_node,
    output_dim_edge = args.output_dim_edge,
    hidden_dim = args.hidden_dim,
    n_processors = args.n_processors
)
model = model.to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# Called as criterion(pred, target, weight) with per-node / per-edge weights.
criterion = WeightedMSELoss()

In [None]:
# Train function v2
def _rollout_loss(model, data, args, loss_fn):
    """Roll the recurrent model autoregressively over one graph and sum the loss.

    Step 1 is fed the ground-truth pressure/velocity at t=0. Every later step is
    fed the model's own (detached) prediction from the previous step. The weighted
    loss against the ground truth at each step i is summed over all steps.

    Args:
        model: RecurrentMeshGraphNet, called as ``model(x, edge_index, edge_attr, hidden)``
            and returning ``(node_out, edge_out, hidden)``.
        data: graph sample exposing x, edge_index, edge_attr, pressure, velocity,
            node_weight, edge_weight and number_of_timesteps.
        args: config object providing ``hidden_dim`` and ``device``.
        loss_fn: callable ``(pred, target, weight)`` returning a scalar tensor.

    Returns:
        Scalar tensor: loss summed over steps 1 .. number_of_timesteps - 1.

    Raises:
        ValueError: if the sample has fewer than two time steps (nothing to predict).
    """
    if data.number_of_timesteps < 2:
        raise ValueError(
            f'need at least 2 time steps to roll out, got {data.number_of_timesteps}')

    device = args.device
    # One hidden_dim-sized state per edge.
    hidden = torch.zeros(data.edge_attr.size(0), args.hidden_dim, device=device)
    edge_index = data.edge_index.to(device)
    node_weight = data.node_weight.to(device)
    edge_weight = data.edge_weight.to(device)

    pressure = data.pressure[:, 0].unsqueeze(1)
    velocity = data.velocity[:, 0].unsqueeze(1)

    loss = 0
    for i in range(1, data.number_of_timesteps):
        x = torch.cat([data.x, pressure], dim=1).to(device)
        edge_attr = torch.cat([data.edge_attr, velocity], dim=1).to(device)
        node_out, edge_out, hidden = model(x, edge_index, edge_attr, hidden)
        # Truncated BPTT: no gradient flows across steps, neither through the
        # hidden state nor through the predictions fed back as inputs.
        hidden = hidden.detach()
        pressure = node_out.detach().cpu()
        velocity = edge_out.detach().cpu()

        loss += loss_fn(node_out, data.pressure[:, i].unsqueeze(1).float().to(device), node_weight)
        loss += loss_fn(edge_out, data.velocity[:, i].unsqueeze(1).float().to(device), edge_weight)
    return loss


def train(model, data, args, opt=None, loss_fn=None):
    """Run one optimisation step on a single graph sample.

    The loss of the whole rollout is back-propagated once, so each sample
    produces exactly one optimizer step.

    Args:
        model, data, args: see ``_rollout_loss``.
        opt: optimizer to step; defaults to the notebook-level ``optimizer``.
        loss_fn: weighted loss; defaults to the notebook-level ``criterion``.

    Returns:
        float: the rollout loss summed over time steps.
    """
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn

    model.train()
    opt.zero_grad()
    loss = _rollout_loss(model, data, args, loss_fn)
    loss.backward()
    opt.step()
    return loss.item()


# Eval function
def eval(model, data, args, loss_fn=None):
    """Compute the rollout loss of one graph sample without updating the model.

    NOTE: this name shadows Python's builtin ``eval`` in the notebook namespace;
    it is kept so existing calls keep working.

    Args:
        model, data, args: see ``_rollout_loss``.
        loss_fn: weighted loss; defaults to the notebook-level ``criterion``.

    Returns:
        float: the rollout loss summed over time steps.
    """
    loss_fn = criterion if loss_fn is None else loss_fn

    model.eval()
    with torch.no_grad():
        loss = _rollout_loss(model, data, args, loss_fn)
    return loss.item()


# Training
CHECKPOINT_EVERY = 25  # epochs between intermediate checkpoints
os.makedirs('models', exist_ok=True)  # torch.save does not create missing directories

total_train_loss = []
total_eval_loss = []
for epoch in range(args.epoch):
    # Release cached, currently unused GPU memory before each epoch.
    torch.cuda.empty_cache()

    train_loss = 0
    for data in train_dataset:
        train_loss += train(model=model, data=data, args=args)
    train_loss /= len(train_dataset)
    total_train_loss.append(train_loss)

    eval_loss = 0
    for data in eval_dataset:
        eval_loss += eval(model=model, data=data, args=args)
    eval_loss /= len(eval_dataset)
    total_eval_loss.append(eval_loss)

    print(f'Epoch {epoch}: train loss = {train_loss}; eval loss = {eval_loss}')
    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        torch.save(model.state_dict(), f'models/rmgn_v2_epoch{epoch+1}.pth')
    

Epoch 0: train loss = 0.009624206320393234; eval loss = 0.003637926196808171
Epoch 1: train loss = 0.003433878563462657; eval loss = 0.003213463997502215
Epoch 2: train loss = 0.0031897206352853214; eval loss = 0.003072503704462737
Epoch 3: train loss = 0.0030830692345816726; eval loss = 0.0029965330179060746
Epoch 4: train loss = 0.003019670055787654; eval loss = 0.002949016358062797
Epoch 5: train loss = 0.002978012608626122; eval loss = 0.002915856439307858
Epoch 6: train loss = 0.0029481432022587886; eval loss = 0.002890296689925361
Epoch 7: train loss = 0.0029255742775144192; eval loss = 0.0028706905003134026
Epoch 8: train loss = 0.002907925525987666; eval loss = 0.0028551543669160582
Epoch 9: train loss = 0.0028936764326507806; eval loss = 0.002842287858166382
Epoch 10: train loss = 0.0028818009155006452; eval loss = 0.002831555619461957
Epoch 11: train loss = 0.0028718253791567183; eval loss = 0.00282247361052392
Epoch 12: train loss = 0.002863535242476524; eval loss = 0.002814

In [None]:
# Save the final weights, creating the checkpoint directory if it does not exist
# (torch.save does not create missing directories).
os.makedirs('models', exist_ok=True)
torch.save(model.state_dict(), 'models/rmgn_v2_final.pth')

Reconstruct CFD output

In [None]:
# reconstruct CFD
def print_prediction(model, data):
    """Placeholder for reconstructing / printing the CFD output of ``model`` on ``data``.

    TODO: not implemented yet; it currently does nothing and returns None.
    """
    return None

In [None]:
# Model initializing: rebuild the architecture configured in `args` and load the
# final weights written by the "Save model" cell (models/rmgn_v2_final.pth).
model = RecurrentMeshGraphNet(
    input_dim_node = args.input_dim_node,
    input_dim_edge = args.input_dim_edge,
    output_dim_node = args.output_dim_node,
    output_dim_edge = args.output_dim_edge,
    hidden_dim = args.hidden_dim,
    n_processors = args.n_processors
)
model = model.to(args.device)
# map_location lets a checkpoint written on another device load onto args.device.
model.load_state_dict(torch.load('models/rmgn_v2_final.pth', map_location=args.device))

In [None]:
# Plot prediction/ground truth
def plot_comparison(model, data, edge_id=5, t_end=4.8):
    """Plot the model's one-step-ahead velocity prediction for one edge vs. ground truth.

    Unlike the training rollout, every step here is fed the *ground-truth*
    pressure/velocity of the previous step (teacher forcing), so errors do not
    accumulate over time.

    Args:
        model: trained RecurrentMeshGraphNet, called as
            ``model(x, edge_index, edge_attr, hidden)``.
        data: graph sample exposing x, edge_index, edge_attr, pressure, velocity,
            node_weight and edge_weight.
        edge_id: row of ``data.velocity`` (edge) whose time series is plotted.
            Default 5 matches the previously hard-coded index.
        t_end: time of the last step; the time axis is evenly spaced on
            [0, t_end]. Default 4.8 matches the previously hard-coded axis.

    Returns:
        float: weighted pressure + velocity loss summed over all predicted steps.
    """
    device = args.device
    model.eval()
    hidden = torch.zeros(data.edge_attr.size(0), args.hidden_dim, device=device)
    edge_index = data.edge_index.to(device)
    node_weight = data.node_weight.to(device)
    edge_weight = data.edge_weight.to(device)
    # Use the sample's own length so prediction and ground truth always align.
    n_steps = data.velocity.size(1)

    total_loss = 0
    pred_velocity = [data.velocity[:, 0].unsqueeze(1)]
    with torch.no_grad():
        for i in range(1, n_steps):
            prev_pressure = data.pressure[:, i - 1].unsqueeze(1)
            prev_velocity = data.velocity[:, i - 1].unsqueeze(1)
            # Edge input = static edge features + velocity only, matching
            # input_dim_edge = edge_attr.size(1) + 1 used to build the model.
            x = torch.cat([data.x, prev_pressure], dim=1).to(device)
            edge_attr = torch.cat([data.edge_attr, prev_velocity], dim=1).to(device)
            node_out, edge_out, hidden = model(x, edge_index, edge_attr, hidden)

            # Same weighted loss call signature as in train/eval.
            loss = criterion(node_out, data.pressure[:, i].unsqueeze(1).float().to(device), node_weight)
            loss += criterion(edge_out, data.velocity[:, i].unsqueeze(1).float().to(device), edge_weight)
            total_loss += loss.item()
            pred_velocity.append(edge_out.cpu())
    pred_velocity = torch.cat(pred_velocity, dim=1)

    y_pred = pred_velocity[edge_id].numpy()
    y_true = data.velocity[edge_id].numpy()
    times = np.linspace(0.0, t_end, n_steps)

    fig, ax = plt.subplots()
    ax.plot(times, y_pred, c='red', label='RMGN')
    ax.plot(times, y_true, c='blue', linestyle='dashdot', label='ground_truth')
    ax.legend(loc='upper right')
    ax.set_ylabel('Velocity', fontsize=20)
    ax.set_xlabel('Time', fontsize=20)
    ax.set_title(f'Edge {edge_id}')
    plt.show()

    return total_loss
    
# NOTE(review): mean_std_dataset, normalize and _dataset are not defined or imported
# anywhere in this notebook as shown, so this cell raises NameError on a fresh kernel.
# TODO: import the normalisation helpers and define _dataset (or switch to `dataset`).
# Sample 40 is outside the train (0-19) and eval (20-39) ids used for training,
# presumably a held-out case; confirm _dataset uses the same indexing.
mean, std = mean_std_dataset(_dataset, set_id=list(range(_dataset.len())))
plot_comparison(model, normalize(_dataset[40], mean, std))