@@ -615,7 +615,7 @@ async def channel():
615
615
writer .get_extra_info = dict (peername = remote_addr , sockname = remote_addr ).get
616
616
return reader , writer
617
617
async def wait_ssh_connection (self , local_addr = None , family = 0 , tunnel = None ):
618
- if self .sshconn is not None :
618
+ if self .sshconn is not None and not self . sshconn . cancelled () :
619
619
if not self .sshconn .done ():
620
620
await self .sshconn
621
621
else :
@@ -633,18 +633,24 @@ async def wait_ssh_connection(self, local_addr=None, family=0, tunnel=None):
633
633
conn = await asyncssh .connect (host = self .host_name , port = self .port , local_addr = local_addr , family = family , x509_trusted_certs = None , known_hosts = None , username = username , password = password , client_keys = client_keys , keepalive_interval = 60 , tunnel = tunnel )
634
634
self .sshconn .set_result (conn )
635
635
async def wait_open_connection (self , host , port , local_addr , family , tunnel = None ):
636
- await self .wait_ssh_connection (local_addr , family , tunnel )
637
- conn = self .sshconn .result ()
638
- if isinstance (self .jump , ProxySSH ):
639
- reader , writer = await self .jump .wait_open_connection (host , port , None , None , conn )
640
- else :
641
- host , port = self .jump .destination (host , port )
642
- if self .jump .unix :
643
- reader , writer = await conn .open_unix_connection (self .jump .bind )
636
+ try :
637
+ await self .wait_ssh_connection (local_addr , family , tunnel )
638
+ conn = self .sshconn .result ()
639
+ if isinstance (self .jump , ProxySSH ):
640
+ reader , writer = await self .jump .wait_open_connection (host , port , None , None , conn )
644
641
else :
645
- reader , writer = await conn .open_connection (host , port )
646
- reader , writer = self .patch_stream (reader , writer , host , port )
647
- return reader , writer
642
+ host , port = self .jump .destination (host , port )
643
+ if self .jump .unix :
644
+ reader , writer = await conn .open_unix_connection (self .jump .bind )
645
+ else :
646
+ reader , writer = await conn .open_connection (host , port )
647
+ reader , writer = self .patch_stream (reader , writer , host , port )
648
+ return reader , writer
649
+ except Exception as ex :
650
+ if not self .sshconn .done ():
651
+ self .sshconn .set_exception (ex )
652
+ self .sshconn = None
653
+ raise
648
654
async def start_server (self , args , stream_handler = stream_handler , tunnel = None ):
649
655
if type (self .jump ) is ProxyDirect :
650
656
raise Exception ('ssh server mode unsupported' )
0 commit comments