diff --git a/spring-kafka/src/main/java/org/springframework/kafka/core/DefaultKafkaProducerFactory.java b/spring-kafka/src/main/java/org/springframework/kafka/core/DefaultKafkaProducerFactory.java index 6c75c171f2..0333308f13 100644 --- a/spring-kafka/src/main/java/org/springframework/kafka/core/DefaultKafkaProducerFactory.java +++ b/spring-kafka/src/main/java/org/springframework/kafka/core/DefaultKafkaProducerFactory.java @@ -127,8 +127,6 @@ public class DefaultKafkaProducerFactory implements ProducerFactory, private final ThreadLocal> threadBoundProducers = new ThreadLocal<>(); - private final ThreadLocal threadBoundProducerEpochs = new ThreadLocal<>(); - private final AtomicInteger epoch = new AtomicInteger(); private final AtomicInteger clientIdCounter = new AtomicInteger(); @@ -402,25 +400,21 @@ private Producer doCreateProducer(@Nullable String txIdPrefix) { } if (this.producerPerThread) { CloseSafeProducer tlProducer = this.threadBoundProducers.get(); - if (this.threadBoundProducerEpochs.get() == null) { - this.threadBoundProducerEpochs.set(this.epoch.get()); - } - if (tlProducer != null && this.epoch.get() != this.threadBoundProducerEpochs.get()) { + if (tlProducer != null && this.epoch.get() != tlProducer.epoch) { closeThreadBoundProducer(); tlProducer = null; } if (tlProducer == null) { tlProducer = new CloseSafeProducer<>(createKafkaProducer(), this::removeProducer, - this.physicalCloseTimeout); + this.physicalCloseTimeout, this.epoch); this.threadBoundProducers.set(tlProducer); - this.threadBoundProducerEpochs.set(this.epoch.get()); } return tlProducer; } synchronized (this) { if (this.producer == null) { this.producer = new CloseSafeProducer<>(createKafkaProducer(), this::removeProducer, - this.physicalCloseTimeout); + this.physicalCloseTimeout, this.epoch); } return this.producer; } @@ -527,7 +521,8 @@ private CloseSafeProducer doCreateTxProducer(String prefix, String suffix, newProducer = createRawProducer(newProducerConfigs); newProducer.initTransactions(); return new CloseSafeProducer<>(newProducer, getCache(prefix), remover, - (String) newProducerConfigs.get(ProducerConfig.TRANSACTIONAL_ID_CONFIG), this.physicalCloseTimeout); + (String) newProducerConfigs.get(ProducerConfig.TRANSACTIONAL_ID_CONFIG), this.physicalCloseTimeout, + this.epoch); } protected Producer createRawProducer(Map configs) { @@ -596,6 +591,10 @@ protected static class CloseSafeProducer implements Producer { private final Duration closeTimeout; + final int epoch; // NOSONAR + + private final AtomicInteger factoryEpoch; + private volatile Exception producerFailed; private volatile boolean closed; @@ -603,30 +602,46 @@ protected static class CloseSafeProducer implements Producer { CloseSafeProducer(Producer delegate, Consumer> removeProducer, Duration closeTimeout) { - this(delegate, null, removeProducer, null, closeTimeout); + this(delegate, null, removeProducer, null, closeTimeout, new AtomicInteger()); + Assert.isTrue(!(delegate instanceof CloseSafeProducer), "Cannot double-wrap a producer"); + } + + CloseSafeProducer(Producer delegate, Consumer> removeProducer, + Duration closeTimeout, AtomicInteger epoch) { + + this(delegate, null, removeProducer, null, closeTimeout, epoch); Assert.isTrue(!(delegate instanceof CloseSafeProducer), "Cannot double-wrap a producer"); } CloseSafeProducer(Producer delegate, BlockingQueue> cache, Duration closeTimeout) { - this(delegate, cache, null, closeTimeout); + this(delegate, cache, null, null, closeTimeout, new AtomicInteger()); } CloseSafeProducer(Producer delegate, BlockingQueue> cache, @Nullable Consumer> removeConsumerProducer, Duration closeTimeout) { - this(delegate, cache, removeConsumerProducer, null, closeTimeout); + this(delegate, cache, removeConsumerProducer, null, closeTimeout, new AtomicInteger()); } CloseSafeProducer(Producer delegate, BlockingQueue> cache, @Nullable Consumer> removeProducer, @Nullable String txId, Duration closeTimeout) { + this(delegate, cache, removeProducer, txId, closeTimeout, new AtomicInteger()); + } + + CloseSafeProducer(Producer delegate, BlockingQueue> cache, + @Nullable Consumer> removeProducer, @Nullable String txId, + Duration closeTimeout, AtomicInteger epoch) { + this.delegate = delegate; this.cache = cache; this.removeProducer = removeProducer; this.txId = txId; this.closeTimeout = closeTimeout; + this.epoch = epoch.get(); + this.factoryEpoch = epoch; LOGGER.debug(() -> "Created new Producer: " + this); } @@ -760,8 +775,8 @@ public void close(@Nullable Duration timeout) { else { if (this.cache != null && this.removeProducer == null) { // dedicated consumer producers are not cached synchronized (this) { - if (!this.cache.contains(this) - && !this.cache.offer(this)) { + if (this.epoch != this.factoryEpoch.get() + || (!this.cache.contains(this) && !this.cache.offer(this))) { this.closed = true; this.delegate.close(timeout); } diff --git a/spring-kafka/src/test/java/org/springframework/kafka/core/DefaultKafkaProducerFactoryTests.java b/spring-kafka/src/test/java/org/springframework/kafka/core/DefaultKafkaProducerFactoryTests.java index 163fc5a607..9c0f7af89b 100644 --- a/spring-kafka/src/test/java/org/springframework/kafka/core/DefaultKafkaProducerFactoryTests.java +++ b/spring-kafka/src/test/java/org/springframework/kafka/core/DefaultKafkaProducerFactoryTests.java @@ -174,7 +174,42 @@ protected Producer createTransactionalProducer(String txIdPrefix) { @Test @SuppressWarnings({ "rawtypes", "unchecked" }) - void testThreadLocal() { + void dontReturnToCacheAfterReset() { + final Producer producer = mock(Producer.class); + ApplicationContext ctx = mock(ApplicationContext.class); + DefaultKafkaProducerFactory pf = new DefaultKafkaProducerFactory(new HashMap<>()) { + + @Override + protected Producer createRawProducer(Map configs) { + return producer; + } + + }; + pf.setApplicationContext(ctx); + pf.setTransactionIdPrefix("foo"); + Producer aProducer = pf.createProducer(); + assertThat(aProducer).isNotNull(); + aProducer.close(); + Producer bProducer = pf.createProducer(); + assertThat(bProducer).isSameAs(aProducer); + bProducer.close(); + assertThat(KafkaTestUtils.getPropertyValue(pf, "producer")).isNull(); + Map cache = KafkaTestUtils.getPropertyValue(pf, "cache", Map.class); + assertThat(cache.size()).isEqualTo(1); + Queue queue = (Queue) cache.get("foo"); + assertThat(queue.size()).isEqualTo(1); + bProducer = pf.createProducer(); + assertThat(bProducer).isSameAs(aProducer); + assertThat(queue.size()).isEqualTo(0); + pf.reset(); + bProducer.close(); + assertThat(queue.size()).isEqualTo(0); + pf.destroy(); + } + + @Test + @SuppressWarnings({ "rawtypes", "unchecked" }) + void testThreadLocal() throws InterruptedException { final Producer producer = mock(Producer.class); DefaultKafkaProducerFactory pf = new DefaultKafkaProducerFactory(new HashMap<>()) {