Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: write a wrapper for the provisioner to call gw and r-r clients #1585

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
55ddeb0
feat: resource DAL
chesedo Jan 23, 2024
daf7803
tests: basic get_project_resource DAL test
chesedo Jan 23, 2024
484b221
tests: getting only RDS resources
chesedo Jan 23, 2024
acaa7ba
refactor: set token correctly
chesedo Jan 24, 2024
a77ae91
refactor: clippy suggestions
chesedo Jan 24, 2024
41e9730
refactor: add some comments
chesedo Jan 24, 2024
5eeecc0
refactor: add rr-client to provisioner
chesedo Jan 24, 2024
d7d892a
refactor: add gateway-client to provisioner
chesedo Jan 24, 2024
944cce0
feat: have provisioner check RDS quota before provisioning
chesedo Jan 25, 2024
d3332f4
Merge remote-tracking branch 'origin/main' into feature/engn-55-write…
chesedo Jan 25, 2024
e0ea928
refactor: make test-utils available to other crates
chesedo Jan 25, 2024
6b34ee4
refactor: move `ClaimExt` to backends
chesedo Jan 25, 2024
b0cd477
refactor: update compose file
chesedo Jan 25, 2024
fc364b7
refactor: name all mock getters the same
chesedo Jan 26, 2024
b626278
refactor: correct gateway port
chesedo Jan 26, 2024
a870039
refactor: test comments
chesedo Jan 26, 2024
a0f3dd0
Merge remote-tracking branch 'origin/main' into feature/engn-55-write…
chesedo Jan 26, 2024
dd4c000
tests: ci fixes
chesedo Jan 29, 2024
8989dfa
refactor: add tracing to remote DAL calls
chesedo Jan 29, 2024
910f5f8
refactor: move rr mock out of shuttle-common-test
chesedo Jan 29, 2024
0cd0403
tests: make sure client connects to the mocked server
chesedo Jan 29, 2024
d572aa7
Update common/src/backends/client/resource_recorder.rs
chesedo Jan 29, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion auth/src/main.rs
@@ -1,7 +1,7 @@
use std::io;

use clap::Parser;
use shuttle_common::{backends::tracing::setup_tracing, claims::AccountTier, log::Backend};
use shuttle_common::{backends::trace::setup_tracing, claims::AccountTier, log::Backend};
use sqlx::migrate::Migrator;
use tracing::trace;

Expand Down
2 changes: 1 addition & 1 deletion builder/src/main.rs
Expand Up @@ -5,7 +5,7 @@ use shuttle_builder::{args::Args, Service};
use shuttle_common::{
backends::{
auth::{AuthPublicKey, JwtAuthenticationLayer},
tracing::{setup_tracing, ExtractPropagationLayer},
trace::{setup_tracing, ExtractPropagationLayer},
},
log::Backend,
};
Expand Down
2 changes: 1 addition & 1 deletion common-tests/src/builder.rs
Expand Up @@ -12,7 +12,7 @@ use shuttle_proto::builder::{
use tonic::transport::{Endpoint, Server};
use tower::ServiceBuilder;

pub async fn mocked_builder_client(
pub async fn get_mocked_builder_client(
builder: impl Builder,
) -> BuilderClient<
shuttle_common::claims::ClaimService<
Expand Down
14 changes: 13 additions & 1 deletion common-tests/src/lib.rs
Expand Up @@ -2,7 +2,6 @@ pub mod builder;
pub mod cargo_shuttle;
pub mod logger;
pub mod postgres;
pub mod resource_recorder;

use shuttle_common::claims::{AccountTier, Claim, Scope};

Expand Down Expand Up @@ -65,3 +64,16 @@ where
self.inner.call(req)
}
}

pub trait ClaimTestsExt {
/// Fill the token of a test key correctly
fn fill_token(self) -> Self;
}

impl ClaimTestsExt for Claim {
fn fill_token(mut self) -> Self {
self.token = Some(self.sub.clone());

self
}
}
2 changes: 1 addition & 1 deletion common-tests/src/logger.rs
Expand Up @@ -47,7 +47,7 @@ impl Logger for MockedLogger {
}
}

