diff --git a/consumergroup.go b/consumergroup.go index 0c9843a45..103311fe0 100644 --- a/consumergroup.go +++ b/consumergroup.go @@ -432,8 +432,15 @@ func (g *Generation) CommitOffsets(offsets map[string]map[int]int64) error { Topics: topics, } - _, err := g.coord.offsetCommit(genCtx{g}, request) + resp, err := g.coord.offsetCommit(genCtx{g}, request) if err == nil { + for _, partitions := range resp.Topics { + for _, partition := range partitions { + if partition.Error != nil { + return partition.Error + } + } + } // if logging is enabled, print out the partitions that were committed. g.log(func(l Logger) { var report []string @@ -470,7 +477,7 @@ func (g *Generation) heartbeatLoop(interval time.Duration) { case <-ctx.Done(): return case <-ticker.C: - _, err := g.coord.heartbeat(ctx, &HeartbeatRequest{ + resp, err := g.coord.heartbeat(ctx, &HeartbeatRequest{ GroupID: g.GroupID, GenerationID: g.ID, MemberID: g.MemberID, @@ -478,6 +485,9 @@ func (g *Generation) heartbeatLoop(interval time.Duration) { if err != nil { return } + if resp.Error != nil { + return + } } } }) @@ -1091,6 +1101,9 @@ func (cg *ConsumerGroup) fetchOffsets(subs map[string][]int) (map[string]map[int for topic, offsets := range offsets.Topics { offsetsByPartition := map[int]int64{} for _, pr := range offsets { + if pr.Error != nil { + return nil, pr.Error + } if pr.CommittedOffset < 0 { pr.CommittedOffset = cg.config.StartOffset } @@ -1137,7 +1150,7 @@ func (cg *ConsumerGroup) leaveGroup(ctx context.Context, memberID string) error log.Printf("Leaving group %s, member %s", cg.config.ID, memberID) }) - _, err := cg.coord.leaveGroup(ctx, &LeaveGroupRequest{ + resp, err := cg.coord.leaveGroup(ctx, &LeaveGroupRequest{ GroupID: cg.config.ID, Members: []LeaveGroupRequestMember{ { @@ -1145,6 +1158,9 @@ func (cg *ConsumerGroup) leaveGroup(ctx context.Context, memberID string) error }, }, }) + if err == nil && resp.Error != nil { + err = resp.Error + } if err != nil { cg.withErrorLogger(func(log Logger) { log.Printf("leave group failed for group, %v, and member, %v: %v", cg.config.ID, memberID, err) diff --git a/consumergroup_test.go b/consumergroup_test.go index 3bc72b68e..8c5216ef2 100644 --- a/consumergroup_test.go +++ b/consumergroup_test.go @@ -606,3 +606,96 @@ func TestGenerationStartsFunctionAfterClosed(t *testing.T) { } } } + +func TestGenerationEndsOnHeartbeatError(t *testing.T) { + gen := Generation{ + coord: &mockCoordinator{ + heartbeatFunc: func(context.Context, *HeartbeatRequest) (*HeartbeatResponse, error) { + return nil, errors.New("some error") + }, + }, + done: make(chan struct{}), + joined: make(chan struct{}), + log: func(func(Logger)) {}, + logError: func(func(Logger)) {}, + } + + ch := make(chan error) + gen.Start(func(ctx context.Context) { + <-ctx.Done() + ch <- ctx.Err() + }) + + gen.heartbeatLoop(time.Millisecond) + + select { + case <-time.After(time.Second): + t.Fatal("timed out waiting for func to run") + case err := <-ch: + if !errors.Is(err, ErrGenerationEnded) { + t.Fatalf("expected %v but got %v", ErrGenerationEnded, err) + } + } +} + +func TestGenerationEndsOnHeartbeatRebalaceInProgress(t *testing.T) { + gen := Generation{ + coord: &mockCoordinator{ + heartbeatFunc: func(context.Context, *HeartbeatRequest) (*HeartbeatResponse, error) { + return &HeartbeatResponse{ + Error: makeError(int16(RebalanceInProgress), ""), + }, nil + }, + }, + done: make(chan struct{}), + joined: make(chan struct{}), + log: func(func(Logger)) {}, + logError: func(func(Logger)) {}, + } + + ch := make(chan error) + gen.Start(func(ctx context.Context) { + <-ctx.Done() + ch <- ctx.Err() + }) + + gen.heartbeatLoop(time.Millisecond) + + select { + case <-time.After(time.Second): + t.Fatal("timed out waiting for func to run") + case err := <-ch: + if !errors.Is(err, ErrGenerationEnded) { + t.Fatalf("expected %v but got %v", ErrGenerationEnded, err) + } + } +} + +func TestGenerationOffsetCommitErrorsAreReturned(t *testing.T) { + mc := mockCoordinator{ + offsetCommitFunc: func(context.Context, *OffsetCommitRequest) (*OffsetCommitResponse, error) { + return &OffsetCommitResponse{ + Topics: map[string][]OffsetCommitPartition{ + "topic": { + { + Error: ErrGenerationEnded, + }, + }, + }, + }, nil + }, + } + gen := Generation{ + coord: mc, + log: func(func(Logger)) {}, + } + + err := gen.CommitOffsets(map[string]map[int]int64{ + "topic": { + 0: 100, + }, + }) + if err == nil { + t.Fatal("got nil from CommitOffsets when expecting an error") + } +} diff --git a/joingroup.go b/joingroup.go index 13adc71d2..8cb8fb500 100644 --- a/joingroup.go +++ b/joingroup.go @@ -3,7 +3,9 @@ package kafka import ( "bufio" "context" + "errors" "fmt" + "io" "net" "time" @@ -163,7 +165,9 @@ func (c *Client) JoinGroup(ctx context.Context, req *JoinGroupRequest) (*JoinGro for _, member := range r.Members { var meta consumer.Subscription - err = protocol.Unmarshal(member.Metadata, consumer.MaxVersionSupported, &meta) + metaVersion := makeInt16(member.Metadata[0:2]) + err = protocol.Unmarshal(member.Metadata, metaVersion, &meta) + err = joinGroupSubscriptionMetaError(err, metaVersion) if err != nil { return nil, fmt.Errorf("kafka.(*Client).JoinGroup: %w", err) } @@ -188,6 +192,16 @@ func (c *Client) JoinGroup(ctx context.Context, req *JoinGroupRequest) (*JoinGro return res, nil } +// sarama indicates there are some misbehaving clients out there that +// set the version as 1 but don't include the OwnedPartitions section +// https://github.com/Shopify/sarama/blob/610514edec1825240d59b62e4d7f1aba4b1fa000/consumer_group_members.go#L43 +func joinGroupSubscriptionMetaError(err error, version int16) error { + if version >= 1 && errors.Is(err, io.ErrUnexpectedEOF) { + return nil + } + return err +} + type groupMetadata struct { Version int16 Topics []string diff --git a/joingroup_test.go b/joingroup_test.go index 926f5b4a6..a8695e196 100644 --- a/joingroup_test.go +++ b/joingroup_test.go @@ -5,10 +5,14 @@ import ( "bytes" "context" "errors" + "net" "reflect" "testing" "time" + "github.com/segmentio/kafka-go/protocol" + "github.com/segmentio/kafka-go/protocol/consumer" + "github.com/segmentio/kafka-go/protocol/joingroup" ktesting "github.com/segmentio/kafka-go/testing" ) @@ -124,6 +128,84 @@ func TestClientJoinGroup(t *testing.T) { } } +type roundTripFn func(context.Context, net.Addr, Request) (Response, error) + +func (f roundTripFn) RoundTrip(ctx context.Context, addr net.Addr, req Request) (Response, error) { + return f(ctx, addr, req) +} + +// https://github.com/Shopify/sarama/blob/610514edec1825240d59b62e4d7f1aba4b1fa000/consumer_group_members.go#L43 +func TestClientJoinGroupSaramaCompatibility(t *testing.T) { + subscription := consumer.Subscription{ + Version: 1, + Topics: []string{"topic"}, + } + + // Marhsal as Verzon 0 (Without OwnedPartitions) but + // with Version=1. + metadata, err := protocol.Marshal(0, subscription) + if err != nil { + t.Fatalf("failed to marshal subscription %v", err) + } + + client := &Client{ + Addr: TCP("fake:9092"), + Transport: roundTripFn(func(context.Context, net.Addr, Request) (Response, error) { + resp := joingroup.Response{ + ProtocolType: "consumer", + ProtocolName: RoundRobinGroupBalancer{}.ProtocolName(), + LeaderID: "member", + MemberID: "member", + Members: []joingroup.ResponseMember{ + { + MemberID: "member", + Metadata: metadata, + }, + }, + } + return &resp, nil + }), + } + + expResp := JoinGroupResponse{ + ProtocolName: RoundRobinGroupBalancer{}.ProtocolName(), + ProtocolType: "consumer", + LeaderID: "member", + MemberID: "member", + Members: []JoinGroupResponseMember{ + { + ID: "member", + Metadata: GroupProtocolSubscription{ + Topics: []string{"topic"}, + OwnedPartitions: map[string][]int{}, + }, + }, + }, + } + + gotResp, err := client.JoinGroup(context.Background(), &JoinGroupRequest{ + GroupID: "group", + MemberID: "member", + ProtocolType: "consumer", + Protocols: []GroupProtocol{ + { + Name: RoundRobinGroupBalancer{}.ProtocolName(), + Metadata: GroupProtocolSubscription{ + Topics: []string{"topic"}, + UserData: metadata, + }, + }, + }, + }) + if err != nil { + t.Fatalf("error calling JoinGroup: %v", err) + } + + if !reflect.DeepEqual(expResp, *gotResp) { + t.Fatalf("unexpected JoinGroup resp\nexpected: %#v\n got: %#v", expResp, *gotResp) + } +} + func TestSaramaCompatibility(t *testing.T) { var ( // sample data from github.com/Shopify/sarama diff --git a/protocol/heartbeat/heartbeat.go b/protocol/heartbeat/heartbeat.go index cf4c11185..962d6f467 100644 --- a/protocol/heartbeat/heartbeat.go +++ b/protocol/heartbeat/heartbeat.go @@ -27,8 +27,8 @@ type Response struct { // type. _ struct{} `kafka:"min=v4,max=v4,tag"` - ErrorCode int16 `kafka:"min=v0,max=v4"` ThrottleTimeMs int32 `kafka:"min=v1,max=v4"` + ErrorCode int16 `kafka:"min=v0,max=v4"` } func (r *Response) ApiKey() protocol.ApiKey { diff --git a/reader_test.go b/reader_test.go index 7aa4ca9e1..edf4bc6c3 100644 --- a/reader_test.go +++ b/reader_test.go @@ -1684,7 +1684,7 @@ func TestConsumerGroupMultipleWithDefaultTransport(t *testing.T) { recvErr2 <- err }() - time.Sleep(conf1.MaxWait) + time.Sleep(conf1.MaxWait * 5) totalMessages := 10 diff --git a/syncgroup.go b/syncgroup.go index e649f0db9..9f4766376 100644 --- a/syncgroup.go +++ b/syncgroup.go @@ -127,9 +127,13 @@ func (c *Client) SyncGroup(ctx context.Context, req *SyncGroupRequest) (*SyncGro r := m.(*syncgroup.Response) var assignment consumer.Assignment - err = protocol.Unmarshal(r.Assignments, consumer.MaxVersionSupported, &assignment) - if err != nil { - return nil, fmt.Errorf("kafka.(*Client).SyncGroup: %w", err) + var metaVersion int16 + if len(r.Assignments) > 2 { + metaVersion = makeInt16(r.Assignments[0:2]) + err = protocol.Unmarshal(r.Assignments, metaVersion, &assignment) + if err != nil { + return nil, fmt.Errorf("kafka.(*Client).SyncGroup: %w", err) + } } res := &SyncGroupResponse{ diff --git a/syncgroup_test.go b/syncgroup_test.go index 930696bde..435a3875f 100644 --- a/syncgroup_test.go +++ b/syncgroup_test.go @@ -6,11 +6,77 @@ import ( "context" "errors" "io" + "net" "reflect" "testing" "time" + + "github.com/segmentio/kafka-go/protocol" + "github.com/segmentio/kafka-go/protocol/consumer" + "github.com/segmentio/kafka-go/protocol/syncgroup" ) +func TestClientSyncGroupAssignmentV0(t *testing.T) { + client := &Client{ + Addr: TCP("fake:9092"), + Transport: roundTripFn(func(context.Context, net.Addr, Request) (Response, error) { + assigments := consumer.Assignment{ + Version: 0, + AssignedPartitions: []consumer.TopicPartition{ + { + Topic: "topic", + Partitions: []int32{0, 1, 2}, + }, + }, + } + assignmentBytes, err := protocol.Marshal(0, assigments) + if err != nil { + t.Fatalf("failed to marshal assigments: %v", err) + } + resp := syncgroup.Response{ + ProtocolType: "consumer", + ProtocolName: RoundRobinGroupBalancer{}.ProtocolName(), + Assignments: assignmentBytes, + } + return &resp, nil + }), + } + + expResp := SyncGroupResponse{ + ProtocolType: "consumer", + ProtocolName: RoundRobinGroupBalancer{}.ProtocolName(), + Assignment: GroupProtocolAssignment{ + AssignedPartitions: map[string][]int{ + "topic": {0, 1, 2}, + }, + }, + } + + gotResp, err := client.SyncGroup(context.Background(), &SyncGroupRequest{ + GroupID: "group", + MemberID: "member", + ProtocolType: "consumer", + ProtocolName: "roundrobin", + Assignments: []SyncGroupRequestAssignment{ + { + MemberID: "member", + Assignment: GroupProtocolAssignment{ + AssignedPartitions: map[string][]int{ + "topic": {0, 1, 2}, + }, + }, + }, + }, + }) + if err != nil { + t.Fatalf("error calling SyncGroup: %v", err) + } + + if !reflect.DeepEqual(expResp, *gotResp) { + t.Fatalf("unexpected SyncGroup resp\nexpected: %#v\n got: %#v", expResp, *gotResp) + } +} + func TestClientSyncGroup(t *testing.T) { // In order to get to a sync group call we need to first // join a group.