diff --git a/internal/tasks/tasks.go b/internal/tasks/tasks.go index 3590101800..bdee8531a6 100644 --- a/internal/tasks/tasks.go +++ b/internal/tasks/tasks.go @@ -5,49 +5,62 @@ import ( "fmt" "os" "os/signal" + "reflect" ) -type Task func(ctx context.Context, args interface{}) (nextArgs interface{}, err error) -type TaskWithCleanup[T any] func(ctx context.Context, args interface{}) (nextArgs interface{}, cleanupArgs T, err error) -type Cleanup[T any] func(ctx context.Context, cleanupArgs T) error +type TaskFunc[T any, U any] func(t *Task, args T) (nextArgs U, err error) +type CleanupFunc func(ctx context.Context) error -type taskInfo struct { - Name string - function TaskWithCleanup[any] - cleanFunction Cleanup[any] - cleanupArgs interface{} +type Task struct { + Name string + Ctx context.Context + + taskFunction TaskFunc[any, any] + argType reflect.Type + returnType reflect.Type + cleanFunctions []CleanupFunc } type Tasks struct { - tasks []taskInfo + tasks []Task } func Begin() *Tasks { return &Tasks{} } -// Add a task that does not need cleanup -func (ts *Tasks) Add(name string, task Task) { - ts.tasks = append(ts.tasks, taskInfo{ - Name: name, - function: func(ctx context.Context, i interface{}) (passedData interface{}, cleanUpData interface{}, err error) { - passedData, err = task(ctx, i) +func Add[TaskArg any, TaskReturn any](ts *Tasks, name string, taskFunc TaskFunc[TaskArg, TaskReturn]) { + var argValue TaskArg + var returnValue TaskReturn + argType := reflect.TypeOf(argValue) + returnType := reflect.TypeOf(returnValue) + + tasksAmount := len(ts.tasks) + if tasksAmount > 0 { + lastTask := &ts.tasks[tasksAmount-1] + if argType != lastTask.returnType { + panic(fmt.Errorf("invalid task declared, wait for %s, previous task returns %s", argType.Name(), lastTask.returnType.Name())) + } + } + + ts.tasks = append(ts.tasks, Task{ + Name: name, + argType: argType, + returnType: returnType, + taskFunction: func(t *Task, i interface{}) (passedData interface{}, err error) { + if i == nil { + var zero TaskArg + passedData, err = taskFunc(t, zero) + } else { + passedData, err = taskFunc(t, i.(TaskArg)) + } return }, }) } -// AddWithCleanUp adds a task to the list with a cleanup function in case of fail during tasks execution -func AddWithCleanUp[T any](ts *Tasks, name string, task TaskWithCleanup[T], clean Cleanup[T]) { - ts.tasks = append(ts.tasks, taskInfo{ - Name: name, - function: func(ctx context.Context, args interface{}) (nextArgs interface{}, cleanUpArgs any, err error) { - return task(ctx, args) - }, - cleanFunction: func(ctx context.Context, cleanupArgs any) error { - return clean(ctx, cleanupArgs.(T)) - }, - }) +func (t *Task) AddToCleanUp(cleanupFunc CleanupFunc) { + t.cleanFunctions = append(t.cleanFunctions, cleanupFunc) } // setupContext return a contextWithCancel that will cancel on os interrupt (Ctrl-C) @@ -73,14 +86,17 @@ func (ts *Tasks) Cleanup(ctx context.Context, failed int) { default: } - if task.cleanFunction != nil { + if len(task.cleanFunctions) != 0 { fmt.Printf("[%d/%d] Cleaning task %q\n", i+1, totalTasks, task.Name) loader.Start() - err := task.cleanFunction(cancelableCtx, task.cleanupArgs) - if err != nil { - fmt.Printf("task %d failed to cleanup, there may be dangling resources: %s\n", i+1, err.Error()) + for _, cleanUpFunc := range task.cleanFunctions { + err := cleanUpFunc(cancelableCtx) + if err != nil { + fmt.Printf("task %d failed to cleanup, there may be dangling resources: %s\n", i+1, err.Error()) + } } + loader.Stop() } } @@ -97,19 +113,27 @@ func (ts *Tasks) Execute(ctx context.Context, data interface{}) (interface{}, er for i := range ts.tasks { task := &ts.tasks[i] + // Add context and reset cleanup functions, allows to execute multiple times + task.Ctx = cancelableCtx + task.cleanFunctions = []CleanupFunc(nil) + fmt.Printf("[%d/%d] %s\n", i+1, totalTasks, task.Name) loader.Start() - data, task.cleanupArgs, err = task.function(cancelableCtx, data) + data, err = task.taskFunction(task, data) taskIsCancelled := false select { case <-cancelableCtx.Done(): taskIsCancelled = true default: } - if err != nil || taskIsCancelled { + if err != nil { loader.Stop() - fmt.Println("task failed, cleaning up created resources") + if taskIsCancelled { + fmt.Println("task canceled, cleaning up created resources") + } else { + fmt.Println("task failed, cleaning up created resources") + } ts.Cleanup(ctx, i) return nil, fmt.Errorf("task %d %q failed: %w", i+1, task.Name, err) } diff --git a/internal/tasks/tasks_test.go b/internal/tasks/tasks_test.go index b4912d6f85..c79d270140 100644 --- a/internal/tasks/tasks_test.go +++ b/internal/tasks/tasks_test.go @@ -5,6 +5,8 @@ import ( "fmt" "os" "runtime" + "strconv" + "strings" "testing" "time" @@ -12,29 +14,69 @@ import ( "github.com/scaleway/scaleway-cli/v2/internal/tasks" ) +func TestGeneric(t *testing.T) { + ts := tasks.Begin() + + tasks.Add(ts, "convert int to string", func(t *tasks.Task, args int) (nextArgs string, err error) { + return fmt.Sprintf("%d", args), nil + }) + tasks.Add(ts, "convert string to int and divide by 4", func(t *tasks.Task, args string) (nextArgs int, err error) { + i, err := strconv.ParseInt(args, 10, 32) + if err != nil { + return 0, err + } + return int(i) / 4, nil + }) + + res, err := ts.Execute(context.Background(), 12) + assert.Nil(t, err) + assert.Equal(t, 3, res) +} + +func TestInvalidGeneric(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } + }() + + ts := tasks.Begin() + + tasks.Add(ts, "convert int to string", func(t *tasks.Task, args int) (nextArgs string, err error) { + return fmt.Sprintf("%d", args), nil + }) + tasks.Add(ts, "divide by 4", func(t *tasks.Task, args int) (nextArgs int, err error) { + return args / 4, nil + }) +} + func TestCleanup(t *testing.T) { ts := tasks.Begin() clean := 0 - tasks.AddWithCleanUp(ts, "Task 1", func(context.Context, interface{}) (interface{}, string, error) { - return nil, "", nil - }, func(context.Context, string) error { - clean++ - return nil + tasks.Add(ts, "TaskFunc 1", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) { + task.AddToCleanUp(func(ctx context.Context) error { + clean++ + return nil + }) + return nil, nil }) - tasks.AddWithCleanUp(ts, "Task 2", func(context.Context, interface{}) (interface{}, string, error) { - return nil, "", nil - }, func(context.Context, string) error { - clean++ - return nil + tasks.Add(ts, "TaskFunc 2", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) { + task.AddToCleanUp(func(ctx context.Context) error { + clean++ + return nil + }) + return nil, nil }) - tasks.AddWithCleanUp(ts, "Task 3", func(context.Context, interface{}) (interface{}, string, error) { - return nil, "", fmt.Errorf("fail") - }, func(context.Context, string) error { - clean++ - return nil + tasks.Add(ts, "TaskFunc 3", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) { + task.AddToCleanUp(func(ctx context.Context) error { + clean++ + return nil + }) + return nil, fmt.Errorf("fail") }) + _, err := ts.Execute(context.Background(), nil) assert.NotNil(t, err, "Execute should return error after cleanup") assert.Equal(t, clean, 2, "2 task cleanup should have been executed") @@ -49,48 +91,46 @@ func TestCleanupOnContext(t *testing.T) { clean := 0 ctx := context.Background() - tasks.AddWithCleanUp(ts, "Task 1", - func(context.Context, interface{}) (interface{}, string, error) { - return nil, "", nil - }, func(context.Context, string) error { + tasks.Add(ts, "TaskFunc 1", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) { + task.AddToCleanUp(func(ctx context.Context) error { clean++ return nil - }, - ) - tasks.AddWithCleanUp(ts, "Task 2", - func(context.Context, interface{}) (interface{}, string, error) { - return nil, "", nil - }, func(context.Context, string) error { + }) + return nil, nil + }) + tasks.Add(ts, "TaskFunc 2", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) { + task.AddToCleanUp(func(ctx context.Context) error { clean++ return nil - }, - ) - tasks.AddWithCleanUp(ts, "Task 3", - func(ctx context.Context, _ interface{}) (interface{}, string, error) { - p, err := os.FindProcess(os.Getpid()) - if err != nil { - return nil, "", err - } - - // Interrupt tasks, as done with Ctrl-C - err = p.Signal(os.Interrupt) - if err != nil { - t.Fatal(err) - } - - select { - case <-time.After(time.Second): - return nil, "", nil - case <-ctx.Done(): - return nil, "", fmt.Errorf("interrupted") - } - }, func(context.Context, string) error { + }) + return nil, nil + }) + tasks.Add(ts, "TaskFunc 3", func(task *tasks.Task, args interface{}) (nextArgs interface{}, err error) { + task.AddToCleanUp(func(ctx context.Context) error { clean++ return nil - }, - ) + }) + p, err := os.FindProcess(os.Getpid()) + if err != nil { + return nil, err + } + + // Interrupt tasks, as done with Ctrl-C + err = p.Signal(os.Interrupt) + if err != nil { + t.Fatal(err) + } + + select { + case <-task.Ctx.Done(): + return nil, fmt.Errorf("interrupted") + case <-time.After(time.Second * 3): + return nil, nil + } + }) _, err := ts.Execute(ctx, nil) assert.NotNil(t, err, "context should have been interrupted") + assert.True(t, strings.Contains(err.Error(), "interrupted"), "error is not interrupted: %s", err) assert.Equal(t, clean, 2, "2 task cleanup should have been executed") }