diff --git a/src/apis/handlers.rs b/src/apis/handlers.rs
index 8d968902..38959edb 100644
--- a/src/apis/handlers.rs
+++ b/src/apis/handlers.rs
@@ -66,8 +66,8 @@ pub async fn add_torrent_to_whitelist_handler(
     match InfoHash::from_str(&info_hash.0) {
         Err(_) => invalid_info_hash_param_response(&info_hash.0),
         Ok(info_hash) => match tracker.add_torrent_to_whitelist(&info_hash).await {
-            Ok(..) => ok_response(),
-            Err(..) => failed_to_whitelist_torrent_response(),
+            Ok(_) => ok_response(),
+            Err(e) => failed_to_whitelist_torrent_response(e),
         },
     }
 }
@@ -79,16 +79,16 @@ pub async fn remove_torrent_from_whitelist_handler(
     match InfoHash::from_str(&info_hash.0) {
         Err(_) => invalid_info_hash_param_response(&info_hash.0),
         Ok(info_hash) => match tracker.remove_torrent_from_whitelist(&info_hash).await {
-            Ok(..) => ok_response(),
-            Err(..) => failed_to_remove_torrent_from_whitelist_response(),
+            Ok(_) => ok_response(),
+            Err(e) => failed_to_remove_torrent_from_whitelist_response(e),
         },
     }
 }
 
 pub async fn reload_whitelist_handler(State(tracker): State<Arc<Tracker>>) -> Response {
     match tracker.load_whitelist().await {
-        Ok(..) => ok_response(),
-        Err(..) => failed_to_reload_whitelist_response(),
+        Ok(_) => ok_response(),
+        Err(e) => failed_to_reload_whitelist_response(e),
     }
 }
 
@@ -96,7 +96,7 @@ pub async fn generate_auth_key_handler(State(tracker): State<Arc<Tracker>>, Path
     let seconds_valid = seconds_valid_or_key;
     match tracker.generate_auth_key(Duration::from_secs(seconds_valid)).await {
         Ok(auth_key) => auth_key_response(&AuthKey::from(auth_key)),
-        Err(_) => failed_to_generate_key_response(),
+        Err(e) => failed_to_generate_key_response(e),
     }
 }
 
@@ -111,15 +111,15 @@ pub async fn delete_auth_key_handler(
         Err(_) => invalid_auth_key_param_response(&seconds_valid_or_key.0),
         Ok(key_id) => match tracker.remove_auth_key(&key_id.to_string()).await {
             Ok(_) => ok_response(),
-            Err(_) => failed_to_delete_key_response(),
+            Err(e) => failed_to_delete_key_response(e),
         },
     }
 }
 
 pub async fn reload_keys_handler(State(tracker): State<Arc<Tracker>>) -> Response {
     match tracker.load_keys().await {
-        Ok(..) => ok_response(),
-        Err(..) => failed_to_reload_keys_response(),
+        Ok(_) => ok_response(),
+        Err(e) => failed_to_reload_keys_response(e),
     }
 }
diff --git a/src/apis/responses.rs b/src/apis/responses.rs
index b150b4bf..3704c7a1 100644
--- a/src/apis/responses.rs
+++ b/src/apis/responses.rs
@@ -1,3 +1,5 @@
+use std::error::Error;
+
 use axum::http::{header, StatusCode};
 use axum::response::{IntoResponse, Json, Response};
 use serde::Serialize;
@@ -110,33 +112,33 @@ pub fn torrent_not_known_response() -> Response {
 }
 
 #[must_use]
-pub fn failed_to_remove_torrent_from_whitelist_response() -> Response {
-    unhandled_rejection_response("failed to remove torrent from whitelist".to_string())
+pub fn failed_to_remove_torrent_from_whitelist_response<E: Error>(e: E) -> Response {
+    unhandled_rejection_response(format!("failed to remove torrent from whitelist: {e}").to_string())
 }
 
 #[must_use]
-pub fn failed_to_whitelist_torrent_response() -> Response {
-    unhandled_rejection_response("failed to whitelist torrent".to_string())
+pub fn failed_to_whitelist_torrent_response<E: Error>(e: E) -> Response {
+    unhandled_rejection_response(format!("failed to whitelist torrent: {e}").to_string())
 }
 
 #[must_use]
-pub fn failed_to_reload_whitelist_response() -> Response {
-    unhandled_rejection_response("failed to reload whitelist".to_string())
+pub fn failed_to_reload_whitelist_response<E: Error>(e: E) -> Response {
+    unhandled_rejection_response(format!("failed to reload whitelist: {e}").to_string())
 }
 
 #[must_use]
-pub fn failed_to_generate_key_response() -> Response {
-    unhandled_rejection_response("failed to generate key".to_string())
+pub fn failed_to_generate_key_response<E: Error>(e: E) -> Response {
+    unhandled_rejection_response(format!("failed to generate key: {e}").to_string())
 }
 
 #[must_use]
-pub fn failed_to_delete_key_response() -> Response {
-    unhandled_rejection_response("failed to delete key".to_string())
+pub fn failed_to_delete_key_response<E: Error>(e: E) -> Response {
+    unhandled_rejection_response(format!("failed to delete key: {e}").to_string())
 }
 
 #[must_use]
-pub fn failed_to_reload_keys_response() -> Response {
-    unhandled_rejection_response("failed to reload keys".to_string())
+pub fn failed_to_reload_keys_response<E: Error>(e: E) -> Response {
+    unhandled_rejection_response(format!("failed to reload keys: {e}").to_string())
 }
 
 /// This error response is to keep backward compatibility with the old Warp API.
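The response helpers above now accept the underlying error as a generic parameter and embed its Display output in the rejection reason, so API clients see why an operation failed instead of a fixed message. A minimal, self-contained sketch of that pattern (the helper name below is illustrative, not part of the patch):

use std::error::Error;

// Any error type can be interpolated into the rejection reason with `{e}`.
fn rejection_reason<E: Error>(context: &str, e: E) -> String {
    format!("{context}: {e}")
}

fn main() {
    let db_err = std::io::Error::new(std::io::ErrorKind::Other, "connection pool exhausted");
    assert_eq!(
        rejection_reason("failed to reload whitelist", db_err),
        "failed to reload whitelist: connection pool exhausted"
    );
}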
diff --git a/src/databases/driver.rs b/src/databases/driver.rs
index 7eaa9064..c601f186 100644
--- a/src/databases/driver.rs
+++ b/src/databases/driver.rs
@@ -1,7 +1,30 @@
 use serde::{Deserialize, Serialize};
 
-#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
+use super::error::Error;
+use super::mysql::Mysql;
+use super::sqlite::Sqlite;
+use super::{Builder, Database};
+
+#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, derive_more::Display, Clone)]
 pub enum Driver {
     Sqlite3,
     MySQL,
 }
+
+impl Driver {
+    /// .
+    ///
+    /// # Errors
+    ///
+    /// This function will return an error if unable to connect to the database.
+    pub fn build(&self, db_path: &str) -> Result<Box<dyn Database>, Error> {
+        let database = match self {
+            Driver::Sqlite3 => Builder::<Sqlite>::build(db_path),
+            Driver::MySQL => Builder::<Mysql>::build(db_path),
+        }?;
+
+        database.create_database_tables().expect("Could not create database tables.");
+
+        Ok(database)
+    }
+}
diff --git a/src/databases/error.rs b/src/databases/error.rs
index 467db407..b0081c40 100644
--- a/src/databases/error.rs
+++ b/src/databases/error.rs
@@ -1,21 +1,95 @@
-use derive_more::{Display, Error};
+use std::panic::Location;
+use std::sync::Arc;
 
-#[derive(Debug, Display, PartialEq, Eq, Error)]
-#[allow(dead_code)]
+use r2d2_mysql::mysql::UrlError;
+
+use super::driver::Driver;
+use crate::located_error::{Located, LocatedError};
+
+#[derive(thiserror::Error, Debug, Clone)]
 pub enum Error {
-    #[display(fmt = "Query returned no rows.")]
-    QueryReturnedNoRows,
-    #[display(fmt = "Invalid query.")]
-    InvalidQuery,
-    #[display(fmt = "Database error.")]
-    DatabaseError,
+    #[error("The {driver} query unexpectedly returned nothing: {source}")]
+    QueryReturnedNoRows {
+        source: LocatedError<'static, dyn std::error::Error>,
+        driver: Driver,
+    },
+
+    #[error("The {driver} query was malformed: {source}")]
+    InvalidQuery {
+        source: LocatedError<'static, dyn std::error::Error>,
+        driver: Driver,
+    },
+
+    #[error("Unable to insert record into {driver} database, {location}")]
+    InsertFailed {
+        location: &'static Location<'static>,
+        driver: Driver,
+    },
+
+    #[error("Failed to remove record from {driver} database, error-code: {error_code}, {location}")]
+    DeleteFailed {
+        location: &'static Location<'static>,
+        error_code: usize,
+        driver: Driver,
+    },
+
+    #[error("Failed to connect to {driver} database: {source}")]
+    ConnectionError {
+        source: LocatedError<'static, UrlError>,
+        driver: Driver,
+    },
+
+    #[error("Failed to create r2d2 {driver} connection pool: {source}")]
+    ConnectionPool {
+        source: LocatedError<'static, r2d2::Error>,
+        driver: Driver,
+    },
 }
 
 impl From<r2d2_sqlite::rusqlite::Error> for Error {
-    fn from(e: r2d2_sqlite::rusqlite::Error) -> Self {
-        match e {
-            r2d2_sqlite::rusqlite::Error::QueryReturnedNoRows => Error::QueryReturnedNoRows,
-            _ => Error::InvalidQuery,
+    #[track_caller]
+    fn from(err: r2d2_sqlite::rusqlite::Error) -> Self {
+        match err {
+            r2d2_sqlite::rusqlite::Error::QueryReturnedNoRows => Error::QueryReturnedNoRows {
+                source: (Arc::new(err) as Arc<dyn std::error::Error>).into(),
+                driver: Driver::Sqlite3,
+            },
+            _ => Error::InvalidQuery {
+                source: (Arc::new(err) as Arc<dyn std::error::Error>).into(),
+                driver: Driver::Sqlite3,
+            },
+        }
+    }
+}
+
+impl From<r2d2_mysql::mysql::Error> for Error {
+    #[track_caller]
+    fn from(err: r2d2_mysql::mysql::Error) -> Self {
+        let e: Arc<dyn std::error::Error> = Arc::new(err);
+        Error::InvalidQuery {
+            source: e.into(),
+            driver: Driver::MySQL,
+        }
+    }
+}
+
+impl From<UrlError> for Error {
+    #[track_caller]
+    fn from(err: UrlError) -> Self {
+        Self::ConnectionError {
+            source: Located(err).into(),
+            driver: Driver::MySQL,
+        }
+    }
+}
+
+impl From<(r2d2::Error, Driver)> for Error {
+    #[track_caller]
+    fn from(e: (r2d2::Error, Driver)) -> Self {
+        let (err, driver) = e;
+        Self::ConnectionPool {
+            source: Located(err).into(),
+            driver,
         }
     }
 }
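The new `Error` enum records which driver produced a failure, and the `From<(r2d2::Error, Driver)>` impl lets the drivers below convert a pool error plus a driver tag with a single `map_err(|e| (e, DRIVER))?`. A self-contained miniature of that idiom, with `PoolError` standing in for `r2d2::Error` (all names here are illustrative, not the crate's):

// The unused Sqlite3 variant only mirrors the real enum in this tiny demo.
#[allow(dead_code)]
#[derive(Debug)]
enum Driver {
    Sqlite3,
    MySQL,
}

#[derive(Debug)]
struct PoolError(String);

#[derive(Debug)]
enum Error {
    ConnectionPool { source: PoolError, driver: Driver },
}

impl From<(PoolError, Driver)> for Error {
    fn from((source, driver): (PoolError, Driver)) -> Self {
        Error::ConnectionPool { source, driver }
    }
}

// The `?` operator applies the tuple conversion above, so the call site only
// tags the failure with the active driver.
fn get_connection() -> Result<(), Error> {
    const DRIVER: Driver = Driver::MySQL;
    let pool_result: Result<(), PoolError> = Err(PoolError("pool exhausted".to_string()));
    pool_result.map_err(|e| (e, DRIVER))?;
    Ok(())
}

fn main() {
    println!("{:?}", get_connection());
}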
diff --git a/src/databases/mod.rs b/src/databases/mod.rs
index 873dd70e..809decc2 100644
--- a/src/databases/mod.rs
+++ b/src/databases/mod.rs
@@ -3,37 +3,48 @@ pub mod error;
 pub mod mysql;
 pub mod sqlite;
 
+use std::marker::PhantomData;
+
 use async_trait::async_trait;
 
-use self::driver::Driver;
 use self::error::Error;
-use crate::databases::mysql::Mysql;
-use crate::databases::sqlite::Sqlite;
 use crate::protocol::info_hash::InfoHash;
 use crate::tracker::auth;
 
-/// # Errors
-///
-/// Will return `r2d2::Error` if `db_path` is not able to create a database.
-pub fn connect(db_driver: &Driver, db_path: &str) -> Result<Box<dyn Database>, r2d2::Error> {
-    let database: Box<dyn Database> = match db_driver {
-        Driver::Sqlite3 => {
-            let db = Sqlite::new(db_path)?;
-            Box::new(db)
-        }
-        Driver::MySQL => {
-            let db = Mysql::new(db_path)?;
-            Box::new(db)
-        }
-    };
-
-    database.create_database_tables().expect("Could not create database tables.");
-
-    Ok(database)
+pub(self) struct Builder<T>
+where
+    T: Database,
+{
+    phantom: PhantomData<T>,
+}
+
+impl<T> Builder<T>
+where
+    T: Database + 'static,
+{
+    /// .
+    ///
+    /// # Errors
+    ///
+    /// Will return `r2d2::Error` if `db_path` is not able to create a database.
+    pub(self) fn build(db_path: &str) -> Result<Box<dyn Database>, Error> {
+        Ok(Box::new(T::new(db_path)?))
+    }
 }
 
 #[async_trait]
 pub trait Database: Sync + Send {
+    /// .
+    ///
+    /// # Errors
+    ///
+    /// Will return `r2d2::Error` if `db_path` is not able to create a database.
+    fn new(db_path: &str) -> Result<Self, Error>
+    where
+        Self: std::marker::Sized;
+
+    /// .
+    ///
     /// # Errors
     ///
     /// Will return `Error` if unable to create own tables.
@@ -52,27 +63,22 @@ pub trait Database: Sync + Send {
 
     async fn save_persistent_torrent(&self, info_hash: &InfoHash, completed: u32) -> Result<(), Error>;
 
-    async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<InfoHash, Error>;
+    async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<Option<InfoHash>, Error>;
 
     async fn add_info_hash_to_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error>;
 
    async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error>;
 
-    async fn get_key_from_keys(&self, key: &str) -> Result<auth::Key, Error>;
+    async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::Key>, Error>;
 
     async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result<usize, Error>;
 
     async fn remove_key_from_keys(&self, key: &str) -> Result<usize, Error>;
 
     async fn is_info_hash_whitelisted(&self, info_hash: &InfoHash) -> Result<bool, Error> {
-        self.get_info_hash_from_whitelist(&info_hash.clone().to_string())
-            .await
-            .map_or_else(
-                |e| match e {
-                    Error::QueryReturnedNoRows => Ok(false),
-                    e => Err(e),
-                },
-                |_| Ok(true),
-            )
+        Ok(self
+            .get_info_hash_from_whitelist(&info_hash.clone().to_string())
+            .await?
+            .is_some())
     }
 }
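`Builder<T>` replaces the old `connect` free function: the type parameter picks the concrete driver at compile time, `PhantomData<T>` keeps the struct zero-sized, and the boxed return keeps call sites dynamic. A stripped-down, runnable sketch of the same shape (simplified trait and `String` errors, not the crate's types):

use std::marker::PhantomData;

trait Database {
    fn new(path: &str) -> Result<Self, String>
    where
        Self: Sized;
    fn describe(&self) -> String;
}

struct Sqlite {
    path: String,
}

impl Database for Sqlite {
    fn new(path: &str) -> Result<Self, String> {
        Ok(Self { path: path.to_string() })
    }
    fn describe(&self) -> String {
        format!("sqlite database at {}", self.path)
    }
}

struct Builder<T: Database> {
    // Never read; it only pins the type parameter, as in the patch above.
    #[allow(dead_code)]
    phantom: PhantomData<T>,
}

impl<T: Database + 'static> Builder<T> {
    fn build(path: &str) -> Result<Box<dyn Database>, String> {
        Ok(Box::new(T::new(path)?))
    }
}

fn main() -> Result<(), String> {
    let database = Builder::<Sqlite>::build("./sqlite3.db")?;
    println!("{}", database.describe());
    Ok(())
}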
diff --git a/src/databases/mysql.rs b/src/databases/mysql.rs
index 71b06378..ac54ebb8 100644
--- a/src/databases/mysql.rs
+++ b/src/databases/mysql.rs
@@ -8,33 +8,32 @@ use r2d2_mysql::mysql::prelude::Queryable;
 use r2d2_mysql::mysql::{params, Opts, OptsBuilder};
 use r2d2_mysql::MysqlConnectionManager;
 
+use super::driver::Driver;
 use crate::databases::{Database, Error};
 use crate::protocol::common::AUTH_KEY_LENGTH;
 use crate::protocol::info_hash::InfoHash;
 use crate::tracker::auth;
 
+const DRIVER: Driver = Driver::MySQL;
+
 pub struct Mysql {
     pool: Pool<MysqlConnectionManager>,
 }
 
-impl Mysql {
+#[async_trait]
+impl Database for Mysql {
     /// # Errors
     ///
     /// Will return `r2d2::Error` if `db_path` is not able to create `MySQL` database.
-    pub fn new(db_path: &str) -> Result<Mysql, r2d2::Error> {
-        let opts = Opts::from_url(db_path).expect("Failed to connect to MySQL database.");
+    fn new(db_path: &str) -> Result<Mysql, Error> {
+        let opts = Opts::from_url(db_path)?;
         let builder = OptsBuilder::from_opts(opts);
         let manager = MysqlConnectionManager::new(builder);
-        let pool = r2d2::Pool::builder()
-            .build(manager)
-            .expect("Failed to create r2d2 MySQL connection pool.");
+        let pool = r2d2::Pool::builder().build(manager).map_err(|e| (e, DRIVER))?;
 
         Ok(Self { pool })
     }
-}
 
-#[async_trait]
-impl Database for Mysql {
     fn create_database_tables(&self) -> Result<(), Error> {
         let create_whitelist_table = "
         CREATE TABLE IF NOT EXISTS whitelist (
@@ -63,7 +62,7 @@ impl Database for Mysql {
             i8::try_from(AUTH_KEY_LENGTH).expect("auth::Auth Key Length Should fit within a i8!")
         );
 
-        let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+        let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
         conn.query_drop(&create_torrents_table)
             .expect("Could not create torrents table.");
@@ -87,7 +86,7 @@ impl Database for Mysql {
         DROP TABLE `keys`;"
             .to_string();
 
-        let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+        let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
         conn.query_drop(&drop_whitelist_table)
             .expect("Could not drop `whitelist` table.");
@@ -99,155 +98,124 @@
     }
 
     async fn load_persistent_torrents(&self) -> Result<Vec<(InfoHash, u32)>, Error> {
-        let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
-
-        let torrents: Vec<(InfoHash, u32)> = conn
-            .query_map(
-                "SELECT info_hash, completed FROM torrents",
-                |(info_hash_string, completed): (String, u32)| {
-                    let info_hash = InfoHash::from_str(&info_hash_string).unwrap();
-                    (info_hash, completed)
-                },
-            )
-            .map_err(|_| Error::QueryReturnedNoRows)?;
+        let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;
+
+        let torrents = conn.query_map(
+            "SELECT info_hash, completed FROM torrents",
+            |(info_hash_string, completed): (String, u32)| {
+                let info_hash = InfoHash::from_str(&info_hash_string).unwrap();
+                (info_hash, completed)
+            },
+        )?;
 
         Ok(torrents)
     }
 
     async fn load_keys(&self) -> Result<Vec<auth::Key>, Error> {
-        let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
-
-        let keys: Vec<auth::Key> = conn
-            .query_map(
-                "SELECT `key`, valid_until FROM `keys`",
-                |(key, valid_until): (String, i64)| auth::Key {
-                    key,
-                    valid_until: Some(Duration::from_secs(valid_until.unsigned_abs())),
-                },
-            )
-            .map_err(|_| Error::QueryReturnedNoRows)?;
+        let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;
+
+        let keys = conn.query_map(
+            "SELECT `key`, valid_until FROM `keys`",
+            |(key, valid_until): (String, i64)| auth::Key {
+                key,
+                valid_until: Some(Duration::from_secs(valid_until.unsigned_abs())),
+            },
+        )?;
 
         Ok(keys)
     }
 
     async fn load_whitelist(&self) -> Result<Vec<InfoHash>, Error> {
-        let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+        let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
-        let info_hashes: Vec<InfoHash> = conn
-            .query_map("SELECT info_hash FROM whitelist", |info_hash: String| {
-                InfoHash::from_str(&info_hash).unwrap()
-            })
-            .map_err(|_| Error::QueryReturnedNoRows)?;
+        let info_hashes = conn.query_map("SELECT info_hash FROM whitelist", |info_hash: String| {
+            InfoHash::from_str(&info_hash).unwrap()
+        })?;
 
         Ok(info_hashes)
     }
 
     async fn save_persistent_torrent(&self, info_hash: &InfoHash, completed: u32) -> Result<(), Error> {
-        let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+        const COMMAND : &str = "INSERT INTO torrents (info_hash, completed) VALUES (:info_hash_str, :completed) ON DUPLICATE KEY UPDATE completed = VALUES(completed)";
+
+        let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
         let info_hash_str = info_hash.to_string();
 
         debug!("{}", info_hash_str);
 
-        match conn.exec_drop("INSERT INTO torrents (info_hash, completed) VALUES (:info_hash_str, :completed) ON DUPLICATE KEY UPDATE completed = VALUES(completed)", params! { info_hash_str, completed }) {
-            Ok(_) => {
-                Ok(())
-            }
-            Err(e) => {
-                debug!("{:?}", e);
-                Err(Error::InvalidQuery)
-            }
-        }
+        Ok(conn.exec_drop(COMMAND, params! { info_hash_str, completed })?)
     }
 
-    async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<InfoHash, Error> {
-        let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
-
-        match conn
-            .exec_first::<String, _, _>(
-                "SELECT info_hash FROM whitelist WHERE info_hash = :info_hash",
-                params! { info_hash },
-            )
-            .map_err(|_| Error::DatabaseError)?
-        {
-            Some(info_hash) => Ok(InfoHash::from_str(&info_hash).unwrap()),
-            None => Err(Error::QueryReturnedNoRows),
-        }
+    async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<Option<InfoHash>, Error> {
+        let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;
+
+        let select = conn.exec_first::<String, _, _>(
+            "SELECT info_hash FROM whitelist WHERE info_hash = :info_hash",
+            params! { info_hash },
+        )?;
+
+        let info_hash = select.map(|f| InfoHash::from_str(&f).expect("Failed to decode InfoHash String from DB!"));
+
+        Ok(info_hash)
     }
 
     async fn add_info_hash_to_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error> {
-        let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+        let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
         let info_hash_str = info_hash.to_string();
 
-        match conn.exec_drop(
+        conn.exec_drop(
             "INSERT INTO whitelist (info_hash) VALUES (:info_hash_str)",
             params! { info_hash_str },
-        ) {
-            Ok(_) => Ok(1),
-            Err(e) => {
-                debug!("{:?}", e);
-                Err(Error::InvalidQuery)
-            }
-        }
+        )?;
+
+        Ok(1)
     }
 
     async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error> {
-        let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+        let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
         let info_hash = info_hash.to_string();
 
-        match conn.exec_drop("DELETE FROM whitelist WHERE info_hash = :info_hash", params! { info_hash }) {
-            Ok(_) => Ok(1),
-            Err(e) => {
-                debug!("{:?}", e);
-                Err(Error::InvalidQuery)
-            }
-        }
+        conn.exec_drop("DELETE FROM whitelist WHERE info_hash = :info_hash", params! { info_hash })?;
+
+        Ok(1)
    }
 
-    async fn get_key_from_keys(&self, key: &str) -> Result<auth::Key, Error> {
-        let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+    async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::Key>, Error> {
+        let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
-        match conn
-            .exec_first::<(String, i64), _, _>("SELECT `key`, valid_until FROM `keys` WHERE `key` = :key", params! { key })
-            .map_err(|_| Error::QueryReturnedNoRows)?
-        {
-            Some((key, valid_until)) => Ok(auth::Key {
-                key,
-                valid_until: Some(Duration::from_secs(valid_until.unsigned_abs())),
-            }),
-            None => Err(Error::InvalidQuery),
-        }
+        let query =
+            conn.exec_first::<(String, i64), _, _>("SELECT `key`, valid_until FROM `keys` WHERE `key` = :key", params! { key });
+
+        let key = query?;
+
+        Ok(key.map(|(key, expiry)| auth::Key {
+            key,
+            valid_until: Some(Duration::from_secs(expiry.unsigned_abs())),
+        }))
     }
 
     async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result<usize, Error> {
-        let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+        let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
         let key = auth_key.key.to_string();
         let valid_until = auth_key.valid_until.unwrap_or(Duration::ZERO).as_secs().to_string();
 
-        match conn.exec_drop(
+        conn.exec_drop(
             "INSERT INTO `keys` (`key`, valid_until) VALUES (:key, :valid_until)",
             params! { key, valid_until },
-        ) {
-            Ok(_) => Ok(1),
-            Err(e) => {
-                debug!("{:?}", e);
-                Err(Error::InvalidQuery)
-            }
-        }
+        )?;
+
+        Ok(1)
     }
 
     async fn remove_key_from_keys(&self, key: &str) -> Result<usize, Error> {
-        let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
-
-        match conn.exec_drop("DELETE FROM `keys` WHERE key = :key", params! { key }) {
-            Ok(_) => Ok(1),
-            Err(e) => {
-                debug!("{:?}", e);
-                Err(Error::InvalidQuery)
-            }
-        }
+        let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;
+
+        conn.exec_drop("DELETE FROM `keys` WHERE key = :key", params! { key })?;
+
+        Ok(1)
     }
 }
diff --git a/src/databases/sqlite.rs b/src/databases/sqlite.rs
index 1d7caf05..3425b15c 100644
--- a/src/databases/sqlite.rs
+++ b/src/databases/sqlite.rs
@@ -1,32 +1,32 @@
+use std::panic::Location;
 use std::str::FromStr;
 
 use async_trait::async_trait;
-use log::debug;
 use r2d2::Pool;
 use r2d2_sqlite::SqliteConnectionManager;
 
+use super::driver::Driver;
 use crate::databases::{Database, Error};
 use crate::protocol::clock::DurationSinceUnixEpoch;
 use crate::protocol::info_hash::InfoHash;
 use crate::tracker::auth;
 
+const DRIVER: Driver = Driver::Sqlite3;
+
 pub struct Sqlite {
     pool: Pool<SqliteConnectionManager>,
 }
 
-impl Sqlite {
+#[async_trait]
+impl Database for Sqlite {
     /// # Errors
     ///
     /// Will return `r2d2::Error` if `db_path` is not able to create `SqLite` database.
-    pub fn new(db_path: &str) -> Result<Sqlite, r2d2::Error> {
+    fn new(db_path: &str) -> Result<Sqlite, Error> {
         let cm = SqliteConnectionManager::file(db_path);
-        let pool = Pool::new(cm).expect("Failed to create r2d2 SQLite connection pool.");
-        Ok(Sqlite { pool })
+        Pool::new(cm).map_or_else(|err| Err((err, Driver::Sqlite3).into()), |pool| Ok(Sqlite { pool }))
     }
-}
 
-#[async_trait]
-impl Database for Sqlite {
     fn create_database_tables(&self) -> Result<(), Error> {
         let create_whitelist_table = "
         CREATE TABLE IF NOT EXISTS whitelist (
@@ -51,13 +51,13 @@ impl Database for Sqlite {
         );"
         .to_string();
 
-        let conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+        let conn = self.pool.get().map_err(|e| (e, DRIVER))?;
+
+        conn.execute(&create_whitelist_table, [])?;
+        conn.execute(&create_keys_table, [])?;
+        conn.execute(&create_torrents_table, [])?;
 
-        conn.execute(&create_whitelist_table, [])
-            .and_then(|_| conn.execute(&create_keys_table, []))
-            .and_then(|_| conn.execute(&create_torrents_table, []))
-            .map_err(|_| Error::InvalidQuery)
-            .map(|_| ())
+        Ok(())
     }
 
     fn drop_database_tables(&self) -> Result<(), Error> {
@@ -73,17 +73,17 @@ impl Database for Sqlite {
         DROP TABLE keys;"
        .to_string();
 
-        let conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+        let conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
         conn.execute(&drop_whitelist_table, [])
             .and_then(|_| conn.execute(&drop_torrents_table, []))
-            .and_then(|_| conn.execute(&drop_keys_table, []))
-            .map_err(|_| Error::InvalidQuery)
-            .map(|_| ())
+            .and_then(|_| conn.execute(&drop_keys_table, []))?;
+
+        Ok(())
     }
 
     async fn load_persistent_torrents(&self) -> Result<Vec<(InfoHash, u32)>, Error> {
-        let conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+        let conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
         let mut stmt = conn.prepare("SELECT info_hash, completed FROM torrents")?;
 
@@ -94,13 +94,16 @@ impl Database for Sqlite {
             Ok((info_hash, completed))
         })?;
 
+        //torrent_iter?;
+        //let torrent_iter = torrent_iter.unwrap();
+
         let torrents: Vec<(InfoHash, u32)> = torrent_iter.filter_map(std::result::Result::ok).collect();
 
         Ok(torrents)
     }
 
     async fn load_keys(&self) -> Result<Vec<auth::Key>, Error> {
-        let conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+        let conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
         let mut stmt = conn.prepare("SELECT key, valid_until FROM keys")?;
 
@@ -120,7 +123,7 @@ impl Database for Sqlite {
     }
 
     async fn load_whitelist(&self) -> Result<Vec<InfoHash>, Error> {
-        let conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+        let conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
         let mut stmt = conn.prepare("SELECT info_hash FROM whitelist")?;
 
@@ -136,130 +139,117 @@ impl Database for Sqlite {
     }
 
     async fn save_persistent_torrent(&self, info_hash: &InfoHash, completed: u32) -> Result<(), Error> {
-        let conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+        let conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
-        match conn.execute(
+        let insert = conn.execute(
             "INSERT INTO torrents (info_hash, completed) VALUES (?1, ?2) ON CONFLICT(info_hash) DO UPDATE SET completed = ?2",
             [info_hash.to_string(), completed.to_string()],
-        ) {
-            Ok(updated) => {
-                if updated > 0 {
-                    return Ok(());
-                }
-                Err(Error::QueryReturnedNoRows)
-            }
-            Err(e) => {
-                debug!("{:?}", e);
-                Err(Error::InvalidQuery)
-            }
+        )?;
+
+        if insert == 0 {
+            Err(Error::InsertFailed {
+                location: Location::caller(),
+                driver: DRIVER,
+            })
+        } else {
+            Ok(())
         }
     }
-    async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<InfoHash, Error> {
-        let conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+    async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<Option<InfoHash>, Error> {
+        let conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
         let mut stmt = conn.prepare("SELECT info_hash FROM whitelist WHERE info_hash = ?")?;
+
         let mut rows = stmt.query([info_hash])?;
 
-        match rows.next() {
-            Ok(row) => match row {
-                Some(row) => Ok(InfoHash::from_str(&row.get_unwrap::<_, String>(0)).unwrap()),
-                None => Err(Error::QueryReturnedNoRows),
-            },
-            Err(e) => {
-                debug!("{:?}", e);
-                Err(Error::InvalidQuery)
-            }
-        }
+        let query = rows.next()?;
+
+        Ok(query.map(|f| InfoHash::from_str(&f.get_unwrap::<_, String>(0)).unwrap()))
     }
 
     async fn add_info_hash_to_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error> {
-        let conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
-
-        match conn.execute("INSERT INTO whitelist (info_hash) VALUES (?)", [info_hash.to_string()]) {
-            Ok(updated) => {
-                if updated > 0 {
-                    return Ok(updated);
-                }
-                Err(Error::QueryReturnedNoRows)
-            }
-            Err(e) => {
-                debug!("{:?}", e);
-                Err(Error::InvalidQuery)
-            }
+        let conn = self.pool.get().map_err(|e| (e, DRIVER))?;
+
+        let insert = conn.execute("INSERT INTO whitelist (info_hash) VALUES (?)", [info_hash.to_string()])?;
+
+        if insert == 0 {
+            Err(Error::InsertFailed {
+                location: Location::caller(),
+                driver: DRIVER,
+            })
+        } else {
+            Ok(insert)
         }
     }
 
     async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error> {
-        let conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
-
-        match conn.execute("DELETE FROM whitelist WHERE info_hash = ?", [info_hash.to_string()]) {
-            Ok(updated) => {
-                if updated > 0 {
-                    return Ok(updated);
-                }
-                Err(Error::QueryReturnedNoRows)
-            }
-            Err(e) => {
-                debug!("{:?}", e);
-                Err(Error::InvalidQuery)
-            }
+        let conn = self.pool.get().map_err(|e| (e, DRIVER))?;
+
+        let deleted = conn.execute("DELETE FROM whitelist WHERE info_hash = ?", [info_hash.to_string()])?;
+
+        if deleted == 1 {
+            // should only remove a single record.
+            Ok(deleted)
+        } else {
+            Err(Error::DeleteFailed {
+                location: Location::caller(),
+                error_code: deleted,
+                driver: DRIVER,
+            })
        }
     }
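The sqlite driver now derives write failures from the affected-row count instead of a generic query error: zero rows on insert becomes `InsertFailed`, and anything but exactly one row on delete becomes `DeleteFailed`, both carrying the caller's location. A small, self-contained sketch of that check (illustrative types only, not the crate's):

use std::panic::Location;

#[derive(Debug)]
enum DbError {
    InsertFailed { location: &'static Location<'static> },
}

// With #[track_caller], Location::caller() reports the line that called
// check_insert, mirroring how the driver records where a write went wrong.
#[track_caller]
fn check_insert(affected_rows: usize) -> Result<usize, DbError> {
    if affected_rows == 0 {
        Err(DbError::InsertFailed {
            location: Location::caller(),
        })
    } else {
        Ok(affected_rows)
    }
}

fn main() {
    println!("{:?}", check_insert(0)); // Err(InsertFailed { location: ... })
    println!("{:?}", check_insert(1)); // Ok(1)
}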
-    async fn get_key_from_keys(&self, key: &str) -> Result<auth::Key, Error> {
-        let conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+    async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::Key>, Error> {
+        let conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
         let mut stmt = conn.prepare("SELECT key, valid_until FROM keys WHERE key = ?")?;
+
         let mut rows = stmt.query([key.to_string()])?;
 
-        if let Some(row) = rows.next()? {
-            let key: String = row.get(0).unwrap();
-            let valid_until: i64 = row.get(1).unwrap();
+        let key = rows.next()?;
 
-            Ok(auth::Key {
-                key,
-                valid_until: Some(DurationSinceUnixEpoch::from_secs(valid_until.unsigned_abs())),
-            })
-        } else {
-            Err(Error::QueryReturnedNoRows)
-        }
+        Ok(key.map(|f| {
+            let expiry: i64 = f.get(1).unwrap();
+            auth::Key {
+                key: f.get(0).unwrap(),
+                valid_until: Some(DurationSinceUnixEpoch::from_secs(expiry.unsigned_abs())),
+            }
+        }))
     }
 
     async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result<usize, Error> {
-        let conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
+        let conn = self.pool.get().map_err(|e| (e, DRIVER))?;
 
-        match conn.execute(
+        let insert = conn.execute(
             "INSERT INTO keys (key, valid_until) VALUES (?1, ?2)",
             [auth_key.key.to_string(), auth_key.valid_until.unwrap().as_secs().to_string()],
-        ) {
-            Ok(updated) => {
-                if updated > 0 {
-                    return Ok(updated);
-                }
-                Err(Error::QueryReturnedNoRows)
-            }
-            Err(e) => {
-                debug!("{:?}", e);
-                Err(Error::InvalidQuery)
-            }
+        )?;
+
+        if insert == 0 {
+            Err(Error::InsertFailed {
+                location: Location::caller(),
+                driver: DRIVER,
+            })
+        } else {
+            Ok(insert)
         }
     }
 
     async fn remove_key_from_keys(&self, key: &str) -> Result<usize, Error> {
-        let conn = self.pool.get().map_err(|_| Error::DatabaseError)?;
-
-        match conn.execute("DELETE FROM keys WHERE key = ?", [key]) {
-            Ok(updated) => {
-                if updated > 0 {
-                    return Ok(updated);
-                }
-                Err(Error::QueryReturnedNoRows)
-            }
-            Err(e) => {
-                debug!("{:?}", e);
-                Err(Error::InvalidQuery)
-            }
+        let conn = self.pool.get().map_err(|e| (e, DRIVER))?;
+
+        let deleted = conn.execute("DELETE FROM keys WHERE key = ?", [key])?;
+
+        if deleted == 1 {
+            // should only remove a single record.
+            Ok(deleted)
+        } else {
+            Err(Error::DeleteFailed {
+                location: Location::caller(),
+                error_code: deleted,
+                driver: DRIVER,
+            })
         }
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index e8cf5304..cbda2854 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,6 +3,7 @@ pub mod config;
 pub mod databases;
 pub mod http;
 pub mod jobs;
+pub mod located_error;
 pub mod logging;
 pub mod protocol;
 pub mod setup;
diff --git a/src/located_error.rs b/src/located_error.rs
new file mode 100644
index 00000000..30e8cfad
--- /dev/null
+++ b/src/located_error.rs
@@ -0,0 +1,103 @@
+// https://stackoverflow.com/questions/74336993/getting-line-numbers-with-when-using-boxdyn-stderrorerror
+
+use std::error::Error;
+use std::panic::Location;
+use std::sync::Arc;
+
+pub struct Located<E>(pub E);
+
+#[derive(Debug)]
+pub struct LocatedError<'a, E>
+where
+    E: Error + ?Sized,
+{
+    source: Arc<E>,
+    location: Box<Location<'a>>,
+}
+
+impl<'a, E> std::fmt::Display for LocatedError<'a, E>
+where
+    E: Error + ?Sized,
+{
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}, {}", self.source, self.location)
+    }
+}
+
+impl<'a, E> Error for LocatedError<'a, E>
+where
+    E: Error + ?Sized + 'static,
+{
+    fn source(&self) -> Option<&(dyn Error + 'static)> {
+        Some(&self.source)
+    }
+}
+
+impl<'a, E> Clone for LocatedError<'a, E>
+where
+    E: Error + ?Sized,
+{
+    fn clone(&self) -> Self {
+        LocatedError {
+            source: self.source.clone(),
+            location: self.location.clone(),
+        }
+    }
+}
+
+#[allow(clippy::from_over_into)]
+impl<'a, E> Into<LocatedError<'a, E>> for Located<E>
+where
+    E: Error,
+    Arc<E>: Clone,
+{
+    #[track_caller]
+    fn into(self) -> LocatedError<'a, E> {
+        let e = LocatedError {
+            source: Arc::new(self.0),
+            location: Box::new(*std::panic::Location::caller()),
+        };
+        log::debug!("{e}");
+        e
+    }
+}
+
+#[allow(clippy::from_over_into)]
+impl<'a> Into<LocatedError<'a, dyn std::error::Error>> for Arc<dyn std::error::Error> {
+    #[track_caller]
+    fn into(self) -> LocatedError<'a, dyn std::error::Error> {
+        LocatedError {
+            source: self,
+            location: Box::new(*std::panic::Location::caller()),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::panic::Location;
+
+    use super::LocatedError;
+    use crate::located_error::Located;
+
+    #[derive(thiserror::Error, Debug)]
+    enum TestError {
+        #[error("Test")]
+        Test,
+    }
+
+    #[track_caller]
+    fn get_caller_location() -> Location<'static> {
+        *Location::caller()
+    }
+
+    #[test]
+    fn error_should_include_location() {
+        let e = TestError::Test;
+
+        let b: LocatedError<TestError> = Located(e).into();
+        let l = get_caller_location();
+
+        assert_eq!(b.location.file(), l.file());
+    }
+}
diff --git a/src/tracker/mod.rs b/src/tracker/mod.rs
index 4f1dab49..af4c7361 100644
--- a/src/tracker/mod.rs
+++ b/src/tracker/mod.rs
@@ -15,6 +15,7 @@ use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{RwLock, RwLockReadGuard};
 
 use crate::config::Configuration;
+use crate::databases::driver::Driver;
 use crate::databases::{self, Database};
 use crate::protocol::info_hash::InfoHash;
@@ -40,13 +41,13 @@ pub struct TorrentsMetrics {
 
 impl Tracker {
     /// # Errors
     ///
-    /// Will return a `r2d2::Error` if unable to connect to database.
+    /// Will return a `databases::error::Error` if unable to connect to database.
     pub fn new(
         config: &Arc<Configuration>,
         stats_event_sender: Option<Box<dyn statistics::EventSender>>,
         stats_repository: statistics::Repo,
-    ) -> Result<Tracker, r2d2::Error> {
-        let database = databases::connect(&config.db_driver, &config.db_path)?;
+    ) -> Result<Tracker, databases::error::Error> {
+        let database = Driver::build(&config.db_driver, &config.db_path)?;
 
         Ok(Tracker {
             config: config.clone(),
diff --git a/tests/api/asserts.rs b/tests/api/asserts.rs
index 5f9d3970..6c0d3cae 100644
--- a/tests/api/asserts.rs
+++ b/tests/api/asserts.rs
@@ -37,9 +37,20 @@ pub async fn assert_auth_key_utf8(response: Response) -> AuthKey {
 // OK response
 
 pub async fn assert_ok(response: Response) {
-    assert_eq!(response.status(), 200);
-    assert_eq!(response.headers().get("content-type").unwrap(), "application/json");
-    assert_eq!(response.text().await.unwrap(), "{\"status\":\"ok\"}");
+    let response_status = response.status().clone();
+    let response_headers = response.headers().get("content-type").cloned().unwrap();
+    let response_text = response.text().await.unwrap();
+
+    let details = format!(
+        r#"
+        status: ´{response_status}´
+        headers: ´{response_headers:?}´
+        text: ´"{response_text}"´"#
+    );
+
+    assert_eq!(response_status, 200, "details:{details}.");
+    assert_eq!(response_headers, "application/json", "\ndetails:{details}.");
+    assert_eq!(response_text, "{\"status\":\"ok\"}", "\ndetails:{details}.");
 }
 
 // Error responses
@@ -118,8 +129,11 @@ pub async fn assert_failed_to_reload_keys(response: Response) {
 async fn assert_unhandled_rejection(response: Response, reason: &str) {
     assert_eq!(response.status(), 500);
     assert_eq!(response.headers().get("content-type").unwrap(), "text/plain; charset=utf-8");
-    assert_eq!(
-        response.text().await.unwrap(),
-        format!("Unhandled rejection: Err {{ reason: \"{reason}\" }}")
+
+    let reason_text = format!("Unhandled rejection: Err {{ reason: \"{reason}");
+    let response_text = response.text().await.unwrap();
+    assert!(
+        response_text.contains(&reason_text),
+        ":\n response: `\"{response_text}\"`\n dose not contain: `\"{reason_text}\"`."
     );
 }
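A usage sketch for the new located-error module (assuming the crate is consumed as the `torrust_tracker` library and `thiserror` is available; the error type here is illustrative): wrapping a plain error in `Located` and converting it with `.into()` records the caller's file and line through `#[track_caller]`, which is what the database `Error` variants above carry in their `source` fields.

use torrust_tracker::located_error::{Located, LocatedError};

#[derive(thiserror::Error, Debug)]
#[error("the database file is missing")]
struct MissingFile;

fn open_database() -> Result<(), LocatedError<'static, MissingFile>> {
    // The recorded location is this line, thanks to #[track_caller] on `into`.
    Err(Located(MissingFile).into())
}

fn main() {
    if let Err(e) = open_database() {
        // Prints the wrapped error followed by a "src/main.rs:10:9"-style location.
        println!("{e}");
    }
}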