diff --git a/cmd/config.go b/cmd/config.go index 1a7e4387df3..5fc081806e9 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -952,10 +952,8 @@ func (c *Config) newRootCmd() (*cobra.Command, error) { func (c *Config) persistentPostRunRootE(cmd *cobra.Command, args []string) error { defer pprof.StopCPUProfile() - if c.persistentState != nil { - if err := c.persistentState.Close(); err != nil { - return err - } + if err := c.persistentState.Close(); err != nil { + return err } if boolAnnotation(cmd, modifiesConfigFile) { @@ -1097,7 +1095,7 @@ func (c *Config) persistentPreRunRootE(cmd *cobra.Command, args []string) error return err } default: - c.persistentState = nil + c.persistentState = chezmoi.NullPersistentState{} } if c.debug && c.persistentState != nil { c.persistentState = chezmoi.NewDebugPersistentState(c.persistentState) @@ -1256,6 +1254,9 @@ func (c *Config) run(dir chezmoi.AbsPath, name string, args []string) error { } func (c *Config) runEditor(args []string) error { + if err := c.persistentState.Close(); err != nil { + return err + } editor, editorArgs := c.editor() return c.run("", editor, append(editorArgs, args...)) } diff --git a/cmd/mergecmd.go b/cmd/mergecmd.go index 089404c83cf..2499214cc7c 100644 --- a/cmd/mergecmd.go +++ b/cmd/mergecmd.go @@ -79,6 +79,9 @@ func (c *Config) runMergeCmd(cmd *cobra.Command, args []string, sourceState *che string(c.sourceDirAbsPath.Join(sourceStateEntry.SourceRelPath().RelPath())), string(targetStatePath), ) + if err := c.persistentState.Close(); err != nil { + return err + } if err := c.run(c.destDirAbsPath, c.Merge.Command, args); err != nil { return fmt.Errorf("%s: %w", targetRelPath, err) } diff --git a/internal/chezmoi/boltpersistentstate.go b/internal/chezmoi/boltpersistentstate.go index 70bf2c5f190..6cdd514362c 100644 --- a/internal/chezmoi/boltpersistentstate.go +++ b/internal/chezmoi/boltpersistentstate.go @@ -18,19 +18,29 @@ const ( // A BoltPersistentState is a state persisted with bolt. type BoltPersistentState struct { - db *bbolt.DB + system System + empty bool + path AbsPath + options bbolt.Options + db *bbolt.DB } // NewBoltPersistentState returns a new BoltPersistentState. func NewBoltPersistentState(system System, path AbsPath, mode BoltPersistentStateMode) (*BoltPersistentState, error) { - if _, err := system.Stat(path); os.IsNotExist(err) { - if mode == BoltPersistentStateReadOnly { - return &BoltPersistentState{}, nil - } - if err := MkdirAll(system, path.Dir(), 0o777); err != nil { - return nil, err - } + empty := false + switch _, err := system.Stat(path); { + case os.IsNotExist(err): + // We need to simulate an empty persistent state because Bolt's + // read-only mode is only supported for databases that already exist. + // + // If the database does not already exist, then Bolt will open it with + // O_RDONLY and then attempt to initialize it, which leads to EBADF + // errors on Linux. See also https://github.com/etcd-io/bbolt/issues/98. + empty = true + case err != nil: + return nil, err } + options := bbolt.Options{ OpenFile: func(name string, flag int, perm os.FileMode) (*os.File, error) { rawPath, err := system.RawPath(AbsPath(name)) @@ -42,28 +52,34 @@ func NewBoltPersistentState(system System, path AbsPath, mode BoltPersistentStat ReadOnly: mode == BoltPersistentStateReadOnly, Timeout: time.Second, } - db, err := bbolt.Open(string(path), 0o600, &options) - if err != nil { - return nil, err - } + return &BoltPersistentState{ - db: db, + system: system, + empty: empty, + path: path, + options: options, }, nil } // Close closes b. func (b *BoltPersistentState) Close() error { - if b.db == nil { - return nil + if b.db != nil { + if err := b.db.Close(); err != nil { + return err + } + b.db = nil } - return b.db.Close() + return nil } // CopyTo copies b to p. func (b *BoltPersistentState) CopyTo(p PersistentState) error { - if b.db == nil { + if b.empty { return nil } + if err := b.open(); err != nil { + return err + } return b.db.View(func(tx *bbolt.Tx) error { return tx.ForEach(func(bucket []byte, b *bbolt.Bucket) error { @@ -77,6 +93,13 @@ func (b *BoltPersistentState) CopyTo(p PersistentState) error { // Delete deletes the value associate with key in bucket. If bucket or key does // not exist then Delete does nothing. func (b *BoltPersistentState) Delete(bucket, key []byte) error { + if b.empty { + return nil + } + if err := b.open(); err != nil { + return err + } + return b.db.Update(func(tx *bbolt.Tx) error { b := tx.Bucket(bucket) if b == nil { @@ -88,9 +111,12 @@ func (b *BoltPersistentState) Delete(bucket, key []byte) error { // ForEach calls fn for each key, value pair in bucket. func (b *BoltPersistentState) ForEach(bucket []byte, fn func(k, v []byte) error) error { - if b.db == nil { + if b.empty { return nil } + if err := b.open(); err != nil { + return err + } return b.db.View(func(tx *bbolt.Tx) error { b := tx.Bucket(bucket) @@ -105,9 +131,12 @@ func (b *BoltPersistentState) ForEach(bucket []byte, fn func(k, v []byte) error) // Get returns the value associated with key in bucket. func (b *BoltPersistentState) Get(bucket, key []byte) ([]byte, error) { - if b.db == nil { + if b.empty { return nil, nil } + if err := b.open(); err != nil { + return nil, err + } var value []byte if err := b.db.View(func(tx *bbolt.Tx) error { @@ -126,6 +155,10 @@ func (b *BoltPersistentState) Get(bucket, key []byte) ([]byte, error) { // Set sets the value associated with key in bucket. bucket will be created if // it does not already exist. func (b *BoltPersistentState) Set(bucket, key, value []byte) error { + if err := b.open(); err != nil { + return err + } + return b.db.Update(func(tx *bbolt.Tx) error { b, err := tx.CreateBucketIfNotExists(bucket) if err != nil { @@ -135,6 +168,22 @@ func (b *BoltPersistentState) Set(bucket, key, value []byte) error { }) } +func (b *BoltPersistentState) open() error { + if b.db != nil { + return nil + } + if err := MkdirAll(b.system, b.path.Dir(), 0o777); err != nil { + return err + } + db, err := bbolt.Open(string(b.path), 0o600, &b.options) + if err != nil { + return err + } + b.empty = false + b.db = db + return nil +} + func copyByteSlice(value []byte) []byte { if value == nil { return nil diff --git a/internal/chezmoi/boltpersistentstate_test.go b/internal/chezmoi/boltpersistentstate_test.go index ae21ede00cf..7999e1bec3d 100644 --- a/internal/chezmoi/boltpersistentstate_test.go +++ b/internal/chezmoi/boltpersistentstate_test.go @@ -25,17 +25,34 @@ func TestBoltPersistentState(t *testing.T) { b1, err := NewBoltPersistentState(s, path, BoltPersistentStateReadWrite) require.NoError(t, err) + + // Test that getting a key from an non-existent state does not create + // the state. + actualValue, err := b1.Get(bucket, key) + require.NoError(t, err) vfst.RunTests(t, fs, "", vfst.TestPath(string(path), - vfst.TestModeIsRegular, + vfst.TestDoesNotExist, ), ) - - actualValue, err := b1.Get(bucket, key) - require.NoError(t, err) assert.Equal(t, []byte(nil), actualValue) + // Test that deleting a key from a non-existent state does not create + // the state. + require.NoError(t, b1.Delete(bucket, key)) + vfst.RunTests(t, fs, "", + vfst.TestPath(string(path), + vfst.TestDoesNotExist, + ), + ) + + // Test that setting a key creates the state. assert.NoError(t, b1.Set(bucket, key, value)) + vfst.RunTests(t, fs, "", + vfst.TestPath(string(path), + vfst.TestModeIsRegular, + ), + ) actualValue, err = b1.Get(bucket, key) require.NoError(t, err) assert.Equal(t, value, actualValue) diff --git a/internal/chezmoi/mockpersistentstate.go b/internal/chezmoi/mockpersistentstate.go index 464e6c39d82..6482455cc00 100644 --- a/internal/chezmoi/mockpersistentstate.go +++ b/internal/chezmoi/mockpersistentstate.go @@ -1,11 +1,5 @@ package chezmoi -import ( - "errors" -) - -var errClosed = errors.New("closed") - // A MockPersistentState is a mock persistent state. type MockPersistentState struct { buckets map[string]map[string][]byte @@ -20,18 +14,11 @@ func NewMockPersistentState() *MockPersistentState { // Close closes s. func (s *MockPersistentState) Close() error { - if s.buckets == nil { - return errClosed - } - s.buckets = nil return nil } // CopyTo implements PersistentState.CopyTo. func (s *MockPersistentState) CopyTo(p PersistentState) error { - if s.buckets == nil { - return errClosed - } for bucket, bucketMap := range s.buckets { for key, value := range bucketMap { if err := p.Set([]byte(bucket), []byte(key), value); err != nil { @@ -44,9 +31,6 @@ func (s *MockPersistentState) CopyTo(p PersistentState) error { // Delete implements PersistentState.Delete. func (s *MockPersistentState) Delete(bucket, key []byte) error { - if s.buckets == nil { - return errClosed - } bucketMap, ok := s.buckets[string(bucket)] if !ok { return nil @@ -57,9 +41,6 @@ func (s *MockPersistentState) Delete(bucket, key []byte) error { // ForEach implements PersistentState.ForEach. func (s *MockPersistentState) ForEach(bucket []byte, fn func(k, v []byte) error) error { - if s.buckets == nil { - return errClosed - } for k, v := range s.buckets[string(bucket)] { if err := fn([]byte(k), v); err != nil { return err @@ -70,9 +51,6 @@ func (s *MockPersistentState) ForEach(bucket []byte, fn func(k, v []byte) error) // Get implements PersistentState.Get. func (s *MockPersistentState) Get(bucket, key []byte) ([]byte, error) { - if s.buckets == nil { - return nil, errClosed - } bucketMap, ok := s.buckets[string(bucket)] if !ok { return nil, nil @@ -82,9 +60,6 @@ func (s *MockPersistentState) Get(bucket, key []byte) ([]byte, error) { // Set implements PersistentState.Set. func (s *MockPersistentState) Set(bucket, key, value []byte) error { - if s.buckets == nil { - return errClosed - } bucketMap, ok := s.buckets[string(bucket)] if !ok { bucketMap = make(map[string][]byte) diff --git a/internal/chezmoi/mockpersistentstate_test.go b/internal/chezmoi/mockpersistentstate_test.go index 7fd4fffcde5..4961f768aa4 100644 --- a/internal/chezmoi/mockpersistentstate_test.go +++ b/internal/chezmoi/mockpersistentstate_test.go @@ -8,43 +8,51 @@ import ( "github.com/stretchr/testify/require" ) -func TestPersistentState(t *testing.T) { +func TestMockPersistentState(t *testing.T) { var ( bucket = []byte("bucket") key = []byte("key") value = []byte("value") ) - s := NewMockPersistentState() + s1 := NewMockPersistentState() - require.NoError(t, s.Delete(bucket, value)) + require.NoError(t, s1.Delete(bucket, value)) - actualValue, err := s.Get(bucket, key) + actualValue, err := s1.Get(bucket, key) require.NoError(t, err) assert.Nil(t, actualValue) - require.NoError(t, s.Set(bucket, key, value)) + require.NoError(t, s1.Set(bucket, key, value)) - actualValue, err = s.Get(bucket, key) + actualValue, err = s1.Get(bucket, key) require.NoError(t, err) assert.Equal(t, value, actualValue) - require.NoError(t, s.ForEach(bucket, func(k, v []byte) error { + require.NoError(t, s1.ForEach(bucket, func(k, v []byte) error { assert.Equal(t, key, k) assert.Equal(t, value, v) return nil })) - assert.Equal(t, io.EOF, s.ForEach(bucket, func(k, v []byte) error { + assert.Equal(t, io.EOF, s1.ForEach(bucket, func(k, v []byte) error { return io.EOF })) - require.NoError(t, s.Delete(bucket, key)) - actualValue, err = s.Get(bucket, key) + s2 := NewMockPersistentState() + require.NoError(t, s1.CopyTo(s2)) + actualValue, err = s2.Get(bucket, key) + assert.NoError(t, err) + assert.Equal(t, value, actualValue) + + require.NoError(t, s1.Close()) + + actualValue, err = s1.Get(bucket, key) + assert.NoError(t, err) + assert.Equal(t, value, actualValue) + + require.NoError(t, s1.Delete(bucket, key)) + actualValue, err = s1.Get(bucket, key) require.NoError(t, err) assert.Nil(t, actualValue) - - require.NoError(t, s.Close()) - _, err = s.Get(bucket, key) - assert.Equal(t, errClosed, err) }