Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Added support for Python 3.x -- test suite passes on Python 3.1 and
2.6
  • Loading branch information
Alex Grönholm committed Mar 8, 2011
1 parent f121918 commit 36a1456
Show file tree
Hide file tree
Showing 12 changed files with 157 additions and 114 deletions.
25 changes: 25 additions & 0 deletions py4j-python/src/py4j/compat.py
@@ -0,0 +1,25 @@
#coding: utf-8
'''
Compatibility functions for unified behavior between Python 2.x and 3.x.
:author: Alex Grönholm
'''
import sys


if sys.version_info[0] < 3:
items = lambda d: d.items()
iteritems = lambda d: d.iteritems()
next = lambda x: x.next()
range = xrange
long = long
basestring = basestring
unicode = unicode
else:
items = lambda d: list(d.items())
iteritems = lambda d: d.items()
next = next
range = range
long = int
basestring = str
unicode = str
6 changes: 4 additions & 2 deletions py4j-python/src/py4j/finalizer.py
Expand Up @@ -10,6 +10,8 @@

from threading import RLock

from py4j.compat import items


class ThreadSafeFinalizer(object):
"""A `ThreadSafeFinalizer` is a global class used to register weak reference finalizers
Expand Down Expand Up @@ -57,7 +59,7 @@ def clear_finalizers(cls, clear_all=False):
if clear_all:
cls.finalizers.clear()
else:
for id, ref in cls.finalizers.items():
for id, ref in items(cls.finalizers):
if ref() is None:
cls.finalizers.pop(id, None)

Expand Down Expand Up @@ -104,7 +106,7 @@ def clear_finalizers(cls, clear_all=False):
if clear_all:
cls.finalizers.clear()
else:
for id, ref in cls.finalizers.items():
for id, ref in items(cls.finalizers):
if ref() is None:
cls.finalizers.pop(id, None)

Expand Down
61 changes: 32 additions & 29 deletions py4j-python/src/py4j/java_collections.py
Expand Up @@ -9,9 +9,11 @@
:author: Barthelemy Dagenais
'''
from collections import MutableMapping, Sequence, MutableSequence, MutableSet, Set
import sys

from py4j.java_gateway import JavaObject, JavaMember, get_method, JavaClass
from py4j.protocol import *
from py4j.compat import iteritems, next


class JavaIterator(JavaObject):
Expand All @@ -38,6 +40,8 @@ def next(self):
return self._methods[self._next_name]()
except Py4JError:
raise StopIteration()

__next__ = next


class JavaMap(JavaObject, MutableMapping):
Expand Down Expand Up @@ -70,24 +74,31 @@ def __contains__(self, key):
def __str__(self):
return self.__repr__()

def __repr__(self):
# TODO Make it more efficient/pythonic
# TODO Debug why strings are not outputed with apostrophes.
if len(self) == 0:
return '{}'
else:
srep = '{'
for key in self:
srep += repr(key) + ': ' + repr(self[key]) + ', '
# def __repr__(self):
# # TODO Make it more efficient/pythonic
# # TODO Debug why strings are not outputed with apostrophes.
# if len(self) == 0:
# return '{}'
# else:
# srep = '{'
# for key in self:
# srep += repr(key) + ': ' + repr(self[key]) + ', '
#
# return srep[:-2] + '}'

return srep[:-2] + '}'
def __repr__(self):
items = ('{0}: {1}'.format(repr(k), repr(v)) for k, v in iteritems(self))
return '{{{0}}}'.format(', '.join(items))


class JavaSet(JavaObject, MutableSet):
"""Maps a Python Set to a Java Set.
All operations possible on a Python set are implemented."""

__EMPTY_SET = 'set([])' if sys.version_info[0] < 3 else 'set()'
__SET_TEMPLATE = 'set([{0}])' if sys.version_info[0] < 3 else '{{{0}}}'

def __init__(self, target_id, gateway_client):
JavaObject.__init__(self, target_id, gateway_client)
self._add = get_method(self, 'add')
Expand Down Expand Up @@ -122,14 +133,9 @@ def __str__(self):
return self.__repr__()

def __repr__(self):
if len(self) == 0:
return 'set([])'
else:
srep = 'set(['
for value in self:
srep += repr(value) + ', '

return srep[:-2] + '])'
if len(self):
return self.__SET_TEMPLATE.format(', '.join((repr(x) for x in self)))
return self.__EMPTY_SET


class JavaArray(JavaObject, Sequence):
Expand Down Expand Up @@ -182,7 +188,7 @@ def __getitem__(self, key):
def __repl_item_from_slice(self, range, iterable):
value_iter = iter(iterable)
for i in range:
value = value_iter.next()
value = next(value_iter)
self.__set_item(i, value)

def __set_item(self, key, value):
Expand Down Expand Up @@ -225,6 +231,9 @@ class JavaList(JavaObject, MutableSequence):
will create a copy of the list on the JVM. Slicing is thus not equivalent to subList(), because
a modification to a slice such as the addition of a new element will not affect the original
list."""

