In [None]:

import tensorflow as tf
#import tf_util
import numpy as np


def attn_feature(input_feature, output_dim, neighbors_idx, activation, in_dropout=0.0, coef_dropout=0.0,
                 is_training=None, bn_decay=None, layer='', k=20, i=0, is_dist=False):
    """Single graph-attention head: attention-weighted sum of encoded edge features.

    Args:
        input_feature: feature tensor; every size-1 axis is squeezed away and a
            size-1 axis is then re-inserted at position -2, so the layer works on
            a 4-D tensor (presumably batch x points x 1 x channels - TODO confirm
            against callers).
        output_dim: number of output channels of the feature encoders.
        neighbors_idx: neighbour indices forwarded to get_neighbors as nn_idx.
        activation: callable applied to the attention-weighted result.
        in_dropout: accepted but not used by the current implementation.
        coef_dropout: accepted but not used by the current implementation.
        is_training: forwarded to the batch-norm layers of every convolution.
        bn_decay: forwarded to the batch-norm layers of every convolution.
        layer: prefix for the variable scopes created by this head.
        k: number of neighbours; used to tile the features along axis 2.
        i: head index, appended to every variable scope name.
        is_dist: forwarded to the convolutions to select the batch-norm variant.

    Returns:
        Tuple (ret, self_attention, edge_feature): the activated attention output,
        the self-attention logits and the encoded edge features. The softmax
        coefficients themselves are not returned.
    """
    static_batch = input_feature.get_shape()[0]

    # tf.squeeze without an axis drops *all* singleton axes, so the batch axis
    # has to be restored by hand when the batch size is 1.
    features = tf.squeeze(input_feature)
    if static_batch == 1:
        features = tf.expand_dims(features, 0)
    # Size-1 axis that is later tiled k times to line up with the neighbours.
    features = tf.expand_dims(features, axis=-2)

    head = str(i)
    # Settings shared by every 1x1 convolution of this head.
    conv_kwargs = dict(padding='VALID', stride=[1, 1], bn=True,
                       is_training=is_training, bn_decay=bn_decay, is_dist=is_dist)

    # Encode each point's own features with a bias-free 1x1 convolution.
    new_feature = conv2d_nobias(features, output_dim, [1, 1],
                                scope=layer + '_newfea_conv_head_' + head, **conv_kwargs)

    # Gather the neighbours, build the edge features (point minus neighbour)
    # and encode them as well.
    neighbors = get_neighbors(features, nn_idx=neighbors_idx, k=k)
    tiled_features = tf.tile(features, [1, 1, k, 1])
    edge_feature = conv2d(tiled_features - neighbors, output_dim, [1, 1],
                          scope=layer + '_edgefea_' + head, **conv_kwargs)

    # Reduce both encodings to a single channel each; their sum is the logit.
    self_attention = conv2d(new_feature, 1, [1, 1],
                            scope=layer + '_self_att_conv_head_' + head, **conv_kwargs)
    neibor_attention = conv2d(edge_feature, 1, [1, 1],
                              scope=layer + '_neib_att_conv_head_' + head, **conv_kwargs)

    # Swap the last two axes so the softmax (last axis) runs over axis 2 of the
    # logits, i.e. the axis that was tiled k times above.
    logits = tf.transpose(a=self_attention + neibor_attention, perm=[0, 1, 3, 2])
    coefs = tf.nn.softmax(tf.nn.leaky_relu(logits))

    # Attention-weighted combination of the encoded edge features. The same
    # activation is applied whether or not is_dist is set.
    ret = activation(tf.matmul(coefs, edge_feature))

    return ret, self_attention, edge_feature

In [None]:
#tf_util.py
  
""" Wrapper functions for TensorFlow layers.
Author: Charles R. Qi
Date: November 2016
Updated by Yue Wang and Yongbin Sun
"""

import numpy as np
import tensorflow as tf
# import lorentz
from math import *
from itertools import combinations


# from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
# import tensorflow.contrib.seq2seq as seq2seq

def _variable_on_cpu(name, shape, initializer, use_fp16=False, trainable=True):
    """Create (or fetch) a variable pinned to the CPU device.

    Args:
        name: name of the variable
        shape: list of ints
        initializer: initializer for the variable
        use_fp16: store the variable as float16 instead of float32
        trainable: whether the variable is marked as trainable
    Returns:
        The variable returned by tf.compat.v1.get_variable.
    """
    var_dtype = tf.float16 if use_fp16 else tf.float32
    with tf.device('/cpu:0'):
        return tf.compat.v1.get_variable(name,
                                         shape,
                                         initializer=initializer,
                                         dtype=var_dtype,
                                         trainable=trainable,
                                         use_resource=False)


def _variable_with_weight_decay(name, shape, stddev, wd, use_xavier=True):
    """Create an initialized CPU variable and optionally register an L2 penalty.

    Args:
        name: name of the variable
        shape: list of ints
        stddev: standard deviation of the truncated normal initializer
            (only used when use_xavier is false)
        wd: multiplier for the L2 loss of the variable, which is added to the
            'losses' collection. If None, no penalty is registered.
        use_xavier: bool, use a uniform fan-average variance-scaling
            initializer instead of the truncated normal one
    Returns:
        The created variable.
    """
    initializer = (
        tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")
        if use_xavier
        else tf.compat.v1.truncated_normal_initializer(stddev=stddev)
    )
    var = _variable_on_cpu(name, shape, initializer)
    if wd is None:
        return var
    # Note: wd == 0.0 still registers a (zero-valued) loss term.
    penalty = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.compat.v1.add_to_collection('losses', penalty)
    return var


def conv1d(inputs,
           num_output_channels,
           kernel_size,
           scope,
           stride=1,
           padding='SAME',
           use_xavier=True,
           stddev=1e-3,
           weight_decay=0.0,
           activation_fn=tf.nn.relu,
           bn=False,
           bn_decay=None,
           is_training=None,
           is_dist=False):
    """ 1D convolution with non-linear operation.
  Args:
    inputs: 3-D tensor variable BxLxC
    num_output_channels: int
    kernel_size: int
    scope: string
    stride: int
    padding: 'SAME' or 'VALID'
    use_xavier: bool, use xavier_initializer if true
    stddev: float, stddev for truncated_normal init
    weight_decay: float
    activation_fn: function, or None to skip the non-linearity
    bn: bool, whether to use batch norm
    bn_decay: float or float tensor variable in [0,1]
    is_training: bool Tensor variable
    is_dist: bool, use the distributed batch-norm variant
  Returns:
    Variable tensor
  """
    with tf.compat.v1.variable_scope(scope) as sc:
        # as_list() yields plain Python ints regardless of whether indexing a
        # TensorShape returns a Dimension object or an int, so no `.value`
        # access is needed (it fails when the entry is already an int).
        num_in_channels = inputs.get_shape().as_list()[-1]
        kernel_shape = [kernel_size,
                        num_in_channels, num_output_channels]
        kernel = _variable_with_weight_decay('weights',
                                             shape=kernel_shape,
                                             use_xavier=use_xavier,
                                             stddev=stddev,
                                             wd=weight_decay)
        outputs = tf.nn.conv1d(input=inputs, filters=kernel,
                               stride=stride,
                               padding=padding)
        biases = _variable_on_cpu('biases', [num_output_channels],
                                  tf.compat.v1.constant_initializer(0.0))
        outputs = tf.nn.bias_add(outputs, biases)

        if bn:
            outputs = batch_norm_for_conv1d(outputs, is_training,
                                            bn_decay=bn_decay, scope='bn', is_dist=is_dist)

        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return outputs


def conv2d_nobias(inputs, num_output_channels, kernel_size, scope,
                  stride=[1, 1], padding='SAME', use_xavier=True, stddev=1e-3,
                  weight_decay=0.0, activation_fn=tf.nn.relu, bn=False,
                  bn_decay=None, is_training=None, is_dist=False):
    """Bias-free 2D convolution followed by optional batch norm and activation.

    Args:
        inputs: 4-D tensor variable BxHxWxC
        num_output_channels: int
        kernel_size: a list of 2 ints
        scope: string, variable scope holding the 'weights' variable
        stride: a list of 2 ints
        padding: 'SAME' or 'VALID'
        use_xavier: bool, use xavier_initializer if true
        stddev: float, stddev for truncated_normal init
        weight_decay: float
        activation_fn: function, or None to skip the non-linearity
        bn: bool, whether to use batch norm
        bn_decay: float or float tensor variable in [0,1]
        is_training: bool Tensor variable
        is_dist: bool, use the distributed batch-norm variant
    Returns:
        Variable tensor
    """
    with tf.compat.v1.variable_scope(scope):
        k_h, k_w = kernel_size
        in_channels = inputs.get_shape()[-1]
        weights = _variable_with_weight_decay(
            'weights',
            shape=[k_h, k_w, in_channels, num_output_channels],
            use_xavier=use_xavier,
            stddev=stddev,
            wd=weight_decay)
        s_h, s_w = stride
        result = tf.nn.conv2d(input=inputs, filters=weights,
                              strides=[1, s_h, s_w, 1], padding=padding)

        # No bias term here, unlike conv2d.
        if bn:
            result = batch_norm_for_conv2d(result, is_training,
                                           bn_decay=bn_decay, scope='bn', is_dist=is_dist)

        return result if activation_fn is None else activation_fn(result)


def get_neighbors(point_cloud, nn_idx, k=20):
    """Gather the features of every point's neighbours.

    Args:
        point_cloud: (batch_size, num_points, 1, num_dims)
        nn_idx: (batch_size, num_points, k)
        k: int (not used by the implementation; the neighbour count comes
            from nn_idx)
    Returns:
        neighbors features: (batch_size, num_points, k, num_dims)
    """
    input_dims = point_cloud.get_shape().as_list()
    points = tf.squeeze(point_cloud)
    # tf.squeeze removes every singleton axis; restore the batch and feature
    # axes when they were singletons themselves.
    if input_dims[0] == 1:
        points = tf.expand_dims(points, 0)
    if input_dims[-1] == 1:
        points = tf.expand_dims(points, -1)

    squeezed_shape = points.get_shape()
    batch_size = squeezed_shape[0]
    num_points = squeezed_shape[1]
    num_dims = squeezed_shape[2]

    # Offset of each batch element inside the flattened point list, so the
    # per-cloud indices in nn_idx address the right rows after flattening.
    batch_offsets = tf.reshape(tf.range(batch_size) * num_points, [batch_size, 1, 1])

    flat_points = tf.reshape(points, [-1, num_dims])
    return tf.gather(flat_points, nn_idx + batch_offsets)


