From d39fa2024ceb8c3152d44ebdbb57f497eff760fb Mon Sep 17 00:00:00 2001 From: Andrey Smirnov Date: Fri, 30 Aug 2024 18:04:03 +0400 Subject: [PATCH] fix: add context to luks commands Make sure commands can be aborted on timeout. Signed-off-by: Andrey Smirnov --- encryption/luks/luks.go | 52 ++++++++++++++++++------------------ encryption/luks/luks_test.go | 50 ++++++++++++++++++---------------- encryption/provider.go | 21 ++++++++------- 3 files changed, 64 insertions(+), 59 deletions(-) diff --git a/encryption/luks/luks.go b/encryption/luks/luks.go index f7e0e1a..fc0232e 100644 --- a/encryption/luks/luks.go +++ b/encryption/luks/luks.go @@ -133,14 +133,14 @@ func New(cipher Cipher, options ...Option) *LUKS { } // Open runs luksOpen on a device and returns mapped device path. -func (l *LUKS) Open(deviceName, mappedName string, key *encryption.Key) (string, error) { +func (l *LUKS) Open(ctx context.Context, deviceName, mappedName string, key *encryption.Key) (string, error) { args := slices.Concat( []string{"luksOpen", deviceName, mappedName, "--key-file=-"}, keyslotArgs(key), l.perfArgs(), ) - _, err := l.runCommand(args, key.Value) + _, err := l.runCommand(ctx, args, key.Value) if err != nil { return "", err } @@ -149,7 +149,7 @@ func (l *LUKS) Open(deviceName, mappedName string, key *encryption.Key) (string, } // Encrypt implements encryption.Provider. -func (l *LUKS) Encrypt(deviceName string, key *encryption.Key) error { +func (l *LUKS) Encrypt(ctx context.Context, deviceName string, key *encryption.Key) error { cipher, err := l.cipher.String() if err != nil { return err @@ -166,29 +166,29 @@ func (l *LUKS) Encrypt(deviceName string, key *encryption.Key) error { args = append(args, fmt.Sprintf("--sector-size=%d", l.blockSize)) } - _, err = l.runCommand(args, key.Value) + _, err = l.runCommand(ctx, args, key.Value) return err } // Resize implements encryption.Provider. -func (l *LUKS) Resize(devname string, key *encryption.Key) error { +func (l *LUKS) Resize(ctx context.Context, devname string, key *encryption.Key) error { args := []string{"resize", devname, "--key-file=-"} - _, err := l.runCommand(args, key.Value) + _, err := l.runCommand(ctx, args, key.Value) return err } // Close implements encryption.Provider. -func (l *LUKS) Close(devname string) error { - _, err := l.runCommand([]string{"luksClose", devname}, nil) +func (l *LUKS) Close(ctx context.Context, devname string) error { + _, err := l.runCommand(ctx, []string{"luksClose", devname}, nil) return err } // AddKey adds a new key at the LUKS encryption slot. -func (l *LUKS) AddKey(devname string, key, newKey *encryption.Key) error { +func (l *LUKS) AddKey(ctx context.Context, devname string, key, newKey *encryption.Key) error { var buffer bytes.Buffer keyfileLen, _ := buffer.Write(key.Value) @@ -206,13 +206,13 @@ func (l *LUKS) AddKey(devname string, key, newKey *encryption.Key) error { keyslotArgs(newKey), ) - _, err := l.runCommand(args, buffer.Bytes()) + _, err := l.runCommand(ctx, args, buffer.Bytes()) return err } // SetKey sets new key value at the LUKS encryption slot. -func (l *LUKS) SetKey(devname string, oldKey, newKey *encryption.Key) error { +func (l *LUKS) SetKey(ctx context.Context, devname string, oldKey, newKey *encryption.Key) error { if oldKey.Slot != newKey.Slot { return fmt.Errorf("old and new key slots must match") } @@ -234,19 +234,19 @@ func (l *LUKS) SetKey(devname string, oldKey, newKey *encryption.Key) error { l.perfArgs(), ) - _, err := l.runCommand(args, buffer.Bytes()) + _, err := l.runCommand(ctx, args, buffer.Bytes()) return err } // CheckKey checks if the key is valid. -func (l *LUKS) CheckKey(devname string, key *encryption.Key) (bool, error) { +func (l *LUKS) CheckKey(ctx context.Context, devname string, key *encryption.Key) (bool, error) { args := slices.Concat( []string{"luksOpen", "--test-passphrase", devname, "--key-file=-"}, keyslotArgs(key), ) - _, err := l.runCommand(args, key.Value) + _, err := l.runCommand(ctx, args, key.Value) if err != nil { if err == encryption.ErrEncryptionKeyRejected { //nolint:errorlint return false, nil @@ -259,13 +259,13 @@ func (l *LUKS) CheckKey(devname string, key *encryption.Key) (bool, error) { } // RemoveKey removes a key at the specified LUKS encryption slot. -func (l *LUKS) RemoveKey(devname string, slot int, key *encryption.Key) error { - _, err := l.runCommand([]string{"luksKillSlot", devname, strconv.Itoa(slot), "--key-file=-"}, key.Value) +func (l *LUKS) RemoveKey(ctx context.Context, devname string, slot int, key *encryption.Key) error { + _, err := l.runCommand(ctx, []string{"luksKillSlot", devname, strconv.Itoa(slot), "--key-file=-"}, key.Value) if err != nil { return err } - if err = l.RemoveToken(devname, slot); err != nil && !errors.Is(err, encryption.ErrTokenNotFound) { + if err = l.RemoveToken(ctx, devname, slot); err != nil && !errors.Is(err, encryption.ErrTokenNotFound) { return err } @@ -306,7 +306,7 @@ func (l *LUKS) ReadKeyslots(deviceName string) (*encryption.Keyslots, error) { // SetToken adds arbitrary token to the key slot. // Token id == slot id: only one token per key slot is supported. -func (l *LUKS) SetToken(devname string, slot int, token token.Token) error { +func (l *LUKS) SetToken(ctx context.Context, devname string, slot int, token token.Token) error { data, err := token.Bytes() if err != nil { return err @@ -314,14 +314,14 @@ func (l *LUKS) SetToken(devname string, slot int, token token.Token) error { id := strconv.Itoa(slot) - _, err = l.runCommand([]string{"token", "import", "-q", devname, "--token-id", id, "--json-file=-", "--token-replace"}, data) + _, err = l.runCommand(ctx, []string{"token", "import", "-q", devname, "--token-id", id, "--json-file=-", "--token-replace"}, data) return err } // ReadToken reads arbitrary token from the luks metadata. -func (l *LUKS) ReadToken(devname string, slot int, token token.Token) error { - stdout, err := l.runCommand([]string{"token", "export", "-q", devname, "--token-id", strconv.Itoa(slot), "--json-file=-"}, nil) +func (l *LUKS) ReadToken(ctx context.Context, devname string, slot int, token token.Token) error { + stdout, err := l.runCommand(ctx, []string{"token", "export", "-q", devname, "--token-id", strconv.Itoa(slot), "--json-file=-"}, nil) if err != nil { return err } @@ -330,8 +330,8 @@ func (l *LUKS) ReadToken(devname string, slot int, token token.Token) error { } // RemoveToken removes token from the luks metadata. -func (l *LUKS) RemoveToken(devname string, slot int) error { - _, err := l.runCommand([]string{"token", "remove", "--token-id", strconv.Itoa(slot), devname}, nil) +func (l *LUKS) RemoveToken(ctx context.Context, devname string, slot int) error { + _, err := l.runCommand(ctx, []string{"token", "remove", "--token-id", strconv.Itoa(slot), devname}, nil) return err } @@ -339,10 +339,10 @@ func (l *LUKS) RemoveToken(devname string, slot int) error { var notFoundMatcher = regexp.MustCompile("(is not in use|Failed to get token)") // runCommand executes cryptsetup with arguments. -func (l *LUKS) runCommand(args []string, stdin []byte) (string, error) { +func (l *LUKS) runCommand(ctx context.Context, args []string, stdin []byte) (string, error) { stdout, err := cmd.RunContext(cmd.WithStdin( - context.Background(), - bytes.NewBuffer(stdin)), "cryptsetup", args...) + ctx, + bytes.NewReader(stdin)), "cryptsetup", args...) if err != nil { var exitError *cmd.ExitError diff --git a/encryption/luks/luks_test.go b/encryption/luks/luks_test.go index fb36c2c..b3ca8f6 100644 --- a/encryption/luks/luks_test.go +++ b/encryption/luks/luks_test.go @@ -5,6 +5,7 @@ package luks_test import ( + "context" "errors" randv2 "math/rand/v2" "os" @@ -31,6 +32,9 @@ const ( ) func testEncrypt(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + tmpDir := t.TempDir() rawImage := filepath.Join(tmpDir, "image.raw") @@ -97,21 +101,21 @@ func testEncrypt(t *testing.T) { t.Logf("unencrypted partition path %s", path) - require.NoError(t, provider.Encrypt(path, key)) + require.NoError(t, provider.Encrypt(ctx, path, key)) - encryptedPath, err := provider.Open(path, mappedName, key) + encryptedPath, err := provider.Open(ctx, path, mappedName, key) require.NoError(t, err) - require.NoError(t, provider.Resize(encryptedPath, key)) + require.NoError(t, provider.Resize(ctx, encryptedPath, key)) - require.NoError(t, provider.AddKey(path, key, keyExtra)) - require.NoError(t, provider.SetKey(path, keyExtra, keyExtra)) + require.NoError(t, provider.AddKey(ctx, path, key, keyExtra)) + require.NoError(t, provider.SetKey(ctx, path, keyExtra, keyExtra)) - valid, err := provider.CheckKey(path, keyExtra) + valid, err := provider.CheckKey(ctx, path, keyExtra) require.NoError(t, err) require.True(t, valid) - valid, err = provider.CheckKey(path, encryption.NewKey(1, []byte("nope"))) + valid, err = provider.CheckKey(ctx, path, encryption.NewKey(1, []byte("nope"))) require.NoError(t, err) require.False(t, valid) @@ -131,36 +135,36 @@ func testEncrypt(t *testing.T) { Type: "sealedkey", } - err = provider.SetToken(path, 0, token) + err = provider.SetToken(ctx, path, 0, token) require.NoError(t, err) - err = provider.ReadToken(path, 0, token) + err = provider.ReadToken(ctx, path, 0, token) require.NoError(t, err) require.Equal(t, token.UserData.SealedKey, "aaaa") - require.NoError(t, provider.RemoveToken(path, 0)) - require.Error(t, provider.ReadToken(path, 0, token)) + require.NoError(t, provider.RemoveToken(ctx, path, 0)) + require.Error(t, provider.ReadToken(ctx, path, 0, token)) // create and replace token - err = provider.SetToken(path, 0, token) + err = provider.SetToken(ctx, path, 0, token) require.NoError(t, err) token.UserData.SealedKey = "bbbb" - err = provider.SetToken(path, 0, token) + err = provider.SetToken(ctx, path, 0, token) require.NoError(t, err) require.NoError(t, unix.Mount(encryptedPath, mountPath, "vfat", 0, "")) require.NoError(t, unix.Unmount(mountPath, 0)) - require.NoError(t, provider.Close(encryptedPath)) - require.Error(t, provider.Close(encryptedPath)) + require.NoError(t, provider.Close(ctx, encryptedPath)) + require.Error(t, provider.Close(ctx, encryptedPath)) // second key slot - encryptedPath, err = provider.Open(path, mappedName, keyExtra) + encryptedPath, err = provider.Open(ctx, path, mappedName, keyExtra) require.NoError(t, err) - require.NoError(t, provider.Close(encryptedPath)) + require.NoError(t, provider.Close(ctx, encryptedPath)) // check keyslots list keyslots, err := provider.ReadKeyslots(path) @@ -172,23 +176,23 @@ func testEncrypt(t *testing.T) { require.True(t, ok) // remove key slot - err = provider.RemoveKey(path, 1, key) + err = provider.RemoveKey(ctx, path, 1, key) require.NoError(t, err) - _, err = provider.Open(path, mappedName, keyExtra) + _, err = provider.Open(ctx, path, mappedName, keyExtra) require.Equal(t, err, encryption.ErrEncryptionKeyRejected) - valid, err = provider.CheckKey(path, key) + valid, err = provider.CheckKey(ctx, path, key) require.NoError(t, err) require.True(t, valid) // unhappy cases - _, err = provider.Open(path, mappedName, encryption.NewKey(0, []byte("エクスプロシオン"))) + _, err = provider.Open(ctx, path, mappedName, encryption.NewKey(0, []byte("エクスプロシオン"))) require.Equal(t, err, encryption.ErrEncryptionKeyRejected) - _, err = provider.Open("/dev/nosuchdevice", mappedName, encryption.NewKey(0, []byte("エクスプロシオン"))) + _, err = provider.Open(ctx, "/dev/nosuchdevice", mappedName, encryption.NewKey(0, []byte("エクスプロシオン"))) require.Error(t, err) - _, err = provider.Open(loDev.Path(), mappedName, key) + _, err = provider.Open(ctx, loDev.Path(), mappedName, key) require.Error(t, err) } diff --git a/encryption/provider.go b/encryption/provider.go index 5977e50..3aa9d4d 100644 --- a/encryption/provider.go +++ b/encryption/provider.go @@ -5,6 +5,7 @@ package encryption import ( + "context" "fmt" "github.com/siderolabs/go-blockdevice/v2/encryption/token" @@ -20,21 +21,21 @@ const ( // Provider represents encryption utility methods. type Provider interface { TokenProvider - Encrypt(devname string, key *Key) error - Open(devname, mappedName string, key *Key) (string, error) - Close(devname string) error - AddKey(devname string, key, newKey *Key) error - SetKey(devname string, key, newKey *Key) error - CheckKey(devname string, key *Key) (bool, error) - RemoveKey(devname string, slot int, key *Key) error + Encrypt(ctx context.Context, devname string, key *Key) error + Open(ctx context.Context, devname, mappedName string, key *Key) (string, error) + Close(ctx context.Context, devname string) error + AddKey(ctx context.Context, devname string, key, newKey *Key) error + SetKey(ctx context.Context, devname string, key, newKey *Key) error + CheckKey(ctx context.Context, devname string, key *Key) (bool, error) + RemoveKey(ctx context.Context, devname string, slot int, key *Key) error ReadKeyslots(deviceName string) (*Keyslots, error) } // TokenProvider represents token management methods. type TokenProvider interface { - SetToken(devname string, slot int, token token.Token) error - ReadToken(devname string, slot int, token token.Token) error - RemoveToken(devname string, slot int) error + SetToken(ctx context.Context, devname string, slot int, token token.Token) error + ReadToken(ctx context.Context, devname string, slot int, token token.Token) error + RemoveToken(ctx context.Context, devname string, slot int) error } var (