In [1]:
import sys
sys.path.insert(1, '../Codes_1Dtree')
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Load libraries

In [2]:
import torch
from torch_geometric.loader import DataLoader
from data.graph_dataset import OneDDatasetBuilder, OneDDatasetLoader, normalize
# from networks.gcn import GraphUNet, RecurrentFormulationNet
from networks.gcnv4 import RecurrentFormulationNet
import matplotlib.pyplot as plt
from torch_geometric.loader import NeighborLoader

from data.graph_dataset import batchgraph

In [3]:
class objectview(object):
    def __init__(self, d) -> None:
        self.__dict__ = d
    def setattr(self, attr_name, attr_value):
        self.__dict__[attr_name] = attr_value

args = objectview({
    'n_field': 1,
    'n_meshfield': 16,
    'hidden_size': 256,
    'latent_size': 256,
    'aggr': 'sum',
    'act': 'mish',
    'dropout': 0.2,
    'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    'lr': 1e-6,
    'weight_decay': 1e-3,
    'n_epoch': 500,
    'alpha': 0.5,
    'batchsize': 20000,
    'timestep': 201,
    'timeslice_hops': 0,
    'timeslice_steps': 5,
    'n_data_per_batch': 1,
    'criterion': torch.nn.MSELoss(),
    'plot': False
})

Load dataset

In [4]:
# dataset = OneDDatasetBuilder(
#     raw_dir='/data1/tam/datasets',
#     root_dir='/data1/tam/downloaded_datasets_Static_v1',
#     sub_dir='processed',
#     subjects='all',
#     time_names=[str(i).zfill(3) for i in range(201)],
#     data_type = torch.float32,
#     readme='edge_index(2xn_edge), node_attr(n_nodex10), pressure+flowrate(n_nodex201)'
# )
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_Static_v1',
    sub_dir='processed',
    subjects='all',
    time_names=[str(i).zfill(3) for i in range(201)],
    data_type = torch.float32
)

In [5]:
dataset = normalize(
    dataset=dataset,
    sub_dir='normalized',
    scaler_dict={
        'node_attr': ('minmax_scaler', 0, None),
        'pressure': ('robust_scaler', None, None),
        # 'flowrate': ('robust_scaler', None, None),
        'pressure_dot': ('robust_scaler', None, None),
        # 'flowrate_dot': ('robust_scaler', None, None),
        # 'time': ('minmax_scaler', None, None)
    }
)
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_Static_v1',
    sub_dir='normalized',
    subjects='all',
    time_names=[str(i).zfill(3) for i in range(201)],
    data_type = torch.float32
)

In [6]:
# dataset = batchgraph_generation_wise(
#     dataset,
#     sub_dir='batched',
#     batch_gens=[[14,15]],
#     timestep=args.timestep,
#     timeslice_hops=args.timeslice_hops,
#     timeslice_steps=args.timeslice_steps
# )
dataset = batchgraph(
    dataset,
    sub_dir='batched_1',
    batchsize=None,
    timestep=args.timestep,
    timeslice_hops=args.timeslice_hops,
    timeslice_steps=args.timeslice_steps
)
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_Static_v1',
    sub_dir='batched_1',
    subjects='all',
    time_names=[str(i).zfill(3) for i in range(201)],
    data_type = torch.float32
)

In [7]:
fold_size = 2
fold = [list(range(i*fold_size, (i+1)*fold_size)) for i in range(8)]
n_train_fold = 5

# for i in range(n_train_fold):
#     test_subset = fold[i]
#     train_subset = []
#     for j in range(n_train_fold):
#         if j != i:
#             train_subset += fold[j]

# (train_loader, test_loader) = dataset_to_loader(
#     dataset=dataset,
#     data_subset_dict={
#         'train': list(range(5, 6)),
#         'test': list(range(0, 1))
#     },
#     n_data_per_batch=args.n_data_per_batch
# )

# train_set, test_set = dataset_to_loader(
#     dataset=dataset,
#     data_subset_dict={
#         'train': list(range(6, 36)),
#         'test': list(range(0, 5))
#     },
#     n_data_per_batch=args.n_data_per_batch
# )

Model initializing and training

In [8]:
model = RecurrentFormulationNet(
    n_field=args.n_field,
    n_meshfield=args.n_meshfield,
    hidden_size=args.hidden_size,
    latent_size=args.latent_size,
    act=args.act,
    use_time_feature=True,
    dropout=args.dropout,
    use_hidden=True
)
setattr(model, 'name', 'PARC_GCN_UNet_full')
model = model.to(args.device)
model.load_state_dict(torch.load(f'models/{model.name}_node2_epoch500.pth', map_location=args.device) )
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
setattr(args, 'optimizer', optimizer)

