Skip to content

Commit

Permalink
Add normalization, checks to Conv2D attributes (#488)
Browse files Browse the repository at this point in the history
* add additional normalization, checks on attributes

* move normalize_data_format to seperate file

* fix lint

* add more tests
  • Loading branch information
yanndupis committed May 21, 2019
1 parent 34897f4 commit 183c632
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 16 deletions.
30 changes: 20 additions & 10 deletions tf_encrypted/keras/layers/convolutional.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# pylint: disable=arguments-differ
"""Convolutional Layer implementation."""
import logging

import numpy as np
from tensorflow.python.keras import initializers
from tensorflow.python.keras.utils import conv_utils

from tf_encrypted.keras.engine import Layer
from tf_encrypted.keras import activations
Expand Down Expand Up @@ -99,11 +99,17 @@ def __init__(self,

super(Conv2D, self).__init__(**kwargs)

self.rank = 2
self.filters = filters
self.kernel_size = kernel_size
self.strides = strides
self.padding = padding.upper()
self.data_format = data_format
self.kernel_size = conv_utils.normalize_tuple(
kernel_size, self.rank, 'kernel_size')
if self.kernel_size[0] != self.kernel_size[1]:
raise NotImplementedError("TF Encrypted currently only supports same "
"stride along the height and the width."
"You gave: {}".format(self.kernel_size))
self.strides = conv_utils.normalize_tuple(strides, self.rank, 'strides')
self.padding = conv_utils.normalize_padding(padding).upper()
self.data_format = conv_utils.normalize_data_format(data_format)
if activation is not None:
logger.info("Performing an activation before a pooling layer can result "
"in unnecesary performance loss. Check model definition in "
Expand Down Expand Up @@ -167,7 +173,10 @@ def call(self, inputs):
if self.data_format != 'channels_first':
inputs = self.prot.transpose(inputs, perm=[0, 3, 1, 2])

outputs = self.prot.conv2d(inputs, self.kernel, self.strides, self.padding)
outputs = self.prot.conv2d(inputs,
self.kernel,
self.strides[0],
self.padding)

if self.use_bias:
outputs = outputs + self.bias
Expand All @@ -189,10 +198,11 @@ def compute_output_shape(self, input_shape):
n_x, h_x, w_x, _ = input_shape.as_list()

if self.padding == "SAME":
h_out = int(np.ceil(float(h_x) / float(self.strides)))
w_out = int(np.ceil(float(w_x) / float(self.strides)))
h_out = int(np.ceil(float(h_x) / float(self.strides[0])))
w_out = int(np.ceil(float(w_x) / float(self.strides[0])))
if self.padding == "VALID":
h_out = int(np.ceil(float(h_x - h_filter + 1) / float(self.strides)))
w_out = int(np.ceil(float(w_x - w_filter + 1) / float(self.strides)))
h_out = int(np.ceil(float(h_x - h_filter + 1) / float(self.strides[0])))
w_out = int(np.ceil(float(w_x - w_filter + 1) / float(self.strides[0])))

return [n_x, n_filters, h_out, w_out]

23 changes: 17 additions & 6 deletions tf_encrypted/keras/layers/convolutional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,33 @@ def setUp(self):
tf.reset_default_graph()

def test_conv2d_bias(self):
self._core_conv2d(use_bias=True)
self._core_conv2d(kernel_size=2, use_bias=True)

def test_conv2d_nobias(self):
self._core_conv2d(use_bias=False)
self._core_conv2d(kernel_size=2, use_bias=False)

def test_conv2d_same_padding(self):
self._core_conv2d(kernel_size=2, padding='same')

def test_conv2d_kernelsize_tuple(self):
self._core_conv2d(kernel_size=(2, 2))

def _core_conv2d(self, **layer_kwargs):
filters_in = 3
input_shape = [2, filters_in, 6, 6] # channels first
input_shape = [2, 6, 6, filters_in] # channels last
filters = 5
kernel_size = (2, 2)
kernel = np.random.normal(kernel_size + (filters_in, filters))

if isinstance(layer_kwargs['kernel_size'], int):
kernel_size_in = (layer_kwargs['kernel_size'],) * 2
else:
kernel_size_in = layer_kwargs['kernel_size']

kernel = np.random.normal(kernel_size_in +
(filters_in, filters))
initializer = tf.keras.initializers.Constant(kernel)

base_kwargs = {
"filters": filters,
"kernel_size": kernel_size,
"strides": 2,
"kernel_initializer": initializer,
}
Expand Down

0 comments on commit 183c632

Please sign in to comment.