Skip to content

Commit

Permalink
refactor: simplify the buffer-locking code
Browse files Browse the repository at this point in the history
  • Loading branch information
xaionaro committed Mar 14, 2020
1 parent 9788140 commit ed1029f
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 289 deletions.
72 changes: 11 additions & 61 deletions buffer_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,10 @@ import (
"sync/atomic"
)

var (
nextLockID uint64
)

type lockID uint64

type buffer struct {
locker lockerRWMutex

pool *bufferPool
locker sync.RWMutex
isMonopolized uint32
isBusy bool
lockID uint64
Expand All @@ -26,67 +21,22 @@ type buffer struct {
MetadataVariableUInt uint
}

func (buf *buffer) IsMonopolized() bool {
return atomic.LoadUint32(&buf.isMonopolized) != 0
}

func (buf *buffer) SetMonopolized(prevLockID lockID, isMonopolized bool) error {
var addErr error
lockErr := buf.lockDo(prevLockID, func(lockID) {
if isMonopolized {
atomic.StoreUint32(&buf.isMonopolized, 1)
} else {
if atomic.SwapUint32(&buf.isMonopolized, 0) == 0 {
addErr = newErrNotMonopolized()
}
}
}, !isMonopolized)
if lockErr != nil {
return lockErr
func (buf *buffer) LockDo(fn func()) {
if !buf.isBusy {
panic(`should not happen`)
}
return addErr
buf.locker.LockDo(fn)
}

func (buf *buffer) LockDo(prevLockID lockID, fn func(lockID)) error {
return buf.lockDo(prevLockID, fn, false)
}

func (buf *buffer) lockDo(prevLockID lockID, fn func(lockID), ignoreIsMonopolized bool) error {
var lockIDValue lockID
if prevLockID != 0 && lockID(atomic.LoadUint64(&buf.lockID)) == prevLockID {
lockIDValue = prevLockID
} else {
buf.locker.Lock()
defer func() {
atomic.StoreUint64(&buf.lockID, 0)
buf.locker.Unlock()
}()
lockIDValue = lockID(atomic.AddUint64(&nextLockID, 1))
atomic.StoreUint64(&buf.lockID, uint64(lockIDValue))
}

func (buf *buffer) Lock() {
if !buf.isBusy {
panic(`should not happened`)
}
if !ignoreIsMonopolized && buf.IsMonopolized() {
return newErrMonopolized()
}

if len(buf.Bytes) > maxPossiblePacketSize {
panic(fmt.Sprintf(`should not happened: %v > %v`, len(buf.Bytes), maxPossiblePacketSize))
}
fn(lockID(lockIDValue))
if len(buf.Bytes) > maxPossiblePacketSize {
panic(fmt.Sprintf(`should not happened: %v > %v`, len(buf.Bytes), maxPossiblePacketSize))
panic(`should not happen`)
}
return nil
buf.locker.Lock()
}

func (buf *buffer) RLockDo(fn func()) error {
buf.locker.RLock()
defer buf.locker.RUnlock()
fn()
return nil
func (buf *buffer) Unlock() {
buf.locker.Unlock()
}

func (buf *buffer) Read(b []byte) (int, error) {
Expand Down
49 changes: 0 additions & 49 deletions buffer_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,6 @@ import (

var testBufferPool = newBufferPool(64)

func TestBuffer_RLockDo(t *testing.T) {
buf := testBufferPool.AcquireBuffer()
defer buf.Release()

assert.NoError(t, buf.RLockDo(func() {
buf.locker.RUnlock()
buf.locker.RLock()
}))
}

func TestBuffer_Read(t *testing.T) {
buf := testBufferPool.AcquireBuffer()
defer buf.Release()
Expand All @@ -38,49 +28,10 @@ func TestBuffer_Read(t *testing.T) {
assert.Equal(t, byte(2), b[0])
}

func TestBuffer_SetMonopolized(t *testing.T) {
buf := testBufferPool.AcquireBuffer()
defer buf.Release()

assert.NoError(t, buf.SetMonopolized(0, true))
assert.Error(t, buf.SetMonopolized(0, true))
assert.NoError(t, buf.SetMonopolized(0, false))
assert.Error(t, buf.SetMonopolized(0, false))
}

func TestBuffer_negative(t *testing.T) {
pool := newBufferPool(64)
buf := pool.AcquireBuffer()

func() {
defer func() {
assert.NotNil(t, recover())
}()
buf.isBusy = false
_ = buf.lockDo(0, func(id lockID) {}, false)
}()
buf.isBusy = true

func() {
defer func() {
assert.NotNil(t, recover())
}()

buf.Bytes = make([]byte, maxPossiblePacketSize+1)
_ = buf.lockDo(0, func(id lockID) {}, false)
}()

func() {
defer func() {
assert.NotNil(t, recover())
}()

buf.Bytes = nil
_ = buf.lockDo(0, func(id lockID) {
buf.Bytes = make([]byte, maxPossiblePacketSize+1)
}, false)
}()

buf.Release()
func() {
defer func() {
Expand Down
11 changes: 11 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,14 @@ func newErrInvalidPublicKey() error {
func (err errInvalidPublicKey) Error() string {
return fmt.Sprintf("[kx] invalid public key")
}

type errUnableToLock struct{}

func newErrUnableToLock() error {
err := errors.New(errUnableToLock{})
err.Traceback.CutOffFirstNLines += 2
return err
}
func (err errUnableToLock) Error() string {
return fmt.Sprintf("unable to get a lock")
}
5 changes: 3 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ require (
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da
github.com/aead/ecdh v0.2.0
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
github.com/stretchr/testify v1.4.0
github.com/stretchr/testify v1.5.1
github.com/xaionaro-go/bytesextra v0.0.0-20200223153815-1ed74f0bfcd8
github.com/xaionaro-go/errors v0.0.0-20200223133802-5f1bdcd2dd3e
github.com/xaionaro-go/slice v0.0.0-20200126131228-455c082ffedb
github.com/xaionaro-go/spinlock v0.0.0-20190309154744-55278e21e817
github.com/xaionaro-go/unsafetools v0.0.0-20200202162159-021b112c4d30
golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073
golang.org/x/sys v0.0.0-20200301204400-5d559ad92b82 // indirect
lukechampine.com/blake3 v1.0.0
)
8 changes: 6 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,22 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/xaionaro-go/bytesextra v0.0.0-20200223153815-1ed74f0bfcd8 h1:WemQcsheKiKH0sLCbwv2h15UO/DByoPQX+VLx6jLSaE=
github.com/xaionaro-go/bytesextra v0.0.0-20200223153815-1ed74f0bfcd8/go.mod h1:op5hoGu7YbHB+PlxrR0jAhl5OaCpYCEqtdCfusfZCYk=
github.com/xaionaro-go/errors v0.0.0-20200223133802-5f1bdcd2dd3e h1:xPlvDQvKD3ZebBBhBn1p+mT7Kjeo1dGsXVEwLBdDmhc=
github.com/xaionaro-go/errors v0.0.0-20200223133802-5f1bdcd2dd3e/go.mod h1:kbL3cDyZdqjEhQR1+LVHz0/ZYbFpH62e6t3zCDKmbLA=
github.com/xaionaro-go/slice v0.0.0-20200126131228-455c082ffedb h1:qEmeKHWu+mhuk7+Na2/apZpMQb9U/wjtBDb3NlOhkHc=
github.com/xaionaro-go/slice v0.0.0-20200126131228-455c082ffedb/go.mod h1:dd/cZOg36Ci+eE6IS0quZIBv5+4WLuQj/kMmU3s/ti4=
github.com/xaionaro-go/spinlock v0.0.0-20190309154744-55278e21e817 h1:0ikx4JlTx9uNiHGGC4o0k93GhcWOtONYdhk2H8RUnZU=
github.com/xaionaro-go/spinlock v0.0.0-20190309154744-55278e21e817/go.mod h1:Nb/15eS0BMty6TMuWgRQM8WCDIUlyPZagcpchHT6c9Y=
github.com/xaionaro-go/unsafetools v0.0.0-20200202162159-021b112c4d30 h1:6HCWbXp+IoQx6XRlVbcDlM0BSWigu8H2FM6VdXeOk5s=
github.com/xaionaro-go/unsafetools v0.0.0-20200202162159-021b112c4d30/go.mod h1:spWmgrD4QEkFsootCLGU3OLWgeeJew0KXrCsVPWBQ88=
golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d h1:1ZiEyfaQIg3Qh0EoqpwAakHVhecoE5wlSg5GjnafJGw=
golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073 h1:xMPOj6Pz6UipU1wXLkrtqpHbR0AVFnyPEQq/wRWz9lM=
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
Expand Down
14 changes: 12 additions & 2 deletions identity_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import (

func TestMissedKeySeedMessage(t *testing.T) {
conn0, conn1 := testUDPPair(t)
identity0, identity1, _, _ := testPair(t)
identity0, identity1, _c0, _c1 := testPair(t)
_c0.Close()
_c1.Close()

opts := &SessionOptions{}
opts.OnInitFuncs = []OnInitFunc{func(sess *Session) { printLogsOfSession(t, false, sess) }}
Expand All @@ -24,7 +26,15 @@ func TestMissedKeySeedMessage(t *testing.T) {
opts.PacketIDStorageSize = -1 // it's UDP :(
opts.KeyExchangerOptions.RetryInterval = time.Millisecond // speed-up the unit test

ctx := context.Background()
ctx, cancelFn := context.WithCancel(context.Background())
defer cancelFn()

go func() {
select {
case <-ctx.Done():
case <-time.After(time.Second):
}
}()

var wg sync.WaitGroup

Expand Down
8 changes: 8 additions & 0 deletions identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ func TestNewIdentity(t *testing.T) {

func testIdentityMutualConfirmationOfIdentityWithPSKs(t *testing.T, shouldFail bool, psk0, psk1 []byte) {
identity0, identity1, conn0, conn1 := testPair(t)

defer conn0.Close()
defer conn1.Close()

Expand All @@ -82,6 +83,13 @@ func testIdentityMutualConfirmationOfIdentityWithPSKs(t *testing.T, shouldFail b
ctx, cancelFunc := context.WithDeadline(context.Background(), time.Now().Add(time.Second*5))
defer cancelFunc()

go func() {
select {
case <-ctx.Done():
case <-time.After(time.Second):
}
}()

var wg sync.WaitGroup

var err0 error
Expand Down
2 changes: 1 addition & 1 deletion key_exchanger.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ type keyExchanger struct {
localKeyCreatedAt uint64
nextLocalKeyCreatedAt uint64
successNotifyChan chan uint64
keyUpdateLocker lockerRWMutex
keyUpdateLocker lockerMutex
skipKeyUpdateUntil time.Time

cryptoRandReader io.Reader
Expand Down

0 comments on commit ed1029f

Please sign in to comment.