From 252abc725dead87b622a18340c56fda2621427be Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Mon, 12 May 2025 22:19:35 -0700 Subject: [PATCH 1/4] Reorg errors --- src/apis/mod.rs | 2 +- src/apis/source.rs | 141 ++++---- src/backends/azure.rs | 213 +++++------- src/backends/common.rs | 22 +- src/backends/s3.rs | 752 +++++++++++------------------------------ src/utils/errors.rs | 124 ++++--- 6 files changed, 446 insertions(+), 808 deletions(-) diff --git a/src/apis/mod.rs b/src/apis/mod.rs index 349ca0c..cff51ed 100644 --- a/src/apis/mod.rs +++ b/src/apis/mod.rs @@ -27,5 +27,5 @@ pub trait API { &self, account_id: String, user_identity: UserIdentity, - ) -> Result; + ) -> Result; } diff --git a/src/apis/source.rs b/src/apis/source.rs index 02ea45d..217bb0a 100644 --- a/src/apis/source.rs +++ b/src/apis/source.rs @@ -4,7 +4,7 @@ use crate::backends::common::Repository; use crate::backends::s3::S3Repository; use crate::utils::api::process_json_response; use crate::utils::auth::UserIdentity; -use crate::utils::errors::{APIError, BackendError, InternalServerError}; +use crate::utils::errors::BackendError; use async_trait::async_trait; use moka::future::Cache; use rusoto_core::Region; @@ -237,7 +237,7 @@ impl API for SourceAPI { &self, account_id: String, user_identity: UserIdentity, - ) -> Result { + ) -> Result { let client = reqwest::Client::new(); // Create headers let mut headers = reqwest::header::HeaderMap::new(); @@ -252,29 +252,27 @@ impl API for SourceAPI { ); } - match client + let response = client .get(format!( "{}/api/v1/repositories/{}", self.endpoint, account_id )) .headers(headers) .send() - .await - { - Ok(response) => match response.json::().await { - Ok(repository_list) => { - let mut account = Account::default(); + .await?; - for repository in repository_list.repositories { - account.repositories.push(repository.repository_id); - } + let repository_list = process_json_response::( + response, + BackendError::RepositoryNotFound, + ) + .await?; + let mut account = Account::default(); - Ok(account) - } - Err(_) => Err(()), - }, - Err(_) => Err(()), + for repository in repository_list.repositories { + account.repositories.push(repository.repository_id); } + + Ok(account) } } @@ -323,7 +321,7 @@ impl SourceAPI { /// # Returns /// /// Returns a `Result` containing either a `SourceRepository` struct with the - /// repository information or a boxed `APIError` if the request fails. + /// repository information or a BackendError if the request fails. pub async fn get_repository_record( &self, account_id: &String, @@ -397,7 +395,7 @@ impl SourceAPI { } } - pub async fn get_api_key(&self, access_key_id: String) -> Result> { + pub async fn get_api_key(&self, access_key_id: String) -> Result { // Try to get the cached value let cache_key = format!("{}", access_key_id); @@ -406,32 +404,25 @@ impl SourceAPI { } // If not in cache, fetch it - match self.fetch_api_key(access_key_id).await { - Ok(secret) => { - // Cache the successful result - match secret { - Some(secret) => { - self.api_key_cache.insert(cache_key, secret.clone()).await; - Ok(secret) - } - None => { - let secret = APIKey { - access_key_id: "".to_string(), - secret_access_key: "".to_string(), - }; - self.api_key_cache.insert(cache_key, secret.clone()).await; - Ok(secret) - } - } + let secret = self.fetch_api_key(access_key_id).await?; + // Cache the successful result + match secret { + Some(secret) => { + self.api_key_cache.insert(cache_key, secret.clone()).await; + Ok(secret) + } + None => { + let secret = APIKey { + access_key_id: "".to_string(), + secret_access_key: "".to_string(), + }; + self.api_key_cache.insert(cache_key, secret.clone()).await; + Ok(secret) } - Err(e) => Err(e), } } - async fn fetch_api_key( - &self, - access_key_id: String, - ) -> Result, Box> { + async fn fetch_api_key(&self, access_key_id: String) -> Result, BackendError> { if access_key_id.is_empty() { return Ok(None); } @@ -445,44 +436,29 @@ impl SourceAPI { reqwest::header::AUTHORIZATION, reqwest::header::HeaderValue::from_str(&source_key).unwrap(), ); - match client + let response = client .get(format!( "{}/api/v1/api-keys/{}/auth", source_api_url, access_key_id )) .headers(headers) .send() - .await - { - Ok(response) => { - if response.status().is_success() { - match response.text().await { - Ok(text) => { - let json: Value = serde_json::from_str(&text).unwrap(); - let secret_access_key = json["secret_access_key"].as_str().unwrap(); - - return Ok(Some(APIKey { - access_key_id, - secret_access_key: secret_access_key.to_string(), - })); - } - Err(_) => Err(Box::new(InternalServerError { - message: "Internal Server Error".to_string(), - })), - } - } else { - if response.status().is_client_error() { - return Ok(None); - } - Err(Box::new(InternalServerError { - message: "Internal Server Error".to_string(), - })) - } - } - Err(_) => Err(Box::new(InternalServerError { - message: "Internal Server Error".to_string(), - })), + .await?; + let status = response.status(); + let text = response.text().await?; + if !status.is_success() { + return Err(BackendError::UnexpectedApiError(format!( + "Failed to get API key for access key id: {}, {}", + access_key_id, text + ))); } + let json: Value = serde_json::from_str(&text).unwrap(); + let secret_access_key = json["secret_access_key"].as_str().unwrap(); + + Ok(Some(APIKey { + access_key_id, + secret_access_key: secret_access_key.to_string(), + })) } async fn fetch_repository( @@ -495,8 +471,7 @@ impl SourceAPI { self.endpoint, account_id, repository_id )) .await?; - process_json_response::(response, BackendError::RepositoryNotFound) - .await + process_json_response::(response, BackendError::RepositoryNotFound).await } pub async fn is_authorized( @@ -535,10 +510,26 @@ impl SourceAPI { self.permissions_cache .insert(cache_key, permissions.clone()) .await; - + Ok(permissions.contains(&permission)) } + pub async fn assert_authorized( + &self, + user_identity: UserIdentity, + account_id: &String, + repository_id: &String, + permission: RepositoryPermission, + ) -> Result { + let authorized = self + .is_authorized(user_identity, account_id, repository_id, permission) + .await?; + if !authorized { + return Err(BackendError::UnauthorizedError); + } + Ok(authorized) + } + async fn fetch_permission( &self, user_identity: UserIdentity, @@ -569,7 +560,7 @@ impl SourceAPI { .headers(headers) .send() .await?; - + process_json_response::>( response, BackendError::RepositoryPermissionsNotFound, diff --git a/src/backends/azure.rs b/src/backends/azure.rs index ea865ea..cd0fdeb 100644 --- a/src/backends/azure.rs +++ b/src/backends/azure.rs @@ -17,7 +17,7 @@ use crate::backends::common::{ GetObjectResponse, HeadObjectResponse, ListBucketResult, Repository, }; use crate::utils::core::replace_first; -use crate::utils::errors::{APIError, InternalServerError, ObjectNotFoundError}; +use crate::utils::errors::BackendError; use super::common::{MultipartPart, UploadPartResponse}; @@ -52,7 +52,7 @@ impl Repository for AzureRepository { &self, key: String, range: Option, - ) -> Result> { + ) -> Result { let credentials = StorageCredentials::anonymous(); let client = BlobServiceClient::new(format!("{}", &self.account_name), credentials) @@ -64,96 +64,81 @@ impl Repository for AzureRepository { key )); - match blob_client.get_properties().await { - Ok(blob) => { - let content_type = blob.blob.properties.content_type.to_string(); - let etag = blob.blob.properties.etag.to_string(); - let last_modified = rfc2822_to_rfc7231( - blob.blob - .properties - .last_modified - .format(&Rfc2822) - .unwrap_or_else(|_| String::from("Invalid DateTime")) - .as_str(), - ) - .unwrap_or_else(|_| String::from("Invalid DateTime")); - - let client = reqwest::Client::new(); - - // Start building the request - let mut request = client.get(format!( - "https://{}.blob.core.windows.net/{}/{}/{}", - self.account_name, - self.container_name, - self.base_prefix.trim_end_matches('/').to_string(), - key - )); - - // If a range is provided, add it to the headers - if let Some(range_value) = range { - request = request.header(RANGE, range_value); - } + let blob = blob_client.get_properties().await?; + let content_type = blob.blob.properties.content_type.to_string(); + let etag = blob.blob.properties.etag.to_string(); + let last_modified = rfc2822_to_rfc7231( + blob.blob + .properties + .last_modified + .format(&Rfc2822) + .unwrap_or_else(|_| String::from("Invalid DateTime")) + .as_str(), + ) + .unwrap_or_else(|_| String::from("Invalid DateTime")); + + let client = reqwest::Client::new(); + + // Start building the request + let mut request = client.get(format!( + "https://{}.blob.core.windows.net/{}/{}/{}", + self.account_name, + self.container_name, + self.base_prefix.trim_end_matches('/').to_string(), + key + )); - // Send the request and await the response - match request.send().await { - Ok(response) => { - // Check if the status code is successful - if !response.status().is_success() { - return Err(Box::new(InternalServerError { - message: "Internal Server Error".to_string(), - })); - } + // If a range is provided, add it to the headers + if let Some(range_value) = range { + request = request.header(RANGE, range_value); + } - // Get the byte stream from the response - let content_length = response.content_length(); - let stream = response.bytes_stream(); - let boxed_stream: Pin< - Box> + Send>, - > = Box::pin(stream); - - Ok(GetObjectResponse { - content_length: content_length.unwrap_or(0) as u64, - content_type, - etag, - last_modified, - body: boxed_stream, - }) - } - Err(_) => Err(Box::new(InternalServerError { - message: "Internal Server Error".to_string(), - })), - } - } - Err(_) => Err(Box::new(InternalServerError { - message: "Internal Server Error".to_string(), - })), + // Send the request and await the response + let response = request.send().await?; + // Check if the status code is successful + if !response.status().is_success() { + return Err(BackendError::UnexpectedApiError(response.text().await?)); } + + // Get the byte stream from the response + let content_length = response.content_length(); + let stream = response.bytes_stream(); + let boxed_stream: Pin> + Send>> = + Box::pin(stream); + + Ok(GetObjectResponse { + content_length: content_length.unwrap_or(0) as u64, + content_type, + etag, + last_modified, + body: boxed_stream, + }) } - async fn delete_object(&self, _key: String) -> Result<(), Box> { - Err(Box::new(InternalServerError { - message: "Internal Server Error".to_string(), - })) + async fn delete_object(&self, _key: String) -> Result<(), BackendError> { + Err(BackendError::UnsupportedOperation( + "Delete object is not supported on Azure".to_string(), + )) } async fn create_multipart_upload( &self, _key: String, _content_type: Option, - ) -> Result> { - Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })) + ) -> Result { + Err(BackendError::UnsupportedOperation( + "Create multipart upload is not supported on Azure".to_string(), + )) } async fn abort_multipart_upload( &self, _key: String, _upload_id: String, - ) -> Result<(), Box> { - Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })) + ) -> Result<(), BackendError> { + Err(BackendError::UnsupportedOperation( + "Abort multipart upload is not supported on Azure".to_string(), + )) } async fn complete_multipart_upload( @@ -161,10 +146,10 @@ impl Repository for AzureRepository { _key: String, _upload_id: String, _parts: Vec, - ) -> Result> { - Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })) + ) -> Result { + Err(BackendError::UnsupportedOperation( + "Complete multipart upload is not supported on Azure".to_string(), + )) } async fn upload_multipart_part( @@ -173,10 +158,10 @@ impl Repository for AzureRepository { _upload_id: String, _part_number: String, _bytes: Bytes, - ) -> Result> { - Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })) + ) -> Result { + Err(BackendError::UnsupportedOperation( + "Upload multipart part is not supported on Azure".to_string(), + )) } async fn put_object( @@ -184,56 +169,42 @@ impl Repository for AzureRepository { _key: String, _bytes: Bytes, _content_type: Option, - ) -> Result<(), Box> { - Err(Box::new(InternalServerError { - message: "Internal Server Error".to_string(), - })) + ) -> Result<(), BackendError> { + Err(BackendError::UnsupportedOperation( + "Put object is not supported on Azure".to_string(), + )) } - async fn head_object(&self, key: String) -> Result> { + async fn head_object(&self, key: String) -> Result { let credentials = StorageCredentials::anonymous(); // Create a client for anonymous access let client = BlobServiceClient::new(format!("{}", &self.account_name), credentials) .container_client(&self.container_name); - match client + let blob = client .blob_client(format!( "{}/{}", self.base_prefix.trim_end_matches('/').to_string(), key )) .get_properties() - .await - { - Ok(blob) => Ok(HeadObjectResponse { - content_length: blob.blob.properties.content_length, - content_type: blob.blob.properties.content_type.to_string(), - etag: blob.blob.properties.etag.to_string(), - last_modified: rfc2822_to_rfc7231( - blob.blob - .properties - .last_modified - .format(&Rfc2822) - .unwrap_or_else(|_| String::from("Invalid DateTime")) - .as_str(), - ) - .unwrap_or_else(|_| String::from("Invalid DateTime")), - }), - Err(e) => { - if e.as_http_error().unwrap().status() == 404 { - return Err(Box::new(ObjectNotFoundError { - account_id: self.account_id.clone(), - repository_id: self.repository_id.clone(), - key, - })); - } else { - Err(Box::new(InternalServerError { - message: "Internal Server Error".to_string(), - })) - } - } - } + .await?; + + Ok(HeadObjectResponse { + content_length: blob.blob.properties.content_length, + content_type: blob.blob.properties.content_type.to_string(), + etag: blob.blob.properties.etag.to_string(), + last_modified: rfc2822_to_rfc7231( + blob.blob + .properties + .last_modified + .format(&Rfc2822) + .unwrap_or_else(|_| String::from("Invalid DateTime")) + .as_str(), + ) + .unwrap_or_else(|_| String::from("Invalid DateTime")), + }) } async fn list_objects_v2( @@ -242,7 +213,7 @@ impl Repository for AzureRepository { continuation_token: Option, delimiter: Option, max_keys: NonZeroU32, - ) -> Result> { + ) -> Result { let mut result = ListBucketResult { name: format!("{}", self.account_id), prefix: prefix.clone(), @@ -333,7 +304,7 @@ impl Repository for AzureRepository { _copy_identifier_path: String, _key: String, _range: Option, - ) -> Result<(), Box> { + ) -> Result<(), BackendError> { Ok(()) } } diff --git a/src/backends/common.rs b/src/backends/common.rs index 6167c7e..703db50 100644 --- a/src/backends/common.rs +++ b/src/backends/common.rs @@ -1,4 +1,3 @@ -use crate::utils::errors::APIError; use async_trait::async_trait; use bytes::Bytes; use core::num::NonZeroU32; @@ -9,6 +8,7 @@ use std::pin::Pin; use reqwest::Error as ReqwestError; type BoxedReqwestStream = Pin> + Send>>; +use crate::utils::errors::BackendError; pub struct GetObjectResponse { pub content_length: u64, @@ -39,55 +39,55 @@ pub struct CompleteMultipartUploadResponse { #[async_trait] pub trait Repository { - async fn delete_object(&self, key: String) -> Result<(), Box>; + async fn delete_object(&self, key: String) -> Result<(), BackendError>; async fn create_multipart_upload( &self, key: String, content_type: Option, - ) -> Result>; + ) -> Result; async fn abort_multipart_upload( &self, key: String, upload_id: String, - ) -> Result<(), Box>; + ) -> Result<(), BackendError>; async fn complete_multipart_upload( &self, key: String, upload_id: String, parts: Vec, - ) -> Result>; + ) -> Result; async fn upload_multipart_part( &self, key: String, upload_id: String, part_number: String, bytes: Bytes, - ) -> Result>; + ) -> Result; async fn put_object( &self, key: String, bytes: Bytes, content_type: Option, - ) -> Result<(), Box>; + ) -> Result<(), BackendError>; async fn get_object( &self, key: String, range: Option, - ) -> Result>; - async fn head_object(&self, key: String) -> Result>; + ) -> Result; + async fn head_object(&self, key: String) -> Result; async fn list_objects_v2( &self, prefix: String, continuation_token: Option, delimiter: Option, max_keys: NonZeroU32, - ) -> Result>; + ) -> Result; async fn copy_object( &self, copy_identifier_path: String, key: String, range: Option, - ) -> Result<(), Box>; + ) -> Result<(), BackendError>; } #[derive(Debug, Serialize)] diff --git a/src/backends/s3.rs b/src/backends/s3.rs index 0a43443..8bacd85 100644 --- a/src/backends/s3.rs +++ b/src/backends/s3.rs @@ -1,9 +1,10 @@ +use super::common::{MultipartPart, UploadPartResponse}; use crate::backends::common::{ CommonPrefix, CompleteMultipartUploadResponse, Content, CreateMultipartUploadResponse, GetObjectResponse, HeadObjectResponse, ListBucketResult, Repository, }; use crate::utils::core::replace_first; -use crate::utils::errors::{APIError, InternalServerError, ObjectNotFoundError}; +use crate::utils::errors::BackendError; use actix_web::http::header::RANGE; use async_trait::async_trait; use bytes::Bytes; @@ -12,7 +13,6 @@ use core::num::NonZeroU32; use futures_core::Stream; use reqwest; use rusoto_core::Region; -use rusoto_core::RusotoError; use rusoto_s3::{ AbortMultipartUploadRequest, CompleteMultipartUploadRequest, CompletedMultipartUpload, CompletedPart, CreateMultipartUploadRequest, DeleteObjectRequest, HeadObjectRequest, @@ -20,8 +20,6 @@ use rusoto_s3::{ }; use std::pin::Pin; -use super::common::{MultipartPart, UploadPartResponse}; - pub struct S3Repository { pub account_id: String, pub repository_id: String, @@ -33,121 +31,99 @@ pub struct S3Repository { pub secret_access_key: Option, } -#[async_trait] -impl Repository for S3Repository { - async fn get_object( - &self, - key: String, - range: Option, - ) -> Result> { - match self.head_object(key.clone()).await { - Ok(head_object_response) => { - let client = reqwest::Client::new(); - let url: String; - - if self.auth_method == "s3_local" { - url = format!( - "http://localhost:5050/{}/{}/{}", - self.bucket, self.base_prefix, key - ) - } else { - url = format!( - "https://s3.{}.amazonaws.com/{}/{}/{}", - self.region.name(), - self.bucket, - self.base_prefix, - key - ); - } - // Start building the request - let mut request = client.get(url); - - // If a range is provided, add it to the headers - if let Some(range_value) = range { - request = request.header(RANGE, range_value); - } - - // Send the request and await the response - match request.send().await { - Ok(response) => { - // Get the byte stream from the response - let content_length = response.content_length(); - let stream = response.bytes_stream(); - let boxed_stream: Pin< - Box> + Send>, - > = Box::pin(stream); - - Ok(GetObjectResponse { - content_length: content_length.unwrap_or(0) as u64, - content_type: head_object_response.content_type, - etag: head_object_response.etag, - last_modified: head_object_response.last_modified, - body: boxed_stream, - }) - } - Err(error) => { - if error.is_status() { - let code = error.status().unwrap().as_u16(); - if code == 404 { - return Err(Box::new(ObjectNotFoundError { - account_id: self.account_id.clone(), - repository_id: self.repository_id.clone(), - key, - })); - } - } - - return Err(Box::new(InternalServerError { - message: "Internal Server Error".to_string(), - })); - } - } - } - Err(err) => { - // Pass through the error from the head_object call - return Err(err); - } - } - } - - async fn put_object( - &self, - key: String, - bytes: Bytes, - content_type: Option, - ) -> Result<(), Box> { - let client: S3Client; - +impl S3Repository { + fn create_client(&self) -> Result { if self.auth_method == "s3_access_key" { let credentials = rusoto_credential::StaticProvider::new_minimal( self.access_key_id.clone().unwrap(), self.secret_access_key.clone().unwrap(), ); - client = S3Client::new_with( + return Ok(S3Client::new_with( rusoto_core::request::HttpClient::new().unwrap(), credentials, self.region.clone(), - ); + )); } else if self.auth_method == "s3_ecs_task_role" { let credentials = rusoto_credential::ContainerProvider::new(); - client = S3Client::new_with( + return Ok(S3Client::new_with( rusoto_core::request::HttpClient::new().unwrap(), credentials, self.region.clone(), - ); + )); } else if self.auth_method == "s3_local" { let credentials = rusoto_credential::ChainProvider::new(); - client = S3Client::new_with( + return Ok(S3Client::new_with( rusoto_core::request::HttpClient::new().unwrap(), credentials, self.region.clone(), - ); + )); + } else { + return Err(BackendError::UnsupportedAuthMethod(format!( + "Unsupported auth method: {}", + self.auth_method + ))); + } + } +} + +#[async_trait] +impl Repository for S3Repository { + async fn get_object( + &self, + key: String, + range: Option, + ) -> Result { + let head_object_response = self.head_object(key.clone()).await?; + let client = reqwest::Client::new(); + let url: String; + + if self.auth_method == "s3_local" { + url = format!( + "http://localhost:5050/{}/{}/{}", + self.bucket, self.base_prefix, key + ) } else { - return Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })); + url = format!( + "https://s3.{}.amazonaws.com/{}/{}/{}", + self.region.name(), + self.bucket, + self.base_prefix, + key + ); + } + // Start building the request + let mut request = client.get(url); + + // If a range is provided, add it to the headers + if let Some(range_value) = range { + request = request.header(RANGE, range_value); } + // Send the request and await the response + let response = request.send().await?; + // Get the byte stream from the response + let content_length = response.content_length(); + let stream = response.bytes_stream(); + let boxed_stream: Pin> + Send>> = + Box::pin(stream); + + Ok(GetObjectResponse { + content_length: content_length.unwrap_or(0) as u64, + content_type: head_object_response.content_type, + etag: head_object_response.etag, + last_modified: head_object_response.last_modified, + body: boxed_stream, + }) + } + + async fn put_object( + &self, + key: String, + bytes: Bytes, + content_type: Option, + ) -> Result<(), BackendError> { + let client = self.create_client()?; + let request = PutObjectRequest { bucket: self.bucket.clone(), key: format!("{}/{}", self.base_prefix, key), @@ -156,50 +132,16 @@ impl Repository for S3Repository { ..Default::default() }; - match client.put_object(request).await { - Ok(_) => Ok(()), - Err(e) => Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })), - } + client.put_object(request).await?; + Ok(()) } async fn create_multipart_upload( &self, key: String, content_type: Option, - ) -> Result> { - let client: S3Client; - - if self.auth_method == "s3_access_key" { - let credentials = rusoto_credential::StaticProvider::new_minimal( - self.access_key_id.clone().unwrap(), - self.secret_access_key.clone().unwrap(), - ); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_ecs_task_role" { - let credentials = rusoto_credential::ContainerProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_local" { - let credentials = rusoto_credential::ChainProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else { - return Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })); - } + ) -> Result { + let client = self.create_client()?; let request = CreateMultipartUploadRequest { bucket: self.bucket.clone(), @@ -208,54 +150,20 @@ impl Repository for S3Repository { ..Default::default() }; - match client.create_multipart_upload(request).await { - Ok(result) => Ok(CreateMultipartUploadResponse { - bucket: self.account_id.clone(), - key: key.clone(), - upload_id: result.upload_id.unwrap(), - }), - Err(e) => Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })), - } + let result = client.create_multipart_upload(request).await?; + Ok(CreateMultipartUploadResponse { + bucket: self.account_id.clone(), + key: key.clone(), + upload_id: result.upload_id.unwrap(), + }) } async fn abort_multipart_upload( &self, key: String, upload_id: String, - ) -> Result<(), Box> { - let client: S3Client; - - if self.auth_method == "s3_access_key" { - let credentials = rusoto_credential::StaticProvider::new_minimal( - self.access_key_id.clone().unwrap(), - self.secret_access_key.clone().unwrap(), - ); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_ecs_task_role" { - let credentials = rusoto_credential::ContainerProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_local" { - let credentials = rusoto_credential::ChainProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else { - return Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })); - } + ) -> Result<(), BackendError> { + let client = self.create_client()?; let request = AbortMultipartUploadRequest { bucket: self.bucket.clone(), @@ -264,12 +172,8 @@ impl Repository for S3Repository { ..Default::default() }; - match client.abort_multipart_upload(request).await { - Ok(_) => Ok(()), - Err(_) => Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })), - } + client.abort_multipart_upload(request).await?; + Ok(()) } async fn complete_multipart_upload( @@ -277,38 +181,8 @@ impl Repository for S3Repository { key: String, upload_id: String, parts: Vec, - ) -> Result> { - let client: S3Client; - - if self.auth_method == "s3_access_key" { - let credentials = rusoto_credential::StaticProvider::new_minimal( - self.access_key_id.clone().unwrap(), - self.secret_access_key.clone().unwrap(), - ); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_ecs_task_role" { - let credentials = rusoto_credential::ContainerProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_local" { - let credentials = rusoto_credential::ChainProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else { - return Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })); - } + ) -> Result { + let client = self.create_client()?; let request = CompleteMultipartUploadRequest { bucket: self.bucket.clone(), @@ -328,17 +202,13 @@ impl Repository for S3Repository { ..Default::default() }; - match client.complete_multipart_upload(request).await { - Ok(result) => Ok(CompleteMultipartUploadResponse { - location: "".to_string(), - bucket: self.account_id.clone(), - key: key.clone(), - etag: result.e_tag.unwrap(), - }), - Err(e) => Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })), - } + let result = client.complete_multipart_upload(request).await?; + Ok(CompleteMultipartUploadResponse { + location: "".to_string(), + bucket: self.account_id.clone(), + key: key.clone(), + etag: result.e_tag.unwrap(), + }) } async fn upload_multipart_part( @@ -347,38 +217,8 @@ impl Repository for S3Repository { upload_id: String, part_number: String, bytes: Bytes, - ) -> Result> { - let client: S3Client; - - if self.auth_method == "s3_access_key" { - let credentials = rusoto_credential::StaticProvider::new_minimal( - self.access_key_id.clone().unwrap(), - self.secret_access_key.clone().unwrap(), - ); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_ecs_task_role" { - let credentials = rusoto_credential::ContainerProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_local" { - let credentials = rusoto_credential::ChainProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else { - return Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })); - } + ) -> Result { + let client = self.create_client()?; let request = UploadPartRequest { bucket: self.bucket.clone(), @@ -389,128 +229,44 @@ impl Repository for S3Repository { ..Default::default() }; - match client.upload_part(request).await { - Ok(result) => Ok(UploadPartResponse { - etag: result.e_tag.unwrap(), - }), - Err(_) => Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })), - } + let result = client.upload_part(request).await?; + Ok(UploadPartResponse { + etag: result.e_tag.unwrap(), + }) } - async fn delete_object(&self, key: String) -> Result<(), Box> { - let client: S3Client; + async fn delete_object(&self, key: String) -> Result<(), BackendError> { + let client = self.create_client()?; - if self.auth_method == "s3_access_key" { - let credentials = rusoto_credential::StaticProvider::new_minimal( - self.access_key_id.clone().unwrap(), - self.secret_access_key.clone().unwrap(), - ); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_ecs_task_role" { - let credentials = rusoto_credential::ContainerProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_local" { - let credentials = rusoto_credential::ChainProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else { - return Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })); - } let request = DeleteObjectRequest { bucket: self.bucket.clone(), key: format!("{}/{}", self.base_prefix, key), ..Default::default() }; - match client.delete_object(request).await { - Ok(_) => Ok(()), - Err(_) => Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })), - } + client.delete_object(request).await?; + Ok(()) } - async fn head_object(&self, key: String) -> Result> { - let client: S3Client; + async fn head_object(&self, key: String) -> Result { + let client = self.create_client()?; - if self.auth_method == "s3_access_key" { - let credentials = rusoto_credential::StaticProvider::new_minimal( - self.access_key_id.clone().unwrap(), - self.secret_access_key.clone().unwrap(), - ); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_ecs_task_role" { - let credentials = rusoto_credential::ContainerProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_local" { - let credentials = rusoto_credential::ChainProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else { - return Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })); - } let request = HeadObjectRequest { bucket: self.bucket.clone(), key: format!("{}/{}", self.base_prefix, key), ..Default::default() }; - match client.head_object(request).await { - Ok(result) => Ok(HeadObjectResponse { - content_length: result.content_length.unwrap_or(0) as u64, - content_type: result.content_type.unwrap_or_else(|| "".to_string()), - etag: result.e_tag.unwrap_or_else(|| "".to_string()), - last_modified: result - .last_modified - .unwrap_or_else(|| Utc::now().to_rfc2822()), - }), - Err(error) => { - match error { - RusotoError::Unknown(response) => { - if response.status.eq(&404) { - return Err(Box::new(ObjectNotFoundError { - account_id: self.account_id.clone(), - repository_id: self.repository_id.clone(), - key, - })); - } - } - _ => (), - } - - Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })) - } - } + let result = client.head_object(request).await?; + + Ok(HeadObjectResponse { + content_length: result.content_length.unwrap_or(0) as u64, + content_type: result.content_type.unwrap_or_else(|| "".to_string()), + etag: result.e_tag.unwrap_or_else(|| "".to_string()), + last_modified: result + .last_modified + .unwrap_or_else(|| Utc::now().to_rfc2822()), + }) } async fn list_objects_v2( @@ -519,38 +275,9 @@ impl Repository for S3Repository { continuation_token: Option, delimiter: Option, max_keys: NonZeroU32, - ) -> Result> { - let client: S3Client; + ) -> Result { + let client = self.create_client()?; - if self.auth_method == "s3_access_key" { - let credentials = rusoto_credential::StaticProvider::new_minimal( - self.access_key_id.clone().unwrap(), - self.secret_access_key.clone().unwrap(), - ); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_ecs_task_role" { - let credentials = rusoto_credential::ContainerProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_local" { - let credentials = rusoto_credential::ChainProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else { - return Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })); - } let mut request = ListObjectsV2Request { bucket: self.bucket.clone(), prefix: Some(format!("{}/{}", self.base_prefix, prefix)), @@ -563,180 +290,95 @@ impl Repository for S3Repository { request.continuation_token = Some(token); } - match client.list_objects_v2(request).await { - Ok(output) => { - let result = ListBucketResult { - name: format!("{}", self.account_id), - prefix: format!("{}/{}", self.repository_id, prefix), - key_count: output.key_count.unwrap_or(0), - max_keys: output.max_keys.unwrap_or(0), - is_truncated: output.is_truncated.unwrap_or(false), - next_continuation_token: output.next_continuation_token, - contents: output - .contents - .unwrap_or_default() - .iter() - .map(|item| Content { - key: replace_first( - item.key.clone().unwrap_or_else(|| "".to_string()), - self.base_prefix.clone(), - format!("{}", self.repository_id), - ), - last_modified: item - .last_modified - .clone() - .unwrap_or_else(|| Utc::now().to_rfc2822()), - etag: item.e_tag.clone().unwrap_or_else(|| "".to_string()), - size: item.size.unwrap_or(0), - storage_class: item - .storage_class - .clone() - .unwrap_or_else(|| "".to_string()), - }) - .collect(), - common_prefixes: output - .common_prefixes - .unwrap_or_default() - .iter() - .map(|item| CommonPrefix { - prefix: replace_first( - item.prefix.clone().unwrap_or_else(|| "".to_string()), - self.base_prefix.clone(), - format!("{}", self.repository_id), - ), - }) - .collect(), - }; - - return Ok(result); - } - Err(error) => { - return Err(Box::new(InternalServerError { - message: "Internal Server Error".to_string(), - })); - } - } + let output = client.list_objects_v2(request).await?; + let result = ListBucketResult { + name: format!("{}", self.account_id), + prefix: format!("{}/{}", self.repository_id, prefix), + key_count: output.key_count.unwrap_or(0), + max_keys: output.max_keys.unwrap_or(0), + is_truncated: output.is_truncated.unwrap_or(false), + next_continuation_token: output.next_continuation_token, + contents: output + .contents + .unwrap_or_default() + .iter() + .map(|item| Content { + key: replace_first( + item.key.clone().unwrap_or_else(|| "".to_string()), + self.base_prefix.clone(), + format!("{}", self.repository_id), + ), + last_modified: item + .last_modified + .clone() + .unwrap_or_else(|| Utc::now().to_rfc2822()), + etag: item.e_tag.clone().unwrap_or_else(|| "".to_string()), + size: item.size.unwrap_or(0), + storage_class: item.storage_class.clone().unwrap_or_else(|| "".to_string()), + }) + .collect(), + common_prefixes: output + .common_prefixes + .unwrap_or_default() + .iter() + .map(|item| CommonPrefix { + prefix: replace_first( + item.prefix.clone().unwrap_or_else(|| "".to_string()), + self.base_prefix.clone(), + format!("{}", self.repository_id), + ), + }) + .collect(), + }; + + Ok(result) } + async fn copy_object( &self, copy_identifier_path: String, key: String, range: Option, - ) -> Result<(), Box> { - let client: S3Client; + ) -> Result<(), BackendError> { + let client = self.create_client()?; - if self.auth_method == "s3_access_key" { - let credentials = rusoto_credential::StaticProvider::new_minimal( - self.access_key_id.clone().unwrap(), - self.secret_access_key.clone().unwrap(), - ); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_ecs_task_role" { - let credentials = rusoto_credential::ContainerProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else if self.auth_method == "s3_local" { - let credentials = rusoto_credential::ChainProvider::new(); - client = S3Client::new_with( - rusoto_core::request::HttpClient::new().unwrap(), - credentials, - self.region.clone(), - ); - } else { - return Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })); - } let request = HeadObjectRequest { bucket: self.bucket.clone(), key: format!("{}", copy_identifier_path), ..Default::default() }; - match client.head_object(request).await { - Ok(result) => { - let url_client = reqwest::Client::new(); - let url: String; - - if self.auth_method == "s3_local" { - url = format!( - "http://localhost:5050/{}/{}", - self.bucket, copy_identifier_path - ) - } else { - url = format!( - "https://s3.{}.amazonaws.com/{}/{}", - self.region.name(), - self.bucket, - copy_identifier_path - ); - } - - let mut request = url_client.get(url); - - if let Some(range_value) = range { - request = request.header(RANGE, range_value); - } - - match request.send().await { - Ok(response) => { - let content_bytes = response - .bytes() - .await - .unwrap_or_else(|_| bytes::Bytes::from(vec![])); - match self - .put_object(key.clone(), content_bytes, result.content_type) - .await - { - Ok(_put_res) => Ok(()), - Err(err) => { - return Err(err); - } - } - } - Err(error) => { - if error.is_status() { - let code = error.status().unwrap().as_u16(); - if code == 404 { - return Err(Box::new(ObjectNotFoundError { - account_id: self.account_id.clone(), - repository_id: self.repository_id.clone(), - key, - })); - } - } - - return Err(Box::new(InternalServerError { - message: "Internal Server Error".to_string(), - })); - } - } - } - Err(error) => { - match error { - RusotoError::Unknown(response) => { - if response.status.eq(&404) { - return Err(Box::new(ObjectNotFoundError { - account_id: self.account_id.clone(), - repository_id: self.repository_id.clone(), - key, - })); - } - } - _ => (), - } - - Err(Box::new(InternalServerError { - message: format!("Internal Server Error"), - })) - } + let result = client.head_object(request).await?; + let url_client = reqwest::Client::new(); + let url: String; + + if self.auth_method == "s3_local" { + url = format!( + "http://localhost:5050/{}/{}", + self.bucket, copy_identifier_path + ) + } else { + url = format!( + "https://s3.{}.amazonaws.com/{}/{}", + self.region.name(), + self.bucket, + copy_identifier_path + ); + } + + let mut request = url_client.get(url); + + if let Some(range_value) = range { + request = request.header(RANGE, range_value); } + + let response = request.send().await?; + let content_bytes = response + .bytes() + .await + .unwrap_or_else(|_| bytes::Bytes::from(vec![])); + self.put_object(key.clone(), content_bytes, result.content_type) + .await?; + Ok(()) } } diff --git a/src/utils/errors.rs b/src/utils/errors.rs index 498be59..a402c84 100644 --- a/src/utils/errors.rs +++ b/src/utils/errors.rs @@ -1,45 +1,68 @@ use actix_web::error; use actix_web::http::StatusCode; use actix_web::HttpResponse; +use azure_core::error::Error as AzureError; use log::error; +use quick_xml::DeError; use reqwest::Error as ReqwestError; -use serde::Serialize; -use std::error::Error; -use std::fmt; +use rusoto_core::RusotoError; +use rusoto_s3::{ + AbortMultipartUploadError, CompleteMultipartUploadError, CreateMultipartUploadError, + DeleteObjectError, HeadObjectError, ListObjectsV2Error, PutObjectError, UploadPartError, +}; use thiserror::Error; #[derive(Error, Debug)] pub enum BackendError { #[error("repository not found")] RepositoryNotFound, + #[error("failed to fetch repository permissions")] RepositoryPermissionsNotFound, + #[error("source repository missing primary mirror")] SourceRepositoryMissingPrimaryMirror, + #[error("data connection not found")] DataConnectionNotFound, + #[error("reqwest error (url {}, message {})", .0.url().map(|u| u.to_string()).unwrap_or("unknown".to_string()), .0.to_string())] ReqwestError(#[from] ReqwestError), - #[error("Api threw a server error (url {}, status {}, message {})", .url, .status, .message)] + + #[error("api threw a server error (url {}, status {}, message {})", .url, .status, .message)] ApiServerError { url: String, status: u16, message: String, }, - #[error("Api threw a client error (url {}, status {}, message {})", .url, .status, .message)] + + #[error("api threw a client error (url {}, status {}, message {})", .url, .status, .message)] ApiClientError { url: String, status: u16, message: String, }, - #[error("Failed to parse JSON (url {})", .url)] + + #[error("failed to parse JSON (url {})", .url)] JsonParseError { url: String }, - #[error("Unexpected data connection provider (provider {})", .provider)] + + #[error("unexpected data connection provider (provider {})", .provider)] UnexpectedDataConnectionProvider { provider: String }, - #[error("Unauthorized")] - UnauthorizationError, - #[error("Unexpected API error: {0}")] // TODO: remove this + + #[error("unauthorized")] + UnauthorizedError, + + #[error("unexpected API error: {0}")] // TODO: remove this UnexpectedApiError(String), + + #[error("unsupported auth method: {0}")] + UnsupportedAuthMethod(String), + + #[error("unsupported operation: {0}")] + UnsupportedOperation(String), + + #[error("xml parse error: {0}")] + XmlParseError(String), } impl error::ResponseError for BackendError { @@ -57,8 +80,11 @@ impl error::ResponseError for BackendError { HttpResponse::InternalServerError().finish() } BackendError::RepositoryPermissionsNotFound => HttpResponse::BadGateway().finish(), - BackendError::UnauthorizationError => HttpResponse::Unauthorized().finish(), + BackendError::UnauthorizedError => HttpResponse::Unauthorized().finish(), BackendError::UnexpectedApiError(_) => HttpResponse::InternalServerError().finish(), + BackendError::UnsupportedAuthMethod(_) => HttpResponse::BadRequest().finish(), + BackendError::UnsupportedOperation(_) => HttpResponse::BadRequest().finish(), + BackendError::XmlParseError(_) => HttpResponse::InternalServerError().finish(), } } @@ -75,58 +101,66 @@ impl error::ResponseError for BackendError { StatusCode::INTERNAL_SERVER_ERROR } BackendError::RepositoryPermissionsNotFound => StatusCode::BAD_GATEWAY, - BackendError::UnauthorizationError => StatusCode::UNAUTHORIZED, + BackendError::UnauthorizedError => StatusCode::UNAUTHORIZED, BackendError::UnexpectedApiError(_) => StatusCode::INTERNAL_SERVER_ERROR, + BackendError::UnsupportedAuthMethod(_) => StatusCode::BAD_REQUEST, + BackendError::UnsupportedOperation(_) => StatusCode::BAD_REQUEST, + BackendError::XmlParseError(_) => StatusCode::INTERNAL_SERVER_ERROR, } } } -impl From> for BackendError { - fn from(error: Box) -> BackendError { +// Azure API Errors +impl From for BackendError { + fn from(error: AzureError) -> BackendError { BackendError::UnexpectedApiError(error.to_string()) } } -pub trait APIError: std::error::Error + Send + Sync { - fn to_response(&self) -> HttpResponse; +// S3 API Errors +impl From> for BackendError { + fn from(error: RusotoError) -> BackendError { + BackendError::UnexpectedApiError(error.to_string()) + } } - -#[derive(Serialize, Debug)] -pub struct ObjectNotFoundError { - pub account_id: String, - pub repository_id: String, - pub key: String, +impl From> for BackendError { + fn from(error: RusotoError) -> BackendError { + BackendError::UnexpectedApiError(error.to_string()) + } } - -impl APIError for ObjectNotFoundError { - fn to_response(&self) -> HttpResponse { - HttpResponse::NotFound().json(self) +impl From> for BackendError { + fn from(error: RusotoError) -> BackendError { + BackendError::UnexpectedApiError(error.to_string()) } } - -impl fmt::Display for ObjectNotFoundError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Object Not Found: {}", self.key) +impl From> for BackendError { + fn from(error: RusotoError) -> BackendError { + BackendError::UnexpectedApiError(error.to_string()) } } - -impl Error for ObjectNotFoundError {} - -#[derive(Serialize, Debug)] -pub struct InternalServerError { - pub message: String, +impl From> for BackendError { + fn from(error: RusotoError) -> BackendError { + BackendError::UnexpectedApiError(error.to_string()) + } } - -impl APIError for InternalServerError { - fn to_response(&self) -> HttpResponse { - HttpResponse::InternalServerError().json(self) +impl From> for BackendError { + fn from(error: RusotoError) -> BackendError { + BackendError::UnexpectedApiError(error.to_string()) } } - -impl fmt::Display for InternalServerError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Internal Server Error: {}", self.message) +impl From> for BackendError { + fn from(error: RusotoError) -> BackendError { + BackendError::UnexpectedApiError(error.to_string()) + } +} +impl From> for BackendError { + fn from(error: RusotoError) -> BackendError { + BackendError::UnexpectedApiError(error.to_string()) } } -impl Error for InternalServerError {} +impl From for BackendError { + fn from(error: DeError) -> BackendError { + BackendError::XmlParseError(format!("failed to parse xml: {}", error)) + } +} From d396b8e1891fdc53be8108999b8cc37be309fe5c Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Mon, 12 May 2025 22:25:34 -0700 Subject: [PATCH 2/4] Finalize flattening & error handling --- src/main.rs | 497 +++++++++++++++++--------------------------- src/utils/errors.rs | 14 +- 2 files changed, 198 insertions(+), 313 deletions(-) diff --git a/src/main.rs b/src/main.rs index 339b8b9..c6abfa5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,7 +17,6 @@ use bytes::Bytes; use core::num::NonZeroU32; use env_logger::Env; use futures_util::StreamExt; -use log::error; use quick_xml::se::to_string_with_root; use serde::Deserialize; use serde_xml_rs::from_str; @@ -49,8 +48,6 @@ impl MessageBody for FakeBody { } } -// TODO: Map the APIErrors to HTTP Responses - #[get("/{account_id}/{repository_id}/{key:.*}")] async fn get_object( api_client: web::Data, @@ -85,8 +82,8 @@ async fn get_object( .get_backend_client(&account_id, &repository_id) .await?; - let authorized = api_client - .is_authorized( + api_client + .assert_authorized( user_identity.into_inner(), &account_id, &repository_id, @@ -94,10 +91,6 @@ async fn get_object( ) .await?; - if !authorized { - return Err(BackendError::UnauthorizationError); - } - // Found the repository, now try to get the object let res = client.get_object(key.clone(), range).await?; @@ -158,52 +151,31 @@ async fn delete_object( params: web::Query, path: web::Path<(String, String, String)>, user_identity: web::ReqData, -) -> impl Responder { +) -> Result { let (account_id, repository_id, key) = path.into_inner(); - if let Ok(client) = api_client + let client = api_client .get_backend_client(&account_id, &repository_id) - .await - { - match api_client - .is_authorized( - user_identity.into_inner(), - &account_id, - &repository_id, - RepositoryPermission::Write, - ) - .await - { - Ok(authorized) => { - if !authorized { - return HttpResponse::Unauthorized().finish(); - } - } - Err(_) => return HttpResponse::InternalServerError().finish(), - } + .await?; - if params.upload_id.is_none() { - // Found the repository, now try to delete the object - match client.delete_object(key.clone()).await { - Ok(_) => { - return HttpResponse::NoContent().finish(); - } - Err(_) => HttpResponse::NotFound().finish(), - } - } else { - match client - .abort_multipart_upload(key.clone(), params.upload_id.clone().unwrap()) - .await - { - Ok(_) => { - return HttpResponse::NoContent().finish(); - } - Err(_) => HttpResponse::NotFound().finish(), - } - } + api_client + .assert_authorized( + user_identity.into_inner(), + &account_id, + &repository_id, + RepositoryPermission::Write, + ) + .await?; + + if params.upload_id.is_none() { + // Found the repository, now try to delete the object + client.delete_object(key.clone()).await?; + Ok(HttpResponse::NoContent().finish()) } else { - // Could not find the repository - return HttpResponse::NotFound().finish(); + client + .abort_multipart_upload(key.clone(), params.upload_id.clone().unwrap()) + .await?; + Ok(HttpResponse::NoContent().finish()) } } @@ -223,85 +195,61 @@ async fn put_object( params: web::Query, path: web::Path<(String, String, String)>, user_identity: web::ReqData, -) -> impl Responder { +) -> Result { let (account_id, repository_id, key) = path.into_inner(); let headers = req.headers(); - if let Ok(client) = api_client + let client = api_client .get_backend_client(&account_id, &repository_id) - .await - { - match api_client - .is_authorized( - user_identity.into_inner(), - &account_id, - &repository_id, - RepositoryPermission::Write, - ) - .await - { - Ok(authorized) => { - if !authorized { - return HttpResponse::Unauthorized().finish(); - } - } - Err(_) => return HttpResponse::InternalServerError().finish(), - } + .await?; - if params.part_number.is_none() && params.upload_id.is_none() { - // Check if this is a server-side copy operation - if let Some(header_copy_identifier) = req.headers().get("x-amz-copy-source") { - let copy_identifier_path = header_copy_identifier.to_str().unwrap_or(""); - match client - .copy_object((©_identifier_path).to_string(), key.clone(), None) - .await - { - Ok(_) => HttpResponse::NoContent().finish(), - Err(_) => { - return HttpResponse::InternalServerError() - .body("Failed to store copied object") - } - } - } else { - // Found the repository, now try to upload the object - match client - .put_object( - key.clone(), - bytes, - headers - .get(CONTENT_TYPE) - .and_then(|h| h.to_str().ok()) - .map(|s| s.to_string()), - ) - .await - { - Ok(_) => HttpResponse::NoContent().finish(), - - Err(_) => HttpResponse::NotFound().finish(), - } - } - } else if params.part_number.is_some() && params.upload_id.is_some() { - match client - .upload_multipart_part( + api_client + .assert_authorized( + user_identity.into_inner(), + &account_id, + &repository_id, + RepositoryPermission::Write, + ) + .await?; + + if params.part_number.is_none() && params.upload_id.is_none() { + // Check if this is a server-side copy operation + if let Some(header_copy_identifier) = req.headers().get("x-amz-copy-source") { + let copy_identifier_path = header_copy_identifier.to_str().unwrap_or(""); + client + .copy_object((©_identifier_path).to_string(), key.clone(), None) + .await?; + Ok(HttpResponse::NoContent().finish()) + } else { + // Found the repository, now try to upload the object + client + .put_object( key.clone(), - params.upload_id.clone().unwrap(), - params.part_number.clone().unwrap(), bytes, + headers + .get(CONTENT_TYPE) + .and_then(|h| h.to_str().ok()) + .map(|s| s.to_string()), ) - .await - { - Ok(res) => HttpResponse::Ok() - .insert_header(("ETag", res.etag)) - .finish(), - - Err(_) => HttpResponse::NotFound().finish(), - } - } else { - return HttpResponse::NotFound().finish(); + .await?; + Ok(HttpResponse::NoContent().finish()) } + } else if params.part_number.is_some() && params.upload_id.is_some() { + let res = client + .upload_multipart_part( + key.clone(), + params.upload_id.clone().unwrap(), + params.part_number.clone().unwrap(), + bytes, + ) + .await?; + Ok(HttpResponse::Ok() + .insert_header(("ETag", res.etag)) + .finish()) } else { - // Could not find the repository - return HttpResponse::NotFound().finish(); + return Err(BackendError::InvalidRequest(format!( + "Must provide both part number and upload id or neither." + ))); } } @@ -320,100 +268,63 @@ async fn post_handler( mut payload: web::Payload, path: web::Path<(String, String, String)>, user_identity: web::ReqData, -) -> impl Responder { +) -> Result { let (account_id, repository_id, key) = path.into_inner(); let headers = req.headers(); - if let Ok(client) = api_client + let client = api_client .get_backend_client(&account_id, &repository_id) - .await - { - match api_client - .is_authorized( - user_identity.into_inner(), - &account_id, - &repository_id, - RepositoryPermission::Write, - ) - .await - { - Ok(authorized) => { - if !authorized { - return HttpResponse::Unauthorized().finish(); - } - } - Err(_) => return HttpResponse::InternalServerError().finish(), - } + .await?; - if params.uploads.is_some() { - match client - .create_multipart_upload( - key, - headers - .get(CONTENT_TYPE) - .and_then(|h| h.to_str().ok()) - .map(|s| s.to_string()), - ) - .await - { - Ok(res) => match to_string_with_root("InitiateMultipartUploadResult", &res) { - Ok(serialized) => { - return HttpResponse::Ok() - .content_type("application/xml") - .body(serialized) - } - Err(_) => return HttpResponse::InternalServerError().finish(), - }, - Err(_) => { - return HttpResponse::NotFound().finish(); - } - } - } else if params.upload_id.is_some() { - let mut body = String::new(); - while let Some(chunk) = payload.next().await { - match chunk { - Ok(chunk) => match from_utf8(&chunk) { - Ok(s) => body.push_str(s), - Err(_) => return HttpResponse::BadRequest().body("Invalid UTF-8"), - }, - Err(_) => return HttpResponse::InternalServerError().finish(), - } - } + api_client + .assert_authorized( + user_identity.into_inner(), + &account_id, + &repository_id, + RepositoryPermission::Write, + ) + .await?; - match from_str::(&body) { - Ok(upload) => { - match client - .complete_multipart_upload( - key, - params.upload_id.clone().unwrap(), - upload.parts, - ) - .await - { - Ok(res) => match to_string_with_root("CompleteMultipartUploadResult", &res) - { - Ok(serialized) => { - return HttpResponse::Ok() - .content_type("application/xml") - .body(serialized) - } - Err(_) => return HttpResponse::InternalServerError().finish(), - }, - Err(_) => { - return HttpResponse::NotFound().finish(); - } + if params.uploads.is_some() { + let res = client + .create_multipart_upload( + key, + headers + .get(CONTENT_TYPE) + .and_then(|h| h.to_str().ok()) + .map(|s| s.to_string()), + ) + .await?; + let serialized = to_string_with_root("InitiateMultipartUploadResult", &res)?; + Ok(HttpResponse::Ok() + .content_type("application/xml") + .body(serialized)) + } else if params.upload_id.is_some() { + let mut body = String::new(); + while let Some(chunk) = payload.next().await { + match chunk { + Ok(chunk) => match from_utf8(&chunk) { + Ok(s) => body.push_str(s), + Err(_) => { + return Err(BackendError::InvalidRequest("Invalid UTF-8".to_string())) } - } - Err(_) => { - return HttpResponse::BadRequest().finish(); - } + }, + Err(err) => return Err(BackendError::UnexpectedApiError(err.to_string())), } - } else { - return HttpResponse::NotFound().finish(); } + + let upload = from_str::(&body)?; + let res = client + .complete_multipart_upload(key, params.upload_id.clone().unwrap(), upload.parts) + .await?; + let serialized = to_string_with_root("CompleteMultipartUploadResult", &res)?; + Ok(HttpResponse::Ok() + .content_type("application/xml") + .body(serialized)) } else { - // Could not find the repository - return HttpResponse::NotFound().finish(); + return Err(BackendError::InvalidRequest( + "Must provide either uploads or uploadId".to_string(), + )); } } @@ -422,44 +333,30 @@ async fn head_object( api_client: web::Data, path: web::Path<(String, String, String)>, user_identity: web::ReqData, -) -> impl Responder { +) -> Result { let (account_id, repository_id, key) = path.into_inner(); - match api_client + let client = api_client .get_backend_client(&account_id, &repository_id) - .await - { - Ok(client) => { - match api_client - .is_authorized( - user_identity.into_inner(), - &account_id, - &repository_id, - RepositoryPermission::Read, - ) - .await - { - Ok(authorized) => { - if !authorized { - return HttpResponse::Unauthorized().finish(); - } - } - Err(_) => return HttpResponse::InternalServerError().finish(), - } + .await?; - match client.head_object(key.clone()).await { - Ok(res) => HttpResponse::Ok() - .insert_header(("Content-Type", res.content_type)) - .insert_header(("Last-Modified", res.last_modified)) - .insert_header(("ETag", res.etag)) - .body(BoxBody::new(FakeBody { - size: res.content_length as usize, - })), - Err(error) => error.to_response(), - } - } - Err(_) => HttpResponse::NotFound().finish(), - } + api_client + .assert_authorized( + user_identity.into_inner(), + &account_id, + &repository_id, + RepositoryPermission::Read, + ) + .await?; + + let res = client.head_object(key.clone()).await?; + Ok(HttpResponse::Ok() + .insert_header(("Content-Type", res.content_type)) + .insert_header(("Last-Modified", res.last_modified)) + .insert_header(("ETag", res.etag)) + .body(BoxBody::new(FakeBody { + size: res.content_length as usize, + }))) } #[derive(Deserialize)] @@ -482,44 +379,36 @@ async fn list_objects( info: web::Query, path: web::Path, user_identity: web::ReqData, -) -> impl Responder { +) -> Result { let account_id = path.into_inner(); if info.prefix.clone().is_some_and(|s| s.is_empty()) || info.prefix.is_none() { - match api_client + let account = api_client .get_account(account_id.clone(), (*user_identity).clone()) - .await - { - Ok(account) => { - let repositories = account.repositories; - let mut common_prefixes = Vec::new(); - for repository_id in repositories.iter() { - common_prefixes.push(CommonPrefix { - prefix: format!("{}/", repository_id.clone()), - }); - } - let list_response = ListBucketResult { - name: account_id.clone(), - prefix: "/".to_string(), - key_count: 0, - max_keys: 0, - is_truncated: false, - contents: vec![], - common_prefixes, - next_continuation_token: None, - }; - - match to_string_with_root("ListBucketResult", &list_response) { - Ok(serialized) => { - return HttpResponse::Ok() - .content_type("application/xml") - .body(serialized) - } - Err(_) => return HttpResponse::InternalServerError().finish(), - } - } - Err(_) => return HttpResponse::InternalServerError().finish(), + .await?; + + let repositories = account.repositories; + let mut common_prefixes = Vec::new(); + for repository_id in repositories.iter() { + common_prefixes.push(CommonPrefix { + prefix: format!("{}/", repository_id.clone()), + }); } + let list_response = ListBucketResult { + name: account_id.clone(), + prefix: "/".to_string(), + key_count: 0, + max_keys: 0, + is_truncated: false, + contents: vec![], + common_prefixes, + next_continuation_token: None, + }; + + let serialized = to_string_with_root("ListBucketResult", &list_response)?; + return Ok(HttpResponse::Ok() + .content_type("application/xml") + .body(serialized)); } let path_prefix = info.prefix.clone().unwrap_or("".to_string()); @@ -531,50 +420,34 @@ async fn list_objects( max_keys = mk; } - if let Ok(client) = api_client + let client = api_client .get_backend_client(&account_id, &repository_id.to_string()) - .await - { - match api_client - .is_authorized( - user_identity.into_inner(), - &account_id, - &repository_id.to_string(), - RepositoryPermission::Read, - ) - .await - { - Ok(authorized) => { - if !authorized { - return HttpResponse::Unauthorized().finish(); - } - } - Err(_) => return HttpResponse::InternalServerError().finish(), - } + .await?; - // We're listing within a repository, so we need to query the object store backend - match client - .list_objects_v2( - prefix.to_string(), - info.continuation_token.clone(), - info.delimiter.clone(), - max_keys, - ) - .await - { - Ok(res) => match to_string_with_root("ListBucketResult", &res) { - Ok(serialized) => HttpResponse::Ok() - .content_type("application/xml") - .body(serialized), - Err(e) => HttpResponse::InternalServerError().finish(), - }, - Err(_) => HttpResponse::NotFound().finish(), - } - // Found the repository, now make the list objects request - } else { - // Could not find the repository - return HttpResponse::NotFound().finish(); - } + api_client + .assert_authorized( + user_identity.into_inner(), + &account_id, + &repository_id.to_string(), + RepositoryPermission::Read, + ) + .await?; + + // We're listing within a repository, so we need to query the object store backend + let res = client + .list_objects_v2( + prefix.to_string(), + info.continuation_token.clone(), + info.delimiter.clone(), + max_keys, + ) + .await?; + + let serialized = to_string_with_root("ListBucketResult", &res)?; + + Ok(HttpResponse::Ok() + .content_type("application/xml") + .body(serialized)) } #[get("/")] diff --git a/src/utils/errors.rs b/src/utils/errors.rs index a402c84..041abe0 100644 --- a/src/utils/errors.rs +++ b/src/utils/errors.rs @@ -26,6 +26,9 @@ pub enum BackendError { #[error("data connection not found")] DataConnectionNotFound, + #[error("invalid request")] + InvalidRequest(String), + #[error("reqwest error (url {}, message {})", .0.url().map(|u| u.to_string()).unwrap_or("unknown".to_string()), .0.to_string())] ReqwestError(#[from] ReqwestError), @@ -52,7 +55,7 @@ pub enum BackendError { #[error("unauthorized")] UnauthorizedError, - #[error("unexpected API error: {0}")] // TODO: remove this + #[error("unexpected API error: {0}")] UnexpectedApiError(String), #[error("unsupported auth method: {0}")] @@ -72,6 +75,9 @@ impl error::ResponseError for BackendError { BackendError::RepositoryNotFound => HttpResponse::NotFound().finish(), BackendError::SourceRepositoryMissingPrimaryMirror => HttpResponse::NotFound().finish(), BackendError::DataConnectionNotFound => HttpResponse::NotFound().finish(), + BackendError::InvalidRequest(message) => { + HttpResponse::BadRequest().body(message.clone()) + } BackendError::ReqwestError(_) => HttpResponse::BadGateway().finish(), BackendError::ApiServerError { .. } => HttpResponse::BadGateway().finish(), BackendError::ApiClientError { .. } => HttpResponse::BadGateway().finish(), @@ -93,6 +99,7 @@ impl error::ResponseError for BackendError { BackendError::RepositoryNotFound => StatusCode::NOT_FOUND, BackendError::SourceRepositoryMissingPrimaryMirror => StatusCode::NOT_FOUND, BackendError::DataConnectionNotFound => StatusCode::NOT_FOUND, + BackendError::InvalidRequest(_) => StatusCode::BAD_REQUEST, BackendError::ReqwestError(_) => StatusCode::BAD_GATEWAY, BackendError::ApiServerError { .. } => StatusCode::BAD_GATEWAY, BackendError::ApiClientError { .. } => StatusCode::BAD_GATEWAY, @@ -164,3 +171,8 @@ impl From for BackendError { BackendError::XmlParseError(format!("failed to parse xml: {}", error)) } } +impl From for BackendError { + fn from(error: serde_xml_rs::Error) -> BackendError { + BackendError::XmlParseError(format!("failed to parse xml: {}", error)) + } +} From e7026b8d08f6747ac755e11f86cd89d9d8d92af7 Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Tue, 13 May 2025 08:09:14 -0700 Subject: [PATCH 3/4] Continue to flatten --- src/apis/source.rs | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/apis/source.rs b/src/apis/source.rs index 217bb0a..6ced977 100644 --- a/src/apis/source.rs +++ b/src/apis/source.rs @@ -405,20 +405,18 @@ impl SourceAPI { // If not in cache, fetch it let secret = self.fetch_api_key(access_key_id).await?; + // Cache the successful result - match secret { - Some(secret) => { - self.api_key_cache.insert(cache_key, secret.clone()).await; - Ok(secret) - } - None => { - let secret = APIKey { - access_key_id: "".to_string(), - secret_access_key: "".to_string(), - }; - self.api_key_cache.insert(cache_key, secret.clone()).await; - Ok(secret) - } + if let Some(secret) = secret { + self.api_key_cache.insert(cache_key, secret.clone()).await; + Ok(secret) + } else { + let secret = APIKey { + access_key_id: "".to_string(), + secret_access_key: "".to_string(), + }; + self.api_key_cache.insert(cache_key, secret.clone()).await; + Ok(secret) } } From ce7a621a27a74bd8043dfb0992d443409c2607f4 Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Tue, 13 May 2025 08:19:09 -0700 Subject: [PATCH 4/4] Make use of api handler util --- Cargo.lock | 1 - Cargo.toml | 1 - src/apis/source.rs | 14 ++------------ src/utils/errors.rs | 5 +++++ 4 files changed, 7 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dd1fa94..cbf0a27 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2353,7 +2353,6 @@ dependencies = [ "rusoto_s3", "serde", "serde-xml-rs", - "serde_json", "sha2 0.10.8", "thiserror 2.0.12", "time", diff --git a/Cargo.toml b/Cargo.toml index 060bfec..6206942 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,6 @@ xml-rs = "0.8" serde = { version = "1.0", features = ["derive"] } serde-xml-rs = "0.6" bytes = "1.0" -serde_json = "1.0" pin-project-lite = "0.2" futures = "0.3" futures-core = "0.3" diff --git a/src/apis/source.rs b/src/apis/source.rs index 6ced977..6225e48 100644 --- a/src/apis/source.rs +++ b/src/apis/source.rs @@ -9,7 +9,6 @@ use async_trait::async_trait; use moka::future::Cache; use rusoto_core::Region; use serde::{Deserialize, Serialize}; -use serde_json::Value; use std::collections::HashMap; use std::env; use std::sync::Arc; @@ -442,20 +441,11 @@ impl SourceAPI { .headers(headers) .send() .await?; - let status = response.status(); - let text = response.text().await?; - if !status.is_success() { - return Err(BackendError::UnexpectedApiError(format!( - "Failed to get API key for access key id: {}, {}", - access_key_id, text - ))); - } - let json: Value = serde_json::from_str(&text).unwrap(); - let secret_access_key = json["secret_access_key"].as_str().unwrap(); + let key: APIKey = process_json_response::(response, BackendError::ApiKeyNotFound).await?; Ok(Some(APIKey { access_key_id, - secret_access_key: secret_access_key.to_string(), + secret_access_key: key.secret_access_key, })) } diff --git a/src/utils/errors.rs b/src/utils/errors.rs index 041abe0..59cd5bf 100644 --- a/src/utils/errors.rs +++ b/src/utils/errors.rs @@ -23,6 +23,9 @@ pub enum BackendError { #[error("source repository missing primary mirror")] SourceRepositoryMissingPrimaryMirror, + #[error("api key not found")] + ApiKeyNotFound, + #[error("data connection not found")] DataConnectionNotFound, @@ -74,6 +77,7 @@ impl error::ResponseError for BackendError { match self { BackendError::RepositoryNotFound => HttpResponse::NotFound().finish(), BackendError::SourceRepositoryMissingPrimaryMirror => HttpResponse::NotFound().finish(), + BackendError::ApiKeyNotFound => HttpResponse::NotFound().finish(), BackendError::DataConnectionNotFound => HttpResponse::NotFound().finish(), BackendError::InvalidRequest(message) => { HttpResponse::BadRequest().body(message.clone()) @@ -98,6 +102,7 @@ impl error::ResponseError for BackendError { match self { BackendError::RepositoryNotFound => StatusCode::NOT_FOUND, BackendError::SourceRepositoryMissingPrimaryMirror => StatusCode::NOT_FOUND, + BackendError::ApiKeyNotFound => StatusCode::NOT_FOUND, BackendError::DataConnectionNotFound => StatusCode::NOT_FOUND, BackendError::InvalidRequest(_) => StatusCode::BAD_REQUEST, BackendError::ReqwestError(_) => StatusCode::BAD_GATEWAY,