Skip to content

Commit

Permalink
Allow returning multiple filters and masks in ConnectorAccessControl
Browse files Browse the repository at this point in the history
Now that the `SystemAccessControl` can provide multiple filtering and
masking expressions, there's no reason for the `ConnectorAccessControl`
not to follow suit.
  • Loading branch information
ksobolew authored and kokosing committed Apr 3, 2022
1 parent ae66a8b commit 827de57
Show file tree
Hide file tree
Showing 17 changed files with 148 additions and 45 deletions.
Expand Up @@ -1165,8 +1165,8 @@ public List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObje
CatalogAccessControlEntry entry = getConnectorAccessControl(context.getTransactionId(), tableName.getCatalogName());

if (entry != null) {
entry.getAccessControl().getRowFilter(entry.toConnectorSecurityContext(context), tableName.asSchemaTableName())
.ifPresent(filters::add);
entry.getAccessControl().getRowFilters(entry.toConnectorSecurityContext(context), tableName.asSchemaTableName())
.forEach(filters::add);
}

for (SystemAccessControl systemAccessControl : getSystemAccessControls()) {
Expand All @@ -1188,8 +1188,8 @@ public List<ViewExpression> getColumnMasks(SecurityContext context, QualifiedObj
// connector-provided masks take precedence over global masks
CatalogAccessControlEntry entry = getConnectorAccessControl(context.getTransactionId(), tableName.getCatalogName());
if (entry != null) {
entry.getAccessControl().getColumnMask(entry.toConnectorSecurityContext(context), tableName.asSchemaTableName(), columnName, type)
.ifPresent(masks::add);
entry.getAccessControl().getColumnMasks(entry.toConnectorSecurityContext(context), tableName.asSchemaTableName(), columnName, type)
.forEach(masks::add);
}

for (SystemAccessControl systemAccessControl : getSystemAccessControls()) {
Expand Down
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.security;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.QualifiedObjectName;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.CatalogSchemaName;
Expand All @@ -26,6 +27,7 @@
import io.trino.spi.security.ViewExpression;
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -445,21 +447,21 @@ public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sche
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
checkArgument(context == null, "context must be null");
if (accessControl.getRowFilters(securityContext, new QualifiedObjectName(catalogName, tableName.getSchemaName(), tableName.getTableName())).isEmpty()) {
return Optional.empty();
return ImmutableList.of();
}
throw new TrinoException(NOT_SUPPORTED, "Row filtering not supported");
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
checkArgument(context == null, "context must be null");
if (accessControl.getColumnMasks(securityContext, new QualifiedObjectName(catalogName, tableName.getSchemaName(), tableName.getTableName()), columnName, type).isEmpty()) {
return Optional.empty();
return ImmutableList.of();
}
throw new TrinoException(NOT_SUPPORTED, "Column masking not supported");
}
Expand Down
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.connector;

import com.google.common.collect.ImmutableList;
import io.trino.plugin.base.security.AllowAllAccessControl;
import io.trino.spi.connector.ConnectorSecurityContext;
import io.trino.spi.connector.SchemaTableName;
Expand All @@ -23,6 +24,7 @@
import io.trino.spi.type.Type;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
Expand Down Expand Up @@ -120,15 +122,19 @@ public void checkCanRevokeTablePrivilege(ConnectorSecurityContext context, Privi
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
return Optional.ofNullable(rowFilters.apply(tableName));
return Optional.ofNullable(rowFilters.apply(tableName))
.map(ImmutableList::of)
.orElseGet(ImmutableList::of);
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return Optional.ofNullable(columnMasks.apply(tableName, columnName));
return Optional.ofNullable(columnMasks.apply(tableName, columnName))
.map(ImmutableList::of)
.orElseGet(ImmutableList::of);
}

public void grantSchemaPrivileges(String schemaName, Set<Privilege> privileges, TrinoPrincipal grantee, boolean grantOption)
Expand Down
Expand Up @@ -224,9 +224,9 @@ public void checkCanSetSystemSessionProperty(SystemSecurityContext context, Stri
accessControlManager.addCatalogAccessControl(new CatalogName("catalog"), new ConnectorAccessControl()
{
@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String column, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String column, Type type)
{
return Optional.of(new ViewExpression("user", Optional.empty(), Optional.empty(), "connector mask"));
return ImmutableList.of(new ViewExpression("user", Optional.empty(), Optional.empty(), "connector mask"));
}

@Override
Expand Down
Expand Up @@ -13,7 +13,11 @@
*/
package io.trino.security;

import com.google.common.collect.ImmutableSet;
import io.trino.spi.connector.ConnectorAccessControl;
import io.trino.spi.connector.ConnectorSecurityContext;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.type.Type;
import org.testng.annotations.Test;

import static io.trino.spi.testing.InterfaceTestUtils.assertAllMethodsOverridden;
Expand All @@ -22,7 +26,10 @@ public class TestInjectedConnectorAccessControl
{
@Test
public void testEverythingImplemented()
throws NoSuchMethodException
{
assertAllMethodsOverridden(ConnectorAccessControl.class, InjectedConnectorAccessControl.class);
assertAllMethodsOverridden(ConnectorAccessControl.class, InjectedConnectorAccessControl.class, ImmutableSet.of(
InjectedConnectorAccessControl.class.getMethod("getRowFilter", ConnectorSecurityContext.class, SchemaTableName.class),
InjectedConnectorAccessControl.class.getMethod("getColumnMask", ConnectorSecurityContext.class, SchemaTableName.class, String.class, Type.class)));
}
}
Expand Up @@ -18,6 +18,7 @@
import io.trino.spi.security.ViewExpression;
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -74,6 +75,7 @@
import static io.trino.spi.security.AccessDeniedException.denyShowTables;
import static io.trino.spi.security.AccessDeniedException.denyTruncateTable;
import static io.trino.spi.security.AccessDeniedException.denyUpdateTableColumns;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptySet;

public interface ConnectorAccessControl
Expand Down Expand Up @@ -600,22 +602,51 @@ default void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sch
* The filter must be a scalar SQL expression of boolean type over the columns in the table.
*
* @return the filter, or {@link Optional#empty()} if not applicable
* @deprecated use {@link #getRowFilters(ConnectorSecurityContext, SchemaTableName)} instead
*/
@Deprecated
default Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
{
return Optional.empty();
}

/**
* Get row filters associated with the given table and identity.
* <p>
* Each filter must be a scalar SQL expression of boolean type over the columns in the table.
*
* @return the list of filters, or empty list if not applicable
*/
default List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
return emptyList();
}

/**
* Get a column mask associated with the given table, column and identity.
* <p>
* The mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression
* must be written in terms of columns in the table.
*
* @return the mask, or {@link Optional#empty()} if not applicable
* @deprecated use {@link #getColumnMasks(ConnectorSecurityContext, SchemaTableName, String, Type)} instead
*/
@Deprecated
default Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return Optional.empty();
}

/**
* Get column masks associated with the given table, column and identity.
* <p>
* Each mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression
* must be written in terms of columns in the table.
*
* @return the list of masks, or empty list if not applicable
*/
default List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return emptyList();
}
}
Expand Up @@ -25,6 +25,7 @@

import javax.inject.Inject;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -493,18 +494,18 @@ public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sche
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
return delegate.getRowFilter(context, tableName);
return delegate.getRowFilters(context, tableName);
}
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
return delegate.getColumnMask(context, tableName, columnName, type);
return delegate.getColumnMasks(context, tableName, columnName, type);
}
}
}
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.base.security;

