#Load data

In [1]:
import sys
sys.path.insert(1, '/data1/tam/python_graph_utilities_v2/Codes/')
import torch
from dataset import OneDDatasetBuilder, OneDDatasetLoader
from plot import *
from preprocessing import dataset_split_to_loader
from networks_v4 import RecurrentFormulationNetwork

import torch_geometric
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3"
# %env CUDA_VISIBLE_DEVICE=2
CUDA_LAUNCH_BLOCKING=1
torch.cuda.empty_cache()

In [2]:
# define argument
class objectview(object):
    def __init__(self, d) -> None:
        self.__dict__ = d
    def setattr(self, attr_name, attr_value):
        self.__dict__[attr_name] = attr_value

args = objectview(d={
    # data params
    'total_time': 4.0,
    'n_time': 201,
    'batch_size': 2000,
    'batch_n_time': None,
    'batch_step': 1,
    'batch_recursive': True,
    # model params
    'n_field': 2,
    'n_meshfield': (19, 0),
    'n_boundaryfield': 1,
    'n_globalfield': 0,
    'latent_size': 128,
    'n_latent': 10,
    'hidden_size': 256,
    'n_hidden': 5,
    'forward_sequence': False,
    # training params
    'device': torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    'lr': 5e-7,
    'weight_decay': 5e-4,
    'epoch': 100,
    'criterion': torch.nn.MSELoss(),
    'n_datas_per_batch': 1,
})

In [3]:
# # Build dataset
# dataset = OneDDatasetBuilder(
#     raw_dir='/data1/tam/datasets',
#     root_dir='/data1/tam/downloaded_datasets_new1',
#     data_names='all',
#     time_names=[str(i).zfill(3) for i in range(201)]
# )

# Load dataset
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_WT_v2',
    sub_dir='processed',
    data_names='all',
    time_names=[str(i).zfill(3) for i in range(201)]
)

dataset = dataset.normalizing(
    sub_dir='normalized_quantile',
    scalers = {
        'node_attr' : ['quantile_transformer', 0],
        # 'edge_attr' : ['minmax_scaler', 0],
        'pressure' : ['quantile_transformer', None],
        'flowrate' : ['quantile_transformer', None]
    }
)

dataset = dataset.batching(
    batch_size = args.batch_size,
    batch_n_times = args.batch_n_time, 
    recursive = args.batch_recursive, 
    sub_dir='/normalized_quantile_batched', 
    step=args.batch_step
)

In [4]:
# Load dataset
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_WT_v2',
    sub_dir='normed_batched',
    data_names='all',
    time_names=[str(i).zfill(3) for i in range(201)]
)

train_loader, test_loader = dataset_split_to_loader(
    dataset = dataset,
    subset_ids = {
        'train': list(range(0, 20)),
        'test': list(range(30, 35))
    },
    n_datas_per_batch = args.n_datas_per_batch
)

In [5]:
# prepare model
model = RecurrentFormulationNetwork(
    n_field=args.n_field,
    n_meshfield=args.n_meshfield,
    n_boundaryfield=args.n_boundaryfield,
    # n_globalfield=args.n_globalfield,
    n_hidden=args.n_hidden,
    hidden_size=args.hidden_size,
    integration=None
)
# model = torch_geometric.nn.DataParallel(model)
model = model.to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# args.setattr(attr_name='optimizer', attr_value=optimizer)

#Training

