Skip to content

Commit

Permalink
Add access control check while listing views
Browse files Browse the repository at this point in the history
  • Loading branch information
skrzypo987 authored and kokosing committed Apr 1, 2020
1 parent 0c13d16 commit f922c85
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 4 deletions.
Expand Up @@ -4733,6 +4733,37 @@ public void testShowColumnMetadata()
assertUpdate("DROP TABLE " + tableName);
}

@Test
public void testShowViews()
{
String viewName = "test_show_views";

Session testSession = testSessionBuilder()
.setIdentity(Identity.ofUser("test_view_access_owner"))
.setCatalog(getSession().getCatalog().get())
.setSchema(getSession().getSchema().get())
.build();

assertUpdate("CREATE VIEW " + viewName + " AS SELECT abs(1) as whatever");

String showViews = format("SELECT * FROM information_schema.views WHERE table_name = '%s'", viewName);
assertQuery(
format("SELECT table_name FROM information_schema.views WHERE table_name = '%s'", viewName),
format("VALUES '%s'", viewName));

executeExclusively(() -> {
try {
getQueryRunner().getAccessControl().denyTables(table -> false);
assertQueryReturnsEmptyResult(testSession, showViews);
}
finally {
getQueryRunner().getAccessControl().reset();
}
});

assertUpdate("DROP VIEW " + viewName);
}

@Test
public void testShowTablePrivileges()
{
Expand Down
Expand Up @@ -17,7 +17,6 @@
import com.google.common.collect.ImmutableList;
import io.prestosql.Session;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.QualifiedObjectName;
import io.prestosql.metadata.QualifiedTablePrefix;
import io.prestosql.security.AccessControl;
import io.prestosql.spi.Page;
Expand Down Expand Up @@ -53,6 +52,7 @@
import static com.google.common.collect.Sets.union;
import static io.prestosql.connector.informationschema.InformationSchemaMetadata.defaultPrefixes;
import static io.prestosql.connector.informationschema.InformationSchemaMetadata.isTablesEnumeratingTable;
import static io.prestosql.metadata.MetadataListing.getViews;
import static io.prestosql.metadata.MetadataListing.listSchemas;
import static io.prestosql.metadata.MetadataListing.listTableColumns;
import static io.prestosql.metadata.MetadataListing.listTablePrivileges;
Expand Down Expand Up @@ -288,11 +288,11 @@ private void addTablesRecords(QualifiedTablePrefix prefix)

private void addViewsRecords(QualifiedTablePrefix prefix)
{
for (Map.Entry<QualifiedObjectName, ConnectorViewDefinition> entry : metadata.getViews(session, prefix).entrySet()) {
for (Map.Entry<SchemaTableName, ConnectorViewDefinition> entry : getViews(session, metadata, accessControl, prefix).entrySet()) {
addRecord(
entry.getKey().getCatalogName(),
prefix.getCatalogName(),
entry.getKey().getSchemaName(),
entry.getKey().getObjectName(),
entry.getKey().getTableName(),
entry.getValue().getOriginalSql());
if (isLimitExhausted()) {
return;
Expand Down
Expand Up @@ -22,6 +22,7 @@
import io.prestosql.security.AccessControl;
import io.prestosql.spi.connector.CatalogSchemaTableName;
import io.prestosql.spi.connector.ColumnMetadata;
import io.prestosql.spi.connector.ConnectorViewDefinition;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.security.GrantInfo;

Expand Down Expand Up @@ -75,6 +76,18 @@ public static Set<SchemaTableName> listViews(Session session, Metadata metadata,
return accessControl.filterTables(session.toSecurityContext(), prefix.getCatalogName(), tableNames);
}

public static Map<SchemaTableName, ConnectorViewDefinition> getViews(Session session, Metadata metadata, AccessControl accessControl, QualifiedTablePrefix prefix)
{
Map<SchemaTableName, ConnectorViewDefinition> views = metadata.getViews(session, prefix).entrySet().stream()
.collect(toImmutableMap(entry -> entry.getKey().asSchemaTableName(), Entry::getValue));

Set<SchemaTableName> accessible = accessControl.filterTables(session.toSecurityContext(), prefix.getCatalogName(), views.keySet());

return views.entrySet().stream()
.filter(entry -> accessible.contains(entry.getKey()))
.collect(toImmutableMap(Entry::getKey, Entry::getValue));
}

public static Set<GrantInfo> listTablePrivileges(Session session, Metadata metadata, AccessControl accessControl, QualifiedTablePrefix prefix)
{
List<GrantInfo> grants = metadata.listTablePrivileges(session, prefix);
Expand Down
Expand Up @@ -20,6 +20,7 @@
import io.prestosql.security.SecurityContext;
import io.prestosql.spi.connector.CatalogSchemaName;
import io.prestosql.spi.connector.CatalogSchemaTableName;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.security.Identity;
import io.prestosql.spi.security.ViewExpression;
import io.prestosql.spi.type.Type;
Expand Down Expand Up @@ -112,6 +113,7 @@ public class TestingAccessControlManager
private final Map<RowFilterKey, List<ViewExpression>> rowFilters = new HashMap<>();
private final Map<ColumnMaskKey, List<ViewExpression>> columnMasks = new HashMap<>();
private Predicate<String> deniedCatalogs = s -> true;
private Predicate<SchemaTableName> deniedTables = s -> true;

@Inject
public TestingAccessControlManager(TransactionManager transactionManager)
Expand Down Expand Up @@ -155,6 +157,7 @@ public void reset()
{
denyPrivileges.clear();
deniedCatalogs = s -> true;
deniedTables = s -> true;
rowFilters.clear();
columnMasks.clear();
}
Expand All @@ -164,6 +167,11 @@ public void denyCatalogs(Predicate<String> deniedCatalogs)
this.deniedCatalogs = this.deniedCatalogs.and(deniedCatalogs);
}

public void denyTables(Predicate<SchemaTableName> deniedTables)
{
this.deniedTables = this.deniedTables.and(deniedTables);
}

@Override
public Set<String> filterCatalogs(Identity identity, Set<String> catalogs)
{
Expand All @@ -174,6 +182,17 @@ public Set<String> filterCatalogs(Identity identity, Set<String> catalogs)
.collect(toImmutableSet()));
}

@Override
public Set<SchemaTableName> filterTables(SecurityContext context, String catalogName, Set<SchemaTableName> tableNames)
{
return super.filterTables(
context,
catalogName,
tableNames.stream()
.filter(this.deniedTables)
.collect(toImmutableSet()));
}

@Override
public void checkCanImpersonateUser(Identity identity, String userName)
{
Expand Down

0 comments on commit f922c85

Please sign in to comment.