import com.google.common.collect.ImmutableList;
import io.trino.spi.connector.ConnectorAccessControl;
import io.trino.spi.connector.ConnectorSecurityContext;
import io.trino.spi.connector.SchemaRoutineName;
Expand All @@ -22,6 +23,7 @@
import io.trino.spi.security.ViewExpression;
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -316,14 +318,14 @@ public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sche
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
return Optional.empty();
return ImmutableList.of();
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return Optional.empty();
return ImmutableList.of();
}
}
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.base.security;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.plugin.base.CatalogName;
import io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege;
Expand All @@ -31,10 +32,10 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.DELETE;
import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.GRANT_SELECT;
Expand Down Expand Up @@ -591,33 +592,37 @@ public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sche
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
if (INFORMATION_SCHEMA_NAME.equals(tableName.getSchemaName())) {
return Optional.empty();
return ImmutableList.of();
}

ConnectorIdentity identity = context.getIdentity();
return tableRules.stream()
.filter(rule -> rule.matches(identity.getUser(), identity.getEnabledSystemRoles(), identity.getGroups(), tableName))
.map(rule -> rule.getFilter(identity.getUser(), catalogName, tableName.getSchemaName()))
.findFirst()
.flatMap(Function.identity());
.flatMap(Optional::stream)
// we return the first one we find
.limit(1)
.collect(toImmutableList());
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
if (INFORMATION_SCHEMA_NAME.equals(tableName.getSchemaName())) {
return Optional.empty();
return ImmutableList.of();
}

ConnectorIdentity identity = context.getIdentity();
return tableRules.stream()
.filter(rule -> rule.matches(identity.getUser(), identity.getEnabledSystemRoles(), identity.getGroups(), tableName))
.map(rule -> rule.getColumnMask(identity.getUser(), catalogName, tableName.getSchemaName(), columnName))
.findFirst()
.flatMap(Function.identity());
.flatMap(Optional::stream)
// we return the first one we find
.limit(1)
.collect(toImmutableList());
}

private boolean canSetSessionProperty(ConnectorSecurityContext context, String property)
Expand Down
Expand Up @@ -22,6 +22,7 @@
import io.trino.spi.security.ViewExpression;
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -391,9 +392,21 @@ public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, S
return delegate().getRowFilter(context, tableName);
}

@Override
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
return delegate().getRowFilters(context, tableName);
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return delegate().getColumnMask(context, tableName, columnName, type);
}

@Override
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return delegate().getColumnMasks(context, tableName, columnName, type);
}
}

0 comments on commit 827de57

Please sign in to comment.