Skip to content

Commit

Permalink
feat: migrate to struct module
Browse files Browse the repository at this point in the history
Removes the Entry class in favor of the builtin struct module. This
will simplify marshalling the disk image byte data.
  • Loading branch information
swysocki committed Jul 14, 2022
1 parent 6430b06 commit aaa777f
Show file tree
Hide file tree
Showing 9 changed files with 378 additions and 426 deletions.
30 changes: 14 additions & 16 deletions gpt_image/disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ class Disk:
image_path: file image path (absolute or relative)
"""

def __init__(self, image_path: str) -> None:
def __init__(self, image_path: str, sector_size: int = 512) -> None:
"""Init Disk with a file path"""
self.image_path = pathlib.Path(image_path)
self.name = self.image_path.name
self.sector_size = 512 # 512 bytes
self.sector_size = sector_size

def open(self):
"""Read existing GPT disk Table"""
Expand All @@ -41,9 +41,9 @@ def open(self):
+ self.geometry.header_length
]
self.table.primary_header = Header(self.geometry)
self.table.primary_header.read(primary_header_b)
self.table.primary_header.read(primary_header_b, self.geometry)
self.table.secondary_header = Header(self.geometry, is_backup=True)
self.table.secondary_header.read(backup_header_b)
self.table.secondary_header.read(backup_header_b, self.geometry)
# read the partition tables
primary_part_table_b = disk_bytes[
self.geometry.primary_array_byte : self.geometry.primary_array_byte
Expand All @@ -56,16 +56,15 @@ def open(self):
if primary_part_table_b != backup_part_table_b:
raise TableReadError("primary and backup table do not match")
# unmarshal the partition bytes to objects
# loop through the entire array and unmarshall the bytes if partition
# data is found
# add the partition to the entry list if the type_guid is valid
for i in range(PartitionEntryArray.EntryCount):
offset = i * PartitionEntryArray.EntryLength
partition_bytes = primary_part_table_b[
offset : offset + PartitionEntryArray.EntryLength
]
new_part = Partition()
new_part.read(partition_bytes)
if new_part.type_guid.data != b"\x00" * 16:
new_part = Partition.read(partition_bytes, self.geometry.sector_size)
print(new_part.type_guid)
if new_part.type_guid != Partition._EMPTY_GUID:
self.table.partitions.entries.append(new_part)

def create(self, size: int):
Expand All @@ -84,7 +83,7 @@ def create(self, size: int):
# zero entire disk
f.write(b"\x00" * self.size)
f.seek(0)
f.write(self.table.protective_mbr.byte_structure)
f.write(self.table.protective_mbr.marshal())
self.write()

def write(self):
Expand All @@ -98,19 +97,19 @@ def write(self):
with open(self.image_path, "r+b") as f:
# write primary header
f.seek(self.geometry.primary_header_byte)
f.write(self.table.primary_header.byte_structure)
f.write(self.table.primary_header.marshal())

# write primary partition table
f.seek(self.geometry.primary_array_byte)
f.write(self.table.partitions.byte_structure)
f.write(self.table.partitions.marshal())

# move to secondary header location and write
f.seek(self.geometry.alternate_header_byte)
f.write(self.table.secondary_header.byte_structure)
f.write(self.table.secondary_header.marshal())

# write secondary partition table
f.seek(self.geometry.alternate_array_byte)
f.write(self.table.partitions.byte_structure)
f.write(self.table.partitions.marshal())

def write_data(self, data: bytes, partition: Partition, offset: int = 0) -> None:
# @NOTE: this isn't a GPT function. Writing data should be outside the
Expand All @@ -127,8 +126,7 @@ def write_data(self, data: bytes, partition: Partition, offset: int = 0) -> None
raise ValueError(f"data must be of type bytes. found type: {type(data)}")

with open(self.image_path, "r+b") as f:
start_lba = int.from_bytes(partition.first_lba.data_bytes, "little")
start_byte = int(start_lba * self.sector_size)
start_byte = int(partition.first_lba * self.sector_size)
with open(self.image_path, "r+b") as f:
f.seek(start_byte + offset)
f.write(data)
46 changes: 0 additions & 46 deletions gpt_image/entry.py

This file was deleted.

150 changes: 62 additions & 88 deletions gpt_image/partition.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import struct
import uuid
from enum import Enum
from math import ceil
from sys import byteorder
from typing import List, Union
from typing import List

from gpt_image.entry import Entry
from gpt_image.geometry import Geometry


class PartitionEntryError(Exception):
pass
"""Exception class for erros in partition entries"""


class PartitionAttribute(Enum):
Expand All @@ -36,48 +35,32 @@ class Partition:
from a table's partition list.
"""

_PARTITION_FORMAT = struct.Struct("<16s16sQQQ72s")
_EMPTY_GUID = "00000000-0000-0000-0000-000000000000"
# https://en.wikipedia.org/wiki/GUID_Partition_Table#Partition_entries
LINUX_FILE_SYSTEM = "0FC63DAF-8483-4772-8E79-3D69D8477DE4"
EFI_SYSTEM_PARTITION = "C12A7328-F81F-11D2-BA4B-00A0C93EC93B"

