In [None]:
import sys
# Make the local `data` / `networks` packages imported in the next cell
# findable (assumes they live in the parent directory of this notebook).
sys.path.insert(1, '../')
import os
# Both variables must be in the process environment before CUDA is
# initialised, i.e. before torch touches the GPU (torch is imported next cell).
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# Debugging aid: synchronous kernel launches give accurate error locations but
# slow training. A bare Python assignment (`CUDA_LAUNCH_BLOCKING=1`) is never
# seen by CUDA, so it has to be exported through os.environ.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
import torch
from data.graph_dataset import OneDDatasetBuilder, OneDDatasetLoader
from data.pre_process import *
from data.post_process import print_1D
from networks.gcnv1 import GraphNet
import matplotlib.pyplot as plt
from networks.losses import LpLoss

In [3]:
# Mode switch: False runs the model-building and training cells below;
# True skips training and instead loads a saved checkpoint and plots predictions.
test = False

In [4]:
class objectview(object):
    """Attribute-style view over a plain dict.

    The dict passed in becomes the instance's ``__dict__`` itself (it is not
    copied), so ``view.key`` reads ``d['key']`` and attribute writes land in ``d``.
    """

    def __init__(self, d) -> None:
        self.__dict__ = d

    def setattr(self, attr_name, attr_value):
        """Store ``attr_value`` under the key ``attr_name`` in the backing dict."""
        backing = vars(self)
        backing[attr_name] = attr_value

# Run configuration. `objectview` exposes every key as an attribute
# (e.g. args.lr); later cells also attach args.optimizer to the same dict.
args = objectview(dict(
    n_fields=1,                    # GraphNet n_fields; 1 output field (pressure), matching F_true's last axis
    n_meshfields=(9, 5),           # (#node_attr, #edge_attr) columns given scalers in the normalize cell
                                   # NOTE(review): dataset[0] reports edge_attr with 15 columns — confirm 5
    hidden_size=512,               # GraphNet hidden_size
    n_layers=10,                   # GraphNet n_layers
    n_timesteps=201,               # pressure series length (loaders use 201 time_names)
    n_previous_timesteps=1,        # forwarded to GraphNet
    aggr='sum',                    # not passed to GraphNet in the cells shown
    act='relu',                    # GraphNet activation name
    dropout=0.2,                   # not passed to GraphNet in the cells shown
    # 'cuda:0' is the first *visible* GPU, i.e. physical GPU 2 given CUDA_VISIBLE_DEVICES.
    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    lr=5e-5,                       # Adam learning rate
    weight_decay=5e-4,             # Adam weight decay
    n_epoch=50000,                 # number of training epochs
    # The keys from alpha to plot are not read by any cell shown here,
    # except criterion (train/eval loss).
    alpha=1.0,
    batch_size=100,
    timestep=201,
    timeslice_hops=0,
    timeslice_steps=1,
    n_data_per_batch=1,            # NOTE(review): dataset_to_loader is called with 200, not this value
    forward_sequence=False,
    criterion=torch.nn.MSELoss(),
    plot=False,
))

In [None]:
# # Build dataset
# dataset = OneDDatasetBuilder(
#     raw_dir='/data1/tam/datasets',
#     root_dir='/data1/tam/downloaded_datasets_v3',
#     sub_dir='processed',
#     subjects='all',
#     refined_max_length=4.,
#     time_names=[str(i).zfill(3) for i in range(201)],
#     data_type = torch.float64,
# )

In [5]:
# Load the samples stored under sub_dir='normalized' of downloaded_datasets_v2
# (one time name per step: '000' ... '200').
# NOTE(review): this cell was labelled "raw" but reads the 'normalized' sub-dir,
# and the later cells use downloaded_datasets_v3 — confirm which root is intended.
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_v2',
    sub_dir='normalized',
    subjects='all',
    time_names=[f'{step:03d}' for step in range(201)],
    data_type=torch.float64,
)

In [6]:
# Inspect the first graph sample; the cell output lists its tensor shapes.
dataset[0]

