# Load data

In [1]:
import sys
sys.path.insert(1, '../Codes/')
import torch
from dataset import OneDDatasetBuilder, OneDDatasetLoader
from plot import *
from preprocessing import dataset_split_to_loader, cal_deriv_F
from networks_v3 import PARC_Graph

import torch_geometric
import os
# Restrict this process to GPUs 0-3. The setting only takes effect if it is in
# place before CUDA is initialised in this process.
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3"
# %env CUDA_VISIBLE_DEVICE=2
# CUDA_LAUNCH_BLOCKING is read from the process environment, so assigning a
# Python variable alone has no effect on the CUDA runtime. Export it as well;
# `setdefault` keeps any value already chosen in the shell (e.g. "0" to keep
# asynchronous launches for full-speed runs).
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
CUDA_LAUNCH_BLOCKING=1  # kept so the notebook-level name still exists
torch.cuda.empty_cache()

In [2]:
# define argument
class objectview(object):
    """Attribute-style view over a dict.

    The dict passed in becomes the instance ``__dict__`` itself (it is not
    copied), so entries added through the view are also visible in the dict.
    """

    def __init__(self, d) -> None:
        # Adopt the mapping directly as the attribute storage.
        self.__dict__ = d

    def setattr(self, attr_name, attr_value):
        """Store ``attr_value`` under key ``attr_name`` in the underlying dict."""
        vars(self)[attr_name] = attr_value

# Run configuration, exposed as attributes (args.lr, args.device, ...).
args = objectview(d=dict(
    # data params
    total_time=4.0,
    n_time=201,
    batch_size=None,
    batch_n_time=None,
    batch_step=2,
    batch_recursive=True,
    # model params
    n_field=2,
    n_meshfield=(3, 15),
    n_boundaryfield=1,
    hidden_size=256,
    unet_depth=5,
    forward_sequence=False,
    max_gen=10,
    # training params
    # second visible GPU when CUDA is available, CPU otherwise
    device=torch.device("cuda:1" if torch.cuda.is_available() else "cpu"),
    lr=2e-7,
    weight_decay=5e-3,
    epoch=2000,
    criterion=torch.nn.MSELoss(),
    n_datas_per_batch=1,
))

In [3]:
# # Build dataset
# dataset = OneDDatasetBuilder(
#     raw_dir='/data1/tam/datasets_231228',
#     root_dir='/data1/tam/downloaded_datasets_WT_v1',
#     data_names='all',
#     time_names=[str(i).zfill(3) for i in range(201)]
# )

# Load the dataset stored under the 'processed' sub-directory.
# Time names are the zero-padded strings '000' ... '200'.
dataset = OneDDatasetLoader(
    root_dir='M:/Tam/fromHung_data231228_source/download_dataset_WT_v1',
    sub_dir='processed',
    data_names='all',
    time_names=[f'{i:03d}' for i in range(201)],
)

# dataset = dataset.cut_branch(max_gen=args.max_gen)

# dataset = dataset.normalizing(
#     sub_dir='normalized',
#     scalers = {
#         'node_attr' : ['minmax_scaler', 0],
#         'edge_attr' : ['minmax_scaler', 0],
#         'pressure' : ['minmax_scaler', None],
#         'flowrate' : ['minmax_scaler', None],
#         'pressure_dot' : ['minmax_scaler', None],
#         'flowrate_dot' : ['minmax_scaler', None]
#     }
# )

# dataset = dataset.batching(
#     batch_size = args.batch_size,
#     batch_n_times = args.batch_n_time, 
#     recursive = args.batch_recursive, 
#     sub_dir='/normed_batched', 
#     step=args.batch_step
# )

In [4]:
from loss import OneDAirwayLoss
# Quick check of the project loss: instantiate it with its default arguments
# and apply it to a single item of the dataset loaded above.
criterion = OneDAirwayLoss()
data = dataset[1]  # item at index 1
loss = criterion(data)  # the criterion receives the whole data object

torch.Size([59692, 201])
tensor(0)


In [None]:
# Show the size of the first column of `loss` from the previous cell
# (assumes `loss` is a 2-D tensor-like object — TODO confirm against OneDAirwayLoss).
loss[:,0].size()

