Skip to content
This repository has been archived by the owner on Aug 18, 2022. It is now read-only.

Commit

Permalink
Add BaseKnownVLR abc to factorize some code
Browse files Browse the repository at this point in the history
  • Loading branch information
tmontaigu committed Jul 11, 2018
1 parent cc0ed95 commit 33a0aca
Showing 1 changed file with 42 additions and 60 deletions.
102 changes: 42 additions & 60 deletions pylas/vlrs/known.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..extradims import get_type_for_extra_dim


class KnownVLR:
class IKnownVLR:
""" Interface that any KnownVLR must implement.
A KnownVLR is a VLR for which we know how to parse its record_data
Expand Down Expand Up @@ -66,6 +66,17 @@ def parse_record_data(self, record_data):
"""
pass


class BaseKnownVLR(BaseVLR, IKnownVLR):
""" Base Class to factorize common code between the different type of Known VLRs
"""
def __init__(self, record_id=None, description=""):
super().__init__(
self.official_user_id(),
self.official_record_ids()[0] if record_id is None else record_id,
description,
)

@classmethod
def from_raw(cls, raw):
vlr = cls()
Expand Down Expand Up @@ -97,15 +108,11 @@ def size():
return ctypes.sizeof(ClassificationLookupStruct)


class ClassificationLookupVlr(BaseVLR, KnownVLR):
class ClassificationLookupVlr(BaseKnownVLR):
_lookup_size = ClassificationLookupStruct.size()

def __init__(self):
super().__init__(
self.official_user_id(),
self.official_record_ids()[0],
description="Classification Lookup",
)
super().__init__(description="Classification Lookup")
self.lookups = []

def _is_max_num_lookups_reached(self):
Expand All @@ -127,7 +134,7 @@ def parse_record_data(self, record_data):
for i in range(len(record_data) // ctypes.sizeof(ClassificationLookupStruct)):
self.lookups.append(
ClassificationLookupStruct.from_buffer_copy(
record_data[self._lookup_size * i : self._lookup_size * (i + 1)]
record_data[self._lookup_size * i: self._lookup_size * (i + 1)]
)
)

Expand All @@ -143,13 +150,9 @@ def official_record_ids():
return (0,)


class LasZipVlr(BaseVLR, KnownVLR):
class LasZipVlr(BaseKnownVLR):
def __init__(self, data):
super().__init__(
LasZipVlr.official_user_id(),
LasZipVlr.official_record_ids()[0],
"http://laszip.org",
)
super().__init__(description="http://laszip.org")
self.record_data = data

def parse_record_data(self, record_data):
Expand Down Expand Up @@ -206,11 +209,9 @@ def __repr__(self):
)


class ExtraBytesVlr(BaseVLR, KnownVLR):
class ExtraBytesVlr(BaseKnownVLR):
def __init__(self):
super().__init__(
"LASF_Spec", self.official_record_ids()[0], "Extra Bytes Record"
)
super().__init__(description="Extra Bytes Record")
self.extra_bytes_structs = []

def parse_record_data(self, data):
Expand All @@ -224,7 +225,7 @@ def parse_record_data(self, data):
self.extra_bytes_structs = [None] * num_extra_bytes_structs
for i in range(num_extra_bytes_structs):
self.extra_bytes_structs[i] = ExtraBytesStruct.from_buffer_copy(
data[ExtraBytesStruct.size() * i : ExtraBytesStruct.size() * (i + 1)]
data[ExtraBytesStruct.size() * i: ExtraBytesStruct.size() * (i + 1)]
)

def record_data_bytes(self):
Expand Down Expand Up @@ -265,11 +266,9 @@ def size():
return ctypes.sizeof(WaveformPacketStruct)


class WaveformPacketVlr(BaseVLR, KnownVLR):
class WaveformPacketVlr(BaseKnownVLR):
def __init__(self, record_id, description=""):
super().__init__(
self.official_user_id(), record_id=record_id, description=description
)
super().__init__(record_id=record_id, description=description)
self.parsed_record = None

def parse_record_data(self, record_data):
Expand Down Expand Up @@ -342,13 +341,9 @@ def __repr__(self):
)


class GeoKeyDirectoryVlr(BaseVLR, KnownVLR):
class GeoKeyDirectoryVlr(BaseKnownVLR):
def __init__(self):
super().__init__(
self.official_user_id(),
self.official_record_ids()[0],
description="GeoTIFF GeoKeyDirectoryTag",
)
super().__init__(description="GeoTIFF GeoKeyDirectoryTag")
self.geo_keys_header = GeoKeysHeaderStructs()
self.geo_keys = [GeoKeyEntryStruct()]

