Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 57 additions & 33 deletions internal/tasks/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,49 +5,62 @@ import (
"fmt"
"os"
"os/signal"
"reflect"
)

type Task func(ctx context.Context, args interface{}) (nextArgs interface{}, err error)
type TaskWithCleanup[T any] func(ctx context.Context, args interface{}) (nextArgs interface{}, cleanupArgs T, err error)
type Cleanup[T any] func(ctx context.Context, cleanupArgs T) error
type TaskFunc[T any, U any] func(t *Task, args T) (nextArgs U, err error)
type CleanupFunc func(ctx context.Context) error

type taskInfo struct {
Name string
function TaskWithCleanup[any]
cleanFunction Cleanup[any]
cleanupArgs interface{}
type Task struct {
Name string
Ctx context.Context

taskFunction TaskFunc[any, any]
argType reflect.Type
returnType reflect.Type
cleanFunctions []CleanupFunc
}

type Tasks struct {
tasks []taskInfo
tasks []Task
}

func Begin() *Tasks {
return &Tasks{}
}

// Add a task that does not need cleanup
func (ts *Tasks) Add(name string, task Task) {
ts.tasks = append(ts.tasks, taskInfo{
Name: name,
function: func(ctx context.Context, i interface{}) (passedData interface{}, cleanUpData interface{}, err error) {
passedData, err = task(ctx, i)
func Add[TaskArg any, TaskReturn any](ts *Tasks, name string, taskFunc TaskFunc[TaskArg, TaskReturn]) {
var argValue TaskArg
var returnValue TaskReturn
argType := reflect.TypeOf(argValue)
returnType := reflect.TypeOf(returnValue)

tasksAmount := len(ts.tasks)
if tasksAmount > 0 {
lastTask := &ts.tasks[tasksAmount-1]
if argType != lastTask.returnType {
panic(fmt.Errorf("invalid task declared, wait for %s, previous task returns %s", argType.Name(), lastTask.returnType.Name()))
}
}

ts.tasks = append(ts.tasks, Task{
Name: name,
argType: argType,
returnType: returnType,
taskFunction: func(t *Task, i interface{}) (passedData interface{}, err error) {
if i == nil {
var zero TaskArg
passedData, err = taskFunc(t, zero)
} else {
passedData, err = taskFunc(t, i.(TaskArg))
}
return
},
})
}

// AddWithCleanUp adds a task to the list with a cleanup function in case of fail during tasks execution
func AddWithCleanUp[T any](ts *Tasks, name string, task TaskWithCleanup[T], clean Cleanup[T]) {
ts.tasks = append(ts.tasks, taskInfo{
Name: name,
function: func(ctx context.Context, args interface{}) (nextArgs interface{}, cleanUpArgs any, err error) {
return task(ctx, args)
},
cleanFunction: func(ctx context.Context, cleanupArgs any) error {
return clean(ctx, cleanupArgs.(T))
},
})
func (t *Task) AddToCleanUp(cleanupFunc CleanupFunc) {
t.cleanFunctions = append(t.cleanFunctions, cleanupFunc)
}

// setupContext return a contextWithCancel that will cancel on os interrupt (Ctrl-C)
Expand All @@ -73,14 +86,17 @@ func (ts *Tasks) Cleanup(ctx context.Context, failed int) {
default:
}

if task.cleanFunction != nil {
if len(task.cleanFunctions) != 0 {
fmt.Printf("[%d/%d] Cleaning task %q\n", i+1, totalTasks, task.Name)
loader.Start()

err := task.cleanFunction(cancelableCtx, task.cleanupArgs)
if err != nil {
fmt.Printf("task %d failed to cleanup, there may be dangling resources: %s\n", i+1, err.Error())
for _, cleanUpFunc := range task.cleanFunctions {
err := cleanUpFunc(cancelableCtx)
if err != nil {
fmt.Printf("task %d failed to cleanup, there may be dangling resources: %s\n", i+1, err.Error())
}
}

loader.Stop()
}
}
Expand All @@ -97,19 +113,27 @@ func (ts *Tasks) Execute(ctx context.Context, data interface{}) (interface{}, er

for i := range ts.tasks {
task := &ts.tasks[i]
// Add context and reset cleanup functions, allows to execute multiple times
task.Ctx = cancelableCtx
task.cleanFunctions = []CleanupFunc(nil)

fmt.Printf("[%d/%d] %s\n", i+1, totalTasks, task.Name)
loader.Start()

data, task.cleanupArgs, err = task.function(cancelableCtx, data)
data, err = task.taskFunction(task, data)
taskIsCancelled := false
select {
case <-cancelableCtx.Done():
taskIsCancelled = true
default:
}
if err != nil || taskIsCancelled {
if err != nil {
loader.Stop()
fmt.Println("task failed, cleaning up created resources")
if taskIsCancelled {
fmt.Println("task canceled, cleaning up created resources")
} else {
fmt.Println("task failed, cleaning up created resources")
}
ts.Cleanup(ctx, i)
return nil, fmt.Errorf("task %d %q failed: %w", i+1, task.Name, err)
}
Expand Down
138 changes: 89 additions & 49 deletions internal/tasks/tasks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,78 @@ import (
"fmt"
"os"
"runtime"
"strconv"
"strings"
"testing"
"time"

"github.com/alecthomas/assert"
"github.com/scaleway/scaleway-cli/v2/internal/tasks"
)

func TestGeneric(t *testing.T) {
ts := tasks.Begin()

tasks.Add(ts, "convert int to string", func(t *tasks.Task, args int) (nextArgs string, err error) {
return fmt.Sprintf("%d", args), nil
})
tasks.Add(ts, "convert string to int and divide by 4", func(t *tasks.Task, args string) (nextArgs int, err error) {
i, err := strconv.ParseInt(args, 10, 32)
if err != nil {
return 0, err
}
return int(i) / 4, nil
})

res, err := ts.Execute(context.Background(), 12)
assert.Nil(t, err)
assert.Equal(t, 3, res)
}

func TestInvalidGeneric(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("The code did not panic")
}
}()

ts := tasks.Begin()

tasks.Add(ts, "convert int to string", func(t *tasks.Task, args int) (nextArgs string, err error) {
return fmt.Sprintf("%d", args), nil
})
tasks.Add(ts, "divide by 4", func(t *tasks.Task, args int) (nextArgs int, err error) {
return args / 4, nil
})
}

func TestCleanup(t *testing.T) {
ts := tasks.Begin()

clean := 0

tasks.AddWithCleanUp(ts, "Task 1", func(context.Context, interface{}) (interface{}, string, error) {
return nil, "", nil
}, func(context.Context, string) error {
clean++
return nil
tasks.Add(ts, "TaskFunc 1", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) {
task.AddToCleanUp(func(ctx context.Context) error {
clean++
return nil
})
return nil, nil
})
tasks.AddWithCleanUp(ts, "Task 2", func(context.Context, interface{}) (interface{}, string, error) {
return nil, "", nil
}, func(context.Context, string) error {
clean++
return nil
tasks.Add(ts, "TaskFunc 2", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) {
task.AddToCleanUp(func(ctx context.Context) error {
clean++
return nil
})
return nil, nil
})
tasks.AddWithCleanUp(ts, "Task 3", func(context.Context, interface{}) (interface{}, string, error) {
return nil, "", fmt.Errorf("fail")
}, func(context.Context, string) error {
clean++
return nil
tasks.Add(ts, "TaskFunc 3", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) {
task.AddToCleanUp(func(ctx context.Context) error {
clean++
return nil
})
return nil, fmt.Errorf("fail")
})

_, err := ts.Execute(context.Background(), nil)
assert.NotNil(t, err, "Execute should return error after cleanup")
assert.Equal(t, clean, 2, "2 task cleanup should have been executed")
Expand All @@ -49,48 +91,46 @@ func TestCleanupOnContext(t *testing.T) {
clean := 0
ctx := context.Background()

tasks.AddWithCleanUp(ts, "Task 1",
func(context.Context, interface{}) (interface{}, string, error) {
return nil, "", nil
}, func(context.Context, string) error {
tasks.Add(ts, "TaskFunc 1", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) {
task.AddToCleanUp(func(ctx context.Context) error {
clean++
return nil
},
)
tasks.AddWithCleanUp(ts, "Task 2",
func(context.Context, interface{}) (interface{}, string, error) {
return nil, "", nil
}, func(context.Context, string) error {
})
return nil, nil
})
tasks.Add(ts, "TaskFunc 2", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) {
task.AddToCleanUp(func(ctx context.Context) error {
clean++
return nil
},
)
tasks.AddWithCleanUp(ts, "Task 3",
func(ctx context.Context, _ interface{}) (interface{}, string, error) {
p, err := os.FindProcess(os.Getpid())
if err != nil {
return nil, "", err
}

// Interrupt tasks, as done with Ctrl-C
err = p.Signal(os.Interrupt)
if err != nil {
t.Fatal(err)
}

select {
case <-time.After(time.Second):
return nil, "", nil
case <-ctx.Done():
return nil, "", fmt.Errorf("interrupted")
}
}, func(context.Context, string) error {
})
return nil, nil
})
tasks.Add(ts, "TaskFunc 3", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) {
task.AddToCleanUp(func(ctx context.Context) error {
clean++
return nil
},
)
})
p, err := os.FindProcess(os.Getpid())
if err != nil {
return nil, err
}

// Interrupt tasks, as done with Ctrl-C
err = p.Signal(os.Interrupt)
if err != nil {
t.Fatal(err)
}

select {
case <-task.Ctx.Done():
return nil, fmt.Errorf("interrupted")
case <-time.After(time.Second * 3):
return nil, nil
}
})

_, err := ts.Execute(ctx, nil)
assert.NotNil(t, err, "context should have been interrupted")
assert.True(t, strings.Contains(err.Error(), "interrupted"), "error is not interrupted: %s", err)
assert.Equal(t, clean, 2, "2 task cleanup should have been executed")
}