__EMPTY_SET = '[]' if sys.version_info[0] < 3 else '{}'
__REPR_TEMPLATE = 'set([%s])' if sys.version_info[0] < 3 else '{%s}'

def __init__(self, target_id, gateway_client):
JavaObject.__init__(self, target_id, gateway_client)
Expand Down Expand Up @@ -263,7 +272,7 @@ def __set_item_from_slice(self, indices, iterable):
# First replace and delete if from_slice > to_slice
for i in range(*indices):
try:
value = value_iter.next()
value = next(value_iter)
self.__set_item(i, value)
except StopIteration:
self.__del_item(i)
Expand All @@ -284,7 +293,7 @@ def __insert_item_from_slice(self, indices, iterable):
def __repl_item_from_slice(self, range, iterable):
value_iter = iter(iterable)
for i in range:
value = value_iter.next()
value = value = next(value_iter)
self.__set_item(i, value)

def __append_item_from_slice(self, range, iterable):
Expand Down Expand Up @@ -426,14 +435,8 @@ def __str__(self):
return self.__repr__()

def __repr__(self):
if len(self) == 0:
return '[]'
else:
srep = '['
for elem in self:
srep += repr(elem) + ', '

return srep[:-2] + ']'
items = (repr(x) for x in self)
return '[{0}]'.format(', '.join(items))


class SetConverter(object):
Expand Down
5 changes: 3 additions & 2 deletions py4j-python/src/py4j/java_gateway.py
Expand Up @@ -20,6 +20,7 @@

from py4j.finalizer import ThreadSafeFinalizer
from py4j.protocol import *
from py4j.compat import range


class NullHandler(logging.Handler):
Expand Down Expand Up @@ -258,7 +259,7 @@ def start(self):
"""Starts the connection by connecting to the `address` and the `port`"""
self.socket.connect((self.address, self.port))
self.is_connected = True
self.stream = self.socket.makefile('r', 0)
self.stream = self.socket.makefile('rb', 0)

def close(self, throw_exception=False):
"""Closes the connection by closing the socket."""
Expand Down Expand Up @@ -703,7 +704,7 @@ def run(self):

while not self.is_shutdown:
socket, _ = self.server_socket.accept()
input = socket.makefile('r', 0)
input = socket.makefile('rb', 0)
connection = CallbackConnection(self.pool, input, socket, self.gateway_client)
with self.lock:
if not self.is_shutdown:
Expand Down
3 changes: 3 additions & 0 deletions py4j-python/src/py4j/protocol.py
Expand Up @@ -16,6 +16,9 @@
:author: Barthelemy Dagenais
'''
from py4j.compat import long, basestring


ESCAPE_CHAR = "\\"

# Entry point
Expand Down
5 changes: 3 additions & 2 deletions py4j-python/src/py4j/tests/java_array_test.py
Expand Up @@ -3,6 +3,7 @@
@author: Barthelemy Dagenais
'''
from __future__ import unicode_literals
from multiprocessing.process import Process
import subprocess
import time
Expand Down Expand Up @@ -46,12 +47,12 @@ def testArray(self):
self.assertEqual(3, len(array1))
self.assertEqual(4, len(array2))

self.assertEqual(u'333', array1[2])
self.assertEqual('333', array1[2])
self.assertEqual(5, array2[1])

array1[2] = 'aaa'
array2[1] = 6
self.assertEqual(u'aaa', array1[2])
self.assertEqual('aaa', array1[2])
self.assertEqual(6, array2[1])

new_array = array2[1:3]
Expand Down
3 changes: 2 additions & 1 deletion py4j-python/src/py4j/tests/java_callback_test.py
Expand Up @@ -11,6 +11,7 @@

from py4j.java_gateway import JavaGateway, PythonProxyPool
from py4j.tests.java_gateway_test import PY4J_JAVA_PATH
from py4j.compat import range


def start_example_server():
Expand Down Expand Up @@ -83,7 +84,7 @@ class TestPool(unittest.TestCase):

def testPool(self):
pool = PythonProxyPool()
runners = [Runner(xrange(0, 10000), pool) for _ in xrange(0, 3)]
runners = [Runner(range(0, 10000), pool) for _ in range(0, 3)]
for runner in runners:
runner.start()

