This repository has been archived by the owner on Aug 31, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 441
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #56 from google/categorical
Ref #55: Adding first implementation of categorical variables support
- Loading branch information
Showing
10 changed files
with
318 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
"""Implements preprocesing transformers for categorical variables.""" | ||
# Copyright 2015-present Scikit Flow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import math | ||
import numpy as np | ||
|
||
from skflow.preprocessing import categorical_vocabulary | ||
|
||
|
||
class CategoricalProcessor(object): | ||
"""Maps documents to sequences of word ids. | ||
As a common convention, Nan values are handled as unknown tokens. | ||
Both float('nan') and np.nan are accepted. | ||
Parameters: | ||
min_frequency: Minimum frequency of categories in the vocabulary. | ||
share: Share vocabulary between variables. | ||
vocabularies: list of CategoricalVocabulary objects for each variable in | ||
the input dataset. | ||
Attributes: | ||
vocabularies_: CategoricalVocabulary object. | ||
""" | ||
|
||
def __init__(self, min_frequency=0, share=False, vocabularies=None): | ||
self.min_frequency = min_frequency | ||
self.share = share | ||
self.vocabularies_ = vocabularies | ||
|
||
def freeze(self, freeze=True): | ||
"""Freeze or unfreeze all vocabularies. | ||
Args: | ||
freeze: Boolean, indicate if vocabularies should be frozen. | ||
""" | ||
for vocab in self.vocabularies_: | ||
vocab.freeze(freeze) | ||
|
||
def fit(self, X, unused_y=None): | ||
"""Learn a vocabulary dictionary of all categories in X. | ||
Args: | ||
raw_documents: numpy matrix or iterable of lists/numpy arrays. | ||
unused_y: to match fit format signature of estimators. | ||
Returns: | ||
self | ||
""" | ||
for row in X: | ||
# Create vocabularies if not given. | ||
if self.vocabularies_ is None: | ||
# If not share, one per column, else one shared across. | ||
if not self.share: | ||
self.vocabularies_ = [ | ||
categorical_vocabulary.CategoricalVocabulary() for _ in row] | ||
else: | ||
vocab = categorical_vocabulary.CategoricalVocabulary() | ||
self.vocabularies_ = [vocab for _ in row] | ||
for idx, value in enumerate(row): | ||
# Nans are handled as unknowns. | ||
if (isinstance(value, float) and math.isnan(value)) or value == np.nan: | ||
continue | ||
self.vocabularies_[idx].add(value) | ||
if self.min_frequency > 0: | ||
for vocab in self.vocabularies_: | ||
vocab.trim(self.min_frequency) | ||
self.freeze() | ||
return self | ||
|
||
def fit_transform(self, X, unused_y=None): | ||
"""Learn the vocabulary dictionary and return indexies of categories. | ||
Args: | ||
X: numpy matrix or iterable of lists/numpy arrays. | ||
unused_y: to match fit_transform signature of estimators. | ||
Returns: | ||
X: iterable, [n_samples]. Category-id matrix. | ||
""" | ||
self.fit(X) | ||
return self.transform(X) | ||
|
||
def transform(self, X): | ||
"""Transform documents to category-id matrix. | ||
Converts categories to ids give fitted vocabulary from `fit` or | ||
one provided in the constructor. | ||
Args: | ||
X: numpy matrix or iterable of lists/numpy arrays. | ||
Returns: | ||
X: iterable, [n_samples]. Category-id matrix. | ||
""" | ||
self.freeze() | ||
for row in X: | ||
output_row = [] | ||
for idx, value in enumerate(row): | ||
# Return <UNK> when it's Nan. | ||
if (isinstance(value, float) and math.isnan(value)) or value == np.nan: | ||
output_row.append(0) | ||
continue | ||
output_row.append(self.vocabularies_[idx].get(value)) | ||
yield np.array(output_row, dtype=np.int64) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
"""Categorical vocabulary classes to map categories to indexes. | ||
Can be used for categorical variables, sparse variables and words. | ||
""" | ||
|
||
# Copyright 2015-present Scikit Flow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import collections | ||
import six | ||
|
||
|
||
class CategoricalVocabulary(object): | ||
"""Categorical variables vocabulary class. | ||
Accumulates and provides mapping from classes to indexes. | ||
Can be easily used for words. | ||
""" | ||
|
||
def __init__(self, unknown_token='<UNK>'): | ||
self._mapping = {unknown_token: 0} | ||
self._freq = collections.defaultdict(int) | ||
self._freeze = False | ||
|
||
def __len__(self): | ||
return len(self._mapping) | ||
|
||
def freeze(self, freeze=True): | ||
"""Freezes the vocabulary, after which new words return unknown token id. | ||
Args: | ||
freeze: True to freeze, False to unfreeze. | ||
""" | ||
self._freeze = freeze | ||
|
||
def get(self, category): | ||
"""Returns word's id in the vocabulary. | ||
If category is new, creates a new id for it. | ||
Args: | ||
category: string or integer to lookup in vocabulary. | ||
Returns: | ||
interger, id in the vocabulary. | ||
""" | ||
if category not in self._mapping: | ||
if self._freeze: | ||
return 0 | ||
self._mapping[category] = len(self._mapping) | ||
return self._mapping[category] | ||
|
||
def add(self, category, count=1): | ||
"""Adds count of the category to the frequency table. | ||
Args: | ||
category: string or integer, category to add frequency to. | ||
count: optional integer, how many to add. | ||
""" | ||
category_id = self.get(category) | ||
if category_id <= 0: | ||
return | ||
self._freq[category] += count | ||
|
||
def trim(self, min_frequency, max_frequency=-1): | ||
"""Trims vocabulary for minimum frequency. | ||
Args: | ||
min_frequency: minimum frequency to keep. | ||
max_frequency: optional, maximum frequency to keep. | ||
Useful to remove very frequent categories (like stop words). | ||
""" | ||
for category, count in six.iteritems(self._freq): | ||
if count <= min_frequency and (max_frequency < 0 or | ||
count >= max_frequency): | ||
self._mapping.pop(category) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# encoding: utf-8 | ||
|
||
# Copyright 2015-present Scikit Flow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from skflow.preprocessing import categorical | ||
|
||
|
||
class CategoricalTest(tf.test.TestCase): | ||
|
||
def testSingleCategoricalProcessor(self): | ||
cat_processor = categorical.CategoricalProcessor( | ||
min_frequency=1) | ||
X = cat_processor.fit_transform( | ||
[["0"], [1], [float('nan')], | ||
["C"], ["C"], [1], ["0"], [np.nan], [3]]) | ||
self.assertAllEqual(list(X), [ | ||
[1], [2], [0], [3], | ||
[3], [2], [1], [0], | ||
[0]]) | ||
|
||
def testMultiCategoricalProcessor(self): | ||
cat_processor = categorical.CategoricalProcessor( | ||
min_frequency=0, share=False) | ||
x = cat_processor.fit_transform( | ||
[["0", "Male"], [1, "Female"], ["3", "Male"]]) | ||
self.assertAllEqual(list(x), [[1, 1], [2, 2], [3, 1]]) | ||
|
||
|
||
if __name__ == "__main__": | ||
tf.test.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# encoding: utf-8 | ||
|
||
# Copyright 2015-present Scikit Flow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import tensorflow as tf | ||
|
||
from skflow.preprocessing import categorical_vocabulary | ||
|
||
|
||
class CategoricalVocabularyTest(tf.test.TestCase): | ||
|
||
def testIntVocabulary(self): | ||
vocab = categorical_vocabulary.CategoricalVocabulary() | ||
self.assertEqual(vocab.get(1), 1) | ||
self.assertEqual(vocab.get(3), 2) | ||
self.assertEqual(vocab.get(2), 3) | ||
self.assertEqual(vocab.get(3), 2) | ||
self.assertEqual(vocab.get(float('nan')), 4) | ||
|
||
|
||
def testWordVocabulary(self): | ||
vocab = categorical_vocabulary.CategoricalVocabulary() | ||
self.assertEqual(vocab.get('a'), 1) | ||
self.assertEqual(vocab.get('b'), 2) | ||
self.assertEqual(vocab.get('a'), 1) | ||
self.assertEqual(vocab.get('b'), 2) | ||
|
||
|
||
if __name__ == "__main__": | ||
tf.test.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.