Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions app/src/ai/aws_credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,19 +131,31 @@ fn aws_credentials_state_for_error(err: LoadAwsCredentialsError) -> AwsCredentia
/// # Arguments
/// * `profile` - AWS profile name. If empty, uses the default AWS SDK behavior
/// (checks AWS_PROFILE env var, then uses "default").
/// * `region_override` - Explicit region to use. If empty, the region is resolved
/// from the AWS SDK default provider chain (AWS_REGION env var, ~/.aws/config, etc.).
pub async fn load_aws_credentials_from_sdk(
profile: &str,
region_override: &str,
) -> Result<AwsCredentials, LoadAwsCredentialsError> {
let region_provider = aws_config::meta::region::RegionProviderChain::default_provider();
let loader =
aws_config::defaults(aws_config::BehaviorVersion::latest()).region(region_provider);
let loader = if region_override.trim().is_empty() {
let region_provider = aws_config::meta::region::RegionProviderChain::default_provider();
aws_config::defaults(aws_config::BehaviorVersion::latest()).region(region_provider)
} else {
aws_config::defaults(aws_config::BehaviorVersion::latest())
.region(aws_config::Region::new(region_override.trim().to_string()))
};
let loader = if profile.trim().is_empty() {
loader // Let AWS SDK use its default behavior
} else {
loader.profile_name(profile)
};
let config = loader.load().await;

let resolved_region = config
.region()
.map(|r| r.as_ref().to_string())
.unwrap_or_default();

let provider = config
.credentials_provider()
.ok_or(LoadAwsCredentialsError::NotConfigured)?;
Expand All @@ -164,6 +176,7 @@ pub async fn load_aws_credentials_from_sdk(
creds.secret_access_key().to_string(),
creds.session_token().map(|s| s.to_string()),
creds.expiry(),
resolved_region,
))
}

