In [1]:
import sys
sys.path.insert(1, '../')
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1" 
CUDA_LAUNCH_BLOCKING=1

In [2]:
import torch
from data.graph_dataset import OneDDatasetBuilder, OneDDatasetLoader
from data.pre_process import *
from data.post_process import print_1D
from networks.gcnv1 import GraphNet
import matplotlib.pyplot as plt
from networks.losses import LpLoss

In [3]:
class objectview(object):
    def __init__(self, d) -> None:
        self.__dict__ = d
    def setattr(self, attr_name, attr_value):
        self.__dict__[attr_name] = attr_value

args = objectview({
    'n_fields': 1,
    'n_meshfields': (9, 15),
    'hidden_size': 512,
    'n_layers': 10,
    'n_timesteps': 201,
    'n_previous_timesteps': 1,
    'aggr': 'sum',
    'act': 'relu',
    'dropout': 0.15,
    'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    'lr': 1e-5,
    'weight_decay': 5e-4,
    'n_epoch': 50000,
    'alpha': 1.0,
    'batch_size': 100,
    'timestep': 201,
    'timeslice_hops': 0,
    'timeslice_steps': 1,
    'n_data_per_batch': 2,
    'forward_sequence': False,
    'criterion': LpLoss(),
    'plot': False
})

In [4]:
# # Build dataset
# dataset = OneDDatasetBuilder(
#     raw_dir='/data1/tam/datasets',
#     root_dir='/data1/tam/downloaded_datasets_v2',
#     sub_dir='processed',
#     subjects='all',
#     refined_max_length=4.,
#     time_names=[str(i).zfill(3) for i in range(201)],
#     data_type = torch.float64,
# )

In [5]:
# Load raw dataset
# dataset = OneDDatasetLoader(
#     root_dir='/data1/tam/downloaded_datasets_v2',
#     sub_dir='processed',
#     subjects='all',
#     time_names=[str(i).zfill(3) for i in range(201)],
#     data_type = torch.float64,
# )

In [6]:
# Normalize dataset
# dataset = normalize(
#     dataset=dataset,
#     sub_dir='normalized',
#     scaler_dict={
#         'node_attr' : ['minmax_scaler']*9,
#         'edge_attr' : ['minmax_scaler']*15,
#         'pressure' : 'minmax_scaler',
#     },
#     clipping=5e-4
# )

In [7]:
# Load normalized datasset
# dataset = OneDDatasetLoader(
#     root_dir='/data1/tam/downloaded_datasets_v2',
#     sub_dir='normalized',
#     subjects='all',
#     time_names=[str(i).zfill(3) for i in range(201)],
#     data_type = torch.float64,
# )

In [8]:
# Batch dataset
# batched_dataset = batchgraph_generation_wise(
#     sub_dir = 'batched',
#     dataset=dataset,
#     batch_gens=[[0,9], [10, 13], [14, 17], [18, 50]],
#     subset_hops=1,
# )

In [9]:
# Load batched dataset
dataset = OneDDatasetLoader(
    root_dir='/data1/tam/downloaded_datasets_v2',
    sub_dir='batched',
    subjects='all',
    time_names=[str(i).zfill(3) for i in range(201)],
    data_type = torch.float64,
)

In [10]:
(train_loader, test_loader) = dataset_to_loader(
    dataset=dataset,
    data_subset_dict={
        'train': list(range(0, 30)),
        'test': list(range(30, 35))
    },
    n_data_per_batch=args.n_data_per_batch
)

In [11]:
model = GraphNet(
    n_fields=args.n_fields,
    n_meshfields=args.n_meshfields,
    n_timesteps=args.n_timesteps,
    hidden_size=args.hidden_size,
    n_layers=args.n_layers,
    n_previous_timesteps=args.n_previous_timesteps,
    act=args.act,
)
setattr(model, 'name', 'model_GraphUNet')
model = model.to(args.device)
# model.load_state_dict(torch.load(f'models/{model.name}_node2_epoch6000.pth', map_location=args.device) )
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
setattr(args, 'optimizer', optimizer)
model

GraphNet(
  (act): ReLU()
  (net1): GraphConv()
  (net2): GraphUNet(512, 512, 512, depth=5, pool_ratios=[0.5, 0.5, 0.5, 0.5, 0.5])
  (net3): MLP(512, 512, 201)
)

In [None]:
# Train
start_epoch = 0
for epoch in range(args.n_epoch):
    CUDA_LAUCH_BLOCKING = 1
    torch.cuda.empty_cache()

    train_loss = 0
    for i in range(train_loader.__len__()):
        data = next(iter(train_loader))
        data = data.to(args.device)
        
        F_true = torch.cat([data.pressure.unsqueeze(2)], dim=2) \
                    .float().to(args.device)
        model.train()
        args.optimizer.zero_grad()
        F_pred = model.forward(data)

        loss = args.criterion(F_pred, F_true)
        loss.backward()
        args.optimizer.step()
        train_loss += loss.item()
    train_loss /= train_loader.__len__()

    eval_loss = 0
    for i in range(test_loader.__len__()):
        data = next(iter(test_loader))
        data = data.to(args.device)

        F_true = torch.cat([data.pressure.unsqueeze(2)], dim=2) \
                    .float().to(args.device)
        model.eval()
        F_pred = model.forward(data)

        loss = args.criterion(F_pred, F_true)
        eval_loss += loss.item()
    eval_loss /= test_loader.__len__()

    print(f'Epoch {epoch}: train loss = {train_loss}; eval loss = {eval_loss}')

    if (epoch+1) % 20 == 0:
        torch.save(model.state_dict(), f'models/{model.name}_node1_epoch{start_epoch+epoch+1}.pth')


  adj = torch.sparse_csr_tensor(


Epoch 0: train loss = 0.4693300951272249; eval loss = 0.28785794973373413
Epoch 1: train loss = 0.19331734602650005; eval loss = 0.22585470974445343
Epoch 2: train loss = 0.16500102157394092; eval loss = 0.2200707048177719
Epoch 3: train loss = 0.1447904149691264; eval loss = 0.24198313057422638
Epoch 4: train loss = 0.12717481292784213; eval loss = 0.2543887495994568
Epoch 5: train loss = 0.11529046408832073; eval loss = 0.28710174560546875
Epoch 6: train loss = 0.10557033084332942; eval loss = 0.3235306739807129
Epoch 7: train loss = 0.09637945480644702; eval loss = 0.32248154282569885
Epoch 8: train loss = 0.0899086520075798; eval loss = 0.31760627031326294
Epoch 9: train loss = 0.08402706881364187; eval loss = 0.3083067536354065
Epoch 10: train loss = 0.07861129275212685; eval loss = 0.3518325984477997
Epoch 11: train loss = 0.07395327699681123; eval loss = 0.3065919578075409
Epoch 12: train loss = 0.06995657042910655; eval loss = 0.3291291892528534
Epoch 13: train loss = 0.0663830

KeyboardInterrupt: 