In [9]:
def train(model, data, args):
    ##
    F_true = torch.cat([data.pressure.unsqueeze(2)], dim=2) \
                .float().to(args.device)
    F_dot_true = torch.cat([data.pressure_dot.unsqueeze(2)], dim=2) \
                .float().to(args.device)
    ##
    F_0 = F_true[:,0,:]
    # edge_index = data.edge_index.to(args.device)
    edge_index = torch.cat([data.edge_index, torch.flip(data.edge_index, dims=[0])], dim=1).to(args.device)
    node_attr = data.node_attr.float().to(args.device)
    F_true = F_true[:,1:,:]
    F_dot_true = F_dot_true[:,1:,:]
    time = data.time.float().to(args.device)
    ##
    F_pred, F_dot_pred = model.forward(
        F_0=F_0,
        edge_index=edge_index,
        meshfield=node_attr,
        time=time,
        n_time=data.number_of_timesteps - 1,
        device=args.device
    )
    ##
    loss = args.criterion(F_pred, F_true)*args.alpha + (1.-args.alpha)*args.criterion(F_dot_pred, F_dot_true)
    loss.backward()
    args.optimizer.step()
    
    return loss.item()

def eval(model, data, args):
    ##
    F_true = torch.cat([data.pressure.unsqueeze(2)], dim=2) \
                .float().to(args.device)
    F_dot_true = torch.cat([data.pressure_dot.unsqueeze(2)], dim=2) \
                .float().to(args.device)
    ##
    F_0 = F_true[:,0,:]
    # edge_index = data.edge_index.to(args.device)
    edge_index = torch.cat([data.edge_index, torch.flip(data.edge_index, dims=[0])], dim=1).to(args.device)
    node_attr = data.node_attr.float().to(args.device)
    F_true = F_true[:,1:,:]
    F_dot_true = F_dot_true[:,1:,:]
    time = data.time.float().to(args.device)
    ##
    with torch.no_grad():
        F_pred, F_dot_pred = model.forward(
            F_0=F_0,
            edge_index=edge_index,
            meshfield=node_attr,
            time=time,
            n_time=data.number_of_timesteps - 1,
            device=args.device
        )
        loss = args.criterion(F_pred, F_true)*args.alpha + (1.-args.alpha)*args.criterion(F_dot_pred, F_dot_true)
        
    return loss.item()

In [10]:
# Training
total_train_loss = []
total_eval_loss = []
train_loss = 0
eval_loss = 0
for epoch in range(args.n_epoch):
    CUDA_LAUNCH_BLOCKING=1
    torch.cuda.empty_cache()
    ##
    test_subset = fold[epoch % n_train_fold]
    train_subset = []
    for j in range(n_train_fold):
        if j != (epoch % n_train_fold):
            train_subset += fold[j]
    ##
    torch.cuda.empty_cache()
    # train_loss = 0
    for i_data in train_subset:
        data = dataset[i_data]
        train_loader = NeighborLoader(data, num_neighbors=[1], batch_size=args.batchsize)
        for i in range(train_loader.__len__()):
            data = next(iter(train_loader))
            train_loss += train(model=model, data=data, args=args) / train_loader.__len__()

    # eval_loss = 0
    for i_data in test_subset:
        data = dataset[i_data]
        test_loader = NeighborLoader(data, num_neighbors=[1], batch_size=args.batchsize)
        for i in range(test_loader.__len__()):
            data = next(iter(test_loader))
            eval_loss += eval(model=model, data=data, args=args) / test_loader.__len__()
    
    if (epoch+1) % n_train_fold == 0:
        train_loss /= n_train_fold*fold_size
        eval_loss /= fold_size
        print(f'Epoch {epoch}: train loss = {train_loss}; eval loss = {eval_loss}')
        train_loss = 0
        eval_loss = 0
    # else:
        # print(f'Epoch {epoch}.')
        
    if (epoch+1) % 20 == 0:
        torch.save(model.state_dict(), f'models/{model.name}_node2_epoch{epoch+1}.pth')

Epoch 4: train loss = 3.4960363090038293; eval loss = 4.376752078533173
Epoch 9: train loss = 3.456063965459664; eval loss = 4.333602366348107
Epoch 14: train loss = 3.491986308991909; eval loss = 4.329758994281292
Epoch 19: train loss = 3.4658332392573357; eval loss = 4.333074827988942
Epoch 24: train loss = 3.469724143544833; eval loss = 4.352372484902542
Epoch 29: train loss = 3.490648339192073; eval loss = 4.356131213406723
Epoch 34: train loss = 3.5297192454338067; eval loss = 4.380157118042311
Epoch 39: train loss = 3.504964096844199; eval loss = 4.388063937425613
Epoch 44: train loss = 3.510457652310528; eval loss = 4.383433627585571
Epoch 49: train loss = 3.5420380170146606; eval loss = 4.440773792564869
Epoch 54: train loss = 3.5342303360501943; eval loss = 4.412063437203566
Epoch 59: train loss = 3.647164194782576; eval loss = 4.568718279401461
Epoch 64: train loss = 3.921804159382976; eval loss = 4.953121195236843
Epoch 69: train loss = 4.521487582226594; eval loss = 5.69555