In [6]:
def train(model, data, args, forward_sequence=False):
    ## Field tensor: Tensor(n_node, n_time, n_field)
    F_true = torch.cat([
            data.pressure.unsqueeze(2),
            data.flowrate.unsqueeze(2)
        ], dim=2).float().to(args.device)
    if not forward_sequence:
        # forward only initial state
        F = F_true[:,0,:]
    else:
        # forward all timestep
        F = F_true
    F_true = F_true[:,1:,:]
    
    ## Connectivity/edge_index: Tensor(2, n_edge)
    # edge_index = data.edge_index.to(args.device)
    edge_index = torch.cat([
        data.edge_index, torch.flip(data.edge_index, dims=[0]
    )], dim=1).to(args.device)

    ## Mesh features: Tuple(Tensor(n_node, n_node_attr), Tensor(n_edge, n_edge_attr))
    node_attr = data.node_attr.float().to(args.device)
    # edge_attr = torch.cat([data.edge_attr, data.edge_attr], dim=0).float().to(args.device)
    edge_attr = None
    meshfield = (node_attr, edge_attr)

    ## Boundary field
    boundaryfield = torch.cat([data.flowrate[0,:].unsqueeze(0)]*data.flowrate.size(0), dim=0)
    boundaryfield = boundaryfield.float().to(args.device)

    # ## Global field
    # globalfield = torch.cat([data.global_attr.unsqueeze(0)]*node_attr.size(0), dim=0)
    # globalfield = globalfield.float().to(args.device)

    ## Predict output sequence
    F_pred = model(
        F = F,
        edge_index=edge_index,
        meshfield=meshfield,
        boundaryfield=boundaryfield,
        # globalfield=globalfield,
        forward_sequence=forward_sequence,
        n_time=data.flowrate.size(1)
    )
    
    ## loss calculation
    # print(F_true.size(), F_pred.size())
    loss = args.criterion(F_pred, F_true)
    loss.backward()
    optimizer.step()

    return loss.item()
# train(model, dataset[0], args, forward_sequence=True)

In [7]:
def eval(model, data, args, forward_sequence=False):
    ## Field tensor: Tensor(n_node, n_time, n_field)
    F_true = torch.cat([
            data.pressure.unsqueeze(2),
            data.flowrate.unsqueeze(2)
        ], dim=2).float().to(args.device)
    if not forward_sequence:
        # forward only initial state
        F = F_true[:,0,:]
    else:
        # forward all timestep
        F = F_true
    F_true = F_true[:,1:,:]
    
    ## Connectivity/edge_index: Tensor(2, n_edge)
    # edge_index = data.edge_index.to(args.device)
    edge_index = torch.cat([
        data.edge_index, torch.flip(data.edge_index, dims=[0]
    )], dim=1).to(args.device)

    ## Mesh features: Tuple(Tensor(n_node, n_node_attr), Tensor(n_edge, n_edge_attr))
    node_attr = data.node_attr.float().to(args.device)
    # edge_attr = data.edge_attr.float().to(args.device)
    # edge_attr = torch.cat([data.edge_attr, data.edge_attr], dim=0).float().to(args.device)
    edge_attr = None
    meshfield = (node_attr, edge_attr)

    ## Boundary field
    boundaryfield = torch.cat([data.flowrate[0,:].unsqueeze(0)]*data.flowrate.size(0), dim=0)
    boundaryfield = boundaryfield.float().to(args.device)

    # ## Global field
    # globalfield = torch.cat([data.global_attr.unsqueeze(0)]*node_attr.size(0), dim=0)
    # globalfield = globalfield.float().to(args.device)

    ## Predict output sequence
    with torch.no_grad():
        F_pred = model(
            F = F,
            edge_index=edge_index,
            meshfield=meshfield,
            boundaryfield=boundaryfield,
            # globalfield=globalfield,
            forward_sequence=forward_sequence,
            n_time=data.flowrate.size(1)
        )
        
        ## loss calculation
        loss = args.criterion(F_pred, F_true)

    return loss.item()

In [8]:
# Training
total_train_loss = []
total_eval_loss = []
# batch = enumerate(list(range(0,10)))
for epoch in range(args.epoch):
    if epoch == 10:
        args.forward_sequence = False
    torch.cuda.empty_cache()
    train_loss = 0
    for i in range(train_loader.__len__()):
        data = next(iter(train_loader))
        # print(data)
        train_loss += train(model=model, data=data, args=args, forward_sequence=True)

    # train_loss /= train_loader.__len__() # len(train_dataset)
    total_train_loss.append(train_loss)

    eval_loss = 0
    for i in range(test_loader.__len__()):
        data = next(iter(test_loader))
        eval_loss += eval(model=model, data=data, args=args)
    # eval_loss /= test_loader.__len__() #len(eval_dataset)
    total_eval_loss.append(eval_loss)
    
    print(f'Epoch {epoch}: train loss = {train_loss}; eval loss = {eval_loss}')
    if (epoch+1) % 20 == 0:
        torch.save(model.state_dict(), f'models/parc_test4_epoch{epoch+1}.pth')

