Merge pull request #672 from ufal/img_feat

Extracting image features

jindrahelcl committed Mar 12, 2018
2 parents dfe8e09 + a1a8365 commit 07ec99c
Showing 6 changed files with 196 additions and 110 deletions.
1 change: 0 additions & 1 deletion neuralmonkey/encoders/__init__.py
@@ -1,6 +1,5 @@
from .cnn_encoder import CNNEncoder
from .cnn_encoder import CNNTemporalView
from .numpy_encoder import VectorEncoder
from .raw_rnn_encoder import RawRNNEncoder
from .recurrent import FactoredEncoder
from .recurrent import RecurrentEncoder
35 changes: 9 additions & 26 deletions neuralmonkey/encoders/numpy_encoder.py
@@ -13,7 +13,7 @@
# pylint: disable=too-few-public-methods


class VectorEncoder(ModelPart, Stateful):
class StatefulFiller(ModelPart, Stateful):

def __init__(self,
name: str,
@@ -59,12 +59,11 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
return {self.vector: dataset.get_series(self.data_id)}


class PostCNNImageEncoder(ModelPart, SpatialStatefulWithOutput):
class SpatialFiller(ModelPart, SpatialStatefulWithOutput):
# pylint: disable=too-many-arguments
def __init__(self,
name: str,
input_shape: List[int],
output_shape: int,
data_id: str,
save_checkpoint: Optional[str] = None,
load_checkpoint: Optional[str] = None,
@@ -74,40 +73,24 @@ def __init__(self,
initializers)

assert len(input_shape) == 3
if output_shape <= 0:
raise ValueError("Output vector dimension must be postive.")

self.data_id = data_id

with self.use_scope():
features_shape = [None] + input_shape # type: ignore
self.image_features = tf.placeholder(tf.float32,
shape=features_shape,
name="image_input")

self.flat = tf.reduce_mean(self.image_features,
axis=[1, 2],
name="average_image")

self.project_w = get_variable(
name="img_init_proj_W",
shape=[input_shape[2], output_shape],
initializer=tf.glorot_normal_initializer())
self.project_b = get_variable(
name="img_init_b", shape=[output_shape],
initializer=tf.zeros_initializer())
self.input_shape = input_shape

@tensor
def output(self) -> tf.Tensor:
return tf.tanh(tf.matmul(self.flat, self.project_w) + self.project_b)
return tf.reduce_mean(
self.spatial_states, axis=[1, 2], name="average_image")

@tensor
def spatial_states(self) -> tf.Tensor:
return self.image_features
features_shape = [None] + self.input_shape # type: ignore
return tf.placeholder(
tf.float32, shape=features_shape, name="spatial_states")

@tensor
def spatial_mask(self) -> tf.Tensor:
return tf.ones(tf.shape(self.spatial_states)[:3])

def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
return {self.image_features: dataset.get_series(self.data_id)}
return {self.spatial_states: dataset.get_series(self.data_id)}
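
A minimal usage sketch of the renamed SpatialFiller, assuming only the constructor signature and the Dataset call visible in this commit; the shape, the names, and the random stand-in features are illustrative:

import numpy as np
import tensorflow as tf

from neuralmonkey.dataset import Dataset
from neuralmonkey.encoders.numpy_encoder import SpatialFiller

# A filler for precomputed 14x14x512 feature maps (e.g. a VGG conv5 layer).
filler = SpatialFiller(
    name="feature_filler", input_shape=[14, 14, 512], data_id="images")

# Two random arrays standing in for precomputed CNN feature maps.
features = np.random.rand(2, 14, 14, 512).astype(np.float32)
dataset = Dataset("dataset", {"images": features}, {})

with tf.Session() as sess:
    states, mean = sess.run(
        [filler.spatial_states, filler.output],
        feed_dict=filler.feed_dict(dataset))
    # states.shape == (2, 14, 14, 512); mean.shape == (2, 512)
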
62 changes: 36 additions & 26 deletions neuralmonkey/readers/image_reader.py
@@ -122,37 +122,47 @@ def load(list_files: List[str]) -> Iterable[np.ndarray]:
"Image file '{}' no. {} does not exist."
.format(path, i + 1))

image = Image.open(path).convert("RGB")

width, height = image.size
if width == height:
_rescale_or_crop(image, target_width, target_height,
True, True, False)
elif height < width:
_rescale_or_crop(
image,
int(width * float(target_height) / height),
target_height, True, True, False)
else:
_rescale_or_crop(
image, target_width,
int(height * float(target_width) / width),
True, True, False)
cropped_image = _crop(image, target_width, target_height)

res = _pad(np.array(cropped_image),
target_width, target_height, 3)
assert res.shape == (target_width, target_height, 3)

if vgg_normalization:
res -= VGG_RGB_MEANS
if zero_one_normalization:
res /= 255.
res = single_image_for_imagenet(
path, target_height, target_width,
vgg_normalization, zero_one_normalization)

yield res
return load


def single_image_for_imagenet(
path: str, target_height: int, target_width: int,
vgg_normalization: bool, zero_one_normalization: bool) -> np.ndarray:
image = Image.open(path).convert("RGB")

width, height = image.size
if width == height:
_rescale_or_crop(image, target_width, target_height,
True, True, False)
elif height < width:
_rescale_or_crop(
image,
int(width * float(target_height) / height),
target_height, True, True, False)
else:
_rescale_or_crop(
image, target_width,
int(height * float(target_width) / width),
True, True, False)
cropped_image = _crop(image, target_width, target_height)

res = _pad(np.array(cropped_image),
target_width, target_height, 3)
assert res.shape == (target_width, target_height, 3)

if vgg_normalization:
res -= VGG_RGB_MEANS
if zero_one_normalization:
res /= 255.

return res


def _rescale_or_crop(image: Image.Image, pad_w: int, pad_h: int,
rescale_w: bool, rescale_h: bool,
keep_aspect_ratio: bool) -> Image.Image:
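
For reference, a minimal sketch of the newly extracted single_image_for_imagenet helper used in isolation; the image path is hypothetical, and 224 px with VGG normalization mirrors the vgg_* branch of the extraction script below:

from neuralmonkey.readers.image_reader import single_image_for_imagenet

# Rescale, crop and pad one image to 224x224 and subtract the VGG RGB means.
img = single_image_for_imagenet(
    "data/img/0001.jpg", target_height=224, target_width=224,
    vgg_normalization=True, zero_one_normalization=False)
assert img.shape == (224, 224, 3)
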
36 changes: 34 additions & 2 deletions neuralmonkey/readers/numpy_reader.py
@@ -1,10 +1,42 @@
from typing import List
from typing import List, Callable, Iterable
import os

from typeguard import check_argument_types
import numpy as np


def numpy_reader(files: List[str]):
def single_tensor(files: List[str]) -> np.ndarray:
"""Load a single tensor from a numpy file."""
check_argument_types()
if len(files) == 1:
return np.load(files[0])

return np.concatenate([np.load(f) for f in files], axis=0)


def from_file_list(prefix: str,
default_tensor_name: str = "arr_0") -> Callable:
"""Load a list of numpy arrays from a list of .npz numpy files.
Args:
prefix: A common prefix for the files in the list.
default_tensor_name: Key of the tensors to load from the npz files.
Returns:
A generator function that yields the loaded arryas.
"""
check_argument_types()

def load(files: List[str]) -> Iterable[np.ndarray]:
for list_file in files:
with open(list_file, encoding="utf-8") as f_list:
for line in f_list:
path = os.path.join(prefix, line.rstrip())
with np.load(path) as npz:
yield npz[default_tensor_name]

return load


# pylint: disable=invalid-name
numpy_file_list_reader = from_file_list(prefix="")
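
A self-contained sketch of the two new readers; the file names here are made up for illustration:

import numpy as np

from neuralmonkey.readers.numpy_reader import from_file_list, single_tensor

# Write a toy .npz file and a list file pointing to it.
np.savez("example.npz", np.zeros((4, 4, 3)))
with open("example_list.txt", "w", encoding="utf-8") as f_list:
    f_list.write("example.npz\n")

# from_file_list returns a generator function over the listed .npz files;
# with prefix="" it is the same as numpy_file_list_reader.
reader = from_file_list(prefix="")
for arr in reader(["example_list.txt"]):
    print(arr.shape)  # (4, 4, 3), read from the default key "arr_0"

# single_tensor loads a single .npy file (or concatenates several).
np.save("vectors.npy", np.ones((10, 5)))
print(single_tensor(["vectors.npy"]).shape)  # (10, 5)
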
55 changes: 0 additions & 55 deletions neuralmonkey/tests/test_encoders_init.py
@@ -6,8 +6,6 @@

from typing import Dict, List, Any, Iterable

from neuralmonkey.encoders.numpy_encoder import (VectorEncoder,
PostCNNImageEncoder)
from neuralmonkey.encoders.recurrent import SentenceEncoder
from neuralmonkey.encoders.sentence_cnn_encoder import SentenceCNNEncoder
from neuralmonkey.model.sequence import EmbeddedSequence
@@ -80,37 +78,6 @@
"use_noisy_activations": [None, SentenceEncoder]
}

VECTOR_ENCODER_GOOD = {
"name": ["vector_encoder"],
"dimension": [10],
"data_id": ["marmelade"],
"output_shape": [1, None, 100]
}

VECTOR_ENCODER_BAD = {
"nonexistent": ["ahoj"],
"name": [None, 1],
"dimension": [0, -1, "ahoj", 3.14, VOCABULARY, SentenceEncoder, None],
"data_id": [3.14, VOCABULARY, None],
"output_shape": [0, -1, "ahoj", 3.14, VOCABULARY]
}

POST_CNN_IMAGE_ENCODER_GOOD = {
"name": ["vector_encoder"],
"input_shape": [[1, 2, 3], [10, 20, 3]],
"output_shape": [10],
"data_id": ["marmelade"],
}

POST_CNN_IMAGE_ENCODER_BAD = {
"nonexistent": ["ahoj"],
"name": [None, 1],
"data_id": [3.14, VOCABULARY, None],
"output_shape": [0, -1, "hoj", 3.14, None, VOCABULARY, SentenceEncoder],
"input_shape": [3, [10, 20], [-1, 10, 20], "123", "ahoj", 3.14,
VOCABULARY, []]
}


def traverse_combinations(
params: Dict[str, List[Any]],
@@ -187,28 +154,6 @@ def test_sentence_cnn_encoder(self):
SENTENCE_CNN_ENCODER_GOOD,
SENTENCE_CNN_ENCODER_BAD)

def test_vector_encoder(self):
with self.assertRaises(Exception):
# pylint: disable=no-value-for-parameter
# on purpose, should fail
VectorEncoder()
# pylint: enable=no-value-for-parameter

self._run_constructors(VectorEncoder,
VECTOR_ENCODER_GOOD,
VECTOR_ENCODER_BAD)

def test_post_cnn_encoder(self):
with self.assertRaises(Exception):
# pylint: disable=no-value-for-parameter
# on purpose, should fail
PostCNNImageEncoder()
# pylint: enable=no-value-for-parameter

self._run_constructors(PostCNNImageEncoder,
POST_CNN_IMAGE_ENCODER_GOOD,
POST_CNN_IMAGE_ENCODER_BAD)


if __name__ == "__main__":
unittest.main()
117 changes: 117 additions & 0 deletions scripts/imagenet_features.py
@@ -0,0 +1,117 @@
#!/usr/bin/env python3
"""Extract imagenet features from given images.
The script reads a list of pahts to images (specified by path prefix and list
of relative paths), process the images using an imagenet network and extract a
given convolutional map from the image. The maps are saved as numpy tensors in
files with a different prefix and the same relative path from this prefix
ending with .npz.
"""

import argparse
import os
import sys

import numpy as np
import tensorflow as tf

from neuralmonkey.dataset import Dataset
from neuralmonkey.encoders.imagenet_encoder import ImageNet
from neuralmonkey.logging import log
from neuralmonkey.readers.image_reader import single_image_for_imagenet


SUPPORTED_NETWORKS = [
    "vgg_16", "vgg_19", "resnet_v2_50", "resnet_v2_101", "resnet_v2_152"]


def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--net", type=str, choices=SUPPORTED_NETWORKS,
                        help="Type of imagenet network.")
    parser.add_argument("--input-prefix", type=str, default="",
                        help="Prefix of the image path.")
    parser.add_argument("--output-prefix", type=str, default="",
                        help="Prefix of the path to the output numpy files.")
    parser.add_argument("--slim-models", type=str, required=True,
                        help="Path to SLIM models in a cloned "
                             "tensorflow/models repository.")
    parser.add_argument("--model-checkpoint", type=str, required=True,
                        help="Path to the ImageNet model checkpoint.")
    parser.add_argument("--conv-map", type=str, required=True,
                        help="Name of the convolutional map to extract.")
    parser.add_argument("--images", type=str,
                        help="File with paths to images or stdin by default.")
    parser.add_argument("--batch-size", type=int, default=128)
    args = parser.parse_args()

    if not os.path.exists(args.input_prefix):
        raise ValueError("Directory {} does not exist.".format(
            args.input_prefix))
    if not os.path.exists(args.output_prefix):
        raise ValueError("Directory {} does not exist.".format(
            args.output_prefix))

    if args.net.startswith("vgg_"):
        img_size = 224
        vgg_normalization = True
        zero_one_normalization = False
    elif args.net.startswith("resnet_v2"):
        img_size = 299
        vgg_normalization = False
        zero_one_normalization = True
    else:
        raise ValueError("Unsupported network: {}.".format(args.net))

    log("Creating graph for the ImageNet network.")
    imagenet = ImageNet(
        name="imagenet", data_id="images", network_type=args.net,
        slim_models_path=args.slim_models,
        load_checkpoint=args.model_checkpoint,
        spatial_layer=args.conv_map)

    log("Creating TensorFlow session.")
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    log("Loading ImageNet model variables.")
    imagenet.load(session)

    if args.images is None:
        log("No input file provided, reading paths from stdin.")
        source = sys.stdin
    else:
        source = open(args.images)

    images = []
    image_paths = []

    def process_images():
        dataset = Dataset("dataset", {"images": np.array(images)}, {})
        feed_dict = imagenet.feed_dict(dataset)
        feature_maps = session.run(
            imagenet.spatial_states, feed_dict=feed_dict)

        for features, rel_path in zip(feature_maps, image_paths):
            npz_path = os.path.join(args.output_prefix, rel_path + ".npz")
            os.makedirs(os.path.dirname(npz_path), exist_ok=True)
            np.savez(npz_path, features)
            print(npz_path)

    for img in source:
        img_path = os.path.join(args.input_prefix, img.rstrip())
        images.append(single_image_for_imagenet(
            img_path, img_size, img_size, vgg_normalization,
            zero_one_normalization))
        image_paths.append(img.rstrip())

        if len(images) >= args.batch_size:
            process_images()
            images = []
            image_paths = []

    # Process the final, possibly incomplete batch (if any images remain).
    if images:
        process_images()

    if args.images is not None:
        source.close()


if __name__ == "__main__":
    main()
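
To sketch the intended pipeline end to end (every path and layer name here is hypothetical): running the script with --net vgg_16, --conv-map set to a slim endpoint such as vgg_16/conv5/conv5_3, --input-prefix and --output-prefix pointing at the image and feature directories, the mandatory --slim-models and --model-checkpoint paths, and the image list on stdin saves one .npz file per image and prints its path. The saved maps can then be read back with the new reader:

from neuralmonkey.readers.numpy_reader import from_file_list

# features.txt (hypothetical) lists the saved files relative to the output
# prefix, e.g. "0001.jpg.npz"; np.savez stored each map under "arr_0",
# which is the default key from_file_list reads.
reader = from_file_list(prefix="data/feat")
for feature_map in reader(["features.txt"]):
    print(feature_map.shape)  # e.g. (14, 14, 512) for a VGG conv5 map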
