Skip to content

Commit

Permalink
integrate sudo into other command types
Browse files Browse the repository at this point in the history
lint: remove unreachable code

add test of delete command with and without sudo
  • Loading branch information
umputun committed May 8, 2023
1 parent 467d568 commit 30987fd
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 70 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ Each command type supports the following options:
- `ignore_errors`: if set to `true` the command will not fail the task in case of an error.
- `no_auto`: if set to `true` the command will not be executed automatically, but can be executed manually using the `--only` flag.
- `local`: if set to `true` the command will be executed on the local host (the one running the `spot` command) instead of the remote host(s).
- `sudo`: if set to `true` the script command will be executed with `sudo` privileges.
- `sudo`: if set to `true` the command will be executed with `sudo` privileges.

example setting `ignore_errors` and `no_auto` options:

Expand All @@ -257,7 +257,7 @@ example setting `ignore_errors` and `no_auto` options:
options: {ignore_errors: true, no_auto: true}
```

Please note that the `sudo` option is only supported for the `script` command type. This limitation exists because there is no direct and universal method for uploading files over SFTP with sudo privileges. As a workaround, users can first use the `copy` command to transfer files to a temporary location, and then execute a `script` command with `sudo: true` to move those files to their final destination. Alternatively, using the root user directly in the playbook will allow direct file transfer to any restricted location and enable running privileged commands without the need to use sudo.
Please note that the `sudo` option is not supported for the `sync` command type, but all other command types support it.


### Script Execution
Expand Down
14 changes: 7 additions & 7 deletions pkg/config/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,17 @@ func (cmd *Cmd) GetScript() (string, io.Reader) {
elems := strings.Split(cmd.Script, "\n")
if len(elems) > 1 {
log.Printf("[DEBUG] command %q is multiline, using script file", cmd.Name)
return "", cmd.getScriptFile()
return "", cmd.scriptFile(cmd.Script)
}

log.Printf("[DEBUG] command %q is single line, using script string", cmd.Name)
return cmd.getScriptCommand(), nil
return cmd.scriptCommand(cmd.Script), nil
}

// GetScriptCommand concatenates all script line in commands into one a string to be executed by shell.
// Empty string is returned if no script is defined.
func (cmd *Cmd) getScriptCommand() string {
if cmd.Script == "" {
func (cmd *Cmd) scriptCommand(inp string) string {
if inp == "" {
return ""
}

Expand All @@ -100,7 +100,7 @@ func (cmd *Cmd) getScriptCommand() string {
res += strings.Join(secrets, " ") + " "
}

elems := strings.Split(cmd.Script, "\n")
elems := strings.Split(inp, "\n")
var parts []string // nolint
for _, el := range elems {
c := strings.TrimSpace(el)
Expand All @@ -118,7 +118,7 @@ func (cmd *Cmd) getScriptCommand() string {

// GetScriptFile returns a reader for script file. All the line in the command used as a script, with hashbang,
// set -e and environment variables.
func (cmd *Cmd) getScriptFile() io.Reader {
func (cmd *Cmd) scriptFile(inp string) io.Reader {
var buf bytes.Buffer

buf.WriteString("#!/bin/sh\n") // add hashbang
Expand All @@ -133,7 +133,7 @@ func (cmd *Cmd) getScriptFile() io.Reader {
}
}

elems := strings.Split(cmd.Script, "\n")
elems := strings.Split(inp, "\n")
for _, el := range elems {
c := strings.TrimSpace(el)
if len(c) < 2 {
Expand Down
8 changes: 4 additions & 4 deletions pkg/config/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,21 +123,21 @@ func TestCmd_getScriptCommand(t *testing.T) {
t.Run("script", func(t *testing.T) {
cmd := c.Tasks[0].Commands[3]
assert.Equal(t, "git", cmd.Name, "name")
res := cmd.getScriptCommand()
res := cmd.scriptCommand(cmd.Script)
assert.Equal(t, `sh -c "git clone https://example.com/remark42.git /srv || true; cd /srv; git pull"`, res)
})

t.Run("no-script", func(t *testing.T) {
cmd := c.Tasks[0].Commands[1]
assert.Equal(t, "copy configuration", cmd.Name)
res := cmd.getScriptCommand()
res := cmd.scriptCommand(cmd.Script)
assert.Equal(t, "", res)
})

t.Run("script with env", func(t *testing.T) {
cmd := c.Tasks[0].Commands[4]
assert.Equal(t, "docker", cmd.Name)
res := cmd.getScriptCommand()
res := cmd.scriptCommand(cmd.Script)
assert.Equal(t, `sh -c "BAR='qux' FOO='bar' docker pull umputun/remark42:latest; docker stop remark42 || true; docker rm remark42 || true; docker run -d --name remark42 -p 8080:8080 umputun/remark42:latest"`, res)
})
}
Expand Down Expand Up @@ -212,7 +212,7 @@ func TestCmd_getScriptFile(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := tt.cmd.getScriptFile()
reader := tt.cmd.scriptFile(tt.cmd.Script)
scriptContentBytes, err := io.ReadAll(reader)
assert.NoError(t, err)
scriptContent := string(scriptContentBytes)
Expand Down
123 changes: 75 additions & 48 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type Process struct {
secrets []string
}

const tmpRemoteDir = "/tmp/.spot" // this is a directory on remote host to store temporary files

// Connector is an interface for connecting to a host, and returning remote executer.
type Connector interface {
Connect(ctx context.Context, hostAddr, hostName, user string) (*executor.Remote, error)
Expand Down Expand Up @@ -263,21 +265,25 @@ func (p *Process) execCopyCommand(ctx context.Context, ep execCmdParams) (detail
if ep.cmd.Options.Sudo {
// if sudo is set, we need to upload the file to a temporary directory and move it to the final destination
details = fmt.Sprintf(" {copy: %s -> %s, sudo: true}", src, dst)
tmpDest := filepath.Join("/tmp/.spot", filepath.Base(dst))
tmpDest := filepath.Join(tmpRemoteDir, filepath.Base(dst))
if err := ep.exec.Upload(ctx, src, tmpDest, true); err != nil { // upload to a temporary directory with mkdir
return details, fmt.Errorf("can't copy file to %s: %w", ep.hostAddr, err)
}

mvScripts := fmt.Sprintf("mv -f %s %s\n rm -rf %s", tmpDest, dst, tmpDest)
mvCmd := fmt.Sprintf("mv -f %s %s", tmpDest, dst) // move a single file
if strings.Contains(src, "*") && !strings.HasSuffix(tmpDest, "/") {
mvScripts = fmt.Sprintf("mv -f %s/* %s\n rm -rf %s", tmpDest, dst, tmpDest)
mvCmd = fmt.Sprintf("mv -f %s/* %s", tmpDest, dst) // move multiple files, if wildcard is used
defer func() {
// remove temporary directory we created under /tmp/.spot for multiple files
if _, err := ep.exec.Run(ctx, fmt.Sprintf("rm -rf %s", tmpDest), p.Verbose); err != nil {
log.Printf("[WARN] can't remove temporary directory on %s: %v", ep.hostAddr, err)
}
}()
}
rdr := strings.NewReader(mvScripts)
c, teardown, err := p.prepScript(ctx, "", rdr, ep)
c, _, err := p.prepScript(ctx, mvCmd, nil, ep)
if err != nil {
return details, fmt.Errorf("can't prepare script sudo moving on %s: %w", ep.hostAddr, err)
return details, fmt.Errorf("can't prepare sudo moving command on %s: %w", ep.hostAddr, err)
}
defer func() { _ = teardown() }()

sudoMove := fmt.Sprintf("sudo %s", c)
if _, err := ep.exec.Run(ctx, sudoMove, p.Verbose); err != nil {
Expand All @@ -296,8 +302,10 @@ func (p *Process) execMCopyCommand(ctx context.Context, ep execCmdParams) (detai
dst := p.applyTemplates(c.Dest,
templateData{hostAddr: ep.hostAddr, hostName: ep.hostName, task: ep.tsk, command: ep.cmd.Name})
msgs = append(msgs, fmt.Sprintf("%s -> %s", src, dst))
if err := ep.exec.Upload(ctx, src, dst, c.Mkdir); err != nil {
return details, fmt.Errorf("can't copy file on %s: %w", ep.hostAddr, err)
epSingle := ep
epSingle.cmd.Copy = config.CopyInternal{Source: src, Dest: dst, Mkdir: c.Mkdir}
if _, err := p.execCopyCommand(ctx, epSingle); err != nil {
return details, fmt.Errorf("can't copy file to %s: %w", ep.hostAddr, err)
}
}
details = fmt.Sprintf(" {copy: %s}", strings.Join(msgs, ", "))
Expand All @@ -319,23 +327,71 @@ func (p *Process) execSyncCommand(ctx context.Context, ep execCmdParams) (detail
func (p *Process) execDeleteCommand(ctx context.Context, ep execCmdParams) (details string, err error) {
loc := p.applyTemplates(ep.cmd.Delete.Location,
templateData{hostAddr: ep.hostAddr, hostName: ep.hostName, task: ep.tsk, command: ep.cmd.Name})
details = fmt.Sprintf(" {delete: %s, recursive: %v}", loc, ep.cmd.Delete.Recursive)
if err := ep.exec.Delete(ctx, loc, ep.cmd.Delete.Recursive); err != nil {
return details, fmt.Errorf("can't delete files on %s: %w", ep.hostAddr, err)

if !ep.cmd.Options.Sudo {
// if sudo is not set, we can delete the file directly
if err := ep.exec.Delete(ctx, loc, ep.cmd.Delete.Recursive); err != nil {
return details, fmt.Errorf("can't delete files on %s: %w", ep.hostAddr, err)
}
details = fmt.Sprintf(" {delete: %s, recursive: %v}", loc, ep.cmd.Delete.Recursive)
}

if ep.cmd.Options.Sudo {
// if sudo is set, we need to delete the file using sudo by ssh-ing into the host and running the command
cmd := fmt.Sprintf("sudo rm -f %s", loc)
if ep.cmd.Delete.Recursive {
cmd = fmt.Sprintf("sudo rm -rf %s", loc)
}
if _, err := ep.exec.Run(ctx, cmd, p.Verbose); err != nil {
return details, fmt.Errorf("can't delete file(s) on %s: %w", ep.hostAddr, err)
}
details = fmt.Sprintf(" {delete: %s, recursive: %v, sudo: true}", loc, ep.cmd.Delete.Recursive)
}

return details, nil
}

// execWaitCommand waits for a command to complete on a target hostAddr. It runs the command in a loop with a check duration
// until the command succeeds or the timeout is exceeded.
func (p *Process) execWaitCommand(ctx context.Context, ep execCmdParams) (details string, err error) {
c := p.applyTemplates(ep.cmd.Wait.Command,
templateData{hostAddr: ep.hostAddr, hostName: ep.hostName, task: ep.tsk, command: ep.cmd.Name})
params := config.WaitInternal{Command: c, Timeout: ep.cmd.Wait.Timeout, CheckDuration: ep.cmd.Wait.CheckDuration}

timeout, duration := ep.cmd.Wait.Timeout, ep.cmd.Wait.CheckDuration
if duration == 0 {
duration = 5 * time.Second // default check duration if not set
}
if timeout == 0 {
timeout = time.Hour * 24 // default timeout if not set, wait practically forever
}

details = fmt.Sprintf(" {wait: %s, timeout: %v, duration: %v}",
c, ep.cmd.Wait.Timeout.Truncate(100*time.Millisecond), ep.cmd.Wait.CheckDuration.Truncate(100*time.Millisecond))
if err := p.wait(ctx, ep.exec, params); err != nil {
return details, fmt.Errorf("wait failed on %s: %w", ep.hostAddr, err)
c, timeout.Truncate(100*time.Millisecond), duration.Truncate(100*time.Millisecond))

waitCmd := fmt.Sprintf("sh -c %q", c) // run wait command in a shell
if ep.cmd.Options.Sudo {
details = fmt.Sprintf(" {wait: %s, timeout: %v, duration: %v, sudo: true}",
c, timeout.Truncate(100*time.Millisecond), duration.Truncate(100*time.Millisecond))
waitCmd = fmt.Sprintf("sudo sh -c %q", c) // add sudo if needed
}

checkTk := time.NewTicker(duration)
defer checkTk.Stop()
timeoutTk := time.NewTicker(timeout)
defer timeoutTk.Stop()

for {
select {
case <-ctx.Done():
return details, ctx.Err()
case <-timeoutTk.C:
return details, fmt.Errorf("timeout exceeded")
case <-checkTk.C:
if _, err := ep.exec.Run(ctx, waitCmd, false); err == nil {
return details, nil // command succeeded
}
}
}
return details, nil
}

type tdFn func() error // tdFn is a type for teardown functions, should be called after the command execution
Expand Down Expand Up @@ -378,10 +434,10 @@ func (p *Process) prepScript(ctx context.Context, s string, r io.Reader, ep exec
}

// get temp file name for remote hostAddr
dst := filepath.Join("/tmp", filepath.Base(tmp.Name())) // nolint
dst := filepath.Join(tmpRemoteDir, filepath.Base(tmp.Name())) // nolint

// upload the script to the remote hostAddr
if err = ep.exec.Upload(ctx, tmp.Name(), dst, false); err != nil {
if err = ep.exec.Upload(ctx, tmp.Name(), dst, true); err != nil {
return "", nil, fmt.Errorf("can't upload script to %s: %w", ep.hostAddr, err)
}
remoteCmd := fmt.Sprintf("sh -c %s", dst)
Expand All @@ -397,35 +453,6 @@ func (p *Process) prepScript(ctx context.Context, s string, r io.Reader, ep exec
return remoteCmd, teardown, nil
}

// wait waits for a command to complete on a target hostAddr. It runs the command in a loop with a check duration
// until the command succeeds or the timeout is exceeded.
func (p *Process) wait(ctx context.Context, sess executor.Interface, params config.WaitInternal) error {
if params.Timeout == 0 {
return nil
}
duration := params.CheckDuration
if params.CheckDuration == 0 {
duration = 5 * time.Second // default check duration if not set
}
checkTk := time.NewTicker(duration)
defer checkTk.Stop()
timeoutTk := time.NewTicker(params.Timeout)
defer timeoutTk.Stop()

for {
select {
case <-ctx.Done():
return ctx.Err()
case <-timeoutTk.C:
return fmt.Errorf("timeout exceeded")
case <-checkTk.C:
if _, err := sess.Run(ctx, params.Command, false); err == nil {
return nil
}
}
}
}

type templateData struct {
hostAddr string
hostName string
Expand Down
Loading

0 comments on commit 30987fd

Please sign in to comment.