Skip to content

Commit

Permalink
Merge pull request #116 from dmlloyd/assoc-sec
Browse files Browse the repository at this point in the history
Update AssociationImpl to read the identity from the request and dispatch it on method invoke
  • Loading branch information
dmlloyd committed Mar 2, 2017
2 parents 5f81688 + 1306f2d commit 18de628
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions ejb3/src/main/java/org/jboss/as/ejb3/remote/AssociationImpl.java
Expand Up @@ -57,11 +57,13 @@
import org.jboss.remoting3.Connection;
import org.wildfly.clustering.registry.Registry;
import org.wildfly.common.annotation.NotNull;
import org.wildfly.security.auth.server.SecurityIdentity;

import javax.ejb.EJBException;

import java.io.IOException;
import java.lang.reflect.Method;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -354,7 +356,7 @@ public void deploymentRemoved(final DeploymentModuleIdentifier deployment) {
return () -> deploymentRepository.removeListener(listener);
}

protected EjbDeploymentInformation findEJB(final String appName, final String moduleName, final String distinctName, final String beanName) {
private EjbDeploymentInformation findEJB(final String appName, final String moduleName, final String distinctName, final String beanName) {
final DeploymentModuleIdentifier ejbModule = new DeploymentModuleIdentifier(appName, moduleName, distinctName);
final Map<DeploymentModuleIdentifier, ModuleDeployment> modules = this.deploymentRepository.getStartedModules();
if (modules == null || modules.isEmpty()) {
Expand All @@ -367,7 +369,7 @@ protected EjbDeploymentInformation findEJB(final String appName, final String mo
return moduleDeployment.getEjbs().get(beanName);
}

private Object invokeMethod(final ComponentView componentView, final Method method, final InvocationRequest incomingInvocation, final InvocationRequest.Resolved content, final CancellationFlag cancellationFlag) throws Exception {
static Object invokeMethod(final ComponentView componentView, final Method method, final InvocationRequest incomingInvocation, final InvocationRequest.Resolved content, final CancellationFlag cancellationFlag) throws Exception {
final InterceptorContext interceptorContext = new InterceptorContext();
interceptorContext.setParameters(content.getParameters());
interceptorContext.setMethod(method);
Expand Down Expand Up @@ -409,21 +411,30 @@ private Object invokeMethod(final ComponentView componentView, final Method meth
if (content.hasTransaction()) {
interceptorContext.setTransactionSupplier(content::getTransaction);
}
// add security identity
final SecurityIdentity securityIdentity = incomingInvocation.getSecurityIdentity();
final boolean isAsync = componentView.isAsynchronous(method);
final boolean oneWay = isAsync && method.getReturnType() == void.class;
final boolean isSessionBean = componentView.getComponent() instanceof SessionBeanComponent;
if (isAsync && isSessionBean) {
if (! oneWay) {
interceptorContext.putPrivateData(CancellationFlag.class, cancellationFlag);
}
final Object result = componentView.invoke(interceptorContext);
final Object result = invokeWithIdentity(componentView, interceptorContext, securityIdentity);
return result == null ? null : ((Future<?>) result).get();
} else {
return componentView.invoke(interceptorContext);
return invokeWithIdentity(componentView, interceptorContext, securityIdentity);
}
}

private Method findMethod(final ComponentView componentView, final EJBMethodLocator ejbMethodLocator) {
private static Object invokeWithIdentity(final ComponentView componentView, final InterceptorContext interceptorContext, final SecurityIdentity securityIdentity) throws Exception {
return securityIdentity == null ? componentView.invoke(interceptorContext) : securityIdentity.runAs((PrivilegedExceptionAction<Object>) () -> {
// TODO: replace this with identity.runAsFunctionEx() once it is available
return componentView.invoke(interceptorContext);
});
}

private static Method findMethod(final ComponentView componentView, final EJBMethodLocator ejbMethodLocator) {
final Set<Method> viewMethods = componentView.getViewMethods();
for (final Method method : viewMethods) {
if (method.getName().equals(ejbMethodLocator.getMethodName())) {
Expand All @@ -446,7 +457,7 @@ private Method findMethod(final ComponentView componentView, final EJBMethodLoca
return null;
}

private Affinity getWeakAffinity(final StatefulSessionComponent statefulSessionComponent, final StatefulEJBLocator<?> statefulEJBLocator) {
private static Affinity getWeakAffinity(final StatefulSessionComponent statefulSessionComponent, final StatefulEJBLocator<?> statefulEJBLocator) {
final SessionID sessionID = statefulEJBLocator.getSessionId();
return statefulSessionComponent.getCache().getWeakAffinity(sessionID);
}
Expand Down

0 comments on commit 18de628

Please sign in to comment.