Skip to content

Commit 9553b72

Browse files
authored
[aws][feat]: Update collect method implementation (#2051)
1 parent 96e5a22 commit 9553b72

30 files changed

+375
-139
lines changed

plugins/aws/fix_plugin_aws/resource/apigateway.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,14 +513,16 @@ def called_mutator_apis(cls) -> List[AwsApiSpec]:
513513
]
514514

515515
@classmethod
516-
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> None:
516+
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> List[AwsResource]:
517+
instances: List[AwsResource] = []
517518
for js in json:
518519
if api_instance := cls.from_api(js, builder):
519520
api_instance.set_arn(
520521
builder=builder,
521522
account="",
522523
resource=f"/restapis/{api_instance.id}",
523524
)
525+
instances.append(api_instance)
524526
builder.add_node(api_instance, js)
525527
for deployment in builder.client.list(
526528
service_name, "get-deployments", "items", restApiId=api_instance.id
@@ -532,6 +534,7 @@ def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) ->
532534
resource=f"/restapis/{api_instance.id}/deployments/{deploy_instance.id}",
533535
)
534536
deploy_instance.api_link = api_instance.id
537+
instances.append(deploy_instance)
535538
builder.add_node(deploy_instance, deployment)
536539
builder.add_edge(api_instance, EdgeType.default, node=deploy_instance)
537540
for stage in builder.client.list(
@@ -544,6 +547,7 @@ def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) ->
544547
stage["syntheticId"] = f'{api_instance.id}_{stage["stageName"]}' # create unique id
545548
if stage_instance := AwsApiGatewayStage.from_api(stage, builder):
546549
stage_instance.api_link = api_instance.id
550+
instances.append(stage_instance)
547551
builder.add_node(stage_instance, stage)
548552
# reference kinds for this edge are maintained in AwsApiGatewayDeployment.reference_kinds # noqa: E501
549553
builder.add_edge(deploy_instance, EdgeType.default, node=stage_instance)
@@ -552,18 +556,21 @@ def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) ->
552556
):
553557
if auth_instance := AwsApiGatewayAuthorizer.from_api(authorizer, builder):
554558
auth_instance.api_link = api_instance.id
559+
instances.append(auth_instance)
555560
builder.add_node(auth_instance, authorizer)
556561
builder.add_edge(api_instance, EdgeType.default, node=auth_instance)
557562
for resource in builder.client.list(service_name, "get-resources", "items", restApiId=api_instance.id):
558563
if resource_instance := AwsApiGatewayResource.from_api(resource, builder):
559564
resource_instance.api_link = api_instance.id
565+
instances.append(resource_instance)
560566
if resource_instance.resource_methods:
561567
for method in resource_instance.resource_methods:
562568
mapped = bend(AwsApiGatewayMethod.mapping, resource["resourceMethods"][method])
563569
if gm := parse_json(mapped, AwsApiGatewayMethod, builder):
564570
resource_instance.resource_methods[method] = gm
565571
builder.add_node(resource_instance, resource)
566572
builder.add_edge(api_instance, EdgeType.default, node=resource_instance)
573+
return instances
567574

568575
def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
569576
if self.api_endpoint_configuration:

plugins/aws/fix_plugin_aws/resource/athena.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ def called_mutator_apis(cls) -> List[AwsApiSpec]:
133133
]
134134

135135
@classmethod
136-
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> None:
136+
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> List[AwsResource]:
137+
workgroups: List[AwsResource] = []
138+
137139
def fetch_workgroup(name: str) -> Optional[AwsAthenaWorkGroup]:
138140
result = builder.client.get(
139141
aws_service=service_name, action="get-work-group", result_name="WorkGroup", WorkGroup=name
@@ -146,6 +148,7 @@ def fetch_workgroup(name: str) -> Optional[AwsAthenaWorkGroup]:
146148
builder=builder,
147149
resource=f"workgroup/{workgroup.name}",
148150
)
151+
workgroups.append(workgroup)
149152
builder.add_node(workgroup, result)
150153
builder.submit_work(service_name, add_tags, workgroup)
151154
return workgroup
@@ -165,6 +168,7 @@ def add_tags(data_catalog: AwsAthenaWorkGroup) -> None:
165168
for js in json:
166169
if (name := js.get("Name")) is not None and isinstance(name, str):
167170
fetch_workgroup(name)
171+
return workgroups
168172

169173
# noinspection PyUnboundLocalVariable
170174
def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
@@ -248,7 +252,9 @@ def called_mutator_apis(cls) -> List[AwsApiSpec]:
248252
]
249253

