Skip to content

Commit

Permalink
feat: Move updated GravNet model into the repository
Browse files Browse the repository at this point in the history
  • Loading branch information
shahrukhqasim committed Mar 11, 2019
1 parent ab86f0a commit 1cf5988
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 6 deletions.
13 changes: 8 additions & 5 deletions python/models/gravnet_segment.py
@@ -1,13 +1,15 @@
from models.network_segment_interface import NetworkSegmentInterface
import tensorflow as tf
from caloGraphNN import *
from ops.ties import *
from caloGraphNN import layer_global_exchange, high_dim_dense
from ops.ties import layer_GravNet2

class GravnetSegment(NetworkSegmentInterface):
def build_network_segment(self, feat):
x = feat
feat_in = feat
feat_in= high_dim_dense(feat_in, 32, activation=tf.nn.leaky_relu)
# feat_in= high_dim_dense(feat_in, 32, activation=tf.nn.leaky_relu)

all_feats = []
for i in range(4):
x = tf.layers.batch_normalization(x, momentum=0.8, training=self.training)
x = layer_global_exchange(x)
Expand All @@ -19,9 +21,10 @@ def build_network_segment(self, feat):
n_neighbours=30,
n_dimensions=4,
n_filters=64,
n_propagate=64)
n_propagate=32)
all_feats.append(x)

x = tf.concat([x, feat_in], axis=-1)
x = tf.concat(all_feats+[feat_in], axis=-1)
x = high_dim_dense(x, 128, activation=tf.nn.relu)
x = high_dim_dense(x, 128, activation=tf.nn.relu)
return x
35 changes: 34 additions & 1 deletion python/ops/ties.py
Expand Up @@ -57,4 +57,37 @@ def edge_conv_layer(vertices_in, num_neighbors=30,

vertex_out = aggregation_function(edge, axis=2)

return vertex_out
return vertex_out


def layer_GravNet2(vertices_in,
n_neighbours,
n_dimensions,
n_filters,
n_propagate):
vertices_prop = high_dim_dense(vertices_in, n_propagate, activation=None)
neighb_dimensions = high_dim_dense(vertices_in, n_dimensions, activation=None) # BxVxND,

indexing, distance = indexing_tensor(neighb_dimensions, n_neighbours)

net = tf.gather_nd(vertices_prop, indexing) # BxVxNxF

distance_scale = 1 + tf.nn.softmax(-distance)[..., tf.newaxis]

net = distance_scale * net
batch, max_vertices, _, _ = net.shape

net = tf.reduce_mean(net, axis=-1)
# net = tf.reshape(net, shape=(batch, max_vertices, -1))
net = high_dim_dense(net, n_filters, activation=None)
print(net.shape)
return net
0/0
return net
print(net.shape, distance.shape)
0/0

collapsed = collapse_to_vertex(indexing, distance, vertices_prop)
updated_vertices = tf.concat([vertices_in, collapsed], axis=-1)

return high_dim_dense(updated_vertices, n_filters, activation=None)

0 comments on commit 1cf5988

Please sign in to comment.