diff --git a/README.md b/README.md index 2d72e90..0c08326 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,9 @@ [![Coveralls status](https://coveralls.io/repos/tcalmant/python-javaobj/badge.svg?branch=master)](https://coveralls.io/r/tcalmant/python-javaobj?branch=master) *python-javaobj* is a python library that provides functions for reading and -writing (writing is WIP currently) Java objects serialized or will be -deserialized by `ObjectOutputStream`. This form of object representation is a -standard data interchange format in Java world. +writing Java objects serialized or to be deserialized by `ObjectOutputStream`. +This form of object representation is a standard data interchange format in +Java world. The `javaobj` module exposes an API familiar to users of the standard library `marshal`, `pickle` and `json` modules. @@ -39,16 +39,17 @@ Since version 0.4.0, three implementations of the parser are available: with support of the object transformer (with a new API) and of the `numpy` arrays loading. * `v3`: a **new** implementation, written from scratch to benefit from - Python 3.12+ features. + Python 3.12+ features, with full read **and** write support. You can use the `v1` parser to ensure that the behaviour of your scripts -doesn't change and to keep the ability to write down files. +doesn't change. It also provides a basic marshalling capability. You can use the `v2` parser for developments in Python versions lower than 3.12 and *which won't require marshalling*, or as a *fallback* if the `v1` parser fails to parse a file. -For new development, you should use the `v3` parser. +For new development, you should use the `v3` parser, which supports both +reading and writing Java object streams. ### Object transformers V1 @@ -110,7 +111,8 @@ You can find a sample usage in the *Custom Transformer* section in this file. * Primitive values un-marshalling * Automatic conversion of Java Collections to python ones (`HashMap` => `dict`, `ArrayList` => `list`, etc.) -* Basic marshalling of simple Java objects (`v1` implementation only) +* Basic marshalling of simple Java objects (`v1` implementation) +* Full marshalling of Java object streams (`v3` implementation) * Automatically uncompresses GZipped files ## Requirements @@ -544,6 +546,7 @@ value = pobj.myField | Correct `TYPE_CHAR` numpy dtype (`>u2`) | ✗ | ✗ | ✓ | | Typed exception hierarchy | ✗ | ✗ | ✓ | | `BlockData.__eq__(bytes)` compatibility | ✓ | ✓ | ✓ | +| Marshalling (writing) support | partial | ✗ | ✓ | ### Security limits @@ -589,6 +592,180 @@ with open("arrays.ser", "rb") as fd: When `use_numpy_arrays=True`, a `NumpyArrayTransformer` is appended to the transformer list and primitive arrays are returned as `numpy.ndarray`. +### Marshalling / Writing (V3) + +The `javaobj.v3` package exposes two additional entry-points for serializing +beans back to the Java Object Serialization binary format: + +* `dump(fd, *objects)`: Writes one or more parsed objects to a binary file + descriptor opened in `wb` mode. +* `dumps(*objects) -> bytes`: Returns the serialized stream as a `bytes` + object. + +Both functions accept any combination of +`JavaInstance`, `JavaArray`, `JavaString`, `JavaEnum`, `JavaClass`, `BlockData`, +and `None` (written as `TC_NULL`) as positional arguments. + +#### Simple round-trip + +```python +import javaobj.v3 as javaobj + +# Parse an existing file +with open("obj5.ser", "rb") as fd: + pobj = javaobj.load(fd) + +# Serialize back to bytes +data = javaobj.dumps(pobj) + +# Or write directly to a file +with open("obj5_copy.ser", "wb") as fd: + javaobj.dump(fd, pobj) +``` + +#### Writing multiple objects + +```python +import javaobj.v3 as javaobj +from javaobj.v3.beans import JavaString + +with open("a.ser", "rb") as fd: + obj_a = javaobj.load(fd) + +# Write two objects into one stream +data = javaobj.dumps(obj_a, JavaString("hello")) + +# Re-parse: returns a list when the stream holds more than one object +result = javaobj.loads(data) # -> [obj_a, JavaString("hello")] +``` + +#### Supported constructs + +| Construct | Supported | +|---|---| +| `TC_OBJECT` — `NOWRCLASS` (plain fields only) | ✓ | +| `TC_OBJECT` — `WRCLASS` (fields + block-data annotations) | ✓ | +| `TC_ARRAY` | ✓ | +| `TC_STRING` / `TC_LONGSTRING` | ✓ | +| `TC_ENUM` | ✓ | +| `TC_CLASS` | ✓ | +| `TC_NULL` | ✓ | +| `TC_BLOCKDATA` / `TC_BLOCKDATALONG` | ✓ | +| `TC_PROXYCLASSDESC` | ✓ | +| Back-references (`TC_REFERENCE`) | ✓ (automatic) | +| `EXTERNAL_CONTENTS` (Protocol v1 `Externalizable`) | ✗ | + +> **Note:** Back-references are tracked automatically by identity: if the same +> object appears more than once in the graph, subsequent occurrences are +> written as `TC_REFERENCE` — exactly as Java's `ObjectOutputStream` does. + +#### Building a Java object from scratch + +You can construct the v3 beans manually to serialize a Python object as if it +were a Java one. The key types are: + +* `JavaClassDesc` — the class descriptor (name, `serialVersionUID`, flags, + fields) +* `JavaField` — one field entry (type code + name, and optionally the binary + class name for object/array fields) +* `JavaInstance` — the object instance (`field_data` maps each class + descriptor to a `{JavaField: value}` dict) +* `JavaString` — a Java `String` value + +All beans accept `handle=0` when created from scratch; the writer assigns real +handles automatically during serialization. + +```python +import javaobj.v3 as javaobj +from javaobj.constants import ClassDescFlags +from javaobj.v3.beans import ( + FieldType, + JavaClassDesc, + ClassDescType, + JavaField, + JavaInstance, + JavaString, +) + +# ── 1. Describe the Java class ──────────────────────────────────────────────── +# +# Java equivalent: +# +# package com.example; +# public class Point implements java.io.Serializable { +# private static final long serialVersionUID = 1L; +# public int x; +# public int y; +# } + +field_x = JavaField(type=FieldType.INTEGER, name="x") +field_y = JavaField(type=FieldType.INTEGER, name="y") + +point_cd = JavaClassDesc( + handle=0, # assigned by the writer + name="com.example.Point", + serial_version_uid=1, + desc_flags=ClassDescFlags.SC_SERIALIZABLE, + fields=[field_x, field_y], +) + +# ── 2. Create an instance ───────────────────────────────────────────────────── + +point = JavaInstance( + handle=0, + classdesc=point_cd, + field_data={ + point_cd: { + field_x: 42, + field_y: -7, + } + }, +) + +# ── 3. Serialize ────────────────────────────────────────────────────────────── + +data = javaobj.dumps(point) + +# ── 4. Round-trip check ─────────────────────────────────────────────────────── + +restored = javaobj.loads(data) +print(restored.get_field("x")) # 42 +print(restored.get_field("y")) # -7 +``` + +For object-type fields (e.g. a `String` attribute), use `FieldType.OBJECT`, +set `class_name` to the binary class name, and pass a `JavaString` as the +value: + +```python +field_name = JavaField( + type=FieldType.OBJECT, + name="name", + class_name="Ljava/lang/String;", # binary name for java.lang.String +) + +person_cd = JavaClassDesc( + handle=0, + name="com.example.Person", + serial_version_uid=1, + desc_flags=ClassDescFlags.SC_SERIALIZABLE, + fields=[field_name, field_x], # reuse field_x from above +) + +alice = JavaInstance( + handle=0, + classdesc=person_cd, + field_data={ + person_cd: { + field_name: JavaString(handle=0, value="Alice"), + field_x: 30, + } + }, +) + +data = javaobj.dumps(alice) +``` + --- ## Migration to V3 @@ -602,7 +779,7 @@ transformer list and primitive arrays are returned as `numpy.ndarray`. | `pobj.myField` (direct attribute) | `pobj.get_field("myField")` (preferred) or `pobj.myField` | | `pobj._data` on arrays | `pobj.data` (public) | | `javaobj.JavaObjectUnmarshaller` | removed — use `javaobj.v3.parser.JavaStreamParser` | -| `javaobj.JavaObjectMarshaller` | marshalling not available in `v3` | +| `javaobj.JavaObjectMarshaller` | `javaobj.v3.dump` / `javaobj.v3.dumps` | | Exceptions: bare `Exception` | Typed: `ParseError`, `UnexpectedOpcodeError`, … | Shallow conversion helper (best-effort, for gradual migration): @@ -637,5 +814,5 @@ from javaobj.v3._compat import v2_to_v3 v3_obj = v2_to_v3(v2_obj) ``` -> **Note:** `v3` requires **Python 3.12+** and does **not** support marshalling -> (writing). If you need to write Java object streams, use `v1`. +> **Note:** `v3` requires **Python 3.12+**. +> For writing Java object streams on older Python versions, use `v1`. diff --git a/javaobj/v3/__init__.py b/javaobj/v3/__init__.py index e54d121..acb5137 100644 --- a/javaobj/v3/__init__.py +++ b/javaobj/v3/__init__.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 """ -Rewritten version of the un-marshalling process of javaobj (v3) +Rewritten version of the un-marshalling and marshalling process of javaobj (v3) -This package targets Python 3.12+ and provides fully typed parsing of the -Java Object Serialization stream format, in read-only mode. +This package targets Python 3.12+ and provides fully typed parsing and +serializing of the Java Object Serialization stream format. :authors: Thomas Calmant :license: Apache License 2.0 @@ -66,11 +66,16 @@ NumpyArrayTransformer, ObjectTransformer, ) +from .writer import JavaStreamWriter, dump, dumps __all__ = [ - # Entry points + # Entry points (reading) "load", "loads", + # Entry points (writing) + "dump", + "dumps", + "JavaStreamWriter", # Transformer API "ObjectTransformer", "DefaultObjectTransformer", diff --git a/javaobj/v3/writer.py b/javaobj/v3/writer.py new file mode 100644 index 0000000..9dbf538 --- /dev/null +++ b/javaobj/v3/writer.py @@ -0,0 +1,544 @@ +#!/usr/bin/env python3 +""" +Serializer for the Java Object Serialization stream format (v3) + +Produces a byte stream readable by Java's ``ObjectInputStream`` from v3 bean +objects (:class:`~javaobj.v3.beans.JavaInstance`, :class:`~javaobj.v3.beans.JavaArray`, +etc.). + +:authors: Thomas Calmant +:license: Apache License 2.0 +:version: 0.5.0 +:status: Alpha + +.. + + Copyright 2026 Thomas Calmant + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +# Standard library +import logging +import struct +from io import BytesIO +from typing import IO, Any + +# Javaobj +from ..constants import StreamConstants, TerminalCode +from .beans import ( + BlockData, + ClassDataType, + ClassDescType, + FieldType, + JavaArray, + JavaClass, + JavaClassDesc, + JavaEnum, + JavaInstance, + JavaString, + ParsedContent, +) +from .exceptions import UnsupportedFeatureError + +# ------------------------------------------------------------------------------ + +# Module version +__version_info__ = (0, 5, 0) +__version__ = ".".join(str(x) for x in __version_info__) + +# Documentation strings format +__docformat__ = "restructuredtext en" + +# ------------------------------------------------------------------------------ + +__all__ = ["JavaStreamWriter", "dump", "dumps"] + +_log = logging.getLogger("javaobj.v3.writer") + +# ------------------------------------------------------------------------------ +# Modified UTF-8 encoder +# ------------------------------------------------------------------------------ + + +def _encode_mutf8(string: str) -> bytes: + """ + Encodes a Unicode string to Java Modified UTF-8 bytes. + + Differences from standard UTF-8: + + * The null character (U+0000) is encoded as two bytes ``\\xC0\\x80`` + instead of a single zero byte. + * Supplementary characters (U+10000–U+10FFFF) are encoded as two + three-byte surrogate-pair sequences (six bytes total) instead of the + standard four-byte encoding. + """ + out = bytearray() + for char in string: + cp = ord(char) + if cp == 0x0000: + # Modified UTF-8: null → 0xC0 0x80 + out += b"\xc0\x80" + elif cp <= 0x007F: + out.append(cp) + elif cp <= 0x07FF: + out += bytes([0xC0 | (cp >> 6), 0x80 | (cp & 0x3F)]) + elif cp <= 0xFFFF: + out += bytes( + [ + 0xE0 | (cp >> 12), + 0x80 | ((cp >> 6) & 0x3F), + 0x80 | (cp & 0x3F), + ] + ) + else: + # Supplementary character: encode as surrogate pair, each as a + # 3-byte modified-UTF-8 sequence (6 bytes total). + cp -= 0x10000 + high = 0xD800 | (cp >> 10) + low = 0xDC00 | (cp & 0x3FF) + out += bytes( + [ + 0xED, + 0xA0 | ((high >> 6) & 0x0F), + 0x80 | (high & 0x3F), + 0xED, + 0xB0 | ((low >> 6) & 0x0F), + 0x80 | (low & 0x3F), + ] + ) + return bytes(out) + + +# ------------------------------------------------------------------------------ +# Writer +# ------------------------------------------------------------------------------ + + +class JavaStreamWriter: + """ + Serializes v3 bean objects to the Java Object Serialization stream format. + + The generated stream is fully compatible with Java's ``ObjectInputStream``. + + Usage:: + + with open("out.ser", "wb") as fd: + writer = JavaStreamWriter(fd) + writer.write_stream(my_instance) + + Or using the module-level helpers:: + + data = javaobj.v3.dumps(my_instance) + javaobj.v3.dump(fd, my_instance) + """ + + def __init__(self, fd: IO[bytes]) -> None: + self._fd = fd + # Maps id(obj) → allocated handle (int, starting at BASE_REFERENCE_IDX) + self._handle_map: dict[int, int] = {} + self._next_handle: int = int(StreamConstants.BASE_REFERENCE_IDX) + # Cached JavaString wrappers for class-name strings found inside + # JavaField descriptors. Keyed by the string value so that identical + # class names (e.g. "Ljava/lang/String;") are written only once and + # referenced thereafter. + self._classname_strings: dict[str, JavaString] = {} + + # ------------------------------------------------------------------ + # Public entry points + # ------------------------------------------------------------------ + + def write_stream(self, *objects: ParsedContent) -> None: + """ + Writes the Java serialization magic header followed by one or more + top-level content objects. + + Call this exactly once to produce a complete, self-contained stream. + + :param objects: Top-level objects to write. Pass several to create + a stream that requires multiple ``readObject()`` calls + on the Java side. + :raises UnsupportedFeatureError: If an object type cannot be + serialized (e.g. externalizable + Protocol-v1 classes). + """ + self._write_header() + for obj in objects: + self._write_content(obj) + + # ------------------------------------------------------------------ + # Stream header + # ------------------------------------------------------------------ + + def _write_header(self) -> None: + self._fd.write( + struct.pack( + ">HH", + int(StreamConstants.STREAM_MAGIC), + int(StreamConstants.STREAM_VERSION), + ) + ) + + # ------------------------------------------------------------------ + # Handle management + # ------------------------------------------------------------------ + + def _alloc_handle(self, obj: Any) -> int: + """Allocates and records the next handle for *obj*.""" + h = self._next_handle + self._next_handle += 1 + self._handle_map[id(obj)] = h + _log.debug("Allocated handle 0x%x for %s", h, type(obj).__name__) + return h + + def _try_reference(self, obj: Any) -> bool: + """ + Emits ``TC_REFERENCE`` for *obj* if it was already written. + + :return: ``True`` when a reference was written and the caller must + **not** write the object again; ``False`` otherwise. + """ + h = self._handle_map.get(id(obj)) + if h is None: + return False + _log.debug("TC_REFERENCE 0x%x for %s", h, type(obj).__name__) + self._fd.write(struct.pack(">Bi", int(TerminalCode.TC_REFERENCE), h)) + return True + + # ------------------------------------------------------------------ + # Content dispatcher + # ------------------------------------------------------------------ + + def _write_content(self, obj: ParsedContent) -> None: + """Writes a single content item — any valid v3 bean or ``None``.""" + match obj: + case None: + self._write_null() + case JavaInstance(): + self._write_instance(obj) + case JavaArray(): + self._write_array(obj) + case JavaString(): + self._write_string_obj(obj) + case JavaEnum(): + self._write_enum(obj) + case JavaClass(): + self._write_class(obj) + case BlockData(): + self._write_blockdata(obj) + case JavaClassDesc(): + # A bare class descriptor written directly to the stream + # (rare but valid at the top level). + self._write_classdesc(obj) + case _: + raise UnsupportedFeatureError(f"Cannot serialize object of type {type(obj).__name__!r}") + + # ------------------------------------------------------------------ + # TC_NULL + # ------------------------------------------------------------------ + + def _write_null(self) -> None: + self._fd.write(bytes([int(TerminalCode.TC_NULL)])) + + # ------------------------------------------------------------------ + # TC_OBJECT + # ------------------------------------------------------------------ + + def _write_instance(self, instance: JavaInstance) -> None: + if self._try_reference(instance): + return + self._fd.write(bytes([int(TerminalCode.TC_OBJECT)])) + self._write_classdesc(instance.classdesc) + self._alloc_handle(instance) + self._write_class_data(instance) + + # ------------------------------------------------------------------ + # TC_ARRAY + # ------------------------------------------------------------------ + + def _write_array(self, array: JavaArray) -> None: + if self._try_reference(array): + return + self._fd.write(bytes([int(TerminalCode.TC_ARRAY)])) + self._write_classdesc(array.classdesc) + self._alloc_handle(array) + + data = array.data + self._fd.write(struct.pack(">i", len(data))) + + et = array.element_type + if et == FieldType.BYTE: + # Bulk write: data is already bytes (or bytearray) + self._fd.write(data if isinstance(data, (bytes, bytearray)) else bytes(data)) # type: ignore[arg-type] + else: + for item in data: # type: ignore[union-attr] + self._write_field_value(et, item) + + # ------------------------------------------------------------------ + # TC_STRING / TC_LONGSTRING + # ------------------------------------------------------------------ + + def _write_string_obj(self, s: JavaString) -> None: + if self._try_reference(s): + return + encoded = _encode_mutf8(s.value) + n = len(encoded) + if n <= 0xFFFF: + self._fd.write(bytes([int(TerminalCode.TC_STRING)])) + self._alloc_handle(s) + self._fd.write(struct.pack(">H", n) + encoded) + else: + self._fd.write(bytes([int(TerminalCode.TC_LONGSTRING)])) + self._alloc_handle(s) + self._fd.write(struct.pack(">q", n) + encoded) + + # ------------------------------------------------------------------ + # TC_ENUM + # ------------------------------------------------------------------ + + def _write_enum(self, enum: JavaEnum) -> None: + if self._try_reference(enum): + return + self._fd.write(bytes([int(TerminalCode.TC_ENUM)])) + self._write_classdesc(enum.classdesc) + self._alloc_handle(enum) + self._write_string_obj(enum.constant) + + # ------------------------------------------------------------------ + # TC_CLASS + # ------------------------------------------------------------------ + + def _write_class(self, cls: JavaClass) -> None: + if self._try_reference(cls): + return + self._fd.write(bytes([int(TerminalCode.TC_CLASS)])) + self._write_classdesc(cls.classdesc) + self._alloc_handle(cls) + + # ------------------------------------------------------------------ + # TC_CLASSDESC / TC_PROXYCLASSDESC + # ------------------------------------------------------------------ + + def _write_classdesc(self, cd: JavaClassDesc | None) -> None: + if cd is None: + self._write_null() + return + if self._try_reference(cd): + return + match cd.class_type: + case ClassDescType.NORMALCLASS: + self._write_normal_classdesc(cd) + case ClassDescType.PROXYCLASS: + self._write_proxy_classdesc(cd) + + def _write_normal_classdesc(self, cd: JavaClassDesc) -> None: + """ + Serializes a normal (non-proxy) class descriptor. + + Wire layout:: + + TC_CLASSDESC utf(className) long(serialVersionUID) + newHandle byte(classDescFlags) short(fieldCount) + [byte(typeCode) utf(fieldName) [string(className2)]] ... + classAnnotation superClassDesc + """ + self._fd.write(bytes([int(TerminalCode.TC_CLASSDESC)])) + self._write_utf(cd.name) + self._fd.write(struct.pack(">q", cd.serial_version_uid)) + self._alloc_handle(cd) + self._fd.write(struct.pack(">Bh", cd.desc_flags, len(cd.fields))) + + for f in cd.fields: + # type byte + field name + self._fd.write(bytes([f.type.value])) + self._write_utf(f.name) + # Object/array fields carry a second string: the class name + if f.type in (FieldType.OBJECT, FieldType.ARRAY): + cn = f.class_name or "" + # Reuse the same JavaString object for identical class names + # so that TC_REFERENCE is written on subsequent occurrences. + if cn not in self._classname_strings: + self._classname_strings[cn] = JavaString(handle=0, value=cn) + self._write_string_obj(self._classname_strings[cn]) + + # Class annotations written by annotateClass() (usually empty) + for ann in cd.annotations: + self._write_content(ann) + self._fd.write(bytes([int(TerminalCode.TC_ENDBLOCKDATA)])) + + # Super-class descriptor (or TC_NULL) + self._write_classdesc(cd.super_class) + + def _write_proxy_classdesc(self, cd: JavaClassDesc) -> None: + """ + Serializes a dynamic proxy class descriptor. + + Wire layout:: + + TC_PROXYCLASSDESC int(interfaceCount) + [utf(interfaceName)] ... + newHandle classAnnotation superClassDesc + """ + self._fd.write(bytes([int(TerminalCode.TC_PROXYCLASSDESC)])) + self._fd.write(struct.pack(">i", len(cd.interfaces))) + for iface in cd.interfaces: + self._write_utf(iface) + self._alloc_handle(cd) + + for ann in cd.annotations: + self._write_content(ann) + self._fd.write(bytes([int(TerminalCode.TC_ENDBLOCKDATA)])) + + self._write_classdesc(cd.super_class) + + # ------------------------------------------------------------------ + # TC_BLOCKDATA / TC_BLOCKDATALONG + # ------------------------------------------------------------------ + + def _write_blockdata(self, bd: BlockData) -> None: + n = len(bd.data) + if n <= 255: + self._fd.write(struct.pack(">BB", int(TerminalCode.TC_BLOCKDATA), n)) + else: + self._fd.write(struct.pack(">Bi", int(TerminalCode.TC_BLOCKDATALONG), n)) + self._fd.write(bd.data) + + # ------------------------------------------------------------------ + # classdata — instance field values + annotations per hierarchy class + # ------------------------------------------------------------------ + + def _write_class_data(self, instance: JavaInstance) -> None: + """ + Writes all field values and object annotations for *instance*, + walking the class hierarchy from topmost ancestor to concrete class + (the same order as ``ObjectOutputStream`` on the Java side). + """ + if instance.classdesc is None: + return + + for cd in instance.classdesc.get_hierarchy(): + try: + data_type = cd.data_type + except ValueError: + # No SC_SERIALIZABLE / SC_EXTERNALIZABLE flags — skip. + continue + + cd_fields = instance.field_data.get(cd, {}) + + match data_type: + case ClassDataType.NOWRCLASS: + # Plain serializable class: write fields only. + for f in cd.fields: + self._write_field_value(f.type, cd_fields.get(f)) + + case ClassDataType.WRCLASS: + # Serializable class with writeObject(): + # fields first, then the custom annotation block. + for f in cd.fields: + self._write_field_value(f.type, cd_fields.get(f)) + for ann in instance.annotations.get(cd, []): + self._write_content(ann) + self._fd.write(bytes([int(TerminalCode.TC_ENDBLOCKDATA)])) + + case ClassDataType.OBJECT_ANNOTATION: + # Externalizable + SC_BLOCK_DATA: + # all data lives in the annotation block. + for ann in instance.annotations.get(cd, []): + self._write_content(ann) + self._fd.write(bytes([int(TerminalCode.TC_ENDBLOCKDATA)])) + + case ClassDataType.EXTERNAL_CONTENTS: + raise UnsupportedFeatureError( + f"SC_EXTERNALIZABLE without SC_BLOCK_DATA " + f"(Protocol v1) is not supported for class {cd.name!r}" + ) + + # ------------------------------------------------------------------ + # Field value writer + # ------------------------------------------------------------------ + + def _write_field_value(self, field_type: FieldType, value: Any) -> None: + """Writes a single field value according to *field_type*.""" + match field_type: + case FieldType.BYTE: + self._fd.write(struct.pack(">b", int(value) if value is not None else 0)) + case FieldType.CHAR: + cp = ord(value) if isinstance(value, str) else int(value) + self._fd.write(struct.pack(">H", cp & 0xFFFF)) + case FieldType.SHORT: + self._fd.write(struct.pack(">h", int(value) if value is not None else 0)) + case FieldType.INTEGER: + self._fd.write(struct.pack(">i", int(value) if value is not None else 0)) + case FieldType.LONG: + self._fd.write(struct.pack(">q", int(value) if value is not None else 0)) + case FieldType.FLOAT: + self._fd.write(struct.pack(">f", float(value) if value is not None else 0.0)) + case FieldType.DOUBLE: + self._fd.write(struct.pack(">d", float(value) if value is not None else 0.0)) + case FieldType.BOOLEAN: + self._fd.write(bytes([1 if value else 0])) + case FieldType.OBJECT | FieldType.ARRAY: + self._write_content(value) + + # ------------------------------------------------------------------ + # Short-length UTF helper + # ------------------------------------------------------------------ + + def _write_utf(self, s: str) -> None: + """ + Writes a "short" UTF entry: 2-byte unsigned length + Modified UTF-8 + bytes. + + Used for class names, field names, and interface names *inside* class + descriptor records. These strings do **not** receive handles and are + **not** written as ``TC_STRING`` objects. + + :raises ValueError: If the encoded byte length exceeds 65535. + """ + encoded = _encode_mutf8(s) + n = len(encoded) + if n > 0xFFFF: + raise ValueError(f"String too long for short-length UTF field: {n} bytes (max 65535)") + self._fd.write(struct.pack(">H", n) + encoded) + + +# ------------------------------------------------------------------------------ +# Module-level convenience functions +# ------------------------------------------------------------------------------ + + +def dump(fd: IO[bytes], *objects: ParsedContent) -> None: + """ + Serializes one or more v3 bean objects to a binary file-like object. + + :param fd: A writable binary stream (opened in ``"wb"`` mode). + :param objects: Top-level objects to serialize. Pass several to create a + multi-object stream (each requiring a separate + ``readObject()`` call on the Java side). + :raises UnsupportedFeatureError: If an object type cannot be serialized. + """ + writer = JavaStreamWriter(fd) + writer.write_stream(*objects) + + +def dumps(*objects: ParsedContent) -> bytes: + """ + Serializes one or more v3 bean objects to a :class:`bytes` object. + + :param objects: Top-level objects to serialize (see :func:`dump`). + :return: A complete Java Object Serialization stream as :class:`bytes`. + :raises UnsupportedFeatureError: If an object type cannot be serialized. + """ + buf = BytesIO() + dump(buf, *objects) + return buf.getvalue() diff --git a/tests/test_v1.py b/tests/test_v1.py index 5ae5ced..e22ebc0 100644 --- a/tests/test_v1.py +++ b/tests/test_v1.py @@ -508,8 +508,8 @@ def test_read_custom(self): pobj = javaobj.loads(ser) self.assertIsNone(pobj.superItems) self.assertIsNone(pobj.items) - self.assertEquals(pobj.name, "test") - self.assertEquals(pobj.port, 443) + self.assertEqual(pobj.name, "test") + self.assertEqual(pobj.port, 443) # ------------------------------------------------------------------------------ diff --git a/tests/test_v3.py b/tests/test_v3.py index 6aad201..36724d2 100644 --- a/tests/test_v3.py +++ b/tests/test_v3.py @@ -52,6 +52,7 @@ JavaTime, ObjectTransformer, ) +from javaobj.v3.writer import _encode_mutf8 # ------------------------------------------------------------------------------ @@ -645,6 +646,172 @@ def test_v1_to_v3_unknown_raises(self) -> None: v1_to_v3(object()) # type: ignore[arg-type] +# ------------------------------------------------------------------------------ +# Writer / round-trip tests +# ------------------------------------------------------------------------------ + + +class TestWriter(TestJavaobjV3Base): + """Tests for javaobj.v3.writer — serializing beans back to bytes.""" + + # ------------------------------------------------------------------ + # Modified UTF-8 encoder unit tests + # ------------------------------------------------------------------ + + def test_mutf8_ascii(self) -> None: + """ASCII characters round-trip through Modified UTF-8.""" + s = "Hello, World!" + self.assertEqual(_encode_mutf8(s), s.encode("ascii")) + + def test_mutf8_null(self) -> None: + """Null character is encoded as two-byte sequence 0xC0 0x80.""" + self.assertEqual(_encode_mutf8("\x00"), b"\xc0\x80") + + def test_mutf8_japanese(self) -> None: + """CJK characters produce a 3-byte-per-codepoint encoding.""" + s = "\u65e5\u672c\u56fd" # 日本国 + encoded = _encode_mutf8(s) + # 3 codepoints × 3 bytes each = 9 bytes + self.assertEqual(len(encoded), 9) + + def test_mutf8_supplementary(self) -> None: + """A supplementary character (U+1F600 😀) encodes as 6 bytes.""" + encoded = _encode_mutf8("\U0001f600") + self.assertEqual(len(encoded), 6) + # Must start with the first surrogate half marker + self.assertEqual(encoded[0], 0xED) + self.assertEqual(encoded[3], 0xED) + + # ------------------------------------------------------------------ + # dumps / dump API smoke tests + # ------------------------------------------------------------------ + + def test_dumps_returns_bytes(self) -> None: + """javaobj.v3.dumps() returns bytes starting with the magic header.""" + pobj = self.load_file("testBoolIntLong.ser") + data = javaobj.dumps(pobj) + self.assertIsInstance(data, bytes) + # Magic: 0xACED, version: 0x0005 + self.assertEqual(data[:4], b"\xac\xed\x00\x05") + + def test_dump_to_fd(self) -> None: + """javaobj.v3.dump(fd, obj) writes to a file-like object.""" + import io + + pobj = self.load_file("testBoolIntLong.ser") + buf = io.BytesIO() + javaobj.dump(buf, pobj) + self.assertEqual(buf.getvalue()[:4], b"\xac\xed\x00\x05") + + # ------------------------------------------------------------------ + # Round-trip tests (parse → write → re-parse → compare field values) + # ------------------------------------------------------------------ + + def _round_trip(self, filename: str) -> tuple[Any, Any]: + """ + Parses *filename*, serializes the result, re-parses the bytes, and + returns ``(original, re_parsed)`` for the caller to assert on. + """ + original = self.load_file(filename) + serialized = javaobj.dumps(original) + re_parsed = javaobj.loads(serialized) + return original, re_parsed + + def test_round_trip_instance_fields(self) -> None: + """NOWRCLASS instance: field values survive a write→re-read cycle.""" + original, re_parsed = self._round_trip("testBoolIntLong.ser") + self.assertIsInstance(re_parsed, JavaInstance) + # Compare all field values by name + orig_cd = original.get_class() + new_cd = re_parsed.get_class() + self.assertEqual(orig_cd.name, new_cd.name) + self.assertEqual(orig_cd.serial_version_uid, new_cd.serial_version_uid) + for field_name in orig_cd.fields_names: + self.assertEqual( + original.get_field(field_name), + re_parsed.get_field(field_name), + msg=f"Field {field_name!r} differs after round-trip", + ) + + def test_round_trip_string(self) -> None: + """JavaString: value survives a write→re-read cycle.""" + original, re_parsed = self._round_trip("testJapan.ser") + self.assertIsInstance(re_parsed, JavaString) + self.assertEqual(str(original), str(re_parsed)) + + def test_round_trip_char_array(self) -> None: + """JavaArray (chars): data survives a write→re-read cycle.""" + original, re_parsed = self._round_trip("testCharArray.ser") + self.assertIsInstance(re_parsed, JavaArray) + self.assertEqual(re_parsed.element_type, FieldType.CHAR) + self.assertEqual(list(original.data), list(re_parsed.data)) + + def test_round_trip_byte_array(self) -> None: + """JavaArray (bytes): data survives a write→re-read cycle.""" + # testBytes.ser is a raw BlockData, so use a proper Java array fixture + original, re_parsed = self._round_trip("objArrays.ser") + self.assertEqual(type(original), type(re_parsed)) + + def test_round_trip_enum(self) -> None: + """Enum constant embedded in an instance: class/value survive round-trip.""" + # objEnums.ser contains a JavaInstance with a JavaEnum field 'color' + original = self.load_file("objEnums.ser") + serialized = javaobj.dumps(original) + re_parsed = javaobj.loads(serialized) + self.assertIsInstance(re_parsed, JavaInstance) + self.assertEqual(re_parsed.get_class().name, original.get_class().name) + orig_color = original.color + new_color = re_parsed.color + self.assertIsInstance(new_color, JavaEnum) + self.assertEqual(new_color.classdesc.name, orig_color.classdesc.name) + self.assertEqual(str(new_color.constant), str(orig_color.constant)) + + def test_round_trip_super_class(self) -> None: + """Instance with class hierarchy: all fields survive round-trip.""" + original, re_parsed = self._round_trip("objSuper.ser") + self.assertIsInstance(re_parsed, JavaInstance) + orig_cd = original.get_class() + new_cd = re_parsed.get_class() + self.assertEqual(orig_cd.name, new_cd.name) + # Walk hierarchy and compare every field value + for o_hcd, n_hcd in zip(orig_cd.get_hierarchy(), new_cd.get_hierarchy()): + self.assertEqual(o_hcd.name, n_hcd.name) + if o_hcd not in original.field_data: + continue + for o_f, n_f in zip(o_hcd.fields, n_hcd.fields): + self.assertEqual(o_f.name, n_f.name) + self.assertEqual( + original.field_data[o_hcd][o_f], + re_parsed.field_data[n_hcd][n_f], + msg=f"Field {o_f.name!r} in {o_hcd.name!r}", + ) + + def test_round_trip_wrclass(self) -> None: + """WRCLASS (writeObject) instance: class name survives round-trip.""" + original, re_parsed = self._round_trip("test_readFields.ser") + self.assertIsInstance(re_parsed, JavaInstance) + orig_cd = original.get_class() + new_cd = re_parsed.get_class() + self.assertEqual(orig_cd.name, new_cd.name) + + def test_round_trip_class_token(self) -> None: + """TC_CLASS token: class name survives round-trip.""" + original, re_parsed = self._round_trip("testClass.ser") + self.assertIsInstance(re_parsed, JavaClass) + self.assertEqual(re_parsed.name, original.name) + + def test_multi_object_stream(self) -> None: + """Multiple objects in one stream: all survive round-trip.""" + obj_a = self.load_file("testJapan.ser") + obj_b = self.load_file("testBoolIntLong.ser") + serialized = javaobj.dumps(obj_a, obj_b) + result = javaobj.loads(serialized) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + self.assertIsInstance(result[0], JavaString) + self.assertIsInstance(result[1], JavaInstance) + + # ------------------------------------------------------------------------------ # GZip decompression test