Skip to content

Commit

Permalink
mysql-srv: add support to COM_CHANGE_USER
Browse files Browse the repository at this point in the history
Add support for COM_CHANGE_USER command. This command is used to
change the user of the current connection.
As part of this work, we now store the client capabilities and the
initial scramble auth_plugin_data we sent to this client. This info
is used to re-authenticate the client when the COM_CHANGE_USER
command is received.

Refs: REA-4212

Release-Note-Core: Add support for MySQL COM_CHANGE_USER
  command.

Change-Id: I2790da03cefc04e21b487979738554aa3e1a2227
Reviewed-on: https://gerrit.readyset.name/c/readyset/+/7102
Tested-by: Buildkite CI
Reviewed-by: Luke Osborne <luke@readyset.io>
  • Loading branch information
altmannmarcelo committed Mar 20, 2024
1 parent 3f68ffd commit 7c8a1b7
Show file tree
Hide file tree
Showing 9 changed files with 380 additions and 7 deletions.
107 changes: 106 additions & 1 deletion mysql-srv/src/commands.rs
Expand Up @@ -6,7 +6,7 @@ use nom::number::complete::{le_i16, le_i24, le_i64, le_u16, le_u32, le_u8};
use nom::sequence::preceded;
use nom::IResult;

use crate::myc::constants::{CapabilityFlags, Command as CommandByte};
use crate::myc::constants::{CapabilityFlags, Command as CommandByte, UTF8MB4_GENERAL_CI};

#[derive(Debug)]
pub struct ClientHandshake<'a> {
Expand All @@ -19,6 +19,15 @@ pub struct ClientHandshake<'a> {
pub auth_plugin_name: Option<&'a str>,
}

#[derive(Debug)]
pub struct ClientChangeUser<'a> {
pub username: &'a str,
pub password: &'a [u8],
pub database: Option<&'a str>,
pub charset: u16,
pub auth_plugin_name: &'a str,
}

/// Parse a "length-encoded integer" as specified by the [mysql binary protocol documentation][docs]
///
/// [docs]: https://dev.mysql.com/doc/internals/en/integer.html#length-encoded-integer
Expand Down Expand Up @@ -52,6 +61,53 @@ fn null_terminated_string(i: &[u8]) -> IResult<&[u8], &str> {
Ok((i, res))
}

/// Parse a COM_CHANGE_USER packet as specified by the [mysql binary protocol documentation][docs]
/// [docs]: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_change_user.html
pub fn change_user(
i: &[u8],
client_capability_flags: CapabilityFlags,
) -> IResult<&[u8], ClientChangeUser<'_>> {
let (i, username) = null_terminated_string(i)?;
let (i, password) =
if client_capability_flags.contains(CapabilityFlags::CLIENT_SECURE_CONNECTION) {
let (q, auth_token_length) = le_u8(i)?;
take(auth_token_length)(q)?
} else {
map(null_terminated_string, |s| s.as_bytes())(i)?
};
let (i, database) = map(null_terminated_string, Some)(i)?;
let (i, charset, auth_plugin_name) = if !i.is_empty() {
let (i, charset) = if client_capability_flags.contains(CapabilityFlags::CLIENT_PROTOCOL_41)
{
let (i, bytes) = take(2usize)(i)?;
let charset = u16::from_le_bytes(bytes.try_into().unwrap());
(i, charset)
} else {
(i, UTF8MB4_GENERAL_CI)
};
let (i, auth_plugin_name) =
if client_capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH) {
null_terminated_string(i)?
} else {
(i, "")
};
(i, charset, auth_plugin_name)
} else {
(i, UTF8MB4_GENERAL_CI, "")
};

Ok((
i,
ClientChangeUser {
username,
password,
database,
charset,
auth_plugin_name,
},
))
}

/// <https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse41>
pub fn client_handshake(i: &[u8]) -> IResult<&[u8], ClientHandshake<'_>> {
let (i, capabilities) = map(le_u32, CapabilityFlags::from_bits_truncate)(i)?;
Expand Down Expand Up @@ -116,6 +172,7 @@ pub enum Command<'a> {
},
Ping,
Quit,
ChangeUser(&'a [u8]),
}

