Skip to content

Commit

Permalink
spawn
Browse files Browse the repository at this point in the history
  • Loading branch information
olivier-lacroix committed May 19, 2024
1 parent b77aa15 commit e7f950a
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 43 deletions.
6 changes: 3 additions & 3 deletions src/cli/global/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,13 @@ pub(super) fn channel_name_from_prefix(
/// # Returns
///
/// The package records (with dependencies records) for the given package MatchSpec
pub fn load_package_records(
pub fn load_package_records<'a>(
package_matchspec: MatchSpec,
sparse_repodata: &IndexMap<(Channel, Platform), SparseRepoData>,
sparse_repodata: impl IntoIterator<Item = &'a SparseRepoData>,
) -> miette::Result<Vec<RepoDataRecord>> {
let package_name = package_name(&package_matchspec)?;
let available_packages =
SparseRepoData::load_records_recursive(sparse_repodata.values(), vec![package_name], None)
SparseRepoData::load_records_recursive(sparse_repodata, vec![package_name], None)
.into_diagnostic()?;
let virtual_packages = rattler_virtual_packages::VirtualPackage::current()
.into_diagnostic()?
Expand Down
2 changes: 1 addition & 1 deletion src/cli/global/install.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ pub async fn execute(args: Args) -> miette::Result<()> {
// Install the package(s)
let mut executables = vec![];
for (package_name, package_matchspec) in args.specs()? {
let records = load_package_records(package_matchspec, &sparse_repodata)?;
let records = load_package_records(package_matchspec, sparse_repodata.values())?;

let (prefix_package, scripts, _) = globally_install_package(
&package_name,
Expand Down
82 changes: 43 additions & 39 deletions src/cli/global/upgrade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ use std::collections::HashMap;
use std::time::Duration;

use clap::Parser;
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use indexmap::IndexMap;
use indicatif::ProgressBar;
use itertools::Itertools;
use miette::IntoDiagnostic;
use miette::{IntoDiagnostic, Report};
use rattler_conda_types::{Channel, MatchSpec, PackageName, Platform};

use crate::config::Config;
Expand Down Expand Up @@ -59,56 +61,64 @@ pub(super) async fn upgrade_packages(
cli_channels: &[String],
platform: &Platform,
) -> miette::Result<()> {
// Get channels and versions of globally installed packages
let mut installed_versions = HashMap::with_capacity(specs.len());
let mut channels = config.compute_channels(cli_channels).into_diagnostic()?;

for package_name in specs.keys() {
let prefix_record = find_installed_package(package_name).await?;
let last_installed_channel = Channel::from_str(
prefix_record.repodata_record.channel.clone(),
config.channel_config(),
)
.into_diagnostic()?;

channels.push(last_installed_channel);

let installed_version = prefix_record
.repodata_record
.package_record
.version
.into_version();
installed_versions.insert(package_name.clone(), installed_version);
let channel_cli = config.compute_channels(cli_channels).into_diagnostic()?;

// Get channels and version of globally installed packages in parallel
let mut channels = HashMap::with_capacity(specs.len());
let mut versions = HashMap::with_capacity(specs.len());
let mut handles = FuturesUnordered::<tokio::task::JoinHandle<Result<_, Report>>>::new();
for package_name in specs.keys().cloned() {
let channel_config = config.channel_config().clone();
let future: tokio::task::JoinHandle<Result<_, Report>> = tokio::spawn(async move {
let p = find_installed_package(&package_name).await?;
let channel =
Channel::from_str(p.repodata_record.channel, &channel_config).into_diagnostic()?;
let version = p.repodata_record.package_record.version.into_version();
Ok((package_name, channel, version))
});
handles.push(future);
}
while let Some(data) = handles.next().await {
let (package_name, channel, version) = data.into_diagnostic()??;
channels.insert(package_name.clone(), channel);
versions.insert(package_name, version);
}
channels = channels.into_iter().unique().collect();

// Fetch sparse repodata
let (authenticated_client, sparse_repodata) =
get_client_and_sparse_repodata(&channels, *platform, &config).await?;
// Fetch sparse repodata across all channels
let all_channels = channels.values().chain(channel_cli.iter()).unique();
let (client, repodata) =
get_client_and_sparse_repodata(all_channels, *platform, &config).await?;

// Upgrade each package when relevant
let mut upgraded = false;
for (package_name, package_matchspec) in specs {
let matchspec_has_version = package_matchspec.version.is_some();
let records = load_package_records(package_matchspec, &sparse_repodata)?;
for (package_name, package_matchspec) in specs.iter() {
// Filter repodata based on channels specific to the package (and from the CLI)
let specific_repodata = repodata.iter().filter_map(|((c, _), v)| {
if channel_cli.contains(c) || channels.get(package_name).unwrap() == c {
Some(v)
} else {
None
}
});
let records = load_package_records(package_matchspec.clone(), specific_repodata)?;
let package_record = records
.iter()
.find(|r| r.package_record.name == package_name)
.find(|r| r.package_record.name == *package_name)
.ok_or_else(|| {
miette::miette!(
"Package {} not found in the specified channels",
package_name.as_normalized()
)
})?;
let toinstall_version = package_record.package_record.version.version().to_owned();
let installed_version = installed_versions
.get(&package_name)
let installed_version = versions
.get(package_name)
.expect("should have the installed version")
.to_owned();

// Perform upgrade if a specific version was requested
// OR if a more recent version is available
if matchspec_has_version || toinstall_version > installed_version {
if package_matchspec.version.is_some() || toinstall_version > installed_version {
let message = format!(
"{} v{} -> v{}",
package_name.as_normalized(),
Expand All @@ -124,13 +134,7 @@ pub(super) async fn upgrade_packages(
console::style("Updating").green(),
message
));
globally_install_package(
&package_name,
records,
authenticated_client.clone(),
platform,
)
.await?;
globally_install_package(package_name, records, client.clone(), platform).await?;
pb.finish_with_message(format!("{} {}", console::style("Updated").green(), message));
upgraded = true;
}
Expand Down

0 comments on commit e7f950a

Please sign in to comment.