diff --git a/acceptance/upload_product_test.go b/acceptance/upload_product_test.go index 41970e209..c225591e4 100644 --- a/acceptance/upload_product_test.go +++ b/acceptance/upload_product_test.go @@ -19,11 +19,11 @@ import ( . "github.com/onsi/gomega" ) -type TLSServer struct { +type UploadProductTestServer struct { UploadHandler http.Handler } -func (t *TLSServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { +func (t *UploadProductTestServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { var responseString string w.Header().Set("Content-Type", "application/json") @@ -104,7 +104,7 @@ name: some-product`) }) JustBeforeEach(func() { - server = httptest.NewTLSServer(&TLSServer{UploadHandler: http.HandlerFunc(uploadHandler)}) + server = httptest.NewTLSServer(&UploadProductTestServer{UploadHandler: http.HandlerFunc(uploadHandler)}) }) AfterEach(func() { diff --git a/acceptance/upload_stemcell_test.go b/acceptance/upload_stemcell_test.go index 7bb8f8615..05616470a 100644 --- a/acceptance/upload_stemcell_test.go +++ b/acceptance/upload_stemcell_test.go @@ -18,11 +18,56 @@ import ( . "github.com/onsi/gomega" ) +type UploadStemcellTestServer struct { + UploadHandler http.Handler +} + +func (t *UploadStemcellTestServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + var responseString string + w.Header().Set("Content-Type", "application/json") + + switch req.URL.Path { + case "/uaa/oauth/token": + req.ParseForm() + + if req.PostForm.Get("password") == "" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + responseString = `{ + "access_token": "some-opsman-token", + "token_type": "bearer", + "expires_in": 3600 + }` + case "/api/v0/diagnostic_report": + responseString = "{}" + case "/api/v0/stemcells": + auth := req.Header.Get("Authorization") + + if auth != "Bearer some-opsman-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + t.UploadHandler.ServeHTTP(w, req) + return + default: + out, err := httputil.DumpRequest(req, true) + Expect(err).NotTo(HaveOccurred()) + Fail(fmt.Sprintf("unexpected request: %s", out)) + } + + w.Write([]byte(responseString)) +} + var _ = Describe("upload-stemcell command", func() { var ( - stemcellName string - content *os.File - server *httptest.Server + stemcellName string + content *os.File + server *httptest.Server + uploadHandler func(http.ResponseWriter, *http.Request) + snip chan struct{} ) BeforeEach(func() { @@ -33,49 +78,19 @@ var _ = Describe("upload-stemcell command", func() { _, err = content.WriteString("content so validation does not fail") Expect(err).NotTo(HaveOccurred()) - server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - var responseString string - w.Header().Set("Content-Type", "application/json") - - switch req.URL.Path { - case "/uaa/oauth/token": - req.ParseForm() - - if req.PostForm.Get("password") == "" { - w.WriteHeader(http.StatusUnauthorized) - return - } - - responseString = `{ - "access_token": "some-opsman-token", - "token_type": "bearer", - "expires_in": 3600 - }` - case "/api/v0/diagnostic_report": - responseString = "{}" - case "/api/v0/stemcells": - auth := req.Header.Get("Authorization") - - if auth != "Bearer some-opsman-token" { - w.WriteHeader(http.StatusUnauthorized) - return - } - - err := req.ParseMultipartForm(100) - if err != nil { - panic(err) - } - - stemcellName = req.MultipartForm.File["stemcell[file]"][0].Filename - responseString = "{}" - default: - out, err := httputil.DumpRequest(req, true) - Expect(err).NotTo(HaveOccurred()) - Fail(fmt.Sprintf("unexpected request: %s", out)) + uploadHandler = func(w http.ResponseWriter, req *http.Request) { + err := req.ParseMultipartForm(100) + if err != nil { + panic(err) } - w.Write([]byte(responseString)) - })) + stemcellName = req.MultipartForm.File["stemcell[file]"][0].Filename + w.Write([]byte("{}")) + } + }) + + JustBeforeEach(func() { + server = httptest.NewTLSServer(&UploadStemcellTestServer{UploadHandler: http.HandlerFunc(uploadHandler)}) }) AfterEach(func() { @@ -214,5 +229,52 @@ var _ = Describe("upload-stemcell command", func() { Eventually(session.Err).Should(gbytes.Say(`no such file or directory`)) }) }) + + Context("when the server returns EOF during upload", func() { + BeforeEach(func() { + snip = make(chan struct{}) + uploadCallCount := 0 + uploadHandler = func(w http.ResponseWriter, req *http.Request) { + uploadCallCount++ + + if uploadCallCount == 1 { + close(snip) + return + } else { + err := req.ParseMultipartForm(100) + if err != nil { + panic(err) + } + + stemcellName = req.MultipartForm.File["stemcell[file]"][0].Filename + w.Write([]byte("{}")) + } + } + }) + + JustBeforeEach(func() { + go func() { + <-snip + + server.CloseClientConnections() + }() + }) + + It("retries the upload", func() { + command := exec.Command(pathToMain, + "--target", server.URL, + "--username", "some-username", + "--password", "some-password", + "--skip-ssl-validation", + "upload-stemcell", + "--stemcell", content.Name(), + ) + + session, err := gexec.Start(command, GinkgoWriter, GinkgoWriter) + Expect(err).NotTo(HaveOccurred()) + + Eventually(session, 5).Should(gexec.Exit(0)) + }) + }) }) }) diff --git a/commands/upload_product.go b/commands/upload_product.go index 6113836e4..0e0d77859 100644 --- a/commands/upload_product.go +++ b/commands/upload_product.go @@ -10,7 +10,7 @@ import ( "github.com/pivotal-cf/om/validator" ) -const maxUploadRetries = 2 +const maxProductUploadRetries = 2 type UploadProduct struct { multipart multipart @@ -97,7 +97,7 @@ func (up UploadProduct) Execute(args []string) error { return nil } - for i := 0; i <= maxUploadRetries; i++ { + for i := 0; i <= maxProductUploadRetries; i++ { up.logger.Printf("processing product") err = up.multipart.AddFile("product[file]", up.Options.Product) @@ -115,7 +115,7 @@ func (up UploadProduct) Execute(args []string) error { ContentLength: submission.ContentLength, PollingInterval: up.Options.PollingInterval, }) - if network.CanRetry(err) && i < maxUploadRetries { + if network.CanRetry(err) && i < maxProductUploadRetries { up.logger.Printf("retrying product upload after error: %s\n", err) up.multipart.Reset() } else { diff --git a/commands/upload_stemcell.go b/commands/upload_stemcell.go index e08d8a3c3..bc8aeff7d 100644 --- a/commands/upload_stemcell.go +++ b/commands/upload_stemcell.go @@ -7,11 +7,14 @@ import ( "github.com/pivotal-cf/jhanda" "github.com/pivotal-cf/om/api" "github.com/pivotal-cf/om/formcontent" + "github.com/pivotal-cf/om/network" "github.com/pivotal-cf/om/validator" "strconv" ) +const maxStemcellUploadRetries = 2 + type UploadStemcell struct { multipart multipart logger logger @@ -94,28 +97,37 @@ func (us UploadStemcell) Execute(args []string) error { } } - err := us.multipart.AddFile("stemcell[file]", us.Options.Stemcell) - if err != nil { - return fmt.Errorf("failed to load stemcell: %s", err) - } - - err = us.multipart.AddField("stemcell[floating]", strconv.FormatBool(us.Options.Floating)) - if err != nil { - return fmt.Errorf("failed to load stemcell: %s", err) - } + var err error + for i := 0; i <= maxStemcellUploadRetries; i++ { + err = us.multipart.AddFile("stemcell[file]", us.Options.Stemcell) + if err != nil { + return fmt.Errorf("failed to load stemcell: %s", err) + } - submission := us.multipart.Finalize() - if err != nil { - return fmt.Errorf("failed to create multipart form: %s", err) - } + err = us.multipart.AddField("stemcell[floating]", strconv.FormatBool(us.Options.Floating)) + if err != nil { + return fmt.Errorf("failed to load stemcell: %s", err) + } - us.logger.Printf("beginning stemcell upload to Ops Manager") + submission := us.multipart.Finalize() + if err != nil { + return fmt.Errorf("failed to create multipart form: %s", err) + } - _, err = us.service.UploadStemcell(api.StemcellUploadInput{ - Stemcell: submission.Content, - ContentType: submission.ContentType, - ContentLength: submission.ContentLength, - }) + us.logger.Printf("beginning stemcell upload to Ops Manager") + + _, err = us.service.UploadStemcell(api.StemcellUploadInput{ + Stemcell: submission.Content, + ContentType: submission.ContentType, + ContentLength: submission.ContentLength, + }) + if network.CanRetry(err) && i < maxStemcellUploadRetries { + us.logger.Printf("retrying stemcell upload after error: %s\n", err) + us.multipart.Reset() + } else { + break + } + } if err != nil { return fmt.Errorf("failed to upload stemcell: %s", err) } diff --git a/commands/upload_stemcell_test.go b/commands/upload_stemcell_test.go index 368f98cc2..7caad4d8d 100644 --- a/commands/upload_stemcell_test.go +++ b/commands/upload_stemcell_test.go @@ -1,12 +1,14 @@ package commands_test import ( - "errors" "fmt" + "io" "io/ioutil" "os" "strings" + "github.com/pkg/errors" + "github.com/pivotal-cf/jhanda" "github.com/pivotal-cf/om/api" "github.com/pivotal-cf/om/commands" @@ -117,6 +119,64 @@ var _ = Describe("UploadStemcell", func() { format, v = logger.PrintfArgsForCall(2) Expect(fmt.Sprintf(format, v...)).To(Equal("finished upload")) }) + + Context("when the product fails to upload the first time with a retryable error", func() { + It("tries again", func() { + submission := formcontent.ContentSubmission{ + Content: ioutil.NopCloser(strings.NewReader("")), + ContentType: "some content-type", + ContentLength: 10, + } + multipart.FinalizeReturns(submission) + + fakeService.GetDiagnosticReportReturns(api.DiagnosticReport{Stemcells: []string{}}, nil) + + command := commands.NewUploadStemcell(multipart, fakeService, logger) + + fakeService.UploadStemcellReturnsOnCall(0, api.StemcellUploadOutput{}, errors.Wrap(io.EOF, "some upload error")) + fakeService.UploadStemcellReturnsOnCall(1, api.StemcellUploadOutput{}, nil) + + err := command.Execute([]string{ + "--stemcell", "/path/to/stemcell.tgz", + }) + Expect(err).NotTo(HaveOccurred()) + + Expect(multipart.AddFileCallCount()).To(Equal(2)) + Expect(multipart.FinalizeCallCount()).To(Equal(2)) + Expect(multipart.ResetCallCount()).To(Equal(1)) + + Expect(fakeService.UploadStemcellCallCount()).To(Equal(2)) + }) + }) + + Context("when the product fails to upload three times", func() { + It("returns an error", func() { + submission := formcontent.ContentSubmission{ + Content: ioutil.NopCloser(strings.NewReader("")), + ContentType: "some content-type", + ContentLength: 10, + } + multipart.FinalizeReturns(submission) + + fakeService.GetDiagnosticReportReturns(api.DiagnosticReport{Stemcells: []string{}}, nil) + + command := commands.NewUploadStemcell(multipart, fakeService, logger) + + fakeService.UploadStemcellReturns(api.StemcellUploadOutput{}, errors.Wrap(io.EOF, "some upload error")) + + err := command.Execute([]string{ + "--stemcell", "/path/to/stemcell.tgz", + }) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("EOF")) + + Expect(multipart.AddFileCallCount()).To(Equal(3)) + Expect(multipart.FinalizeCallCount()).To(Equal(3)) + Expect(multipart.ResetCallCount()).To(Equal(2)) + + Expect(fakeService.UploadStemcellCallCount()).To(Equal(3)) + }) + }) }) Context("when the stemcell already exists", func() {