Skip to content

Commit

Permalink
Unlock state when running editor or merge command
Browse files Browse the repository at this point in the history
  • Loading branch information
twpayne committed Mar 26, 2021
1 parent c53f938 commit 30f1ab9
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 67 deletions.
11 changes: 6 additions & 5 deletions cmd/config.go
Expand Up @@ -952,10 +952,8 @@ func (c *Config) newRootCmd() (*cobra.Command, error) {
func (c *Config) persistentPostRunRootE(cmd *cobra.Command, args []string) error {
defer pprof.StopCPUProfile()

if c.persistentState != nil {
if err := c.persistentState.Close(); err != nil {
return err
}
if err := c.persistentState.Close(); err != nil {
return err
}

if boolAnnotation(cmd, modifiesConfigFile) {
Expand Down Expand Up @@ -1097,7 +1095,7 @@ func (c *Config) persistentPreRunRootE(cmd *cobra.Command, args []string) error
return err
}
default:
c.persistentState = nil
c.persistentState = chezmoi.NullPersistentState{}
}
if c.debug && c.persistentState != nil {
c.persistentState = chezmoi.NewDebugPersistentState(c.persistentState)
Expand Down Expand Up @@ -1256,6 +1254,9 @@ func (c *Config) run(dir chezmoi.AbsPath, name string, args []string) error {
}

func (c *Config) runEditor(args []string) error {
if err := c.persistentState.Close(); err != nil {
return err
}
editor, editorArgs := c.editor()
return c.run("", editor, append(editorArgs, args...))
}
Expand Down
3 changes: 3 additions & 0 deletions cmd/mergecmd.go
Expand Up @@ -79,6 +79,9 @@ func (c *Config) runMergeCmd(cmd *cobra.Command, args []string, sourceState *che
string(c.sourceDirAbsPath.Join(sourceStateEntry.SourceRelPath().RelPath())),
string(targetStatePath),
)
if err := c.persistentState.Close(); err != nil {
return err
}
if err := c.run(c.destDirAbsPath, c.Merge.Command, args); err != nil {
return fmt.Errorf("%s: %w", targetRelPath, err)
}
Expand Down
87 changes: 68 additions & 19 deletions internal/chezmoi/boltpersistentstate.go
Expand Up @@ -18,19 +18,29 @@ const (

// A BoltPersistentState is a state persisted with bolt.
type BoltPersistentState struct {
db *bbolt.DB
system System
empty bool
path AbsPath
options bbolt.Options
db *bbolt.DB
}

// NewBoltPersistentState returns a new BoltPersistentState.
func NewBoltPersistentState(system System, path AbsPath, mode BoltPersistentStateMode) (*BoltPersistentState, error) {
if _, err := system.Stat(path); os.IsNotExist(err) {
if mode == BoltPersistentStateReadOnly {
return &BoltPersistentState{}, nil
}
if err := MkdirAll(system, path.Dir(), 0o777); err != nil {
return nil, err
}
empty := false
switch _, err := system.Stat(path); {
case os.IsNotExist(err):
// We need to simulate an empty persistent state because Bolt's
// read-only mode is only supported for databases that already exist.
//
// If the database does not already exist, then Bolt will open it with
// O_RDONLY and then attempt to initialize it, which leads to EBADF
// errors on Linux. See also https://github.com/etcd-io/bbolt/issues/98.
empty = true
case err != nil:
return nil, err
}

options := bbolt.Options{
OpenFile: func(name string, flag int, perm os.FileMode) (*os.File, error) {
rawPath, err := system.RawPath(AbsPath(name))
Expand All @@ -42,28 +52,34 @@ func NewBoltPersistentState(system System, path AbsPath, mode BoltPersistentStat
ReadOnly: mode == BoltPersistentStateReadOnly,
Timeout: time.Second,
}
db, err := bbolt.Open(string(path), 0o600, &options)
if err != nil {
return nil, err
}

return &BoltPersistentState{
db: db,
system: system,
empty: empty,
path: path,
options: options,
}, nil
}

// Close closes b.
func (b *BoltPersistentState) Close() error {
if b.db == nil {
return nil
if b.db != nil {
if err := b.db.Close(); err != nil {
return err
}
b.db = nil
}
return b.db.Close()
return nil
}

// CopyTo copies b to p.
func (b *BoltPersistentState) CopyTo(p PersistentState) error {
if b.db == nil {
if b.empty {
return nil
}
if err := b.open(); err != nil {
return err
}

return b.db.View(func(tx *bbolt.Tx) error {
return tx.ForEach(func(bucket []byte, b *bbolt.Bucket) error {
Expand All @@ -77,6 +93,13 @@ func (b *BoltPersistentState) CopyTo(p PersistentState) error {
// Delete deletes the value associate with key in bucket. If bucket or key does
// not exist then Delete does nothing.
func (b *BoltPersistentState) Delete(bucket, key []byte) error {
if b.empty {
return nil
}
if err := b.open(); err != nil {
return err
}

return b.db.Update(func(tx *bbolt.Tx) error {
b := tx.Bucket(bucket)
if b == nil {
Expand All @@ -88,9 +111,12 @@ func (b *BoltPersistentState) Delete(bucket, key []byte) error {

// ForEach calls fn for each key, value pair in bucket.
func (b *BoltPersistentState) ForEach(bucket []byte, fn func(k, v []byte) error) error {
if b.db == nil {
if b.empty {
return nil
}
if err := b.open(); err != nil {
return err
}

return b.db.View(func(tx *bbolt.Tx) error {
b := tx.Bucket(bucket)
Expand All @@ -105,9 +131,12 @@ func (b *BoltPersistentState) ForEach(bucket []byte, fn func(k, v []byte) error)

// Get returns the value associated with key in bucket.
func (b *BoltPersistentState) Get(bucket, key []byte) ([]byte, error) {
if b.db == nil {
if b.empty {
return nil, nil
}
if err := b.open(); err != nil {
return nil, err
}

var value []byte
if err := b.db.View(func(tx *bbolt.Tx) error {
Expand All @@ -126,6 +155,10 @@ func (b *BoltPersistentState) Get(bucket, key []byte) ([]byte, error) {
// Set sets the value associated with key in bucket. bucket will be created if
// it does not already exist.
func (b *BoltPersistentState) Set(bucket, key, value []byte) error {
if err := b.open(); err != nil {
return err
}

return b.db.Update(func(tx *bbolt.Tx) error {
b, err := tx.CreateBucketIfNotExists(bucket)
if err != nil {
Expand All @@ -135,6 +168,22 @@ func (b *BoltPersistentState) Set(bucket, key, value []byte) error {
})
}

func (b *BoltPersistentState) open() error {
if b.db != nil {
return nil
}
if err := MkdirAll(b.system, b.path.Dir(), 0o777); err != nil {
return err
}
db, err := bbolt.Open(string(b.path), 0o600, &b.options)
if err != nil {
return err
}
b.empty = false
b.db = db
return nil
}

func copyByteSlice(value []byte) []byte {
if value == nil {
return nil
Expand Down
25 changes: 21 additions & 4 deletions internal/chezmoi/boltpersistentstate_test.go
Expand Up @@ -25,17 +25,34 @@ func TestBoltPersistentState(t *testing.T) {

b1, err := NewBoltPersistentState(s, path, BoltPersistentStateReadWrite)
require.NoError(t, err)

// Test that getting a key from an non-existent state does not create
// the state.
actualValue, err := b1.Get(bucket, key)
require.NoError(t, err)
vfst.RunTests(t, fs, "",
vfst.TestPath(string(path),
vfst.TestModeIsRegular,
vfst.TestDoesNotExist,
),
)

actualValue, err := b1.Get(bucket, key)
require.NoError(t, err)
assert.Equal(t, []byte(nil), actualValue)

// Test that deleting a key from a non-existent state does not create
// the state.
require.NoError(t, b1.Delete(bucket, key))
vfst.RunTests(t, fs, "",
vfst.TestPath(string(path),
vfst.TestDoesNotExist,
),
)

// Test that setting a key creates the state.
assert.NoError(t, b1.Set(bucket, key, value))
vfst.RunTests(t, fs, "",
vfst.TestPath(string(path),
vfst.TestModeIsRegular,
),
)
actualValue, err = b1.Get(bucket, key)
require.NoError(t, err)
assert.Equal(t, value, actualValue)
Expand Down
25 changes: 0 additions & 25 deletions internal/chezmoi/mockpersistentstate.go
@@ -1,11 +1,5 @@
package chezmoi

import (
"errors"
)

var errClosed = errors.New("closed")

// A MockPersistentState is a mock persistent state.
type MockPersistentState struct {
buckets map[string]map[string][]byte
Expand All @@ -20,18 +14,11 @@ func NewMockPersistentState() *MockPersistentState {

// Close closes s.
func (s *MockPersistentState) Close() error {
if s.buckets == nil {
return errClosed
}
s.buckets = nil
return nil
}

// CopyTo implements PersistentState.CopyTo.
func (s *MockPersistentState) CopyTo(p PersistentState) error {
if s.buckets == nil {
return errClosed
}
for bucket, bucketMap := range s.buckets {
for key, value := range bucketMap {
if err := p.Set([]byte(bucket), []byte(key), value); err != nil {
Expand All @@ -44,9 +31,6 @@ func (s *MockPersistentState) CopyTo(p PersistentState) error {

// Delete implements PersistentState.Delete.
func (s *MockPersistentState) Delete(bucket, key []byte) error {
if s.buckets == nil {
return errClosed
}
bucketMap, ok := s.buckets[string(bucket)]
if !ok {
return nil
Expand All @@ -57,9 +41,6 @@ func (s *MockPersistentState) Delete(bucket, key []byte) error {

// ForEach implements PersistentState.ForEach.
func (s *MockPersistentState) ForEach(bucket []byte, fn func(k, v []byte) error) error {
if s.buckets == nil {
return errClosed
}
for k, v := range s.buckets[string(bucket)] {
if err := fn([]byte(k), v); err != nil {
return err
Expand All @@ -70,9 +51,6 @@ func (s *MockPersistentState) ForEach(bucket []byte, fn func(k, v []byte) error)

// Get implements PersistentState.Get.
func (s *MockPersistentState) Get(bucket, key []byte) ([]byte, error) {
if s.buckets == nil {
return nil, errClosed
}
bucketMap, ok := s.buckets[string(bucket)]
if !ok {
return nil, nil
Expand All @@ -82,9 +60,6 @@ func (s *MockPersistentState) Get(bucket, key []byte) ([]byte, error) {

// Set implements PersistentState.Set.
func (s *MockPersistentState) Set(bucket, key, value []byte) error {
if s.buckets == nil {
return errClosed
}
bucketMap, ok := s.buckets[string(bucket)]
if !ok {
bucketMap = make(map[string][]byte)
Expand Down
36 changes: 22 additions & 14 deletions internal/chezmoi/mockpersistentstate_test.go
Expand Up @@ -8,43 +8,51 @@ import (
"github.com/stretchr/testify/require"
)

func TestPersistentState(t *testing.T) {
func TestMockPersistentState(t *testing.T) {
var (
bucket = []byte("bucket")
key = []byte("key")
value = []byte("value")
)

s := NewMockPersistentState()
s1 := NewMockPersistentState()

require.NoError(t, s.Delete(bucket, value))
require.NoError(t, s1.Delete(bucket, value))

actualValue, err := s.Get(bucket, key)
actualValue, err := s1.Get(bucket, key)
require.NoError(t, err)
assert.Nil(t, actualValue)

require.NoError(t, s.Set(bucket, key, value))
require.NoError(t, s1.Set(bucket, key, value))

actualValue, err = s.Get(bucket, key)
actualValue, err = s1.Get(bucket, key)
require.NoError(t, err)
assert.Equal(t, value, actualValue)

require.NoError(t, s.ForEach(bucket, func(k, v []byte) error {
require.NoError(t, s1.ForEach(bucket, func(k, v []byte) error {
assert.Equal(t, key, k)
assert.Equal(t, value, v)
return nil
}))

assert.Equal(t, io.EOF, s.ForEach(bucket, func(k, v []byte) error {
assert.Equal(t, io.EOF, s1.ForEach(bucket, func(k, v []byte) error {
return io.EOF
}))

require.NoError(t, s.Delete(bucket, key))
actualValue, err = s.Get(bucket, key)
s2 := NewMockPersistentState()
require.NoError(t, s1.CopyTo(s2))
actualValue, err = s2.Get(bucket, key)
assert.NoError(t, err)
assert.Equal(t, value, actualValue)

require.NoError(t, s1.Close())

actualValue, err = s1.Get(bucket, key)
assert.NoError(t, err)
assert.Equal(t, value, actualValue)

require.NoError(t, s1.Delete(bucket, key))
actualValue, err = s1.Get(bucket, key)
require.NoError(t, err)
assert.Nil(t, actualValue)

require.NoError(t, s.Close())
_, err = s.Get(bucket, key)
assert.Equal(t, errClosed, err)
}

0 comments on commit 30f1ab9

Please sign in to comment.