Skip to content

Commit

Permalink
INT-4366: Fix MulticastSendingMessageHandler (#2329)
Browse files Browse the repository at this point in the history
* INT-4366: Fix MulticastSendingMessageHandler

JIRA: https://jira.spring.io/browse/INT-4366

Fix race condition in the `MulticastSendingMessageHandler` around
`multicastSocket` and super `socket` properties.

* Synchronize around `this` and check for the `multicastSocket == null`.
This let the `MulticastSendingMessageHandler` to fully configure and
prepare the socket for use.
* Remove `socket.setInterface(whichNic)` since it is populated by the
`InetSocketAddress` ctor before

**Cherry-pick to 4.3.x**

* Fix thread leaks in TCP/IP tests
  • Loading branch information
artembilan authored and garyrussell committed Jan 19, 2018
1 parent 8aa91d1 commit c3b64dc
Show file tree
Hide file tree
Showing 13 changed files with 261 additions and 235 deletions.
@@ -1,5 +1,5 @@
/*
* Copyright 2001-2016 the original author or authors.
* Copyright 2001-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -38,6 +38,8 @@
* determine success.
*
* @author Gary Russell
* @author Artem Bilan
*
* @since 2.0
*/
public class MulticastSendingMessageHandler extends UnicastSendingMessageHandler {
Expand Down Expand Up @@ -126,49 +128,45 @@ public MulticastSendingMessageHandler(String destinationExpression) {

@Override
protected DatagramSocket getSocket() throws IOException {
if (this.getTheSocket() == null) {
if (this.multicastSocket == null) {
synchronized (this) {
createSocket();
if (this.multicastSocket == null) {
createSocket();
}
}
}
return this.getTheSocket();
return getTheSocket();
}

private void createSocket() throws IOException {
if (this.getTheSocket() == null) {
MulticastSocket socket;
if (this.isAcknowledge()) {
int ackPort = this.getAckPort();
if (this.localAddress == null) {
socket = ackPort == 0 ? new MulticastSocket() : new MulticastSocket(ackPort);
}
else {
InetAddress whichNic = InetAddress.getByName(this.localAddress);
socket = new MulticastSocket(new InetSocketAddress(whichNic, ackPort));
}
if (getSoReceiveBufferSize() > 0) {
socket.setReceiveBufferSize(this.getSoReceiveBufferSize());
}
if (logger.isDebugEnabled()) {
logger.debug("Listening for acks on port: " + socket.getLocalPort());
}
setSocket(socket);
updateAckAddress();
MulticastSocket socket;
if (isAcknowledge()) {
int ackPort = getAckPort();
if (this.localAddress == null) {
socket = ackPort == 0 ? new MulticastSocket() : new MulticastSocket(ackPort);
}
else {
socket = new MulticastSocket();
setSocket(socket);
InetAddress whichNic = InetAddress.getByName(this.localAddress);
socket = new MulticastSocket(new InetSocketAddress(whichNic, ackPort));
}
if (this.timeToLive >= 0) {
socket.setTimeToLive(this.timeToLive);
if (getSoReceiveBufferSize() > 0) {
socket.setReceiveBufferSize(getSoReceiveBufferSize());
}
setSocketAttributes(socket);
if (this.localAddress != null) {
InetAddress whichNic = InetAddress.getByName(this.localAddress);
socket.setInterface(whichNic);
if (logger.isDebugEnabled()) {
logger.debug("Listening for acks on port: " + socket.getLocalPort());
}
this.multicastSocket = socket;
setSocket(socket);
updateAckAddress();
}
else {
socket = new MulticastSocket();
setSocket(socket);
}
if (this.timeToLive >= 0) {
socket.setTimeToLive(this.timeToLive);
}
setSocketAttributes(socket);
this.multicastSocket = socket;
}


Expand All @@ -178,7 +176,7 @@ private void createSocket() throws IOException {
* @param minAcksForSuccess The minimum number of acks that will represent success.
*/
public void setMinAcksForSuccess(int minAcksForSuccess) {
this.setAckCounter(minAcksForSuccess);
setAckCounter(minAcksForSuccess);
}

/**
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,7 +28,6 @@
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -39,6 +38,7 @@
import org.junit.Test;

import org.springframework.beans.factory.BeanFactory;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.integration.channel.DirectChannel;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.handler.ServiceActivatingHandler;
Expand All @@ -56,6 +56,8 @@

/**
* @author Gary Russell
* @author Artem Bilan
*
* @since 2.0
*/
public class TcpInboundGatewayTests {
Expand Down Expand Up @@ -119,30 +121,31 @@ public void testNetClientMode() throws Exception {
final CountDownLatch latch2 = new CountDownLatch(1);
final CountDownLatch latch3 = new CountDownLatch(1);
final AtomicBoolean done = new AtomicBoolean();
Executors.newSingleThreadExecutor().execute(() -> {
try {
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0, 10);
port.set(server.getLocalPort());
latch1.countDown();
Socket socket = server.accept();
socket.getOutputStream().write("Test1\r\nTest2\r\n".getBytes());
byte[] bytes = new byte[12];
readFully(socket.getInputStream(), bytes);
assertEquals("Echo:Test1\r\n", new String(bytes));
readFully(socket.getInputStream(), bytes);
assertEquals("Echo:Test2\r\n", new String(bytes));
latch2.await();
socket.close();
server.close();
done.set(true);
latch3.countDown();
}
catch (Exception e) {
if (!done.get()) {
e.printStackTrace();
}
}
});
new SimpleAsyncTaskExecutor()
.execute(() -> {
try {
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0, 10);
port.set(server.getLocalPort());
latch1.countDown();
Socket socket = server.accept();
socket.getOutputStream().write("Test1\r\nTest2\r\n".getBytes());
byte[] bytes = new byte[12];
readFully(socket.getInputStream(), bytes);
assertEquals("Echo:Test1\r\n", new String(bytes));
readFully(socket.getInputStream(), bytes);
assertEquals("Echo:Test2\r\n", new String(bytes));
latch2.await();
socket.close();
server.close();
done.set(true);
latch3.countDown();
}
catch (Exception e) {
if (!done.get()) {
e.printStackTrace();
}
}
});
assertTrue(latch1.await(10, TimeUnit.SECONDS));
AbstractClientConnectionFactory ccf = new TcpNetClientConnectionFactory("localhost", port.get());
ccf.setSingleUse(false);
Expand Down
Expand Up @@ -43,7 +43,6 @@
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
Expand All @@ -61,6 +60,8 @@
import org.springframework.beans.factory.BeanFactory;
import org.springframework.core.serializer.DefaultDeserializer;
import org.springframework.core.serializer.DefaultSerializer;
import org.springframework.core.task.AsyncTaskExecutor;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
import org.springframework.expression.spel.standard.SpelExpressionParser;
Expand Down Expand Up @@ -90,6 +91,8 @@ public class TcpOutboundGatewayTests {

private static final Log logger = LogFactory.getLog(TcpOutboundGatewayTests.class);

private AsyncTaskExecutor executor = new SimpleAsyncTaskExecutor();

@ClassRule
public static LongRunningIntegrationTest longTests = new LongRunningIntegrationTest();

Expand All @@ -101,13 +104,13 @@ public class TcpOutboundGatewayTests {
public void testGoodNetSingle() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean done = new AtomicBoolean();
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<ServerSocket>();
Executors.newSingleThreadExecutor().execute(() -> {
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<>();
this.executor.execute(() -> {
try {
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0, 100);
serverSocket.set(server);
latch.countDown();
List<Socket> sockets = new ArrayList<Socket>();
List<Socket> sockets = new ArrayList<>();
int i = 0;
while (true) {
Socket socket = server.accept();
Expand Down Expand Up @@ -165,8 +168,8 @@ public void testGoodNetSingle() throws Exception {
public void testGoodNetMultiplex() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean done = new AtomicBoolean();
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<ServerSocket>();
Executors.newSingleThreadExecutor().execute(() -> {
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<>();
this.executor.execute(() -> {
try {
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0, 10);
serverSocket.set(server);
Expand Down Expand Up @@ -220,8 +223,8 @@ public void testGoodNetMultiplex() throws Exception {
public void testGoodNetTimeout() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean done = new AtomicBoolean();
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<ServerSocket>();
Executors.newSingleThreadExecutor().execute(() -> {
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<>();
this.executor.execute(() -> {
try {
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0);
serverSocket.set(server);
Expand Down Expand Up @@ -260,12 +263,12 @@ public void testGoodNetTimeout() throws Exception {
Future<Integer>[] results = (Future<Integer>[]) new Future<?>[2];
for (int i = 0; i < 2; i++) {
final int j = i;
results[j] = (Executors.newSingleThreadExecutor().submit(() -> {
results[j] = (this.executor.submit(() -> {
gateway.handleMessage(MessageBuilder.withPayload("Test" + j).build());
return 0;
}));
}
Set<String> replies = new HashSet<String>();
Set<String> replies = new HashSet<>();
int timeouts = 0;
for (int i = 0; i < 2; i++) {
try {
Expand Down Expand Up @@ -344,7 +347,7 @@ private void testGoodNetGWTimeoutGuts(final int port, AbstractClientConnectionFa
final AtomicReference<String> lastReceived = new AtomicReference<String>();
final CountDownLatch serverLatch = new CountDownLatch(2);

Executors.newSingleThreadExecutor().execute(() -> {
this.executor.execute(() -> {
try {
latch.countDown();
int i = 0;
Expand Down Expand Up @@ -398,7 +401,7 @@ private void testGoodNetGWTimeoutGuts(final int port, AbstractClientConnectionFa

for (int i = 0; i < 2; i++) {
final int j = i;
results[j] = (Executors.newSingleThreadExecutor().submit(() -> {
results[j] = (this.executor.submit(() -> {
gateway.handleMessage(MessageBuilder.withPayload("Test" + j).build());
return j;
}));
Expand Down Expand Up @@ -442,7 +445,7 @@ public void testCachingFailover() throws Exception {
final AtomicBoolean done = new AtomicBoolean();
final CountDownLatch serverLatch = new CountDownLatch(1);

Executors.newSingleThreadExecutor().execute(() -> {
this.executor.execute(() -> {
try {
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0);
serverSocket.set(server);
Expand Down Expand Up @@ -517,12 +520,12 @@ public void testCachingFailover() throws Exception {

@Test
public void testFailoverCached() throws Exception {
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<ServerSocket>();
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<>();
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean done = new AtomicBoolean();
final CountDownLatch serverLatch = new CountDownLatch(1);

Executors.newSingleThreadExecutor().execute(() -> {
this.executor.execute(() -> {
try {
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0);
serverSocket.set(server);
Expand Down Expand Up @@ -667,11 +670,11 @@ private void testGWPropagatesSocketCloseGuts(final int port, AbstractClientConne
final ServerSocket server) throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean done = new AtomicBoolean();
final AtomicReference<String> lastReceived = new AtomicReference<String>();
final AtomicReference<String> lastReceived = new AtomicReference<>();
final CountDownLatch serverLatch = new CountDownLatch(1);

Executors.newSingleThreadExecutor().execute(() -> {
List<Socket> sockets = new ArrayList<Socket>();
this.executor.execute(() -> {
List<Socket> sockets = new ArrayList<>();
try {
latch.countDown();
while (!done.get()) {
Expand Down Expand Up @@ -793,8 +796,8 @@ private void testGWPropagatesSocketTimeoutGuts(final int port, AbstractClientCon
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean done = new AtomicBoolean();

Executors.newSingleThreadExecutor().execute(() -> {
List<Socket> sockets = new ArrayList<Socket>();
this.executor.execute(() -> {
List<Socket> sockets = new ArrayList<>();
try {
latch.countDown();
while (!done.get()) {
Expand Down

0 comments on commit c3b64dc

Please sign in to comment.