In [1]:
import torch
from torch_geometric.loader import DataLoader
from data.graph_dataset import OneDDatasetBuilder, OneDDatasetLoader, normalize, dataset_to_loader
import matplotlib.pyplot as plt
import os
from model.Geo_DeepOnet_v2 import Net1D_v2
from data.utils import LpLoss
import torch.nn as nn
from time import time
# Expose only GPU 0 to this process. CUDA reads this variable when it is first
# initialised (the first `.cuda()` call in the training cell), so it must be set
# before any CUDA work happens.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
# Load the 1D graph dataset from the 'normalized' sub-directory for every
# subject. Time steps are addressed by zero-padded names "000" ... "200"
# (201 steps in total).
dataset = OneDDatasetLoader(
    root_dir='./pre_data',
    sub_dir='normalized',
    subjects='all',
    time_names=[f'{t:03d}' for t in range(201)],
    data_type=torch.float32,
)

# Split by subject index; each batch holds a single graph.
# NOTE(review): index 30 belongs to neither split (train = 0-29,
# test = 31-34). Confirm whether that subject is excluded on purpose or the
# test range was meant to start at 30.
train_set, test_set = dataset_to_loader(
    dataset=dataset,
    data_subset_dict=dict(
        train=[*range(30)],
        test=[*range(31, 35)],
    ),
    n_data_per_batch=1,
)

# ---- Hyper-parameters ---------------------------------------------------------
SEED = 0              # makes weight init and torch-RNG-based shuffling repeatable
Epoch = 200000        # total number of training epochs
learning_rate = 1e-3  # initial Adam learning rate
# MultiStepLR milestones: the lr is multiplied by gamma (default 0.1) once the
# scheduler has been stepped this many times. With one scheduler step per epoch,
# the last milestone (== Epoch) only fires after the final epoch and has no
# effect on training.
learning_rate_decay = [20, 250, Epoch]
# Number of trunk query points. Must equal the number of time_names used when
# loading the dataset (201 there).
N_TIME_STEPS = 201

# Seed before building the model so its parameter initialisation is reproducible.
torch.manual_seed(SEED)

model = Net1D_v2(n_branch=13, width=1000, depth=3, p=512, act=torch.relu).cuda()

myloss = LpLoss(size_average=True)
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
lr_decay = torch.optim.lr_scheduler.MultiStepLR(opt, learning_rate_decay)

# Trunk-network input: N_TIME_STEPS evenly spaced coordinates in [-1, 1],
# reshaped to a column of shape (N_TIME_STEPS, 1).
x_trunk = torch.linspace(-1, 1, N_TIME_STEPS).cuda()[:, None]

# Not used in the visible code; kept in case later cells reference it.
tanh = nn.Tanh()

start = time()

# ---- Training loop ---------------------------------------------------------------
# Each epoch: one optimiser step per batch of train_set, then one LR-scheduler
# step. Checkpoints are written every `ckpt_every` epochs (skipping epoch 0) and
# also after the final epoch.
ckpt_dir = './module/Geo_DeepOnet_v2'
ckpt_every = 500  # checkpoint interval, in epochs
log_every = 1     # print interval, in epochs; raise it to keep the cell output short

# torch.save does not create missing directories; without this the run would
# crash at the first checkpoint.
os.makedirs(ckpt_dir, exist_ok=True)

model.train()
for epoch in range(Epoch):

    batch_losses = []
    for x in train_set:
        x = x.cuda()

        output = model(x.node_attr, x.edge_index, x_trunk)
        loss = myloss(output, x.pressure)
        batch_losses.append(loss.item())

        opt.zero_grad()
        loss.backward()
        opt.step()
    lr_decay.step()

    if not batch_losses:
        raise RuntimeError('train_set yielded no batches; check the train subset indices')
    mean_loss = sum(batch_losses) / len(batch_losses)

    is_last_epoch = epoch == Epoch - 1

    if epoch % log_every == 0 or is_last_epoch:
        print('epoch:{}, loss:{:.3f}, time:{:.3f} min'.format(epoch, mean_loss, ((time() - start) / 60)))

    # The interval alone never reaches the final epoch (e.g. 199999), so save it explicitly.
    if (epoch != 0 and epoch % ckpt_every == 0) or is_last_epoch:
        torch.save(
            model.state_dict(),
            os.path.join(ckpt_dir, 'model{}-{:.3f}.pth'.format(epoch, mean_loss)),
        )

Processing...
Done!


epoch:0, loss:1.847, time:0.116 min
epoch:1, loss:0.855, time:0.215 min
epoch:2, loss:0.968, time:0.322 min
epoch:3, loss:0.938, time:0.424 min
epoch:4, loss:0.905, time:0.526 min
epoch:5, loss:0.867, time:0.626 min
epoch:6, loss:0.771, time:0.730 min
epoch:7, loss:0.682, time:0.831 min
epoch:8, loss:0.733, time:0.940 min
epoch:9, loss:0.669, time:1.037 min
epoch:10, loss:0.521, time:1.143 min
epoch:11, loss:0.544, time:1.233 min
epoch:12, loss:0.474, time:1.334 min
epoch:13, loss:0.465, time:1.437 min
epoch:14, loss:0.519, time:1.534 min
epoch:15, loss:0.530, time:1.630 min
epoch:16, loss:0.542, time:1.736 min
epoch:17, loss:0.537, time:1.835 min
epoch:18, loss:0.569, time:1.932 min
epoch:19, loss:0.474, time:2.036 min
epoch:20, loss:0.462, time:2.140 min
epoch:21, loss:0.438, time:2.241 min
epoch:22, loss:0.436, time:2.344 min
epoch:23, loss:0.434, time:2.444 min
epoch:24, loss:0.433, time:2.546 min
epoch:25, loss:0.431, time:2.650 min
epoch:26, loss:0.430, time:2.751 min
epoch:27, l