Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions test/models/test_alexnet.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
import numpy as np
import pytest
import tensorflow as tf
from torchvision.models import alexnet

from test.utils import convert_and_test
from torchvision.models import alexnet


@pytest.mark.parametrize('change_ordering', [True, False])
def test_alexnet(change_ordering):
if not tf.test.gpu_device_name() and not change_ordering:
pytest.skip("Skip! Since tensorflow Conv2D op currently only supports the NHWC tensor format on the CPU")
def test_alexnet():
model = alexnet()
model.eval()

input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
error = convert_and_test(model, input_np, verbose=False, change_ordering=change_ordering)
error = convert_and_test(model, input_np, verbose=False, should_transform_inputs=True)
6 changes: 2 additions & 4 deletions test/models/test_deeplab.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@


@pytest.mark.slow
@pytest.mark.parametrize('change_ordering', [False])
@pytest.mark.parametrize('model_class', [deeplabv3_resnet50, deeplabv3_resnet101, deeplabv3_mobilenet_v3_large])
def test_deeplab(change_ordering, model_class):
def test_deeplab(model_class):
model = model_class()
model.eval()

input_np = np.random.uniform(0, 1, (1, 3, 256, 256))
error = convert_and_test(model, input_np, verbose=False, change_ordering=change_ordering,
should_transform_inputs=True)
error = convert_and_test(model, input_np, verbose=False, should_transform_inputs=True)
10 changes: 3 additions & 7 deletions test/models/test_densenet.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import numpy as np
import pytest
import tensorflow as tf
from torchvision.models.densenet import densenet121

from test.utils import convert_and_test
from torchvision.models.densenet import densenet121


@pytest.mark.slow
@pytest.mark.parametrize('change_ordering', [True, False])
def test_densenet(change_ordering):
if not tf.test.gpu_device_name() and not change_ordering:
pytest.skip("Skip! Since tensorflow Conv2D op currently only supports the NHWC tensor format on the CPU")
def test_densenet():
model = densenet121()
model.eval()

input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
error = convert_and_test(model, input_np, verbose=False, change_ordering=change_ordering)
error = convert_and_test(model, input_np, verbose=False, should_transform_inputs=True)
10 changes: 3 additions & 7 deletions test/models/test_googlenet.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import numpy as np
import pytest
import tensorflow as tf
from torchvision.models import googlenet

from test.utils import convert_and_test
from torchvision.models import googlenet


@pytest.mark.slow
@pytest.mark.parametrize('change_ordering', [True, False])
def test_googlenet(change_ordering):
if not tf.test.gpu_device_name() and not change_ordering:
pytest.skip("Skip! Since tensorflow Conv2D op currently only supports the NHWC tensor format on the CPU")
def test_googlenet():
model = googlenet()
model.eval()

input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
error = convert_and_test(model, input_np, verbose=False, change_ordering=change_ordering)
error = convert_and_test(model, input_np, verbose=False, should_transform_inputs=True)
11 changes: 3 additions & 8 deletions test/models/test_mbnet2.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
import numpy as np
import pytest
import tensorflow as tf
from torchvision.models import mobilenet_v2

from test.utils import convert_and_test
from torchvision.models import mobilenet_v2


@pytest.mark.parametrize('change_ordering', [True, False])
def test_mobilenetv2(change_ordering):
if not tf.test.gpu_device_name() and not change_ordering:
pytest.skip("Skip! Since tensorflow Conv2D op currently only supports the NHWC tensor format on the CPU")
def test_mobilenetv2():
model = mobilenet_v2()
model.eval()

input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
error = convert_and_test(model, input_np, verbose=False, change_ordering=change_ordering)
error = convert_and_test(model, input_np, verbose=False, should_transform_inputs=True)
10 changes: 3 additions & 7 deletions test/models/test_mnasnet.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
import numpy as np
import pytest
import tensorflow as tf
from torchvision.models import mnasnet0_5, mnasnet1_0, mnasnet0_75, mnasnet1_3

from test.utils import convert_and_test
from torchvision.models import mnasnet0_5, mnasnet1_0, mnasnet0_75, mnasnet1_3


@pytest.mark.slow
@pytest.mark.parametrize('change_ordering', [True, False])
@pytest.mark.parametrize('model_class', [mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3])
def test_mnasnet(change_ordering, model_class):
if not tf.test.gpu_device_name() and not change_ordering:
pytest.skip("Skip! Since tensorflow Conv2D op currently only supports the NHWC tensor format on the CPU")
def test_mnasnet(model_class):
model = model_class()
model.eval()

