From 91e750e4f0ce007cb436551e7556f72193463b85 Mon Sep 17 00:00:00 2001 From: Ricardo Maraschini Date: Thu, 22 Jun 2023 15:12:00 +0200 Subject: [PATCH 1/3] chore: cover supervisor package with unit tests this commit covers 84.1% of the supervisor package with unit tests. --- pkg/controller/k0scontroller.go | 17 +- pkg/supervisor/detachattr.go | 27 -- pkg/supervisor/logwriter.go | 82 ------- pkg/supervisor/options.go | 53 ++++ pkg/supervisor/options_test.go | 45 ++++ pkg/supervisor/supervisor.go | 394 ++++++++++++++---------------- pkg/supervisor/supervisor_test.go | 137 +++++++++++ pkg/supervisor/supervisor_unix.go | 132 ---------- 8 files changed, 424 insertions(+), 463 deletions(-) delete mode 100644 pkg/supervisor/detachattr.go delete mode 100644 pkg/supervisor/logwriter.go create mode 100644 pkg/supervisor/options.go create mode 100644 pkg/supervisor/options_test.go create mode 100644 pkg/supervisor/supervisor_test.go delete mode 100644 pkg/supervisor/supervisor_unix.go diff --git a/pkg/controller/k0scontroller.go b/pkg/controller/k0scontroller.go index 778ba80777..b086ab53c8 100644 --- a/pkg/controller/k0scontroller.go +++ b/pkg/controller/k0scontroller.go @@ -5,7 +5,6 @@ import ( "context" "encoding/csv" "fmt" - "io" "os" "os/user" "path/filepath" @@ -28,8 +27,7 @@ import ( type K0sController struct { Options config.K0sControllerOptions - supervisor supervisor.Supervisor - Output io.Writer + supervisor *supervisor.Supervisor uid int gid int } @@ -68,16 +66,9 @@ func (k *K0sController) Init(_ context.Context) error { if k.Options.CmdLogLevels != nil { args = append(args, fmt.Sprintf("--logging=%s", createS2SFlag(k.Options.CmdLogLevels))) } - k.supervisor = supervisor.Supervisor{ - Name: "k0s", - UID: k.uid, - GID: k.gid, - BinPath: assets.BinPath("k0s", k.Options.BinDir()), - Output: k.Output, - RunDir: k.Options.RunDir(), - DataDir: k.Options.DataDir, - KeepEnvPrefix: true, - Args: args, + k0spath := assets.BinPath("k0s", k.Options.BinDir()) + if k.supervisor, 
err = supervisor.New(k0spath, args); err != nil { + return fmt.Errorf("failed to create supervisor: %w", err) } return nil } diff --git a/pkg/supervisor/detachattr.go b/pkg/supervisor/detachattr.go deleted file mode 100644 index 2b807e93aa..0000000000 --- a/pkg/supervisor/detachattr.go +++ /dev/null @@ -1,27 +0,0 @@ -//go:build !windows -// +build !windows - -package supervisor - -import ( - "os" - "syscall" -) - -// DetachAttr creates the proper syscall attributes to run the managed processes -func DetachAttr(uid, gid int) *syscall.SysProcAttr { - var creds *syscall.Credential - - if os.Geteuid() == 0 { - creds = &syscall.Credential{ - Uid: uint32(uid), - Gid: uint32(gid), - } - } - - return &syscall.SysProcAttr{ - Setpgid: true, - Pgid: 0, - Credential: creds, - } -} diff --git a/pkg/supervisor/logwriter.go b/pkg/supervisor/logwriter.go deleted file mode 100644 index b62798114a..0000000000 --- a/pkg/supervisor/logwriter.go +++ /dev/null @@ -1,82 +0,0 @@ -package supervisor - -import ( - "bytes" - "unicode/utf8" - - "github.com/sirupsen/logrus" -) - -// logWriter implements [io.Writer] by forwarding whole lines to log. In case -// the lines get too long, it logs them in multiple chunks. -// -// This is in contrast to logrus's implementation of io.Writer, which simply -// errors out if the log line gets longer than 64k. -type logWriter struct { - log logrus.FieldLogger // receives (possibly chunked) log lines - buf []byte // buffer in which to accumulate chunks; len(buf) determines the chunk length - len int // current buffer length - chunkNo uint // current chunk number; 0 means "no chunk" -} - -// Write implements [io.Writer]. -func (w *logWriter) Write(in []byte) (int, error) { - w.writeBytes(in) - return len(in), nil -} - -func (w *logWriter) writeBytes(in []byte) { - // Fill and drain buffer with available data until everything has been consumed. 
- for rest := in; len(rest) > 0; { - - n := copy(w.buf[w.len:], rest) // fill buffer with new input data - rest = rest[n:] // strip copied input data - w.len += n // increase buffer length accordingly - - // Loop over buffer as long as there are newlines in it - for off := 0; ; { - idx := bytes.IndexRune(w.buf[off:w.len], '\n') - - // Discard already logged chunks and break if no newline left - if idx < 0 { - if off > 0 { - w.len = copy(w.buf, w.buf[off:w.len]) - } - break - } - - // Strip trailing carriage returns - line := bytes.TrimRight(w.buf[off:off+idx], "\r") - - if w.chunkNo == 0 { - w.log.Infof("%s", line) - } else { - if len(line) > 0 { - w.log.WithField("chunk", w.chunkNo+1).Infof("%s", line) - } - w.chunkNo = 0 - } - - off += idx + 1 // advance read offset behind the newline - } - - // Issue a chunked log entry in case the buffer is full - if w.len == len(w.buf) { - // Try to chunk at UTF-8 rune boundaries - length := w.len - for i := 0; i < utf8.MaxRune && i < w.len; i++ { - if r, _ := utf8.DecodeLastRune(w.buf[:w.len-i]); r != utf8.RuneError { - length = length - i - break - } - } - - // Strip trailing carriage returns - line := bytes.TrimRight(w.buf[:length], "\r") - - w.log.WithField("chunk", w.chunkNo+1).Infof("%s", line) - w.chunkNo++ // increase chunk number - w.len = copy(w.buf, w.buf[length:]) // discard logged bytes - } - } -} diff --git a/pkg/supervisor/options.go b/pkg/supervisor/options.go new file mode 100644 index 0000000000..e417c0d190 --- /dev/null +++ b/pkg/supervisor/options.go @@ -0,0 +1,53 @@ +package supervisor + +import ( + "fmt" + "time" +) + +// Option sets an option on a Supervisor reference. +type Option func(*Supervisor) + +// WithName sets the name of the Supervisor. +func WithName(name string) Option { + return func(s *Supervisor) { + s.name = name + s.log = s.log.WithField("component", name) + s.pidFile = fmt.Sprintf("/run/replicated/%s.pid", name) + } +} + +// WithTimeoutRespawn sets the timeout for respawning a process. 
+func WithTimeoutRespawn(timeoutRespawn time.Duration) Option { + return func(s *Supervisor) { + s.timeoutRespawn = timeoutRespawn + } +} + +// WithTimeoutStop sets the timeout for stopping a process. +func WithTimeoutStop(timeoutStop time.Duration) Option { + return func(s *Supervisor) { + s.timeoutStop = timeoutStop + } +} + +// WithUID sets the UID of the Supervisor. +func WithUID(uid int) Option { + return func(s *Supervisor) { + s.uid = uid + } +} + +// WithGID sets the GID of the Supervisor. +func WithGID(gid int) Option { + return func(s *Supervisor) { + s.gid = gid + } +} + +// WithPIDFile sets the PID file of the Supervisor. +func WithPIDFile(pidFile string) Option { + return func(s *Supervisor) { + s.pidFile = pidFile + } +} diff --git a/pkg/supervisor/options_test.go b/pkg/supervisor/options_test.go new file mode 100644 index 0000000000..b753d12453 --- /dev/null +++ b/pkg/supervisor/options_test.go @@ -0,0 +1,45 @@ +package supervisor + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestWithName(t *testing.T) { + s, err := New("/bin/cat", []string{"-"}, WithName("CAT")) + assert.NoError(t, err) + assert.Equal(t, "CAT", s.name) + assert.Equal(t, "/run/replicated/CAT.pid", s.pidFile) +} + +func TestWithTimeoutRespawn(t *testing.T) { + s, err := New("/bin/cat", []string{"-"}, WithTimeoutRespawn(time.Second)) + assert.NoError(t, err) + assert.Equal(t, time.Second, s.timeoutRespawn) +} + +func TestWithTimeoutStop(t *testing.T) { + s, err := New("/bin/cat", []string{"-"}, WithTimeoutStop(time.Second)) + assert.NoError(t, err) + assert.Equal(t, time.Second, s.timeoutStop) +} + +func TestWithGID(t *testing.T) { + s, err := New("/bin/cat", []string{"-"}, WithGID(1000)) + assert.NoError(t, err) + assert.Equal(t, 1000, s.gid) +} + +func TestWithUID(t *testing.T) { + s, err := New("/bin/cat", []string{"-"}, WithUID(1000)) + assert.NoError(t, err) + assert.Equal(t, 1000, s.uid) +} + +func TestWithPIDFile(t *testing.T) { + s, err := 
New("/bin/cat", []string{"-"}, WithPIDFile("/run/abc.pid")) + assert.NoError(t, err) + assert.Equal(t, "/run/abc.pid", s.pidFile) +} diff --git a/pkg/supervisor/supervisor.go b/pkg/supervisor/supervisor.go index f8b2569dba..6de3ab6c01 100644 --- a/pkg/supervisor/supervisor.go +++ b/pkg/supervisor/supervisor.go @@ -1,17 +1,12 @@ -/* -Package supervisor implements a simple process supervisor -*/ package supervisor import ( "context" "fmt" - "io" "os" "os/exec" - "path" - "runtime" - "sort" + "os/user" + "path/filepath" "strconv" "strings" "sync" @@ -21,277 +16,258 @@ import ( "github.com/sirupsen/logrus" ) -// Supervisor is dead simple and stupid process supervisor, just tries to keep the process running in a while-true loop -type Supervisor struct { - Name string - BinPath string - RunDir string - DataDir string - Args []string - PidFile string - UID int - GID int - TimeoutStop time.Duration - TimeoutRespawn time.Duration - // For those components having env prefix convention such as ETCD_xxx, we should keep the prefix. - KeepEnvPrefix bool - // ProcFSPath is only used for testing - ProcFSPath string - // KillFunction is only used for testing - KillFunction func(int, syscall.Signal) error - // A function to clean some leftovers before starting or restarting the supervised process - CleanBeforeFn func() error - Output io.Writer +// New returns a new Supervisor that will start and supervise the provided command with +// the provided arguments. 
+func New(path string, args []string, opts ...Option) (*Supervisor, error) { + usr, err := user.Current() + if err != nil { + return nil, fmt.Errorf("failed to determine current user: %w", err) + } + uid, err := strconv.Atoi(usr.Uid) + if err != nil { + return nil, fmt.Errorf("failed to parse current user id: %w", err) + } + gid, err := strconv.Atoi(usr.Gid) + if err != nil { + return nil, fmt.Errorf("failed to parse current group id: %w", err) + } + res := &Supervisor{ + binPath: path, + args: args, + name: filepath.Base(path), + log: logrus.WithField("component", filepath.Base(path)), + pidFile: fmt.Sprintf("/run/replicated/%s.pid", filepath.Base(path)), + timeoutStop: 5 * time.Second, + timeoutRespawn: 5 * time.Second, + uid: uid, + gid: gid, + } + for _, opt := range opts { + opt(res) + } + return res, nil +} +// Supervisor is process supervisor, just tries to keep the process running in a while-true loop. +type Supervisor struct { + name string + binPath string + log logrus.FieldLogger + args []string + uid int + gid int + timeoutStop time.Duration + timeoutRespawn time.Duration + pidFile string cmd *exec.Cmd done chan bool - log logrus.FieldLogger - mutex sync.Mutex startStopMutex sync.Mutex cancel context.CancelFunc } -const k0sManaged = "_K0S_MANAGED=yes" - -// processWaitQuit waits for a process to exit or a shut down signal -// returns true if shutdown is requested -func (s *Supervisor) processWaitQuit(ctx context.Context) bool { +// processWaitQuit waits for a process to exit or a shut down signal returns true if shutdown is requested. 
+func (s *Supervisor) processWaitQuit(ctx context.Context) (bool, error) { waitresult := make(chan error) go func() { waitresult <- s.cmd.Wait() }() - - pidbuf := []byte(strconv.Itoa(s.cmd.Process.Pid) + "\n") - err := os.WriteFile(s.PidFile, pidbuf, 0644) - if err != nil { - s.log.Warnf("Failed to write file %s: %v", s.PidFile, err) + pidbuf := []byte(strconv.Itoa(s.cmd.Process.Pid)) + if err := os.WriteFile(s.pidFile, pidbuf, 0644); err != nil { + return false, fmt.Errorf("failed to write file %s: %w", s.pidFile, err) } defer func() { - _ = os.Remove(s.PidFile) + _ = os.Remove(s.pidFile) }() select { case <-ctx.Done(): - for { - if runtime.GOOS == "windows" { - // Graceful shutdown not implemented on Windows. This requires - // attaching to the target process's console and generating a - // CTRL+BREAK (or CTRL+C) event. Since a process can only be - // attached to a single console at a time, this would require - // k0s to detach from its own console, which is definitely not - // something that k0s wants to do. There might be ways to do - // this by generating the event via a separate helper process, - // but that's left open here as a TODO. 
- // https://learn.microsoft.com/en-us/windows/console/freeconsole - // https://learn.microsoft.com/en-us/windows/console/attachconsole - // https://learn.microsoft.com/en-us/windows/console/generateconsolectrlevent - // https://learn.microsoft.com/en-us/windows/console/ctrl-c-and-ctrl-break-signals - s.log.Infof("Killing pid %d", s.cmd.Process.Pid) - if err := s.cmd.Process.Kill(); err != nil { - s.log.Warnf("Failed to kill pid %d: %s", s.cmd.Process.Pid, err) - } - } else { - s.log.Infof("Shutting down pid %d", s.cmd.Process.Pid) - if err := s.cmd.Process.Signal(syscall.SIGTERM); err != nil { - s.log.Warnf("Failed to send SIGTERM to pid %d: %s", s.cmd.Process.Pid, err) - } - } - select { - case <-time.After(s.TimeoutStop): - continue - case <-waitresult: - return true - } + if err := s.maybeKillPid(); err != nil { + return true, fmt.Errorf("failed to kill %s process: %w", s.name, err) } + return true, nil case err := <-waitresult: if err != nil { - s.log.Warnf("Failed to wait for process: %v", err) - } else { - s.log.Warnf("Process exited: %s", s.cmd.ProcessState) + s.log.Warnf("Failed to waiting for process %q: %v", s.name, err) + break } + s.log.Warnf("Process %q exited: %s", s.name, s.cmd.ProcessState) } - return false + return false, nil } -// Supervise Starts supervising the given process +// Supervise Starts supervising the given process. 
func (s *Supervisor) Supervise() error { s.startStopMutex.Lock() defer s.startStopMutex.Unlock() - // check if it is already started if s.cancel != nil { - s.log.Warn("Already started") + s.log.Warnf("Supervisor for %q already started", s.name) return nil } - s.log = logrus.WithField("component", s.Name) - s.PidFile = path.Join(s.RunDir, s.Name) + ".pid" - - if err := os.MkdirAll(s.RunDir, 0755); err != nil { + dir := filepath.Dir(s.pidFile) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create run dir: %w", err) + } + if err := s.maybeKillPid(); err != nil { return err } + var ctx context.Context + ctx, s.cancel = context.WithCancel(context.Background()) + s.done = make(chan bool) + if err := s.supervise(ctx); err != nil { + return fmt.Errorf("failed to supervise %q: %w", s.name, err) + } + return nil +} - if s.TimeoutStop == 0 { - s.TimeoutStop = 5 * time.Second +// detachAttr creates the proper syscall attributes to run the managed processes. +func (s *Supervisor) detachAttr() *syscall.SysProcAttr { + var creds *syscall.Credential + if os.Geteuid() == 0 { + creds = &syscall.Credential{ + Uid: uint32(s.uid), + Gid: uint32(s.gid), + } } - if s.TimeoutRespawn == 0 { - s.TimeoutRespawn = 5 * time.Second + return &syscall.SysProcAttr{ + Setpgid: true, + Pgid: 0, + Credential: creds, } +} - if err := s.maybeKillPidFile(nil, nil); err != nil { - return err +// supervise starts the process and waits for it to exit. +func (s *Supervisor) supervise(ctx context.Context) error { + defer func() { + close(s.done) + }() + s.log.Infof("Starting to supervise %q", s.name) + s.cmd = exec.Command(s.binPath, s.args...) 
+ s.cmd.Stdout = logrus.WithField("stream", "stdout").Writer() + s.cmd.Stderr = logrus.WithField("stream", "stderr").Writer() + s.cmd.SysProcAttr = s.detachAttr() + if err := s.cmd.Start(); err != nil { + return fmt.Errorf("failed to start %q: %w", s.name, err) } - - var ctx context.Context - ctx, s.cancel = context.WithCancel(context.Background()) - started := make(chan error) - s.done = make(chan bool) + s.log.Infof("Started %q with pid %d", s.name, s.cmd.Process.Pid) go func() { - defer func() { - close(s.done) - }() - - s.log.Info("Starting to supervise") - restarts := 0 + var restarts int for { - s.mutex.Lock() - - var err error - if s.CleanBeforeFn != nil { - err = s.CleanBeforeFn() - } - if err != nil { - s.log.Warnf("Failed to clean before running the process %s: %s", s.BinPath, err) - } else { - s.cmd = exec.Command(s.BinPath, s.Args...) - s.cmd.Dir = s.DataDir - s.cmd.Env = getEnv(s.DataDir, s.Name, s.KeepEnvPrefix) - - // detach from the process group so children don't 
- // get signals sent directly to parent. - s.cmd.SysProcAttr = DetachAttr(s.UID, s.GID) - - const maxLogChunkLen = 16 * 1024 - if s.Output != nil { - s.cmd.Stdout = s.Output - s.cmd.Stderr = s.Output - } else { - s.cmd.Stdout = &logWriter{ - log: logrus.WithField("stream", "stdout"), - buf: make([]byte, maxLogChunkLen), - } - s.cmd.Stderr = &logWriter{ - log: logrus.WithField("stream", "stderr"), - buf: make([]byte, maxLogChunkLen), - } - } - - err = s.cmd.Start() - } - s.mutex.Unlock() - if err != nil { - s.log.Warnf("Failed to start: %s", err) - if restarts == 0 { - started <- err - return - } - } else { - if restarts == 0 { - s.log.Infof("Started successfully, go nuts pid %d", s.cmd.Process.Pid) - started <- nil - } else { - s.log.Infof("Restarted (%d)", restarts) - } - restarts++ - if s.processWaitQuit(ctx) { - return - } + if ended, err := s.processWaitQuit(ctx); err != nil { + s.log.Errorf("Supervise for %q ended with error: %v", s.name, err) + return + } else if ended { + s.log.Infof("Supervise for %q ended", s.name) + return } - - // TODO Maybe some backoff thingy would be nice - s.log.Infof("respawning in %s", s.TimeoutRespawn.String()) - + restarts++ + s.log.Infof("Respawning %q in %s", s.name, s.timeoutRespawn.String()) select { case <-ctx.Done(): - s.log.Debug("respawn cancelled") + s.log.Infof("Respawn of %q cancelled", s.name) return - case <-time.After(s.TimeoutRespawn): - s.log.Debug("respawning") + case <-time.After(s.timeoutRespawn): + s.log.Infof("Respawning %q", s.name) + } + s.cmd = exec.Command(s.binPath, s.args...) + s.cmd.Stdout = logrus.WithField("stream", "stdout").Writer() + s.cmd.Stderr = logrus.WithField("stream", "stderr").Writer() + s.cmd.SysProcAttr = s.detachAttr() + if err := s.cmd.Start(); err != nil { + s.log.Errorf("Failed to respawn %q: %s", s.name, err) } } }() - return <-started + return nil } -// Stop stops the supervised +// Stop stops the supervised process. 
func (s *Supervisor) Stop() error { s.startStopMutex.Lock() defer s.startStopMutex.Unlock() if s.cancel == nil { - s.log.Warn("Not started") + s.log.Warn("Supervised not started") return nil } - s.log.Debug("Sending stop message") - + s.log.Infof("Sending stop message") s.cancel() s.cancel = nil - s.log.Debug("Waiting for stopping is done") if s.done != nil { <-s.done } + s.log.Infof("Supervisor for %q stopped", s.name) return nil } -// Prepare the env for exec: -// - handle component specific env -// - inject k0s embedded bins into path -func getEnv(dataDir, component string, keepEnvPrefix bool) []string { - env := os.Environ() - componentPrefix := fmt.Sprintf("%s_", strings.ToUpper(component)) - - // put the component specific env vars in the front. - sort.Slice(env, func(i, j int) bool { return strings.HasPrefix(env[i], componentPrefix) }) - - overrides := map[string]struct{}{} - i := 0 - for _, e := range env { - kv := strings.SplitN(e, "=", 2) - k, v := kv[0], kv[1] - // if there is already a correspondent component specific env, skip it. - if _, ok := overrides[k]; ok { - continue - } - if strings.HasPrefix(k, componentPrefix) { - var shouldOverride bool - k1 := strings.TrimPrefix(k, componentPrefix) - switch k1 { - // always override proxy env - case "HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY": - shouldOverride = true - default: - if !keepEnvPrefix { - shouldOverride = true +// killPid signals SIGTERM to a PID and if it's still running after s.timeoutStop sends SIGKILL. 
+func (s *Supervisor) killPid(pid int) error { + deadlineTicker := time.NewTicker(s.timeoutStop) + checkTicker := time.NewTicker(time.Second) + defer deadlineTicker.Stop() + defer checkTicker.Stop() + var stop bool + for { + select { + case <-checkTicker.C: + s.log.Infof("Sending SIGTERM to pid %d", pid) + if err := syscall.Kill(pid, syscall.SIGTERM); err != nil { + if err == syscall.ESRCH { + return nil } + return fmt.Errorf("failed to send sigterm to %d: %w", pid, err) } - if shouldOverride { - k = k1 - overrides[k] = struct{}{} - } + case <-deadlineTicker.C: + stop = true } - env[i] = fmt.Sprintf("%s=%s", k, v) - if k == "PATH" { - env[i] = fmt.Sprintf("PATH=%s:%s", path.Join(dataDir, "bin"), v) + if !stop { + continue + } + s.log.Errorf("Process %d still running, sending SIGKILL", pid) + break + } + if err := syscall.Kill(pid, syscall.SIGKILL); err != nil { + if err == syscall.ESRCH { + return nil } - i++ + return fmt.Errorf("failed to send SIGKILL to pid %d: %w", pid, err) } - env = append([]string{k0sManaged}, env...) - i++ + return nil +} - return env[:i] +// maybeKillPid kills the process in the pidFile if it has the same binary as the supervisor's. +// This function does not delete the old pidFile as this is done by the caller. 
+func (s *Supervisor) maybeKillPid() error { + bpid, err := os.ReadFile(s.pidFile) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("failed to read pid file %s: %v", s.pidFile, err) + } + pid, err := strconv.Atoi(string(bpid)) + if err != nil { + return fmt.Errorf("failed to parse pid file %s: %v", s.pidFile, err) + } + if should, err := s.shouldKillProcess(pid); err != nil { + return fmt.Errorf("failed to assess if we should kill pid %d: %w", pid, err) + } else if !should { + return fmt.Errorf("pid %d is not a %q process", pid, s.name) + } + return s.killPid(pid) } -// GetProcess returns the last started process -func (s *Supervisor) GetProcess() *os.Process { - s.mutex.Lock() - defer s.mutex.Unlock() - return s.cmd.Process +// shouldKillProcess returns true if the proccess with the provided pid should be killed. By should be +// killed is understood as the command for process with the given pid matches the command we are +// supervising. +func (s *Supervisor) shouldKillProcess(pid int) (bool, error) { + cmdline, err := os.ReadFile(filepath.Join("/proc", strconv.Itoa(pid), "cmdline")) + if os.IsNotExist(err) { + return false, nil + } else if err != nil { + return false, fmt.Errorf("failed to read process %d cmdline: %v", pid, err) + } + if cmd := strings.Split(string(cmdline), "\x00"); len(cmd) > 0 { + return cmd[0] == s.binPath, nil + } + return false, nil } diff --git a/pkg/supervisor/supervisor_test.go b/pkg/supervisor/supervisor_test.go new file mode 100644 index 0000000000..13674f90b8 --- /dev/null +++ b/pkg/supervisor/supervisor_test.go @@ -0,0 +1,137 @@ +package supervisor + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "os/user" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNew(t *testing.T) { + usr, err := user.Current() + assert.NoError(t, err) + uid, err := strconv.Atoi(usr.Uid) + assert.NoError(t, err) + gid, err := strconv.Atoi(usr.Gid) + assert.NoError(t, err) + 
s, err := New("/bin/cat", []string{"-"}) + assert.NoError(t, err) + assert.Equal(t, "/bin/cat", s.binPath) + assert.Equal(t, []string{"-"}, s.args) + assert.Equal(t, uid, s.uid) + assert.Equal(t, gid, s.gid) + assert.NotZero(t, s.timeoutRespawn) + assert.NotZero(t, s.timeoutStop) + assert.NotNil(t, s.log) + assert.Equal(t, s.name, "cat") +} + +func TestSupervise(t *testing.T) { + s, err := New("/usr/bin/sleep", []string{"60"}, WithPIDFile("/tmp/cat.pid")) + assert.NoError(t, err) + fmt.Println(time.Now()) + assert.NoError(t, s.Supervise()) + fmt.Println(time.Now()) + time.Sleep(time.Second) + assert.NoError(t, s.Stop()) +} + +func Test_shouldKillProcess(t *testing.T) { + s, err := New("", nil) + assert.NoError(t, err) + should, err := s.shouldKillProcess(-999) + assert.NoError(t, err) + assert.False(t, should) + should, err = s.shouldKillProcess(os.Getpid()) + assert.NoError(t, err) + assert.False(t, should) + s, err = New("/usr/bin/sleep", []string{"60"}, WithPIDFile("/tmp/cat.pid")) + assert.NoError(t, err) + assert.NoError(t, s.Supervise()) + time.Sleep(time.Second) + should, err = s.shouldKillProcess(s.cmd.Process.Pid) + assert.NoError(t, err) + assert.True(t, should) + assert.NoError(t, s.Stop()) +} + +func Test_killPid(t *testing.T) { + s, err := New("/usr/bin/sleep", []string{"60"}, WithPIDFile("/tmp/cat.pid")) + assert.NoError(t, err) + assert.NoError(t, s.Supervise()) + time.Sleep(time.Second) + assert.NoError(t, s.killPid(s.cmd.Process.Pid)) + assert.NoError(t, s.Stop()) + cmd := exec.Command("/usr/bin/sleep", "60") + assert.NoError(t, cmd.Start()) + assert.NoError(t, s.killPid(cmd.Process.Pid)) +} + +func Test_processWaitQuit(t *testing.T) { + cmd := exec.Command("/usr/bin/sleep", "5") + assert.NoError(t, cmd.Start()) + ppath := "/does-not-exist/does-not-exist.pid" + s := Supervisor{pidFile: ppath, cmd: cmd} + s.pidFile = "/does-not-exist/does-not-exist.pid" + _, err := s.processWaitQuit(context.Background()) + assert.Error(t, err, "failed to write file 
%[1]s: open %[1]s: no such file or directory", ppath) +} + +var goodScript = `#!/bin/sh +date >> /tmp/good_supervisor_test.log +` + +func TestExitingProcess(t *testing.T) { + assert.NoError(t, os.RemoveAll("/tmp/good_supervisor_test.log")) + assert.NoError(t, os.WriteFile("/tmp/good_supervisor_test.sh", []byte(goodScript), 0755)) + s, err := New("/tmp/good_supervisor_test.sh", nil, WithPIDFile("/tmp/good_supervisor_test.pid"), WithTimeoutRespawn(100*time.Millisecond)) + assert.NoError(t, err) + assert.NoError(t, s.Supervise()) + time.Sleep(time.Second) + assert.NoError(t, s.Stop()) + assert.FileExists(t, "/tmp/good_supervisor_test.log") + data, err := os.ReadFile("/tmp/good_supervisor_test.log") + assert.NoError(t, err) + lines := bytes.Split(data, []byte("\n")) + assert.Greater(t, len(lines), 1) +} + +var badScript = `#!/bin/sh +date >> /tmp/bad_supervisor_test.log +exit 3 +` + +func TestCrashingProcess(t *testing.T) { + assert.NoError(t, os.RemoveAll("/tmp/bad_supervisor_test.log")) + assert.NoError(t, os.WriteFile("/tmp/bad_supervisor_test.sh", []byte(badScript), 0755)) + s, err := New("/tmp/bad_supervisor_test.sh", nil, WithPIDFile("/tmp/bad_supervisor_test.pid"), WithTimeoutRespawn(100*time.Millisecond)) + assert.NoError(t, err) + assert.NoError(t, s.Supervise()) + time.Sleep(time.Second) + assert.NoError(t, s.Stop()) + assert.FileExists(t, "/tmp/bad_supervisor_test.log") + data, err := os.ReadFile("/tmp/bad_supervisor_test.log") + assert.NoError(t, err) + lines := bytes.Split(data, []byte("\n")) + assert.Greater(t, len(lines), 1) +} + +func Test_maybeKillPid(t *testing.T) { + ppath := "/tmp/maybe_kill_pid.pid" + assert.NoError(t, os.WriteFile(ppath, []byte("abc"), 0644)) + s, err := New("/usr/bin/sleep", []string{"60"}, WithPIDFile(ppath)) + assert.NoError(t, err) + assert.Error(t, s.maybeKillPid(), "failed to parse pid file %[1]s: strconv.Atoi: parsing \"abc\": invalid syntax", ppath) + + assert.NoError(t, os.WriteFile(ppath, []byte("1"), 0644)) + s, err = 
New("/usr/bin/sleep", []string{"60"}, WithPIDFile(ppath)) + assert.NoError(t, err) + assert.Error(t, s.maybeKillPid(), `pid 1 is not a "sleep" process`) +} diff --git a/pkg/supervisor/supervisor_unix.go b/pkg/supervisor/supervisor_unix.go deleted file mode 100644 index b607b09626..0000000000 --- a/pkg/supervisor/supervisor_unix.go +++ /dev/null @@ -1,132 +0,0 @@ -//go:build unix - -package supervisor - -import ( - "fmt" - "os" - "path/filepath" - "strconv" - "strings" - "syscall" - "time" -) - -const ( - exitCheckInterval = 200 * time.Millisecond -) - -// killPid signals SIGTERM to a PID and if it's still running after -// s.TimeoutStop sends SIGKILL. -func (s *Supervisor) killPid(pid int, check <-chan time.Time, deadline <-chan time.Time) error { - if s.KillFunction == nil { - s.KillFunction = syscall.Kill - } - // Kill the process pid - if deadline == nil { - deadlineTicker := time.NewTicker(s.TimeoutStop) - deadline = deadlineTicker.C - defer deadlineTicker.Stop() - } - if check == nil { - checkTicker := time.NewTicker(exitCheckInterval) - check = checkTicker.C - defer checkTicker.Stop() - - } - - // Using two tickers is not very elegant but makes testing easier... 
-Loop: - for { - select { - case <-check: - shouldKill, err := s.shouldKillProcess(pid) - if err != nil { - return err - } - if !shouldKill { - return nil - } - - err = s.KillFunction(pid, syscall.SIGTERM) - if err == syscall.ESRCH { - return nil - } else if err != nil { - return fmt.Errorf("failed to send SIGTERM to pid %d: %s", s.cmd.Process.Pid, err) - } - case <-deadline: - break Loop - } - } - - shouldKill, err := s.shouldKillProcess(pid) - if err != nil { - return err - } - if !shouldKill { - return nil - } - - err = s.KillFunction(pid, syscall.SIGKILL) - if err == syscall.ESRCH { - return nil - } else if err != nil { - return fmt.Errorf("failed to send SIGKILL to pid %d: %s", s.cmd.Process.Pid, err) - } - return nil -} - -// maybeKillPidFile checks kills the process in the pidFile if it's has -// the same binary as the supervisor's and also checks that the env -// `_KOS_MANAGED=yes`. This function does not delete the old pidFile as -// this is done by the caller. -// The tickers are used for testing purposes, otherwise set them to nil. 
-func (s *Supervisor) maybeKillPidFile(check <-chan time.Time, deadline <-chan time.Time) error { - if s.ProcFSPath == "" { - s.ProcFSPath = "/proc" - } - - pid, err := os.ReadFile(s.PidFile) - if os.IsNotExist(err) { - return nil - } else if err != nil { - return fmt.Errorf("failed to read pid file %s: %v", s.PidFile, err) - } - - p, err := strconv.Atoi(strings.TrimSuffix(string(pid), "\n")) - if err != nil { - return fmt.Errorf("failed to parse pid file %s: %v", s.PidFile, err) - } - - return s.killPid(p, check, deadline) -} - -func (s *Supervisor) shouldKillProcess(pid int) (bool, error) { - cmdline, err := os.ReadFile(filepath.Join(s.ProcFSPath, strconv.Itoa(pid), "cmdline")) - if os.IsNotExist(err) { - return false, nil - } else if err != nil { - return false, fmt.Errorf("failed to read process %d cmdline: %v", pid, err) - } - - // only kill process if it has the expected cmd - cmd := strings.Split(string(cmdline), "\x00") - if cmd[0] != s.BinPath { - return false, nil - } - - //only kill process if it has the _KOS_MANAGED env set - env, err := os.ReadFile(filepath.Join(s.ProcFSPath, strconv.Itoa(pid), "environ")) - if os.IsNotExist(err) { - return false, nil - } else if err != nil { - return false, fmt.Errorf("failed to read process %d environ: %v", pid, err) - } - - for _, e := range strings.Split(string(env), "\x00") { - if e == k0sManaged { - return true, nil - } - } - return false, nil -} From 7ed7829f2c461b0b2b27663ac7fed463db4182d3 Mon Sep 17 00:00:00 2001 From: Ricardo Maraschini Date: Thu, 22 Jun 2023 15:46:44 +0200 Subject: [PATCH 2/3] chore: fix lint warnings. - keep lines shorter. - added a package comment to the supervisor package. 
--- pkg/supervisor/supervisor.go | 3 ++- pkg/supervisor/supervisor_test.go | 16 ++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/pkg/supervisor/supervisor.go b/pkg/supervisor/supervisor.go index 6de3ab6c01..479a681598 100644 --- a/pkg/supervisor/supervisor.go +++ b/pkg/supervisor/supervisor.go @@ -1,3 +1,4 @@ +// Package supervisor implements tooling for spawning processes and keeping them running. package supervisor import ( @@ -256,7 +257,7 @@ func (s *Supervisor) maybeKillPid() error { return s.killPid(pid) } -// shouldKillProcess returns true if the proccess with the provided pid should be killed. By should be +// shouldKillProcess returns true if the process with the provided pid should be killed. By should be // killed is understood as the command for process with the given pid matches the command we are // supervising. func (s *Supervisor) shouldKillProcess(pid int) (bool, error) { diff --git a/pkg/supervisor/supervisor_test.go b/pkg/supervisor/supervisor_test.go index 13674f90b8..b11ec17e61 100644 --- a/pkg/supervisor/supervisor_test.go +++ b/pkg/supervisor/supervisor_test.go @@ -91,7 +91,9 @@ date >> /tmp/good_supervisor_test.log func TestExitingProcess(t *testing.T) { assert.NoError(t, os.RemoveAll("/tmp/good_supervisor_test.log")) assert.NoError(t, os.WriteFile("/tmp/good_supervisor_test.sh", []byte(goodScript), 0755)) - s, err := New("/tmp/good_supervisor_test.sh", nil, WithPIDFile("/tmp/good_supervisor_test.pid"), WithTimeoutRespawn(100*time.Millisecond)) + script := "/tmp/good_supervisor_test.sh" + pidpath := "/tmp/good_supervisor_test.pid" + s, err := New(script, nil, WithPIDFile(pidpath), WithTimeoutRespawn(100*time.Millisecond)) assert.NoError(t, err) assert.NoError(t, s.Supervise()) time.Sleep(time.Second) @@ -111,7 +113,9 @@ exit 3 func TestCrashingProcess(t *testing.T) { assert.NoError(t, os.RemoveAll("/tmp/bad_supervisor_test.log")) assert.NoError(t, os.WriteFile("/tmp/bad_supervisor_test.sh", []byte(badScript), 
0755)) - s, err := New("/tmp/bad_supervisor_test.sh", nil, WithPIDFile("/tmp/bad_supervisor_test.pid"), WithTimeoutRespawn(100*time.Millisecond)) + script := "/tmp/good_supervisor_test.sh" + pidpath := "/tmp/good_supervisor_test.pid" + s, err := New(script, nil, WithPIDFile(pidpath), WithTimeoutRespawn(100*time.Millisecond)) assert.NoError(t, err) assert.NoError(t, s.Supervise()) time.Sleep(time.Second) @@ -128,8 +132,12 @@ func Test_maybeKillPid(t *testing.T) { assert.NoError(t, os.WriteFile(ppath, []byte("abc"), 0644)) s, err := New("/usr/bin/sleep", []string{"60"}, WithPIDFile(ppath)) assert.NoError(t, err) - assert.Error(t, s.maybeKillPid(), "failed to parse pid file %[1]s: strconv.Atoi: parsing \"abc\": invalid syntax", ppath) - + assert.Error( + t, + s.maybeKillPid(), + "failed to parse pid file %[1]s: strconv.Atoi: parsing \"abc\": invalid syntax", + ppath, + ) assert.NoError(t, os.WriteFile(ppath, []byte("1"), 0644)) s, err = New("/usr/bin/sleep", []string{"60"}, WithPIDFile(ppath)) assert.NoError(t, err) From 2eb9a1e70626df768383d1e464f2ee5bfc1b3110 Mon Sep 17 00:00:00 2001 From: Ricardo Maraschini Date: Thu, 22 Jun 2023 16:19:31 +0200 Subject: [PATCH 3/3] chore: fixed wrong unit tests these tests were wrong. 
--- pkg/supervisor/supervisor_test.go | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/pkg/supervisor/supervisor_test.go b/pkg/supervisor/supervisor_test.go index b11ec17e61..2ad394fb26 100644 --- a/pkg/supervisor/supervisor_test.go +++ b/pkg/supervisor/supervisor_test.go @@ -3,7 +3,6 @@ package supervisor import ( "bytes" "context" - "fmt" "os" "os/exec" "os/user" @@ -34,11 +33,9 @@ func TestNew(t *testing.T) { } func TestSupervise(t *testing.T) { - s, err := New("/usr/bin/sleep", []string{"60"}, WithPIDFile("/tmp/cat.pid")) + s, err := New("/usr/bin/sleep", []string{"60"}, WithPIDFile("/tmp/supervise.pid")) assert.NoError(t, err) - fmt.Println(time.Now()) assert.NoError(t, s.Supervise()) - fmt.Println(time.Now()) time.Sleep(time.Second) assert.NoError(t, s.Stop()) } @@ -52,7 +49,7 @@ func Test_shouldKillProcess(t *testing.T) { should, err = s.shouldKillProcess(os.Getpid()) assert.NoError(t, err) assert.False(t, should) - s, err = New("/usr/bin/sleep", []string{"60"}, WithPIDFile("/tmp/cat.pid")) + s, err = New("/usr/bin/sleep", []string{"60"}, WithPIDFile("/tmp/should_kill.pid")) assert.NoError(t, err) assert.NoError(t, s.Supervise()) time.Sleep(time.Second) @@ -63,7 +60,7 @@ func Test_shouldKillProcess(t *testing.T) { } func Test_killPid(t *testing.T) { - s, err := New("/usr/bin/sleep", []string{"60"}, WithPIDFile("/tmp/cat.pid")) + s, err := New("/usr/bin/sleep", []string{"60"}, WithPIDFile("/tmp/kill_pid.pid")) assert.NoError(t, err) assert.NoError(t, s.Supervise()) time.Sleep(time.Second) @@ -113,8 +110,8 @@ exit 3 func TestCrashingProcess(t *testing.T) { assert.NoError(t, os.RemoveAll("/tmp/bad_supervisor_test.log")) assert.NoError(t, os.WriteFile("/tmp/bad_supervisor_test.sh", []byte(badScript), 0755)) - script := "/tmp/good_supervisor_test.sh" - pidpath := "/tmp/good_supervisor_test.pid" + script := "/tmp/bad_supervisor_test.sh" + pidpath := "/tmp/bad_supervisor_test.pid" s, err := New(script, nil, WithPIDFile(pidpath), 
WithTimeoutRespawn(100*time.Millisecond)) assert.NoError(t, err) assert.NoError(t, s.Supervise())