Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

located errors: database #124

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/apis/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ pub async fn add_torrent_to_whitelist_handler(
match InfoHash::from_str(&info_hash.0) {
Err(_) => invalid_info_hash_param_response(&info_hash.0),
Ok(info_hash) => match tracker.add_torrent_to_whitelist(&info_hash).await {
Ok(..) => ok_response(),
Err(..) => failed_to_whitelist_torrent_response(),
Ok(_) => ok_response(),
Err(e) => failed_to_whitelist_torrent_response(e),
},
}
}
Expand All @@ -79,24 +79,24 @@ pub async fn remove_torrent_from_whitelist_handler(
match InfoHash::from_str(&info_hash.0) {
Err(_) => invalid_info_hash_param_response(&info_hash.0),
Ok(info_hash) => match tracker.remove_torrent_from_whitelist(&info_hash).await {
Ok(..) => ok_response(),
Err(..) => failed_to_remove_torrent_from_whitelist_response(),
Ok(_) => ok_response(),
Err(e) => failed_to_remove_torrent_from_whitelist_response(e),
},
}
}

pub async fn reload_whitelist_handler(State(tracker): State<Arc<Tracker>>) -> Response {
match tracker.load_whitelist().await {
Ok(..) => ok_response(),
Err(..) => failed_to_reload_whitelist_response(),
Ok(_) => ok_response(),
Err(e) => failed_to_reload_whitelist_response(e),
}
}

pub async fn generate_auth_key_handler(State(tracker): State<Arc<Tracker>>, Path(seconds_valid_or_key): Path<u64>) -> Response {
let seconds_valid = seconds_valid_or_key;
match tracker.generate_auth_key(Duration::from_secs(seconds_valid)).await {
Ok(auth_key) => auth_key_response(&AuthKey::from(auth_key)),
Err(_) => failed_to_generate_key_response(),
Err(e) => failed_to_generate_key_response(e),
}
}

Expand All @@ -111,15 +111,15 @@ pub async fn delete_auth_key_handler(
Err(_) => invalid_auth_key_param_response(&seconds_valid_or_key.0),
Ok(key_id) => match tracker.remove_auth_key(&key_id.to_string()).await {
Ok(_) => ok_response(),
Err(_) => failed_to_delete_key_response(),
Err(e) => failed_to_delete_key_response(e),
},
}
}

pub async fn reload_keys_handler(State(tracker): State<Arc<Tracker>>) -> Response {
match tracker.load_keys().await {
Ok(..) => ok_response(),
Err(..) => failed_to_reload_keys_response(),
Ok(_) => ok_response(),
Err(e) => failed_to_reload_keys_response(e),
}
}

Expand Down
26 changes: 14 additions & 12 deletions src/apis/responses.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::error::Error;

use axum::http::{header, StatusCode};
use axum::response::{IntoResponse, Json, Response};
use serde::Serialize;
Expand Down Expand Up @@ -110,33 +112,33 @@ pub fn torrent_not_known_response() -> Response {
}