Expand All @@ -358,17 +353,17 @@ def parse_record_data(self, record_data):
self.geo_keys_header = GeoKeysHeaderStructs.from_buffer(header_data)
self.geo_keys = []
keys_data = record_data[GeoKeysHeaderStructs.size():]
num_keys = len(
record_data[GeoKeysHeaderStructs.size() :]
) // GeoKeyEntryStruct.size()
num_keys = (
len(record_data[GeoKeysHeaderStructs.size():]) // GeoKeyEntryStruct.size()
)
if num_keys != self.geo_keys_header.number_of_keys:
# print("Mismatch num keys")
self.geo_keys_header.number_of_keys = num_keys

for i in range(self.geo_keys_header.number_of_keys):
data = keys_data[
(i * GeoKeyEntryStruct.size()) : (i + 1) * GeoKeyEntryStruct.size()
]
(i * GeoKeyEntryStruct.size()): (i + 1) * GeoKeyEntryStruct.size()
]
self.geo_keys.append(GeoKeyEntryStruct.from_buffer(data))

def record_data_bytes(self):
Expand All @@ -385,13 +380,9 @@ def official_record_ids():
return (34735,)


class GeoDoubleParamsVlr(BaseVLR, KnownVLR):
class GeoDoubleParamsVlr(BaseKnownVLR):
def __init__(self):
super().__init__(
self.official_user_id(),
self.official_record_ids()[0],
description="GeoTIFF GeoDoubleParamsTag",
)
super().__init__(description="GeoTIFF GeoDoubleParamsTag")
self.doubles = []

def parse_record_data(self, record_data):
Expand All @@ -405,7 +396,7 @@ def parse_record_data(self, record_data):
record_data = bytearray(record_data)
num_doubles = len(record_data) // sizeof_double
for i in range(num_doubles):
b = record_data[i * sizeof_double : (i + 1) * sizeof_double]
b = record_data[i * sizeof_double: (i + 1) * sizeof_double]
self.doubles.append(ctypes.c_double.from_buffer(b))

def record_data_bytes(self):
Expand All @@ -420,13 +411,9 @@ def official_record_ids():
return (34736,)


class GeoAsciiParamsVlr(BaseVLR, KnownVLR):
class GeoAsciiParamsVlr(BaseKnownVLR):
def __init__(self):
super().__init__(
self.official_user_id(),
self.official_record_ids()[0],
description="GeoTIFF GeoAsciiParamsTag",
)
super().__init__(description="GeoTIFF GeoAsciiParamsTag")
self.strings = []

def parse_record_data(self, record_data):
Expand All @@ -444,7 +431,7 @@ def official_record_ids():
return (34737,)


class WktMathTransformVlr(BaseVLR, KnownVLR):
class WktMathTransformVlr(BaseKnownVLR):
"""
From the Spec:
Note that the math transform WKT record is added for completeness, and a coordinate system WKT
Expand All @@ -453,9 +440,7 @@ class WktMathTransformVlr(BaseVLR, KnownVLR):
"""

def __init__(self):
super().__init__(
self.official_user_id(), self.official_record_ids()[0], description=""
)
super().__init__(description="")
self.string = ""

def _encode_string(self):
Expand All @@ -476,17 +461,13 @@ def official_record_ids():
return (2112,)


class WktCoordinateSystemVlr(BaseVLR, KnownVLR):
class WktCoordinateSystemVlr(BaseKnownVLR):
""" Replaces Coordinates Reference System for new las files (point fmt >= 5)
"LAS is not using the “ESRI WKT”
"""

def __init__(self):
super().__init__(
self.official_user_id(),
self.official_record_ids()[0],
description="OGC Transformation Record",
)
super().__init__(description="OGC Transformation Record")
self.string = ""

def _encode_string(self):
Expand All @@ -509,10 +490,11 @@ def official_record_ids():

def vlr_factory(raw_vlr):
user_id = raw_vlr.header.user_id.rstrip(NULL_BYTE).decode()
for known_vlr in KnownVLR.__subclasses__():
known_vlrs = BaseKnownVLR.__subclasses__()
for known_vlr in known_vlrs:
if (
known_vlr.official_user_id() == user_id
and raw_vlr.header.record_id in known_vlr.official_record_ids()
known_vlr.official_user_id() == user_id
and raw_vlr.header.record_id in known_vlr.official_record_ids()
):
return known_vlr.from_raw(raw_vlr)
else:
Expand Down

0 comments on commit 33a0aca

Please sign in to comment.