Use a ChannelFutureListener to release read memory after the Netty response is written.

Notes on this revision of the patch:
* Line structure is restored (one diff line per source line). Leading whitespace was rebuilt
  from the project's 2-space formatting; use `git apply --ignore-whitespace` if a context
  line does not match exactly.
* ShuffleServerNettyHandler: the success paths return early and release read memory from the
  listener, so the error paths still have to release it themselves. The former `finally`
  release is kept as the last statement of the final catch block, and the
  FileNotFoundException branch of handleGetLocalShuffleIndexRequest releases explicitly.
* The post-image blob id on the ShuffleServerNettyHandler `index` line predates these changes.

diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataRequest.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataRequest.java
index d80d0aa6d9..b96c028fbd 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataRequest.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataRequest.java
@@ -131,4 +131,9 @@ public int getLength() {
   public long getTimestamp() {
     return timestamp;
   }
+
+  @Override
+  public String getOperationType() {
+    return "getLocalShuffleData";
+  }
 }
diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexRequest.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexRequest.java
index 1ccdfae10f..105fea051d 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexRequest.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexRequest.java
@@ -93,4 +93,9 @@ public int getPartitionNumPerRange() {
   public int getPartitionNum() {
     return partitionNum;
   }
+
+  @Override
+  public String getOperationType() {
+    return "getLocalShuffleIndex";
+  }
 }
diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataRequest.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataRequest.java
index d358cf7cdf..13a2412414 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataRequest.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetMemoryShuffleDataRequest.java
@@ -148,4 +148,9 @@ public long getTimestamp() {
   public Roaring64NavigableMap getExpectedTaskIdsBitmap() {
     return expectedTaskIdsBitmap;
   }
+
+  @Override
+  public String getOperationType() {
+    return "getMemoryShuffleData";
+  }
 }
diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/RequestMessage.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RequestMessage.java
index cfa55287cf..946f906cc6 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/RequestMessage.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RequestMessage.java
@@ -35,4 +35,6 @@ public RequestMessage(long requestId, ManagedBuffer managedBuffer) {
   public long getRequestId() {
     return requestId;
   }
+
+  public abstract String getOperationType();
 }
diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
index 492b5b64b9..a77b0d3c7a 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java
@@ -145,4 +145,9 @@ public long getTimestamp() {
   public void setTimestamp(long timestamp) {
     this.timestamp = timestamp;
   }
+
+  @Override
+  public String getOperationType() {
+    return "sendShuffleData";
+  }
 }