Expand Down
34 changes: 18 additions & 16 deletions py4j-python/src/py4j/tests/java_gateway_test.py
Expand Up @@ -3,6 +3,7 @@
@author: barthelemy
'''
from __future__ import unicode_literals
from multiprocessing.process import Process
from socket import AF_INET, SOCK_STREAM, socket
from threading import Thread
Expand All @@ -17,6 +18,7 @@
from py4j.protocol import *
from py4j.java_gateway import JavaGateway, JavaMember, get_field, get_method, \
GatewayClient, set_field, java_import, JavaObject
from py4j.compat import range


SERVER_PORT = 25333
Expand Down Expand Up @@ -90,7 +92,7 @@ def tearDown(self):

def testEscape(self):
self.assertEqual("Hello\t\rWorld\n\\", unescape_new_line(escape_new_line("Hello\t\rWorld\n\\")))
self.assertEqual(u"Hello\t\rWorld\n\\", unescape_new_line(escape_new_line(u"Hello\t\rWorld\n\\")))
self.assertEqual("Hello\t\rWorld\n\\", unescape_new_line(escape_new_line("Hello\t\rWorld\n\\")))

def testProtocolSend(self):
testConnection = TestConnection()
Expand Down Expand Up @@ -179,7 +181,7 @@ def testException(self):
testSocket.sendall('yo\n'.encode('utf-8'))
testSocket.sendall('yro0\n'.encode('utf-8'))
testSocket.sendall('yo\n'.encode('utf-8'))
testSocket.sendall('x\n')
testSocket.sendall(b'x\n')
testSocket.close()
time.sleep(1)

Expand Down Expand Up @@ -235,13 +237,13 @@ def testNoneArg(self):

def testUnicode(self):
sb = self.gateway.jvm.java.lang.StringBuffer()
sb.append(u'\r\n\tHello\r\n\t')
self.assertEqual(u'\r\n\tHello\r\n\t', sb.toString())
sb.append('\r\n\tHello\r\n\t')
self.assertEqual('\r\n\tHello\r\n\t', sb.toString())

def testEscape(self):
sb = self.gateway.jvm.java.lang.StringBuffer()
sb.append('\r\n\tHello\r\n\t')
self.assertEqual(u'\r\n\tHello\r\n\t', sb.toString())
self.assertEqual('\r\n\tHello\r\n\t', sb.toString())


class FieldTest(unittest.TestCase):
Expand All @@ -260,7 +262,7 @@ def testAutoField(self):
self.assertEqual(ex.field10, 10)
sb = ex.field20
sb.append('Hello')
self.assertEqual(u'Hello', sb.toString())
self.assertEqual('Hello', sb.toString())
self.assertTrue(ex.field21 == None)

def testNoField(self):
Expand All @@ -285,7 +287,7 @@ def testNoAutoField(self):
ex._auto_field = True
sb = ex.field20
sb.append('Hello')
self.assertEqual(u'Hello', sb.toString())
self.assertEqual('Hello', sb.toString())

try:
get_field(ex, 'field20')
Expand All @@ -302,7 +304,7 @@ def testSetField(self):

sb = self.gateway.jvm.java.lang.StringBuffer('Hello World!')
set_field(ex, 'field21', sb)
self.assertEquals(get_field(ex, 'field21').toString(), u'Hello World!')
self.assertEquals(get_field(ex, 'field21').toString(), 'Hello World!')

try:
set_field(ex, 'field1', 123)
Expand Down Expand Up @@ -458,22 +460,22 @@ def testConstructors(self):
sb = jvm.java.lang.StringBuffer('hello')
sb.append('hello world')
sb.append(1)
self.assertEqual(sb.toString(), u'hellohello world1')
self.assertEqual(sb.toString(), 'hellohello world1')

l1 = jvm.java.util.ArrayList()
l1.append('hello world')
l1.append(1)
self.assertEqual(2, len(l1))
self.assertEqual(u'hello world', l1[0])
l2 = [u'hello world', 1]
self.assertEqual('hello world', l1[0])
l2 = ['hello world', 1]
print(l1)
print(l2)
self.assertEqual(str(l2), str(l1))

def testStaticMethods(self):
System = self.gateway.jvm.java.lang.System
self.assertTrue(System.currentTimeMillis() > 0)
self.assertEqual(u'123', self.gateway.jvm.java.lang.String.valueOf(123))
self.assertEqual('123', self.gateway.jvm.java.lang.String.valueOf(123))

def testStaticFields(self):
Short = self.gateway.jvm.java.lang.Short
Expand All @@ -483,7 +485,7 @@ def testStaticFields(self):

def testDefaultImports(self):
self.assertTrue(self.gateway.jvm.System.currentTimeMillis() > 0)
self.assertEqual(u'123', self.gateway.jvm.String.valueOf(123))
self.assertEqual('123', self.gateway.jvm.String.valueOf(123))

def testNone(self):
ex = self.gateway.entry_point.getNewExample()
Expand Down Expand Up @@ -581,9 +583,9 @@ def testStress(self):
# runner2 = Runner(xrange(1000,1000000,10000), self.gateway)
# runner3 = Runner(xrange(1000,1000000,10000), self.gateway)
# Small stress test
runner1 = Runner(xrange(1, 10000, 1000), self.gateway)
runner2 = Runner(xrange(1000, 1000000, 100000), self.gateway)
runner3 = Runner(xrange(1000, 1000000, 100000), self.gateway)
runner1 = Runner(range(1, 10000, 1000), self.gateway)
runner2 = Runner(range(1000, 1000000, 100000), self.gateway)
runner3 = Runner(range(1000, 1000000, 100000), self.gateway)
runner1.start()
runner2.start()
runner3.start()
Expand Down

0 comments on commit 36a1456

Please sign in to comment.