Add an opt-in database backup (-b / --backup-database) to predictive_punter/command.py
and assert on the backup database in tests/scrape_test.py.

After a date is processed without error the database is copied to "<name>_backup";
if processing a date raises, the database is restored from that backup before the
exception propagates.

NOTE(review): this patch was re-flowed from a whitespace-collapsed copy. Indentation,
dict alignment and blank context lines were inferred from syntax and the hunk line
counts - verify against the target files before applying.
NOTE(review): relative to the collapsed copy, the added lines now read the new keyword
with kwargs.get(), refuse to drop the live database when no backup database exists,
and carry short why-comments. The "+" hunk ranges were adjusted to match; the index
hashes were left as they were.
NOTE(review): relies on the MongoDB "copydb" admin command - confirm the target server
version still provides it.

diff --git a/predictive_punter/command.py b/predictive_punter/command.py
index 7089b9f..c5eb1dd 100644
--- a/predictive_punter/command.py
+++ b/predictive_punter/command.py
@@ -29,17 +29,21 @@ def parse_args(cls, args):
         """Return a dictionary of configuration values based on the provided command line arguments"""
 
         config = {
-            'database_uri': 'mongodb://localhost:27017/predictive_punter',
-            'date_from':    datetime.now(),
-            'date_to':      datetime.now(),
-            'redis_uri':    'redis://localhost:6379/predictive_punter'
+            'backup_database':  False,
+            'database_uri':     'mongodb://localhost:27017/predictive_punter',
+            'date_from':        datetime.now(),
+            'date_to':          datetime.now(),
+            'redis_uri':        'redis://localhost:6379/predictive_punter'
         }
 
 
-        opts, args = getopt(args, 'd:r:', ['database-uri=', 'redis-uri='])
+        opts, args = getopt(args, 'bd:r:', ['backup-database', 'database-uri=', 'redis-uri='])
         for opt, arg in opts:
 
-            if opt in ('-d', '--database-uri'):
+            if opt in ('-b', '--backup-database'):
+                config['backup_database'] = True
+
+            elif opt in ('-d', '--database-uri'):
                 config['database_uri'] = arg
 
             elif opt in ('-r', '--redis-uri'):
@@ -55,7 +59,12 @@ def parse_args(cls, args):
     def __init__(self, *args, **kwargs):
 
         database_client = pymongo.MongoClient(kwargs['database_uri'])
-        database = database_client.get_default_database()
+        self.database = database_client.get_default_database()
+
+        # Backups are opt-in; callers that omit the keyword get no backup database
+        self.backup_database_name = None
+        if kwargs.get('backup_database'):
+            self.backup_database_name = self.database.name + '_backup'
 
         http_client = None
         try:
@@ -70,7 +79,25 @@ def __init__(self, *args, **kwargs):
 
         scraper = punters_client.Scraper(http_client, html_parser)
 
-        self.provider = racing_data.Provider(database, scraper)
+        self.provider = racing_data.Provider(self.database, scraper)
+
+    def backup_database(self):
+        """Replace the backup database with a copy of the database, if backups are enabled"""
+
+        if self.backup_database_name is not None:
+            self.database.client.drop_database(self.backup_database_name)
+            self.database.client.admin.command('copydb', fromdb=self.database.name, todb=self.backup_database_name)
+
+    def restore_database(self):
+        """Replace the database with a copy of the backup database, if backups are enabled and one exists"""
+
+        if self.backup_database_name is None:
+            return
+
+        # Only drop the live database when there is a backup to restore it from
+        if self.backup_database_name in self.database.client.database_names():
+            self.database.client.drop_database(self.database.name)
+            self.database.client.admin.command('copydb', fromdb=self.backup_database_name, todb=self.database.name)
 
     def process_collection(self, collection, target):
         """Asynchronously process all items in collection via target"""
@@ -92,7 +119,16 @@ def process_dates(self, date_from, date_to):
     def process_date(self, date):
         """Process all racing data for the specified date"""
 
-        self.process_collection(self.provider.get_meets_by_date(date), self.process_meet)
+        try:
+            self.process_collection(self.provider.get_meets_by_date(date), self.process_meet)
+
+        # BaseException rather than Exception, so interrupts such as Ctrl-C also trigger the restore
+        except BaseException:
+            self.restore_database()
+            raise
+
+        else:
+            self.backup_database()
 
     def process_meet(self, meet):
         """Process the specified meet"""
