Import dataset

In [None]:
import numpy as np
import torch
from data.dataset import OneDDatasetLoader, DatasetLoader
from typing import List

def train_eval_split(dataset : DatasetLoader, train_id : List, eval_id : List):
    """Split ``dataset`` into a train list and an eval list of items.

    For a dataset loaded from the ``'/batched/'`` sub-directory, ``train_id`` and
    ``eval_id`` are matched against ``dataset.batching_id``. Every position whose
    batching id is listed is selected. Otherwise the ids are used directly as
    dataset indices.

    Args:
        dataset: dataset exposing ``_sub_dir`` and ``__getitem__``. In the
            batched case it also needs a CPU ``batching_id`` tensor with one
            entry per stored item (presumably the id of the source sample each
            batch was cut from).
        train_id: ids selecting the training items.
        eval_id: ids selecting the evaluation items.

    Returns:
        ``(train_dataset, eval_dataset)``: two lists of dataset items.
    """
    if dataset._sub_dir == '/batched/':
        batching_id = dataset.batching_id.numpy()
        # np.isin already yields a boolean mask, so no `== True` is needed.
        # tolist() gives plain Python ints rather than numpy scalars for indexing.
        train_id = np.where(np.isin(batching_id, train_id))[0].tolist()
        eval_id = np.where(np.isin(batching_id, eval_id))[0].tolist()
    train_dataset = [dataset[i] for i in train_id]
    eval_dataset = [dataset[i] for i in eval_id]
    return train_dataset, eval_dataset

In [None]:
# Per-sample dataset stored under the '/normalized/' sub-directory.
dataset = OneDDatasetLoader(
    sub_dir='/normalized/',
    root_dir='/data1/tam/downloaded_datasets_nodeattr_in',
)
print('Dataset loaded.')

# The '/batched/' sub-directory is assumed to have been written earlier with:
#   dataset.batching(batch_size=None, batch_n_times=10, recursive=True,
#                    sub_dir='/batched/')
# Here it is only reloaded from disk.
batched_dataset = OneDDatasetLoader(
    sub_dir='/batched/',
    root_dir='/data1/tam/downloaded_datasets_nodeattr_in',
)

# Sample ids 0-19 go to training and 20-39 to evaluation.
train_dataset, eval_dataset = train_eval_split(
    batched_dataset,
    list(range(20)),
    list(range(20, 40)),
)
print('Train/eval spliting finished.')

Train

In [None]:
import os
# These must be real environment variables and set before CUDA is initialised.
# A bare `CUDA_LAUNCH_BLOCKING=1` assignment only binds a Python name and has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Debug aid: synchronous kernel launches make CUDA errors point at the failing op (slower).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import torch
from networks.network_parc import PARC
from networks.network_recurrent import objectview
torch.cuda.empty_cache()

# Model params
args = objectview({
    'n_fields' : 2,             # [pressure, velocity] channels, as stacked in train/eval
    'n_timesteps' : 1,          # overwritten per sample in train/eval (n_time - 1)
    'n_hiddenfields' : 128,
    'n_meshfields' : dataset[0].node_attr.size(1),  # width of the node attribute features
    'timesteps' : 0.002,        # delta_t for the finite-difference derivative targets
    'device' : torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    'lr' : 1e-7,
    'weight_decay' : 1e-3,
    'epoch' : 100,
    'train_lambda' : 0.5        # loss weight of the derivative term; (1 - lambda) for the fields
})

# Model initializing
model = PARC(
    n_fields=args.n_fields,
    n_timesteps=args.n_timesteps,
    n_hiddenfields=args.n_hiddenfields,
    n_meshfields=args.n_meshfields
)
model = model.to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
criterion = torch.nn.MSELoss()
# criterion = WeightedMSELoss()
# criterion = WeightedMSELoss()

In [None]:
def cal_derivative(F : torch.Tensor, dim : int = -1, delta_t : float = 1.) -> torch.Tensor:
    """Forward finite-difference derivative of ``F`` along ``dim``.

    Computes ``(F[..., i, ...] - F[..., i-1, ...]) / delta_t`` for every pair of
    consecutive entries along ``dim``. The result has the same shape as ``F``
    except that axis ``dim`` is one shorter. Works for any tensor rank and any
    positive or negative ``dim``; a length-1 axis yields an empty result.

    Args:
        F: input tensor.
        dim: axis to differentiate along (default: last axis).
        delta_t: spacing between consecutive samples along ``dim``.

    Returns:
        Tensor of consecutive differences divided by ``delta_t``.
    """
    # torch.diff keeps every other axis in its original position, whatever the
    # rank or dim. A transpose(0, dim) + unsqueeze(dim) loop does not
    # (e.g. dim=-1 on a 3-D tensor).
    return torch.diff(F, dim=dim) / delta_t

