Skip to content

Commit

Permalink
fixed netty connection
Browse files Browse the repository at this point in the history
  • Loading branch information
vzakharchenko committed Feb 2, 2020
1 parent b478eab commit c068d50
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package com.github.vzakharchenko.radsec.server;

import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.ReferenceCountUtil;
import org.jboss.logging.Logger;
import org.keycloak.models.KeycloakSession;

public class RadSecChannelInitializer extends ChannelInitializer<NioSocketChannel> {
private static final Logger LOGGER = Logger.getLogger(RadSecChannelInitializer.class);
private final IRadSecServerProvider sslProvider;
private final ChannelHandler logger;
private final ChannelHandler codec;
Expand All @@ -23,10 +27,25 @@ public RadSecChannelInitializer(IRadSecServerProvider sslProvider,
}

@Override
protected void initChannel(NioSocketChannel ch) throws Exception {
protected void initChannel(NioSocketChannel ch) {
ch.pipeline().addFirst(sslProvider.createHandler(ch));
ch.pipeline().addLast(logger);
ch.pipeline().addLast(codec);
ch.pipeline().addLast(handler);
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
LOGGER.error("Failed to initialize a channel. Closing: " + ctx.channel(), cause);
ctx.close();
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
try {
ctx.fireChannelRead(msg);
} finally {
ReferenceCountUtil.release(msg);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ public ChannelHandler createHandler(Channel ch) {
}
}

private EventLoopGroup createGroup(int threads) {
return new NioEventLoopGroup(threads);
}

public RadSecServerProvider(KeycloakSession session) {
super();
Expand All @@ -71,10 +74,10 @@ public RadSecServerProvider(KeycloakSession session) {
if (radSecSettings.isUseRadSec()) {
SecretProvider secretProvider = new KeycloakSecretProvider();
final PacketEncoder packetEncoder = createPacketEncoder(session);
EventLoopGroup bossGroup = new NioEventLoopGroup(radSecSettings.getnThreads());
ServerBootstrap serverBootstrap = new ServerBootstrap();
channelFuture = serverBootstrap
.channel(NioServerSocketChannel.class).group(bossGroup).clone()
.channel(NioServerSocketChannel.class).group(createGroup(6),
createGroup(radSecSettings.getnThreads())).clone()
.childHandler(new RadSecChannelInitializer(this,
new LoggingHandler(LogLevel.TRACE),
new RadSecCodec(packetEncoder, secretProvider), session))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package com.github.vzakharchenko.radsec.server;

import com.github.vzakharchenko.radius.test.AbstractRadiusTest;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.socket.nio.NioSocketChannel;
import org.mockito.Mock;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
import com.github.vzakharchenko.radius.test.AbstractRadiusTest;

import java.util.Arrays;
import java.util.List;

import static org.mockito.Mockito.*;

public class RadSecChannelInitializerTest extends AbstractRadiusTest {

@Mock
Expand All @@ -31,6 +34,24 @@ public void testRadSecChannelInitializer() throws Exception {
radSecChannelInitializer.initChannel(new NioSocketChannel());
}

@Test
public void testexceptionCaught() {
RadSecChannelInitializer radSecChannelInitializer =
new RadSecChannelInitializer(sslProvider, channel1, channel2, session);
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
radSecChannelInitializer.exceptionCaught(ctx, new Exception());
verify(ctx).close();
}

@Test
public void testChannelRead() {
RadSecChannelInitializer radSecChannelInitializer =
new RadSecChannelInitializer(sslProvider, channel1, channel2, session);
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
radSecChannelInitializer.channelRead(ctx, this);
verify(ctx).fireChannelRead(this);
}

@Override
protected List<? extends Object> resetMock() {
return Arrays.asList(channel1, channel2, sslProvider);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.github.vzakharchenko.radius.radius.handlers;

import io.netty.channel.ChannelHandlerContext;
import org.jboss.logging.Logger;
import org.tinyradius.server.RequestCtx;
import org.tinyradius.server.handler.RequestHandler;

Expand All @@ -10,14 +11,24 @@
public abstract class AbstractThreadRequestHandler
extends RequestHandler {

private static final Logger LOGGER = Logger.getLogger(AbstractThreadRequestHandler.class);

private static final ExecutorService EXECUTOR_SERVICE = Executors
.newCachedThreadPool();

@Override
protected final void channelRead0(ChannelHandlerContext ctx, RequestCtx msg) {
EXECUTOR_SERVICE.execute(() -> channelReadRadius(ctx, msg));
EXECUTOR_SERVICE.execute(() -> {
channelReadRadius(ctx, msg);
ctx.close();
});
}

protected abstract void channelReadRadius(ChannelHandlerContext ctx, RequestCtx msg);

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
LOGGER.error("Connection exception " + cause.getMessage(), cause);
ctx.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public void testChannelRead0() throws InterruptedException {
authHandler.getChannelHandler(session);
authHandler.setAuthRequestInitialization(authRequestInitialization);
authHandler.channelRead0(channelHandlerContext, requestCtx);
authHandler.exceptionCaught(channelHandlerContext, new IllegalStateException());
TimeUnit.SECONDS.sleep(3);
verify(channelHandlerContext).writeAndFlush(any());
}
Expand Down

0 comments on commit c068d50

Please sign in to comment.