Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce support for the most common features of Spring Security's @PreAuthorize #5787

Merged
merged 4 commits into from Nov 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -10,6 +10,7 @@

/**
* Used as an integration point when extensions need to customize the security behavior of a bean
* The ResultHandle that is returned by function needs to be an instance of SecurityCheck
*/
public final class AdditionalSecurityCheckBuildItem extends MultiBuildItem {

Expand Down
Expand Up @@ -17,7 +17,7 @@ public class SecurityCheckInstantiationUtil {
private SecurityCheckInstantiationUtil() {
}

public static Function<BytecodeCreator, ResultHandle> rolesAllowedSecurityCheck(final String[] rolesAllowed) {
public static Function<BytecodeCreator, ResultHandle> rolesAllowedSecurityCheck(String... rolesAllowed) {
return new Function<BytecodeCreator, ResultHandle>() {
@Override
public ResultHandle apply(BytecodeCreator creator) {
Expand Down
Expand Up @@ -20,11 +20,11 @@ public class SecurityConstrainer {
@Inject
SecurityCheckStorage storage;

public void checkRoles(Method method) {
public void check(Method method, Object[] parameters) {

SecurityCheck securityCheck = storage.getSecurityCheck(method);
if (securityCheck != null) {
securityCheck.apply(identity);
securityCheck.apply(identity, method, parameters);
}
}
}
Expand Up @@ -20,7 +20,7 @@ public Object handle(InvocationContext ic) throws Exception {
if (alreadyHandled(ic)) {
return ic.proceed();
}
constrainer.checkRoles(ic.getMethod());
constrainer.check(ic.getMethod(), ic.getParameters());
return ic.proceed();
}

Expand Down
@@ -1,5 +1,7 @@
package io.quarkus.security.runtime.interceptor.check;

import java.lang.reflect.Method;

import io.quarkus.security.UnauthorizedException;
import io.quarkus.security.identity.SecurityIdentity;

Expand All @@ -11,7 +13,7 @@ private AuthenticatedCheck() {
}

@Override
public void apply(SecurityIdentity identity) {
public void apply(SecurityIdentity identity, Method method, Object[] parameters) {
if (identity.isAnonymous()) {
throw new UnauthorizedException();
}
Expand Down
@@ -1,5 +1,7 @@
package io.quarkus.security.runtime.interceptor.check;

import java.lang.reflect.Method;

import io.quarkus.security.ForbiddenException;
import io.quarkus.security.UnauthorizedException;
import io.quarkus.security.identity.SecurityIdentity;
Expand All @@ -12,7 +14,7 @@ private DenyAllCheck() {
}

@Override
public void apply(SecurityIdentity identity) {
public void apply(SecurityIdentity identity, Method method, Object[] parameters) {
if (identity.isAnonymous()) {
throw new UnauthorizedException();
} else {
Expand Down
@@ -1,5 +1,7 @@
package io.quarkus.security.runtime.interceptor.check;

import java.lang.reflect.Method;

import io.quarkus.security.identity.SecurityIdentity;

public class PermitAllCheck implements SecurityCheck {
Expand All @@ -10,6 +12,6 @@ private PermitAllCheck() {
}

@Override
public void apply(SecurityIdentity identity) {
public void apply(SecurityIdentity identity, Method method, Object[] parameters) {
}
}
@@ -1,5 +1,6 @@
package io.quarkus.security.runtime.interceptor.check;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
Expand Down Expand Up @@ -49,7 +50,7 @@ private static Collection<String> getCollectionForKey(String[] allowedRoles) {
}

@Override
public void apply(SecurityIdentity identity) {
public void apply(SecurityIdentity identity, Method method, Object[] parameters) {
Set<String> roles = identity.getRoles();
if (roles != null) {
for (String role : allowedRoles) {
Expand Down
@@ -1,7 +1,9 @@
package io.quarkus.security.runtime.interceptor.check;

import java.lang.reflect.Method;

import io.quarkus.security.identity.SecurityIdentity;

public interface SecurityCheck {
void apply(SecurityIdentity identity);
}
void apply(SecurityIdentity identity, Method method, Object[] parameters);
}
Expand Up @@ -8,9 +8,11 @@
public class AuthData {
public final Set<String> roles;
public final boolean anonymous;
public final String name;

public AuthData(Set<String> roles, boolean anonymous) {
public AuthData(Set<String> roles, boolean anonymous, String name) {
this.roles = roles;
this.anonymous = anonymous;
this.name = name;
}
}
Expand Up @@ -22,21 +22,28 @@
@Priority(1)
public class IdentityMock implements SecurityIdentity {

public static final AuthData ANONYMOUS = new AuthData(null, true);
public static final AuthData USER = new AuthData(Collections.singleton("user"), false);
public static final AuthData ADMIN = new AuthData(Collections.singleton("admin"), false);
public static final AuthData ANONYMOUS = new AuthData(null, true, null);
public static final AuthData USER = new AuthData(Collections.singleton("user"), false, "user");
public static final AuthData ADMIN = new AuthData(Collections.singleton("admin"), false, "admin");

private static volatile boolean anonymous;
private static volatile Set<String> roles;
private static volatile String name;

public static void setUpAuth(AuthData auth) {
IdentityMock.anonymous = auth.anonymous;
IdentityMock.roles = auth.roles;
IdentityMock.name = auth.name;
}

@Override
public Principal getPrincipal() {
return () -> "whatever";
return new Principal() {
@Override
public String getName() {
return name;
}
};
}

@Override
Expand Down
@@ -0,0 +1,24 @@
package io.quarkus.spring.di.deployment;

import java.util.Map;

import org.jboss.jandex.DotName;

import io.quarkus.builder.item.SimpleBuildItem;

/**
* The purpose of this bean is to map the names of the Spring Beans to their associated DotName
* This info is needed when trying to convert SpEL expressions that reference beans by name, to bytecode
*/
public final class SpringBeanNameToDotNameBuildItem extends SimpleBuildItem {

private final Map<String, DotName> map;

public SpringBeanNameToDotNameBuildItem(Map<String, DotName> map) {
this.map = map;
}

public Map<String, DotName> getMap() {
return map;
}
}
Expand Up @@ -88,6 +88,37 @@ FeatureBuildItem registerFeature() {
return new FeatureBuildItem(FeatureBuildItem.SPRING_DI);
}

/*
* This Build Item can't be generated in the beanTransformer method because the annotation transformer
* is generated lazily.
* However the logic is the same
*/
@BuildStep
SpringBeanNameToDotNameBuildItem createBeanNamesMap(BeanArchiveIndexBuildItem beanArchiveIndexBuildItem) {
final Map<String, DotName> result = new HashMap<>();

final IndexView index = beanArchiveIndexBuildItem.getIndex();
final Collection<AnnotationInstance> stereotypeInstances = new ArrayList<>();
stereotypeInstances.addAll(index.getAnnotations(SPRING_COMPONENT));
stereotypeInstances.addAll(index.getAnnotations(SPRING_REPOSITORY));
stereotypeInstances.addAll(index.getAnnotations(SPRING_SERVICE));
for (AnnotationInstance stereotypeInstance : stereotypeInstances) {
if (stereotypeInstance.target().kind() != AnnotationTarget.Kind.CLASS) {
continue;
}
result.put(getBeanNameFromStereotypeInstance(stereotypeInstance), stereotypeInstance.target().asClass().name());
}

for (AnnotationInstance beanInstance : index.getAnnotations(BEAN_ANNOTATION)) {
if (beanInstance.target().kind() != AnnotationTarget.Kind.METHOD) {
continue;
}
result.put(getBeanNameFromBeanInstance(beanInstance), beanInstance.target().asMethod().returnType().name());
}

return new SpringBeanNameToDotNameBuildItem(result);
}

@BuildStep
AnnotationsTransformerBuildItem beanTransformer(
final BeanArchiveIndexBuildItem beanArchiveIndexBuildItem,
Expand Down Expand Up @@ -276,17 +307,8 @@ Set<AnnotationInstance> getAnnotationsToAdd(
if (scopeNames != null) {
scopes.addAll(scopeNames);
}
if (SPRING_STEREOTYPE_ANNOTATIONS.contains(clazzAnnotation)) {
//check if the spring annotation defines a name for the bean
final AnnotationValue value = classInfo.classAnnotation(clazzAnnotation).value();
if (value == null) {
continue;
}
final String name = value.asString();
if (name == null || name.isEmpty()) {
continue;
}
names.add(name);
if (SPRING_STEREOTYPE_ANNOTATIONS.contains(clazzAnnotation) && !isAnnotation(classInfo.flags())) {
names.add(getBeanNameFromStereotypeInstance(classInfo.classAnnotation(clazzAnnotation)));
}
}
}
Expand Down Expand Up @@ -377,12 +399,11 @@ Set<AnnotationInstance> getAnnotationsToAdd(
Collections.emptyList()));
}

//check if the spring annotation defines a name for the bean
final AnnotationValue beanNameAnnotationValue = methodInfo.annotation(BEAN_ANNOTATION).value("name");
final AnnotationValue beanValueAnnotationValue = methodInfo.annotation(BEAN_ANNOTATION).value("value");
if (!addCDINamedAnnotation(target, beanNameAnnotationValue, annotationsToAdd)) {
addCDINamedAnnotation(target, beanValueAnnotationValue, annotationsToAdd);
}
String beanName = getBeanNameFromBeanInstance(methodInfo.annotation(BEAN_ANNOTATION));
annotationsToAdd.add(create(
CDI_NAMED_ANNOTATION,
target,
Collections.singletonList(AnnotationValue.createStringValue("value", beanName))));
}

// add method parameter conversion annotations
Expand All @@ -407,6 +428,79 @@ Set<AnnotationInstance> getAnnotationsToAdd(
return annotationsToAdd;
}

/**
* Meant to be called with instances of @Component, @Service, @Repository
*/
private String getBeanNameFromStereotypeInstance(AnnotationInstance annotationInstance) {
if (annotationInstance.target().kind() != AnnotationTarget.Kind.CLASS) {
throw new IllegalStateException(
"AnnotationInstance " + annotationInstance + " is an invalid target. Only Class targets are supported");
}
final AnnotationValue value = annotationInstance.value();
if ((value == null) || value.asString().isEmpty()) {
return getDefaultBeanNameFromClass(annotationInstance.target().asClass().name().toString());
} else {
return value.asString();
}
}

/**
* Meant to be called with instances of @Bean
*/
private String getBeanNameFromBeanInstance(AnnotationInstance annotationInstance) {
if (annotationInstance.target().kind() != AnnotationTarget.Kind.METHOD) {
throw new IllegalStateException(
"AnnotationInstance " + annotationInstance + " is an invalid target. Only Method targets are supported");
}

String beanName = null;
final AnnotationValue beanNameAnnotationValue = annotationInstance.value("name");
if (beanNameAnnotationValue != null) {
beanName = determineName(beanNameAnnotationValue);
}
if (beanName == null || beanName.isEmpty()) {
final AnnotationValue beanValueAnnotationValue = annotationInstance.value();
if (beanNameAnnotationValue != null) {
beanName = determineName(beanValueAnnotationValue);
}
}
if (beanName == null || beanName.isEmpty()) {
beanName = annotationInstance.target().asMethod().name();
}

return beanName;
}

// this does what Spring's AnnotationBeanNameGenerator does to generate a name
private String getDefaultBeanNameFromClass(String className) {
return decapitalize(getShortNameOfClass(className));
}

private String getShortNameOfClass(String className) {
int lastDotIndex = className.lastIndexOf('.');
int nameEndIndex = className.indexOf("$$");
if (nameEndIndex == -1) {
nameEndIndex = className.length();
}
String shortName = className.substring(lastDotIndex + 1, nameEndIndex);
shortName = shortName.replace('$', '.');
return shortName;
}

private String decapitalize(String name) {
if (name != null && name.length() != 0) {
if (name.length() > 1 && Character.isUpperCase(name.charAt(1)) && Character.isUpperCase(name.charAt(0))) {
return name;
} else {
char[] chars = name.toCharArray();
chars[0] = Character.toLowerCase(chars[0]);
return new String(chars);
}
} else {
return name;
}
}

private void addSpringValueAnnotations(AnnotationTarget target, AnnotationInstance annotation, boolean addInject,
Set<AnnotationInstance> annotationsToAdd) {
final AnnotationValue annotationValue = annotation.value();
Expand Down Expand Up @@ -440,26 +534,6 @@ private void addSpringValueAnnotations(AnnotationTarget target, AnnotationInstan
}
}

private static boolean addCDINamedAnnotation(AnnotationTarget target,
AnnotationValue annotationValue,
Set<AnnotationInstance> annotationsToAdd) {
if (annotationValue == null) {
return false;
}

final String beanName = determineName(annotationValue);
if (beanName != null && !"".equals(beanName)) {
annotationsToAdd.add(create(
CDI_NAMED_ANNOTATION,
target,
Collections.singletonList(AnnotationValue.createStringValue("value", beanName))));

return true;
}

return false;
}

private static String determineName(AnnotationValue annotationValue) {
if (annotationValue.kind() == AnnotationValue.Kind.ARRAY) {
return annotationValue.asStringArray()[0];
Expand Down
@@ -0,0 +1,7 @@
package io.quarkus.spring.di.deployment;

import org.springframework.stereotype.Component;

@Component
public class Bean {
}