Skip to content
This repository has been archived by the owner on Mar 9, 2021. It is now read-only.

fix client view getter hack #23

Merged
merged 1 commit into from
Mar 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/diffs/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func serve(parent *kingpin.Application, sps *string, errs io.Writer) {
kc.Action(func(_ *kingpin.ParseContext) error {
ps := fmt.Sprintf(":%d", *port)
log.Printf("Listening on %s...", ps)
s := servepkg.NewService(*sps, accounts.Accounts(), *overrideClientViewURL)
s := servepkg.NewService(*sps, accounts.Accounts(), *overrideClientViewURL, servepkg.ClientViewGetter{})
http.Handle("/", s)
return http.ListenAndServe(fmt.Sprintf(":%d", *port), nil)
})
Expand Down
8 changes: 3 additions & 5 deletions serve/client_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,16 @@ import (
servetypes "roci.dev/diff-server/serve/types"
)

type ClientViewGetter struct {
url string
}
type ClientViewGetter struct{}

// Get fetches a client view. It returns an error if the response from the data layer doesn't have
// a lastMutationID.
func (g ClientViewGetter) Get(req servetypes.ClientViewRequest, authToken string) (servetypes.ClientViewResponse, error) {
func (g ClientViewGetter) Get(url string, req servetypes.ClientViewRequest, authToken string) (servetypes.ClientViewResponse, error) {
reqBody, err := json.Marshal(req)
if err != nil {
return servetypes.ClientViewResponse{}, fmt.Errorf("could not marshal ClientViewRequest: %w", err)
}
httpReq, err := http.NewRequest("POST", g.url, bytes.NewReader(reqBody))
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
if err != nil {
return servetypes.ClientViewResponse{}, fmt.Errorf("could not create client view http request: %w", err)
}
Expand Down
6 changes: 2 additions & 4 deletions serve/client_view_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,8 @@ func TestClientViewGetter_Get(t *testing.T) {
w.Write([]byte(tt.respBody))
}))

g := ClientViewGetter{
url: server.URL,
}
got, err := g.Get(tt.req, tt.clientViewAuth)
g := ClientViewGetter{}
got, err := g.Get(server.URL, tt.req, tt.clientViewAuth)
if tt.wantErr == "" {
assert.NoError(err)
} else {
Expand Down
2 changes: 1 addition & 1 deletion serve/hello_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestHello(t *testing.T) {

for i, t := range tc {
td, _ := ioutil.TempDir("", "")
s := NewService(td, []Account{Account{ID: "accountID", Name: "accountID", Pubkey: nil}}, "")
s := NewService(td, []Account{Account{ID: "accountID", Name: "accountID", Pubkey: nil}}, "", nil)

msg := fmt.Sprintf("test case %d", i)
req := httptest.NewRequest(t.method, "/hello", nil)
Expand Down
2 changes: 1 addition & 1 deletion serve/inject_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestInject(t *testing.T) {

for i, t := range tc {
td, _ := ioutil.TempDir("", "")
s := NewService(td, []Account{Account{ID: "accountID", Name: "accountID", Pubkey: nil}}, "")
s := NewService(td, []Account{Account{ID: "accountID", Name: "accountID", Pubkey: nil}}, "", nil)

msg := fmt.Sprintf("test case %d", i)
req := httptest.NewRequest(t.method, "/inject", strings.NewReader(t.req))
Expand Down
2 changes: 1 addition & 1 deletion serve/prod/prod.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ const (
)

var (
svc = serve.NewService("aws:replicant/aa-replicant2", accounts.Accounts(), "")
svc = serve.NewService("aws:replicant/aa-replicant2", accounts.Accounts(), "", serve.ClientViewGetter{})
)

func init() {
Expand Down
11 changes: 3 additions & 8 deletions serve/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,18 +80,13 @@ func (s *Service) pull(rw http.ResponseWriter, req *http.Request) {
return
}

// TODO fritz HACK, the cvg should take a url to Get and should be passed in as a depenency to the service.
clientViewURL := acct.ClientViewURL
if s.overridClientViewURL != "" {
log.Printf("WARNING: overriding all client view URLs with %s", s.overridClientViewURL)
clientViewURL = s.overridClientViewURL
}
if clientViewURL != "" {
s.cvg = ClientViewGetter{url: clientViewURL}
}
cvReq := servetypes.ClientViewRequest{}
maybeGetAndStoreNewClientView(db, preq.ClientViewAuth, s.cvg, cvReq)
s.cvg = nil
maybeGetAndStoreNewClientView(db, preq.ClientViewAuth, clientViewURL, s.clientViewGetter, cvReq)

patch, err := db.Diff(from, *fromChecksum)
if err != nil {
Expand Down Expand Up @@ -128,7 +123,7 @@ func (s *Service) pull(rw http.ResponseWriter, req *http.Request) {
}
}

func maybeGetAndStoreNewClientView(db *db.DB, clientViewAuth string, cvg clientViewGetter, cvReq servetypes.ClientViewRequest) {
func maybeGetAndStoreNewClientView(db *db.DB, clientViewAuth string, url string, cvg clientViewGetter, cvReq servetypes.ClientViewRequest) {
var err error
defer func() {
if err != nil {
Expand All @@ -140,7 +135,7 @@ func maybeGetAndStoreNewClientView(db *db.DB, clientViewAuth string, cvg clientV
err = errors.New("not fetching new client view: no url provided via account or --clientview")
return
}
cvResp, err := cvg.Get(cvReq, clientViewAuth)
cvResp, err := cvg.Get(url, cvReq, clientViewAuth)
if err != nil {
return
}
Expand Down
17 changes: 8 additions & 9 deletions serve/pull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,14 @@ func TestAPI(t *testing.T) {
}

for i, t := range tc {
fcvg := fakeClientViewGet{resp: t.CVResponse, err: t.CVErr}
var cvg clientViewGetter
if t.expCVReq != nil {
cvg = &fcvg
}

td, _ := ioutil.TempDir("", "")
s := NewService(td, []Account{Account{ID: "accountID", Name: "accountID", Pubkey: nil}}, "")
s := NewService(td, []Account{Account{ID: "accountID", Name: "accountID", Pubkey: nil}}, "", cvg)
noms, err := s.getNoms("accountID")
assert.NoError(err)
db, err := db.New(noms.GetDataset("client/clientid"))
Expand All @@ -178,13 +184,6 @@ func TestAPI(t *testing.T) {
err = db.PutData(m.NomsMap(), types.String(m.Checksum().String()), 1 /*lastMutationID*/)
assert.NoError(err)

fcvg := fakeClientViewGet{resp: t.CVResponse, err: t.CVErr}
if t.expCVReq == nil {
s.cvg = nil
} else {
s.cvg = &fcvg
}

msg := fmt.Sprintf("test case %d: %s", i, t.pullReq)
req := httptest.NewRequest(t.pullMethod, "/sync", strings.NewReader(t.pullReq))
req.Header.Set("Content-type", "application/json")
Expand Down Expand Up @@ -220,7 +219,7 @@ type fakeClientViewGet struct {
gotAuth string
}

func (f *fakeClientViewGet) Get(req servetypes.ClientViewRequest, authToken string) (servetypes.ClientViewResponse, error) {
func (f *fakeClientViewGet) Get(url string, req servetypes.ClientViewRequest, authToken string) (servetypes.ClientViewResponse, error) {
f.called = true
f.gotReq = req
f.gotAuth = authToken
Expand Down
11 changes: 6 additions & 5 deletions serve/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ var (
pathRegex = regexp.MustCompile(`^\/([\w-]+)\/([\w-]+)\/([\w-]+)\/?$`)
)

// Service is a running instance of the Replicant service. A service handles one or more servers.
// Service is a running instance of the Replicant service.
type Service struct {
storageRoot string
urlPrefix string
Expand All @@ -37,11 +37,11 @@ type Service struct {

// cvg may be nil, in which case the server skips the client view request in pull, which is
// useful if you are populating the db directly or in tests.
cvg clientViewGetter
clientViewGetter clientViewGetter
}

type clientViewGetter interface {
Get(req servetypes.ClientViewRequest, authToken string) (servetypes.ClientViewResponse, error)
Get(url string, req servetypes.ClientViewRequest, authToken string) (servetypes.ClientViewResponse, error)
}

// Account is information about a customer of Replicant. This is a stand-in for what will eventually be
Expand All @@ -54,13 +54,14 @@ type Account struct {
}

// NewService creates a new instances of the Replicant web service.
func NewService(storageRoot string, accounts []Account, clientViewURL string) *Service {
func NewService(storageRoot string, accounts []Account, overrideClientViewURL string, cvg clientViewGetter) *Service {
return &Service{
storageRoot: storageRoot,
accounts: accounts,
nomsen: map[string]datas.Database{},
overridClientViewURL: clientViewURL,
overridClientViewURL: overrideClientViewURL,
mu: sync.Mutex{},
clientViewGetter: cvg,
}
}

Expand Down
7 changes: 3 additions & 4 deletions serve/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func DISABLED_TestCheckAccess(t *testing.T) {
accounts = append(accounts, *t.addAccount)
}

svc := NewService(td, accounts, "")
svc := NewService(td, accounts, "", nil)
res := httptest.NewRecorder()

req := httptest.NewRequest("POST", fmt.Sprintf("/%s/pull", t.dbName),
Expand All @@ -117,7 +117,6 @@ func DISABLED_TestCheckAccess(t *testing.T) {
}

func TestConcurrentAccessUsingMultipleServices(t *testing.T) {
// TO
assert := assert.New(t)
td, _ := ioutil.TempDir("", "")

Expand All @@ -129,8 +128,8 @@ func TestConcurrentAccessUsingMultipleServices(t *testing.T) {
},
}

svc1 := NewService(td, accounts, "")
svc2 := NewService(td, accounts, "")
svc1 := NewService(td, accounts, "", nil)
svc2 := NewService(td, accounts, "", nil)

res := []*httptest.ResponseRecorder{
httptest.NewRecorder(),
Expand Down