3636
3737use std:: cell:: RefCell ;
3838use std:: collections:: { HashMap , VecDeque } ;
39+ use std:: convert:: TryInto ;
3940use std:: num:: Wrapping ;
4041use std:: pin:: Pin ;
4142use std:: sync:: Arc ;
@@ -49,7 +50,7 @@ use russh_keys::encoding::Reader;
4950#[ cfg( feature = "openssl" ) ]
5051use russh_keys:: key:: SignatureHash ;
5152use russh_keys:: key:: { self , parse_public_key, PublicKey } ;
52- use tokio:: io:: { AsyncRead , AsyncWrite , AsyncWriteExt } ;
53+ use tokio:: io:: { AsyncRead , AsyncWrite , AsyncWriteExt , ReadHalf , WriteHalf } ;
5354use tokio:: net:: { TcpStream , ToSocketAddrs } ;
5455use tokio:: pin;
5556use tokio:: sync:: mpsc:: {
@@ -195,6 +196,19 @@ pub struct Prompt {
195196 pub echo : bool ,
196197}
197198
199+ #[ derive( Debug ) ]
200+ pub struct RemoteDisconnectInfo {
201+ pub reason_code : crate :: Disconnect ,
202+ pub message : String ,
203+ pub lang_tag : String ,
204+ }
205+
206+ #[ derive( Debug ) ]
207+ pub enum DisconnectReason < E : From < crate :: Error > + Send > {
208+ ReceivedDisconnect ( RemoteDisconnectInfo ) ,
209+ Error ( E ) ,
210+ }
211+
198212/// Handle to a session, used to send messages to a client outside of
199213/// the request/response cycle.
200214pub struct Handle < H : Handler > {
@@ -737,23 +751,59 @@ impl Session {
737751
738752 async fn run < H : Handler + Send , R : AsyncRead + AsyncWrite + Unpin + Send > (
739753 mut self ,
740- mut stream : SshRead < R > ,
754+ stream : SshRead < R > ,
741755 mut handler : H ,
742- mut encrypted_signal : Option < tokio:: sync:: oneshot:: Sender < ( ) > > ,
756+ encrypted_signal : Option < tokio:: sync:: oneshot:: Sender < ( ) > > ,
743757 ) -> Result < ( ) , H :: Error > {
758+ let ( stream_read, mut stream_write) = stream. split ( ) ;
759+ let result = self
760+ . run_inner (
761+ stream_read,
762+ & mut stream_write,
763+ & mut handler,
764+ encrypted_signal,
765+ )
766+ . await ;
767+ trace ! ( "disconnected" ) ;
768+ self . receiver . close ( ) ;
769+ self . inbound_channel_receiver . close ( ) ;
770+ stream_write. shutdown ( ) . await . map_err ( crate :: Error :: from) ?;
771+ match result {
772+ Ok ( v) => {
773+ handler
774+ . disconnected ( DisconnectReason :: ReceivedDisconnect ( v) )
775+ . await ?;
776+ Ok ( ( ) )
777+ }
778+ Err ( e) => {
779+ handler. disconnected ( DisconnectReason :: Error ( e) ) . await ?;
780+ //Err(e)
781+ Ok ( ( ) )
782+ }
783+ }
784+ }
785+
786+ async fn run_inner < H : Handler + Send , R : AsyncRead + AsyncWrite + Unpin + Send > (
787+ & mut self ,
788+ stream_read : SshRead < ReadHalf < R > > ,
789+ stream_write : & mut WriteHalf < R > ,
790+ handler : & mut H ,
791+ mut encrypted_signal : Option < tokio:: sync:: oneshot:: Sender < ( ) > > ,
792+ ) -> Result < RemoteDisconnectInfo , H :: Error > {
793+ let mut result: Result < RemoteDisconnectInfo , H :: Error > =
794+ Err ( crate :: Error :: Disconnect . into ( ) ) ;
744795 self . flush ( ) ?;
745796 if !self . common . write_buffer . buffer . is_empty ( ) {
746797 debug ! ( "writing {:?} bytes" , self . common. write_buffer. buffer. len( ) ) ;
747- stream
798+ stream_write
748799 . write_all ( & self . common . write_buffer . buffer )
749800 . await
750801 . map_err ( crate :: Error :: from) ?;
751- stream . flush ( ) . await . map_err ( crate :: Error :: from) ?;
802+ stream_write . flush ( ) . await . map_err ( crate :: Error :: from) ?;
752803 }
753804 self . common . write_buffer . buffer . clear ( ) ;
754805 let mut decomp = CryptoVec :: new ( ) ;
755806
756- let ( stream_read, mut stream_write) = stream. split ( ) ;
757807 let buffer = SSHBuffer :: new ( ) ;
758808
759809 // Allow handing out references to the cipher
@@ -805,10 +855,10 @@ impl Session {
805855 if !buf. is_empty( ) {
806856 #[ allow( clippy:: indexing_slicing) ] // length checked
807857 if buf[ 0 ] == crate :: msg:: DISCONNECT {
808- break ;
858+ result = self . process_disconnect ( buf ) ;
809859 } else {
810860 self . common. received_data = true ;
811- reply( & mut self , & mut handler, & mut encrypted_signal, & mut buffer. seqn, buf) . await ?;
861+ reply( self , handler, & mut encrypted_signal, & mut buffer. seqn, buf) . await ?;
812862 }
813863 }
814864
@@ -906,12 +956,30 @@ impl Session {
906956 }
907957 }
908958 }
909- debug ! ( "disconnected" ) ;
910- self . receiver . close ( ) ;
911- self . inbound_channel_receiver . close ( ) ;
912- stream_write. shutdown ( ) . await . map_err ( crate :: Error :: from) ?;
913959
914- Ok ( ( ) )
960+ result
961+ }
962+
963+ fn process_disconnect < E : From < crate :: Error > + Send > (
964+ & mut self ,
965+ buf : & [ u8 ] ,
966+ ) -> Result < RemoteDisconnectInfo , E > {
967+ self . common . disconnected = true ;
968+ let mut reader = buf. reader ( 1 ) ;
969+
970+ let reason_code = reader. read_u32 ( ) . map_err ( crate :: Error :: from) ?. try_into ( ) ?;
971+ let message = std:: str:: from_utf8 ( reader. read_string ( ) . map_err ( crate :: Error :: from) ?)
972+ . map_err ( crate :: Error :: from) ?
973+ . to_owned ( ) ;
974+ let lang_tag = std:: str:: from_utf8 ( reader. read_string ( ) . map_err ( crate :: Error :: from) ?)
975+ . map_err ( crate :: Error :: from) ?
976+ . to_owned ( ) ;
977+
978+ Ok ( RemoteDisconnectInfo {
979+ reason_code,
980+ message,
981+ lang_tag,
982+ } )
915983 }
916984
917985 fn handle_msg ( & mut self , msg : Msg ) -> Result < ( ) , crate :: Error > {
@@ -1360,7 +1428,7 @@ impl Default for Config {
13601428
13611429#[ async_trait]
13621430pub trait Handler : Sized + Send {
1363- type Error : From < crate :: Error > + Send ;
1431+ type Error : From < crate :: Error > + Send + core :: fmt :: Debug ;
13641432
13651433 /// Called when the server sends us an authentication banner. This
13661434 /// is usually meant to be shown to the user, see
@@ -1620,4 +1688,19 @@ pub trait Handler: Sized + Send {
16201688 debug ! ( "openssh_ext_hostkeys_announced: {:?}" , keys) ;
16211689 Ok ( ( ) )
16221690 }
1691+
1692+ /// Called when the server sent a disconnect message
1693+ ///
1694+ /// If reason is an Error, this function should re-return the error so the join can also evaluate it
1695+ #[ allow( unused_variables) ]
1696+ async fn disconnected (
1697+ & mut self ,
1698+ reason : DisconnectReason < Self :: Error > ,
1699+ ) -> Result < ( ) , Self :: Error > {
1700+ debug ! ( "disconnected: {:?}" , reason) ;
1701+ match reason {
1702+ DisconnectReason :: ReceivedDisconnect ( _) => Ok ( ( ) ) ,
1703+ DisconnectReason :: Error ( e) => Err ( e) ,
1704+ }
1705+ }
16231706}
0 commit comments