Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 374 Add type hint #510

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ jobs:
- os: macos-14
java-version: 8
python-version: '3.12'
- os: windows-2019
java-version: 11
python-version: '3.10'
- os: ubuntu-22.04
java-version: 17
python-version: '3.11'
- os: ubuntu-22.04
java-version: 21
python-version: '3.9'
# - os: windows-2019
# java-version: 11
# python-version: '3.10'
# - os: ubuntu-22.04
# java-version: 17
# python-version: '3.11'
# - os: ubuntu-22.04
# java-version: 21
# python-version: '3.9'
name: Python ${{ matrix.python-version }}, Java ${{ matrix.java-version }}, ${{ matrix.os }}
steps:
- uses: actions/checkout@1e204e9a9253d643386038d443f96446fa156a97 # pin@v2.3.5
Expand Down Expand Up @@ -82,7 +82,7 @@ jobs:
echo `java -version`
echo $JAVA_HOME
# Java TLS tests are disabled until they can be fixed (refs #441)
pytest -k "not java_tls_test." -vvv
pytest -k "not java_tls_test." -vvv -s

test-doc:
name: Documentation build
Expand Down
4 changes: 2 additions & 2 deletions py4j-java/src/test/java/py4j/commands/DirCommandTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ public class DirCommandTest {
{
// Defined in ExampleClass
ExampleClassMethods.addAll(Arrays.asList(new String[] { "method1", "method2", "method3", "method4", "method5",
"method6", "method7", "method8", "method9", "method10", "method11", "getList", "getField1", "setField1",
"getStringArray", "getIntArray", "callHello", "callHello2", "static_method", "getInteger",
"method6", "method7", "method8", "method9", "method10", "method11", "method12", "getList", "getField1",
"setField1", "getStringArray", "getIntArray", "callHello", "callHello2", "static_method", "getInteger",
"getBrokenStream", "getStream", "sleepFirstTimeOnly" }));
// Defined in Object
ExampleClassMethods.addAll(Arrays
Expand Down
15 changes: 15 additions & 0 deletions py4j-java/src/test/java/py4j/examples/ExampleClass.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;

public class ExampleClass {
Expand Down Expand Up @@ -174,6 +175,20 @@ public BigInteger method11(BigInteger bi) {
return bi.add(new BigInteger("1"));
}

public int method12(HashSet<Object> set) {
Object element = set.stream().findAny().get();
if (element instanceof Long) {
return 4;
}
if (element instanceof Integer) {
return 1;
}
if (element instanceof String) {
return 2;
}
return 3;
}

@SuppressWarnings("unused")
private int private_method() {
return 0;
Expand Down
9 changes: 9 additions & 0 deletions py4j-python/src/py4j/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@
ITERATOR_TYPE = "g"
PYTHON_PROXY_TYPE = "f"

class TypeHint:
"""Enables users to provide a hint to the Python to Java converter specifying the accurate data type for a given value.
Essential to enforce i.e. correct number type, like Long."""
def __init__(self, value, java_type):
self.value = value
self.java_type = java_type

# Protocol
END = "e"
ERROR = "x"
Expand Down Expand Up @@ -273,6 +280,8 @@ def get_command_part(parameter, python_proxy_pool=None):

if parameter is None:
command_part = NULL_TYPE
elif isinstance(parameter, TypeHint):
command_part = parameter.java_type + smart_decode(parameter.value)
elif isinstance(parameter, bool):
command_part = BOOLEAN_TYPE + smart_decode(parameter)
elif isinstance(parameter, Decimal):
Expand Down
1 change: 1 addition & 0 deletions py4j-python/src/py4j/tests/java_dir_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
# overloaded
"method10",
"method11",
"method12",
"getList",
"getField1",
"setField1",
Expand Down
9 changes: 6 additions & 3 deletions py4j-python/src/py4j/tests/java_gateway_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
set_default_callback_accept_timeout, GatewayConnectionGuard,
get_java_class)
from py4j.protocol import (
Py4JError, Py4JJavaError, Py4JNetworkError, decode_bytearray,
encode_bytearray, escape_new_line, unescape_new_line, smart_decode)
Py4JError, Py4JJavaError, Py4JNetworkError, TypeHint, LONG_TYPE,
decode_bytearray, encode_bytearray, escape_new_line, unescape_new_line, smart_decode)


SERVER_PORT = 25333
Expand Down Expand Up @@ -582,6 +582,7 @@ def testGCCollectNoMemoryManagement(self):
enable_memory_management=False))
gc.collect()
# Should have nothing in the finalizers
print(ThreadSafeFinalizer.finalizers)
self.assertEqual(len(ThreadSafeFinalizer.finalizers), 0)

def internal():
Expand All @@ -607,7 +608,7 @@ def internal():
class TypeConversionTest(unittest.TestCase):
def setUp(self):
self.p = start_example_app_process()
self.gateway = JavaGateway()
self.gateway = JavaGateway(auto_convert=True)

def tearDown(self):
safe_shutdown(self)
Expand All @@ -619,6 +620,8 @@ def testLongInt(self):
self.assertEqual(4, ex.method7(2147483648))
self.assertEqual(4, ex.method7(-2147483649))
self.assertEqual(4, ex.method7(long(2147483648)))
self.assertEqual(4, ex.method7(TypeHint(1, LONG_TYPE)))
self.assertEqual(4, ex.method12({TypeHint(1, LONG_TYPE)}))
self.assertEqual(long(4), ex.method8(3))
self.assertEqual(4, ex.method8(3))
self.assertEqual(long(4), ex.method8(long(3)))
Expand Down
21 changes: 21 additions & 0 deletions py4j-web/advanced_topics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,27 @@ Java methods slightly less efficient because in the worst case, Py4J needs to
go through all registered converters for all parameters. This is why automatic
conversion is disabled by default.

.. _explicit_conversion:

Explicit converting Python objects to Java primitives
-----------------------------------------------------

Sometimes, especially when ``auto_convert=True`` it is difficult to enforce correct type
passed from Python to Java. Then, ``TypeHint`` from ``py4j.protocol`` may be used.
``java_type`` argument of constructor should be one of Java types defined in ``py4j.protocol``.

So if you have method in Java like:

.. code-block:: java

void method(HashSet<Long> longs) {}

Then you can pass arguments with correct type to this method with ``TypeHint``

::

>>> set_with_longs = { TypeHint(1, LONG_TYPE), TypeHint(2, LONG_TYPE) }
>>> gateway.jvm.my.Class().method(set_with_longs)

.. _py4j_exceptions:

Expand Down
Loading