Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 52 additions & 39 deletions lib/api-helper/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,31 +134,16 @@ impl EndpointRouter {
let endpoints = self
.routes
.into_iter()
.map(|endpoint| endpoint.render())
.map(|endpoint| endpoint.render(self.cors_config.is_some()))
.collect::<syn::Result<Vec<_>>>()?;

let cors = self
let cors_config = self
.cors_config
.map(|cors_config| {
quote! {
lazy_static::lazy_static! {
static ref CORS_CONFIG: api_helper::util::CorsConfig = #cors_config;
}

match api_helper::util::verify_cors(request, &*CORS_CONFIG)? {
// Set headers and immediately return empty response
api_helper::util::CorsResponse::Preflight(headers) => {
response.headers_mut().map(|h| h.extend(headers));

return Ok(Some(Vec::new()));
}
// Set headers, continue with request
api_helper::util::CorsResponse::Regular(headers) => {
response.headers_mut().map(|h| h.extend(headers));
}
// No CORS
api_helper::util::CorsResponse::NoCors => {}
}
}
})
.unwrap_or_else(|| quote! {});
Expand All @@ -174,21 +159,25 @@ impl EndpointRouter {
};

quote! {
.try_or_else(|| rivet_operation::prelude::futures_util::FutureExt::boxed(tracing::Instrument::in_current_span(async {
router_config.prefix = #mount_prefix;

#mount_path::__inner(
shared_client.clone(),
pools.clone(),
cache.clone(),
ray_id,
request,
response,
router_config,
.try_or_else(|| {
rivet_operation::prelude::futures_util::FutureExt::boxed(
tracing::Instrument::in_current_span(async {
router_config.prefix = #mount_prefix;

#mount_path::__inner(
shared_client.clone(),
pools.clone(),
cache.clone(),
ray_id,
request,
response,
router_config,
)
.await
.map(std::convert::Into::into)
})
)
.await
.map(std::convert::Into::into)
}))).await?
}).await?
}
})
.collect::<Vec<_>>();
Expand All @@ -214,9 +203,8 @@ impl EndpointRouter {
return Ok(None);
}

// Cors is handled after path segments are created so that we are sure that we are in
// the correct mount (if nested routers)
#cors
// Define cors config
#cors_config

let body = __AsyncOption::None
#(#endpoints)*
Expand Down Expand Up @@ -381,7 +369,6 @@ impl Parse for RequestPath {
enum RequestPathSegment {
LitStr(syn::LitStr),
Type(syn::Type),
Empty,
}

impl Parse for RequestPathSegment {
Expand All @@ -391,7 +378,10 @@ impl Parse for RequestPathSegment {
if let Ok(lit) = fork.parse::<syn::LitStr>() {
input.advance_to(&fork);
if lit.value().is_empty() {
Ok(RequestPathSegment::Empty)
return Err(syn::Error::new(
lit.span(),
format!("Empty segment not allowed"),
));
} else {
Ok(RequestPathSegment::LitStr(lit))
}
Expand Down Expand Up @@ -433,7 +423,6 @@ impl RequestPathSegment {
};
}
}
RequestPathSegment::Empty => quote! {}, // Handle empty path segment
}
}
}
Expand Down Expand Up @@ -486,7 +475,29 @@ impl Parse for Endpoint {
}

impl Endpoint {
fn render(self) -> syn::Result<TokenStream2> {
fn render(self, verify_cors: bool) -> syn::Result<TokenStream2> {
// Check cors at the endpoint level
let cors = if verify_cors {
quote! {
match api_helper::util::verify_cors(request, &*CORS_CONFIG)? {
// Set headers and immediately return empty response
api_helper::util::CorsResponse::Preflight(headers) => {
response.headers_mut().map(|h| h.extend(headers));

return Ok(__AsyncOption::Some(Vec::new()));
}
// Set headers, continue with request
api_helper::util::CorsResponse::Regular(headers) => {
response.headers_mut().map(|h| h.extend(headers));
}
// No CORS
api_helper::util::CorsResponse::NoCors => {}
}
}
} else {
quote! {}
};

// Generate a path to use for the metrics
//
// This path can't contain the actual variables from the real
Expand All @@ -497,7 +508,6 @@ impl Endpoint {
.map(|segment| match segment {
RequestPathSegment::LitStr(lit) => lit.value(),
RequestPathSegment::Type(_) => "{}".to_string(),
RequestPathSegment::Empty => "".to_string(), // Handle empty path
})
.collect::<Vec<String>>()
.join("/");
Expand Down Expand Up @@ -540,7 +550,10 @@ impl Endpoint {
.rev();
#(#segment_parsing)*

// Path matches
if path_segments.next().is_none() {
#cors

match request.method() {
#(#arms)*
_ => {
Expand Down