Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.temporal.common.interceptors.WorkflowOutboundCallsInterceptor;
import io.temporal.common.interceptors.WorkflowOutboundCallsInterceptorBase;
import io.temporal.failure.CanceledFailure;
import io.temporal.internal.WorkflowThreadMarker;
import io.temporal.internal.context.ContextThreadLocal;
import io.temporal.internal.replay.ExecuteActivityParameters;
import io.temporal.internal.replay.ExecuteLocalActivityParameters;
Expand Down Expand Up @@ -142,7 +143,13 @@ static Optional<WorkflowThread> currentThreadInternalIfPresent() {
}

static void setCurrentThreadInternal(WorkflowThread coroutine) {
currentThreadThreadLocal.set(coroutine);
if (coroutine != null) {
currentThreadThreadLocal.set(coroutine);
WorkflowThreadMarkerAccessor.markAsWorkflowThread();
} else {
currentThreadThreadLocal.set(null);
WorkflowThreadMarkerAccessor.markAsNonWorkflowThread();
}
}

/**
Expand Down Expand Up @@ -757,4 +764,14 @@ public int getAttempt() {
return 1;
}
}

private static class WorkflowThreadMarkerAccessor extends WorkflowThreadMarker {
public static void markAsWorkflowThread() {
isWorkflowThreadThreadLocal.set(true);
}

public static void markAsNonWorkflowThread() {
isWorkflowThreadThreadLocal.set(false);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

package io.temporal.internal.sync;

import static io.temporal.internal.WorkflowThreadMarker.enforceNonWorkflowThread;

import com.google.common.base.Strings;
import com.google.common.reflect.TypeToken;
import com.uber.m3.tally.Scope;
Expand All @@ -32,6 +34,7 @@
import io.temporal.common.converter.DataConverter;
import io.temporal.common.interceptors.WorkflowClientCallsInterceptor;
import io.temporal.common.interceptors.WorkflowClientInterceptor;
import io.temporal.internal.WorkflowThreadMarker;
import io.temporal.internal.client.RootWorkflowClientInvoker;
import io.temporal.internal.external.GenericWorkflowClientExternalImpl;
import io.temporal.internal.external.ManualActivityCompletionClientFactory;
Expand Down Expand Up @@ -62,15 +65,18 @@ public final class WorkflowClientInternal implements WorkflowClient {
private final Scope metricsScope;

/**
* Creates client that connects to an instance of the Temporal Service.
* Creates client that connects to an instance of the Temporal Service. Cannot be used from within
* workflow code.
*
* @param service client to the Temporal Service endpoint.
* @param options Options (like {@link io.temporal.common.converter.DataConverter} override) for
* configuring client.
*/
public static WorkflowClient newInstance(
WorkflowServiceStubs service, WorkflowClientOptions options) {
return new WorkflowClientInternal(service, options);
enforceNonWorkflowThread();
return WorkflowThreadMarker.protectFromWorkflowThread(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should consider adding an option that would disable this enforcement. If users are already running "bad" workflows in production, then there should be a way to disable this feature until workflows are refactored/updated, otherwise we are forcing users to stay on older SDK versions.

new WorkflowClientInternal(service, options), WorkflowClient.class);
}