Epoch 0: train loss = 0.00040585276997262554; eval loss = 0.9666224056854844
Epoch 1: train loss = 0.00012802546424950378; eval loss = 0.1745601030997932
Epoch 2: train loss = 3.555949609790332e-05; eval loss = 0.03760470013367012
Epoch 3: train loss = 0.00011399848389714862; eval loss = 0.16122134460601956
Epoch 4: train loss = 0.00011673552931767972; eval loss = 0.08398367225890979
Epoch 5: train loss = 5.510589656765319e-05; eval loss = 0.01531207085645292
Epoch 6: train loss = 2.3532425334593654e-05; eval loss = 0.024208789531257935
Epoch 7: train loss = 2.267596809346628e-05; eval loss = 0.05833310403977521
Epoch 8: train loss = 2.8815811919713497e-05; eval loss = 0.06409591823467053
Epoch 9: train loss = 2.732691957874067e-05; eval loss = 0.035488810201059096
Epoch 10: train loss = 2.110053697279568e-05; eval loss = 0.008698428922798485
Epoch 11: train loss = 1.848051065422851e-05; eval loss = 0.00359968802149524
Epoch 12: train loss = 2.032678924024367e-05; eval loss = 0.0112505

In [9]:
forward_sequence_validate = False
# Load to evaluate
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_WT_v2',
    sub_dir='normalized',
    # sub_dir='normed_and_batched',
    data_names='all',
    time_names=[str(i).zfill(3) for i in range(201)]
)

data = dataset[38]
# args.device = torch.device('cpu')
# prepare model
model = RecurrentFormulationNetwork(
    n_field=args.n_field,
    n_meshfield=args.n_meshfield,
    n_boundaryfield=args.n_boundaryfield,
    # n_globalfield=args.n_globalfield,
    n_hidden=args.n_hidden,
    hidden_size=args.hidden_size,
    integration=None
)
model.load_state_dict(torch.load(
    'models/parc_test3_epoch80.pth',
    map_location='cuda:0'
))
model=model.to(args.device)




F_true = torch.cat([
        data.pressure.unsqueeze(2),
        data.flowrate.unsqueeze(2)
    ], dim=2).float().to(args.device)
if not forward_sequence_validate:
    # forward only initial state
    F = F_true[:,0,:]
else:
    # forward all timestep
    F = F_true
F_true = F_true[:,1:,:]

## Connectivity/edge_index: Tensor(2, n_edge)
# edge_index = data.edge_index.to(args.device)
edge_index = torch.cat([
    data.edge_index, torch.flip(data.edge_index, dims=[0]
)], dim=1).to(args.device)

## Mesh features: Tuple(Tensor(n_node, n_node_attr), Tensor(n_edge, n_edge_attr))
node_attr = data.node_attr.float().to(args.device)
# edge_attr = data.edge_attr.float().to(args.device)
# edge_attr = torch.cat([data.edge_attr, data.edge_attr], dim=0).float().to(args.device)
edge_attr = None
meshfield = (node_attr, edge_attr)

## Boundary field
# boundaryfield = None
boundaryfield = torch.cat([data.flowrate[0,:].unsqueeze(0)]*data.flowrate.size(0), dim=0)
boundaryfield = boundaryfield.float().to(args.device)

## Time tensor
# time = data.time.float().to(args.device)
time = None

