Skip to content

Commit

Permalink
fix tests & review
Browse files Browse the repository at this point in the history
  • Loading branch information
jlibovicky committed Mar 12, 2018
1 parent c778f5f commit a1a8365
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# pylint: disable=too-few-public-methods


class StatefulNumberEncoder(ModelPart, Stateful):
class StatefulFiller(ModelPart, Stateful):

def __init__(self,
name: str,
Expand Down Expand Up @@ -59,13 +59,12 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
return {self.vector: dataset.get_series(self.data_id)}


class SpatialNumpyEncoder(ModelPart, SpatialStatefulWithOutput):
class SpatialFiller(ModelPart, SpatialStatefulWithOutput):
# pylint: disable=too-many-arguments
def __init__(self,
name: str,
input_shape: List[int],
data_id: str,
output_shape: int = None,
save_checkpoint: Optional[str] = None,
load_checkpoint: Optional[str] = None,
initializers: InitializerSpecs = None) -> None:
Expand All @@ -74,8 +73,6 @@ def __init__(self,
initializers)

assert len(input_shape) == 3
if output_shape is not None and output_shape <= 0:
raise ValueError("Output vector dimension must be postive.")

self.data_id = data_id
self.input_shape = input_shape
Expand Down
22 changes: 11 additions & 11 deletions neuralmonkey/readers/numpy_reader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import List
from typing import List, Callable, Iterable
import os

from typeguard import check_argument_types
import numpy as np


def single_tensor(files: List[str]):
def single_tensor(files: List[str]) -> np.ndarray:
"""Load a single tensor from a numpy file."""
check_argument_types()
if len(files) == 1:
Expand All @@ -14,20 +14,20 @@ def single_tensor(files: List[str]):
return np.concatenate([np.load(f) for f in files], axis=0)


def from_file_list(prefix: str, default_tensor_name: str = "arr_0"):
"""Load list of numpy arrays according to a list of files.
def from_file_list(prefix: str,
default_tensor_name: str = "arr_0") -> Callable:
"""Load a list of numpy arrays from a list of .npz numpy files.
Args:
prefix: A common prefix of the files in lists of relative paths.
default_tensor_name: Key of the tensors to load in the npz files.
prefix: A common prefix for the files in the list.
default_tensor_name: Key of the tensors to load from the npz files.
Return:
A reader function that loads numpy arrays from files on path writen
path relatively to the given prefix.
Returns:
A generator function that yields the loaded arryas.
"""
check_argument_types()

def load(files: List[str]):
def load(files: List[str]) -> Iterable[np.ndarray]:
for list_file in files:
with open(list_file, encoding="utf-8") as f_list:
for line in f_list:
Expand All @@ -39,4 +39,4 @@ def load(files: List[str]):


# pylint: disable=invalid-name
numpy_file_list_reader = from_file_list("")
numpy_file_list_reader = from_file_list(prefix="")
55 changes: 0 additions & 55 deletions neuralmonkey/tests/test_encoders_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

from typing import Dict, List, Any, Iterable

from neuralmonkey.encoders.numpy_encoder import (VectorEncoder,
PostCNNImageEncoder)
from neuralmonkey.encoders.recurrent import SentenceEncoder
from neuralmonkey.encoders.sentence_cnn_encoder import SentenceCNNEncoder
from neuralmonkey.model.sequence import EmbeddedSequence
Expand Down Expand Up @@ -80,37 +78,6 @@
"use_noisy_activations": [None, SentenceEncoder]
}

VECTOR_ENCODER_GOOD = {
"name": ["vector_encoder"],
"dimension": [10],
"data_id": ["marmelade"],
"output_shape": [1, None, 100]
}

VECTOR_ENCODER_BAD = {
"nonexistent": ["ahoj"],
"name": [None, 1],
"dimension": [0, -1, "ahoj", 3.14, VOCABULARY, SentenceEncoder, None],
"data_id": [3.14, VOCABULARY, None],
"output_shape": [0, -1, "ahoj", 3.14, VOCABULARY]
}

POST_CNN_IMAGE_ENCODER_GOOD = {
"name": ["vector_encoder"],
"input_shape": [[1, 2, 3], [10, 20, 3]],
"output_shape": [10],
"data_id": ["marmelade"],
}

POST_CNN_IMAGE_ENCODER_BAD = {
"nonexistent": ["ahoj"],
"name": [None, 1],
"data_id": [3.14, VOCABULARY, None],
"output_shape": [0, -1, "hoj", 3.14, None, VOCABULARY, SentenceEncoder],
"input_shape": [3, [10, 20], [-1, 10, 20], "123", "ahoj", 3.14,
VOCABULARY, []]
}


def traverse_combinations(
params: Dict[str, List[Any]],
Expand Down Expand Up @@ -187,28 +154,6 @@ def test_sentence_cnn_encoder(self):
SENTENCE_CNN_ENCODER_GOOD,
SENTENCE_CNN_ENCODER_BAD)

def test_vector_encoder(self):
with self.assertRaises(Exception):
# pylint: disable=no-value-for-parameter
# on purpose, should fail
VectorEncoder()
# pylint: enable=no-value-for-parameter

self._run_constructors(VectorEncoder,
VECTOR_ENCODER_GOOD,
VECTOR_ENCODER_BAD)

def test_post_cnn_encoder(self):
with self.assertRaises(Exception):
# pylint: disable=no-value-for-parameter
# on purpose, should fail
PostCNNImageEncoder()
# pylint: enable=no-value-for-parameter

self._run_constructors(PostCNNImageEncoder,
POST_CNN_IMAGE_ENCODER_GOOD,
POST_CNN_IMAGE_ENCODER_BAD)


if __name__ == "__main__":
unittest.main()

0 comments on commit a1a8365

Please sign in to comment.