Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

handles case of empty source for DataMap

  • Loading branch information...
commit ed9d4ecc1444f98b42fc6d0f0899aee44b48c36e 1 parent 1903d73
@unbracketed authored
View
7 snowbird/datamap.py
@@ -50,6 +50,9 @@ def __init__(self, **options):
def get_extract_jobs(self):
"""Returns a list of DataJobs that will use this DataMap"""
metrics = self.IN.get_metrics()
+ if not metrics['rows']:
+ LOG.warning("No rows for source %s" %self.IN)
+ return []
num_jobs = (metrics['rows'] / self.batch_size)
return [DataJob(self.__class__,
offset*self.batch_size,
@@ -69,9 +72,9 @@ def run_job(self):
for row in self.IN:
mapped = self.process_row(row)
if self.OUT:
- pass
+ pass
else:
- LOG.info(str(row))
+ LOG.info(str(row))
def process_row(self, row):
return row
View
18 snowbird/tests/models.py
@@ -2,8 +2,16 @@
class TestModel(models.Model):
- field_1 = models.CharField(max_length=100)
- field_2 = models.CharField(max_length=100)
- field_3 = models.CharField(max_length=100)
- field_4 = models.CharField(max_length=100)
- field_5 = models.CharField(max_length=100)
+ field_1 = models.CharField(max_length=100)
+ field_2 = models.CharField(max_length=100)
+ field_3 = models.CharField(max_length=100)
+ field_4 = models.CharField(max_length=100)
+ field_5 = models.CharField(max_length=100)
+
+
+class TestModel2(models.Model):
+ field_1 = models.CharField(max_length=100)
+ field_2 = models.CharField(max_length=100)
+ field_3 = models.CharField(max_length=100)
+ field_4 = models.CharField(max_length=100)
+ field_5 = models.CharField(max_length=100)
View
22 snowbird/tests/test_datamaps.py
@@ -1,24 +1,22 @@
from django.test import TestCase
-from snowbird.datamap import DataJob
+from snowbird.datamap import DataJob, DataMap
+from snowbird.io import DjangoModel
from snowbird.tests.datamaps import TestModelDataMap
+from snowbird.tests.models import TestModel2
from django.test.utils import setup_test_environment
setup_test_environment()
class TestDataMap(TestCase):
- fixtures = ['testmodel.json']
+ fixtures = ['testmodel']
def setUp(self):
self.tm = TestModelDataMap()
- def test_create_datamap(self):
- pass
-
def test_get_extract_jobs(self):
- print self.tm.get_extract_jobs()
+ "make sure get_extract_jobs works"
dj = DataJob(self.tm.__class__, 0, 10)
- #[DataJob(source=<class 'snowbird.tests.datamaps.TestModelDataMap'>, offset=0, num_rows=1000)]
jobs = self.tm.get_extract_jobs()
self.assertEqual(len(jobs), 1)
self.assertEqual(jobs[0], dj)
@@ -26,7 +24,15 @@ def test_get_extract_jobs(self):
TestModelDataMap.batch_size = 3
tm = TestModelDataMap()
jobs = tm.get_extract_jobs()
- print jobs
self.assertEqual(len(jobs), 4)
self.assertEqual([3,3,3,1], [j.num_rows for j in jobs])
self.assertEqual([0,3,6,9], [j.offset for j in jobs])
+
+ class DM2(DjangoModel):
+ model = TestModel2
+ class TM2(DataMap):
+ source = DM2
+ tm2 = TM2()
+ jobs = tm2.get_extract_jobs()
+ print jobs
+ self.assertEqual(len(jobs), 0)
Please sign in to comment.
Something went wrong with that request. Please try again.