pub async fn mocked_logger_client(
pub async fn get_mocked_logger_client(
logger: impl Logger,
) -> LoggerClient<
shuttle_common::claims::ClaimService<
Expand Down
2 changes: 2 additions & 0 deletions common/Cargo.toml
Expand Up @@ -52,6 +52,7 @@ ttl_cache = { workspace = true, optional = true }
url = { workspace = true, features = ["serde"] }
uuid = { workspace = true, features = ["v4", "serde"], optional = true }
zeroize = { workspace = true }
wiremock = { workspace = true, optional = true }

[features]
backend = [
Expand Down Expand Up @@ -92,6 +93,7 @@ display = ["chrono/clock", "comfy-table", "crossterm"]
models = ["async-trait", "http", "reqwest", "service", "thiserror"]
persist = ["sqlx", "rand"]
service = ["chrono/serde", "display", "tracing", "tracing-subscriber", "uuid"]
test-utils = ["wiremock"]
tracing = ["dep:tracing"]
wasm = [
"chrono/clock",
Expand Down
19 changes: 6 additions & 13 deletions common/src/backends/auth.rs
Expand Up @@ -15,10 +15,7 @@ use tower::{Layer, Service};
use tracing::{error, trace, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;

use crate::{
claims::{Claim, Scope},
limits::ClaimExt,
};
use crate::claims::{Claim, Scope};

use super::{
cache::{CacheManagement, CacheManager},
Expand Down Expand Up @@ -430,9 +427,11 @@ where
pub trait VerifyClaim {
type Error;

/// Verify that this request has the given scopes
fn verify(&self, required_scope: Scope) -> Result<(), Self::Error>;

fn verify_rds_access(&self) -> Result<(), Self::Error>;
/// Get the claim if it exists on the request
fn get_claim(&self) -> Result<Claim, Self::Error>;
}

#[cfg(feature = "tonic")]
Expand All @@ -459,19 +458,13 @@ impl<B> VerifyClaim for tonic::Request<B> {
}
}

fn verify_rds_access(&self) -> Result<(), Self::Error> {
fn get_claim(&self) -> Result<Claim, Self::Error> {
let claim = self
.extensions()
.get::<Claim>()
.ok_or_else(|| tonic::Status::internal("could not get claim"))?;

if claim.can_provision_rds() {
Ok(())
} else {
Err(tonic::Status::permission_denied(
"don't have permission to provision rds instances",
))
}
Ok(claim.clone())
}
}

Expand Down
29 changes: 16 additions & 13 deletions common/src/backends/client/gateway.rs
@@ -1,18 +1,19 @@
use headers::Authorization;
use http::{Method, Uri};
use tracing::instrument;

use crate::models;

use super::{Error, ServicesApiClient};

/// Wrapper struct to make API calls to gateway easier
#[derive(Clone)]
pub struct GatewayClient {
pub struct Client {
public_client: ServicesApiClient,
private_client: ServicesApiClient,
}

impl GatewayClient {
impl Client {
/// Make a gateway client that is able to call the public and private APIs on gateway
pub fn new(public_uri: Uri, private_uri: Uri) -> Self {
Self {
Expand All @@ -33,7 +34,8 @@ impl GatewayClient {
}

/// Interact with all the data relating to projects
trait ProjectsDal {
#[allow(async_fn_in_trait)]
pub trait ProjectsDal {
/// Get the projects that belong to a user
async fn get_user_projects(
&self,
Expand All @@ -53,7 +55,8 @@ trait ProjectsDal {
}
}

impl ProjectsDal for GatewayClient {
impl ProjectsDal for Client {
#[instrument(skip_all)]
async fn get_user_projects(
&self,
user_token: &str,
Expand All @@ -78,24 +81,24 @@ mod tests {
use test_context::{test_context, AsyncTestContext};

use crate::models::project::{Response, State};
use crate::test_utils::mocked_gateway_server;
use crate::test_utils::get_mocked_gateway_server;

use super::{GatewayClient, ProjectsDal};
use super::{Client, ProjectsDal};

#[async_trait]
impl AsyncTestContext for GatewayClient {
impl AsyncTestContext for Client {
async fn setup() -> Self {
let server = mocked_gateway_server().await;
let server = get_mocked_gateway_server().await;

GatewayClient::new(server.uri().parse().unwrap(), server.uri().parse().unwrap())
Client::new(server.uri().parse().unwrap(), server.uri().parse().unwrap())
}

async fn teardown(mut self) {}
}

#[test_context(GatewayClient)]
#[test_context(Client)]
#[tokio::test]
async fn get_user_projects(client: &mut GatewayClient) {
async fn get_user_projects(client: &mut Client) {
let res = client.get_user_projects("user-1").await.unwrap();

assert_eq!(
Expand All @@ -117,9 +120,9 @@ mod tests {
)
}

#[test_context(GatewayClient)]
#[test_context(Client)]
#[tokio::test]
async fn get_user_project_ids(client: &mut GatewayClient) {
async fn get_user_project_ids(client: &mut Client) {
let res = client.get_user_project_ids("user-2").await.unwrap();

assert_eq!(res, vec!["id3"])
Expand Down
12 changes: 8 additions & 4 deletions common/src/backends/client/mod.rs
Expand Up @@ -8,9 +8,11 @@ use thiserror::Error;
use tracing::{trace, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;

mod gateway;
pub mod gateway;
mod resource_recorder;

pub use gateway::GatewayClient;
pub use gateway::ProjectsDal;
pub use resource_recorder::ResourceDal;

#[derive(Error, Debug)]
pub enum Error {
Expand All @@ -22,6 +24,8 @@ pub enum Error {
Http(#[from] hyper::http::Error),
#[error("Request did not return correctly. Got status code: {0}")]
RequestError(StatusCode),
#[error("GRpc request did not return correctly. Got status code: {0}")]
GrpcError(#[from] tonic::Status),
}

/// `Hyper` wrapper to make request to RESTful Shuttle services easy
Expand Down Expand Up @@ -96,14 +100,14 @@ mod tests {
use http::{Method, StatusCode};

use crate::models;
use crate::test_utils::mocked_gateway_server;
use crate::test_utils::get_mocked_gateway_server;

use super::{Error, ServicesApiClient};

// Make sure we handle any unexpected responses correctly
#[tokio::test]
async fn api_errors() {
let server = mocked_gateway_server().await;
let server = get_mocked_gateway_server().await;

let client = ServicesApiClient::new(server.uri().parse().unwrap());

Expand Down
53 changes: 53 additions & 0 deletions common/src/backends/client/resource_recorder.rs
@@ -0,0 +1,53 @@
use async_trait::async_trait;
use tracing::instrument;

use crate::{database, resource};

use super::Error;

/// DAL for access resources data of projects
#[async_trait]
pub trait ResourceDal: Send {
/// Get the resources belonging to a project
async fn get_project_resources(
&mut self,
project_id: &str,
token: &str,
) -> Result<Vec<resource::Response>, Error>;

/// Get only the RDS resources that belong to a project
async fn get_project_rds_resources(
&mut self,
project_id: &str,
token: &str,
) -> Result<Vec<resource::Response>, Error> {
let rds_resources = self
.get_project_resources(project_id, token)
.await?
.into_iter()
.filter(|r| {
matches!(
r.r#type,
resource::Type::Database(database::Type::AwsRds(_))
)
})
.collect();

Ok(rds_resources)
}
}

#[async_trait]
impl<T> ResourceDal for &mut T
where
T: ResourceDal,
{
#[instrument(skip_all, fields(shuttle.project.id = project_id))]
async fn get_project_resources(
&mut self,
project_id: &str,
token: &str,
) -> Result<Vec<resource::Response>, Error> {
(**self).get_project_resources(project_id, token).await
}
}
54 changes: 53 additions & 1 deletion common/src/backends/mod.rs
@@ -1,8 +1,60 @@
use tracing::instrument;

use crate::claims::{Claim, Scope};

use self::client::{ProjectsDal, ResourceDal};

pub mod auth;
pub mod cache;
pub mod client;
mod future;
pub mod headers;
pub mod metrics;
mod otlp_tracing_bridge;
pub mod tracing;
pub mod trace;

#[allow(async_fn_in_trait)]
pub trait ClaimExt {
/// Verify that the [Claim] has the [Scope::Admin] scope.
fn is_admin(&self) -> bool;
/// Verify that the user's current project count is lower than the account limit in [Claim::limits].
fn can_create_project(&self, current_count: u32) -> bool;
/// Verify that the user has permission to provision RDS instances.
async fn can_provision_rds<G: ProjectsDal, R: ResourceDal>(
&self,
projects_dal: &G,
resource_dal: &mut R,
) -> Result<bool, client::Error>;
}

impl ClaimExt for Claim {
fn is_admin(&self) -> bool {
self.scopes.contains(&Scope::Admin)
}

fn can_create_project(&self, current_count: u32) -> bool {
self.is_admin() || self.limits.project_limit() > current_count
}

#[instrument(skip_all)]
async fn can_provision_rds<G: ProjectsDal, R: ResourceDal>(
&self,
projects_dal: &G,
resource_dal: &mut R,
) -> Result<bool, client::Error> {
let token = self.token.as_ref().expect("token to be set");

let projects = projects_dal.get_user_project_ids(token).await?;

let mut rds_count = 0;

for project_id in projects {
rds_count += resource_dal
.get_project_rds_resources(&project_id, token)
.await?
.len();
}

Ok(self.is_admin() || self.limits.rds_quota > (rds_count as u32))
}
}
File renamed without changes.