Skip to content

Commit

Permalink
annotate pdfminer.jbig2
Browse files Browse the repository at this point in the history
  • Loading branch information
0xabu committed Sep 5, 2021
1 parent 0d40b7c commit 4dbcf87
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 62 deletions.
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ disallow_untyped_defs = True
disallow_untyped_calls = False
disallow_untyped_defs = False

[mypy-pdfminer.jbig2]
disallow_untyped_defs = False

[mypy-cryptography.hazmat.*]
ignore_missing_imports = True

Expand Down
9 changes: 4 additions & 5 deletions pdfminer/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,11 @@ def export_image(self, image: LTImage) -> str:
input_stream = BytesIO()
input_stream.write(image.stream.get_data())
input_stream.seek(0)
reader = \
JBIG2StreamReader(input_stream) # type:ignore[no-untyped-call]
segments = reader.get_segments() # type: ignore[no-untyped-call]
reader = JBIG2StreamReader(input_stream)
segments = reader.get_segments()

writer = JBIG2StreamWriter(fp) # type: ignore[no-untyped-call]
writer.write_file(segments) # type: ignore[no-untyped-call]
writer = JBIG2StreamWriter(fp)
writer.write_file(segments)
elif image.bits == 1:
bmp = BMPWriter(fp, 1, width, height)
data = image.stream.get_data()
Expand Down
129 changes: 75 additions & 54 deletions pdfminer/jbig2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import os
from struct import pack, unpack, calcsize
from typing import BinaryIO, Dict, Iterable, List, Optional, Tuple, Union, cast

