From 309d50644e978ce989f66bb464c1ed6d75830f2b Mon Sep 17 00:00:00 2001 From: Daniel Vaz Gaspar Date: Wed, 31 Aug 2022 18:11:03 +0100 Subject: [PATCH] fix: dataset name change and permission change (#21161) * fix: dataset name change and permission change (cherry picked from commit 3f2e894af3dbb7a5c714de46240243b91d3d579c) --- superset/connectors/sqla/models.py | 7 +- superset/databases/commands/update.py | 34 + superset/datasets/commands/create.py | 11 +- superset/datasets/commands/delete.py | 24 - superset/models/core.py | 2 +- superset/security/manager.py | 638 +++++++++--- tests/integration_tests/datasets/api_tests.py | 13 + tests/integration_tests/security_tests.py | 930 ++++++++++++++---- tests/unit_tests/conftest.py | 3 + 9 files changed, 1278 insertions(+), 384 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index b9ffb4792cfa..a7ce51a34492 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -2266,11 +2266,11 @@ def after_insert( For more context: https://github.com/apache/superset/issues/14909 """ - security_manager.set_perm(mapper, connection, sqla_table) + security_manager.dataset_after_insert(mapper, connection, sqla_table) sqla_table.write_shadow_dataset() @staticmethod - def after_delete( # pylint: disable=unused-argument + def after_delete( mapper: Mapper, connection: Connection, sqla_table: "SqlaTable", @@ -2287,6 +2287,7 @@ def after_delete( # pylint: disable=unused-argument For more context: https://github.com/apache/superset/issues/14909 """ + security_manager.dataset_after_delete(mapper, connection, sqla_table) session = inspect(sqla_table).session dataset = ( session.query(NewDataset).filter_by(uuid=sqla_table.uuid).one_or_none() @@ -2313,7 +2314,7 @@ def after_update( For more context: https://github.com/apache/superset/issues/14909 """ # set permissions - security_manager.set_perm(mapper, connection, sqla_table) + security_manager.dataset_after_update(mapper, connection, sqla_table) inspector = inspect(sqla_table) session = inspector.session diff --git a/superset/databases/commands/update.py b/superset/databases/commands/update.py index f30adf00015e..80e3a9b54e61 100644 --- a/superset/databases/commands/update.py +++ b/superset/databases/commands/update.py @@ -32,6 +32,7 @@ from superset.databases.dao import DatabaseDAO from superset.extensions import db, security_manager from superset.models.core import Database +from superset.utils.core import DatasourceType logger = logging.getLogger(__name__) @@ -66,8 +67,10 @@ def run(self) -> Model: except Exception as ex: db.session.rollback() raise DatabaseConnectionFailedError() from ex + # Update database schema permissions new_schemas: List[str] = [] + for schema in schemas: old_view_menu_name = security_manager.get_schema_perm( old_database_name, schema @@ -81,6 +84,10 @@ def run(self) -> Model: # Update the schema permission if the database name changed if schema_pvm and old_database_name != database.database_name: schema_pvm.view_menu.name = new_view_menu_name + + self._propagate_schema_permissions( + old_view_menu_name, new_view_menu_name + ) else: new_schemas.append(schema) for schema in new_schemas: @@ -94,6 +101,33 @@ def run(self) -> Model: raise DatabaseUpdateFailedError() from ex return database + @staticmethod + def _propagate_schema_permissions( + old_view_menu_name: str, new_view_menu_name: str + ) -> None: + from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel + SqlaTable, + ) + from superset.models.slice import ( # pylint: disable=import-outside-toplevel + Slice, + ) + + # Update schema_perm on all datasets + datasets = ( + db.session.query(SqlaTable) + .filter(SqlaTable.schema_perm == old_view_menu_name) + .all() + ) + for dataset in datasets: + dataset.schema_perm = new_view_menu_name + charts = db.session.query(Slice).filter( + Slice.datasource_type == DatasourceType.TABLE, + Slice.datasource_id == dataset.id, + ) + # Update schema_perm on all charts + for chart in charts: + chart.schema_perm = new_view_menu_name + def validate(self) -> None: exceptions: List[ValidationError] = [] # Validate/populate model exists diff --git a/superset/datasets/commands/create.py b/superset/datasets/commands/create.py index b638581abe08..1fa2e0ccf766 100644 --- a/superset/datasets/commands/create.py +++ b/superset/datasets/commands/create.py @@ -31,7 +31,7 @@ TableNotFoundValidationError, ) from superset.datasets.dao import DatasetDAO -from superset.extensions import db, security_manager +from superset.extensions import db logger = logging.getLogger(__name__) @@ -47,15 +47,6 @@ def run(self) -> Model: dataset = DatasetDAO.create(self._properties, commit=False) # Updates columns and metrics from the dataset dataset.fetch_metadata(commit=False) - # Add datasource access permission - security_manager.add_permission_view_menu( - "datasource_access", dataset.get_perm() - ) - # Add schema access permission if exists - if dataset.schema: - security_manager.add_permission_view_menu( - "schema_access", dataset.schema_perm - ) db.session.commit() except (SQLAlchemyError, DAOCreateFailedError) as ex: logger.warning(ex, exc_info=True) diff --git a/superset/datasets/commands/delete.py b/superset/datasets/commands/delete.py index 9ab8f41a4270..6f9156795813 100644 --- a/superset/datasets/commands/delete.py +++ b/superset/datasets/commands/delete.py @@ -45,30 +45,6 @@ def run(self) -> Model: self.validate() try: dataset = DatasetDAO.delete(self._model, commit=False) - - view_menu = ( - security_manager.find_view_menu(self._model.get_perm()) - if self._model - else None - ) - - if view_menu: - permission_views = ( - db.session.query(security_manager.permissionview_model) - .filter_by(view_menu=view_menu) - .all() - ) - - for permission_view in permission_views: - db.session.delete(permission_view) - if view_menu: - db.session.delete(view_menu) - else: - if not view_menu: - logger.error( - "Could not find the data access permission for the dataset", - exc_info=True, - ) db.session.commit() except (SQLAlchemyError, DAODeleteFailedError) as ex: logger.exception(ex) diff --git a/superset/models/core.py b/superset/models/core.py index 2ecd68d182f3..8b937562a1d2 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -810,7 +810,7 @@ def get_dialect(self) -> Dialect: return sqla_url.get_dialect()() -sqla.event.listen(Database, "after_insert", security_manager.set_perm) +sqla.event.listen(Database, "after_insert", security_manager.database_after_insert) sqla.event.listen(Database, "after_update", security_manager.database_after_update) sqla.event.listen(Database, "after_delete", security_manager.database_after_delete) diff --git a/superset/security/manager.py b/superset/security/manager.py index a66e35e2d845..7ddbd0a44cd5 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -77,12 +77,18 @@ GuestTokenUser, GuestUser, ) -from superset.utils.core import DatasourceName, get_user_id, RowLevelSecurityFilterType +from superset.utils.core import ( + DatasourceName, + DatasourceType, + get_user_id, + RowLevelSecurityFilterType, +) from superset.utils.urls import get_url_host if TYPE_CHECKING: from superset.common.query_context import QueryContext from superset.connectors.base.models import BaseDatasource + from superset.connectors.sqla.models import SqlaTable from superset.models.core import Database from superset.models.dashboard import Dashboard from superset.models.sql_lab import Query @@ -941,16 +947,89 @@ def _is_granter_pvm( # pylint: disable=no-self-use return pvm.permission.name in {"can_override_role_permissions", "can_approve"} + def database_after_insert( + self, + mapper: Mapper, + connection: Connection, + target: "Database", + ) -> None: + """ + Handles permissions when a database is created. + Triggered by a SQLAlchemy after_insert event. + + We need to create: + - The database PVM + + :param mapper: The SQLA mapper + :param connection: The SQLA connection + :param target: The changed database object + :return: + """ + self._insert_pvm_on_sqla_event( + mapper, connection, "database_access", target.get_perm() + ) + def database_after_delete( self, mapper: Mapper, connection: Connection, target: "Database", ) -> None: + """ + Handles permissions update when a database is deleted. + Triggered by a SQLAlchemy after_delete event. + + We need to delete: + - The database PVM + + :param mapper: The SQLA mapper + :param connection: The SQLA connection + :param target: The changed database object + :return: + """ self._delete_vm_database_access( mapper, connection, target.id, target.database_name ) + def database_after_update( + self, + mapper: Mapper, + connection: Connection, + target: "Database", + ) -> None: + """ + Handles all permissions update when a database is changed. + Triggered by a SQLAlchemy after_update event. + + We need to update: + - The database PVM + - All datasets PVMs that reference the db, and it's local perm name + - All datasets local schema perm that reference the db. + - All charts local perm related with said datasets + - All charts local schema perm related with said datasets + + :param mapper: The SQLA mapper + :param connection: The SQLA connection + :param target: The changed database object + :return: + """ + # Check if database name has changed + state = inspect(target) + history = state.get_history("database_name", True) + if not history.has_changes() or not history.deleted: + return + + old_database_name = history.deleted[0] + # update database access permission + self._update_vm_database_access(mapper, connection, old_database_name, target) + # update datasource access + self._update_vm_datasources_access( + mapper, connection, old_database_name, target + ) + # Note schema permissions are updated at the API level + # (database.commands.update). Since we need to fetch all existing schemas from + # the db + def _delete_vm_database_access( self, mapper: Mapper, @@ -958,29 +1037,11 @@ def _delete_vm_database_access( database_id: int, database_name: str, ) -> None: - view_menu_table = self.viewmenu_model.__table__ # pylint: disable=no-member - permission_view_menu_table = ( - self.permissionview_model.__table__ # pylint: disable=no-member - ) view_menu_name = self.get_database_perm(database_id, database_name) # Clean database access permission - db_pvm = self.find_permission_view_menu("database_access", view_menu_name) - if not db_pvm: - logger.warning( - "Could not find previous database permission %s", - view_menu_name, - ) - return - connection.execute( - permission_view_menu_table.delete().where( - permission_view_menu_table.c.id == db_pvm.id - ) - ) - self.on_permission_after_delete(mapper, connection, db_pvm) - connection.execute( - view_menu_table.delete().where(view_menu_table.c.id == db_pvm.view_menu_id) + self._delete_pvm_on_sqla_event( + mapper, connection, "database_access", view_menu_name ) - # Clean database schema permissions schema_pvms = ( self.get_session.query(self.permissionview_model) @@ -991,17 +1052,7 @@ def _delete_vm_database_access( .all() ) for schema_pvm in schema_pvms: - connection.execute( - permission_view_menu_table.delete().where( - permission_view_menu_table.c.id == schema_pvm.id - ) - ) - self.on_permission_after_delete(mapper, connection, schema_pvm) - connection.execute( - view_menu_table.delete().where( - view_menu_table.c.id == schema_pvm.view_menu_id - ) - ) + self._delete_pvm_on_sqla_event(mapper, connection, pvm=schema_pvm) def _update_vm_database_access( self, @@ -1010,6 +1061,15 @@ def _update_vm_database_access( old_database_name: str, target: "Database", ) -> Optional[ViewMenu]: + """ + Helper method that Updates all database access permission + when a database name changes. + + :param connection: Current connection (called on SQLAlchemy event listener scope) + :param old_database_name: the old database name + :param target: The database object + :return: A list of changed view menus (permission resource names) + """ view_menu_table = self.viewmenu_model.__table__ # pylint: disable=no-member new_database_name = target.database_name old_view_menu_name = self.get_database_perm(target.id, old_database_name) @@ -1020,6 +1080,9 @@ def _update_vm_database_access( "Could not find previous database permission %s", old_view_menu_name, ) + self._insert_pvm_on_sqla_event( + mapper, connection, "database_access", new_view_menu_name + ) return None new_updated_pvm = self.find_permission_view_menu( "database_access", new_view_menu_name @@ -1051,11 +1114,12 @@ def _update_vm_datasources_access( # pylint: disable=too-many-locals target: "Database", ) -> List[ViewMenu]: """ - Updates all datasource access permission when a database name changes + Helper method that Updates all datasource access permission + when a database name changes. :param connection: Current connection (called on SQLAlchemy event listener scope) :param old_database_name: the old database name - :param target: The new database name + :param target: The database object :return: A list of changed view menus (permission resource names) """ from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel @@ -1090,6 +1154,9 @@ def _update_vm_datasources_access( # pylint: disable=too-many-locals .where(view_menu_table.c.name == old_dataset_vm_name) .values(name=new_dataset_vm_name) ) + # After update refresh + new_dataset_view_menu = self.find_view_menu(new_dataset_vm_name) + # Update dataset (SqlaTable perm field) connection.execute( sqlatable_table.update() @@ -1106,83 +1173,417 @@ def _update_vm_datasources_access( # pylint: disable=too-many-locals .values(perm=new_dataset_vm_name) ) self.on_view_menu_after_update(mapper, connection, new_dataset_view_menu) - updated_view_menus.append(self.find_view_menu(new_dataset_view_menu)) + updated_view_menus.append(new_dataset_view_menu) return updated_view_menus - def database_after_update( + def dataset_after_insert( self, mapper: Mapper, connection: Connection, - target: "Database", + target: "SqlaTable", ) -> None: - # Check if database name has changed + """ + Handles permission creation when a dataset is inserted. + Triggered by a SQLAlchemy after_insert event. + + We need to create: + - The dataset PVM and set local and schema perm + + :param mapper: The SQLA mapper + :param connection: The SQLA connection + :param target: The changed dataset object + :return: + """ + try: + dataset_perm = target.get_perm() + except DatasetInvalidPermissionEvaluationException: + logger.warning("Dataset has no database refusing to set permission") + return + dataset_table = target.__table__ + + self._insert_pvm_on_sqla_event( + mapper, connection, "datasource_access", dataset_perm + ) + if target.perm != dataset_perm: + target.perm = dataset_perm + connection.execute( + dataset_table.update() + .where(dataset_table.c.id == target.id) + .values(perm=dataset_perm) + ) + + if target.schema: + dataset_schema_perm = self.get_schema_perm( + target.database.database_name, target.schema + ) + self._insert_pvm_on_sqla_event( + mapper, connection, "schema_access", dataset_schema_perm + ) + target.schema_perm = dataset_schema_perm + connection.execute( + dataset_table.update() + .where(dataset_table.c.id == target.id) + .values(schema_perm=dataset_schema_perm) + ) + + def dataset_after_delete( + self, + mapper: Mapper, + connection: Connection, + target: "SqlaTable", + ) -> None: + """ + Handles permissions update when a dataset is deleted. + Triggered by a SQLAlchemy after_delete event. + + We need to delete: + - The dataset PVM + + :param mapper: The SQLA mapper + :param connection: The SQLA connection + :param target: The changed dataset object + :return: + """ + dataset_vm_name = self.get_dataset_perm( + target.id, target.table_name, target.database.database_name + ) + self._delete_pvm_on_sqla_event( + mapper, connection, "datasource_access", dataset_vm_name + ) + + def dataset_after_update( + self, + mapper: Mapper, + connection: Connection, + target: "SqlaTable", + ) -> None: + """ + Handles all permissions update when a dataset is changed. + Triggered by a SQLAlchemy after_update event. + + We need to update: + - The dataset PVM and local perm + - All charts local perm related with said datasets + - All charts local schema perm related with said datasets + + :param mapper: The SQLA mapper + :param connection: The SQLA connection + :param target: The changed dataset object + :return: + """ + # Check if watched fields have changed state = inspect(target) - history = state.get_history("database_name", True) - if not history.has_changes() or not history.deleted: + history_database = state.get_history("database_id", True) + history_table_name = state.get_history("table_name", True) + history_schema = state.get_history("schema", True) + + # When database name changes + if history_database.has_changes() and history_database.deleted: + new_dataset_vm_name = self.get_dataset_perm( + target.id, target.table_name, target.database.database_name + ) + self._update_dataset_perm( + mapper, connection, target.perm, new_dataset_vm_name, target + ) + + # Updates schema permissions + new_dataset_schema_name = self.get_schema_perm( + target.database.database_name, target.schema + ) + self._update_dataset_schema_perm( + mapper, + connection, + new_dataset_schema_name, + target, + ) + + # When table name changes + if history_table_name.has_changes() and history_table_name.deleted: + old_dataset_name = history_table_name.deleted[0] + new_dataset_vm_name = self.get_dataset_perm( + target.id, target.table_name, target.database.database_name + ) + old_dataset_vm_name = self.get_dataset_perm( + target.id, old_dataset_name, target.database.database_name + ) + self._update_dataset_perm( + mapper, connection, old_dataset_vm_name, new_dataset_vm_name, target + ) + + # When schema changes + if history_schema.has_changes() and history_schema.deleted: + new_dataset_schema_name = self.get_schema_perm( + target.database.database_name, target.schema + ) + self._update_dataset_schema_perm( + mapper, + connection, + new_dataset_schema_name, + target, + ) + + def _update_dataset_schema_perm( + self, + mapper: Mapper, + connection: Connection, + new_schema_permission_name: Optional[str], + target: "SqlaTable", + ) -> None: + """ + Helper method that is called by SQLAlchemy events on datasets to update + a new schema permission name, propagates the name change to datasets and charts. + + If the schema permission name does not exist already has a PVM, + creates a new one. + + :param mapper: The SQLA event mapper + :param connection: The SQLA connection + :param new_schema_permission_name: The new schema permission name that changed + :param target: Dataset that was updated + :return: + """ + from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel + SqlaTable, + ) + from superset.models.slice import ( # pylint: disable=import-outside-toplevel + Slice, + ) + + sqlatable_table = SqlaTable.__table__ # pylint: disable=no-member + chart_table = Slice.__table__ # pylint: disable=no-member + + # insert new schema PVM if it does not exist + self._insert_pvm_on_sqla_event( + mapper, connection, "schema_access", new_schema_permission_name + ) + + # Update dataset (SqlaTable schema_perm field) + connection.execute( + sqlatable_table.update() + .where( + sqlatable_table.c.id == target.id, + ) + .values(schema_perm=new_schema_permission_name) + ) + + # Update charts (Slice schema_perm field) + connection.execute( + chart_table.update() + .where( + chart_table.c.datasource_id == target.id, + chart_table.c.datasource_type == DatasourceType.TABLE, + ) + .values(schema_perm=new_schema_permission_name) + ) + + def _update_dataset_perm( # pylint: disable=too-many-arguments + self, + mapper: Mapper, + connection: Connection, + old_permission_name: Optional[str], + new_permission_name: Optional[str], + target: "SqlaTable", + ) -> None: + """ + Helper method that is called by SQLAlchemy events on datasets to update + a permission name change, propagates the name change to VM, datasets and charts. + + :param mapper: + :param connection: + :param old_permission_name + :param new_permission_name: + :param target: + :return: + """ + from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel + SqlaTable, + ) + from superset.models.slice import ( # pylint: disable=import-outside-toplevel + Slice, + ) + + view_menu_table = self.viewmenu_model.__table__ # pylint: disable=no-member + sqlatable_table = SqlaTable.__table__ # pylint: disable=no-member + chart_table = Slice.__table__ # pylint: disable=no-member + + new_dataset_view_menu = self.find_view_menu(new_permission_name) + if new_dataset_view_menu: return + # Update VM + connection.execute( + view_menu_table.update() + .where(view_menu_table.c.name == old_permission_name) + .values(name=new_permission_name) + ) + # VM changed, so call hook + new_dataset_view_menu = self.find_view_menu(new_permission_name) + self.on_view_menu_after_update(mapper, connection, new_dataset_view_menu) + # Update dataset (SqlaTable perm field) + connection.execute( + sqlatable_table.update() + .where( + sqlatable_table.c.id == target.id, + ) + .values(perm=new_permission_name) + ) + # Update charts (Slice perm field) + connection.execute( + chart_table.update() + .where( + chart_table.c.datasource_type == DatasourceType.TABLE, + chart_table.c.datasource_id == target.id, + ) + .values(perm=new_permission_name) + ) - old_database_name = history.deleted[0] - # update database access permission - self._update_vm_database_access(mapper, connection, old_database_name, target) - # update datasource access - self._update_vm_datasources_access( - mapper, connection, old_database_name, target + def _delete_pvm_on_sqla_event( # pylint: disable=too-many-arguments + self, + mapper: Mapper, + connection: Connection, + permission_name: Optional[str] = None, + view_menu_name: Optional[str] = None, + pvm: Optional[PermissionView] = None, + ) -> None: + """ + Helper method that is called by SQLAlchemy events. + Deletes a PVM. + + :param mapper: The SQLA event mapper + :param connection: The SQLA connection + :param permission_name: e.g.: datasource_access, schema_access + :param view_menu_name: e.g. [db1].[public] + :param pvm: Can be called with the actual PVM already + :return: + """ + view_menu_table = self.viewmenu_model.__table__ # pylint: disable=no-member + permission_view_menu_table = ( + self.permissionview_model.__table__ # pylint: disable=no-member ) - def on_view_menu_after_update( - self, mapper: Mapper, connection: Connection, target: ViewMenu + if not pvm: + pvm = self.find_permission_view_menu(permission_name, view_menu_name) + if not pvm: + return + # Delete Any Role to PVM association + connection.execute( + assoc_permissionview_role.delete().where( + assoc_permissionview_role.c.permission_view_id == pvm.id + ) + ) + # Delete the database access PVM + connection.execute( + permission_view_menu_table.delete().where( + permission_view_menu_table.c.id == pvm.id + ) + ) + self.on_permission_view_after_delete(mapper, connection, pvm) + connection.execute( + view_menu_table.delete().where(view_menu_table.c.id == pvm.view_menu_id) + ) + + def _insert_pvm_on_sqla_event( + self, + mapper: Mapper, + connection: Connection, + permission_name: str, + view_menu_name: Optional[str], ) -> None: """ - Hook that allows for further custom operations when a new ViewMenu - is updated + Helper method that is called by SQLAlchemy events. + Inserts a new PVM (if it does not exist already) - Since the update may be performed on after_update event. We cannot - update ViewMenus using a session, so any SQLAlchemy events hooked to - `ViewMenu` will not trigger an after_update. + :param mapper: The SQLA event mapper + :param connection: The SQLA connection + :param permission_name: e.g.: datasource_access, schema_access + :param view_menu_name: e.g. [db1].[public] + :return: + """ + permission_table = self.permission_model.__table__ # pylint: disable=no-member + view_menu_table = self.viewmenu_model.__table__ # pylint: disable=no-member + permission_view_table = ( + self.permissionview_model.__table__ # pylint: disable=no-member + ) + if not view_menu_name: + return + pvm = self.find_permission_view_menu(permission_name, view_menu_name) + if pvm: + return + permission = self.find_permission(permission_name) + view_menu = self.find_view_menu(view_menu_name) + if not permission: + connection.execute(permission_table.insert().values(name=permission_name)) + permission = self.find_permission(permission_name) + self.on_permission_after_insert(mapper, connection, permission) + if not view_menu: + connection.execute(view_menu_table.insert().values(name=view_menu_name)) + view_menu = self.find_view_menu(view_menu_name) + self.on_view_menu_after_insert(mapper, connection, view_menu) + connection.execute( + permission_view_table.insert().values( + permission_id=permission.id, view_menu_id=view_menu.id + ) + ) + permission = self.find_permission_view_menu(permission_name, view_menu_name) + self.on_permission_view_after_insert(mapper, connection, permission) + + def on_role_after_update( + self, mapper: Mapper, connection: Connection, target: Role + ) -> None: + """ + Hook that allows for further custom operations when a Role update + is created by SQLAlchemy events. + + On SQLAlchemy after_insert events, we cannot + create new view_menu's using a session, so any SQLAlchemy events hooked to + `ViewMenu` will not trigger an after_insert. :param mapper: The table mapper :param connection: The DB-API connection - :param target: The mapped instance being persisted + :param target: The mapped instance being changed """ - def on_permission_after_delete( - self, mapper: Mapper, connection: Connection, target: Permission + def on_view_menu_after_insert( + self, mapper: Mapper, connection: Connection, target: ViewMenu ) -> None: """ - Hook that allows for further custom operations when a permission - is deleted by sqlalchemy events. + Hook that allows for further custom operations when a new ViewMenu + is created by set_perm. + + On SQLAlchemy after_insert events, we cannot + create new view_menu's using a session, so any SQLAlchemy events hooked to + `ViewMenu` will not trigger an after_insert. :param mapper: The table mapper :param connection: The DB-API connection :param target: The mapped instance being persisted """ - def on_permission_after_insert( - self, mapper: Mapper, connection: Connection, target: Permission + def on_view_menu_after_update( + self, mapper: Mapper, connection: Connection, target: ViewMenu ) -> None: """ - Hook that allows for further custom operations when a new permission - is created by set_perm. + Hook that allows for further custom operations when a new ViewMenu + is updated - Since set_perm is executed by SQLAlchemy after_insert events, we cannot - create new permissions using a session, so any SQLAlchemy events hooked to - `Permission` will not trigger an after_insert. + Since the update may be performed on after_update event. We cannot + update ViewMenus using a session, so any SQLAlchemy events hooked to + `ViewMenu` will not trigger an after_update. :param mapper: The table mapper :param connection: The DB-API connection :param target: The mapped instance being persisted """ - def on_view_menu_after_insert( - self, mapper: Mapper, connection: Connection, target: ViewMenu + def on_permission_after_insert( + self, mapper: Mapper, connection: Connection, target: Permission ) -> None: """ - Hook that allows for further custom operations when a new ViewMenu + Hook that allows for further custom operations when a new permission is created by set_perm. Since set_perm is executed by SQLAlchemy after_insert events, we cannot - create new view_menu's using a session, so any SQLAlchemy events hooked to - `ViewMenu` will not trigger an after_insert. + create new permissions using a session, so any SQLAlchemy events hooked to + `Permission` will not trigger an after_insert. :param mapper: The table mapper :param connection: The DB-API connection @@ -1194,9 +1595,9 @@ def on_permission_view_after_insert( ) -> None: """ Hook that allows for further custom operations when a new PermissionView - is created by set_perm. + is created by SQLAlchemy events. - Since set_perm is executed by SQLAlchemy after_insert events, we cannot + On SQLAlchemy after_insert events, we cannot create new pvms using a session, so any SQLAlchemy events hooked to `PermissionView` will not trigger an after_insert. @@ -1205,98 +1606,21 @@ def on_permission_view_after_insert( :param target: The mapped instance being persisted """ - def set_perm( - self, mapper: Mapper, connection: Connection, target: "BaseDatasource" + def on_permission_view_after_delete( + self, mapper: Mapper, connection: Connection, target: PermissionView ) -> None: """ - Set the datasource permissions. + Hook that allows for further custom operations when a new PermissionView + is delete by SQLAlchemy events. + + On SQLAlchemy after_delete events, we cannot + delete pvms using a session, so any SQLAlchemy events hooked to + `PermissionView` will not trigger an after_delete. :param mapper: The table mapper :param connection: The DB-API connection :param target: The mapped instance being persisted """ - try: - target_get_perm = target.get_perm() - except DatasetInvalidPermissionEvaluationException: - logger.warning("Dataset has no database refusing to set permission") - return - permission_table = self.permission_model.__table__ # pylint: disable=no-member - view_menu_table = self.viewmenu_model.__table__ # pylint: disable=no-member - link_table = target.__table__ - if target.perm != target_get_perm: - connection.execute( - link_table.update() - .where(link_table.c.id == target.id) - .values(perm=target_get_perm) - ) - connection.execute( - permission_table.update() - .where(permission_table.c.name == target.perm) - .values(name=target_get_perm) - ) - connection.execute( - view_menu_table.update() - .where(view_menu_table.c.name == target.perm) - .values(name=target_get_perm) - ) - target.perm = target_get_perm - - # check schema perm for datasets - if ( - hasattr(target, "schema_perm") - and target.schema_perm != target.get_schema_perm() - ): - connection.execute( - link_table.update() - .where(link_table.c.id == target.id) - .values(schema_perm=target.get_schema_perm()) - ) - target.schema_perm = target.get_schema_perm() - - pvm_names = [] - if target.__tablename__ in {"dbs", "clusters"}: - pvm_names.append(("database_access", target_get_perm)) - else: - pvm_names.append(("datasource_access", target_get_perm)) - if target.schema: - pvm_names.append(("schema_access", target.get_schema_perm())) - - # TODO(bogdan): modify slice permissions as well. - for permission_name, view_menu_name in pvm_names: - permission = self.find_permission(permission_name) - view_menu = self.find_view_menu(view_menu_name) - pv = None - - if not permission: - connection.execute( - permission_table.insert().values(name=permission_name) - ) - permission = self.find_permission(permission_name) - self.on_permission_after_insert(mapper, connection, permission) - if not view_menu: - connection.execute(view_menu_table.insert().values(name=view_menu_name)) - view_menu = self.find_view_menu(view_menu_name) - self.on_view_menu_after_insert(mapper, connection, view_menu) - - if permission and view_menu: - pv = ( - self.get_session.query(self.permissionview_model) - .filter_by(permission=permission, view_menu=view_menu) - .first() - ) - if not pv and permission and view_menu: - permission_view_table = ( - self.permissionview_model.__table__ # pylint: disable=no-member - ) - connection.execute( - permission_view_table.insert().values( - permission_id=permission.id, view_menu_id=view_menu.id - ) - ) - permission = self.find_permission_view_menu( - permission_name, view_menu_name - ) - self.on_permission_view_after_insert(mapper, connection, permission) def raise_for_access( # pylint: disable=too-many-arguments,too-many-locals diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index d410797c52bc..019f07027fc5 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -260,6 +260,15 @@ def test_get_dataset_related_database_gamma(self): if backend() == "sqlite": return + # Add main database access to gamma role + main_db = get_main_database() + main_db_pvm = security_manager.find_permission_view_menu( + "database_access", main_db.perm + ) + gamma_role = security_manager.find_role("Gamma") + gamma_role.permissions.append(main_db_pvm) + db.session.commit() + self.login(username="gamma") uri = "api/v1/dataset/related/database" rv = self.client.get(uri) @@ -270,6 +279,10 @@ def test_get_dataset_related_database_gamma(self): main_db = get_main_database() assert filter(lambda x: x.text == main_db, response["result"]) != [] + # revert gamma permission + gamma_role.permissions.remove(main_db_pvm) + db.session.commit() + @pytest.mark.usefixtures("load_energy_table_with_slice") def test_get_dataset_item(self): """ diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index ebb1e65e36f4..26d7c6e772ab 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -28,10 +28,9 @@ import pytest from flask import current_app +from flask_appbuilder.security.sqla.models import Role from superset.datasource.dao import DatasourceDAO - from superset.models.dashboard import Dashboard - from superset import app, appbuilder, db, security_manager, viz from superset.connectors.sqla.models import SqlaTable from superset.errors import ErrorLevel, SupersetError, SupersetErrorType @@ -156,42 +155,43 @@ def tearDown(self): session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE)) session.commit() - def test_set_perm_sqla_table(self): + def test_after_insert_dataset(self): security_manager.on_view_menu_after_insert = Mock() security_manager.on_permission_view_after_insert = Mock() session = db.session + tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") + session.add(tmp_db1) + table = SqlaTable( schema="tmp_schema", table_name="tmp_perm_table", - database=get_example_database(), + database=tmp_db1, ) session.add(table) session.commit() - stored_table = ( - session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one() - ) - self.assertEqual( - stored_table.perm, f"[examples].[tmp_perm_table](id:{stored_table.id})" - ) + table = session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one() + self.assertEqual(table.perm, f"[tmp_db1].[tmp_perm_table](id:{table.id})") pvm_dataset = security_manager.find_permission_view_menu( - "datasource_access", stored_table.perm + "datasource_access", table.perm ) pvm_schema = security_manager.find_permission_view_menu( - "schema_access", stored_table.schema_perm + "schema_access", table.schema_perm ) + # Assert dataset permission is created and local perms are ok self.assertIsNotNone(pvm_dataset) - self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema]") + self.assertEqual(table.perm, f"[tmp_db1].[tmp_perm_table](id:{table.id})") + self.assertEqual(table.schema_perm, "[tmp_db1].[tmp_schema]") self.assertIsNotNone(pvm_schema) # assert on permission hooks view_menu_dataset = security_manager.find_view_menu( - f"[examples].[tmp_perm_table](id:{stored_table.id})" + f"[tmp_db1].[tmp_perm_table](id:{table.id})" ) - view_menu_schema = security_manager.find_view_menu(f"[examples].[tmp_schema]") + view_menu_schema = security_manager.find_view_menu(f"[tmp_db1].[tmp_schema]") security_manager.on_view_menu_after_insert.assert_has_calls( [ call(ANY, ANY, view_menu_dataset), @@ -205,103 +205,43 @@ def test_set_perm_sqla_table(self): ] ) - # table name change - orig_table_perm = stored_table.perm - stored_table.table_name = "tmp_perm_table_v2" + # Cleanup + session.delete(table) + session.delete(tmp_db1) session.commit() - stored_table = ( - session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() - ) - self.assertEqual( - stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})" - ) - self.assertIsNone( - security_manager.find_permission_view_menu( - "datasource_access", orig_table_perm - ) - ) - self.assertIsNotNone( - security_manager.find_permission_view_menu( - "datasource_access", stored_table.perm - ) - ) - # no changes in schema - self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema]") - self.assertIsNotNone( - security_manager.find_permission_view_menu( - "schema_access", stored_table.schema_perm - ) - ) - # schema name change - stored_table.schema = "tmp_schema_v2" + def test_after_insert_dataset_rollback(self): + session = db.session + tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") + session.add(tmp_db1) session.commit() - stored_table = ( - session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() - ) - self.assertEqual( - stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})" - ) - self.assertIsNotNone( - security_manager.find_permission_view_menu( - "datasource_access", stored_table.perm - ) - ) - # no changes in schema - self.assertEqual(stored_table.schema_perm, "[examples].[tmp_schema_v2]") - self.assertIsNotNone( - security_manager.find_permission_view_menu( - "schema_access", stored_table.schema_perm - ) - ) - # database change - new_db = Database(sqlalchemy_uri="sqlite://", database_name="tmp_db") - session.add(new_db) - stored_table.database = ( - session.query(Database).filter_by(database_name="tmp_db").one() - ) - session.commit() - stored_table = ( - session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() - ) - self.assertEqual( - stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})" - ) - self.assertIsNotNone( - security_manager.find_permission_view_menu( - "datasource_access", stored_table.perm - ) - ) - # no changes in schema - self.assertEqual(stored_table.schema_perm, "[tmp_db].[tmp_schema_v2]") - self.assertIsNotNone( - security_manager.find_permission_view_menu( - "schema_access", stored_table.schema_perm - ) + table = SqlaTable( + schema="tmp_schema", + table_name="tmp_table", + database=tmp_db1, ) + session.add(table) + session.flush() - # no schema - stored_table.schema = None - session.commit() - stored_table = ( - session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() - ) - self.assertEqual( - stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})" + pvm_dataset = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db1].[tmp_table](id:{table.id})" ) - self.assertIsNotNone( - security_manager.find_permission_view_menu( - "datasource_access", stored_table.perm - ) + self.assertIsNotNone(pvm_dataset) + table_id = table.id + session.rollback() + + table = session.query(SqlaTable).filter_by(table_name="tmp_table").one_or_none() + self.assertIsNone(table) + pvm_dataset = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db1].[tmp_table](id:{table_id})" ) - self.assertIsNone(stored_table.schema_perm) + self.assertIsNone(pvm_dataset) - session.delete(new_db) - session.delete(stored_table) + session.delete(tmp_db1) session.commit() - def test_set_perm_sqla_table_none(self): + def test_after_insert_dataset_table_none(self): session = db.session table = SqlaTable( schema="tmp_schema", @@ -327,126 +267,197 @@ def test_set_perm_sqla_table_none(self): "datasource_access", f"[None].[tmp_perm_table](id:{stored_table.id})" ) ) + + # Cleanup session.delete(table) session.commit() - def test_set_perm_database(self): + def test_after_insert_database(self): + security_manager.on_permission_view_after_insert = Mock() + session = db.session - database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://") - session.add(database) + tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") + session.add(tmp_db1) + + tmp_db1 = session.query(Database).filter_by(database_name="tmp_db1").one() + self.assertEqual(tmp_db1.perm, f"[tmp_db1].(id:{tmp_db1.id})") + tmp_db1_pvm = security_manager.find_permission_view_menu( + "database_access", tmp_db1.perm + ) + self.assertIsNotNone(tmp_db1_pvm) + + # Assert the hook is called + security_manager.on_permission_view_after_insert.assert_has_calls( + [ + call(ANY, ANY, tmp_db1_pvm), + ] + ) + session.delete(tmp_db1) + session.commit() - stored_db = ( - session.query(Database).filter_by(database_name="tmp_database").one() + def test_after_insert_database_rollback(self): + session = db.session + tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") + session.add(tmp_db1) + session.flush() + + pvm_database = security_manager.find_permission_view_menu( + "database_access", f"[tmp_db1].(id:{tmp_db1.id})" ) - self.assertEqual(stored_db.perm, f"[tmp_database].(id:{stored_db.id})") + self.assertIsNotNone(pvm_database) + session.rollback() + + pvm_database = security_manager.find_permission_view_menu( + "database_access", f"[tmp_db1](id:{tmp_db1.id})" + ) + self.assertIsNone(pvm_database) + + def test_after_update_database__perm_database_access(self): + security_manager.on_view_menu_after_update = Mock() + + session = db.session + tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") + session.add(tmp_db1) + session.commit() + tmp_db1 = session.query(Database).filter_by(database_name="tmp_db1").one() + self.assertIsNotNone( - security_manager.find_permission_view_menu( - "database_access", stored_db.perm - ) + security_manager.find_permission_view_menu("database_access", tmp_db1.perm) ) - stored_db.database_name = "tmp_database2" + tmp_db1.database_name = "tmp_db2" session.commit() - stored_db = ( - session.query(Database).filter_by(database_name="tmp_database2").one() + + # Assert that the old permission was updated + self.assertIsNone( + security_manager.find_permission_view_menu( + "database_access", f"[tmp_db1].(id:{tmp_db1.id})" + ) ) - self.assertEqual(stored_db.perm, f"[tmp_database2].(id:{stored_db.id})") + # Assert that the db permission was updated self.assertIsNotNone( security_manager.find_permission_view_menu( - "database_access", stored_db.perm + "database_access", f"[tmp_db2].(id:{tmp_db1.id})" ) ) - session.delete(stored_db) + # Assert the hook is called + tmp_db1_view_menu = security_manager.find_view_menu( + f"[tmp_db2].(id:{tmp_db1.id})" + ) + security_manager.on_view_menu_after_update.assert_has_calls( + [ + call(ANY, ANY, tmp_db1_view_menu), + ] + ) + + session.delete(tmp_db1) session.commit() - def test_after_update_database__perm_database_access(self): + def test_after_update_database_rollback(self): session = db.session - database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://") - session.add(database) + tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") + session.add(tmp_db1) session.commit() - stored_db = ( - session.query(Database).filter_by(database_name="tmp_database").one() - ) + tmp_db1 = session.query(Database).filter_by(database_name="tmp_db1").one() self.assertIsNotNone( - security_manager.find_permission_view_menu( - "database_access", stored_db.perm - ) + security_manager.find_permission_view_menu("database_access", tmp_db1.perm) ) - stored_db.database_name = "tmp_database2" - session.commit() + tmp_db1.database_name = "tmp_db2" + session.flush() # Assert that the old permission was updated self.assertIsNone( security_manager.find_permission_view_menu( - "database_access", f"[tmp_database].(id:{stored_db.id})" + "database_access", f"[tmp_db1].(id:{tmp_db1.id})" ) ) # Assert that the db permission was updated self.assertIsNotNone( security_manager.find_permission_view_menu( - "database_access", f"[tmp_database2].(id:{stored_db.id})" + "database_access", f"[tmp_db2].(id:{tmp_db1.id})" ) ) - session.delete(stored_db) + + session.rollback() + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "database_access", f"[tmp_db1].(id:{tmp_db1.id})" + ) + ) + # Assert that the db permission was updated + self.assertIsNone( + security_manager.find_permission_view_menu( + "database_access", f"[tmp_db2].(id:{tmp_db1.id})" + ) + ) + + session.delete(tmp_db1) session.commit() def test_after_update_database__perm_database_access_exists(self): + security_manager.on_permission_view_after_delete = Mock() + session = db.session # Add a bogus existing permission before the change - database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://") - session.add(database) + tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") + session.add(tmp_db1) session.commit() - stored_db = ( - session.query(Database).filter_by(database_name="tmp_database").one() - ) + tmp_db1 = session.query(Database).filter_by(database_name="tmp_db1").one() security_manager.add_permission_view_menu( - "database_access", f"[tmp_database2].(id:{stored_db.id})" + "database_access", f"[tmp_db2].(id:{tmp_db1.id})" ) self.assertIsNotNone( - security_manager.find_permission_view_menu( - "database_access", stored_db.perm - ) + security_manager.find_permission_view_menu("database_access", tmp_db1.perm) ) - stored_db.database_name = "tmp_database2" + tmp_db1.database_name = "tmp_db2" session.commit() # Assert that the old permission was updated self.assertIsNone( security_manager.find_permission_view_menu( - "database_access", f"[tmp_database].(id:{stored_db.id})" + "database_access", f"[tmp_db1].(id:{tmp_db1.id})" ) ) # Assert that the db permission was updated self.assertIsNotNone( security_manager.find_permission_view_menu( - "database_access", f"[tmp_database2].(id:{stored_db.id})" + "database_access", f"[tmp_db2].(id:{tmp_db1.id})" ) ) - session.delete(stored_db) + + security_manager.on_permission_view_after_delete.assert_has_calls( + [ + call(ANY, ANY, ANY), + ] + ) + + session.delete(tmp_db1) session.commit() def test_after_update_database__perm_datasource_access(self): + security_manager.on_view_menu_after_update = Mock() + session = db.session - database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://") - session.add(database) + tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") + session.add(tmp_db1) session.commit() table1 = SqlaTable( schema="tmp_schema", table_name="tmp_table1", - database=database, + database=tmp_db1, ) session.add(table1) table2 = SqlaTable( schema="tmp_schema", table_name="tmp_table2", - database=database, + database=tmp_db1, ) session.add(table2) session.commit() @@ -465,81 +476,633 @@ def test_after_update_database__perm_datasource_access(self): # assert initial perms self.assertIsNotNone( security_manager.find_permission_view_menu( - "datasource_access", f"[tmp_database].[tmp_table1](id:{table1.id})" + "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" ) ) self.assertIsNotNone( security_manager.find_permission_view_menu( - "datasource_access", f"[tmp_database].[tmp_table2](id:{table2.id})" + "datasource_access", f"[tmp_db1].[tmp_table2](id:{table2.id})" ) ) - self.assertEqual(slice1.perm, f"[tmp_database].[tmp_table1](id:{table1.id})") - self.assertEqual(table1.perm, f"[tmp_database].[tmp_table1](id:{table1.id})") - self.assertEqual(table2.perm, f"[tmp_database].[tmp_table2](id:{table2.id})") + self.assertEqual(slice1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") + self.assertEqual(table1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") + self.assertEqual(table2.perm, f"[tmp_db1].[tmp_table2](id:{table2.id})") - stored_db = ( - session.query(Database).filter_by(database_name="tmp_database").one() - ) - stored_db.database_name = "tmp_database2" + # Refresh and update the database name + tmp_db1 = session.query(Database).filter_by(database_name="tmp_db1").one() + tmp_db1.database_name = "tmp_db2" session.commit() # Assert that the old permissions were updated self.assertIsNone( security_manager.find_permission_view_menu( - "datasource_access", f"[tmp_database].[tmp_table1](id:{table1.id})" + "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" ) ) self.assertIsNone( security_manager.find_permission_view_menu( - "datasource_access", f"[tmp_database].[tmp_table2](id:{table2.id})" + "datasource_access", f"[tmp_db1].[tmp_table2](id:{table2.id})" ) ) # Assert that the db permission was updated self.assertIsNotNone( security_manager.find_permission_view_menu( - "datasource_access", f"[tmp_database2].[tmp_table1](id:{table1.id})" + "datasource_access", f"[tmp_db2].[tmp_table1](id:{table1.id})" ) ) self.assertIsNotNone( security_manager.find_permission_view_menu( - "datasource_access", f"[tmp_database2].[tmp_table2](id:{table2.id})" + "datasource_access", f"[tmp_db2].[tmp_table2](id:{table2.id})" ) ) - self.assertEqual(slice1.perm, f"[tmp_database2].[tmp_table1](id:{table1.id})") - self.assertEqual(table1.perm, f"[tmp_database2].[tmp_table1](id:{table1.id})") - self.assertEqual(table2.perm, f"[tmp_database2].[tmp_table2](id:{table2.id})") + self.assertEqual(slice1.perm, f"[tmp_db2].[tmp_table1](id:{table1.id})") + self.assertEqual(table1.perm, f"[tmp_db2].[tmp_table1](id:{table1.id})") + self.assertEqual(table2.perm, f"[tmp_db2].[tmp_table2](id:{table2.id})") + + # Assert hooks are called + tmp_db1_view_menu = security_manager.find_view_menu( + f"[tmp_db2].(id:{tmp_db1.id})" + ) + table1_view_menu = security_manager.find_view_menu( + f"[tmp_db2].[tmp_table1](id:{table1.id})" + ) + table2_view_menu = security_manager.find_view_menu( + f"[tmp_db2].[tmp_table2](id:{table2.id})" + ) + security_manager.on_view_menu_after_update.assert_has_calls( + [ + call(ANY, ANY, tmp_db1_view_menu), + call(ANY, ANY, table1_view_menu), + call(ANY, ANY, table2_view_menu), + ] + ) session.delete(slice1) session.delete(table1) session.delete(table2) - session.delete(stored_db) + session.delete(tmp_db1) session.commit() - def test_after_delete_database__perm_database_access(self): + def test_after_delete_database(self): session = db.session - database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://") - session.add(database) + tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") + session.add(tmp_db1) session.commit() - stored_db = ( - session.query(Database).filter_by(database_name="tmp_database").one() + tmp_db1 = session.query(Database).filter_by(database_name="tmp_db1").one() + + database_pvm = security_manager.find_permission_view_menu( + "database_access", tmp_db1.perm ) + self.assertIsNotNone(database_pvm) + role1 = Role(name="tmp_role1") + role1.permissions.append(database_pvm) + session.add(role1) + session.commit() - self.assertIsNotNone( + session.delete(tmp_db1) + session.commit() + + # Assert that PVM is removed from Role + role1 = security_manager.find_role("tmp_role1") + self.assertEqual(role1.permissions, []) + + # Assert that the old permission was updated + self.assertIsNone( security_manager.find_permission_view_menu( - "database_access", stored_db.perm + "database_access", f"[tmp_db1].(id:{tmp_db1.id})" ) ) - session.delete(stored_db) + + # Cleanup + session.delete(role1) + session.commit() + + def test_after_delete_database_rollback(self): + session = db.session + tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") + session.add(tmp_db1) + session.commit() + tmp_db1 = session.query(Database).filter_by(database_name="tmp_db1").one() + + database_pvm = security_manager.find_permission_view_menu( + "database_access", tmp_db1.perm + ) + self.assertIsNotNone(database_pvm) + role1 = Role(name="tmp_role1") + role1.permissions.append(database_pvm) + session.add(role1) session.commit() - # Assert that the old permission was updated + session.delete(tmp_db1) + session.flush() + + role1 = security_manager.find_role("tmp_role1") + self.assertEqual(role1.permissions, []) + self.assertIsNone( security_manager.find_permission_view_menu( - "database_access", f"[tmp_database].(id:{stored_db.id})" + "database_access", f"[tmp_db1].(id:{tmp_db1.id})" ) ) + session.rollback() + + # Test a rollback reverts everything + database_pvm = security_manager.find_permission_view_menu( + "database_access", f"[tmp_db1].(id:{tmp_db1.id})" + ) + + role1 = security_manager.find_role("tmp_role1") + self.assertEqual(role1.permissions, [database_pvm]) + + # Cleanup + session.delete(role1) + session.delete(tmp_db1) + session.commit() + + def test_after_delete_dataset(self): + security_manager.on_permission_view_after_delete = Mock() + + session = db.session + tmp_db = Database(database_name="tmp_db", sqlalchemy_uri="sqlite://") + session.add(tmp_db) + session.commit() + + table1 = SqlaTable( + schema="tmp_schema", + table_name="tmp_table1", + database=tmp_db, + ) + session.add(table1) + session.commit() + + table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" + ) + self.assertIsNotNone(table1_pvm) + + role1 = Role(name="tmp_role1") + role1.permissions.append(table1_pvm) + session.add(role1) + session.commit() + + # refresh + table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + + # Test delete + session.delete(table1) + session.commit() + + role1 = security_manager.find_role("tmp_role1") + self.assertEqual(role1.permissions, []) + + table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" + ) + self.assertIsNone(table1_pvm) + table1_view_menu = security_manager.find_view_menu( + f"[tmp_db].[tmp_table1](id:{table1.id})" + ) + self.assertIsNone(table1_view_menu) + + # Assert the hook is called + security_manager.on_permission_view_after_delete.assert_has_calls( + [ + call(ANY, ANY, ANY), + ] + ) + + # cleanup + session.delete(role1) + session.delete(tmp_db) + session.commit() + + def test_after_delete_dataset_rollback(self): + session = db.session + tmp_db = Database(database_name="tmp_db", sqlalchemy_uri="sqlite://") + session.add(tmp_db) + session.commit() + + table1 = SqlaTable( + schema="tmp_schema", + table_name="tmp_table1", + database=tmp_db, + ) + session.add(table1) + session.commit() + + table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" + ) + self.assertIsNotNone(table1_pvm) + + role1 = Role(name="tmp_role1") + role1.permissions.append(table1_pvm) + session.add(role1) + session.commit() + + # refresh + table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + + # Test delete, permissions are correctly deleted + session.delete(table1) + session.flush() + + role1 = security_manager.find_role("tmp_role1") + self.assertEqual(role1.permissions, []) + + table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" + ) + self.assertIsNone(table1_pvm) + + # Test rollback, permissions exist everything is correctly rollback + session.rollback() + role1 = security_manager.find_role("tmp_role1") + table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" + ) + self.assertIsNotNone(table1_pvm) + self.assertEqual(role1.permissions, [table1_pvm]) + + # cleanup + session.delete(table1) + session.delete(role1) + session.delete(tmp_db) + session.commit() + + def test_after_update_dataset__name_changes(self): + security_manager.on_view_menu_after_update = Mock() + + session = db.session + tmp_db = Database(database_name="tmp_db", sqlalchemy_uri="sqlite://") + session.add(tmp_db) + session.commit() + + table1 = SqlaTable( + schema="tmp_schema", + table_name="tmp_table1", + database=tmp_db, + ) + session.add(table1) + session.commit() + + slice1 = Slice( + datasource_id=table1.id, + datasource_type=DatasourceType.TABLE, + datasource_name="tmp_table1", + slice_name="tmp_slice1", + ) + session.add(slice1) + session.commit() + + table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" + ) + self.assertIsNotNone(table1_pvm) + + # refresh + table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + # Test update + table1.table_name = "tmp_table1_changed" + session.commit() + + # Test old permission does not exist + old_table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" + ) + self.assertIsNone(old_table1_pvm) + + # Test new permission exist + new_table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db].[tmp_table1_changed](id:{table1.id})" + ) + self.assertIsNotNone(new_table1_pvm) + + # test dataset permission changed + changed_table1 = ( + session.query(SqlaTable).filter_by(table_name="tmp_table1_changed").one() + ) + self.assertEqual( + changed_table1.perm, f"[tmp_db].[tmp_table1_changed](id:{table1.id})" + ) + + # Test Chart permission changed + slice1 = session.query(Slice).filter_by(slice_name="tmp_slice1").one() + self.assertEqual(slice1.perm, f"[tmp_db].[tmp_table1_changed](id:{table1.id})") + + # Assert hook is called + view_menu_dataset = security_manager.find_view_menu( + f"[tmp_db].[tmp_table1_changed](id:{table1.id})" + ) + security_manager.on_view_menu_after_update.assert_has_calls( + [ + call(ANY, ANY, view_menu_dataset), + ] + ) + # cleanup + session.delete(slice1) + session.delete(table1) + session.delete(tmp_db) + session.commit() + + def test_after_update_dataset_rollback(self): + session = db.session + tmp_db = Database(database_name="tmp_db", sqlalchemy_uri="sqlite://") + session.add(tmp_db) + session.commit() + + table1 = SqlaTable( + schema="tmp_schema", + table_name="tmp_table1", + database=tmp_db, + ) + session.add(table1) + session.commit() + + slice1 = Slice( + datasource_id=table1.id, + datasource_type=DatasourceType.TABLE, + datasource_name="tmp_table1", + slice_name="tmp_slice1", + ) + session.add(slice1) + session.commit() + + # refresh + table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + # Test update + table1.table_name = "tmp_table1_changed" + session.flush() + + # Test old permission does not exist + old_table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" + ) + self.assertIsNone(old_table1_pvm) + + # Test new permission exist + new_table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db].[tmp_table1_changed](id:{table1.id})" + ) + self.assertIsNotNone(new_table1_pvm) + + # Test rollback + session.rollback() + + old_table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" + ) + self.assertIsNotNone(old_table1_pvm) + + # cleanup + session.delete(slice1) + session.delete(table1) + session.delete(tmp_db) + session.commit() + + def test_after_update_dataset__db_changes(self): + session = db.session + tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") + tmp_db2 = Database(database_name="tmp_db2", sqlalchemy_uri="sqlite://") + session.add(tmp_db1) + session.add(tmp_db2) + session.commit() + + table1 = SqlaTable( + schema="tmp_schema", + table_name="tmp_table1", + database=tmp_db1, + ) + session.add(table1) + session.commit() + + slice1 = Slice( + datasource_id=table1.id, + datasource_type=DatasourceType.TABLE, + datasource_name="tmp_table1", + slice_name="tmp_slice1", + ) + session.add(slice1) + session.commit() + + table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" + ) + self.assertIsNotNone(table1_pvm) + + # refresh + table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + # Test update + table1.database = tmp_db2 + session.commit() + + # Test old permission does not exist + table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" + ) + self.assertIsNone(table1_pvm) + + # Test new permission exist + table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db2].[tmp_table1](id:{table1.id})" + ) + self.assertIsNotNone(table1_pvm) + + # test dataset permission and schema permission changed + changed_table1 = ( + session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + ) + self.assertEqual(changed_table1.perm, f"[tmp_db2].[tmp_table1](id:{table1.id})") + self.assertEqual(changed_table1.schema_perm, f"[tmp_db2].[tmp_schema]") + + # Test Chart permission changed + slice1 = session.query(Slice).filter_by(slice_name="tmp_slice1").one() + self.assertEqual(slice1.perm, f"[tmp_db2].[tmp_table1](id:{table1.id})") + self.assertEqual(slice1.schema_perm, f"[tmp_db2].[tmp_schema]") + + # cleanup + session.delete(slice1) + session.delete(table1) + session.delete(tmp_db1) + session.delete(tmp_db2) + session.commit() + + def test_after_update_dataset__schema_changes(self): + session = db.session + tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") + session.add(tmp_db1) + session.commit() + + table1 = SqlaTable( + schema="tmp_schema", + table_name="tmp_table1", + database=tmp_db1, + ) + session.add(table1) + session.commit() + + slice1 = Slice( + datasource_id=table1.id, + datasource_type=DatasourceType.TABLE, + datasource_name="tmp_table1", + slice_name="tmp_slice1", + ) + session.add(slice1) + session.commit() + + table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" + ) + self.assertIsNotNone(table1_pvm) + + # refresh + table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + # Test update + table1.schema = "tmp_schema_changed" + session.commit() + + # Test permission still exists + table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" + ) + self.assertIsNotNone(table1_pvm) + + # test dataset schema permission changed + changed_table1 = ( + session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + ) + self.assertEqual(changed_table1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") + self.assertEqual(changed_table1.schema_perm, f"[tmp_db1].[tmp_schema_changed]") + + # Test Chart schema permission changed + slice1 = session.query(Slice).filter_by(slice_name="tmp_slice1").one() + self.assertEqual(slice1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") + self.assertEqual(slice1.schema_perm, f"[tmp_db1].[tmp_schema_changed]") + + # cleanup + session.delete(slice1) + session.delete(table1) + session.delete(tmp_db1) + session.commit() + + def test_after_update_dataset__schema_none(self): + session = db.session + tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") + session.add(tmp_db1) + session.commit() + + table1 = SqlaTable( + schema="tmp_schema", + table_name="tmp_table1", + database=tmp_db1, + ) + session.add(table1) + session.commit() + + slice1 = Slice( + datasource_id=table1.id, + datasource_type=DatasourceType.TABLE, + datasource_name="tmp_table1", + slice_name="tmp_slice1", + ) + session.add(slice1) + session.commit() + + table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" + ) + self.assertIsNotNone(table1_pvm) + + # refresh + table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + # Test update + table1.schema = None + session.commit() + + # refresh + table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + + self.assertEqual(table1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") + self.assertIsNone(table1.schema_perm) + + # cleanup + session.delete(slice1) + session.delete(table1) + session.delete(tmp_db1) + session.commit() + + def test_after_update_dataset__name_db_changes(self): + session = db.session + tmp_db1 = Database(database_name="tmp_db1", sqlalchemy_uri="sqlite://") + tmp_db2 = Database(database_name="tmp_db2", sqlalchemy_uri="sqlite://") + session.add(tmp_db1) + session.add(tmp_db2) + session.commit() + + table1 = SqlaTable( + schema="tmp_schema", + table_name="tmp_table1", + database=tmp_db1, + ) + session.add(table1) + session.commit() + + slice1 = Slice( + datasource_id=table1.id, + datasource_type=DatasourceType.TABLE, + datasource_name="tmp_table1", + slice_name="tmp_slice1", + ) + session.add(slice1) + session.commit() + + table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" + ) + self.assertIsNotNone(table1_pvm) + + # refresh + table1 = session.query(SqlaTable).filter_by(table_name="tmp_table1").one() + # Test update + table1.table_name = "tmp_table1_changed" + table1.database = tmp_db2 + session.commit() + + # Test old permission does not exist + table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" + ) + self.assertIsNone(table1_pvm) + + # Test new permission exist + table1_pvm = security_manager.find_permission_view_menu( + "datasource_access", f"[tmp_db2].[tmp_table1_changed](id:{table1.id})" + ) + self.assertIsNotNone(table1_pvm) + + # test dataset permission and schema permission changed + changed_table1 = ( + session.query(SqlaTable).filter_by(table_name="tmp_table1_changed").one() + ) + self.assertEqual( + changed_table1.perm, f"[tmp_db2].[tmp_table1_changed](id:{table1.id})" + ) + self.assertEqual(changed_table1.schema_perm, f"[tmp_db2].[tmp_schema]") + + # Test Chart permission changed + slice1 = session.query(Slice).filter_by(slice_name="tmp_slice1").one() + self.assertEqual(slice1.perm, f"[tmp_db2].[tmp_table1_changed](id:{table1.id})") + self.assertEqual(slice1.schema_perm, f"[tmp_db2].[tmp_schema]") + + # cleanup + session.delete(slice1) + session.delete(table1) + session.delete(tmp_db1) + session.delete(tmp_db2) + session.commit() + def test_hybrid_perm_database(self): database = Database(database_name="tmp_database3", sqlalchemy_uri="sqlite://") @@ -590,23 +1153,14 @@ def test_set_perm_slice(self): table.schema = "tmp_perm_schema" table.table_name = "tmp_perm_table_v2" session.commit() - # TODO(bogdan): modify slice permissions on the table update. - self.assertNotEqual(slice.perm, table.perm) - self.assertEqual(slice.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})") - self.assertEqual( - table.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})" - ) - # TODO(bogdan): modify slice schema permissions on the table update. - self.assertNotEqual(slice.schema_perm, table.schema_perm) - self.assertIsNone(slice.schema_perm) - - # updating slice refreshes the permissions - slice.slice_name = "slice_name_v2" - session.commit() + table = session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() self.assertEqual(slice.perm, table.perm) self.assertEqual( slice.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})" ) + self.assertEqual( + table.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})" + ) self.assertEqual(slice.schema_perm, table.schema_perm) self.assertEqual(slice.schema_perm, "[tmp_database].[tmp_perm_schema]") @@ -616,8 +1170,6 @@ def test_set_perm_slice(self): session.commit() - # TODO test slice permission - @patch("superset.utils.core.g") @patch("superset.security.manager.g") def test_schemas_accessible_by_user_admin(self, mock_sm_g, mock_g): diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index a9c645d082cd..935e128dc9b9 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -53,6 +53,9 @@ def get_session(): get_session.return_value = in_memory_session # FAB calls get_session.get_bind() to get a handler to the engine get_session.get_bind.return_value = engine + # Allow for queries on security manager + get_session.query = in_memory_session.query + mocker.patch("superset.db.session", in_memory_session) return in_memory_session