-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from petrux/dev
Merging from dev branch
- Loading branch information
Showing
6 changed files
with
290 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
*.pyc | ||
.vscode | ||
.vscode | ||
.venv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
"""Test module for the `liteflow.vocabulary` module.""" | ||
|
||
import unittest | ||
|
||
import mock | ||
|
||
from liteflow import vocabulary | ||
|
||
|
||
class _BaseVocabulary(vocabulary.BaseVocabulary):
    """Concrete do-nothing subclass of the abstract `BaseVocabulary`.

    Every abstract method is overridden with a stub returning `None`, so
    that the class can be instantiated and its methods patched in tests.
    """

    def contains(self, word):
        return None

    def index(self, word):
        return None

    def word(self, index):
        return None

    def size(self):
        return None

    def items(self):
        return None
|
||
|
||
class BaseVocabularyTest(unittest.TestCase):
    """Test case for the `liteflow.vocabulary.BaseVocabulary` contract.

    Each test patches one method of the `_BaseVocabulary` stub and checks
    that the corresponding magic method of the base class delegates to it.
    """

    @mock.patch.object(_BaseVocabulary, 'contains')
    def test_contains(self, contains):
        """Test that __contains__ bounces on contains()."""

        vocab = _BaseVocabulary()
        self.assertEqual(0, contains.call_count)

        arg = 23
        _ = arg in vocab
        self.assertEqual(1, contains.call_count)
        contains.assert_called_with(arg)

        arg = object()
        _ = arg in vocab
        self.assertEqual(2, contains.call_count)
        contains.assert_called_with(arg)

    @mock.patch.object(_BaseVocabulary, 'size')
    def test_size(self, size):
        """Test that __len__ bounces on size()."""

        vocab = _BaseVocabulary()
        # `len()` requires `__len__` to produce an integer, so give the mock
        # an explicit integer result rather than relying on whatever the
        # default mock return value happens to convert to.
        size.return_value = 0
        self.assertEqual(0, size.call_count)

        _ = len(vocab)
        self.assertEqual(1, size.call_count)

        _ = len(vocab)
        self.assertEqual(2, size.call_count)

    @mock.patch.object(_BaseVocabulary, 'items')
    def test_items(self, items):
        """Test that __iter__ bounces on items()."""

        vocab = _BaseVocabulary()
        # `iter()` requires `__iter__` to return an actual iterator.
        items.return_value = iter([])
        self.assertEqual(0, items.call_count)
        _ = iter(vocab)
        self.assertEqual(1, items.call_count)
|
||
|
||
class InMemoryVocabularyTest(unittest.TestCase):
    """Test case for the `liteflow.vocabulary.InMemoryVocabulary` class."""

    def test_empty(self):
        """Test the empty vocabulary."""
        vocab = vocabulary.InMemoryVocabulary()
        self.assertEqual(0, len(vocab))

    def test_base(self):
        """Test the basic functionalities of the vocabulary."""

        words = 'A B C X Y Z'.split()
        vocab = vocabulary.InMemoryVocabulary()

        # Words are expected to receive consecutive indexes in insertion
        # order, starting from 0.
        for i, word in enumerate(words):
            self.assertFalse(word in vocab)
            self.assertEqual(i, vocab.add(word))
            self.assertTrue(word in vocab)
            self.assertEqual(i, vocab.index(word))
            self.assertEqual(word, vocab.word(i))
            self.assertEqual(i + 1, len(vocab))

    def test_oov_words(self):
        """Test out-of-vocabulary words."""

        unk = 'Q'
        words = 'A B C X Y Z'.split()
        vocab = vocabulary.InMemoryVocabulary()
        for word in words:
            vocab.add(word)

        self.assertFalse(unk in vocab)
        self.assertRaises(ValueError, vocab.index, unk)

    def test_oov_indexes(self):
        """Test out-of-vocabulary indexes."""

        words = 'A B C X Y Z'.split()
        vocab = vocabulary.InMemoryVocabulary()
        for word in words:
            vocab.add(word)

        for index, word in enumerate(words):
            self.assertEqual(word, vocab.word(index))
        # Both ends of the valid range [0, len(words) - 1] are overstepped.
        self.assertRaises(ValueError, vocab.word, -1)
        self.assertRaises(ValueError, vocab.word, len(words))

    def test_add_twice(self):
        """Test adding a word twice."""

        word = 'WORD'
        vocab = vocabulary.InMemoryVocabulary()
        index = vocab.add(word)
        self.assertEqual(1, vocab.size())
        self.assertEqual(index, vocab.add(word))
        self.assertEqual(1, vocab.size())
