Skip to content

Commit

Permalink
Add force flag to upgrade without confirmation
Browse files Browse the repository at this point in the history
  • Loading branch information
ruimarinho committed Mar 6, 2021
1 parent 4e3cac6 commit f8a6ec6
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 19 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/jdxcode/netrc v0.0.0-20190329161231-b36f1c51d91d
github.com/kr/pretty v0.1.0 // indirect
github.com/sirupsen/logrus v1.5.0
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.3.0
golang.org/x/sys v0.0.0-20200117145432-59e60aa80a0c // indirect
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.5.0 h1:1N5EYkVAPEywqZRJd7cwnRtCb6xJx7NH3T3WUTF980Q=
github.com/sirupsen/logrus v1.5.0/go.mod h1:+F7Ogzej0PZc/94MaYx/nvG9jOFMD2osvC3s+Squfpo=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
Expand Down
11 changes: 6 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package main

import (
"flag"
"fmt"
"os"

log "github.com/sirupsen/logrus"
flag "github.com/spf13/pflag"
)

var (
Expand All @@ -16,10 +16,11 @@ var (

var (
domain = flag.String("domain", "local", "Set the search domain for the local network.")
waitTime = flag.Int("wait", 60, "Duration in [s] to run discovery.")
httpPort = flag.Int("http-port", 0, "HTTP port to listen for OTA requests. If not specified, a random port is chosen.")
waitTime = flag.IntP("wait", "w", 60, "Duration in [s] to run discovery.")
httpPort = flag.IntP("http-port", "p", 0, "HTTP port to listen for OTA requests. If not specified, a random port is chosen.")
verbose = flag.Bool("verbose", false, "Enable verbose mode.")
showVersion = flag.Bool("version", false, "Show version information")
showVersion = flag.BoolP("version", "v", false, "Show version information")
force = flag.BoolP("force", "f", false, "Force upgrades without asking for confirmation")
)

func main() {
Expand All @@ -38,7 +39,7 @@ func main() {
os.Exit(0)
}

updater, err := NewOTAUpdater(*httpPort, "_http._tcp.", *domain, *waitTime)
updater, err := NewOTAUpdater(*httpPort, "_http._tcp.", *domain, *waitTime, WithForcedUpgrades(*force))
if err != nil {
log.Fatal(err)
}
Expand Down
40 changes: 26 additions & 14 deletions ota_updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,28 @@ type OTAUpdater struct {
downloadDir string
httpPort int
serverIP net.IP
force bool
}

// OTAUpdaterOption is an option interface for OTAUpdater.
type OTAUpdaterOption func(*OTAUpdater)

// WithAPIClient is an OTAUpdater option that allows overriding the
// APIClient used to interact with the Shelly API.
func WithAPIClient(api *APIClient) func(*OTAUpdater) {
func WithAPIClient(api *APIClient) OTAUpdaterOption {
return func(o *OTAUpdater) {
o.api = api
}
}

// WithForcedUpgrades is an OTAUpdater option that allows overriding
// the default behaviour of confirming upgrades interactively.
func WithForcedUpgrades(force bool) OTAUpdaterOption {
return func(o *OTAUpdater) {
o.force = force
}
}

// NewOTAUpdater returns an instance of OTAUpdater with the default
// options. Firmware downloads are stored on the OS cache or temp
// directories.
Expand Down Expand Up @@ -70,8 +79,8 @@ func NewOTAUpdater(httpPort int, service string, domain string, waitTime int, op
}

// Apply custom OTAUpdaterOptions.
for i := range options {
options[i](&updater)
for _, option := range options {
option(&updater)
}

return updater, nil
Expand Down Expand Up @@ -230,19 +239,22 @@ func (o *OTAUpdater) PromptForUpgrade() error {
}

upgrade := false
prompt := &survey.Confirm{
Message: fmt.Sprintf("Would you like to upgrade %v (%v) from %v to %v?", device.ModelName(), device.IP, device.CurrentFWVersion, device.NewFWVersion),
}

err := survey.AskOne(prompt, &upgrade)
if err == terminal.InterruptErr {
break
} else if err != nil {
return err
}
if !o.force {
prompt := &survey.Confirm{
Message: fmt.Sprintf("Would you like to upgrade %v (%v) from %v to %v?", device.ModelName(), device.IP, device.CurrentFWVersion, device.NewFWVersion),
}

if !upgrade {
continue
err := survey.AskOne(prompt, &upgrade)
if err == terminal.InterruptErr {
break
} else if err != nil {
return err
}

if !upgrade {
continue
}
}

o.UpgradeDevice(device)
Expand Down

0 comments on commit f8a6ec6

Please sign in to comment.