Skip to content

Commit

Permalink
Fix #3369: javalib FileChannel unix-like now reports partial writes (#…
Browse files Browse the repository at this point in the history
…3370)

(cherry picked from commit 3bcc03e)
  • Loading branch information
LeeTibbert authored and WojciechMazur committed Sep 4, 2023
1 parent 7ad8306 commit 54e68d4
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 62 deletions.
160 changes: 98 additions & 62 deletions javalib/src/main/scala/java/nio/channels/FileChannelImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import scala.scalanative.nio.fs.unix.UnixException
import java.io.FileDescriptor
import java.io.File

import java.util.Objects

import scala.scalanative.meta.LinktimeInfo.isWindows
import java.io.IOException

Expand Down Expand Up @@ -51,6 +53,12 @@ private[java] final class FileChannelImpl(
private def ensureOpen(): Unit =
if (!isOpen()) throw new ClosedChannelException()

private def ensureOpenForWrite(): Unit = {
ensureOpen()
if (!openForWriting)
throw new NonWritableChannelException()
}

private def seekEOF(): Unit = {
if (isWindows) {
SetFilePointerEx(
Expand All @@ -76,10 +84,8 @@ private[java] final class FileChannelImpl(
override def force(metadata: Boolean): Unit =
fd.sync()

@inline private def assertIfCanLock(): Unit = {
if (!isOpen()) throw new ClosedChannelException()
if (!openForWriting) throw new NonWritableChannelException()
}
@inline private def assertIfCanLock(): Unit =
ensureOpenForWrite()

override def tryLock(
position: Long,
Expand Down Expand Up @@ -456,51 +462,62 @@ private[java] final class FileChannelImpl(
this
}

/* 2023-07-02 NOTE: This method is BROKEN! It should be returning
* an Int number of bytes written. It detects errors but not
* partial writes. Bad dog!
*
* Fix 'writeByteBuffer()' after this methods gets fixed.
* The former should return the actual number of bytes written
* on partial writes.
*/
private def writeArray(
array: Array[Byte],
offset: Int,
count: Int
): Int = {
// Precondition: caller has checked arguments.

val nWritten =
if (count == 0) 0
else {
// we use the runtime knowledge of the array layout to avoid an
// intermediate buffer, and read straight from the array memory.
val buf = array.at(offset)

if (isWindows) {
val hasSucceded =
WriteFile(fd.handle, buf, count.toUInt, null, null)
if (!hasSucceded) {
throw WindowsException.onPath(
file.fold("<file descriptor>")(_.toString)
)
}

count // Windows will fail on partial write, so nWritten == count
} else {
// unix-like may do partial writes, so be robust to them.
val writeCount = unistd.write(fd.fd, buf, count.toUInt)

if (writeCount < 0) {
// negative value (typically -1) indicates that write failed
throw UnixException(file.fold("")(_.toString), errno.errno)
}

writeCount // may be < requested count
}
}

nWritten
}

// since all of java package can call this, be stricter with argument checks.
private[java] def write(
buffer: Array[Byte],
offset: Int,
count: Int
): Unit = {
if (buffer == null) {
throw new NullPointerException
}
if (offset < 0 || count < 0 || count > buffer.length - offset) {
throw new IndexOutOfBoundsException
}
if (count == 0) {
return
}
): Int = {
Objects.requireNonNull(buffer, "buffer")

// we use the runtime knowledge of the array layout to avoid
// intermediate buffer, and read straight from the array memory
val buf = buffer.at(offset)
if (isWindows) {
val hasSucceded =
WriteFile(fd.handle, buf, count.toUInt, null, null)
if (!hasSucceded) {
throw WindowsException.onPath(
file.fold("<file descriptor>")(_.toString)
)
}
} else {
val writeCount = unistd.write(fd.fd, buf, count.toUInt)
if ((offset < 0) || (count < 0) || (count > buffer.length - offset))
throw new IndexOutOfBoundsException

if (writeCount < 0) {
// negative value (typically -1) indicates that write failed
throw UnixException(file.fold("")(_.toString), errno.errno)
}
}
writeArray(buffer, offset, count)
}

private def writeByteBuffer(src: ByteBuffer): Int = {
// Precondition: caller has ensured that channel is open and open for write
val srcPos = src.position()
val srcLim = src.limit()
val nBytes = srcLim - srcPos // number of bytes in range.
Expand All @@ -513,35 +530,54 @@ private[java] final class FileChannelImpl(
(ba, 0)
}

write(arr, offset, nBytes)

src.position(srcPos + nBytes)
val nWritten = writeArray(arr, offset, nBytes)

/* 2023-07-02 NOTE: This return is BROKEN! It does not handle
* partial OS writes. Fix after/when the 'write(arr, offset, nBytes)'
* method gets fixed to return a value.
/* Advance the srcPos only by the number of bytes actually written.
* This allows higher level callers to re-try partial writes
* in a 'natural' manner (no buffer futzing required).
*/
nBytes // BUGGY
}
src.position(srcPos + nWritten)

/* 2023-07-02 NOTE: This method is BROKEN! It should be returning
* an Long number of bytes written. Instead it is wrongly returning
* 'i' the number of buffers written. At least here the return type is
* correct.
*/
nWritten
}

override def write(
buffers: Array[ByteBuffer],
srcs: Array[ByteBuffer],
offset: Int,
length: Int
): Long = {
// write(ByteBuffer) will call ensureOpen(), saveCPU cycles by no call here
var i = 0
while (i < length) {
write(buffers(offset + i))
i += 1

Objects.requireNonNull(srcs, "srcs")

if ((offset < 0) ||
(offset > srcs.length) ||
(length < 0) ||
(length > srcs.length - offset))
throw new IndexOutOfBoundsException

ensureOpenForWrite()

var totalWritten = 0

var partialWriteSeen = false
var j = 0

while ((j < length) && !partialWriteSeen) {
val src = srcs(j)
val srcPos = src.position()
val srcLim = src.limit()
val nExpected = srcLim - srcPos // number of bytes in range.

val nWritten = writeByteBuffer(src)

totalWritten += nWritten
if (nWritten < nExpected)
partialWriteSeen = true

j += 1
}
i

totalWritten
}

/* Write to absolute position, do not change current position.
Expand All @@ -556,7 +592,7 @@ private[java] final class FileChannelImpl(
* really changing the "current position".
*/
override def write(src: ByteBuffer, pos: Long): Int = {
ensureOpen()
ensureOpenForWrite()
val stashPosition = position()
compelPosition(pos)

Expand All @@ -572,7 +608,7 @@ private[java] final class FileChannelImpl(

// Write relative to current position (SEEK_CUR) or, for APPEND, SEEK_END.
override def write(src: ByteBuffer): Int = {
ensureOpen()
ensureOpenForWrite()
writeByteBuffer(src)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,32 @@ class FileChannelTest {
}
}

@Test def writeOfMultipleBuffersReturnsTotalBytesWritten(): Unit = {
withTemporaryDirectory { dir =>
val f = dir.resolve("f")

val data = Array("Parsley", "sage", "rosemary", "thyme")

val nbytes = new Array[Int](data.size)
for (j <- 0 until data.size)
nbytes(j) = data(j).size

val srcs = new Array[ByteBuffer](data.size)
for (j <- 0 until data.size)
srcs(j) = ByteBuffer.wrap(data(j).getBytes("UTF-8"))

val expectedTotalWritten = nbytes.sum

val channel =
FileChannel.open(f, StandardOpenOption.CREATE, StandardOpenOption.WRITE)

try {
val nWritten = channel.write(srcs, 0, srcs.size)
assertEquals("total bytes written", expectedTotalWritten, nWritten)
} finally channel.close()
}
}

@Test def canMoveFilePointer(): Unit = {
withTemporaryDirectory { dir =>
val f = dir.resolve("f")
Expand Down

0 comments on commit 54e68d4

Please sign in to comment.