# F = torch.cat([dataset[0].pressure.unsqueeze(-1),dataset[0].pressure.unsqueeze(-1)], dim=-1)
# print(F.size())
# deriv_F = cal_derivative(F, dim=1, delta_t=args.timesteps)
# print(deriv_F.size())

In [None]:
# Train function v2
def train(model, data, args):
    """Run one optimisation step of ``model`` on a single sample.

    The model is rolled out from the first time step of ``data`` for
    ``n_time - 1`` steps. The loss mixes the MSE on the predicted fields with
    the MSE on their finite-difference time derivatives, weighted by
    ``args.train_lambda``. Uses the notebook-level ``optimizer`` and ``criterion``.

    Args:
        model: PARC network; its ``n_timesteps`` attribute is overwritten.
        data: sample with 2-D ``pressure``/``velocity`` tensors (time on axis 1,
            presumably nodes on axis 0), ``node_attr`` and ``edge_index``.
        args: config providing ``device``, ``timesteps`` and ``train_lambda``.

    Returns:
        The scalar loss for this sample as a Python float.
    """
    model.train()
    # backward() accumulates into .grad, so clear gradients from the previous
    # sample before computing this one.
    optimizer.zero_grad()

    n_time = data.pressure.size(1)
    edge_index = data.edge_index.to(args.device)
    mesh_features = data.node_attr.to(args.device)
    # Initial state at t=0 with channels [pressure, velocity].
    F_initial = torch.cat([data.pressure[:,0].unsqueeze(1), data.velocity[:,0].unsqueeze(1)], dim=-1)\
                .to(args.device)
    model.n_timesteps = n_time - 1

    Fs, F_dots = model(F_initial, mesh_features, edge_index)

    # Ground truth stacked on a trailing channel axis: [pressure, velocity].
    Fs_hat = torch.cat([data.pressure.unsqueeze(-1), data.velocity.unsqueeze(-1)], dim=-1)\
                .to(args.device)
    F_dots_hat = cal_derivative(Fs_hat, dim=1, delta_t=args.timesteps)
    # Drop t=0 (the given initial state); the rollout covers the n_time - 1 later steps.
    Fs_hat = Fs_hat[:,1:,:]

    # criterion(prediction, target); MSE is symmetric, so the value is unchanged.
    loss = (1.-args.train_lambda)*criterion(Fs, Fs_hat) + \
            args.train_lambda*criterion(F_dots, F_dots_hat)

    loss.backward()
    optimizer.step()
    return loss.item()

# Eval function
def eval(model, data, args):
    """Compute the training objective of ``model`` on one sample without updating it.

    Same rollout and loss as ``train``, run under ``torch.no_grad()`` with the
    model in eval mode. The caller's train/eval mode is restored afterwards.
    Uses the notebook-level ``criterion``. Note: the name shadows the builtin
    ``eval`` in this notebook; it is kept for existing callers.

    Args:
        model: PARC network; its ``n_timesteps`` attribute is overwritten.
        data: sample with 2-D ``pressure``/``velocity`` tensors (time on axis 1),
            ``node_attr`` and ``edge_index``.
        args: config providing ``device``, ``timesteps`` and ``train_lambda``.

    Returns:
        The scalar loss for this sample as a Python float.
    """
    n_time = data.pressure.size(1)
    edge_index = data.edge_index.to(args.device)
    mesh_features = data.node_attr.to(args.device)
    F_initial = torch.cat([data.pressure[:,0].unsqueeze(1), data.velocity[:,0].unsqueeze(1)], dim=-1)\
                .to(args.device)
    model.n_timesteps = n_time - 1

    # Evaluate in eval mode (e.g. no dropout), then restore whatever mode the
    # caller had so a following training step is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            Fs, F_dots = model(F_initial, mesh_features, edge_index)

            Fs_hat = torch.cat([data.pressure.unsqueeze(-1), data.velocity.unsqueeze(-1)], dim=-1)\
                        .to(args.device)
            F_dots_hat = cal_derivative(Fs_hat, dim=1, delta_t=args.timesteps)
            # Drop t=0 (the given initial state) to align with the rollout.
            Fs_hat = Fs_hat[:,1:,:]

            loss = (1.-args.train_lambda)*criterion(Fs, Fs_hat) + \
                args.train_lambda*criterion(F_dots, F_dots_hat)
    finally:
        model.train(was_training)

    return loss.item()