# segment structure base
SEG_STRUCT = [
Expand Down Expand Up @@ -34,49 +35,58 @@
FILE_HEAD_FLAG_PAGES_UNKNOWN = 0b00000010


def bit_set(bit_pos, value):
def bit_set(bit_pos: int, value: int) -> bool:
return bool((value >> bit_pos) & 1)


def check_flag(flag, value):
def check_flag(flag: int, value: int) -> bool:
return bool(flag & value)


def masked_value(mask, value):
def masked_value(mask: int, value: int) -> int:
for bit_pos in range(0, 31):
if bit_set(bit_pos, mask):
return (value & mask) >> bit_pos

raise Exception("Invalid mask or value")


def mask_value(mask, value):
def mask_value(mask: int, value: int) -> int:
for bit_pos in range(0, 31):
if bit_set(bit_pos, mask):
return (value & (mask >> bit_pos)) << bit_pos

raise Exception("Invalid mask or value")


def unpack_int(format: str, buffer: bytes) -> int:
assert format in {">B", ">I", ">L"}
[result] = cast(Tuple[int], unpack(format, buffer))
return result


JBIG2SegmentFlags = Dict[str, Union[int, bool]]
JBIG2RetentionFlags = Dict[str, Union[int, List[int], List[bool]]]
JBIG2Segment = Dict[str, Union[bool, int, bytes, JBIG2SegmentFlags,
JBIG2RetentionFlags]]


class JBIG2StreamReader:
"""Read segments from a JBIG2 byte stream"""

def __init__(self, stream):
def __init__(self, stream: BinaryIO) -> None:
self.stream = stream

def get_segments(self):
segments = []
def get_segments(self) -> List[JBIG2Segment]:
segments: List[JBIG2Segment] = []
while not self.is_eof():
segment = {}
segment: JBIG2Segment = {}
for field_format, name in SEG_STRUCT:
field_len = calcsize(field_format)
field = self.stream.read(field_len)
if len(field) < field_len:
segment["_error"] = True
break
value = unpack(field_format, field)
if len(value) == 1:
[value] = value
value = unpack_int(field_format, field)
parser = getattr(self, "parse_%s" % name, None)
if callable(parser):
value = parser(segment, value, field)
Expand All @@ -86,21 +96,23 @@ def get_segments(self):
segments.append(segment)
return segments

def is_eof(self):
def is_eof(self) -> bool:
if self.stream.read(1) == b'':
return True
else:
self.stream.seek(-1, os.SEEK_CUR)
return False

def parse_flags(self, segment, flags, field):
def parse_flags(self, segment: JBIG2Segment, flags: int, field: bytes
) -> JBIG2SegmentFlags:
return {
"deferred": check_flag(HEADER_FLAG_DEFERRED, flags),
"page_assoc_long": check_flag(HEADER_FLAG_PAGE_ASSOC_LONG, flags),
"type": masked_value(SEG_TYPE_MASK, flags)
}

def parse_retention_flags(self, segment, flags, field):
def parse_retention_flags(self, segment: JBIG2Segment, flags: int,
field: bytes) -> JBIG2RetentionFlags:
ref_count = masked_value(REF_COUNT_SHORT_MASK, flags)
retain_segments = []
ref_segments = []
Expand All @@ -110,15 +122,16 @@ def parse_retention_flags(self, segment, flags, field):
retain_segments.append(bit_set(bit_pos, flags))
else:
field += self.stream.read(3)
[ref_count] = unpack(">L", field)
ref_count = unpack_int(">L", field)
ref_count = masked_value(REF_COUNT_LONG_MASK, ref_count)
ret_bytes_count = int(math.ceil((ref_count + 1) / 8))
for ret_byte_index in range(ret_bytes_count):
[ret_byte] = unpack(">B", self.stream.read(1))
ret_byte = unpack_int(">B", self.stream.read(1))
for bit_pos in range(7):
retain_segments.append(bit_set(bit_pos, ret_byte))

seg_num = segment["number"]
assert isinstance(seg_num, int)
if seg_num <= 256:
ref_format = ">B"
elif seg_num <= 65536:
Expand All @@ -129,8 +142,8 @@ def parse_retention_flags(self, segment, flags, field):
ref_size = calcsize(ref_format)

for ref_index in range(ref_count):
ref = self.stream.read(ref_size)
[ref] = unpack(ref_format, ref)
ref_data = self.stream.read(ref_size)
ref = unpack_int(ref_format, ref_data)
ref_segments.append(ref)

return {
Expand All @@ -139,15 +152,18 @@ def parse_retention_flags(self, segment, flags, field):
"ref_segments": ref_segments,
}

def parse_page_assoc(self, segment, page, field):
if segment["flags"]["page_assoc_long"]:
def parse_page_assoc(self, segment: JBIG2Segment, page: int, field: bytes
) -> int:
if cast(JBIG2SegmentFlags, segment["flags"])["page_assoc_long"]:
field += self.stream.read(3)
[page] = unpack(">L", field)
page = unpack_int(">L", field)
return page

def parse_data_length(self, segment, length, field):
def parse_data_length(self, segment: JBIG2Segment, length: int,
field: bytes) -> int:
if length:
if (segment["flags"]["type"] == SEG_TYPE_IMMEDIATE_GEN_REGION) \
if (cast(JBIG2SegmentFlags, segment["flags"])["type"] ==
SEG_TYPE_IMMEDIATE_GEN_REGION) \
and (length == DATA_LEN_UNKNOWN):

raise NotImplementedError(
Expand All @@ -163,25 +179,33 @@ def parse_data_length(self, segment, length, field):
class JBIG2StreamWriter:
"""Write JBIG2 segments to a file in JBIG2 format"""

def __init__(self, stream):
EMPTY_RETENTION_FLAGS: JBIG2RetentionFlags = {
'ref_count': 0,
'ref_segments': cast(List[int], []),
'retain_segments': cast(List[bool], [])
}

def __init__(self, stream: BinaryIO) -> None:
self.stream = stream

def write_segments(self, segments, fix_last_page=True):
def write_segments(self, segments: Iterable[JBIG2Segment],
fix_last_page: bool = True) -> int:
data_len = 0
current_page = None
seg_num = None
current_page: Optional[int] = None
seg_num: Optional[int] = None

for segment in segments:
data = self.encode_segment(segment)
self.stream.write(data)
data_len += len(data)

seg_num = segment["number"]
seg_num = cast(Optional[int], segment["number"])

if fix_last_page:
seg_page = segment.get("page_assoc")
seg_page = cast(int, segment.get("page_assoc"))

if segment["flags"]["type"] == SEG_TYPE_END_OF_PAGE:
if cast(JBIG2SegmentFlags, segment["flags"])["type"] == \
SEG_TYPE_END_OF_PAGE:
current_page = None
elif seg_page:
current_page = seg_page
Expand All @@ -194,7 +218,8 @@ def write_segments(self, segments, fix_last_page=True):

return data_len

def write_file(self, segments, fix_last_page=True):
def write_file(self, segments: Iterable[JBIG2Segment],
fix_last_page: bool = True) -> int:
header = FILE_HEADER_ID
header_flags = FILE_HEAD_FLAG_SEQUENTIAL | FILE_HEAD_FLAG_PAGES_UNKNOWN
header += pack(">B", header_flags)
Expand All @@ -205,7 +230,7 @@ def write_file(self, segments, fix_last_page=True):

seg_num = 0
for segment in segments:
seg_num = segment["number"]
seg_num = cast(int, segment["number"])

eof_segment = self.get_eof_segment(seg_num + 1)
data = self.encode_segment(eof_segment)
Expand All @@ -215,7 +240,7 @@ def write_file(self, segments, fix_last_page=True):

return data_len

def encode_segment(self, segment):
def encode_segment(self, segment: JBIG2Segment) -> bytes:
data = b''
for field_format, name in SEG_STRUCT:
value = segment.get(name)
Expand All @@ -227,7 +252,8 @@ def encode_segment(self, segment):
data += field
return data

def encode_flags(self, value, segment):
def encode_flags(self, value: JBIG2SegmentFlags, segment: JBIG2Segment
) -> bytes:
flags = 0
if value.get("deferred"):
flags |= HEADER_FLAG_DEFERRED
Expand All @@ -237,17 +263,19 @@ def encode_flags(self, value, segment):
if value["page_assoc_long"] else flags
else:
flags |= HEADER_FLAG_PAGE_ASSOC_LONG \
if segment.get("page", 0) > 255 else flags
if cast(int, segment.get("page", 0)) > 255 else flags

flags |= mask_value(SEG_TYPE_MASK, value["type"])

return pack(">B", flags)

def encode_retention_flags(self, value, segment):
def encode_retention_flags(self, value: JBIG2RetentionFlags,
segment: JBIG2Segment) -> bytes:
flags = []
flags_format = ">B"
ref_count = value["ref_count"]
retain_segments = value.get("retain_segments", [])
assert isinstance(ref_count, int)
retain_segments = cast(List[bool], value.get("retain_segments", []))

if ref_count <= 4:
flags_byte = mask_value(REF_COUNT_SHORT_MASK, ref_count)
Expand All @@ -271,9 +299,9 @@ def encode_retention_flags(self, value, segment):

flags.append(ret_byte)

ref_segments = value.get("ref_segments", [])
ref_segments = cast(List[int], value.get("ref_segments", []))

seg_num = segment["number"]
seg_num = cast(int, segment["number"])
if seg_num <= 256:
ref_format = "B"
elif seg_num <= 65536:
Expand All @@ -287,35 +315,28 @@ def encode_retention_flags(self, value, segment):

return pack(flags_format, *flags)

def encode_data_length(self, value, segment):
def encode_data_length(self, value: int, segment: JBIG2Segment) -> bytes:
data = pack(">L", value)
data += segment["raw_data"]
data += cast(bytes, segment["raw_data"])
return data

def get_eop_segment(self, seg_number, page_number):
def get_eop_segment(self, seg_number: int, page_number: int
) -> JBIG2Segment:
return {
'data_length': 0,
'flags': {'deferred': False, 'type': SEG_TYPE_END_OF_PAGE},
'number': seg_number,
'page_assoc': page_number,
'raw_data': b'',
'retention_flags': {
'ref_count': 0,
'ref_segments': [],
'retain_segments': []
}
'retention_flags': JBIG2StreamWriter.EMPTY_RETENTION_FLAGS
}

def get_eof_segment(self, seg_number):
def get_eof_segment(self, seg_number: int) -> JBIG2Segment:
return {
'data_length': 0,
'flags': {'deferred': False, 'type': SEG_TYPE_END_OF_FILE},
'number': seg_number,
'page_assoc': 0,
'raw_data': b'',
'retention_flags': {
'ref_count': 0,
'ref_segments': [],
'retain_segments': []
}
'retention_flags': JBIG2StreamWriter.EMPTY_RETENTION_FLAGS
}

0 comments on commit 4dbcf87

Please sign in to comment.