Skip to content

Commit

Permalink
Merge pull request #488 from onyxfish/89
Browse files (browse the repository at this point in the history)
Add test for FilteringCSVReader with any_match argument, closes #89
  • Loading branch information
James McKinney committed Jan 22, 2016
2 parents a89715e + fbefdf0 commit 0ae1294
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 34 deletions.
17 changes: 11 additions & 6 deletions csvkit/grep.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,18 @@ def __next__(self):

     def test_row(self, row):
         for idx, test in self.patterns.items():
-            if self.any_match and test(row[idx]):
-                return not self.inverse # True
-
-            if not self.any_match and not test(row[idx]):
-                return self.inverse # False
+            result = test(row[idx])
+            if self.any_match:
+                if result:
+                    return not self.inverse # True
+            else:
+                if not result:
+                    return self.inverse # False

-        return not self.inverse # True
+        if self.any_match:
+            return self.inverse # False
+        else:
+            return not self.inverse # True

def standardize_patterns(column_names, patterns):
"""
Expand Down
78 changes: 50 additions & 28 deletions tests/test_grep.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from csvkit.grep import FilteringCSVReader
from csvkit.exceptions import ColumnIdentifierError


class TestGrep(unittest.TestCase):
def setUp(self):
self.tab1 = [
Expand All @@ -24,23 +25,23 @@ def setUp(self):
[u'1', u'first', u'0'],
[u'4', u'only', u'0'],
[u'1', u'second', u'0'],
[u'2', u'only', u'0', u'0']] # Note extra value in this column
[u'2', u'only', u'0', u'0']] # Note extra value in this column

def test_pattern(self):
fcr = FilteringCSVReader(iter(self.tab1),patterns=['1'])
self.assertEqual(self.tab1[0],next(fcr))
self.assertEqual(self.tab1[1],next(fcr))
self.assertEqual(self.tab1[4],next(fcr))
fcr = FilteringCSVReader(iter(self.tab1), patterns=['1'])
self.assertEqual(self.tab1[0], next(fcr))
self.assertEqual(self.tab1[1], next(fcr))
self.assertEqual(self.tab1[4], next(fcr))
try:
next(fcr)
self.fail("Should be no more rows left.")
except StopIteration:
pass

def test_no_header(self):
fcr = FilteringCSVReader(iter(self.tab1),patterns={ 2: 'only' },header=False)
self.assertEqual(self.tab1[2],next(fcr))
self.assertEqual(self.tab1[3],next(fcr))
fcr = FilteringCSVReader(iter(self.tab1), patterns={2: 'only'}, header=False)
self.assertEqual(self.tab1[2], next(fcr))
self.assertEqual(self.tab1[3], next(fcr))
try:
next(fcr)
self.fail("Should be no more rows left.")
Expand All @@ -49,44 +50,44 @@ def test_no_header(self):

def test_regex(self):
pattern = re.compile(".*(Reader|Tribune).*")
fcr = FilteringCSVReader(iter(self.tab1),patterns = { 1: pattern })
self.assertEqual(self.tab1[0],next(fcr))
self.assertEqual(self.tab1[1],next(fcr))
self.assertEqual(self.tab1[3],next(fcr))
self.assertEqual(self.tab1[4],next(fcr))
fcr = FilteringCSVReader(iter(self.tab1), patterns={1: pattern})

self.assertEqual(self.tab1[0], next(fcr))
self.assertEqual(self.tab1[1], next(fcr))
self.assertEqual(self.tab1[3], next(fcr))
self.assertEqual(self.tab1[4], next(fcr))
try:
next(fcr)
self.fail("Should be no more rows left.")
except StopIteration:
pass

def test_inverse(self):
fcr = FilteringCSVReader(iter(self.tab2),patterns = ['1'], inverse=True)
self.assertEqual(self.tab2[0],next(fcr))
self.assertEqual(self.tab2[2],next(fcr))
self.assertEqual(self.tab2[4],next(fcr))
fcr = FilteringCSVReader(iter(self.tab2), patterns=['1'], inverse=True)
self.assertEqual(self.tab2[0], next(fcr))
self.assertEqual(self.tab2[2], next(fcr))
self.assertEqual(self.tab2[4], next(fcr))
try:
next(fcr)
self.fail("Should be no more rows left.")
except StopIteration:
pass

def test_column_names_in_patterns(self):
fcr = FilteringCSVReader(iter(self.tab2),patterns = {'age': 'only'})
self.assertEqual(self.tab2[0],next(fcr))
self.assertEqual(self.tab2[2],next(fcr))
self.assertEqual(self.tab2[4],next(fcr))
fcr = FilteringCSVReader(iter(self.tab2), patterns={'age': 'only'})
self.assertEqual(self.tab2[0], next(fcr))
self.assertEqual(self.tab2[2], next(fcr))
self.assertEqual(self.tab2[4], next(fcr))
try:
next(fcr)
self.fail("Should be no more rows left.")
except StopIteration:
pass

def test_mixed_indices_and_column_names_in_patterns(self):
fcr = FilteringCSVReader(iter(self.tab2),patterns = {'age': 'only', 0: '2'})
self.assertEqual(self.tab2[0],next(fcr))
self.assertEqual(self.tab2[4],next(fcr))
fcr = FilteringCSVReader(iter(self.tab2), patterns={'age': 'only', 0: '2'})
self.assertEqual(self.tab2[0], next(fcr))
self.assertEqual(self.tab2[4], next(fcr))
try:
next(fcr)
self.fail("Should be no more rows left.")
Expand All @@ -95,8 +96,29 @@ def test_mixed_indices_and_column_names_in_patterns(self):

def test_duplicate_column_ids_in_patterns(self):
try:
fcr = FilteringCSVReader(iter(self.tab2),patterns = {'age': 'only', 1: 'second'})
FilteringCSVReader(iter(self.tab2), patterns={'age': 'only', 1: 'second'})
self.fail("Should be an exception.")
except ColumnIdentifierError:
pass

def test_any_match(self):
fcr = FilteringCSVReader(iter(self.tab2), patterns={'age': 'only', 0: '2'}, any_match=True)
self.assertEqual(self.tab2[0], next(fcr))
self.assertEqual(self.tab2[2], next(fcr))
self.assertEqual(self.tab2[4], next(fcr))
try:
next(fcr)
self.fail("Should be no more rows left.")
except StopIteration:
pass

def test_any_match_and_inverse(self):
fcr = FilteringCSVReader(iter(self.tab2), patterns={'age': 'only', 0: '2'}, any_match=True, inverse=True)
self.assertEqual(self.tab2[0], next(fcr))
self.assertEqual(self.tab2[1], next(fcr))
self.assertEqual(self.tab2[3], next(fcr))
try:
next(fcr)
self.fail("Should be no more rows left.")
except StopIteration:
pass

0 comments on commit 0ae1294

Please sign in to comment.