diff --git a/db/tasks.go b/db/tasks.go index 8b7f5d012..ae8ead63a 100644 --- a/db/tasks.go +++ b/db/tasks.go @@ -191,6 +191,10 @@ func (db *DB) UpdateTaskLog(id uuid.UUID, more string) error { } func (db *DB) CreateTaskArchive(id uuid.UUID, key string, effective time.Time) (uuid.UUID, error) { + // fail on empty store_key, as '' seems to satisfy the NOT NULL constraint in postgres + if key == "" { + return nil, fmt.Errorf("cannot create an archive without a store_key") + } // determine how long we need to keep this specific archive for r, err := db.Query( `SELECT r.expiry diff --git a/db/tasks_test.go b/db/tasks_test.go index 371352448..b2e22e859 100644 --- a/db/tasks_test.go +++ b/db/tasks_test.go @@ -28,6 +28,12 @@ var _ = Describe("Task Management", func() { Ω(n).Should(BeNumerically(">", 0)) } + shouldNotExist := func(q string, params ...interface{}) { + n, err := db.Count(q, params...) + Expect(err).ShouldNot(HaveOccurred()) + Expect(n).Should(BeNumerically("==", 0)) + } + BeforeEach(func() { var err error db, err = Database( @@ -184,7 +190,7 @@ var _ = Describe("Task Management", func() { Ω(db.CompleteTask(id, time.Now())).Should(Succeed()) archive_id, err := db.CreateTaskArchive(id, "SOME-KEY", time.Now()) Expect(err).ShouldNot(HaveOccurred()) - Expect(id).ShouldNot(BeNil()) + Expect(archive_id).ShouldNot(BeNil()) shouldExist(`SELECT * FROM tasks`) shouldExist(`SELECT * FROM tasks WHERE archive_uuid IS NOT NULL`) @@ -197,4 +203,17 @@ var _ = Describe("Task Management", func() { shouldExist(`SELECT * FROM archives WHERE taken_at IS NOT NULL`) shouldExist(`SELECT * FROM archives WHERE expires_at IS NOT NULL`) }) + It("Fails to associate archives with a task, when no restore key is present", func() { + id, err := db.CreateBackupTask("bob", JOB_UUID) + Expect(err).ShouldNot(HaveOccurred()) + Expect(id).ShouldNot(BeNil()) + + Expect(db.StartTask(id, time.Now())).Should(Succeed()) + Expect(db.CompleteTask(id, time.Now())).Should(Succeed()) + archive_id, err := db.CreateTaskArchive(id, "", time.Now()) + Expect(err).Should(HaveOccurred()) + Expect(archive_id).Should(BeNil()) + + shouldNotExist(`SELECT * from archives where store_key = ''`) + }) }) diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go index c70235746..0e5fb5f45 100644 --- a/supervisor/supervisor.go +++ b/supervisor/supervisor.go @@ -268,7 +268,7 @@ func (s *Supervisor) Run() error { case FAILED: log.Warnf(" %s: task failed!", u.Task) - if err := s.Database.FailTask(u.Task, time.Now()); err != nil { + if err := s.Database.FailTask(u.Task, u.StoppedAt); err != nil { log.Errorf(" %s: !! failed to update database - %s", u.Task, err) } @@ -282,6 +282,7 @@ func (s *Supervisor) Run() error { log.Infof(" %s: restore key is %s", u.Task, u.Output) if id, err := s.Database.CreateTaskArchive(u.Task, u.Output, time.Now()); err != nil { log.Errorf(" %s: !! failed to update database - %s", u.Task, err) + } else { if !u.TaskSuccess { s.Database.InvalidateArchive(id) } diff --git a/supervisor/worker.go b/supervisor/worker.go index a08320a5e..a867bfe72 100644 --- a/supervisor/worker.go +++ b/supervisor/worker.go @@ -117,7 +117,6 @@ func worker(id uint, privateKeyFile string, work chan Task, updates chan WorkerU if err != nil { updates <- WorkerUpdate{Task: t.UUID, Op: OUTPUT, Output: fmt.Sprintf("TASK FAILED!! shield worker %d failed to execute the command against the remote agent %s (%s)\n", id, remote, err)} - updates <- WorkerUpdate{Task: t.UUID, Op: FAILED} jobFailed = true } @@ -134,15 +133,22 @@ func worker(id uint, privateKeyFile string, work chan Task, updates chan WorkerU err := dec.Decode(&v) if err != nil { + jobFailed = true updates <- WorkerUpdate{Task: t.UUID, Op: OUTPUT, Output: fmt.Sprintf("WORKER FAILED!! shield worker %d failed to parse JSON response from remote agent %s (%s)\n", id, remote, err)} } else { - updates <- WorkerUpdate{ - Task: t.UUID, - Op: RESTORE_KEY, - TaskSuccess: !jobFailed, - Output: v.Key, + if v.Key != "" { + updates <- WorkerUpdate{ + Task: t.UUID, + Op: RESTORE_KEY, + TaskSuccess: !jobFailed, + Output: v.Key, + } + } else { + jobFailed = true + updates <- WorkerUpdate{Task: t.UUID, Op: OUTPUT, + Output: fmt.Sprintf("TASK FAILED!! No restore key detected in worker %d. Cowardly refusing to create an archive record", id)} } } } @@ -156,10 +162,14 @@ func worker(id uint, privateKeyFile string, work chan Task, updates chan WorkerU } // signal to the supervisor that we finished - updates <- WorkerUpdate{ - Task: t.UUID, - Op: STOPPED, - StoppedAt: time.Now(), + if jobFailed { + updates <- WorkerUpdate{Task: t.UUID, Op: FAILED, StoppedAt: time.Now()} + } else { + updates <- WorkerUpdate{ + Task: t.UUID, + Op: STOPPED, + StoppedAt: time.Now(), + } } } }