def conv2d(inputs, num_output_channels, kernel_size, scope,
           stride=[1, 1], padding='SAME', use_xavier=True, stddev=1e-3,
           weight_decay=0.0, activation_fn=tf.nn.relu, bn=False,
           bn_decay=None, is_training=None, is_dist=False):
    """Biased 2D convolution followed by optional batch norm and activation.

    Args:
        inputs: 4-D tensor variable BxHxWxC
        num_output_channels: int
        kernel_size: a list of 2 ints
        scope: string, variable scope holding 'weights' and 'biases'
        stride: a list of 2 ints
        padding: 'SAME' or 'VALID'
        use_xavier: bool, use xavier_initializer if true
        stddev: float, stddev for truncated_normal init
        weight_decay: float
        activation_fn: function, or None to skip the non-linearity
        bn: bool, whether to use batch norm
        bn_decay: float or float tensor variable in [0,1]
        is_training: bool Tensor variable
        is_dist: bool, use the distributed batch-norm variant
    Returns:
        Variable tensor
    """
    with tf.compat.v1.variable_scope(scope):
        k_h, k_w = kernel_size
        in_channels = inputs.get_shape()[-1]
        weights = _variable_with_weight_decay(
            'weights',
            shape=[k_h, k_w, in_channels, num_output_channels],
            use_xavier=use_xavier,
            stddev=stddev,
            wd=weight_decay)
        s_h, s_w = stride
        result = tf.nn.conv2d(input=inputs, filters=weights,
                              strides=[1, s_h, s_w, 1], padding=padding)
        bias = _variable_on_cpu('biases', [num_output_channels],
                                tf.compat.v1.constant_initializer(0.0))
        result = tf.nn.bias_add(result, bias)

        if bn:
            result = batch_norm_for_conv2d(result, is_training,
                                           bn_decay=bn_decay, scope='bn', is_dist=is_dist)

        return result if activation_fn is None else activation_fn(result)


def conv2d_transpose(inputs,
                     num_output_channels,
                     kernel_size,
                     scope,
                     stride=[1, 1],
                     padding='SAME',
                     use_xavier=True,
                     stddev=1e-3,
                     weight_decay=0.0,
                     activation_fn=tf.nn.relu,
                     bn=False,
                     bn_decay=None,
                     is_training=None,
                     is_dist=False):
    """ 2D convolution transpose with non-linear operation.
  Args:
    inputs: 4-D tensor variable BxHxWxC; batch, height and width may be
        statically unknown, in which case the output shape is computed at
        run time
    num_output_channels: int
    kernel_size: a list of 2 ints
    scope: string
    stride: a list of 2 ints
    padding: 'SAME' or 'VALID'
    use_xavier: bool, use xavier_initializer if true
    stddev: float, stddev for truncated_normal init
    weight_decay: float
    activation_fn: function, or None to skip the non-linearity
    bn: bool, whether to use batch norm
    bn_decay: float or float tensor variable in [0,1]
    is_training: bool Tensor variable
    is_dist: bool, use the distributed batch-norm variant
  Returns:
    Variable tensor
  Note: conv2d(conv2d_transpose(a, num_out, ksize, stride), a.shape[-1], ksize, stride) == a
  """
    with tf.compat.v1.variable_scope(scope) as sc:
        kernel_h, kernel_w = kernel_size
        num_in_channels = inputs.get_shape()[-1]
        kernel_shape = [kernel_h, kernel_w,
                        num_output_channels, num_in_channels]  # reversed to conv2d
        kernel = _variable_with_weight_decay('weights',
                                             shape=kernel_shape,
                                             use_xavier=use_xavier,
                                             stddev=stddev,
                                             wd=weight_decay)
        stride_h, stride_w = stride

        static_shape = inputs.get_shape().as_list()

        def dim_or_dynamic(axis):
            # Prefer the static size; fall back to the run-time size when the
            # dimension is unknown (None) at graph-construction time.
            static_dim = static_shape[axis]
            if static_dim is None:
                return tf.shape(input=inputs)[axis]
            return static_dim

        # from slim.convolution2d_transpose; works for ints and scalar tensors
        def get_deconv_dim(dim_size, stride_size, kernel_size, padding):
            dim_size = dim_size * stride_size

            if padding == 'VALID':
                dim_size = dim_size + max(kernel_size - stride_size, 0)
            return dim_size

        # calculate output shape
        batch_size = dim_or_dynamic(0)
        out_height = get_deconv_dim(dim_or_dynamic(1), stride_h, kernel_h, padding)
        out_width = get_deconv_dim(dim_or_dynamic(2), stride_w, kernel_w, padding)
        output_shape = [batch_size, out_height, out_width, num_output_channels]
        if any(tf.is_tensor(dim) for dim in output_shape):
            # Mixed static/dynamic entries: pack them into one 1-D shape tensor.
            output_shape = tf.stack(output_shape)

        outputs = tf.nn.conv2d_transpose(inputs, kernel, output_shape,
                                         [1, stride_h, stride_w, 1],
                                         padding=padding)
        biases = _variable_on_cpu('biases', [num_output_channels],
                                  tf.compat.v1.constant_initializer(0.0))
        outputs = tf.nn.bias_add(outputs, biases)

        if bn:
            outputs = batch_norm_for_conv2d(outputs, is_training,
                                            bn_decay=bn_decay, scope='bn', is_dist=is_dist)

        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return outputs


def conv3d(inputs, num_output_channels, kernel_size, scope,
           stride=[1, 1, 1], padding='SAME', use_xavier=True, stddev=1e-3,
           weight_decay=0.0, activation_fn=tf.nn.relu, bn=False,
           bn_decay=None, is_training=None, is_dist=False):
    """Biased 3D convolution followed by optional batch norm and activation.

    Args:
        inputs: 5-D tensor variable BxDxHxWxC
        num_output_channels: int
        kernel_size: a list of 3 ints
        scope: string, variable scope holding 'weights' and 'biases'
        stride: a list of 3 ints
        padding: 'SAME' or 'VALID'
        use_xavier: bool, use xavier_initializer if true
        stddev: float, stddev for truncated_normal init
        weight_decay: float
        activation_fn: function, or None to skip the non-linearity
        bn: bool, whether to use batch norm
        bn_decay: float or float tensor variable in [0,1]
        is_training: bool Tensor variable
        is_dist: bool, use the distributed batch-norm variant
    Returns:
        Variable tensor
    """
    with tf.compat.v1.variable_scope(scope):
        k_d, k_h, k_w = kernel_size
        in_channels = inputs.get_shape()[-1]
        weights = _variable_with_weight_decay(
            'weights',
            shape=[k_d, k_h, k_w, in_channels, num_output_channels],
            use_xavier=use_xavier,
            stddev=stddev,
            wd=weight_decay)
        s_d, s_h, s_w = stride
        result = tf.nn.conv3d(inputs, weights, [1, s_d, s_h, s_w, 1], padding=padding)
        bias = _variable_on_cpu('biases', [num_output_channels],
                                tf.compat.v1.constant_initializer(0.0))
        result = tf.nn.bias_add(result, bias)

        if bn:
            result = batch_norm_for_conv3d(result, is_training,
                                           bn_decay=bn_decay, scope='bn', is_dist=is_dist)

        return result if activation_fn is None else activation_fn(result)


def fully_connected(inputs, num_outputs, scope,
                    use_xavier=True, stddev=1e-3, weight_decay=0.0,
                    activation_fn=tf.nn.relu, bn=False, bn_decay=None,
                    is_training=None, is_dist=False):
    """Dense layer: matmul plus bias, then optional batch norm and activation.

    Args:
        inputs: 2-D tensor BxN
        num_outputs: int
        scope: string, variable scope holding 'weights' and 'biases'
        use_xavier: bool, use xavier_initializer if true
        stddev: float, stddev for truncated_normal init
        weight_decay: float
        activation_fn: function, or None to skip the non-linearity. When it
            is tf.nn.softmax, the row-wise maximum is subtracted first.
        bn: bool, whether to use batch norm
        bn_decay: float or float tensor variable in [0,1]
        is_training: bool Tensor variable
        is_dist: bool, use the distributed batch-norm variant

    Returns:
        Variable tensor of size B x num_outputs.
    """
    with tf.compat.v1.variable_scope(scope):
        in_units = inputs.get_shape()[-1]
        weights = _variable_with_weight_decay('weights',
                                              shape=[in_units, num_outputs],
                                              use_xavier=use_xavier,
                                              stddev=stddev,
                                              wd=weight_decay)
        result = tf.matmul(inputs, weights)
        bias = _variable_on_cpu('biases', [num_outputs],
                                tf.compat.v1.constant_initializer(0.0))
        result = tf.nn.bias_add(result, bias)

        if bn:
            result = batch_norm_for_fc(result, is_training, bn_decay, 'bn', is_dist=is_dist)

        if activation_fn is None:
            return result
        if activation_fn == tf.nn.softmax:
            # Shift each row by its maximum before the softmax.
            result = result - tf.reduce_max(input_tensor=result, axis=1, keepdims=True)
        return activation_fn(result)


def max_pool2d(inputs, kernel_size, scope, stride=[2, 2], padding='VALID'):
    """Max pooling over the two middle axes of a 4-D tensor.

    Args:
        inputs: 4-D tensor BxHxWxC
        kernel_size: a list of 2 ints
        scope: string, variable scope whose name is given to the pooling op
        stride: a list of 2 ints
        padding: 'SAME' or 'VALID'

    Returns:
        Variable tensor
    """
    with tf.compat.v1.variable_scope(scope) as pool_scope:
        k_h, k_w = kernel_size
        s_h, s_w = stride
        return tf.nn.max_pool2d(input=inputs,
                                ksize=[1, k_h, k_w, 1],
                                strides=[1, s_h, s_w, 1],
                                padding=padding,
                                name=pool_scope.name)


def avg_pool2d(inputs, kernel_size, scope, stride=[2, 2], padding='VALID'):
    """Average pooling over the two middle axes of a 4-D tensor.

    Args:
        inputs: 4-D tensor BxHxWxC
        kernel_size: a list of 2 ints
        scope: string, variable scope whose name is given to the pooling op
        stride: a list of 2 ints
        padding: 'SAME' or 'VALID'

    Returns:
        Variable tensor
    """
    with tf.compat.v1.variable_scope(scope) as pool_scope:
        k_h, k_w = kernel_size
        s_h, s_w = stride
        return tf.nn.avg_pool2d(input=inputs,
                                ksize=[1, k_h, k_w, 1],
                                strides=[1, s_h, s_w, 1],
                                padding=padding,
                                name=pool_scope.name)


def max_pool3d(inputs, kernel_size, scope, stride=[2, 2, 2], padding='VALID'):
    """Max pooling over the three middle axes of a 5-D tensor.

    Args:
        inputs: 5-D tensor BxDxHxWxC
        kernel_size: a list of 3 ints
        scope: string, variable scope whose name is given to the pooling op
        stride: a list of 3 ints
        padding: 'SAME' or 'VALID'

    Returns:
        Variable tensor
    """
    with tf.compat.v1.variable_scope(scope) as pool_scope:
        k_d, k_h, k_w = kernel_size
        s_d, s_h, s_w = stride
        return tf.nn.max_pool3d(inputs,
                                ksize=[1, k_d, k_h, k_w, 1],
                                strides=[1, s_d, s_h, s_w, 1],
                                padding=padding,
                                name=pool_scope.name)


