Skip to content

Commit

Permalink
feat: add arg index_elements to ColumnsIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
pckhoi committed Nov 23, 2021
1 parent 65168e1 commit e86ba47
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 5 deletions.
25 changes: 20 additions & 5 deletions datamatch/indices.py
Expand Up @@ -6,6 +6,7 @@
import operator
import functools
import itertools
from collections.abc import Iterable
from typing import Type
from abc import ABC, abstractmethod

Expand Down Expand Up @@ -91,17 +92,22 @@ class ColumnsIndex(BaseIndex):
"""Split data into multiple buckets based on one or more columns.
"""

def __init__(self, cols: str or list[str], ignore_key_error: bool = False) -> None:
def __init__(self, cols: str or list[str], ignore_key_error: bool = False, index_elements: bool = False) -> None:
"""
:param cols: single column name or list of column names to index.
:type cols: :obj:`str` or :obj:`list` of :obj:`str`
:param ignore_key_error: When set to True and a column does not exist in the frame, produce no
buckets instead of raising a KeyError.
:type ignore_key_error: :obj:`bool`
:param index_elements: Set this to True when each value in the column to index is a list, and
you want to index using the list elements.
:type index_elements: :obj:`bool`
"""
super().__init__()
self._ignore_key_error = ignore_key_error
self._index_elements = index_elements
if type(cols) is str:
self._cols = [cols]
else:
Expand All @@ -111,10 +117,19 @@ def _key_ind_map(self, df: pd.DataFrame) -> dict:
result = dict()
try:
for idx, row in df.iterrows():
key = tuple(
row[col] for col in self._cols
)
result.setdefault(key, list()).append(idx)
if self._index_elements:
for col in self._cols:
if not isinstance(row[col], Iterable):
raise ValueError('column %s at row %s is not iterable: %s' % (
col, row.name, row[col]
))
for key in itertools.product(*list(row[col] for col in self._cols)):
result.setdefault(key, list()).append(idx)
else:
key = tuple(
row[col] for col in self._cols
)
result.setdefault(key, list()).append(idx)
for l in result.values():
l.sort()
except KeyError:
Expand Down
59 changes: 59 additions & 0 deletions datamatch/test_indices.py
@@ -1,5 +1,6 @@
import unittest
import pandas as pd
from pandas.core.indexes.range import RangeIndex
from pandas.testing import assert_frame_equal
import itertools

Expand Down Expand Up @@ -77,6 +78,64 @@ def test_ignore_key_error(self):
self.assertEqual(ColumnsIndex(
'c', ignore_key_error=True).keys(df), set())

def test_index_elements(self):
    """Check single-column indexing with ``index_elements=True``.

    ``col1`` holds lists; every element of a row's list is expected to
    become its own one-element key, and a row is expected to appear in
    the bucket of each element it contains.
    """
    columns = ['col1', 'col2']
    frame = pd.DataFrame(
        [
            [['a', 'b'], 'q'],
            [['c'], 'w'],
            [['b'], 'e'],
        ],
        index=RangeIndex(start=0, stop=3),
        columns=columns,
    )
    index = ColumnsIndex('col1', index_elements=True)

    # One key per distinct list element found in col1.
    self.assertEqual(index.keys(frame), {('a',), ('b',), ('c',)})

    # 'a' occurs only in row 0.
    expected_a = pd.DataFrame(
        [[['a', 'b'], 'q']],
        index=[0],
        columns=columns,
    )
    assert_frame_equal(index.bucket(frame, ('a',)), expected_a)

    # 'b' occurs in rows 0 and 2, so both rows share that bucket.
    expected_b = pd.DataFrame(
        [
            [['a', 'b'], 'q'],
            [['b'], 'e'],
        ],
        index=[0, 2],
        columns=columns,
    )
    assert_frame_equal(index.bucket(frame, ('b',)), expected_b)

def test_index_elements_multi_columns(self):
    """Check multi-column indexing with ``index_elements=True``.

    ``col1`` and ``col3`` both hold lists. Each row is expected to be
    bucketed under every combination of one element from ``col1`` and
    one element from ``col3``; ``col2`` is not indexed.
    """
    cols = ['col1', 'col2', 'col3']
    df = pd.DataFrame(
        [
            [['a', 'b'], 'q', [1]],
            [['c'], 'w', [2, 3]],
            [['b'], 'e', [1]],
        ],
        index=RangeIndex(start=0, stop=3),
        columns=cols,
    )
    idx = ColumnsIndex(['col1', 'col3'], index_elements=True)

    # Rows 0 and 2 both yield ('b', 1); as a set member it is listed once.
    self.assertEqual(
        idx.keys(df),
        {('a', 1), ('b', 1), ('c', 2), ('c', 3)},
    )

    # ('a', 1) is produced by row 0 only.
    assert_frame_equal(
        idx.bucket(df, ('a', 1)),
        pd.DataFrame(
            [
                [['a', 'b'], 'q', [1]],
            ],
            index=[0],
            columns=cols,
        ),
    )

    # ('b', 1) is produced by rows 0 and 2.
    assert_frame_equal(
        idx.bucket(df, ('b', 1)),
        pd.DataFrame(
            [
                [['a', 'b'], 'q', [1]],
                [['b'], 'e', [1]],
            ],
            index=[0, 2],
            columns=cols,
        ),
    )


class MultiIndexTestCase(BaseIndexTestCase):
def test_index(self):
Expand Down

0 comments on commit e86ba47

Please sign in to comment.