In [1]:
import sys
sys.path.insert(1, '../')
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0" 
CUDA_LAUNCH_BLOCKING=1

In [2]:
import torch
from data.graph_dataset import OneDDatasetBuilder, OneDDatasetLoader
from data.pre_process import *
from data.post_process import print_1D
from networks.graph_parc import GraphPARC
import matplotlib.pyplot as plt
# from networks.losses import LpLoss
from neuralop.losses.data_losses import LpLoss, H1Loss
from torch_geometric.loader import NeighborLoader
import matplotlib.pyplot as plt

In [None]:
class objectview(object):
    def __init__(self, d) -> None:
        self.__dict__ = d
    def setattr(self, attr_name, attr_value):
        self.__dict__[attr_name] = attr_value

args = objectview({
    'n_fields': 1,
    'n_meshfields': (13, 0),
    'hidden_size': 48,
    'n_layers': 7,
    'n_timesteps': 201,
    'n_previous_timesteps': 1,
    'aggr': 'sum',
    'act': 'mish',
    'dropout': 0.1,
    'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    'lr': 5e-6,
    'weight_decay': 5e-4,
    'grad_clip' : 1.,
    'n_epoch': 5000,
    'alpha': 1.0, 
    'batch_size': 10000,
    'timestep': 201,
    'timeslice_hops': 0,
    'timeslice_steps': 1,
    'n_data_per_batch': 1,
    'forward_sequence': False,
    'criterion': torch.nn.MSELoss(),
    'plot': False
})

In [4]:
# # Build dataset
# dataset = OneDDatasetBuilder(
#     raw_dir='/data1/tam/datasets',
#     root_dir='/data1/tam/downloaded_datasets_v3',
#     sub_dir='processed',
#     subjects='all',
#     refined_max_length=4.,
#     time_names=[str(i).zfill(3) for i in range(201)],
#     data_type = torch.float64,
# )

In [5]:
# Load raw dataset
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_node_features',
    sub_dir='processed',
    subjects='all',
    time_names=[str(i).zfill(3) for i in range(201)],
    data_type = torch.float64,
)

In [6]:
# Normalize dataset
dataset = normalize(
    dataset=dataset,
    sub_dir='normalized',
    scaler_dict={
        'node_attr' : ['minmax_scaler']*3+['robust_scaler']*2+['minmax_scaler']*8,
        #'edge_attr' : ['minmax_scaler']*args.n_meshfields[1],
        'pressure' : 'robust_scaler',
    },
    clipping=5e-4
)

In [7]:
# Load normalized datasset
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_node_features',
    sub_dir='normalized',
    subjects='all',
    time_names=[str(i).zfill(3) for i in range(201)],
    data_type = torch.float64,
)

In [8]:
# data = dataset[0]
# x = data.pressure.flatten()
# plt.hist(x, bins=100)
# plt.show()

In [9]:
# Batch dataset
# batched_dataset = batchgraph_generation_wise(
#     sub_dir = 'batched',
#     dataset=dataset,
#     batch_gens=[[0,9], [10, 13], [14, 17], [18, 50]],
#     subset_hops=1,
# )

batched_dataset = batchgraph_timeslice(
    sub_dir = 'batched',
    dataset=dataset,
    timestep = 201,
    timeslice_hops = 0,
    timeslice_steps = 10
)

In [10]:
# Load batched dataset
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_node_features',
    sub_dir='batched',
    subjects='all',
    time_names=[str(i).zfill(3) for i in range(201)],
    data_type = torch.float64,
)

In [11]:
(train_loader, test_loader) = dataset_to_loader(
    dataset=dataset,
    data_subset_dict={
        'train': list(range(0, 20)),
        'test': list(range(30, 35))
    },
    n_data_per_batch=1
)

