diff --git a/example/tutorial_vgg19.py b/example/tutorial_vgg19.py index 4a40fc48b..498b3eddd 100755 --- a/example/tutorial_vgg19.py +++ b/example/tutorial_vgg19.py @@ -36,7 +36,8 @@ def load_image(path): # load image img = skimage.io.imread(path) img = img / 255.0 - assert (0 <= img).all() and (img <= 1.0).all() + if ((0 <= img).all() and (img <= 1.0).all()) is False: + raise Exception("image value should be [0, 1]") # print "Original Image Shape: ", img.shape # we crop image from center short_edge = min(img.shape[:2]) @@ -78,9 +79,12 @@ def Vgg19(rgb): else: # TF 1.0 print(rgb_scaled) red, green, blue = tf.split(rgb_scaled, 3, 3) - assert red.get_shape().as_list()[1:] == [224, 224, 1] - assert green.get_shape().as_list()[1:] == [224, 224, 1] - assert blue.get_shape().as_list()[1:] == [224, 224, 1] + if red.get_shape().as_list()[1:] != [224, 224, 1]: + raise Exception("image size unmatch") + if green.get_shape().as_list()[1:] != [224, 224, 1]: + raise Exception("image size unmatch") + if blue.get_shape().as_list()[1:] != [224, 224, 1]: + raise Exception("image size unmatch") if tf.__version__ <= '0.11': bgr = tf.concat(3, [ blue - VGG_MEAN[0], @@ -94,7 +98,8 @@ def Vgg19(rgb): green - VGG_MEAN[1], red - VGG_MEAN[2], ], axis=3) - assert bgr.get_shape().as_list()[1:] == [224, 224, 3] + if bgr.get_shape().as_list()[1:] != [224, 224, 3]: + raise Exception("image size unmatch") # input layer net_in = InputLayer(bgr, name='input') # conv1 @@ -149,9 +154,12 @@ def Vgg19_simple_api(rgb): else: # TF 1.0 print(rgb_scaled) red, green, blue = tf.split(rgb_scaled, 3, 3) - assert red.get_shape().as_list()[1:] == [224, 224, 1] - assert green.get_shape().as_list()[1:] == [224, 224, 1] - assert blue.get_shape().as_list()[1:] == [224, 224, 1] + if red.get_shape().as_list()[1:] != [224, 224, 1]: + raise Exception("image size unmatch") + if green.get_shape().as_list()[1:] != [224, 224, 1]: + raise Exception("image size unmatch") + if blue.get_shape().as_list()[1:] != [224, 224, 1]: + raise Exception("image size unmatch") if tf.__version__ <= '0.11': bgr = tf.concat(3, [ blue - VGG_MEAN[0], @@ -165,7 +173,8 @@ def Vgg19_simple_api(rgb): green - VGG_MEAN[1], red - VGG_MEAN[2], ], axis=3) - assert bgr.get_shape().as_list()[1:] == [224, 224, 3] + if bgr.get_shape().as_list()[1:] != [224, 224, 3]: + raise Exception("image size unmatch") # input layer net_in = InputLayer(bgr, name='input') # conv1