23 changes: 23 additions & 0 deletions onnx2kerastl/activation_layers.py
@@ -1,3 +1,5 @@
import torch.nn
import tensorflow as tf
from tensorflow import keras
import logging
from .utils import ensure_tf_type, ensure_numpy_type
@@ -85,6 +87,27 @@ def convert_sigmoid(node, params, layers, lambda_func, node_name, keras_name):
    layers[node_name] = sigmoid(input_0)


def convert_hard_sigmoid(node, params, layers, lambda_func, node_name, keras_name, alpha=0.167, beta=0.5):
    """
    Convert Hard Sigmoid activation layer
    :param node: current operation node
    :param params: operation attributes
    :param layers: available keras layers
    :param lambda_func: function for keras Lambda layer
    :param node_name: internal converter name
    :param keras_name: resulting layer name
    :param alpha: slope of the linear region (unused; the torch default of 1/6 is assumed)
    :param beta: offset of the linear region (unused; 0.5 is assumed)
    :return: None
    """
    if len(node.input) != 1:
        raise AttributeError('More than 1 input for an activation layer.')

    input_0 = ensure_tf_type(layers[node.input[0]], name="%s_const" % keras_name)

    # Keras' 'hard_sigmoid' uses a slope of 0.2, while the HardSigmoid exported by
    # torch's nn.Hardsigmoid uses 1/6, so pre-scaling the input by 5/6 makes the two coincide.
    input_0 = tf.multiply(input_0, 5 / 6)
    hard_sigmoid = keras.layers.Activation('hard_sigmoid', name=keras_name)
    layers[node_name] = hard_sigmoid(input_0)


def convert_tanh(node, params, layers, lambda_func, node_name, keras_name):
"""
Convert Tanh activation layer
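For reference, a minimal standalone check of the 5/6 rescaling used in convert_hard_sigmoid above. It is an illustrative sketch, not part of the diff, and assumes the Keras 2.x definition hard_sigmoid(x) = clip(0.2*x + 0.5, 0, 1) against torch's Hardsigmoid(x) = clip(x/6 + 0.5, 0, 1):

```python
# Sanity check (not part of the PR): feeding x * 5/6 into Keras' hard_sigmoid
# should reproduce torch's Hardsigmoid, since 0.2 * (5x/6) + 0.5 = x/6 + 0.5
# and both saturate at |x| >= 3.
import numpy as np
import tensorflow as tf
import torch

x = np.linspace(-5, 5, 101).astype(np.float32)
keras_out = tf.keras.activations.hard_sigmoid(tf.constant(x * 5 / 6)).numpy()
torch_out = torch.nn.Hardsigmoid()(torch.from_numpy(x)).numpy()
assert np.allclose(keras_out, torch_out, atol=1e-6)
```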
3 changes: 2 additions & 1 deletion onnx2kerastl/layers.py
@@ -1,6 +1,6 @@
from .convolution_layers import convert_conv, convert_convtranspose
from .activation_layers import convert_relu, convert_elu, convert_lrelu, convert_selu, \
convert_sigmoid, convert_tanh, convert_softmax, convert_prelu
convert_sigmoid, convert_hard_sigmoid, convert_tanh, convert_softmax, convert_prelu
from .operation_layers import convert_clip, convert_exp, convert_reduce_sum, convert_reduce_mean, \
convert_log, convert_pow, convert_sqrt, convert_split, convert_cast, convert_floor, convert_identity, \
convert_argmax, convert_reduce_l2, convert_reduce_max
@@ -22,6 +22,7 @@
'Elu': convert_elu,
'LeakyRelu': convert_lrelu,
'Sigmoid': convert_sigmoid,
'HardSigmoid': convert_hard_sigmoid,
'Tanh': convert_tanh,
'Selu': convert_selu,
'Clip': convert_clip,
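For context, a rough sketch of how the new 'HardSigmoid' table entry gets exercised end to end. It assumes this fork keeps the upstream onnx2keras entry point onnx_to_keras(onnx_model, input_names), so treat the import and call signature as assumptions rather than this repo's confirmed API:

```python
# Hypothetical usage sketch; onnx_to_keras and its signature are assumed to
# match upstream onnx2keras and may differ in this fork.
import numpy as np
import onnx
import torch

from onnx2kerastl import onnx_to_keras  # assumption: same entry point as onnx2keras

model = torch.nn.Sequential(torch.nn.Hardsigmoid())
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy, 'hard_sigmoid.onnx', input_names=['input'])

onnx_model = onnx.load('hard_sigmoid.onnx')
# The graph's HardSigmoid node is dispatched to convert_hard_sigmoid via the table above.
k_model = onnx_to_keras(onnx_model, ['input'])
print(k_model.predict(np.random.uniform(-3, 3, (1, 3, 224, 224)).astype(np.float32)))
```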
3 changes: 3 additions & 0 deletions onnx2kerastl/utils.py
@@ -117,3 +117,6 @@ def check_torch_keras_error(model, k_model, input_np, epsilon=1e-5, change_order
            max_error = error

    return max_error



46 changes: 46 additions & 0 deletions test/layers/activations/test_hard_sigmoid.py
@@ -0,0 +1,46 @@
import torch.nn as nn
import numpy as np
import pytest

from test.utils import convert_and_test


class LayerHardSigmoid(nn.Module):
    """
    Test for nn.layers based types
    """
    def __init__(self):
        super(LayerHardSigmoid, self).__init__()
        self.hard_sig = nn.Hardsigmoid()

    def forward(self, x):
        x = self.hard_sig(x)
        return x


class FHardSigmoid(nn.Module):
    """
    Test for nn.functional types
    """
    def __init__(self):
        super(FHardSigmoid, self).__init__()

    def forward(self, x):
        from torch.nn import functional as F
        return F.hardsigmoid(x)


@pytest.mark.parametrize('change_ordering', [True, False])
def test_layer_hard_sigmoid(change_ordering):
    model = LayerHardSigmoid()
    model.eval()
    input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
    error = convert_and_test(model, input_np, verbose=False, change_ordering=change_ordering)


@pytest.mark.parametrize('change_ordering', [True, False])
def test_f_hard_sigmoid(change_ordering):
    model = FHardSigmoid()
    model.eval()
    input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
    error = convert_and_test(model, input_np, verbose=False, change_ordering=change_ordering)