Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(connector): use AwsAuthProps instead of HashMap in config #13513

Merged
merged 9 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
139 changes: 0 additions & 139 deletions src/connector/src/aws_auth.rs

This file was deleted.

18 changes: 2 additions & 16 deletions src/connector/src/aws_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,9 @@ use risingwave_common::error::ErrorCode::InternalError;
use risingwave_common::error::{Result, RwError};
use url::Url;

use crate::aws_auth::AwsAuthProps;
use crate::common::AwsAuthProps;

pub const REGION: &str = "region";
pub const ACCESS_KEY: &str = "access_key";
pub const SECRET_ACCESS: &str = "secret_access";

pub const AWS_DEFAULT_CONFIG: [&str; 7] = [
REGION,
"arn",
"profile",
ACCESS_KEY,
SECRET_ACCESS,
"session_token",
"endpoint_url",
];
pub const AWS_CUSTOM_CONFIG_KEY: [&str; 3] = ["retry_times", "conn_timeout", "read_timeout"];
const AWS_CUSTOM_CONFIG_KEY: [&str; 3] = ["retry_times", "conn_timeout", "read_timeout"];

pub fn default_conn_config() -> HashMap<String, u64> {
let mut default_conn_config = HashMap::new();
Expand Down Expand Up @@ -118,7 +105,6 @@ pub fn s3_client(
}

// TODO(Tao): Probably we should never allow to use S3 URI.
/// properties require keys: refer to [`AWS_DEFAULT_CONFIG`]
pub async fn load_file_descriptor_from_s3(
location: &Url,
config: &AwsAuthProps,
Expand Down
108 changes: 95 additions & 13 deletions src/connector/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ use time::OffsetDateTime;
use url::Url;
use with_options::WithOptions;

use crate::aws_auth::AwsAuthProps;
use crate::aws_utils::load_file_descriptor_from_s3;
use crate::deserialize_duration_from_string;
use crate::sink::SinkError;
Expand All @@ -51,6 +50,98 @@ pub struct AwsPrivateLinkItem {
pub port: u16,
}

use aws_config::default_provider::region::DefaultRegionChain;
use aws_config::sts::AssumeRoleProvider;
use aws_credential_types::provider::SharedCredentialsProvider;
use aws_types::region::Region;
use aws_types::SdkConfig;

/// A flattened config map for AWS authentication.
///
/// Every field is an `Option`, so none of them is required at deserialization time; the
/// builder methods on this type decide which ones must actually be present.
#[derive(Deserialize, Serialize, Debug, Clone, WithOptions)]
pub struct AwsAuthProps {
/// Region name. When absent, `build_region` falls back to the SDK default region chain.
pub region: Option<String>,
/// Custom endpoint URL handed to the SDK config loader. Also accepted under the key
/// `endpoint_url`.
#[serde(alias = "endpoint_url")]
pub endpoint: Option<String>,
/// Access key. `build_credential_provider` requires it together with `secret_key`.
pub access_key: Option<String>,
/// Secret key. `build_credential_provider` requires it together with `access_key`.
pub secret_key: Option<String>,
/// Optional session token passed along with the access/secret key pair.
pub session_token: Option<String>,
/// Role to assume. When set, the key-based credentials are wrapped in an assume-role provider.
pub arn: Option<String>,
/// This field was added for kinesis. Not sure if it's useful for other connectors.
/// Please ignore it in the documentation for now.
pub external_id: Option<String>,
/// Profile name. Only consulted for region resolution, and only when `region` is not set.
pub profile: Option<String>,
}

impl AwsAuthProps {
    /// Resolves the AWS region to use.
    ///
    /// An explicitly configured `region` always wins. Otherwise the SDK's default region
    /// chain is consulted, scoped to `profile` when one is configured.
    ///
    /// # Errors
    ///
    /// Returns an error when `region` is unset and the default chain yields no region.
    async fn build_region(&self) -> anyhow::Result<Region> {
        if let Some(region_name) = &self.region {
            return Ok(Region::new(region_name.clone()));
        }

        let mut region_chain = DefaultRegionChain::builder();
        if let Some(profile_name) = &self.profile {
            region_chain = region_chain.profile_name(profile_name);
        }

        region_chain
            .build()
            .region()
            .await
            .ok_or_else(|| anyhow::format_err!("region should be provided"))
    }

    /// Builds a static credentials provider from `access_key` and `secret_key`, attaching
    /// `session_token` when present.
    ///
    /// # Errors
    ///
    /// Returns an error unless both `access_key` and `secret_key` are set.
    fn build_credential_provider(&self) -> anyhow::Result<SharedCredentialsProvider> {
        // Match on both options at once so there is no `unwrap` that could ever panic.
        match (&self.access_key, &self.secret_key) {
            (Some(access_key), Some(secret_key)) => Ok(SharedCredentialsProvider::new(
                aws_credential_types::Credentials::from_keys(
                    access_key,
                    secret_key,
                    self.session_token.clone(),
                ),
            )),
            // The message names the real option keys (`secret_key`, not `secret_access`).
            _ => Err(anyhow::anyhow!(
                "Both \"access_key\" and \"secret_key\" are required."
            )),
        }
    }

    /// Wraps `credential` in an assume-role provider when `arn` is configured; otherwise
    /// returns `credential` unchanged.
    ///
    /// The assumed-role session is named `"RisingWave"`, and `external_id` is forwarded when
    /// set.
    ///
    /// # Errors
    ///
    /// Returns an error when `arn` is set but the region cannot be resolved.
    async fn with_role_provider(
        &self,
        credential: SharedCredentialsProvider,
    ) -> anyhow::Result<SharedCredentialsProvider> {
        let Some(role_name) = &self.arn else {
            return Ok(credential);
        };

        let region = self.build_region().await?;
        let mut role = AssumeRoleProvider::builder(role_name)
            .session_name("RisingWave")
            .region(region);
        if let Some(id) = &self.external_id {
            role = role.external_id(id);
        }

        let provider = role.build_from_provider(credential).await;
        Ok(SharedCredentialsProvider::new(provider))
    }

    /// Assembles the [`SdkConfig`] described by these properties.
    ///
    /// Starts from the environment-based loader, then overrides the region, the credentials
    /// provider (optionally role-assuming, see `arn`) and, when configured, the endpoint URL.
    ///
    /// # Errors
    ///
    /// Returns an error when the region cannot be resolved or when `access_key` /
    /// `secret_key` are not both provided.
    pub async fn build_config(&self) -> anyhow::Result<SdkConfig> {
        let region = self.build_region().await?;
        let base_credentials = self.build_credential_provider()?;
        let credentials_provider = self.with_role_provider(base_credentials).await?;

        let mut config_loader = aws_config::from_env()
            .region(region)
            .credentials_provider(credentials_provider);
        if let Some(endpoint) = self.endpoint.as_ref() {
            config_loader = config_loader.endpoint_url(endpoint);
        }

        Ok(config_loader.load().await)
    }
}

#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, WithOptions)]
pub struct KafkaCommon {
Expand Down Expand Up @@ -282,8 +373,7 @@ pub struct PulsarOauthCommon {
pub scope: Option<String>,

#[serde(flatten)]
/// required keys refer to [`crate::aws_utils::AWS_DEFAULT_CONFIG`]
pub s3_credentials: HashMap<String, String>,
pub aws_auth_props: AwsAuthProps,
}

impl PulsarCommon {
Expand All @@ -294,16 +384,8 @@ impl PulsarCommon {
let url = Url::parse(&oauth.credentials_url)?;
match url.scheme() {
"s3" => {
let credentials = load_file_descriptor_from_s3(
&url,
&AwsAuthProps::from_pairs(
oauth
.s3_credentials
.iter()
.map(|(k, v)| (k.as_str(), v.as_str())),
),
)
.await?;
let credentials =
load_file_descriptor_from_s3(&url, &oauth.aws_auth_props).await?;
let mut f = NamedTempFile::new()?;
f.write_all(&credentials)?;
f.as_file().sync_all()?;
Expand Down
1 change: 0 additions & 1 deletion src/connector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ use risingwave_pb::connector_service::SinkPayloadFormat;
use risingwave_rpc_client::ConnectorClient;
use serde::de;

pub mod aws_auth;
pub mod aws_utils;
pub mod error;
mod macros;
Expand Down
11 changes: 3 additions & 8 deletions src/connector/src/parser/avro/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ mod test {
read_schema_from_http, read_schema_from_local, read_schema_from_s3, AvroAccessBuilder,
AvroParserConfig,
};
use crate::aws_auth::AwsAuthProps;
use crate::common::AwsAuthProps;
use crate::parser::bytes_parser::BytesAccessBuilder;
use crate::parser::plain_parser::PlainParser;
use crate::parser::unified::avro::unix_epoch_days;
Expand Down Expand Up @@ -256,14 +256,9 @@ mod test {
#[ignore]
async fn test_load_schema_from_s3() {
let schema_location = "s3://mingchao-schemas/complex-schema.avsc".to_string();
let mut s3_config_props = HashMap::new();
s3_config_props.insert("region".to_string(), "ap-southeast-1".to_string());
let url = Url::parse(&schema_location).unwrap();
let aws_auth_config = AwsAuthProps::from_pairs(
s3_config_props
.iter()
.map(|(k, v)| (k.as_str(), v.as_str())),
);
let aws_auth_config: AwsAuthProps =
serde_json::from_str(r#"region":"ap-southeast-1"#).unwrap();
let schema_content = read_schema_from_s3(&url, &aws_auth_config).await;
assert!(schema_content.is_ok());
let schema = Schema::parse_str(&schema_content.unwrap());
Expand Down
20 changes: 13 additions & 7 deletions src/connector/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use self::simd_json_parser::DebeziumJsonAccessBuilder;
use self::unified::{AccessImpl, AccessResult};
use self::upsert_parser::UpsertParser;
use self::util::get_kafka_topic;
use crate::aws_auth::AwsAuthProps;
use crate::common::AwsAuthProps;
use crate::parser::maxwell::MaxwellParser;
use crate::schema::schema_registry::SchemaRegistryAuth;
use crate::source::{
Expand Down Expand Up @@ -912,9 +912,12 @@ impl SpecificParserConfig {
config.topic = get_kafka_topic(props)?.clone();
config.client_config = SchemaRegistryAuth::from(props);
} else {
config.aws_auth_props = Some(AwsAuthProps::from_pairs(
props.iter().map(|(k, v)| (k.as_str(), v.as_str())),
));
config.aws_auth_props = Some(
serde_json::from_value::<AwsAuthProps>(
serde_json::to_value(props).unwrap(),
)
.map_err(|e| anyhow::anyhow!(e))?,
);
}
EncodingProperties::Avro(config)
}
Expand All @@ -941,9 +944,12 @@ impl SpecificParserConfig {
config.topic = get_kafka_topic(props)?.clone();
config.client_config = SchemaRegistryAuth::from(props);
} else {
config.aws_auth_props = Some(AwsAuthProps::from_pairs(
props.iter().map(|(k, v)| (k.as_str(), v.as_str())),
));
config.aws_auth_props = Some(
serde_json::from_value::<AwsAuthProps>(
serde_json::to_value(props).unwrap(),
)
.map_err(|e| anyhow::anyhow!(e))?,
);
}
EncodingProperties::Protobuf(config)
}
Expand Down
2 changes: 1 addition & 1 deletion src/connector/src/parser/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use risingwave_common::error::ErrorCode::{
};
use risingwave_common::error::{Result, RwError};

use crate::aws_auth::AwsAuthProps;
use crate::aws_utils::{default_conn_config, s3_client};
use crate::common::AwsAuthProps;

const AVRO_SCHEMA_LOCATION_S3_REGION: &str = "region";

Expand Down