diff --git a/altimeter/core/artifact_io/writer.py b/altimeter/core/artifact_io/writer.py
index 8604d6c6..2292bfb4 100644
--- a/altimeter/core/artifact_io/writer.py
+++ b/altimeter/core/artifact_io/writer.py
@@ -5,6 +5,7 @@
 import gzip
 import os
 from pathlib import Path
+import tempfile
 from typing import Optional, Type
 
 import boto3
@@ -35,7 +36,11 @@ def write_json(self, name: str, data: BaseModel) -> str:
 
     @abc.abstractmethod
     def write_graph_set(
-        self, name: str, graph_set: ValidatedGraphSet, compression: Optional[str] = None
+        self,
+        name: str,
+        graph_set: ValidatedGraphSet,
+        compression: Optional[str] = None,
+        high_mem: bool = True,
     ) -> str:
         """Write a graph artifact
 
@@ -94,7 +99,11 @@ def write_json(self, name: str, data: BaseModel) -> str:
         return artifact_path
 
     def write_graph_set(
-        self, name: str, graph_set: ValidatedGraphSet, compression: Optional[str] = None
+        self,
+        name: str,
+        graph_set: ValidatedGraphSet,
+        compression: Optional[str] = None,
+        high_mem: bool = True,
    ) -> str:
         """Write a graph artifact
 
@@ -165,7 +174,11 @@ def write_json(self, name: str, data: BaseModel) -> str:
         return f"s3://{self.bucket}/{output_key}"
 
     def write_graph_set(
-        self, name: str, graph_set: ValidatedGraphSet, compression: Optional[str] = None
+        self,
+        name: str,
+        graph_set: ValidatedGraphSet,
+        compression: Optional[str] = None,
+        high_mem: bool = True,
     ) -> str:
         """Write a graph artifact
 
@@ -187,19 +200,34 @@ def write_graph_set(
         graph = graph_set.to_rdf()
         with logger.bind(bucket=self.bucket, key_prefix=self.key_prefix, key=key):
             logger.info(event=LogEvent.WriteToS3Start)
-            with io.BytesIO() as rdf_bytes_buf:
-                if compression is None:
-                    graph.serialize(rdf_bytes_buf, format="xml")
-                elif compression == GZIP:
-                    with gzip.GzipFile(fileobj=rdf_bytes_buf, mode="wb") as gz:
-                        graph.serialize(gz, format="xml")
-                else:
-                    raise ValueError(f"Unknown compression arg {compression}")
-                rdf_bytes_buf.flush()
-                rdf_bytes_buf.seek(0)
-                session = boto3.Session()
-                s3_client = session.client("s3")
-                s3_client.upload_fileobj(rdf_bytes_buf, self.bucket, output_key)
+            if high_mem:
+                with io.BytesIO() as rdf_bytes_buf:
+                    if compression is None:
+                        graph.serialize(rdf_bytes_buf, format="xml")
+                    elif compression == GZIP:
+                        with gzip.GzipFile(fileobj=rdf_bytes_buf, mode="wb") as gz:
+                            graph.serialize(gz, format="xml")
+                    else:
+                        raise ValueError(f"Unknown compression arg {compression}")
+                    rdf_bytes_buf.flush()
+                    rdf_bytes_buf.seek(0)
+                    session = boto3.Session()
+                    s3_client = session.client("s3")
+                    s3_client.upload_fileobj(rdf_bytes_buf, self.bucket, output_key)
+            else:
+                with tempfile.TemporaryDirectory() as graph_dir:
+                    graph_path = Path(graph_dir, "graph.rdf")
+                    with graph_path.open("wb") as graph_fp:
+                        if compression is None:
+                            graph.serialize(graph_fp, format="xml")
+                        elif compression == GZIP:
+                            with gzip.GzipFile(fileobj=graph_fp, mode="wb") as gz:
+                                graph.serialize(gz, format="xml")
+                        else:
+                            raise ValueError(f"Unknown compression arg {compression}")
+                    session = boto3.Session()
+                    s3_client = session.client("s3")
+                    s3_client.upload_file(str(graph_path), self.bucket, output_key)
             s3_client.put_object_tagging(
                 Bucket=self.bucket,
                 Key=output_key,
diff --git a/bin/sfn_compile_graphs.py b/bin/sfn_compile_graphs.py
index b12018fa..5fce55e1 100755
--- a/bin/sfn_compile_graphs.py
+++ b/bin/sfn_compile_graphs.py
@@ -18,6 +18,7 @@ class CompileGraphsInput(BaseImmutableModel):
     config: AWSConfig
     scan_id: str
     account_scan_manifests: Tuple[AccountScanManifest, ...]
+    high_mem: bool = True
 
 
 class CompileGraphsOutput(BaseImmutableModel):
@@ -60,21 +61,25 @@ def lambda_handler(event: Dict[str, Any], _: Any) -> Dict[str, Any]:
     if not graph_sets:
         raise Exception("BUG: No graph_sets generated.")
     validated_graph_set = ValidatedGraphSet.from_graph_set(GraphSet.from_graph_sets(graph_sets))
-    master_artifact_path = artifact_writer.write_json(name="master", data=validated_graph_set)
-    start_time = validated_graph_set.start_time
-    end_time = validated_graph_set.end_time
-    scan_manifest = ScanManifest(
-        scanned_accounts=scanned_accounts,
-        master_artifact=master_artifact_path,
-        artifacts=artifacts,
-        errors=errors,
-        unscanned_accounts=list(unscanned_accounts),
-        start_time=start_time,
-        end_time=end_time,
-    )
-    artifact_writer.write_json("manifest", data=scan_manifest)
+    if compile_graphs_input.high_mem:
+        master_artifact_path = artifact_writer.write_json(name="master", data=validated_graph_set)
+        start_time = validated_graph_set.start_time
+        end_time = validated_graph_set.end_time
+        scan_manifest = ScanManifest(
+            scanned_accounts=scanned_accounts,
+            master_artifact=master_artifact_path,
+            artifacts=artifacts,
+            errors=errors,
+            unscanned_accounts=list(unscanned_accounts),
+            start_time=start_time,
+            end_time=end_time,
+        )
+        artifact_writer.write_json("manifest", data=scan_manifest)
 
     rdf_path = artifact_writer.write_graph_set(
-        name="master", graph_set=validated_graph_set, compression=GZIP
+        name="master",
+        graph_set=validated_graph_set,
+        compression=GZIP,
+        high_mem=compile_graphs_input.high_mem,
     )
     return CompileGraphsOutput(