diff --git a/usersession/agent/export_test.go b/usersession/agent/export_test.go index 472e1ebe073..36b1d1243b5 100644 --- a/usersession/agent/export_test.go +++ b/usersession/agent/export_test.go @@ -19,7 +19,22 @@ package agent +import ( + "time" +) + var ( SessionInfoCmd = sessionInfoCmd ServicesCmd = servicesCmd ) + +func MockStopTimeouts(stop, kill time.Duration) (restore func()) { + oldStopTimeout := stopTimeout + stopTimeout = stop + oldKillWait := killWait + killWait = kill + return func() { + stopTimeout = oldStopTimeout + killWait = oldKillWait + } +} diff --git a/usersession/agent/rest_api.go b/usersession/agent/rest_api.go index 02f17291221..e241e66dbbb 100644 --- a/usersession/agent/rest_api.go +++ b/usersession/agent/rest_api.go @@ -65,10 +65,13 @@ type serviceInstruction struct { Services []string `json:"services"` } -var killWait = 5 * time.Second +var ( + stopTimeout = time.Duration(timeout.DefaultTimeout) + killWait = 5 * time.Second +) func stopOneService(service string, sysd systemd.Systemd) error { - err := sysd.Stop(service, time.Duration(timeout.DefaultTimeout)) + err := sysd.Stop(service, stopTimeout) if err != nil && systemd.IsTimeout(err) { // ignore errors for kill; nothing we'd do differently at this point sysd.Kill(service, "TERM", "") @@ -86,7 +89,7 @@ func serviceStart(inst *serviceInstruction, sysd systemd.Systemd) Response { } } - var startErrors map[string]string + startErrors := make(map[string]string) var started []string for _, service := range inst.Services { if err := sysd.Start(service); err != nil { @@ -114,7 +117,7 @@ func serviceStop(inst *serviceInstruction, sysd systemd.Systemd) Response { } } - var stopErrors map[string]string + stopErrors := make(map[string]string) for _, service := range inst.Services { if err := stopOneService(service, sysd); err != nil { stopErrors[service] = err.Error() diff --git a/usersession/agent/rest_api_test.go b/usersession/agent/rest_api_test.go index 2692a0ddbf8..2acf221fd6d 100644 --- a/usersession/agent/rest_api_test.go +++ b/usersession/agent/rest_api_test.go @@ -57,6 +57,8 @@ func (s *restSuite) SetUpTest(c *C) { s.AddCleanup(restore) restore = systemd.MockStopDelays(time.Millisecond, 25*time.Second) s.AddCleanup(restore) + restore = agent.MockStopTimeouts(20*time.Millisecond, time.Millisecond) + s.AddCleanup(restore) } func (s *restSuite) TearDownTest(c *C) { @@ -141,7 +143,7 @@ func (s *restSuite) TestServicesStart(c *C) { var rsp resp c.Assert(json.Unmarshal(rec.Body.Bytes(), &rsp), IsNil) c.Check(rsp.Type, Equals, agent.ResponseTypeSync) - c.Check(rsp.Result, IsNil) + c.Check(rsp.Result, DeepEquals, map[string]interface{}{}) c.Check(s.sysdLog, DeepEquals, [][]string{ {"--user", "start", "snap.foo.service"}, @@ -149,6 +151,66 @@ func (s *restSuite) TestServicesStart(c *C) { }) } +func (s *restSuite) TestServicesStartNonSnap(c *C) { + _, err := agent.New() + c.Assert(err, IsNil) + + req, err := http.NewRequest("POST", "/v1/services", bytes.NewBufferString(`{"action":"start","services":["snap.foo.service", "not-snap.bar.service"]}`)) + req.Header.Set("Content-Type", "application/json") + c.Assert(err, IsNil) + rec := httptest.NewRecorder() + agent.ServicesCmd.POST(agent.ServicesCmd, req).ServeHTTP(rec, req) + c.Check(rec.Code, Equals, 500) + c.Check(rec.HeaderMap.Get("Content-Type"), Equals, "application/json") + + var rsp resp + c.Assert(json.Unmarshal(rec.Body.Bytes(), &rsp), IsNil) + c.Check(rsp.Type, Equals, agent.ResponseTypeError) + c.Check(rsp.Result, DeepEquals, map[string]interface{}{ + "message": "cannot start service not-snap.bar.service", + }) + + // No services were started on the error. + c.Check(s.sysdLog, HasLen, 0) +} + +func (s *restSuite) TestServicesStartFailureStopsServices(c *C) { + var sysdLog [][]string + restore := systemd.MockSystemctl(func(cmd ...string) ([]byte, error) { + sysdLog = append(sysdLog, cmd) + if cmd[0] == "--user" && cmd[1] == "start" && cmd[2] == "snap.bar.service" { + return nil, fmt.Errorf("start failure") + } + return []byte("ActiveState=inactive\n"), nil + }) + defer restore() + + _, err := agent.New() + c.Assert(err, IsNil) + + req, err := http.NewRequest("POST", "/v1/services", bytes.NewBufferString(`{"action":"start","services":["snap.foo.service", "snap.bar.service"]}`)) + req.Header.Set("Content-Type", "application/json") + c.Assert(err, IsNil) + rec := httptest.NewRecorder() + agent.ServicesCmd.POST(agent.ServicesCmd, req).ServeHTTP(rec, req) + c.Check(rec.Code, Equals, 200) + c.Check(rec.HeaderMap.Get("Content-Type"), Equals, "application/json") + + var rsp resp + c.Assert(json.Unmarshal(rec.Body.Bytes(), &rsp), IsNil) + c.Check(rsp.Type, Equals, agent.ResponseTypeSync) + c.Check(rsp.Result, DeepEquals, map[string]interface{}{ + "snap.bar.service": "start failure", + }) + + c.Check(sysdLog, DeepEquals, [][]string{ + {"--user", "start", "snap.foo.service"}, + {"--user", "start", "snap.bar.service"}, + {"--user", "stop", "snap.foo.service"}, + {"--user", "show", "--property=ActiveState", "snap.foo.service"}, + }) +} + func (s *restSuite) TestServicesStop(c *C) { _, err := agent.New() c.Assert(err, IsNil) @@ -164,7 +226,7 @@ func (s *restSuite) TestServicesStop(c *C) { var rsp resp c.Assert(json.Unmarshal(rec.Body.Bytes(), &rsp), IsNil) c.Check(rsp.Type, Equals, agent.ResponseTypeSync) - c.Check(rsp.Result, IsNil) + c.Check(rsp.Result, DeepEquals, map[string]interface{}{}) c.Check(s.sysdLog, DeepEquals, [][]string{ {"--user", "stop", "snap.foo.service"}, @@ -173,3 +235,66 @@ func (s *restSuite) TestServicesStop(c *C) { {"--user", "show", "--property=ActiveState", "snap.bar.service"}, }) } + +func (s *restSuite) TestServicesStopNonSnap(c *C) { + _, err := agent.New() + c.Assert(err, IsNil) + + req, err := http.NewRequest("POST", "/v1/services", bytes.NewBufferString(`{"action":"stop","services":["snap.foo.service", "not-snap.bar.service"]}`)) + req.Header.Set("Content-Type", "application/json") + c.Assert(err, IsNil) + rec := httptest.NewRecorder() + agent.ServicesCmd.POST(agent.ServicesCmd, req).ServeHTTP(rec, req) + c.Check(rec.Code, Equals, 500) + c.Check(rec.HeaderMap.Get("Content-Type"), Equals, "application/json") + + var rsp resp + c.Assert(json.Unmarshal(rec.Body.Bytes(), &rsp), IsNil) + c.Check(rsp.Type, Equals, agent.ResponseTypeError) + c.Check(rsp.Result, DeepEquals, map[string]interface{}{ + "message": "cannot stop service not-snap.bar.service", + }) + + // No services were started on the error. + c.Check(s.sysdLog, HasLen, 0) +} + +func (s *restSuite) TestServicesStopFallbackToKill(c *C) { + var sysdLog [][]string + restore := systemd.MockSystemctl(func(cmd ...string) ([]byte, error) { + // Ignore "show" spam + if cmd[1] != "show" { + sysdLog = append(sysdLog, cmd) + } + if cmd[len(cmd)-1] == "snap.bar.service" { + return []byte("ActiveState=active\n"), nil + } + return []byte("ActiveState=inactive\n"), nil + }) + defer restore() + + _, err := agent.New() + c.Assert(err, IsNil) + + req, err := http.NewRequest("POST", "/v1/services", bytes.NewBufferString(`{"action":"stop","services":["snap.foo.service", "snap.bar.service"]}`)) + req.Header.Set("Content-Type", "application/json") + c.Assert(err, IsNil) + rec := httptest.NewRecorder() + agent.ServicesCmd.POST(agent.ServicesCmd, req).ServeHTTP(rec, req) + c.Check(rec.Code, Equals, 200) + c.Check(rec.HeaderMap.Get("Content-Type"), Equals, "application/json") + + var rsp resp + c.Assert(json.Unmarshal(rec.Body.Bytes(), &rsp), IsNil) + c.Check(rsp.Type, Equals, agent.ResponseTypeSync) + c.Check(rsp.Result, DeepEquals, map[string]interface{}{ + "snap.bar.service": "snap.bar.service failed to stop: timeout", + }) + + c.Check(sysdLog, DeepEquals, [][]string{ + {"--user", "stop", "snap.foo.service"}, + {"--user", "stop", "snap.bar.service"}, + {"--user", "kill", "snap.bar.service", "-s", "TERM", "--kill-who=all"}, + {"--user", "kill", "snap.bar.service", "-s", "KILL", "--kill-who=all"}, + }) +}