pub fn execute(i: &[u8]) -> IResult<&[u8], Command<'_>> {
Expand Down Expand Up @@ -175,6 +232,10 @@ pub fn parse(i: &[u8]) -> IResult<&[u8], Command<'_>> {
),
map(tag(&[CommandByte::COM_QUIT as u8]), |_| Command::Quit),
map(tag(&[CommandByte::COM_PING as u8]), |_| Command::Ping),
map(
preceded(tag(&[CommandByte::COM_CHANGE_USER as u8]), rest),
Command::ChangeUser,
),
))(i)
}

Expand Down Expand Up @@ -251,4 +312,48 @@ mod tests {
Command::ListFields(&b"select @@version_comment limit 1"[..])
);
}

#[tokio::test]
async fn it_parses_change_user() {
let data = &[
0x24, 0x00, 0x00, 0x00, 0x72, 0x6f, 0x6f, 0x74, 0x00, 0x00, 0x74, 0x65, 0x73, 0x74,
0x00, 0x2d, 0x00, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76,
0x65, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, 0x00,
];
let r = Cursor::new(&data[..]);
let mut pr = PacketReader::new(r);
let (_, p) = pr.next().await.unwrap().unwrap();
let capability_flags = CapabilityFlags::CLIENT_PROTOCOL_41
| CapabilityFlags::CLIENT_SECURE_CONNECTION
| CapabilityFlags::CLIENT_PLUGIN_AUTH;
let (_, changeuser) = change_user(&p, capability_flags).unwrap();
assert_eq!(changeuser.username, "root");
assert_eq!(changeuser.password, b"");
assert_eq!(changeuser.database, Some("test"));
assert_eq!(changeuser.charset, UTF8MB4_GENERAL_CI);
assert_eq!(changeuser.auth_plugin_name, "mysql_native_password");

let data = &[
0x38, 0x00, 0x00, 0x00, 0x72, 0x6f, 0x6f, 0x74, 0x00, 0x14, 0x95, 0x16, 0xb1, 0x01,
0x40, 0x6d, 0x7c, 0xc3, 0x17, 0x22, 0xc5, 0x9d, 0x00, 0xf3, 0x5d, 0x37, 0xb9, 0xb5,
0x6d, 0x0f, 0x74, 0x65, 0x73, 0x74, 0x00, 0x2d, 0x00, 0x6d, 0x79, 0x73, 0x71, 0x6c,
0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f,
0x72, 0x64, 0x00, 0x00,
];
let r = Cursor::new(&data[..]);
let mut pr = PacketReader::new(r);
let (_, p) = pr.next().await.unwrap().unwrap();
let (_, changeuser) = change_user(&p, capability_flags).unwrap();
assert_eq!(changeuser.username, "root");
assert_eq!(
changeuser.password,
&[
0x95, 0x16, 0xb1, 0x01, 0x40, 0x6d, 0x7c, 0xc3, 0x17, 0x22, 0xc5, 0x9d, 0x00, 0xf3,
0x5d, 0x37, 0xb9, 0xb5, 0x6d, 0x0f
]
);
assert_eq!(changeuser.database, Some("test"));
assert_eq!(changeuser.charset, UTF8MB4_GENERAL_CI);
assert_eq!(changeuser.auth_plugin_name, "mysql_native_password");
}
}
95 changes: 94 additions & 1 deletion mysql-srv/src/lib.rs
Expand Up @@ -53,6 +53,9 @@
//! async fn on_init(&mut self, _: &str, w: Option<InitWriter<'_, W>>) -> io::Result<()> {
//! w.unwrap().ok().await
//! }
//! async fn on_change_user(&mut self, _: &str, _: &str, _: &str) -> io::Result<()> {
//! Ok(())
//! }
//!
//! async fn on_query(
//! &mut self,
Expand Down Expand Up @@ -176,6 +179,7 @@ use tracing::{debug, info, trace};
use writers::write_err;

use crate::authentication::{generate_auth_data, hash_password, AUTH_PLUGIN_NAME};
use crate::commands::change_user;
pub use crate::myc::constants::{ColumnFlags, ColumnType, StatusFlags};
pub use crate::writers::prepare_column_definitions;

