Skip to content

Commit

Permalink
lazy init stmtBuf
Browse files Browse the repository at this point in the history
  • Loading branch information
bobotu committed Apr 2, 2020
1 parent b711194 commit 466c87c
Showing 1 changed file with 59 additions and 21 deletions.
80 changes: 59 additions & 21 deletions session/txn.go
Expand Up @@ -62,12 +62,36 @@ func (st *TxnState) init() {
st.mutations = make(map[int64]*binlog.TableMutation)
}

func (st *TxnState) initStmtBuf() {
if st.stmtBuf == nil {
st.stmtBuf = st.Transaction.NewStagingBuffer()
}
}

func (st *TxnState) stmtBufLen() int {
if st.stmtBuf == nil {
return 0
}
return st.stmtBuf.Len()
}

func (st *TxnState) stmtBufSize() int {
if st.stmtBuf == nil {
return 0
}
return st.stmtBuf.Size()
}

func (st *TxnState) stmtBufGet(ctx context.Context, k kv.Key) ([]byte, error) {
if st.stmtBuf == nil {
return nil, kv.ErrNotExist
}
return st.stmtBuf.Get(ctx, k)
}

// Size implements the MemBuffer interface.
func (st *TxnState) Size() int {
size := 0
if st.stmtBuf != nil {
size += st.stmtBuf.Size()
}
size := st.stmtBufSize()
if st.Transaction != nil {
size += st.Transaction.Size()
}
Expand All @@ -76,16 +100,23 @@ func (st *TxnState) Size() int {

// NewBuffer returns a new child write buffer.
func (st *TxnState) NewStagingBuffer() kv.MemBuffer {
st.initStmtBuf()
return st.stmtBuf.NewStagingBuffer()
}

// Flush flushes all staging kvs into parent buffer.
func (st *TxnState) Flush() (int, error) {
if st.stmtBuf == nil {
return 0, nil
}
return st.stmtBuf.Flush()
}

// Discard discards all staging kvs.
func (st *TxnState) Discard() {
if st.stmtBuf == nil {
return
}
st.stmtBuf.Discard()
}

Expand Down Expand Up @@ -127,8 +158,8 @@ func (st *TxnState) GoString() string {
if len(st.mutations) > 0 {
fmt.Fprintf(&s, ", len(mutations)=%d, %#v", len(st.mutations), st.mutations)
}
if st.stmtBuf != nil && st.stmtBuf.Len() != 0 {
fmt.Fprintf(&s, ", buf.length: %d, buf.size: %d", st.stmtBuf.Len(), st.stmtBuf.Size())
if st.stmtBufLen() != 0 {
fmt.Fprintf(&s, ", buf.length: %d, buf.size: %d", st.stmtBufLen(), st.stmtBufSize())
}
} else {
s.WriteString("state=invalid")
Expand All @@ -140,7 +171,6 @@ func (st *TxnState) GoString() string {

func (st *TxnState) changeInvalidToValid(txn kv.Transaction) {
st.Transaction = txn
st.stmtBuf = txn.NewStagingBuffer()
st.txnFuture = nil
}

Expand All @@ -163,7 +193,6 @@ func (st *TxnState) changePendingToValid() error {
return err
}
st.Transaction = txn
st.stmtBuf = txn.NewStagingBuffer()
return nil
}

Expand Down Expand Up @@ -210,7 +239,7 @@ func ResetMockAutoRandIDRetryCount(failTimes int64) {
// Commit overrides the Transaction interface.
func (st *TxnState) Commit(ctx context.Context) error {
defer st.reset()
if len(st.mutations) != 0 || len(st.dirtyTableOP) != 0 || st.stmtBuf.Len() != 0 {
if len(st.mutations) != 0 || len(st.dirtyTableOP) != 0 || st.stmtBufLen() != 0 {
logutil.BgLogger().Error("the code should never run here",
zap.String("TxnState", st.GoString()),
zap.Stack("something must be wrong"))
Expand Down Expand Up @@ -262,7 +291,7 @@ func (st *TxnState) reset() {

// Get overrides the Transaction interface.
func (st *TxnState) Get(ctx context.Context, k kv.Key) ([]byte, error) {
val, err := st.stmtBuf.Get(ctx, k)
val, err := st.stmtBufGet(ctx, k)
if kv.IsErrNotFound(err) {
val, err = st.Transaction.Get(ctx, k)
if kv.IsErrNotFound(err) {
Expand All @@ -283,7 +312,7 @@ func (st *TxnState) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]b
bufferValues := make([][]byte, len(keys))
shrinkKeys := make([]kv.Key, 0, len(keys))
for i, key := range keys {
val, err := st.stmtBuf.Get(ctx, key)
val, err := st.stmtBufGet(ctx, key)
if kv.IsErrNotFound(err) {
shrinkKeys = append(shrinkKeys, key)
continue
Expand All @@ -310,21 +339,26 @@ func (st *TxnState) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]b

// Set overrides the Transaction interface.
func (st *TxnState) Set(k kv.Key, v []byte) error {
st.initStmtBuf()
return st.stmtBuf.Set(k, v)
}

// Delete overrides the Transaction interface.
func (st *TxnState) Delete(k kv.Key) error {
st.initStmtBuf()
return st.stmtBuf.Delete(k)
}

// Iter overrides the Transaction interface.
func (st *TxnState) Iter(k kv.Key, upperBound kv.Key) (kv.Iterator, error) {
bufferIt, err := st.stmtBuf.Iter(k, upperBound)
retrieverIt, err := st.Transaction.Iter(k, upperBound)
if err != nil {
return nil, err
}
retrieverIt, err := st.Transaction.Iter(k, upperBound)
if st.stmtBuf == nil {
return retrieverIt, nil
}
bufferIt, err := st.stmtBuf.Iter(k, upperBound)
if err != nil {
return nil, err
}
Expand All @@ -333,11 +367,14 @@ func (st *TxnState) Iter(k kv.Key, upperBound kv.Key) (kv.Iterator, error) {

// IterReverse overrides the Transaction interface.
func (st *TxnState) IterReverse(k kv.Key) (kv.Iterator, error) {
bufferIt, err := st.stmtBuf.IterReverse(k)
retrieverIt, err := st.Transaction.IterReverse(k)
if err != nil {
return nil, err
}
retrieverIt, err := st.Transaction.IterReverse(k)
if st.stmtBuf == nil {
return retrieverIt, nil
}
bufferIt, err := st.stmtBuf.IterReverse(k)
if err != nil {
return nil, err
}
Expand All @@ -346,10 +383,8 @@ func (st *TxnState) IterReverse(k kv.Key) (kv.Iterator, error) {

func (st *TxnState) cleanup() {
if st.stmtBuf != nil {
st.Discard()
}
if st.Transaction != nil {
st.stmtBuf = st.Transaction.NewStagingBuffer()
st.stmtBuf.Discard()
st.stmtBuf = nil
}
for key := range st.mutations {
delete(st.mutations, key)
Expand All @@ -370,7 +405,10 @@ func (st *TxnState) cleanup() {

// KeysNeedToLock returns the keys need to be locked.
func (st *TxnState) KeysNeedToLock() ([]kv.Key, error) {
keys := make([]kv.Key, 0, st.stmtBuf.Len())
if st.stmtBufLen() == 0 {
return nil, nil
}
keys := make([]kv.Key, 0, st.stmtBufLen())
if err := kv.WalkMemBuffer(st.stmtBuf, func(k kv.Key, v []byte) error {
if !keyNeedToLock(k, v) {
return nil
Expand Down Expand Up @@ -512,7 +550,7 @@ func (s *session) StmtCommit(memTracker *memory.Tracker) error {

var err error
failpoint.Inject("mockStmtCommitError", func(val failpoint.Value) {
if val.(bool) && st.stmtBuf.Len() > 3 {
if val.(bool) && st.stmtBufLen() > 3 {
err = errors.New("mock stmt commit error")
}
})
Expand Down

0 comments on commit 466c87c

Please sign in to comment.