250254
@classmethod
251-
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> None:
255+
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> List[AwsResource]:
256+
catalogs: List[AwsResource] = []
257+
252258
def fetch_data_catalog(data_catalog_name: str) -> Optional[AwsAthenaDataCatalog]:
253259
result = builder.client.get(
254260
aws_service=service_name,
@@ -260,6 +266,7 @@ def fetch_data_catalog(data_catalog_name: str) -> Optional[AwsAthenaDataCatalog]
260266
return None
261267
if catalog := AwsAthenaDataCatalog.from_api(result, builder):
262268
catalog.set_arn(builder=builder, resource=f"datacatalog/{catalog.name}")
269+
catalogs.append(catalog)
263270
builder.add_node(catalog, result)
264271
builder.submit_work(service_name, add_tags, catalog)
265272
return catalog
@@ -279,6 +286,7 @@ def add_tags(data_catalog: AwsAthenaDataCatalog) -> None:
279286
# we filter out the default data catalog as it is not possible to do much with it
280287
if (name := js.get("CatalogName")) is not None and isinstance(name, str) and name != "AwsDataCatalog":
281288
fetch_data_catalog(name)
289+
return catalogs
282290

283291
def update_resource_tag(self, client: AwsClient, key: str, value: str) -> bool:
284292
client.call(

plugins/aws/fix_plugin_aws/resource/base.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,10 @@ def collect_resources(cls: Type[AwsResource], builder: GraphBuilder) -> None:
202202
expected_errors=spec.expected_errors,
203203
**kwargs,
204204
)
205-
cls.collect(items, builder)
205+
collected = cls.collect(items, builder)
206206
if builder.config.collect_usage_metrics:
207207
try:
208-
cls.collect_usage_metrics(builder)
208+
cls.collect_usage_metrics(builder, collected)
209209
except Exception as e:
210210
log.warning(
211211
f"Failed to collect usage metrics for {cls.__name__} in region {builder.region.id}: {e}"
@@ -220,20 +220,26 @@ def collect_resources(cls: Type[AwsResource], builder: GraphBuilder) -> None:
220220
raise
221221

222222
@classmethod
223-
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> None:
223+
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> List[AwsResource]:
224224
# Default behavior: iterate over json snippets and for each:
225225
# - bend the json
226226
# - transform the result into a resource
227227
# - add the resource to the graph
228+
# - return a list of resources
228229
# In case additional work needs to be done, override this method.
230+
instances = []
229231
for js in json:
230232
if instance := cls.from_api(js, builder):
231233
# post process
232234
instance.post_process(builder, js)
233235
builder.add_node(instance, js)
236+
instances.append(instance)
237+
return instances
234238

235239
@classmethod
236-
def collect_usage_metrics(cls: Type[AwsResource], builder: GraphBuilder) -> None:
240+
def collect_usage_metrics(
241+
cls: Type[AwsResource], builder: GraphBuilder, collected_resources: List[AwsResource]
242+
) -> None:
237243
# Default behavior: do nothing
238244
pass
239245

plugins/aws/fix_plugin_aws/resource/cloudformation.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import time
22
from datetime import datetime
33
from typing import Any, ClassVar, Dict, Literal, Optional, List, Type, cast
4+
from concurrent.futures import Future, wait as futures_wait
45

56
from attrs import define, field
67

@@ -282,10 +283,14 @@ class AwsCloudFormationStackSet(AwsResource):
282283
stack_set_parameters: Optional[Dict[str, Any]] = None
283284

284285
@classmethod
285-
def collect(cls, json: List[Json], builder: GraphBuilder) -> None:
286-
def stack_set_instances(ss: AwsCloudFormationStackSet) -> None:
286+
def collect(cls, json: List[Json], builder: GraphBuilder) -> List[AwsResource]:
287+
stack_items: List[AwsResource] = []
288+
289+
def stack_set_instances(ss: AwsCloudFormationStackSet) -> List[AwsCloudFormationStackInstanceSummary]:
290+
instances = []
287291
for sij in builder.client.list(service_name, "list-stack-instances", "Summaries", StackSetName=ss.name):
288292
if sii := AwsCloudFormationStackInstanceSummary.from_api(sij, builder):
293+
instances.append(sii)
289294
builder.add_node(sii, sij)
290295
builder.add_edge(ss, node=sii)
291296
builder.graph.add_deferred_edge(
@@ -294,11 +299,18 @@ def stack_set_instances(ss: AwsCloudFormationStackSet) -> None:
294299
f'is(aws_cloudformation_stack) and reported.id="{sii.stack_instance_stack_id}"'
295300
),
296301
)
302+
return instances
297303

304+
futures: List[Future[List[AwsCloudFormationStackInstanceSummary]]] = []
298305
for js in json:
299306
if stack_set := cls.from_api(js, builder):
307+
stack_items.append(stack_set)
300308
builder.add_node(stack_set, js)
301-
builder.submit_work(service_name, stack_set_instances, stack_set)
309+
future = builder.submit_work(service_name, stack_set_instances, stack_set)
310+
futures.append(future)
311+
futures_wait(futures)
312+
stack_instances: List[AwsResource] = [result for future in futures for result in future.result()]
313+
return stack_items + stack_instances
302314

303315
def _modify_tag(self, client: AwsClient, key: str, value: Optional[str], mode: Literal["update", "delete"]) -> bool:
304316
tags = dict(self.tags)

plugins/aws/fix_plugin_aws/resource/cloudfront.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323

2424
class CloudFrontResource:
2525
@classmethod
26-
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> None: # type: ignore
26+
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> List[AwsResource]: # type: ignore
27+
instances = []
28+
2729
def add_tags(res: AwsResource) -> None:
2830
tags = builder.client.get(
2931
service_name, "list-tags-for-resource", "Tags", Resource=res.arn, expected_errors=["InvalidArgument"]
@@ -33,9 +35,11 @@ def add_tags(res: AwsResource) -> None:
3335

3436
for js in json:
3537
if instance := cls.from_api(js, builder):
38+
instances.append(instance)
3639
if instance.arn:
3740
builder.submit_work(service_name, add_tags, instance)
3841
builder.add_node(instance, js)
42+
return instances
3943

4044
@staticmethod
4145
def delete_cloudfront_resource(client: AwsClient, resource: str, rid: str) -> bool:
@@ -642,11 +646,6 @@ def fetch_distribution(did: str) -> None:
642646
aws_service=service_name, action="list-distributions", result_name="DistributionList.Items"
643647
):
644648
builder.submit_work(service_name, fetch_distribution, item["Id"])
645-
if builder.config.collect_usage_metrics:
646-
try:
647-
cls.collect_usage_metrics(builder)
648-
except Exception as e:
649-
log.warning(f"Failed to collect usage metrics for {cls.__name__}: {e}")
650649
except Boto3Error as e:
651650
msg = f"Error while collecting {cls.__name__} in region {builder.region.name}: {e}"
652651
builder.core_feedback.error(msg, log)

plugins/aws/fix_plugin_aws/resource/cloudtrail.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from datetime import datetime
22
from typing import ClassVar, Dict, Optional, Type, List, Any
3+
from concurrent.futures import wait as futures_wait
34

45
from attr import define, field as attrs_field, field
56

@@ -195,8 +196,8 @@ def called_collect_apis(cls) -> List[AwsApiSpec]:
195196
]
196197