|
||
|
||
# Run the test cases of this module when it is executed as a script.
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
"""Vocabularies for managing symbols encoding/decoding. | ||
In this module, a base class `BaseVocabulary` defines the main API for a vocabulary | ||
that maps words to integer values and vice versa. Such a problem is very common when dealing | ||
with natural language processing tasks, so a shared API and some already implemented | ||
code can be handy sometimes. | ||
""" | ||
|
||
import abc | ||
|
||
|
||
# The base class is created by calling `abc.ABCMeta` directly so that the
# metaclass is applied on both Python 2 and Python 3. A `__metaclass__` class
# attribute is only honoured by Python 2; on Python 3 it is a plain attribute
# and the `@abc.abstractmethod` markers would not be enforced.
class BaseVocabulary(abc.ABCMeta('_AbstractBase', (object,), {})):
    """Base vocabulary read-only interface.

    The `BaseVocabulary` abstract class provides the base interface
    for a vocabulary containing a set of |V| words and their corresponding
    index, an integer value ranging from 0 to |V|-1.

    Subclasses must implement `contains()`, `index()`, `word()`, `size()`
    and `items()`. The `in` operator, `len()` and `iter()` are then
    supported through `__contains__`, `__len__` and `__iter__`, which
    delegate to `contains()`, `size()` and `items()` respectively.
    """

    @abc.abstractmethod
    def contains(self, word):
        """Check if a word is contained in the vocabulary.

        Arguments:
          word: a word from the vocabulary.

        Returns:
          `True` if the word is contained in the vocabulary,
          `False` otherwise.
        """
        raise NotImplementedError(
            'The abstract method `contains(self, word)` '
            'must be implemented in subclasses.')

    @abc.abstractmethod
    def index(self, word):
        """Get the index value for the given word.

        Arguments:
          word: a word from the vocabulary.

        Returns:
          an `int` representing the index value of the
          given word.

        Raises:
          ValueError: if the word is not in the vocabulary.
        """
        raise NotImplementedError(
            'The abstract method `index(self, word)` '
            'must be implemented in subclasses.')

    @abc.abstractmethod
    def word(self, index):
        """Get the word for the given index.

        Arguments:
          index: an `int` representing an index value for a word.

        Returns:
          the word corresponding to the given index value.

        Raises:
          ValueError: if the index value is not between 0 and the
            number of words contained in the vocabulary minus 1.
        """
        raise NotImplementedError(
            'The abstract method `word(self, index)` '
            'must be implemented in subclasses.')

    @abc.abstractmethod
    def size(self):
        """Get the number of words in the vocabulary."""
        raise NotImplementedError(
            'The abstract method `size(self)` '
            'must be implemented in subclasses.')

    @abc.abstractmethod
    def items(self):
        """Return an iterator over the pairs (index, word)."""
        raise NotImplementedError(
            'The abstract method `items(self)` '
            'must be implemented in subclasses.')

    def __contains__(self, item):
        """Magic method wrapping the `contains()` method."""
        return self.contains(item)

    def __len__(self):
        """Magic method wrapping the `size()` method."""
        return self.size()

    def __iter__(self):
        """Magic method wrapping the `items()` method."""
        return self.items()
|
||
|
||
class InMemoryVocabulary(BaseVocabulary):
    """In-memory implementation of the BaseVocabulary contract.

    The InMemoryVocabulary class holds two in-memory data structures to
    extend the BaseVocabulary superclass: a dictionary mapping each word to
    its index and a list mapping each index back to its word. Every
    operation is a single dictionary or list access, so none of them scans
    the whole vocabulary.
    """

    def __init__(self):
        self._index = {}  # word -> index
        self._words = []  # index -> word

    def contains(self, word):
        """Return `True` if the word is in the vocabulary, else `False`."""
        return word in self._index

    def index(self, word):
        """Get the index value for the given word.

        Arguments:
          word: a word from the vocabulary.

        Returns:
          the `int` index assigned to the word when it was added.

        Raises:
          ValueError: if the word is not in the vocabulary.
        """
        if word in self._index:
            return self._index[word]
        # The argument is wrapped in a 1-tuple so that a tuple-valued word
        # is formatted as a single value instead of being unpacked by `%`.
        raise ValueError('Word \'%s\' is not in the vocabulary.' % (word,))

    def word(self, index):
        """Get the word for the given index.

        Arguments:
          index: an `int` representing an index value for a word.

        Returns:
          the word corresponding to the given index value.

        Raises:
          ValueError: if the index value is not between 0 and the
            number of words contained in the vocabulary minus 1.
        """
        # Negative values are rejected explicitly: plain list indexing would
        # otherwise accept them and count from the end of the list.
        if index < 0 or index >= len(self._words):
            raise ValueError('Index must be between 0 and %d, found %d instead'
                             % (len(self._words) - 1, index))
        return self._words[index]

    def size(self):
        """Get the number of words in the vocabulary."""
        return len(self._words)

    def items(self):
        """Return an iterator over the pairs (index, word)."""
        return enumerate(self._words)

    def add(self, word):
        """Add a new word to the vocabulary.

        Arguments:
          word: a word to be added to the vocabulary.

        Returns:
          The index of the word.

        Remarks:
          if the word is already in the vocabulary, it is not added twice
          and its current index value will be returned.
        """
        # Membership is checked against the dictionary, not the list: this
        # avoids a linear scan and makes an unhashable word fail here,
        # before either structure has been modified.
        if word in self._index:
            return self._index[word]

        index = len(self._words)
        self._words.append(word)
        self._index[word] = index
        return index
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
appdirs==1.4.3 | ||
funcsigs==1.0.2 | ||
mock==2.0.0 | ||
numpy==1.12.1 | ||
packaging==16.8 | ||
pbr==2.1.0 | ||
pkg-resources==0.0.0 | ||
protobuf==3.2.0 | ||
pyparsing==2.2.0 | ||
pyspark==2.1.0+hadoop2.7 | ||
six==1.10.0 | ||
tensorflow==1.0.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters