Skip to content

Commit

Permalink
feat(tasks): add a generic feature toggle for skippable tasks (#4008)
Browse files Browse the repository at this point in the history
* feat(tasks): add a generic feature toggle for every task

This would give us a consistent way to disable any task (moving straight to SKIPPED) without the need for each task implementation to reinvent the wheel.

Also:
* move toggle computing logic from StartTaskHandler to SkippableTask (this makes this setting overridable)
* add documentation to SkippableTask
* fix QueueIntegrationTest and StartTaskHandlerTest (the behavior is slightly different now, the task resolver must point to a concrete DummyTask instance so that we can call its `isEnabledPropertyName()` method)
* opportunistically reduce mocking code duplication in QueueIntegrationTest
  • Loading branch information
dreynaud committed Nov 20, 2020
1 parent b20fc42 commit 57eed3f
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 51 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.netflix.spinnaker.orca.api.pipeline;

/**
* A skippable task can be configured via properties to go directly from NOT_STARTED to SKIPPED. By
* default, the property name is:
*
* <p>tasks.$taskId.enabled
*
* <p>where `taskId` corresponds to the simple class name (without the package) with a lower case
* first character. For example, a skippable class `com.foo.DummySkippableTask` could be disabled
* via property
*
* <p>tasks.dummySkippableTask.enabled
*
* @see StartTaskHandler
*/
public interface SkippableTask extends Task {
static String isEnabledPropertyName(String name) {
String loweredName = Character.toLowerCase(name.charAt(0)) + name.substring(1);
return String.format("tasks.%s.enabled", loweredName);
}

default String isEnabledPropertyName() {
return isEnabledPropertyName(
getClass().getSimpleName().isBlank() ? getClass().getName() : getClass().getSimpleName());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

@Component
public class WaitStage implements StageDefinitionBuilder {

public static String STAGE_TYPE = "wait";

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.netflix.spinnaker.kork.discovery.DiscoveryStatusChangeEvent
import com.netflix.spinnaker.kork.discovery.InstanceStatus
import com.netflix.spinnaker.kork.discovery.RemoteStatusChangedEvent
import com.netflix.spinnaker.orca.api.pipeline.CancellableStage
import com.netflix.spinnaker.orca.api.pipeline.SkippableTask
import com.netflix.spinnaker.orca.api.pipeline.SyntheticStageOwner
import com.netflix.spinnaker.orca.api.pipeline.SyntheticStageOwner.STAGE_BEFORE
import com.netflix.spinnaker.orca.api.pipeline.TaskResult
Expand Down Expand Up @@ -124,7 +125,13 @@ abstract class QueueIntegrationTest {
}

@After
fun resetMocks() = reset(dummyTask)
fun resetMocks() {
reset(dummyTask)
whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.getDynamicTimeout(any())) doReturn 2000L
whenever(dummyTask.isEnabledPropertyName) doReturn SkippableTask.isEnabledPropertyName(dummyTask.javaClass.simpleName)
}


@Test
fun `can run a simple pipeline`() {
Expand All @@ -137,8 +144,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.getDynamicTimeout(any())) doReturn 2000L
whenever(dummyTask.execute(any())) doReturn TaskResult.SUCCEEDED

context.runToCompletion(pipeline, runner::start, repository)
Expand All @@ -158,8 +163,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.getDynamicTimeout(any())) doReturn 2000L
whenever(dummyTask.execute(any())) doReturn TaskResult.RUNNING doReturn TaskResult.SUCCEEDED

context.runToCompletion(pipeline, runner::start, repository)
Expand Down Expand Up @@ -194,8 +197,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.getDynamicTimeout(any())) doReturn 2000L
whenever(dummyTask.execute(any())) doReturn TaskResult.SUCCEEDED

context.runToCompletion(pipeline, runner::start, repository)
Expand Down Expand Up @@ -235,8 +236,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.getDynamicTimeout(any())) doReturn 2000L
whenever(dummyTask.execute(any())) doReturn TaskResult.SUCCEEDED

context.runToCompletion(pipeline, runner::start, repository)
Expand Down Expand Up @@ -284,7 +283,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.execute(any())) doReturn TaskResult.ofStatus(TERMINAL)

context.runToCompletion(pipeline, runner::start, repository)
Expand Down Expand Up @@ -330,8 +328,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.getDynamicTimeout(any())) doReturn 2000L
whenever(dummyTask.execute(argThat { refId == "2a1" })) doReturn TaskResult.ofStatus(TERMINAL)
whenever(dummyTask.execute(argThat { refId != "2a1" })) doReturn TaskResult.SUCCEEDED

Expand Down Expand Up @@ -379,8 +375,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.getDynamicTimeout(any())) doReturn 2000L
whenever(dummyTask.execute(argThat { refId == "2a1" })) doReturn TaskResult.ofStatus(TERMINAL)
whenever(dummyTask.execute(argThat { refId != "2a1" })) doReturn TaskResult.SUCCEEDED

Expand Down Expand Up @@ -435,8 +429,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.getDynamicTimeout(any())) doReturn 2000L
whenever(dummyTask.execute(argThat { refId == "2a1" })) doReturn TaskResult.ofStatus(TERMINAL)
whenever(dummyTask.execute(argThat { refId != "2a1" })) doReturn TaskResult.SUCCEEDED

Expand Down Expand Up @@ -481,8 +473,8 @@ abstract class QueueIntegrationTest {
repository.store(childPipeline)
repository.store(parentPipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.execute(argThat { refId == "1b" })) doReturn TaskResult.ofStatus(CANCELED)

context.runParentToCompletion(parentPipeline, childPipeline, runner::start, repository)

repository.retrieve(PIPELINE, parentPipeline.id).apply {
Expand Down Expand Up @@ -524,7 +516,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.execute(argThat { refId == "2b" })) doReturn TaskResult.ofStatus(TERMINAL)

context.runToCompletion(pipeline, runner::start, repository)
Expand Down Expand Up @@ -576,8 +567,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.getDynamicTimeout(any())) doReturn 2000L
whenever(dummyTask.execute(any())) doReturn TaskResult.SUCCEEDED

context.runToCompletion(pipeline, runner::start, repository)
Expand Down Expand Up @@ -616,8 +605,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.getDynamicTimeout(any())) doReturn 2000L
whenever(dummyTask.execute(any())) doReturn TaskResult.SUCCEEDED

context.runToCompletion(pipeline, runner::start, repository)
Expand Down Expand Up @@ -656,8 +643,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.getDynamicTimeout(any())) doReturn 2000L
whenever(dummyTask.execute(any())) doReturn TaskResult.builder(SUCCEEDED).context(mapOf("output" to "foo")).build()

context.runToCompletion(pipeline, runner::start, repository)
Expand Down Expand Up @@ -708,8 +693,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.getDynamicTimeout(any())) doReturn 2000L
whenever(dummyTask.execute(any())) doReturn TaskResult.SUCCEEDED // second run succeeds