def avg_pool3d(inputs, kernel_size, scope, stride=[2, 2, 2], padding='VALID'):
    """Average pooling over the three middle axes of a 5-D tensor.

    Args:
        inputs: 5-D tensor BxDxHxWxC
        kernel_size: a list of 3 ints
        scope: string, variable scope whose name is given to the pooling op
        stride: a list of 3 ints
        padding: 'SAME' or 'VALID'

    Returns:
        Variable tensor
    """
    with tf.compat.v1.variable_scope(scope) as pool_scope:
        k_d, k_h, k_w = kernel_size
        s_d, s_h, s_w = stride
        return tf.nn.avg_pool3d(inputs,
                                ksize=[1, k_d, k_h, k_w, 1],
                                strides=[1, s_d, s_h, s_w, 1],
                                padding=padding,
                                name=pool_scope.name)


def batch_norm_template(inputs, is_training, scope, moments_dims, bn_decay):
    """ Batch normalization on convolutional maps and beyond...
  Ref.: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow

  During training the current batch statistics are used for normalization and
  an exponential moving average of them is updated; otherwise the stored
  moving averages are used.

  Args:
      inputs:        Tensor, k-D input ... x C could be BC or BHWC or BDHWC
      is_training:   boolean tf.Variable, true indicates training phase
                     (used as the predicate of tf.cond)
      scope:         string, variable scope
      moments_dims:  a list of ints, indicating dimensions for moments calculation
      bn_decay:      float or float tensor variable, controlling moving average weight;
                     None falls back to 0.9
  Return:
      normed:        batch-normalized maps
  """
    with tf.compat.v1.variable_scope(scope) as sc:
        num_channels = inputs.get_shape()[-1]
        # Learnable per-channel shift (beta, starts at 0) and scale (gamma, starts at 1).
        beta = tf.Variable(tf.constant(0.0, shape=[num_channels]),
                           name='beta', trainable=True)
        gamma = tf.Variable(tf.constant(1.0, shape=[num_channels]),
                            name='gamma', trainable=True)
        # Statistics of the current batch, reduced over moments_dims.
        batch_mean, batch_var = tf.nn.moments(x=inputs, axes=moments_dims, name='moments')
        decay = bn_decay if bn_decay is not None else 0.9
        ema = tf.train.ExponentialMovingAverage(decay=decay)
        # Operator that maintains moving averages of variables.
        # It is a no-op when is_training is false.
        ema_apply_op = tf.cond(pred=is_training,
                               true_fn=lambda: ema.apply([batch_mean, batch_var]),
                               false_fn=lambda: tf.no_op())

        # Update moving average and return current batch's avg and var.
        # The control dependency ties the update to reading the statistics.
        def mean_var_with_update():
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        # ema.average returns the Variable holding the average of var.
        # Training: batch statistics (plus EMA update); otherwise: the stored averages.
        mean, var = tf.cond(pred=is_training,
                            true_fn=mean_var_with_update,
                            false_fn=lambda: (ema.average(batch_mean), ema.average(batch_var)))
        # print('types again', inputs.dtype, mean.dtype, var.dtype, beta.dtype, gamma.dtype)
        # normed = tf.nn.batch_normalization(inputs, mean, var, tf.cast(beta,tf.float64), tf.cast(gamma,tf.float64), 1e-3)
        # The final positional argument (1e-3) is the variance epsilon.
        normed = tf.nn.batch_normalization(inputs, mean, var, beta, gamma, 1e-3)
    return normed


def batch_norm_dist_template(inputs, is_training, scope, moments_dims, bn_decay):
    """Batch normalization variant used for distributed training.

    Keeps explicit, non-trainable population statistics ('pop_mean' and
    'pop_var') that are updated with a decayed average while training and
    used for normalization otherwise.

    Args:
        inputs:        Tensor, k-D input ... x C could be BC or BHWC or BDHWC
        is_training:   boolean tf.Variable, true indicates training phase
        scope:         string, variable scope
        moments_dims:  a list of ints, indicating dimensions for moments calculation
        bn_decay:      float or float tensor variable, controlling moving average
                       weight; None falls back to 0.9
    Return:
        normed:        batch-normalized maps
    """
    epsilon = 1e-3

    with tf.compat.v1.variable_scope(scope):
        channel_shape = [inputs.get_shape()[-1]]

        # Learnable shift / scale.
        beta = _variable_on_cpu('beta', channel_shape,
                                initializer=tf.compat.v1.zeros_initializer())
        gamma = _variable_on_cpu('gamma', channel_shape,
                                 initializer=tf.compat.v1.ones_initializer())

        # Running population statistics; excluded from training.
        pop_mean = _variable_on_cpu('pop_mean', channel_shape,
                                    initializer=tf.compat.v1.zeros_initializer(), trainable=False)
        pop_var = _variable_on_cpu('pop_var', channel_shape,
                                   initializer=tf.compat.v1.ones_initializer(), trainable=False)

        def normalize_with_batch_stats():
            # Training branch: normalize with the batch moments and fold them
            # into the population statistics.
            batch_mean, batch_var = tf.nn.moments(x=inputs, axes=moments_dims, name='moments')
            decay = 0.9 if bn_decay is None else bn_decay
            update_mean = tf.compat.v1.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
            update_var = tf.compat.v1.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
            # Force both assignments to run whenever the normalized output is evaluated.
            with tf.control_dependencies([update_mean, update_var]):
                return tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta, gamma, epsilon)

        def normalize_with_population_stats():
            # Inference branch: rely on the stored population statistics.
            return tf.nn.batch_normalization(inputs, pop_mean, pop_var, beta, gamma, epsilon)

        return tf.cond(pred=is_training,
                       true_fn=normalize_with_batch_stats,
                       false_fn=normalize_with_population_stats)


def batch_norm_for_fc(inputs, is_training, bn_decay, scope, is_dist=False):
    """Batch normalization for fully connected activations.

    Moments are computed over axis 0 only.

    Args:
        inputs:      Tensor, 2D BxC input
        is_training: boolean tf.Variable, true indicates training phase
        bn_decay:    float or float tensor variable, controlling moving average weight
        scope:       string, variable scope
        is_dist:     true selects the distributed training template
    Return:
        normed:      batch-normalized maps
    """
    template = batch_norm_dist_template if is_dist else batch_norm_template
    return template(inputs, is_training, scope, [0, ], bn_decay)


def batch_norm_for_conv1d(inputs, is_training, bn_decay, scope, is_dist=False):
    """Batch normalization for 1D convolutional maps.

    Moments are computed over axes 0 and 1.

    Args:
        inputs:      Tensor, 3D BLC input maps
        is_training: boolean tf.Variable, true indicates training phase
        bn_decay:    float or float tensor variable, controlling moving average weight
        scope:       string, variable scope
        is_dist:     true selects the distributed training template
    Return:
        normed:      batch-normalized maps
    """
    template = batch_norm_dist_template if is_dist else batch_norm_template
    return template(inputs, is_training, scope, [0, 1], bn_decay)


def batch_norm_for_conv2d(inputs, is_training, bn_decay, scope, is_dist=False):
    """Batch normalization for 2D convolutional maps.

    Moments are computed over axes 0, 1 and 2.

    Args:
        inputs:      Tensor, 4D BHWC input maps
        is_training: boolean tf.Variable, true indicates training phase
        bn_decay:    float or float tensor variable, controlling moving average weight
        scope:       string, variable scope
        is_dist:     true selects the distributed training template
    Return:
        normed:      batch-normalized maps
    """
    template = batch_norm_dist_template if is_dist else batch_norm_template
    return template(inputs, is_training, scope, [0, 1, 2], bn_decay)


def batch_norm_for_conv3d(inputs, is_training, bn_decay, scope, is_dist=False):
    """Batch normalization for 3D convolutional maps.

    Moments are computed over axes 0, 1, 2 and 3.

    Args:
        inputs:      Tensor, 5D BDHWC input maps
        is_training: boolean tf.Variable, true indicates training phase
        bn_decay:    float or float tensor variable, controlling moving average weight
        scope:       string, variable scope
        is_dist:     true selects the distributed training template
    Return:
        normed:      batch-normalized maps
    """
    template = batch_norm_dist_template if is_dist else batch_norm_template
    return template(inputs, is_training, scope, [0, 1, 2, 3], bn_decay)


def dropout(inputs, is_training, scope, keep_prob=0.5, noise_shape=None):
    """Dropout that is only active while training.

    Args:
        inputs: tensor
        is_training: boolean tf.Variable, predicate selecting the dropout branch
        scope: string
        keep_prob: float in [0,1]; `1 - keep_prob` is what gets handed to
            tf.nn.dropout as its second positional argument
        noise_shape: list of ints, forwarded to tf.nn.dropout
    Returns:
        tensor variable
    """
    def _with_dropout():
        return tf.nn.dropout(inputs, 1 - (keep_prob), noise_shape)

    def _passthrough():
        # Outside of training the inputs are returned untouched.
        return inputs

    with tf.compat.v1.variable_scope(scope):
        return tf.cond(pred=is_training, true_fn=_with_dropout, false_fn=_passthrough)


def pairwise_distance(point_cloud):
    """Squared pairwise distances between all points of each cloud.

    Uses the expansion |a|^2 - 2 a.b + |b|^2 (the result is not square-rooted).

    Args:
        point_cloud: tensor (batch_size, num_points, num_dims)
    Returns:
        pairwise distance: (batch_size, num_points, num_points)
    """
    static_batch = point_cloud.get_shape().as_list()[0]
    points = tf.squeeze(point_cloud)
    if static_batch == 1:
        # tf.squeeze also removed the batch axis; put it back in front.
        points = tf.expand_dims(points, 0)

    points_t = tf.transpose(a=points, perm=[0, 2, 1])
    # -2 * a.b for every pair of points.
    cross_term = -2 * tf.matmul(points, points_t)
    # |a|^2 per point, kept as a column so it broadcasts across rows.
    squared_norm = tf.reduce_sum(input_tensor=tf.square(points), axis=-1, keepdims=True)
    squared_norm_t = tf.transpose(a=squared_norm, perm=[0, 2, 1])
    return squared_norm + cross_term + squared_norm_t