#[must_use]
pub fn failed_to_remove_torrent_from_whitelist_response() -> Response {
unhandled_rejection_response("failed to remove torrent from whitelist".to_string())
pub fn failed_to_remove_torrent_from_whitelist_response<E: Error>(e: E) -> Response {
unhandled_rejection_response(format!("failed to remove torrent from whitelist: {e}").to_string())
}

#[must_use]
pub fn failed_to_whitelist_torrent_response() -> Response {
unhandled_rejection_response("failed to whitelist torrent".to_string())
pub fn failed_to_whitelist_torrent_response<E: Error>(e: E) -> Response {
unhandled_rejection_response(format!("failed to whitelist torrent: {e}").to_string())
}

#[must_use]
pub fn failed_to_reload_whitelist_response() -> Response {
unhandled_rejection_response("failed to reload whitelist".to_string())
pub fn failed_to_reload_whitelist_response<E: Error>(e: E) -> Response {
unhandled_rejection_response(format!("failed to reload whitelist: {e}").to_string())
}

#[must_use]
pub fn failed_to_generate_key_response() -> Response {
unhandled_rejection_response("failed to generate key".to_string())
pub fn failed_to_generate_key_response<E: Error>(e: E) -> Response {
unhandled_rejection_response(format!("failed to generate key: {e}").to_string())
}

#[must_use]
pub fn failed_to_delete_key_response() -> Response {
unhandled_rejection_response("failed to delete key".to_string())
pub fn failed_to_delete_key_response<E: Error>(e: E) -> Response {
unhandled_rejection_response(format!("failed to delete key: {e}").to_string())
}

#[must_use]
pub fn failed_to_reload_keys_response() -> Response {
unhandled_rejection_response("failed to reload keys".to_string())
pub fn failed_to_reload_keys_response<E: Error>(e: E) -> Response {
unhandled_rejection_response(format!("failed to reload keys: {e}").to_string())
}

/// This error response is to keep backward compatibility with the old Warp API.
Expand Down
25 changes: 24 additions & 1 deletion src/databases/driver.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,30 @@
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
use super::error::Error;
use super::mysql::Mysql;
use super::sqlite::Sqlite;
use super::{Builder, Database};

#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, derive_more::Display, Clone)]
pub enum Driver {
Sqlite3,
MySQL,
}

impl Driver {
/// .
///
/// # Errors
///
/// This function will return an error if unable to connect to the database.
pub fn build(&self, db_path: &str) -> Result<Box<dyn Database>, Error> {
let database = match self {
Driver::Sqlite3 => Builder::<Sqlite>::build(db_path),
Driver::MySQL => Builder::<Mysql>::build(db_path),
}?;

database.create_database_tables().expect("Could not create database tables.");

Ok(database)
}
}
100 changes: 87 additions & 13 deletions src/databases/error.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,95 @@
use derive_more::{Display, Error};
use std::panic::Location;
use std::sync::Arc;

#[derive(Debug, Display, PartialEq, Eq, Error)]
#[allow(dead_code)]
use r2d2_mysql::mysql::UrlError;

use super::driver::Driver;
use crate::located_error::{Located, LocatedError};

#[derive(thiserror::Error, Debug, Clone)]
pub enum Error {
#[display(fmt = "Query returned no rows.")]
QueryReturnedNoRows,
#[display(fmt = "Invalid query.")]
InvalidQuery,
#[display(fmt = "Database error.")]
DatabaseError,
#[error("The {driver} query unexpectedly returned nothing: {source}")]
QueryReturnedNoRows {
source: LocatedError<'static, dyn std::error::Error>,
driver: Driver,
},

#[error("The {driver} query was malformed: {source}")]
InvalidQuery {
source: LocatedError<'static, dyn std::error::Error>,
driver: Driver,
},

#[error("Unable to insert record into {driver} database, {location}")]
InsertFailed {
location: &'static Location<'static>,
driver: Driver,
},

#[error("Failed to remove record from {driver} database, error-code: {error_code}, {location}")]
DeleteFailed {
location: &'static Location<'static>,
error_code: usize,
driver: Driver,
},

#[error("Failed to connect to {driver} database: {source}")]
ConnectionError {
source: LocatedError<'static, UrlError>,
driver: Driver,
},

#[error("Failed to create r2d2 {driver} connection pool: {source}")]
ConnectionPool {
source: LocatedError<'static, r2d2::Error>,
driver: Driver,
},
}

impl From<r2d2_sqlite::rusqlite::Error> for Error {
fn from(e: r2d2_sqlite::rusqlite::Error) -> Self {
match e {
r2d2_sqlite::rusqlite::Error::QueryReturnedNoRows => Error::QueryReturnedNoRows,
_ => Error::InvalidQuery,
#[track_caller]
fn from(err: r2d2_sqlite::rusqlite::Error) -> Self {
match err {
r2d2_sqlite::rusqlite::Error::QueryReturnedNoRows => Error::QueryReturnedNoRows {
source: (Arc::new(err) as Arc<dyn std::error::Error>).into(),
driver: Driver::Sqlite3,
},
_ => Error::InvalidQuery {
source: (Arc::new(err) as Arc<dyn std::error::Error>).into(),
driver: Driver::Sqlite3,
},
}
}
}

impl From<r2d2_mysql::mysql::Error> for Error {
#[track_caller]
fn from(err: r2d2_mysql::mysql::Error) -> Self {
let e: Arc<dyn std::error::Error> = Arc::new(err);
Error::InvalidQuery {
source: e.into(),
driver: Driver::MySQL,
}
}
}

impl From<UrlError> for Error {
#[track_caller]
fn from(err: UrlError) -> Self {
Self::ConnectionError {
source: Located(err).into(),
driver: Driver::MySQL,
}
}
}

impl From<(r2d2::Error, Driver)> for Error {
#[track_caller]
fn from(e: (r2d2::Error, Driver)) -> Self {
let (err, driver) = e;
Self::ConnectionPool {
source: Located(err).into(),
driver,
}
}
}
70 changes: 38 additions & 32 deletions src/databases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,48 @@ pub mod error;
pub mod mysql;
pub mod sqlite;

use std::marker::PhantomData;

use async_trait::async_trait;

use self::driver::Driver;
use self::error::Error;
use crate::databases::mysql::Mysql;
use crate::databases::sqlite::Sqlite;
use crate::protocol::info_hash::InfoHash;
use crate::tracker::auth;

/// # Errors
///
/// Will return `r2d2::Error` if `db_path` is not able to create a database.
pub fn connect(db_driver: &Driver, db_path: &str) -> Result<Box<dyn Database>, r2d2::Error> {
let database: Box<dyn Database> = match db_driver {
Driver::Sqlite3 => {
let db = Sqlite::new(db_path)?;
Box::new(db)
}
Driver::MySQL => {
let db = Mysql::new(db_path)?;
Box::new(db)
}
};

database.create_database_tables().expect("Could not create database tables.");

Ok(database)
pub(self) struct Builder<T>
where
T: Database,
{
phantom: PhantomData<T>,
}

impl<T> Builder<T>
where
T: Database + 'static,
{
/// .
///
/// # Errors
///
/// Will return `r2d2::Error` if `db_path` is not able to create a database.
pub(self) fn build(db_path: &str) -> Result<Box<dyn Database>, Error> {
Ok(Box::new(T::new(db_path)?))
}
}

#[async_trait]
pub trait Database: Sync + Send {
/// .
///
/// # Errors
///
/// Will return `r2d2::Error` if `db_path` is not able to create a database.
fn new(db_path: &str) -> Result<Self, Error>
where
Self: std::marker::Sized;

/// .
///
/// # Errors
///
/// Will return `Error` if unable to create own tables.
Expand All @@ -52,27 +63,22 @@ pub trait Database: Sync + Send {

async fn save_persistent_torrent(&self, info_hash: &InfoHash, completed: u32) -> Result<(), Error>;

async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<InfoHash, Error>;
async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<Option<InfoHash>, Error>;

async fn add_info_hash_to_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error>;

async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error>;

async fn get_key_from_keys(&self, key: &str) -> Result<auth::Key, Error>;
async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::Key>, Error>;

async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result<usize, Error>;

async fn remove_key_from_keys(&self, key: &str) -> Result<usize, Error>;

async fn is_info_hash_whitelisted(&self, info_hash: &InfoHash) -> Result<bool, Error> {
self.get_info_hash_from_whitelist(&info_hash.clone().to_string())
.await
.map_or_else(
|e| match e {
Error::QueryReturnedNoRows => Ok(false),
e => Err(e),
},
|_| Ok(true),
)
Ok(self
.get_info_hash_from_whitelist(&info_hash.clone().to_string())
.await?
.is_some())
}
}
Loading