diff --git a/daemon/api.go b/daemon/api.go index 0d10281cba6..674346e02d6 100644 --- a/daemon/api.go +++ b/daemon/api.go @@ -968,7 +968,7 @@ func snapUpdateMany(inst *snapInstruction, st *state.State) (*snapInstructionRes } // TODO: use a per-request context - updated, tasksets, err := snapstateUpdateMany(context.TODO(), st, inst.Snaps, inst.userID) + updated, tasksets, err := snapstateUpdateMany(context.TODO(), st, inst.Snaps, inst.userID, nil) if err != nil { return nil, err } diff --git a/daemon/api_test.go b/daemon/api_test.go index 63414e3ab4f..a91f14e556d 100644 --- a/daemon/api_test.go +++ b/daemon/api_test.go @@ -3427,7 +3427,7 @@ func (s *apiSuite) TestRefreshIgnoreValidation(c *check.C) { func (s *apiSuite) TestPostSnapsOp(c *check.C) { assertstateRefreshSnapDeclarations = func(*state.State, int) error { return nil } - snapstateUpdateMany = func(_ context.Context, s *state.State, names []string, userID int) ([]string, []*state.TaskSet, error) { + snapstateUpdateMany = func(_ context.Context, s *state.State, names []string, userID int, flags *snapstate.Flags) ([]string, []*state.TaskSet, error) { c.Check(names, check.HasLen, 0) t := s.NewTask("fake-refresh-all", "Refreshing everything") return []string{"fake1", "fake2"}, []*state.TaskSet{state.NewTaskSet(t)}, nil @@ -3472,7 +3472,7 @@ func (s *apiSuite) TestRefreshAll(c *check.C) { } { refreshSnapDecls = false - snapstateUpdateMany = func(_ context.Context, s *state.State, names []string, userID int) ([]string, []*state.TaskSet, error) { + snapstateUpdateMany = func(_ context.Context, s *state.State, names []string, userID int, flags *snapstate.Flags) ([]string, []*state.TaskSet, error) { c.Check(names, check.HasLen, 0) t := s.NewTask("fake-refresh-all", "Refreshing everything") return tst.snaps, []*state.TaskSet{state.NewTaskSet(t)}, nil @@ -3496,7 +3496,7 @@ func (s *apiSuite) TestRefreshAllNoChanges(c *check.C) { return assertstate.RefreshSnapDeclarations(s, userID) } - snapstateUpdateMany = func(_ context.Context, s *state.State, names []string, userID int) ([]string, []*state.TaskSet, error) { + snapstateUpdateMany = func(_ context.Context, s *state.State, names []string, userID int, flags *snapstate.Flags) ([]string, []*state.TaskSet, error) { c.Check(names, check.HasLen, 0) return nil, nil, nil } @@ -3519,7 +3519,7 @@ func (s *apiSuite) TestRefreshMany(c *check.C) { return nil } - snapstateUpdateMany = func(_ context.Context, s *state.State, names []string, userID int) ([]string, []*state.TaskSet, error) { + snapstateUpdateMany = func(_ context.Context, s *state.State, names []string, userID int, flags *snapstate.Flags) ([]string, []*state.TaskSet, error) { c.Check(names, check.HasLen, 2) t := s.NewTask("fake-refresh-2", "Refreshing two") return names, []*state.TaskSet{state.NewTaskSet(t)}, nil @@ -3544,7 +3544,7 @@ func (s *apiSuite) TestRefreshMany1(c *check.C) { return nil } - snapstateUpdateMany = func(_ context.Context, s *state.State, names []string, userID int) ([]string, []*state.TaskSet, error) { + snapstateUpdateMany = func(_ context.Context, s *state.State, names []string, userID int, flags *snapstate.Flags) ([]string, []*state.TaskSet, error) { c.Check(names, check.HasLen, 1) t := s.NewTask("fake-refresh-1", "Refreshing one") return names, []*state.TaskSet{state.NewTaskSet(t)}, nil diff --git a/image/helpers.go b/image/helpers.go index 91d6194ad3a..921db3c58f1 100644 --- a/image/helpers.go +++ b/image/helpers.go @@ -50,7 +50,7 @@ import ( // A Store can find metadata on snaps, download snaps and fetch assertions. type Store interface { SnapAction(context.Context, []*store.CurrentSnap, []*store.SnapAction, *auth.UserState, *store.RefreshOptions) ([]*snap.Info, error) - Download(ctx context.Context, name, targetFn string, downloadInfo *snap.DownloadInfo, pbar progress.Meter, user *auth.UserState) error + Download(ctx context.Context, name, targetFn string, downloadInfo *snap.DownloadInfo, pbar progress.Meter, user *auth.UserState, dlOpts *store.DownloadOptions) error Assertion(assertType *asserts.AssertionType, primaryKey []string, user *auth.UserState) (asserts.Assertion, error) } @@ -265,7 +265,7 @@ func (tsto *ToolingStore) DownloadSnap(name string, revision snap.Revision, opts os.Exit(1) }() - if err = sto.Download(context.TODO(), name, targetFn, &snap.DownloadInfo, pb, tsto.user); err != nil { + if err = sto.Download(context.TODO(), name, targetFn, &snap.DownloadInfo, pb, tsto.user, nil); err != nil { return "", nil, err } diff --git a/image/image_test.go b/image/image_test.go index b322e7bf478..6fd32b33326 100644 --- a/image/image_test.go +++ b/image/image_test.go @@ -54,7 +54,7 @@ func (s *emptyStore) SnapAction(context.Context, []*store.CurrentSnap, []*store. return nil, fmt.Errorf("cannot find snap") } -func (s *emptyStore) Download(ctx context.Context, name, targetFn string, downloadInfo *snap.DownloadInfo, pbar progress.Meter, user *auth.UserState) error { +func (s *emptyStore) Download(ctx context.Context, name, targetFn string, downloadInfo *snap.DownloadInfo, pbar progress.Meter, user *auth.UserState, dlOpts *store.DownloadOptions) error { return fmt.Errorf("cannot download") } @@ -202,7 +202,7 @@ func (s *imageSuite) SnapAction(_ context.Context, _ []*store.CurrentSnap, actio return nil, fmt.Errorf("no %q in the fake store", actions[0].InstanceName) } -func (s *imageSuite) Download(ctx context.Context, name, targetFn string, downloadInfo *snap.DownloadInfo, pbar progress.Meter, user *auth.UserState) error { +func (s *imageSuite) Download(ctx context.Context, name, targetFn string, downloadInfo *snap.DownloadInfo, pbar progress.Meter, user *auth.UserState, dlOpts *store.DownloadOptions) error { return osutil.CopyFile(s.downloadedSnaps[name], targetFn, 0) } diff --git a/overlord/configstate/configcore/corecfg.go b/overlord/configstate/configcore/corecfg.go index a99587f4583..87566e4c867 100644 --- a/overlord/configstate/configcore/corecfg.go +++ b/overlord/configstate/configcore/corecfg.go @@ -101,6 +101,9 @@ func Run(tr Conf) error { if err := validateRefreshSchedule(tr); err != nil { return err } + if err := validateRefreshRateLimit(tr); err != nil { + return err + } if err := validateExperimentalSettings(tr); err != nil { return err } diff --git a/overlord/configstate/configcore/refresh.go b/overlord/configstate/configcore/refresh.go index 4cd13d2f508..48fc626be50 100644 --- a/overlord/configstate/configcore/refresh.go +++ b/overlord/configstate/configcore/refresh.go @@ -25,6 +25,7 @@ import ( "time" "github.com/snapcore/snapd/overlord/devicestate" + "github.com/snapcore/snapd/strutil" "github.com/snapcore/snapd/timeutil" ) @@ -34,6 +35,7 @@ func init() { supportedConfigurations["core.refresh.timer"] = true supportedConfigurations["core.refresh.metered"] = true supportedConfigurations["core.refresh.retain"] = true + supportedConfigurations["core.refresh.rate-limit"] = true } func validateRefreshSchedule(tr Conf) error { @@ -114,3 +116,18 @@ func validateRefreshSchedule(tr Conf) error { _, err = timeutil.ParseLegacySchedule(refreshScheduleStr) return err } + +func validateRefreshRateLimit(tr Conf) error { + refreshRateLimit, err := coreCfg(tr, "refresh.rate-limit") + if err != nil { + return err + } + // reset is fine + if len(refreshRateLimit) == 0 { + return nil + } + if _, err := strutil.ParseByteSize(refreshRateLimit); err != nil { + return err + } + return nil +} diff --git a/overlord/hookstate/ctlcmd/services_test.go b/overlord/hookstate/ctlcmd/services_test.go index 18b20c3070f..46cf45f5dbe 100644 --- a/overlord/hookstate/ctlcmd/services_test.go +++ b/overlord/hookstate/ctlcmd/services_test.go @@ -346,7 +346,7 @@ func (s *servicectlSuite) TestQueuedCommandsUpdateMany(c *C) { s.st.Lock() chg := s.st.NewChange("update many change", "update change") - installed, tts, err := snapstate.UpdateMany(context.TODO(), s.st, []string{"test-snap", "other-snap"}, 0) + installed, tts, err := snapstate.UpdateMany(context.TODO(), s.st, []string{"test-snap", "other-snap"}, 0, nil) c.Assert(err, IsNil) sort.Strings(installed) c.Check(installed, DeepEquals, []string{"other-snap", "test-snap"}) diff --git a/overlord/managers_test.go b/overlord/managers_test.go index 42dc7b5efd1..3a5664de87f 100644 --- a/overlord/managers_test.go +++ b/overlord/managers_test.go @@ -1023,7 +1023,7 @@ version: @VERSION@ snapPath, _ = ms.makeStoreTestSnap(c, strings.Replace(snapYamlContent, "@VERSION@", ver, -1), revno) ms.serveSnap(snapPath, revno) - updated, tss, err := snapstate.UpdateMany(context.TODO(), st, []string{"foo"}, 0) + updated, tss, err := snapstate.UpdateMany(context.TODO(), st, []string{"foo"}, 0, nil) c.Check(updated, IsNil) c.Check(tss, IsNil) // no validation we, get an error @@ -1043,7 +1043,7 @@ version: @VERSION@ c.Assert(err, IsNil) // ... and try again - updated, tss, err = snapstate.UpdateMany(context.TODO(), st, []string{"foo"}, 0) + updated, tss, err = snapstate.UpdateMany(context.TODO(), st, []string{"foo"}, 0, nil) c.Assert(err, IsNil) c.Assert(updated, DeepEquals, []string{"foo"}) c.Assert(tss, HasLen, 1) @@ -1590,7 +1590,7 @@ apps: ms.serveSnap(fooPath, "15") // refresh all - updated, tss, err := snapstate.UpdateMany(context.TODO(), st, nil, 0) + updated, tss, err := snapstate.UpdateMany(context.TODO(), st, nil, 0, nil) c.Assert(err, IsNil) c.Assert(updated, DeepEquals, []string{"foo"}) c.Assert(tss, HasLen, 1) @@ -1836,7 +1836,7 @@ apps: err = assertstate.RefreshSnapDeclarations(st, 0) c.Assert(err, IsNil) - updated, tss, err := snapstate.UpdateMany(context.TODO(), st, nil, 0) + updated, tss, err := snapstate.UpdateMany(context.TODO(), st, nil, 0, nil) c.Assert(err, IsNil) sort.Strings(updated) c.Assert(updated, DeepEquals, []string{"bar", "foo"}) @@ -2361,7 +2361,7 @@ version: @VERSION@` err := assertstate.RefreshSnapDeclarations(st, 0) c.Assert(err, IsNil) - updates, tts, err := snapstate.UpdateMany(context.TODO(), st, []string{"core", "some-snap", "other-snap"}, 0) + updates, tts, err := snapstate.UpdateMany(context.TODO(), st, []string{"core", "some-snap", "other-snap"}, 0, nil) c.Assert(err, IsNil) c.Check(updates, HasLen, 3) c.Assert(tts, HasLen, 3) diff --git a/overlord/snapstate/backend.go b/overlord/snapstate/backend.go index 11e91f8bf05..644b93a14f1 100644 --- a/overlord/snapstate/backend.go +++ b/overlord/snapstate/backend.go @@ -42,7 +42,7 @@ type StoreService interface { Sections(ctx context.Context, user *auth.UserState) ([]string, error) WriteCatalogs(ctx context.Context, names io.Writer, adder store.SnapAdder) error - Download(context.Context, string, string, *snap.DownloadInfo, progress.Meter, *auth.UserState) error + Download(context.Context, string, string, *snap.DownloadInfo, progress.Meter, *auth.UserState, *store.DownloadOptions) error Assertion(assertType *asserts.AssertionType, primaryKey []string, user *auth.UserState) (asserts.Assertion, error) diff --git a/overlord/snapstate/backend_test.go b/overlord/snapstate/backend_test.go index 17463d90432..fa35cbd5483 100644 --- a/overlord/snapstate/backend_test.go +++ b/overlord/snapstate/backend_test.go @@ -97,6 +97,7 @@ type fakeDownload struct { name string macaroon string target string + opts *store.DownloadOptions } type byName []store.CurrentSnap @@ -485,7 +486,7 @@ func (f *fakeStore) SuggestedCurrency() string { return "XTS" } -func (f *fakeStore) Download(ctx context.Context, name, targetFn string, snapInfo *snap.DownloadInfo, pb progress.Meter, user *auth.UserState) error { +func (f *fakeStore) Download(ctx context.Context, name, targetFn string, snapInfo *snap.DownloadInfo, pb progress.Meter, user *auth.UserState, dlOpts *store.DownloadOptions) error { f.pokeStateLock() if _, key := snap.SplitInstanceName(name); key != "" { @@ -495,10 +496,15 @@ func (f *fakeStore) Download(ctx context.Context, name, targetFn string, snapInf if user != nil { macaroon = user.StoreMacaroon } + // only add the options if they contain anything interessting + if *dlOpts == (store.DownloadOptions{}) { + dlOpts = nil + } f.downloads = append(f.downloads, fakeDownload{ macaroon: macaroon, name: name, target: targetFn, + opts: dlOpts, }) f.fakeBackend.ops = append(f.fakeBackend.ops, fakeOp{op: "storesvc-download", name: name}) diff --git a/overlord/snapstate/flags.go b/overlord/snapstate/flags.go index 50414517283..e9ff49d54e4 100644 --- a/overlord/snapstate/flags.go +++ b/overlord/snapstate/flags.go @@ -57,6 +57,9 @@ type Flags struct { // Amend allows refreshing out of a snap unknown to the store // and into one that is known. Amend bool `json:"amend,omitempty"` + + // IsAutoRefresh is true if the snap is currently auto-refreshed + IsAutoRefresh bool `json:"is-auto-refresh,omitempty"` } // DevModeAllowed returns whether a snap can be installed with devmode confinement (either set or overridden) diff --git a/overlord/snapstate/handlers.go b/overlord/snapstate/handlers.go index 0ccd493de27..1c328514100 100644 --- a/overlord/snapstate/handlers.go +++ b/overlord/snapstate/handlers.go @@ -43,6 +43,8 @@ import ( "github.com/snapcore/snapd/overlord/state" "github.com/snapcore/snapd/release" "github.com/snapcore/snapd/snap" + "github.com/snapcore/snapd/store" + "github.com/snapcore/snapd/strutil" ) // hook setup by devicestate @@ -401,6 +403,23 @@ func installInfoUnlocked(st *state.State, snapsup *SnapSetup) (*snap.Info, error return installInfo(st, snapsup.InstanceName(), snapsup.Channel, snapsup.Revision(), snapsup.UserID) } +// autoRefreshRateLimited returns the rate limit of auto-refreshes or 0 if +// there is no limit. +func autoRefreshRateLimited(st *state.State) (rate int64) { + tr := config.NewTransaction(st) + + var rateLimit string + err := tr.Get("core", "refresh.rate-limit", &rateLimit) + if err != nil { + return 0 + } + val, err := strutil.ParseByteSize(rateLimit) + if err != nil { + return 0 + } + return val +} + func (m *SnapManager) doDownloadSnap(t *state.Task, tomb *tomb.Tomb) error { st := t.State() st.Lock() @@ -412,6 +431,7 @@ func (m *SnapManager) doDownloadSnap(t *state.Task, tomb *tomb.Tomb) error { st.Lock() theStore := Store(st) + rate := autoRefreshRateLimited(st) user, err := userFromUserID(st, snapsup.UserID) st.Unlock() if err != nil { @@ -420,6 +440,11 @@ func (m *SnapManager) doDownloadSnap(t *state.Task, tomb *tomb.Tomb) error { meter := NewTaskProgressAdapterUnlocked(t) targetFn := snapsup.MountFile() + + dlOpts := &store.DownloadOptions{} + if snapsup.IsAutoRefresh && rate > 0 { + dlOpts.RateLimit = rate + } if snapsup.DownloadInfo == nil { var storeInfo *snap.Info // COMPATIBILITY - this task was created from an older version @@ -429,10 +454,10 @@ func (m *SnapManager) doDownloadSnap(t *state.Task, tomb *tomb.Tomb) error { if err != nil { return err } - err = theStore.Download(tomb.Context(nil), snapsup.SnapName(), targetFn, &storeInfo.DownloadInfo, meter, user) + err = theStore.Download(tomb.Context(nil), snapsup.SnapName(), targetFn, &storeInfo.DownloadInfo, meter, user, dlOpts) snapsup.SideInfo = &storeInfo.SideInfo } else { - err = theStore.Download(tomb.Context(nil), snapsup.SnapName(), targetFn, snapsup.DownloadInfo, meter, user) + err = theStore.Download(tomb.Context(nil), snapsup.SnapName(), targetFn, snapsup.DownloadInfo, meter, user, dlOpts) } if err != nil { return err diff --git a/overlord/snapstate/handlers_download_test.go b/overlord/snapstate/handlers_download_test.go index adcd29eebc1..f988e1a1720 100644 --- a/overlord/snapstate/handlers_download_test.go +++ b/overlord/snapstate/handlers_download_test.go @@ -20,8 +20,12 @@ package snapstate_test import ( + "path/filepath" + . "gopkg.in/check.v1" + "github.com/snapcore/snapd/dirs" + "github.com/snapcore/snapd/overlord/configstate/config" "github.com/snapcore/snapd/overlord/snapstate" "github.com/snapcore/snapd/overlord/state" "github.com/snapcore/snapd/snap" @@ -186,3 +190,47 @@ func (s *downloadSnapSuite) TestDoUndoDownloadSnap(c *C) { c.Assert(err, Equals, state.ErrNoState) } + +func (s *downloadSnapSuite) TestDoDownloadRateLimitedIntegration(c *C) { + s.state.Lock() + + // set auto-refresh rate-limit + tr := config.NewTransaction(s.state) + tr.Set("core", "refresh.rate-limit", "1234B") + tr.Commit() + + // setup fake auto-refresh download + si := &snap.SideInfo{ + RealName: "foo", + SnapID: "foo-id", + Revision: snap.R(11), + } + t := s.state.NewTask("download-snap", "test") + t.Set("snap-setup", &snapstate.SnapSetup{ + SideInfo: si, + DownloadInfo: &snap.DownloadInfo{ + DownloadURL: "http://some-url.com/snap", + }, + Flags: snapstate.Flags{ + IsAutoRefresh: true, + }, + }) + s.state.NewChange("dummy", "...").AddTask(t) + + s.state.Unlock() + + s.se.Ensure() + s.se.Wait() + + // ensure that rate limit was honored + c.Assert(s.fakeStore.downloads, DeepEquals, []fakeDownload{ + { + name: "foo", + target: filepath.Join(dirs.SnapBlobDir, "foo_11.snap"), + opts: &store.DownloadOptions{ + RateLimit: 1234, + }, + }, + }) + +} diff --git a/overlord/snapstate/snapstate.go b/overlord/snapstate/snapstate.go index 461bd62e3ed..0f2984441fe 100644 --- a/overlord/snapstate/snapstate.go +++ b/overlord/snapstate/snapstate.go @@ -669,7 +669,10 @@ var ValidateRefreshes func(st *state.State, refreshes []*snap.Info, ignoreValida // UpdateMany updates everything from the given list of names that the // store says is updateable. If the list is empty, update everything. // Note that the state must be locked by the caller. -func UpdateMany(ctx context.Context, st *state.State, names []string, userID int) ([]string, []*state.TaskSet, error) { +func UpdateMany(ctx context.Context, st *state.State, names []string, userID int, flags *Flags) ([]string, []*state.TaskSet, error) { + if flags == nil { + flags = &Flags{} + } user, err := userFromUserID(st, userID) if err != nil { return nil, nil, err @@ -698,10 +701,14 @@ func UpdateMany(ctx context.Context, st *state.State, names []string, userID int } - return doUpdate(st, names, updates, params, userID) + return doUpdate(ctx, st, names, updates, params, userID, flags) } -func doUpdate(st *state.State, names []string, updates []*snap.Info, params func(*snap.Info) (channel string, flags Flags, snapst *SnapState), userID int) ([]string, []*state.TaskSet, error) { +func doUpdate(ctx context.Context, st *state.State, names []string, updates []*snap.Info, params func(*snap.Info) (channel string, flags Flags, snapst *SnapState), userID int, globalFlags *Flags) ([]string, []*state.TaskSet, error) { + if globalFlags == nil { + globalFlags = &Flags{} + } + tasksets := make([]*state.TaskSet, 0, len(updates)) refreshAll := len(names) == 0 @@ -756,7 +763,7 @@ func doUpdate(st *state.State, names []string, updates []*snap.Info, params func // and bases and then other snaps for _, update := range updates { channel, flags, snapst := params(update) - + flags.IsAutoRefresh = globalFlags.IsAutoRefresh if err := validateInfoAndFlags(update, snapst, flags); err != nil { if refreshAll { logger.Noticef("cannot update %q: %v", update.InstanceName(), err) @@ -1075,7 +1082,7 @@ func Update(st *state.State, name, channel string, revision snap.Revision, userI return channel, flags, &snapst } - _, tts, err := doUpdate(st, []string{name}, updates, params, userID) + _, tts, err := doUpdate(context.TODO(), st, []string{name}, updates, params, userID, &flags) if err != nil { return nil, err } @@ -1189,7 +1196,7 @@ func AutoRefresh(ctx context.Context, st *state.State) ([]string, []*state.TaskS } } - return UpdateMany(ctx, st, nil, userID) + return UpdateMany(ctx, st, nil, userID, &Flags{IsAutoRefresh: true}) } // Enable sets a snap to the active state diff --git a/overlord/snapstate/snapstate_test.go b/overlord/snapstate/snapstate_test.go index 686ec207aad..9bf2c1e8168 100644 --- a/overlord/snapstate/snapstate_test.go +++ b/overlord/snapstate/snapstate_test.go @@ -439,6 +439,19 @@ func verifyRemoveTasks(c *C, ts *state.TaskSet) { verifyStopReason(c, ts, "remove") } +func checkIsAutoRefresh(c *C, tasks []*state.Task, expected bool) { + for _, t := range tasks { + if t.Kind() == "download-snap" { + var snapsup snapstate.SnapSetup + err := t.Get("snap-setup", &snapsup) + c.Assert(err, IsNil) + c.Check(snapsup.IsAutoRefresh, Equals, expected) + return + } + } + c.Fatalf("cannot find download-snap task in %q", tasks) +} + func (s *snapmgrTestSuite) TestLastIndexFindsLast(c *C) { snapst := &snapstate.SnapState{Sequence: []*snap.SideInfo{ {Revision: snap.R(7)}, @@ -738,7 +751,7 @@ func (s *snapmgrTestSuite) TestUpdateMany(c *C) { SnapType: "app", }) - updates, tts, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 0) + updates, tts, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 0, nil) c.Assert(err, IsNil) c.Assert(tts, HasLen, 1) c.Check(updates, DeepEquals, []string{"some-snap"}) @@ -751,6 +764,8 @@ func (s *snapmgrTestSuite) TestUpdateMany(c *C) { c.Assert(t.Lanes(), DeepEquals, []int{1}) } c.Assert(s.state.TaskCount(), Equals, len(ts.Tasks())) + + checkIsAutoRefresh(c, ts.Tasks(), false) } func (s *snapmgrTestSuite) TestParallelInstanceUpdateMany(c *C) { @@ -784,7 +799,7 @@ func (s *snapmgrTestSuite) TestParallelInstanceUpdateMany(c *C) { InstanceKey: "instance", }) - updates, tts, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 0) + updates, tts, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 0, nil) c.Assert(err, IsNil) c.Assert(tts, HasLen, 2) // ensure stable ordering of updates list @@ -827,7 +842,7 @@ func (s *snapmgrTestSuite) TestUpdateManyDevModeConfinementFiltering(c *C) { }) // updated snap is devmode, updatemany doesn't update it - _, tts, _ := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, s.user.ID) + _, tts, _ := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, s.user.ID, nil) // FIXME: UpdateMany will not error out in this case (daemon catches this case, with a weird error) c.Assert(tts, HasLen, 0) } @@ -849,7 +864,7 @@ func (s *snapmgrTestSuite) TestUpdateManyClassicConfinementFiltering(c *C) { }) // if a snap installed without --classic gets a classic update it isn't installed - _, tts, _ := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, s.user.ID) + _, tts, _ := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, s.user.ID, nil) // FIXME: UpdateMany will not error out in this case (daemon catches this case, with a weird error) c.Assert(tts, HasLen, 0) } @@ -872,7 +887,7 @@ func (s *snapmgrTestSuite) TestUpdateManyClassic(c *C) { }) // snap installed with classic: refresh gets classic - _, tts, err := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, s.user.ID) + _, tts, err := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, s.user.ID, nil) c.Assert(err, IsNil) c.Assert(tts, HasLen, 1) } @@ -891,7 +906,7 @@ func (s *snapmgrTestSuite) TestUpdateManyDevMode(c *C) { SnapType: "app", }) - updates, _, err := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, 0) + updates, _, err := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, 0, nil) c.Assert(err, IsNil) c.Check(updates, HasLen, 1) } @@ -910,7 +925,7 @@ func (s *snapmgrTestSuite) TestUpdateAllDevMode(c *C) { SnapType: "app", }) - updates, _, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 0) + updates, _, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 0, nil) c.Assert(err, IsNil) c.Check(updates, HasLen, 0) } @@ -947,7 +962,7 @@ func (s *snapmgrTestSuite) TestUpdateManyWaitForBasesUC16(c *C) { Channel: "channel-for-base", }) - updates, tts, err := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap", "core", "some-base"}, 0) + updates, tts, err := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap", "core", "some-base"}, 0, nil) c.Assert(err, IsNil) c.Assert(tts, HasLen, 3) c.Check(updates, HasLen, 3) @@ -1025,7 +1040,7 @@ func (s *snapmgrTestSuite) TestUpdateManyWaitForBasesUC18(c *C) { Channel: "channel-for-base", }) - updates, tts, err := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap", "core18", "some-base", "snapd"}, 0) + updates, tts, err := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap", "core18", "some-base", "snapd"}, 0, nil) c.Assert(err, IsNil) c.Assert(tts, HasLen, 4) c.Check(updates, HasLen, 4) @@ -1089,7 +1104,7 @@ func (s *snapmgrTestSuite) TestUpdateManyValidateRefreshes(c *C) { // hook it up snapstate.ValidateRefreshes = validateRefreshes - updates, tts, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 0) + updates, tts, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 0, nil) c.Assert(err, IsNil) c.Assert(tts, HasLen, 1) c.Check(updates, DeepEquals, []string{"some-snap"}) @@ -1122,13 +1137,13 @@ func (s *snapmgrTestSuite) TestUpdateManyValidateRefreshesUnhappy(c *C) { snapstate.ValidateRefreshes = validateRefreshes // refresh all => no error - updates, tts, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 0) + updates, tts, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 0, nil) c.Assert(err, IsNil) c.Check(tts, HasLen, 0) c.Check(updates, HasLen, 0) // refresh some-snap => report error - updates, tts, err = snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, 0) + updates, tts, err = snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, 0, nil) c.Assert(err, Equals, validateErr) c.Check(tts, HasLen, 0) c.Check(updates, HasLen, 0) @@ -3125,7 +3140,7 @@ func (s *snapmgrTestSuite) TestUpdateManyMultipleCredsNoUserRunThrough(c *C) { chg := s.state.NewChange("refresh", "refresh all snaps") // no user is passed to use for UpdateMany - updated, tts, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 0) + updated, tts, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 0, nil) c.Assert(err, IsNil) for _, ts := range tts { chg.AddAll(ts) @@ -3215,7 +3230,7 @@ func (s *snapmgrTestSuite) TestUpdateManyMultipleCredsUserRunThrough(c *C) { chg := s.state.NewChange("refresh", "refresh all snaps") // do UpdateMany using user 2 as fallback - updated, tts, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 2) + updated, tts, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 2, nil) c.Assert(err, IsNil) for _, ts := range tts { chg.AddAll(ts) @@ -3324,7 +3339,7 @@ func (s *snapmgrTestSuite) TestUpdateManyMultipleCredsUserWithNoStoreAuthRunThro chg := s.state.NewChange("refresh", "refresh all snaps") // no user is passed to use for UpdateMany - updated, tts, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 0) + updated, tts, err := snapstate.UpdateMany(context.TODO(), s.state, nil, 0, nil) c.Assert(err, IsNil) for _, ts := range tts { chg.AddAll(ts) @@ -4190,7 +4205,7 @@ func (s *snapmgrTestSuite) TestUpdateIgnoreValidationSticky(c *C) { s.fakeStore.refreshRevnos = map[string]snap.Revision{ "some-snap-id": snap.R(12), } - _, tts, err := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, s.user.ID) + _, tts, err := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, s.user.ID, nil) c.Assert(err, IsNil) c.Check(tts, HasLen, 1) @@ -4417,7 +4432,7 @@ func (s *snapmgrTestSuite) TestMultiUpdateBlockedRevision(c *C) { SnapType: "app", }) - updates, _, err := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, s.user.ID) + updates, _, err := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, s.user.ID, nil) c.Assert(err, IsNil) c.Check(updates, DeepEquals, []string{"some-snap"}) @@ -4456,7 +4471,7 @@ func (s *snapmgrTestSuite) TestAllUpdateBlockedRevision(c *C) { Current: si7.Revision, }) - updates, _, err := snapstate.UpdateMany(context.TODO(), s.state, nil, s.user.ID) + updates, _, err := snapstate.UpdateMany(context.TODO(), s.state, nil, s.user.ID, nil) c.Check(err, IsNil) c.Check(updates, HasLen, 0) @@ -4560,7 +4575,7 @@ func (s *snapmgrTestSuite) TestUpdateManyAutoAliasesScenarios(c *C) { snapstate.Set(s.state, snapName, &snapst) } - updates, tts, err := snapstate.UpdateMany(context.TODO(), s.state, scenario.names, s.user.ID) + updates, tts, err := snapstate.UpdateMany(context.TODO(), s.state, scenario.names, s.user.ID, nil) c.Check(err, IsNil) _, dropped, err := snapstate.AutoAliasesDelta(s.state, []string{"some-snap", "other-snap"}) @@ -7723,6 +7738,8 @@ func (s *snapmgrTestSuite) TestEnsureRefreshesWithUpdate(c *C) { c.Check(chg.Kind(), Equals, "auto-refresh") c.Check(chg.IsReady(), Equals, false) s.verifyRefreshLast(c) + + checkIsAutoRefresh(c, chg.Tasks(), true) } func (s *snapmgrTestSuite) TestEnsureRefreshesImmediateWithUpdate(c *C) { @@ -11383,7 +11400,7 @@ func (s *snapmgrTestSuite) TestUpdateManyExplicitLayoutsChecksFeatureFlag(c *C) SnapType: "app", }) - _, _, err := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, s.user.ID) + _, _, err := snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, s.user.ID, nil) c.Assert(err, ErrorMatches, "experimental feature disabled - test it by setting 'experimental.layouts' to true") // enable layouts @@ -11391,7 +11408,7 @@ func (s *snapmgrTestSuite) TestUpdateManyExplicitLayoutsChecksFeatureFlag(c *C) tr.Set("core", "experimental.layouts", true) tr.Commit() - _, _, err = snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, s.user.ID) + _, _, err = snapstate.UpdateMany(context.TODO(), s.state, []string{"some-snap"}, s.user.ID, nil) c.Assert(err, IsNil) } @@ -11409,7 +11426,7 @@ func (s *snapmgrTestSuite) TestUpdateManyLayoutsChecksFeatureFlag(c *C) { SnapType: "app", }) - refreshes, _, err := snapstate.UpdateMany(context.TODO(), s.state, nil, s.user.ID) + refreshes, _, err := snapstate.UpdateMany(context.TODO(), s.state, nil, s.user.ID, nil) c.Assert(err, IsNil) c.Assert(refreshes, HasLen, 0) @@ -11418,7 +11435,7 @@ func (s *snapmgrTestSuite) TestUpdateManyLayoutsChecksFeatureFlag(c *C) { tr.Set("core", "experimental.layouts", true) tr.Commit() - refreshes, _, err = snapstate.UpdateMany(context.TODO(), s.state, nil, s.user.ID) + refreshes, _, err = snapstate.UpdateMany(context.TODO(), s.state, nil, s.user.ID, nil) c.Assert(err, IsNil) c.Assert(refreshes, DeepEquals, []string{"some-snap"}) } diff --git a/packaging/fedora/snapd.spec b/packaging/fedora/snapd.spec index 8e4974f2ede..1fb680f3b2f 100644 --- a/packaging/fedora/snapd.spec +++ b/packaging/fedora/snapd.spec @@ -150,6 +150,7 @@ BuildRequires: golang(github.com/godbus/dbus) BuildRequires: golang(github.com/godbus/dbus/introspect) BuildRequires: golang(github.com/gorilla/mux) BuildRequires: golang(github.com/jessevdk/go-flags) +BuildRequires: golang(github.com/juju/ratelimit) BuildRequires: golang(github.com/kr/pretty) BuildRequires: golang(github.com/kr/text) BuildRequires: golang(github.com/mvo5/goconfigparser) @@ -243,6 +244,7 @@ Requires: golang(github.com/godbus/dbus) Requires: golang(github.com/godbus/dbus/introspect) Requires: golang(github.com/gorilla/mux) Requires: golang(github.com/jessevdk/go-flags) +Requires: golang(github.com/juju/ratelimit) Requires: golang(github.com/kr/pretty) Requires: golang(github.com/kr/text) Requires: golang(github.com/mvo5/goconfigparser) @@ -271,6 +273,7 @@ Provides: bundled(golang(github.com/godbus/dbus)) Provides: bundled(golang(github.com/godbus/dbus/introspect)) Provides: bundled(golang(github.com/gorilla/mux)) Provides: bundled(golang(github.com/jessevdk/go-flags)) +Provides: bundled(golang(github.com/juju/ratelimit)) Provides: bundled(golang(github.com/kr/pretty)) Provides: bundled(golang(github.com/kr/text)) Provides: bundled(golang(github.com/mvo5/goconfigparser)) diff --git a/store/export_test.go b/store/export_test.go index ee199e92dee..0317929fa33 100644 --- a/store/export_test.go +++ b/store/export_test.go @@ -20,9 +20,12 @@ package store import ( - "github.com/snapcore/snapd/testutil" + "io" + "github.com/juju/ratelimit" "gopkg.in/retry.v1" + + "github.com/snapcore/snapd/testutil" ) var ( @@ -57,3 +60,11 @@ func MockOsRemove(f func(name string) error) func() { osRemove = oldOsRemove } } + +func MockRatelimitReader(f func(r io.Reader, bucket *ratelimit.Bucket) io.Reader) (restore func()) { + oldRatelimitReader := ratelimitReader + ratelimitReader = f + return func() { + ratelimitReader = oldRatelimitReader + } +} diff --git a/store/store.go b/store/store.go index ab46c94c949..8a2cf5bff28 100644 --- a/store/store.go +++ b/store/store.go @@ -38,6 +38,7 @@ import ( "sync" "time" + "github.com/juju/ratelimit" "golang.org/x/net/context" "golang.org/x/net/context/ctxhttp" "gopkg.in/retry.v1" @@ -1319,11 +1320,15 @@ func (e HashError) Error() string { return fmt.Sprintf("sha3-384 mismatch for %q: got %s but expected %s", e.name, e.sha3_384, e.targetSha3_384) } +type DownloadOptions struct { + RateLimit int64 +} + // Download downloads the snap addressed by download info and returns its // filename. // The file is saved in temporary storage, and should be removed // after use to prevent the disk from running out of space. -func (s *Store) Download(ctx context.Context, name string, targetPath string, downloadInfo *snap.DownloadInfo, pbar progress.Meter, user *auth.UserState) error { +func (s *Store) Download(ctx context.Context, name string, targetPath string, downloadInfo *snap.DownloadInfo, pbar progress.Meter, user *auth.UserState, dlOpts *DownloadOptions) error { if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { return err } @@ -1380,7 +1385,7 @@ func (s *Store) Download(ctx context.Context, name string, targetPath string, do } if downloadInfo.Size == 0 || resume < downloadInfo.Size { - err = download(ctx, name, downloadInfo.Sha3_384, url, user, s, w, resume, pbar) + err = download(ctx, name, downloadInfo.Sha3_384, url, user, s, w, resume, pbar, dlOpts) if err != nil { logger.Debugf("download of %q failed: %#v", url, err) } @@ -1410,7 +1415,7 @@ func (s *Store) Download(ctx context.Context, name string, targetPath string, do if err != nil { return err } - err = download(ctx, name, downloadInfo.Sha3_384, url, user, s, w, 0, pbar) + err = download(ctx, name, downloadInfo.Sha3_384, url, user, s, w, 0, pbar, nil) if err != nil { logger.Debugf("download of %q failed: %#v", url, err) } @@ -1431,7 +1436,7 @@ func (s *Store) Download(ctx context.Context, name string, targetPath string, do return s.cacher.Put(downloadInfo.Sha3_384, targetPath) } -func downloadOptions(storeURL *url.URL, cdnHeader string) *requestOptions { +func reqOptions(storeURL *url.URL, cdnHeader string) *requestOptions { reqOptions := requestOptions{ Method: "GET", URL: storeURL, @@ -1444,8 +1449,14 @@ func downloadOptions(storeURL *url.URL, cdnHeader string) *requestOptions { return &reqOptions } +var ratelimitReader = ratelimit.Reader + // download writes an http.Request showing a progress.Meter -var download = func(ctx context.Context, name, sha3_384, downloadURL string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter) error { +var download = func(ctx context.Context, name, sha3_384, downloadURL string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter, dlOpts *DownloadOptions) error { + if dlOpts == nil { + dlOpts = &DownloadOptions{} + } + storeURL, err := url.Parse(downloadURL) if err != nil { return err @@ -1460,7 +1471,7 @@ var download = func(ctx context.Context, name, sha3_384, downloadURL string, use var dlSize float64 startTime := time.Now() for attempt := retry.Start(defaultRetryStrategy, nil); attempt.Next(); { - reqOptions := downloadOptions(storeURL, cdnHeader) + reqOptions := reqOptions(storeURL, cdnHeader) httputil.MaybeLogRetryAttempt(reqOptions.URL.String(), attempt, startTime) @@ -1519,7 +1530,13 @@ var download = func(ctx context.Context, name, sha3_384, downloadURL string, use dlSize = float64(resp.ContentLength) pbar.Start(name, dlSize) mw := io.MultiWriter(w, h, pbar) - _, finalErr = io.Copy(mw, resp.Body) + var limiter io.Reader + limiter = resp.Body + if limit := dlOpts.RateLimit; limit > 0 { + bucket := ratelimit.NewBucketWithRate(float64(limit), 2*limit) + limiter = ratelimitReader(resp.Body, bucket) + } + _, finalErr = io.Copy(mw, limiter) pbar.Finished() if finalErr != nil { if httputil.ShouldRetryError(attempt, finalErr) { @@ -1584,7 +1601,7 @@ func (s *Store) downloadDelta(deltaName string, downloadInfo *snap.DownloadInfo, url = deltaInfo.DownloadURL } - return download(context.TODO(), deltaName, deltaInfo.Sha3_384, url, user, s, w, 0, pbar) + return download(context.TODO(), deltaName, deltaInfo.Sha3_384, url, user, s, w, 0, pbar, nil) } func getXdelta3Cmd(args ...string) (*exec.Cmd, error) { @@ -2370,7 +2387,7 @@ func (s *Store) snapConnCheck() ([]string, error) { return hosts, err } - reqOptions := downloadOptions(dlURL, cdnHeader) + reqOptions := reqOptions(dlURL, cdnHeader) reqOptions.Method = "HEAD" // not actually a download // TODO: We need the HEAD here so that we get redirected to the diff --git a/store/store_test.go b/store/store_test.go index baa404a6531..75acab06331 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -37,6 +37,7 @@ import ( "testing" "time" + "github.com/juju/ratelimit" "golang.org/x/crypto/sha3" "golang.org/x/net/context" . "gopkg.in/check.v1" @@ -164,7 +165,7 @@ type storeTestSuite struct { localUser *auth.UserState device *auth.DeviceState - origDownloadFunc func(context.Context, string, string, string, *auth.UserState, *Store, io.ReadWriteSeeker, int64, progress.Meter) error + origDownloadFunc func(context.Context, string, string, string, *auth.UserState, *Store, io.ReadWriteSeeker, int64, progress.Meter, *DownloadOptions) error mockXDelta *testutil.MockCmd restoreLogger func() @@ -419,7 +420,7 @@ func (s *storeTestSuite) expectedAuthorization(c *C, user *auth.UserState) strin func (s *storeTestSuite) TestDownloadOK(c *C) { expectedContent := []byte("I was downloaded") - download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter) error { + download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter, dlOpts *DownloadOptions) error { c.Check(url, Equals, "anon-url") w.Write(expectedContent) return nil @@ -432,7 +433,7 @@ func (s *storeTestSuite) TestDownloadOK(c *C) { snap.Size = int64(len(expectedContent)) path := filepath.Join(c.MkDir(), "downloaded-file") - err := s.store.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, nil) + err := s.store.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, nil, nil) c.Assert(err, IsNil) defer os.Remove(path) @@ -444,7 +445,7 @@ func (s *storeTestSuite) TestDownloadRangeRequest(c *C) { missingContentStr := "was downloaded" expectedContentStr := partialContentStr + missingContentStr - download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter) error { + download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter, dlOpts *DownloadOptions) error { c.Check(resume, Equals, int64(len(partialContentStr))) c.Check(url, Equals, "anon-url") w.Write([]byte(missingContentStr)) @@ -462,7 +463,7 @@ func (s *storeTestSuite) TestDownloadRangeRequest(c *C) { err := ioutil.WriteFile(targetFn+".partial", []byte(partialContentStr), 0644) c.Assert(err, IsNil) - err = s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil) + err = s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil, nil) c.Assert(err, IsNil) c.Assert(targetFn, testutil.FileEquals, expectedContentStr) @@ -484,7 +485,7 @@ func (s *storeTestSuite) TestResumeOfCompleted(c *C) { err := ioutil.WriteFile(targetFn+".partial", []byte(expectedContentStr), 0644) c.Assert(err, IsNil) - err = s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil) + err = s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil, nil) c.Assert(err, IsNil) c.Assert(targetFn, testutil.FileEquals, expectedContentStr) @@ -525,7 +526,7 @@ func (s *storeTestSuite) TestDownloadEOFHandlesResumeHashCorrectly(c *C) { snap.Size = 50000 targetFn := filepath.Join(c.MkDir(), "foo_1.0_all.snap") - err := s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil) + err := s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil, nil) c.Assert(err, IsNil) c.Assert(targetFn, testutil.FileEquals, buf) c.Assert(s.logbuf.String(), Matches, "(?s).*Retrying .* attempt 2, .*") @@ -569,7 +570,7 @@ func (s *storeTestSuite) TestDownloadRetryHashErrorIsFullyRetried(c *C) { snap.Size = 50000 targetFn := filepath.Join(c.MkDir(), "foo_1.0_all.snap") - err := s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil) + err := s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil, nil) c.Assert(err, IsNil) c.Assert(targetFn, testutil.FileEquals, buf) @@ -606,7 +607,7 @@ func (s *storeTestSuite) TestResumeOfCompletedRetriedOnHashFailure(c *C) { targetFn := filepath.Join(c.MkDir(), "foo_1.0_all.snap") c.Assert(ioutil.WriteFile(targetFn+".partial", badbuf, 0644), IsNil) - err := s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil) + err := s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil, nil) c.Assert(err, IsNil) c.Assert(targetFn, testutil.FileEquals, buf) @@ -634,7 +635,7 @@ func (s *storeTestSuite) TestDownloadRetryHashErrorIsFullyRetriedOnlyOnce(c *C) snap.Size = int64(len("something invalid")) targetFn := filepath.Join(c.MkDir(), "foo_1.0_all.snap") - err := s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil) + err := s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil, nil) _, ok := err.(HashError) c.Assert(ok, Equals, true) @@ -647,7 +648,7 @@ func (s *storeTestSuite) TestDownloadRangeRequestRetryOnHashError(c *C) { partialContentStr := "partial content " n := 0 - download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter) error { + download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter, dlOpts *DownloadOptions) error { n++ if n == 1 { // force sha3 error on first download @@ -669,7 +670,7 @@ func (s *storeTestSuite) TestDownloadRangeRequestRetryOnHashError(c *C) { err := ioutil.WriteFile(targetFn+".partial", []byte(partialContentStr), 0644) c.Assert(err, IsNil) - err = s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil) + err = s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil, nil) c.Assert(err, IsNil) c.Assert(n, Equals, 2) @@ -680,7 +681,7 @@ func (s *storeTestSuite) TestDownloadRangeRequestFailOnHashError(c *C) { partialContentStr := "partial content " n := 0 - download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter) error { + download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter, dlOpts *DownloadOptions) error { n++ return HashError{"foo", "1234", "5678"} } @@ -696,7 +697,7 @@ func (s *storeTestSuite) TestDownloadRangeRequestFailOnHashError(c *C) { err := ioutil.WriteFile(targetFn+".partial", []byte(partialContentStr), 0644) c.Assert(err, IsNil) - err = s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil) + err = s.store.Download(context.TODO(), "foo", targetFn, &snap.DownloadInfo, nil, nil, nil) c.Assert(err, NotNil) c.Assert(err, ErrorMatches, `sha3-384 mismatch for "foo": got 1234 but expected 5678`) c.Assert(n, Equals, 2) @@ -704,7 +705,7 @@ func (s *storeTestSuite) TestDownloadRangeRequestFailOnHashError(c *C) { func (s *storeTestSuite) TestAuthenticatedDownloadDoesNotUseAnonURL(c *C) { expectedContent := []byte("I was downloaded") - download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, _ *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter) error { + download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, _ *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter, dlOpts *DownloadOptions) error { // check user is pass and auth url is used c.Check(user, Equals, s.user) c.Check(url, Equals, "AUTH-URL") @@ -720,7 +721,7 @@ func (s *storeTestSuite) TestAuthenticatedDownloadDoesNotUseAnonURL(c *C) { snap.Size = int64(len(expectedContent)) path := filepath.Join(c.MkDir(), "downloaded-file") - err := s.store.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, s.user) + err := s.store.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, s.user, nil) c.Assert(err, IsNil) defer os.Remove(path) @@ -729,7 +730,7 @@ func (s *storeTestSuite) TestAuthenticatedDownloadDoesNotUseAnonURL(c *C) { func (s *storeTestSuite) TestAuthenticatedDeviceDoesNotUseAnonURL(c *C) { expectedContent := []byte("I was downloaded") - download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter) error { + download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter, dlOpts *DownloadOptions) error { // check auth url is used c.Check(url, Equals, "AUTH-URL") @@ -747,7 +748,7 @@ func (s *storeTestSuite) TestAuthenticatedDeviceDoesNotUseAnonURL(c *C) { sto := New(&Config{}, authContext) path := filepath.Join(c.MkDir(), "downloaded-file") - err := sto.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, nil) + err := sto.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, nil, nil) c.Assert(err, IsNil) defer os.Remove(path) @@ -756,7 +757,7 @@ func (s *storeTestSuite) TestAuthenticatedDeviceDoesNotUseAnonURL(c *C) { func (s *storeTestSuite) TestLocalUserDownloadUsesAnonURL(c *C) { expectedContentStr := "I was downloaded" - download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter) error { + download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter, dlOpts *DownloadOptions) error { c.Check(url, Equals, "anon-url") w.Write([]byte(expectedContentStr)) @@ -770,7 +771,7 @@ func (s *storeTestSuite) TestLocalUserDownloadUsesAnonURL(c *C) { snap.Size = int64(len(expectedContentStr)) path := filepath.Join(c.MkDir(), "downloaded-file") - err := s.store.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, s.localUser) + err := s.store.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, s.localUser, nil) c.Assert(err, IsNil) defer os.Remove(path) @@ -779,7 +780,7 @@ func (s *storeTestSuite) TestLocalUserDownloadUsesAnonURL(c *C) { func (s *storeTestSuite) TestDownloadFails(c *C) { var tmpfile *os.File - download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter) error { + download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter, dlOpts *DownloadOptions) error { tmpfile = w.(*os.File) return fmt.Errorf("uh, it failed") } @@ -791,7 +792,7 @@ func (s *storeTestSuite) TestDownloadFails(c *C) { snap.Size = 1 // simulate a failed download path := filepath.Join(c.MkDir(), "downloaded-file") - err := s.store.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, nil) + err := s.store.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, nil, nil) c.Assert(err, ErrorMatches, "uh, it failed") // ... and ensure that the tempfile is removed c.Assert(osutil.FileExists(tmpfile.Name()), Equals, false) @@ -799,7 +800,7 @@ func (s *storeTestSuite) TestDownloadFails(c *C) { func (s *storeTestSuite) TestDownloadSyncFails(c *C) { var tmpfile *os.File - download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter) error { + download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter, dlOpts *DownloadOptions) error { tmpfile = w.(*os.File) w.Write([]byte("sync will fail")) err := tmpfile.Close() @@ -815,7 +816,7 @@ func (s *storeTestSuite) TestDownloadSyncFails(c *C) { // simulate a failed sync path := filepath.Join(c.MkDir(), "downloaded-file") - err := s.store.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, nil) + err := s.store.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, nil, nil) c.Assert(err, ErrorMatches, `(sync|fsync:) .*`) // ... and ensure that the tempfile is removed c.Assert(osutil.FileExists(tmpfile.Name()), Equals, false) @@ -835,7 +836,7 @@ func (s *storeTestSuite) TestActualDownload(c *C) { var buf SillyBuffer // keep tests happy sha3 := "" - err := download(context.TODO(), "foo", sha3, mockServer.URL, nil, theStore, &buf, 0, nil) + err := download(context.TODO(), "foo", sha3, mockServer.URL, nil, theStore, &buf, 0, nil, nil) c.Assert(err, IsNil) c.Check(buf.String(), Equals, "response-data") c.Check(n, Equals, 1) @@ -856,7 +857,7 @@ func (s *storeTestSuite) TestActualDownloadNoCDN(c *C) { var buf SillyBuffer // keep tests happy sha3 := "" - err := download(context.TODO(), "foo", sha3, mockServer.URL, nil, theStore, &buf, 0, nil) + err := download(context.TODO(), "foo", sha3, mockServer.URL, nil, theStore, &buf, 0, nil, nil) c.Assert(err, IsNil) c.Check(buf.String(), Equals, "response-data") } @@ -875,7 +876,7 @@ func (s *storeTestSuite) TestActualDownloadFullCloudInfoFromAuthContext(c *C) { var buf SillyBuffer // keep tests happy sha3 := "" - err := download(context.TODO(), "foo", sha3, mockServer.URL, nil, theStore, &buf, 0, nil) + err := download(context.TODO(), "foo", sha3, mockServer.URL, nil, theStore, &buf, 0, nil, nil) c.Assert(err, IsNil) c.Check(buf.String(), Equals, "response-data") } @@ -894,7 +895,7 @@ func (s *storeTestSuite) TestActualDownloadLessDetailedCloudInfoFromAuthContext( var buf SillyBuffer // keep tests happy sha3 := "" - err := download(context.TODO(), "foo", sha3, mockServer.URL, nil, theStore, &buf, 0, nil) + err := download(context.TODO(), "foo", sha3, mockServer.URL, nil, theStore, &buf, 0, nil, nil) c.Assert(err, IsNil) c.Check(buf.String(), Equals, "response-data") } @@ -922,7 +923,7 @@ func (s *storeTestSuite) TestDownloadCancellation(c *C) { go func() { sha3 := "" var buf SillyBuffer - err := download(ctx, "foo", sha3, mockServer.URL, nil, theStore, &buf, 0, nil) + err := download(ctx, "foo", sha3, mockServer.URL, nil, theStore, &buf, 0, nil, nil) result <- err.Error() close(result) }() @@ -935,6 +936,27 @@ func (s *storeTestSuite) TestDownloadCancellation(c *C) { c.Assert(err, Equals, "The download has been cancelled: context canceled") } +func (s *storeTestSuite) TestActualDownloadRateLimited(c *C) { + var ratelimitReaderUsed bool + restore := MockRatelimitReader(func(r io.Reader, bucket *ratelimit.Bucket) io.Reader { + ratelimitReaderUsed = true + return r + }) + defer restore() + + canary := "downloaded data" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, canary) + })) + defer ts.Close() + + var buf SillyBuffer + err := download(context.TODO(), "example-name", "", ts.URL, nil, s.store, &buf, 0, nil, &DownloadOptions{RateLimit: 1}) + c.Assert(err, IsNil) + c.Check(buf.String(), Equals, canary) + c.Check(ratelimitReaderUsed, Equals, true) +} + type nopeSeeker struct{ io.ReadWriter } func (nopeSeeker) Seek(int64, int) (int64, error) { @@ -954,7 +976,7 @@ func (s *storeTestSuite) TestActualDownloadNonPurchased402(c *C) { theStore := New(&Config{}, nil) var buf bytes.Buffer - err := download(context.TODO(), "foo", "sha3", mockServer.URL, nil, theStore, nopeSeeker{&buf}, -1, nil) + err := download(context.TODO(), "foo", "sha3", mockServer.URL, nil, theStore, nopeSeeker{&buf}, -1, nil, nil) c.Assert(err, NotNil) c.Check(err.Error(), Equals, "please buy foo before installing it.") c.Check(n, Equals, 1) @@ -971,7 +993,7 @@ func (s *storeTestSuite) TestActualDownload404(c *C) { theStore := New(&Config{}, nil) var buf SillyBuffer - err := download(context.TODO(), "foo", "sha3", mockServer.URL, nil, theStore, &buf, 0, nil) + err := download(context.TODO(), "foo", "sha3", mockServer.URL, nil, theStore, &buf, 0, nil, nil) c.Assert(err, NotNil) c.Assert(err, FitsTypeOf, &DownloadError{}) c.Check(err.(*DownloadError).Code, Equals, 404) @@ -989,7 +1011,7 @@ func (s *storeTestSuite) TestActualDownload500(c *C) { theStore := New(&Config{}, nil) var buf SillyBuffer - err := download(context.TODO(), "foo", "sha3", mockServer.URL, nil, theStore, &buf, 0, nil) + err := download(context.TODO(), "foo", "sha3", mockServer.URL, nil, theStore, &buf, 0, nil, nil) c.Assert(err, NotNil) c.Assert(err, FitsTypeOf, &DownloadError{}) c.Check(err.(*DownloadError).Code, Equals, 500) @@ -1013,7 +1035,7 @@ func (s *storeTestSuite) TestActualDownload500Once(c *C) { var buf SillyBuffer // keep tests happy sha3 := "" - err := download(context.TODO(), "foo", sha3, mockServer.URL, nil, theStore, &buf, 0, nil) + err := download(context.TODO(), "foo", sha3, mockServer.URL, nil, theStore, &buf, 0, nil, nil) c.Assert(err, IsNil) c.Check(buf.String(), Equals, "response-data") c.Check(n, Equals, 2) @@ -1080,7 +1102,7 @@ func (s *storeTestSuite) TestActualDownloadResume(c *C) { h := crypto.SHA3_384.New() h.Write([]byte("some data")) sha3 := fmt.Sprintf("%x", h.Sum(nil)) - err := download(context.TODO(), "foo", sha3, mockServer.URL, nil, theStore, buf, int64(len("some ")), nil) + err := download(context.TODO(), "foo", sha3, mockServer.URL, nil, theStore, buf, int64(len("some ")), nil, nil) c.Check(err, IsNil) c.Check(buf.String(), Equals, "some data") c.Check(n, Equals, 1) @@ -1214,7 +1236,7 @@ func (s *storeTestSuite) TestDownloadWithDelta(c *C) { for _, testCase := range deltaTests { testCase.info.Size = int64(len(testCase.expectedContent)) downloadIndex := 0 - download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter) error { + download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter, dlOpts *DownloadOptions) error { if testCase.downloads[downloadIndex].error { downloadIndex++ return errors.New("Bang") @@ -1232,7 +1254,7 @@ func (s *storeTestSuite) TestDownloadWithDelta(c *C) { } path := filepath.Join(c.MkDir(), "subdir", "downloaded-file") - err := s.store.Download(context.TODO(), "foo", path, &testCase.info, nil, nil) + err := s.store.Download(context.TODO(), "foo", path, &testCase.info, nil, nil, nil) c.Assert(err, IsNil) defer os.Remove(path) @@ -1344,7 +1366,7 @@ func (s *storeTestSuite) TestDownloadDelta(c *C) { for _, testCase := range downloadDeltaTests { sto.deltaFormat = testCase.format - download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, _ *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter) error { + download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, _ *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter, dlOpts *DownloadOptions) error { expectedUser := s.user if testCase.useLocalUser { expectedUser = s.localUser @@ -4472,7 +4494,7 @@ func (s *storeTestSuite) TestDownloadCacheHit(c *C) { obs := &cacheObserver{inCache: map[string]bool{"the-snaps-sha3_384": true}} s.store.cacher = obs - download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter) error { + download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter, dlOpts *DownloadOptions) error { c.Fatalf("download should not be called when results come from the cache") return nil } @@ -4481,7 +4503,7 @@ func (s *storeTestSuite) TestDownloadCacheHit(c *C) { snap.Sha3_384 = "the-snaps-sha3_384" path := filepath.Join(c.MkDir(), "downloaded-file") - err := s.store.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, nil) + err := s.store.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, nil, nil) c.Assert(err, IsNil) c.Check(obs.gets, DeepEquals, []string{fmt.Sprintf("%s:%s", snap.Sha3_384, path)}) @@ -4495,7 +4517,7 @@ func (s *storeTestSuite) TestDownloadCacheMiss(c *C) { s.store.cacher = obs downloadWasCalled := false - download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter) error { + download = func(ctx context.Context, name, sha3, url string, user *auth.UserState, s *Store, w io.ReadWriteSeeker, resume int64, pbar progress.Meter, dlOpts *DownloadOptions) error { downloadWasCalled = true return nil } @@ -4504,7 +4526,7 @@ func (s *storeTestSuite) TestDownloadCacheMiss(c *C) { snap.Sha3_384 = "the-snaps-sha3_384" path := filepath.Join(c.MkDir(), "downloaded-file") - err := s.store.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, nil) + err := s.store.Download(context.TODO(), "foo", path, &snap.DownloadInfo, nil, nil, nil) c.Assert(err, IsNil) c.Check(downloadWasCalled, Equals, true) diff --git a/store/storetest/storetest.go b/store/storetest/storetest.go index f927b69b612..31af24c50db 100644 --- a/store/storetest/storetest.go +++ b/store/storetest/storetest.go @@ -53,7 +53,7 @@ func (Store) SnapAction(context.Context, []*store.CurrentSnap, []*store.SnapActi panic("Store.SnapAction not expected") } -func (Store) Download(context.Context, string, string, *snap.DownloadInfo, progress.Meter, *auth.UserState) error { +func (Store) Download(context.Context, string, string, *snap.DownloadInfo, progress.Meter, *auth.UserState, *store.DownloadOptions) error { panic("Store.Download not expected") } diff --git a/vendor/vendor.json b/vendor/vendor.json index 8e282a35c5e..8884139bd04 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -42,6 +42,12 @@ "revision": "96dc06278ce32a0e9d957d590bb987c81ee66407", "revisionTime": "2017-07-20T12:40:56Z" }, + { + "checksumSHA1": "dqtKfXGotqkiYaS328PpWeusig8=", + "path": "github.com/juju/ratelimit", + "revision": "59fac5042749a5afb9af70e813da1dd5474f0167", + "revisionTime": "2017-10-26T09:04:26Z" + }, { "checksumSHA1": "3ohk4dFYrERZ6WTdKkIwnTA0HSI=", "path": "github.com/kr/pretty",