In [1]:
import torch
from torch_geometric.loader import DataLoader
from data.graph_dataset import OneDDatasetBuilder, OneDDatasetLoader, normalize, dataset_to_loader
import matplotlib.pyplot as plt
import os
from model.Geo_DeepOnet import Net1D
from data.utils import LpLoss
from time import time
# Expose only GPU 0 to this process.
# NOTE: only effective if set before the first CUDA call in this kernel.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [2]:
# Load the pre-batched dataset: 201 time steps per subject (names '000'..'200').
# (An earlier variant read sub_dir='normalized' with otherwise identical arguments.)
dataset = OneDDatasetLoader(
    root_dir='./pre_data',
    sub_dir='batched',
    subjects='all',
    time_names=[str(i).zfill(3) for i in range(201)],
    data_type=torch.float32,
)

# Subject indices for each split. The test split starts at 30: the previous
# range(31, 35) silently left subject 30 out of both train and test.
# NOTE(review): if subject 30 was excluded on purpose, restore 31 and document why.
train_subjects = list(range(0, 30))
test_subjects = list(range(30, 35))
assert not set(train_subjects) & set(test_subjects), 'train/test subjects overlap'

train_set, test_set = dataset_to_loader(
    dataset=dataset,
    data_subset_dict={
        'train': train_subjects,
        'test': test_subjects,
    },
    n_data_per_batch=100,
)

Processing...
Done!


In [None]:
# Training hyperparameters.
Epoch = 200000
learning_rate = 1e-3
# MultiStepLR milestones (lr_decay.step() is called once per epoch): the LR is
# multiplied by gamma (PyTorch default 0.1) after epoch 20. The second milestone
# equals Epoch, so it is only reached after training has finished.
learning_rate_decay = [20, Epoch]
n_time_steps = 201            # must match the number of time_names used to build the dataset
checkpoint_every = 500        # save weights every this many epochs (and after the last one)
checkpoint_dir = './module/Geo_DeepOnet_batch'
log_every = 1                 # print a progress line every this many epochs

model = Net1D(n_branch=13, width=1000, depth=3, p=512, act=torch.relu).cuda()

myloss = LpLoss(size_average=True)
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
lr_decay = torch.optim.lr_scheduler.MultiStepLR(opt, learning_rate_decay)

# Trunk-net query points: n_time_steps evenly spaced values in [-1, 1],
# reshaped to a (n_time_steps, 1) column.
x_trunk = torch.linspace(-1, 1, n_time_steps).cuda()[:, None]

# Create the checkpoint directory up front; otherwise the first torch.save
# (after ~500 epochs of training) fails if the directory does not exist.
os.makedirs(checkpoint_dir, exist_ok=True)


def train_one_epoch(net, loader, trunk_input, loss_fn, optimizer):
    """Run one optimisation pass over `loader` and return the mean batch loss.

    Each batch is moved to the GPU, passed through
    ``net(batch.node_attr, batch.edge_index, trunk_input)`` and scored against
    ``batch.pressure`` with ``loss_fn``; one optimizer step is taken per batch.

    Raises:
        ValueError: if ``loader`` yields no batches (the mean would be undefined).
    """
    net.train()
    batch_losses = []
    for batch in loader:
        batch = batch.cuda()
        output = net(batch.node_attr, batch.edge_index, trunk_input)
        loss = loss_fn(output.to(torch.float32), batch.pressure.to(torch.float32))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    if not batch_losses:
        raise ValueError('training loader yielded no batches')
    return sum(batch_losses) / len(batch_losses)


start = time()

for epoch in range(Epoch):
    mean_loss = train_one_epoch(model, train_set, x_trunk, myloss, opt)
    lr_decay.step()

    if epoch % log_every == 0:
        print('epoch:{}, loss:{:.3f}, time:{:.3f} min'.format(epoch, mean_loss, (time() - start) / 60))

    # Periodic checkpoints, plus one after the final epoch so the trained
    # weights are not lost (with the defaults, epoch Epoch-1 is not a
    # multiple of checkpoint_every).
    is_last_epoch = epoch == Epoch - 1
    if is_last_epoch or (epoch != 0 and epoch % checkpoint_every == 0):
        torch.save(
            model.state_dict(),
            os.path.join(checkpoint_dir, 'model{}-{:.3f}.pth'.format(epoch, mean_loss)),
        )

epoch:0, loss:0.206, time:0.155 min
epoch:1, loss:0.068, time:0.285 min
epoch:2, loss:0.054, time:0.414 min
epoch:3, loss:0.063, time:0.542 min
epoch:4, loss:0.052, time:0.670 min
epoch:5, loss:0.058, time:0.799 min
epoch:6, loss:0.053, time:0.930 min
epoch:7, loss:0.052, time:1.061 min
epoch:8, loss:0.044, time:1.189 min
epoch:9, loss:0.052, time:1.318 min
epoch:10, loss:0.044, time:1.446 min
epoch:11, loss:0.050, time:1.576 min
epoch:12, loss:0.044, time:1.707 min
epoch:13, loss:0.042, time:1.838 min
epoch:14, loss:0.041, time:1.973 min
epoch:15, loss:0.040, time:2.108 min
epoch:16, loss:0.054, time:2.239 min
epoch:17, loss:0.046, time:2.370 min
epoch:18, loss:0.039, time:2.501 min
epoch:19, loss:0.039, time:2.633 min
epoch:20, loss:0.034, time:2.765 min
epoch:21, loss:0.031, time:2.896 min
epoch:22, loss:0.030, time:3.030 min
epoch:23, loss:0.030, time:3.165 min
epoch:24, loss:0.029, time:3.296 min
epoch:25, loss:0.029, time:3.427 min
epoch:26, loss:0.029, time:3.558 min
epoch:27, l