Skip to content

Commit

Permalink
Issue-374 Add unit tests & docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Rafal Mucha committed Mar 18, 2024
1 parent 2a32b8a commit a642fa5
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 14 deletions.
4 changes: 2 additions & 2 deletions py4j-java/src/test/java/py4j/commands/DirCommandTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ public class DirCommandTest {
{
// Defined in ExampleClass
ExampleClassMethods.addAll(Arrays.asList(new String[] { "method1", "method2", "method3", "method4", "method5",
"method6", "method7", "method8", "method9", "method10", "method11", "getList", "getField1", "setField1",
"getStringArray", "getIntArray", "callHello", "callHello2", "static_method", "getInteger",
"method6", "method7", "method8", "method9", "method10", "method11", "method12", "getList", "getField1",
"setField1", "getStringArray", "getIntArray", "callHello", "callHello2", "static_method", "getInteger",
"getBrokenStream", "getStream", "sleepFirstTimeOnly" }));
// Defined in Object
ExampleClassMethods.addAll(Arrays
Expand Down
15 changes: 15 additions & 0 deletions py4j-java/src/test/java/py4j/examples/ExampleClass.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;

public class ExampleClass {
Expand Down Expand Up @@ -174,6 +175,20 @@ public BigInteger method11(BigInteger bi) {
return bi.add(new BigInteger("1"));
}

public int method12(HashSet<Object> set) {
Object element = set.stream().findAny().get();
if (element instanceof Long) {
return 4;
}
if (element instanceof Integer) {
return 1;
}
if (element instanceof String) {
return 2;
}
return 3;
}

@SuppressWarnings("unused")
private int private_method() {
return 0;
Expand Down
17 changes: 8 additions & 9 deletions py4j-python/src/py4j/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
from base64 import standard_b64encode, standard_b64decode

from decimal import Decimal
from enum import Enum
from collections import namedtuple

from py4j.compat import (
long, basestring, unicode, bytearray2,
Expand Down Expand Up @@ -72,11 +70,12 @@
ITERATOR_TYPE = "g"
PYTHON_PROXY_TYPE = "f"

class JavaType(Enum):
PRIMITIVE_INT = INTEGER_TYPE
PRIMITIVE_LONG = LONG_TYPE

TypeInt = namedtuple('TypeInt', ['value', 'java_type'])
class TypeHint:
"""Enables users to provide a hint to the Python to Java converter specifying the accurate data type for a given value.
Essential to enforce i.e. correct number type, like Long."""
def __init__(self, value, java_type):
self.value = value
self.java_type = java_type

# Protocol
END = "e"
Expand Down Expand Up @@ -281,12 +280,12 @@ def get_command_part(parameter, python_proxy_pool=None):

if parameter is None:
command_part = NULL_TYPE
elif isinstance(parameter, TypeHint):
command_part = parameter.java_type + smart_decode(parameter.value)
elif isinstance(parameter, bool):
command_part = BOOLEAN_TYPE + smart_decode(parameter)
elif isinstance(parameter, Decimal):
command_part = DECIMAL_TYPE + smart_decode(parameter)
elif isinstance(parameter, TypeInt):
command_part = parameter.java_type.value + smart_decode(parameter.value)
elif isinstance(parameter, int) and parameter <= JAVA_MAX_INT\
and parameter >= JAVA_MIN_INT:
command_part = INTEGER_TYPE + smart_decode(parameter)
Expand Down
1 change: 1 addition & 0 deletions py4j-python/src/py4j/tests/java_dir_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
# overloaded
"method10",
"method11",
"method12",
"getList",
"getField1",
"setField1",
Expand Down
8 changes: 5 additions & 3 deletions py4j-python/src/py4j/tests/java_gateway_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
set_default_callback_accept_timeout, GatewayConnectionGuard,
get_java_class)
from py4j.protocol import (
Py4JError, Py4JJavaError, Py4JNetworkError, decode_bytearray,
encode_bytearray, escape_new_line, unescape_new_line, smart_decode)
Py4JError, Py4JJavaError, Py4JNetworkError, TypeHint, LONG_TYPE,
decode_bytearray, encode_bytearray, escape_new_line, unescape_new_line, smart_decode)


SERVER_PORT = 25333
Expand Down Expand Up @@ -607,7 +607,7 @@ def internal():
class TypeConversionTest(unittest.TestCase):
def setUp(self):
self.p = start_example_app_process()
self.gateway = JavaGateway()
self.gateway = JavaGateway(auto_convert=True)

def tearDown(self):
safe_shutdown(self)
Expand All @@ -619,6 +619,8 @@ def testLongInt(self):
self.assertEqual(4, ex.method7(2147483648))
self.assertEqual(4, ex.method7(-2147483649))
self.assertEqual(4, ex.method7(long(2147483648)))
self.assertEqual(4, ex.method7(TypeHint(1, LONG_TYPE)))
self.assertEqual(4, ex.method12({TypeHint(1, LONG_TYPE)}))
self.assertEqual(long(4), ex.method8(3))
self.assertEqual(4, ex.method8(3))
self.assertEqual(long(4), ex.method8(long(3)))
Expand Down
21 changes: 21 additions & 0 deletions py4j-web/advanced_topics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,27 @@ Java methods slightly less efficient because in the worst case, Py4J needs to
go through all registered converters for all parameters. This is why automatic
conversion is disabled by default.

.. _explicit_conversion:

Explicit converting Python objects to Java primitives
-----------------------------------------------------

Sometimes, especially when ``auto_convert=True`` it is difficult to enforce correct type
passed from Python to Java. Then, ``TypeHint`` from ``py4j.protocol`` may be used.
``java_type`` argument of constructor should be one of Java types defined in ``py4j.protocol``.

So if you have method in Java like:

.. code-block:: java
void method(HashSet<Long> longs) {}
Then you can pass arguments with correct type to this method with ``TypeHint``

::

>>> set_with_longs = { TypeHint(1, LONG_TYPE), TypeHint(2, LONG_TYPE) }
>>> gateway.jvm.my.Class().method(set_with_longs)

.. _py4j_exceptions:

Expand Down

0 comments on commit a642fa5

Please sign in to comment.