Skip to content

Commit

Permalink
Merge pull request #424 from olegz/INT-2508
Browse files Browse the repository at this point in the history
  • Loading branch information
garyrussell committed May 9, 2012
2 parents 9235efe + 7d0dcf0 commit 77e3886
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 23 deletions.
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2011 the original author or authors.
* Copyright 2002-2012 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,11 +17,9 @@
package org.springframework.integration.channel;

import java.util.Comparator;
import java.util.Map;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.atomic.AtomicLong;

import org.springframework.beans.DirectFieldAccessor;
import org.springframework.integration.Message;
import org.springframework.integration.MessageHeaders;
import org.springframework.integration.util.UpperBound;
Expand All @@ -38,9 +36,6 @@ public class PriorityChannel extends QueueChannel {
private final UpperBound upperBound;

private final AtomicLong sequenceCounter = new AtomicLong();

private static final String SEQUENCE_HEADER_NAME = "__priorityChannelSequence__";


/**
* Create a channel with the specified queue capacity. If the capacity
Expand Down Expand Up @@ -80,26 +75,20 @@ public PriorityChannel() {
this(0, null);
}


@SuppressWarnings({ "rawtypes", "unchecked" })
@Override
protected boolean doSend(Message<?> message, long timeout) {
if (!upperBound.tryAcquire(timeout)) {
return false;
}
Map innerMap = (Map) new DirectFieldAccessor(message.getHeaders()).getPropertyValue("headers");
innerMap.put(SEQUENCE_HEADER_NAME, sequenceCounter.incrementAndGet());
message = new MessageWrapper(message);
return super.doSend(message, 0);
}

@SuppressWarnings({ "rawtypes"})
@Override
protected Message<?> doReceive(long timeout) {
Message<?> message = super.doReceive(timeout);

if (message != null) {
Map innerMap = (Map) new DirectFieldAccessor(message.getHeaders()).getPropertyValue("headers");
innerMap.remove(SEQUENCE_HEADER_NAME);
message = ((MessageWrapper)message).getRootMessage();
upperBound.release();
}
return message;
Expand Down Expand Up @@ -128,12 +117,38 @@ public int compare(Message<?> message1, Message<?> message2) {
}

if (compareResult == 0){
Long sequence1 = (Long) message1.getHeaders().get(SEQUENCE_HEADER_NAME);
Long sequence2 = (Long) message2.getHeaders().get(SEQUENCE_HEADER_NAME);
Long sequence1 = ((MessageWrapper) message1).getSequence();
Long sequence2 = ((MessageWrapper) message2).getSequence();
compareResult = sequence1.compareTo(sequence2);
}
return compareResult;
}
}

//we need this because of INT-2508
private class MessageWrapper implements Message<Object>{
private final Message<?> rootMessage;
private final long sequence;

public MessageWrapper(Message<?> rootMessage){
this.rootMessage = rootMessage;
this.sequence = sequenceCounter.incrementAndGet();
}

public Message<?> getRootMessage(){
return this.rootMessage;
}

public MessageHeaders getHeaders() {
return this.rootMessage.getHeaders();
}

public Object getPayload() {
return rootMessage.getPayload();
}

long getSequence(){
return this.sequence;
}
}
}
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2010 the original author or authors.
* Copyright 2002-2012 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,12 +16,6 @@

package org.springframework.integration.channel;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;

import java.util.Comparator;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
Expand All @@ -35,6 +29,12 @@
import org.springframework.integration.message.GenericMessage;
import org.springframework.integration.support.MessageBuilder;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;

/**
* @author Mark Fisher
*/
Expand Down Expand Up @@ -82,6 +82,27 @@ public void testDefaultComparator() {
assertEquals("test:-99", channel.receive(0).getPayload());
}

// although this test has no assertions it results in ConcurrentModificationException
// if executed before changes for INT-2508
@Test
public void testPriorityChannelWithConcurrentModification() throws Exception{
final PriorityChannel channel = new PriorityChannel();
final Message<String> message = new GenericMessage<String>("hello");
for (int i = 0; i < 1000; i++) {
channel.send(message);
new Thread(new Runnable() {
public void run() {
channel.receive();
}
}).start();
new Thread(new Runnable() {
public void run() {
message.getHeaders().toString();
}
}).start();
}
}

@Test
public void testCustomComparator() {
PriorityChannel channel = new PriorityChannel(5, new StringPayloadComparator());
Expand Down

0 comments on commit 77e3886

Please sign in to comment.