def __init__(
self,
name: str = "",
size: int = 0,
partition_guid: Union[None, uuid.UUID] = None,
attribute_flag: PartitionAttribute = PartitionAttribute.NONE,
name: str,
size: int,
type_guid: str,
partition_guid: str = "",
alignment: int = 8,
partition_attributes: int = 0,
):
"""Initialize Partition Object
All parameters have a default value to allow Partition() to create
an empty partition object. If "name" is set, we assume this is not
an empty object and set the other values.
Empty partition objects are used as placeholders in the partition
entry array.
Attributes:
size: partition size in Bytes
"""
# create an empty partition object
self.type_guid = Entry(0, 16, 0)
self.partition_guid = Entry(16, 16, 0)
self.first_lba = Entry(32, 8, 0)
self.last_lba = Entry(40, 8, 0)
self.partition_name = Entry(56, 72, 0)
self._attribute_flags = Entry(48, 8, attribute_flag.value)
# if name is set, this isn't an empty partition. Set relevant fields
# @TODO: don't base an empty partition off of the name attribute
if name:
self.type_guid.data = uuid.UUID(Partition.LINUX_FILE_SYSTEM).bytes_le
if not partition_guid:
self.partition_guid.data = uuid.uuid4().bytes_le
else:
self.partition_guid.data = partition_guid.bytes_le
# the partition name is stored as utf_16_le
self.partition_name.data = bytes(name, encoding="utf_16_le")

"""Initialize Partition Object"""
self.type_guid = type_guid
self.partition_name = name
self._attribute_flags = 0
self.partition_guid = partition_guid
# if the partition GUID is empty, generate one
if not partition_guid:
self.partition_guid = str(uuid.uuid4())
self.first_lba = 0
self.last_lba = 0
self._attribute_flags = partition_attributes
self.alignment = alignment
self.size = size

Expand All @@ -94,52 +77,44 @@ def attribute_flags(self, flag: PartitionAttribute):
"""
if flag.value == 0:
self._attribute_flags.data = 0
self._attribute_flags = 0
else:
attr_int = int.from_bytes(
self._attribute_flags.data_bytes, byteorder="little"
)
self._attribute_flags = Entry(48, 8, attr_int | (1 << flag.value))

@property
def byte_structure(self) -> bytes:
part_fields = [
self.type_guid,
self.partition_guid,
# bit indices are zero-based so we subtract 1 from our flag
self._attribute_flags = self._attribute_flags | (1 << flag.value)

def marshal(self) -> bytes:
"""Marshal to byte structure"""
partition_bytes = self._PARTITION_FORMAT.pack(
uuid.UUID(self.type_guid).bytes_le,
uuid.UUID(self.partition_guid).bytes_le,
self.first_lba,
self.last_lba,
self.attribute_flags,
self.partition_name,
]
byte_list = [x.data_bytes for x in part_fields]
return b"".join(byte_list)

def read(self, partition_bytes: bytes):
"""Unmarshal bytes to Partition Object"""
self.type_guid.data = partition_bytes[
self.type_guid.offset : self.type_guid.offset + self.type_guid.length
]
self.partition_guid.data = partition_bytes[
self.partition_guid.offset : self.partition_guid.offset
+ self.partition_guid.length
]
self.first_lba.data = partition_bytes[
self.first_lba.offset : self.first_lba.offset + self.first_lba.length
]
self.last_lba.data = partition_bytes[
self.last_lba.offset : self.last_lba.offset + self.last_lba.length
]
self.attribute_flags.data = partition_bytes[
self.attribute_flags.offset : self.attribute_flags.offset
+ self.attribute_flags.length
]
part_name_b = partition_bytes[
self.partition_name.offset : self.partition_name.offset
+ self.partition_name.length
]
# partition name is stored as UTF-16-LE padded to 72 bytes
self.partition_name.data = part_name_b.decode(encoding="utf_16_le").rstrip(
"\x00"
bytes(self.partition_name, encoding="utf_16_le"),
)
return partition_bytes

@staticmethod
def read(partition_bytes: bytes, sector_size: int) -> "Partition":
"""Create a Partition object from existing bytes"""
if len(partition_bytes) != PartitionEntryArray.EntryLength:
raise ValueError(f"Invalid Partition Entry length: {len(partition_bytes)}")
(
type_guid,
partition_guid,
first_lba,
last_lba,
attribute_flags,
partition_name,
) = Partition._PARTITION_FORMAT.unpack(partition_bytes)
size = (last_lba - first_lba + 1) * sector_size

return Partition(
partition_name.decode("utf_16_le").rstrip("\x00"),
size,
str(uuid.UUID(bytes_le=type_guid)),
str(uuid.UUID(bytes_le=partition_guid)),
partition_attributes=attribute_flags,
)


Expand All @@ -159,8 +134,8 @@ def add(self, partition: Partition) -> None:
Appends the Partition to the next available entry. Calculates the
LBA's
"""
partition.first_lba.data = self._get_first_lba(partition)
partition.last_lba.data = self._get_last_lba(partition)
partition.first_lba = self._get_first_lba(partition)
partition.last_lba = self._get_last_lba(partition)
self.entries.append(partition)

def _get_first_lba(self, partition: Partition) -> int:
Expand All @@ -180,7 +155,7 @@ def next_lba(end_lba: int, alignment: int):

largest_lba = 0
for part in self.entries:
lba = part.last_lba.data
lba = part.last_lba
if int(lba) > int(largest_lba):
largest_lba = lba
last_lba = 33 if largest_lba == 0 else largest_lba
Expand All @@ -197,13 +172,12 @@ def _get_last_lba(self, partition: Partition) -> int:

# round the LBA up to ensure our LBA will hold the partition
lba = int(ceil(partition.size / self._geometry.sector_size))
f_lba = int(partition.first_lba.data)
f_lba = int(partition.first_lba)
return (f_lba + lba) - 1

@property
def byte_structure(self) -> bytes:
"""Convert the Partition Array to its byte structure"""
parts = [x.byte_structure for x in self.entries]
def marshal(self) -> bytes:
"""Convert the Partition Entry Array to its byte structure"""
parts = [x.marshal() for x in self.entries]
part_bytes = b"".join(parts)
# pad the rest with zeros
padded = part_bytes + b"\x00" * (
Expand Down

0 comments on commit aaa777f

Please sign in to comment.