Skip to content
This repository has been archived by the owner on Aug 31, 2021. It is now read-only.

Commit

Permalink
Merge pull request #56 from google/categorical
Browse files Browse the repository at this point in the history
Ref #55: Adding first implementation of categorical variables support
  • Loading branch information
ilblackdragon committed Dec 30, 2015
2 parents 0c400ff + 0e7fd81 commit cc349fc
Show file tree
Hide file tree
Showing 10 changed files with 318 additions and 78 deletions.
8 changes: 6 additions & 2 deletions skflow/data_feeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@

def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size):
"""Returns shape for input and output of the data feeder."""
input_shape = [batch_size] + list(x_shape[1:])
x_shape = list(x_shape[1:]) if len(x_shape) > 1 else [1]
input_shape = [batch_size] + x_shape
y_shape = list(y_shape[1:]) if len(y_shape) > 1 else []
# Skip first dimention if it is 1.
if y_shape and y_shape[0] == 1:
Expand Down Expand Up @@ -92,7 +93,10 @@ def _feed_dict_fn():
out = np.zeros(self.output_shape, dtype=self.output_dtype)
for i in xrange(self.batch_size):
sample = self.random_state.randint(0, self.X.shape[0])
inp[i, :] = self.X[sample, :]
if len(self.X.shape) == 1:
inp[i, :] = [self.X[sample]]
else:
inp[i, :] = self.X[sample, :]
if self.n_classes > 1:
if len(self.output_shape) == 2:
out.itemset((i, self.y[sample]), 1.0)
Expand Down
3 changes: 3 additions & 0 deletions skflow/estimators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import datetime
from six import string_types

import numpy as np
import tensorflow as tf