diff --git a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
index ceca592113..4d42b0576f 100644
--- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
@@ -98,7 +98,12 @@ public ShuffleBufferManager(
     this.readCapacity = conf.getSizeAsBytes(ShuffleServerConf.SERVER_READ_BUFFER_CAPACITY);
     if (this.readCapacity < 0) {
       this.readCapacity =
-          (long) (heapSize * conf.getDouble(ShuffleServerConf.SERVER_READ_BUFFER_CAPACITY_RATIO));
+          nettyServerEnabled
+              ? (long)
+                  (NettyUtils.getMaxDirectMemory()
+                      * conf.getDouble(ShuffleServerConf.SERVER_READ_BUFFER_CAPACITY_RATIO))
+              : (long)
+                  (heapSize * conf.getDouble(ShuffleServerConf.SERVER_READ_BUFFER_CAPACITY_RATIO));
     }
     LOG.info(
         "Init shuffle buffer manager with capacity: {}, read buffer capacity: {}.",
diff --git a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
index ac8973ecc8..4dbc43f3bc 100644
--- a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
+++ b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
@@ -24,6 +24,8 @@
 
 import com.google.common.collect.Lists;
 import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -281,18 +283,13 @@ public void handleGetMemoryShuffleDataRequest(
           ShuffleServerMetrics.counterTotalReadDataSize.inc(data.size());
           ShuffleServerMetrics.counterTotalReadMemoryDataSize.inc(data.size());
         }
-        long costTime = System.currentTimeMillis() - start;
-        shuffleServer
-            .getNettyMetrics()
-            .recordProcessTime(GetMemoryShuffleDataRequest.class.getName(), costTime);
-        LOG.info(
-            "Successfully getInMemoryShuffleData cost {} ms with {} bytes shuffle" + " data for {}",
-            costTime,
-            data.size(),
-            requestInfo);
-
         response =
             new GetMemoryShuffleDataResponse(req.getRequestId(), status, msg, bufferSegments, data);
+        ReleaseMemoryAndRecordReadTimeListener listener =
+            new ReleaseMemoryAndRecordReadTimeListener(
+                start, readBufferSize, data.size(), requestInfo, req, client);
+        client.getChannel().writeAndFlush(response).addListener(listener);
+        return;
       } catch (Exception e) {
         status = StatusCode.INTERNAL_ERROR;
         msg =
@@ -304,8 +301,7 @@ public void handleGetMemoryShuffleDataRequest(
         response =
             new GetMemoryShuffleDataResponse(
                 req.getRequestId(), status, msg, Lists.newArrayList(), Unpooled.EMPTY_BUFFER);
-      } finally {
         shuffleServer.getShuffleBufferManager().releaseReadMemory(readBufferSize);
       }
     } else {
       status = StatusCode.INTERNAL_ERROR;
@@ -361,12 +357,12 @@ public void handleGetLocalShuffleIndexRequest(
         response =
             new GetLocalShuffleIndexResponse(
                 req.getRequestId(), status, msg, data, shuffleIndexResult.getDataFileLen());
-        long readTime = System.currentTimeMillis() - start;
-        LOG.info(
-            "Successfully getShuffleIndex cost {} ms for {}" + " bytes with {}",
-            readTime,
-            data.size(),
-            requestInfo);
+        ReleaseMemoryAndRecordReadTimeListener listener =
+            new ReleaseMemoryAndRecordReadTimeListener(
+                start, assumedFileSize, data.size(), requestInfo, req, client);
+        client.getChannel().writeAndFlush(response).addListener(listener);
+        return;
       } catch (FileNotFoundException indexFileNotFoundException) {
+        shuffleServer.getShuffleBufferManager().releaseReadMemory(assumedFileSize);
         LOG.warn(
             "Index file for {} is not found, maybe the data has been flushed to cold storage.",
@@ -382,8 +378,7 @@ public void handleGetLocalShuffleIndexRequest(
         response =
             new GetLocalShuffleIndexResponse(
                 req.getRequestId(), status, msg, Unpooled.EMPTY_BUFFER, 0L);
-      } finally {
         shuffleServer.getShuffleBufferManager().releaseReadMemory(assumedFileSize);
       }
     } else {
       status = StatusCode.INTERNAL_ERROR;
@@ -459,20 +454,16 @@ public void handleGetLocalShuffleData(TransportClient client, GetLocalShuffleDat
                     storageType,
                     offset,
                     length);
-        long readTime = System.currentTimeMillis() - start;
-        ShuffleServerMetrics.counterTotalReadTime.inc(readTime);
         ShuffleServerMetrics.counterTotalReadDataSize.inc(sdr.getDataLength());
         ShuffleServerMetrics.counterTotalReadLocalDataFileSize.inc(sdr.getDataLength());
-        shuffleServer
-            .getNettyMetrics()
-            .recordProcessTime(GetLocalShuffleDataRequest.class.getName(), readTime);
-        LOG.info(
-            "Successfully getShuffleData cost {} ms for shuffle" + " data with {}",
-            readTime,
-            requestInfo);
         response =
             new GetLocalShuffleDataResponse(
                 req.getRequestId(), status, msg, sdr.getManagedBuffer());
+        ReleaseMemoryAndRecordReadTimeListener listener =
+            new ReleaseMemoryAndRecordReadTimeListener(
+                start, length, sdr.getDataLength(), requestInfo, req, client);
+        client.getChannel().writeAndFlush(response).addListener(listener);
+        return;
       } catch (Exception e) {
         status = StatusCode.INTERNAL_ERROR;
         msg = "Error happened when get shuffle data for " + requestInfo + ", " + e.getMessage();
@@ -480,8 +471,7 @@ public void handleGetLocalShuffleData(TransportClient client, GetLocalShuffleDat
         response =
             new GetLocalShuffleDataResponse(
                 req.getRequestId(), status, msg, new NettyManagedBuffer(Unpooled.EMPTY_BUFFER));
-      } finally {
         shuffleServer.getShuffleBufferManager().releaseReadMemory(length);
       }
     } else {
       status = StatusCode.INTERNAL_ERROR;
@@ -522,4 +512,95 @@ private ShufflePartitionedBlock[] toPartitionedBlock(List bloc
     }
     return ret;
   }
+
+  /**
+   * Listener attached to the write of a read response. When the write completes, it releases
+   * {@code readBufferSize} of read memory, records the elapsed time since
+   * {@code readStartedTime}, and on failure logs the cause and attempts to send an
+   * INTERNAL_ERROR response of the type matching the request.
+   */
+  class ReleaseMemoryAndRecordReadTimeListener implements ChannelFutureListener {
+    private final long readStartedTime;
+    private final long readBufferSize;
+    private final long dataSize;
+    private final String requestInfo;
+    private final RequestMessage request;
+    private final TransportClient client;
+
+    public ReleaseMemoryAndRecordReadTimeListener(
+        long readStartedTime,
+        long readBufferSize,
+        long dataSize,
+        String requestInfo,
+        RequestMessage request,
+        TransportClient client) {
+      this.readStartedTime = readStartedTime;
+      this.readBufferSize = readBufferSize;
+      this.dataSize = dataSize;
+      this.requestInfo = requestInfo;
+      this.request = request;
+      this.client = client;
+    }
+
+    /** Runs once the response write has finished, whether it succeeded or failed. */
+    @Override
+    public void operationComplete(ChannelFuture future) {
+      shuffleServer.getShuffleBufferManager().releaseReadMemory(readBufferSize);
+      long readTime = System.currentTimeMillis() - readStartedTime;
+      ShuffleServerMetrics.counterTotalReadTime.inc(readTime);
+      shuffleServer.getNettyMetrics().recordProcessTime(request.getClass().getName(), readTime);
+      if (!future.isSuccess()) {
+        Throwable cause = future.cause();
+        String errorMsg =
+            "Error happened when executing "
+                + request.getOperationType()
+                + " for "
+                + requestInfo
+                + ", "
+                + cause.getMessage();
+        LOG.error(errorMsg, cause);
+        RpcResponse errorResponse;
+        if (request instanceof GetLocalShuffleDataRequest) {
+          errorResponse =
+              new GetLocalShuffleDataResponse(
+                  request.getRequestId(),
+                  StatusCode.INTERNAL_ERROR,
+                  errorMsg,
+                  new NettyManagedBuffer(Unpooled.EMPTY_BUFFER));
+        } else if (request instanceof GetLocalShuffleIndexRequest) {
+          errorResponse =
+              new GetLocalShuffleIndexResponse(
+                  request.getRequestId(),
+                  StatusCode.INTERNAL_ERROR,
+                  errorMsg,
+                  Unpooled.EMPTY_BUFFER,
+                  0L);
+        } else if (request instanceof GetMemoryShuffleDataRequest) {
+          errorResponse =
+              new GetMemoryShuffleDataResponse(
+                  request.getRequestId(),
+                  StatusCode.INTERNAL_ERROR,
+                  errorMsg,
+                  Lists.newArrayList(),
+                  Unpooled.EMPTY_BUFFER);
+        } else {
+          throw new RssException("Can not handle request " + request.type());
+        }
+        client.getChannel().writeAndFlush(errorResponse);
+        LOG.error(
+            "Failed to execute {} for {}. Took {} ms and could not retrieve {} bytes of data",
+            request.getOperationType(),
+            requestInfo,
+            readTime,
+            dataSize);
+      } else {
+        LOG.info(
+            "Successfully executed {} for {}. Took {} ms and retrieved {} bytes of data",
+            request.getOperationType(),
+            requestInfo,
+            readTime,
+            dataSize);
+      }
+    }
+  }
 }