Skip to content

Commit

Permalink
fix [javalib]: FileChannel scattering api respect offsets (#3907)
Browse files Browse the repository at this point in the history
Filechannel impliementation of ScatteringByteChannel discard the
buffer offsets when reading and writing

Add extra check for index sizes
  • Loading branch information
RustedBones committed May 10, 2024
1 parent 1a42105 commit eefc50e
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 46 deletions.
64 changes: 30 additions & 34 deletions javalib/src/main/scala/java/nio/channels/FileChannelImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -257,27 +257,33 @@ private[java] final class FileChannelImpl(
start: Int,
number: Int
): Long = {
Objects.requireNonNull(buffers, "dsts")
Objects.checkFromIndexSize(start, number, buffers.length)

ensureOpen()

var bytesRead = 0L
var i = 0

while (i < number) {
val startPos = buffers(i).position()
val len = buffers(i).limit() - startPos
val dst = new Array[Byte](len)
val nb = read(dst, 0, dst.length)

if (nb > 0) {
buffers(i).put(dst)
buffers(i).position(startPos + nb)
var partialReadSeen = false
var totalRead = 0L
while (i < number && !partialReadSeen) {
val dst = buffers(start + i)
val len = dst.remaining()

val bs = new Array[Byte](len)
val n = read(bs, 0, len)

if (n > 0) {
dst.put(bs, 0, n)
totalRead += n
}

bytesRead += nb
if (n < len) {
partialReadSeen = true
}
i += 1
}

bytesRead
totalRead
}

override def read(buffer: ByteBuffer, pos: Long): Int = {
Expand Down Expand Up @@ -676,35 +682,25 @@ private[java] final class FileChannelImpl(
offset: Int,
length: Int
): Long = {

Objects.requireNonNull(srcs, "srcs")

if ((offset < 0) ||
(offset > srcs.length) ||
(length < 0) ||
(length > srcs.length - offset))
throw new IndexOutOfBoundsException
Objects.checkFromIndexSize(offset, length, srcs.length)

ensureOpenForWrite()

var totalWritten = 0

var i = 0
var partialWriteSeen = false
var j = 0

while ((j < length) && !partialWriteSeen) {
val src = srcs(j)
val srcPos = src.position()
val srcLim = src.limit()
val nExpected = srcLim - srcPos // number of bytes in range.
var totalWritten = 0
while (i < length && !partialWriteSeen) {
val src = srcs(offset + i)
val len = src.remaining()

val nWritten = writeByteBuffer(src)
val n = writeByteBuffer(src)

totalWritten += nWritten
if (nWritten < nExpected)
totalWritten += n
if (n < len) {
partialWriteSeen = true

j += 1
}
i += 1
}

totalWritten
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,28 +145,63 @@ class FileChannelTest {
assertTrue(Files.getAttribute(f, "size") == 5)

val channel = FileChannel.open(f)
val bufferA = ByteBuffer.allocate(2)
val bufferB = ByteBuffer.allocate(3)
val buffers = Array[ByteBuffer](bufferA, bufferB)

val bread = channel.read(buffers)
bufferA.flip()
bufferB.flip()
val offset = 1
val limit = 2
val dsts = Array[ByteBuffer](
ByteBuffer.allocate(1),
ByteBuffer.allocate(2),
ByteBuffer.allocate(3),
ByteBuffer.allocate(4)
)

assertTrue(bufferA.limit() == 2)
assertTrue(bufferB.limit() == 3)
assertTrue(bufferA.position() == 0)
assertTrue(bufferB.position() == 0)
val bread = channel.read(dsts, offset, limit)
dsts.foreach(_.flip())

assertTrue(bread == 5L)
assertTrue(bufferA.array() sameElements Array[Byte](1, 2))
assertTrue(bufferB.array() sameElements Array[Byte](3, 4, 5))

assertTrue(dsts(0).remaining() == 0)
assertTrue(dsts(1).remaining() == 2)
assertTrue(dsts(2).remaining() == 3)
assertTrue(dsts(3).remaining() == 0)

assertTrue(dsts(1).array() sameElements Array[Byte](1, 2))
assertTrue(dsts(2).array() sameElements Array[Byte](3, 4, 5))

channel.close()
}
}

@Test def fileChannelCanWriteToFile(): Unit = {
withTemporaryDirectory { dir =>
val f = dir.resolve("f")
val offset = 1
val limit = 3
val srcs = Array[ByteBuffer](
ByteBuffer.wrap(Array[Byte](1)),
ByteBuffer.wrap(Array[Byte](2, 3)),
ByteBuffer.wrap(Array[Byte](4, 5, 6)),
ByteBuffer.wrap(Array[Byte](7, 8, 9, 10))
)
val channel =
FileChannel.open(f, StandardOpenOption.WRITE, StandardOpenOption.CREATE)

val expected = Array[Byte](2, 3, 4, 5, 6)
var written = 0
while (written < expected.length) {
written += channel.write(srcs, offset, limit).toInt
}

val in = Files.newInputStream(f)
var i = 0
while (i < expected.length) {
assertTrue(in.read() == expected(i))
i += 1
}
}
}

@Test def fileChannelCanWriteBuffersToFile(): Unit = {
withTemporaryDirectory { dir =>
val f = dir.resolve("f")
val bytes = Array.apply[Byte](1, 2, 3, 4, 5)
Expand Down

0 comments on commit eefc50e

Please sign in to comment.