from google.protobuf import text_format
Expand Down Expand Up @@ -226,6 +227,8 @@ def _predict(self, X):
raise NotFittedError()
if HAS_PANDAS:
X = extract_pandas_data(X)
if len(X.shape) == 1:
X = np.reshape(X, (-1, 1))
pred = self._session.run(self._model_predictions,
feed_dict={
self._inp.name: X
Expand Down
1 change: 1 addition & 0 deletions skflow/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
# limitations under the License.

from skflow.preprocessing.text import *
from skflow.preprocessing.categorical import *
118 changes: 118 additions & 0 deletions skflow/preprocessing/categorical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Implements preprocesing transformers for categorical variables."""
# Copyright 2015-present Scikit Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import numpy as np

from skflow.preprocessing import categorical_vocabulary


class CategoricalProcessor(object):
"""Maps documents to sequences of word ids.
As a common convention, Nan values are handled as unknown tokens.
Both float('nan') and np.nan are accepted.
Parameters:
min_frequency: Minimum frequency of categories in the vocabulary.
share: Share vocabulary between variables.
vocabularies: list of CategoricalVocabulary objects for each variable in
the input dataset.
Attributes:
vocabularies_: CategoricalVocabulary object.
"""

def __init__(self, min_frequency=0, share=False, vocabularies=None):
self.min_frequency = min_frequency
self.share = share
self.vocabularies_ = vocabularies

def freeze(self, freeze=True):
"""Freeze or unfreeze all vocabularies.
Args:
freeze: Boolean, indicate if vocabularies should be frozen.
"""
for vocab in self.vocabularies_:
vocab.freeze(freeze)

def fit(self, X, unused_y=None):
"""Learn a vocabulary dictionary of all categories in X.
Args:
raw_documents: numpy matrix or iterable of lists/numpy arrays.
unused_y: to match fit format signature of estimators.
Returns:
self
"""
for row in X:
# Create vocabularies if not given.
if self.vocabularies_ is None:
# If not share, one per column, else one shared across.
if not self.share:
self.vocabularies_ = [
categorical_vocabulary.CategoricalVocabulary() for _ in row]
else:
vocab = categorical_vocabulary.CategoricalVocabulary()
self.vocabularies_ = [vocab for _ in row]
for idx, value in enumerate(row):
# Nans are handled as unknowns.
if (isinstance(value, float) and math.isnan(value)) or value == np.nan:
continue
self.vocabularies_[idx].add(value)
if self.min_frequency > 0:
for vocab in self.vocabularies_:
vocab.trim(self.min_frequency)
self.freeze()
return self

def fit_transform(self, X, unused_y=None):
"""Learn the vocabulary dictionary and return indexies of categories.
Args:
X: numpy matrix or iterable of lists/numpy arrays.
unused_y: to match fit_transform signature of estimators.
Returns:
X: iterable, [n_samples]. Category-id matrix.
"""
self.fit(X)
return self.transform(X)

def transform(self, X):
"""Transform documents to category-id matrix.
Converts categories to ids give fitted vocabulary from `fit` or
one provided in the constructor.
Args:
X: numpy matrix or iterable of lists/numpy arrays.
Returns:
X: iterable, [n_samples]. Category-id matrix.
"""
self.freeze()
for row in X:
output_row = []
for idx, value in enumerate(row):
# Return <UNK> when it's Nan.
if (isinstance(value, float) and math.isnan(value)) or value == np.nan:
output_row.append(0)
continue
output_row.append(self.vocabularies_[idx].get(value))
yield np.array(output_row, dtype=np.int64)

88 changes: 88 additions & 0 deletions skflow/preprocessing/categorical_vocabulary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Categorical vocabulary classes to map categories to indexes.
Can be used for categorical variables, sparse variables and words.
"""

# Copyright 2015-present Scikit Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import six


class CategoricalVocabulary(object):
"""Categorical variables vocabulary class.
Accumulates and provides mapping from classes to indexes.
Can be easily used for words.
"""

def __init__(self, unknown_token='<UNK>'):
self._mapping = {unknown_token: 0}
self._freq = collections.defaultdict(int)
self._freeze = False

def __len__(self):
return len(self._mapping)

def freeze(self, freeze=True):
"""Freezes the vocabulary, after which new words return unknown token id.
Args:
freeze: True to freeze, False to unfreeze.
"""
self._freeze = freeze

def get(self, category):
"""Returns word's id in the vocabulary.
If category is new, creates a new id for it.
Args:
category: string or integer to lookup in vocabulary.
Returns:
interger, id in the vocabulary.
"""
if category not in self._mapping:
if self._freeze:
return 0
self._mapping[category] = len(self._mapping)
return self._mapping[category]

def add(self, category, count=1):
"""Adds count of the category to the frequency table.
Args:
category: string or integer, category to add frequency to.
count: optional integer, how many to add.
"""
category_id = self.get(category)
if category_id <= 0:
return
self._freq[category] += count

def trim(self, min_frequency, max_frequency=-1):
"""Trims vocabulary for minimum frequency.
Args:
min_frequency: minimum frequency to keep.
max_frequency: optional, maximum frequency to keep.
Useful to remove very frequent categories (like stop words).
"""
for category, count in six.iteritems(self._freq):
if count <= min_frequency and (max_frequency < 0 or
count >= max_frequency):
self._mapping.pop(category)

45 changes: 45 additions & 0 deletions skflow/preprocessing/tests/test_categorical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# encoding: utf-8

# Copyright 2015-present Scikit Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import tensorflow as tf

from skflow.preprocessing import categorical


class CategoricalTest(tf.test.TestCase):

def testSingleCategoricalProcessor(self):
cat_processor = categorical.CategoricalProcessor(
min_frequency=1)
X = cat_processor.fit_transform(
[["0"], [1], [float('nan')],
["C"], ["C"], [1], ["0"], [np.nan], [3]])
self.assertAllEqual(list(X), [
[1], [2], [0], [3],
[3], [2], [1], [0],
[0]])

def testMultiCategoricalProcessor(self):
cat_processor = categorical.CategoricalProcessor(
min_frequency=0, share=False)
x = cat_processor.fit_transform(
[["0", "Male"], [1, "Female"], ["3", "Male"]])
self.assertAllEqual(list(x), [[1, 1], [2, 2], [3, 1]])


if __name__ == "__main__":
tf.test.main()
42 changes: 42 additions & 0 deletions skflow/preprocessing/tests/test_categorical_vocabulary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# encoding: utf-8

# Copyright 2015-present Scikit Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

from skflow.preprocessing import categorical_vocabulary


class CategoricalVocabularyTest(tf.test.TestCase):

def testIntVocabulary(self):
vocab = categorical_vocabulary.CategoricalVocabulary()
self.assertEqual(vocab.get(1), 1)
self.assertEqual(vocab.get(3), 2)
self.assertEqual(vocab.get(2), 3)
self.assertEqual(vocab.get(3), 2)
self.assertEqual(vocab.get(float('nan')), 4)


def testWordVocabulary(self):
vocab = categorical_vocabulary.CategoricalVocabulary()
self.assertEqual(vocab.get('a'), 1)
self.assertEqual(vocab.get('b'), 2)
self.assertEqual(vocab.get('a'), 1)
self.assertEqual(vocab.get('b'), 2)


if __name__ == "__main__":
tf.test.main()
7 changes: 0 additions & 7 deletions skflow/preprocessing/tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,6 @@ def testByteProcessor(self):
[97, 98, 99, 0, 0, 0, 0, 0],
[49, 50, 51, 52, 53, 54, 55, 56]])

def testWordVocabulary(self):
vocab = text.WordVocabulary()
self.assertEqual(vocab.get('a'), 1)
self.assertEqual(vocab.get('b'), 2)
self.assertEqual(vocab.get('a'), 1)
self.assertEqual(vocab.get('b'), 2)

def testVocabularyProcessor(self):
vocab_processor = text.VocabularyProcessor(
max_document_length=4,
Expand Down
Loading

0 comments on commit cc349fc

Please sign in to comment.