diff --git a/spring-grpc-core/src/main/java/org/springframework/grpc/server/security/AuthenticationProcessInterceptor.java b/spring-grpc-core/src/main/java/org/springframework/grpc/server/security/AuthenticationProcessInterceptor.java index 5a06a080..b0648e35 100644 --- a/spring-grpc-core/src/main/java/org/springframework/grpc/server/security/AuthenticationProcessInterceptor.java +++ b/spring-grpc-core/src/main/java/org/springframework/grpc/server/security/AuthenticationProcessInterceptor.java @@ -26,6 +26,7 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener; import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCall.Listener; @@ -90,7 +91,64 @@ else if (user == null || !user.isAuthenticated()) { throw new BadCredentialsException("not authenticated"); } - return next.startCall(call, headers); + SecurityContext currentContext = SecurityContextHolder.getContext(); + return new SecurityContextClearingListener<>(next.startCall(call, headers), currentContext); + } + + static class SecurityContextClearingListener extends SimpleForwardingServerCallListener { + + private final SecurityContext securityContext; + + SecurityContextClearingListener(ServerCall.Listener delegate, SecurityContext securityContext) { + super(delegate); + this.securityContext = securityContext; + } + + @Override + public void onMessage(ReqT message) { + SecurityContextHolder.setContext(this.securityContext); + try { + super.onMessage(message); + } + finally { + SecurityContextHolder.clearContext(); + } + } + + @Override + public void onHalfClose() { + SecurityContextHolder.setContext(this.securityContext); + try { + super.onHalfClose(); + } + finally { + SecurityContextHolder.clearContext(); + } + } + + @Override + public void onReady() { + SecurityContextHolder.setContext(this.securityContext); + try { + super.onReady(); + } + finally { + SecurityContextHolder.clearContext(); + } + } + + @Override + public void onCancel() { + super.onCancel(); + SecurityContextHolder.clearContext(); + } + + @Override + public void onComplete() { + super.onComplete(); + SecurityContextHolder.clearContext(); + } + } }