[COST-4589] Correctly get end of month data for Azure Metering (#4897)
* move SQL to files per provider, case statement for Azure to handle EoM scenario
cgoodfred authored Feb 16, 2024
1 parent 89bb16a commit a7d38ed
Showing 10 changed files with 272 additions and 184 deletions.
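The refactor visible below replaces per-provider SQL fragments that were assembled from Python dictionaries with complete SQL files under `koku/subs/trino_sql/<provider>/`, loaded by provider type at runtime. A minimal sketch of the loading pattern, using the names that appear in the diff (the schema and source UUID values are placeholders):

```python
import pkgutil


def load_provider_sql(provider_type: str, filename: str) -> str:
    """Load a provider-specific Trino SQL template packaged with the subs module."""
    # e.g. "trino_sql/aws/determine_ids_for_provider.sql"
    #  or  "trino_sql/azure/subs_row_count.sql"
    sql_file = f"trino_sql/{provider_type.lower()}/{filename}"
    return pkgutil.get_data("subs", sql_file).decode("utf-8")


# In the extractor, the template is passed with its parameters; execution is
# delegated to SUBSDataExtractor._execute_trino_raw_sql_query in the real code.
sql = load_provider_sql("AWS", "determine_ids_for_provider.sql")
sql_params = {
    "schema": "acct10001",        # placeholder schema
    "source_uuid": "uuid-here",   # placeholder source UUID
    "year": "2024",
    "month": "02",
    "excluded_ids": [],
}
```

Because each provider now owns a full SQL file, provider-specific behavior such as Azure's end-of-month handling lives in the SQL templates rather than in string-concatenated clauses.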
169 changes: 23 additions & 146 deletions koku/subs/subs_data_extractor.py
@@ -21,92 +21,10 @@
from masu.util.aws.common import get_s3_resource
from reporting.models import SubsIDMap
from reporting.models import SubsLastProcessed
from reporting.provider.aws.models import TRINO_LINE_ITEM_TABLE as AWS_TABLE
from reporting.provider.azure.models import TRINO_LINE_ITEM_TABLE as AZURE_TABLE


LOG = logging.getLogger(__name__)

TABLE_MAP = {
Provider.PROVIDER_AWS: AWS_TABLE,
Provider.PROVIDER_AZURE: AZURE_TABLE,
}

ID_COLUMN_MAP = {
Provider.PROVIDER_AWS: "lineitem_usageaccountid",
Provider.PROVIDER_AZURE: "COALESCE(NULLIF(subscriptionid, ''), subscriptionguid)",
}

RECORD_FILTER_MAP = {
Provider.PROVIDER_AWS: (
" lineitem_productcode = 'AmazonEC2' AND lineitem_lineitemtype IN ('Usage', 'SavingsPlanCoveredUsage') "
"AND product_vcpu != '' AND strpos(lower(resourcetags), 'com_redhat_rhel') > 0"
),
Provider.PROVIDER_AZURE: (
" metercategory = 'Virtual Machines' AND chargetype = 'Usage' "
"AND json_extract_scalar(lower(additionalinfo), '$.vcpus') IS NOT NULL "
"AND json_extract_scalar(lower(tags), '$.com_redhat_rhel') IS NOT NULL"
),
}

RESOURCE_ID_FILTER_MAP = {
Provider.PROVIDER_AWS: (
" AND lineitem_productcode = 'AmazonEC2' "
"AND strpos(lower(resourcetags), 'com_redhat_rhel') > 0 AND lineitem_usageaccountid = {{usage_account}}"
),
Provider.PROVIDER_AZURE: (
" AND metercategory = 'Virtual Machines' "
"AND json_extract_scalar(lower(additionalinfo), '$.vcpus') IS NOT NULL "
"AND json_extract_scalar(lower(tags), '$.com_redhat_rhel') IS NOT NULL "
"AND (subscriptionid = {{usage_account}} or subscriptionguid = {{usage_account}}) "
),
}

RESOURCE_SELECT_MAP = {
Provider.PROVIDER_AWS: " SELECT lineitem_resourceid, max(lineitem_usagestartdate) ",
Provider.PROVIDER_AZURE: " SELECT coalesce(NULLIF(resourceid, ''), instanceid), date_add('day', -1, max(coalesce(date, usagedatetime))) ", # noqa E501
}

RESOURCE_ID_GROUP_BY_MAP = {
Provider.PROVIDER_AWS: " GROUP BY lineitem_resourceid",
Provider.PROVIDER_AZURE: " GROUP BY resourceid, instanceid",
}

RESOURCE_ID_EXCLUSION_CLAUSE_MAP = {
Provider.PROVIDER_AWS: " AND lineitem_resourceid NOT IN {{excluded_ids | inclause}} ",
Provider.PROVIDER_AZURE: " and coalesce(NULLIF(resourceid, ''), instanceid) NOT IN {{excluded_ids | inclause}} ",
}

RESOURCE_ID_SQL_CLAUSE_MAP = {
Provider.PROVIDER_AWS: (
" ( lineitem_resourceid = {{{{ rid_{0} }}}} "
" AND lineitem_usagestartdate >= {{{{ start_date_{0} }}}} "
" AND lineitem_usagestartdate <= {{{{ end_date_{0} }}}}) "
),
Provider.PROVIDER_AZURE: (
" ( coalesce(NULLIF(resourceid, ''), instanceid) = {{{{ rid_{0} }}}} "
"AND coalesce(date, usagedatetime) >= {{{{ start_date_{0} }}}} "
"AND coalesce(date, usagedatetime) <= {{{{ end_date_{0} }}}}) "
),
}

POST_OR_CLAUSE_SQL_MAP = {
Provider.PROVIDER_AWS: """
OFFSET
{{ offset }}
LIMIT
{{ limit }}
)
WHERE json_extract_scalar(tags, '$.com_redhat_rhel') IS NOT NULL
""",
Provider.PROVIDER_AZURE: """
OFFSET
{{ offset }}
LIMIT
{{ limit }}
""",
}


class SUBSDataExtractor(ReportDBAccessorBase):
def __init__(self, tracing_id, context):
@@ -125,16 +43,6 @@ def __init__(self, tracing_id, context):
settings.S3_SUBS_ACCESS_KEY, settings.S3_SUBS_SECRET, settings.S3_SUBS_REGION
)
self.context = context
# The following variables all change depending on the provider type to run the correct SQL
self.table = TABLE_MAP.get(self.provider_type)
self.id_column = ID_COLUMN_MAP.get(self.provider_type)
self.provider_where_clause = RECORD_FILTER_MAP.get(self.provider_type)
self.resource_select_sql = RESOURCE_SELECT_MAP.get(self.provider_type)
self.resource_id_where_clause = RESOURCE_ID_FILTER_MAP.get(self.provider_type)
self.resource_id_group_by = RESOURCE_ID_GROUP_BY_MAP.get(self.provider_type)
self.resource_id_sql_clause = RESOURCE_ID_SQL_CLAUSE_MAP.get(self.provider_type)
self.resource_id_exclusion_clause = RESOURCE_ID_EXCLUSION_CLAUSE_MAP.get(self.provider_type)
self.post_or_clause_sql = POST_OR_CLAUSE_SQL_MAP.get(self.provider_type)

@cached_property
def subs_s3_path(self):
@@ -176,20 +84,15 @@ def determine_ids_for_provider(self, year, month):
excluded_ids = list(
SubsIDMap.objects.exclude(source_uuid=self.provider_uuid).values_list("usage_id", flat=True)
)
sql = (
"SELECT DISTINCT {{id_column | sqlsafe}} FROM hive.{{schema | sqlsafe}}.{{table | sqlsafe}} WHERE"
" source={{source_uuid}} AND year={{year}} AND month={{month}}"
)
if excluded_ids:
sql += " AND {{id_column | sqlsafe}} NOT IN {{excluded_ids | inclause}}"
sql_file = f"trino_sql/{self.provider_type.lower()}/determine_ids_for_provider.sql"
sql = pkgutil.get_data("subs", sql_file)
sql = sql.decode("utf-8")
sql_params = {
"schema": self.schema,
"source_uuid": self.provider_uuid,
"year": year,
"month": month,
"excluded_ids": excluded_ids,
"id_column": self.id_column,
"table": self.table,
}
ids = self._execute_trino_raw_sql_query(
sql, sql_params=sql_params, context=self.context, log_ref="subs_determine_ids_for_provider"
@@ -202,49 +105,31 @@ def determine_ids_for_provider(self, year, month):
SubsIDMap.objects.bulk_create(bulk_maps, ignore_conflicts=True)
return id_list

def determine_line_item_count(self, where_clause, sql_params):
"""Determine the number of records in the table that have not been processed and match the criteria"""
table_count_sql = f"SELECT count(*) FROM {self.schema}.{self.table} {where_clause}"
count = self._execute_trino_raw_sql_query(
table_count_sql, sql_params=sql_params, log_ref="determine_subs_processing_count"
)
def determine_row_count(self, sql_params):
"""Determine the number of records in the table that have not been processed and match the criteria."""
sql_file = f"trino_sql/{self.provider_type.lower()}/subs_row_count.sql"
sql = pkgutil.get_data("subs", sql_file)
sql = sql.decode("utf-8")
count = self._execute_trino_raw_sql_query(sql, sql_params=sql_params, log_ref="determine_subs_row_count")
return count[0][0]

def determine_where_clause_and_params(self, year, month):
"""Determine the where clause to use when processing subs data"""
where_clause = "WHERE source={{source_uuid}} AND year={{year}} AND month={{month}} AND"
# different provider types have different required filters here
where_clause += self.provider_where_clause
sql_params = {
"source_uuid": self.provider_uuid,
"year": year,
"month": month,
}
return where_clause, sql_params

def get_resource_ids_for_usage_account(self, usage_account, year, month):
"""Determine the relevant resource ids and end time to process to for each resource id."""
with schema_context(self.schema):
# get a list of IDs to exclude from this source processing
excluded_ids = list(
SubsLastProcessed.objects.exclude(source_uuid=self.provider_uuid).values_list("resource_id", flat=True)
)
sql = self.resource_select_sql + (
" FROM hive.{{schema | sqlsafe}}.{{table | sqlsafe}} WHERE"
" source={{source_uuid}} AND year={{year}} AND month={{month}}"
)
sql += self.resource_id_where_clause
if excluded_ids:
sql += self.resource_id_exclusion_clause
sql += self.resource_id_group_by
sql_file = f"trino_sql/{self.provider_type.lower()}/determine_resource_ids_for_usage_account.sql"
sql = pkgutil.get_data("subs", sql_file)
sql = sql.decode("utf-8")
sql_params = {
"schema": self.schema,
"source_uuid": self.provider_uuid,
"year": year,
"month": month,
"excluded_ids": excluded_ids,
"usage_account": usage_account,
"table": self.table,
}
ids = self._execute_trino_raw_sql_query(
sql, sql_params=sql_params, context=self.context, log_ref="subs_determine_rids_for_provider"
@@ -253,33 +138,25 @@ def get_resource_ids_for_usage_account(self, usage_account, year, month):

def gather_and_upload_for_resource_batch(self, year, month, batch, base_filename):
"""Gather the data and upload it to S3 for a batch of resource ids"""
where_clause, sql_params = self.determine_where_clause_and_params(year, month)
sql_file = f"trino_sql/{self.provider_type.lower()}_subs_pre_or_clause.sql"
sql_params = {
"source_uuid": self.provider_uuid,
"year": year,
"month": month,
"schema": self.schema,
"resources": batch,
}
sql_file = f"trino_sql/{self.provider_type.lower()}/subs_summary.sql"
summary_sql = pkgutil.get_data("subs", sql_file)
summary_sql = summary_sql.decode("utf-8")
rid_sql_clause = " AND ( "
for i, e in enumerate(batch):
rid, start_time, end_time = e
sql_params[f"rid_{i}"] = rid
sql_params[f"start_date_{i}"] = start_time
sql_params[f"end_date_{i}"] = end_time
rid_sql_clause += self.resource_id_sql_clause.format(i)
if i < len(batch) - 1:
rid_sql_clause += " OR "
rid_sql_clause += " )"
where_clause += rid_sql_clause
summary_sql += rid_sql_clause
summary_sql += self.post_or_clause_sql
total_count = self.determine_line_item_count(where_clause, sql_params)
total_count = self.determine_row_count(sql_params)
LOG.debug(
log_json(
self.tracing_id,
msg=f"identified {total_count} matching records for metered rhel",
context=self.context | {"resource_ids": [rid for rid, _, _ in batch]},
context=self.context | {"resource_ids": [row["rid"] for row in batch]},
)
)
upload_keys = []
sql_params["schema"] = self.schema
for i, offset in enumerate(range(0, total_count, settings.PARQUET_PROCESSING_BATCH_SIZE)):
sql_params["offset"] = offset
sql_params["limit"] = settings.PARQUET_PROCESSING_BATCH_SIZE
@@ -359,7 +236,7 @@ def extract_data_to_s3(self, month_start):
)
for rid, end_time in resource_ids:
start_time = max(last_processed_dict.get(rid, month_start), self.creation_processing_time)
batch.append((rid, start_time, end_time))
batch.append({"rid": rid, "start": start_time, "end": end_time})
if len(batch) >= 100:
upload_keys.extend(
self.gather_and_upload_for_resource_batch(year, month, batch, f"{base_filename}_{batch_num}")
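The batch handed to `gather_and_upload_for_resource_batch` also changes from positional tuples to dicts with `rid`, `start`, and `end` keys, which the new SQL templates iterate over directly via `{% for item in resources %}` (see `subs_row_count.sql` below). A minimal sketch of how the batch and template parameters are built, mirroring the diff (schema and UUID values are placeholders, and the `creation_processing_time` bound is omitted):

```python
import datetime

month_start = datetime.datetime(2024, 2, 1, tzinfo=datetime.timezone.utc)

# (resource_id, end_time) pairs, as returned by get_resource_ids_for_usage_account
resource_ids = [
    ("i-0abc123", datetime.datetime(2024, 2, 29, tzinfo=datetime.timezone.utc)),
]
last_processed_dict = {}  # resource_id -> last processed timestamp, if any

batch = []
for rid, end_time in resource_ids:
    # start from the later of "last processed" and the month start
    start_time = last_processed_dict.get(rid, month_start)
    batch.append({"rid": rid, "start": start_time, "end": end_time})

sql_params = {
    "schema": "acct10001",        # placeholder schema
    "source_uuid": "uuid-here",   # placeholder source UUID
    "year": "2024",
    "month": "02",
    "resources": batch,           # consumed by {% for item in resources %} in the templates
}
```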
56 changes: 18 additions & 38 deletions koku/subs/test/test_subs_data_extractor.py
@@ -84,29 +84,11 @@ def test_determine_latest_processed_time_for_provider_without_return_value(self)
self.assertIsNone(actual)

@patch("subs.subs_data_extractor.SUBSDataExtractor._execute_trino_raw_sql_query")
def test_determine_line_item_count(self, mock_trino):
def test_determine_row_count(self, mock_trino):
"""Test determining the line item count for the subs query calls trino"""
self.extractor.determine_line_item_count("fake where clause", {"fake": "params"})
self.extractor.determine_row_count({"fake": "params"})
mock_trino.assert_called_once()

def test_determine_where_clause_and_params(self):
"""Test resulting where clause and params matches expected values"""
year = "2023"
month = "07"
expected_sql_params = {
"source_uuid": self.aws_provider.uuid,
"year": year,
"month": month,
}
expected_clause = (
"WHERE source={{source_uuid}} AND year={{year}} AND month={{month}} AND"
" lineitem_productcode = 'AmazonEC2' AND lineitem_lineitemtype IN ('Usage', 'SavingsPlanCoveredUsage') AND"
" product_vcpu != '' AND strpos(lower(resourcetags), 'com_redhat_rhel') > 0"
)
actual_clause, actual_params = self.extractor.determine_where_clause_and_params(year, month)
self.assertEqual(expected_clause, actual_clause)
self.assertEqual(expected_sql_params, actual_params)

@patch("subs.subs_data_extractor.SUBSDataExtractor.bulk_update_latest_processed_time")
@patch("subs.subs_data_extractor.SUBSDataExtractor.gather_and_upload_for_resource_batch")
@patch("subs.subs_data_extractor.SUBSDataExtractor.get_resource_ids_for_usage_account")
@@ -176,9 +158,8 @@ def test_extract_data_to_s3_no_resource_ids_found(

@patch("subs.subs_data_extractor.SUBSDataExtractor.copy_data_to_subs_s3_bucket")
@patch("subs.subs_data_extractor.SUBSDataExtractor._execute_trino_raw_sql_query_with_description")
@patch("subs.subs_data_extractor.SUBSDataExtractor.determine_line_item_count")
@patch("subs.subs_data_extractor.SUBSDataExtractor.determine_where_clause_and_params")
def test_gather_and_upload_for_resource_batch(self, mock_where_clause, mock_li_count, mock_trino, mock_copy):
@patch("subs.subs_data_extractor.SUBSDataExtractor.determine_row_count")
def test_gather_and_upload_for_resource_batch(self, mock_row_count, mock_trino, mock_copy):
"""Test gathering data and uploading it to S3 calls the right functions and returns the right value."""
self.dh.month_start(self.yesterday)
rid = "12345"
@@ -187,28 +168,26 @@ def test_gather_and_upload_for_resource_batch(self, mock_where_clause, mock_li_c
rid_2 = "23456"
start_time = datetime.datetime(2023, 4, 3, tzinfo=datetime.timezone.utc)
end_time = datetime.datetime(2023, 4, 5, tzinfo=datetime.timezone.utc)
batch = [(rid, start_time, end_time), (rid_2, start_time, end_time)]
mock_li_count.return_value = 10
batch = [
{"rid": rid, "start": start_time, "end": end_time},
{"rid": rid_2, "start": start_time, "end": end_time},
]
mock_row_count.return_value = 10
expected_key = "fake_key"
base_filename = "fake_filename"
mock_copy.return_value = expected_key
mock_trino.return_value = (MagicMock(), MagicMock())
mock_where_clause.return_value = (MagicMock(), MagicMock())
upload_keys = self.extractor.gather_and_upload_for_resource_batch(year, month, batch, base_filename)
mock_where_clause.assert_called_once()
mock_li_count.assert_called_once()
mock_row_count.assert_called_once()
mock_trino.assert_called_once()
mock_copy.assert_called_once()
expected_result = [expected_key]
self.assertEqual(expected_result, upload_keys)

@patch("subs.subs_data_extractor.SUBSDataExtractor.copy_data_to_subs_s3_bucket")
@patch("subs.subs_data_extractor.SUBSDataExtractor._execute_trino_raw_sql_query_with_description")
@patch("subs.subs_data_extractor.SUBSDataExtractor.determine_line_item_count")
@patch("subs.subs_data_extractor.SUBSDataExtractor.determine_where_clause_and_params")
def test_gather_and_upload_for_resource_batch_no_result(
self, mock_where_clause, mock_li_count, mock_trino, mock_copy
):
@patch("subs.subs_data_extractor.SUBSDataExtractor.determine_row_count")
def test_gather_and_upload_for_resource_batch_no_result(self, mock_row_count, mock_trino, mock_copy):
"""Test uploading does not attempt with empty values from trino query."""
self.dh.month_start(self.yesterday)
rid = "12345"
@@ -217,16 +196,17 @@ def test_gather_and_upload_for_resource_batch_no_result(
rid_2 = "23456"
start_time = datetime.datetime(2023, 4, 3, tzinfo=datetime.timezone.utc)
end_time = datetime.datetime(2023, 4, 5, tzinfo=datetime.timezone.utc)
batch = [(rid, start_time, end_time), (rid_2, start_time, end_time)]
mock_li_count.return_value = 10
batch = [
{"rid": rid, "start": start_time, "end": end_time},
{"rid": rid_2, "start": start_time, "end": end_time},
]
mock_row_count.return_value = 10
expected_key = "fake_key"
base_filename = "fake_filename"
mock_copy.return_value = expected_key
mock_trino.return_value = ([], [("fake_col1",), ("fake_col2",)])
mock_where_clause.return_value = (MagicMock(), MagicMock())
upload_keys = self.extractor.gather_and_upload_for_resource_batch(year, month, batch, base_filename)
mock_where_clause.assert_called_once()
mock_li_count.assert_called_once()
mock_row_count.assert_called_once()
mock_trino.assert_called_once()
mock_copy.assert_not_called()
self.assertEqual(upload_keys, [])
8 changes: 8 additions & 0 deletions koku/subs/trino_sql/aws/determine_ids_for_provider.sql
@@ -0,0 +1,8 @@
SELECT DISTINCT lineitem_usageaccountid
FROM hive.{{schema | sqlsafe}}.aws_line_items
WHERE source={{source_uuid}}
AND year={{year}}
AND month={{month}}
{% if excluded_ids %}
AND lineitem_usageaccountid NOT IN {{excluded_ids | inclause}}
{% endif %}
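The templates rely on jinjasql-style filters: `sqlsafe` interpolates the schema name literally, `inclause` expands a list into a parenthesized IN-list of bound parameters, and the `{% if excluded_ids %}` guard drops the exclusion clause entirely when the list is empty. As an illustration only (the real rendering happens inside `_execute_trino_raw_sql_query`, and the exact rendering layer used by koku is an assumption here), this is roughly how such a template becomes a parameterized query with the `jinjasql` package:

```python
from jinjasql import JinjaSql  # assumed rendering library; koku wraps this internally

template = """
SELECT DISTINCT lineitem_usageaccountid
FROM hive.{{schema | sqlsafe}}.aws_line_items
WHERE source={{source_uuid}}
    AND year={{year}} AND month={{month}}
{% if excluded_ids %}
    AND lineitem_usageaccountid NOT IN {{excluded_ids | inclause}}
{% endif %}
"""

params = {
    "schema": "acct10001",
    "source_uuid": "uuid-here",
    "year": "2024",
    "month": "02",
    "excluded_ids": ["111111111111", "222222222222"],
}

j = JinjaSql(param_style="format")
query, bind_params = j.prepare_query(template, params)
# "schema" is inlined as text (sqlsafe); every other value becomes a %s
# placeholder, and the IN clause expands to one placeholder per excluded id.
print(query)
print(list(bind_params))
```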
12 changes: 12 additions & 0 deletions koku/subs/trino_sql/aws/determine_resource_ids_for_usage_account.sql
@@ -0,0 +1,12 @@
SELECT lineitem_resourceid, max(lineitem_usagestartdate)
FROM hive.{{schema | sqlsafe}}.aws_line_items
WHERE source={{source_uuid}}
AND year={{year}}
AND month={{month}}
AND lineitem_productcode = 'AmazonEC2'
AND strpos(lower(resourcetags), 'com_redhat_rhel') > 0
AND lineitem_usageaccountid = {{usage_account}}
{% if excluded_ids %}
AND lineitem_resourceid NOT IN {{excluded_ids | inclause}}
{% endif %}
GROUP BY lineitem_resourceid
23 changes: 23 additions & 0 deletions koku/subs/trino_sql/aws/subs_row_count.sql
@@ -0,0 +1,23 @@
SELECT count(*)
FROM
hive.{{schema | sqlsafe}}.aws_line_items
WHERE
source = {{ source_uuid }}
AND year = {{ year }}
AND month = {{ month }}
AND lineitem_productcode = 'AmazonEC2'
AND lineitem_lineitemtype IN ('Usage', 'SavingsPlanCoveredUsage')
AND product_vcpu != ''
AND strpos(lower(resourcetags), 'com_redhat_rhel') > 0
AND (
{% for item in resources %}
(
lineitem_resourceid = {{item.rid}} AND
lineitem_usagestartdate >= {{item.start}} AND
lineitem_usagestartdate <= {{item.end}}
)
{% if not loop.last %}
OR
{% endif %}
{% endfor %}
)
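The Azure counterparts of these files are not shown in this view. Based only on the commit message ("case statement for Azure to handle EoM scenario") and the old Azure clauses removed above, which always stepped back one day via `date_add('day', -1, ...)` and so could drop the final day of the month, a hypothetical sketch of the kind of CASE expression the fix describes might look like the following. This is an illustration, not the committed Azure file; the table name and the use of Trino's `last_day_of_month` are assumptions:

```sql
-- Hypothetical illustration only: not the committed Azure SQL.
SELECT
    coalesce(NULLIF(resourceid, ''), instanceid),
    CASE
        -- keep the latest usage date when it already falls on the last day of the month
        WHEN date(max(coalesce(date, usagedatetime))) = last_day_of_month(max(coalesce(date, usagedatetime)))
            THEN max(coalesce(date, usagedatetime))
        -- otherwise keep the old behavior of stepping back one day
        ELSE date_add('day', -1, max(coalesce(date, usagedatetime)))
    END
FROM hive.{{schema | sqlsafe}}.azure_line_items  -- table name assumed, mirroring aws_line_items
WHERE source={{source_uuid}}
    AND year={{year}}
    AND month={{month}}
    AND metercategory = 'Virtual Machines'
    AND json_extract_scalar(lower(additionalinfo), '$.vcpus') IS NOT NULL
    AND json_extract_scalar(lower(tags), '$.com_redhat_rhel') IS NOT NULL
GROUP BY resourceid, instanceid
```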