Skip to content

Commit cc3a355

Browse files
Use atomic.Bool in Go 1.19+
Instead of using our own atomicBool type, let's use the standard atomic.Bool type in versions of Go that contain it.
1 parent c2b5ed2 commit cc3a355

File tree

6 files changed

+51
-36
lines changed

6 files changed

+51
-36
lines changed

pkg/kgo/atomic_maybe_work.go

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,6 @@ package kgo
22

33
import "sync/atomic"
44

5-
// a helper type for some places
6-
type atomicBool uint32
7-
8-
func (b *atomicBool) set(v bool) {
9-
if v {
10-
atomic.StoreUint32((*uint32)(b), 1)
11-
} else {
12-
atomic.StoreUint32((*uint32)(b), 0)
13-
}
14-
}
15-
16-
func (b *atomicBool) get() bool { return atomic.LoadUint32((*uint32)(b)) == 1 }
17-
18-
func (b *atomicBool) swap(v bool) bool {
19-
var swap uint32
20-
if v {
21-
swap = 1
22-
}
23-
return atomic.SwapUint32((*uint32)(b), swap) == 1
24-
}
25-
265
const (
276
stateUnstarted = iota
287
stateWorking

pkg/kgo/consumer_group.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ func (g *groupConsumer) manage() {
371371
g.lastAssigned = nil
372372
g.fetching = nil
373373

374-
g.leader.set(false)
374+
g.leader.Store(false)
375375
g.resetExternal()
376376
}
377377

@@ -992,13 +992,13 @@ func (g *groupConsumer) rejoin(why string) {
992992
// for group cancelation to return early.
993993
func (g *groupConsumer) joinAndSync(joinWhy string) error {
994994
g.cfg.logger.Log(LogLevelInfo, "joining group", "group", g.cfg.group)
995-
g.leader.set(false)
995+
g.leader.Store(false)
996996
g.getAndResetExternalRejoin()
997997
defer func() {
998998
// If we are not leader, we clear any tracking of external
999999
// topics from when we were previously leader, since tracking
10001000
// these is just a waste.
1001-
if !g.leader.get() {
1001+
if !g.leader.Load() {
10021002
g.resetExternal()
10031003
}
10041004
}()
@@ -1169,7 +1169,7 @@ func (g *groupConsumer) handleJoinResp(resp *kmsg.JoinGroupResponse) (restart bo
11691169
leader := resp.LeaderID == resp.MemberID
11701170
leaderNoPlan := !leader && resp.Version <= 8 && g.cfg.instanceID != nil && strings.HasPrefix(resp.LeaderID, *g.cfg.instanceID+"-")
11711171
if leader {
1172-
g.leader.set(true)
1172+
g.leader.Store(true)
11731173
g.cfg.logger.Log(LogLevelInfo, "joined, balancing group",
11741174
"group", g.cfg.group,
11751175
"member_id", g.memberID,
@@ -1180,7 +1180,7 @@ func (g *groupConsumer) handleJoinResp(resp *kmsg.JoinGroupResponse) (restart bo
11801180
)
11811181
plan, err = g.balanceGroup(protocol, resp.Members, resp.SkipAssignment)
11821182
} else if leaderNoPlan {
1183-
g.leader.set(true)
1183+
g.leader.Store(true)
11841184
g.cfg.logger.Log(LogLevelInfo, "joined as leader but unable to balance group due to KIP-345 limitations",
11851185
"group", g.cfg.group,
11861186
"member_id", g.memberID,
@@ -1259,8 +1259,8 @@ func (g *groupConsumer) getAndResetExternalRejoin() bool {
12591259
if e == nil {
12601260
return false
12611261
}
1262-
defer e.rejoin.set(false)
1263-
return e.rejoin.get()
1262+
defer e.rejoin.Store(false)
1263+
return e.rejoin.Load()
12641264
}
12651265

12661266
// Runs fn over a load, not copy, of our map.
@@ -1307,7 +1307,7 @@ func (g *groupExternal) updateLatest(meta map[string]*metadataTopic) {
13071307
}
13081308
}
13091309
if rejoin {
1310-
g.rejoin.set(true)
1310+
g.rejoin.Store(true)
13111311
}
13121312
})
13131313
}
@@ -1662,7 +1662,7 @@ func (g *groupConsumer) findNewAssignments() {
16621662
}
16631663
}
16641664

1665-
externalRejoin := g.leader.get() && g.getAndResetExternalRejoin()
1665+
externalRejoin := g.leader.Load() && g.getAndResetExternalRejoin()
16661666

16671667
if len(toChange) == 0 && !externalRejoin {
16681668
return
@@ -1687,7 +1687,7 @@ func (g *groupConsumer) findNewAssignments() {
16871687

16881688
if numNewTopics > 0 {
16891689
g.rejoin("rejoining because there are more topics to consume, our interests have changed")
1690-
} else if g.leader.get() {
1690+
} else if g.leader.Load() {
16911691
if len(toChange) > 0 {
16921692
g.rejoin("rejoining because we are the leader and noticed some topics have new partitions")
16931693
} else if externalRejoin {

pkg/kgo/go118.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//go:build !go1.19
2+
// +build !go1.19
3+
4+
package kgo
5+
6+
import "sync/atomic"
7+
8+
type atomicBool uint32
9+
10+
func (b *atomicBool) Store(v bool) {
11+
if v {
12+
atomic.StoreUint32((*uint32)(b), 1)
13+
} else {
14+
atomic.StoreUint32((*uint32)(b), 0)
15+
}
16+
}
17+
18+
func (b *atomicBool) Load() bool { return atomic.LoadUint32((*uint32)(b)) == 1 }
19+
20+
func (b *atomicBool) Swap(v bool) bool {
21+
var swap uint32
22+
if v {
23+
swap = 1
24+
}
25+
return atomic.SwapUint32((*uint32)(b), swap) == 1
26+
}

pkg/kgo/go119.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
//go:build go1.19
2+
// +build go1.19
3+
4+
package kgo
5+
6+
import "sync/atomic"
7+
8+
type atomicBool struct {
9+
atomic.Bool
10+
}

pkg/kgo/sink.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ func (t *txnReqBuilder) add(rb *recBuf) {
158158
if t.txnID == nil {
159159
return
160160
}
161-
if rb.addedToTxn.swap(true) {
161+
if rb.addedToTxn.Swap(true) {
162162
return
163163
}
164164
if t.req == nil {
@@ -443,7 +443,7 @@ func (s *sink) doTxnReq(
443443
// inflight, and that it was not added to the txn and that we need to reset the
444444
// drain index.
445445
func (b *recBatch) removeFromTxn() {
446-
b.owner.addedToTxn.set(false)
446+
b.owner.addedToTxn.Store(false)
447447
b.owner.resetBatchDrainIdx()
448448
b.decInflight()
449449
}

pkg/kgo/txn.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ func (cl *Client) EndAndBeginTransaction(
596596
var readd map[string][]int32
597597
for topic, parts := range cl.producer.topics.load() {
598598
for i, part := range parts.load().partitions {
599-
if part.records.addedToTxn.swap(false) {
599+
if part.records.addedToTxn.Swap(false) {
600600
if how == EndBeginTxnUnsafe {
601601
if readd == nil {
602602
readd = make(map[string][]int32)
@@ -705,7 +705,7 @@ func (cl *Client) EndAndBeginTransaction(
705705

706706
for topic, parts := range cl.producer.topics.load() {
707707
for i, part := range parts.load().partitions {
708-
if part.records.addedToTxn.get() {
708+
if part.records.addedToTxn.Load() {
709709
readd[topic] = append(readd[topic], int32(i))
710710
}
711711
}
@@ -836,7 +836,7 @@ func (cl *Client) EndTransaction(ctx context.Context, commit TransactionEndTry)
836836
// addedToTxn to false outside of any mutex.
837837
for _, parts := range cl.producer.topics.load() {
838838
for _, part := range parts.load().partitions {
839-
anyAdded = part.records.addedToTxn.swap(false) || anyAdded
839+
anyAdded = part.records.addedToTxn.Swap(false) || anyAdded
840840
}
841841
}
842842

0 commit comments

Comments
 (0)