input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
error = convert_and_test(model, input_np, verbose=False, change_ordering=change_ordering)
error = convert_and_test(model, input_np, verbose=False, should_transform_inputs=True)
8 changes: 2 additions & 6 deletions test/models/test_resnet18.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import numpy as np
import pytest
import tensorflow as tf

from torchvision.models import resnet18

from test.utils import convert_and_test


@pytest.mark.parametrize('change_ordering', [True, False])
@pytest.mark.parametrize('change_ordering', [False])
def test_resnet18(change_ordering):
if not tf.test.gpu_device_name() and not change_ordering:
pytest.skip("Skip! Since tensorflow Conv2D op currently only supports the NHWC tensor format on the CPU")
model = resnet18()
model.eval()

input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
error = convert_and_test(model, input_np, verbose=False, change_ordering=change_ordering)
error = convert_and_test(model, input_np, verbose=False, should_transform_inputs=True)
10 changes: 3 additions & 7 deletions test/models/test_resnext.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
import numpy as np
import pytest
import tensorflow as tf
from torchvision.models import resnext50_32x4d, resnext101_32x8d

from test.utils import convert_and_test
from torchvision.models import resnext50_32x4d, resnext101_32x8d


@pytest.mark.slow
@pytest.mark.parametrize('change_ordering', [True, False])
@pytest.mark.parametrize('model_class', [resnext50_32x4d, resnext101_32x8d])
def test_resnext(change_ordering, model_class):
if not tf.test.gpu_device_name() and not change_ordering:
pytest.skip("Skip! Since tensorflow Conv2D op currently only supports the NHWC tensor format on the CPU")
def test_resnext(model_class):
model = model_class()
model.eval()

input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
error = convert_and_test(model, input_np, verbose=False, change_ordering=change_ordering)
error = convert_and_test(model, input_np, verbose=False, should_transform_inputs=True)
10 changes: 3 additions & 7 deletions test/models/test_squeezenet.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
import numpy as np
import pytest
import tensorflow as tf
from torchvision.models import squeezenet1_0, squeezenet1_1

from test.utils import convert_and_test
from torchvision.models import squeezenet1_0, squeezenet1_1


@pytest.mark.slow
@pytest.mark.parametrize('change_ordering', [True, False])
@pytest.mark.parametrize('model_class', [squeezenet1_1, squeezenet1_0])
def test_squeezenet(change_ordering, model_class):
if not tf.test.gpu_device_name() and not change_ordering:
pytest.skip("Skip! Since tensorflow Conv2D op currently only supports the NHWC tensor format on the CPU")
def test_squeezenet(model_class):
model = model_class()
model.eval()

input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
error = convert_and_test(model, input_np, verbose=False, change_ordering=change_ordering)
error = convert_and_test(model, input_np, verbose=False, should_transform_inputs=True)
10 changes: 3 additions & 7 deletions test/models/test_vgg.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
import numpy as np
import pytest
import tensorflow as tf
from torchvision.models import vgg11, vgg11_bn

from test.utils import convert_and_test
from torchvision.models import vgg11, vgg11_bn


@pytest.mark.slow
@pytest.mark.parametrize('change_ordering', [True, False])
@pytest.mark.parametrize('model_class', [vgg11, vgg11_bn])
def test_vgg(change_ordering, model_class):
if not tf.test.gpu_device_name() and not change_ordering:
pytest.skip("Skip! Since tensorflow Conv2D op currently only supports the NHWC tensor format on the CPU")
def test_vgg(model_class):
model = model_class()
model.eval()

input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
error = convert_and_test(model, input_np, verbose=False, change_ordering=change_ordering)
error = convert_and_test(model, input_np, verbose=False, should_transform_inputs=True)
7 changes: 5 additions & 2 deletions test/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import io
import warnings

import onnx
import torch
Expand All @@ -9,6 +8,10 @@
from onnx2kerastl import onnx_to_keras, check_torch_keras_error


class LambdaLayerException(Exception):
pass


def torch2keras(model: torch.nn.Module, input_variable, verbose=True, change_ordering=False):
if isinstance(input_variable, (tuple, list)):
input_variable = tuple(torch.FloatTensor(var) for var in input_variable)
Expand Down Expand Up @@ -36,7 +39,7 @@ def convert_and_test(model: torch.nn.Module,
error = check_torch_keras_error(model, k_model, input_variable, change_ordering=change_ordering, epsilon=epsilon,
should_transform_inputs=should_transform_inputs)
if is_lambda_layers_exist(k_model):
warnings.warn("Found Lambda layers")
raise LambdaLayerException("Found Lambda layers")
return error


Expand Down