diff --git a/rust/src/database/migration.rs b/rust/src/database/migration.rs index 959f8124c6..c86a1c66e6 100644 --- a/rust/src/database/migration.rs +++ b/rust/src/database/migration.rs @@ -19,14 +19,16 @@ use std::{ cmp::max, - collections::VecDeque, fs::{File, OpenOptions}, - io::{BufRead, BufWriter, Read, Write}, + io::{BufRead, Read, Write}, marker::PhantomData, path::Path, }; -use prost::Message; +use prost::{ + bytes::{Buf, BytesMut}, + Message, +}; use typedb_protocol::migration::Item as MigrationItemProto; use crate::{error::MigrationError, Error, Result}; @@ -40,35 +42,42 @@ pub(crate) enum DatabaseExportAnswer { pub struct ProtoMessageIterator { reader: R, - buffer: VecDeque, + buffer: BytesMut, _phantom_data: PhantomData, } impl ProtoMessageIterator { - const BUF_CAPACITY: usize = 1024; + const BUF_CAPACITY: usize = 8 * 1024; // prost's length delimiters take up to 10 bytes const MAX_LENGTH_DELIMITER_LEN: usize = 10; pub fn new(reader: R) -> Self { - Self { reader, buffer: VecDeque::with_capacity(Self::BUF_CAPACITY), _phantom_data: PhantomData } + Self { reader, buffer: BytesMut::with_capacity(Self::BUF_CAPACITY), _phantom_data: PhantomData } } - fn try_read_more(&mut self, bytes_to_read: usize) -> std::io::Result { - let mut addition = vec![0; bytes_to_read]; + fn read_more(&mut self, bytes_to_read: usize) -> std::io::Result { + if self.buffer.capacity() - self.buffer.len() < bytes_to_read { + self.buffer.reserve(max(bytes_to_read, Self::BUF_CAPACITY)); + } + let mut addition = vec![0u8; max(bytes_to_read, 1)]; let bytes_read = self.reader.read(&mut addition)?; - self.buffer.extend(&addition[..bytes_read]); + self.buffer.extend_from_slice(&addition[..bytes_read]); Ok(bytes_read) } - fn try_get_next_message_len(&mut self) -> Result> { + fn decode_next_len(&mut self) -> Result> { loop { - if let Ok(len) = prost::decode_length_delimiter(&mut self.buffer) { - return Ok(Some(len)); - } else { - if self.buffer.len() < Self::MAX_LENGTH_DELIMITER_LEN { - assert!(Self::MAX_LENGTH_DELIMITER_LEN < Self::BUF_CAPACITY); - let to_read = max(Self::MAX_LENGTH_DELIMITER_LEN, Self::BUF_CAPACITY - self.buffer.len()); - match self.try_read_more(to_read) { + let mut cursor: &[u8] = &self.buffer; + match prost::decode_length_delimiter(&mut cursor) { + Ok(len) => { + let consumed = self.buffer.len() - cursor.len(); + return Ok(Some((len, consumed))); + } + Err(_) => { + if self.buffer.len() >= Self::MAX_LENGTH_DELIMITER_LEN { + return Err(Error::Migration(MigrationError::CannotDecodeImportedConceptLength)); + } + match self.read_more(Self::MAX_LENGTH_DELIMITER_LEN - self.buffer.len()) { Ok(bytes_read) if bytes_read == 0 => { return if self.buffer.is_empty() { Ok(None) @@ -79,40 +88,34 @@ impl ProtoMessageIterator { Err(_) => return Err(Error::Migration(MigrationError::CannotDecodeImportedConceptLength)), Ok(_) => continue, } - } else { - return Err(Error::Migration(MigrationError::CannotDecodeImportedConceptLength)); } } } } - - fn get_message_buf(&mut self, len: usize) -> VecDeque { - let message_buf = self.buffer.split_off(len); - std::mem::replace(&mut self.buffer, message_buf) - } } impl Iterator for ProtoMessageIterator { type Item = Result; fn next(&mut self) -> Option { - let message_len = match self.try_get_next_message_len() { - Ok(Some(len)) => len, + let (message_len, consumed) = match self.decode_next_len() { + Ok(Some(res)) => res, Ok(None) => return None, Err(err) => return Some(Err(err)), }; - if self.buffer.len() < message_len { - let required = message_len - self.buffer.len(); - let to_read = max(required, Self::BUF_CAPACITY); - match self.try_read_more(to_read) { - Ok(bytes_read) if bytes_read >= required => {} - _ => return Some(Err(Error::Migration(MigrationError::CannotDecodeImportedConcept))), + let required = consumed + message_len; + while self.buffer.len() < required { + let to_read = required - self.buffer.len(); + match self.read_more(max(to_read, Self::BUF_CAPACITY)) { + Ok(0) | Err(_) => return Some(Err(Error::Migration(MigrationError::CannotDecodeImportedConcept))), + Ok(_) => {} } } - let mut message_buf = self.get_message_buf(message_len); - Some(M::decode(&mut message_buf).map_err(|_| Error::Migration(MigrationError::CannotDecodeImportedConcept))) + self.buffer.advance(consumed); + let message_bytes = self.buffer.split_to(message_len).freeze(); + Some(M::decode(message_bytes).map_err(|_| Error::Migration(MigrationError::CannotDecodeImportedConcept))) } }