diff --git a/applications/tari_console_wallet/src/cli.rs b/applications/tari_console_wallet/src/cli.rs index 8930930cf3..57962e1014 100644 --- a/applications/tari_console_wallet/src/cli.rs +++ b/applications/tari_console_wallet/src/cli.rs @@ -27,7 +27,10 @@ use clap::{Args, Parser, Subcommand}; use tari_app_utilities::{common_cli_args::CommonCliArgs, utilities::UniPublicKey}; use tari_comms::multiaddr::Multiaddr; use tari_core::transactions::{tari_amount, tari_amount::MicroTari}; -use tari_utilities::hex::{Hex, HexError}; +use tari_utilities::{ + hex::{Hex, HexError}, + SafePassword, +}; const DEFAULT_NETWORK: &str = "dibbler"; @@ -45,7 +48,7 @@ pub(crate) struct Cli { /// command line, since it's visible using `ps ax` from anywhere on the system, so always use the env var where /// possible. #[clap(long, env = "TARI_WALLET_PASSWORD", hide_env_values = true)] - pub password: Option, + pub password: Option, /// Change the password for the console wallet #[clap(long, alias = "update-password")] pub change_password: bool, diff --git a/applications/tari_console_wallet/src/init/mod.rs b/applications/tari_console_wallet/src/init/mod.rs index 9fdc08ac5b..b9a353d3ed 100644 --- a/applications/tari_console_wallet/src/init/mod.rs +++ b/applications/tari_console_wallet/src/init/mod.rs @@ -39,6 +39,7 @@ use tari_crypto::keys::PublicKey; use tari_key_manager::{cipher_seed::CipherSeed, mnemonic::MnemonicLanguage}; use tari_p2p::{initialization::CommsInitializationError, peer_seeds::SeedPeer, TransportType}; use tari_shutdown::ShutdownSignal; +use tari_utilities::SafePassword; use tari_wallet::{ error::{WalletError, WalletStorageError}, output_manager_service::storage::database::OutputManagerDatabase, @@ -72,20 +73,19 @@ pub enum WalletBoot { /// Gets the password provided by command line argument or environment variable if available. /// Otherwise prompts for the password to be typed in. pub fn get_or_prompt_password( - arg_password: Option, - config_password: Option, -) -> Result, ExitError> { + arg_password: Option, + config_password: Option, +) -> Result, ExitError> { if arg_password.is_some() { return Ok(arg_password); } let env = std::env::var_os(TARI_WALLET_PASSWORD); if let Some(p) = env { - let env_password = Some( - p.into_string() - .map_err(|_| ExitError::new(ExitCode::IOError, "Failed to convert OsString into String"))?, - ); - return Ok(env_password); + let env_password = p + .into_string() + .map_err(|_| ExitError::new(ExitCode::IOError, "Failed to convert OsString into String"))?; + return Ok(Some(env_password.into())); } if config_password.is_some() { @@ -97,7 +97,7 @@ pub fn get_or_prompt_password( Ok(Some(password)) } -fn prompt_password(prompt: &str) -> Result { +fn prompt_password(prompt: &str) -> Result { let password = loop { let pass = prompt_password_stdout(prompt).map_err(|e| ExitError::new(ExitCode::IOError, e))?; if pass.is_empty() { @@ -108,13 +108,13 @@ fn prompt_password(prompt: &str) -> Result { } }; - Ok(password) + Ok(SafePassword::from(password)) } /// Allows the user to change the password of the wallet. pub async fn change_password( config: &ApplicationConfig, - arg_password: Option, + arg_password: Option, shutdown_signal: ShutdownSignal, ) -> Result<(), ExitError> { let mut wallet = init_wallet(config, arg_password, None, None, shutdown_signal).await?; @@ -221,7 +221,7 @@ pub(crate) fn wallet_mode(cli: &Cli, boot_mode: WalletBoot) -> WalletMode { #[allow(clippy::too_many_lines)] pub async fn init_wallet( config: &ApplicationConfig, - arg_password: Option, + arg_password: Option, seed_words_file_name: Option, recovery_seed: Option, shutdown_signal: ShutdownSignal, diff --git a/base_layer/wallet/src/config.rs b/base_layer/wallet/src/config.rs index 869e12851a..b480833cfb 100644 --- a/base_layer/wallet/src/config.rs +++ b/base_layer/wallet/src/config.rs @@ -34,6 +34,7 @@ use tari_common::{ }; use tari_comms::multiaddr::Multiaddr; use tari_p2p::P2pConfig; +use tari_utilities::SafePassword; use crate::{ base_node_service::config::BaseNodeServiceConfig, @@ -72,7 +73,7 @@ pub struct WalletConfig { /// The main wallet db sqlite database backend connection pool size for concurrent reads pub db_connection_pool_size: usize, /// The main wallet password - pub password: Option, // TODO: Make clear on drop + pub password: Option, /// The auto ping interval to use for contacts liveness data #[serde(with = "serializers::seconds")] pub contacts_auto_ping_interval: Duration, diff --git a/base_layer/wallet/src/key_manager_service/storage/sqlite_db/key_manager_state.rs b/base_layer/wallet/src/key_manager_service/storage/sqlite_db/key_manager_state.rs index 05b5df969c..8ee80df77c 100644 --- a/base_layer/wallet/src/key_manager_service/storage/sqlite_db/key_manager_state.rs +++ b/base_layer/wallet/src/key_manager_service/storage/sqlite_db/key_manager_state.rs @@ -151,24 +151,38 @@ pub struct KeyManagerStateUpdateSql { } impl Encryptable for KeyManagerStateSql { + fn domain(&self, field_name: &'static str) -> Vec { + [Self::KEY_MANAGER, self.branch_seed.as_bytes(), field_name.as_bytes()] + .concat() + .to_vec() + } + fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { - let encrypted_index = encrypt_bytes_integral_nonce(cipher, self.primary_key_index.clone())?; - self.primary_key_index = encrypted_index; + self.primary_key_index = + encrypt_bytes_integral_nonce(cipher, self.domain("primary_key_index"), self.primary_key_index.clone())?; + Ok(()) } fn decrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { - let decrypted_index = decrypt_bytes_integral_nonce(cipher, self.primary_key_index.clone())?; - self.primary_key_index = decrypted_index; + self.primary_key_index = + decrypt_bytes_integral_nonce(cipher, self.domain("primary_key_index"), self.primary_key_index.clone())?; Ok(()) } } impl Encryptable for NewKeyManagerStateSql { + fn domain(&self, field_name: &'static str) -> Vec { + [Self::KEY_MANAGER, self.branch_seed.as_bytes(), field_name.as_bytes()] + .concat() + .to_vec() + } + fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { - let encrypted_index = encrypt_bytes_integral_nonce(cipher, self.primary_key_index.clone())?; - self.primary_key_index = encrypted_index; + self.primary_key_index = + encrypt_bytes_integral_nonce(cipher, self.domain("primary_key_index"), self.primary_key_index.clone())?; + Ok(()) } diff --git a/base_layer/wallet/src/output_manager_service/storage/sqlite_db/mod.rs b/base_layer/wallet/src/output_manager_service/storage/sqlite_db/mod.rs index 09fdafe15d..73af04cd9c 100644 --- a/base_layer/wallet/src/output_manager_service/storage/sqlite_db/mod.rs +++ b/base_layer/wallet/src/output_manager_service/storage/sqlite_db/mod.rs @@ -1456,13 +1456,23 @@ impl From for KnownOneSidedPaymentScriptSql { } impl Encryptable for KnownOneSidedPaymentScriptSql { + fn domain(&self, field_name: &'static str) -> Vec { + [ + Self::KNOWN_ONESIDED_PAYMENT_SCRIPT, + self.script_hash.as_slice(), + field_name.as_bytes(), + ] + .concat() + .to_vec() + } + fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { - self.private_key = encrypt_bytes_integral_nonce(cipher, self.private_key.clone())?; + self.private_key = encrypt_bytes_integral_nonce(cipher, self.domain("private_key"), self.private_key.clone())?; Ok(()) } fn decrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { - self.private_key = decrypt_bytes_integral_nonce(cipher, self.private_key.clone())?; + self.private_key = decrypt_bytes_integral_nonce(cipher, self.domain("private_key"), self.private_key.clone())?; Ok(()) } } diff --git a/base_layer/wallet/src/output_manager_service/storage/sqlite_db/new_output_sql.rs b/base_layer/wallet/src/output_manager_service/storage/sqlite_db/new_output_sql.rs index a48219f5af..8c4e97a5d4 100644 --- a/base_layer/wallet/src/output_manager_service/storage/sqlite_db/new_output_sql.rs +++ b/base_layer/wallet/src/output_manager_service/storage/sqlite_db/new_output_sql.rs @@ -121,15 +121,36 @@ impl NewOutputSql { } impl Encryptable for NewOutputSql { + fn domain(&self, field_name: &'static str) -> Vec { + // WARNING: using `OUTPUT` for both NewOutputSql and OutputSql due to later transition without re-encryption + [Self::OUTPUT, self.script.as_slice(), field_name.as_bytes()] + .concat() + .to_vec() + } + fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { - self.spending_key = encrypt_bytes_integral_nonce(cipher, self.spending_key.clone())?; - self.script_private_key = encrypt_bytes_integral_nonce(cipher, self.script_private_key.clone())?; + self.spending_key = + encrypt_bytes_integral_nonce(cipher, self.domain("spending_key"), self.spending_key.clone())?; + + self.script_private_key = encrypt_bytes_integral_nonce( + cipher, + self.domain("script_private_key"), + self.script_private_key.clone(), + )?; + Ok(()) } fn decrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { - self.spending_key = decrypt_bytes_integral_nonce(cipher, self.spending_key.clone())?; - self.script_private_key = decrypt_bytes_integral_nonce(cipher, self.script_private_key.clone())?; + self.spending_key = + decrypt_bytes_integral_nonce(cipher, self.domain("spending_key"), self.spending_key.clone())?; + + self.script_private_key = decrypt_bytes_integral_nonce( + cipher, + self.domain("script_private_key"), + self.script_private_key.clone(), + )?; + Ok(()) } } diff --git a/base_layer/wallet/src/output_manager_service/storage/sqlite_db/output_sql.rs b/base_layer/wallet/src/output_manager_service/storage/sqlite_db/output_sql.rs index 63480ee86c..8e6cbdd476 100644 --- a/base_layer/wallet/src/output_manager_service/storage/sqlite_db/output_sql.rs +++ b/base_layer/wallet/src/output_manager_service/storage/sqlite_db/output_sql.rs @@ -746,15 +746,36 @@ impl TryFrom for DbUnblindedOutput { } impl Encryptable for OutputSql { + fn domain(&self, field_name: &'static str) -> Vec { + // WARNING: using `OUTPUT` for both NewOutputSql and OutputSql due to later transition without re-encryption + [Self::OUTPUT, self.script.as_slice(), field_name.as_bytes()] + .concat() + .to_vec() + } + fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { - self.spending_key = encrypt_bytes_integral_nonce(cipher, self.spending_key.clone())?; - self.script_private_key = encrypt_bytes_integral_nonce(cipher, self.script_private_key.clone())?; + self.spending_key = + encrypt_bytes_integral_nonce(cipher, self.domain("spending_key"), self.spending_key.clone())?; + + self.script_private_key = encrypt_bytes_integral_nonce( + cipher, + self.domain("script_private_key"), + self.script_private_key.clone(), + )?; + Ok(()) } fn decrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { - self.spending_key = decrypt_bytes_integral_nonce(cipher, self.spending_key.clone())?; - self.script_private_key = decrypt_bytes_integral_nonce(cipher, self.script_private_key.clone())?; + self.spending_key = + decrypt_bytes_integral_nonce(cipher, self.domain("spending_key"), self.spending_key.clone())?; + + self.script_private_key = decrypt_bytes_integral_nonce( + cipher, + self.domain("script_private_key"), + self.script_private_key.clone(), + )?; + Ok(()) } } diff --git a/base_layer/wallet/src/storage/database.rs b/base_layer/wallet/src/storage/database.rs index 9cb0acba60..9050ef9a34 100644 --- a/base_layer/wallet/src/storage/database.rs +++ b/base_layer/wallet/src/storage/database.rs @@ -34,6 +34,7 @@ use tari_comms::{ tor::TorIdentity, }; use tari_key_manager::cipher_seed::CipherSeed; +use tari_utilities::SafePassword; use crate::{error::WalletStorageError, utxo_scanner_service::service::ScannedBlock}; @@ -46,7 +47,7 @@ pub trait WalletBackend: Send + Sync + Clone { /// Modify the state the of the backend with a write operation fn write(&self, op: WriteOperation) -> Result, WalletStorageError>; /// Apply encryption to the backend. - fn apply_encryption(&self, passphrase: String) -> Result; + fn apply_encryption(&self, passphrase: SafePassword) -> Result; /// Remove encryption from the backend. fn remove_encryption(&self) -> Result<(), WalletStorageError>; @@ -276,7 +277,7 @@ where T: WalletBackend + 'static Ok(()) } - pub async fn apply_encryption(&self, passphrase: String) -> Result { + pub async fn apply_encryption(&self, passphrase: SafePassword) -> Result { let db_clone = self.db.clone(); tokio::task::spawn_blocking(move || db_clone.apply_encryption(passphrase)) .await diff --git a/base_layer/wallet/src/storage/sqlite_db/wallet.rs b/base_layer/wallet/src/storage/sqlite_db/wallet.rs index f97972c438..7f207e86a5 100644 --- a/base_layer/wallet/src/storage/sqlite_db/wallet.rs +++ b/base_layer/wallet/src/storage/sqlite_db/wallet.rs @@ -25,11 +25,7 @@ use std::{ sync::{Arc, RwLock}, }; -use aes_gcm::{ - aead::{generic_array::GenericArray, Aead}, - Aes256Gcm, - NewAead, -}; +use aes_gcm::{aead::generic_array::GenericArray, Aes256Gcm, NewAead}; use argon2::{ password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, Argon2, @@ -46,6 +42,7 @@ use tari_key_manager::cipher_seed::CipherSeed; use tari_utilities::{ hex::{from_hex, Hex}, message_format::MessageFormat, + SafePassword, }; use tokio::time::Instant; @@ -57,7 +54,13 @@ use crate::{ sqlite_db::scanned_blocks::ScannedBlockSql, sqlite_utilities::wallet_db_connection::WalletDbConnection, }, - util::encryption::{decrypt_bytes_integral_nonce, encrypt_bytes_integral_nonce, Encryptable, AES_NONCE_BYTES}, + util::encryption::{ + decrypt_bytes_integral_nonce, + encrypt_bytes_integral_nonce, + Encryptable, + AES_MAC_BYTES, + AES_NONCE_BYTES, + }, utxo_scanner_service::service::ScannedBlock, }; @@ -72,7 +75,7 @@ pub struct WalletSqliteDatabase { impl WalletSqliteDatabase { pub fn new( database_connection: WalletDbConnection, - passphrase: Option, + passphrase: Option, ) -> Result { let cipher = check_db_encryption_status(&database_connection, passphrase)?; @@ -94,8 +97,9 @@ impl WalletSqliteDatabase { }, Some(cipher) => { let seed_bytes = seed.encipher(None)?; - let ciphertext_integral_nonce = encrypt_bytes_integral_nonce(cipher, seed_bytes) - .map_err(|e| WalletStorageError::AeadError(format!("Encryption Error:{}", e)))?; + let ciphertext_integral_nonce = + encrypt_bytes_integral_nonce(cipher, b"wallet_setting_master_seed".to_vec(), seed_bytes) + .map_err(|e| WalletStorageError::AeadError(format!("Encryption Error:{}", e)))?; WalletSettingSql::new(DbKey::MasterSeed.to_string(), ciphertext_integral_nonce.to_hex()).set(conn)?; }, } @@ -109,8 +113,12 @@ impl WalletSqliteDatabase { let seed = match cipher.as_ref() { None => CipherSeed::from_enciphered_bytes(&from_hex(seed_str.as_str())?, None)?, Some(cipher) => { - let decrypted_key_bytes = decrypt_bytes_integral_nonce(cipher, from_hex(seed_str.as_str())?) - .map_err(|e| WalletStorageError::AeadError(format!("Decryption Error:{}", e)))?; + let decrypted_key_bytes = decrypt_bytes_integral_nonce( + cipher, + b"wallet_setting_master_seed".to_vec(), + from_hex(seed_str.as_str())?, + ) + .map_err(|e| WalletStorageError::AeadError(format!("Decryption Error:{}", e)))?; CipherSeed::from_enciphered_bytes(&decrypted_key_bytes, None)? }, }; @@ -171,8 +179,10 @@ impl WalletSqliteDatabase { }, Some(cipher) => { let bytes = bincode::serialize(&tor).map_err(|e| WalletStorageError::ConversionError(e.to_string()))?; - let ciphertext_integral_nonce = encrypt_bytes_integral_nonce(cipher, bytes) - .map_err(|e| WalletStorageError::AeadError(format!("Encryption Error:{}", e)))?; + let ciphertext_integral_nonce = + encrypt_bytes_integral_nonce(cipher, b"wallet_setting_tor_id".to_vec(), bytes) + .map_err(|e| WalletStorageError::AeadError(format!("Encryption Error:{}", e)))?; + WalletSettingSql::new(DbKey::TorId.to_string(), ciphertext_integral_nonce.to_hex()).set(conn)?; }, } @@ -188,8 +198,10 @@ impl WalletSqliteDatabase { TorIdentity::from_json(&key_str).map_err(|e| WalletStorageError::ConversionError(e.to_string()))? }, Some(cipher) => { - let decrypted_key_bytes = decrypt_bytes_integral_nonce(cipher, from_hex(&key_str)?) - .map_err(|e| WalletStorageError::AeadError(format!("Decryption Error:{}", e)))?; + let decrypted_key_bytes = + decrypt_bytes_integral_nonce(cipher, b"wallet_setting_tor_id".to_vec(), from_hex(&key_str)?) + .map_err(|e| WalletStorageError::AeadError(format!("Decryption Error:{}", e)))?; + bincode::deserialize(&decrypted_key_bytes) .map_err(|e| WalletStorageError::ConversionError(e.to_string()))? }, @@ -383,7 +395,7 @@ impl WalletBackend for WalletSqliteDatabase { } } - fn apply_encryption(&self, passphrase: String) -> Result { + fn apply_encryption(&self, passphrase: SafePassword) -> Result { let mut current_cipher = acquire_write_lock!(self.cipher); if current_cipher.is_some() { return Err(WalletStorageError::AlreadyEncrypted); @@ -404,16 +416,17 @@ impl WalletBackend for WalletSqliteDatabase { let passphrase_salt = SaltString::generate(&mut OsRng); let passphrase_hash = argon2 - .hash_password_simple(passphrase.as_bytes(), &passphrase_salt) + .hash_password_simple(passphrase.reveal(), &passphrase_salt) .map_err(|e| WalletStorageError::AeadError(e.to_string()))? .to_string(); let encryption_salt = SaltString::generate(&mut OsRng); let derived_encryption_key = argon2 - .hash_password_simple(passphrase.as_bytes(), encryption_salt.as_str()) + .hash_password_simple(passphrase.reveal(), encryption_salt.as_str()) .map_err(|e| WalletStorageError::AeadError(e.to_string()))? .hash .ok_or_else(|| WalletStorageError::AeadError("Problem generating encryption key hash".to_string()))?; + let key = GenericArray::from_slice(derived_encryption_key.as_bytes()); let cipher = Aes256Gcm::new(key); @@ -424,11 +437,14 @@ impl WalletBackend for WalletSqliteDatabase { None => return Err(WalletStorageError::ValueNotFound(DbKey::MasterSeed)), Some(sk) => sk, }; + let master_seed_bytes = from_hex(master_seed_str.as_str())?; + // Sanity check that the decrypted bytes are a valid CipherSeed let _master_seed = CipherSeed::from_enciphered_bytes(&master_seed_bytes, None)?; - let ciphertext_integral_nonce = encrypt_bytes_integral_nonce(&cipher, master_seed_bytes) - .map_err(|e| WalletStorageError::AeadError(format!("Encryption Error:{}", e)))?; + let ciphertext_integral_nonce = + encrypt_bytes_integral_nonce(&cipher, b"wallet_setting_master_seed".to_vec(), master_seed_bytes) + .map_err(|e| WalletStorageError::AeadError(format!("Encryption Error:{}", e)))?; WalletSettingSql::new(DbKey::MasterSeed.to_string(), ciphertext_integral_nonce.to_hex()).set(&conn)?; // Encrypt all the client values @@ -444,8 +460,9 @@ impl WalletBackend for WalletSqliteDatabase { if let Some(v) = tor_id { let tor = TorIdentity::from_json(&v).map_err(|e| WalletStorageError::ConversionError(e.to_string()))?; let bytes = bincode::serialize(&tor).map_err(|e| WalletStorageError::ConversionError(e.to_string()))?; - let ciphertext_integral_nonce = encrypt_bytes_integral_nonce(&cipher, bytes) - .map_err(|e| WalletStorageError::AeadError(format!("Encryption Error:{}", e)))?; + let ciphertext_integral_nonce = + encrypt_bytes_integral_nonce(&cipher, b"wallet_setting_tor_id".to_vec(), bytes) + .map_err(|e| WalletStorageError::AeadError(format!("Encryption Error:{}", e)))?; WalletSettingSql::new(DbKey::TorId.to_string(), ciphertext_integral_nonce.to_hex()).set(&conn)?; } @@ -478,8 +495,13 @@ impl WalletBackend for WalletSqliteDatabase { Some(sk) => sk, }; - let master_seed_bytes = decrypt_bytes_integral_nonce(&cipher, from_hex(master_seed_str.as_str())?) - .map_err(|e| WalletStorageError::AeadError(format!("Decryption Error:{}", e)))?; + let master_seed_bytes = decrypt_bytes_integral_nonce( + &cipher, + b"wallet_setting_master_seed".to_vec(), + from_hex(master_seed_str.as_str())?, + ) + .map_err(|e| WalletStorageError::AeadError(format!("Decryption Error:{}", e)))?; + // Sanity check that the decrypted bytes are a valid CipherSeed let _master_seed = CipherSeed::from_enciphered_bytes(&master_seed_bytes, None)?; WalletSettingSql::new(DbKey::MasterSeed.to_string(), master_seed_bytes.to_hex()).set(&conn)?; @@ -498,10 +520,13 @@ impl WalletBackend for WalletSqliteDatabase { // remove tor id encryption if present let key_str = WalletSettingSql::get(DbKey::TorId.to_string(), &conn)?; if let Some(v) = key_str { - let decrypted_key_bytes = decrypt_bytes_integral_nonce(&cipher, from_hex(v.as_str())?) - .map_err(|e| WalletStorageError::AeadError(format!("Decryption Error:{}", e)))?; + let decrypted_key_bytes = + decrypt_bytes_integral_nonce(&cipher, b"wallet_setting_tor_id".to_vec(), from_hex(v.as_str())?) + .map_err(|e| WalletStorageError::AeadError(format!("Decryption Error:{}", e)))?; + let tor_id: TorIdentity = bincode::deserialize(&decrypted_key_bytes) .map_err(|e| WalletStorageError::ConversionError(e.to_string()))?; + let tor_string = tor_id .to_json() .map_err(|e| WalletStorageError::ConversionError(e.to_string()))?; @@ -560,7 +585,7 @@ impl WalletBackend for WalletSqliteDatabase { /// Master Public Key that is stored in the db fn check_db_encryption_status( database_connection: &WalletDbConnection, - passphrase: Option, + passphrase: Option, ) -> Result, WalletStorageError> { let start = Instant::now(); let conn = database_connection.get_pooled_connection()?; @@ -581,13 +606,14 @@ fn check_db_encryption_status( let argon2 = Argon2::default(); let stored_hash = PasswordHash::new(&db_passphrase_hash).map_err(|e| WalletStorageError::AeadError(e.to_string()))?; - if let Err(e) = argon2.verify_password(passphrase.as_bytes(), &stored_hash) { + + if let Err(e) = argon2.verify_password(passphrase.reveal(), &stored_hash) { error!(target: LOG_TARGET, "Incorrect passphrase ({})", e); return Err(WalletStorageError::InvalidPassphrase); } let derived_encryption_key = argon2 - .hash_password_simple(passphrase.as_bytes(), encryption_salt.as_str()) + .hash_password_simple(passphrase.reveal(), encryption_salt.as_str()) .map_err(|e| WalletStorageError::AeadError(e.to_string()))? .hash .ok_or_else(|| WalletStorageError::AeadError("Problem generating encryption key hash".to_string()))?; @@ -622,18 +648,19 @@ fn check_db_encryption_status( Err(_) => { // This means the secret key was encrypted. Try decrypt if let Some(cipher_inner) = cipher.clone() { - let mut sk_bytes: Vec = from_hex(sk.as_str())?; - if sk_bytes.len() < AES_NONCE_BYTES { + let sk_bytes: Vec = from_hex(sk.as_str())?; + + if sk_bytes.len() < AES_NONCE_BYTES + AES_MAC_BYTES { return Err(WalletStorageError::MissingNonce); } - // This leaves the nonce in sk_bytes - let data = sk_bytes.split_off(AES_NONCE_BYTES); - let nonce = GenericArray::from_slice(sk_bytes.as_slice()); - let decrypted_key = cipher_inner.decrypt(nonce, data.as_ref()).map_err(|e| { - error!(target: LOG_TARGET, "Incorrect passphrase ({})", e); - WalletStorageError::InvalidPassphrase - })?; + let decrypted_key = + decrypt_bytes_integral_nonce(&cipher_inner, b"wallet_setting_master_seed".to_vec(), sk_bytes) + .map_err(|e| { + error!(target: LOG_TARGET, "Incorrect passphrase ({})", e); + WalletStorageError::InvalidPassphrase + })?; + let _cipher_seed = CipherSeed::from_enciphered_bytes(&decrypted_key, None).map_err(|_| { error!( target: LOG_TARGET, @@ -748,20 +775,32 @@ impl ClientKeyValueSql { } impl Encryptable for ClientKeyValueSql { + fn domain(&self, field_name: &'static str) -> Vec { + [Self::CLIENT_KEY_VALUE, self.key.as_bytes(), field_name.as_bytes()] + .concat() + .to_vec() + } + #[allow(unused_assignments)] fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { - let encrypted_value = encrypt_bytes_integral_nonce(cipher, self.value.as_bytes().to_vec())?; - self.value = encrypted_value.to_hex(); + self.value = + encrypt_bytes_integral_nonce(cipher, self.domain("value"), self.value.as_bytes().to_vec())?.to_hex(); + Ok(()) } #[allow(unused_assignments)] fn decrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { - let decrypted_value = - decrypt_bytes_integral_nonce(cipher, from_hex(self.value.as_str()).map_err(|e| e.to_string())?)?; + let decrypted_value = decrypt_bytes_integral_nonce( + cipher, + self.domain("value"), + from_hex(self.value.as_str()).map_err(|e| e.to_string())?, + )?; + self.value = from_utf8(decrypted_value.as_slice()) .map_err(|e| e.to_string())? .to_string(); + Ok(()) } } @@ -770,7 +809,7 @@ impl Encryptable for ClientKeyValueSql { mod test { use tari_key_manager::cipher_seed::CipherSeed; use tari_test_utils::random::string; - use tari_utilities::hex::Hex; + use tari_utilities::{hex::Hex, SafePassword}; use tempfile::tempdir; use crate::storage::{ @@ -826,7 +865,7 @@ mod test { let db_folder = db_tempdir.path().to_str().unwrap().to_string(); let connection = run_migration_and_create_sqlite_connection(&format!("{}{}", db_folder, db_name), 16).unwrap(); - let passphrase = "an example very very secret key.".to_string(); + let passphrase = SafePassword::from("an example very very secret key.".to_string()); assert!(WalletSqliteDatabase::new(connection.clone(), Some(passphrase.clone())).is_err()); @@ -879,7 +918,7 @@ mod test { }; assert_eq!(seed, read_seed1); - let passphrase = "an example very very secret key.".to_string(); + let passphrase = "an example very very secret key.".to_string().into(); db.apply_encryption(passphrase).unwrap(); let read_seed2 = match db.fetch(&DbKey::MasterSeed).unwrap().unwrap() { DbValue::MasterSeed(sk) => sk, diff --git a/base_layer/wallet/src/storage/sqlite_utilities/mod.rs b/base_layer/wallet/src/storage/sqlite_utilities/mod.rs index c1aae1ce18..9802aa13c7 100644 --- a/base_layer/wallet/src/storage/sqlite_utilities/mod.rs +++ b/base_layer/wallet/src/storage/sqlite_utilities/mod.rs @@ -25,6 +25,7 @@ use std::{fs::File, path::Path, time::Duration}; use fs2::FileExt; use log::*; use tari_common_sqlite::sqlite_connection_pool::SqliteConnectionPool; +use tari_utilities::SafePassword; pub use wallet_db_connection::WalletDbConnection; use crate::{ @@ -125,7 +126,7 @@ pub fn acquire_exclusive_file_lock(db_path: &Path) -> Result>( db_path: P, - passphrase: Option, + passphrase: Option, sqlite_pool_size: usize, ) -> Result< ( diff --git a/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs b/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs index 17972e34a1..2e6921cc50 100644 --- a/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs +++ b/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs @@ -1443,20 +1443,38 @@ impl InboundTransactionSql { } impl Encryptable for InboundTransactionSql { + fn domain(&self, field_name: &'static str) -> Vec { + [ + Self::INBOUND_TRANSACTION, + self.tx_id.to_le_bytes().as_slice(), + field_name.as_bytes(), + ] + .concat() + .to_vec() + } + fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { - let encrypted_protocol = encrypt_bytes_integral_nonce(cipher, self.receiver_protocol.as_bytes().to_vec())?; - self.receiver_protocol = encrypted_protocol.to_hex(); + self.receiver_protocol = encrypt_bytes_integral_nonce( + cipher, + self.domain("receiver_protocol"), + self.receiver_protocol.as_bytes().to_vec(), + )? + .to_hex(); + Ok(()) } fn decrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { let decrypted_protocol = decrypt_bytes_integral_nonce( cipher, + self.domain("receiver_protocol"), from_hex(self.receiver_protocol.as_str()).map_err(|e| e.to_string())?, )?; + self.receiver_protocol = from_utf8(decrypted_protocol.as_slice()) .map_err(|e| e.to_string())? .to_string(); + Ok(()) } } @@ -1613,20 +1631,38 @@ impl OutboundTransactionSql { } impl Encryptable for OutboundTransactionSql { + fn domain(&self, field_name: &'static str) -> Vec { + [ + Self::OUTBOUND_TRANSACTION, + self.tx_id.to_le_bytes().as_slice(), + field_name.as_bytes(), + ] + .concat() + .to_vec() + } + fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { - let encrypted_protocol = encrypt_bytes_integral_nonce(cipher, self.sender_protocol.as_bytes().to_vec())?; - self.sender_protocol = encrypted_protocol.to_hex(); + self.sender_protocol = encrypt_bytes_integral_nonce( + cipher, + self.domain("sender_protocol"), + self.sender_protocol.as_bytes().to_vec(), + )? + .to_hex(); + Ok(()) } fn decrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { let decrypted_protocol = decrypt_bytes_integral_nonce( cipher, + self.domain("sender_protocol"), from_hex(self.sender_protocol.as_str()).map_err(|e| e.to_string())?, )?; + self.sender_protocol = from_utf8(decrypted_protocol.as_slice()) .map_err(|e| e.to_string())? .to_string(); + Ok(()) } } @@ -1941,20 +1977,38 @@ impl CompletedTransactionSql { } impl Encryptable for CompletedTransactionSql { + fn domain(&self, field_name: &'static str) -> Vec { + [ + Self::COMPLETED_TRANSACTION, + self.tx_id.to_le_bytes().as_slice(), + field_name.as_bytes(), + ] + .concat() + .to_vec() + } + fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { - let encrypted_protocol = encrypt_bytes_integral_nonce(cipher, self.transaction_protocol.as_bytes().to_vec())?; - self.transaction_protocol = encrypted_protocol.to_hex(); + self.transaction_protocol = encrypt_bytes_integral_nonce( + cipher, + self.domain("transaction_protocol"), + self.transaction_protocol.as_bytes().to_vec(), + )? + .to_hex(); + Ok(()) } fn decrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), String> { let decrypted_protocol = decrypt_bytes_integral_nonce( cipher, + self.domain("transaction_protocol"), from_hex(self.transaction_protocol.as_str()).map_err(|e| e.to_string())?, )?; + self.transaction_protocol = from_utf8(decrypted_protocol.as_slice()) .map_err(|e| e.to_string())? .to_string(); + Ok(()) } } diff --git a/base_layer/wallet/src/types.rs b/base_layer/wallet/src/types.rs index b25778807d..03c1ef6894 100644 --- a/base_layer/wallet/src/types.rs +++ b/base_layer/wallet/src/types.rs @@ -35,4 +35,12 @@ pub(crate) trait PersistentKeyManager { fn create_and_store_new(&mut self) -> Result; } -hasher!(Blake256, WalletHasher, "com.tari.base_layer.wallet", 1); +hasher!( + Blake256, + WalletEncryptionHasher, + "com.tari.base_layer.wallet.encryption", + 1, + wallet_encryption_hasher +); + +hasher!(Blake256, WalletHasher, "com.tari.base_layer.wallet", 1, wallet_hasher); diff --git a/base_layer/wallet/src/util/encryption.rs b/base_layer/wallet/src/util/encryption.rs index f7e4364db6..91a5a16044 100644 --- a/base_layer/wallet/src/util/encryption.rs +++ b/base_layer/wallet/src/util/encryption.rs @@ -25,33 +25,83 @@ use aes_gcm::{ Aes256Gcm, }; use rand::{rngs::OsRng, RngCore}; +use tari_utilities::ByteArray; + +use crate::types::WalletEncryptionHasher; pub const AES_NONCE_BYTES: usize = 12; pub const AES_KEY_BYTES: usize = 32; +pub const AES_MAC_BYTES: usize = 32; pub trait Encryptable { + const KEY_MANAGER: &'static [u8] = b"KEY_MANAGER"; + const OUTPUT: &'static [u8] = b"OUTPUT"; + const WALLET_SETTING_MASTER_SEED: &'static [u8] = b"MASTER_SEED"; + const WALLET_SETTING_TOR_ID: &'static [u8] = b"TOR_ID"; + const INBOUND_TRANSACTION: &'static [u8] = b"INBOUND_TRANSACTION"; + const OUTBOUND_TRANSACTION: &'static [u8] = b"OUTBOUND_TRANSACTION"; + const COMPLETED_TRANSACTION: &'static [u8] = b"COMPLETED_TRANSACTION"; + const KNOWN_ONESIDED_PAYMENT_SCRIPT: &'static [u8] = b"KNOWN_ONESIDED_PAYMENT_SCRIPT"; + const CLIENT_KEY_VALUE: &'static [u8] = b"CLIENT_KEY_VALUE"; + + fn domain(&self, field_name: &'static str) -> Vec; fn encrypt(&mut self, cipher: &C) -> Result<(), String>; fn decrypt(&mut self, cipher: &C) -> Result<(), String>; } -pub fn decrypt_bytes_integral_nonce(cipher: &Aes256Gcm, ciphertext: Vec) -> Result, String> { - if ciphertext.len() < AES_NONCE_BYTES { +pub fn decrypt_bytes_integral_nonce( + cipher: &Aes256Gcm, + domain: Vec, + ciphertext: Vec, +) -> Result, String> { + if ciphertext.len() < AES_NONCE_BYTES + AES_MAC_BYTES { return Err(AeadError.to_string()); } - let (nonce, cipher_text) = ciphertext.split_at(AES_NONCE_BYTES); + + let (nonce, ciphertext) = ciphertext.split_at(AES_NONCE_BYTES); + let (ciphertext, appended_mac) = ciphertext.split_at(ciphertext.len().saturating_sub(AES_MAC_BYTES)); let nonce = GenericArray::from_slice(nonce); - cipher.decrypt(nonce, cipher_text.as_ref()).map_err(|e| e.to_string()) + + let expected_mac = WalletEncryptionHasher::new_with_label("storage_encryption_mac") + .chain(nonce.as_slice()) + .chain(ciphertext) + .chain(domain) + .finalize(); + + if appended_mac != expected_mac.as_ref() { + return Err(AeadError.to_string()); + } + + let plaintext = cipher.decrypt(nonce, ciphertext.as_ref()).map_err(|e| e.to_string())?; + + Ok(plaintext) } -pub fn encrypt_bytes_integral_nonce(cipher: &Aes256Gcm, plaintext: Vec) -> Result, String> { +pub fn encrypt_bytes_integral_nonce( + cipher: &Aes256Gcm, + domain: Vec, + plaintext: Vec, +) -> Result, String> { let mut nonce = [0u8; AES_NONCE_BYTES]; OsRng.fill_bytes(&mut nonce); let nonce_ga = GenericArray::from_slice(&nonce); + let mut ciphertext = cipher - .encrypt(nonce_ga, plaintext.as_ref()) + .encrypt(nonce_ga, plaintext.as_bytes()) .map_err(|e| e.to_string())?; + + let mut mac = WalletEncryptionHasher::new_with_label("storage_encryption_mac") + .chain(nonce.as_slice()) + .chain(ciphertext.clone()) + .chain(domain.as_slice()) + .finalize() + .as_ref() + .to_vec(); + let mut ciphertext_integral_nonce = nonce.to_vec(); ciphertext_integral_nonce.append(&mut ciphertext); + ciphertext_integral_nonce.append(&mut mac); + Ok(ciphertext_integral_nonce) } @@ -70,8 +120,25 @@ mod test { let key = GenericArray::from_slice(b"an example very very secret key."); let cipher = Aes256Gcm::new(key); - let cipher_text = encrypt_bytes_integral_nonce(&cipher, plaintext.clone()).unwrap(); - let decrypted_text = decrypt_bytes_integral_nonce(&cipher, cipher_text).unwrap(); + let ciphertext = encrypt_bytes_integral_nonce(&cipher, b"correct_domain".to_vec(), plaintext.clone()).unwrap(); + let decrypted_text = + decrypt_bytes_integral_nonce(&cipher, b"correct_domain".to_vec(), ciphertext.clone()).unwrap(); + + // decrypted text must be equal to the original plaintext assert_eq!(decrypted_text, plaintext); + + // must fail with a wrong domain + assert!(decrypt_bytes_integral_nonce(&cipher, b"wrong_domain".to_vec(), ciphertext.clone()).is_err()); + + // must fail without nonce + assert!(decrypt_bytes_integral_nonce(&cipher, b"correct_domain".to_vec(), ciphertext[0..12].to_vec()).is_err()); + + // must fail without mac + assert!(decrypt_bytes_integral_nonce( + &cipher, + b"correct_domain".to_vec(), + ciphertext[0..ciphertext.len().saturating_sub(32)].to_vec() + ) + .is_err()); } } diff --git a/base_layer/wallet/src/wallet.rs b/base_layer/wallet/src/wallet.rs index 34834d4aae..8521034b2a 100644 --- a/base_layer/wallet/src/wallet.rs +++ b/base_layer/wallet/src/wallet.rs @@ -69,6 +69,7 @@ use tari_p2p::{ use tari_script::{script, ExecutionStack, TariScript}; use tari_service_framework::StackBuilder; use tari_shutdown::ShutdownSignal; +use tari_utilities::SafePassword; use crate::{ assets::{infrastructure::initializer::AssetManagerServiceInitializer, AssetManagerHandle}, @@ -685,7 +686,7 @@ where /// Apply encryption to all the Wallet db backends. The Wallet backend will test if the db's are already encrypted /// in which case this will fail. - pub async fn apply_encryption(&mut self, passphrase: String) -> Result<(), WalletError> { + pub async fn apply_encryption(&mut self, passphrase: SafePassword) -> Result<(), WalletError> { debug!(target: LOG_TARGET, "Applying wallet encryption."); let cipher = self.db.apply_encryption(passphrase).await?; self.output_manager_service.apply_encryption(cipher.clone()).await?; diff --git a/base_layer/wallet/tests/output_manager_service_tests/service.rs b/base_layer/wallet/tests/output_manager_service_tests/service.rs index 0f8b980230..963b5a12c6 100644 --- a/base_layer/wallet/tests/output_manager_service_tests/service.rs +++ b/base_layer/wallet/tests/output_manager_service_tests/service.rs @@ -185,35 +185,17 @@ async fn setup_output_manager_service>(), None, ) .unwrap(); diff --git a/base_layer/wallet/tests/wallet.rs b/base_layer/wallet/tests/wallet.rs index cd49ddf3fe..a669bd8ede 100644 --- a/base_layer/wallet/tests/wallet.rs +++ b/base_layer/wallet/tests/wallet.rs @@ -61,7 +61,7 @@ use tari_p2p::{ use tari_script::{inputs, script}; use tari_shutdown::{Shutdown, ShutdownSignal}; use tari_test_utils::{collect_recv, random}; -use tari_utilities::Hashable; +use tari_utilities::{Hashable, SafePassword}; use tari_wallet::{ contacts_service::{ handle::ContactsLivenessEvent, @@ -114,7 +114,7 @@ async fn create_wallet( database_name: &str, factories: CryptoFactories, shutdown_signal: ShutdownSignal, - passphrase: Option, + passphrase: Option, recovery_seed: Option, ) -> Result { const NETWORK: Network = Network::LocalNet; @@ -316,14 +316,14 @@ async fn test_wallet() { let current_wallet_path = alice_db_tempdir.path().join("alice_db").with_extension("sqlite3"); alice_wallet - .apply_encryption("It's turtles all the way down".to_string()) + .apply_encryption("It's turtles all the way down".to_string().into()) .await .unwrap(); // Second encryption should fail #[allow(clippy::match_wild_err_arm)] match alice_wallet - .apply_encryption("It's turtles all the way down".to_string()) + .apply_encryption("It's turtles all the way down".to_string().into()) .await { Ok(_) => panic!("Should not be able to encrypt twice"), @@ -342,7 +342,7 @@ async fn test_wallet() { panic!("Should not be able to instantiate encrypted wallet without cipher"); } - let result = WalletSqliteDatabase::new(connection.clone(), Some("wrong passphrase".to_string())); + let result = WalletSqliteDatabase::new(connection.clone(), Some("wrong passphrase".to_string().into())); if let Err(err) = result { assert!(matches!(err, WalletStorageError::InvalidPassphrase)); @@ -350,7 +350,7 @@ async fn test_wallet() { panic!("Should not be able to instantiate encrypted wallet without cipher"); } - let db = WalletSqliteDatabase::new(connection, Some("It's turtles all the way down".to_string())) + let db = WalletSqliteDatabase::new(connection, Some("It's turtles all the way down".to_string().into())) .expect("Should be able to instantiate db with cipher"); drop(db); @@ -360,7 +360,7 @@ async fn test_wallet() { "alice_db", factories.clone(), shutdown_a.to_signal(), - Some("It's turtles all the way down".to_string()), + Some("It's turtles all the way down".to_string().into()), None, ) .await @@ -781,14 +781,14 @@ async fn test_recovery_birthday() { // let seed = CipherSeed::new(); // use tari_key_manager::mnemonic::MnemonicLanguage; // let mnemonic_seq = seed - // .to_mnemonic(MnemonicLanguage::English, None) + // .to_mnemonic(MnemonicLanguage::Spanish, None) // .expect("Couldn't convert CipherSeed to Mnemonic"); // println!("{:?}", mnemonic_seq); let seed_words: Vec = [ - "parade", "allow", "earth", "sibling", "jealous", "tower", "pet", "project", "pole", "dizzy", "tower", "genre", - "marine", "immense", "region", "diagram", "dress", "symptom", "dutch", "require", "virus", "angry", "cotton", - "nominee", + "octavo", "joroba", "aplicar", "lamina", "semilla", "tiempo", "codigo", "contar", "maniqui", "guiso", + "imponer", "barba", "torpedo", "mejilla", "fijo", "grave", "caer", "libertad", "sol", "sordo", "alacran", + "bucle", "diente", "vereda", ] .iter() .map(|w| w.to_string()) diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index 1a645e5f19..6182babeea 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -131,7 +131,7 @@ use tari_p2p::{ }; use tari_script::{inputs, script}; use tari_shutdown::Shutdown; -use tari_utilities::{hex, hex::Hex}; +use tari_utilities::{hex, hex::Hex, SafePassword}; use tari_wallet::{ connectivity_service::WalletConnectivityInterface, contacts_service::storage::database::Contact, @@ -4256,7 +4256,7 @@ pub unsafe extern "C" fn wallet_create( .to_str() .expect("A non-null passphrase should be able to be converted to string") .to_owned(); - Some(pf) + Some(SafePassword::from(pf)) }; let network = if network_str.is_null() { @@ -6792,8 +6792,8 @@ pub unsafe extern "C" fn wallet_apply_encryption( let pf = CStr::from_ptr(passphrase) .to_str() - .expect("A non-null passphrase should be able to be converted to string") - .to_owned(); + .map(|s| SafePassword::from(s.to_owned())) + .expect("A non-null passphrase should be able to be converted to string"); if let Err(e) = (*wallet).runtime.block_on((*wallet).wallet.apply_encryption(pf)) { error = LibWalletError::from(e).code; @@ -9276,17 +9276,16 @@ mod test { // To create a new seed word sequence, uncomment below // let seed = CipherSeed::new(); - // use tari_key_manager::mnemonic::MnemonicLanguage; - // use tari_key_manager::mnemonic::Mnemonic; + // use tari_key_manager::mnemonic::{Mnemonic, MnemonicLanguage}; // let mnemonic_seq = seed // .to_mnemonic(MnemonicLanguage::English, None) // .expect("Couldn't convert CipherSeed to Mnemonic"); // println!("{:?}", mnemonic_seq); let mnemonic = vec![ - "theme", "stove", "win", "endorse", "ostrich", "voyage", "frequent", "battle", "crime", "volcano", - "dune", "also", "lunar", "banner", "clay", "that", "urge", "spin", "uncover", "extra", "village", - "mask", "trumpet", "bag", + "scale", "poem", "sorry", "language", "gorilla", "despair", "alarm", "jungle", "invite", "orient", + "blast", "try", "jump", "escape", "estate", "reward", "race", "taxi", "pitch", "soccer", "matter", + "team", "parrot", "enter", ]; let seed_words = seed_words_create(); diff --git a/comms/dht/src/crypt.rs b/comms/dht/src/crypt.rs index aa6273eaa5..a2c6c31214 100644 --- a/comms/dht/src/crypt.rs +++ b/comms/dht/src/crypt.rs @@ -55,6 +55,9 @@ use crate::{ pub struct CipherKey(chacha20::Key); pub struct AuthenticatedCipherKey(chacha20poly1305::Key); +const LITTLE_ENDIAN_U32_SIZE_REPRESENTATION: usize = 4; +const MESSAGE_BASE_LENGTH: usize = 6000; + /// Generates a Diffie-Hellman secret `kx.G` as a `chacha20::Key` given secret scalar `k` and public key `P = x.G`. pub fn generate_ecdh_secret(secret_key: &CommsSecretKey, public_key: &CommsPublicKey) -> [u8; 32] { // TODO: PK will still leave the secret in released memory. Implementing Zerioze on RistrettoPublicKey is not @@ -66,6 +69,47 @@ pub fn generate_ecdh_secret(secret_key: &CommsSecretKey, public_key: &CommsPubli output } +fn pad_message_to_base_length_multiple(message: &[u8]) -> Vec { + let n = message.len(); + // little endian representation of message length, to be appended to padded message, + // assuming our code runs on 64-bits system + let prepend_to_message = (n as u32).to_le_bytes(); + + let k = prepend_to_message.len(); + + let div_n_base_len = (n + k) / MESSAGE_BASE_LENGTH; + let output_size = (div_n_base_len + 1) * MESSAGE_BASE_LENGTH; + + // join prepend_message_len | message | zero_padding + let mut output = Vec::with_capacity(output_size); + output.extend_from_slice(&prepend_to_message); + output.extend_from_slice(message); + output.extend(std::iter::repeat(0u8).take(output_size - n - k)); + + output +} + +fn get_original_message_from_padded_text(message: &[u8]) -> Result, DhtOutboundError> { + let mut le_bytes = [0u8; 4]; + le_bytes.copy_from_slice(&message[..LITTLE_ENDIAN_U32_SIZE_REPRESENTATION]); + + // obtain length of original message, assuming our code runs on 64-bits system + let original_message_len = u32::from_le_bytes(le_bytes) as usize; + + if original_message_len > message.len() { + return Err(DhtOutboundError::CipherError( + "Original length message is invalid".to_string(), + )); + } + + // obtain original message + let start = LITTLE_ENDIAN_U32_SIZE_REPRESENTATION; + let end = LITTLE_ENDIAN_U32_SIZE_REPRESENTATION + original_message_len; + let original_message = &message[start..end]; + + Ok(original_message.to_vec()) +} + pub fn generate_key_message(data: &[u8]) -> CipherKey { // domain separated hash of data (e.g. ecdh shared secret) using hashing API let domain_separated_hash = comms_dht_hash_domain_key_message().chain(data).finalize(); @@ -96,6 +140,9 @@ pub fn decrypt(cipher_key: &CipherKey, cipher_text: &[u8]) -> Result, Dh let mut cipher = ChaCha20::new(&cipher_key.0, nonce); cipher.apply_keystream(cipher_text.as_mut_slice()); + + // get original message, from decrypted padded cipher text + let cipher_text = get_original_message_from_padded_text(cipher_text.as_slice())?; Ok(cipher_text) } @@ -117,6 +164,9 @@ pub fn decrypt_with_chacha20_poly1305( /// Encrypt the plain text using the ChaCha20 stream cipher pub fn encrypt(cipher_key: &CipherKey, plain_text: &[u8]) -> Vec { + // pad plain_text to avoid message length leaks + let plain_text = pad_message_to_base_length_multiple(plain_text); + let mut nonce = [0u8; size_of::()]; OsRng.fill_bytes(&mut nonce); @@ -125,7 +175,8 @@ pub fn encrypt(cipher_key: &CipherKey, plain_text: &[u8]) -> Vec { let mut buf = vec![0u8; plain_text.len() + nonce.len()]; buf[..nonce.len()].copy_from_slice(&nonce[..]); - buf[nonce.len()..].copy_from_slice(plain_text); + + buf[nonce.len()..].copy_from_slice(plain_text.as_slice()); cipher.apply_keystream(&mut buf[nonce.len()..]); buf } @@ -226,9 +277,10 @@ mod test { fn decrypt_fn() { let pk = CommsPublicKey::default(); let key = CipherKey(*chacha20::Key::from_slice(pk.as_bytes())); - let cipher_text = - from_hex("24bf9e698e14938e93c09e432274af7c143f8fb831f344f244ef02ca78a07ddc28b46fec536a0ca5c04737a604") - .unwrap(); + let cipher_text = from_hex( + "6063cd49c7b871c0fc9785e9b959fda553fadbb10bcaaced0958e88eb6858e05fe310b4401a78d03b52a81be49db2bffcce13765e1a64460063d33289b1a3527af3df8e292c79abca71aa9a87baa1a0a6c23532a3297dda9e0c22d4b60606db1ed02a75e7a7d21fafe1214cbf8a3a66ec319a6aafeeb0e7b06375370c52b2abe63170ce50552a697f1ff87dc03ae1df574ed8e7abf915aec6959808ec526d6da78f08f2bed24268028baeba3ebd52d0fde34b145267ced68a08c4d480c213d0ab8b9c55a1630e956ed9531d6db600537f3997d612bec5905bd3ce72f5eace475e9793e6e316349f9cbd49022e401870af357605a1c7d279d5f414a5cae13e378711f345eabf46eb7fabc9465376f027a8d2d69448243cf2d70223c2430f6ce1ff55b8c2f54e27ccf77040f70c9eb84c9da9f8176a867ebcf8cbb9dfbb9256d688a76ec02af3afe3fa8221a4876462a754bc65a15c584b8c132a48dd955821961e47d5fdce62668b8f11a42b9127d5d98414bbd026eed0b5511628b96eb8435880f38d28bad6573f90611397b7f0df46d371e29a1de3591a3aa221623a540693941d6fc22d4fa57bc612282868f024613c2f69224141ad623eb49cb2b151ff79b36ec30f9a2842c696fd94f3092989ba3f8f5136850ed1f970a559e3fe19c24b84f650da4d46a5abc3514a242e3088cae15b6012e2a3cd878a7b988cc6183e81a27fc263d3be1ef07403262b22972b4f78a38d930f3e4ec1f5c9f1a1639cd138f24c665fc39b808dfbed81f8888e8e23aac1cdd980dcd0f7d95dc71fb02abcf532be17208b8d86a5fd501c18f2b18234956617797ea78523907c46a7b558cde76d1d034d6ab36ca33c4c17ea617929e3e40af011417f9b3983912fd0ed60685261532bcca5f3b3863d92cf2e4eb33a97613d9ca55112d1bd63df9d591042f972e2da3bde7b5a572f9f4cd2633330fdb3250430f27d3a40a30a996d5a41d61dafc3fe96a1fb63e2cab5c3ec0f1084e778f303da498fc970fe3117b8513166a6c8798e00a82a6c96a61b419e12717fac57dd1c989d3a10b3ab798ee0c15a5cd04dbe83667523ad4b7587ff331513b63f15c72f1844d67c7830f723fa754969d3f4254895e71d7087617ded071a797ee5791b7a95abcb360ef504bdaa85191b09b2345c0d0096fafa85b10d1675f0c4a1fe45231fd88c715c32c38d9697cebfb712e1ce57645c8faa08b4983ab8f537b017c9898b7907c06b25c01ea0f9d736d573ec7e3b14efa84d84258131ddc11a9696e12234ab65fb4e653d8dcdf4c6ce51e0104aedf2b8089593c4b5d665ed885c2d798843cf80041657ba9d800bbb9eee7076c212cfdc56df1d63e3eb4de5c6f132376c39b0272fc35e4729988eeaf9f7748142b68b10a72a948f6db24baebe3323c11002edfbe878972f02920a9ca536ff917c37a20eb0d984ee230885950fb271e56a8640d5fff75cef501cc82a46244eca5bfed4ab180b4c4d61abb209f38dabddcb4a82581bf4990631e11cc670507f8ce88c10a7fcc06ed824d7a1aab0b4dc4a2b984f45c7447b077262dc6b90f04516a5598193e95a40e982f092a2b61fcf5203379e6b042715792915adfd0edc66a69bb7fbacf64b9c86f3b15da66f8d8eecfc373204ad10f2c666181f0facd972be208c620bb0539e396c6254f2efcc2934ae6265fdc0b8172577435758e7da46df6f175f2867e6bd734fe70416b70b643b9e4b0e0e844ccfa019c699f85313f94cc6a1bb371bb14fd65a1599355ec8cdd71a555d98a900b8cbd2b1de3378f42a137eb95b473d62be3559ed687583d6963b6857d3be5f7acc12f5f5b04b26b5c582ce6dd8d1bee6322b02c2c1dc29fcba20899d529df4fd6c1edbfd1081d6cf10b20b9451ad935da2c4cef66c160550b180ba1b668029ed15448cd288427aca7f6e6505fdfc69b8111a2a071601d78484e857705a4bc7f9423800ded4eba46e0f22ee85c48fc9e8a191355edc0868df350d627a7f1120d79ba4aa1dde1ec97f8daeb0a07be914d5f6d2e74270666d03e4ca92845957b85982761dc1ee6f7603e31681dd323a045c0ac3b06b515d7bd485bfe7f6abe31e35aac7d8536b3f9c572121fcdd44c505ccfffe514e498732cab4e70524a5281b0942f5ae861b535764f056df6a1951b3c1c261f21b3b5f0a276ed05e32879ede428683b34ac8e7ebc88c9c767bf5e7cfb0cf444c1f9fd5be9041f69f6ae9772b0e035c6a2a7d9c1c614858a506a4a4bc00cee0577561b96e98c973edfa43d52471db9c716699e52260a20150aa99f8adea872c999b66fb4395d5b8c717a2c97eb7638a1d92da2ef8b2ec80db3afa3ce83445aaccae09f38c0b84c85a8983ba4c3b9a13fed4c65fd8899333da4dbca549cd2a487eb58841881f3571dfa4821bc522b56993d657bce51dfb41caf6c2cb78e8b6beceddc44febdea144da13ae9ccd9465b3ac96b06dfe79baced35ad51763d05090dc7620c89f448134507f41828be8703fd2ab1f53370e75e55366eba1e903311313707279d5965e3343476c0a8aeef2001ad88d5e452d648dd2029a6f549809c4177d1871c88abcd1404d52ebee2dd97dc52ad1a9c018428a1a64fda6773a6ea967d4124a6cf98c7e6dc4c4d9c051a376d3e3fe2e17f6cd044dd60ee32e9d6bdbdfdcbdecc4e7306092186a7ad8ab87328f9fedb6ee8ab9417968fbaa0e582205a660fa55e1ba3c5b0c84b67017f250338125894d162c400be8d563c9f0416dc5641d31bad577543cba8c6c9a7c04064e412597d47c6272d8e087bc11397533cb1bd7feebea9feee44e1b6a5f49b937594da3b719e1982a90594277f43798a39e419c204f18a6920e5ac1a751eddeaef9392a6f84d68d73aabc6ba68750d47ad4da8bd842662226225a764661ea11ff9f13d328e0242a0b513aa5ad9fbe9d484b3d28a41890e4fc62820ef2342a90c0837b30c831eb78213e7e2cd6dfbda26a7e6103ab8b4219462ca70ca57c79638b2c49f0469ea6f68335071294257c5337ccf452ca1bfedf81610f353e7576f02a2b32aba64a4252946fda330de11990f51207817860e0d8b7c9cb58a5858155db61376a01c02aaedb7017fd3c36adf4f3c07f29f352330c6d78ab6bbb7d4aabf3725833e86523b755094273465ba57545162623036a7786f426d0a63e13bebf2205a6b488bd6da3c93469a4df4b3811e9c63d62c61e0cdd263df821adc0d1b751c1314be9fd93761b447931e425db7e09baac9083aed472de5fe6172c8e8f729ade8faa96d131e86204462e14e0411b4b7629de25a0c5dbf848c9ca8c42376f5d54bff34bf36074136bbe98228745dbc9d411d891553f0af00240e1729ce7757fba2775fa5b700e95460910008584a833fb9edc073cd4d8333643631e193040d850f87cd50d9ca2e2e5c3943787dc4a4677ac7e130c2d6739945fd3b059ebe040abb38a20d73a7669516cf8503f40642217c8580a27b127f1f33eaa7adff44c922afac813c870795563fac79d139d5b5233a26728328f88f1f9daaaea1c4e1ee64ded0b006ce46015d512e8c4a411ab788a5383563949c95846202250c5b9e0baab0bc8620327ed2aacd36e1bdc9d3a4d6b4e22627d75bd088cdd47ca204f1ce44357d1b471b37581c820f6bbdfe3da1f4f90dc353833731703b7b9bb87ff2d0cae1e2f0321994759d1a21b2075a620b58b814cd65812092891261dd7e879b65843480382f59e20d6b6c67b2fb750ff0cfce897891f976b0fae7ac31e02384b251bcbecce6ff98819cf0cd6d41fdab9ba6907742394732ed5e74bcd13aad1a188855c020f09e62540be9b2992a397b30107ad730ebe183504226b303f30032f4c0a683812d05be57961430504866bd2ee6993423ebe34ba4d2d022ce6d5b2345bbed34d6807aec473ad0701b9b8fe2db1cef57748dfcb29ddb3b253a865dd7383d04253cd70c350d02ccd2371cecfd74aae820fa91eddd89d27925c33183e03d44c7f88f8068c64d223d2d5f4ab18fae6d209e1e267395576f4f48ae056da7d6e91f94991659b4c07f44aa1c45aaf75b7274b7668753f968d5e6635f4abf238e5d44ffc38e68cad8237f7e7a25d5fc0dcd5afc2bedbac6b42e8bc8064118c9042d1159f70dfac73d65c8a9782c264445af11c878591d49d49ad46f4e6d086d55232afd234c3bceab2eef0e22e5c2875670c5125e8a172f5f2168e59fe0cb5e9e1a81bf645a2c45d115b9a3efe9fe2d1799f12b0c11f50ae5540ff4e90e6220eb62451e10ce1418929e03c751d9019d47b87847595333feb6ab4af40662d04c3ece4f93b4c2c2f2ee2078724090336f16a4f33801095036a31b557960b5d8d2552f0aadfa3dc9dcfe8f1dd6a61631b6a69ee6ce8433153f8b1ea99a9a5ac688026d6ef408f2aa958ada8baf0193b3989f359c7a913fcb9eec230568584bcda3a759c824884c9febff518c7cf312360d2c1ffd2bbdd0b2e9346cbe1bf383446bd2fec431475ec509474ef9eb06817f53d3c4ca74fba08c3b434eabf3ae9fcc2287c588fc5574bff37066705ca9a39d088cd5cbb83b385b5cf647ced0c23885295d2b24f37e4098be82edccc23e1c973b1855e2009de63408c78e570b3cec65c6d236d81adb1bc298436a1e125b99bb995a5c6df5b2a4e70b8cf1db5de38120134527ce349c32f8e35fa43837aa38cdb1d5695a34d12d27bd5ee4536d9a20e62b55e59cdc7ecca1f4398dac7a4b756d9e131a7d2c8bde32c20ef0424154c88c8276fdf3c75f08f3cd423bd648ff3520680a1f1dd956451881f6d31238c11c99a20e1d9170410c8d8eb88ce90e179fc80e23e36a28b1810383a4d0d1ef0f2db94206aa1fb25498b425e5ad1f0f0bd3eed22ca5545ef541880f37f8fea82fecd59d8c94765d3a454e81775844701412e3c01a6dcdbf277428969a7f08d67313cdd2ce3b531addee28733552ee1bf4124ad8b3e40e04b94599e04cce60f5676307b0605ad7dc73b03cc26227eab60196d37c312a01858f5ad6a901e0f1c796c52cb9690da5c712a2d36c74e65ec9a60ea41387b8a0f79697cdfd93e40ab569d6a55361be97fb7ac8d80b5a5482908d44af94df2fb09a777978f4d911008d528ff44aef960cfd25fb56e26c341850721f020f9fd112cd52fc28dd129ffc2f9a12a829dcb69b54a894d4b3d1ac3b63bc9bcd39e30a00e419c8f4d2b630c224880a7d3af9c19c8a79262818b368589e7ad03b722021306fbcbcf7bc87ad418a3eb6616e7d4ce286264554be6040e8e4cd0c5a9bbdd2367e47d1fe0a9c3eeaf2455c4f6f779bab3d5bea5284a244fc3e804fb6d0e50fec91f85b71c6ba91f43a240fa48900229e5f3038b0806f70a1cb72fdea58b664f06c04bf688183a4f22255d6976f2102aafb669ee117fa1e44ae325ad52001469fed9d26e4f8592f56e42bf5e7195f521c0beaf891e47a703075fa1948ee07add55a765346b94ae498fa96145ad8460f23248222e329398fec6ad7f323c448ce82bb706b24e07adc0681901a63d5d1c7b871a9df8009ed7bb10be4e39a987c1bf039554a016ac8693284a7248fb8a9aa440dde213c2414447727c1556d25f1fbce057652044e2350b9ef5627584d403a934dd33e8c26e20799f1dbf915705b70d66256d31ca7c407307fa18e163917635d67f742828deba4b942b5f0d916b5e737b5811d3c3b4ac386c7ebaad1a6c465ce9fb229bc6ce7ae62f8efd8632e5312db8ba213d28d19843ac7fbae105a1433921b34c216c3c2ab247080a629c7ac5507129b27ce0d38ddde06722a5a0a979894d6140c31a82bfc517adf0b57c761f75bc14d65d8701e0dea92a06584f2d877dc5fb0b32496754e6b0115e99a9623ad631ea0a76b4e7893bf0982151e1c5ed6d64a305393f6f715de333653ace204c2f03de8f36c463c937f7f23326a88337624fc606317d7c0ea2badf69e40602c2ff1e2dcc9cfca1ccef566381712af157c5e458335c8a283733a617a75fd7cef52c515c754443c8e9e1930994805e6f0b2a9a2ccbd848f6580896317dd9dbfda17d00e80d35bc58a704fbe7d6d6d45752811130f682ea9471903c9af9b5c95d074ec87b32c86dc5b29a186a60a7a03e630c7a5cd38e6ab1a2f561642d5662658fc20239233505727575e75dbbc5be630f9ebc9485f03bbf0569554a87bfb5cfd397daf5d8d92a38fc4b24b1a433edad26a12c4e29506362dec83fb9b1f31158e834cde319d40bd283f0a1f3995b3ade08bcc01c794d656583b928300a6f2be57e5bf1586a123cf28ac8e2287e0b7ab67419dc4f527f714fdee8c47088cc1857a39825524dc3f5a3777c1f906cf496dac43e3f8304ebe5d5696da5b7e5d79f176a391736ba46bc718356e1713a00ea754a52b5899ba7eb71b10bbd211cead7d1890f2d8bb981a2549e2cd53bf895f96f628c4d00061275c87f4dcbadcb5944f912d27aaa4124cadd0e2a1d82ee4d3a8b977bd1a03fc6f79caf4c306addea0bd72b754c113350655324dec3dcbb1f1de66e3e7a9f06ba0e04de0cae7af7d6e31298bf5be706038e0d8477a79ea5f8e21decdf6f5ff71090d8cb2dabb9d1a87ee526b0ec84be81ad9585b09f165cfff7a4a63e30ac7341a3a42e3e02bfc34486a2a5492a39dd31dc233c74f454584e5bd2524382a08357ba2d3ef46833551ed3fa6f672d5ba0ec75258430738c16989840b6ae6909f10340d1845bd975cf2933047a1c22c332606f76681dae8727921d4f1345f0457700b8622ea72a50c17cda201f7019d4af9dc0b67bd95317d98c2fe38ab8e12dbefe236b463014caf9c3cdd390dfcff034e90e4f51e1233bb8b341bba6f922d1b0629e261844018f39d054b26cb82592e33466354f552a14f9e6175d418cb9724fa045e723b5ab9ece5aa45800f1202b3d174fe4e129e7320a9063039f8bdea8601762ff45933503e0bd10944893e565d641a39289da67e5269dc7cfc22dc3d5f5011b66b25340cea4055bd66a7752c69e624bfc12e5cc68cd0b5cab3242860b7f40541303e228e666a6500e1e739b0b6d853b5715cf3a668facc135d133d4eefa035f36c16838f25ab4a0d2f20d18f05c0fae1a705370162cfe6f7fcf1c69654c2ce73e3820b48568c25a6d9d036c2386dee6ef14e5fc967ffccde38bf263c8c0f924bcfcdf54669dcf872724280cc4f81bb2aa993998f6312d0c6084ed823e5e6bff0ed25cc4e82b749dc11f4cf55290344d9c307d634793e81b9d3d457765dc6f81b66f1a6aaaa1079558a4892edfc342fe24856200b5dcc65c9e7809b655d3cc7bb26bcda91933f590bc61099fd0b83e04bad2174150645afd7c3ccac5417234e30da4e7574af953f8d9b7a5029417d439f1d13c4390bed2bc05d73821ff2355c33da5f95623c73abd826614572841e4777a9a0b538cef4a2c6327c75116977322a8c488f466178cdcaf3f0e10df86dbd1827ba2cc4c8fba90a1d64ad783a77704c5b1262cd11cb010f09ab04377d6e5ebed4d5dfe8eaa0cc2535a0be69bdf5e1987167b8135428ef84287aa4424c35c7a7bc94cabd553df4840121403b2ba3479e1a7f86085cce49c245af944a2ed78b77784309d05d5f5587a6e589baa6e7d279b0ab43afb5497b6d5954b3a8f66dd547d3b72565437ead511c9d5342406aace95e5cc31f2c6d618a24c219d0298f980a571ce29b999d20b3ed94a60e286ed7ddc647c439e3d421814da6d91c8b7d5b3fa70ec6a3c261ab9ee4ac779545edcc7db6df3345db26b91c5a997dac5e62b75dc05358821bab4fe65a049fe9ce8e537ae81a10dfcca0ef97c6cb95e3ff1573e5461f7b505e1678fbe97a41ab696b53f6ea09038ecda09eed34247251424b766306c0c64fd836274cf85fe0e0c19638127c2210b580c9194fe0cccac7ac80e3dde38e9cccd6a194ee923f4e73800bec0c77f60553f9c7c8413ea87d20c114d7b415fdd87fc55f273b1b3a9f9c71c4462d5b3f300daf0fc6c338278e5991e0c6de07a3c288d237df00325230be204f7b2bb7a127ac28b001e4225e910eeb9521f5af6cfeae1f18c08bf8ac9d1513c3794ba5b8ea9fb9a57825cb154fc1e9a9dddf809dd6bb11a207625b23b274344e7e0b7ca666e456735d5901f1341aca42e749183823b3debbd563aeebc68f9b15dce13d0fc1acef47d38d5967c2b6b3fe8ed69b180dbcbf17455ee6825641202ccd145c0a0a0f4091622338f48474e5838d8915f814eb87ad45e710b07f79f662c2120278ce05978d8a7aee20fc5661a08c072977ed878092e7183332b70c9c54db307c705e527f6fd2076e39c216b0490f552d52a109652958c62fc6bf7f913818dbdf5d92550779aae541d54d059d5844658422c17a24e374fa6f92e5a9fda87eee249747b9cd292043c9731d2c1d08d06eab030fb49e779cb58bf4f776d6aa0185db860007d8b2d0f7205dacb9201ac9538d2c37062f736b6b44e971e11500", + ) + .unwrap(); let plain_text = decrypt(&key, &cipher_text).unwrap(); let secret_msg = "Last enemy position 0830h AJ 9863".as_bytes().to_vec(); assert_eq!(plain_text, secret_msg); @@ -305,7 +357,7 @@ mod test { } #[test] - fn decryption_fails_if_message_sned_to_incorrect_node() { + fn decryption_fails_if_message_send_to_incorrect_node() { let (sk, pk) = CommsPublicKey::random_keypair(&mut OsRng); let (other_sk, other_pk) = CommsPublicKey::random_keypair(&mut OsRng); @@ -325,4 +377,170 @@ mod test { .to_string() .contains("Authenticated decryption failed")); } + + #[test] + fn pad_message_correctness() { + // test for small message + let message = &[0u8, 10, 22, 11, 38, 74, 59, 91, 73, 82, 75, 23, 59]; + let prepend_message = (message.len() as u32).to_le_bytes(); + let pad = std::iter::repeat(0u8) + .take(MESSAGE_BASE_LENGTH - message.len() - prepend_message.len()) + .collect::>(); + + let pad_message = pad_message_to_base_length_multiple(message); + + // padded message is of correct length + assert_eq!(pad_message.len(), MESSAGE_BASE_LENGTH); + // prepend message is well specified + assert_eq!(prepend_message, pad_message[..prepend_message.len()]); + // message body is well specified + assert_eq!( + *message, + pad_message[prepend_message.len()..prepend_message.len() + message.len()] + ); + // pad is well specified + assert_eq!(pad, pad_message[prepend_message.len() + message.len()..]); + + // test for large message + let message = &[100u8; MESSAGE_BASE_LENGTH * 8 - 100]; + let prepend_message = (message.len() as u32).to_le_bytes(); + let pad_message = pad_message_to_base_length_multiple(message); + let pad = std::iter::repeat(0u8) + .take((8 * MESSAGE_BASE_LENGTH) - message.len() - prepend_message.len()) + .collect::>(); + + // padded message is of correct length + assert_eq!(pad_message.len(), 8 * MESSAGE_BASE_LENGTH); + // prepend message is well specified + assert_eq!(prepend_message, pad_message[..prepend_message.len()]); + // message body is well specified + assert_eq!( + *message, + pad_message[prepend_message.len()..prepend_message.len() + message.len()] + ); + // pad is well specified + assert_eq!(pad, pad_message[prepend_message.len() + message.len()..]); + + // test for base message of multiple base length + let message = &[100u8; MESSAGE_BASE_LENGTH * 9 - 123]; + let prepend_message = (message.len() as u32).to_le_bytes(); + let pad = std::iter::repeat(0u8) + .take((9 * MESSAGE_BASE_LENGTH) - message.len() - prepend_message.len()) + .collect::>(); + + let pad_message = pad_message_to_base_length_multiple(message); + + // padded message is of correct length + assert_eq!(pad_message.len(), 9 * MESSAGE_BASE_LENGTH); + // prepend message is well specified + assert_eq!(prepend_message, pad_message[..prepend_message.len()]); + // message body is well specified + assert_eq!( + *message, + pad_message[prepend_message.len()..prepend_message.len() + message.len()] + ); + // pad is well specified + assert_eq!(pad, pad_message[prepend_message.len() + message.len()..]); + + // test for empty message + let message: [u8; 0] = []; + let prepend_message = (message.len() as u32).to_le_bytes(); + let pad_message = pad_message_to_base_length_multiple(&message); + let pad = [0u8; MESSAGE_BASE_LENGTH - 4]; + + // padded message is of correct length + assert_eq!(pad_message.len(), MESSAGE_BASE_LENGTH); + // prepend message is well specified + assert_eq!(prepend_message, pad_message[..prepend_message.len()]); + // message body is well specified + assert_eq!( + message, + pad_message[prepend_message.len()..prepend_message.len() + message.len()] + ); + + // pad is well specified + assert_eq!(pad, pad_message[prepend_message.len() + message.len()..]); + } + + #[test] + fn get_original_message_from_padded_text_successful() { + // test for short message + let message = vec![0u8, 10, 22, 11, 38, 74, 59, 91, 73, 82, 75, 23, 59]; + let pad_message = pad_message_to_base_length_multiple(message.as_slice()); + + let output_message = get_original_message_from_padded_text(pad_message.as_slice()).unwrap(); + assert_eq!(message, output_message); + + // test for large message + let message = vec![100u8; 1024]; + let pad_message = pad_message_to_base_length_multiple(message.as_slice()); + + let output_message = get_original_message_from_padded_text(pad_message.as_slice()).unwrap(); + assert_eq!(message, output_message); + + // test for base message of base length + let message = vec![100u8; 984]; + let pad_message = pad_message_to_base_length_multiple(message.as_slice()); + + let output_message = get_original_message_from_padded_text(pad_message.as_slice()).unwrap(); + assert_eq!(message, output_message); + + // test for empty message + let message: Vec = vec![]; + let pad_message = pad_message_to_base_length_multiple(message.as_slice()); + + let output_message = get_original_message_from_padded_text(pad_message.as_slice()).unwrap(); + assert_eq!(message, output_message); + } + + #[test] + fn padding_fails_if_pad_message_prepend_length_is_bigger_than_plaintext_length() { + let message = "This is my secret message, keep it secret !".as_bytes(); + let mut pad_message = pad_message_to_base_length_multiple(message); + + // we modify the prepend length, in order to assert that the get original message + // method will output a different length message + pad_message[0] = 1; + + let modified_message = get_original_message_from_padded_text(pad_message.as_slice()).unwrap(); + assert!(message.len() != modified_message.len()); + + // add big number from le bytes of prepend bytes + pad_message[0] = 255; + pad_message[1] = 255; + pad_message[2] = 255; + pad_message[3] = 255; + + assert!(get_original_message_from_padded_text(pad_message.as_slice()) + .unwrap_err() + .to_string() + .contains("Original length message is invalid")); + } + + #[test] + fn check_decryption_succeeds_if_pad_message_padding_is_modified() { + // this should not be problematic as any changes in the content of the encrypted padding, should not affect + // in any way the value of the decrypted content, by applying a cipher stream + let pk = CommsPublicKey::default(); + let key = CipherKey(*chacha20::Key::from_slice(pk.as_bytes())); + let message = "My secret message, keep it secret !".as_bytes().to_vec(); + let mut encrypted = encrypt(&key, &message); + + let n = encrypted.len(); + encrypted[n - 1] += 1; + + assert!(decrypt(&key, &encrypted).unwrap() == message); + } + + #[test] + fn decryption_fails_if_message_body_is_modified() { + let pk = CommsPublicKey::default(); + let key = CipherKey(*chacha20::Key::from_slice(pk.as_bytes())); + let message = "My secret message, keep it secret !".as_bytes().to_vec(); + let mut encrypted = encrypt(&key, &message); + + encrypted[size_of::() + LITTLE_ENDIAN_U32_SIZE_REPRESENTATION + 1] += 1; + + assert!(decrypt(&key, &encrypted).unwrap() != message); + } } diff --git a/dan_layer/engine/tests/hello_world/Cargo.lock b/dan_layer/engine/tests/hello_world/Cargo.lock index b09f1bec3b..7d65fd86d2 100644 --- a/dan_layer/engine/tests/hello_world/Cargo.lock +++ b/dan_layer/engine/tests/hello_world/Cargo.lock @@ -171,6 +171,7 @@ dependencies = [ "quote", "syn", "tari_template_abi", + "tari_template_lib", ] [[package]] diff --git a/dan_layer/engine/tests/state/Cargo.lock b/dan_layer/engine/tests/state/Cargo.lock index 9964c09d41..f89e31465c 100644 --- a/dan_layer/engine/tests/state/Cargo.lock +++ b/dan_layer/engine/tests/state/Cargo.lock @@ -107,9 +107,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.40" +version = "1.0.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd96a1e8ed2596c337f8eae5f24924ec83f5ad5ab21ea8e455d3566c69fbcaf7" +checksum = "c278e965f1d8cf32d6e0e96de3d3e79712178ae67986d9cf9151f51e95aac89b" dependencies = [ "unicode-ident", ] @@ -135,6 +135,7 @@ version = "0.1.0" dependencies = [ "tari_template_abi", "tari_template_lib", + "tari_template_macros", ] [[package]] @@ -162,6 +163,17 @@ dependencies = [ "tari_template_abi", ] +[[package]] +name = "tari_template_macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "tari_template_abi", + "tari_template_lib", +] + [[package]] name = "toml" version = "0.5.9" diff --git a/dan_layer/engine/tests/state/Cargo.toml b/dan_layer/engine/tests/state/Cargo.toml index 19b00846d8..9374e4d4a2 100644 --- a/dan_layer/engine/tests/state/Cargo.toml +++ b/dan_layer/engine/tests/state/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" [dependencies] tari_template_abi = { path = "../../../template_abi" } tari_template_lib = { path = "../../../template_lib" } +tari_template_macros = { path = "../../../template_macros" } [profile.release] opt-level = 's' # Optimize for size. diff --git a/dan_layer/engine/tests/state/src/lib.rs b/dan_layer/engine/tests/state/src/lib.rs index 9e516b4fea..a8ed6e3f17 100644 --- a/dan_layer/engine/tests/state/src/lib.rs +++ b/dan_layer/engine/tests/state/src/lib.rs @@ -20,23 +20,15 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use tari_template_abi::{decode, encode_with_len, FunctionDef, Type}; -use tari_template_lib::{call_engine, generate_abi, generate_main, TemplateImpl}; +use tari_template_macros::template; -// that's what the example should look like from the user's perspective -#[allow(dead_code)] +#[template] mod state_template { - use tari_template_abi::{borsh, Decode, Encode}; - - // #[tari::template] - #[derive(Encode, Decode)] pub struct State { - value: u32, + pub value: u32, } - // #[tari::impl] impl State { - // #[tari::constructor] pub fn new() -> Self { Self { value: 0 } } @@ -50,97 +42,3 @@ mod state_template { } } } - -// TODO: Macro generated code -#[no_mangle] -extern "C" fn State_abi() -> *mut u8 { - let template_name = "State".to_string(); - - let functions = vec![ - FunctionDef { - name: "new".to_string(), - arguments: vec![], - output: Type::U32, // the component_id - }, - FunctionDef { - name: "set".to_string(), - arguments: vec![Type::U32, Type::U32], // the component_id and the new value - output: Type::Unit, // does not return anything - }, - FunctionDef { - name: "get".to_string(), - arguments: vec![Type::U32], // the component_id - output: Type::U32, // the stored value - }, - ]; - - generate_abi(template_name, functions) -} - -#[no_mangle] -unsafe extern "C" fn State_main(call_info: *mut u8, call_info_len: usize) -> *mut u8 { - let mut template_impl = TemplateImpl::new(); - use tari_template_abi::{ops::*, CreateComponentArg, EmitLogArg, LogLevel}; - use tari_template_lib::models::ComponentId; - - tari_template_lib::call_engine::<_, ()>(OP_EMIT_LOG, &EmitLogArg { - message: "This is a log message from State_main!".to_string(), - level: LogLevel::Info, - }); - - // constructor - template_impl.add_function( - "new".to_string(), - Box::new(|_| { - let ret = state_template::State::new(); - let encoded = encode_with_len(&ret); - // Call the engine to create a new component - // TODO: proper component id - // The macro will know to generate this call because of the #[tari(constructor)] attribute - // TODO: what happens if the user wants to return multiple components/types? - let component_id = call_engine::<_, ComponentId>(OP_CREATE_COMPONENT, &CreateComponentArg { - name: "State".to_string(), - quantity: 1, - metadata: Default::default(), - state: encoded, - }); - let component_id = component_id.expect("no asset id returned"); - encode_with_len(&component_id) - }), - ); - - template_impl.add_function( - "set".to_string(), - Box::new(|args| { - // read the function paramenters - let _component_id: u32 = decode(&args[0]).unwrap(); - let _new_value: u32 = decode(&args[1]).unwrap(); - - // update the component value - // TODO: use a real op code (not "123") when they are implemented - call_engine::<_, ()>(123, &()); - - // the function does not return any value - // TODO: implement "Unit" type empty responses. Right now this fails: wrap_ptr(vec![]) - encode_with_len(&0) - }), - ); - - template_impl.add_function( - "get".to_string(), - Box::new(|args| { - // read the function paramenters - let _component_id: u32 = decode(&args[0]).unwrap(); - - // get the component state - // TODO: use a real op code (not "123") when they are implemented - let _state = call_engine::<_, ()>(123, &()); - - // return the value - let value = 1_u32; // TODO: read from the component state - encode_with_len(&value) - }), - ); - - generate_main(call_info, call_info_len, template_impl) -} diff --git a/dan_layer/engine/tests/test.rs b/dan_layer/engine/tests/test.rs index 65194806b9..df7468e79a 100644 --- a/dan_layer/engine/tests/test.rs +++ b/dan_layer/engine/tests/test.rs @@ -46,13 +46,11 @@ fn test_hello_world() { #[test] fn test_state() { + // TODO: use the Component and ComponentId types in the template let template_test = TemplateTest::new("State".to_string(), "tests/state".to_string()); // constructor let component: ComponentId = template_test.call_function("new".to_string(), vec![]); - assert_eq!(component.1, 0); - let component: ComponentId = template_test.call_function("new".to_string(), vec![]); - assert_eq!(component.1, 1); // call the "set" method to update the instance value let new_value = 20_u32; @@ -60,11 +58,13 @@ fn test_state() { encode_with_len(&component), encode_with_len(&new_value), ]); + // call the "get" method to get the current value let value: u32 = template_test.call_method("State".to_string(), "get".to_string(), vec![encode_with_len( &component, )]); - assert_eq!(value, 1); + // TODO: when state storage is implemented in the engine, assert the previous setted value (20_u32) + assert_eq!(value, 0); } struct TemplateTest { diff --git a/dan_layer/template_lib/src/lib.rs b/dan_layer/template_lib/src/lib.rs index f178a71471..c5f2689371 100644 --- a/dan_layer/template_lib/src/lib.rs +++ b/dan_layer/template_lib/src/lib.rs @@ -33,7 +33,7 @@ pub mod models; // TODO: we should only use stdlib if the template dev needs to include it e.g. use core::mem when stdlib is not // available -use std::{collections::HashMap, mem, ptr::copy, slice}; +use std::{collections::HashMap, mem, slice}; use tari_template_abi::{encode_with_len, Decode, Encode, FunctionDef, TemplateDef}; diff --git a/dan_layer/template_lib/src/models/component.rs b/dan_layer/template_lib/src/models/component.rs index 3b27286bbc..6b377bc74c 100644 --- a/dan_layer/template_lib/src/models/component.rs +++ b/dan_layer/template_lib/src/models/component.rs @@ -20,4 +20,42 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// TODO: use the actual component id type pub type ComponentId = ([u8; 32], u32); + +use tari_template_abi::{Decode, Encode, encode_with_len, ops::OP_CREATE_COMPONENT, CreateComponentArg}; + +use crate::call_engine; + +pub fn initialise(template_name: String, initial_state: T) -> ComponentId { + let encoded_state = encode_with_len(&initial_state); + + // Call the engine to create a new component + // TODO: proper component id + // TODO: what happens if the user wants to return multiple components/types? + let component_id = call_engine::<_, ComponentId>(OP_CREATE_COMPONENT, &CreateComponentArg { + name: template_name, + quantity: 1, + metadata: Default::default(), + state: encoded_state, + }); + component_id.expect("no asset id returned") +} + +pub fn get_state(_id: u32) -> T { + // get the component state + // TODO: use a real op code (not "123") when they are implemented + let _state = call_engine::<_, ()>(123, &()); + + // create and return a mock state because state is not implemented yet in the engine + let len = std::mem::size_of::(); + let byte_vec = vec![0_u8; len]; + let mut mock_value = byte_vec.as_slice(); + T::deserialize(&mut mock_value).unwrap() +} + +pub fn set_state(_id: u32, _state: T) { + // update the component value + // TODO: use a real op code (not "123") when they are implemented + call_engine::<_, ()>(123, &()); +} diff --git a/dan_layer/template_lib/src/models/mod.rs b/dan_layer/template_lib/src/models/mod.rs index ef04fea78d..a2237b672d 100644 --- a/dan_layer/template_lib/src/models/mod.rs +++ b/dan_layer/template_lib/src/models/mod.rs @@ -21,4 +21,4 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. mod component; -pub use component::ComponentId; +pub use component::*; diff --git a/dan_layer/template_macros/Cargo.lock b/dan_layer/template_macros/Cargo.lock new file mode 100644 index 0000000000..746c58c20a --- /dev/null +++ b/dan_layer/template_macros/Cargo.lock @@ -0,0 +1,200 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "ahash" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" +dependencies = [ + "getrandom", + "once_cell", + "version_check", +] + +[[package]] +name = "borsh" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15bf3650200d8bffa99015595e10f1fbd17de07abbc25bb067da79e769939bfa" +dependencies = [ + "borsh-derive", + "hashbrown", +] + +[[package]] +name = "borsh-derive" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6441c552f230375d18e3cc377677914d2ca2b0d36e52129fe15450a2dce46775" +dependencies = [ + "borsh-derive-internal", + "borsh-schema-derive-internal", + "proc-macro-crate", + "proc-macro2", + "syn", +] + +[[package]] +name = "borsh-derive-internal" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5449c28a7b352f2d1e592a8a28bf139bc71afb0764a14f3c02500935d8c44065" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "borsh-schema-derive-internal" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdbd5696d8bfa21d53d9fe39a714a18538bad11492a42d066dbbc395fb1951c0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "getrandom" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "hashbrown" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +dependencies = [ + "ahash", +] + +[[package]] +name = "indoc" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05a0bd019339e5d968b37855180087b7b9d512c5046fbd244cf8c95687927d6e" + +[[package]] +name = "libc" +version = "0.2.126" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836" + +[[package]] +name = "once_cell" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18a6dbe30758c9f83eb00cbea4ac95966305f5a7772f3f42ebfc7fc7eddbd8e1" + +[[package]] +name = "proc-macro-crate" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d6ea3c4595b96363c13943497db34af4460fb474a95c43f4446ad341b8c9785" +dependencies = [ + "toml", +] + +[[package]] +name = "proc-macro2" +version = "1.0.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c278e965f1d8cf32d6e0e96de3d3e79712178ae67986d9cf9151f51e95aac89b" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bcdf212e9776fbcb2d23ab029360416bb1706b1aea2d1a5ba002727cbcab804" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "serde" +version = "1.0.140" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc855a42c7967b7c369eb5860f7164ef1f6f81c20c7cc1141f2a604e18723b03" + +[[package]] +name = "syn" +version = "1.0.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c50aef8a904de4c23c788f104b7dddc7d6f79c647c7c8ce4cc8f73eb0ca773dd" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tari_template_abi" +version = "0.1.0" +dependencies = [ + "borsh", +] + +[[package]] +name = "tari_template_lib" +version = "0.1.0" +dependencies = [ + "tari_template_abi", +] + +[[package]] +name = "tari_template_macros" +version = "0.1.0" +dependencies = [ + "indoc", + "proc-macro2", + "quote", + "syn", + "tari_template_abi", + "tari_template_lib", +] + +[[package]] +name = "toml" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d82e1a7758622a465f8cee077614c73484dac5b836c02ff6a40d5d1010324d7" +dependencies = [ + "serde", +] + +[[package]] +name = "unicode-ident" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15c61ba63f9235225a22310255a29b806b907c9b8c964bcbd0a2c70f3f2deea7" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" diff --git a/dan_layer/template_macros/Cargo.toml b/dan_layer/template_macros/Cargo.toml index 8b077deb05..a3335cabe6 100644 --- a/dan_layer/template_macros/Cargo.toml +++ b/dan_layer/template_macros/Cargo.toml @@ -10,6 +10,7 @@ proc-macro = true [dependencies] tari_template_abi = { path = "../template_abi" } +tari_template_lib = { path = "../template_lib" } syn = { version = "1.0.98", features = ["full"] } proc-macro2 = "1.0.42" quote = "1.0.20" diff --git a/dan_layer/template_macros/src/ast.rs b/dan_layer/template_macros/src/ast.rs index fd6f458297..27079f882a 100644 --- a/dan_layer/template_macros/src/ast.rs +++ b/dan_layer/template_macros/src/ast.rs @@ -34,6 +34,7 @@ use syn::{ ItemStruct, Result, ReturnType, + Signature, Stmt, }; @@ -95,38 +96,44 @@ impl TemplateAst { match item { ImplItem::Method(m) => FunctionAst { name: m.sig.ident.to_string(), - input_types: Self::get_input_type_tokens(&m.sig.inputs), + input_types: Self::get_input_types(&m.sig.inputs), output_type: Self::get_output_type_token(&m.sig.output), statements: Self::get_statements(m), + is_constructor: Self::is_constructor(&m.sig), }, _ => todo!(), } } - fn get_input_type_tokens(inputs: &Punctuated) -> Vec { + fn get_input_types(inputs: &Punctuated) -> Vec { inputs .iter() .map(|arg| match arg { // TODO: handle the "self" case - syn::FnArg::Receiver(_) => todo!(), - syn::FnArg::Typed(t) => Self::get_type_token(&t.ty), + syn::FnArg::Receiver(r) => { + // TODO: validate that it's indeed a reference ("&") to self + + let mutability = r.mutability.is_some(); + TypeAst::Receiver { mutability } + }, + syn::FnArg::Typed(t) => Self::get_type_ast(&t.ty), }) .collect() } - fn get_output_type_token(ast_type: &ReturnType) -> String { + fn get_output_type_token(ast_type: &ReturnType) -> Option { match ast_type { - syn::ReturnType::Default => String::new(), // the function does not return anything - syn::ReturnType::Type(_, t) => Self::get_type_token(t), + syn::ReturnType::Default => None, // the function does not return anything + syn::ReturnType::Type(_, t) => Some(Self::get_type_ast(t)), } } - fn get_type_token(syn_type: &syn::Type) -> String { + fn get_type_ast(syn_type: &syn::Type) -> TypeAst { match syn_type { syn::Type::Path(type_path) => { // TODO: handle "Self" // TODO: detect more complex types - type_path.path.segments[0].ident.to_string() + TypeAst::Typed(type_path.path.segments[0].ident.clone()) }, _ => todo!(), } @@ -135,11 +142,27 @@ impl TemplateAst { fn get_statements(method: &ImplItemMethod) -> Vec { method.block.stmts.clone() } + + fn is_constructor(sig: &Signature) -> bool { + match &sig.output { + syn::ReturnType::Default => false, // the function does not return anything + syn::ReturnType::Type(_, t) => match t.as_ref() { + syn::Type::Path(type_path) => type_path.path.segments[0].ident == "Self", + _ => false, + }, + } + } } pub struct FunctionAst { pub name: String, - pub input_types: Vec, - pub output_type: String, + pub input_types: Vec, + pub output_type: Option, pub statements: Vec, + pub is_constructor: bool, +} + +pub enum TypeAst { + Receiver { mutability: bool }, + Typed(Ident), } diff --git a/dan_layer/template_macros/src/template/abi.rs b/dan_layer/template_macros/src/template/abi.rs index e1386b3198..a2c964f019 100644 --- a/dan_layer/template_macros/src/template/abi.rs +++ b/dan_layer/template_macros/src/template/abi.rs @@ -24,7 +24,7 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote}; use syn::{parse_quote, Expr, Result}; -use crate::ast::{FunctionAst, TemplateAst}; +use crate::ast::{FunctionAst, TemplateAst, TypeAst}; pub fn generate_abi(ast: &TemplateAst) -> Result { let abi_function_name = format_ident!("{}_abi", ast.struct_section.ident); @@ -51,13 +51,13 @@ pub fn generate_abi(ast: &TemplateAst) -> Result { fn generate_function_def(f: &FunctionAst) -> Expr { let name = f.name.clone(); - let arguments: Vec = f - .input_types - .iter() - .map(String::as_str) - .map(generate_abi_type) - .collect(); - let output = generate_abi_type(&f.output_type); + + let arguments: Vec = f.input_types.iter().map(generate_abi_type).collect(); + + let output = match &f.output_type { + Some(type_ast) => generate_abi_type(type_ast), + None => parse_quote!(Type::Unit), + }; parse_quote!( FunctionDef { @@ -68,26 +68,36 @@ fn generate_function_def(f: &FunctionAst) -> Expr { ) } -fn generate_abi_type(rust_type: &str) -> Expr { - // TODO: there may be a better way of handling this +fn generate_abi_type(rust_type: &TypeAst) -> Expr { match rust_type { - "" => parse_quote!(Type::Unit), - "bool" => parse_quote!(Type::Bool), - "i8" => parse_quote!(Type::I8), - "i16" => parse_quote!(Type::I16), - "i32" => parse_quote!(Type::I32), - "i64" => parse_quote!(Type::I64), - "i128" => parse_quote!(Type::I128), - "u8" => parse_quote!(Type::U8), - "u16" => parse_quote!(Type::U16), - "u32" => parse_quote!(Type::U32), - "u64" => parse_quote!(Type::U64), - "u128" => parse_quote!(Type::U128), - "String" => parse_quote!(Type::String), - _ => todo!(), + // on "&self" we want to pass the component id + TypeAst::Receiver { .. } => get_component_id_type(), + // basic type + // TODO: there may be a better way of handling this + TypeAst::Typed(ident) => match ident.to_string().as_str() { + "" => parse_quote!(Type::Unit), + "bool" => parse_quote!(Type::Bool), + "i8" => parse_quote!(Type::I8), + "i16" => parse_quote!(Type::I16), + "i32" => parse_quote!(Type::I32), + "i64" => parse_quote!(Type::I64), + "i128" => parse_quote!(Type::I128), + "u8" => parse_quote!(Type::U8), + "u16" => parse_quote!(Type::U16), + "u32" => parse_quote!(Type::U32), + "u64" => parse_quote!(Type::U64), + "u128" => parse_quote!(Type::U128), + "String" => parse_quote!(Type::String), + "Self" => get_component_id_type(), + _ => todo!(), + }, } } +fn get_component_id_type() -> Expr { + parse_quote!(Type::U32) +} + #[cfg(test)] mod tests { use std::str::FromStr; @@ -101,7 +111,7 @@ mod tests { use crate::ast::TemplateAst; #[test] - fn test_hello_world() { + fn test_signatures() { let input = TokenStream::from_str(indoc! {" mod foo { struct Foo {} @@ -112,7 +122,9 @@ mod tests { pub fn some_args_function(a: i8, b: String) -> u32 { 1_u32 } - pub fn no_return_function() {} + pub fn no_return_function() {} + pub fn constructor() -> Self {} + pub fn method(&self){} } } "}) @@ -144,6 +156,16 @@ mod tests { name: "no_return_function".to_string(), arguments: vec![], output: Type::Unit, + }, + FunctionDef { + name: "constructor".to_string(), + arguments: vec![], + output: Type::U32, + }, + FunctionDef { + name: "method".to_string(), + arguments: vec![Type::U32], + output: Type::Unit, } ], }; diff --git a/dan_layer/template_macros/src/template/definition.rs b/dan_layer/template_macros/src/template/definition.rs index dbc330bdb1..f3c98825ed 100644 --- a/dan_layer/template_macros/src/template/definition.rs +++ b/dan_layer/template_macros/src/template/definition.rs @@ -27,15 +27,16 @@ use crate::ast::TemplateAst; pub fn generate_definition(ast: &TemplateAst) -> TokenStream { let template_name = format_ident!("{}", ast.struct_section.ident); + let template_fields = &ast.struct_section.fields; + let semi_token = &ast.struct_section.semi_token; let functions = &ast.impl_section.items; quote! { pub mod template { - use super::*; + use tari_template_abi::borsh; - pub struct #template_name { - // TODO: fill template fields - } + #[derive(tari_template_abi::borsh::BorshSerialize, tari_template_abi::borsh::BorshDeserialize)] + pub struct #template_name #template_fields #semi_token impl #template_name { #(#functions)* diff --git a/dan_layer/template_macros/src/template/dispatcher.rs b/dan_layer/template_macros/src/template/dispatcher.rs index 12ebede5f3..90339769d2 100644 --- a/dan_layer/template_macros/src/template/dispatcher.rs +++ b/dan_layer/template_macros/src/template/dispatcher.rs @@ -20,11 +20,11 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use proc_macro2::{Span, TokenStream}; +use proc_macro2::{Ident, Span, TokenStream}; use quote::{format_ident, quote}; -use syn::{token::Brace, Block, Expr, ExprBlock, Result}; +use syn::{parse_quote, token::Brace, Block, Expr, ExprBlock, Result}; -use crate::ast::TemplateAst; +use crate::ast::{FunctionAst, TemplateAst, TypeAst}; pub fn generate_dispatcher(ast: &TemplateAst) -> Result { let dispatcher_function_name = format_ident!("{}_main", ast.struct_section.ident); @@ -35,6 +35,7 @@ pub fn generate_dispatcher(ast: &TemplateAst) -> Result { #[no_mangle] pub extern "C" fn #dispatcher_function_name(call_info: *mut u8, call_info_len: usize) -> *mut u8 { use ::tari_template_abi::{decode, encode_with_len, CallInfo}; + use ::tari_template_lib::models::{get_state, set_state, initialise}; if call_info.is_null() { panic!("call_info is null"); @@ -43,94 +44,113 @@ pub fn generate_dispatcher(ast: &TemplateAst) -> Result { let call_data = unsafe { Vec::from_raw_parts(call_info, call_info_len, call_info_len) }; let call_info: CallInfo = decode(&call_data).unwrap(); - let result = match call_info.func_name.as_str() { - #( #function_names => #function_blocks )*, + let result; + match call_info.func_name.as_str() { + #( #function_names => #function_blocks ),*, _ => panic!("invalid function name") }; - wrap_ptr(encode_with_len(&result)) + wrap_ptr(result) } }; Ok(output) } -pub fn get_function_names(ast: &TemplateAst) -> Vec { +fn get_function_names(ast: &TemplateAst) -> Vec { ast.get_functions().iter().map(|f| f.name.clone()).collect() } -pub fn get_function_blocks(ast: &TemplateAst) -> Vec { +fn get_function_blocks(ast: &TemplateAst) -> Vec { let mut blocks = vec![]; for function in ast.get_functions() { - let statements = function.statements; - blocks.push(Expr::Block(ExprBlock { - attrs: vec![], - label: None, - block: Block { - brace_token: Brace { - span: Span::call_site(), - }, - stmts: statements, - }, - })); + let block = get_function_block(&ast.template_name, function); + blocks.push(block); } blocks } -#[cfg(test)] -mod tests { - use std::str::FromStr; - - use indoc::indoc; - use proc_macro2::TokenStream; - use quote::quote; - use syn::parse2; - - use crate::{ast::TemplateAst, template::dispatcher::generate_dispatcher}; - - #[test] - fn test_hello_world() { - let input = TokenStream::from_str(indoc! {" - mod hello_world { - struct HelloWorld {} - impl HelloWorld { - pub fn greet() -> String { - \"Hello World!\".to_string() - } - } - } - "}) - .unwrap(); - - let ast = parse2::(input).unwrap(); - - let output = generate_dispatcher(&ast).unwrap(); - - assert_code_eq(output, quote! { - #[no_mangle] - pub extern "C" fn HelloWorld_main(call_info: *mut u8, call_info_len: usize) -> *mut u8 { - use ::tari_template_abi::{decode, encode_with_len, CallInfo}; - - if call_info.is_null() { - panic!("call_info is null"); +fn get_function_block(template_ident: &Ident, ast: FunctionAst) -> Expr { + let mut args: Vec = vec![]; + let mut stmts = vec![]; + let mut should_get_state = false; + let mut should_set_state = false; + + // encode all arguments of the functions + for (i, input_type) in ast.input_types.into_iter().enumerate() { + let arg_ident = format_ident!("arg_{}", i); + let stmt = match input_type { + // "self" argument + TypeAst::Receiver { mutability } => { + should_get_state = true; + should_set_state = mutability; + args.push(parse_quote! { &mut state }); + parse_quote! { + let #arg_ident = + decode::(&call_info.args[#i]) + .unwrap(); + } + }, + // non-self argument + TypeAst::Typed(type_ident) => { + args.push(parse_quote! { #arg_ident }); + parse_quote! { + let #arg_ident = + decode::<#type_ident>(&call_info.args[#i]) + .unwrap(); } + }, + }; + stmts.push(stmt); + } - let call_data = unsafe { Vec::from_raw_parts(call_info, call_info_len, call_info_len) }; - let call_info: CallInfo = decode(&call_data).unwrap(); + // load the component state + if should_get_state { + stmts.push(parse_quote! { + let mut state: template::#template_ident = get_state(arg_0); + }); + } - let result = match call_info.func_name.as_str() { - "greet" => { "Hello World!".to_string() }, - _ => panic!("invalid function name") - }; + // call the user defined function in the template + let function_ident = Ident::new(&ast.name, Span::call_site()); + if ast.is_constructor { + stmts.push(parse_quote! { + let state = template::#template_ident::#function_ident(#(#args),*); + }); - wrap_ptr(encode_with_len(&result)) - } + let template_name_str = template_ident.to_string(); + stmts.push(parse_quote! { + let rtn = initialise(#template_name_str.to_string(), state); + }); + } else { + stmts.push(parse_quote! { + let rtn = template::#template_ident::#function_ident(#(#args),*); }); } - fn assert_code_eq(a: TokenStream, b: TokenStream) { - assert_eq!(a.to_string(), b.to_string()); + // encode the result value + stmts.push(parse_quote! { + result = encode_with_len(&rtn); + }); + + // after user function invocation, update the component state + if should_set_state { + stmts.push(parse_quote! { + set_state(arg_0, state); + }); } + + // construct the code block for the function + Expr::Block(ExprBlock { + attrs: vec![], + label: None, + block: Block { + brace_token: Brace { + span: Span::call_site(), + }, + stmts, + }, + }) } diff --git a/dan_layer/template_macros/src/template/mod.rs b/dan_layer/template_macros/src/template/mod.rs index e717fd73db..e0bd5541d9 100644 --- a/dan_layer/template_macros/src/template/mod.rs +++ b/dan_layer/template_macros/src/template/mod.rs @@ -57,3 +57,170 @@ pub fn generate_template(input: TokenStream) -> Result { Ok(output) } + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use indoc::indoc; + use proc_macro2::TokenStream; + use quote::quote; + + use super::generate_template; + + #[test] + #[allow(clippy::too_many_lines)] + fn test_state() { + let input = TokenStream::from_str(indoc! {" + mod test { + struct State { + value: u32 + } + impl State { + pub fn new() -> Self { + Self { value: 0 } + } + pub fn get(&self) -> u32 { + self.value + } + pub fn set(&mut self, value: u32) { + self.value = value; + } + } + } + "}) + .unwrap(); + + let output = generate_template(input).unwrap(); + + assert_code_eq(output, quote! { + pub mod template { + use tari_template_abi::borsh; + + #[derive(tari_template_abi::borsh::BorshSerialize, tari_template_abi::borsh::BorshDeserialize)] + pub struct State { + value: u32 + } + + impl State { + pub fn new() -> Self { + Self { value: 0 } + } + pub fn get(&self) -> u32 { + self.value + } + pub fn set(&mut self, value: u32) { + self.value = value; + } + } + } + + #[no_mangle] + pub extern "C" fn State_abi() -> *mut u8 { + use ::tari_template_abi::{encode_with_len, FunctionDef, TemplateDef, Type}; + + let template = TemplateDef { + template_name: "State".to_string(), + functions: vec![ + FunctionDef { + name: "new".to_string(), + arguments: vec![], + output: Type::U32, + }, + FunctionDef { + name: "get".to_string(), + arguments: vec![Type::U32], + output: Type::U32, + }, + FunctionDef { + name: "set".to_string(), + arguments: vec![Type::U32, Type::U32], + output: Type::Unit, + } + ], + }; + + let buf = encode_with_len(&template); + wrap_ptr(buf) + } + + #[no_mangle] + pub extern "C" fn State_main(call_info: *mut u8, call_info_len: usize) -> *mut u8 { + use ::tari_template_abi::{decode, encode_with_len, CallInfo}; + use ::tari_template_lib::models::{get_state, set_state, initialise}; + + if call_info.is_null() { + panic!("call_info is null"); + } + + let call_data = unsafe { Vec::from_raw_parts(call_info, call_info_len, call_info_len) }; + let call_info: CallInfo = decode(&call_data).unwrap(); + + let result; + match call_info.func_name.as_str() { + "new" => { + let state = template::State::new(); + let rtn = initialise("State".to_string(), state); + result = encode_with_len(&rtn); + }, + "get" => { + let arg_0 = decode::(&call_info.args[0usize]).unwrap(); + let mut state: template::State = get_state(arg_0); + let rtn = template::State::get(&mut state); + result = encode_with_len(&rtn); + }, + "set" => { + let arg_0 = decode::(&call_info.args[0usize]).unwrap(); + let arg_1 = decode::(&call_info.args[1usize]).unwrap(); + let mut state: template::State = get_state(arg_0); + let rtn = template::State::set(&mut state, arg_1); + result = encode_with_len(&rtn); + set_state(arg_0, state); + }, + _ => panic!("invalid function name") + }; + + wrap_ptr(result) + } + + extern "C" { + pub fn tari_engine(op: u32, input_ptr: *const u8, input_len: usize) -> *mut u8; + } + + pub fn wrap_ptr(mut v: Vec) -> *mut u8 { + use std::mem; + + let ptr = v.as_mut_ptr(); + mem::forget(v); + ptr + } + + #[no_mangle] + pub unsafe extern "C" fn tari_alloc(len: u32) -> *mut u8 { + use std::{mem, intrinsics::copy}; + + let cap = (len + 4) as usize; + let mut buf = Vec::::with_capacity(cap); + let ptr = buf.as_mut_ptr(); + mem::forget(buf); + copy(len.to_le_bytes().as_ptr(), ptr, 4); + ptr + } + + #[no_mangle] + pub unsafe extern "C" fn tari_free(ptr: *mut u8) { + use std::intrinsics::copy; + + let mut len = [0u8; 4]; + copy(ptr, len.as_mut_ptr(), 4); + + let cap = (u32::from_le_bytes(len) + 4) as usize; + let _ = Vec::::from_raw_parts(ptr, cap, cap); + } + }); + } + + fn assert_code_eq(a: TokenStream, b: TokenStream) { + assert_eq!(a.to_string(), b.to_string()); + } +}