Expand Down Expand Up @@ -289,6 +293,9 @@ pub trait MySqlShim<W: AsyncWrite + Unpin + Send> {
/// Called when client switches database.
async fn on_init(&mut self, _: &str, _: Option<InitWriter<'_, W>>) -> io::Result<()>;

/// Called when client switches user.
async fn on_change_user(&mut self, _: &str, _: &str, _: &str) -> io::Result<()>;

/// Retrieve the password for the user with the given username, if any.
///
/// If the user doesn't exist, return [`None`].
Expand Down Expand Up @@ -320,6 +327,10 @@ pub struct MySqlIntermediary<B, R: AsyncRead + Unpin, W: AsyncWrite + Unpin> {
schema_cache: HashMap<u32, CachedSchema>,
/// Whether to log statements received from a client
enable_statement_logging: bool,
/// The capabilities of the client
client_capabilities: CapabilityFlags,
/// Auth data sent to client
auth_data: [u8; 20],
}

impl<B: MySqlShim<net::tcp::OwnedWriteHalf> + Send>
Expand Down Expand Up @@ -391,6 +402,8 @@ impl<B: MySqlShim<W> + Send, R: AsyncRead + Unpin, W: AsyncWrite + Unpin + Send>
writer: w,
schema_cache: HashMap::new(),
enable_statement_logging,
client_capabilities: CapabilityFlags::empty(),
auth_data: [0; 20],
};
if let (true, database) = mi.init().await? {
if let Some(database) = database {
Expand All @@ -414,7 +427,7 @@ impl<B: MySqlShim<W> + Send, R: AsyncRead + Unpin, W: AsyncWrite + Unpin + Send>
async fn init(&mut self) -> Result<(bool, Option<String>), io::Error> {
let auth_data =
generate_auth_data().map_err(|_| other_error(OtherErrorKind::AuthDataErr))?;

self.auth_data = auth_data;
let mut init_packet = Vec::with_capacity(
1 + 16 + 4 + 8 + 1 + 2 + 1 + 2 + 2 + 1 + 6 + 4 + 12 + 1 + AUTH_PLUGIN_NAME.len() + 1,
);
Expand Down Expand Up @@ -468,6 +481,7 @@ impl<B: MySqlShim<W> + Send, R: AsyncRead + Unpin, W: AsyncWrite + Unpin + Send>

self.writer.set_seq(seq + 1);

self.client_capabilities = handshake.capabilities;
let username = handshake.username.to_owned();
let password = handshake.password.to_vec();
let database = handshake.database.map(String::from);
Expand Down Expand Up @@ -580,6 +594,85 @@ impl<B: MySqlShim<W> + Send, R: AsyncRead + Unpin, W: AsyncWrite + Unpin + Send>
info!(target: "client_statement", "{:?}", cmd);
}
match cmd {
Command::ChangeUser(q) => {
let change_user = change_user(q, self.client_capabilities)
.map_err(|e| {
other_error(OtherErrorKind::GenericErr {
error: format!("{:?}", e),
})
})?
.1;
let username = change_user.username.to_owned();
let authpassword = change_user.password.to_vec();

if change_user.auth_plugin_name != AUTH_PLUGIN_NAME {
// This should never happen, as we already accepted a connection using
// AUTH_PLUGIN_NAME
writers::write_err(
ErrorKind::ER_ACCESS_DENIED_ERROR,
format!(
"Access denied for user {}. Incorrect auth plugin {}",
username, change_user.auth_plugin_name
)
.as_bytes(),
&mut self.writer,
)
.await?;
self.writer.flush().await?;
continue;
}
let plain_password = self.shim.password_for_username(&username);
let auth_success = !self.shim.require_authentication()
|| plain_password.as_ref().map_or(false, |password| {
let expected = hash_password(password, &self.auth_data);
let actual = authpassword.as_slice();
trace!(?expected, ?actual);
expected == actual
});

if auth_success {
debug!("Successfully authenticated client");
match self
.shim
.on_change_user(
&username,
&plain_password
.as_ref()
.map(|p| String::from_utf8_lossy(p))
.unwrap_or_default(),
change_user.database.unwrap_or_default(),
)
.await
{
Ok(()) => {
writers::write_ok_packet(
&mut self.writer,
0,
0,
StatusFlags::empty(),
)
.await?;
}
Err(_) => {
writers::write_err(
ErrorKind::ER_ACCESS_DENIED_ERROR,
format!("Access denied for user {}", username).as_bytes(),
&mut self.writer,
)
.await?;
}
}
} else {
debug!("Received incorrect password");
writers::write_err(
ErrorKind::ER_ACCESS_DENIED_ERROR,
format!("Access denied for user {}", username).as_bytes(),
&mut self.writer,
)
.await?;
}
self.writer.flush().await?;
}
Command::Query(q) => {
let w = QueryResultWriter::new(&mut self.writer, false);
let res = self
Expand Down

0 comments on commit 7c8a1b7

Please sign in to comment.