context.restartAndRunToCompletion(pipeline.stageByRef("1"), runner::restart, repository)
Expand Down Expand Up @@ -750,8 +733,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.getDynamicTimeout(any())) doReturn 2000L
whenever(dummyTask.execute(any())) doAnswer {
val stage = it.arguments.first() as StageExecution
if (stage.refId == "1") {
Expand Down Expand Up @@ -808,8 +789,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.getDynamicTimeout(any())) doReturn 2000L
whenever(dummyTask.execute(any())) doAnswer {
val stage = it.arguments.first() as StageExecution
if (stage.refId == "1") {
Expand Down Expand Up @@ -846,8 +825,6 @@ abstract class QueueIntegrationTest {
}
repository.store(pipeline)

whenever(dummyTask.extensionClass) doReturn dummyTask::class.java
whenever(dummyTask.getDynamicTimeout(any())) doReturn 2000L
whenever(dummyTask.execute(any())) doAnswer {
val stage = it.arguments.first() as StageExecution
if (stage.refId == "1") {
Expand Down Expand Up @@ -892,6 +869,7 @@ class TestConfig {
fun dummyTask(): DummyTask = mock {
on { extensionClass } doReturn DummyTask::class.java
on { getDynamicTimeout(any()) } doReturn Duration.ofMinutes(2).toMillis()
on { isEnabledPropertyName } doReturn SkippableTask.isEnabledPropertyName(DummyTask::class.java.simpleName)
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ package com.netflix.spinnaker.orca.q

import com.netflix.spinnaker.orca.api.pipeline.OverridableTimeoutRetryableTask
import com.netflix.spinnaker.orca.api.pipeline.RetryableTask
import com.netflix.spinnaker.orca.api.pipeline.SkippableTask
import com.netflix.spinnaker.orca.api.pipeline.Task
import com.netflix.spinnaker.orca.clouddriver.utils.CloudProviderAware

interface DummyTask : RetryableTask
interface DummyTask : RetryableTask, SkippableTask
interface DummyCloudProviderAwareTask : RetryableTask, CloudProviderAware
interface InvalidTask : Task
interface DummyTimeoutOverrideTask : OverridableTimeoutRetryableTask
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ class CompleteTaskHandler(
return false
}

// the task was not successful and _should not_ run subsequent tasks
return status != SUCCEEDED
// the task _should not_ run subsequent tasks
return status.isHalt
}

override val messageType = CompleteTask::class.java
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,24 @@
package com.netflix.spinnaker.orca.q.handler

import com.netflix.spinnaker.orca.TaskResolver
import com.netflix.spinnaker.orca.api.pipeline.SkippableTask
import com.netflix.spinnaker.orca.api.pipeline.models.ExecutionStatus.RUNNING
import com.netflix.spinnaker.orca.api.pipeline.models.ExecutionStatus.SKIPPED
import com.netflix.spinnaker.orca.api.pipeline.models.TaskExecution
import com.netflix.spinnaker.orca.events.TaskComplete
import com.netflix.spinnaker.orca.events.TaskStarted
import com.netflix.spinnaker.orca.pipeline.StageDefinitionBuilderFactory
import com.netflix.spinnaker.orca.pipeline.persistence.ExecutionRepository
import com.netflix.spinnaker.orca.pipeline.util.ContextParameterProcessor
import com.netflix.spinnaker.orca.q.CompleteTask
import com.netflix.spinnaker.orca.q.RunTask
import com.netflix.spinnaker.orca.q.StartTask
import com.netflix.spinnaker.q.Queue
import java.time.Clock
import org.springframework.beans.factory.annotation.Qualifier
import org.springframework.context.ApplicationEventPublisher
import org.springframework.core.env.Environment
import org.springframework.stereotype.Component
import java.time.Clock

@Component
class StartTaskHandler(
Expand All @@ -39,25 +44,50 @@ class StartTaskHandler(
override val stageDefinitionBuilderFactory: StageDefinitionBuilderFactory,
@Qualifier("queueEventPublisher") private val publisher: ApplicationEventPublisher,
private val taskResolver: TaskResolver,
private val clock: Clock
private val clock: Clock,
private val environment: Environment
) : OrcaMessageHandler<StartTask>, ExpressionAware {

override fun handle(message: StartTask) {
message.withTask { stage, task ->
task.status = RUNNING
task.startTime = clock.millis()
val mergedContextStage = stage.withMergedContext()
repository.storeStage(mergedContextStage)
if (isTaskEnabled(task)) {
task.status = RUNNING
task.startTime = clock.millis()
val mergedContextStage = stage.withMergedContext()
repository.storeStage(mergedContextStage)

queue.push(RunTask(message, task.id, task.type))
queue.push(RunTask(message, task.id, task.type))
publisher.publishEvent(TaskStarted(this, mergedContextStage, task))
} else {
task.status = SKIPPED
val mergedContextStage = stage.withMergedContext()
repository.storeStage(mergedContextStage)

publisher.publishEvent(TaskStarted(this, mergedContextStage, task))
queue.push(CompleteTask(message, SKIPPED))
publisher.publishEvent(TaskComplete(this, mergedContextStage, task))
}
}
}

fun isTaskEnabled(task: TaskExecution): Boolean =
when (task.instance) {
is SkippableTask -> {
val asSkippableTask = task.instance as SkippableTask
val enabled = environment.getProperty(asSkippableTask.isEnabledPropertyName, Boolean::class.java, true)
if (!enabled) {
log.debug("Skipping task.type=${task.type} because ${asSkippableTask.isEnabledPropertyName}=false")
}
enabled
}
else -> true
}

override val messageType = StartTask::class.java

@Suppress("UNCHECKED_CAST")
private val TaskExecution.type
get() = taskResolver.getTaskClass(implementingClass)

private val TaskExecution.instance
get() = taskResolver.getTask(implementingClass)
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.netflix.spinnaker.orca.api.pipeline.models.ExecutionStatus.FAILED_CON
import com.netflix.spinnaker.orca.api.pipeline.models.ExecutionStatus.NOT_STARTED
import com.netflix.spinnaker.orca.api.pipeline.models.ExecutionStatus.REDIRECT
import com.netflix.spinnaker.orca.api.pipeline.models.ExecutionStatus.SKIPPED
import com.netflix.spinnaker.orca.api.pipeline.models.ExecutionStatus.STOPPED
import com.netflix.spinnaker.orca.api.pipeline.models.ExecutionStatus.SUCCEEDED
import com.netflix.spinnaker.orca.api.pipeline.models.ExecutionStatus.TERMINAL
import com.netflix.spinnaker.orca.api.pipeline.models.ExecutionType.PIPELINE
Expand Down Expand Up @@ -79,7 +80,7 @@ object CompleteTaskHandlerTest : SubjectSpek<CompleteTaskHandler>({

fun resetMocks() = reset(queue, repository, publisher)

setOf(SUCCEEDED).forEach { successfulStatus ->
setOf(SUCCEEDED, SKIPPED).forEach { successfulStatus ->
describe("when a task completes with $successfulStatus status") {
given("the stage contains further tasks") {
val pipeline = pipeline {
Expand Down Expand Up @@ -256,7 +257,7 @@ object CompleteTaskHandlerTest : SubjectSpek<CompleteTaskHandler>({
}
}

setOf(TERMINAL, CANCELED).forEach { status ->
setOf(TERMINAL, CANCELED, STOPPED).forEach { status ->
describe("when a task completes with $status status") {
val pipeline = pipeline {
stage {
Expand Down
Loading

0 comments on commit 57eed3f

Please sign in to comment.