Skip to content

Commit

Permalink
chore: properly propagate context object in the controller
Browse files Browse the repository at this point in the history
This is required to correctly handle ACPI reboot or forceful reboots
during sequence that locks the controller.
Additionally fix `NoSchedule` untaint when the configuration is changed.

Signed-off-by: Artem Chernyshev <artem.0xD2@gmail.com>
  • Loading branch information
Unix4ever committed Mar 3, 2021
1 parent 60aa011 commit 638af35
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (s *Server) ApplyConfiguration(ctx context.Context, in *machine.ApplyConfig
}

go func() {
if err := s.Controller.Run(runtime.SequenceApplyConfiguration, in); err != nil {
if err := s.Controller.Run(context.Background(), runtime.SequenceApplyConfiguration, in); err != nil {
if !runtime.IsRebootError(err) {
log.Println("apply configuration failed:", err)
}
Expand Down Expand Up @@ -198,7 +198,7 @@ func (s *Server) Reboot(ctx context.Context, in *empty.Empty) (reply *machine.Re
}

go func() {
if err := s.Controller.Run(runtime.SequenceReboot, in); err != nil {
if err := s.Controller.Run(context.Background(), runtime.SequenceReboot, in); err != nil {
if !runtime.IsRebootError(err) {
log.Println("reboot failed:", err)
}
Expand Down Expand Up @@ -264,7 +264,7 @@ func (s *Server) Rollback(ctx context.Context, in *machine.RollbackRequest) (*ma
}

go func() {
if err := s.Controller.Run(runtime.SequenceReboot, in, runtime.WithForce()); err != nil {
if err := s.Controller.Run(context.Background(), runtime.SequenceReboot, in, runtime.WithForce()); err != nil {
if !runtime.IsRebootError(err) {
log.Println("reboot failed:", err)
}
Expand Down Expand Up @@ -295,7 +295,7 @@ func (s *Server) Bootstrap(ctx context.Context, in *machine.BootstrapRequest) (r
}

go func() {
if err := s.Controller.Run(runtime.SequenceBootstrap, in); err != nil {
if err := s.Controller.Run(context.Background(), runtime.SequenceBootstrap, in); err != nil {
log.Println("bootstrap failed:", err)

if err != runtime.ErrLocked {
Expand Down Expand Up @@ -326,7 +326,7 @@ func (s *Server) Shutdown(ctx context.Context, in *empty.Empty) (reply *machine.
}

go func() {
if err := s.Controller.Run(runtime.SequenceShutdown, in); err != nil {
if err := s.Controller.Run(context.Background(), runtime.SequenceShutdown, in); err != nil {
if !runtime.IsRebootError(err) {
log.Println("shutdown failed:", err)
}
Expand Down Expand Up @@ -426,7 +426,7 @@ func (s *Server) Upgrade(ctx context.Context, in *machine.UpgradeRequest) (reply
defer mu.Unlock(ctx) // nolint: errcheck
}

if err := s.Controller.Run(runtime.SequenceStageUpgrade, in); err != nil {
if err := s.Controller.Run(context.Background(), runtime.SequenceStageUpgrade, in); err != nil {
if !runtime.IsRebootError(err) {
log.Println("reboot for staged upgrade failed:", err)
}
Expand All @@ -444,7 +444,7 @@ func (s *Server) Upgrade(ctx context.Context, in *machine.UpgradeRequest) (reply
defer mu.Unlock(ctx) // nolint: errcheck
}

if err := s.Controller.Run(runtime.SequenceUpgrade, in); err != nil {
if err := s.Controller.Run(context.Background(), runtime.SequenceUpgrade, in); err != nil {
if !runtime.IsRebootError(err) {
log.Println("upgrade failed:", err)
}
Expand Down Expand Up @@ -543,7 +543,7 @@ func (s *Server) Reset(ctx context.Context, in *machine.ResetRequest) (reply *ma
}

go func() {
if err := s.Controller.Run(runtime.SequenceReset, &opts); err != nil {
if err := s.Controller.Run(context.Background(), runtime.SequenceReset, &opts); err != nil {
if !runtime.IsRebootError(err) {
log.Println("reset failed:", err)
}
Expand Down
8 changes: 4 additions & 4 deletions internal/app/machined/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func run() error {

// Start signal and ACPI listeners.
go func() {
if e := c.ListenForEvents(); e != nil {
if e := c.ListenForEvents(ctx); e != nil {
log.Printf("WARNING: signals and ACPI events will be ignored: %s", e)
}
}()
Expand All @@ -218,20 +218,20 @@ func run() error {
}()

// Initialize the machine.
if err = c.Run(runtime.SequenceInitialize, nil); err != nil {
if err = c.Run(ctx, runtime.SequenceInitialize, nil); err != nil {
return err
}

// Perform an installation if required.
if err = c.Run(runtime.SequenceInstall, nil); err != nil {
if err = c.Run(ctx, runtime.SequenceInstall, nil); err != nil {
return err
}

// Start the machine API.
system.Services(c.Runtime()).LoadAndStart(&services.Machined{Controller: c})

// Boot the machine.
if err = c.Run(runtime.SequenceBoot, nil); err != nil {
if err = c.Run(ctx, runtime.SequenceBoot, nil); err != nil {
return err
}

Expand Down
2 changes: 1 addition & 1 deletion internal/app/machined/pkg/runtime/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func DefaultControllerOptions() ControllerOptions {
type Controller interface {
Runtime() Runtime
Sequencer() Sequencer
Run(Sequence, interface{}, ...ControllerOption) error
Run(context.Context, Sequence, interface{}, ...ControllerOption) error
V1Alpha2() V1Alpha2Controller
}

Expand Down
28 changes: 17 additions & 11 deletions internal/app/machined/pkg/runtime/v1alpha1/v1alpha1_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func NewController(b []byte) (*Controller, error) {

// Run executes all phases known to the controller in serial. `Controller`
// aborts immediately if any phase fails.
func (c *Controller) Run(seq runtime.Sequence, data interface{}, setters ...runtime.ControllerOption) error {
func (c *Controller) Run(ctx context.Context, seq runtime.Sequence, data interface{}, setters ...runtime.ControllerOption) error {
// We must ensure that the runtime is configured since all sequences depend
// on the runtime.
if c.r == nil {
Expand Down Expand Up @@ -129,7 +129,7 @@ func (c *Controller) Run(seq runtime.Sequence, data interface{}, setters ...runt
return err
}

err = c.run(seq, phases, data)
err = c.run(ctx, seq, phases, data)
if err != nil {
c.Runtime().Events().Publish(&machine.SequenceEvent{
Sequence: seq.String(),
Expand Down Expand Up @@ -163,7 +163,7 @@ func (c *Controller) Sequencer() runtime.Sequencer {

// ListenForEvents starts the event listener. The listener will trigger a
// shutdown in response to a SIGTERM signal and ACPI button/power event.
func (c *Controller) ListenForEvents() error {
func (c *Controller) ListenForEvents(ctx context.Context) error {
sigs := make(chan os.Signal, 1)

signal.Notify(sigs, syscall.SIGTERM)
Expand All @@ -176,7 +176,7 @@ func (c *Controller) ListenForEvents() error {

log.Printf("shutdown via SIGTERM received")

if err := c.Run(runtime.SequenceShutdown, nil); err != nil {
if err := c.Run(ctx, runtime.SequenceShutdown, nil); err != nil {
log.Printf("shutdown failed: %v", err)
}

Expand All @@ -198,7 +198,7 @@ func (c *Controller) ListenForEvents() error {

// TODO: The sequencer lock will prevent this. We need a way to force the
// shutdown.
if err := c.Run(runtime.SequenceShutdown, nil); err != nil {
if err := c.Run(ctx, runtime.SequenceShutdown, nil); err != nil {
log.Printf("shutdown failed: %v", err)
}

Expand All @@ -222,7 +222,7 @@ func (c *Controller) Unlock() bool {
return atomic.CompareAndSwapInt32(&c.semaphore, 1, 0)
}

func (c *Controller) run(seq runtime.Sequence, phases []runtime.Phase, data interface{}) error {
func (c *Controller) run(ctx context.Context, seq runtime.Sequence, phases []runtime.Phase, data interface{}) error {
c.Runtime().Events().Publish(&machine.SequenceEvent{
Sequence: seq.String(),
Action: machine.SequenceEvent_START,
Expand Down Expand Up @@ -263,21 +263,27 @@ func (c *Controller) run(seq runtime.Sequence, phases []runtime.Phase, data inte

log.Printf("phase %s (%s): %d tasks(s)", phase.Name, progress, len(phase.Tasks))

if err = c.runPhase(phase, seq, data); err != nil {
if err = c.runPhase(ctx, phase, seq, data); err != nil {
if !runtime.IsRebootError(err) {
log.Printf("phase %s (%s): failed", phase.Name, progress)
}

return fmt.Errorf("error running phase %d in %s sequence: %w", number, seq.String(), err)
}

select {
case <-ctx.Done():
return ctx.Err()
default:
}

log.Printf("phase %s (%s): done, %s", phase.Name, progress, time.Since(start))
}

return nil
}

func (c *Controller) runPhase(phase runtime.Phase, seq runtime.Sequence, data interface{}) error {
func (c *Controller) runPhase(ctx context.Context, phase runtime.Phase, seq runtime.Sequence, data interface{}) error {
c.Runtime().Events().Publish(&machine.PhaseEvent{
Phase: phase.Name,
Action: machine.PhaseEvent_START,
Expand All @@ -301,7 +307,7 @@ func (c *Controller) runPhase(phase runtime.Phase, seq runtime.Sequence, data in
eg.Go(func() error {
progress := fmt.Sprintf("%d/%d", number, len(phase.Tasks))

if err := c.runTask(progress, task, seq, data); err != nil {
if err := c.runTask(ctx, progress, task, seq, data); err != nil {
return fmt.Errorf("task %s: failed, %w", progress, err)
}

Expand All @@ -312,7 +318,7 @@ func (c *Controller) runPhase(phase runtime.Phase, seq runtime.Sequence, data in
return eg.Wait()
}

func (c *Controller) runTask(progress string, f runtime.TaskSetupFunc, seq runtime.Sequence, data interface{}) error {
func (c *Controller) runTask(ctx context.Context, progress string, f runtime.TaskSetupFunc, seq runtime.Sequence, data interface{}) error {
task, taskName := f(seq, data)
if task == nil {
return nil
Expand Down Expand Up @@ -346,7 +352,7 @@ func (c *Controller) runTask(progress string, f runtime.TaskSetupFunc, seq runti

logger := log.New(log.Writer(), fmt.Sprintf("[talos] task %s (%s): ", taskName, progress), log.Flags())

err = task(context.TODO(), logger, c.r)
err = task(ctx, logger, c.r)

return err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package v1alpha1

import (
"context"
"reflect"
"strconv"
"testing"
Expand Down Expand Up @@ -63,14 +64,16 @@ func TestController_Run(t *testing.T) {
// TODO: Add test cases.
}

ctx := context.Background()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
r: tt.fields.r,
s: tt.fields.s,
semaphore: tt.fields.semaphore,
}
if err := c.Run(tt.args.seq, tt.args.data); (err != nil) != tt.wantErr {
if err := c.Run(ctx, tt.args.seq, tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("Controller.Run() error = %v, wantErr %v", err, tt.wantErr)
}
})
Expand Down Expand Up @@ -149,6 +152,7 @@ func TestController_ListenForEvents(t *testing.T) {
}{
// TODO: Add test cases.
}
ctx := context.Background()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -157,7 +161,7 @@ func TestController_ListenForEvents(t *testing.T) {
s: tt.fields.s,
semaphore: tt.fields.semaphore,
}
if err := c.ListenForEvents(); (err != nil) != tt.wantErr {
if err := c.ListenForEvents(ctx); (err != nil) != tt.wantErr {
t.Errorf("Controller.ListenForEvents() error = %v, wantErr %v", err, tt.wantErr)
}
})
Expand Down Expand Up @@ -244,14 +248,16 @@ func TestController_run(t *testing.T) {
// TODO: Add test cases.
}

ctx := context.Background()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
r: tt.fields.r,
s: tt.fields.s,
semaphore: tt.fields.semaphore,
}
if err := c.run(tt.args.seq, tt.args.phases, tt.args.data); (err != nil) != tt.wantErr {
if err := c.run(ctx, tt.args.seq, tt.args.phases, tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("Controller.run() error = %v, wantErr %v", err, tt.wantErr)
}
})
Expand Down Expand Up @@ -279,6 +285,7 @@ func TestController_runPhase(t *testing.T) {
}{
// TODO: Add test cases.
}
ctx := context.Background()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -287,7 +294,7 @@ func TestController_runPhase(t *testing.T) {
s: tt.fields.s,
semaphore: tt.fields.semaphore,
}
if err := c.runPhase(tt.args.phase, tt.args.seq, tt.args.data); (err != nil) != tt.wantErr {
if err := c.runPhase(ctx, tt.args.phase, tt.args.seq, tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("Controller.runPhase() error = %v, wantErr %v", err, tt.wantErr)
}
})
Expand Down Expand Up @@ -316,6 +323,7 @@ func TestController_runTask(t *testing.T) {
}{
// TODO: Add test cases.
}
ctx := context.Background()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -324,7 +332,7 @@ func TestController_runTask(t *testing.T) {
s: tt.fields.s,
semaphore: tt.fields.semaphore,
}
if err := c.runTask(strconv.Itoa(tt.args.n), tt.args.f, tt.args.seq, tt.args.data); (err != nil) != tt.wantErr {
if err := c.runTask(ctx, strconv.Itoa(tt.args.n), tt.args.f, tt.args.seq, tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("Controller.runTask() error = %v, wantErr %v", err, tt.wantErr)
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1285,11 +1285,11 @@ func UncordonNode(seq runtime.Sequence, data interface{}) (runtime.TaskExecution
return err
}

if err = kubeHelper.WaitUntilReady(nodename); err != nil {
if err = kubeHelper.WaitUntilReady(ctx, nodename); err != nil {
return err
}

if err = kubeHelper.Uncordon(nodename, false); err != nil {
if err = kubeHelper.Uncordon(ctx, nodename, false); err != nil {
return err
}

Expand Down Expand Up @@ -1503,8 +1503,8 @@ func LabelNodeAsMaster(seq runtime.Sequence, data interface{}) (runtime.TaskExec
return err
}

err = retry.Constant(constants.NodeReadyTimeout, retry.WithUnits(3*time.Second), retry.WithErrorLogging(true)).Retry(func() error {
if err = h.LabelNodeAsMaster(nodename, !r.Config().Cluster().ScheduleOnMasters()); err != nil {
err = retry.Constant(constants.NodeReadyTimeout, retry.WithUnits(3*time.Second), retry.WithErrorLogging(true)).RetryWithContext(ctx, func(ctx context.Context) error {
if err = h.LabelNodeAsMaster(ctx, nodename, !r.Config().Cluster().ScheduleOnMasters()); err != nil {
return retry.ExpectedError(err)
}

Expand Down

0 comments on commit 638af35

Please sign in to comment.