diff --git a/src/modules/aws.rs b/src/modules/aws.rs index fa83c6111653a..e4fa11a18635c 100644 --- a/src/modules/aws.rs +++ b/src/modules/aws.rs @@ -63,7 +63,7 @@ fn get_creds<'a>(context: &Context, config: &'a OnceCell>) -> Option // Get the section for a given profile name in the config file. fn get_profile_config<'a>( config: &'a Ini, - profile: &Option, + profile: Option<&Profile>, ) -> Option<&'a ini::Properties> { match profile { Some(profile) => config.section(Some(format!("profile {}", profile))), @@ -88,7 +88,7 @@ fn get_aws_region_from_config( aws_config: &AwsConfigFile, ) -> Option { let config = get_config(context, aws_config)?; - let section = get_profile_config(config, aws_profile)?; + let section = get_profile_config(config, aws_profile.as_ref())?; section.get("region").map(std::borrow::ToOwned::to_owned) } @@ -148,13 +148,38 @@ fn alias_name(name: Option, aliases: &HashMap) -> Option, + aws_profile: Option<&Profile>, aws_config: &AwsConfigFile, + aws_creds: &AwsCredsFile, ) -> Option { let config = get_config(context, aws_config)?; + let credentials = get_creds(context, aws_creds); + + let empty_section = ini::Properties::new(); + // We use the aws_profile here because `get_profile_config()` treats None + // as "special" and falls back to the "[default]"; otherwise this tries + // to look up "[profile default]" which doesn't exist + let config_section = get_profile_config(config, aws_profile).or(Some(&empty_section))?; + + let default_profile = "default".to_string(); + let credential_profile = aws_profile.unwrap_or(&default_profile); + + // credentials are optional; if we can't find the credential file, + // we work with just the config and treat the credential section as + // empty + let credential_section = match credentials { + Some(credentials) => credentials + .section(Some(credential_profile)) + .or(Some(&empty_section))?, + None => &empty_section, + }; - let section = get_profile_config(config, aws_profile)?; - Some(section.contains_key("credential_process") || section.contains_key("sso_start_url")) + Some( + config_section.contains_key("credential_process") + || config_section.contains_key("sso_start_url") + || credential_section.contains_key("credential_process") + || credential_section.contains_key("sso_start_url"), + ) } fn has_defined_credentials( @@ -195,7 +220,8 @@ pub fn module<'a>(context: &'a Context) -> Option> { // only display if credential_process is defined or has valid credentials if !config.force_display - && !has_credential_process_or_sso(context, &aws_profile, &aws_config).unwrap_or(false) + && !has_credential_process_or_sso(context, aws_profile.as_ref(), &aws_config, &aws_creds) + .unwrap_or(false) && !has_defined_credentials(context, &aws_profile, &aws_creds).unwrap_or(false) { return None; @@ -888,6 +914,44 @@ credential_process = /opt/bin/awscreds-retriever dir.close() } + #[test] + fn credential_process_set_in_credentials() -> io::Result<()> { + let dir = tempfile::tempdir()?; + let config_path = dir.path().join("config"); + let credential_path = dir.path().join("credentials"); + let mut file = File::create(&config_path)?; + + file.write_all( + "[default] +region = ap-northeast-2 +" + .as_bytes(), + )?; + + let mut file = File::create(&credential_path)?; + + file.write_all( + "[default] +credential_process = /opt/bin/awscreds-for-tests +" + .as_bytes(), + )?; + let actual = ModuleRenderer::new("aws") + .env("AWS_CONFIG_FILE", config_path.to_string_lossy().as_ref()) + .env( + "AWS_CREDENTIALS_FILE", + credential_path.to_string_lossy().as_ref(), + ) + .collect(); + let expected = Some(format!( + "on {}", + Color::Yellow.bold().paint("☁️ (ap-northeast-2) ") + )); + + assert_eq!(expected, actual); + dir.close() + } + #[test] fn sso_set() -> io::Result<()> { let dir = tempfile::tempdir()?;