TorchGraphData(edge_index=[2, 64111], edge_index_raw=[2, 59587], original_flag=[64112], node_attr=[64112, 9], edge_attr=[64111, 15], pressure=[64112, 201], flowrate=[64112, 201])

In [None]:
# Quick sanity check: collate the graphs two at a time with PyG's DataLoader.
from torch_geometric.loader import DataLoader

loader = DataLoader(dataset=dataset, batch_size=2)

In [None]:
# Peek at the first collated mini-batch produced by `loader`.
next(iter(loader))

In [None]:
# Normalize the dataset and write it under sub_dir='normalized'. Every
# node_attr / edge_attr column (counts from args.n_meshfields) and the pressure
# field get a 'minmax_scaler'; clipping=5e-4 is forwarded to `normalize`
# (its exact semantics are defined in data.pre_process).
dataset = normalize(
    dataset=dataset,
    sub_dir='normalized',
    scaler_dict=dict(
        node_attr=['minmax_scaler' for _ in range(args.n_meshfields[0])],
        edge_attr=['minmax_scaler' for _ in range(args.n_meshfields[1])],
        pressure='minmax_scaler',
    ),
    clipping=5e-4,
)

In [None]:
# Load the normalized dataset (all subjects) from downloaded_datasets_v3,
# one time name per step: '000' ... '200'.
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_v3',
    sub_dir='normalized',
    subjects='all',
    time_names=[f'{step:03d}' for step in range(201)],
    data_type=torch.float64,
)

In [None]:
# Split the normalized graphs edge-wise and save the pieces under
# sub_dir='batched' (subset_hops=2 is forwarded to batchgraph_edgewise).
#
# Alternative generation-wise batching, kept for reference:
# batched_dataset = batchgraph_generation_wise(
#     sub_dir='batched',
#     dataset=dataset,
#     batch_gens=[[0, 9], [10, 13], [14, 17], [18, 50]],
#     subset_hops=1,
# )
batched_dataset = batchgraph_edgewise(
    dataset=dataset,
    sub_dir='batched',
    subset_hops=2,
)

In [None]:
# Load the batched samples of subject '10081' only from downloaded_datasets_v3,
# one time name per step: '000' ... '200'.
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_v3',
    sub_dir='batched',
    subjects=['10081'],
    time_names=[f'{step:03d}' for step in range(201)],
    data_type=torch.float64,
)

In [None]:
# Display the repr of the loaded batched dataset.
dataset

In [None]:
# Presumably lists the processed files backing this dataset.
# NOTE(review): in torch_geometric `processed_file_names` is normally a property;
# if OneDDatasetLoader follows that convention, drop the call parentheses.
dataset.processed_file_names()

In [None]:

# Build train/test loaders from subsets of the batched samples
# (presumably sample indices: 0-29 for training, 30-34 for evaluation).
# NOTE(review): n_data_per_batch is hard-coded to 200 here while
# args.n_data_per_batch is 1 — confirm which value is intended.
train_loader, test_loader = dataset_to_loader(
    dataset=dataset,
    data_subset_dict=dict(
        train=list(range(30)),
        test=list(range(30, 35)),
    ),
    n_data_per_batch=200,
)

In [None]:
if not test:
    # Build the GNN from the hyper-parameters in `args`.
    model = GraphNet(
        n_fields=args.n_fields,
        n_meshfields=args.n_meshfields,
        n_timesteps=args.n_timesteps,
        hidden_size=args.hidden_size,
        n_layers=args.n_layers,
        n_previous_timesteps=args.n_previous_timesteps,
        act=args.act,
    )
    # `name` is the file prefix used when checkpoints are saved/loaded.
    model.name = 'model_GraphUNet'
    model = model.to(args.device)
    # To resume training, load an earlier checkpoint here, e.g.:
    # model.load_state_dict(torch.load(f'models/{model.name}_node2_epoch6000.pth', map_location=args.device))
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    args.optimizer = optimizer
    # A bare `model` inside an `if` is not echoed by Jupyter (only a top-level
    # last expression is), so print the architecture explicitly.
    print(model)

