Skip to content

Commit

Permalink
Add BytesCountingFilterInputStream for counting bytes read
Browse files Browse the repository at this point in the history
See title
  • Loading branch information
ywangd committed Jul 18, 2024
1 parent 7ba5b4f commit 432a756
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.common.io.stream;

import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;

public class BytesCountingFilterInputStream extends FilterInputStream {

private int bytesRead = 0;

public BytesCountingFilterInputStream(InputStream in) {
super(in);
}

@Override
public int read() throws IOException {
assert assertInvariant();
final int result = super.read();
if (result != -1) {
bytesRead += 1;
}
return result;
}

// Not overriding read(byte[]) because FilterInputStream delegates to read(byte[], int, int)

@Override
public int read(byte[] b, int off, int len) throws IOException {
assert assertInvariant();
final int n = super.read(b, off, len);
if (n != -1) {
bytesRead += n;
}
return n;
}

@Override
public long skip(long n) throws IOException {
assert assertInvariant();
final long skipped = super.skip(n);
bytesRead += Math.toIntExact(skipped);
return skipped;
}

public int getBytesRead() {
return bytesRead;
}

protected boolean assertInvariant() {
return true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.common.io.stream;

import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.test.ESTestCase;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.Arrays;

import static org.hamcrest.Matchers.equalTo;

public class BytesCountingFilterInputStreamTests extends ESTestCase {

public void testBytesCounting() throws IOException {
final byte[] input = randomByteArrayOfLength(between(500, 1000));
final var in = new BytesCountingFilterInputStream(new ByteArrayInputStream(input));

assertThat(in.getBytesRead(), equalTo(0));

final CheckedConsumer<Integer, IOException> readRandomly = (Integer length) -> {
switch (between(0, 3)) {
case 0 -> {
for (var i = 0; i < length; i++) {
final int bytesBefore = in.getBytesRead();
final int result = in.read();
assertThat((byte) result, equalTo(input[bytesBefore]));
assertThat(in.getBytesRead(), equalTo(bytesBefore + 1));
}
}
case 1 -> {
final int bytesBefore = in.getBytesRead();
final byte[] b;
if (randomBoolean()) {
b = in.readNBytes(length);
} else {
b = new byte[length];
assertThat(in.read(b), equalTo(length));
}
assertArrayEquals(Arrays.copyOfRange(input, bytesBefore, bytesBefore + length), b);
assertThat(in.getBytesRead(), equalTo(bytesBefore + length));
}
case 2 -> {
final int bytesBefore = in.getBytesRead();
final byte[] b = new byte[length * between(2, 5)];
if (randomBoolean()) {
assertThat(in.read(b, length / 2, length), equalTo(length));
} else {
assertThat(in.readNBytes(b, length / 2, length), equalTo(length));
}
assertArrayEquals(
Arrays.copyOfRange(input, bytesBefore, bytesBefore + length),
Arrays.copyOfRange(b, length / 2, length / 2 + length)
);
assertThat(in.getBytesRead(), equalTo(bytesBefore + length));
}
case 3 -> {
final int bytesBefore = in.getBytesRead();
if (randomBoolean()) {
assertThat((int) in.skip(length), equalTo(length));
} else {
in.skipNBytes(length);
}
assertThat(in.getBytesRead(), equalTo(bytesBefore + length));
}
default -> fail("unexpected");
}
};

while (in.getBytesRead() < input.length - 50) {
readRandomly.accept(between(1, 30));
}

final int bytesBefore = in.getBytesRead();
final byte[] remainingBytes = in.readAllBytes();
assertThat(in.getBytesRead(), equalTo(bytesBefore + remainingBytes.length));
assertThat(in.getBytesRead(), equalTo(input.length));

// Read beyond available data has no effect
in.read();
final byte[] bytes = new byte[between(20, 30)];
in.read(bytes);
in.read(bytes, between(3, 5), between(5, 10));
in.skip(between(10, 20));

assertThat(in.getBytesRead(), equalTo(input.length));
}
}

0 comments on commit 432a756

Please sign in to comment.