def pairwise_distanceR(point_cloud):
    """Compute a pairwise squared distance in the eta-phi plane.

    Only the first two features enter the distance itself. The third feature
    marks zero-padded entries: wherever it is exactly 0, a penalty of 1000 is
    added to that point's row and to its column.

    Args:
        point_cloud: tensor (batch_size, num_points, num_dims) with at least
            three features; feature indices 0, 1 and 2 are read.
            IMPORTANT: the first two features should be ordered (eta, phi).

    Returns:
        Tensor (batch_size, num_points, num_points).
    """
    had_single_batch = point_cloud.get_shape().as_list()[0] == 1
    cloud = tf.squeeze(point_cloud)
    if had_single_batch:
        # tf.squeeze also removed a batch axis of size 1; put it back.
        cloud = tf.expand_dims(cloud, 0)

    # Padding penalty: 1000 where the third feature is 0, otherwise 0.
    third = tf.expand_dims(cloud[:, :, 2], -1)
    padding_penalty = 1000 * tf.compat.v1.where(tf.equal(third, 0),
                                                tf.ones_like(third),
                                                tf.zeros_like(third))
    padding_penalty_t = tf.transpose(a=padding_penalty, perm=[0, 2, 1])

    eta_phi = cloud[:, :, :2]
    eta_phi_t = tf.transpose(a=eta_phi, perm=[0, 2, 1])

    # |phi_i - phi_j| for every pair, built by tiling the phi row and
    # subtracting its transpose.
    phi_rows = eta_phi_t[:, 1:, :]
    phi_rows = tf.tile(phi_rows, [1, phi_rows.get_shape()[2], 1])
    delta_phi = tf.abs(phi_rows - tf.transpose(a=phi_rows, perm=[0, 2, 1]))

    # Where |dphi| >= 2*pi, adding 4*pi^2 - 4*pi*|dphi| to dphi^2 gives
    # (|dphi| - 2*pi)^2; everywhere else the correction is dphi - dphi.
    wraps_around = tf.greater_equal(tf.abs(delta_phi), 2 * np.pi)
    phi_correction = tf.compat.v1.where(wraps_around,
                                        4 * np.pi ** 2 - 4 * np.pi * delta_phi,
                                        delta_phi - delta_phi)

    cross_term = -2 * tf.matmul(eta_phi, eta_phi_t)
    squared_norms = tf.reduce_sum(input_tensor=tf.square(eta_phi), axis=-1, keepdims=True)
    squared_norms_t = tf.transpose(a=squared_norms, perm=[0, 2, 1])

    return squared_norms + cross_term + squared_norms_t + phi_correction + padding_penalty + padding_penalty_t
    # return point_shift


def knn(adj_matrix, k=20):
    """Pick the k nearest neighbours of every point from a distance matrix.

    Args:
        adj_matrix: pairwise distances, (batch_size, num_points, num_points).
        k: int, number of neighbours to keep.

    Returns:
        Neighbour indices, (batch_size, num_points, k).
    """
    # top_k keeps the largest entries, so negate the distances first.
    nearest = tf.nn.top_k(-adj_matrix, k=k)
    return nearest.indices


def get_edge_feature(point_cloud, nn_idx, k=20, edge_type='dgcnn'):
    """Construct the edge feature of every point from its k neighbours.

    Args:
        point_cloud: (batch_size, num_points, 1, num_dims)
        nn_idx: (batch_size, num_points, k)
        k: int
        edge_type: one of 'dgcnn', 'sub' or 'add'.

    Returns:
        edge features, (batch_size, num_points, k, C). The central features
        always come first along the last axis, followed by:
          'dgcnn' -> neighbour minus central, so C = 2 * num_dims;
          'sub'   -> the output of Sub_Cloud;
          'add'   -> the output of Add_Cloud.

    Raises:
        ValueError: if edge_type is not one of the supported names.
    """
    # Fail early with a clear message instead of reaching the return with an
    # unbound `edge_feature`.
    if edge_type not in ("dgcnn", "sub", "add"):
        raise ValueError(
            "Unknown edge_type {0!r}; expected 'dgcnn', 'sub' or 'add'".format(edge_type))

    og_batch_size = point_cloud.get_shape().as_list()[0]
    point_cloud = tf.squeeze(point_cloud)
    if og_batch_size == 1:
        # tf.squeeze also removed a batch axis of size 1; put it back.
        point_cloud = tf.expand_dims(point_cloud, 0)

    point_cloud_shape = point_cloud.get_shape()
    batch_size = point_cloud_shape[0]
    num_points = point_cloud_shape[1]
    num_dims = point_cloud_shape[2]

    # Offset the per-cloud neighbour indices so they address the flattened
    # (batch_size * num_points, num_dims) table.
    idx_ = tf.range(batch_size) * num_points
    idx_ = tf.reshape(idx_, [batch_size, 1, 1])
    point_cloud_flat = tf.reshape(point_cloud, [-1, num_dims])
    point_cloud_neighbors = tf.gather(point_cloud_flat, nn_idx + idx_)

    # Every branch needs the central point repeated k times so it lines up
    # with the neighbours. The 'add' branch used to skip this, which left the
    # final concat with mismatched neighbour axes (1 vs k).
    point_cloud_central = tf.expand_dims(point_cloud, axis=-2)
    point_cloud_central = tf.tile(point_cloud_central, [1, 1, k, 1])

    if edge_type == "dgcnn":
        edge_feature = point_cloud_neighbors - point_cloud_central
    elif edge_type == "sub":
        # 4 vector difference with invariant mass sum
        edge_feature = Sub_Cloud(point_cloud_central, point_cloud_neighbors)
    else:  # edge_type == "add"
        # 4 vector sum with invariant mass sum
        edge_feature = Add_Cloud(point_cloud_central, point_cloud_neighbors)

    return tf.concat([point_cloud_central, edge_feature], axis=-1)


def Sub_Cloud(central, neighbors):
    """Difference-based edge features with an invariant-mass-like term.

    Args:
        central: BxPxKxF tensor, the central point repeated for its K neighbours.
        neighbors: BxPxKxF tensor of neighbour features.

    Returns:
        BxPxKxF tensor laid out along the last axis as:
          0, 1, 2 -> central - neighbors (index 1 has 2*pi subtracted
                     wherever its magnitude is >= 2*pi);
          3       -> sqrt(|s g s^T|), with s = features 4:8 of
                     (central + neighbors) and g = diag(1, -1, -1, -1);
          4...    -> central - neighbors for the remaining features.
    """
    num_batch, num_point, num_k = (central.get_shape()[axis] for axis in range(3))
    delta = central - neighbors

    # Metric diag(1, -1, -1, -1), repeated for every (batch, point, neighbour).
    metric = -np.identity(4, dtype=np.float32)  # 4vector
    metric[0][0] = 1
    metric = tf.tile(metric[np.newaxis, np.newaxis, np.newaxis],
                     [num_batch, num_point, num_k, 1, 1])

    pair_sum = neighbors + central
    # first neighbor is the point itself, so its slot takes the difference
    pair_sum = tf.concat([delta[:, :, 0:1, :], pair_sum[:, :, 1:, :]], -2)
    four_vec = tf.expand_dims(pair_sum[:, :, :, 4:8], -2)
    four_vec_t = tf.transpose(a=four_vec, perm=[0, 1, 2, 4, 3])
    mass = tf.sqrt(tf.abs(tf.matmul(tf.matmul(four_vec, metric), four_vec_t)))

    # Correct phi for 2pi bound
    dphi = delta[:, :, :, 1:2]
    dphi = tf.compat.v1.where(tf.greater_equal(tf.abs(dphi), 2 * np.pi),
                              dphi - 2 * np.pi,
                              dphi)

    return tf.concat([delta[:, :, :, 0:1],
                      dphi,
                      delta[:, :, :, 2:3],
                      tf.squeeze(mass, axis=-2),
                      delta[:, :, :, 4:]], axis=-1)


def Add_Cloud(central, neighbors):
    """Sum-based edge features with an invariant-mass-like term.

    Args:
        central: BxPxKxF tensor, the central point repeated for its K neighbours.
        neighbors: BxPxKxF tensor of neighbour features.

    Returns:
        BxPxKx5 tensor: index 0 is sqrt(|s g s^T|) with g = diag(1, -1, -1, -1),
        indices 1-4 are s itself, where s = features 4:8 of
        (central + neighbors). For the first neighbour slot, s is taken from
        (central - neighbors) instead.
    """
    num_batch = central.get_shape()[0]
    num_point = central.get_shape()[1]
    num_k = central.get_shape()[2]
    # `point_diff` was referenced below without ever being defined, so the
    # function raised NameError. Define it as Add_3VecCloud does.
    point_diff = central - neighbors

    # Metric diag(1, -1, -1, -1), repeated for every (batch, point, neighbour).
    identity = -np.identity(4, dtype=np.float32)  # 4vector
    identity[0][0] = 1
    identity = tf.expand_dims(identity, 0)
    identity = tf.expand_dims(identity, 0)
    identity = tf.expand_dims(identity, 0)
    identity = tf.tile(identity, [num_batch, num_point, num_k, 1, 1])

    sum_vec = neighbors + central
    sum_vec = tf.concat([point_diff[:, :, 0:1, :], sum_vec[:, :, 1:, :]], -2)  # first neighbor is the point itself
    sum_vec = sum_vec[:, :, :, 4:8]
    point_sum = sum_vec
    sum_vec = tf.expand_dims(sum_vec, -2)
    sum_vec_T = tf.transpose(a=sum_vec, perm=[0, 1, 2, 4, 3])
    mult = tf.matmul(sum_vec, identity)
    mult = tf.matmul(mult, sum_vec_T)
    mult = tf.sqrt(tf.abs(mult))

    diff_update = tf.squeeze(mult, axis=-2)
    diff_update = tf.concat([diff_update, point_sum], axis=-1)
    return diff_update


def Add_3VecCloud(central, neighbors):
    """Sum features 4:8 of central and neighbour and prepend a mass-like term.

    Args:
        central: BxPxKxF tensor.
        neighbors: BxPxKxF tensor.

    Returns:
        BxPxKx5 tensor: index 0 is sqrt(|s g s^T|) with g = diag(1, -1, -1, -1),
        indices 1-4 are s, where s = features 4:8 of (central + neighbors).
        For the first neighbour slot, s comes from (central - neighbors).
    """
    num_batch, num_point, num_k = (central.get_shape()[axis] for axis in range(3))
    delta = central - neighbors

    # Metric diag(1, -1, -1, -1), repeated for every (batch, point, neighbour).
    metric = -np.identity(4, dtype=np.float32)  # 4vector
    metric[0][0] = 1
    metric = tf.tile(metric[np.newaxis, np.newaxis, np.newaxis],
                     [num_batch, num_point, num_k, 1, 1])

    pair_sum = neighbors + central
    pair_sum = tf.concat([delta[:, :, 0:1, :], pair_sum[:, :, 1:, :]], -2)
    point_sum = pair_sum[:, :, :, 4:8]

    four_vec = tf.expand_dims(point_sum, -2)
    four_vec_t = tf.transpose(a=four_vec, perm=[0, 1, 2, 4, 3])
    mass = tf.sqrt(tf.abs(tf.matmul(tf.matmul(four_vec, metric), four_vec_t)))

    return tf.concat([tf.squeeze(mass, axis=-2), point_sum], axis=-1)


# def seq2seq_with_attention(inputs,
#         hidden_size,
#         scope,
#         activation_fn=tf.nn.relu,
#         bn=False,
#         bn_decay=None,
#         is_training=None):
#     """ sequence model with attention.
#        Args:
#          inputs: 4-D tensor variable BxNxTxD
#          hidden_size: int
#          scope: encoder
#          activation_fn: function
#          bn: bool, whether to use batch norm
#          bn_decay: float or float tensor variable in [0,1]
#          is_training: bool Tensor variable
#        Return:
#          Variable Tensor BxNxD
#        """
#     with tf.variable_scope(scope) as sc:
#         batch_size = inputs.get_shape()[0]
#         npoint = inputs.get_shape()[1]
#         nstep = inputs.get_shape()[2]
#         in_size = inputs.get_shape()[3]
#         reshaped_inputs = tf.reshape(inputs, (-1, nstep, in_size))

#         with tf.variable_scope('encoder'):
#             #build encoder
#             encoder_cell = tf.nn.rnn_cell.LSTMCell(hidden_size)
#             encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, reshaped_inputs,
#                                                                sequence_length=tf.fill([batch_size*npoint], 4),
#                                                                dtype=tf.float64, time_major=False)
#         with tf.variable_scope('decoder'):
#             #build decoder
#             decoder_cell = tf.nn.rnn_cell.LSTMCell(hidden_size)
#             decoder_inputs = tf.reshape(encoder_state.h, [batch_size*npoint, 1, hidden_size])

#             # building attention mechanism: default Bahdanau
#             # 'Bahdanau' style attention: https://arxiv.org/abs/1409.0473
#             attention_mechanism = seq2seq.BahdanauAttention(num_units=hidden_size, memory=encoder_outputs)
#             # 'Luong' style attention: https://arxiv.org/abs/1508.04025
#             # attention_mechanism = seq2seq.LuongAttention(num_units=hidden_size, memory=encoder_outputs)

#             # AttentionWrapper wraps RNNCell with the attention_mechanism
#             decoder_cell = seq2seq.AttentionWrapper(cell=decoder_cell, attention_mechanism=attention_mechanism,
#                                                               attention_layer_size=hidden_size)

#             # Helper to feed inputs for training: read inputs from dense ground truth vectors
#             train_helper = seq2seq.TrainingHelper(inputs=decoder_inputs, sequence_length=tf.fill([batch_size*npoint], 1),
#                                                   time_major=False)
#             decoder_initial_state = decoder_cell.zero_state(batch_size=batch_size*npoint, dtype=tf.float64)
#             train_decoder = seq2seq.BasicDecoder(cell=decoder_cell, helper=train_helper, initial_state=decoder_initial_state, output_layer=None)
#             decoder_outputs_train, decoder_last_state_train, decoder_outputs_length_train = seq2seq.dynamic_decode(
#                 decoder=train_decoder, output_time_major=False, impute_finished=True)

#         outputs = tf.reshape(decoder_last_state_train[0].h, (-1, npoint, hidden_size))
#         if bn:
#           outputs = batch_norm_for_fc(outputs, is_training, bn_decay, 'bn')

#         if activation_fn is not None:
#           outputs = activation_fn(outputs)
#         return outputs


if __name__ == "__main__":
    # Smoke test: build the eta-phi distance matrix for one hand-written cloud
    # and print the k nearest neighbours of every point next to the matrix.
    #import provider
    import numpy as np

    num_pt, pos_dim, k = 3, 5, 2
    # wmass = ROOT.TH1D("wmass","wmass",100,0,6)
    # test_feed = np.random.rand(batch_size, num_pt, pos_dim)
    # pairwise_distanceR(pointclouds_pl)

    # One event with three points; the middle row is all zeros.
    a = np.array([[
        [-5.0, 5, 1, 2, 3],
        [0, 0, 0, 0, 0],
        [3, -3, 10, 11, 12],
    ]])
    # a, b, c = provider.load_h5_data_label_seg("../data/ttbb/h5/test_files_ttbar.h5")
    batch_size = a.shape[0]

    with tf.Graph().as_default():
        pointclouds_pl = tf.compat.v1.placeholder(tf.float32, shape=(batch_size, num_pt, pos_dim))
        # nn_idx = tf.placeholder(tf.int32, shape=(batch_size, num_pt, k))
        pair = pairwise_distanceR(pointclouds_pl[:, :, :3])
        nn_idx = knn(pair, k=k)
        # edge = get_edge_feature(pointclouds_pl, nn_idx, k=k,edge_type='sub')

        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            feed_dict = {pointclouds_pl: a}
            # edges = sess.run([pair], feed_dict=feed_dict)
            idxs, pairs = sess.run([nn_idx, pair], feed_dict=feed_dict)
            print(idxs, pairs)
            # for batch in edges:
            #  for point in batch:
            #    for mass in point:
            #       wmass.Fill(mass[1][0])

    # wmass.Draw()
    # raw_input()
    # pair = pairwise_distanceR(a)
    # r = tf.greater_equal(tf.abs(a),2*np.pi)
    # re = tf.where(r,4*np.pi**2-4*np.pi*tf.abs(a),a-a)
    # print(a.get_shape())
    # print(a.get_shape(),'a')
    # at = tf.transpose(a,[0,2,1])
    # print(at.get_shape(),'at')
    # phi = at[:,1:,:]
    # print(phi.get_shape()[2],'phi')
    # phi5 = tf.tile(phi,[1,5,1])
    # b =  tf.constant([[1],[2]])
    # c = a + b
    # print(pair.eval())


  

>     [[[0 2]
>       [2 0]
>       [2 0]]] [[[   0.      1050.        66.94745]
>       [1050.      2000.      1018.     ]
>       [  66.94745 1018.         0.     ]]]

In [None]:
#gapnet_classify.py
import tensorflow as tf
import numpy as np
import math
import os
import sys

#BASE_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = './'
sys.path.append(os.path.dirname(BASE_DIR))
sys.path.append(os.path.join(BASE_DIR, '../utils'))
sys.path.append(os.path.join(BASE_DIR, '../models'))
#import tf_util

#from gat_layers import attn_feature


def placeholder_inputs(batch_size, num_point, num_features):
    """Create the input placeholders of the classifier.

    Args:
        batch_size: number of clouds per batch.
        num_point: number of points per cloud.
        num_features: number of features per point.

    Returns:
        (pointclouds_pl, labels_pl): a float32 placeholder of shape
        (batch_size, num_point, num_features) and an int32 label placeholder
        whose shape argument is `batch_size`.
    """
    cloud_shape = (batch_size, num_point, num_features)
    pointclouds_pl = tf.compat.v1.placeholder(tf.float32, shape=cloud_shape)
    # `(batch_size)` is a parenthesised int, not a 1-tuple; passed on as-is.
    labels_pl = tf.compat.v1.placeholder(tf.int32, shape=(batch_size))
    return pointclouds_pl, labels_pl


def gap_block(k, n_heads, nn_idx, net, point_cloud, edge_size, bn_decay, weight_decay, is_training, scname):
    """Run `n_heads` attention heads over `net` and merge their outputs.

    Args:
        k: number of neighbours per point, forwarded to `attn_feature`.
        n_heads: number of attention heads to build (must be >= 1).
        nn_idx: neighbour indices, forwarded to `attn_feature`.
        net: input features handed to every head.
        point_cloud: tensor concatenated in front of the head outputs after an
            axis is inserted at position -2.
        edge_size: pair (tag, output_dim); the tag names the layer scope and
            output_dim is the per-head output width.
        bn_decay: batch-norm decay forwarded to `attn_feature`.
        weight_decay: accepted for interface compatibility; not used here.
        is_training: boolean tensor forwarded to `attn_feature`.
        scname: suffix appended to the layer scope name.

    Returns:
        (neighbors_features, locals_transform, coefs):
          neighbors_features: `point_cloud` (expanded at axis -2) concatenated
              with all head outputs along the last axis;
          locals_transform: max over axis -2 (kept) of the concatenated
              per-head local features;
          coefs: the second return value of `attn_feature` for the LAST head
              only.

    Raises:
        ValueError: if n_heads < 1.
    """
    if n_heads < 1:
        # Without at least one head there is nothing to concatenate and
        # `coefs` would never be assigned.
        raise ValueError("gap_block needs n_heads >= 1, got {0}".format(n_heads))

    layer_name = 'layer{0}'.format(edge_size[0]) + scname
    attns = []
    local_features = []
    coefs = None
    for i in range(n_heads):
        # `head_locals` avoids shadowing the builtin `locals`.
        edge_feature, coefs, head_locals = attn_feature(net, edge_size[1], nn_idx, activation=tf.nn.relu,
                                                        in_dropout=0.6,
                                                        coef_dropout=0.6, is_training=is_training,
                                                        bn_decay=bn_decay,
                                                        layer=layer_name, k=k, i=i)
        attns.append(edge_feature)  # This is the edge feature * att. coeff. activated by RELU, one per particle
        local_features.append(head_locals)  # Those are the yij

    neighbors_features = tf.concat(attns, axis=-1)
    neighbors_features = tf.concat([tf.expand_dims(point_cloud, -2), neighbors_features], axis=-1)

    locals_transform = tf.reduce_max(input_tensor=tf.concat(local_features, axis=-1), axis=-2, keepdims=True)

    return neighbors_features, locals_transform, coefs