# Training
total_train_loss = []
total_eval_loss = []
# torch.save does not create missing directories. Create the checkpoint folder
# up front so the first checkpoint (epoch 25) cannot fail after training has run.
os.makedirs('models', exist_ok=True)
for epoch in range(args.epoch):  # use range(1) for a quick smoke test
    torch.cuda.empty_cache()

    # Mean per-sample training loss (one optimiser step per sample).
    train_loss = 0
    for data in train_dataset:
        train_loss += train(model=model, data=data, args=args)
    train_loss /= len(train_dataset)
    total_train_loss.append(train_loss)

    # Mean per-sample evaluation loss (no parameter updates).
    eval_loss = 0
    for data in eval_dataset:
        eval_loss += eval(model=model, data=data, args=args)
    eval_loss /= len(eval_dataset)
    total_eval_loss.append(eval_loss)

    print(f'Epoch {epoch}: train loss = {train_loss}; eval loss = {eval_loss}')
    # Periodic checkpoint every 25 epochs.
    if (epoch+1) % 25 == 0:
        torch.save(model.state_dict(), f'models/parc_v2_epoch{epoch+1}.pth')
    
    

In [None]:
# Save model
# torch.save does not create parent directories, so make sure 'models/' exists.
os.makedirs('models', exist_ok=True)
torch.save(model.state_dict(), 'models/parc_v2_final.pth')

Reconstruct CFD output

In [None]:
# reconstruct CFD
def print_prediction(model, data):
    """Placeholder for printing a reconstructed CFD prediction.

    Not implemented yet: ``model`` and ``data`` are ignored and nothing is printed.
    """
    return None

In [None]:
# Model initializing
# Rebuild the PARC network trained above with the same hyper-parameters from
# `args`, then restore the final weights saved to 'models/parc_v2_final.pth'.
# The RecurrentMeshGraphNet class and the input_dim_*/hidden_dim/n_processors
# fields were never imported or defined in this notebook.
model = PARC(
    n_fields=args.n_fields,
    n_timesteps=args.n_timesteps,
    n_hiddenfields=args.n_hiddenfields,
    n_meshfields=args.n_meshfields
)
model = model.to(args.device)
# map_location lets a checkpoint written on another device load onto args.device.
model.load_state_dict(torch.load('models/parc_v2_final.pth', map_location=args.device))

In [None]:
# Plot prediction/ground truth
import matplotlib.pyplot as plt

def plot_comparison(model, data, node=5, field='velocity', total_time=4.8):
    """Roll ``model`` out on one sample and plot prediction vs. ground truth at one node.

    The PARC model is rolled out from the first time step of ``data`` in the
    same way as ``train``/``eval``. One field at node ``node`` is then plotted
    against the ground truth over a time axis spanning ``[0, total_time]``.

    Args:
        model: PARC network; its ``n_timesteps`` is overwritten and its
            train/eval mode is restored afterwards.
        data: sample with 2-D ``pressure``/``velocity`` tensors (time on axis 1),
            ``node_attr`` and ``edge_index``.
        node: index (along axis 0) of the node to plot.
        field: ``'pressure'`` or ``'velocity'``; also used as the y-axis label.
        total_time: duration covered by the sample's time axis (x-axis only).

    Returns:
        Sum over the predicted time steps of the pressure MSE plus the velocity MSE.
    """
    field_idx = {'pressure': 0, 'velocity': 1}[field]
    n_time = data.pressure.size(1)
    edge_index = data.edge_index.to(args.device)
    mesh_features = data.node_attr.to(args.device)
    # Ground truth stacked on a trailing channel axis: [pressure, velocity].
    Fs_true = torch.cat([data.pressure.unsqueeze(-1), data.velocity.unsqueeze(-1)], dim=-1)
    F_initial = Fs_true[:, 0, :].to(args.device)
    model.n_timesteps = n_time - 1

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            Fs, _ = model(F_initial, mesh_features, edge_index)
    finally:
        model.train(was_training)

    # Assumes Fs matches Fs_true[:, 1:, :] in shape, as the training loss implies.
    # Prepend the known initial state so prediction and truth share one time axis.
    Fs_pred = torch.cat([Fs_true[:, :1, :], Fs.cpu()], dim=1)

    total_loss = 0
    for i in range(1, n_time):
        total_loss += criterion(Fs_pred[:, i, 0], Fs_true[:, i, 0]).item()
        total_loss += criterion(Fs_pred[:, i, 1], Fs_true[:, i, 1]).item()

    # plot
    y_pred = Fs_pred[node, :, field_idx].numpy()
    y_true = Fs_true[node, :, field_idx].numpy()
    # One x value per stored time step (201 steps over 4.8 reproduces i * 4.8 / 200).
    x = np.linspace(0., total_time, n_time)
    plt.plot(x, y_pred, c='red', label='PARC')
    plt.plot(x, y_true, c='blue', linestyle='dashdot', label='ground_truth')
    plt.legend(loc='upper right')
    plt.ylabel(field.capitalize(), fontsize=20)
    plt.xlabel('Time', fontsize=20)
    plt.show()

    return total_loss
    
# `_dataset`, `mean_std_dataset` and `normalize` are not defined in this notebook.
# `dataset` was loaded from the '/normalized/' sub-directory and is presumably the
# same data already normalized. TODO confirm that dataset[40] matches
# normalize(_dataset[40], mean, std).
plot_comparison(model, dataset[40])