diff --git a/tests/scrape_test.py b/tests/scrape_test.py
index 7551c90..f770989 100644
--- a/tests/scrape_test.py
+++ b/tests/scrape_test.py
@@ -3,6 +3,17 @@
 import pytest
 
 
+@pytest.fixture(scope='module')
+def backup_database(database_uri):
+
+    # Drop any existing backup database so the run starts without one
+    database_uri += '_backup'
+    database_name = database_uri.split('/')[-1]
+    database_client = pymongo.MongoClient(database_uri)
+    database_client.drop_database(database_name)
+    return database_client.get_default_database()
+
+
 @pytest.fixture(scope='module')
 def database(database_uri):
 
@@ -21,7 +32,7 @@ def database_uri():
 @pytest.fixture(scope='module')
 def scrape_command(database_uri):
 
-    predictive_punter.ScrapeCommand.main(['-d', database_uri, '2016-2-1', '2016-2-2'])
+    predictive_punter.ScrapeCommand.main(['-b', '-d', database_uri, '2016-2-1', '2016-2-2'])
 
 
 def count_distinct(collection, key, exclude=None):
@@ -33,43 +44,44 @@ def count_distinct(collection, key, exclude=None):
     return len(values)
 
 
-def test_meets(database, scrape_command):
+def test_meets(database, backup_database, scrape_command):
     """The scrape command should populate the database with the expected number of meets"""
 
-    assert database['meets'].count() == 5
+    assert database['meets'].count() == backup_database['meets'].count() == 5
 
 
-def test_races(database, scrape_command):
+def test_races(database, backup_database, scrape_command):
     """The scrape command should populate the database with the expected number of races"""
 
-    assert database['races'].count() == 8 + 8 + 8 + 7 + 8
+    assert database['races'].count() == backup_database['races'].count() == 8 + 8 + 8 + 7 + 8
 
 
-def test_runners(database, scrape_command):
+def test_runners(database, backup_database, scrape_command):
     """The scrape command should populate the database with the expected number of runners"""
 
-    assert database['runners'].count() == 11 + 14 + 14 + 15 + 10 + 17 + 11 + 15 + 8 + 9 + 11 + 9 + 11 + 10 + 13 + 16 + 10 + 6 + 11 + 11 + 9 + 9 + 14 + 10 + 9 + 11 + 13 + 18 + 11 + 10 + 14 + 10 + 15 + 13 + 12 + 11 + 11 + 14 + 16
+    assert database['runners'].count() == backup_database['runners'].count() == 11 + 14 + 14 + 15 + 10 + 17 + 11 + 15 + 8 + 9 + 11 + 9 + 11 + 10 + 13 + 16 + 10 + 6 + 11 + 11 + 9 + 9 + 14 + 10 + 9 + 11 + 13 + 18 + 11 + 10 + 14 + 10 + 15 + 13 + 12 + 11 + 11 + 14 + 16
 
 
-def test_horses(database, scrape_command):
+def test_horses(database, backup_database, scrape_command):
     """The scrape command should populate the database with the expected number of horses"""
 
-    assert database['horses'].count() == count_distinct(database['runners'], 'horse_url', 'https://www.punters.com.au')
+    assert database['horses'].count() == backup_database['horses'].count() == count_distinct(database['runners'], 'horse_url', 'https://www.punters.com.au')
 
 
-def test_jockeys(database, scrape_command):
+def test_jockeys(database, backup_database, scrape_command):
     """The scrape command should populate the database with the expected number of jockeys"""
 
-    assert database['jockeys'].count() == count_distinct(database['runners'], 'jockey_url', 'https://www.punters.com.au/')
+    assert database['jockeys'].count() == backup_database['jockeys'].count() == count_distinct(database['runners'], 'jockey_url', 'https://www.punters.com.au/')
 
 
-def test_trainers(database, scrape_command):
+def test_trainers(database, backup_database, scrape_command):
     """The scrape command should populate the database with the expected number of trainers"""
 
-    assert database['trainers'].count() == count_distinct(database['runners'], 'trainer_url', 'https://www.punters.com.au/')
+    assert database['trainers'].count() == backup_database['trainers'].count() == count_distinct(database['runners'], 'trainer_url', 'https://www.punters.com.au/')
 
 
-def test_performances(database, scrape_command):
+def test_performances(database, backup_database, scrape_command):
     """The scrape command should populate the database with the expected number of performances"""
 
     assert database['performances'].count() >= database['runners'].count({'is_scratched': False})
+    assert backup_database['performances'].count() == database['performances'].count()