Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved.
*
* Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Modifications copyright (C) 2017 Uber Technologies, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License is
* located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package io.temporal.authorization;

import static org.junit.Assert.*;

import io.grpc.*;
import io.temporal.client.WorkflowClient;
import io.temporal.client.WorkflowClientOptions;
import io.temporal.client.WorkflowOptions;
import io.temporal.internal.testing.WorkflowTestingTest;
import io.temporal.serviceclient.WorkflowServiceStubsOptions;
import io.temporal.testing.TestEnvironmentOptions;
import io.temporal.testing.TestWorkflowEnvironment;
import io.temporal.worker.Worker;
import io.temporal.workflow.Workflow;
import io.temporal.workflow.shared.TestWorkflows;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestWatcher;
import org.junit.runner.Description;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AuthorizationTokenTest {

private static final Logger log = LoggerFactory.getLogger(AuthorizationTokenTest.class);
private static final String TASK_QUEUE = "test-workflow";
private static final String AUTH_TOKEN = "Bearer <token>";

private TestWorkflowEnvironment testEnvironment;

@Rule
public TestWatcher watchman =
new TestWatcher() {
@Override
protected void failed(Throwable e, Description description) {
System.err.println(testEnvironment.getDiagnostics());
}
};

private final List<GrpcRequest> loggedRequests = new ArrayList<>();

@Before
public void setUp() {
loggedRequests.clear();
WorkflowServiceStubsOptions stubOptions =
WorkflowServiceStubsOptions.newBuilder()
.addGrpcClientInterceptor(
new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
return new LoggingClientCall<>(method, next.newCall(method, callOptions));
}

class LoggingClientCall<ReqT, RespT>
extends ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT> {

private final MethodDescriptor<ReqT, RespT> method;

LoggingClientCall(
MethodDescriptor<ReqT, RespT> method, ClientCall<ReqT, RespT> call) {
super(call);
this.method = method;
}

@Override
public void start(Listener<RespT> responseListener, Metadata headers) {
loggedRequests.add(
new GrpcRequest(
method.getBareMethodName(),
headers.get(
AuthorizationGrpcMetadataProvider.JWT_AUTHORIZATION_HEADER_KEY)));
super.start(responseListener, headers);
}
}
})
.addGrpcMetadataProvider(new AuthorizationGrpcMetadataProvider(() -> AUTH_TOKEN))
.build();

TestEnvironmentOptions options =
TestEnvironmentOptions.newBuilder()
.setWorkflowClientOptions(
WorkflowClientOptions.newBuilder()
.setContextPropagators(
Collections.singletonList(new WorkflowTestingTest.TestContextPropagator()))
.build())
.setWorkflowServiceStubsOptions(stubOptions)
.build();

testEnvironment = TestWorkflowEnvironment.newInstance(options);
}

@After
public void tearDown() {
testEnvironment.close();
}

@Test
public void allRequestsShouldHaveAnAuthToken() {
Worker worker = testEnvironment.newWorker(TASK_QUEUE);
worker.registerWorkflowImplementationTypes(EmptyWorkflowImpl.class);
testEnvironment.start();
WorkflowClient client = testEnvironment.getWorkflowClient();
WorkflowOptions options = WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build();
TestWorkflows.TestWorkflow1 workflow =
client.newWorkflowStub(TestWorkflows.TestWorkflow1.class, options);
String result = workflow.execute("input1");
assertEquals("TestWorkflow1-input1", result);

assertFalse(loggedRequests.isEmpty());
for (GrpcRequest grpcRequest : loggedRequests) {
assertEquals(
"All requests should have an auth token", AUTH_TOKEN, grpcRequest.authTokenValue);
}
}

public static class EmptyWorkflowImpl implements TestWorkflows.TestWorkflow1 {

@Override
public String execute(String input) {
Workflow.sleep(Duration.ofMinutes(5)); // test time skipping
return Workflow.getInfo().getWorkflowType() + "-" + input;
}
}

