diff --git a/channel.go b/channel.go index 8ba9bab..ae6f2d1 100644 --- a/channel.go +++ b/channel.go @@ -1435,6 +1435,11 @@ func (ch *Channel) PublishWithDeferredConfirmWithContext(ctx context.Context, ex ch.m.Lock() defer ch.m.Unlock() + var dc *DeferredConfirmation + if ch.confirming { + dc = ch.confirms.publish() + } + if err := ch.send(&basicPublish{ Exchange: exchange, RoutingKey: key, @@ -1457,14 +1462,13 @@ func (ch *Channel) PublishWithDeferredConfirmWithContext(ctx context.Context, ex AppId: msg.AppId, }, }); err != nil { + if ch.confirming { + ch.confirms.unpublish() + } return nil, err } - if ch.confirming { - return ch.confirms.Publish(), nil - } - - return nil, nil + return dc, nil } /* diff --git a/confirms.go b/confirms.go index f9973b7..577e042 100644 --- a/confirms.go +++ b/confirms.go @@ -39,7 +39,7 @@ func (c *confirms) Listen(l chan Confirmation) { } // Publish increments the publishing counter -func (c *confirms) Publish() *DeferredConfirmation { +func (c *confirms) publish() *DeferredConfirmation { c.publishedMut.Lock() defer c.publishedMut.Unlock() @@ -47,6 +47,15 @@ func (c *confirms) Publish() *DeferredConfirmation { return c.deferredConfirmations.Add(c.published) } +// unpublish decrements the publishing counter and removes the +// DeferredConfirmation. It must be called immediately after a publish fails. +func (c *confirms) unpublish() { + c.publishedMut.Lock() + defer c.publishedMut.Unlock() + c.deferredConfirmations.remove(c.published) + c.published-- +} + // confirm confirms one publishing, increments the expecting delivery tag, and // removes bookkeeping for that delivery tag. func (c *confirms) confirm(confirmation Confirmation) { @@ -135,6 +144,18 @@ func (d *deferredConfirmations) Add(tag uint64) *DeferredConfirmation { return dc } +// remove is only used to drop a tag whose publish failed +func (d *deferredConfirmations) remove(tag uint64) { + d.m.Lock() + defer d.m.Unlock() + dc, found := d.confirmations[tag] + if !found { + return + } + close(dc.done) + delete(d.confirmations, tag) +} + func (d *deferredConfirmations) Confirm(confirmation Confirmation) { d.m.Lock() defer d.m.Unlock() diff --git a/confirms_test.go b/confirms_test.go index e390f9f..967c12d 100644 --- a/confirms_test.go +++ b/confirms_test.go @@ -26,7 +26,7 @@ func TestConfirmOneResequences(t *testing.T) { c.Listen(l) for i := range fixtures { - if want, got := uint64(i+1), c.Publish(); want != got.DeliveryTag { + if want, got := uint64(i+1), c.publish(); want != got.DeliveryTag { t.Fatalf("expected publish to return the 1 based delivery tag published, want: %d, got: %d", want, got.DeliveryTag) } } @@ -64,7 +64,7 @@ func TestConfirmAndPublishDoNotDeadlock(t *testing.T) { }() for i := 0; i < iterations; i++ { - c.Publish() + c.publish() <-l } } @@ -82,7 +82,7 @@ func TestConfirmMixedResequences(t *testing.T) { c.Listen(l) for range fixtures { - c.Publish() + c.publish() } c.One(fixtures[0]) @@ -117,7 +117,7 @@ func TestConfirmMultipleResequences(t *testing.T) { c.Listen(l) for range fixtures { - c.Publish() + c.publish() } c.Multiple(fixtures[len(fixtures)-1]) @@ -141,7 +141,7 @@ func BenchmarkSequentialBufferedConfirms(t *testing.B) { if i > cap(l)-1 { <-l } - c.One(Confirmation{c.Publish().DeliveryTag, true}) + c.One(Confirmation{c.publish().DeliveryTag, true}) } } @@ -159,7 +159,7 @@ func TestConfirmsIsThreadSafe(t *testing.T) { c.Listen(l) for i := 0; i < count; i++ { - go func() { pub <- Confirmation{c.Publish().DeliveryTag, true} }() + go func() { pub <- Confirmation{c.publish().DeliveryTag, true} }() } for i := 0; i < count; i++ {