# fm_ops_keras.py

In [1]:
import tensorflow as tf

In [2]:
import keras
import numpy as np

Using TensorFlow backend.


In [4]:
from tensorflow import nn as nn

In [6]:
import scipy.linalg

In [10]:
# Matrix square root:
#   sqrtm = scipy.linalg.sqrtm(m).real
# Keep dtypes consistent, e.g. with
#   np.ndarray.astype(float / int / ...)   (np.float was removed in NumPy 1.24)

In [11]:
def weightedIterativeStatistic(point_list, weights,
                               iterative_stat_function):
    """Fold a sequence of points into one statistic via weighted updates.

    The running statistic starts as the first point; every following point
    is merged in with ``iterative_stat_function(running, point, weight)``,
    where point k (k >= 1, 0-indexed) uses ``weights[k - 1]``. With weights
    ``[1/2, 1/3, 1/4, ...]`` and an interpolation-style update this gives
    the unweighted (Frechet) mean.

    Args:
        point_list: iterable of points; must contain at least one point.
        weights: iterable of per-step weights. Only the first
            ``len(point_list) - 1`` are consumed; if fewer are supplied the
            fold stops early (``zip`` semantics).
        iterative_stat_function: callable ``(running, point, weight)``
            returning the updated running statistic.

    Returns:
        The running statistic after all points are merged (the first point
        itself when only one point is given).

    Raises:
        ValueError: if ``point_list`` is empty.
    """
    points = iter(point_list)
    missing = object()
    mean = next(points, missing)
    if mean is missing:
        raise ValueError("point_list must contain at least one point")
    # weights[0] blends points 0 and 1, weights[1] merges point 2, etc.
    for point, weight in zip(points, weights):
        mean = iterative_stat_function(mean, point, weight)
    return mean

In [18]:
# Geodesic approximation
def stiefelGeodesicApprox(X, Y, t):
    """Approximate the Stiefel-manifold geodesic from X towards Y at time t.

    Y is projected onto the tangent space at X,
    ``V = Y - X (Y^T X + X^T Y) / 2``, scaled by ``t``, and mapped back to
    the manifold with the polar retraction
    ``R = (X + tV) (I + t^2 V^T V)^{-1/2}``.

    Args:
        X: (n, p) array; assumed to have orthonormal columns (the tangent
            projection relies on ``X^T X = I``).
        Y: (n, p) array, the target point.
        t: scalar step size (``t = 0`` returns X).

    Returns:
        (n, p) array; its columns are orthonormal (up to round-off) when X's
        columns are.
    """
    # Tangent-space projection of Y at X.
    lift = Y - 0.5 * np.matmul(X, np.matmul(Y.transpose(), X)
                               + np.matmul(X.transpose(), Y))
    scale = t * lift
    a = np.identity(scale.shape[-1]) + np.matmul(scale.transpose(), scale)
    # scipy.linalg has no ``sqrt``; the matrix square root is ``sqrtm``.
    # ``a`` is symmetric positive definite, so any imaginary part is
    # round-off and is dropped.
    a = np.linalg.inv(scipy.linalg.sqrtm(a).real)
    return np.matmul(X + scale, a)

In [24]:
def grassmanGeodesic(X, Y, t):
    """Point at time t on the Grassmann geodesic from span(X) to span(Y).

    Closed form: take the thin SVD ``Y (X^T Y)^{-1} - X = U diag(s) V^T``,
    set ``theta = arctan(s)`` and return
    ``X V cos(theta t) + U sin(theta t)``. At t = 0 the result spans the
    same subspace as X, at t = 1 the same subspace as Y.

    Args:
        X: (n, p) array with orthonormal columns (start subspace).
        Y: (n, p) array whose column span is the end subspace; ``X^T Y``
            must be invertible.
        t: scalar position along the geodesic.

    Returns:
        (n, p) array whose columns span the subspace at time t (orthonormal
        when X's columns are).

    Raises:
        numpy.linalg.LinAlgError: propagated from ``np.linalg.inv`` when
            ``X^T Y`` is singular.
    """
    svd_term = np.matmul(Y, np.linalg.inv(np.matmul(X.transpose(), Y))) - X
    # Thin SVD keeps U at (n, p); numpy returns V^T, so transpose it back.
    U, s, Vt = np.linalg.svd(svd_term, full_matrices=False)
    V = Vt.transpose()
    theta = np.arctan(s)
    return (np.matmul(X, np.matmul(V, np.diag(np.cos(theta * t))))
            + np.matmul(U, np.diag(np.sin(theta * t))))

