From e7f950a23eeccf4847be7ba88aae5e970828f472 Mon Sep 17 00:00:00 2001 From: Olivier Lacroix Date: Mon, 6 May 2024 22:41:59 +1000 Subject: [PATCH] spawn --- src/cli/global/common.rs | 6 +-- src/cli/global/install.rs | 2 +- src/cli/global/upgrade.rs | 82 ++++++++++++++++++++------------------- 3 files changed, 47 insertions(+), 43 deletions(-) diff --git a/src/cli/global/common.rs b/src/cli/global/common.rs index 4834c4e9a..f4912145c 100644 --- a/src/cli/global/common.rs +++ b/src/cli/global/common.rs @@ -153,13 +153,13 @@ pub(super) fn channel_name_from_prefix( /// # Returns /// /// The package records (with dependencies records) for the given package MatchSpec -pub fn load_package_records( +pub fn load_package_records<'a>( package_matchspec: MatchSpec, - sparse_repodata: &IndexMap<(Channel, Platform), SparseRepoData>, + sparse_repodata: impl IntoIterator, ) -> miette::Result> { let package_name = package_name(&package_matchspec)?; let available_packages = - SparseRepoData::load_records_recursive(sparse_repodata.values(), vec![package_name], None) + SparseRepoData::load_records_recursive(sparse_repodata, vec![package_name], None) .into_diagnostic()?; let virtual_packages = rattler_virtual_packages::VirtualPackage::current() .into_diagnostic()? diff --git a/src/cli/global/install.rs b/src/cli/global/install.rs index 799333c08..f66af19d1 100644 --- a/src/cli/global/install.rs +++ b/src/cli/global/install.rs @@ -258,7 +258,7 @@ pub async fn execute(args: Args) -> miette::Result<()> { // Install the package(s) let mut executables = vec![]; for (package_name, package_matchspec) in args.specs()? { - let records = load_package_records(package_matchspec, &sparse_repodata)?; + let records = load_package_records(package_matchspec, sparse_repodata.values())?; let (prefix_package, scripts, _) = globally_install_package( &package_name, diff --git a/src/cli/global/upgrade.rs b/src/cli/global/upgrade.rs index bdb7af451..94c6cca7a 100644 --- a/src/cli/global/upgrade.rs +++ b/src/cli/global/upgrade.rs @@ -2,10 +2,12 @@ use std::collections::HashMap; use std::time::Duration; use clap::Parser; +use futures::stream::FuturesUnordered; +use futures::StreamExt; use indexmap::IndexMap; use indicatif::ProgressBar; use itertools::Itertools; -use miette::IntoDiagnostic; +use miette::{IntoDiagnostic, Report}; use rattler_conda_types::{Channel, MatchSpec, PackageName, Platform}; use crate::config::Config; @@ -59,41 +61,49 @@ pub(super) async fn upgrade_packages( cli_channels: &[String], platform: &Platform, ) -> miette::Result<()> { - // Get channels and versions of globally installed packages - let mut installed_versions = HashMap::with_capacity(specs.len()); - let mut channels = config.compute_channels(cli_channels).into_diagnostic()?; - - for package_name in specs.keys() { - let prefix_record = find_installed_package(package_name).await?; - let last_installed_channel = Channel::from_str( - prefix_record.repodata_record.channel.clone(), - config.channel_config(), - ) - .into_diagnostic()?; - - channels.push(last_installed_channel); - - let installed_version = prefix_record - .repodata_record - .package_record - .version - .into_version(); - installed_versions.insert(package_name.clone(), installed_version); + let channel_cli = config.compute_channels(cli_channels).into_diagnostic()?; + + // Get channels and version of globally installed packages in parallel + let mut channels = HashMap::with_capacity(specs.len()); + let mut versions = HashMap::with_capacity(specs.len()); + let mut handles = FuturesUnordered::>>::new(); + for package_name in specs.keys().cloned() { + let channel_config = config.channel_config().clone(); + let future: tokio::task::JoinHandle> = tokio::spawn(async move { + let p = find_installed_package(&package_name).await?; + let channel = + Channel::from_str(p.repodata_record.channel, &channel_config).into_diagnostic()?; + let version = p.repodata_record.package_record.version.into_version(); + Ok((package_name, channel, version)) + }); + handles.push(future); + } + while let Some(data) = handles.next().await { + let (package_name, channel, version) = data.into_diagnostic()??; + channels.insert(package_name.clone(), channel); + versions.insert(package_name, version); } - channels = channels.into_iter().unique().collect(); - // Fetch sparse repodata - let (authenticated_client, sparse_repodata) = - get_client_and_sparse_repodata(&channels, *platform, &config).await?; + // Fetch sparse repodata across all channels + let all_channels = channels.values().chain(channel_cli.iter()).unique(); + let (client, repodata) = + get_client_and_sparse_repodata(all_channels, *platform, &config).await?; // Upgrade each package when relevant let mut upgraded = false; - for (package_name, package_matchspec) in specs { - let matchspec_has_version = package_matchspec.version.is_some(); - let records = load_package_records(package_matchspec, &sparse_repodata)?; + for (package_name, package_matchspec) in specs.iter() { + // Filter repodata based on channels specific to the package (and from the CLI) + let specific_repodata = repodata.iter().filter_map(|((c, _), v)| { + if channel_cli.contains(c) || channels.get(package_name).unwrap() == c { + Some(v) + } else { + None + } + }); + let records = load_package_records(package_matchspec.clone(), specific_repodata)?; let package_record = records .iter() - .find(|r| r.package_record.name == package_name) + .find(|r| r.package_record.name == *package_name) .ok_or_else(|| { miette::miette!( "Package {} not found in the specified channels", @@ -101,14 +111,14 @@ pub(super) async fn upgrade_packages( ) })?; let toinstall_version = package_record.package_record.version.version().to_owned(); - let installed_version = installed_versions - .get(&package_name) + let installed_version = versions + .get(package_name) .expect("should have the installed version") .to_owned(); // Perform upgrade if a specific version was requested // OR if a more recent version is available - if matchspec_has_version || toinstall_version > installed_version { + if package_matchspec.version.is_some() || toinstall_version > installed_version { let message = format!( "{} v{} -> v{}", package_name.as_normalized(), @@ -124,13 +134,7 @@ pub(super) async fn upgrade_packages( console::style("Updating").green(), message )); - globally_install_package( - &package_name, - records, - authenticated_client.clone(), - platform, - ) - .await?; + globally_install_package(package_name, records, client.clone(), platform).await?; pb.finish_with_message(format!("{} {}", console::style("Updated").green(), message)); upgraded = true; }