diff --git a/README.md b/README.md index b7c44f8..737c0f2 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,15 @@ The proxy runs fine natively, but if you wish, you can also create a docker imag ./build.sh user/taskcluster-proxy ``` +## Endpoints + +#### Credentials Update + +The proxy has the endpoint `/credentials` which accepts `PUT` request for +credentials update. The body is a +[Credentials](http://docs.taskcluster.net/queue/api-docs/#claimTask) +object in json format. + ## Running tests diff --git a/authorization_test.go b/authorization_test.go index c871c8c..9575cb1 100644 --- a/authorization_test.go +++ b/authorization_test.go @@ -142,9 +142,11 @@ func TestBewit(t *testing.T) { test := func(t *testing.T, creds *tcclient.Credentials) *httptest.ResponseRecorder { // Test setup - routes := Routes(tcclient.ConnectionData{ - Credentials: creds, - }) + routes := Routes{ + ConnectionData: tcclient.ConnectionData{ + Credentials: creds, + }, + } req, err := http.NewRequest( "POST", "http://localhost:60024/bewit", @@ -186,15 +188,17 @@ func TestAuthorizationDelegate(t *testing.T) { test := func(name string, scopes []string) IntegrationTest { return func(t *testing.T, creds *tcclient.Credentials) *httptest.ResponseRecorder { // Test setup - routes := Routes(tcclient.ConnectionData{ - Authenticate: true, - Credentials: &tcclient.Credentials{ - ClientId: creds.ClientId, - AccessToken: creds.AccessToken, - Certificate: creds.Certificate, - AuthorizedScopes: scopes, + routes := Routes{ + ConnectionData: tcclient.ConnectionData{ + Authenticate: true, + Credentials: &tcclient.Credentials{ + ClientId: creds.ClientId, + AccessToken: creds.AccessToken, + Certificate: creds.Certificate, + AuthorizedScopes: scopes, + }, }, - }) + } // Requires scope "auth:azure-table-access:fakeaccount/DuMmYtAbLe" req, err := http.NewRequest( @@ -229,10 +233,12 @@ func TestAPICallWithPayload(t *testing.T) { test := func(t *testing.T, creds *tcclient.Credentials) *httptest.ResponseRecorder { // Test setup - routes := Routes(tcclient.ConnectionData{ - Authenticate: true, - Credentials: creds, - }) + routes := Routes{ + ConnectionData: tcclient.ConnectionData{ + Authenticate: true, + Credentials: creds, + }, + } taskId := slugid.Nice() taskGroupId := slugid.Nice() created := time.Now() @@ -303,10 +309,12 @@ func TestNon200HasErrorBody(t *testing.T) { test := func(t *testing.T, creds *tcclient.Credentials) *httptest.ResponseRecorder { // Test setup - routes := Routes(tcclient.ConnectionData{ - Authenticate: true, - Credentials: creds, - }) + routes := Routes{ + ConnectionData: tcclient.ConnectionData{ + Authenticate: true, + Credentials: creds, + }, + } taskId := slugid.Nice() req, err := http.NewRequest( @@ -336,10 +344,12 @@ func TestOversteppedScopes(t *testing.T) { test := func(t *testing.T, creds *tcclient.Credentials) *httptest.ResponseRecorder { // Test setup - routes := Routes(tcclient.ConnectionData{ - Authenticate: true, - Credentials: creds, - }) + routes := Routes{ + ConnectionData: tcclient.ConnectionData{ + Authenticate: true, + Credentials: creds, + }, + } // This scope is not in the scopes of the temp credentials, which would // happen if a task declares a scope that the provisioner does not @@ -374,14 +384,16 @@ func TestOversteppedScopes(t *testing.T) { } func TestBadCredsReturns500(t *testing.T) { - routes := Routes(tcclient.ConnectionData{ - Authenticate: true, - Credentials: &tcclient.Credentials{ - ClientId: "abc", - AccessToken: "def", - Certificate: "ghi", // baaaad certificate + routes := Routes{ + ConnectionData: tcclient.ConnectionData{ + Authenticate: true, + Credentials: &tcclient.Credentials{ + ClientId: "abc", + AccessToken: "def", + Certificate: "ghi", // baaaad certificate + }, }, - }) + } req, err := http.NewRequest( "GET", "http://localhost:60024/secrets/v1/secret/garbage/pmoore/foo", diff --git a/credentials_update_test.go b/credentials_update_test.go new file mode 100644 index 0000000..21fd1c3 --- /dev/null +++ b/credentials_update_test.go @@ -0,0 +1,105 @@ +package main + +import ( + "bytes" + "encoding/json" + "github.com/taskcluster/taskcluster-client-go/tcclient" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" +) + +type RoutesTest struct { + Routes + t *testing.T +} + +func TestCredentialsUpdate(t *testing.T) { + newCreds := CredentialsUpdate{ + ClientId: "newClientId", + AccessToken: "newAccessToken", + Certificate: "newCertificate", + } + + body, err := json.Marshal(&newCreds) + + if err != nil { + t.Fatal(err) + } + + routes := NewRoutesTest(t) + + response := routes.request("POST", body) + if response.Code != 405 { + t.Errorf("Should return 405, but returned %d", response.Code) + } + + response = routes.request("PUT", make([]byte, 0)) + if response.Code != 400 { + t.Errorf("Should return 400, but returned %d", response.Code) + } + + response = routes.request("PUT", body) + if response.Code != 200 { + content, _ := ioutil.ReadAll(response.Body) + t.Fatal("Request error %d: %s", response.Code, string(content)) + } + + if routes.Credentials.ClientId != newCreds.ClientId { + t.Errorf( + "ClientId should be \"%s\", but got \"%s\"", + newCreds.ClientId, + routes.Credentials.ClientId, + ) + } + + if routes.Credentials.AccessToken != newCreds.AccessToken { + t.Errorf( + "AccessToken should be \"%s\", but got \"%s\"", + newCreds.AccessToken, + routes.Credentials.AccessToken, + ) + } + + if routes.Credentials.Certificate != newCreds.Certificate { + t.Errorf( + "Certificate should be \"%s\", but got \"%s\"", + newCreds.Certificate, + routes.Credentials.Certificate, + ) + } +} + +func (self *RoutesTest) request(method string, content []byte) (res *httptest.ResponseRecorder) { + req, err := http.NewRequest( + method, + "http://localhost:8080/credentials", + bytes.NewBuffer(content), + ) + + if err != nil { + self.t.Fatal(err) + } + + req.ContentLength = int64(len(content)) + res = httptest.NewRecorder() + self.ServeHTTP(res, req) + return +} + +func NewRoutesTest(t *testing.T) *RoutesTest { + return &RoutesTest{ + Routes: Routes{ + ConnectionData: tcclient.ConnectionData{ + Authenticate: true, + Credentials: &tcclient.Credentials{ + ClientId: "clientId", + AccessToken: "accessToken", + Certificate: "certificate", + }, + }, + }, + t: t, + } +} diff --git a/main.go b/main.go index c3caa54..e03c744 100644 --- a/main.go +++ b/main.go @@ -98,10 +98,12 @@ func main() { log.Println("Proxy with scopes: ", creds.AuthorizedScopes) - routes := Routes(tcclient.ConnectionData{ - Authenticate: true, - Credentials: creds, - }) + routes := Routes{ + ConnectionData: tcclient.ConnectionData{ + Authenticate: true, + Credentials: creds, + }, + } startError := http.ListenAndServe(fmt.Sprintf(":%d", port), &routes) if startError != nil { diff --git a/routes.go b/routes.go index 7353bba..0faef5b 100644 --- a/routes.go +++ b/routes.go @@ -7,6 +7,7 @@ import ( "log" "net/http" "strings" + "sync" "time" "github.com/taskcluster/httpbackoff" @@ -14,7 +15,16 @@ import ( tc "github.com/taskcluster/taskcluster-proxy/taskcluster" ) -type Routes tcclient.ConnectionData +type Routes struct { + tcclient.ConnectionData + lock sync.RWMutex +} + +type CredentialsUpdate struct { + ClientId string `json:"clientId"` + AccessToken string `json:"accessToken"` + Certificate string `json:"certificate"` +} var tcServices = tc.NewServices() var httpClient = &http.Client{} @@ -31,7 +41,7 @@ func (self *Routes) signUrl(res http.ResponseWriter, req *http.Request) { } urlString := strings.TrimSpace(string(body)) - cd := tcclient.ConnectionData(*self) + cd := tcclient.ConnectionData(self.ConnectionData) bewitUrl, err := (&cd).SignedURL(urlString, nil, time.Hour*1) if err != nil { @@ -46,8 +56,44 @@ func (self *Routes) signUrl(res http.ResponseWriter, req *http.Request) { fmt.Fprintf(res, bewitUrl.String()) } +func (self *Routes) updateCredentials(res http.ResponseWriter, req *http.Request) { + if req.Method != "PUT" { + log.Printf("Invalid method %s\n", req.Method) + res.WriteHeader(405) + return + } + + decoder := json.NewDecoder(req.Body) + + credentials := &CredentialsUpdate{} + err := decoder.Decode(credentials) + + if err != nil { + log.Printf("Could not decode request: %v\n", err) + res.WriteHeader(400) + return + } + + self.lock.Lock() + defer self.lock.Unlock() + self.Credentials.ClientId = credentials.ClientId + self.Credentials.AccessToken = credentials.AccessToken + self.Credentials.Certificate = credentials.Certificate + + res.WriteHeader(200) +} + // Routes implements the `http.Handler` interface func (self *Routes) ServeHTTP(res http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/credentials" { + log.Printf("Update credentials request %s\n", req.URL.String()) + self.updateCredentials(res, req) + return + } + + self.lock.RLock() + defer self.lock.RUnlock() + headersToSend := res.Header() headersToSend.Set("X-Taskcluster-Proxy-Version", version) cert, err := self.Credentials.Cert() @@ -122,7 +168,7 @@ func (self *Routes) ServeHTTP(res http.ResponseWriter, req *http.Request) { } } - cd := tcclient.ConnectionData(*self) + cd := tcclient.ConnectionData(self.ConnectionData) _, cs, err := (&cd).APICall(payload, req.Method, targetPath.String(), new(json.RawMessage), nil) // If we fail to create a request notify the client. if err != nil {