197198
@classmethod
198-
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> None:
199-
def collect_trail(trail_arn: str) -> None:
199+
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> List[AwsResource]:
200+
def collect_trail(trail_arn: str) -> Optional[AwsCloudTrail]:
200201
if trail_raw := builder.client.get(service_name, "get-trail", "Trail", Name=trail_arn):
201202
if instance := AwsCloudTrail.from_api(trail_raw, builder):
202203
builder.add_node(instance, js)
@@ -206,6 +207,8 @@ def collect_trail(trail_arn: str) -> None:
206207
collect_event_selectors(instance)
207208
if instance.trail_has_insight_selectors:
208209
collect_insight_selectors(instance)
210+
return instance
211+
return None
209212

210213
def collect_event_selectors(trail: AwsCloudTrail) -> None:
211214
if esj := builder.client.get(service_name, "get-event-selectors", TrailName=trail.arn):
@@ -241,17 +244,22 @@ def collect_tags(trail: AwsCloudTrail) -> None:
241244
):
242245
trail.tags = bend(S("TagsList", default=[]) >> ToDict(), tr)
243246

247+
futures = []
244248
for js in json:
245249
arn = js["TrailARN"]
246250
# list trails will return multi account trails in all regions
247251
if js["HomeRegion"] == builder.region.name and builder.account.id in arn:
248252
# only collect trails in the current account and current region
249-
builder.submit_work(service_name, collect_trail, arn)
253+
future = builder.submit_work(service_name, collect_trail, arn)
254+
futures.append(future)
250255
else:
251256
# add a deferred edge to the trails in another account or region
252257
builder.add_deferred_edge(
253258
builder.region, EdgeType.default, f'is(aws_cloud_trail) and reported.arn=="{arn}"'
254259
)
260+
futures_wait(futures) # only continue, when all task definitions are collected
261+
instances: List[AwsResource] = [result for future in futures if (result := future.result())]
262+
return instances
255263