In [None]:
# Train
def run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return the mean per-batch loss.

    With an ``optimizer`` the model is put in train mode and updated after
    every batch. Without one it is put in eval mode and autograd is disabled.
    Returns nan if ``loader`` yields no batches.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    n_batches = 0
    # Skip building the autograd graph during evaluation.
    with torch.set_grad_enabled(training):
        # Iterate the loader itself: calling next(iter(loader)) per step would
        # build a fresh iterator each time and never walk the whole epoch.
        for batch in loader:
            batch = batch.to(device)
            # Target: pressure with a trailing field axis (n_fields == 1).
            F_true = batch.pressure.unsqueeze(2).float()
            if training:
                optimizer.zero_grad()
            F_pred = model(batch)
            loss = criterion(F_pred, F_true)
            if training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            n_batches += 1
    return total_loss / n_batches if n_batches else float('nan')


if not test:
    # Offset for checkpoint numbering when resuming from a saved epoch.
    start_epoch = 0
    os.makedirs('models', exist_ok=True)
    for epoch in range(args.n_epoch):
        torch.cuda.empty_cache()

        train_loss = run_epoch(model, train_loader, args.criterion, args.device,
                               optimizer=args.optimizer)
        eval_loss = run_epoch(model, test_loader, args.criterion, args.device)

        print(f'Epoch {epoch}: train loss = {train_loss}; eval loss = {eval_loss}')

        if (epoch+1) % 50 == 0:
            torch.save(model.state_dict(), f'models/{model.name}_node1_epoch{start_epoch+epoch+1}.pth')


In [None]:
# Test: rebuild the model and restore trained weights from a checkpoint.
if test:
    model = GraphNet(
        n_fields=args.n_fields,
        n_meshfields=args.n_meshfields,
        n_timesteps=args.n_timesteps,
        hidden_size=args.hidden_size,
        n_layers=args.n_layers,
        n_previous_timesteps=args.n_previous_timesteps,
        act=args.act,
    )
    # Must match the prefix used when the checkpoint was saved.
    model.name = 'model_GraphUNet'
    model = model.to(args.device)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    model.load_state_dict(torch.load(f'models/{model.name}_node1_epoch50.pth', map_location=args.device))
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    args.optimizer = optimizer
    # A bare `model` inside an `if` is not echoed by Jupyter; print it explicitly.
    print(model)

In [None]:
if test:
    i_data = 38
    data = dataset[i_data].to(args.device)

    # Target: pressure with a trailing field axis, matching the model output.
    F_true = data.pressure.unsqueeze(2).float()

    model.eval()
    with torch.no_grad():
        F_pred = model(data=data)
    print(F_true.size(), F_pred.size())

    loss = args.criterion(F_pred, F_true)
    print(loss.item())

    # Copy to host once instead of once per plotted node.
    y_pred_all = F_pred.cpu().numpy()
    y_true_all = F_true.cpu().numpy()
    n_nodes = y_pred_all.shape[0]

    node_list = [1, 10, 50, 100, 1000, 2000, 5000, 20000, 30000, 40000, 50000]
    i_field = 0
    # Time axis: step i maps to i * 4.0 / 200 (assumes 201 samples over 0..4.0 — TODO confirm units).
    x = [i * 4.0 / 200 for i in range(y_pred_all.shape[1])]

    # Draw pressure at selected nodes.
    for i_node in node_list:
        # Hard-coded indices may exceed the size of this particular graph.
        if i_node >= n_nodes:
            print(f'Skipping node {i_node}: graph has only {n_nodes} nodes')
            continue
        fig, ax = plt.subplots()
        ax.plot(x, y_pred_all[i_node, :, i_field], c='red', label='GNN model')
        ax.plot(x, y_true_all[i_node, :, i_field], c='blue', linestyle='dashdot', label='1DCFD')
        ax.set_ylim(0, 1)
        ax.legend(loc='upper right')
        ax.set_ylabel('Pressure', fontsize=20)
        ax.set_xlabel('Time', fontsize=20)
        plt.show()
        # Release the figure so many plots in one cell do not accumulate memory.
        plt.close(fig)