Skip to content

Commit

Permalink
Merge pull request #3 from petrux/dev
Browse files Browse the repository at this point in the history
Merging from dev branch
  • Loading branch information
petrux committed Apr 20, 2017
2 parents b23da2e + f669d74 commit aed41e0
Show file tree
Hide file tree
Showing 6 changed files with 290 additions and 3 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
*.pyc
.vscode
.vscode
.venv
2 changes: 1 addition & 1 deletion liteflow/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def test_timeslice(self):
sess.run(tf.global_variables_initializer())
outputs_actual = sess.run(outputs, {tensor: tensor_actual, indices: indices_actual})
self.assertAllClose(outputs_expected, outputs_actual)
print outputs_actual


if __name__ == '__main__':
tf.test.main()
121 changes: 121 additions & 0 deletions liteflow/tests/test_vocabulary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Test module for the `liteflow.vocabulary` module."""

import unittest

import mock

from liteflow import vocabulary


class _BaseVocabulary(vocabulary.BaseVocabulary):
def contains(self, word): pass # pylint: disable=I0011,C0321
def index(self, word): pass # pylint: disable=I0011,C0321
def word(self, index): pass # pylint: disable=I0011,C0321
def size(self): pass # pylint: disable=I0011,C0321
def items(self): pass # pylint: disable=I0011,C0321


class BaseVocabularyTest(unittest.TestCase):
"""Test case for the `liteflow.vocabulary.BaseVocabulary` contract."""

@mock.patch.object(_BaseVocabulary, 'contains')
def test_contains(self, contains):
"""Test that __contains__ bounces on contains()."""

vocab = _BaseVocabulary()
self.assertEquals(0, contains.call_count)

arg = 23
_ = arg in vocab
self.assertEquals(1, contains.call_count)
contains.assert_called_with(arg)

arg = object()
_ = arg in vocab
self.assertEquals(2, contains.call_count)
contains.assert_called_with(arg)

@mock.patch.object(_BaseVocabulary, 'size')
def test_size(self, size):
"""Test that __len__ bounces on size()."""

vocab = _BaseVocabulary()
self.assertEquals(0, size.call_count)

_ = len(vocab)
self.assertEquals(1, size.call_count)

_ = len(vocab)
self.assertEquals(2, size.call_count)

@mock.patch.object(_BaseVocabulary, 'items')
def test_items(self, items):
"""Test that __iter__ bounces on items()."""

vocab = _BaseVocabulary()
items.return_value = iter([])
self.assertEquals(0, items.call_count)
_ = iter(vocab)
self.assertEquals(1, items.call_count)


class InMemoryVocabularyTest(unittest.TestCase):
"""Test case for the `liteflow.vocabulary.InMemoryVocabulary` class."""

def test_empty(self):
"""Test the empty vocabulary."""
vocab = vocabulary.InMemoryVocabulary()
self.assertEquals(0, len(vocab))

def test_base(self):
"""Test the basic functionalities of the vocabulary."""

words = 'A B C X Y Z'.split()
vocab = vocabulary.InMemoryVocabulary()

for i, word in enumerate(words):
self.assertFalse(word in vocab)
self.assertEquals(i, vocab.add(word))
self.assertTrue(word in vocab)
self.assertEquals(i, vocab.index(word))
self.assertEquals(word, vocab.word(i))
self.assertEquals(i + 1, len(vocab))

def test_oov_words(self):
"""Test out-of-vocabulary words."""

unk = 'Q'
words = 'A B C X Y Z'.split()
vocab = vocabulary.InMemoryVocabulary()
for word in words:
vocab.add(word)

self.assertFalse(unk in vocab)
self.assertRaises(ValueError, lambda: vocab.index(unk))

def test_oov_indexes(self):
"""Test out-of-vocabulary indexes."""

words = 'A B C X Y Z'.split()
vocab = vocabulary.InMemoryVocabulary()
for word in words:
vocab.add(word)

for index, word in enumerate(words):
self.assertEquals(word, vocab.word(index))
self.assertRaises(ValueError, lambda: vocab.word(-1))
self.assertRaises(ValueError, lambda: vocab.word(len(words)))

def test_add_twice(self):
"""Test adding a word twice."""

word = 'WORD'
vocab = vocabulary.InMemoryVocabulary()
index = vocab.add(word)
self.assertEquals(1, vocab.size())
self.assertEquals(index, vocab.add(word))
self.assertEquals(1, vocab.size())


if __name__ == '__main__':
unittest.main()
153 changes: 153 additions & 0 deletions liteflow/vocabulary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""Vocabularies for managing symbols encoding/decoding.
In this module, a base class `BaseVocabulary` defines the main API for a vocabulary
that maps words to integer values and viceversa. Such poblem is very common dealing
with natural language processing tasks, so a shared API and some already implemented
code can be handy sometimes.
"""