def get_model(point_cloud, is_training, num_class,
              weight_decay=None, bn_decay=None, scname=''):
    """Build the GAPNet classifier graph.

    Two attention blocks, each followed by two 1x1 convolutions, are
    concatenated with the blocks' local features, reduced to 3 channels,
    average-pooled over the points and classified by three dense layers.

    Args:
        point_cloud: BxNxF input tensor; the first three features feed the
            eta-phi distance used for the first neighbour search.
        is_training: boolean tensor forwarded to every layer.
        num_class: width of the final dense layer.
        weight_decay: forwarded to `gap_block`.
        bn_decay: batch-norm decay forwarded to every layer.
        scname: suffix appended to every variable scope name.

    Returns:
        (net, max_pool): the squeezed output of the last dense layer and the
        pooled 3-channel encoding. Despite its name, `max_pool` is the output
        of `avg_pool2d`.
    """
    batch_size = point_cloud.get_shape()[0]
    num_point = point_cloud.get_shape()[1]

    k = 10
    n_heads = 1

    def _pointwise_conv(features, channels, scope_prefix):
        # 1x1 convolution + batch norm + ReLU under scope `scope_prefix + scname`.
        return conv2d(features, channels, [1, 1], padding='VALID', stride=[1, 1],
                      activation_fn=tf.nn.relu, bn=True, is_training=is_training,
                      scope=scope_prefix + scname, bn_decay=bn_decay)

    def _dense_relu(features, units, scope_prefix):
        # Dense layer + batch norm + ReLU under scope `scope_prefix + scname`.
        return fully_connected(features, units, bn=True, is_training=is_training,
                               activation_fn=tf.nn.relu, scope=scope_prefix + scname,
                               bn_decay=bn_decay)

    # First block: neighbours come from the eta-phi distance of the raw inputs.
    nn_idx = knn(pairwise_distanceR(point_cloud[:, :, :3]), k=k)
    net, locals_transform, _ = gap_block(k, n_heads, nn_idx, point_cloud, point_cloud,
                                         ('filter0', 16), bn_decay, weight_decay,
                                         is_training, scname)
    net01 = _pointwise_conv(net, 64, 'gapnet01')
    net02 = _pointwise_conv(net01, 128, 'gapnet02')

    # Second block: neighbours are recomputed in the learned feature space.
    nn_idx = knn(pairwise_distance(net02), k=k)
    net, locals_transform1, _ = gap_block(k, n_heads, nn_idx, net02, point_cloud,
                                          ('filter1', 128), bn_decay, weight_decay,
                                          is_training, scname)
    net11 = _pointwise_conv(net, 256, 'gapnet11')
    net12 = _pointwise_conv(net11, 256, 'gapnet12')

    net = tf.concat([net01, net02, net11, net12, locals_transform, locals_transform1], axis=-1)
    net = _pointwise_conv(net, 3, 'agg')

    net = avg_pool2d(net, [num_point, 1], padding='VALID', scope='avgpool' + scname)
    max_pool = net

    net = tf.reshape(net, [batch_size, -1])
    net = _dense_relu(net, 256, 'fc1')
    net = _dense_relu(net, 128, 'fc2')
    net = fully_connected(net, num_class, activation_fn=None, scope='fc3' + scname)

    return tf.squeeze(net), max_pool


def get_focal_loss(y_pred, label, num_class, gamma=4., alpha=10):
    """Focal loss on softmax probabilities, averaged over the batch.

    Args:
        y_pred: logits, indexed (batch, class).
        label: integer class indices.
        num_class: depth of the one-hot encoding.
        gamma: exponent of the (1 - p) modulating factor.
        alpha: constant multiplier applied to the loss.

    Returns:
        Scalar tensor: the mean over the batch of the per-example maximum of
        alpha * y_true * (1 - p)^gamma * (-log p) across classes.
    """
    gamma, alpha = float(gamma), float(alpha)
    epsilon = 1.e-9

    y_true = tf.convert_to_tensor(value=tf.one_hot(indices=label, depth=num_class), dtype=tf.float32)
    logits = tf.convert_to_tensor(value=y_pred, dtype=tf.float32)

    # Subtract the row maximum before the softmax; epsilon keeps log() finite.
    shifted = logits - tf.reduce_max(input_tensor=logits, axis=1, keepdims=True)
    probs = tf.nn.softmax(shifted) + epsilon

    cross_entropy = y_true * -tf.math.log(probs)
    modulation = y_true * tf.pow(1. - probs, gamma)
    focal = alpha * (modulation * cross_entropy)

    per_example = tf.reduce_max(input_tensor=focal, axis=1)
    return tf.reduce_mean(input_tensor=per_example)


def get_loss_kmeans(max_pool, mu, max_dim, n_clusters, alpha=100):
    """Soft k-means loss between pooled encodings and cluster centroids.

    Each distance is weighted by a softmax over clusters of
    -alpha * (distance - min distance), so larger alpha concentrates the
    weight on the closest centroid.

    Args:
        max_pool: pooled encodings; squeezed before use.
        mu: centroid table indexed as mu[i, :].
        max_dim: length of one centroid row.
        n_clusters: number of centroids.
        alpha: sharpness of the soft assignment.

    Returns:
        (kmeans_loss, stack_dist): the scalar loss and the stacked distances
        with the cluster index on axis 0.
    """
    cluster_ids = range(n_clusters)
    encodings = tf.squeeze(max_pool)

    list_dist = [f_func(encodings, tf.reshape(mu[c, :], (1, max_dim))) for c in cluster_ids]
    stack_dist = tf.stack(list_dist)
    min_dist = tf.reduce_min(input_tensor=list_dist, axis=0)

    # Subtracting the minimum keeps the largest exponent at exp(0).
    stack_exp = tf.stack([tf.exp(-alpha * (stack_dist[c] - min_dist)) for c in cluster_ids])
    sum_exponentials = tf.reduce_sum(input_tensor=stack_exp, axis=0)

    stack_weighted_dist = tf.stack(
        [stack_dist[c] * (stack_exp[c] / sum_exponentials) for c in cluster_ids])
    kmeans_loss = tf.reduce_mean(input_tensor=tf.reduce_sum(input_tensor=stack_weighted_dist, axis=0))

    return kmeans_loss, stack_dist


def f_func(x, y):
    """Return the squared differences between x and y summed over axis 1."""
    return tf.reduce_sum(input_tensor=tf.square(x - y), axis=1)


In [None]:
#provider.py

import os
import sys
import numpy as np
import h5py
from sklearn.preprocessing import StandardScaler, MinMaxScaler

#BASE_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = './'
sys.path.append(BASE_DIR)


# Download dataset for point cloud classification


def shuffle_data(data, labels, global_pl=None, weights=None):
    """Shuffle data, labels and one optional companion array together.

    Args:
        data: B,N,... numpy array.
        labels: numpy array with B entries.
        global_pl: optional array with B rows; None or empty means "absent".
        weights: optional array with B entries; None or empty means "absent".
            Ignored when `global_pl` is supplied.

    Returns:
        (data, labels, idx) when neither optional array is given, otherwise
        (data, labels, global_pl, idx) or (data, labels, weights, idx).
        `idx` is the permutation that was applied.
    """
    idx = np.arange(len(labels))
    np.random.shuffle(idx)

    # Presence is tested with len(): comparing an ndarray with `[]` is an
    # elementwise comparison that either raises or has an ambiguous truth
    # value, so the old `!= []` / `== []` checks broke for real arrays.
    if global_pl is not None and len(global_pl) > 0:
        return data[idx, :], labels[idx], global_pl[idx, :], idx
    if weights is not None and len(weights) > 0:
        return data[idx, :], labels[idx], weights[idx], idx
    return data[idx, :], labels[idx], idx


def rotate_point_cloud(batch_data):
    """Randomly rotate the point clouds to augment the dataset.

    Each cloud gets its own random angle; the rotation is about the second
    (up) axis.

    Args:
        batch_data: BxNx3 array, original batch of point clouds.

    Returns:
        BxNx3 float32 array, rotated batch of point clouds.
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    # `range`, not the Python-2-only `xrange`, which is a NameError on Python 3.
    for k in range(batch_data.shape[0]):
        rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array([[cosval, 0, sinval],
                                    [0, 1, 0],
                                    [-sinval, 0, cosval]])
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
    return rotated_data


def rotate_point_cloud_by_angle(batch_data, rotation_angle):
    """Rotate every point cloud about the second (up) axis by a fixed angle.

    Args:
        batch_data: BxNx3 array, original batch of point clouds.
        rotation_angle: angle in radians, shared by all clouds.

    Returns:
        BxNx3 float32 array, rotated batch of point clouds.
    """
    # The matrix is the same for every cloud, so build it once.
    cosval = np.cos(rotation_angle)
    sinval = np.sin(rotation_angle)
    rotation_matrix = np.array([[cosval, 0, sinval],
                                [0, 1, 0],
                                [-sinval, 0, cosval]])

    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    # `range`, not the Python-2-only `xrange`, which is a NameError on Python 3.
    for k in range(batch_data.shape[0]):
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
    return rotated_data


def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """Randomly perturb the point clouds by small rotations.

    For every cloud three angles are drawn from N(0, angle_sigma^2), clipped
    to [-angle_clip, angle_clip], and combined as R = Rz . Ry . Rx.

    Args:
        batch_data: BxNx3 array, original batch of point clouds.
        angle_sigma: standard deviation of the random angles (radians).
        angle_clip: bound applied to the magnitude of each angle (radians).

    Returns:
        BxNx3 float32 array, rotated batch of point clouds.
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    # `range`, not the Python-2-only `xrange`, which is a NameError on Python 3.
    for k in range(batch_data.shape[0]):
        angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
        Rx = np.array([[1, 0, 0],
                       [0, np.cos(angles[0]), -np.sin(angles[0])],
                       [0, np.sin(angles[0]), np.cos(angles[0])]])
        Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])],
                       [0, 1, 0],
                       [-np.sin(angles[1]), 0, np.cos(angles[1])]])
        Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0],
                       [np.sin(angles[2]), np.cos(angles[2]), 0],
                       [0, 0, 1]])
        R = np.dot(Rz, np.dot(Ry, Rx))
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
    return rotated_data


def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """Add clipped Gaussian noise to every coordinate of every point.

    Args:
        batch_data: BxNxC array, original batch of point clouds.
        sigma: standard deviation of the noise.
        clip: positive bound on the magnitude of the noise.

    Returns:
        BxNxC array, jittered batch of point clouds (the input is untouched).
    """
    n_clouds, n_points, n_channels = batch_data.shape
    assert (clip > 0)
    noise = sigma * np.random.randn(n_clouds, n_points, n_channels)
    jittered = np.clip(noise, -clip, clip)
    jittered += batch_data
    return jittered


def shift_point_cloud(batch_data, shift_range=0.1):
    """Translate each point cloud by one random 3-vector.

    The shift is drawn uniformly from [-shift_range, shift_range] per cloud
    and added IN PLACE, so the caller's array is modified.

    Args:
        batch_data: BxNx3 array, original batch of point clouds.
        shift_range: half-width of the uniform shift distribution.

    Returns:
        The same BxNx3 array object, now shifted.
    """
    n_clouds = batch_data.shape[0]
    _, _, _ = batch_data.shape  # input must be 3-D, as before
    offsets = np.random.uniform(-shift_range, shift_range, (n_clouds, 3))
    # One offset per cloud, broadcast over that cloud's points.
    batch_data += offsets[:, np.newaxis, :]
    return batch_data


def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    """Multiply each point cloud by one random scale factor.

    The factor is drawn uniformly from [scale_low, scale_high] per cloud and
    applied IN PLACE, so the caller's array is modified.

    Args:
        batch_data: BxNx3 array, original batch of point clouds.
        scale_low: lower bound of the scale factor.
        scale_high: upper bound of the scale factor.

    Returns:
        The same BxNx3 array object, now scaled.
    """
    n_clouds, _, _ = batch_data.shape
    factors = np.random.uniform(scale_low, scale_high, n_clouds)
    # One factor per cloud, broadcast over points and coordinates.
    batch_data *= factors[:, np.newaxis, np.newaxis]
    return batch_data


