Hi, thanks for sharing the code.
I have a question about the computation of the wFM (weighted Fréchet mean). The original paper states that it uses the recursive estimator, but in the following code it seems to become a direct weighted average of the points on the manifold. Am I missing something here, or is this a simplified way to compute the wFM?
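import tensorflow as tf

def DCNN(x, d, W_root, mode):
    '''
    x is the input, with shape batch * sequence_length * n_para * in_channel
    d is the dilation rate (the number of skipped steps), an integer
    W_root holds the square roots of the weights, with shape k * in_channel * out_channel
    mode is "SPD" or "ODF"
    '''
    W = tf.pow(W_root, 2)  # squaring keeps every weight non-negative
    batch_size = tf.shape(x)[0]  # dynamic batch size; x.shape[0] may be None
    sequence_length = x.shape[1]
    n_para = x.shape[2]
    k = W.shape[0]
    in_channel = W.shape[1]
    out_channel = W.shape[2]
    padding = (k - 1) * d
    # Causal padding: reflect-pad only the front of the time axis so that
    # the output has the same sequence length as the input
    x_pad = tf.pad(x, tf.constant([(0, 0), (1, 0), (0, 0), (0, 0)]) * padding, "REFLECT")
    W = tf.reshape(W, [k * in_channel, out_channel])
    W_sum = tf.reduce_sum(W, 0)
    W = W / W_sum  # constrain sum(w_k_inchannel) = 1 for each output channel
    if mode == "SPD":
        # Fold n_para into the batch so every parameter is filtered independently
        x_reorder = tf.transpose(x_pad, [0, 2, 1, 3])
        x_reshape = tf.reshape(x_reorder, [batch_size * n_para, 1, sequence_length + padding, in_channel])
        W = tf.reshape(W, [1, k, in_channel, out_channel])
        # A dilated 1 x k convolution with these weights is a weighted average
        # over a causal, dilated window
        conv1 = tf.nn.atrous_conv2d(x_reshape, W, d, "VALID", name=None)
        conv1 = tf.reshape(conv1, [batch_size, n_para, sequence_length, out_channel])
        conv1 = tf.transpose(conv1, [0, 2, 1, 3])
        return conv1
    # Only the SPD branch is shown here; other modes fall through and return None.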
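To make the question concrete, here is a NumPy sketch of what the SPD branch computes for a single, unbatched sequence: each output is a causal, dilated weighted sum of earlier inputs, with weights `W_root**2` normalized so that each output channel's weights sum to one. The function name `dilated_wfm_spd` and the NumPy port are illustrative assumptions, not code from the repository, and treating each `n_para` slice as a flattened 2x2 SPD matrix (as the test does) is a hypothetical layout. Because the weights are non-negative and sum to one, each output is a convex combination of SPD matrices and therefore remains SPD.

```python
import numpy as np

def dilated_wfm_spd(x, w_root, d):
    """Hypothetical NumPy port of the SPD branch for one sequence.

    x: (seq_len, n_para, in_ch) array; w_root: (k, in_ch, out_ch); d: dilation rate.
    Returns an array of shape (seq_len, n_para, out_ch).
    """
    w = w_root ** 2                                   # non-negative weights
    k, in_ch, out_ch = w.shape
    w = w / w.reshape(k * in_ch, out_ch).sum(axis=0)  # each output channel's weights sum to 1
    pad = (k - 1) * d
    # Same as TF "REFLECT": mirror the front of the time axis without repeating the edge
    x_pad = np.pad(x, ((pad, 0), (0, 0), (0, 0)), mode="reflect")
    seq_len = x.shape[0]
    out = np.zeros((seq_len, x.shape[1], out_ch))
    for t in range(seq_len):
        for j in range(k):
            # Tap j reads the element (k - 1 - j) * d steps before position t
            out[t] += x_pad[t + j * d] @ w[j]         # (n_para, in_ch) @ (in_ch, out_ch)
    return out
```

Under this reading, the "direct weighted average" noticed in the question is a weighted arithmetic mean taken over a dilated causal window, applied independently to each of the `n_para` components.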
Hi,
This is a simplified way of computing the wFM, and it is used on the SPD manifold only.
We also tried the recursive version of the wFM on the SPD manifold (see https://github.com/zhenxingjian/SPD-SRU/blob/master/matrixcell.py for an idea of how the recursive version works), and in our experiments the two gave similar results. Therefore, we only include this simplified closed-form version for computing the wFM.
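For comparison, the sketch below contrasts a recursive wFM estimator with the closed-form weighted average. The recursion starts at the first point and, for each new point, moves a fraction w_n / (w_1 + ... + w_n) along the geodesic toward it. The affine-invariant geodesic used here is an assumption for illustration; the metric used in the linked `matrixcell.py` may differ, and all function names are hypothetical. With the flat (Euclidean) geodesic (1 - t) A + t B, the recursion reproduces the closed-form weighted average exactly, which is one way to see why the simplified version can stand in for the recursive one; with a curved metric the two generally differ.

```python
import numpy as np

def spd_power(a, p):
    """Matrix power of a symmetric positive definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(a)
    return (vecs * vals ** p) @ vecs.T

def geodesic_ai(a, b, t):
    """Affine-invariant geodesic A^(1/2) (A^(-1/2) B A^(-1/2))^t A^(1/2) (assumed metric)."""
    a_half, a_ihalf = spd_power(a, 0.5), spd_power(a, -0.5)
    inner = a_ihalf @ b @ a_ihalf
    inner = (inner + inner.T) / 2  # symmetrize against round-off
    return a_half @ spd_power(inner, t) @ a_half

def geodesic_euclid(a, b, t):
    """Flat geodesic: straight-line interpolation between matrices."""
    return (1 - t) * a + t * b

def recursive_wfm(mats, weights, geodesic):
    """Recursive wFM estimate: step w_n / (running weight sum) toward each new point."""
    m, total = mats[0], weights[0]
    for x, w in zip(mats[1:], weights[1:]):
        total += w
        m = geodesic(m, x, w / total)
    return m

def weighted_average(mats, weights):
    """Closed-form weighted arithmetic mean with weights normalized to sum to one."""
    weights = np.asarray(weights, dtype=float)
    return np.tensordot(weights / weights.sum(), np.asarray(mats), axes=1)
```

For commuting (for example, diagonal) inputs, the affine-invariant recursion reduces to the weighted geometric mean of the entries, which gives a simple sanity check on the sketch.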
One important note: for the wFM, the weights should be: