Apply minor changes due to test requirements
fdalmaup committed May 14, 2024
1 parent 3a1f81e commit 9615ed8
Showing 4 changed files with 35 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/wazuh_testing/modules/aws/data_generator.py
@@ -1159,7 +1159,7 @@ def get_data_sample(self):
class UmbrellaDataGenerator(DataGenerator):
def __init__(self, date: datetime, region: str, **kwargs) -> None:
super().__init__(date, region)
-        self.base_path = join(kwargs['prefix'], 'dnslogs')
+        self.base_path = join(kwargs['prefix'])
self.base_file_name = ''

def get_filename(self) -> str:
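
For context, this change means the Umbrella data generator now builds its object keys directly under the configured prefix instead of a fixed dnslogs/ subfolder. A minimal sketch of the resulting base paths, assuming a hypothetical prefix value:

from os.path import join

prefix = 'umbrella-logs'                 # hypothetical prefix passed in kwargs
old_base_path = join(prefix, 'dnslogs')  # previous behaviour: 'umbrella-logs/dnslogs'
new_base_path = join(prefix)             # new behaviour: 'umbrella-logs'
print(old_base_path, new_base_path)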
1 change: 1 addition & 0 deletions src/wazuh_testing/modules/aws/patterns.py
@@ -38,6 +38,7 @@
MARKER = ".*DEBUG: \+\+\+ Marker: "
AWS_MODULE_STARTED = ".*DEBUG: Launching S3 Command: .*"
AWS_MODULE_STARTED_PARAMETRIZED = ".*DEBUG: Launching S3 Command: "
+REMOVE_S3_FILE = ".*DEBUG: \+\+\+ Remove file from S3 Bucket:"

# Logs
NEW_LOG_FOUND = ".*Found new log: .*"
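
The new REMOVE_S3_FILE entry is a regular expression meant to match the module's debug message when a processed file is removed from the bucket. A minimal sketch of checking it against a hypothetical log line (the exact message text is an assumption):

import re

REMOVE_S3_FILE = r".*DEBUG: \+\+\+ Remove file from S3 Bucket:"
line = "wazuh-modulesd:aws-s3: DEBUG: +++ Remove file from S3 Bucket: logs/2024/05/14/sample.log.gz"  # hypothetical log line
assert re.match(REMOVE_S3_FILE, line) is not None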
35 changes: 31 additions & 4 deletions src/wazuh_testing/modules/aws/utils.py
@@ -187,6 +187,27 @@ def get_last_file_key(bucket_type: str, bucket_name: str, execution_datetime: da
last_key = ''
return last_key


+def file_exists(filename, bucket_name, client):
+    """Check if a file exists in a bucket.
+    Args:
+        filename (str): Full filename to check.
+        bucket_name (str): Bucket that contains the file.
+        client (boto3.resources.base.ServiceResource): S3 client to access the bucket.
+    Returns:
+        bool: True if exists else False.
+    """
+    exists = True
+    try:
+        client.Object(bucket_name, filename).load()
+    except ClientError as error:
+        if error.response['Error']['Code'] == '404':
+            exists = False
+
+    return exists
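
A minimal usage sketch for the new helper, assuming a boto3 S3 service resource and hypothetical bucket and key names. Note that although the parameter is named client, the call client.Object(bucket, key).load() expects a boto3 resource rather than a low-level client:

import boto3

s3 = boto3.resource('s3')  # service resource, as required by client.Object(...).load()
if file_exists('AWSLogs/2024/05/14/sample.log.gz', 'my-test-bucket', s3):  # hypothetical key and bucket
    print('Object found in bucket')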


"""VPC related utils"""

def create_vpc(vpc_name: str, client) -> str:
@@ -220,14 +241,16 @@ def create_vpc(vpc_name: str, client) -> str:
logger.error(f"Found a problem creating a VPC: {error}.")


-def delete_vpc(vpc_id: str, client) -> None:
-    """Delete a VPC.
+def delete_vpc(vpc_id: str, flow_log_id: str, client) -> None:
+    """Delete a VPC and its inner flow logs.
Args:
vpc_id (str): Id of the VPC to delete.
+        flow_log_id (str): Id of the Flow Log to delete.
client (Service client instance): EC2 client to delete the VPC and its inner resources.
"""
try:
+        client.delete_flow_logs(FlowLogIds=[flow_log_id])
client.delete_vpc(VpcId=vpc_id)
except ClientError as error:
raise error
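
A minimal call sketch for the updated signature, assuming hypothetical resource IDs and a boto3 EC2 client; the helper now removes the flow log first and then the VPC:

import boto3

ec2 = boto3.client('ec2')
delete_vpc(vpc_id='vpc-0123456789abcdef0',      # hypothetical VPC id
           flow_log_id='fl-0123456789abcdef0',  # hypothetical flow log id
           client=ec2)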
@@ -622,7 +645,7 @@ def _default_callback(line: str):


def analyze_command_output(
-    command_output, callback=_default_callback, expected_results=1, error_message=''
+    command_output, callback=_default_callback, expected_results=1, error_message='', match_exact_number=True
):
"""Analyze a given command output searching for a pattern.
@@ -631,6 +654,7 @@
callback (Callable): A callback to process each line. Defaults to _default_callback.
expected_results (int): Number of expected results. Defaults to 1.
error_message (str): Message to show with the exception. Defaults to ''.
+        match_exact_number (bool): Determine if expected_results should exactly match the number of results found.
Raises:
OutputAnalysisError: When the expected results are not correct.
@@ -647,7 +671,10 @@

results_len = len(results)

-    if results_len != expected_results:
+    if not match_exact_number and results_len:
+        return
+
+    if results_len != expected_results and match_exact_number:
if error_message:
logger.error(error_message)
logger.error(RESULTS_FOUND, results_len)
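
A minimal usage sketch of the new match_exact_number flag, assuming command_output is iterated line by line and using a hypothetical callback; with the flag set to False, any non-zero number of matches passes even if it differs from expected_results:

output_lines = [
    'DEBUG: Found new log: a.log',  # hypothetical module output
    'DEBUG: Found new log: b.log',
]

analyze_command_output(
    command_output=output_lines,
    callback=lambda line: line if 'Found new log' in line else None,
    expected_results=1,
    error_message='No new logs were found',
    match_exact_number=False,
)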
3 changes: 2 additions & 1 deletion src/wazuh_testing/utils/db_queries/aws_db.py
@@ -131,7 +131,8 @@ def get_s3_db_row(table_name) -> S3CloudTrailRow:
"""
row_type = _get_s3_row_type(table_name)
query = SELECT_QUERY_TEMPLATE.format(table_name=table_name)
-    row = get_sqlite_fetch_one_query_result(S3_CLOUDTRAIL_DB_PATH, query)[0]
+    row = get_sqlite_fetch_one_query_result(S3_CLOUDTRAIL_DB_PATH, query)
+    print(f"ROW {row} type {type(row)}")
return row_type(*row)
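
For context, the query helper appears to return the full fetched row now (the previous code took only its first element), and that tuple is unpacked into the corresponding row type. A minimal sketch with a hypothetical row type and values:

from collections import namedtuple

S3Row = namedtuple('S3Row', ['bucket_path', 'aws_account_id', 'log_key', 'processed_date', 'created_date'])  # hypothetical fields
row = ('bucket/AWSLogs/', '123456789012', 'logs/sample.json.gz', '2024-05-14 00:00:00', 20240514)  # hypothetical sqlite row
print(S3Row(*row))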


