diff --git a/mysql-srv/src/commands.rs b/mysql-srv/src/commands.rs index 83fa208d3f..c3431fc2ab 100644 --- a/mysql-srv/src/commands.rs +++ b/mysql-srv/src/commands.rs @@ -157,6 +157,7 @@ pub enum Command<'a> { Query(&'a [u8]), ListFields(&'a [u8]), Close(u32), + Reset, ResetStmtData(u32), Prepare(&'a [u8]), Init(&'a [u8]), @@ -230,6 +231,9 @@ pub fn parse(i: &[u8]) -> IResult<&[u8], Command<'_>> { preceded(tag(&[CommandByte::COM_STMT_CLOSE as u8]), le_u32), Command::Close, ), + map(tag(&[CommandByte::COM_RESET_CONNECTION as u8]), |_| { + Command::Reset + }), map(tag(&[CommandByte::COM_QUIT as u8]), |_| Command::Quit), map(tag(&[CommandByte::COM_PING as u8]), |_| Command::Ping), map( diff --git a/mysql-srv/src/lib.rs b/mysql-srv/src/lib.rs index 232910b5c3..1efc37b89f 100644 --- a/mysql-srv/src/lib.rs +++ b/mysql-srv/src/lib.rs @@ -50,6 +50,10 @@ //! } //! async fn on_close(&mut self, _: DeallocateId) {} //! +//! async fn on_reset(&mut self) -> io::Result<()> { +//! Ok(()) +//! } +//! //! async fn on_init(&mut self, _: &str, w: Option>) -> io::Result<()> { //! w.unwrap().ok().await //! } @@ -290,6 +294,9 @@ pub trait MySqlShim { results: QueryResultWriter<'_, W>, ) -> QueryResultsResponse; + /// Called when the client issues a reset command + async fn on_reset(&mut self) -> io::Result<()>; + /// Called when client switches database. async fn on_init(&mut self, _: &str, _: Option>) -> io::Result<()>; @@ -815,6 +822,11 @@ impl + Send, R: AsyncRead + Unpin, W: AsyncWrite + Unpin + Send> writers::write_ok_packet(&mut self.writer, 0, 0, StatusFlags::empty()).await?; self.writer.flush().await?; } + Command::Reset => { + self.shim.on_reset().await?; + writers::write_ok_packet(&mut self.writer, 0, 0, StatusFlags::empty()).await?; + self.writer.flush().await?; + } Command::Quit => { break; } diff --git a/mysql-srv/tests/main.rs b/mysql-srv/tests/main.rs index 6103b1107c..a77ebb6862 100644 --- a/mysql-srv/tests/main.rs +++ b/mysql-srv/tests/main.rs @@ -95,6 +95,10 @@ where async fn on_close(&mut self, _: DeallocateId) {} + async fn on_reset(&mut self) -> io::Result<()> { + Ok(()) + } + async fn on_init(&mut self, schema: &str, writer: Option>) -> io::Result<()> { (self.on_i)(schema, writer.unwrap()).await } diff --git a/readyset-adapter/src/backend.rs b/readyset-adapter/src/backend.rs index dd30b02fca..f7d86538e3 100644 --- a/readyset-adapter/src/backend.rs +++ b/readyset-adapter/src/backend.rs @@ -849,6 +849,15 @@ where .unwrap_or_else(|| DB::DEFAULT_DB_VERSION.to_string()) } + /// Reset the current upstream connection + pub async fn reset(&mut self) -> Result<(), DB::Error> { + if let Some(upstream) = &mut self.upstream { + upstream.reset().await + } else { + Ok(()) + } + } + /// Switch the active database for this backend to the given named database. /// /// Internally, this will set the schema search path to a single-element vector with the diff --git a/readyset-adapter/src/upstream_database.rs b/readyset-adapter/src/upstream_database.rs index 49a0c0cd75..9fe6e57eaf 100644 --- a/readyset-adapter/src/upstream_database.rs +++ b/readyset-adapter/src/upstream_database.rs @@ -110,6 +110,9 @@ pub trait UpstreamDatabase: Sized + Send { database: &str, ) -> Result<(), Self::Error>; + /// Reset the connection to the upstream database + async fn reset(&mut self) -> Result<(), Self::Error>; + /// Returns a database name if it was included in the original connection string, or None if no /// database name was included in the original connection string. fn database(&self) -> Option<&str> { @@ -277,6 +280,10 @@ where .await } + async fn reset(&mut self) -> Result<(), Self::Error> { + self.upstream().await?.reset().await + } + fn database(&self) -> Option<&str> { if let Some(u) = &self.upstream { u.database() diff --git a/readyset-mysql/src/backend.rs b/readyset-mysql/src/backend.rs index d7ad382976..9660bcbc8c 100644 --- a/readyset-mysql/src/backend.rs +++ b/readyset-mysql/src/backend.rs @@ -723,6 +723,14 @@ where let _ = self.noria.remove_statement(statement_id).await; } + async fn on_reset(&mut self) -> io::Result<()> { + let _ = match self.reset().await { + Ok(()) => Ok(()), + Err(e) => Err(io::Error::new(io::ErrorKind::Other, e.to_string())), + }; + Ok(()) + } + async fn on_query( &mut self, query: &str, diff --git a/readyset-mysql/src/upstream.rs b/readyset-mysql/src/upstream.rs index 0adf840018..2331367269 100644 --- a/readyset-mysql/src/upstream.rs +++ b/readyset-mysql/src/upstream.rs @@ -241,6 +241,11 @@ impl UpstreamDatabase for MySqlUpstream { format!("{major}.{minor}.{patch}-readyset\0") } + async fn reset(&mut self) -> Result<(), Self::Error> { + self.conn.reset().await?; + Ok(()) + } + async fn is_connected(&mut self) -> Result { Ok(self.conn.ping().await.is_ok()) } diff --git a/readyset-mysql/tests/fallback.rs b/readyset-mysql/tests/fallback.rs index bddf01ce15..59ef9f7e42 100644 --- a/readyset-mysql/tests/fallback.rs +++ b/readyset-mysql/tests/fallback.rs @@ -926,6 +926,31 @@ async fn replication_failure_ignores_table() { shutdown_tx.shutdown().await; } +#[tokio::test(flavor = "multi_thread")] +#[serial] +async fn reset_user() { + let (opts, _handle, shutdown_tx) = setup().await; + let mut conn = mysql_async::Conn::new(opts).await.unwrap(); + conn.query_drop("CREATE TEMPORARY TABLE t (id INT)") + .await + .unwrap(); + conn.query_drop("INSERT INTO t (id) VALUES (1)") + .await + .unwrap(); + let row_temp_table: Vec = conn.query("SELECT COUNT(*) FROM t").await.unwrap(); + assert_eq!(row_temp_table.len(), 1); + assert_eq!(row_temp_table[0], 1); + conn.reset().await.unwrap(); + let row = conn.query_drop("SELECT COUNT(*) FROM t").await; + + assert_eq!( + row.map_err(|e| e.to_string()), + Err("Server error: `ERROR 42S02 (1146): Table 'noria.t' doesn't exist'".to_string()) + ); + + shutdown_tx.shutdown().await; +} + #[tokio::test(flavor = "multi_thread")] #[serial] #[slow] diff --git a/readyset-psql/src/upstream.rs b/readyset-psql/src/upstream.rs index c3a239a935..fdfb40d161 100644 --- a/readyset-psql/src/upstream.rs +++ b/readyset-psql/src/upstream.rs @@ -230,6 +230,11 @@ impl UpstreamDatabase for PostgreSqlUpstream { }) } + async fn reset(&mut self) -> Result<(), Self::Error> { + self.client.simple_query("DISCARD ALL").await?; + Ok(()) + } + async fn is_connected(&mut self) -> Result { Ok(!self.client.simple_query("select 1").await?.is_empty()) }