Skip to content

Commit

Permalink
Replace file objects used as function arguments
Browse files Browse the repository at this point in the history
By replacing file objects passed as function arguments with
the file content read from them, we simplify the life-cycle
management of temporary file objects: temporary files are now
handled in a single function. This is done for metadata files,
which are fully read into memory right after download anyway.

The same is not true for target files, which should preferably be
processed in chunks, so target download and verification still
deal with file objects (though not in a stream-like manner).

Signed-off-by: Teodora Sechkova <tsechkova@vmware.com>
  • Loading branch information
sechkova committed Apr 20, 2021
1 parent ad760fd commit 6561f7a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 41 deletions.
3 changes: 1 addition & 2 deletions tuf/client_rework/metadata_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ def __init__(self, meta):
self._meta = meta

@classmethod
def from_json_object(cls, tmp_file):
def from_json_object(cls, raw_data):
"""Loads JSON-formatted TUF metadata from a file object."""
raw_data = tmp_file.read()
# Use local scope import to avoid circular import errors
# pylint: disable=import-outside-toplevel
from tuf.api.serialization.json import JSONDeserializer
Expand Down
67 changes: 28 additions & 39 deletions tuf/client_rework/updater_rework.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import fnmatch
import logging
import os
from typing import BinaryIO, Dict, Optional, TextIO
from typing import Dict, Optional

from securesystemslib import exceptions as sslib_exceptions
from securesystemslib import hash as sslib_hash
Expand Down Expand Up @@ -158,9 +158,8 @@ def download_target(self, target: Dict, destination_directory: str):
temp_obj = download.download_file(
file_mirror, target["fileinfo"]["length"], self._fetcher
)

temp_obj.seek(0)
self._verify_target_file(temp_obj, target)
_check_file_length(temp_obj, target["fileinfo"]["length"])
_check_hashes_obj(temp_obj, target["fileinfo"]["hashes"])
break

except Exception as exception:
Expand Down Expand Up @@ -297,7 +296,7 @@ def _root_mirrors_download(self, root_mirrors: Dict) -> "RootWrapper":
)

temp_obj.seek(0)
intermediate_root = self._verify_root(temp_obj)
intermediate_root = self._verify_root(temp_obj.read())
# When we reach this point, a root file has been successfully
# downloaded and verified so we can exit the loop.
break
Expand Down Expand Up @@ -344,7 +343,7 @@ def _load_timestamp(self) -> None:
)

temp_obj.seek(0)
verified_timestamp = self._verify_timestamp(temp_obj)
verified_timestamp = self._verify_timestamp(temp_obj.read())
break

except Exception as exception: # pylint: disable=broad-except
Expand Down Expand Up @@ -397,7 +396,7 @@ def _load_snapshot(self) -> None:
)

temp_obj.seek(0)
verified_snapshot = self._verify_snapshot(temp_obj)
verified_snapshot = self._verify_snapshot(temp_obj.read())
break

except Exception as exception: # pylint: disable=broad-except
Expand Down Expand Up @@ -451,7 +450,7 @@ def _load_targets(self, targets_role: str, parent_role: str) -> None:

temp_obj.seek(0)
verified_targets = self._verify_targets(
temp_obj, targets_role, parent_role
temp_obj.read(), targets_role, parent_role
)
break

Expand All @@ -472,12 +471,12 @@ def _load_targets(self, targets_role: str, parent_role: str) -> None:
self._get_full_meta_name(targets_role, extension=".json")
)

def _verify_root(self, temp_obj: TextIO) -> RootWrapper:
def _verify_root(self, file_content: bytes) -> RootWrapper:
"""
TODO
"""

intermediate_root = RootWrapper.from_json_object(temp_obj)
intermediate_root = RootWrapper.from_json_object(file_content)

# Check for an arbitrary software attack
trusted_root = self._metadata["root"]
Expand All @@ -490,7 +489,6 @@ def _verify_root(self, temp_obj: TextIO) -> RootWrapper:

# Check for a rollback attack.
if intermediate_root.version < trusted_root.version:
temp_obj.close()
raise exceptions.ReplayedMetadataError(
"root", intermediate_root.version(), trusted_root.version()
)
Expand All @@ -499,11 +497,11 @@ def _verify_root(self, temp_obj: TextIO) -> RootWrapper:

return intermediate_root

def _verify_timestamp(self, temp_obj: TextIO) -> TimestampWrapper:
def _verify_timestamp(self, file_content: bytes) -> TimestampWrapper:
"""
TODO
"""
intermediate_timestamp = TimestampWrapper.from_json_object(temp_obj)
intermediate_timestamp = TimestampWrapper.from_json_object(file_content)

