forked from lu-group/mionet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
84 lines (77 loc) · 2.23 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
@author: jpzxshi
"""
import learner as ln
from data import ADVDData
from mionet_periodic import MIONet_periodic
from postprocessing import L2_relative_error
# advection-diffusion system
_NET_TYPES = ('MIONet', 'MIONet_periodic', 'DeepONet')


def main(*, device='gpu', net_type='MIONet_periodic'):
    """Train an operator network on the advection-diffusion (ADVD) dataset.

    Builds an ``ADVDData`` dataset and constructs the network selected by
    ``net_type``. It then trains with ``ln.Brain``, restores the best model,
    calls ``ln.Brain.Output()`` and prints the L2 relative error of the best
    model on ``data``.

    Args:
        device: Device string forwarded to ``ln.Brain.Init``
            ('cpu' or 'gpu').
        net_type: Architecture to train; one of 'MIONet',
            'MIONet_periodic' or 'DeepONet'.

    Raises:
        ValueError: If ``net_type`` is not a supported architecture. This is
            checked before any data is generated, so a typo fails fast.
    """
    if net_type not in _NET_TYPES:
        raise ValueError(
            'Unknown net_type {!r}; expected one of {}'.format(net_type, _NET_TYPES)
        )

    #### data
    sensors1 = 100  # input width of the first branch net (see ``sizes`` below)
    sensors2 = 100  # input width of the second branch net
    mesh = [100, 100]
    p = 100
    length_scale = 0.5
    train_num = 1000
    test_num = 1000

    ##### net
    # NOTE(review): the first entry of each sub-list is the input width of
    # that sub-network. The negative widths (-300 / -248), the 'p' placeholder
    # and the trailing 0 are conventions interpreted by ln.nn.MIONet /
    # MIONet_periodic; their exact meaning is not visible here.
    if net_type == 'MIONet':
        sizes = [
            [sensors1, 300, 300, 300],
            [sensors2, -300],
            [2, 300, 300, 300],
        ]
    elif net_type == 'MIONet_periodic':
        sizes = [
            [sensors1, 248, 248, 248],
            [sensors2, -248],
            ['p', 248, 248, 248],
            [1, 248, 248, 248],
        ]
    else:  # 'DeepONet': both input functions concatenated into one branch
        sizes = [
            [sensors1 + sensors2, 300, 300, 300],
            [2, 300, 300, 0],
        ]
    # All three architectures share the same activation and initializer.
    activation = 'relu'
    initializer = 'Glorot normal'

    ##### training
    lr = 0.0002
    iterations = 100000
    batch_size = None  # presumably full-batch training — confirm in ln.Brain
    print_every = 1000

    data = ADVDData(sensors1, sensors2, mesh, p, length_scale, train_num, test_num)
    # Presumably these re-layout the dataset for the periodic / DeepONet
    # input formats; the plain MIONet uses the data as generated.
    if net_type == 'MIONet_periodic':
        data.trans_to_P()
    elif net_type == 'DeepONet':
        data.trans_to_D()

    # DeepONet is built with ln.nn.MIONet using two sub-networks.
    net_class = MIONet_periodic if net_type == 'MIONet_periodic' else ln.nn.MIONet
    net = net_class(sizes, activation, initializer)

    args = {
        'data': data,
        'net': net,
        'criterion': 'MSE',
        'optimizer': 'adam',
        'lr': lr,
        'iterations': iterations,
        'batch_size': batch_size,
        'print_every': print_every,
        'save': True,
        'callback': None,
        'dtype': 'float',
        'device': device,
    }
    ln.Brain.Init(**args)
    ln.Brain.Run()
    ln.Brain.Restore()
    ln.Brain.Output()
    print('L2 relative error:', L2_relative_error(data, ln.Brain.Best_model()))


if __name__ == '__main__':
    main()