Skip to content

Commit

Permalink
clientupdate: download SPK and MSI packages with distsign
Browse files Browse the repository at this point in the history
Reimplement `downloadURLToFile` using `distsign.Download` and move all
of the progress reporting logic over there.

Updates #6995
Updates #755

Signed-off-by: Andrew Lytvynov <awly@tailscale.com>
  • Loading branch information
awly committed Aug 28, 2023
1 parent c86a610 commit 0d8d649
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 137 deletions.
125 changes: 16 additions & 109 deletions clientupdate/clientupdate.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
Expand All @@ -25,12 +23,10 @@ import (
"runtime"
"strconv"
"strings"
"time"

"github.com/google/uuid"
"tailscale.com/net/tshttpproxy"
"tailscale.com/clientupdate/distsign"
"tailscale.com/types/logger"
"tailscale.com/util/must"
"tailscale.com/util/winutil"
"tailscale.com/version"
"tailscale.com/version/distro"
Expand Down Expand Up @@ -88,6 +84,9 @@ type UpdateArgs struct {
// if this new version should be installed. When Confirm returns false, the
// update is aborted.
Confirm func(newVer string) bool
// PkgsAddr is the address of the pkgs server to fetch updates from.
// Defaults to "https://pkgs.tailscale.com".
PkgsAddr string
}

func (args UpdateArgs) validate() error {
Expand All @@ -109,6 +108,9 @@ func Update(args UpdateArgs) error {
if err := args.validate(); err != nil {
return err
}
if args.PkgsAddr == "" {
args.PkgsAddr = "https://pkgs.tailscale.com"
}
up := &updater{
UpdateArgs: args,
}
Expand Down Expand Up @@ -222,10 +224,9 @@ func (up *updater) updateSynology() error {
if err != nil {
return err
}
url := fmt.Sprintf("https://pkgs.tailscale.com/%s/%s", up.track, spkName)
spkPath := filepath.Join(spkDir, path.Base(url))
// TODO(awly): we should sign SPKs and validate signatures here too.
if err := up.downloadURLToFile(url, spkPath); err != nil {
pkgsPath := fmt.Sprintf("%s/%s", up.track, spkName)
spkPath := filepath.Join(spkDir, path.Base(pkgsPath))
if err := up.downloadURLToFile(pkgsPath, spkPath); err != nil {
return err
}

Expand Down Expand Up @@ -650,9 +651,9 @@ func (up *updater) updateWindows() error {
if err := os.MkdirAll(msiDir, 0700); err != nil {
return err
}
url := fmt.Sprintf("https://pkgs.tailscale.com/%s/tailscale-setup-%s-%s.msi", up.track, ver, arch)
msiTarget := filepath.Join(msiDir, path.Base(url))
if err := up.downloadURLToFile(url, msiTarget); err != nil {
pkgsPath := fmt.Sprintf("%s/tailscale-setup-%s-%s.msi", up.track, ver, arch)
msiTarget := filepath.Join(msiDir, path.Base(pkgsPath))
if err := up.downloadURLToFile(pkgsPath, msiTarget); err != nil {
return err
}

Expand Down Expand Up @@ -751,106 +752,12 @@ func makeSelfCopy() (tmpPathExe string, err error) {
return f2.Name(), f2.Close()
}

func (up *updater) downloadURLToFile(urlSrc, fileDst string) (ret error) {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.Proxy = tshttpproxy.ProxyFromEnvironment
defer tr.CloseIdleConnections()
c := &http.Client{Transport: tr}

quickCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
headReq := must.Get(http.NewRequestWithContext(quickCtx, "HEAD", urlSrc, nil))

res, err := c.Do(headReq)
if err != nil {
return err
}
if res.StatusCode != http.StatusOK {
return fmt.Errorf("HEAD %s: %v", urlSrc, res.Status)
}
if res.ContentLength <= 0 {
return fmt.Errorf("HEAD %s: unexpected Content-Length %v", urlSrc, res.ContentLength)
}
up.Logf("Download size: %v", res.ContentLength)

hashReq := must.Get(http.NewRequestWithContext(quickCtx, "GET", urlSrc+".sha256", nil))
hashRes, err := c.Do(hashReq)
func (up *updater) downloadURLToFile(pathSrc, fileDst string) (ret error) {
c, err := distsign.NewClient(up.Logf, up.PkgsAddr)
if err != nil {
return err
}
hashHex, err := io.ReadAll(io.LimitReader(hashRes.Body, 100))
hashRes.Body.Close()
if res.StatusCode != http.StatusOK {
return fmt.Errorf("GET %s.sha256: %v", urlSrc, res.Status)
}
if err != nil {
return err
}
wantHash, err := hex.DecodeString(string(strings.TrimSpace(string(hashHex))))
if err != nil {
return err
}
hash := sha256.New()

dlReq := must.Get(http.NewRequestWithContext(context.Background(), "GET", urlSrc, nil))
dlRes, err := c.Do(dlReq)
if err != nil {
return err
}
// TODO(bradfitz): resume from existing partial file on disk
if dlRes.StatusCode != http.StatusOK {
return fmt.Errorf("GET %s: %v", urlSrc, dlRes.Status)
}

of, err := os.Create(fileDst)
if err != nil {
return err
}
defer func() {
if ret != nil {
of.Close()
// TODO(bradfitz): os.Remove(fileDst) too? or keep it to resume from/debug later.
}
}()
pw := &progressWriter{total: res.ContentLength, logf: up.Logf}
n, err := io.Copy(io.MultiWriter(hash, of, pw), io.LimitReader(dlRes.Body, res.ContentLength))
if err != nil {
return err
}
if n != res.ContentLength {
return fmt.Errorf("downloaded %v; want %v", n, res.ContentLength)
}
if err := of.Close(); err != nil {
return err
}
pw.print()

if !bytes.Equal(hash.Sum(nil), wantHash) {
return fmt.Errorf("SHA-256 of downloaded MSI didn't match expected value")
}
up.Logf("hash matched")

return nil
}

type progressWriter struct {
done int64
total int64
lastPrint time.Time
logf logger.Logf
}

func (pw *progressWriter) Write(p []byte) (n int, err error) {
pw.done += int64(len(p))
if time.Since(pw.lastPrint) > 2*time.Second {
pw.print()
}
return len(p), nil
}

func (pw *progressWriter) print() {
pw.lastPrint = time.Now()
pw.logf("Downloaded %v/%v (%.1f%%)", pw.done, pw.total, float64(pw.done)/float64(pw.total)*100)
return c.Download(context.Background(), pathSrc, fileDst)
}

func (up *updater) updateFreeBSD() (err error) {
Expand Down
95 changes: 80 additions & 15 deletions clientupdate/distsign/distsign.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
package distsign

import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
Expand All @@ -46,12 +47,17 @@ import (
"fmt"
"hash"
"io"
"log"
"net/http"
"net/url"
"os"
"time"

"github.com/hdevalence/ed25519consensus"
"golang.org/x/crypto/blake2s"
"tailscale.com/net/tshttpproxy"
"tailscale.com/types/logger"
"tailscale.com/util/must"
)

const (
Expand Down Expand Up @@ -177,18 +183,22 @@ func (ph *PackageHash) Len() int64 { return ph.len }

// Client downloads and validates files from a distribution server.
type Client struct {
logf logger.Logf
roots []ed25519.PublicKey
pkgsAddr *url.URL
}

// NewClient returns a new client for distribution server located at pkgsAddr,
// and uses embedded root keys from the roots/ subdirectory of this package.
func NewClient(pkgsAddr string) (*Client, error) {
func NewClient(logf logger.Logf, pkgsAddr string) (*Client, error) {
if logf == nil {
logf = log.Printf
}
u, err := url.Parse(pkgsAddr)
if err != nil {
return nil, fmt.Errorf("invalid pkgsAddr %q: %w", pkgsAddr, err)
}
return &Client{roots: roots(), pkgsAddr: u}, nil
return &Client{logf: logf, roots: roots(), pkgsAddr: u}, nil
}

func (c *Client) url(path string) string {
Expand All @@ -199,7 +209,7 @@ func (c *Client) url(path string) string {
// The file is downloaded to dstPath and its signature is validated using the
// embedded root keys. Download returns an error if anything goes wrong with
// the actual file download or with signature validation.
func (c *Client) Download(srcPath, dstPath string) error {
func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error {
// Always fetch a fresh signing key.
sigPub, err := c.signingKeys()
if err != nil {
Expand All @@ -209,11 +219,13 @@ func (c *Client) Download(srcPath, dstPath string) error {
srcURL := c.url(srcPath)
sigURL := srcURL + ".sig"

c.logf("Downloading %q", srcURL)
dstPathUnverified := dstPath + ".unverified"
hash, len, err := download(srcURL, dstPathUnverified, downloadSizeLimit)
hash, len, err := c.download(ctx, srcURL, dstPathUnverified, downloadSizeLimit)
if err != nil {
return err
}
c.logf("Downloading %q", sigURL)
sig, err := fetch(sigURL, signatureSizeLimit)
if err != nil {
// Best-effort clean up of downloaded package.
Expand All @@ -226,6 +238,7 @@ func (c *Client) Download(srcPath, dstPath string) error {
os.Remove(dstPathUnverified)
return fmt.Errorf("signature %q for key %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, srcURL)
}
c.logf("Signature OK")

if err := os.Rename(dstPathUnverified, dstPath); err != nil {
return fmt.Errorf("failed to move %q to %q after signature validation", dstPathUnverified, dstPath)
Expand Down Expand Up @@ -272,32 +285,84 @@ func fetch(url string, limit int64) ([]byte, error) {

// download writes the response body of url into a local file at dst, up to
// limit bytes. On success, the returned value is a BLAKE2s hash of the file.
func download(url, dst string, limit int64) ([]byte, int64, error) {
resp, err := http.Get(url)
func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.Proxy = tshttpproxy.ProxyFromEnvironment
defer tr.CloseIdleConnections()
hc := &http.Client{Transport: tr}

quickCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
headReq := must.Get(http.NewRequestWithContext(quickCtx, http.MethodHead, url, nil))

res, err := hc.Do(headReq)
if err != nil {
return nil, 0, err
}
defer resp.Body.Close()

h := NewPackageHash()
r := io.TeeReader(io.LimitReader(resp.Body, limit), h)
if res.StatusCode != http.StatusOK {
return nil, 0, fmt.Errorf("HEAD %q: %v", url, res.Status)
}
if res.ContentLength <= 0 {
return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength)
}
c.logf("Download size: %v", res.ContentLength)

f, err := os.Create(dst)
dlReq := must.Get(http.NewRequestWithContext(ctx, http.MethodGet, url, nil))
dlRes, err := hc.Do(dlReq)
if err != nil {
return nil, 0, err
}
defer f.Close()
defer dlRes.Body.Close()
// TODO(bradfitz): resume from existing partial file on disk
if dlRes.StatusCode != http.StatusOK {
return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status)
}

if _, err := io.Copy(f, r); err != nil {
of, err := os.Create(dst)
if err != nil {
return nil, 0, err
}
if err := f.Close(); err != nil {
return nil, 0, err
defer of.Close()
pw := &progressWriter{total: res.ContentLength, logf: c.logf}
h := NewPackageHash()
n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit))
if err != nil {
return nil, n, err
}
if n != res.ContentLength {
return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength)
}
if err := dlRes.Body.Close(); err != nil {
return nil, n, err
}
if err := of.Close(); err != nil {
return nil, n, err
}
pw.print()

return h.Sum(nil), h.Len(), nil
}

type progressWriter struct {
done int64
total int64
lastPrint time.Time
logf logger.Logf
}

func (pw *progressWriter) Write(p []byte) (n int, err error) {
pw.done += int64(len(p))
if time.Since(pw.lastPrint) > 2*time.Second {
pw.print()
}
return len(p), nil
}

func (pw *progressWriter) print() {
pw.lastPrint = time.Now()
pw.logf("Downloaded %v/%v (%.1f%%)", pw.done, pw.total, float64(pw.done)/float64(pw.total)*100)
}

func parsePrivateKey(data []byte, typeTag string) (ed25519.PrivateKey, error) {
b, rest := pem.Decode(data)
if b == nil {
Expand Down

0 comments on commit 0d8d649

Please sign in to comment.