In [None]:
# Load the dataset stored under the 'normed_batched' sub-directory.
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_WT_v1',
    sub_dir='normed_batched',
    data_names='all',
    time_names=[f'{i:03d}' for i in range(201)],
)

# Items 0-29 feed the training loader, items 30-34 the test loader.
train_loader, test_loader = dataset_split_to_loader(
    dataset=dataset,
    subset_ids=dict(train=[*range(30)], test=[*range(30, 35)]),
    n_datas_per_batch=args.n_datas_per_batch,
)


In [None]:
# Build the network from the hyper-parameters in `args` and move it to the
# configured device.
model = PARC_Graph(
    activation=torch.nn.functional.relu,
    unet_depth=args.unet_depth,
    hidden_size=args.hidden_size,
    n_boundaryfield=args.n_boundaryfield,
    n_meshfield=args.n_meshfield,
    n_field=args.n_field,
).to(args.device)
# model = torch_geometric.nn.DataParallel(model)

# Plain SGD with weight decay over all model parameters.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)
# args.setattr(attr_name='optimizer', attr_value=optimizer)

# Training

In [None]:
def train(model, data, args, forward_sequence=False, optim=None):
    """Run one optimisation step on a single batch and return its loss.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(F_input=..., edge_index=..., meshfield=...,
        boundaryfield=..., forward_sequence=..., n_time=...)``; must return the
        pair ``(F_pred, F_dot_pred)``.
    data :
        Batch exposing ``pressure``, ``flowrate``, ``pressure_dot``,
        ``flowrate_dot``, ``edge_index``, ``node_attr`` and ``edge_attr``.
    args :
        Configuration object; ``args.device`` and ``args.criterion`` are read.
    forward_sequence : bool, default False
        If False only time step 0 of the fields is passed to the model,
        otherwise the full sequence is passed.
    optim : torch.optim.Optimizer, optional
        Optimizer to step. When omitted, the notebook-level ``optimizer`` is
        used, as before.

    Returns
    -------
    float
        Sum of the criterion on the predicted fields and on the predicted
        field derivatives.
    """
    # Fall back to the notebook-level optimizer only when none is passed in.
    opt = optimizer if optim is None else optim

    model.train()
    # Clear gradients left over from the previous step; without this every
    # backward() call adds onto the gradients of all earlier steps.
    opt.zero_grad()

    device = args.device

    ## Field tensor: Tensor(n_node, n_time, n_field)
    F_true = torch.cat([
        data.pressure.unsqueeze(2),
        data.flowrate.unsqueeze(2)
    ], dim=2).float().to(device)
    # Model input: the whole sequence, or only the initial state
    F = F_true if forward_sequence else F_true[:,0,:]
    F_true = F_true[:,1:,:] # timestep 1 to N

    ## Field gradient tensor
    F_dot_true = torch.cat([
        data.pressure_dot.unsqueeze(2),
        data.flowrate_dot.unsqueeze(2)
    ], dim=2).float().to(device)
    F_dot_true = F_dot_true[:,:-1,:] # timestep 0 to N-1

    ## Connectivity/edge_index: Tensor(2, n_edge)
    edge_index = data.edge_index.to(device)

    ## Mesh features: Tuple(Tensor(n_node, n_node_attr), Tensor(n_edge, n_edge_attr))
    meshfield = (
        data.node_attr.float().to(device),
        data.edge_attr.float().to(device),
    )

    ## Boundary field: the flowrate row of node 0 repeated once per node
    first_row = data.flowrate[0,:].unsqueeze(0)
    boundaryfield = torch.cat([first_row]*data.flowrate.size(0), dim=0)
    boundaryfield = boundaryfield.float().to(device)

    ## Predict output sequence
    F_pred, F_dot_pred = model(
        F_input=F,
        edge_index=edge_index,
        meshfield=meshfield,
        boundaryfield=boundaryfield,
        forward_sequence=forward_sequence,
        n_time=data.flowrate.size(1)
    )

    ## Loss on the fields plus loss on their time derivatives
    loss = args.criterion(F_pred, F_true) + args.criterion(F_dot_pred, F_dot_true)
    loss.backward()
    opt.step()

    return loss.item()

