diff --git a/set_version.sh b/set_version.sh index a3b74d0dd..d053ae30a 100755 --- a/set_version.sh +++ b/set_version.sh @@ -1,3 +1,3 @@ #!/bin/bash -export version=0.5.1-alpha.24 +export version=0.5.2-alpha.0 echo "pub const VERSION: &str = \"${version}\";" > src/version.rs diff --git a/src/database/entities/bots.rs b/src/database/entities/bots.rs new file mode 100644 index 000000000..d99da9bf6 --- /dev/null +++ b/src/database/entities/bots.rs @@ -0,0 +1,12 @@ +use serde::{Serialize, Deserialize}; + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Bot { + #[serde(rename = "_id")] + pub id: String, + pub owner: String, + pub token: String, + pub public: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub interactions_url: Option, +} diff --git a/src/database/entities/mod.rs b/src/database/entities/mod.rs index 9747de04f..dd3a7d7e1 100644 --- a/src/database/entities/mod.rs +++ b/src/database/entities/mod.rs @@ -5,6 +5,7 @@ mod microservice; mod server; mod sync; mod user; +mod bots; use microservice::*; @@ -16,3 +17,4 @@ pub use message::*; pub use server::*; pub use sync::*; pub use user::*; +pub use bots::*; diff --git a/src/database/entities/server.rs b/src/database/entities/server.rs index daac8eac5..6f050ad18 100644 --- a/src/database/entities/server.rs +++ b/src/database/entities/server.rs @@ -37,13 +37,20 @@ pub type PermissionTuple = ( i32 // channel permission ); +fn if_false(t: &bool) -> bool { + *t == false +} + #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Role { pub name: String, pub permissions: PermissionTuple, #[serde(skip_serializing_if = "Option::is_none")] - pub colour: Option - // Bri'ish API conventions + pub colour: Option, + #[serde(skip_serializing_if = "if_false", default)] + pub hoist: bool, + #[serde(default)] + pub rank: i64, } #[derive(Serialize, Deserialize, Debug, Clone)] diff --git a/src/database/entities/user.rs b/src/database/entities/user.rs index a66f63960..c470635c0 100644 --- a/src/database/entities/user.rs +++ b/src/database/entities/user.rs @@ -73,7 +73,12 @@ pub enum Badges { impl_op_ex_commutative!(+ |a: &i32, b: &Badges| -> i32 { *a | *b as i32 }); -// When changing this struct, update notifications/payload.rs#80 +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct BotInformation { + owner: String +} + +// When changing this struct, update notifications/payload.rs#113 #[derive(Serialize, Deserialize, Debug, Clone)] pub struct User { #[serde(rename = "_id")] @@ -91,6 +96,11 @@ pub struct User { #[serde(skip_serializing_if = "Option::is_none")] pub profile: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub flags: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub bot: Option, + // ? This should never be pushed to the collection. #[serde(skip_serializing_if = "Option::is_none")] pub relationship: Option, diff --git a/src/database/guards/user.rs b/src/database/guards/user.rs index 265b8b510..394c6b842 100644 --- a/src/database/guards/user.rs +++ b/src/database/guards/user.rs @@ -10,6 +10,61 @@ impl<'r> FromRequest<'r> for User { type Error = rauth::util::Error; async fn from_request(request: &'r Request<'_>) -> request::Outcome { + let header_bot_token = request + .headers() + .get("x-bot-token") + .next() + .map(|x| x.to_string()); + + if let Some(bot_token) = header_bot_token { + return if let Ok(result) = get_collection("bots") + .find_one( + doc! { + "token": bot_token + }, + None, + ) + .await + { + if let Some(doc) = result { + let id = doc.get_str("_id").unwrap(); + if let Ok(result) = get_collection("users") + .find_one( + doc! { + "_id": &id + }, + None, + ) + .await + { + if let Some(doc) = result { + Outcome::Success(from_document(doc).unwrap()) + } else { + Outcome::Failure((Status::Forbidden, rauth::util::Error::InvalidSession)) + } + } else { + Outcome::Failure(( + Status::InternalServerError, + rauth::util::Error::DatabaseError { + operation: "find_one", + with: "user", + }, + )) + } + } else { + Outcome::Failure((Status::Forbidden, rauth::util::Error::InvalidSession)) + } + } else { + Outcome::Failure(( + Status::InternalServerError, + rauth::util::Error::DatabaseError { + operation: "find_one", + with: "bot", + }, + )) + } + } + let session: Session = request.guard::().await.unwrap(); if let Ok(result) = get_collection("users") diff --git a/src/database/migrations/init.rs b/src/database/migrations/init.rs index be71b3994..7bb929a83 100644 --- a/src/database/migrations/init.rs +++ b/src/database/migrations/init.rs @@ -57,6 +57,10 @@ pub async fn create_database() { .await .expect("Failed to create user_settings collection."); + db.create_collection("bots", None) + .await + .expect("Failed to create bots collection."); + db.create_collection( "pubsub", CreateCollectionOptions::builder() diff --git a/src/database/migrations/scripts.rs b/src/database/migrations/scripts.rs index 342f52f73..14e30e8cc 100644 --- a/src/database/migrations/scripts.rs +++ b/src/database/migrations/scripts.rs @@ -11,7 +11,7 @@ struct MigrationInfo { revision: i32, } -pub const LATEST_REVISION: i32 = 7; +pub const LATEST_REVISION: i32 = 8; pub async fn migrate_database() { let migrations = get_collection("migrations"); @@ -203,6 +203,15 @@ pub async fn run_migrations(revision: i32) -> i32 { .expect("Failed to create message index."); } + if revision <= 7 { + info!("Running migration [revision 7 / 2021-08-11]: Add message text index."); + + get_db() + .create_collection("bots", None) + .await + .expect("Failed to create bots collection."); + } + // Reminder to update LATEST_REVISION when adding new migrations. LATEST_REVISION } diff --git a/src/notifications/events.rs b/src/notifications/events.rs index 3976b7f75..f5c9a0aca 100644 --- a/src/notifications/events.rs +++ b/src/notifications/events.rs @@ -17,10 +17,22 @@ pub enum WebSocketError { AlreadyAuthenticated, } +#[derive(Deserialize, Debug)] +pub struct BotAuth { + pub token: String +} + +#[derive(Deserialize, Debug)] +#[serde(untagged)] +pub enum AuthType { + User(Session), + Bot(BotAuth) +} + #[derive(Deserialize, Debug)] #[serde(tag = "type")] pub enum ServerboundNotification { - Authenticate(Session), + Authenticate(AuthType), BeginTyping { channel: String }, EndTyping { channel: String }, } diff --git a/src/notifications/websocket.rs b/src/notifications/websocket.rs index 1f0773970..654292398 100644 --- a/src/notifications/websocket.rs +++ b/src/notifications/websocket.rs @@ -1,4 +1,5 @@ use crate::database::*; +use crate::notifications::events::{AuthType, BotAuth}; use crate::util::variables::WS_HOST; use super::subscriptions; @@ -12,8 +13,9 @@ use futures::{pin_mut, prelude::*}; use hive_pubsub::PubSub; use log::{debug, info}; use many_to_many::ManyToMany; +use mongodb::bson::doc; use rauth::{ - auth::{Auth, Session}, + auth::{Auth}, options::Options, }; use std::collections::HashMap; @@ -66,15 +68,15 @@ async fn accept(stream: TcpStream) { } }; - let session: Arc>> = Arc::new(Mutex::new(None)); - let mutex_generator = || session.clone(); + let user_id: Arc>> = Arc::new(Mutex::new(None)); + let mutex_generator = || user_id.clone(); let fwd = rx.map(Ok).forward(write); let incoming = read.try_for_each(async move |msg| { let mutex = mutex_generator(); if let Message::Text(text) = msg { if let Ok(notification) = serde_json::from_str::(&text) { match notification { - ServerboundNotification::Authenticate(new_session) => { + ServerboundNotification::Authenticate(auth) => { { if mutex.lock().unwrap().is_some() { send(ClientboundNotification::Error( @@ -85,12 +87,34 @@ async fn accept(stream: TcpStream) { } } - if let Ok(validated_session) = - Auth::new(get_collection("accounts"), Options::new()) - .verify_session(new_session) - .await - { - let id = validated_session.user_id.clone(); + if let Some(id) = match auth { + AuthType::User(new_session) => { + if let Ok(validated_session) = + Auth::new(get_collection("accounts"), Options::new()) + .verify_session(new_session) + .await + { + Some(validated_session.user_id.clone()) + } else { + None + } + } + AuthType::Bot(BotAuth { token }) => { + if let Ok(doc) = get_collection("bots") + .find_one( + doc! { "token": token }, + None + ).await { + if let Some(doc) = doc { + Some(doc.get_str("_id").unwrap().to_string()) + } else { + None + } + } else { + None + } + } + } { if let Ok(user) = (Ref { id: id.clone() }).fetch_user().await { let was_online = is_online(&id); { @@ -110,7 +134,7 @@ async fn accept(stream: TcpStream) { } } - *mutex.lock().unwrap() = Some(validated_session); + *mutex.lock().unwrap() = Some(id.clone()); if let Err(_) = subscriptions::generate_subscriptions(&user).await { send(ClientboundNotification::Error( @@ -166,8 +190,7 @@ async fn accept(stream: TcpStream) { if mutex.lock().unwrap().is_some() { let user = { let mutex = mutex.lock().unwrap(); - let session = mutex.as_ref().unwrap(); - session.user_id.clone() + mutex.as_ref().unwrap().clone() }; ClientboundNotification::ChannelStartTyping { @@ -187,8 +210,7 @@ async fn accept(stream: TcpStream) { if mutex.lock().unwrap().is_some() { let user = { let mutex = mutex.lock().unwrap(); - let session = mutex.as_ref().unwrap(); - session.user_id.clone() + mutex.as_ref().unwrap().clone() }; ClientboundNotification::ChannelStopTyping { @@ -219,13 +241,13 @@ async fn accept(stream: TcpStream) { let mut offline = None; { - let session = session.lock().unwrap(); - if let Some(session) = session.as_ref() { + let user_id = user_id.lock().unwrap(); + if let Some(user_id) = user_id.as_ref() { let mut users = USERS.write().unwrap(); - users.remove(&session.user_id, &addr); - if users.get_left(&session.user_id).is_none() { - get_hive().drop_client(&session.user_id).unwrap(); - offline = Some(session.user_id.clone()); + users.remove(&user_id, &addr); + if users.get_left(&user_id).is_none() { + get_hive().drop_client(&user_id).unwrap(); + offline = Some(user_id.clone()); } } } diff --git a/src/routes/bots/create.rs b/src/routes/bots/create.rs new file mode 100644 index 000000000..fab17ac79 --- /dev/null +++ b/src/routes/bots/create.rs @@ -0,0 +1,89 @@ +use crate::database::*; +use crate::util::result::{Error, Result}; +use crate::util::variables::MAX_BOT_COUNT; + +use mongodb::bson::{doc, to_document}; +use regex::Regex; +use rocket::serde::json::{Json, Value}; +use serde::{Deserialize, Serialize}; +use ulid::Ulid; +use nanoid::nanoid; +use validator::Validate; + +// ! FIXME: should be global somewhere; maybe use config(?) +// ! tip: CTRL + F, RE_USERNAME +lazy_static! { + static ref RE_USERNAME: Regex = Regex::new(r"^[a-zA-Z0-9_.]+$").unwrap(); +} + +#[derive(Validate, Serialize, Deserialize)] +pub struct Data { + #[validate(length(min = 2, max = 32), regex = "RE_USERNAME")] + name: String, +} + +#[post("/create", data = "")] +pub async fn create_bot(user: User, info: Json) -> Result { + let info = info.into_inner(); + info.validate() + .map_err(|error| Error::FailedValidation { error })?; + + if get_collection("bots") + .count_documents( + doc! { + "owner": &user.id + }, + None, + ) + .await + .map_err(|_| Error::DatabaseError { + operation: "count_documents", + with: "bots", + })? as usize >= *MAX_BOT_COUNT { + return Err(Error::ReachedMaximumBots) + } + + let id = Ulid::new().to_string(); + let token = nanoid!(64); + let bot = Bot { + id: id.clone(), + owner: user.id.clone(), + token, + public: false, + interactions_url: None + }; + + if User::is_username_taken(&info.name).await? { + return Err(Error::UsernameTaken); + } + + get_collection("users") + .insert_one( + doc! { + "_id": &id, + "username": &info.name, + "bot": { + "owner": &user.id + } + }, + None, + ) + .await + .map_err(|_| Error::DatabaseError { + operation: "insert_one", + with: "user", + })?; + + get_collection("bots") + .insert_one( + to_document(&bot).map_err(|_| Error::DatabaseError { with: "bot", operation: "to_document" })?, + None, + ) + .await + .map_err(|_| Error::DatabaseError { + operation: "insert_one", + with: "user", + })?; + + Ok(json!(bot)) +} diff --git a/src/routes/bots/mod.rs b/src/routes/bots/mod.rs new file mode 100644 index 000000000..966ce8a40 --- /dev/null +++ b/src/routes/bots/mod.rs @@ -0,0 +1,9 @@ +use rocket::Route; + +mod create; + +pub fn routes() -> Vec { + routes![ + create::create_bot + ] +} diff --git a/src/routes/mod.rs b/src/routes/mod.rs index d5d960f12..1e647e31b 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -10,6 +10,7 @@ mod root; mod servers; mod sync; mod users; +mod bots; pub fn mount(rocket: Rocket) -> Rocket { rocket @@ -18,6 +19,7 @@ pub fn mount(rocket: Rocket) -> Rocket { .mount("/users", users::routes()) .mount("/channels", channels::routes()) .mount("/servers", servers::routes()) + .mount("/bots", bots::routes()) .mount("/invites", invites::routes()) .mount("/push", push::routes()) .mount("/sync", sync::routes()) diff --git a/src/routes/servers/roles_edit.rs b/src/routes/servers/roles_edit.rs index 81b22c8f3..1f942e174 100644 --- a/src/routes/servers/roles_edit.rs +++ b/src/routes/servers/roles_edit.rs @@ -13,6 +13,8 @@ pub struct Data { name: Option, #[validate(length(min = 1, max = 32))] colour: Option, + hoist: Option, + rank: Option, remove: Option, } @@ -22,7 +24,7 @@ pub async fn req(user: User, target: Ref, role_id: String, data: Json) -> data.validate() .map_err(|error| Error::FailedValidation { error })?; - if data.name.is_none() && data.colour.is_none() && data.remove.is_none() + if data.name.is_none() && data.colour.is_none() && data.hoist.is_none() && data.rank.is_none() && data.remove.is_none() { return Ok(()); } @@ -67,6 +69,16 @@ pub async fn req(user: User, target: Ref, role_id: String, data: Json) -> set_update.insert("colour", colour); } + if let Some(hoist) = &data.hoist { + set.insert(role_key.clone() + ".hoist", hoist); + set_update.insert("hoist", hoist); + } + + if let Some(rank) = &data.rank { + set.insert(role_key.clone() + ".rank", rank); + set_update.insert("rank", rank); + } + let mut operations = doc! {}; if set.len() > 0 { operations.insert("$set", &set); diff --git a/src/routes/users/change_username.rs b/src/routes/users/change_username.rs index 8d7ecdfbd..25683c788 100644 --- a/src/routes/users/change_username.rs +++ b/src/routes/users/change_username.rs @@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize}; use validator::Validate; // ! FIXME: should be global somewhere; maybe use config(?) +// ! tip: CTRL + F, RE_USERNAME lazy_static! { static ref RE_USERNAME: Regex = Regex::new(r"^[a-zA-Z0-9_.]+$").unwrap(); } diff --git a/src/util/result.rs b/src/util/result.rs index 17d3dd4d1..ea089b120 100644 --- a/src/util/result.rs +++ b/src/util/result.rs @@ -44,6 +44,9 @@ pub enum Error { InvalidRole, Banned, + // ? Bot related errors. + ReachedMaximumBots, + // ? General errors. TooManyIds, FailedValidation { @@ -98,6 +101,8 @@ impl<'r> Responder<'r, 'static> for Error { Error::InvalidRole => Status::NotFound, Error::Banned => Status::Forbidden, + Error::ReachedMaximumBots => Status::BadRequest, + Error::FailedValidation { .. } => Status::UnprocessableEntity, Error::DatabaseError { .. } => Status::InternalServerError, Error::InternalError => Status::InternalServerError, diff --git a/src/util/variables.rs b/src/util/variables.rs index c6413d69e..8f087a249 100644 --- a/src/util/variables.rs +++ b/src/util/variables.rs @@ -62,6 +62,8 @@ lazy_static! { // Application Logic Settings pub static ref MAX_GROUP_SIZE: usize = env::var("REVOLT_MAX_GROUP_SIZE").unwrap_or_else(|_| "50".to_string()).parse().unwrap(); + pub static ref MAX_BOT_COUNT: usize = + env::var("REVOLT_MAX_BOT_COUNT").unwrap_or_else(|_| "5".to_string()).parse().unwrap(); pub static ref EARLY_ADOPTER_BADGE: i64 = env::var("REVOLT_EARLY_ADOPTER_BADGE").unwrap_or_else(|_| "0".to_string()).parse().unwrap(); } diff --git a/src/version.rs b/src/version.rs index 63868e8ff..070fe35a9 100644 --- a/src/version.rs +++ b/src/version.rs @@ -1 +1 @@ -pub const VERSION: &str = "0.5.1-alpha.24"; +pub const VERSION: &str = "0.5.2-alpha.0";