Skip to content

Commit

Permalink
Correctly write pending data after ssl handshake completes. Related to [
Browse files Browse the repository at this point in the history
netty#2437]

Motivation:
When writing data from a server before the ssl handshake completes may not be written at all to the remote peer
if nothing else is written after the handshake was done.

Modification:
Correctly try to write pending data after the handshake was complete

Result:
Correctly write out all pending data
  • Loading branch information
Norman Maurer committed Apr 30, 2014
1 parent 82220b0 commit 76355a2
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 3 deletions.
15 changes: 12 additions & 3 deletions handler/src/main/java/io/netty/handler/ssl/SslHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH

private final boolean startTls;
private boolean sentFirstMessage;

private boolean flushedBeforeHandshakeDone;
private final LazyChannelPromise handshakePromise = new LazyChannelPromise();
private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise();
private final Deque<PendingWrite> pendingUnencryptedWrites = new ArrayDeque<PendingWrite>();
Expand Down Expand Up @@ -415,6 +415,9 @@ public void flush(ChannelHandlerContext ctx) throws Exception {
if (pendingUnencryptedWrites.isEmpty()) {
pendingUnencryptedWrites.add(PendingWrite.newInstance(Unpooled.EMPTY_BUFFER, null));
}
if (!handshakePromise.isDone()) {
flushedBeforeHandshakeDone = true;
}
wrap(ctx, false);
ctx.flush();
}
Expand All @@ -437,7 +440,6 @@ private void wrap(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLExcepti
pendingUnencryptedWrites.remove();
continue;
}

ByteBuf buf = (ByteBuf) pending.msg();
SSLEngineResult result = wrap(engine, buf, out);

Expand Down Expand Up @@ -890,7 +892,6 @@ private void unwrapNonApp(ChannelHandlerContext ctx) throws SSLException {
private void unwrapMultiple(
ChannelHandlerContext ctx, ByteBuffer packet, int totalLength,
int[] recordLengths, int nRecords, List<Object> out) throws SSLException {

for (int i = 0; i < nRecords; i ++) {
packet.limit(packet.position() + recordLengths[i]);
try {
Expand Down Expand Up @@ -949,6 +950,14 @@ private void unwrapSingle(
wrapLater = true;
continue;
}
if (flushedBeforeHandshakeDone) {
// We need to call wrap(...) in case there was a flush done before the handshake completed.
//
// See https://github.com/netty/netty/pull/2437
flushedBeforeHandshakeDone = false;
wrapLater = true;
}

break;
default:
throw new IllegalStateException("Unknown handshake status: " + handshakeStatus);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Copyright 2014 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.testsuite.transport.socket;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.SslHandler;
import io.netty.testsuite.util.BogusSslContextFactory;
import io.netty.util.ReferenceCountUtil;
import org.junit.Test;

import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

import static org.junit.Assert.assertEquals;


public class SocketSslGreetingTest extends AbstractSocketTest {

private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
private final ByteBuf greeting = ReferenceCountUtil.releaseLater(Unpooled.buffer().writeByte('a'));

// Test for https://github.com/netty/netty/pull/2437
@Test(timeout = 30000)
public void testSslGreeting() throws Throwable {
run();
}

public void testSslGreeting(ServerBootstrap sb, Bootstrap cb) throws Throwable {
final ServerHandler sh = new ServerHandler();
final ClientHandler ch = new ClientHandler();

sb.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel sch) throws Exception {
final SSLEngine sse = BogusSslContextFactory.getServerContext().createSSLEngine();
sse.setUseClientMode(false);

ChannelPipeline p = sch.pipeline();
p.addLast(new SslHandler(sse));
p.addLast("logger", new LoggingHandler(LOG_LEVEL));
p.addLast(sh);
}
});

cb.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel sch) throws Exception {
final SSLEngine cse = BogusSslContextFactory.getClientContext().createSSLEngine();
cse.setUseClientMode(true);

ChannelPipeline p = sch.pipeline();
p.addLast(new SslHandler(cse));
p.addLast("logger", new LoggingHandler(LOG_LEVEL));
p.addLast(ch);
}
});

Channel sc = sb.bind().sync().channel();
Channel cc = cb.connect().sync().channel();

ch.latch.await();

sh.channel.close().awaitUninterruptibly();
cc.close().awaitUninterruptibly();
sc.close().awaitUninterruptibly();

if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
throw sh.exception.get();
}
if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
throw ch.exception.get();
}
if (sh.exception.get() != null) {
throw sh.exception.get();
}
if (ch.exception.get() != null) {
throw ch.exception.get();
}
}

private class ClientHandler extends SimpleChannelInboundHandler<ByteBuf> {

final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
final CountDownLatch latch = new CountDownLatch(1);

@Override
public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception {
assertEquals(greeting, buf);
latch.countDown();
ctx.close();
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx,
Throwable cause) throws Exception {
if (logger.isWarnEnabled()) {
logger.warn("Unexpected exception from the client side", cause);
}

exception.compareAndSet(null, cause);
ctx.close();
}
}

private class ServerHandler extends SimpleChannelInboundHandler<String> {
volatile Channel channel;
final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();

@Override
protected void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
// discard
}

@Override
public void channelActive(ChannelHandlerContext ctx)
throws Exception {
channel = ctx.channel();
channel.writeAndFlush(greeting.duplicate().retain());
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx,
Throwable cause) throws Exception {
if (logger.isWarnEnabled()) {
logger.warn("Unexpected exception from the server side", cause);
}

exception.compareAndSet(null, cause);
ctx.close();
}
}
}

0 comments on commit 76355a2

Please sign in to comment.