import abc


class BaseVocabulary(object):
"""Base vocabulary read-only interface.
The `BaseVocabulary` abstract class provides the base interface
for a vocabulary containing a set of |V| words and their corresponding
index, an integer value ranging from 0 to |V|-1.
"""

__metaclass__ = abc.ABCMeta

@abc.abstractmethod
def contains(self, word):
"""Check if a word is contained in the vocabulary.
Arguments:
word: a word from the vocabulary.
Returns:
`True` if the word is contained in the vocabulary,
`False` otherwise.
"""
raise NotImplementedError(
"""The abstract method `contains(self, word)` """
"""must be implemented in subclasses.""")

@abc.abstractmethod
def index(self, word):
"""Get the index value for the given word.
Arguments:
word: a word from the vocabulary.
Returns:
an `int` representing the index value of the
given word.
Raises:
ValueError: if the word is not in the vocabulary.
"""
raise NotImplementedError(
"""The abstract method `index(self, word)` """
"""must be implemented in subclasses.""")

@abc.abstractmethod
def word(self, index):
"""Get the word for the given index.
Arguments:
index: an `int` representing an index value for a word.
Returns:
the word corresponding to the given index value.
Raises:
ValueError: if the index value is not between 0 and the
number of words contained in the vocabulary minus 1.
"""
raise NotImplementedError(
"""The abstract method `word(self, index)` """
"""must be implemented in subclasses.""")

@abc.abstractmethod
def size(self):
"""Get the number of words in the vocabulary."""
raise NotImplementedError(
"""The abstract method `size(self)` """
"""must be implemented in subclasses.""")

@abc.abstractmethod
def items(self):
"""Return an iterator over the pairs (index, word)."""
raise NotImplementedError(
"""The abstract method `items(self)` """
"""must be implemented in subclasses.""")

def __contains__(self, item):
"""Magic method wrapping the `contains()` method."""
return self.contains(item)

def __len__(self):
"""Magic method wrapping the `size()` method."""
return self.size()

def __iter__(self):
"""Magic method wrapping the `items()` method."""
return self.items()


class InMemoryVocabulary(BaseVocabulary):
"""In-memory implementation of the BaseVocabulary contract.
The InMemoryVocabulary class holds in-memory data structure to
extend the BaseVocabulary superclass. All the access operations
are ensured to have O(1) time complexity.
"""

def __init__(self):
self._index = {}
self._words = []

def contains(self, word):
return word in self._index

def index(self, word):
if word in self._index:
return self._index[word]
raise ValueError('Word \'%s\' is not in the vocabulary.' % word)

def word(self, index):
if index < 0 or index >= len(self._words):
raise ValueError('Index must be between 0 and %d, found %d instead'
% (len(self._words) - 1, index))
return self._words[index]

def size(self):
return len(self._words)

def items(self):
return enumerate(self._words)

def add(self, word):
"""Add a new word to the vocabulary.
Arguments:
word: a word to be added to the vocabulary.
Returns:
The index of the word.
Remarks:
if the word is already in the vocabulary, it is not added twice
and its current index value will be returned.
"""

if word in self._words:
return self._index[word]

index = len(self._words)
self._words.append(word)
self._index[word] = index
return index

12 changes: 12 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
appdirs==1.4.3
funcsigs==1.0.2
mock==2.0.0
numpy==1.12.1
packaging==16.8
pbr==2.1.0
pkg-resources==0.0.0
protobuf==3.2.0
pyparsing==2.2.0
pyspark==2.1.0+hadoop2.7
six==1.10.0
tensorflow==1.0.1
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(name='liteflow',
version='0.1',
description='Liteweight TensorFlow extensio library',
description='Liteweight TensorFlow extension library',
url='https://github.com/petrux/LiTeFlow',
author='Giulio Petrucci (petrux)',
author_email='giulio.petrucci@gmail.com',
Expand Down

0 comments on commit aed41e0

Please sign in to comment.