Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 70 additions & 5 deletions agent/agent.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package agent

import (
"encoding/binary"
"fmt"
"github.com/starkandwayne/goutils/log"
"io"
"net"
"os"
"os/exec"
"syscall"

"golang.org/x/crypto/ssh"
)
Expand Down Expand Up @@ -129,13 +132,75 @@ func (agent *Agent) handleConn(conn *ssh.ServerConn, chans <-chan ssh.NewChannel
// run the agent request
err = request.Run(output)
<-done
rc := []byte{0, 0, 0, 0}
if err != nil {
rc[0] = 1
var rc int
if exitErr, ok := err.(*exec.ExitError); ok {
sys := exitErr.ProcessState.Sys()
// os.ProcessState.Sys() may not return syscall.WaitStatus on non-UNIX machines,
// so currently this feature only works on UNIX, but shouldn't crash on other OSes
if ws, ok := sys.(syscall.WaitStatus); ok {
if ws.Exited() {
rc = ws.ExitStatus()
} else {
var signal syscall.Signal
if ws.Signaled() {
signal = ws.Signal()
}
if ws.Stopped() {
signal = ws.StopSignal()
}
sigStr, ok := SIGSTRING[signal]
if !ok {
sigStr = "ABRT" // use ABRT as catch-all signal for any that don't translate
log.Infof("Task execution terminted due to %s, translating as ABRT for ssh transport", signal)
} else {
log.Infof("Task execution terminated due to SIG%s", sigStr)
}
sigMsg := struct {
Signal string
CoreDumped bool
Error string
Lang string
}{
Signal: sigStr,
CoreDumped: false,
Error: fmt.Sprintf("shield-pipe terminated due to SIG%s", sigStr),
Lang: "en-US",
}
channel.SendRequest("exit-signal", false, ssh.Marshal(&sigMsg))
channel.Close()
continue
}
}
} else if err != nil {
// we got some kind of error that isn't a command execution error
// from a UNIX system; use a magical error code (16777216) to signal
// this to the shield daemon
log.Infof("Task could not execute: %s", err)
rc = 16777216
}
log.Infof("Task completed with rc=%d", rc[0])
channel.SendRequest("exit-status", false, rc)

log.Infof("Task completed with rc=%d", rc)
byteCode := make([]byte, 4)
binary.BigEndian.PutUint32(byteCode, uint32(rc)) // SSH protocol is big-endian byte ordering
channel.SendRequest("exit-status", false, byteCode)
channel.Close()
}
}
}

// SIGSTRING maps a process signal to its short name, without the "SIG"
// prefix. The set of signals is based on what's handled in
// https://github.com/golang/crypto/blob/master/ssh/session.go#L21
var SIGSTRING = map[syscall.Signal]string{
	syscall.SIGHUP:  "HUP",
	syscall.SIGINT:  "INT",
	syscall.SIGQUIT: "QUIT",
	syscall.SIGILL:  "ILL",
	syscall.SIGABRT: "ABRT",
	syscall.SIGFPE:  "FPE",
	syscall.SIGKILL: "KILL",
	syscall.SIGUSR1: "USR1",
	syscall.SIGSEGV: "SEGV",
	syscall.SIGUSR2: "USR2",
	syscall.SIGPIPE: "PIPE",
	syscall.SIGALRM: "ALRM",
	syscall.SIGTERM: "TERM",
}
4 changes: 4 additions & 0 deletions db/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ func (db *DB) UpdateTaskLog(id uuid.UUID, more string) error {
}

