-
Notifications
You must be signed in to change notification settings - Fork 1
/
models.py
101 lines (75 loc) · 2.6 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os

# Set TensorFlow's C++ log level *before* importing TensorFlow so that it also
# applies to messages emitted while TensorFlow itself is being imported; the
# imports below are therefore intentionally placed after this statement.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf  # noqa: E402
from tensorflow.keras import Model  # noqa: E402
from tensorflow.keras.layers import (  # noqa: E402
    Input,
    Concatenate,
    MaxPool1D,
    Dense,
)

from layers import GConv, ConcatAdj  # noqa: E402
def PointSegGCN(cfg, levels=4):
    '''
    Builds the PointSegGCN model from GCN layers joined by skip connections.

    The encoder applies ``levels`` GConv layers (with dropout) and keeps each
    output. The decoder then applies one GConv per saved output, deepest
    first, concatenating the saved output onto its result. Finally the output
    of the very first GConv is concatenated in, and a softmax GConv produces
    ``num_classes`` outputs.

    :param cfg: Model parameters retrieved from cfg file; must provide the
        keys 'n_node_features' (input feature width) and 'num_classes'
        (output width).
    :param levels: Number of encoder GConv layers, each mirrored by one
        decoder GConv + skip concatenation. Defaults to 4, the value that
        was previously hard-coded.
    :return: Built (uncompiled) TF model taking ``[X_in, A_in]``, ready for
        forward pass.
    '''
    n_features = cfg['n_node_features']
    num_classes = cfg['num_classes']

    X_in = Input(shape=(n_features,), name='X_in')
    # Adjacency input, fed as a sparse tensor with a dynamic width.
    A_in = Input(shape=(None,), sparse=True)

    x = GConv(32)([X_in, A_in])
    # Output of the first layer; concatenated back in right before the
    # classification layer.
    x_first = x

    # Encoder: remember every level's output for the matching decoder stage.
    skips = []
    for _ in range(levels):
        x = GConv(32, dropout=True)([x, A_in])
        skips.append(x)

    # Decoder: consume the saved outputs deepest-first.
    for skip in reversed(skips):
        x = GConv(32)([x, A_in])
        x = Concatenate()([x, skip])

    x = Concatenate()([x, x_first])
    output = GConv(num_classes, activation='softmax',
                   kernel_init='glorot_uniform')([x, A_in])
    return Model(inputs=[X_in, A_in], outputs=output, name='PointSegGCN_v1')
def Dense_GCN(cfg, levels=3, *, hidden_activation='softmax'):
    '''
    Builds a Dense GCN model with vertex-wise skip connections and MLP layers.

    At every level the GConv output is stacked (along axis 0) with the
    features saved at the previous level, and the adjacency matrix is grown
    with a block-diagonal concatenation (ConcatAdj) so it stays aligned with
    the stacked rows. After the last level everything is stacked once more,
    max-pooled and passed through a Dense MLP head.

    :param cfg: Model parameters retrieved from cfg file; must provide the
        keys 'n_node_features' (input feature width) and 'num_classes'
        (output width).
    :param levels: No. of hierarchical feature extraction levels.
    :param hidden_activation: Activation of the three hidden Dense layers
        (256/128/64 units). Defaults to 'softmax' to reproduce the original
        model exactly. NOTE(review): softmax on hidden layers is unusual
        ('relu' is the conventional choice) -- confirm this is intended.
    :return: Built (uncompiled) TF model named 'Dense_GCN', ready for
        forward pass.
    '''
    n_features = cfg['n_node_features']
    num_classes = cfg['num_classes']

    # NOTE(review): both inputs pin batch_size=1, i.e. a fixed leading
    # dimension of 1 for X_in -- confirm this matches how graphs are fed.
    X_in = Input(shape=(n_features,), name='X_in', batch_size=1)
    A_in = Input(shape=(None,), name='A_in', sparse=True, batch_size=1)

    a = A_in
    x = GConv(32)([X_in, a])
    x_skips = [x]
    a_skips = [a]

    for i in range(levels):
        # NOTE(review): the second positional argument is presumably
        # `dropout` (cf. GConv(32, dropout=True) in PointSegGCN) -- confirm
        # against the GConv signature.
        x = GConv(32, True)([x, a])
        x = Concatenate(axis=0)([x, x_skips[i]])
        x_skips.append(x)
        # Block-diagonal concatenation for adjacency matrix, keeping `a`
        # aligned with the rows just stacked into `x`.
        a = ConcatAdj()(a, a_skips[i])
        a_skips.append(a)

    x = GConv(32)([x, a])
    x = Concatenate(axis=0)([x, *x_skips])
    for a_skip in a_skips:
        a = ConcatAdj()(a, a_skip)
    x = GConv(32)([x, a])

    # Max pooling kernel size. Assuming GConv keeps the row count and X_in
    # has N rows: x_skips[k] has 2**k * N rows (k = 0..levels), and the final
    # stack is the last GConv output (2**levels * N rows) plus all x_skips,
    # i.e. (2**levels + 2**(levels + 1) - 1) * N = (3 * 2**levels - 1) * N.
    mp_size = int(3 * 2 ** levels - 1)

    # MLP block.
    # NOTE(review): MaxPool1D (strides default to pool_size) pools
    # consecutive windows of mp_size rows of the stacked tensor; these equal
    # "all copies of one node" only when N == 1 -- confirm.
    x = MaxPool1D(pool_size=mp_size,
                  data_format='channels_last')(tf.expand_dims(x, 0))
    for units in (256, 128, 64):
        x = Dense(units, activation=hidden_activation)(x)
    output = Dense(num_classes, activation='softmax')(x)
    return Model(inputs=[X_in, A_in], outputs=output, name='Dense_GCN')