256264
def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
257265
if s3 := self.trail_s3_bucket_name:

plugins/aws/fix_plugin_aws/resource/cloudwatch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,16 +255,19 @@ class AwsCloudwatchAlarm(CloudwatchTaggable, AwsResource):
255255
cloudwatch_threshold_metric_id: Optional[str] = field(default=None)
256256

257257
@classmethod
258-
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> None:
258+
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> List[AwsResource]:
259259
def add_tags(alarm: AwsCloudwatchAlarm) -> None:
260260
tags = builder.client.list(service_name, "list-tags-for-resource", "Tags", ResourceARN=alarm.arn)
261261
if tags:
262262
alarm.tags = bend(ToDict(), tags)
263263

264+
instances = []
264265
for js in json:
265266
if instance := cls.from_api(js, builder):
267+
instances.append(instance)
266268
builder.add_node(instance, js)
267269
builder.submit_work(service_name, add_tags, instance)
270+
return instances
268271

269272
def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
270273
super().connect_in_graph(builder, source)

plugins/aws/fix_plugin_aws/resource/cognito.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,9 @@ def called_mutator_apis(cls) -> List[AwsApiSpec]:
242242
]
243243

244244
@classmethod
245-
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> None:
245+
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> List[AwsResource]:
246+
instances: List[AwsResource] = []
247+
246248
def add_tags(pool: AwsCognitoUserPool) -> None:
247249
tags = builder.client.get(service_name, "list-tags-for-resource", "Tags", ResourceArn=pool.arn)
248250
if tags:
@@ -251,18 +253,22 @@ def add_tags(pool: AwsCognitoUserPool) -> None:
251253
for pool in json:
252254
if pool_instance := cls.from_api(pool, builder):
253255
pool_instance.set_arn(builder=builder, resource=f"userpool/{pool_instance.id}")
256+
instances.append(pool_instance)
254257
builder.add_node(pool_instance, pool)
255258
builder.submit_work(service_name, add_tags, pool_instance)
256259
for user in builder.client.list(service_name, "list-users", "Users", UserPoolId=pool_instance.id):
257260
if user_instance := AwsCognitoUser.from_api(user, builder):
258261
user_instance.pool_name = pool_instance.name
259262
user_instance._pool_id = pool_instance.id
263+
instances.append(user_instance)
260264
builder.add_node(user_instance, user)
261265
builder.add_edge(from_node=pool_instance, edge_type=EdgeType.default, node=user_instance)
262266
for group in builder.client.list(service_name, "list-groups", "Groups", UserPoolId=pool_instance.id):
263267
if group_instance := AwsCognitoGroup.from_api(group, builder):
268+
instances.append(group_instance)
264269
builder.add_node(group_instance, group)
265270
builder.add_edge(from_node=pool_instance, edge_type=EdgeType.default, node=group_instance)
271+
return instances
266272

267273
def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
268274
if self.lambda_config:

plugins/aws/fix_plugin_aws/resource/config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,11 @@ class AwsConfigRecorder(AwsResource):
8484
recorder_status: Optional[AwsConfigRecorderStatus] = field(default=None, metadata=dict(ignore_history=True))
8585

8686
@classmethod
87-
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> None:
87+
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> List[AwsResource]:
8888
# get all statuses
8989
statuses: Dict[str, AwsConfigRecorderStatus] = {}
90+
91+
instances: List[AwsResource] = []
9092
for r in builder.client.list(
9193
service_name, "describe-configuration-recorder-status", "ConfigurationRecordersStatus"
9294
):
@@ -95,10 +97,12 @@ def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) ->
9597

9698
for js in json:
9799
if instance := AwsConfigRecorder.from_api(js, builder):
100+
instances.append(instance)
98101
if status := statuses.get(instance.id):
99102
instance.recorder_status = status
100103
instance.mtime = status.last_status_change_time
101104
builder.add_node(instance, js)
105+
return instances
102106

103107
def delete_resource(self, client: AwsClient, graph: Graph) -> bool:
104108
client.call(service_name, "delete-configuration-recorder", self.name)

0 commit comments

Comments
 (0)