Skip to content

Commit

Permalink
wip(cpio): Copy in parallel
Browse files Browse the repository at this point in the history
Signed-off-by: Cezar Craciunoiu <cezar.craciunoiu@unikraft.io>
  • Loading branch information
craciunoiuc committed May 14, 2024
1 parent 99b9159 commit 1c888ed
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 17 deletions.
4 changes: 4 additions & 0 deletions initrd/ociimage.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ type ociimage struct {

// NewFromOCIImage creates a new initrd from a remote container image.
func NewFromOCIImage(ctx context.Context, path string, opts ...InitrdOption) (Initrd, error) {
if _, err := os.Stat(path); err == nil {
return nil, fmt.Errorf("file with path %s already exists", path)
}

if !strings.Contains("://", path) {
path = fmt.Sprintf("docker://%s", path)
}
Expand Down
106 changes: 89 additions & 17 deletions internal/cli/kraft/cloud/volume/import/cpio.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"

"kraftkit.sh/initrd"
Expand Down Expand Up @@ -47,7 +49,7 @@ func (r *okResponse) clear() {
r.message = nil
}

func (r *okResponse) parse(resp []byte) error {
func (r *okResponse) parseMetadata(resp []byte) error {
r.clear()

err := binary.Read(bytes.NewReader(resp[:4]), binary.LittleEndian, &r.status)
Expand All @@ -64,25 +66,46 @@ func (r *okResponse) parse(resp []byte) error {
return err
}

return nil
}

func (r *okResponse) parse(resp []byte) error {
if err := r.parseMetadata(resp); err != nil {
return err
}

r.message = resp[8 : 8+r.msglen]

return nil
}

func (r *okResponse) waitForOK(conn *tls.Conn, errorMsg string) error {
retErr := fmt.Errorf(errorMsg)
for it := 0; ; it++ {
for {
// A message can have at max:
// status - 4 bytes
// msglen - 4 bytes
// msg - 1024 bytes
respRaw := make([]byte, 1032)
respHeadRaw := make([]byte, 8)
respMsgRaw := make([]byte, 1024)

_, err := io.ReadAtLeast(conn, respRaw, 4)
_, err := io.CopyN(bytes.NewBuffer(respHeadRaw), conn, 8)
if err != nil {
return fmt.Errorf("%w: %s", retErr, err)
}

if err := r.parseMetadata(respHeadRaw); err != nil {
return fmt.Errorf("%w: %s", retErr, err)
}

if r.msglen != 0 {
_, err = io.CopyN(bytes.NewBuffer(respMsgRaw), conn, int64(r.msglen))
if err != nil {
return fmt.Errorf("%w: %s", retErr, err)
}
}

respRaw := append(respHeadRaw, respMsgRaw...)
if err := r.parse(respRaw); err != nil {
return fmt.Errorf("%w: %s", retErr, err)
}
Expand All @@ -102,6 +125,33 @@ func (r *okResponse) waitForOK(conn *tls.Conn, errorMsg string) error {
}
}

var counterFileName int32
var counterFileContent int32
var counterHdr int32
var counter int32

// waitForOKs waits for OKs to be sent over the connection and decrements the
// waitgroup counter.
func waitForOKs(conn *tls.Conn, waitErr *error, waitFor *sync.WaitGroup) {
resp := okResponse{}

for {
if err := resp.waitForOK(conn, "transmission failed"); err != nil {
_ = conn.SetWriteDeadline(immediateNetCancel)
_ = conn.SetReadDeadline(immediateNetCancel)

if err == io.EOF {
return
}
*waitErr = err

return
}
atomic.AddInt32(&counter, -1)
// waitFor.Done()
}
}

// buildCPIO generates a CPIO archive from the data at the given source.
func buildCPIO(ctx context.Context, source string) (path string, size int64, err error) {
if source == "." {
Expand Down Expand Up @@ -129,9 +179,11 @@ func buildCPIO(ctx context.Context, source string) (path string, size int64, err
}

// copyCPIO copies the CPIO archive at the given path over the provided tls.Conn.
func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, size uint64, callback progressCallbackFunc) error {
func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS, size uint64, callback progressCallbackFunc) (err error) {
var resp okResponse
var currentSize uint64
var waitErr error
var waitFor sync.WaitGroup

// NOTE(antoineco): this call is critical as it allows writes to be later
// cancelled, because the deadline applies to all future and pending I/O and
Expand All @@ -157,6 +209,16 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS,
return err
}

// From this point forward we wait for OKs to be sent on a separate thread
// When returning errors we will use `returnErrors` to ensure that the
// correct error is propagated up.
go waitForOKs(conn, &waitErr, &waitFor)
defer func() {
if waitErr != nil {
err = waitErr
}
}()

fi, err := os.Open(path)
if err != nil {
return err
Expand Down Expand Up @@ -185,6 +247,8 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS,
}

// 1. Send the header
waitFor.Add(1)
atomic.AddInt32(&counterHdr, 1)
n, err := io.CopyN(conn, bytes.NewBuffer(raw.Bytes()), int64(len(raw.Bytes())))
// NOTE(antoineco): such error can be expected if volimport exited early or
// a deadline was set due to cancellation. What we should convey in the error
Expand All @@ -199,9 +263,6 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS,
return err
}

if err := resp.waitForOK(conn, "header copy failed"); err != nil {
return err
}
currentSize += uint64(len(raw.Bytes()))
updateProgress(float64(currentSize), float64(size), callback)

Expand All @@ -211,6 +272,10 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS,
nameBytesToSend = append(nameBytesToSend, 0x00)

// 2. Send the file name
if !shouldStop {
waitFor.Add(1)
atomic.AddInt32(&counterFileName, 1)
}
n, err = io.CopyN(conn, bytes.NewReader(nameBytesToSend), int64(len(nameBytesToSend)))
if err != nil {
if !isNetClosedError(err) {
Expand All @@ -222,9 +287,6 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS,
return err
}

if err := resp.waitForOK(conn, "name copy failed"); err != nil {
return err
}
currentSize += uint64(len(nameBytesToSend))
updateProgress(float64(currentSize), float64(size), callback)

Expand Down Expand Up @@ -256,6 +318,10 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS,
return err
}

if empty {
waitFor.Add(1)
atomic.AddInt32(&counterFileContent, 1)
}
n, err := io.CopyN(conn, bytes.NewReader(buf), int64(bread))
if err != nil {
if !isNetClosedError(err) {
Expand All @@ -274,6 +340,8 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS,
} else {
bread := len(hdr.Linkname)

waitFor.Add(1)
atomic.AddInt32(&counterFileContent, 1)
n, err := io.CopyN(conn, bytes.NewReader([]byte(hdr.Linkname)), int64(bread))
if err != nil {
if !isNetClosedError(err) {
Expand All @@ -289,14 +357,18 @@ func copyCPIO(ctx context.Context, conn *tls.Conn, auth, path string, timeoutS,
currentSize += uint64(bread)
updateProgress(float64(currentSize), float64(size), callback)
}
}

// Don't wait for ok if nothing was written
if !empty {
if err := resp.waitForOK(conn, "file copy failed"); err != nil {
return err
}
go func() {
for {
fmt.Printf("Waiting for OKs: %d\n\n\n\n", atomic.LoadInt32(&counter))
fmt.Printf("Waiting for OKs: %d\n\n\n\n", atomic.LoadInt32(&counterHdr))
fmt.Printf("Waiting for OKs: %d\n\n\n\n", atomic.LoadInt32(&counterFileName))
fmt.Printf("Waiting for OKs: %d\n\n\n\n", atomic.LoadInt32(&counterFileContent))
time.Sleep(1 * time.Second)
}
}
}()
waitFor.Wait()

return nil
}
Expand Down

0 comments on commit 1c888ed

Please sign in to comment.