In [25]:
def weightedFrechetMeanUpdate(previous_mean, new_point, weight,
                              geodesic_generator=grassmanGeodesic):
    """Move the running Frechet mean towards a new point along a geodesic.

    Args:
        previous_mean: current running mean (a point on the manifold).
        new_point: point to merge in.
        weight: position along the geodesic from ``previous_mean`` to
            ``new_point``; converted to a Python float.
        geodesic_generator: callable ``(X, Y, t)`` returning the point at
            time ``t`` on the geodesic from X to Y; defaults to
            ``grassmanGeodesic``.

    Returns:
        Whatever ``geodesic_generator`` returns for ``t = float(weight)``.
    """
    # Builtin ``float`` rather than ``np.float``: that alias was identical
    # to ``float`` and has been removed from NumPy (1.24+).
    return geodesic_generator(previous_mean, new_point, float(weight))

In [33]:
class fullConv1d(keras.layers.Layer):
    """Weighted iterative Frechet mean over a list of blocks.

    Init with:
        num_frames: number of blocks that will be averaged.
        iterative_mean_function: update ``(mean, point, weight) -> mean``;
            defaults to ``weightedFrechetMeanUpdate``.
    """

    def __init__(self, num_frames, iterative_mean_function=\
                weightedFrechetMeanUpdate):
        super(fullConv1d, self).__init__()
        self.iterative_mean_function = iterative_mean_function
        # Default weights 1/2, 1/3, ..., 1/(num_frames + 1), intended to
        # give the unweighted Frechet mean. Stored as a NumPy array (not a
        # tf.data.Dataset) so it can be indexed/iterated by
        # weightedIterativeStatistic and summed by np.sum.
        self.weight = np.array([1.0 / n for n in range(2, num_frames + 2)])
        # Reference total for the weight penalty computed in ``forward``.
        self.weight_reference = np.sum(self.weight)

    def forward(self, block_list):
        """Return ``(weighted FM of block_list, weight_penalty)``.

        ``weight_penalty`` is the squared difference between the current
        ``sum(self.weight)`` and its value at construction time.
        """
        out = weightedIterativeStatistic(block_list, self.weight,
                                         self.iterative_mean_function)
        weight_penalty = (self.weight_reference - np.sum(self.weight)) ** 2
        return out, weight_penalty

    def call(self, block_list):
        # Keras' Layer.__call__ dispatches to ``call``; delegate to the
        # PyTorch-style ``forward`` used throughout this file.
        return self.forward(block_list)

In [35]:
class GrassmannAverageProjection(keras.layers.Layer):
    """Combine the input with its weighted Grassmann average.

    Wraps ``GrassmannAverage(in_frames, out_frames)`` and passes the input
    together with that average to ``temporalProjection``.
    NOTE(review): GrassmannAverage and temporalProjection are defined
    elsewhere; their shapes/semantics are not visible here — confirm.
    """

    def __init__(self, in_frames, out_frames):
        super(GrassmannAverageProjection, self).__init__()
        self.temporal_mean = GrassmannAverage(in_frames,
                                              out_frames)

    def forward(self, x):
        """Return ``(temporalProjection(x, mean), weight_penalty)``."""
        # NOTE(review): this goes through Keras __call__ -> call; confirm
        # GrassmannAverage defines ``call`` and not only ``forward``.
        y, weight_penalty = self.temporal_mean(x)
        x = temporalProjection(x, y)
        return x, weight_penalty

    def call(self, x):
        # Keras' Layer.__call__ dispatches to ``call``; delegate to forward.
        return self.forward(x)
        

In [36]:
class GrassmannAverageBottleneck(keras.layers.Layer):
    """Reconstruct the input from its weighted Grassmann average.

    Wraps ``GrassmannAverage(in_frames, out_frames)`` and passes the input
    together with that average to ``temporalReconstruction``.
    NOTE(review): GrassmannAverage and temporalReconstruction are defined
    elsewhere; their shapes/semantics are not visible here — confirm.
    """

    def __init__(self, in_frames, out_frames):
        super(GrassmannAverageBottleneck, self).__init__()
        self.temporal_mean = GrassmannAverage(in_frames,
                                              out_frames)

    def forward(self, x):
        """Return ``(temporalReconstruction(x, mean), weight_penalty)``."""
        # NOTE(review): this goes through Keras __call__ -> call; confirm
        # GrassmannAverage defines ``call`` and not only ``forward``.
        y, weight_penalty = self.temporal_mean(x)
        x = temporalReconstruction(x, y)
        return x, weight_penalty

    def call(self, x):
        # Keras' Layer.__call__ dispatches to ``call``; delegate to forward.
        return self.forward(x)

In [None]:
# Given an array with dimensions [num_frames, 1, 2, 128],
# reduce the first dimension by taking a weighted FM.
class GrassmannAverage(keras.layers.Layer):
    def __init__(self, in_frames, out_frames):
        super(GrassmannAverage, self).__init__()
        self.out_frames = out_frames
        self.num_blocks = int(in_frames/out_frames)
        self.weights = tf.data.Dataset.from_tensor_slices(\
                    [1/n for n in range(2,self.num_blocks+2)] )
        self.weight_reference = keras.layers.