# Check for an arbitrary software attack
trusted_root = self._metadata["root"]
Expand All @@ -517,7 +515,6 @@ def _verify_timestamp(self, temp_obj: TextIO) -> TimestampWrapper:
intermediate_timestamp.signed.version
<= self._metadata["timestamp"].version
):
temp_obj.close()
raise exceptions.ReplayedMetadataError(
"root",
intermediate_timestamp.version(),
Expand All @@ -529,7 +526,6 @@ def _verify_timestamp(self, temp_obj: TextIO) -> TimestampWrapper:
intermediate_timestamp.snapshot.version
<= self._metadata["timestamp"].snapshot["version"]
):
temp_obj.close()
raise exceptions.ReplayedMetadataError(
"root",
intermediate_timestamp.snapshot.version(),
Expand All @@ -540,24 +536,23 @@ def _verify_timestamp(self, temp_obj: TextIO) -> TimestampWrapper:

return intermediate_timestamp

def _verify_snapshot(self, temp_obj: TextIO) -> SnapshotWrapper:
def _verify_snapshot(self, file_content: bytes) -> SnapshotWrapper:
"""
TODO
"""

# Check against timestamp metadata
if self._metadata["timestamp"].snapshot.get("hash"):
_check_hashes(
temp_obj, self._metadata["timestamp"].snapshot.get("hash")
file_content, self._metadata["timestamp"].snapshot.get("hash")
)

intermediate_snapshot = SnapshotWrapper.from_json_object(temp_obj)
intermediate_snapshot = SnapshotWrapper.from_json_object(file_content)

if (
intermediate_snapshot.version
!= self._metadata["timestamp"].snapshot["version"]
):
temp_obj.close()
raise exceptions.BadVersionNumberError

# Check for an arbitrary software attack
Expand All @@ -573,15 +568,14 @@ def _verify_snapshot(self, temp_obj: TextIO) -> SnapshotWrapper:
target_role["version"]
!= self._metadata["snapshot"].meta[target_role]["version"]
):
temp_obj.close()
raise exceptions.BadVersionNumberError

intermediate_snapshot.expires()

return intermediate_snapshot

def _verify_targets(
self, temp_obj: TextIO, filename: str, parent_role: str
self, file_content: bytes, filename: str, parent_role: str
) -> TargetsWrapper:
"""
TODO
Expand All @@ -590,15 +584,14 @@ def _verify_targets(
# Check against timestamp metadata
if self._metadata["snapshot"].role(filename).get("hash"):
_check_hashes(
temp_obj, self._metadata["snapshot"].targets.get("hash")
file_content, self._metadata["snapshot"].targets.get("hash")
)

intermediate_targets = TargetsWrapper.from_json_object(temp_obj)
intermediate_targets = TargetsWrapper.from_json_object(file_content)
if (
intermediate_targets.version
!= self._metadata["snapshot"].role(filename)["version"]
):
temp_obj.close()
raise exceptions.BadVersionNumberError

# Check for an arbitrary software attack
Expand All @@ -612,15 +605,6 @@ def _verify_targets(

return intermediate_targets

@staticmethod
def _verify_target_file(temp_obj: BinaryIO, targetinfo: Dict) -> None:
"""
TODO
"""

_check_file_length(temp_obj, targetinfo["fileinfo"]["length"])
_check_hashes(temp_obj, targetinfo["fileinfo"]["hashes"])

def _preorder_depth_first_walk(self, target_filepath) -> Dict:
"""
TODO
Expand Down Expand Up @@ -849,19 +833,24 @@ def _check_file_length(file_object, trusted_file_length):
)


def _check_hashes(file_object, trusted_hashes):
def _check_hashes_obj(file_object, trusted_hashes):
    """Verify the hashes of a file object's full content.

    Rewinds *file_object* to the start, reads it entirely into memory,
    and delegates to ``_check_hashes`` to compare the content against
    every algorithm/digest pair in *trusted_hashes*.
    """
    # Rewind first: the caller may have already consumed the stream.
    file_object.seek(0)
    content = file_object.read()
    return _check_hashes(content, trusted_hashes)


def _check_hashes(file_content, trusted_hashes):
"""
TODO
"""
# Verify each trusted hash of 'trusted_hashes'. If all are valid, simply
# return.
for algorithm, trusted_hash in trusted_hashes.items():
digest_object = sslib_hash.digest(algorithm)
# Ensure we read from the beginning of the file object
# TODO: should we store file position (before the loop) and reset
# after we seek about?
file_object.seek(0)
digest_object.update(file_object.read())

digest_object.update(file_content)
computed_hash = digest_object.hexdigest()

# Raise an exception if any of the hashes are incorrect.
Expand Down

0 comments on commit 6561f7a

Please sign in to comment.