From 809d46e0c983d41f89fccd711ea248575683f01f Mon Sep 17 00:00:00 2001 From: michellethomas Date: Thu, 30 Nov 2017 20:47:22 -0800 Subject: [PATCH] Improving speed of dashboard import (#3958) * Improve dashboard import * Updating tests for Slice.import_obj --- superset/models/core.py | 22 +++++++++++++--------- tests/import_export_tests.py | 18 ++++++++++-------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/superset/models/core.py b/superset/models/core.py index 2996625e2cbf..68c305f227bc 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -248,12 +248,17 @@ def get_viz(self): ) @classmethod - def import_obj(cls, slc_to_import, import_time=None): + def import_obj(cls, slc_to_import, slc_to_override, import_time=None): """Inserts or overrides slc in the database. remote_id and import_time fields in params_dict are set to track the slice origin and ensure correct overrides for multiple imports. Slice.perm is used to find the datasources and connect them. + + :param Slice slc_to_import: Slice object to import + :param Slice slc_to_override: Slice to replace, id matches remote_id + :returns: The resulting id for the imported slice + :rtype: int """ session = db.session make_transient(slc_to_import) @@ -261,13 +266,6 @@ def import_obj(cls, slc_to_import, import_time=None): slc_to_import.alter_params( remote_id=slc_to_import.id, import_time=import_time) - # find if the slice was already imported - slc_to_override = None - for slc in session.query(Slice).all(): - if ('remote_id' in slc.params_dict and - slc.params_dict['remote_id'] == slc_to_import.id): - slc_to_override = slc - slc_to_import = slc_to_import.copy() params = slc_to_import.params_dict slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name( @@ -432,10 +430,16 @@ def alter_positions(dashboard, old_to_new_slc_id_dict): new_timed_refresh_immune_slices = [] new_expanded_slices = {} i_params_dict = dashboard_to_import.params_dict + remote_id_slice_map = { + slc.params_dict['remote_id']: slc + for slc in session.query(Slice).all() + if 'remote_id' in slc.params_dict + } for slc in slices: logging.info('Importing slice {} from the dashboard: {}'.format( slc.to_json(), dashboard_to_import.dashboard_title)) - new_slc_id = Slice.import_obj(slc, import_time=import_time) + remote_slc = remote_id_slice_map.get(slc.id) + new_slc_id = Slice.import_obj(slc, remote_slc, import_time=import_time) old_to_new_slc_id_dict[slc.id] = new_slc_id # update json metadata that deals with slice ids new_slc_id_str = '{}'.format(new_slc_id) diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index 0710cacecec5..2a9f069cb138 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -248,7 +248,7 @@ def test_export_2_dashboards(self): def test_import_1_slice(self): expected_slice = self.create_slice('Import Me', id=10001) - slc_id = models.Slice.import_obj(expected_slice, import_time=1989) + slc_id = models.Slice.import_obj(expected_slice, None, import_time=1989) slc = self.get_slice(slc_id) self.assertEquals(slc.datasource.perm, slc.perm) self.assert_slice_equals(expected_slice, slc) @@ -260,9 +260,9 @@ def test_import_2_slices_for_same_table(self): table_id = self.get_table_by_name('wb_health_population').id # table_id != 666, import func will have to find the table slc_1 = self.create_slice('Import Me 1', ds_id=666, id=10002) - slc_id_1 = models.Slice.import_obj(slc_1) + slc_id_1 = models.Slice.import_obj(slc_1, None) slc_2 = self.create_slice('Import Me 2', ds_id=666, id=10003) - slc_id_2 = models.Slice.import_obj(slc_2) + slc_id_2 = models.Slice.import_obj(slc_2, None) imported_slc_1 = self.get_slice(slc_id_1) imported_slc_2 = self.get_slice(slc_id_2) @@ -277,17 +277,19 @@ def test_import_2_slices_for_same_table(self): def test_import_slices_for_non_existent_table(self): with self.assertRaises(IndexError): models.Slice.import_obj(self.create_slice( - 'Import Me 3', id=10004, table_name='non_existent')) + 'Import Me 3', id=10004, table_name='non_existent'), None) def test_import_slices_override(self): slc = self.create_slice('Import Me New', id=10005) - slc_1_id = models.Slice.import_obj(slc, import_time=1990) + slc_1_id = models.Slice.import_obj(slc, None, import_time=1990) slc.slice_name = 'Import Me New' + imported_slc_1 = self.get_slice(slc_1_id) + slc_2 = self.create_slice('Import Me New', id=10005) slc_2_id = models.Slice.import_obj( - self.create_slice('Import Me New', id=10005), import_time=1990) + slc_2, imported_slc_1, import_time=1990) self.assertEquals(slc_1_id, slc_2_id) - imported_slc = self.get_slice(slc_2_id) - self.assert_slice_equals(slc, imported_slc) + imported_slc_2 = self.get_slice(slc_2_id) + self.assert_slice_equals(slc, imported_slc_2) def test_import_empty_dashboard(self): empty_dash = self.create_dashboard('empty_dashboard', id=10001)