From 8212fc513c66ebd9996180456305ecd6c425d5da Mon Sep 17 00:00:00 2001 From: Jan Verbeek Date: Sun, 3 Mar 2024 12:57:28 +0100 Subject: [PATCH] Fix quadratic behavior of repeated vectored writes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Some implementations of `Write::write_vectored` in the standard library (`BufWriter`, `LineWriter`, `Stdout`, `Stderr`) check all buffers to calculate the total length. This is O(n) over the number of buffers. It's common that only a limited number of buffers is written at a time (e.g. 1024 for `writev(2)`). `write_vectored_all` will then call `write_vectored` repeatedly, leading to a runtime of O(n²) over the number of buffers. The fix is to only calculate as much as needed if it's needed. --- library/std/src/io/buffered/bufwriter.rs | 51 ++++++++++--------- library/std/src/io/buffered/linewritershim.rs | 15 ++++-- library/std/src/io/stdio.rs | 15 ++++-- 3 files changed, 50 insertions(+), 31 deletions(-) diff --git a/library/std/src/io/buffered/bufwriter.rs b/library/std/src/io/buffered/bufwriter.rs index 665d8602c0800..2d13230ffbabd 100644 --- a/library/std/src/io/buffered/bufwriter.rs +++ b/library/std/src/io/buffered/bufwriter.rs @@ -554,35 +554,38 @@ impl Write for BufWriter { // same underlying buffer, as otherwise the buffers wouldn't fit in memory). If the // computation overflows, then surely the input cannot fit in our buffer, so we forward // to the inner writer's `write_vectored` method to let it handle it appropriately. - let saturated_total_len = - bufs.iter().fold(0usize, |acc, b| acc.saturating_add(b.len())); + let mut saturated_total_len: usize = 0; - if saturated_total_len > self.spare_capacity() { - // Flush if the total length of the input exceeds our buffer's spare capacity. - // If we would have overflowed, this condition also holds, and we need to flush. - self.flush_buf()?; + for buf in bufs { + saturated_total_len = saturated_total_len.saturating_add(buf.len()); + + if saturated_total_len > self.spare_capacity() && !self.buf.is_empty() { + // Flush if the total length of the input exceeds our buffer's spare capacity. + // If we would have overflowed, this condition also holds, and we need to flush. + self.flush_buf()?; + } + + if saturated_total_len >= self.buf.capacity() { + // Forward to our inner writer if the total length of the input is greater than or + // equal to our buffer capacity. If we would have overflowed, this condition also + // holds, and we punt to the inner writer. + self.panicked = true; + let r = self.get_mut().write_vectored(bufs); + self.panicked = false; + return r; + } } - if saturated_total_len >= self.buf.capacity() { - // Forward to our inner writer if the total length of the input is greater than or - // equal to our buffer capacity. If we would have overflowed, this condition also - // holds, and we punt to the inner writer. - self.panicked = true; - let r = self.get_mut().write_vectored(bufs); - self.panicked = false; - r - } else { - // `saturated_total_len < self.buf.capacity()` implies that we did not saturate. + // `saturated_total_len < self.buf.capacity()` implies that we did not saturate. - // SAFETY: We checked whether or not the spare capacity was large enough above. If - // it was, then we're safe already. If it wasn't, we flushed, making sufficient - // room for any input <= the buffer size, which includes this input. - unsafe { - bufs.iter().for_each(|b| self.write_to_buffer_unchecked(b)); - }; + // SAFETY: We checked whether or not the spare capacity was large enough above. If + // it was, then we're safe already. If it wasn't, we flushed, making sufficient + // room for any input <= the buffer size, which includes this input. + unsafe { + bufs.iter().for_each(|b| self.write_to_buffer_unchecked(b)); + }; - Ok(saturated_total_len) - } + Ok(saturated_total_len) } else { let mut iter = bufs.iter(); let mut total_written = if let Some(buf) = iter.by_ref().find(|&buf| !buf.is_empty()) { diff --git a/library/std/src/io/buffered/linewritershim.rs b/library/std/src/io/buffered/linewritershim.rs index 7eadfd413661d..c3ac7855d4450 100644 --- a/library/std/src/io/buffered/linewritershim.rs +++ b/library/std/src/io/buffered/linewritershim.rs @@ -175,6 +175,10 @@ impl<'a, W: ?Sized + Write> Write for LineWriterShim<'a, W> { } // Find the buffer containing the last newline + // FIXME: This is overly slow if there are very many bufs and none contain + // newlines. e.g. writev() on Linux only writes up to 1024 slices, so + // scanning the rest is wasted effort. This makes write_all_vectored() + // quadratic. let last_newline_buf_idx = bufs .iter() .enumerate() @@ -215,9 +219,14 @@ impl<'a, W: ?Sized + Write> Write for LineWriterShim<'a, W> { // Don't try to reconstruct the exact amount written; just bail // in the event of a partial write - let lines_len = lines.iter().map(|buf| buf.len()).sum(); - if flushed < lines_len { - return Ok(flushed); + let mut lines_len: usize = 0; + for buf in lines { + // With overlapping/duplicate slices the total length may in theory + // exceed usize::MAX + lines_len = lines_len.saturating_add(buf.len()); + if flushed < lines_len { + return Ok(flushed); + } } // Now that the write has succeeded, buffer the rest (or as much of the diff --git a/library/std/src/io/stdio.rs b/library/std/src/io/stdio.rs index 455de6d98bae6..42fc0053030af 100644 --- a/library/std/src/io/stdio.rs +++ b/library/std/src/io/stdio.rs @@ -128,8 +128,8 @@ impl Write for StdoutRaw { } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - let total = bufs.iter().map(|b| b.len()).sum(); - handle_ebadf(self.0.write_vectored(bufs), total) + let total = || bufs.iter().map(|b| b.len()).sum(); + handle_ebadf_lazy(self.0.write_vectored(bufs), total) } #[inline] @@ -160,8 +160,8 @@ impl Write for StderrRaw { } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - let total = bufs.iter().map(|b| b.len()).sum(); - handle_ebadf(self.0.write_vectored(bufs), total) + let total = || bufs.iter().map(|b| b.len()).sum(); + handle_ebadf_lazy(self.0.write_vectored(bufs), total) } #[inline] @@ -193,6 +193,13 @@ fn handle_ebadf(r: io::Result, default: T) -> io::Result { } } +fn handle_ebadf_lazy(r: io::Result, default: impl FnOnce() -> T) -> io::Result { + match r { + Err(ref e) if stdio::is_ebadf(e) => Ok(default()), + r => r, + } +} + /// A handle to the standard input stream of a process. /// /// Each handle is a shared reference to a global buffer of input data to this