private WorkflowClientInternal(
Expand Down Expand Up @@ -210,7 +216,9 @@ public WorkflowStub newUntypedWorkflowStub(
@Override
public ActivityCompletionClient newActivityCompletionClient() {
ActivityCompletionClient result =
new ActivityCompletionClientImpl(manualActivityCompletionClientFactory, () -> {});
WorkflowThreadMarker.protectFromWorkflowThread(
new ActivityCompletionClientImpl(manualActivityCompletionClientFactory, () -> {}),
ActivityCompletionClient.class);
for (WorkflowClientInterceptor i : interceptors) {
result = i.newActivityCompletionClient(result);
}
Expand All @@ -228,6 +236,7 @@ public WorkflowExecution signalWithStart(BatchRequest signalWithStartBatch) {
}

public static WorkflowExecution start(Functions.Proc workflow) {
enforceNonWorkflowThread();
WorkflowInvocationHandler.initAsyncInvocation(InvocationType.START);
try {
workflow.apply();
Expand Down Expand Up @@ -321,6 +330,7 @@ public static <A1, A2, A3, A4, A5, A6, R> WorkflowExecution start(

@SuppressWarnings("unchecked")
public static CompletableFuture<Void> execute(Functions.Proc workflow) {
enforceNonWorkflowThread();
WorkflowInvocationHandler.initAsyncInvocation(InvocationType.EXECUTE);
try {
workflow.apply();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved.
*
* Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Modifications copyright (C) 2017 Uber Technologies, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License is
* located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package io.temporal.workflow;

import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import io.temporal.client.WorkflowClient;
import io.temporal.workflow.shared.SDKTestWorkflowRule;
import io.temporal.workflow.shared.TestWorkflows.ITestNamedChild;
import io.temporal.workflow.shared.TestWorkflows.NoArgsWorkflow;
import io.temporal.workflow.shared.TestWorkflows.TestNamedChild;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;

public class ProhibitedCallsFromWorkflowTest {

@Rule
public SDKTestWorkflowRule testWorkflowRule =
SDKTestWorkflowRule.newBuilder()
.setWorkflowTypes(TestWorkflow.class, TestNamedChild.class)
.build();

private static WorkflowClient workflowClient;

@Before
public void setUp() throws Exception {
workflowClient = testWorkflowRule.getWorkflowClient();
}

@Test
public void testWorkflowClientCallFromWorkflow() {
NoArgsWorkflow client = testWorkflowRule.newWorkflowStubTimeoutOptions(NoArgsWorkflow.class);
client.execute();
}

public static class TestWorkflow implements NoArgsWorkflow {
@Override
public void execute() {
ITestNamedChild child = Workflow.newChildWorkflowStub(ITestNamedChild.class);
try {
WorkflowClient.execute(child::execute, "hello");
fail("should be unreachable, we expect an exception");
} catch (IllegalStateException e) {
assertTrue(e.getMessage().startsWith("Cannot be called from workflow thread."));
}
try {
WorkflowClient.start(child::execute, "world");
fail("should be unreachable, we expect an exception");
} catch (IllegalStateException e) {
assertTrue(e.getMessage().startsWith("Cannot be called from workflow thread."));
}
try {
// let's imagine that the workflow code somehow got a WorkflowClient instance (from DI for
// example).
// Let's make sure it still can't trigger it's methods
workflowClient.getOptions();
fail("should be unreachable, we expect an exception");
} catch (IllegalStateException e) {
assertTrue(e.getMessage().startsWith("Cannot be called from workflow thread."));
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package io.temporal.conf;

public final class EnvironmentVariableNames {
/**
* Specify this env variable to disable checks and enforcement for classes that are not intended
* to be accessed from workflow code.
*
* <p>Not specifying it or setting it to "false" (case insensitive) leaves the checks enforced.
*
* <p>This option is exposed for backwards compatibility only and should never be enabled for any
* new code or application.
*/
public static final String DISABLE_NON_WORKFLOW_CODE_ENFORCEMENTS =
"TEMPORAL_DISABLE_NON_WORKFLOW_CODE_ENFORCEMENTS";

private EnvironmentVariableNames() {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright (C) 2020 Temporal Technologies, Inc. All Rights Reserved.
*
* Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Modifications copyright (C) 2017 Uber Technologies, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License is
* located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package io.temporal.internal;

import io.temporal.conf.EnvironmentVariableNames;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;

/**
* Provides an access to information about a thread type the current code executes in to perform
* different type of access checks inside Temporal library code.
*
* <p>Note: This class is a singleton and is not intended for an extension.
*
* <p>Note: This class shouldn't be accessed in any way by the application code.
*/
public abstract class WorkflowThreadMarker {
protected static final ThreadLocal<Boolean> isWorkflowThreadThreadLocal =
ThreadLocal.withInitial(() -> false);

private static final boolean enableEnforcements;

static {
String envValue =
System.getenv(EnvironmentVariableNames.DISABLE_NON_WORKFLOW_CODE_ENFORCEMENTS);
enableEnforcements = envValue == null || "false".equalsIgnoreCase(envValue);
}

/** @return true if the current thread is workflow thread */
public static boolean isWorkflowThread() {
return isWorkflowThreadThreadLocal.get();
}

/**
* Throws {@link IllegalStateException} if it's called from workflow thread.
*
* @see io.temporal.conf.EnvironmentVariableNames#DISABLE_NON_WORKFLOW_CODE_ENFORCEMENTS
*/
public static void enforceNonWorkflowThread() {
if (enableEnforcements && isWorkflowThread()) {
throw new IllegalStateException("Cannot be called from workflow thread.");
}
}

/**
* Create a proxy that checks all methods executions if they are done from a workflow thread and
* makes them throw an IllegalStateException if they are indeed triggered from workflow code
*
* @param instance an instance to wrap
* @param iface an interface the {@code instance} implements and that proxy should implement and
* intercept
* @return a proxy that makes sure that it's methods can't be called from workflow thread
*/
@SuppressWarnings("unchecked")
public static <T> T protectFromWorkflowThread(T instance, Class<T> iface) {
return (T)
Proxy.newProxyInstance(
iface.getClassLoader(),
new Class<?>[] {iface},
(proxy, method, args) -> {
enforceNonWorkflowThread();
try {
return method.invoke(instance, args);
} catch (InvocationTargetException e) {
throw e.getCause();
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

package io.temporal.serviceclient;

import static io.temporal.internal.WorkflowThreadMarker.enforceNonWorkflowThread;

import io.temporal.api.workflowservice.v1.WorkflowServiceGrpc;
import io.temporal.internal.WorkflowThreadMarker;
import java.util.concurrent.TimeUnit;

/** Initializes and holds gRPC blocking and future stubs. */
Expand All @@ -30,12 +33,12 @@ public interface WorkflowServiceStubs {
* the locally running temporal service.
*/
static WorkflowServiceStubs newInstance() {
return new WorkflowServiceStubsImpl(null, WorkflowServiceStubsOptions.getDefaultInstance());
return newInstance(WorkflowServiceStubsOptions.getDefaultInstance());
}

/** Create gRPC connection stubs using provided options. */
static WorkflowServiceStubs newInstance(WorkflowServiceStubsOptions options) {
return new WorkflowServiceStubsImpl(null, options);
return newInstance(null, options);
}

/**
Expand All @@ -44,7 +47,9 @@ static WorkflowServiceStubs newInstance(WorkflowServiceStubsOptions options) {
*/
static WorkflowServiceStubs newInstance(
WorkflowServiceGrpc.WorkflowServiceImplBase service, WorkflowServiceStubsOptions options) {
return new WorkflowServiceStubsImpl(service, options);
enforceNonWorkflowThread();
return WorkflowThreadMarker.protectFromWorkflowThread(
new WorkflowServiceStubsImpl(service, options), WorkflowServiceStubs.class);
}

/** @return Blocking (synchronous) stub that allows direct calls to service. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,16 @@

package io.temporal.internal.sync;

import com.google.common.annotations.VisibleForTesting;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

@VisibleForTesting
public class DeterministicRunnerWrapper implements InvocationHandler {

private final InvocationHandler invocationHandler;

@VisibleForTesting
public DeterministicRunnerWrapper(InvocationHandler invocationHandler) {
this.invocationHandler = Objects.requireNonNull(invocationHandler);
}
Expand Down
Loading