Skip to content

Commit

Permalink
Merge pull request #5787 from geoand/spring-preauthorize
Browse files Browse the repository at this point in the history
Introduce support for the most common features of Spring Security's @PreAuthorize
  • Loading branch information
gsmet committed Nov 29, 2019
2 parents 560a9be + 02f3563 commit adb87d9
Show file tree
Hide file tree
Showing 71 changed files with 2,571 additions and 63 deletions.
Expand Up @@ -10,6 +10,7 @@

/**
* Used as an integration point when extensions need to customize the security behavior of a bean
* The ResultHandle that is returned by function needs to be an instance of SecurityCheck
*/
public final class AdditionalSecurityCheckBuildItem extends MultiBuildItem {

Expand Down
Expand Up @@ -17,7 +17,7 @@ public class SecurityCheckInstantiationUtil {
private SecurityCheckInstantiationUtil() {
}

public static Function<BytecodeCreator, ResultHandle> rolesAllowedSecurityCheck(final String[] rolesAllowed) {
public static Function<BytecodeCreator, ResultHandle> rolesAllowedSecurityCheck(String... rolesAllowed) {
return new Function<BytecodeCreator, ResultHandle>() {
@Override
public ResultHandle apply(BytecodeCreator creator) {
Expand Down
Expand Up @@ -20,11 +20,11 @@ public class SecurityConstrainer {
@Inject
SecurityCheckStorage storage;

public void checkRoles(Method method) {
public void check(Method method, Object[] parameters) {

SecurityCheck securityCheck = storage.getSecurityCheck(method);
if (securityCheck != null) {
securityCheck.apply(identity);
securityCheck.apply(identity, method, parameters);
}
}
}
Expand Up @@ -20,7 +20,7 @@ public Object handle(InvocationContext ic) throws Exception {
if (alreadyHandled(ic)) {
return ic.proceed();
}
constrainer.checkRoles(ic.getMethod());
constrainer.check(ic.getMethod(), ic.getParameters());
return ic.proceed();
}

Expand Down
@@ -1,5 +1,7 @@
package io.quarkus.security.runtime.interceptor.check;

import java.lang.reflect.Method;

import io.quarkus.security.UnauthorizedException;
import io.quarkus.security.identity.SecurityIdentity;

Expand All @@ -11,7 +13,7 @@ private AuthenticatedCheck() {
}

@Override
public void apply(SecurityIdentity identity) {
public void apply(SecurityIdentity identity, Method method, Object[] parameters) {
if (identity.isAnonymous()) {
throw new UnauthorizedException();
}
Expand Down
@@ -1,5 +1,7 @@
package io.quarkus.security.runtime.interceptor.check;

import java.lang.reflect.Method;

import io.quarkus.security.ForbiddenException;
import io.quarkus.security.UnauthorizedException;
import io.quarkus.security.identity.SecurityIdentity;
Expand All @@ -12,7 +14,7 @@ private DenyAllCheck() {
}

@Override
public void apply(SecurityIdentity identity) {
public void apply(SecurityIdentity identity, Method method, Object[] parameters) {
if (identity.isAnonymous()) {
throw new UnauthorizedException();
} else {
Expand Down
@@ -1,5 +1,7 @@
package io.quarkus.security.runtime.interceptor.check;

import java.lang.reflect.Method;

import io.quarkus.security.identity.SecurityIdentity;

public class PermitAllCheck implements SecurityCheck {
Expand All @@ -10,6 +12,6 @@ private PermitAllCheck() {
}

@Override
public void apply(SecurityIdentity identity) {
public void apply(SecurityIdentity identity, Method method, Object[] parameters) {
}
}
@@ -1,5 +1,6 @@
package io.quarkus.security.runtime.interceptor.check;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
Expand Down Expand Up @@ -49,7 +50,7 @@ private static Collection<String> getCollectionForKey(String[] allowedRoles) {
}

@Override
public void apply(SecurityIdentity identity) {
public void apply(SecurityIdentity identity, Method method, Object[] parameters) {
Set<String> roles = identity.getRoles();
if (roles != null) {
for (String role : allowedRoles) {
Expand Down
@@ -1,7 +1,9 @@
package io.quarkus.security.runtime.interceptor.check;

import java.lang.reflect.Method;

import io.quarkus.security.identity.SecurityIdentity;

public interface SecurityCheck {
void apply(SecurityIdentity identity);
}
void apply(SecurityIdentity identity, Method method, Object[] parameters);
}
Expand Up @@ -8,9 +8,11 @@
public class AuthData {
public final Set<String> roles;
public final boolean anonymous;
public final String name;

public AuthData(Set<String> roles, boolean anonymous) {
public AuthData(Set<String> roles, boolean anonymous, String name) {
this.roles = roles;
this.anonymous = anonymous;
this.name = name;
}
}
Expand Up @@ -22,21 +22,28 @@
@Priority(1)
public class IdentityMock implements SecurityIdentity {

public static final AuthData ANONYMOUS = new AuthData(null, true);
public static final AuthData USER = new AuthData(Collections.singleton("user"), false);
public static final AuthData ADMIN = new AuthData(Collections.singleton("admin"), false);
public static final AuthData ANONYMOUS = new AuthData(null, true, null);
public static final AuthData USER = new AuthData(Collections.singleton("user"), false, "user");
public static final AuthData ADMIN = new AuthData(Collections.singleton("admin"), false, "admin");

private static volatile boolean anonymous;
private static volatile Set<String> roles;
private static volatile String name;

public static void setUpAuth(AuthData auth) {
IdentityMock.anonymous = auth.anonymous;
IdentityMock.roles = auth.roles;
IdentityMock.name = auth.name;
}

@Override
public Principal getPrincipal() {
return () -> "whatever";
return new Principal() {
@Override
public String getName() {
return name;
}
};
}

@Override
Expand Down
@@ -0,0 +1,24 @@
package io.quarkus.spring.di.deployment;

import java.util.Map;

import org.jboss.jandex.DotName;

import io.quarkus.builder.item.SimpleBuildItem;

/**
* The purpose of this bean is to map the names of the Spring Beans to their associated DotName
* This info is needed when trying to convert SpEL expressions that reference beans by name, to bytecode
*/
public final class SpringBeanNameToDotNameBuildItem extends SimpleBuildItem {

private final Map<String, DotName> map;

public SpringBeanNameToDotNameBuildItem(Map<String, DotName> map) {
this.map = map;
}

public Map<String, DotName> getMap() {
return map;
}
}
Expand Up @@ -88,6 +88,37 @@ FeatureBuildItem registerFeature() {
return new FeatureBuildItem(FeatureBuildItem.SPRING_DI);
}

/*
* This Build Item can't be generated in the beanTransformer method because the annotation transformer
* is generated lazily.
* However the logic is the same
*/
@BuildStep
SpringBeanNameToDotNameBuildItem createBeanNamesMap(BeanArchiveIndexBuildItem beanArchiveIndexBuildItem) {
final Map<String, DotName> result = new HashMap<>();

final IndexView index = beanArchiveIndexBuildItem.getIndex();
final Collection<AnnotationInstance> stereotypeInstances = new ArrayList<>();
stereotypeInstances.addAll(index.getAnnotations(SPRING_COMPONENT));
stereotypeInstances.addAll(index.getAnnotations(SPRING_REPOSITORY));
stereotypeInstances.addAll(index.getAnnotations(SPRING_SERVICE));
for (AnnotationInstance stereotypeInstance : stereotypeInstances) {
if (stereotypeInstance.target().kind() != AnnotationTarget.Kind.CLASS) {
continue;
}
result.put(getBeanNameFromStereotypeInstance(stereotypeInstance), stereotypeInstance.target().asClass().name());
}

for (AnnotationInstance beanInstance : index.getAnnotations(BEAN_ANNOTATION)) {
if (beanInstance.target().kind() != AnnotationTarget.Kind.METHOD) {
continue;
}
result.put(getBeanNameFromBeanInstance(beanInstance), beanInstance.target().asMethod().returnType().name());
}

return new SpringBeanNameToDotNameBuildItem(result);
}

@BuildStep
AnnotationsTransformerBuildItem beanTransformer(
final BeanArchiveIndexBuildItem beanArchiveIndexBuildItem,
Expand Down Expand Up @@ -276,17 +307,8 @@ Set<AnnotationInstance> getAnnotationsToAdd(
if (scopeNames != null) {
scopes.addAll(scopeNames);
}
if (SPRING_STEREOTYPE_ANNOTATIONS.contains(clazzAnnotation)) {
//check if the spring annotation defines a name for the bean
final AnnotationValue value = classInfo.classAnnotation(clazzAnnotation).value();
if (value == null) {
continue;
}
final String name = value.asString();
if (name == null || name.isEmpty()) {
continue;
}
names.add(name);
if (SPRING_STEREOTYPE_ANNOTATIONS.contains(clazzAnnotation) && !isAnnotation(classInfo.flags())) {
names.add(getBeanNameFromStereotypeInstance(classInfo.classAnnotation(clazzAnnotation)));
}
}
}
Expand Down Expand Up @@ -377,12 +399,11 @@ Set<AnnotationInstance> getAnnotationsToAdd(
Collections.emptyList()));
}

//check if the spring annotation defines a name for the bean
final AnnotationValue beanNameAnnotationValue = methodInfo.annotation(BEAN_ANNOTATION).value("name");
final AnnotationValue beanValueAnnotationValue = methodInfo.annotation(BEAN_ANNOTATION).value("value");
if (!addCDINamedAnnotation(target, beanNameAnnotationValue, annotationsToAdd)) {
addCDINamedAnnotation(target, beanValueAnnotationValue, annotationsToAdd);
}
String beanName = getBeanNameFromBeanInstance(methodInfo.annotation(BEAN_ANNOTATION));
annotationsToAdd.add(create(
CDI_NAMED_ANNOTATION,
target,
Collections.singletonList(AnnotationValue.createStringValue("value", beanName))));
}

// add method parameter conversion annotations
Expand All @@ -407,6 +428,79 @@ Set<AnnotationInstance> getAnnotationsToAdd(
return annotationsToAdd;
}

/**
* Meant to be called with instances of @Component, @Service, @Repository
*/
private String getBeanNameFromStereotypeInstance(AnnotationInstance annotationInstance) {
if (annotationInstance.target().kind() != AnnotationTarget.Kind.CLASS) {
throw new IllegalStateException(
"AnnotationInstance " + annotationInstance + " is an invalid target. Only Class targets are supported");
}
final AnnotationValue value = annotationInstance.value();
if ((value == null) || value.asString().isEmpty()) {
return getDefaultBeanNameFromClass(annotationInstance.target().asClass().name().toString());
} else {
return value.asString();
}
}

/**
* Meant to be called with instances of @Bean
*/
private String getBeanNameFromBeanInstance(AnnotationInstance annotationInstance) {
if (annotationInstance.target().kind() != AnnotationTarget.Kind.METHOD) {
throw new IllegalStateException(
"AnnotationInstance " + annotationInstance + " is an invalid target. Only Method targets are supported");
}

String beanName = null;
final AnnotationValue beanNameAnnotationValue = annotationInstance.value("name");
if (beanNameAnnotationValue != null) {
beanName = determineName(beanNameAnnotationValue);
}
if (beanName == null || beanName.isEmpty()) {
final AnnotationValue beanValueAnnotationValue = annotationInstance.value();
if (beanNameAnnotationValue != null) {
beanName = determineName(beanValueAnnotationValue);
}
}
if (beanName == null || beanName.isEmpty()) {
beanName = annotationInstance.target().asMethod().name();
}

return beanName;
}

// this does what Spring's AnnotationBeanNameGenerator does to generate a name
private String getDefaultBeanNameFromClass(String className) {
return decapitalize(getShortNameOfClass(className));
}

private String getShortNameOfClass(String className) {
int lastDotIndex = className.lastIndexOf('.');
int nameEndIndex = className.indexOf("$$");
if (nameEndIndex == -1) {
nameEndIndex = className.length();
}
String shortName = className.substring(lastDotIndex + 1, nameEndIndex);
shortName = shortName.replace('$', '.');
return shortName;
}

private String decapitalize(String name) {
if (name != null && name.length() != 0) {
if (name.length() > 1 && Character.isUpperCase(name.charAt(1)) && Character.isUpperCase(name.charAt(0))) {
return name;
} else {
char[] chars = name.toCharArray();
chars[0] = Character.toLowerCase(chars[0]);
return new String(chars);
}
} else {
return name;
}
}

private void addSpringValueAnnotations(AnnotationTarget target, AnnotationInstance annotation, boolean addInject,
Set<AnnotationInstance> annotationsToAdd) {
final AnnotationValue annotationValue = annotation.value();
Expand Down Expand Up @@ -440,26 +534,6 @@ private void addSpringValueAnnotations(AnnotationTarget target, AnnotationInstan
}
}

private static boolean addCDINamedAnnotation(AnnotationTarget target,
AnnotationValue annotationValue,
Set<AnnotationInstance> annotationsToAdd) {
if (annotationValue == null) {
return false;
}

final String beanName = determineName(annotationValue);
if (beanName != null && !"".equals(beanName)) {
annotationsToAdd.add(create(
CDI_NAMED_ANNOTATION,
target,
Collections.singletonList(AnnotationValue.createStringValue("value", beanName))));

return true;
}

return false;
}

private static String determineName(AnnotationValue annotationValue) {
if (annotationValue.kind() == AnnotationValue.Kind.ARRAY) {
return annotationValue.asStringArray()[0];
Expand Down
@@ -0,0 +1,7 @@
package io.quarkus.spring.di.deployment;

import org.springframework.stereotype.Component;

@Component
public class Bean {
}

0 comments on commit adb87d9

Please sign in to comment.