Skip to content

Commit

Permalink
perf: remove PGStream.streamBuffer and reuse PgBufferedOutputStream's…
Browse files Browse the repository at this point in the history
… buffer when sending data from InputStream
  • Loading branch information
vlsi committed May 19, 2024
1 parent 504c48f commit 94b7e3f
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 38 deletions.
53 changes: 17 additions & 36 deletions pgjdbc/src/main/java/org/postgresql/core/PGStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.postgresql.util.PSQLException;
import org.postgresql.util.PSQLState;
import org.postgresql.util.internal.PgBufferedOutputStream;
import org.postgresql.util.internal.SourceStreamIOException;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.ietf.jgss.GSSContext;
Expand Down Expand Up @@ -54,7 +55,6 @@ public class PGStream implements Closeable, Flushable {
private Socket connection;
private VisibleBufferedInputStream pgInput;
private PgBufferedOutputStream pgOutput;
private byte @Nullable [] streamBuffer;

public boolean isGssEncrypted() {
return gssEncrypted;
Expand Down Expand Up @@ -413,9 +413,9 @@ public void send(byte[] buf, int siz) throws IOException {
*/
public void send(byte[] buf, int off, int siz) throws IOException {
int bufamt = buf.length - off;
pgOutput.write(buf, off, bufamt < siz ? bufamt : siz);
for (int i = bufamt; i < siz; i++) {
pgOutput.write(0);
pgOutput.write(buf, off, Math.min(bufamt, siz));
if (siz > bufamt) {
pgOutput.writeZeros(siz - bufamt);
}
}

Expand All @@ -440,9 +440,7 @@ public OutputStream getOutputStream() {
} catch (Exception re) {
throw new IOException("Error writing bytes to stream", re);
}
for (int i = fixedLengthStream.remaining(); i > 0; i--) {
pgOutput.write(0);
}
pgOutput.writeZeros(fixedLengthStream.remaining());
}

/**
Expand Down Expand Up @@ -680,38 +678,21 @@ public void skip(int size) throws IOException {
*
* @param inStream the stream to read data from
* @param remaining the number of bytes to copy
* @throws IOException if a data I/O error occurs
* @throws IOException if error occurs when writing the data to the output stream
* @throws SourceStreamIOException if error occurs when reading the data from the input stream
*/
public void sendStream(InputStream inStream, int remaining) throws IOException {
int expectedLength = remaining;
byte[] streamBuffer = this.streamBuffer;
if (streamBuffer == null) {
this.streamBuffer = streamBuffer = new byte[8192];
}

while (remaining > 0) {
int count = remaining > streamBuffer.length ? streamBuffer.length : remaining;
int readCount;

try {
readCount = inStream.read(streamBuffer, 0, count);
if (readCount < 0) {
throw new EOFException(
GT.tr("Premature end of input stream, expected {0} bytes, but only read {1}.",
expectedLength, expectedLength - remaining));
}
} catch (IOException ioe) {
while (remaining > 0) {
send(streamBuffer, count);
remaining -= count;
count = remaining > streamBuffer.length ? streamBuffer.length : remaining;
}
throw new PGBindException(ioe);
}
pgOutput.write(inStream, remaining);
}

send(streamBuffer, readCount);
remaining -= readCount;
}
/**
* Writes the given amount of zero bytes to the output stream
* @param length the number of zeros to write
* @throws IOException in case writing to the output stream fails
* @throws SourceStreamIOException in case reading from the source stream fails
*/
public void sendZeros(int length) throws IOException {
pgOutput.writeZeros(length);
}

/**
Expand Down
12 changes: 10 additions & 2 deletions pgjdbc/src/main/java/org/postgresql/core/v3/QueryExecutorImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.postgresql.util.PSQLState;
import org.postgresql.util.PSQLWarning;
import org.postgresql.util.ServerErrorMessage;
import org.postgresql.util.internal.SourceStreamIOException;

import org.checkerframework.checker.nullness.qual.Nullable;

Expand Down Expand Up @@ -1771,8 +1772,15 @@ private void sendBind(SimpleQuery query, SimpleParameterList params, @Nullable P
pgStream.sendInteger4(params.getV3Length(i)); // Parameter size
try {
params.writeV3Value(i, pgStream); // Parameter value
} catch (PGBindException be) {
bindException = be;
} catch (SourceStreamIOException sse) {
// Remember the error for rethrow later
if (bindException == null) {
bindException = new PGBindException(sse.getCause());
} else {
bindException.addSuppressed(sse.getCause());
}
// Write out the missing bytes so the stream does not corrupt
pgStream.sendZeros(sse.getBytesRemaining());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
package org.postgresql.util.internal;

import org.postgresql.util.ByteConverter;
import org.postgresql.util.GT;

import java.io.EOFException;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Arrays;

/**
* Buffered output stream. The key difference from {@link java.io.BufferedOutputStream} is that
Expand Down Expand Up @@ -116,4 +120,75 @@ public void write(byte[] b, int off, int len) throws IOException {
System.arraycopy(b, off, buf, 0, len);
count = len;
}

/**
* Writes the given amount of bytes from an input stream to this buffered stream.
* @param inStream input data
* @param remaining the number of bytes to transfer
* @throws IOException in case writing to the output stream fails
* @throws SourceStreamIOException in case reading from the source stream fails
*/
public void write(InputStream inStream, int remaining) throws IOException {
int expectedLength = remaining;
byte[] buf = this.buf;

while (remaining > 0) {
int readSize = Math.min(remaining, buf.length - count);
int readCount;

try {
readCount = inStream.read(buf, count, readSize);
} catch (IOException e) {
throw new SourceStreamIOException(remaining, e);
}

if (readCount < 0) {
throw new SourceStreamIOException(
remaining,
new EOFException(
GT.tr("Premature end of input stream, expected {0} bytes, but only read {1}.",
expectedLength, expectedLength - remaining)));
}

count += readCount;
remaining -= readCount;
if (count == buf.length) {
flushBuffer();
}
}
}

/**
* Writes the required number of zero bytes to the output stream.
* @param len number of bytes to write
* @throws IOException in case writing to the underlying stream fails
*/
public void writeZeros(int len) throws IOException {
int startPos = count;
if (count > 0) {
int avail = buf.length - count;
int prefixLength = Math.min(len, avail);
Arrays.fill(buf, count, count + prefixLength, (byte) 0);
count += prefixLength;
len -= prefixLength;
if (count == buf.length) {
flushBuffer();
}
if (len == 0) {
return;
}
}
// The buffer is empty at this point, and startPos..buf.length is filled with zeros
// So fill the beginning with zeros as well.
Arrays.fill(buf, 0, Math.min(startPos, len), (byte) 0);

while (len >= buf.length) {
// Pretend we have the full buffer
count = buf.length;
flushBuffer();
len -= buf.length;
}
// Pretend we have the remaining zeros in the buffer.
count = len;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2024, PostgreSQL Global Development Group
* See the LICENSE file in the project root for more information.
*/

package org.postgresql.util.internal;

import static org.postgresql.util.internal.Nullness.castNonNull;

import java.io.IOException;

/**
* A marker exception class to distinguish between "IOException when reading the data" and
* "IOException when writing the data" when transferring data from one stream to another.
*/
public class SourceStreamIOException extends IOException {
/**
* The number of bytes remaining to transfer to the destination stream.
*/
private final int bytesRemaining;

public SourceStreamIOException(int bytesRemaining, IOException cause) {
super(cause);
this.bytesRemaining = bytesRemaining;
}

public int getBytesRemaining() {
return bytesRemaining;
}

@Override
public synchronized IOException getCause() {
return (IOException) castNonNull(super.getCause());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@

package org.postgresql.util.internal;

import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.params.provider.Arguments.arguments;

import org.postgresql.test.util.StrangeOutputStream;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;

import java.io.ByteArrayOutputStream;
Expand All @@ -22,6 +26,9 @@
import java.io.OutputStream;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;

Expand Down Expand Up @@ -167,6 +174,118 @@ void bufferFlushedOnlyWhenFull(int offset) throws IOException {
}
}

private static final int ZEROS_BUFFER_SIZE = 16;

public static Iterable<Arguments> zerosOffsetAndNumber() {
List<Arguments> res = new ArrayList<>();
for (int offset : new int[]{
0, 1, 2, 3,
ZEROS_BUFFER_SIZE - 2, ZEROS_BUFFER_SIZE - 1, ZEROS_BUFFER_SIZE - 1}) {
for (int numZeros : new int[]{
1, 2, 3,
ZEROS_BUFFER_SIZE - 2, ZEROS_BUFFER_SIZE - 1, ZEROS_BUFFER_SIZE - 1,
ZEROS_BUFFER_SIZE,
ZEROS_BUFFER_SIZE + 1, ZEROS_BUFFER_SIZE + 2, ZEROS_BUFFER_SIZE + 3,
ZEROS_BUFFER_SIZE * 2 - 2, ZEROS_BUFFER_SIZE * 2 - 1, ZEROS_BUFFER_SIZE * 2 - 1,
ZEROS_BUFFER_SIZE * 2,
ZEROS_BUFFER_SIZE * 2 + 1, ZEROS_BUFFER_SIZE * 2 + 2, ZEROS_BUFFER_SIZE * 2 + 3}) {
res.add(arguments(offset, numZeros));
}
}
return res;
}

@Nested
class ZeroTests {
@ParameterizedTest
@MethodSource("org.postgresql.util.internal.PgBufferedOutputStreamTest#zerosOffsetAndNumber")
void bufferFlushedOnlyWhenFull(int offset, int numZeros) throws IOException {
AssertBufferedWrites dst = new AssertBufferedWrites(new ByteArrayOutputStream());
int bufferSize = ZEROS_BUFFER_SIZE;
PgBufferedOutputStream out = new PgBufferedOutputStream(dst, bufferSize);

dst.forbidWrites("Writing less data than the buffer size should not cause flushes");
out.writeZeros(offset);
if (numZeros + offset >= bufferSize) {
dst.allowWrites(bufferSize);
}
out.writeZeros(numZeros);
dst.allowWrites((numZeros + offset) % bufferSize);
out.flush();
}

@ParameterizedTest
@MethodSource("org.postgresql.util.internal.PgBufferedOutputStreamTest#zerosOffsetAndNumber")
void writesZeros(int offset, int numZeros) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
int bufferSize = ZEROS_BUFFER_SIZE;
PgBufferedOutputStream out = new PgBufferedOutputStream(baos, bufferSize);
// Fill the buffer with non-zero values
for (int i = 0; i < bufferSize; i++) {
out.write(0xff);
}
// Shift the offset within the buffer
for (int i = 0; i < offset; i++) {
out.write(0xfe);
}
out.writeZeros(numZeros);
out.write(0xca);
out.flush();
byte[] res = baos.toByteArray();
assertAll(
() -> assertEquals(bufferSize + offset + numZeros + 1, res.length,
() -> "Result should have "
+ bufferSize + " 0xff prefix bytes, "
+ offset + " 0xfe offset bytes, "
+ numZeros + " 0x00 zero bytes"
+ ", and the final 0xca"
+ ", result: " + Arrays.toString(res)),
() -> {
for (int i = 0; i < bufferSize; i++) {
int pos = i;
assertEquals(
0xff, res[pos] & 0xff,
() -> "bytes [0.." + bufferSize + ") should be 0xff as they were written"
+ " with single-byte write(0xff) calls to fill the buffer"
+ ", mismatch at position " + pos
+ ", result: " + Arrays.toString(res));
}
},
() -> {
for (int i = 0; i < offset; i++) {
int pos = bufferSize + i;
assertEquals(
0xfe, res[pos] & 0xff,
() -> "bytes [" + bufferSize + ".." + (bufferSize + offset) + ")"
+ " should be 0xfe as they were written"
+ " with single-byte write(0xfe) calls to shift the buffer position"
+ ", mismatch at position " + pos
+ ", result: " + Arrays.toString(res));
}
},
() -> {
for (int i = 0; i < numZeros; i++) {
int pos = bufferSize + offset + i;
assertEquals(
0, res[pos] & 0xff,
() -> "bytes [" + (bufferSize + offset) + ".." + (bufferSize + offset + numZeros) + ")"
+ " should be 0xff as they were written"
+ " with writeZeros(len)"
+ ", mismatch at position " + pos
+ ", result: " + Arrays.toString(res));
}
},
() -> {
assertEquals(
0xca, res[bufferSize + offset + numZeros] & 0xff,
"the last byte should be 0xca as it was written with write(0xca)"
+ " as a terminator char"
+ ", result: " + Arrays.toString(res));
}
);
}
}

@Test
void writeAndCompare() throws IOException {
byte[] data = new byte[1024 * 1024];
Expand Down

0 comments on commit 94b7e3f

Please sign in to comment.