In [1]:
import sys
sys.path.insert(1, '../Codes_1Dtree')
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3"

Load libraries

In [2]:
import torch
from torch_geometric.loader import DataLoader
from data.graph_dataset import OneDDatasetBuilder, OneDDatasetLoader, normalize, dataset_to_loader
# from networks.gcn import GraphUNet, RecurrentFormulationNet
from networks.gcnv2 import GraphUNetv2, RecurrentFormulationNet_hidden
import matplotlib.pyplot as plt

from data.graph_dataset import batchgraph_generation_wise, batchgraph

In [3]:
class objectview(object):
    def __init__(self, d) -> None:
        self.__dict__ = d
    def setattr(self, attr_name, attr_value):
        self.__dict__[attr_name] = attr_value

args = objectview({
    'n_field': 2,
    'n_meshfield': 16,
    'hidden_size': 128,
    'latent_size': 128,
    'aggr': 'sum',
    'act': 'mish',
    'dropout': 0.2,
    'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    'lr': 1e-6,
    'weight_decay': 5e-3,
    'n_epoch': 200,
    'alpha': 0.5,
    'timestep': None,
    'timeslice_hops': 1,
    'timeslice_steps': 1,
    'n_data_per_batch': 3,
    'criterion': torch.nn.MSELoss(),
    'plot': False
})

Load dataset

In [4]:
# dataset = OneDDatasetBuilder(
#     raw_dir='/data1/tam/datasets',
#     root_dir='/data1/tam/downloaded_datasets_Static_v1',
#     sub_dir='processed',
#     subjects='all',
#     time_names=[str(i).zfill(3) for i in range(201)],
#     data_type = torch.float32,
#     readme='edge_index(2xn_edge), node_attr(n_nodex10), pressure+flowrate(n_nodex201)'
# )
# dataset = OneDDatasetLoader(
#     root_dir='/data1/tam/downloaded_datasets_Static_v1',
#     sub_dir='processed',
#     subjects='all',
#     time_names=[str(i).zfill(3) for i in range(201)],
#     data_type = torch.float32
# )

In [5]:
# dataset = normalize(
#     dataset=dataset,
#     sub_dir='normalized',
#     scaler_dict={
#         'node_attr': ('minmax_scaler', 0, None),
#         'pressure': ('minmax_scaler', None, 0.99),
#         'flowrate': ('minmax_scaler', None, 0.99),
#         'pressure_dot': ('minmax_scaler', None, 0.99),
#         'flowrate_dot': ('minmax_scaler', None, 0.99),
#         'time': ('minmax_scaler', None, None)
#     }
# )
# dataset = OneDDatasetLoader(
#     root_dir='/data1/tam/downloaded_datasets_Static_v1',
#     sub_dir='normalized',
#     subjects='all',
#     time_names=[str(i).zfill(3) for i in range(201)],
#     data_type = torch.float32
# )

In [6]:
# dataset = batchgraph_generation_wise(
#     dataset,
#     sub_dir='batched',
#     batch_gens=[[14,15]],
#     timestep=args.timestep,
#     timeslice_hops=args.timeslice_hops,
#     timeslice_steps=args.timeslice_steps
# )
# dataset = batchgraph(
#     dataset,
#     sub_dir='batched',
#     batchsize=1000,
#     timestep=args.timestep,
#     timeslice_hops=args.timeslice_hops,
#     timeslice_steps=args.timeslice_steps
# )
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_Static_v1',
    sub_dir='batched',
    subjects='all',
    time_names=[str(i).zfill(3) for i in range(201)],
    data_type = torch.float32
)
if args.plot:
    print(dataset[0])

In [7]:
# [dataset[i] for i in range(dataset.len())]

In [8]:
# fold_size = 5
# fold = [list(range(i*fold_size, (i+1)*fold_size) for i in range(5))]
# (train_loader, test_loader) = dataset_to_loader(
#     dataset=dataset,
#     data_subset_dict={
#         'train': list(range(5, 35)),
#         'test': list(range(0, 5))
#     },
#     n_data_per_batch=args.n_data_per_batch
# )

train_set, test_set = dataset_to_loader(
    dataset=dataset,
    data_subset_dict={
        'train': list(range(6, 36)),
        'test': list(range(0, 5))
    },
    n_data_per_batch=args.n_data_per_batch
)

Model initializing and training

In [9]:
model = RecurrentFormulationNet_hidden(
    n_field=args.n_field,
    n_meshfield=args.n_meshfield,
    hidden_size=args.hidden_size,
    latent_size=args.latent_size,
    act=args.act,
    use_time_feature=True,
    dropout=args.dropout
)
setattr(model, 'name', 'PARC_GCN_UNet_full')
model = model.to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
setattr(args, 'optimizer', optimizer)