def norm_inputs_point_cloud(data, cloud=True):
    """Min-max scale the features while keeping zero entries at zero.

    The scaler is fitted on every `n_points`-th row of the flattened data.
    After scaling, every element equal to the scaled image of 0 for some
    feature is set back to 0.

    Args:
        data: (events, points, features) array when `cloud` is True, otherwise
            a 2-D (rows, features) array.
        cloud: whether `data` carries a per-event point axis.

    Returns:
        The scaled array, reshaped to (-1, points, features) for clouds and to
        (-1, features) otherwise.
    """
    point_axis = 1 if cloud else 0
    n_points = data.shape[point_axis]
    n_features = data.shape[point_axis + 1]

    flat = np.reshape(data, (-1, n_features))
    # scaler = StandardScaler()
    scaler = MinMaxScaler()
    scaler.fit(flat[0::n_points])

    # Where a raw 0 lands after scaling, one value per feature.
    scaled_zero = scaler.transform([[0] * n_features])
    flat = scaler.transform(flat)
    for feature in range(n_features):
        flat[flat == scaled_zero[0][feature]] = 0

    target_shape = (-1, n_points, n_features) if cloud else (-1, n_features)
    flat = np.reshape(flat, target_shape)

    print("Normalized the data")
    return flat


def getDataFiles(list_filename):
    """Read a text file and return its lines with trailing whitespace removed.

    Args:
        list_filename: path of the text file, one entry per line.

    Returns:
        List of strings, one per line of the file (blank lines included).
    """
    # The context manager closes the handle, which the bare open() never did.
    with open(list_filename) as list_file:
        return [line.rstrip() for line in list_file]


def load_add(h5_filename, names=None):
    """Load additional datasets from an HDF5 file into a dict of arrays.

    Args:
        h5_filename: path of the HDF5 file.
        names: dataset names to load. None or an empty sequence means "every
            dataset except 'data', 'pid' and 'label'"; in that case the full
            key list is printed first.

    Returns:
        Dict mapping each dataset name to its contents read with `[:]`.
    """
    # The context manager closes the file; every dataset is copied out with
    # `[:]` before that happens.
    with h5py.File(h5_filename, 'r') as f:
        if names is None or len(names) == 0:
            names = list(f.keys())
            print(names)
            # Filter instead of list.remove() so a file lacking one of these
            # keys does not raise ValueError.
            names = [name for name in names if name not in ('data', 'pid', 'label')]

        return {name: f[name][:] for name in names}


def load_h5(h5_filename, mode='seg', unsup=False, glob=False):
    """Load point-cloud data and labels from an HDF5 file.

    Args:
        h5_filename: path of the HDF5 file.
        mode: 'class' reads labels from the 'pid' dataset, 'seg' from the
            'label' dataset.
        unsup: accepted for interface compatibility; not used here.
        glob: when True, the 'global' dataset is returned as well.

    Returns:
        (data, label), or (data, label, global_pl) when `glob` is True.
        Labels are cast to int.

    Raises:
        ValueError: if `mode` is neither 'class' nor 'seg'.
    """
    label_keys = {'class': 'pid', 'seg': 'label'}
    if mode not in label_keys:
        # An unknown mode used to print a message and then fail on an unbound
        # `label`; raise a clear error instead.
        raise ValueError("No mode found: {0!r} (expected 'class' or 'seg')".format(mode))

    # The context manager closes the file; datasets are copied out with `[:]`.
    with h5py.File(h5_filename, 'r') as f:
        data = f['data'][:]
        # data = norm_inputs_point_cloud(data)
        label = f[label_keys[mode]][:].astype(int)
        global_pl = f['global'][:] if glob else None
        # global_pl = norm_inputs_point_cloud(global_pl,cloud=False)

    print("loaded {0} events".format(len(data)))
    if glob:
        return (data, label, global_pl)
    return (data, label)
        # global_pl = norm_inputs_point_cloud(global_pl,cloud=False)


def load_h5_weights(h5_filename):
    """Load data, labels and every weight-like dataset from an HDF5 file.

    Args:
        h5_filename: path of the HDF5 file.

    Returns:
        (data, label, weights), where `weights` maps each dataset whose name
        contains the letter 'w' to its contents.
    """
    # The context manager closes the file; datasets are copied out with `[:]`.
    with h5py.File(h5_filename, 'r') as f:
        data = f['data'][:]
        label = f['pid'][:]
        # Substring match on the dataset name, exactly as before.
        weights = {var: f[var][:] for var in f.keys() if 'w' in var}

    return (data, label, weights)


def load_h5_eval(h5_filename):
    """Load data, labels and the three event-weight variations.

    Args:
        h5_filename: path of the HDF5 file.

    Returns:
        (data, label, weight_nom, weight_up, weight_down); the three weights
        are returned as absolute values.
    """
    # The context manager closes the file; datasets are copied out with `[:]`.
    with h5py.File(h5_filename, 'r') as f:
        data = f['data'][:]
        label = f['pid'][:]
        weight_nom = np.abs(f['weight_nom'][:])
        weight_up = np.abs(f['weight_up'][:])
        weight_down = np.abs(f['weight_down'][:])
    return (data, label, weight_nom, weight_up, weight_down)


def loadDataFile(filename):
    """Load `filename` through `load_h5` with its default arguments."""
    loaded = load_h5(filename)
    return loaded


def load_h5_data_label_seg(h5_filename):
    """Load the 'data', 'pid' and 'label' datasets from an HDF5 file.

    Args:
        h5_filename: path of the HDF5 file.

    Returns:
        (data, label, seg): the contents of 'data', 'pid' and 'label'.
    """
    # The context manager closes the file; datasets are copied out with `[:]`.
    with h5py.File(h5_filename, 'r') as f:
        data = f['data'][:]  # (2048, 2048, 3)
        # data = norm_inputs_point_cloud(data)
        label = f['pid'][:]  # (2048, 1)
        seg = f['label'][:]  # (2048, 2048)
    print("loaded {0} events".format(len(data)))

    return (data, label, seg)

In [None]:
ls 06_LHC/models
ls 06_LHC/utils
ls 06_LHC/h5

  

>     __pycache__
>     gapnet_classify.py
>     gapnet_seg.py
>     gat_layers.py
>     __pycache__
>     provider.py
>     tf_util.py
>     evaluate_files_RD.txt
>     evaluate_files_b1.txt
>     evaluate_files_b2.txt
>     evaluate_files_b3.txt
>     evaluate_files_gwztop.txt
>     evaluate_files_wztop.txt
>     test_files_RD.txt
>     test_files_b1.txt
>     test_files_b2.txt
>     test_files_b3.txt
>     test_files_gwztop.txt
>     test_files_wztop.txt
>     test_multi_20v_100P.h5
>     train_files_RD.txt
>     train_files_b1.txt
>     train_files_b2.txt
>     train_files_b3.txt
>     train_files_gwztop.txt
>     train_files_wztop.txt
>     train_multi_20v_100P.h5

In [None]:
  

  
  
import argparse
import h5py
from math import *
import tensorflow as tf
import numpy as np
from datetime import datetime
import json
import os, ast
import sys



In [None]:
# DEFAULT SETTINGS
#parser = argparse.ArgumentParser()
#parser.add_argument('--gpu', type=int, default=0, help='GPUs to use [default: 0]')
#parser.add_argument('--n_clusters', type=int, default=3, help='Number of clusters [Default: 3]')
#parser.add_argument('--max_dim', type=int, default=3, help='Dimension of the encoding layer [Default: 3]')
#parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
#parser.add_argument('--batch', type=int, default=512, help='Batch Size  during training [default: 512]')
#parser.add_argument('--num_point', type=int, default=100, help='Point Number [default: 100]')
#parser.add_argument('--data_dir', default='../h5', help='directory with data [default: ../h5]')
#parser.add_argument('--nfeat', type=int, default=8, help='Number of features [default: 8]')
#parser.add_argument('--ncat', type=int, default=20, help='Number of categories [default: 20]')
#parser.add_argument('--name', default="", help='name of the output file')
#parser.add_argument('--h5_folder', default="../h5/", help='folder to store output files')
#parser.add_argument('--full_train', default=False, action='store_true',
#                    help='load full training results [default: False]')

In [None]:
def eval():
    """Rebuild the model graph, restore a checkpoint and run one evaluation pass.

    Reads the module-level settings FLAGSGPU, BATCH_SIZE, NUM_POINT, NFEATURES,
    NUM_CATEGORIES, N_CLUSTERS, MAX_DIM, FULL_TRAINING and LOG_DIR. Restores
    'cluster.ckpt-27000' when FULL_TRAINING is set and 'model.ckpt' otherwise,
    then hands the session and tensors to `eval_one_epoch`.

    NOTE(review): this name shadows the builtin `eval`; it is kept because
    callers use it.
    """
    with tf.Graph().as_default():
        with tf.device('/gpu:' + str(FLAGSGPU)):
            pointclouds_pl, labels_pl = placeholder_inputs(BATCH_SIZE, NUM_POINT, NFEATURES)
            # Not read below, but it is a graph variable and the default Saver
            # covers every variable, so removing it would change what
            # `restore` looks for in the checkpoint.
            batch = tf.Variable(0, trainable=False)
            alpha = tf.compat.v1.placeholder(tf.float32, shape=())
            is_training_pl = tf.compat.v1.placeholder(tf.bool, shape=())
            pred, max_pool = get_model(pointclouds_pl, is_training=is_training_pl, num_class=NUM_CATEGORIES)
            mu = tf.Variable(tf.zeros(shape=(N_CLUSTERS, MAX_DIM)), name="mu",
                             trainable=False)  # k centroids

            classify_loss = get_focal_loss(pred, labels_pl, NUM_CATEGORIES)
            kmeans_loss, stack_dist = get_loss_kmeans(max_pool, mu, MAX_DIM,
                                                      N_CLUSTERS, alpha)

            saver = tf.compat.v1.train.Saver()

        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        checkpoint_name = 'cluster.ckpt-27000' if FULL_TRAINING else 'model.ckpt'

        # The session is a context manager so its resources are released when
        # evaluation finishes or fails; it used to be left open.
        with tf.compat.v1.Session(config=config) as sess:
            saver.restore(sess, os.path.join(LOG_DIR, checkpoint_name))
            print('model restored')

            ops = {'pointclouds_pl': pointclouds_pl,
                   'labels_pl': labels_pl,
                   'stack_dist': stack_dist,
                   'kmeans_loss': kmeans_loss,
                   'pred': pred,
                   'alpha': alpha,
                   'max_pool': max_pool,
                   'is_training_pl': is_training_pl,
                   'classify_loss': classify_loss, }

            eval_one_epoch(sess, ops)

