From 3d21625a9c4a538998f6c5cd59ba9ea7dff6e208 Mon Sep 17 00:00:00 2001 From: Cezar Craciunoiu Date: Tue, 14 May 2024 19:13:35 +0300 Subject: [PATCH] feat(cpio): Wait for oks in parallel Also prints statistics at the end and queries at start. Signed-off-by: Cezar Craciunoiu --- initrd/ociimage.go | 4 + .../cli/kraft/cloud/volume/import/cpio.go | 284 +++++++++++++----- .../cli/kraft/cloud/volume/import/import.go | 26 +- 3 files changed, 229 insertions(+), 85 deletions(-) diff --git a/initrd/ociimage.go b/initrd/ociimage.go index 7b8b8db36..0b4125a73 100644 --- a/initrd/ociimage.go +++ b/initrd/ociimage.go @@ -36,6 +36,10 @@ type ociimage struct { // NewFromOCIImage creates a new initrd from a remote container image. func NewFromOCIImage(ctx context.Context, path string, opts ...InitrdOption) (Initrd, error) { + if _, err := os.Stat(path); err == nil { + return nil, fmt.Errorf("file with path %s already exists", path) + } + if !strings.Contains("://", path) { path = fmt.Sprintf("docker://%s", path) } diff --git a/internal/cli/kraft/cloud/volume/import/cpio.go b/internal/cli/kraft/cloud/volume/import/cpio.go index 6a06a84ac..ee3974ba4 100644 --- a/internal/cli/kraft/cloud/volume/import/cpio.go +++ b/internal/cli/kraft/cloud/volume/import/cpio.go @@ -18,10 +18,66 @@ import ( "strings" "time" + "github.com/dustin/go-humanize" "kraftkit.sh/initrd" "kraftkit.sh/internal/cpio" + "kraftkit.sh/tui/confirm" ) +// startResponse is the response sent by the server after the token is validated. +type startResponse struct { + // Free is the number of bytes free on the volume. + Free uint64 // 8 bytes + + // Total is the total number of bytes on the volume. + Total uint64 // 8 bytes + + // Maxlen is the maximum file name length that can be sent. + Maxlen uint64 // 8 bytes +} + +func parseStartRespose(resp []byte) (*startResponse, error) { + var r startResponse + + if len(resp) != 24 { + return nil, fmt.Errorf("unknown start response") + } + + err := binary.Read(bytes.NewReader(resp), binary.LittleEndian, &r) + if err != nil { + return nil, err + } + + return &r, nil +} + +type stopResponse struct { + // Free is the number of bytes free on the volume. + Free uint64 // 8 bytes + + // Total is the total number of bytes on the volume. + Total uint64 // 8 bytes + + // Maxlen is the maximum file name length that can be sent. + Maxlen uint64 // 8 bytes + +} + +func parseStopRespose(resp []byte) (*stopResponse, error) { + var r stopResponse + + if len(resp) != 24 { + return nil, fmt.Errorf("unknown stop response") + } + + err := binary.Read(bytes.NewReader(resp), binary.LittleEndian, &r) + if err != nil { + return nil, err + } + + return &r, nil +} + type okResponse struct { // status is the status code of the response. It can be 1 for success, -1 // for error, or 0 for finished sending error. @@ -38,7 +94,7 @@ type okResponse struct { const ( // The size read by the `volimport` unikernel on one socket read - msgMaxSize = 32 * 1024 // 32 KiB + msgMaxSize = 64 * 1024 // 64K ) func (r *okResponse) clear() { @@ -47,7 +103,7 @@ func (r *okResponse) clear() { r.message = nil } -func (r *okResponse) parse(resp []byte) error { +func (r *okResponse) parseMetadata(resp []byte) error { r.clear() err := binary.Read(bytes.NewReader(resp[:4]), binary.LittleEndian, &r.status) @@ -64,40 +120,115 @@ func (r *okResponse) parse(resp []byte) error { return err } + return nil +} + +func (r *okResponse) parse(resp []byte) error { + if err := r.parseMetadata(resp); err != nil { + return err + } + r.message = resp[8 : 8+r.msglen] return nil } -func (r *okResponse) waitForOK(conn *tls.Conn, errorMsg string) error { +func (r *okResponse) waitForOK(conn *tls.Conn, errorMsg string) ([]byte, error) { retErr := fmt.Errorf(errorMsg) - for it := 0; ; it++ { + for { // A message can have at max: // status - 4 bytes // msglen - 4 bytes // msg - 1024 bytes - respRaw := make([]byte, 1032) + var respHeadRawBuf []byte + var respMsgRawBuf []byte + respHeadRaw := bytes.NewBuffer(respHeadRawBuf) + respMsgRaw := bytes.NewBuffer(respMsgRawBuf) - _, err := io.ReadAtLeast(conn, respRaw, 4) + _, err := io.CopyN(respHeadRaw, conn, 8) if err != nil { - return fmt.Errorf("%w: %s", retErr, err) + return nil, fmt.Errorf("%w: reading header: %s", retErr, err) + } + + if err := r.parseMetadata(respHeadRaw.Bytes()); err != nil { + return nil, fmt.Errorf("%w: parsing header: %s", retErr, err) } + if r.msglen != 0 { + _, err = io.CopyN(respMsgRaw, conn, int64(r.msglen)) + if err != nil { + return nil, fmt.Errorf("%w: reading body: %s", retErr, err) + } + } + + respRaw := append(respHeadRaw.Bytes(), respMsgRaw.Bytes()...) if err := r.parse(respRaw); err != nil { - return fmt.Errorf("%w: %s", retErr, err) + return nil, fmt.Errorf("%w: parsing body: %s", retErr, err) } + switch { case r.status == 0: - if errorMsg != retErr.Error() { - return retErr + // If error is unchanged, it means that the server has finished sending + // and closed the connection without problems. + if retErr.Error() == errorMsg { + return r.message, nil } - return nil + + return nil, retErr case r.status == 1: - return nil + return nil, nil + case r.status == 2: + return r.message, nil case r.status < 0: - retErr = fmt.Errorf("%w: %s", retErr, strings.TrimSuffix(string(r.message), "\x0a\n")) + retErr = fmt.Errorf("%w: %s", retErr, strings.TrimSuffix(string(r.message[:len(r.message)-1]), "\n")) default: - return fmt.Errorf("unexpected status: %d", r.status) + return nil, fmt.Errorf("unexpected status: %d", r.status) + } + } +} + +// waitForOKs waits for OKs to be sent over the connection and decrements the +// waitgroup counter. +func waitForOKs(conn *tls.Conn, auth string, result chan *stopResponse, waitErr chan *error) { + var err error + var final *stopResponse + resp := okResponse{} + + // Close the context on exit + // We need to do this because the server might have closed before + // we got an answer for all messages. + defer func() { + waitErr <- &err + result <- final + }() + + for { + var stopRespRaw []byte + if stopRespRaw, err = resp.waitForOK(conn, "transmission failed"); err != nil { + if strings.Contains(err.Error(), "EOF") || + strings.Contains(err.Error(), "use of closed network connection") || + strings.Contains(err.Error(), "i/o timeout") || + strings.Contains(err.Error(), "broken pipe") { + return + } + + // Send a term signal to the server + io.Copy(conn, strings.NewReader(auth)) + + _ = conn.SetWriteDeadline(immediateNetCancel) + _ = conn.SetReadDeadline(immediateNetCancel) + + return + } else { + // If we got no error but we got a message then it means we finished + // We signal the main goroutine to exit. + if len(stopRespRaw) > 0 { + final, _ = parseStopRespose(stopRespRaw) + + // Send a term signal to the server + io.Copy(conn, strings.NewReader(auth)) + return + } } } } @@ -129,7 +260,7 @@ func buildCPIO(ctx context.Context, source string) (path string, size int64, err } // copyCPIO copies the CPIO archive at the given path over the provided tls.Conn. -func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, size uint64, callback progressCallbackFunc) error { +func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, size uint64, callback progressCallbackFunc) (free uint64, total uint64, err error) { var resp okResponse var currentSize uint64 @@ -150,16 +281,54 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, }() if _, err := io.Copy(conn, strings.NewReader(auth)); err != nil { - return err + return 0, 0, err } - if err := resp.waitForOK(conn, "authentication failed"); err != nil { - return err + var startRespRaw []byte + if startRespRaw, err = resp.waitForOK(conn, "authentication failed"); err != nil { + return 0, 0, err + } + + volumeStartStats, err := parseStartRespose(startRespRaw) + if err != nil { + return 0, 0, err + } + + // TODO(nderjung): Decide where to move this as currently the promport is hidden + if size > volumeStartStats.Free { + response, err := confirm.NewConfirm( + fmt.Sprintf("Import might exceed volume capacity. Continue? (free: %s, required: %s, total: %s)\n", + humanize.IBytes(volumeStartStats.Free), + humanize.IBytes(size), + humanize.IBytes(volumeStartStats.Total), + ), + ) + if err != nil { + return 0, 0, err + } + + if !response { + return 0, 0, fmt.Errorf("not enough free space on volume for input data (%d/%d)", size, volumeStartStats.Free) + } } + // From this point forward we wait for OKs to be sent on a separate thread + // When returning errors we will use `returnErrors` to ensure that the + // correct error is propagated up. + + var result = make(chan *stopResponse, 1) + var waitErr = make(chan *error, 1) + + go waitForOKs(conn, auth, result, waitErr) + defer func() { + if retErr := <-waitErr; retErr != nil { + err = *retErr + } + }() + fi, err := os.Open(path) if err != nil { - return err + return 0, 0, err } defer fi.Close() @@ -172,36 +341,27 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, // Iterate through the files in the archive. // Sending a file has a list of steps - // 1. Send the raw CPIO header -- wait for OK - // 2. Send the name of the file (NUL terminated) -- wait for OK + // 1. Send the raw CPIO header athe name of the file (NUL terminated) // 2'. Stop if last entry detected - // 3. Copy the file content piece by piece | Link destination -- wait for OK + // 2. Copy the file content piece by piece | Link destination +initrdLoop: for { hdr, raw, err := reader.Next() if err == io.EOF { shouldStop = true } else if err != nil { - return err + return 0, 0, err } // 1. Send the header - n, err := io.CopyN(conn, bytes.NewBuffer(raw.Bytes()), int64(len(raw.Bytes()))) + _, err = io.CopyN(conn, bytes.NewBuffer(raw.Bytes()), int64(len(raw.Bytes()))) // NOTE(antoineco): such error can be expected if volimport exited early or // a deadline was set due to cancellation. What we should convey in the error // is that the data import didn't complete, not the low-level network error. if err != nil { - if !isNetClosedError(err) { - return err - } - if n != int64(len(raw.Bytes())) { - return fmt.Errorf("incomplete write (%d/%d)", n, len(raw.Bytes())) - } - return err + break } - if err := resp.waitForOK(conn, "header copy failed"); err != nil { - return err - } currentSize += uint64(len(raw.Bytes())) updateProgress(float64(currentSize), float64(size), callback) @@ -210,21 +370,12 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, // Add NUL-termination to name string as per CPIO spec nameBytesToSend = append(nameBytesToSend, 0x00) - // 2. Send the file name - n, err = io.CopyN(conn, bytes.NewReader(nameBytesToSend), int64(len(nameBytesToSend))) + // 1. Send the file name + _, err = io.CopyN(conn, bytes.NewReader(nameBytesToSend), int64(len(nameBytesToSend))) if err != nil { - if !isNetClosedError(err) { - return err - } - if n != int64(len(hdr.Name)) { - return fmt.Errorf("incomplete write (%d/%d)", n, len(hdr.Name)) - } - return err + break } - if err := resp.waitForOK(conn, "name copy failed"); err != nil { - return err - } currentSize += uint64(len(nameBytesToSend)) updateProgress(float64(currentSize), float64(size), callback) @@ -233,9 +384,6 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, break } - // If nothing was copied the entry was a directory which has no size - empty := true - // 3. Send the file content. If the file is a link copy the destination // as content in this step. Copy runs uninterrupted until the whole size // was sent. @@ -253,52 +401,38 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, if err == io.EOF { break } else if err != nil { - return err + return 0, 0, err } - n, err := io.CopyN(conn, bytes.NewReader(buf), int64(bread)) + _, err = io.CopyN(conn, bytes.NewReader(buf), int64(bread)) if err != nil { - if !isNetClosedError(err) { - return err - } - if n != int64(bread) { - return fmt.Errorf("incomplete write (%d/%d)", n, int64(bread)) - } - return err + break initrdLoop } - empty = false currentSize += uint64(bread) updateProgress(float64(currentSize), float64(size), callback) } } else { bread := len(hdr.Linkname) - n, err := io.CopyN(conn, bytes.NewReader([]byte(hdr.Linkname)), int64(bread)) + _, err := io.CopyN(conn, bytes.NewReader([]byte(hdr.Linkname)), int64(bread)) if err != nil { - if !isNetClosedError(err) { - return err - } - if n != int64(bread) { - return fmt.Errorf("incomplete write (%d/%d)", n, int64(bread)) - } - return err + break } - empty = false currentSize += uint64(bread) updateProgress(float64(currentSize), float64(size), callback) } + } - // Don't wait for ok if nothing was written - if !empty { - if err := resp.waitForOK(conn, "file copy failed"); err != nil { - return err - } - } + // Wait for finish or error to come from the server + final := <-result + if final == nil { + // If we got here, error will be set in the defer function + return 0, 0, fmt.Errorf("no stop response received") } - return nil + return final.Free, final.Total, nil } var ( diff --git a/internal/cli/kraft/cloud/volume/import/import.go b/internal/cli/kraft/cloud/volume/import/import.go index 58a1c1fe4..598eccbe6 100644 --- a/internal/cli/kraft/cloud/volume/import/import.go +++ b/internal/cli/kraft/cloud/volume/import/import.go @@ -50,6 +50,9 @@ func NewCmd() *cobra.Command { Example: heredoc.Doc(` # Import data from a local directory "path/to/data" to a volume named "my-volume" $ kraft cloud volume import --source path/to/data --volume my-volume + + # Import data from a docker registry "docker.io/nginx:latest" to a volume named "my-volume" + $ kraft cloud volume import --source docker.io/nginx:latest --volume my-volume `), Annotations: map[string]string{ cmdfactory.AnnotationHelpGroup: "kraftcloud-vol", @@ -67,6 +70,10 @@ func (opts *ImportOptions) Pre(cmd *cobra.Command, _ []string) error { return fmt.Errorf("must specify a value for the --volume flag") } + if finfo, err := os.Stat(opts.Source); err == nil && !finfo.IsDir() { + return fmt.Errorf("local source path must be a directory") + } + err := utils.PopulateMetroToken(cmd, &opts.Metro, &opts.Token) if err != nil { return fmt.Errorf("could not populate metro and token: %w", err) @@ -130,8 +137,7 @@ func importVolumeData(ctx context.Context, opts *ImportOptions) (retErr error) { }() var volUUID string - var volSize int64 - if volUUID, volSize, err = volumeSanityCheck(ctx, vcli, opts.VolID, cpioSize); err != nil { + if volUUID, _, err = volumeSanityCheck(ctx, vcli, opts.VolID, cpioSize); err != nil { return err } @@ -166,6 +172,8 @@ func importVolumeData(ctx context.Context, opts *ImportOptions) (retErr error) { // nil error upon context cancellation. We temporarily handle potential copy // errors ourselves here. var copyCPIOErr error + var freeSpace uint64 + var totalSpace uint64 paraprogress, err := paraProgress(ctx, fmt.Sprintf("Importing data (%s)", humanize.IBytes(uint64(cpioSize))), func(ctx context.Context, callback func(float64)) (retErr error) { @@ -174,13 +182,11 @@ func importVolumeData(ctx context.Context, opts *ImportOptions) (retErr error) { if err != nil { return fmt.Errorf("connecting to volume data import instance send port: %w", err) } - defer func() { - retErr = errors.Join(retErr, conn.Close()) - }() + defer conn.Close() ctx, cancel := context.WithCancel(ctx) defer cancel() - err = copyCPIO(ctx, conn, authStr, cpioPath, opts.Timeout, uint64(cpioSize), callback) + freeSpace, totalSpace, err = copyCPIO(ctx, conn, authStr, cpioPath, opts.Timeout, uint64(cpioSize), callback) copyCPIOErr = err return err }, @@ -201,12 +207,12 @@ func importVolumeData(ctx context.Context, opts *ImportOptions) (retErr error) { Value: opts.VolID, }, fancymap.FancyMapEntry{ - Key: "imported", - Value: humanize.IBytes(uint64(cpioSize)), + Key: "free", + Value: humanize.IBytes(freeSpace), }, fancymap.FancyMapEntry{ - Key: "capacity", - Value: humanize.IBytes(uint64(volSize)), + Key: "total", + Value: humanize.IBytes(totalSpace), }, )