class GrpcRequest {
String methodName;
String authTokenValue;

public GrpcRequest(String methodName, String authTokenValue) {
this.methodName = methodName;
this.authTokenValue = authTokenValue;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved.
*
* Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Modifications copyright (C) 2017 Uber Technologies, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License is
* located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package io.temporal.authorization;

import io.grpc.Metadata;
import io.temporal.serviceclient.GrpcMetadataProvider;

public class AuthorizationGrpcMetadataProvider implements GrpcMetadataProvider {
public static final Metadata.Key<String> JWT_AUTHORIZATION_HEADER_KEY =
Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER);

private final AuthorizationTokenSupplier authorizationTokenSupplier;

protected AuthorizationGrpcMetadataProvider(
AuthorizationTokenSupplier authorizationTokenSupplier) {
this.authorizationTokenSupplier = authorizationTokenSupplier;
}

@Override
public Metadata getMetadata() {
Metadata metadata = new Metadata();
metadata.put(JWT_AUTHORIZATION_HEADER_KEY, authorizationTokenSupplier.supply());
return metadata;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved.
*
* Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Modifications copyright (C) 2017 Uber Technologies, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License is
* located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package io.temporal.authorization;

/**
* Supplies tokens that will be sent to the Temporal server to perform authorization.
*
* <p>The default JWT ClaimMapper expects authorization tokens to be in the following format:
*
* <p>{@code Bearer <token>}
*
* <p>{@code <token>} Must be the Base64 url-encoded value of the token.
*
* @see <a href="https://docs.temporal.io/docs/server/security/#format-of-json-web-tokens">Format of
* JWT</a>
*/
public interface AuthorizationTokenSupplier {
String supply();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved.
*
* Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Modifications copyright (C) 2017 Uber Technologies, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License is
* located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package io.temporal.serviceclient;

import io.grpc.Metadata;

/** Provides additional Metadata (gRPC Headers) that should be used on every request. */
public interface GrpcMetadataProvider {
Metadata getMetadata();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved.
*
* Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Modifications copyright (C) 2017 Uber Technologies, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License is
* located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package io.temporal.serviceclient;

import static com.google.common.base.Preconditions.checkNotNull;

import io.grpc.*;
import java.util.Collection;

public class GrpcMetadataProviderInterceptor implements ClientInterceptor {
private final Collection<GrpcMetadataProvider> grpcMetadataProviders;

public GrpcMetadataProviderInterceptor(Collection<GrpcMetadataProvider> grpcMetadataProviders) {
this.grpcMetadataProviders = checkNotNull(grpcMetadataProviders, "grpcMetadataProviders");
}

@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
return new HeaderAttachingClientCall<>(next.newCall(method, callOptions));
}

private final class HeaderAttachingClientCall<ReqT, RespT>
extends ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT> {

HeaderAttachingClientCall(ClientCall<ReqT, RespT> call) {
super(call);
}

@Override
public void start(Listener<RespT> responseListener, Metadata headers) {
grpcMetadataProviders.stream().map(GrpcMetadataProvider::getMetadata).forEach(headers::merge);
super.start(responseListener, headers);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import io.temporal.api.workflowservice.v1.WorkflowServiceGrpc;
import io.temporal.internal.retryer.GrpcRetryer;
import java.io.IOException;
import java.util.Collection;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -163,10 +164,22 @@ public WorkflowServiceStubsImpl(
healthBlockingStub = HealthGrpc.newBlockingStub(channel);
checkHealth();

Channel interceptedChannel = channel;

interceptedChannel = applyCustomInterceptors(interceptedChannel);
interceptedChannel = applyStandardInterceptors(interceptedChannel);

this.blockingStub = WorkflowServiceGrpc.newBlockingStub(interceptedChannel);
this.futureStub = WorkflowServiceGrpc.newFutureStub(interceptedChannel);
log.info(String.format("Created GRPC client for channel: %s", channel));
}

private Channel applyStandardInterceptors(Channel channel) {
GrpcMetricsInterceptor metricsInterceptor =
new GrpcMetricsInterceptor(options.getMetricsScope());
ClientInterceptor deadlineInterceptor = new GrpcDeadlineInterceptor(options);
GrpcTracingInterceptor tracingInterceptor = new GrpcTracingInterceptor();

Metadata headers = new Metadata();
headers.merge(options.getHeaders());
headers.put(LIBRARY_VERSION_HEADER_KEY, Version.LIBRARY_VERSION);
Expand All @@ -181,19 +194,28 @@ public WorkflowServiceStubsImpl(
if (tracingInterceptor.isEnabled()) {
interceptedChannel = ClientInterceptors.intercept(interceptedChannel, tracingInterceptor);
}
WorkflowServiceGrpc.WorkflowServiceBlockingStub bs =
WorkflowServiceGrpc.newBlockingStub(interceptedChannel);
if (options.getBlockingStubInterceptor().isPresent()) {
bs = options.getBlockingStubInterceptor().get().apply(bs);
interceptedChannel = applyGrpcMetadataProviderInterceptors(interceptedChannel);
return interceptedChannel;
}

private Channel applyGrpcMetadataProviderInterceptors(Channel channel) {
Collection<GrpcMetadataProvider> grpcMetadataProviders = options.getGrpcMetadataProviders();
if (grpcMetadataProviders != null && !grpcMetadataProviders.isEmpty()) {
GrpcMetadataProviderInterceptor grpcMetadataProviderInterceptor =
new GrpcMetadataProviderInterceptor(grpcMetadataProviders);
channel = ClientInterceptors.intercept(channel, grpcMetadataProviderInterceptor);
}
this.blockingStub = bs;
WorkflowServiceGrpc.WorkflowServiceFutureStub fs =
WorkflowServiceGrpc.newFutureStub(interceptedChannel);
if (options.getFutureStubInterceptor().isPresent()) {
fs = options.getFutureStubInterceptor().get().apply(fs);
return channel;
}

private Channel applyCustomInterceptors(Channel channel) {
Collection<ClientInterceptor> grpcClientInterceptors = options.getGrpcClientInterceptors();
if (grpcClientInterceptors != null) {
for (ClientInterceptor interceptor : grpcClientInterceptors) {
channel = ClientInterceptors.intercept(channel, interceptor);
}
}
this.futureStub = fs;
log.info(String.format("Created GRPC client for channel: %s", channel));
return channel;
}

private Runnable enterGrpcIdleChannelStateTask() {
Expand Down
Loading