diff --git a/initrd/ociimage.go b/initrd/ociimage.go index 7b8b8db36..0b4125a73 100644 --- a/initrd/ociimage.go +++ b/initrd/ociimage.go @@ -36,6 +36,10 @@ type ociimage struct { // NewFromOCIImage creates a new initrd from a remote container image. func NewFromOCIImage(ctx context.Context, path string, opts ...InitrdOption) (Initrd, error) { + if _, err := os.Stat(path); err == nil { + return nil, fmt.Errorf("file with path %s already exists", path) + } + if !strings.Contains("://", path) { path = fmt.Sprintf("docker://%s", path) } diff --git a/internal/cli/kraft/cloud/volume/import/cpio.go b/internal/cli/kraft/cloud/volume/import/cpio.go index 6a06a84ac..7f3925bfc 100644 --- a/internal/cli/kraft/cloud/volume/import/cpio.go +++ b/internal/cli/kraft/cloud/volume/import/cpio.go @@ -16,6 +16,8 @@ import ( "net" "os" "strings" + "sync" + "sync/atomic" "time" "kraftkit.sh/initrd" @@ -47,7 +49,7 @@ func (r *okResponse) clear() { r.message = nil } -func (r *okResponse) parse(resp []byte) error { +func (r *okResponse) parseMetadata(resp []byte) error { r.clear() err := binary.Read(bytes.NewReader(resp[:4]), binary.LittleEndian, &r.status) @@ -64,6 +66,14 @@ func (r *okResponse) parse(resp []byte) error { return err } + return nil +} + +func (r *okResponse) parse(resp []byte) error { + if err := r.parseMetadata(resp); err != nil { + return err + } + r.message = resp[8 : 8+r.msglen] return nil @@ -71,18 +81,31 @@ func (r *okResponse) parse(resp []byte) error { func (r *okResponse) waitForOK(conn *tls.Conn, errorMsg string) error { retErr := fmt.Errorf(errorMsg) - for it := 0; ; it++ { + for { // A message can have at max: // status - 4 bytes // msglen - 4 bytes // msg - 1024 bytes - respRaw := make([]byte, 1032) + respHeadRaw := make([]byte, 8) + respMsgRaw := make([]byte, 1024) - _, err := io.ReadAtLeast(conn, respRaw, 4) + _, err := io.CopyN(bytes.NewBuffer(respHeadRaw), conn, 8) if err != nil { return fmt.Errorf("%w: %s", retErr, err) } + if err := r.parseMetadata(respHeadRaw); err != nil { + return fmt.Errorf("%w: %s", retErr, err) + } + + if r.msglen != 0 { + _, err = io.CopyN(bytes.NewBuffer(respMsgRaw), conn, int64(r.msglen)) + if err != nil { + return fmt.Errorf("%w: %s", retErr, err) + } + } + + respRaw := append(respHeadRaw, respMsgRaw...) if err := r.parse(respRaw); err != nil { return fmt.Errorf("%w: %s", retErr, err) } @@ -102,6 +125,33 @@ func (r *okResponse) waitForOK(conn *tls.Conn, errorMsg string) error { } } +var counterFileName int32 +var counterFileContent int32 +var counterHdr int32 +var counter int32 + +// waitForOKs waits for OKs to be sent over the connection and decrements the +// waitgroup counter. +func waitForOKs(conn *tls.Conn, waitErr *error, waitFor *sync.WaitGroup) { + resp := okResponse{} + + for { + if err := resp.waitForOK(conn, "transmission failed"); err != nil { + _ = conn.SetWriteDeadline(immediateNetCancel) + _ = conn.SetReadDeadline(immediateNetCancel) + + if err == io.EOF { + return + } + *waitErr = err + + return + } + atomic.AddInt32(&counter, -1) + // waitFor.Done() + } +} + // buildCPIO generates a CPIO archive from the data at the given source. func buildCPIO(ctx context.Context, source string) (path string, size int64, err error) { if source == "." { @@ -129,9 +179,11 @@ func buildCPIO(ctx context.Context, source string) (path string, size int64, err } // copyCPIO copies the CPIO archive at the given path over the provided tls.Conn. -func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, size uint64, callback progressCallbackFunc) error { +func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, size uint64, callback progressCallbackFunc) (err error) { var resp okResponse var currentSize uint64 + var waitErr error + var waitFor sync.WaitGroup // NOTE(antoineco): this call is critical as it allows writes to be later // cancelled, because the deadline applies to all future and pending I/O and @@ -157,6 +209,16 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, return err } + // From this point forward we wait for OKs to be sent on a separate thread + // When returning errors we will use `returnErrors` to ensure that the + // correct error is propagated up. + go waitForOKs(conn, &waitErr, &waitFor) + defer func() { + if waitErr != nil { + err = waitErr + } + }() + fi, err := os.Open(path) if err != nil { return err @@ -185,6 +247,8 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, } // 1. Send the header + waitFor.Add(1) + atomic.AddInt32(&counterHdr, 1) n, err := io.CopyN(conn, bytes.NewBuffer(raw.Bytes()), int64(len(raw.Bytes()))) // NOTE(antoineco): such error can be expected if volimport exited early or // a deadline was set due to cancellation. What we should convey in the error @@ -199,9 +263,6 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, return err } - if err := resp.waitForOK(conn, "header copy failed"); err != nil { - return err - } currentSize += uint64(len(raw.Bytes())) updateProgress(float64(currentSize), float64(size), callback) @@ -211,6 +272,10 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, nameBytesToSend = append(nameBytesToSend, 0x00) // 2. Send the file name + if !shouldStop { + waitFor.Add(1) + atomic.AddInt32(&counterFileName, 1) + } n, err = io.CopyN(conn, bytes.NewReader(nameBytesToSend), int64(len(nameBytesToSend))) if err != nil { if !isNetClosedError(err) { @@ -222,9 +287,6 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, return err } - if err := resp.waitForOK(conn, "name copy failed"); err != nil { - return err - } currentSize += uint64(len(nameBytesToSend)) updateProgress(float64(currentSize), float64(size), callback) @@ -256,6 +318,10 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, return err } + if empty { + waitFor.Add(1) + atomic.AddInt32(&counterFileContent, 1) + } n, err := io.CopyN(conn, bytes.NewReader(buf), int64(bread)) if err != nil { if !isNetClosedError(err) { @@ -274,6 +340,8 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, } else { bread := len(hdr.Linkname) + waitFor.Add(1) + atomic.AddInt32(&counterFileContent, 1) n, err := io.CopyN(conn, bytes.NewReader([]byte(hdr.Linkname)), int64(bread)) if err != nil { if !isNetClosedError(err) { @@ -289,14 +357,18 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, currentSize += uint64(bread) updateProgress(float64(currentSize), float64(size), callback) } + } - // Don't wait for ok if nothing was written - if !empty { - if err := resp.waitForOK(conn, "file copy failed"); err != nil { - return err - } + go func() { + for { + fmt.Printf("Waiting for OKs: %d\n\n\n\n", atomic.LoadInt32(&counter)) + fmt.Printf("Waiting for OKs: %d\n\n\n\n", atomic.LoadInt32(&counterHdr)) + fmt.Printf("Waiting for OKs: %d\n\n\n\n", atomic.LoadInt32(&counterFileName)) + fmt.Printf("Waiting for OKs: %d\n\n\n\n", atomic.LoadInt32(&counterFileContent)) + time.Sleep(1 * time.Second) } - } + }() + waitFor.Wait() return nil }