def get_batch(data, label, start_idx, end_idx):
    """Slice one batch out of the data and label arrays.

    Args:
        data: 3-D array indexed (event, point, feature).
        label: array indexed by event.
        start_idx: first event of the batch (inclusive).
        end_idx: last event of the batch (exclusive).

    Returns:
        (batch_data, batch_label) for events start_idx .. end_idx - 1.
    """
    events = slice(start_idx, end_idx)
    return data[events, :, :], label[events]

In [None]:
def eval_one_epoch(sess, ops, max_batches=5):
    """Assign every evaluated sample to its closest cluster and dump the results to h5.

    For each file named in EVALUATE_FILES (resolved against H5_DIR) the data is
    fed through the restored graph in full batches of BATCH_SIZE. Per batch, the
    'stack_dist' and 'max_pool' tensors are fetched and each sample is assigned
    the index with the smallest entry in its `dist[:, i]` column.

    The collected arrays are written to '<H5_OUT>/<NAME>.h5' as four datasets:
    "pid" (index of the 1 in each row of the loaded cluster array), "label"
    (closest-cluster index), "max_pool" (squeezed fetch) and "masses".

    Args:
        sess: session holding the restored weights.
        ops: dict of graph tensors; the keys 'pointclouds_pl', 'labels_pl',
            'alpha', 'is_training_pl', 'stack_dist' and 'max_pool' are used.
        max_batches: upper bound on the number of batches taken from each file.
            The default of 5 keeps the previously hard-coded limit; pass None
            to evaluate every full batch.

    Raises:
        ValueError: if no full batch could be evaluated, so there is nothing
            to write.
    """
    is_training = False

    # Per-batch results are gathered in lists and concatenated once at the end.
    true_clusters, assigned_clusters, pooled_features, masses = [], [], [], []

    for file_name in EVALUATE_FILES:
        current_file = os.path.join(H5_DIR, file_name)
        current_data, current_label, current_cluster = load_h5_data_label_seg(current_file)
        adds = load_add(current_file, ['masses'])

        current_label = np.squeeze(current_label)

        file_size = current_data.shape[0]
        # Only full batches are used; the cap can never exceed what the file holds,
        # so slices are always exactly BATCH_SIZE long.
        num_batches = file_size // BATCH_SIZE
        if max_batches is not None:
            num_batches = min(num_batches, max_batches)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = start_idx + BATCH_SIZE

            batch_data, batch_label = get_batch(current_data, current_label, start_idx, end_idx)

            feed_dict = {ops['pointclouds_pl']: batch_data,
                         ops['labels_pl']: batch_label,
                         ops['alpha']: 1,  # No impact on evaluation
                         ops['is_training_pl']: is_training,
                         }

            dist, max_pool = sess.run([ops['stack_dist'], ops['max_pool']],
                                      feed_dict=feed_dict)

            # Column i of `dist` holds sample i's entries; the smallest one wins.
            cluster_assign = np.array(
                [np.argmin(dist[:, i]) for i in range(BATCH_SIZE)], dtype=int)

            # Position of the first 1 in each row of the loaded cluster array.
            batch_cluster = np.array(
                [np.where(r == 1)[0][0] for r in current_cluster[start_idx:end_idx]])

            true_clusters.append(batch_cluster)
            assigned_clusters.append(cluster_assign)
            pooled_features.append(np.squeeze(max_pool))
            masses.append(adds['masses'][start_idx:end_idx])

    if not true_clusters:
        raise ValueError(
            'No full batch of {0} samples could be evaluated; nothing to write.'.format(BATCH_SIZE))

    y_val = np.concatenate(true_clusters, axis=0)
    y_assign = np.concatenate(assigned_clusters, axis=0)
    y_pool = np.concatenate(pooled_features, axis=0)
    y_mass = np.concatenate(masses, axis=0)

    with h5py.File(os.path.join(H5_OUT, '{0}.h5'.format(NAME)), "w") as fh5:
        fh5.create_dataset("pid", data=y_val)  # Real jet categories
        fh5.create_dataset("label", data=y_assign)  # Cluster labeling
        fh5.create_dataset("max_pool", data=y_pool)
        fh5.create_dataset("masses", data=y_mass)

In [None]:
# Print numpy arrays in full rather than summarising long ones.
np.set_printoptions(threshold=sys.maxsize)

# The script version used os.path.dirname(os.path.abspath(__file__)); here the
# current directory is used as the base instead.
BASE_DIR = './'

# Make the base directory, its parent and the sibling models/utils folders importable.
sys.path.extend([
    BASE_DIR,
    os.path.dirname(BASE_DIR),
    os.path.join(BASE_DIR, '..', 'models'),
    os.path.join(BASE_DIR, '..', 'utils'),
])
# from MVA_cfg import *
#import provider
#import gapnet_classify as MODEL

In [None]:
#FLAGS = parser.parse_args()
#LOG_DIR = os.path.join('..', 'logs', FLAGS.log_dir)
#DATA_DIR = FLAGS.data_dir
#H5_DIR = os.path.join(BASE_DIR, DATA_DIR)
#H5_OUT = FLAGS.h5_folder
#if not os.path.exists(H5_OUT): os.mkdir(H5_OUT)

# Fixed values standing in for the FLAGS-based settings commented out above.
LOG_DIR = 'log'
# Input h5 directory; H5_DIR and DATA_DIR name the same location.
H5_DIR = DATA_DIR = '../h5'
# Directory the evaluation output file is written to.
H5_OUT = '../h5/'

In [None]:

# MAIN SCRIPT
#NUM_POINT = FLAGS.num_point
#BATCH_SIZE = FLAGS.batch
#NFEATURES = FLAGS.nfeat
#FULL_TRAINING = FLAGS.full_train

NUM_POINT = 100
BATCH_SIZE = 512
NFEATURES = 8
# Must be a real boolean: the previous value, the string 'FALSE', is non-empty
# and therefore truthy, so `if FULL_TRAINING:` always took the full-training
# branch. True keeps that effective behaviour (restoring 'cluster.ckpt-27000')
# while saying what it does.
FULL_TRAINING = True
N_CLUSTERS = 3
MAX_DIM = 3
# Base name of the output file; '.h5' is appended when it is written.
NAME = 'OUTPUT.TXT'

In [None]:
# Debug output of the checkpoint directory.
# NOTE(review): checkpoint_dir is not assigned in this part of the notebook -
# presumably it comes from an earlier cell; confirm it exists on a fresh kernel.
print(checkpoint_dir)

In [None]:
#BASE_DIR = os.path.join(os.getcwd(), '06_LHC','scripts')  
#os.path.dirname(os.path.abspath(__file__))
#sys.path.append(BASE_DIR)
#sys.path.append(os.path.join(BASE_DIR, '..', 'models'))

#RUN THE ACTUAL EVALUATION HERE!
# Folder containing the file-list text files read below.
H5_MODEL_PATH = '06_LHC/h5'

# Fixed here instead of being read from FLAGS.ncat / FLAGS.gpu.
# NOTE(review): an older comment in this cell said "Only used to get how many
# parts per category" - presumably about NUM_CATEGORIES; confirm.
NUM_CATEGORIES = 20
FLAGSGPU = 0

# Run banner: one line per setting, then the start marker.
print(
    '#### Batch Size : {0}'.format(BATCH_SIZE),
    '#### Point Number: {0}'.format(NUM_POINT),
    '#### Using GPUs: {0}'.format(FLAGSGPU),
    '### Starting evaluation',
    sep='\n',
)

EVALUATE_FILES = getDataFiles(os.path.join(H5_MODEL_PATH, 'test_files_wztop.txt'))
#EVALUATE_FILES = getDataFiles(os.path.join(H5_DIR, 'evaluate_files_wztop.txt'))

  

>     #### Batch Size : 512
>     #### Point Number: 100
>     #### Using GPUs: 0
>     ### Starting evaluation

In [None]:
# Override the earlier LOG_DIR with the training run whose checkpoints are restored.
# NOTE(review): absolute, run-specific path - adjust when evaluating another run.
LOG_DIR = '/dbfs/databricks/driver/06_LHC/logs/train/1608545228.287947/'
#os.path.join(os.getcwd(), LOG_DIR)


  

>     Out[97]: '/dbfs/databricks/driver/06_LHC/logs/train/1608545228.287947/'

In [None]:
#%sh 
#ls

  

>     06_LHC
>     conf
>     derby.log
>     eventlogs
>     ganglia
>     log
>     logs

In [None]:
################################################


if __name__ == '__main__':
    # exist_ok=True replaces the separate existence check: it is a no-op when
    # the directory is already there and cannot race with a concurrent creator.
    os.makedirs('log', exist_ok=True)
    eval()

  

>     INFO:tensorflow:Restoring parameters from /dbfs/databricks/driver/06_LHC/logs/train/1608545228.287947/cluster.ckpt-27000
>     model restored

In [None]:
# List the contents of the training log directory to see which checkpoints exist.
ls /dbfs/databricks/driver/06_LHC/logs/train/1608545228.287947/



  

>     checkpoint
>     cluster.ckpt-27000.data-00000-of-00002
>     cluster.ckpt-27000.data-00001-of-00002
>     cluster.ckpt-27000.index
>     cluster.ckpt-27000.meta
>     cluster.ckpt-28000.data-00000-of-00002
>     cluster.ckpt-28000.data-00001-of-00002
>     cluster.ckpt-28000.index
>     cluster.ckpt-28000.meta
>     cluster.ckpt-29000.data-00000-of-00002
>     cluster.ckpt-29000.data-00001-of-00002
>     cluster.ckpt-29000.index
>     cluster.ckpt-29000.meta
>     events.out.tfevents.1608545270.1120-144117-apses921-10-149-235-171
>     events.out.tfevents.1608545270.1120-144117-apses921-10-149-242-155
>     graph.pbtxt
>     model.ckpt-27339.data-00000-of-00002
>     model.ckpt-27339.data-00001-of-00002
>     model.ckpt-27339.index
>     model.ckpt-27339.meta
>     model.ckpt-28778.data-00000-of-00002
>     model.ckpt-28778.data-00001-of-00002
>     model.ckpt-28778.index
>     model.ckpt-28778.meta

In [None]:
# List the h5 data folder (file lists and h5 files).
ls 06_LHC/h5

  

>     evaluate_files_RD.txt
>     evaluate_files_b1.txt
>     evaluate_files_b2.txt
>     evaluate_files_b3.txt
>     evaluate_files_gwztop.txt
>     evaluate_files_wztop.txt
>     test_files_RD.txt
>     test_files_b1.txt
>     test_files_b2.txt
>     test_files_b3.txt
>     test_files_gwztop.txt
>     test_files_wztop.txt
>     test_multi_20v_100P.h5
>     train_files_RD.txt
>     train_files_b1.txt
>     train_files_b2.txt
>     train_files_b3.txt
>     train_files_gwztop.txt
>     train_files_wztop.txt
>     train_multi_20v_100P.h5