Skip to content

Commit

Permalink
refactor: update queries
Browse files Browse the repository at this point in the history
  • Loading branch information
chesedo committed Jul 17, 2023
1 parent 20f03a8 commit ecfdf2f
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 16 deletions.
12 changes: 11 additions & 1 deletion gateway/src/lib.rs
Expand Up @@ -349,6 +349,7 @@ pub mod tests {
use shuttle_common::backends::auth::ConvertResponse;
use shuttle_common::claims::{Claim, Scope};
use shuttle_common::models::project;
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::SqlitePool;
use tokio::sync::mpsc::channel;

Expand Down Expand Up @@ -621,7 +622,16 @@ pub mod tests {

let hyper = HyperClient::builder().build(HttpConnector::new());

let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
let pool = SqlitePool::connect_with(
SqliteConnectOptions::from_str("sqlite::memory:")
.unwrap()
// Set the ulid0 extension for generating ULID's in migrations.
// This uses the ulid0.so file in the crate root, with the
// LD_LIBRARY_PATH env set in build.rs.
.extension("ulid0"),
)
.await
.unwrap();
MIGRATIONS.run(&pool).await.unwrap();

let acme_client = AcmeClient::new();
Expand Down
6 changes: 5 additions & 1 deletion gateway/src/main.rs
Expand Up @@ -47,7 +47,11 @@ async fn main() -> io::Result<()> {
let sqlite_options = SqliteConnectOptions::from_str(db_uri)
.unwrap()
.journal_mode(SqliteJournalMode::Wal)
.synchronous(SqliteSynchronous::Normal);
.synchronous(SqliteSynchronous::Normal)
// Set the ulid0 extension for generating ULID's in migrations.
// This uses the ulid0.so file in the crate root, with the
// LD_LIBRARY_PATH env set in build.rs.
.extension("ulid0");

let db = SqlitePool::connect_with(sqlite_options).await.unwrap();
MIGRATIONS.run(&db).await.unwrap();
Expand Down
35 changes: 21 additions & 14 deletions gateway/src/service.rs
Expand Up @@ -289,7 +289,7 @@ impl GatewayService {

query
.push_bind(account_name)
.push(" ORDER BY created_at DESC NULLS LAST, project_name LIMIT ")
.push(" ORDER BY project_id DESC, project_name LIMIT ")
.push_bind(limit);

if offset > 0 {
Expand Down Expand Up @@ -398,7 +398,7 @@ impl GatewayService {
) -> Result<Project, Error> {
if let Some(row) = query(
r#"
SELECT project_name, account_name, initial_key, project_state
SELECT project_name, project_id, account_name, initial_key, project_state
FROM projects
WHERE (project_name = ?1)
AND (account_name = ?2 OR ?3)
Expand All @@ -412,14 +412,15 @@ impl GatewayService {
{
// If the project already exists and belongs to this account
let project = row.get::<SqlxJson<Project>, _>("project_state").0;
let project_id = row.get::<String, _>("project_id");
if project.is_destroyed() {
// But is in `::Destroyed` state, recreate it
let mut creating = ProjectCreating::new_with_random_initial_key(
project_name.clone(),
idle_minutes,
);
// Restore previous custom domain, if any
match self.find_custom_domain_for_project(&project_name).await {
match self.find_custom_domain_for_project(&project_id).await {
Ok(custom_domain) => {
creating = creating.with_fqdn(custom_domain.fqdn.to_string());
}
Expand Down Expand Up @@ -462,7 +463,7 @@ impl GatewayService {
ProjectCreating::new_with_random_initial_key(project_name.clone(), idle_minutes),
));

query("INSERT INTO projects (project_name, account_name, initial_key, project_state, created_at) VALUES (?1, ?2, ?3, ?4, CURRENT_TIMESTAMP)")
query("INSERT INTO projects (project_id, project_name, account_name, initial_key, project_state) VALUES (ulid(), ?1, ?2, ?3, ?4)")
.bind(&project_name)
.bind(&account_name)
.bind(project.initial_key().unwrap())
Expand All @@ -473,7 +474,7 @@ impl GatewayService {
// If the error is a broken PK constraint, this is a
// project name clash
if let Some(db_err_code) = err.as_database_error().and_then(DatabaseError::code) {
if db_err_code == "1555" { // SQLITE_CONSTRAINT_PRIMARYKEY
if db_err_code == "2067" { // SQLITE_CONSTRAINT_UNIQUE
return Error::from_kind(ErrorKind::ProjectAlreadyExists)
}
}
Expand All @@ -493,9 +494,15 @@ impl GatewayService {
certs: &str,
private_key: &str,
) -> Result<(), Error> {
query("INSERT OR REPLACE INTO custom_domains (fqdn, project_name, certificate, private_key) VALUES (?1, ?2, ?3, ?4)")
.bind(fqdn.to_string())
let project_id = query("SELECT project_id FROM projects WHERE project_name = ?1")
.bind(project_name)
.fetch_one(&self.db)
.await?
.get::<String, _>("project_id");

query("INSERT OR REPLACE INTO custom_domains (fqdn, project_id, certificate, private_key) VALUES (?1, ?2, ?3, ?4)")
.bind(fqdn.to_string())
.bind(project_id)
.bind(certs)
.bind(private_key)
.execute(&self.db)
Expand All @@ -505,7 +512,7 @@ impl GatewayService {
}

pub async fn iter_custom_domains(&self) -> Result<impl Iterator<Item = CustomDomain>, Error> {
query("SELECT fqdn, project_name, certificate, private_key FROM custom_domains")
query("SELECT fqdn, project_name, certificate, private_key FROM custom_domains AS cd JOIN projects AS p ON cd.project_id = p.project_id")
.fetch_all(&self.db)
.await
.map(|res| {
Expand All @@ -519,14 +526,14 @@ impl GatewayService {
.map_err(|_| Error::from_kind(ErrorKind::Internal))
}

pub async fn find_custom_domain_for_project(
async fn find_custom_domain_for_project(
&self,
project_name: &ProjectName,
project_id: &str,
) -> Result<CustomDomain, Error> {
let custom_domain = query(
"SELECT fqdn, project_name, certificate, private_key FROM custom_domains WHERE project_name = ?1",
"SELECT fqdn, project_name, certificate, private_key FROM custom_domains AS cd JOIN projects AS p ON cd.project_id = p.project_id WHERE p.project_id = ?1",
)
.bind(project_name.to_string())
.bind(project_id)
.fetch_optional(&self.db)
.await?
.map(|row| CustomDomain {
Expand All @@ -544,7 +551,7 @@ impl GatewayService {
fqdn: &Fqdn,
) -> Result<CustomDomain, Error> {
let custom_domain = query(
"SELECT fqdn, project_name, certificate, private_key FROM custom_domains WHERE fqdn = ?1",
"SELECT fqdn, project_name, certificate, private_key FROM custom_domains AS cd JOIN projects AS p ON cd.project_id = p.project_id WHERE fqdn = ?1",
)
.bind(fqdn.to_string())
.fetch_optional(&self.db)
Expand Down Expand Up @@ -828,7 +835,7 @@ pub mod tests {
.unwrap();
}

// We need to fetch all of them from the DB since they are ordered by created_at and project_name,
// We need to fetch all of them from the DB since they are ordered by created_at (in the id) and project_name,
// and created_at will be the same for some of them.
let all_projects = svc
.iter_user_projects_detailed(&neo, 0, u32::MAX)
Expand Down

0 comments on commit ecfdf2f

Please sign in to comment.