Skip to content

Commit

Permalink
fix(aws): enable when using .aws/credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
offbyone committed Nov 20, 2022
1 parent 627a058 commit 48aac57
Showing 1 changed file with 68 additions and 12 deletions.
80 changes: 68 additions & 12 deletions src/modules/aws.rs
Expand Up @@ -63,7 +63,7 @@ fn get_creds<'a>(context: &Context, config: &'a OnceCell<Option<Ini>>) -> Option
// Get the section for a given profile name in the config file.
fn get_profile_config<'a>(
config: &'a Ini,
profile: &Option<Profile>,
profile: Option<&Profile>,
) -> Option<&'a ini::Properties> {
match profile {
Some(profile) => config.section(Some(format!("profile {}", profile))),
Expand All @@ -74,11 +74,11 @@ fn get_profile_config<'a>(
// Get the section for a given profile name in the credentials file.
fn get_profile_creds<'a>(
config: &'a Ini,
profile: &Option<Profile>,
profile: Option<&Profile>,
) -> Option<&'a ini::Properties> {
match profile {
None => config.section(Some("default")),
_ => config.section(profile.as_ref()),
_ => config.section(profile),
}
}

Expand All @@ -88,7 +88,7 @@ fn get_aws_region_from_config(
aws_config: &AwsConfigFile,
) -> Option<Region> {
let config = get_config(context, aws_config)?;
let section = get_profile_config(config, aws_profile)?;
let section = get_profile_config(config, aws_profile.as_ref())?;

section.get("region").map(std::borrow::ToOwned::to_owned)
}
Expand Down Expand Up @@ -118,7 +118,7 @@ fn get_aws_profile_and_region(

fn get_credentials_duration(
context: &Context,
aws_profile: &Option<String>,
aws_profile: Option<&Profile>,
aws_creds: &AwsCredsFile,
) -> Option<i64> {
let expiration_env_vars = ["AWS_SESSION_EXPIRATION", "AWSUME_EXPIRATION"];
Expand Down Expand Up @@ -148,18 +148,35 @@ fn alias_name(name: Option<String>, aliases: &HashMap<String, &str>) -> Option<S

fn has_credential_process_or_sso(
context: &Context,
aws_profile: &Option<Profile>,
aws_profile: Option<&Profile>,
aws_config: &AwsConfigFile,
aws_creds: &AwsCredsFile,
) -> Option<bool> {
let config = get_config(context, aws_config)?;
let credentials = get_creds(context, aws_creds);

let empty_section = ini::Properties::new();
// We use the aws_profile here because `get_profile_config()` treats None
// as "special" and falls back to the "[default]"; otherwise this tries
// to look up "[profile default]" which doesn't exist
let config_section = get_profile_config(config, aws_profile).or(Some(&empty_section))?;

let credential_section = match credentials {
Some(credentials) => get_profile_creds(credentials, aws_profile),
None => None,
};

let section = get_profile_config(config, aws_profile)?;
Some(section.contains_key("credential_process") || section.contains_key("sso_start_url"))
Some(
config_section.contains_key("credential_process")
|| config_section.contains_key("sso_start_url")
|| credential_section?.contains_key("credential_process")
|| credential_section?.contains_key("sso_start_url"),
)
}

fn has_defined_credentials(
context: &Context,
aws_profile: &Option<Profile>,
aws_profile: Option<&Profile>,
aws_creds: &AwsCredsFile,
) -> Option<bool> {
let valid_env_vars = [
Expand Down Expand Up @@ -195,14 +212,15 @@ pub fn module<'a>(context: &'a Context) -> Option<Module<'a>> {

// only display if credential_process is defined or has valid credentials
if !config.force_display
&& !has_credential_process_or_sso(context, &aws_profile, &aws_config).unwrap_or(false)
&& !has_defined_credentials(context, &aws_profile, &aws_creds).unwrap_or(false)
&& !has_credential_process_or_sso(context, aws_profile.as_ref(), &aws_config, &aws_creds)
.unwrap_or(false)
&& !has_defined_credentials(context, aws_profile.as_ref(), &aws_creds).unwrap_or(false)
{
return None;
}

let duration = {
get_credentials_duration(context, &aws_profile, &aws_creds).map(|duration| {
get_credentials_duration(context, aws_profile.as_ref(), &aws_creds).map(|duration| {
if duration > 0 {
render_time((duration * 1000) as u128, false)
} else {
Expand Down Expand Up @@ -888,6 +906,44 @@ credential_process = /opt/bin/awscreds-retriever
dir.close()
}

#[test]
fn credential_process_set_in_credentials() -> io::Result<()> {
let dir = tempfile::tempdir()?;
let config_path = dir.path().join("config");
let credential_path = dir.path().join("credentials");
let mut file = File::create(&config_path)?;

file.write_all(
"[default]
region = ap-northeast-2
"
.as_bytes(),
)?;

let mut file = File::create(&credential_path)?;

file.write_all(
"[default]
credential_process = /opt/bin/awscreds-for-tests
"
.as_bytes(),
)?;
let actual = ModuleRenderer::new("aws")
.env("AWS_CONFIG_FILE", config_path.to_string_lossy().as_ref())
.env(
"AWS_CREDENTIALS_FILE",
credential_path.to_string_lossy().as_ref(),
)
.collect();
let expected = Some(format!(
"on {}",
Color::Yellow.bold().paint("鈽侊笍 (ap-northeast-2) ")
));

assert_eq!(expected, actual);
dir.close()
}

#[test]
fn sso_set() -> io::Result<()> {
let dir = tempfile::tempdir()?;
Expand Down

0 comments on commit 48aac57

Please sign in to comment.