diff --git a/src/bin/ch-remote.rs b/src/bin/ch-remote.rs index 1d4e2a4d87..888946247d 100644 --- a/src/bin/ch-remote.rs +++ b/src/bin/ch-remote.rs @@ -445,14 +445,14 @@ fn rest_api_do_command(matches: &ArgMatches, socket: &mut UnixStream) -> ApiResu .map_err(Error::HttpApiClient) } Some("restore") => { - let restore_config = restore_config( + let (restore_config, net_fds) = restore_config( matches .subcommand_matches("restore") .unwrap() .get_one::("restore_config") .unwrap(), )?; - simple_api_command(socket, "PUT", "restore", Some(&restore_config)) + simple_api_command_with_fds(socket, "PUT", "restore", Some(&restore_config), net_fds) .map_err(Error::HttpApiClient) } Some("coredump") => { @@ -661,7 +661,7 @@ fn dbus_api_do_command(matches: &ArgMatches, proxy: &DBusApi1ProxyBlocking<'_>) proxy.api_vm_snapshot(&snapshot_config) } Some("restore") => { - let restore_config = restore_config( + let (restore_config, _net_fds) = restore_config( matches .subcommand_matches("restore") .unwrap() @@ -849,11 +849,15 @@ fn snapshot_config(url: &str) -> String { serde_json::to_string(&snapshot_config).unwrap() } -fn restore_config(config: &str) -> Result { - let restore_config = vmm::config::RestoreConfig::parse(config).map_err(Error::Restore)?; +fn restore_config(config: &str) -> Result<(String, Vec), Error> { + let mut restore_config = vmm::config::RestoreConfig::parse(config).map_err(Error::Restore)?; + + // RestoreConfig is modified on purpose to take out the file descriptors. + // These fds are passed to the server side process via SCM_RIGHTS + let net_fds = restore_config.net_fds.take().unwrap_or_default(); let restore_config = serde_json::to_string(&restore_config).unwrap(); - Ok(restore_config) + Ok((restore_config, net_fds)) } fn coredump_config(destination_url: &str) -> String { diff --git a/vmm/src/api/http/http_endpoint.rs b/vmm/src/api/http/http_endpoint.rs index 81825685ec..3e0c6c3e28 100644 --- a/vmm/src/api/http/http_endpoint.rs +++ b/vmm/src/api/http/http_endpoint.rs @@ -13,7 +13,7 @@ use crate::api::{ VmReboot, VmReceiveMigration, VmRemoveDevice, VmResize, VmResizeZone, VmRestore, VmResume, VmSendMigration, VmShutdown, VmSnapshot, }; -use crate::config::NetConfig; +use crate::config::{NetConfig, RestoreConfig}; use micro_http::{Body, Method, Request, Response, StatusCode, Version}; use std::fs::File; use std::os::unix::io::IntoRawFd; @@ -184,7 +184,6 @@ vm_action_put_handler_body!(VmAddUserDevice); vm_action_put_handler_body!(VmRemoveDevice); vm_action_put_handler_body!(VmResize); vm_action_put_handler_body!(VmResizeZone); -vm_action_put_handler_body!(VmRestore); vm_action_put_handler_body!(VmSnapshot); vm_action_put_handler_body!(VmReceiveMigration); vm_action_put_handler_body!(VmSendMigration); @@ -220,6 +219,34 @@ impl PutHandler for VmAddNet { impl GetHandler for VmAddNet {} +impl PutHandler for VmRestore { + fn handle_request( + &'static self, + api_notifier: EventFd, + api_sender: Sender, + body: &Option, + mut files: Vec, + ) -> std::result::Result, HttpError> { + if let Some(body) = body { + let mut restore_cfg: RestoreConfig = serde_json::from_slice(body.raw())?; + if restore_cfg.net_fds.is_some() { + warn!("Ignoring net FDs sent via the HTTP request body"); + restore_cfg.net_fds = None; + } + if !files.is_empty() { + let fds = files.drain(..).map(|f| f.into_raw_fd()).collect(); + restore_cfg.net_fds = Some(fds); + } + self.send(api_notifier, api_sender, restore_cfg) + .map_err(HttpError::ApiError) + } else { + Err(HttpError::BadRequest) + } + } +} + +impl GetHandler for VmRestore {} + // Common handler for boot, shutdown and reboot pub struct VmActionHandler { action: &'static dyn HttpVmAction, diff --git a/vmm/src/config.rs b/vmm/src/config.rs index 4efc055b5a..6ecb14eb8b 100644 --- a/vmm/src/config.rs +++ b/vmm/src/config.rs @@ -197,6 +197,12 @@ pub enum ValidationError { InvalidIoPortHex(String), #[cfg(feature = "sev_snp")] InvalidHostData, + /// Restore expects all net ids that have fds + RestoreMissingRequiredNetId(String), + /// Restore does not expect net ids that do not have fds + RestoreNonFdNetIdNotExpected(String), + /// Number of FDs passed during Restore are incorrect to the VmConfig + RestoreNetFdCountMismatch(usize, usize), } type ValidationResult = std::result::Result; @@ -336,6 +342,18 @@ impl fmt::Display for ValidationError { InvalidHostData => { write!(f, "Invalid host data format") } + RestoreMissingRequiredNetId(s) => { + write!(f, "Net id {s} is associated with FDs and is required") + } + RestoreNonFdNetIdNotExpected(s) => { + write!(f, "Net id {s} is not associated with FDs and not expected") + } + RestoreNetFdCountMismatch(u1, u2) => { + write!( + f, + "Number of Net FDs passed during Restore: {u1}. Expected: {u2}" + ) + } } } } @@ -2064,17 +2082,24 @@ pub struct RestoreConfig { pub source_url: PathBuf, #[serde(default)] pub prefault: bool, + #[serde(default)] + pub net_ids: Option>, + #[serde(default)] + pub net_fds: Option>, } impl RestoreConfig { pub const SYNTAX: &'static str = "Restore from a VM snapshot. \ \nRestore parameters \"source_url=,prefault=on|off\" \ \n`source_url` should be a valid URL (e.g file:///foo/bar or tcp://192.168.1.10/foo) \ - \n`prefault` brings memory pages in when enabled (disabled by default)"; + \n`prefault` brings memory pages in when enabled (disabled by default) \ + \n`net_ids` is a list of NetConfig id \ + \n`net_fds` is a list of file descriptors for NetConfigs \ + \n`net_ids` and `net_fds` are optional and should be used together"; pub fn parse(restore: &str) -> Result { let mut parser = OptionParser::new(); - parser.add("source_url").add("prefault"); + parser.add("source_url").add("prefault").add("net_ids").add("net_fds"); parser.parse(restore).map_err(Error::ParseRestore)?; let source_url = parser @@ -2086,12 +2111,73 @@ impl RestoreConfig { .map_err(Error::ParseRestore)? .unwrap_or(Toggle(false)) .0; + let net_ids = parser + .convert::("net_ids") + .map_err(Error::ParseRestore)? + .map(|v| v.0); + let net_fds = parser + .convert::("net_fds") + .map_err(Error::ParseRestore)? + .map(|v| v.0.iter().map(|e| *e as i32).collect()); Ok(RestoreConfig { source_url, prefault, + net_ids, + net_fds, }) } + + pub fn validate(&self, vm_config: &VmConfig) -> ValidationResult<()> { + // Check if multiple same ids are passed + // verify if passed ids are valid + let vm_net_ids = match &vm_config.net { + Some(net_configs) => net_configs.iter().map(|v| v.id.as_ref().unwrap().clone()).collect(), + None => Vec::new(), + }; + if let Some(net_ids) = &self.net_ids { + let mut net_ids_= BTreeSet::new(); + for net_id in net_ids.iter() { + if !vm_net_ids.contains(net_id) { + return Err(ValidationError::InvalidIdentifier(net_id.clone())); + } + if !net_ids_.insert(net_id.clone()) { + return Err(ValidationError::IdentifierNotUnique(net_id.clone())); + } + } + } + // Iterate through the net_ids and check if all required ids are passed without any unexpected ids + // Also, verify that the number of fds passed is equal to the number of fds required + let restore_net_fds_count = self.net_fds.as_ref().map(|v| v.len()).unwrap_or(0); + let mut expected_fd_count = 0; + if let Some(net_configs) = &vm_config.net { + let net_ids = match &self.net_ids { + Some(ids) => ids.clone(), + None => Vec::new(), + }; + for net_config in net_configs.iter() { + if let Some(fds) = &net_config.fds { + expected_fd_count += fds.len(); + if !net_ids.contains(&net_config.id.as_ref().unwrap()) { + return Err(ValidationError::RestoreMissingRequiredNetId( + net_config.id.as_ref().unwrap().clone(), + )); + } + } else if net_ids.contains(&net_config.id.as_ref().unwrap()) { + return Err(ValidationError::RestoreNonFdNetIdNotExpected( + net_config.id.as_ref().unwrap().clone(), + )); + } + } + } + if restore_net_fds_count != expected_fd_count { + return Err(ValidationError::RestoreNetFdCountMismatch( + restore_net_fds_count, + expected_fd_count, + )); + } + Ok(()) + } } impl TpmConfig { diff --git a/vmm/src/lib.rs b/vmm/src/lib.rs index 60b4feeda9..2c74c5bb84 100644 --- a/vmm/src/lib.rs +++ b/vmm/src/lib.rs @@ -1321,6 +1321,26 @@ impl RequestHandler for Vmm { let vm_config = Arc::new(Mutex::new( recv_vm_config(source_url).map_err(VmError::Restore)?, )); + restore_cfg.validate(&vm_config.lock().unwrap().clone()).map_err(VmError::ConfigValidation)?; + + if let (Some(net_ids), Some(net_fds)) = (restore_cfg.net_ids, restore_cfg.net_fds) { + let mut idx_net_fds = 0; + for net_id in net_ids.iter() { + if let Some(net_configs) = &mut vm_config.lock().unwrap().net { + for net_config in net_configs.iter_mut() { + if net_config.id == Some(net_id.clone()) { + if let Some(fds) = net_config.fds.as_ref() { + let len = fds.len(); + let new_fds = net_fds[idx_net_fds..idx_net_fds + len].to_vec(); + idx_net_fds += len; + net_config.fds = Some(new_fds); + } + } + } + } + } + } + let snapshot = recv_vm_state(source_url).map_err(VmError::Restore)?; #[cfg(all(feature = "kvm", target_arch = "x86_64"))] let vm_snapshot = get_vm_snapshot(&snapshot).map_err(VmError::Restore)?;