In [None]:
def eval(model, data, args, forward_sequence=False):
    """Compute the loss of ``model`` on one batch without updating it.

    The model is switched to evaluation mode for the forward pass and put back
    into whatever mode it was in afterwards, so a surrounding training loop is
    not affected. No gradients are recorded.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(F_input=..., edge_index=..., meshfield=...,
        boundaryfield=..., forward_sequence=..., n_time=...)``; must return the
        pair ``(F_pred, F_dot_pred)``.
    data :
        Batch exposing ``pressure``, ``flowrate``, ``pressure_dot``,
        ``flowrate_dot``, ``edge_index``, ``node_attr`` and ``edge_attr``.
    args :
        Configuration object; ``args.device`` and ``args.criterion`` are read.
    forward_sequence : bool, default False
        If False only time step 0 of the fields is passed to the model,
        otherwise the full sequence is passed.

    Returns
    -------
    float
        Sum of the criterion on the predicted fields and on the predicted
        field derivatives.
    """
    device = args.device

    ## Field tensor: Tensor(n_node, n_time, n_field)
    F_true = torch.cat([
        data.pressure.unsqueeze(2),
        data.flowrate.unsqueeze(2)
    ], dim=2).float().to(device)
    # Model input: the whole sequence, or only the initial state
    F = F_true if forward_sequence else F_true[:,0,:]
    F_true = F_true[:,1:,:] # timestep 1 to N

    ## Field gradient tensor
    F_dot_true = torch.cat([
        data.pressure_dot.unsqueeze(2),
        data.flowrate_dot.unsqueeze(2)
    ], dim=2).float().to(device)
    F_dot_true = F_dot_true[:,:-1,:] # timestep 0 to N-1

    ## Connectivity/edge_index: Tensor(2, n_edge)
    edge_index = data.edge_index.to(device)

    ## Mesh features: Tuple(Tensor(n_node, n_node_attr), Tensor(n_edge, n_edge_attr))
    meshfield = (
        data.node_attr.float().to(device),
        data.edge_attr.float().to(device),
    )

    ## Boundary field: the flowrate row of node 0 repeated once per node
    first_row = data.flowrate[0,:].unsqueeze(0)
    boundaryfield = torch.cat([first_row]*data.flowrate.size(0), dim=0)
    boundaryfield = boundaryfield.float().to(device)

    # Evaluate in eval mode, then restore the caller's mode even on error.
    was_training = model.training
    model.eval()
    try:
        ## Predict output sequence
        with torch.no_grad():
            F_pred, F_dot_pred = model(
                F_input=F,
                edge_index=edge_index,
                meshfield=meshfield,
                boundaryfield=boundaryfield,
                forward_sequence=forward_sequence,
                n_time=data.flowrate.size(1)
            )

            ## loss calculation
            loss = args.criterion(F_pred, F_true) + args.criterion(F_dot_pred, F_dot_true)
    finally:
        model.train(was_training)

    return loss.item()

In [None]:
# Training
# Per-epoch mean losses, appended in epoch order.
total_train_loss = []
total_eval_loss = []
# Checkpoints are written below; create the target directory up front so the
# first save does not fail after many epochs of work.
os.makedirs('models', exist_ok=True)
for epoch in range(args.epoch):
    torch.cuda.empty_cache()

    # Visit every batch of the loader once. The previous
    # `next(iter(train_loader))` created a brand-new iterator on each step, so
    # each step only ever received the first batch of a fresh iterator instead
    # of walking through the loader.
    train_loss = 0
    for data in train_loader:
        train_loss += train(model=model, data=data, args=args, forward_sequence=True)
    train_loss /= len(train_loader)
    total_train_loss.append(train_loss)

    # Same single pass over the test loader; eval() runs with its default
    # forward_sequence=False.
    eval_loss = 0
    for data in test_loader:
        eval_loss += eval(model=model, data=data, args=args)
    eval_loss /= len(test_loader)
    total_eval_loss.append(eval_loss)

    print(f'Epoch {epoch}: train loss = {train_loss}; eval loss = {eval_loss}')
    # Checkpoint the weights every 20 epochs.
    if (epoch+1) % 20 == 0:
        torch.save(model.state_dict(), f'models/parc_3_epoch{epoch+1}.pth')

