Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RMQSource refactor #1

Merged
merged 5 commits into from
May 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package org.apache.flink.streaming.connectors.rabbitmq;

import org.apache.flink.api.java.typeutils.ResultTypeQueryable;

import com.rabbitmq.client.AMQP;
import com.rabbitmq.client.Envelope;

import java.io.IOException;
import java.io.Serializable;

/**
 * Interface for the set of methods required to parse an RMQ delivery.
 *
 * <p>Every method receives the complete delivery (envelope, properties and body), so both the
 * produced record and the correlation id may be derived from any part of the message.
 *
 * @param <T> The output type of the {@link RMQSource}
 */
public interface RMQDeliveryParser<T> extends Serializable, ResultTypeQueryable<T> {

	/**
	 * Takes all the RabbitMQ delivery information supplied by the client and returns an output matching
	 * the {@link RMQSource}.
	 *
	 * @param envelope the {@link Envelope} of the delivery
	 * @param properties the {@link AMQP.BasicProperties} of the delivery
	 * @param body the raw body of the delivery
	 * @return an output T matching the output of the RMQSource
	 * @throws IOException if the implementation fails to parse the delivery
	 */
	T parse(Envelope envelope, AMQP.BasicProperties properties, byte[] body) throws IOException;

	/**
	 * Extracts a unique correlation id from the RabbitMQ delivery information. This ID is used for
	 * deduplicating the messages in the RMQSource.
	 *
	 * @param envelope the {@link Envelope} of the delivery
	 * @param properties the {@link AMQP.BasicProperties} of the delivery
	 * @param body the raw body of the delivery
	 * @return the correlation id extracted from the delivery
	 */
	String getCorrelationID(Envelope envelope, AMQP.BasicProperties properties, byte[] body);
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
import org.apache.flink.streaming.connectors.rabbitmq.common.RMQConnectionConfig;
import org.apache.flink.util.Preconditions;

import com.rabbitmq.client.AMQP;
import com.rabbitmq.client.Channel;
import com.rabbitmq.client.Connection;
import com.rabbitmq.client.ConnectionFactory;
import com.rabbitmq.client.Envelope;
import com.rabbitmq.client.QueueingConsumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -76,6 +78,7 @@ public class RMQSource<OUT> extends MultipleIdsMessageAcknowledgingSourceBase<OU
// RabbitMQ connection configuration; assigned once by the constructors.
private final RMQConnectionConfig rmqConnectionConfig;
// The queue to receive messages from.
protected final String queueName;
// Whether correlation ids are extracted and used to deduplicate messages.
private final boolean usesCorrelationId;
// Assigned only by the RMQDeliveryParser-based constructors; when non-null it is preferred over schema
// for parsing the body, extracting the correlation id and reporting the produced type.
protected RMQDeliveryParser<OUT> deliveryParser;
// Body deserializer assigned by the DeserializationSchema-based constructors; used when deliveryParser is null.
protected DeserializationSchema<OUT> schema;

protected transient Connection connection;
Expand Down Expand Up @@ -124,6 +127,46 @@ public RMQSource(RMQConnectionConfig rmqConnectionConfig,
this.schema = deserializationSchema;
}

/**
 * Creates a new RabbitMQ source that does not use correlation ids: at-least-once message
 * processing guarantee when checkpointing is enabled, and no strong delivery guarantees when
 * checkpointing is disabled.
 *
 * <p>For exactly-once, please use the constructor
 * {@link RMQSource#RMQSource(RMQConnectionConfig, String, boolean, RMQDeliveryParser)}.
 *
 * <p>The provided {@link RMQDeliveryParser} is used to parse both the correlationID and the message.
 *
 * @param rmqConnectionConfig The RabbitMQ connection configuration {@link RMQConnectionConfig}.
 * @param queueName The queue to receive messages from.
 * @param deliveryParser A {@link RMQDeliveryParser} for parsing the RMQDelivery.
 */
public RMQSource(
		RMQConnectionConfig rmqConnectionConfig,
		String queueName,
		RMQDeliveryParser<OUT> deliveryParser) {
	// Delegate to the full constructor with correlation ids switched off.
	this(rmqConnectionConfig, queueName, false, deliveryParser);
}

/**
 * Creates a new RabbitMQ source. For exactly-once, you must set the correlation ids of messages
 * at the producer. The correlation id must be unique. Otherwise the behavior of the source is
 * undefined. If in doubt, set usesCorrelationId to false. When correlation ids are not
 * used, this source has at-least-once processing semantics when checkpointing is enabled.
 *
 * <p>It also uses the provided {@link RMQDeliveryParser} to parse both the correlationID and the message.
 *
 * @param rmqConnectionConfig The RabbitMQ connection configuration {@link RMQConnectionConfig}.
 * @param queueName The queue to receive messages from.
 * @param usesCorrelationId Whether the messages received are supplied with a <b>unique</b>
 *                          id to deduplicate messages (in case of failed acknowledgments).
 *                          Only used when checkpointing is enabled.
 * @param deliveryParser A {@link RMQDeliveryParser} for parsing the RMQDelivery; must not be null.
 * @throws NullPointerException if {@code deliveryParser} is null
 */
public RMQSource(RMQConnectionConfig rmqConnectionConfig,
		String queueName, boolean usesCorrelationId, RMQDeliveryParser<OUT> deliveryParser) {
	super(String.class);
	this.rmqConnectionConfig = rmqConnectionConfig;
	this.queueName = queueName;
	this.usesCorrelationId = usesCorrelationId;
	// This constructor leaves schema unset, so a null parser would only surface later as an
	// NPE while parsing or querying the produced type. Fail fast at construction instead.
	this.deliveryParser = Preconditions.checkNotNull(deliveryParser, "deliveryParser must not be null");
}

/**
* Initializes the connection to RMQ with a default connection factory. The user may override
* this method to setup and configure their own ConnectionFactory.
Expand Down Expand Up @@ -187,26 +230,75 @@ public void close() throws Exception {
}
}

/**
 * Parses an AMQP delivery into the output type of this source.
 *
 * <p>When the source was built with one of the {@link RMQDeliveryParser} constructors, the whole
 * delivery is handed to {@link RMQDeliveryParser#parse(Envelope, AMQP.BasicProperties, byte[])}.
 *
 * <p>Otherwise (one of the {@link DeserializationSchema} constructors was used) only the body is
 * passed to {@link DeserializationSchema#deserialize(byte[])}.
 *
 * @param envelope the envelope of the delivery
 * @param properties the properties of the delivery
 * @param body the raw body of the delivery
 * @return the parsed record
 * @throws IOException if the parser or the schema fails to process the delivery
 */
protected OUT parseBody(Envelope envelope, AMQP.BasicProperties properties, byte[] body) throws IOException {
	if (deliveryParser == null) {
		// No delivery parser configured: fall back to the deserialization schema.
		return schema.deserialize(body);
	}
	return deliveryParser.parse(envelope, properties, body);
}

/**
 * Extracts the correlation id of an AMQP delivery.
 *
 * <p>When the source was built with one of the {@link RMQDeliveryParser} constructors, the id comes
 * from {@link RMQDeliveryParser#getCorrelationID(Envelope, AMQP.BasicProperties, byte[])}.
 *
 * <p>Otherwise (one of the {@link DeserializationSchema} constructors was used) the id is read from
 * {@link AMQP.BasicProperties#getCorrelationId()}.
 *
 * @param envelope the envelope of the delivery
 * @param properties the properties of the delivery
 * @param body the raw body of the delivery
 * @return the correlation id, as returned by the parser or by the delivery properties
 * @throws IOException declared so that overriding implementations may report extraction failures
 */
protected String extractCorrelationID(Envelope envelope, AMQP.BasicProperties properties, byte[] body) throws IOException {
	return deliveryParser == null
		? properties.getCorrelationId()
		: deliveryParser.getCorrelationID(envelope, properties, body);
}

@Override
public void run(SourceContext<OUT> ctx) throws Exception {
while (running) {
QueueingConsumer.Delivery delivery = consumer.nextDelivery();

synchronized (ctx.getCheckpointLock()) {

OUT result = schema.deserialize(delivery.getBody());
Envelope envelope = delivery.getEnvelope();
AMQP.BasicProperties properties = delivery.getProperties();
byte[] body = delivery.getBody();

OUT result = parseBody(envelope, properties, body);

if (schema.isEndOfStream(result)) {
if (schema != null && schema.isEndOfStream(result)) {
break;
}

if (!autoAck) {
final long deliveryTag = delivery.getEnvelope().getDeliveryTag();
if (usesCorrelationId) {
final String correlationId = delivery.getProperties().getCorrelationId();
final String correlationId = extractCorrelationID(envelope, properties, body);
Preconditions.checkNotNull(correlationId, "RabbitMQ source was instantiated " +
"with usesCorrelationId set to true but a message was received with " +
"correlation id set to null!");
"with usesCorrelationId set to true yet we couldn't extract the correlation id from it !");
if (!addId(correlationId)) {
// we have already processed this message
continue;
Expand Down Expand Up @@ -239,6 +331,6 @@ protected void acknowledgeSessionIDs(List<Long> sessionIds) {

/**
 * Returns the type information of the records produced by this source.
 *
 * <p>The {@link RMQDeliveryParser} is asked when one was supplied at construction; otherwise the
 * {@link DeserializationSchema} is asked. Only the null-safe lookup is kept: an unconditional
 * {@code schema.getProducedType()} fails for sources built with a delivery parser, because
 * {@code schema} is never assigned in that case.
 */
@Override
public TypeInformation<OUT> getProducedType() {
	return deliveryParser == null ? schema.getProducedType() : deliveryParser.getProducedType();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,48 @@ public void testConstructorParams() throws Exception {
assertEquals("passTest", testObj.getFactory().getPassword());
}

/**
 * Tests that the correlation ID comes from the right place for each constructor.
 * A source built with the DeserializationSchema reads the mocked AMQP property correlationID.
 * A source built with the RMQDeliveryParser calls its getCorrelationID, which returns the mocked
 * AMQP property messageId.
 */
@Test
public void testExtractCorrelationId() throws Exception {
	final byte[] emptyBody = "".getBytes();

	// Schema-based source: id is taken from the mocked correlationId property.
	RMQTestSource schemaSource = new RMQTestSource();
	schemaSource.initAMQPMocks();
	String idFromProperties = schemaSource.extractCorrelationID(
		schemaSource.mockedAMQPEnvelope, schemaSource.mockedAMQPProperties, emptyBody);
	assertEquals("0", idFromProperties);

	// Parser-based source: id is taken from the mocked messageId property.
	RMQTestSource parserSource = new RMQTestSource(new CustomDeliveryParser());
	parserSource.initAMQPMocks();
	String idFromParser = parserSource.extractCorrelationID(
		parserSource.mockedAMQPEnvelope, parserSource.mockedAMQPProperties, emptyBody);
	assertEquals("1-MESSAGE_ID", idFromParser);
}

/**
 * Tests that the body is parsed by the right component for each constructor.
 * A source built with the DeserializationSchema uses it to parse the AMQP body.
 * A source built with the RMQDeliveryParser calls its parse method, which returns the mocked AMQP
 * property messageId.
 */
@Test
public void testParseBody() throws Exception {
	// Schema-based source: the body bytes are deserialized back into the original string.
	RMQTestSource schemaSource = new RMQTestSource();
	schemaSource.open(config);
	String bodyFromSchema = schemaSource.parseBody(
		schemaSource.mockedAMQPEnvelope, schemaSource.mockedAMQPProperties, "I Love Turtles".getBytes());
	assertEquals("I Love Turtles", bodyFromSchema);

	// Parser-based source: the body is ignored and the mocked messageId is returned.
	RMQTestSource parserSource = new RMQTestSource(new CustomDeliveryParser());
	parserSource.open(config);
	String bodyFromParser = parserSource.parseBody(
		parserSource.mockedAMQPEnvelope, parserSource.mockedAMQPProperties, "".getBytes());
	assertEquals("1-MESSAGE_ID", bodyFromParser);
}

private static class ConstructorTestClass extends RMQSource<String> {

private ConnectionFactory factory;
Expand Down Expand Up @@ -338,16 +380,44 @@ public TypeInformation<String> getProducedType() {
}
}

private class CustomDeliveryParser implements RMQDeliveryParser<String> {

@Override
public String parse(Envelope envelope, AMQP.BasicProperties properties, byte[] body) throws IOException {
return properties.getMessageId();
}

@Override
public String getCorrelationID(Envelope envelope, AMQP.BasicProperties properties, byte[] body) {
return properties.getMessageId();
}

@Override
public TypeInformation<String> getProducedType() {
return TypeExtractor.getForClass(String.class);
}
}

private class RMQTestSource extends RMQSource<String> {

private ArrayDeque<Tuple2<Long, Set<String>>> restoredState;

private QueueingConsumer.Delivery mockedDelivery;
public Envelope mockedAMQPEnvelope;
public AMQP.BasicProperties mockedAMQPProperties;

/**
 * Creates a test source on the dummy queue with correlation ids enabled, using a
 * {@code StringDeserializationScheme} to deserialize message bodies.
 */
public RMQTestSource() {
	super(
		new RMQConnectionConfig.Builder()
			.setHost("hostTest")
			.setPort(999)
			.setUserName("userTest")
			.setPassword("passTest")
			.setVirtualHost("/")
			.build(),
		"queueDummy",
		true,
		new StringDeserializationScheme());
}

/**
 * Creates a test source on the dummy queue with correlation ids enabled, using the given
 * {@link RMQDeliveryParser} instead of a deserialization schema.
 *
 * @param deliveryParser the parser handed to the {@code RMQSource<String>} super constructor
 */
public RMQTestSource(RMQDeliveryParser<String> deliveryParser) {
	// Parameterized (not raw) so the call to the RMQSource<String> constructor is type-checked.
	super(new RMQConnectionConfig.Builder().setHost("hostTest")
		.setPort(999).setUserName("userTest").setPassword("passTest").setVirtualHost("/").build()
		, "queueDummy", true, deliveryParser);
}

@Override
public void initializeState(FunctionInitializationContext context) throws Exception {
super.initializeState(context);
Expand All @@ -358,10 +428,7 @@ public ArrayDeque<Tuple2<Long, Set<String>>> getRestoredState() {
return this.restoredState;
}

@Override
public void open(Configuration config) throws Exception {
super.open(config);

public void initAMQPMocks() {
consumer = Mockito.mock(QueueingConsumer.class);

// Mock for delivery
Expand All @@ -375,27 +442,39 @@ public void open(Configuration config) throws Exception {
}

// Mock for envelope
Envelope envelope = Mockito.mock(Envelope.class);
Mockito.when(deliveryMock.getEnvelope()).thenReturn(envelope);
mockedAMQPEnvelope = Mockito.mock(Envelope.class);
Mockito.when(deliveryMock.getEnvelope()).thenReturn(mockedAMQPEnvelope);

Mockito.when(envelope.getDeliveryTag()).thenAnswer(new Answer<Long>() {
Mockito.when(mockedAMQPEnvelope.getDeliveryTag()).thenAnswer(new Answer<Long>() {
@Override
public Long answer(InvocationOnMock invocation) throws Throwable {
return ++messageId;
}
});

// Mock for properties
AMQP.BasicProperties props = Mockito.mock(AMQP.BasicProperties.class);
Mockito.when(deliveryMock.getProperties()).thenReturn(props);
mockedAMQPProperties = Mockito.mock(AMQP.BasicProperties.class);
Mockito.when(deliveryMock.getProperties()).thenReturn(mockedAMQPProperties);

Mockito.when(props.getCorrelationId()).thenAnswer(new Answer<String>() {
Mockito.when(mockedAMQPProperties.getCorrelationId()).thenAnswer(new Answer<String>() {
@Override
public String answer(InvocationOnMock invocation) throws Throwable {
return generateCorrelationIds ? "" + messageId : null;
}
});

Mockito.when(mockedAMQPProperties.getMessageId()).thenAnswer(new Answer<String>(){
@Override
public String answer(InvocationOnMock invocation) throws Throwable {
return ++messageId + "-MESSAGE_ID";
}
});
}

/**
 * Opens the source through the parent implementation and then installs the AMQP mocks.
 *
 * @param config the configuration forwarded to the parent {@code open}
 * @throws Exception if the parent {@code open} fails
 */
@Override
public void open(Configuration config) throws Exception {
super.open(config);
// Must run after super.open(): replaces the consumer with a Mockito mock and sets up the
// mocked envelope and properties used by the tests.
initAMQPMocks();
}

@Override
Expand Down