In [10]:
def train(model, data, args):
    ##
    F_true = torch.cat([data.pressure.unsqueeze(2), data.flowrate.unsqueeze(2)], dim=2) \
                .float().to(args.device)
    F_dot_true = torch.cat([data.pressure_dot.unsqueeze(2), data.flowrate_dot.unsqueeze(2)], dim=2) \
                .float().to(args.device)
    ##
    F_0 = F_true[:,0,:]
    # edge_index = data.edge_index.to(args.device)
    edge_index = torch.cat([data.edge_index, torch.flip(data.edge_index, dims=[0])], dim=1).to(args.device)
    node_attr = data.node_attr.float().to(args.device)
    F_true = F_true[:,1:,:]
    F_dot_true = F_dot_true[:,1:,:]
    time = data.time.float().to(args.device)
    ##
    F_pred, F_dot_pred = model.forward(
        F_0=F_0,
        edge_index=edge_index,
        meshfield=node_attr,
        time=time,
        n_time=data.number_of_timesteps - 1,
        device=args.device
    )
    ##
    loss = args.criterion(F_pred, F_true)*args.alpha + (1.-args.alpha)*args.criterion(F_dot_pred, F_dot_true)
    loss.backward()
    args.optimizer.step()
    
    return loss.item()

def eval(model, data, args):
    ##
    F_true = torch.cat([data.pressure.unsqueeze(2), data.flowrate.unsqueeze(2)], dim=2) \
                .float().to(args.device)
    F_dot_true = torch.cat([data.pressure_dot.unsqueeze(2), data.flowrate_dot.unsqueeze(2)], dim=2) \
                .float().to(args.device)
    ##
    F_0 = F_true[:,0,:]
    # edge_index = data.edge_index.to(args.device)
    edge_index = torch.cat([data.edge_index, torch.flip(data.edge_index, dims=[0])], dim=1).to(args.device)
    node_attr = data.node_attr.float().to(args.device)
    F_true = F_true[:,1:,:]
    F_dot_true = F_dot_true[:,1:,:]
    time = data.time.float().to(args.device)
    ##
    with torch.no_grad():
        F_pred, F_dot_pred = model.forward(
            F_0=F_0,
            edge_index=edge_index,
            meshfield=node_attr,
            time=time,
            n_time=data.number_of_timesteps - 1,
            device=args.device
        )
        loss = args.criterion(F_pred, F_true)*args.alpha + (1.-args.alpha)*args.criterion(F_dot_pred, F_dot_true)
        
    return loss.item()

In [11]:
# Training
total_train_loss = []
total_eval_loss = []
for epoch in range(args.n_epoch):
    torch.cuda.empty_cache()
    train_loss = 0
    for i in range(train_loader.__len__()):
        data = next(iter(train_loader))
        train_loss += train(model=model, data=data, args=args)
    train_loss /= train_loader.__len__()
    total_train_loss.append(train_loss)

    eval_loss = 0
    for i in range(test_loader.__len__()):
        data = next(iter(test_loader))
        eval_loss += eval(model=model, data=data, args=args)
    eval_loss /= test_loader.__len__()
    total_eval_loss.append(eval_loss)
    
    print(f'Epoch {epoch}: train loss = {train_loss}; eval loss = {eval_loss}')
    if (epoch+1) % 20 == 0:
        torch.save(model.state_dict(), f'models/{model.name}_node0_epoch{epoch+1}.pth')

  return adj.to_sparse_csr()


Epoch 0: train loss = 0.07876851679457827; eval loss = 0.23630165060361227
Epoch 1: train loss = 0.06922524499308236; eval loss = 0.22864992991842406
Epoch 2: train loss = 0.06325348011531604; eval loss = 0.2242613567246331
Epoch 3: train loss = 0.05898486746572803; eval loss = 0.22050368364411171
Epoch 4: train loss = 0.05545914284065492; eval loss = 0.21554821789866746
Epoch 5: train loss = 0.05085510195763986; eval loss = 0.21069835623105368
Epoch 6: train loss = 0.04728097640066179; eval loss = 0.2077354072320341
Epoch 7: train loss = 0.044964717666408936; eval loss = 0.20587033906368293
Epoch 8: train loss = 0.04332936475094402; eval loss = 0.20446842141223676
Epoch 9: train loss = 0.042107915540435394; eval loss = 0.20329479797921998
Epoch 10: train loss = 0.041008505547914006; eval loss = 0.20239336156483853
Epoch 11: train loss = 0.0400857713432784; eval loss = 0.20141118508998793
Epoch 12: train loss = 0.039432009868008634; eval loss = 0.19992632351138376
Epoch 13: train loss 