In [None]:
# NOTE(review): this flag is not read anywhere in the visible notebook; the
# prediction below uses args.forward_sequence instead.
forward_sequence_validate = False
# Load to evaluate
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_WT_v1',
    sub_dir='normed_batched',
    # sub_dir='normed_and_batched',
    data_names='all',
    time_names=[str(i).zfill(3) for i in range(201)]
)

# Single sample used for the comparison plots
data = dataset[38]
# args.device = torch.device('cpu')
# prepare model
model = PARC_Graph(
    n_field=args.n_field,
    n_meshfield=args.n_meshfield,
    n_boundaryfield=args.n_boundaryfield,
    hidden_size=args.hidden_size,
    unet_depth=args.unet_depth,
    activation=torch.nn.functional.relu
)
# Map the checkpoint onto the configured device rather than a hard-coded
# 'cuda:1', so loading also works when args.device is the CPU or another GPU.
model.load_state_dict(torch.load(
    'models/parc_3_epoch640.pth',
    map_location=args.device
))
model=model.to(args.device)
# Inference only from here on: put the model in evaluation mode.
model.eval()




## Field tensor: pressure and flowrate joined along a new last axis
F_true = torch.cat(
    [channel.unsqueeze(2) for channel in (data.pressure, data.flowrate)],
    dim=2,
).float().to(args.device)
# Model input: the whole sequence, or only the initial state
F = F_true if args.forward_sequence else F_true[:,0,:]
F_true = F_true[:,1:,:]  # keep timestep 1 to N as ground truth

## Field gradient tensor, built the same way from the *_dot attributes
F_dot_true = torch.cat(
    [channel.unsqueeze(2) for channel in (data.pressure_dot, data.flowrate_dot)],
    dim=2,
).float().to(args.device)
F_dot_true = F_dot_true[:,:-1,:]  # timestep 0 to N-1

## Connectivity/edge_index: Tensor(2, n_edge)
edge_index = data.edge_index.to(args.device)

## Mesh features: Tuple(Tensor(n_node, n_node_attr), Tensor(n_edge, n_edge_attr))
node_attr = data.node_attr.float().to(args.device)
edge_attr = data.edge_attr.float().to(args.device)
meshfield = (node_attr, edge_attr)

## Boundary field: the flowrate row of node 0 repeated once per node
boundaryfield = torch.cat(
    [data.flowrate[0,:].unsqueeze(0) for _ in range(data.flowrate.size(0))],
    dim=0,
)
boundaryfield = boundaryfield.float().to(args.device)

## Predict output sequence (no gradients are recorded)
with torch.no_grad():
    F_pred, F_dot_pred = model(
        n_time=data.flowrate.size(1),
        forward_sequence=args.forward_sequence,
        boundaryfield=boundaryfield,
        meshfield=meshfield,
        edge_index=edge_index,
        F_input=F,
    )

import matplotlib.pyplot as plt

# Nodes for which prediction and ground truth are compared
node_list = [5, 10, 50, 100, 150, 200]

# Convert once instead of once per figure
pred_np = F_pred.cpu().numpy()
true_np = F_true.cpu().numpy()

# One figure per (field, node): all pressure plots (field 0) first, then all
# flowrate plots (field 1).
for i_field, field_name in enumerate(['Pressure', 'Flowrate']):
    for i_node in node_list:
        y_pred = pred_np[i_node,:,i_field]
        y_true = true_np[i_node,:,i_field]
        # Time axis: sample index scaled by 4.0 / 200
        x = [i * 4.0 /200 for i in range(len(y_pred))]
        plt.plot(x, y_pred, c='red', label='GNN Euler')
        plt.plot(x, y_true, c='blue', linestyle='dashdot', label='ground_truth')
        plt.legend(loc='upper right')
        plt.ylabel(field_name, fontsize=20)
        plt.xlabel('Time', fontsize=20)
        plt.show()