Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Improved handling of SSH connection errors to remote hosts #5356

18 changes: 17 additions & 1 deletion deployability/modules/allocation/vagrant/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,19 @@ def status(self) -> str:
str: The status of the instance.
"""
output = self.__run_vagrant_command('status')
return self.__parse_vagrant_status(output)
vagrant_status = self.__parse_vagrant_status(output)
if vagrant_status == None:
if VagrantUtils.remote_command(f"sudo ls {self.host_instance_dir} > /dev/null 2>&1", self.remote_host_parameters):
if VagrantUtils.remote_command(f"sudo /usr/local/bin/prlctl list -a | grep {self.identifier} > /dev/null 2>&1", self.remote_host_parameters):
logger.warning(f"The instance was found, it will be deleted. The creation of the instance must be restarted again.")
self.delete()
else:
VagrantUtils.remote_command(f"sudo rm -rf {self.host_instance_dir}", self.remote_host_parameters)
raise ValueError(f"Instance {self.identifier} is not running, remote instance dir {self.host_instance_dir} was removed.")
else:
raise ValueError(f"Instance {self.host_instance_dir} not found.")
else:
return self.__parse_vagrant_status(output)

def ssh_connection_info(self) -> ConnectionInfo:
"""
Expand Down Expand Up @@ -235,6 +247,10 @@ def __parse_vagrant_status(self, message: str) -> str:
Returns:
str: The parsed status.
"""
if message is None:
logger.error("Received None message when parsing Vagrant status")
return None

lines = message.split('\n')
for line in lines:
if 'Current machine states:' in line:
Expand Down
81 changes: 51 additions & 30 deletions deployability/modules/allocation/vagrant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@
# Created by Wazuh, Inc. <info@wazuh.com>.
# This program is a free software; you can redistribute it and/or modify it under the terms of GPLv2

import time
import subprocess
from pathlib import Path
import logging
import random
import socket

import paramiko

from modules.allocation.generic.utils import logger


Expand All @@ -18,33 +25,40 @@ def remote_command(cls, command: str | list, remote_host_parameters: dict) -> st
Returns:
str: The output of the command.
"""
ssh_command = None
server_ip = remote_host_parameters['server_ip']
ssh_user = remote_host_parameters['ssh_user']
if remote_host_parameters.get('ssh_password'):
ssh_password = remote_host_parameters['ssh_password']
ssh_command = f"sshpass -p {ssh_password} ssh -o 'StrictHostKeyChecking no' {ssh_user}@{server_ip} {command}"
ssh = paramiko.SSHClient()
paramiko.util.get_logger("paramiko").setLevel(logging.WARNING)
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh_parameters = {
'hostname': remote_host_parameters['server_ip'],
'port': 22,
'username': remote_host_parameters['ssh_user']
}
if remote_host_parameters.get('ssh_key'):
ssh_key = remote_host_parameters['ssh_key']
ssh_command = f"ssh -i {ssh_key} {ssh_user}@{server_ip} \"{command}\""

try:
output = subprocess.Popen(f"{ssh_command}",
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stdout_data, stderr_data = output.communicate() # Capture stdout and stderr data
stdout_text = stdout_data.decode('utf-8') if stdout_data else "" # Decode stdout bytes to string
stderr_text = stderr_data.decode('utf-8') if stderr_data else "" # Decode stderr bytes to string
ssh_parameters['key_filename'] = remote_host_parameters['ssh_key']
else:
ssh_parameters['password'] = remote_host_parameters['ssh_password']

if stderr_text:
logger.error(f"Command failed: {stderr_text}")
return None
max_retry = 3
ssh_exceptions = (subprocess.CalledProcessError, paramiko.AuthenticationException, paramiko.SSHException, socket.timeout, ConnectionResetError)
for attempt in range(max_retry):
try:
ssh.connect(**ssh_parameters)
stdin_data, stdout_data, stderr_data = ssh.exec_command(command, timeout = 300)
stdout_text = stdout_data.read().decode('utf-8')

return stdout_text
except subprocess.CalledProcessError as e:
logger.error(f"Command failed: {e.stderr.decode('utf-8')}")
return None
ssh.close()
return stdout_text
except ssh_exceptions as e:
if attempt < max_retry - 1:
logger.warning(f"SSH connection error: {str(e)}. Retrying in 30 seconds...")
time.sleep(30)
continue
else:
ssh.close()
raise ValueError(f"Remote command execution failed: {str(e)}")
except Exception as e:
ssh.close()
raise ValueError(f"An unexpected error occurred when executing the remote command: {str(e)}")

@classmethod
def remote_copy(cls, instance_dir: Path, host_instance_dir: Path, remote_host_parameters: dict) -> str:
Expand Down Expand Up @@ -100,9 +114,16 @@ def get_port(cls, remote_host_parameters: dict, arch: str = None) -> int:

raise ValueError(f"ppc64 server has no available SSH ports.")
else:
for i in range(20, 40):
port = f"432{i}"
cmd = f"sudo lsof -i:{port}"
output = cls.remote_command(cmd, remote_host_parameters)
if not output:
return port
used_ports = []
all_ports = [f"432{i}" for i in range(20, 40)]
random.shuffle(all_ports)
for port in all_ports:
if port not in used_ports:
cmd = f"sudo lsof -i:{port}"
output = cls.remote_command(cmd, remote_host_parameters)
if not output:
return port
else:
used_ports.append(port)
else:
raise ValueError(f"The server has no available ports in the range 43220 to 43240.")