diff --git a/cmd/tailscale/cli/update.go b/cmd/tailscale/cli/update.go index 514f75e89976f..77e6ede6bf3fc 100644 --- a/cmd/tailscale/cli/update.go +++ b/cmd/tailscale/cli/update.go @@ -5,17 +5,29 @@ package cli import ( + "bytes" "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "flag" "fmt" + "io" + "log" "net/http" "os" + "os/exec" + "path" + "path/filepath" "runtime" + "strconv" "strings" + "time" "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/net/tshttpproxy" + "tailscale.com/util/must" "tailscale.com/util/winutil" "tailscale.com/version" "tailscale.com/version/distro" @@ -24,7 +36,7 @@ import ( var updateCmd = &ffcli.Command{ Name: "update", ShortUsage: "update", - ShortHelp: "Update Tailscale to the latest/different version", + ShortHelp: "[ALPHA] Update Tailscale to the latest/different version", Exec: runUpdate, FlagSet: (func() *flag.FlagSet { fs := newFlagSet("update") @@ -43,7 +55,22 @@ var updateArgs struct { version string // explicit version; empty means auto } +// winMSIEnv is the environment variable that, if set, contains makes the update +// command install the MSI file of this environment variable value. It's passed +// like this so we can stop the tailscale.exe process from running before the +// msiexec process runs and tries to overwrite ourselves. +const winMSIEnv = "TS_UPDATE_WIN_MSI" + func runUpdate(ctx context.Context, args []string) error { + if msi := os.Getenv(winMSIEnv); msi != "" { + log.Printf("installing %v ...", msi) + if err := installMSI(msi); err != nil { + log.Printf("MSI install failed: %v", err) + return err + } + log.Printf("success.") + return nil + } if len(args) > 0 { return flag.ErrHelp } @@ -57,6 +84,22 @@ func runUpdate(ctx context.Context, args []string) error { return up.update() } +func versionIsStable(v string) (stable, wellFormed bool) { + _, rest, ok := strings.Cut(v, ".") + if !ok { + return false, false + } + minorStr, _, ok := strings.Cut(rest, ".") + if !ok { + return false, false + } + minor, err := strconv.Atoi(minorStr) + if err != nil { + return false, false + } + return minor%2 == 0, true +} + func newUpdater() (*updater, error) { up := &updater{ track: updateArgs.track, @@ -69,6 +112,17 @@ func newUpdater() (*updater, error) { } else { up.track = "stable" } + if updateArgs.version != "" { + stable, ok := versionIsStable(updateArgs.version) + if !ok { + return nil, fmt.Errorf("malformed version %q", updateArgs.version) + } + if stable { + up.track = "stable" + } else { + up.track = "unstable" + } + } default: return nil, fmt.Errorf("unknown track %q; must be 'stable' or 'unstable'", up.track) } @@ -115,6 +169,23 @@ func (up *updater) currentOrDryRun(ver string) bool { return false } +func (up *updater) confirm(ver string) error { + if updateArgs.yes { + log.Printf("Updating Tailscale from %v to %v; --yes given, continuing without prompts.\n", version.Short, ver) + return nil + } + + fmt.Printf("This will update Tailscale from %v to %v. Continue? [y/n] ", version.Short, ver) + var resp string + fmt.Scanln(&resp) + resp = strings.ToLower(resp) + switch resp { + case "y", "yes", "sure": + return nil + } + return errors.New("aborting update") +} + func (up *updater) updateSynology() error { // TODO(bradfitz): detect, map GOARCH+CPU to the right Synology arch. // TODO(bradfitz): add pkgs.tailscale.com endpoint to get release info @@ -200,6 +271,169 @@ func (up *updater) updateWindows() error { if !winutil.IsCurrentProcessElevated() { return errors.New("must be run as Administrator") } - // TODO(bradfitz): require elevated mode - return errors.New("TODO: download + msiexec /i /quiet " + url) + if err := up.confirm(ver); err != nil { + return err + } + + targetDir := filepath.Join(os.Getenv("ProgramData"), "Tailscale", "MSICache") + if err := os.MkdirAll(targetDir, 0700); err != nil { + return err + } + msiTarget := filepath.Join(targetDir, path.Base(url)) + if err := downloadURLToFile(url, msiTarget); err != nil { + return err + } + + log.Printf("copying tailscaled.exe to copy...") + selfCopy, err := makeSelfCopy() + if err != nil { + return err + } + defer os.Remove(selfCopy) + log.Printf("running copy of self...") + + cmd := exec.Command(selfCopy, "update") + cmd.Env = append(os.Environ(), winMSIEnv+"="+msiTarget) + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + cmd.Stdin = os.Stdin + if err := cmd.Start(); err != nil { + return err + } + // Once it's started, exit ourselves, so the binary is free + // to be replaced. + os.Exit(0) + panic("unreachable") +} + +func installMSI(msi string) error { + cmd := exec.Command("msiexec.exe", "/i", filepath.Base(msi), "/quiet", "/promptrestart", "/qn") + // TODO(bradfitz): add REINSTALL=ALL REINSTALLMODE=A to permit downgrades? Doesn't seem to work. + cmd.Dir = filepath.Dir(msi) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Stdin = os.Stdin + return cmd.Run() +} + +func makeSelfCopy() (tmpPathExe string, err error) { + selfExe, err := os.Executable() + if err != nil { + return "", err + } + f, err := os.Open(selfExe) + if err != nil { + return "", err + } + defer f.Close() + f2, err := os.CreateTemp("", "tailscale-updater-*.exe") + if err != nil { + return "", err + } + if _, err := io.Copy(f2, f); err != nil { + f2.Close() + return "", err + } + return f2.Name(), f2.Close() +} + +func downloadURLToFile(urlSrc, fileDst string) (ret error) { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.Proxy = tshttpproxy.ProxyFromEnvironment + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + + quickCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + headReq := must.Get(http.NewRequestWithContext(quickCtx, "HEAD", urlSrc, nil)) + + res, err := c.Do(headReq) + if err != nil { + return err + } + if res.StatusCode != http.StatusOK { + return fmt.Errorf("HEAD %s: %v", urlSrc, res.Status) + } + if res.ContentLength <= 0 { + return fmt.Errorf("HEAD %s: unexpected Content-Length %v", urlSrc, res.ContentLength) + } + log.Printf("Download size: %v", res.ContentLength) + + hashReq := must.Get(http.NewRequestWithContext(quickCtx, "GET", urlSrc+".sha256", nil)) + hashRes, err := c.Do(hashReq) + if err != nil { + return err + } + hashHex, err := io.ReadAll(io.LimitReader(hashRes.Body, 100)) + hashRes.Body.Close() + if res.StatusCode != http.StatusOK { + return fmt.Errorf("GET %s.sha256: %v", urlSrc, res.Status) + } + if err != nil { + return err + } + wantHash, err := hex.DecodeString(string(strings.TrimSpace(string(hashHex)))) + if err != nil { + return err + } + hash := sha256.New() + + dlReq := must.Get(http.NewRequestWithContext(context.Background(), "GET", urlSrc, nil)) + dlRes, err := c.Do(dlReq) + if err != nil { + return err + } + // TODO(bradfitz): resume from existing partial file on disk + if dlRes.StatusCode != http.StatusOK { + return fmt.Errorf("GET %s: %v", urlSrc, dlRes.Status) + } + + of, err := os.Create(fileDst) + if err != nil { + return err + } + defer func() { + if ret != nil { + of.Close() + // TODO(bradfitz): os.Remove(fileDst) too? or keep it to resume from/debug later. + } + }() + pw := &progressWriter{total: res.ContentLength} + n, err := io.Copy(io.MultiWriter(hash, of, pw), io.LimitReader(dlRes.Body, res.ContentLength)) + if err != nil { + return err + } + if n != res.ContentLength { + return fmt.Errorf("downloaded %v; want %v", n, res.ContentLength) + } + if err := of.Close(); err != nil { + return err + } + pw.print() + + if !bytes.Equal(hash.Sum(nil), wantHash) { + return fmt.Errorf("SHA-256 of downloaded MSI didn't match expected value") + } + log.Printf("hash matched") + + return nil +} + +type progressWriter struct { + done int64 + total int64 + lastPrint time.Time +} + +func (pw *progressWriter) Write(p []byte) (n int, err error) { + pw.done += int64(len(p)) + if time.Since(pw.lastPrint) > 2*time.Second { + pw.print() + } + return len(p), nil +} + +func (pw *progressWriter) print() { + pw.lastPrint = time.Now() + log.Printf("Downloaded %v/%v (%.1f%%)", pw.done, pw.total, float64(pw.done)/float64(pw.total)*100) }