Skip to content

Commit

Permalink
Merge pull request #1358 from weaveworks/deflakenize-github-auth-tests
Browse files Browse the repository at this point in the history
Restructure GitHub auth test to remove flakiness
  • Loading branch information
dhwthompson committed Feb 1, 2022
2 parents 6a50f77 + 3526d34 commit a1f07f6
Showing 1 changed file with 109 additions and 50 deletions.
159 changes: 109 additions & 50 deletions pkg/services/auth/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"runtime"
"strings"
"sync"
"time"

"github.com/benbjohnson/clock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
fakehttp "github.com/weaveworks/weave-gitops/pkg/vendorfakes/http"
Expand All @@ -38,6 +37,27 @@ func (t *testServerTransport) RoundTrip(r *http.Request) (*http.Response, error)
return t.roundTripper.RoundTrip(r)
}

// sleeper is a very lightweight fake sleep timer. Instead of faking out the system
// clock, we can accept `sleep` calls and keep track of how long we've slept.
type sleeper struct {
mutex sync.Mutex
time time.Time
}

func (t *sleeper) sleep(d time.Duration) {
t.mutex.Lock()
defer t.mutex.Unlock()

t.time = t.time.Add(d)
}

func (t *sleeper) now() time.Time {
t.mutex.Lock()
defer t.mutex.Unlock()

return t.time
}

var _ = Describe("Github Device Flow", func() {
var ts *httptest.Server
var client *http.Client
Expand Down Expand Up @@ -91,57 +111,86 @@ var _ = Describe("Github Device Flow", func() {
Expect(cliOutput.String()).To(ContainSubstring(verificationUri))
})

// These auth flow tests are failing intermittently, so bypass them for now
XDescribe("pollAuthStatus", func() {
It("retries after a slow_down response from github", func() {
rt := newMockRoundTripper(3, token)
client.Transport = &testServerTransport{testServeUrl: ts.URL, roundTripper: rt}
interval := 5 * time.Second
Describe("pollAuthStatus", func() {
var rt *mockAuthRoundTripper
var s *sleeper

// pollTimes is a convenience function to convert from a series of polling intervals
// to their respective polling timestamps, relative to the sleeper type's starting time
pollTimes := func(intervals []time.Duration) []time.Time {
zero := time.Time{}
times := make([]time.Time, len(intervals))
for index, interval := range intervals {
switch index {
case 0:
times[index] = zero.Add(interval)
default:
times[index] = times[index-1].Add(interval)
}
}
return times
}

c := clock.NewMock()
drainPollTimes := func(pollChan <-chan time.Time) (result []time.Time) {
for pollTime := range pollChan {
result = append(result, pollTime)
}
return
}

go func() {
_, err := pollAuthStatus(c.Sleep, interval, client, "somedevicecode")
Expect(err).NotTo(HaveOccurred())
}()
runtime.Gosched()
Context("after a slow_down response from GitHub", func() {
BeforeEach(func() {
s = &sleeper{}
rt = newMockRoundTripper(1, token, s.now)
client.Transport = &testServerTransport{testServeUrl: ts.URL, roundTripper: rt}
})

// check one second after interval
c.Add(interval + 1*time.Second)
Expect(rt.calls).To(Equal(1), "should have tried the first time")
It("retries with a longer interval", func() {
interval := 5 * time.Second

// check during back off
c.Add(interval)
Expect(rt.calls).To(Equal(1), "should NOT have retried early")
_, _ = pollAuthStatus(s.sleep, interval, client, "somedevicecode")

// check one second after back off ended
c.Add(interval + 6*time.Second)
Expect(rt.calls).To(Equal(2), "should have backed off 10 seconds")
})
It("returns a token after a slow_down", func() {
rt := newMockRoundTripper(1, token)
client.Transport = &testServerTransport{testServeUrl: ts.URL, roundTripper: rt}
interval := 5 * time.Second
c := clock.NewMock()

var resultToken string
var err error
go func() {
resultToken, err = pollAuthStatus(c.Sleep, interval, client, "somedevicecode")
expectedPollTimes := pollTimes([]time.Duration{
interval,
interval + 5*time.Second,
})
Expect(drainPollTimes(rt.callChan)).To(Equal(expectedPollTimes))
})

It("returns a token", func() {
interval := 5 * time.Second

resultToken, err := pollAuthStatus(s.sleep, interval, client, "somedevicecode")

Expect(resultToken).To(Equal(token))
Expect(err).NotTo(HaveOccurred())
}()
runtime.Gosched()
})
})

Context("after several slow_down responses from GitHub", func() {
var s *sleeper

BeforeEach(func() {
s = &sleeper{}
rt = newMockRoundTripper(3, token, s.now)
client.Transport = &testServerTransport{testServeUrl: ts.URL, roundTripper: rt}
})

// check 1 second after interval
c.Add(interval + 1*time.Second)
Expect(rt.calls).To(Equal(1), "should have tried the first time")
It("keeps slowing down", func() {
interval := 5 * time.Second

// check 1 second after back off ended
c.Add(interval + 6*time.Second)
Expect(rt.calls).To(Equal(2), "should have tried again after back off")
_, _ = pollAuthStatus(s.sleep, interval, client, "somedevicecode")

Expect(resultToken).To(Equal(token))
expectedPollTimes := pollTimes([]time.Duration{
interval,
interval + 5*time.Second,
interval + 10*time.Second,
interval + 15*time.Second,
})
Expect(drainPollTimes(rt.callChan)).To(Equal(expectedPollTimes))
})
})

})
})

Expand All @@ -164,8 +213,9 @@ var _ = Describe("ValidateToken", func() {
})

type mockAuthRoundTripper struct {
fn func(r *http.Request) (*http.Response, error)
calls int
fn func(r *http.Request) (*http.Response, error)
calls int
callChan chan time.Time
}

func (rt *mockAuthRoundTripper) MockRoundTrip(fn func(r *http.Request) (*http.Response, error)) {
Expand All @@ -176,15 +226,24 @@ func (rt *mockAuthRoundTripper) RoundTrip(r *http.Request) (*http.Response, erro
return rt.fn(r)
}

func newMockRoundTripper(pollCount int, token string) *mockAuthRoundTripper {
rt := &mockAuthRoundTripper{calls: 0}
func newMockRoundTripper(pollCount int, token string, now func() time.Time) *mockAuthRoundTripper {
rt := &mockAuthRoundTripper{calls: 0, callChan: make(chan time.Time, pollCount+1)}

rt.MockRoundTrip(func(r *http.Request) (*http.Response, error) {
b := bytes.NewBuffer(nil)

data := githubAuthResponse{Error: "slow_down"}
if rt.calls == pollCount {
data = githubAuthResponse{Error: "", AccessToken: token}
var data githubAuthResponse

switch {
case rt.calls > pollCount:
panic("mock API called after successful request")
case rt.calls == pollCount:
data = githubAuthResponse{AccessToken: token}
rt.callChan <- now()
close(rt.callChan)
default:
data = githubAuthResponse{Error: "slow_down"}
rt.callChan <- now()
}

if err := json.NewEncoder(b).Encode(data); err != nil {
Expand Down

0 comments on commit a1f07f6

Please sign in to comment.