From 0bbea8c06057755b76da171dd5fd32c9ef342c43 Mon Sep 17 00:00:00 2001 From: Ti Chi Robot Date: Tue, 15 Aug 2023 10:31:59 +0800 Subject: [PATCH] encryption: fix offset inconsistency between crypter and file (#15092) (#15207) close tikv/tikv#15080 Fix offset inconsistency between crypter and file that could cause data corruption when file I/O is interrupted. Signed-off-by: tabokie Co-authored-by: tabokie Co-authored-by: tonyxuqqi --- components/encryption/src/io.rs | 205 +++++++++++++++++++++----------- 1 file changed, 136 insertions(+), 69 deletions(-) diff --git a/components/encryption/src/io.rs b/components/encryption/src/io.rs index e02aafabe88..fafdd7a2af9 100644 --- a/components/encryption/src/io.rs +++ b/components/encryption/src/io.rs @@ -251,7 +251,11 @@ impl Read for CrypterReader { fn read(&mut self, buf: &mut [u8]) -> IoResult { let count = self.reader.read(buf)?; if let Some(crypter) = self.crypter.as_mut() { - crypter.do_crypter_in_place(&mut buf[..count])?; + if let Err(e) = crypter.do_crypter_in_place(&mut buf[..count]) { + // FIXME: We can't recover from this without rollback `reader` to old offset. + // But that requires `Seek` which requires a wider refactor of user code. + panic!("`do_crypter_in_place` failed: {:?}", e); + } } Ok(count) } @@ -283,7 +287,9 @@ impl AsyncRead for CrypterReader { }; if let Some(crypter) = inner.crypter.as_mut() { if let Err(e) = crypter.do_crypter_in_place(&mut buf[..read_count]) { - return Poll::Ready(Err(e)); + // FIXME: We can't recover from this without rollback `reader` to old offset. + // But that requires `Seek` which requires a wider refactor of user code. + panic!("`do_crypter_in_place` failed: {:?}", e); } } Poll::Ready(Ok(read_count)) @@ -330,7 +336,10 @@ impl Write for CrypterWriter { if let Some(crypter) = self.crypter.as_mut() { let crypted = crypter.do_crypter(buf)?; debug_assert!(crypted.len() == buf.len()); - self.writer.write(crypted) + let r = self.writer.write(crypted); + let missing = buf.len() - r.as_ref().unwrap_or(&0); + crypter.lazy_reset_crypter(crypter.offset - missing as u64); + r } else { self.writer.write(buf) } @@ -388,6 +397,10 @@ struct CrypterCore { key: Vec, mode: Mode, initial_iv: Iv, + + // Used to ensure the atomicity of operation over a chunk of data. Only advance it when + // operation succeeds. + offset: u64, crypter: Option, block_size: usize, @@ -401,9 +414,10 @@ impl CrypterCore { method, key: key.to_owned(), mode, + initial_iv: iv, + offset: 0, crypter: None, block_size: 0, - initial_iv: iv, buffer: Vec::new(), }) } @@ -414,6 +428,17 @@ impl CrypterCore { self.buffer.resize(size + self.block_size, 0); } + // Delay the reset to future operations that use crypter. Guarantees those + // operations can only succeed after crypter is properly reset. + pub fn lazy_reset_crypter(&mut self, offset: u64) { + if self.offset != offset { + self.crypter.take(); + self.offset = offset; + } + } + + // It has the same guarantee as `lazy_reset_crypter`. In addition, it attempts + // to reset immediately and returns any error. pub fn reset_crypter(&mut self, offset: u64) -> IoResult<()> { let mut iv = self.initial_iv; iv.add_offset(offset / AES_BLOCK_SIZE as u64)?; @@ -424,6 +449,7 @@ impl CrypterCore { self.reset_buffer(partial_offset); let crypter_count = crypter.update(&partial_block, &mut self.buffer)?; if crypter_count != partial_offset { + self.lazy_reset_crypter(offset); return Err(IoError::new( ErrorKind::Other, format!( @@ -432,6 +458,7 @@ impl CrypterCore { ), )); } + self.offset = offset; self.crypter = Some(crypter); self.block_size = cipher.block_size(); Ok(()) @@ -443,7 +470,7 @@ impl CrypterCore { /// this code needs to be updated. pub fn do_crypter_in_place(&mut self, buf: &mut [u8]) -> IoResult<()> { if self.crypter.is_none() { - self.reset_crypter(0)?; + self.reset_crypter(self.offset)?; } let count = buf.len(); self.reset_buffer(std::cmp::min(count, MAX_INPLACE_CRYPTION_SIZE)); @@ -454,6 +481,7 @@ impl CrypterCore { debug_assert!(self.buffer.len() >= target - encrypted); let crypter_count = crypter.update(&buf[encrypted..target], &mut self.buffer)?; if crypter_count != target - encrypted { + self.crypter.take(); return Err(IoError::new( ErrorKind::Other, format!( @@ -466,18 +494,20 @@ impl CrypterCore { buf[encrypted..target].copy_from_slice(&self.buffer[..crypter_count]); encrypted += crypter_count; } + self.offset += count as u64; Ok(()) } pub fn do_crypter(&mut self, buf: &[u8]) -> IoResult<&[u8]> { if self.crypter.is_none() { - self.reset_crypter(0)?; + self.reset_crypter(self.offset)?; } let count = buf.len(); self.reset_buffer(count); let crypter = self.crypter.as_mut().unwrap(); let crypter_count = crypter.update(buf, &mut self.buffer)?; if crypter_count != count { + self.crypter.take(); return Err(IoError::new( ErrorKind::Other, format!( @@ -486,6 +516,7 @@ impl CrypterCore { ), )); } + self.offset += count as u64; Ok(&self.buffer[..count]) } @@ -508,7 +539,6 @@ mod tests { use std::{cmp::min, io::Cursor}; use byteorder::{BigEndian, ByteOrder}; - use futures::AsyncReadExt; use rand::{rngs::OsRng, RngCore}; use super::*; @@ -521,6 +551,58 @@ mod tests { key } + struct DecoratedCursor { + cursor: Cursor>, + read_size: usize, + } + + impl DecoratedCursor { + fn new(buff: Vec, read_size: usize) -> DecoratedCursor { + Self { + cursor: Cursor::new(buff.to_vec()), + read_size, + } + } + + fn into_inner(self) -> Vec { + self.cursor.into_inner() + } + } + + impl AsyncRead for DecoratedCursor { + fn poll_read( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let len = min(self.read_size, buf.len()); + Poll::Ready(self.cursor.read(&mut buf[..len])) + } + } + + impl Read for DecoratedCursor { + fn read(&mut self, buf: &mut [u8]) -> IoResult { + let len = min(self.read_size, buf.len()); + self.cursor.read(&mut buf[..len]) + } + } + + impl Write for DecoratedCursor { + fn write(&mut self, buf: &[u8]) -> IoResult { + let len = min(self.read_size, buf.len()); + self.cursor.write(&buf[0..len]) + } + fn flush(&mut self) -> IoResult<()> { + self.cursor.flush() + } + } + + impl Seek for DecoratedCursor { + fn seek(&mut self, s: SeekFrom) -> IoResult { + self.cursor.seek(s) + } + } + #[test] fn test_decrypt_encrypted_text() { let methods = [ @@ -552,24 +634,30 @@ mod tests { let mut plaintext = vec![0; 1024]; OsRng.fill_bytes(&mut plaintext); - let buf = Vec::with_capacity(1024); - let mut encrypter = EncrypterWriter::new(buf, method, &key, iv).unwrap(); + let mut encrypter = EncrypterWriter::new( + DecoratedCursor::new(plaintext.clone(), 1), + method, + &key, + iv, + ) + .unwrap(); encrypter.write_all(&plaintext).unwrap(); - let buf = encrypter.finalize().unwrap(); + let encrypted = encrypter.finalize().unwrap().into_inner(); // Make sure it's properly encrypted. if method != EncryptionMethod::Plaintext { - assert_ne!(buf, plaintext); + assert_ne!(encrypted, plaintext); } else { - assert_eq!(buf, plaintext); + assert_eq!(encrypted, plaintext); } - let buf_reader = std::io::Cursor::new(buf); - let mut decrypter = DecrypterReader::new(buf_reader, method, &key, iv).unwrap(); + let mut decrypter = + DecrypterReader::new(DecoratedCursor::new(encrypted, 1), method, &key, iv) + .unwrap(); let mut piece = vec![0; 5]; // Read the first two blocks randomly. for i in 0..31 { assert_eq!(decrypter.seek(SeekFrom::Start(i as u64)).unwrap(), i as u64); - assert_eq!(decrypter.read(&mut piece).unwrap(), piece.len()); + decrypter.read_exact(&mut piece).unwrap(); assert_eq!(piece, plaintext[i..i + piece.len()]); } // Read the rest of the data sequentially. @@ -579,13 +667,14 @@ mod tests { cursor as u64 ); while cursor + piece.len() <= plaintext.len() { - assert_eq!(decrypter.read(&mut piece).unwrap(), piece.len()); + decrypter.read_exact(&mut piece).unwrap(); assert_eq!(piece, plaintext[cursor..cursor + piece.len()]); cursor += piece.len(); } let tail = plaintext.len() - cursor; - assert_eq!(decrypter.read(&mut piece).unwrap(), tail); - assert_eq!(piece[..tail], plaintext[cursor..cursor + tail]); + let mut short_piece = vec![0; tail]; + decrypter.read_exact(&mut short_piece).unwrap(); + assert_eq!(short_piece[..], plaintext[cursor..cursor + tail]); } } } @@ -605,9 +694,10 @@ mod tests { let sizes = [1024, 10240]; for method in methods { let key = generate_data_key(method); - let readable_text = std::io::Cursor::new(plaintext.clone()); let iv = Iv::new_ctr(); - let encrypter = EncrypterReader::new(readable_text, method, &key, iv).unwrap(); + let encrypter = + EncrypterReader::new(DecoratedCursor::new(plaintext.clone(), 1), method, &key, iv) + .unwrap(); let mut decrypter = DecrypterReader::new(encrypter, method, &key, iv).unwrap(); let mut read = vec![0; 10240]; for offset in offsets { @@ -617,7 +707,7 @@ mod tests { offset as u64 ); let actual_size = std::cmp::min(plaintext.len().saturating_sub(offset), size); - assert_eq!(decrypter.read(&mut read[..size]).unwrap(), actual_size); + decrypter.read_exact(&mut read[..actual_size]).unwrap(); if actual_size > 0 { assert_eq!(read[..actual_size], plaintext[offset..offset + actual_size]); } @@ -642,13 +732,14 @@ mod tests { let written = vec![0; 10240]; for method in methods { let key = generate_data_key(method); - let writable_text = std::io::Cursor::new(written.clone()); let iv = Iv::new_ctr(); - let encrypter = EncrypterWriter::new(writable_text, method, &key, iv).unwrap(); + let encrypter = + EncrypterWriter::new(DecoratedCursor::new(written.clone(), 1), method, &key, iv) + .unwrap(); let mut decrypter = DecrypterWriter::new(encrypter, method, &key, iv).unwrap(); // First write full data. assert_eq!(decrypter.seek(SeekFrom::Start(0)).unwrap(), 0); - assert_eq!(decrypter.write(&plaintext).unwrap(), plaintext.len()); + decrypter.write_all(&plaintext).unwrap(); // Then overwrite specific locations. for offset in offsets { for size in sizes { @@ -657,10 +748,9 @@ mod tests { offset as u64 ); let size = std::cmp::min(plaintext.len().saturating_sub(offset), size); - assert_eq!( - decrypter.write(&plaintext[offset..offset + size]).unwrap(), - size - ); + decrypter + .write_all(&plaintext[offset..offset + size]) + .unwrap(); } } let written = decrypter @@ -673,33 +763,8 @@ mod tests { } } - struct MockCursorReader { - cursor: Cursor>, - read_maxsize_once: usize, - } - - impl MockCursorReader { - fn new(buff: &mut [u8], size_once: usize) -> MockCursorReader { - Self { - cursor: Cursor::new(buff.to_vec()), - read_maxsize_once: size_once, - } - } - } - - impl AsyncRead for MockCursorReader { - fn poll_read( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - let len = min(self.read_maxsize_once, buf.len()); - let r = self.cursor.read(&mut buf[..len]).unwrap(); - Poll::Ready(IoResult::Ok(r)) - } - } - async fn test_poll_read() { + use futures::AsyncReadExt; let methods = [ EncryptionMethod::Plaintext, EncryptionMethod::Aes128Ctr, @@ -716,38 +781,39 @@ mod tests { // encrypt plaintext into encrypt_text let read_once = 16; let mut encrypt_reader = EncrypterReader::new( - MockCursorReader::new(&mut plain_text[..], read_once), + DecoratedCursor::new(plain_text.clone(), read_once), method, &key[..], iv, ) .unwrap(); - let mut encrypt_text = [0; 20480]; + let mut encrypt_text = vec![0; 20480]; let mut encrypt_read_len = 0; loop { - let read_len = encrypt_reader - .read(&mut encrypt_text[encrypt_read_len..]) - .await - .unwrap(); + let read_len = + AsyncReadExt::read(&mut encrypt_reader, &mut encrypt_text[encrypt_read_len..]) + .await + .unwrap(); if read_len == 0 { break; } encrypt_read_len += read_len; } + encrypt_text.truncate(encrypt_read_len); if method == EncryptionMethod::Plaintext { - assert_eq!(encrypt_text[..encrypt_read_len], plain_text); + assert_eq!(encrypt_text, plain_text); } else { - assert_ne!(encrypt_text[..encrypt_read_len], plain_text); + assert_ne!(encrypt_text, plain_text); } // decrypt encrypt_text into decrypt_text - let mut decrypt_text = [0; 20480]; + let mut decrypt_text = vec![0; 20480]; let mut decrypt_read_len = 0; let read_once = 20; let mut decrypt_reader = DecrypterReader::new( - MockCursorReader::new(&mut encrypt_text[..encrypt_read_len], read_once), + DecoratedCursor::new(encrypt_text.clone(), read_once), method, &key[..], iv, @@ -755,17 +821,18 @@ mod tests { .unwrap(); loop { - let read_len = decrypt_reader - .read(&mut decrypt_text[decrypt_read_len..]) - .await - .unwrap(); + let read_len = + AsyncReadExt::read(&mut decrypt_reader, &mut decrypt_text[decrypt_read_len..]) + .await + .unwrap(); if read_len == 0 { break; } decrypt_read_len += read_len; } - assert_eq!(decrypt_text[..decrypt_read_len], plain_text); + decrypt_text.truncate(decrypt_read_len); + assert_eq!(decrypt_text, plain_text); } }