## Predict output sequence
with torch.no_grad():
    F_pred = model(
        F = F,
        edge_index=edge_index,
        meshfield=meshfield,
        boundaryfield=boundaryfield,
        forward_sequence=forward_sequence_validate,
        n_time=data.flowrate.size(1)
    )

node_list = [1, 100, 30000]
## Draw pressure
import matplotlib.pyplot as plt
for i_node in node_list:
    i_field = 0
    y_pred = F_pred.cpu().numpy()[i_node,:,i_field]
    y_true = F_true.cpu().numpy()[i_node,:,i_field]
    # print(y_true, y_pred)
    x = [i * 4.0 /200 for i in range(y_pred.shape[0])]
    # print(data.node_attr.numpy()[i_node, 6])
    # plt.ylim(-1,1)
    plt.plot(x, y_pred, c='red', label='GNN Euler')
    plt.plot(x, y_true, c='blue', linestyle='dashdot', label='ground_truth')
    # plt.ylim([-1,1])
    plt.legend(loc='upper right')
    plt.ylabel('Pressure', fontsize=20)
    plt.xlabel('Time', fontsize=20)
    plt.show()

## Draw flowrate
for i_node in node_list:
    i_field = 1
    y_pred = F_pred.cpu().numpy()[i_node,:,i_field]
    y_true = F_true.cpu().numpy()[i_node,:,i_field]
    x = [i * 4.0 /200 for i in range(y_pred.shape[0])]
    # print(data.node_attr.numpy()[i_node, 6])
    # plt.ylim(-1,1)
    plt.plot(x, y_pred, c='red', label='GNN Euler')
    plt.plot(x, y_true, c='blue', linestyle='dashdot', label='ground_truth')
    # plt.ylim([-1,1])
    plt.legend(loc='upper right')
    plt.ylabel('Flowrate', fontsize=20)
    plt.xlabel('Time', fontsize=20)
    plt.show()

RuntimeError: Error(s) in loading state_dict for RecurrentFormulationNetwork:
	Missing key(s) in state_dict: "pre_net.module_0.bias", "pre_net.module_0.lin.weight", "pre_net.module_1.weight", "pre_net.module_1.bias", "pre_net.module_3.bias", "pre_net.module_3.lin.weight", "net.down_convs.5.bias", "net.down_convs.5.lin.weight", "net.pools.4.weight", "net.up_convs.4.bias", "net.up_convs.4.lin.weight", "post_net.module_0.bias", "post_net.module_0.lin.weight", "post_net.module_2.bias", "post_net.module_2.lin.weight", "post_net.module_4.weight", "post_net.module_4.bias", "post_net.module_6.weight", "post_net.module_6.bias", "post_net.module_8.weight", "post_net.module_8.bias". 
	size mismatch for net.down_convs.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for net.down_convs.0.lin.weight: copying a param with shape torch.Size([128, 22]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for net.down_convs.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for net.down_convs.1.lin.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for net.down_convs.2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for net.down_convs.2.lin.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for net.down_convs.3.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for net.down_convs.3.lin.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for net.down_convs.4.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for net.down_convs.4.lin.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for net.pools.0.weight: copying a param with shape torch.Size([1, 128]) from checkpoint, the shape in current model is torch.Size([1, 256]).
	size mismatch for net.pools.1.weight: copying a param with shape torch.Size([1, 128]) from checkpoint, the shape in current model is torch.Size([1, 256]).
	size mismatch for net.pools.2.weight: copying a param with shape torch.Size([1, 128]) from checkpoint, the shape in current model is torch.Size([1, 256]).
	size mismatch for net.pools.3.weight: copying a param with shape torch.Size([1, 128]) from checkpoint, the shape in current model is torch.Size([1, 256]).
	size mismatch for net.up_convs.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for net.up_convs.0.lin.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for net.up_convs.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for net.up_convs.1.lin.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for net.up_convs.2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for net.up_convs.2.lin.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for net.up_convs.3.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for net.up_convs.3.lin.weight: copying a param with shape torch.Size([2, 256]) from checkpoint, the shape in current model is torch.Size([256, 256]).