Skip to content

Commit

Permalink
Add test for FilteringCSVReader with any_match argument, closes #89
Browse files Browse the repository at this point in the history
  • Loading branch information
James McKinney committed Jan 22, 2016
1 parent a89715e commit 0b75a63
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 27 deletions.
7 changes: 4 additions & 3 deletions csvkit/grep.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,14 @@ def __next__(self):

def test_row(self, row):
for idx, test in self.patterns.items():
if self.any_match and test(row[idx]):
result = test(row[idx])
if result and self.any_match:
return not self.inverse # True

if not self.any_match and not test(row[idx]):
if not result and not self.any_match:
return self.inverse # False

return not self.inverse # True
return not self.any_match and not self.inverse # True

def standardize_patterns(column_names, patterns):
"""
Expand Down
58 changes: 34 additions & 24 deletions tests/test_grep.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,20 @@ def setUp(self):
[u'2', u'only', u'0', u'0']] # Note extra value in this column

def test_pattern(self):
    """
    Filter ``self.tab1`` with the list pattern ``['1']``.

    Expects rows 0, 1 and 4 of the fixture, in that order, and then an
    exhausted reader.
    """
    fcr = FilteringCSVReader(iter(self.tab1), patterns=['1'])
    self.assertEqual(self.tab1[0], next(fcr))
    self.assertEqual(self.tab1[1], next(fcr))
    self.assertEqual(self.tab1[4], next(fcr))
    try:
        next(fcr)
        self.fail("Should be no more rows left.")
    except StopIteration:
        pass

def test_no_header(self):
fcr = FilteringCSVReader(iter(self.tab1),patterns={ 2: 'only' },header=False)
self.assertEqual(self.tab1[2],next(fcr))
self.assertEqual(self.tab1[3],next(fcr))
fcr = FilteringCSVReader(iter(self.tab1), patterns={2: 'only'}, header=False)
self.assertEqual(self.tab1[2], next(fcr))
self.assertEqual(self.tab1[3], next(fcr))
try:
next(fcr)
self.fail("Should be no more rows left.")
Expand All @@ -49,44 +49,44 @@ def test_no_header(self):

def test_regex(self):
    """
    Filter ``self.tab1`` with a compiled regular expression on column 1.

    Expects rows 0, 1, 3 and 4 of the fixture, in that order, and then an
    exhausted reader.
    """
    pattern = re.compile(".*(Reader|Tribune).*")
    fcr = FilteringCSVReader(iter(self.tab1), patterns={1: pattern})

    self.assertEqual(self.tab1[0], next(fcr))
    self.assertEqual(self.tab1[1], next(fcr))
    self.assertEqual(self.tab1[3], next(fcr))
    self.assertEqual(self.tab1[4], next(fcr))
    try:
        next(fcr)
        self.fail("Should be no more rows left.")
    except StopIteration:
        pass

def test_inverse(self):
    """
    Filter ``self.tab2`` with the list pattern ``['1']`` and ``inverse=True``.

    Expects rows 0, 2 and 4 of the fixture, in that order, and then an
    exhausted reader.
    """
    fcr = FilteringCSVReader(iter(self.tab2), patterns=['1'], inverse=True)
    self.assertEqual(self.tab2[0], next(fcr))
    self.assertEqual(self.tab2[2], next(fcr))
    self.assertEqual(self.tab2[4], next(fcr))
    try:
        next(fcr)
        self.fail("Should be no more rows left.")
    except StopIteration:
        pass

def test_column_names_in_patterns(self):
    """
    Filter ``self.tab2`` with a pattern keyed by the column name ``'age'``.

    Expects rows 0, 2 and 4 of the fixture, in that order, and then an
    exhausted reader.
    """
    fcr = FilteringCSVReader(iter(self.tab2), patterns={'age': 'only'})
    self.assertEqual(self.tab2[0], next(fcr))
    self.assertEqual(self.tab2[2], next(fcr))
    self.assertEqual(self.tab2[4], next(fcr))
    try:
        next(fcr)
        self.fail("Should be no more rows left.")
    except StopIteration:
        pass

def test_mixed_indices_and_column_names_in_patterns(self):
fcr = FilteringCSVReader(iter(self.tab2),patterns = {'age': 'only', 0: '2'})
self.assertEqual(self.tab2[0],next(fcr))
self.assertEqual(self.tab2[4],next(fcr))
fcr = FilteringCSVReader(iter(self.tab2), patterns={'age': 'only', 0: '2'})
self.assertEqual(self.tab2[0], next(fcr))
self.assertEqual(self.tab2[4], next(fcr))
try:
next(fcr)
self.fail("Should be no more rows left.")
Expand All @@ -95,8 +95,18 @@ def test_mixed_indices_and_column_names_in_patterns(self):

def test_duplicate_column_ids_in_patterns(self):
    """
    Constructing a reader with the patterns ``'age'`` and ``1`` together
    must raise ``ColumnIdentifierError``.

    NOTE(review): presumably ``'age'`` names column 1 of ``self.tab2``, so
    the two keys collide; confirm against the fixture in ``setUp``.
    """
    try:
        # The reader is never used, so the result is not bound to a name.
        FilteringCSVReader(iter(self.tab2), patterns={'age': 'only', 1: 'second'})
        self.fail("Should be an exception.")
    except ColumnIdentifierError:
        pass

def test_any_match(self):
    """
    Filter ``self.tab2`` with two patterns and ``any_match=True``.

    The patterns are ``'only'`` on the ``'age'`` column and ``'2'`` on
    column 0. Expects rows 0, 2 and 4 of the fixture, in that order, and
    then an exhausted reader.
    """
    fcr = FilteringCSVReader(iter(self.tab2), patterns={'age': 'only', 0: '2'}, any_match=True)
    self.assertEqual(self.tab2[0], next(fcr))
    self.assertEqual(self.tab2[2], next(fcr))
    self.assertEqual(self.tab2[4], next(fcr))
    try:
        next(fcr)
        self.fail("Should be no more rows left.")
    except StopIteration:
        pass

0 comments on commit 0b75a63

Please sign in to comment.