From 1a09c8ab85d86445e1c30680af276c7eff1af5b1 Mon Sep 17 00:00:00 2001
From: Jindrich Libovicky
Date: Fri, 9 Mar 2018 22:18:38 +0100
Subject: [PATCH 1/4] script for generating imagenet feature maps

---
 neuralmonkey/readers/image_reader.py |  62 ++++++++------
 scripts/imagenet_features.py         | 118 ++++++++++++++++++++++++++++
 2 files changed, 154 insertions(+), 26 deletions(-)
 create mode 100755 scripts/imagenet_features.py

diff --git a/neuralmonkey/readers/image_reader.py b/neuralmonkey/readers/image_reader.py
index ecf2754d9..4c133a168 100644
--- a/neuralmonkey/readers/image_reader.py
+++ b/neuralmonkey/readers/image_reader.py
@@ -122,37 +122,47 @@ def load(list_files: List[str]) -> Iterable[np.ndarray]:
                     "Image file '{}' no. {} does not exist."
                     .format(path, i + 1))
 
-            image = Image.open(path).convert("RGB")
-
-            width, height = image.size
-            if width == height:
-                _rescale_or_crop(image, target_width, target_height,
-                                 True, True, False)
-            elif height < width:
-                _rescale_or_crop(
-                    image,
-                    int(width * float(target_height) / height),
-                    target_height, True, True, False)
-            else:
-                _rescale_or_crop(
-                    image, target_width,
-                    int(height * float(target_width) / width),
-                    True, True, False)
-            cropped_image = _crop(image, target_width, target_height)
-
-            res = _pad(np.array(cropped_image),
-                       target_width, target_height, 3)
-            assert res.shape == (target_width, target_height, 3)
-
-            if vgg_normalization:
-                res -= VGG_RGB_MEANS
-            if zero_one_normalization:
-                res /= 255.
+            res = single_image_for_imagenet(
+                path, target_height, target_width,
+                vgg_normalization, zero_one_normalization)
 
             yield res
 
     return load
 
 
+def single_image_for_imagenet(
+        path: str, target_height: int, target_width: int,
+        vgg_normalization: bool, zero_one_normalization: bool) -> np.ndarray:
+    image = Image.open(path).convert("RGB")
+
+    width, height = image.size
+    if width == height:
+        _rescale_or_crop(image, target_width, target_height,
+                         True, True, False)
+    elif height < width:
+        _rescale_or_crop(
+            image,
+            int(width * float(target_height) / height),
+            target_height, True, True, False)
+    else:
+        _rescale_or_crop(
+            image, target_width,
+            int(height * float(target_width) / width),
+            True, True, False)
+    cropped_image = _crop(image, target_width, target_height)
+
+    res = _pad(np.array(cropped_image),
+               target_width, target_height, 3)
+    assert res.shape == (target_width, target_height, 3)
+
+    if vgg_normalization:
+        res -= VGG_RGB_MEANS
+    if zero_one_normalization:
+        res /= 255.
+
+    return res
+
+
 def _rescale_or_crop(image: Image.Image, pad_w: int, pad_h: int,
                      rescale_w: bool, rescale_h: bool,
                      keep_aspect_ratio: bool) -> Image.Image:
diff --git a/scripts/imagenet_features.py b/scripts/imagenet_features.py
new file mode 100755
index 000000000..a079018f0
--- /dev/null
+++ b/scripts/imagenet_features.py
@@ -0,0 +1,118 @@
+#!/usr/bin/env python3
+"""Extract ImageNet features from given images.
+
+The script reads a list of paths to images (specified by a path prefix and a
+list of relative paths), processes the images with an ImageNet network and
+extracts a given convolutional map for each of them. The maps are saved as
+numpy tensors under a different prefix, at the same relative path suffixed
+with .npz.
+""" + +import argparse +import os +import sys + +import numpy as np +import tensorflow as tf + +from neuralmonkey.dataset import Dataset +from neuralmonkey.encoders.imagenet_encoder import ImageNet +from neuralmonkey.logging import log +from neuralmonkey.readers.image_reader import single_image_for_imagenet + + +SUPPORTED_NETWORKS = [ + "vgg_16", "vgg_19", "resnet_v2_50", "resnet_v2_101", "resnet_v2_152"] + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--net", type=str, choices=SUPPORTED_NETWORKS, + help="Type of imagenet network.") + parser.add_argument("--input-prefix", type=str, default="", + help="Prefix of the image path.") + parser.add_argument("--output-prefix", type=str, default="", + help="Prefix of the path to the output numpy files.") + parser.add_argument("--slim-models", type=str, required=True, + help="Path to SLIM models in cloned tensorflow/models " + "repository") + parser.add_argument("--model-checkpoint", type=str, required=True, + help="Path to the ImageNet model checkpoint.") + parser.add_argument("--conv-map", type=str, required=True, + help="Name of the convolutional map that is.") + parser.add_argument("--images", type=str, + help="File with paths to images or stdin by default.") + parser.add_argument("--batch-size", type=int, default=128) + args = parser.parse_args() + + if not os.path.exists(args.input_prefix): + raise ValueError("Directory {} does not exist.".format( + args.input_prefix)) + if not os.path.exists(args.output_prefix): + raise ValueError("Directory {} does not exist.".format( + args.output_prefix)) + + if args.net.startswith("vgg_"): + img_size = 224 + vgg_normalization = True + zero_one_normalization = False + elif args.net.startswith("resnet_v2"): + img_size = 229 + vgg_normalization = False + zero_one_normalization = True + else: + raise ValueError("Unspported network: {}.".format(args._net)) + + log("Creating graph for the ImageNet network.") + imagenet = ImageNet( + name="imagenet", data_id="images", network_type=args.net, + slim_models_path=args.slim_models, load_checkpoint=args.model_checkpoint, + spatial_layer=args.conv_map) + + log("Creating TensorFlow session.") + session = tf.Session() + session.run(tf.global_variables_initializer()) + log("Loading ImageNet model variables.") + imagenet.load(session) + + if args.images is None: + log("No input file provided, reading paths from stdin.") + source = sys.stdin + else: + source = open(args.images) + + images = [] + image_paths = [] + + def process_images(): + dataset = Dataset("dataset", {"images": np.array(images)}, {}) + feed_dict = imagenet.feed_dict(dataset) + feature_maps = session.run(imagenet.spatial_states, feed_dict=feed_dict) + + for features, rel_path in zip(feature_maps, image_paths): + npz_path = os.path.join(args.output_prefix, rel_path + ".npz") + os.makedirs(os.path.dirname(npz_path), exist_ok=True) + np.savez(npz_path, features) + print(npz_path) + + + for img in source: + img_path = os.path.join(args.input_prefix, img.rstrip()) + images.append(single_image_for_imagenet( + img_path, img_size, img_size, vgg_normalization, + zero_one_normalization)) + image_paths.append(img.rstrip()) + + if len(images) >= args.batch_size: + process_images() + images = [] + image_paths = [] + process_images() + + if args.images is not None: + source.close() + + +if __name__ == "__main__": + main() From 2fc5613de15ea9fd32f329defaa75782fb4e13e7 Mon Sep 17 00:00:00 2001 From: Jindrich Libovicky Date: Mon, 12 Mar 2018 15:43:59 +0100 Subject: [PATCH 
From 2fc5613de15ea9fd32f329defaa75782fb4e13e7 Mon Sep 17 00:00:00 2001
From: Jindrich Libovicky
Date: Mon, 12 Mar 2018 15:43:59 +0100
Subject: [PATCH 2/4] update numpy reader to read from a list of files

---
 neuralmonkey/readers/numpy_reader.py | 34 +++++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/neuralmonkey/readers/numpy_reader.py b/neuralmonkey/readers/numpy_reader.py
index 727f99751..8c62fe200 100644
--- a/neuralmonkey/readers/numpy_reader.py
+++ b/neuralmonkey/readers/numpy_reader.py
@@ -1,10 +1,42 @@
 from typing import List
+import os
 
+from typeguard import check_argument_types
 import numpy as np
 
 
-def numpy_reader(files: List[str]):
+def single_tensor(files: List[str]):
+    """Load a single tensor from a numpy file."""
+    check_argument_types()
     if len(files) == 1:
         return np.load(files[0])
 
     return np.concatenate([np.load(f) for f in files], axis=0)
+
+
+def from_file_list(prefix: str, default_tensor_name: str = "arr_0"):
+    """Load list of numpy arrays according to a list of files.
+
+    Args:
+        prefix: A common prefix of the files in lists of relative paths.
+        default_tensor_name: Key of the tensors to load in the npz files.
+
+    Return:
+        A reader function that loads numpy arrays from files on path writen
+        path relatively to the given prefix.
+    """
+    check_argument_types()
+
+    def load(files: List[str]):
+        for list_file in files:
+            with open(list_file, encoding="utf-8") as f_list:
+                for line in f_list:
+                    path = os.path.join(prefix, line.rstrip())
+                    with np.load(path) as npz:
+                        yield npz[default_tensor_name]
+
+    return load
+
+
+# pylint: disable=invalid-name
+numpy_file_list_reader = from_file_list("")
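
To make the reader's contract concrete, a minimal sketch, assuming a
hypothetical list file and prefix (the .npz key is the one np.savez writes in
the script above):

    from neuralmonkey.readers.numpy_reader import from_file_list

    # "lists/train.txt" holds one path per line, relative to the prefix,
    # e.g. "val2014/COCO_val2014_000000000042.jpg.npz".
    reader = from_file_list(prefix="/data/coco-features")

    for features in reader(["lists/train.txt"]):
        print(features.shape)  # one feature map per listed image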
From c778f5f49b520b20895114d475262375246adafc Mon Sep 17 00:00:00 2001
From: Jindrich Libovicky
Date: Mon, 12 Mar 2018 15:46:10 +0100
Subject: [PATCH 3/4] refactor numpy encoder

---
 neuralmonkey/encoders/__init__.py      |  1 -
 neuralmonkey/encoders/numpy_encoder.py | 36 ++++++++------------------
 2 files changed, 11 insertions(+), 26 deletions(-)

diff --git a/neuralmonkey/encoders/__init__.py b/neuralmonkey/encoders/__init__.py
index 07234434d..5db833b0e 100644
--- a/neuralmonkey/encoders/__init__.py
+++ b/neuralmonkey/encoders/__init__.py
@@ -1,6 +1,5 @@
 from .cnn_encoder import CNNEncoder
 from .cnn_encoder import CNNTemporalView
-from .numpy_encoder import VectorEncoder
 from .raw_rnn_encoder import RawRNNEncoder
 from .recurrent import FactoredEncoder
 from .recurrent import RecurrentEncoder
diff --git a/neuralmonkey/encoders/numpy_encoder.py b/neuralmonkey/encoders/numpy_encoder.py
index b0edf718d..98c1ff604 100644
--- a/neuralmonkey/encoders/numpy_encoder.py
+++ b/neuralmonkey/encoders/numpy_encoder.py
@@ -13,7 +13,7 @@
 
 # pylint: disable=too-few-public-methods
-class VectorEncoder(ModelPart, Stateful):
+class StatefulNumberEncoder(ModelPart, Stateful):
 
     def __init__(self,
                  name: str,
@@ -59,13 +59,13 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
         return {self.vector: dataset.get_series(self.data_id)}
 
 
-class PostCNNImageEncoder(ModelPart, SpatialStatefulWithOutput):
+class SpatialNumpyEncoder(ModelPart, SpatialStatefulWithOutput):
 
     # pylint: disable=too-many-arguments
     def __init__(self,
                  name: str,
                  input_shape: List[int],
-                 output_shape: int,
                  data_id: str,
+                 output_shape: int = None,
                  save_checkpoint: Optional[str] = None,
                  load_checkpoint: Optional[str] = None,
                  initializers: InitializerSpecs = None) -> None:
@@ -74,40 +74,26 @@ def __init__(self,
                          initializers)
 
         assert len(input_shape) == 3
-        if output_shape <= 0:
+        if output_shape is not None and output_shape <= 0:
             raise ValueError("Output vector dimension must be postive.")
 
         self.data_id = data_id
-
-        with self.use_scope():
-            features_shape = [None] + input_shape  # type: ignore
-            self.image_features = tf.placeholder(tf.float32,
-                                                 shape=features_shape,
-                                                 name="image_input")
-
-            self.flat = tf.reduce_mean(self.image_features,
-                                       axis=[1, 2],
-                                       name="average_image")
-
-            self.project_w = get_variable(
-                name="img_init_proj_W",
-                shape=[input_shape[2], output_shape],
-                initializer=tf.glorot_normal_initializer())
-            self.project_b = get_variable(
-                name="img_init_b", shape=[output_shape],
-                initializer=tf.zeros_initializer())
+        self.input_shape = input_shape
 
     @tensor
     def output(self) -> tf.Tensor:
-        return tf.tanh(tf.matmul(self.flat, self.project_w) + self.project_b)
+        return tf.reduce_mean(
+            self.spatial_states, axis=[1, 2], name="average_image")
 
     @tensor
     def spatial_states(self) -> tf.Tensor:
-        return self.image_features
+        features_shape = [None] + self.input_shape  # type: ignore
+        return tf.placeholder(
+            tf.float32, shape=features_shape, name="spatial_states")
 
     @tensor
     def spatial_mask(self) -> tf.Tensor:
         return tf.ones(tf.shape(self.spatial_states)[:3])
 
     def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
-        return {self.image_features: dataset.get_series(self.data_id)}
+        return {self.spatial_states: dataset.get_series(self.data_id)}

From a1a8365d528486b90300a8f3e787791f477290e4 Mon Sep 17 00:00:00 2001
From: Jindrich Libovicky
Date: Mon, 12 Mar 2018 17:06:11 +0100
Subject: [PATCH 4/4] fix tests & review

---
 ...py_encoder.py => numpy_stateful_filler.py} |  7 +--
 neuralmonkey/readers/numpy_reader.py          | 22 ++++----
 neuralmonkey/tests/test_encoders_init.py      | 55 -------------------
 3 files changed, 13 insertions(+), 71 deletions(-)
 rename neuralmonkey/encoders/{numpy_encoder.py => numpy_stateful_filler.py} (91%)

diff --git a/neuralmonkey/encoders/numpy_encoder.py b/neuralmonkey/encoders/numpy_stateful_filler.py
similarity index 91%
rename from neuralmonkey/encoders/numpy_encoder.py
rename to neuralmonkey/encoders/numpy_stateful_filler.py
index 98c1ff604..d3d9394a9 100644
--- a/neuralmonkey/encoders/numpy_encoder.py
+++ b/neuralmonkey/encoders/numpy_stateful_filler.py
@@ -13,7 +13,7 @@
 
 # pylint: disable=too-few-public-methods
-class StatefulNumberEncoder(ModelPart, Stateful):
+class StatefulFiller(ModelPart, Stateful):
 
     def __init__(self,
                  name: str,
@@ -59,13 +59,12 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
         return {self.vector: dataset.get_series(self.data_id)}
 
 
-class SpatialNumpyEncoder(ModelPart, SpatialStatefulWithOutput):
+class SpatialFiller(ModelPart, SpatialStatefulWithOutput):
 
     # pylint: disable=too-many-arguments
    def __init__(self,
                 name: str,
                 input_shape: List[int],
                 data_id: str,
-                 output_shape: int = None,
                 save_checkpoint: Optional[str] = None,
                 load_checkpoint: Optional[str] = None,
                 initializers: InitializerSpecs = None) -> None:
@@ -74,8 +73,6 @@ def __init__(self,
                          initializers)
 
         assert len(input_shape) == 3
-        if output_shape is not None and output_shape <= 0:
-            raise ValueError("Output vector dimension must be postive.")
 
         self.data_id = data_id
         self.input_shape = input_shape
diff --git a/neuralmonkey/readers/numpy_reader.py b/neuralmonkey/readers/numpy_reader.py
index 8c62fe200..47c131377 100644
--- a/neuralmonkey/readers/numpy_reader.py
+++ b/neuralmonkey/readers/numpy_reader.py
@@ -1,11 +1,11 @@
-from typing import List
+from typing import List, Callable, Iterable
 import os
 
 from typeguard import check_argument_types
 import numpy as np
 
 
-def single_tensor(files: List[str]):
+def single_tensor(files: List[str]) -> np.ndarray:
file.""" check_argument_types() if len(files) == 1: @@ -14,20 +14,20 @@ def single_tensor(files: List[str]): return np.concatenate([np.load(f) for f in files], axis=0) -def from_file_list(prefix: str, default_tensor_name: str = "arr_0"): - """Load list of numpy arrays according to a list of files. +def from_file_list(prefix: str, + default_tensor_name: str = "arr_0") -> Callable: + """Load a list of numpy arrays from a list of .npz numpy files. Args: - prefix: A common prefix of the files in lists of relative paths. - default_tensor_name: Key of the tensors to load in the npz files. + prefix: A common prefix for the files in the list. + default_tensor_name: Key of the tensors to load from the npz files. - Return: - A reader function that loads numpy arrays from files on path writen - path relatively to the given prefix. + Returns: + A generator function that yields the loaded arryas. """ check_argument_types() - def load(files: List[str]): + def load(files: List[str]) -> Iterable[np.ndarray]: for list_file in files: with open(list_file, encoding="utf-8") as f_list: for line in f_list: @@ -39,4 +39,4 @@ def load(files: List[str]): # pylint: disable=invalid-name -numpy_file_list_reader = from_file_list("") +numpy_file_list_reader = from_file_list(prefix="") diff --git a/neuralmonkey/tests/test_encoders_init.py b/neuralmonkey/tests/test_encoders_init.py index 3a6986518..fe39373c6 100755 --- a/neuralmonkey/tests/test_encoders_init.py +++ b/neuralmonkey/tests/test_encoders_init.py @@ -6,8 +6,6 @@ from typing import Dict, List, Any, Iterable -from neuralmonkey.encoders.numpy_encoder import (VectorEncoder, - PostCNNImageEncoder) from neuralmonkey.encoders.recurrent import SentenceEncoder from neuralmonkey.encoders.sentence_cnn_encoder import SentenceCNNEncoder from neuralmonkey.model.sequence import EmbeddedSequence @@ -80,37 +78,6 @@ "use_noisy_activations": [None, SentenceEncoder] } -VECTOR_ENCODER_GOOD = { - "name": ["vector_encoder"], - "dimension": [10], - "data_id": ["marmelade"], - "output_shape": [1, None, 100] -} - -VECTOR_ENCODER_BAD = { - "nonexistent": ["ahoj"], - "name": [None, 1], - "dimension": [0, -1, "ahoj", 3.14, VOCABULARY, SentenceEncoder, None], - "data_id": [3.14, VOCABULARY, None], - "output_shape": [0, -1, "ahoj", 3.14, VOCABULARY] -} - -POST_CNN_IMAGE_ENCODER_GOOD = { - "name": ["vector_encoder"], - "input_shape": [[1, 2, 3], [10, 20, 3]], - "output_shape": [10], - "data_id": ["marmelade"], -} - -POST_CNN_IMAGE_ENCODER_BAD = { - "nonexistent": ["ahoj"], - "name": [None, 1], - "data_id": [3.14, VOCABULARY, None], - "output_shape": [0, -1, "hoj", 3.14, None, VOCABULARY, SentenceEncoder], - "input_shape": [3, [10, 20], [-1, 10, 20], "123", "ahoj", 3.14, - VOCABULARY, []] -} - def traverse_combinations( params: Dict[str, List[Any]], @@ -187,28 +154,6 @@ def test_sentence_cnn_encoder(self): SENTENCE_CNN_ENCODER_GOOD, SENTENCE_CNN_ENCODER_BAD) - def test_vector_encoder(self): - with self.assertRaises(Exception): - # pylint: disable=no-value-for-parameter - # on purpose, should fail - VectorEncoder() - # pylint: enable=no-value-for-parameter - - self._run_constructors(VectorEncoder, - VECTOR_ENCODER_GOOD, - VECTOR_ENCODER_BAD) - - def test_post_cnn_encoder(self): - with self.assertRaises(Exception): - # pylint: disable=no-value-for-parameter - # on purpose, should fail - PostCNNImageEncoder() - # pylint: enable=no-value-for-parameter - - self._run_constructors(PostCNNImageEncoder, - POST_CNN_IMAGE_ENCODER_GOOD, - POST_CNN_IMAGE_ENCODER_BAD) - if __name__ == "__main__": 
unittest.main()
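
Finally, a self-contained sketch of feeding precomputed maps through the
renamed SpatialFiller; the map shape, all names and the zero-filled batch are
assumptions for illustration, not part of the patches:

    import numpy as np
    import tensorflow as tf

    from neuralmonkey.dataset import Dataset
    from neuralmonkey.encoders.numpy_stateful_filler import SpatialFiller

    # Placeholder-backed model part that feeds precomputed feature maps.
    filler = SpatialFiller(name="image_features",
                           input_shape=[8, 8, 2048],  # assumed map shape
                           data_id="images")

    # Stand-in for feature maps loaded by numpy_file_list_reader.
    batch = np.zeros((4, 8, 8, 2048), dtype=np.float32)
    dataset = Dataset("dummy", {"images": batch}, {})

    with tf.Session() as sess:
        states, mean = sess.run(
            [filler.spatial_states, filler.output],
            feed_dict=filler.feed_dict(dataset))

    print(states.shape)  # (4, 8, 8, 2048): the maps, passed through
    print(mean.shape)    # (4, 2048): their spatial average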