In [12]:
model = GraphPARC(
    n_fields=args.n_fields,
    n_meshfields=args.n_meshfields,
    hidden_channels=args.hidden_size,
    num_layers=args.n_layers,
    dropout=args.dropout,
    act=args.act,
)
setattr(model, 'name', 'GraphPARC')
model = model.to(args.device)
# model.load_state_dict(torch.load(f'models/{model.name}_node2_epoch6000.pth', map_location=args.device) )
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
setattr(args, 'optimizer', optimizer)
model

GraphPARC(
  (act): Mish()
  (differentiator): Sequential(
    (0) - SAGEConv(14, 32, aggr=lstm): x, edge_index -> x
    (1) - Dropout(p=0.2, inplace=False): x -> x
    (2) - Mish(): x -> x
    (3) - SAGEConv(32, 64, aggr=lstm): x, edge_index -> x
    (4) - Mish(): x -> x
    (5) - SAGEConv(64, 128, aggr=lstm): x, edge_index -> x
    (6) - Mish(): x -> x
    (7) - SAGEConv(128, 128, aggr=lstm): x, edge_index -> x
    (8) - Mish(): x -> x
    (9) - SAGEConv(128, 64, aggr=lstm): x, edge_index -> x
    (10) - Mish(): x -> x
    (11) - SAGEConv(64, 32, aggr=lstm): x, edge_index -> x
    (12) - Mish(): x -> x
    (13) - SAGEConv(32, 1, aggr=lstm): x, edge_index -> x
  )
  (integrator): Sequential(
    (0) - SAGEConv(2, 32, aggr=lstm): x, edge_index -> x
    (1) - Mish(): x -> x
    (2) - SAGEConv(32, 64, aggr=lstm): x, edge_index -> x
    (3) - Mish(): x -> x
    (4) - SAGEConv(64, 128, aggr=lstm): x, edge_index -> x
    (5) - Mish(): x -> x
    (6) - SAGEConv(128, 128, aggr=lstm): x, edge_

In [13]:
# Train
if 1:
    n_time = 20
    delta_t = 4.0 / n_time
    start_epoch = 0
    train_record = []
    eval_record = []
    for epoch in range(args.n_epoch):
        CUDA_LAUCH_BLOCKING = 1
        torch.cuda.empty_cache()
        
        train_loss = 0
        model.train()
        for i in range(train_loader.__len__()):
            _data = next(iter(train_loader))
            loader = NeighborLoader(_data, num_neighbors=[1], batch_size=30000)
            for (_, data) in enumerate(loader):
                data = data.to(args.device)
                F_true = torch.cat([data.pressure.unsqueeze(2)], dim=2) \
                            .float().to(args.device)
                F_dot_true = torch.diff(F_true, dim=1) / delta_t
                F_true = F_true[:,1:]
                
                args.optimizer.zero_grad()
                F_pred, F_dot_pred = model.forward(data, device=args.device, n_time=n_time)

                loss = args.criterion(F_pred.unsqueeze(0), F_true.unsqueeze(0)) \
                        + args.criterion(F_dot_pred.unsqueeze(0), F_dot_true.unsqueeze(0))
                loss.backward()
                args.optimizer.step()
                train_loss += loss.item()

        train_loss /= train_loader.__len__()
        train_record.append(train_loss)

        torch.cuda.empty_cache()

        eval_loss = 0
        model.eval()
        for i in range(test_loader.__len__()):
            _data = next(iter(test_loader))
            loader = NeighborLoader(_data, num_neighbors=[1], batch_size=30000)
            for (_, data) in enumerate(loader):
                data = data.to(args.device)
                F_true = torch.cat([data.pressure.unsqueeze(2)], dim=2) \
                            .float().to(args.device)
                F_dot_true = torch.diff(F_true, dim=1) / delta_t
                F_true = F_true[:,1:]
                
                F_pred, F_dot_pred = model.forward(data, device=args.device, n_time=n_time)

                loss = args.criterion(F_pred.unsqueeze(0), F_true.unsqueeze(0)) \
                        + args.criterion(F_dot_pred.unsqueeze(0), F_dot_true.unsqueeze(0))
                eval_loss += loss.item()

        eval_loss /= test_loader.__len__()
        eval_record.append(eval_loss)

        print(f'Epoch {epoch}: train loss = {train_loss}; eval loss = {eval_loss}')

        if (epoch+1) % 20 == 0:
            torch.save(model.state_dict(), f'models/{model.name}_node1_epoch{start_epoch+epoch+1}.pth')
            scheduler.step()

Epoch 0: train loss = 8.975405299663544; eval loss = 15.871912002563477
Epoch 1: train loss = 8.755983924865722; eval loss = 15.688855171203613
Epoch 2: train loss = 8.554837906360627; eval loss = 15.522186279296875
Epoch 3: train loss = 8.369728100299834; eval loss = 15.370030403137207
Epoch 4: train loss = 8.198942625522614; eval loss = 15.231188774108887
Epoch 5: train loss = 8.041369318962097; eval loss = 15.104832649230957
Epoch 6: train loss = 7.896286654472351; eval loss = 14.990443229675293
Epoch 7: train loss = 7.76327508687973; eval loss = 14.887731075286865
Epoch 8: train loss = 7.64211550951004; eval loss = 14.796544551849365
Epoch 9: train loss = 7.53274267911911; eval loss = 14.716784477233887
Epoch 10: train loss = 7.43514232635498; eval loss = 14.648363590240479
Epoch 11: train loss = 7.349305140972137; eval loss = 14.591103553771973
Epoch 12: train loss = 7.27513964176178; eval loss = 14.544672966003418
Epoch 13: train loss = 7.2124155282974245; eval loss = 14.50855541

KeyboardInterrupt: 

In [None]:
# plt.plot(train_record, label='train loss')
# plt.plot(eval_record, label='eval loss')
# plt.legend()
# plt.show()

In [None]:
# Test
if 1:
    model = GraphPARC(
        n_fields=args.n_fields,
        n_meshfields=args.n_meshfields,
        hidden_channels=args.hidden_size,
        num_layers=args.n_layers,
        dropout=args.dropout,
        act=args.act,
    )
    setattr(model, 'name', 'GraphPARC')
    model = model.to(args.device)
    model.load_state_dict(torch.load(f'models/{model.name}_node1_epoch20.pth', map_location=args.device) )
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # setattr(args, 'optimizer', optimizer)
    model

In [None]:
if 1:
    device = args.device
    torch.cuda.empty_cache()
    i_data = 38
    n_time = 20
    data = dataset[i_data].to(args.device)


    ##
    F_true = torch.cat([data.pressure.unsqueeze(2)], dim=2) \
                            .float().to(args.device)
    F_true = F_true[:,1:]

    ##
    model.eval()

    with torch.no_grad():
        
        F_pred, F_dot_pred = model.forward(data, device=device, n_time=n_time)

    loss = args.criterion(F_pred, F_true)
    print(loss.item())



    ##
    # gen = dataset[i_data].node_attr[:,5]
    node_list = [1, 10, 50, 100, 1000, 2000, 5000, 20000, 30000, 40000, 50000]
    ## Draw pressure
    import matplotlib.pyplot as plt
    for i_node in node_list:
        i_field = 0
        y_pred = F_pred.cpu().numpy()[i_node,:,i_field]
        y_true = F_true.cpu().numpy()[i_node,:,i_field]

        # print(y_true.shape, y_pred.shape)
        x = [i * 4.0 /(n_time) for i in range(y_pred.shape[0])]
        # print(data.node_attr.numpy()[i_node, 6])
        # plt.ylim(0,1)
        # plt.title(f'gen = {int(gen[i_node])}')
        plt.plot(x, y_pred, c='red', label='GNN model')
        plt.plot(x, y_true, c='blue', linestyle='dashdot', label='1DCFD')
        # plt.ylim([0.46,0.55])
        plt.legend(loc='upper right')
        plt.ylabel('Pressure', fontsize=20)
        plt.xlabel('Time', fontsize=20)
        plt.show()