func (db *DB) CreateTaskArchive(id uuid.UUID, key string, effective time.Time) (uuid.UUID, error) {
// fail on empty store_key, as '' seems to satisfy the NOT NULL constraint in postgres
if key == "" {
return nil, fmt.Errorf("cannot create an archive without a store_key")
}
// determine how long we need to keep this specific archive for
r, err := db.Query(
`SELECT r.expiry
Expand Down
21 changes: 20 additions & 1 deletion db/tasks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ var _ = Describe("Task Management", func() {
Ω(n).Should(BeNumerically(">", 0))
}

// shouldNotExist asserts that db.Count succeeds for the given query and
// parameters and that the count it reports is exactly zero.
shouldNotExist := func(q string, params ...interface{}) {
	count, countErr := db.Count(q, params...)
	Expect(countErr).ShouldNot(HaveOccurred())
	Expect(count).Should(BeNumerically("==", 0))
}

BeforeEach(func() {
var err error
db, err = Database(
Expand Down Expand Up @@ -184,7 +190,7 @@ var _ = Describe("Task Management", func() {
Ω(db.CompleteTask(id, time.Now())).Should(Succeed())
archive_id, err := db.CreateTaskArchive(id, "SOME-KEY", time.Now())
Expect(err).ShouldNot(HaveOccurred())
Expect(id).ShouldNot(BeNil())
Expect(archive_id).ShouldNot(BeNil())

shouldExist(`SELECT * FROM tasks`)
shouldExist(`SELECT * FROM tasks WHERE archive_uuid IS NOT NULL`)
Expand All @@ -197,4 +203,17 @@ var _ = Describe("Task Management", func() {
shouldExist(`SELECT * FROM archives WHERE taken_at IS NOT NULL`)
shouldExist(`SELECT * FROM archives WHERE expires_at IS NOT NULL`)
})
// Passing an empty key to CreateTaskArchive is expected to return an error
// and a nil archive id, and no archive with an empty store_key may exist
// afterwards.
It("Fails to associate archives with a task, when no restore key is present", func() {
// create a backup task to attach the archive to
id, err := db.CreateBackupTask("bob", JOB_UUID)
Expect(err).ShouldNot(HaveOccurred())
Expect(id).ShouldNot(BeNil())

// start and complete the task before attempting to create its archive
Expect(db.StartTask(id, time.Now())).Should(Succeed())
Expect(db.CompleteTask(id, time.Now())).Should(Succeed())
// the empty key is the condition under test
archive_id, err := db.CreateTaskArchive(id, "", time.Now())
Expect(err).Should(HaveOccurred())
Expect(archive_id).Should(BeNil())

// verify that nothing was written with an empty store_key
shouldNotExist(`SELECT * from archives where store_key = ''`)
})
})
3 changes: 2 additions & 1 deletion supervisor/supervisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func (s *Supervisor) Run() error {

case FAILED:
log.Warnf(" %s: task failed!", u.Task)
if err := s.Database.FailTask(u.Task, time.Now()); err != nil {
if err := s.Database.FailTask(u.Task, u.StoppedAt); err != nil {
log.Errorf(" %s: !! failed to update database - %s", u.Task, err)
}

Expand All @@ -282,6 +282,7 @@ func (s *Supervisor) Run() error {
log.Infof(" %s: restore key is %s", u.Task, u.Output)
if id, err := s.Database.CreateTaskArchive(u.Task, u.Output, time.Now()); err != nil {
log.Errorf(" %s: !! failed to update database - %s", u.Task, err)
} else {
if !u.TaskSuccess {
s.Database.InvalidateArchive(id)
}
Expand Down
30 changes: 20 additions & 10 deletions supervisor/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ func worker(id uint, privateKeyFile string, work chan Task, updates chan WorkerU
if err != nil {
updates <- WorkerUpdate{Task: t.UUID, Op: OUTPUT,
Output: fmt.Sprintf("TASK FAILED!! shield worker %d failed to execute the command against the remote agent %s (%s)\n", id, remote, err)}
updates <- WorkerUpdate{Task: t.UUID, Op: FAILED}
jobFailed = true
}

Expand All @@ -134,15 +133,22 @@ func worker(id uint, privateKeyFile string, work chan Task, updates chan WorkerU
err := dec.Decode(&v)

if err != nil {
jobFailed = true
updates <- WorkerUpdate{Task: t.UUID, Op: OUTPUT,
Output: fmt.Sprintf("WORKER FAILED!! shield worker %d failed to parse JSON response from remote agent %s (%s)\n", id, remote, err)}

} else {
updates <- WorkerUpdate{
Task: t.UUID,
Op: RESTORE_KEY,
TaskSuccess: !jobFailed,
Output: v.Key,
if v.Key != "" {
updates <- WorkerUpdate{
Task: t.UUID,
Op: RESTORE_KEY,
TaskSuccess: !jobFailed,
Output: v.Key,
}
} else {
jobFailed = true
updates <- WorkerUpdate{Task: t.UUID, Op: OUTPUT,
Output: fmt.Sprintf("TASK FAILED!! No restore key detected in worker %d. Cowardly refusing to create an archive record", id)}
}
}
}
Expand All @@ -156,10 +162,14 @@ func worker(id uint, privateKeyFile string, work chan Task, updates chan WorkerU
}

// signal to the supervisor that we finished
updates <- WorkerUpdate{
Task: t.UUID,
Op: STOPPED,
StoppedAt: time.Now(),
if jobFailed {
updates <- WorkerUpdate{Task: t.UUID, Op: FAILED, StoppedAt: time.Now()}
} else {
updates <- WorkerUpdate{
Task: t.UUID,
Op: STOPPED,
StoppedAt: time.Now(),
}
}
}
}
Expand Down