Expand Down Expand Up @@ -225,6 +238,7 @@ impl AwsCredentialRefresher for ApiKeyManager {
if matches!(
event,
AISettingsChangedEvent::AwsBedrockProfile { .. }
| AISettingsChangedEvent::AwsBedrockRegion { .. }
| AISettingsChangedEvent::AwsBedrockAuthRefreshCommand { .. }
| AISettingsChangedEvent::AwsBedrockCredentialsEnabled { .. }
) {
Expand Down Expand Up @@ -267,13 +281,14 @@ fn refresh_aws_credentials_local_chain(
}

let profile = (*AISettings::as_ref(ctx).aws_bedrock_profile).clone();
let region = (*AISettings::as_ref(ctx).aws_bedrock_region).clone();

manager.set_aws_credentials_state(AwsCredentialsState::Refreshing, ctx);

let (tx, rx) = channel();
// credential fetch from aws cli's disk cache
let _ = ctx.spawn(
async move { load_aws_credentials_from_sdk(&profile).await },
async move { load_aws_credentials_from_sdk(&profile, &region).await },
move |manager, result, ctx| {
let (new_state, tx_result) = match result {
Ok(credentials) => (
Expand Down Expand Up @@ -375,6 +390,7 @@ fn refresh_aws_credentials_oidc(
credentials.secret_access_key().to_string(),
Some(credentials.session_token().to_string()),
SystemTime::try_from(*credentials.expiration()).ok(),
region.clone(),
))
},
move |manager, result, ctx| {
Expand Down
12 changes: 12 additions & 0 deletions app/src/settings/ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,18 @@ define_settings_group!(AISettings, settings: [
toml_path: "cloud_platform.third_party_api_keys.aws_bedrock_profile",
description: "The AWS profile name to use for Bedrock credentials.",
}
// AWS region to use for Bedrock API requests (e.g. us-east-1, us-west-2).
// If empty, the region is resolved from the AWS SDK default provider chain
// (AWS_REGION env var, ~/.aws/config, etc.).
aws_bedrock_region: AwsBedrockRegion {
type: String,
default: String::new(),
supported_platforms: SupportedPlatforms::DESKTOP,
sync_to_cloud: SyncToCloud::Globally(RespectUserSyncSetting::Yes),
private: false,
toml_path: "cloud_platform.third_party_api_keys.aws_bedrock_region",
description: "The AWS region for Bedrock requests (e.g. us-east-1). If empty, resolved from AWS config.",
}
// Whether the AWS Bedrock login banner has been permanently dismissed.
//
// Not a user-visible setting - we model it as a setting so we can track state.
Expand Down
59 changes: 58 additions & 1 deletion app/src/settings_view/ai_page.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7874,6 +7874,7 @@ impl SettingsWidget for ApiKeysWidget {
struct AwsBedrockWidget {
aws_auth_refresh_command_editor: ViewHandle<EditorView>,
aws_auth_refresh_profile_editor: ViewHandle<EditorView>,
aws_region_editor: ViewHandle<EditorView>,
credentials_enabled_toggle: SwitchStateHandle,
auto_login_toggle: SwitchStateHandle,
refresh_credentials_button: ViewHandle<ActionButton>,
Expand All @@ -7886,6 +7887,7 @@ impl AwsBedrockWidget {

let aws_auth_refresh_command = ai_settings.aws_bedrock_auth_refresh_command.value().clone();
let aws_auth_refresh_profile = ai_settings.aws_bedrock_profile.value().clone();
let aws_region = ai_settings.aws_bedrock_region.value().clone();
let is_usage_enabled = is_any_ai_enabled
&& UserWorkspaces::as_ref(ctx).is_aws_bedrock_credentials_enabled(ctx);

Expand Down Expand Up @@ -7983,6 +7985,41 @@ impl AwsBedrockWidget {
}
});

let aws_region_editor = ctx.add_typed_action_view(move |ctx| {
let appearance = Appearance::as_ref(ctx);
let options = SingleLineEditorOptions {
is_password: false,
text: TextOptions {
font_size_override: Some(appearance.ui_font_size()),
font_family_override: Some(appearance.monospace_font_family()),
text_colors_override: Some(TextColors {
default_color: appearance.theme().active_ui_text_color(),
disabled_color: appearance.theme().disabled_ui_text_color(),
hint_color: appearance.theme().disabled_ui_text_color(),
}),
..Default::default()
},
..Default::default()
};
let mut editor = EditorView::single_line(options, ctx);
editor.set_placeholder_text("us-east-1", ctx);
editor.set_buffer_text(&aws_region, ctx);
editor
});
AISettingsPageView::update_editor_interaction_state(
aws_region_editor.clone(),
is_usage_enabled,
ctx,
);
ctx.subscribe_to_view(&aws_region_editor, |_, editor, event, ctx| {
if matches!(event, EditorEvent::Blurred | EditorEvent::Enter) {
let value = editor.as_ref(ctx).buffer_text(ctx);
AISettings::handle(ctx).update(ctx, |settings, ctx| {
let _ = settings.aws_bedrock_region.set_value(value, ctx);
});
}
});

let refresh_credentials_button = ctx.add_typed_action_view(|_| {
ActionButton::new("Refresh", SecondaryTheme)
.with_icon(Icon::RefreshCw04)
Expand All @@ -7998,6 +8035,7 @@ impl AwsBedrockWidget {
// Keep enablement in sync with the Global AI toggle.
let aws_auth_refresh_command_editor_clone = aws_auth_refresh_command_editor.clone();
let aws_auth_refresh_profile_editor_clone = aws_auth_refresh_profile_editor.clone();
let aws_region_editor_clone = aws_region_editor.clone();
let refresh_credentials_button_clone = refresh_credentials_button.clone();
ctx.subscribe_to_model(&AISettings::handle(ctx), move |_, _, event, ctx| {
if matches!(
Expand All @@ -8019,6 +8057,11 @@ impl AwsBedrockWidget {
is_usage_enabled,
ctx,
);
AISettingsPageView::update_editor_interaction_state(
aws_region_editor_clone.clone(),
is_usage_enabled,
ctx,
);
refresh_credentials_button_clone.update(ctx, |button, ctx| {
button.set_disabled(!is_usage_enabled, ctx);
});
Expand All @@ -8029,6 +8072,7 @@ impl AwsBedrockWidget {

let aws_auth_refresh_command_editor_clone = aws_auth_refresh_command_editor.clone();
let aws_auth_refresh_profile_editor_clone = aws_auth_refresh_profile_editor.clone();
let aws_region_editor_clone = aws_region_editor.clone();
let refresh_credentials_button_clone = refresh_credentials_button.clone();
ctx.subscribe_to_model(
&UserWorkspaces::handle(ctx),
Expand All @@ -8050,6 +8094,11 @@ impl AwsBedrockWidget {
is_usage_enabled,
ctx,
);
AISettingsPageView::update_editor_interaction_state(
aws_region_editor_clone.clone(),
is_usage_enabled,
ctx,
);
refresh_credentials_button_clone.update(ctx, |button, ctx| {
button.set_disabled(!is_usage_enabled, ctx);
});
Expand All @@ -8062,6 +8111,7 @@ impl AwsBedrockWidget {
Self {
aws_auth_refresh_command_editor,
aws_auth_refresh_profile_editor,
aws_region_editor,
credentials_enabled_toggle: SwitchStateHandle::default(),
auto_login_toggle: SwitchStateHandle::default(),
refresh_credentials_button,
Expand Down Expand Up @@ -8240,6 +8290,13 @@ impl AwsBedrockWidget {
is_usage_enabled,
app,
));
column.add_child(render_input(
appearance,
"AWS Region",
self.aws_region_editor.clone(),
is_usage_enabled,
app,
));

let auto_login_enabled = *AISettings::as_ref(app).aws_bedrock_auto_login.value();

Expand Down Expand Up @@ -8272,7 +8329,7 @@ impl SettingsWidget for AwsBedrockWidget {
type View = AISettingsPageView;

fn search_terms(&self) -> &str {
"aws bedrock amazon credentials login profile"
"aws bedrock amazon credentials login profile region"
}

fn should_render(&self, app: &AppContext) -> bool {
Expand Down
38 changes: 38 additions & 0 deletions crates/ai/src/api_keys_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,41 @@ fn api_keys_for_request_none_for_custom_endpoints_only() {
});
assert!(mgr.api_keys_for_request(true, false).is_none());
}

#[test]
fn api_keys_for_request_includes_aws_credentials_with_region() {
let mut mgr = make_manager(ApiKeys::default());
mgr.aws_credentials_state = AwsCredentialsState::Loaded {
credentials: AwsCredentials::new(
"AKID".to_string(),
"secret".to_string(),
Some("token".to_string()),
None,
"us-west-2".to_string(),
),
loaded_at: std::time::SystemTime::now(),
};
let result = mgr.api_keys_for_request(false, true).unwrap();
let aws = result.aws_credentials.unwrap();
assert_eq!(aws.access_key, "AKID");
assert_eq!(aws.secret_key, "secret");
assert_eq!(aws.session_token, "token");
assert_eq!(aws.region, "us-west-2");
}

#[test]
fn api_keys_for_request_aws_credentials_none_when_disabled() {
let mut mgr = make_manager(ApiKeys::default());
mgr.aws_credentials_state = AwsCredentialsState::Loaded {
credentials: AwsCredentials::new(
"AKID".to_string(),
"secret".to_string(),
None,
None,
"us-east-1".to_string(),
),
loaded_at: std::time::SystemTime::now(),
};
// include_aws_bedrock_credentials = false and no OIDC strategy
assert!(mgr.api_keys_for_request(false, false).is_none());
}
40 changes: 30 additions & 10 deletions crates/ai/src/aws_credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub struct AwsCredentials {
secret_key: String,
session_token: Option<String>,
expires_at: Option<SystemTime>,
region: String,
}

impl AwsCredentials {
Expand All @@ -20,18 +21,24 @@ impl AwsCredentials {
secret_key: String,
session_token: Option<String>,
expires_at: Option<SystemTime>,
region: String,
) -> Self {
Self {
access_key,
secret_key,
session_token,
expires_at,
region,
}
}

pub fn expires_at(&self) -> Option<SystemTime> {
self.expires_at
}

pub fn region(&self) -> &str {
&self.region
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
Expand All @@ -54,7 +61,7 @@ impl From<AwsCredentials> for api::request::settings::api_keys::AwsCredentials {
access_key: creds.access_key,
secret_key: creds.secret_key,
session_token: creds.session_token.unwrap_or_default(),
region: String::new(),
region: creds.region,
}
}
}
Expand All @@ -68,6 +75,10 @@ fn format_status_timestamp(time: SystemTime) -> String {
}
}

#[cfg(test)]
#[path = "aws_credentials_tests.rs"]
mod tests;

impl AwsCredentialsState {
pub fn user_facing_components(&self) -> (String, String, Icon) {
match self {
Expand All @@ -91,18 +102,27 @@ impl AwsCredentialsState {
Self::Loaded {
credentials,
loaded_at,
} => (
"Credentials loaded".to_string(),
match credentials.expires_at() {
} => {
let region_suffix = if credentials.region.is_empty() {
String::new()
} else {
format!(" · Region: {}", credentials.region)
};
let detail = match credentials.expires_at() {
Some(expires_at) => format!(
"Loaded at {}, expires {}",
"Loaded at {}, expires {}{}",
format_status_timestamp(*loaded_at),
format_status_timestamp(expires_at)
format_status_timestamp(expires_at),
region_suffix
),
None => format!("Loaded at {}", format_status_timestamp(*loaded_at)),
},
Icon::CheckCircleBroken,
),
None => format!(
"Loaded at {}{}",
format_status_timestamp(*loaded_at),
region_suffix
),
};
("Credentials loaded".to_string(), detail, Icon::CheckCircleBroken)
}
Self::Failed { message } => (
"Unable to load credentials".to_string(),
message.clone(),
Expand Down
Loading