From 4a4b4feb94fca4ddc1a8c06963653b70e91138d9 Mon Sep 17 00:00:00 2001 From: MasterPtato <23087326+MasterPtato@users.noreply.github.com> Date: Tue, 27 Aug 2024 18:57:51 +0000 Subject: [PATCH] fix(api): move cors verification to endpoint level (#1094) Fixes RVTEE-551 ## Changes --- lib/api-helper/macros/src/lib.rs | 91 ++++++++++++++++++-------------- 1 file changed, 52 insertions(+), 39 deletions(-) diff --git a/lib/api-helper/macros/src/lib.rs b/lib/api-helper/macros/src/lib.rs index c456fd794b..e57c6469f5 100644 --- a/lib/api-helper/macros/src/lib.rs +++ b/lib/api-helper/macros/src/lib.rs @@ -134,31 +134,16 @@ impl EndpointRouter { let endpoints = self .routes .into_iter() - .map(|endpoint| endpoint.render()) + .map(|endpoint| endpoint.render(self.cors_config.is_some())) .collect::>>()?; - let cors = self + let cors_config = self .cors_config .map(|cors_config| { quote! { lazy_static::lazy_static! { static ref CORS_CONFIG: api_helper::util::CorsConfig = #cors_config; } - - match api_helper::util::verify_cors(request, &*CORS_CONFIG)? { - // Set headers and immediately return empty response - api_helper::util::CorsResponse::Preflight(headers) => { - response.headers_mut().map(|h| h.extend(headers)); - - return Ok(Some(Vec::new())); - } - // Set headers, continue with request - api_helper::util::CorsResponse::Regular(headers) => { - response.headers_mut().map(|h| h.extend(headers)); - } - // No CORS - api_helper::util::CorsResponse::NoCors => {} - } } }) .unwrap_or_else(|| quote! {}); @@ -174,21 +159,25 @@ impl EndpointRouter { }; quote! { - .try_or_else(|| rivet_operation::prelude::futures_util::FutureExt::boxed(tracing::Instrument::in_current_span(async { - router_config.prefix = #mount_prefix; - - #mount_path::__inner( - shared_client.clone(), - pools.clone(), - cache.clone(), - ray_id, - request, - response, - router_config, + .try_or_else(|| { + rivet_operation::prelude::futures_util::FutureExt::boxed( + tracing::Instrument::in_current_span(async { + router_config.prefix = #mount_prefix; + + #mount_path::__inner( + shared_client.clone(), + pools.clone(), + cache.clone(), + ray_id, + request, + response, + router_config, + ) + .await + .map(std::convert::Into::into) + }) ) - .await - .map(std::convert::Into::into) - }))).await? + }).await? } }) .collect::>(); @@ -214,9 +203,8 @@ impl EndpointRouter { return Ok(None); } - // Cors is handled after path segments are created so that we are sure that we are in - // the correct mount (if nested routers) - #cors + // Define cors config + #cors_config let body = __AsyncOption::None #(#endpoints)* @@ -381,7 +369,6 @@ impl Parse for RequestPath { enum RequestPathSegment { LitStr(syn::LitStr), Type(syn::Type), - Empty, } impl Parse for RequestPathSegment { @@ -391,7 +378,10 @@ impl Parse for RequestPathSegment { if let Ok(lit) = fork.parse::() { input.advance_to(&fork); if lit.value().is_empty() { - Ok(RequestPathSegment::Empty) + return Err(syn::Error::new( + lit.span(), + format!("Empty segment not allowed"), + )); } else { Ok(RequestPathSegment::LitStr(lit)) } @@ -433,7 +423,6 @@ impl RequestPathSegment { }; } } - RequestPathSegment::Empty => quote! {}, // Handle empty path segment } } } @@ -486,7 +475,29 @@ impl Parse for Endpoint { } impl Endpoint { - fn render(self) -> syn::Result { + fn render(self, verify_cors: bool) -> syn::Result { + // Check cors at the endpoint level + let cors = if verify_cors { + quote! { + match api_helper::util::verify_cors(request, &*CORS_CONFIG)? { + // Set headers and immediately return empty response + api_helper::util::CorsResponse::Preflight(headers) => { + response.headers_mut().map(|h| h.extend(headers)); + + return Ok(__AsyncOption::Some(Vec::new())); + } + // Set headers, continue with request + api_helper::util::CorsResponse::Regular(headers) => { + response.headers_mut().map(|h| h.extend(headers)); + } + // No CORS + api_helper::util::CorsResponse::NoCors => {} + } + } + } else { + quote! {} + }; + // Generate a path to use for the metrics // // This path can't contain the actual variables from the real @@ -497,7 +508,6 @@ impl Endpoint { .map(|segment| match segment { RequestPathSegment::LitStr(lit) => lit.value(), RequestPathSegment::Type(_) => "{}".to_string(), - RequestPathSegment::Empty => "".to_string(), // Handle empty path }) .collect::>() .join("/"); @@ -540,7 +550,10 @@ impl Endpoint { .rev(); #(#segment_parsing)* + // Path matches if path_segments.next().is_none() { + #cors + match request.method() { #(#arms)* _ => {