[FLINK-22865][python] Optimize state serialize/deserialize in PyFlink
This closes apache#16069.
HuangXingBo authored and rkhachatryan committed Jul 20, 2021
1 parent 3460eeb commit 98c7837
Showing 32 changed files with 359 additions and 178 deletions.
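
Note: the user-visible effect of this commit is that state descriptors now accept the state's real TypeInformation, so state values go through typed coders instead of being pickled via Types.PICKLED_BYTE_ARRAY(). A minimal sketch of the new usage (the job and names are illustrative, not part of this diff):

    from pyflink.common.typeinfo import Types
    from pyflink.datastream.functions import KeyedProcessFunction, RuntimeContext
    from pyflink.datastream.state import ValueStateDescriptor


    class CountFunction(KeyedProcessFunction):
        """Counts elements per key using typed (non-pickled) value state."""

        def open(self, runtime_context: RuntimeContext):
            # Before this commit, only Types.PICKLED_BYTE_ARRAY() was accepted here.
            self.count = runtime_context.get_state(
                ValueStateDescriptor("count", Types.INT()))

        def process_element(self, value, ctx):
            current = (self.count.value() or 0) + 1
            self.count.update(current)
            yield value[0], current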
4 changes: 4 additions & 0 deletions flink-python/pyflink/common/serializer.py
@@ -61,6 +61,10 @@ def _get_coder(self):
         deserialize_func = self.deserialize
 
         class CoderAdapter(object):
+            def get_impl(self):
+                return CoderAdapterIml()
+
+        class CoderAdapterIml(object):
 
             def encode_nested(self, element):
                 bytes_io = BytesIO()
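
The get_impl() hook added here separates the coder's description from its implementation, matching the coder/coder-impl split used in pyflink.fn_execution.coders: callers build the adapter once and do the byte-level work against the implementation object. A rough, self-contained model of the pattern (hypothetical demo names, not the commit's classes):

    import pickle
    from io import BytesIO


    class DemoCoderAdapter:
        """Describes the coder; cheap to build and pass around."""

        def get_impl(self):
            return DemoCoderAdapterImpl()


    class DemoCoderAdapterImpl:
        """Performs the actual byte-level encoding and decoding."""

        def encode_nested(self, element):
            bytes_io = BytesIO()
            pickle.dump(element, bytes_io)
            return bytes_io.getvalue()

        def decode_nested(self, data):
            return pickle.load(BytesIO(data))


    impl = DemoCoderAdapter().get_impl()
    assert impl.decode_nested(impl.encode_nested([1, 2])) == [1, 2]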
7 changes: 3 additions & 4 deletions flink-python/pyflink/datastream/data_stream.py
@@ -943,8 +943,7 @@ def __init__(self, reduce_function: ReduceFunction):
 
     def open(self, runtime_context: RuntimeContext):
         self._reduce_value_state = runtime_context.get_state(
-            ValueStateDescriptor("_reduce_state" + str(uuid.uuid4()),
-                                 Types.PICKLED_BYTE_ARRAY()))
+            ValueStateDescriptor("_reduce_state" + str(uuid.uuid4()), output_type))
         self._reduce_function.open(runtime_context)
         from pyflink.fn_execution.datastream.runtime_context import StreamingRuntimeContext
         self._in_batch_execution_mode = \
@@ -1148,7 +1147,7 @@ def get_execution_environment(self):
         return self._keyed_stream.get_execution_environment()
 
     def get_input_type(self):
-        return self._keyed_stream.get_type()
+        return _from_java_type(self._keyed_stream._original_data_type_info.get_java_type_info())
 
     def trigger(self, trigger: Trigger):
         """
@@ -1212,7 +1211,7 @@ def _get_result_data_stream(
             self.get_execution_environment())
         window_serializer = self._window_assigner.get_window_serializer()
         window_state_descriptor = ListStateDescriptor(
-            "window-contents", Types.PICKLED_BYTE_ARRAY())
+            "window-contents", self.get_input_type())
         window_operation_descriptor = WindowOperationDescriptor(
             self._window_assigner,
             self._window_trigger,
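
With these changes the reduce operator's internal value state and the window operator's "window-contents" list state are declared with the stream's actual type instead of pickled bytes. Illustrative usage that now benefits from typed state (not part of this diff):

    from pyflink.common.typeinfo import Types
    from pyflink.datastream import StreamExecutionEnvironment

    env = StreamExecutionEnvironment.get_execution_environment()
    ds = env.from_collection(
        [(1, "a"), (1, "b"), (2, "c")],
        type_info=Types.TUPLE([Types.INT(), Types.STRING()]))

    # The operator's hidden "_reduce_state" is now described by this stream's
    # output type rather than Types.PICKLED_BYTE_ARRAY().
    summed = ds.key_by(lambda x: x[0]) \
               .reduce(lambda a, b: (a[0], a[1] + b[1]))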
26 changes: 1 addition & 25 deletions flink-python/pyflink/datastream/state.py
@@ -19,7 +19,7 @@
 
 from typing import TypeVar, Generic, Iterable, List, Iterator, Dict, Tuple
 
-from pyflink.common.typeinfo import TypeInformation, Types, PickledBytesTypeInfo
+from pyflink.common.typeinfo import TypeInformation, Types
 
 __all__ = [
     'ValueStateDescriptor',
@@ -316,10 +316,6 @@ def __init__(self, name: str, value_type_info: TypeInformation):
         :param name: The name of the state.
         :param value_type_info: the type information of the state.
         """
-        if not isinstance(value_type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the value could only be PickledBytesTypeInfo "
-                             "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s."
-                             % type(value_type_info))
         super(ValueStateDescriptor, self).__init__(name, value_type_info)
 
 
@@ -336,10 +332,6 @@ def __init__(self, name: str, elem_type_info: TypeInformation):
         :param name: The name of the state.
         :param elem_type_info: the type information of the state element.
         """
-        if not isinstance(elem_type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the element could only be "
-                             "PickledBytesTypeInfo (created via Types.PICKLED_BYTE_ARRAY()) "
-                             "currently, got %s" % type(elem_type_info))
         super(ListStateDescriptor, self).__init__(name, Types.LIST(elem_type_info))
 
 
@@ -357,14 +349,6 @@ def __init__(self, name: str, key_type_info: TypeInformation, value_type_info: T
         :param key_type_info: The type information of the key.
         :param value_type_info: the type information of the value.
         """
-        if not isinstance(key_type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the key could only be PickledBytesTypeInfo "
-                             "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s"
-                             % type(key_type_info))
-        if not isinstance(value_type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the value could only be PickledBytesTypeInfo "
-                             "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s"
-                             % type(value_type_info))
         super(MapStateDescriptor, self).__init__(name, Types.MAP(key_type_info, value_type_info))
 
 
@@ -392,10 +376,6 @@ def __init__(self,
             reduce_function = ReduceFunctionWrapper(reduce_function)  # type: ignore
         else:
             raise TypeError("The input must be a ReduceFunction or a callable function!")
-        if not isinstance(type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the state could only be PickledBytesTypeInfo "
-                             "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s"
-                             % type(type_info))
         self._reduce_function = reduce_function
 
     def get_reduce_function(self):
@@ -418,10 +398,6 @@ def __init__(self,
         from pyflink.datastream.functions import AggregateFunction
         if not isinstance(agg_function, AggregateFunction):
            raise TypeError("The input must be a pyflink.datastream.functions.AggregateFunction!")
-        if not isinstance(state_type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the state could only be PickledBytesTypeInfo "
-                             "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s"
-                             % type(state_type_info))
         self._agg_function = agg_function
 
     def get_agg_function(self):
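
With the PickledBytesTypeInfo guards removed, every descriptor constructor accepts arbitrary TypeInformation (Types.PICKLED_BYTE_ARRAY() still works as a catch-all). For example, all of the following raised ValueError before this commit (illustrative, not from the diff):

    from pyflink.common.typeinfo import Types
    from pyflink.datastream.state import (ListStateDescriptor, MapStateDescriptor,
                                          ReducingStateDescriptor, ValueStateDescriptor)

    value_desc = ValueStateDescriptor("total", Types.INT())
    list_desc = ListStateDescriptor("events", Types.STRING())
    map_desc = MapStateDescriptor("index", Types.STRING(), Types.LONG())
    reduce_desc = ReducingStateDescriptor("sum", lambda a, b: a + b, Types.INT())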
35 changes: 15 additions & 20 deletions flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -214,9 +214,9 @@ def __init__(self):
 
            def open(self, runtime_context: RuntimeContext):
                self.pre1 = runtime_context.get_state(
-                   ValueStateDescriptor("pre1", Types.PICKLED_BYTE_ARRAY()))
+                   ValueStateDescriptor("pre1", Types.STRING()))
                self.pre2 = runtime_context.get_state(
-                   ValueStateDescriptor("pre2", Types.PICKLED_BYTE_ARRAY()))
+                   ValueStateDescriptor("pre2", Types.STRING()))
 
            def map1(self, value):
                if value[0] == 'b':
@@ -414,7 +414,7 @@ def __init__(self):
 
            def open(self, runtime_context: RuntimeContext):
                self.state = runtime_context.get_state(
-                   ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY()))
+                   ValueStateDescriptor("test_state", Types.INT()))
 
            def map(self, value):
                if value[0] == 'a':
@@ -467,7 +467,7 @@ def __init__(self):
 
            def open(self, runtime_context: RuntimeContext):
                self.state = runtime_context.get_state(
-                   ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY()))
+                   ValueStateDescriptor("test_state", Types.INT()))
 
            def flat_map(self, value):
                state_value = self.state.value()
@@ -511,7 +511,7 @@ def __init__(self):
 
            def open(self, runtime_context: RuntimeContext):
                self.state = runtime_context.get_state(
-                   ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY()))
+                   ValueStateDescriptor("test_state", Types.INT()))
 
            def filter(self, value):
                state_value = self.state.value()
@@ -708,15 +708,11 @@ def __init__(self):
                self.map_state = None
 
            def open(self, runtime_context: RuntimeContext):
-               value_state_descriptor = ValueStateDescriptor('value_state',
-                                                             Types.PICKLED_BYTE_ARRAY())
+               value_state_descriptor = ValueStateDescriptor('value_state', Types.INT())
                self.value_state = runtime_context.get_state(value_state_descriptor)
-               list_state_descriptor = ListStateDescriptor('list_state',
-                                                           Types.PICKLED_BYTE_ARRAY())
+               list_state_descriptor = ListStateDescriptor('list_state', Types.INT())
                self.list_state = runtime_context.get_list_state(list_state_descriptor)
-               map_state_descriptor = MapStateDescriptor('map_state',
-                                                         Types.PICKLED_BYTE_ARRAY(),
-                                                         Types.PICKLED_BYTE_ARRAY())
+               map_state_descriptor = MapStateDescriptor('map_state', Types.INT(), Types.STRING())
                self.map_state = runtime_context.get_map_state(map_state_descriptor)
 
            def process_element(self, value, ctx):
@@ -784,7 +780,7 @@ def __init__(self):
            def open(self, runtime_context: RuntimeContext):
                self.reducing_state = runtime_context.get_reducing_state(
                    ReducingStateDescriptor(
-                       'reducing_state', lambda i, i2: i + i2, Types.PICKLED_BYTE_ARRAY()))
+                       'reducing_state', lambda i, i2: i + i2, Types.INT()))
 
            def process_element(self, value, ctx):
                self.reducing_state.add(value[0])
@@ -828,7 +824,7 @@ def __init__(self):
            def open(self, runtime_context: RuntimeContext):
                self.aggregating_state = runtime_context.get_aggregating_state(
                    AggregatingStateDescriptor(
-                       'aggregating_state', MyAggregateFunction(), Types.PICKLED_BYTE_ARRAY()))
+                       'aggregating_state', MyAggregateFunction(), Types.INT()))
 
            def process_element(self, value, ctx):
                self.aggregating_state.add(value[0])
@@ -1371,7 +1367,7 @@ def __init__(self):
 
            def open(self, runtime_context: RuntimeContext):
                self.map_state = runtime_context.get_map_state(
-                   MapStateDescriptor("map", Types.PICKLED_BYTE_ARRAY(), Types.PICKLED_BYTE_ARRAY()))
+                   MapStateDescriptor("map", Types.STRING(), Types.BOOLEAN()))
 
            def flat_map1(self, value):
                yield str(value[0] + 1)
@@ -1391,8 +1387,7 @@ def __init__(self):
 
            def open(self, runtime_context: RuntimeContext):
                self.timer_registered = False
-               self.count_state = runtime_context.get_state(ValueStateDescriptor(
-                   "count", Types.PICKLED_BYTE_ARRAY()))
+               self.count_state = runtime_context.get_state(ValueStateDescriptor("count", Types.INT()))
 
            def process_element1(self, value, ctx: 'KeyedCoProcessFunction.Context'):
                if not self.timer_registered:
@@ -1426,7 +1421,7 @@ def __init__(self):
 
            def open(self, runtime_context: RuntimeContext):
                self.state = runtime_context.get_state(
-                   ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY()))
+                   ValueStateDescriptor("test_state", Types.INT()))
 
            def reduce(self, value1, value2):
                state_value = self.state.value()
@@ -1454,7 +1449,7 @@ class SimpleCountWindowTrigger(Trigger[tuple, CountWindow]):
     def __init__(self):
         self._window_size = 3
         self._count_state_descriptor = ReducingStateDescriptor(
-            "trigger_counter", lambda a, b: a + b, Types.PICKLED_BYTE_ARRAY())
+            "trigger_counter", lambda a, b: a + b, Types.BIG_INT())
 
     def on_element(self,
                    element: tuple,
@@ -1494,7 +1489,7 @@ def __init__(self):
         self._window_id = 0
         self._window_size = 3
         self._counter_state_descriptor = ReducingStateDescriptor(
-            "assigner_counter", lambda a, b: a + b, Types.PICKLED_BYTE_ARRAY())
+            "assigner_counter", lambda a, b: a + b, Types.BIG_INT())
 
     def assign_windows(self,
                        element: tuple,
22 changes: 8 additions & 14 deletions flink-python/pyflink/datastream/window.py
@@ -153,22 +153,19 @@ def __init__(self):
 
     def serialize(self, element: TimeWindow, stream: BytesIO) -> None:
         if self._underlying_coder is None:
-            self._underlying_coder = self._get_coder()
+            self._underlying_coder = self._get_coder().get_impl()
         bytes_data = self._underlying_coder.encode(element)
         stream.write(bytes_data)
 
     def deserialize(self, stream: BytesIO) -> TimeWindow:
         if self._underlying_coder is None:
-            self._underlying_coder = self._get_coder()
+            self._underlying_coder = self._get_coder().get_impl()
         bytes_data = stream.read(16)
         return self._underlying_coder.decode(bytes_data)
 
     def _get_coder(self):
-        try:
-            from pyflink.fn_execution import coder_impl_fast as coder_impl
-        except:
-            from pyflink.fn_execution import coder_impl_slow as coder_impl
-        return coder_impl.TimeWindowCoderImpl()
+        from pyflink.fn_execution import coders
+        return coders.TimeWindowCoder()
 
 
 class CountWindowSerializer(TypeSerializer[CountWindow]):
@@ -178,22 +175,19 @@ def __init__(self):
 
     def serialize(self, element: CountWindow, stream: BytesIO) -> None:
         if self._underlying_coder is None:
-            self._underlying_coder = self._get_coder()
+            self._underlying_coder = self._get_coder().get_impl()
         bytes_data = self._underlying_coder.encode(element)
         stream.write(bytes_data)
 
     def deserialize(self, stream: BytesIO) -> CountWindow:
         if self._underlying_coder is None:
-            self._underlying_coder = self._get_coder()
+            self._underlying_coder = self._get_coder().get_impl()
         bytes_data = stream.read(8)
         return self._underlying_coder.decode(bytes_data)
 
     def _get_coder(self):
-        try:
-            from pyflink.fn_execution import coder_impl_fast as coder_impl
-        except:
-            from pyflink.fn_execution import coder_impl_slow as coder_impl
-        return coder_impl.CountWindowCoderImpl()
+        from pyflink.fn_execution import coders
+        return coders.CountWindowCoder()
 
 
 T = TypeVar('T')
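
The fixed stream.read(16) / stream.read(8) sizes reflect the window encodings: a TimeWindow is two big-endian 64-bit longs (start, end) and a CountWindow is a single 64-bit id. A stand-alone model of that framing using struct (the real coders live in pyflink.fn_execution.coders; this is a sketch, not the commit's code):

    import struct
    from io import BytesIO

    def encode_time_window(start: int, end: int) -> bytes:
        # Two big-endian int64 values -> exactly 16 bytes on the wire.
        return struct.pack(">qq", start, end)

    def decode_time_window(stream: BytesIO):
        start, end = struct.unpack(">qq", stream.read(16))
        return start, end

    stream = BytesIO(encode_time_window(0, 5000))
    assert decode_time_window(stream) == (0, 5000)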
flink-python/pyflink/fn_execution/beam/beam_coder_impl_fast.pyx
@@ -25,7 +25,7 @@ from apache_beam.coders.coder_impl cimport InputStream as BInputStream
 from apache_beam.coders.coder_impl cimport OutputStream as BOutputStream
 from apache_beam.coders.coder_impl cimport StreamCoderImpl
 
-from pyflink.fn_execution.beam.beam_stream cimport BeamInputStream
+from pyflink.fn_execution.beam.beam_stream_fast cimport BeamInputStream
 from pyflink.fn_execution.stream_fast cimport InputStream
 
 cdef class PassThroughLengthPrefixCoderImpl(StreamCoderImpl):
@@ -59,9 +59,11 @@ cdef class PassThroughPrefixCoderImpl(StreamCoderImpl):
         # create InputStream
         data_input_stream = InputStream()
         data_input_stream._input_data = <char*?>in_stream.allc
-        in_stream.pos = size
+        data_input_stream._input_pos = in_stream.pos
 
-        return self._value_coder.decode_from_stream(data_input_stream, size)
+        result = self._value_coder.decode_from_stream(data_input_stream, size)
+        in_stream.pos = data_input_stream._input_pos
+        return result
 
     cdef void _write_data_output_stream(self, BOutputStream out_stream):
         cdef OutputStream data_out_stream
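
The decode path above no longer assumes the Beam stream holds only this value: the inner InputStream starts at the outer stream's current position, and after decoding the consumed offset is copied back, so the outer stream advances by exactly what the value coder read. A plain-Python model of that bookkeeping (hypothetical classes, illustration only):

    class OuterStream:
        """Stands in for the Beam input stream: shared buffer plus cursor."""
        def __init__(self, data: bytes):
            self.data = data
            self.pos = 0

    class InnerStream:
        """Stands in for pyflink's InputStream over the same buffer."""
        def __init__(self, data: bytes, pos: int):
            self._input_data = data
            self._input_pos = pos

        def read(self, n: int) -> bytes:
            chunk = self._input_data[self._input_pos:self._input_pos + n]
            self._input_pos += n
            return chunk

    outer = OuterStream(b"\x00\x2aREST")
    inner = InnerStream(outer.data, outer.pos)  # start at the outer cursor
    value = inner.read(2)                       # the value coder consumes 2 bytes
    outer.pos = inner._input_pos                # sync back; outer now points at b"REST"
    assert outer.data[outer.pos:] == b"REST"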
flink-python/pyflink/fn_execution/beam/beam_coder_impl_slow.py
@@ -19,7 +19,8 @@
 
 from apache_beam.coders.coder_impl import StreamCoderImpl, create_InputStream, create_OutputStream
 
-from pyflink.fn_execution.stream_slow import OutputStream, InputStream
+from pyflink.fn_execution.stream_slow import OutputStream
+from pyflink.fn_execution.beam.beam_stream_slow import BeamInputStream
 
 
 class PassThroughLengthPrefixCoderImpl(StreamCoderImpl):
@@ -50,7 +51,7 @@ def encode_to_stream(self, value, out_stream: create_OutputStream, nested):
         self._data_output_stream.clear()
 
     def decode_from_stream(self, in_stream: create_InputStream, nested):
-        data_input_stream = InputStream(in_stream.read_all(False))
+        data_input_stream = BeamInputStream(in_stream)
         return self._value_coder.decode_from_stream(data_input_stream)
 
     def __repr__(self):
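
On the slow path, InputStream(in_stream.read_all(False)) materialized every remaining byte into a fresh buffer before decoding; the new BeamInputStream wraps the Beam stream so the value coder pulls bytes on demand. A rough sketch of the assumed wrapper shape (the real class lives in pyflink.fn_execution.beam.beam_stream_slow):

    class LazyBeamInputStream:
        """Reads directly from an underlying Beam-style stream instead of
        copying all remaining bytes up front."""

        def __init__(self, underlying):
            self._underlying = underlying  # anything exposing read(n)

        def read(self, size: int) -> bytes:
            return self._underlying.read(size)

        def read_byte(self) -> int:
            return self._underlying.read(1)[0]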
29 changes: 1 addition & 28 deletions flink-python/pyflink/fn_execution/beam/beam_coders.py
@@ -15,10 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 ################################################################################
-import pickle
-from typing import Any
-
-from apache_beam.coders import Coder, coder_impl
+from apache_beam.coders import Coder
 from apache_beam.coders.coders import FastCoder, LengthPrefixCoder
 from apache_beam.portability import common_urns
 from apache_beam.typehints import typehints
@@ -93,27 +90,3 @@ def __ne__(self, other):
 
     def __hash__(self):
         return hash(self._internal_coder)
-
-
-class DataViewFilterCoder(FastCoder):
-
-    def to_type_hint(self):
-        return Any
-
-    def __init__(self, udf_data_view_specs):
-        self._udf_data_view_specs = udf_data_view_specs
-
-    def filter_data_views(self, row):
-        i = 0
-        for specs in self._udf_data_view_specs:
-            for spec in specs:
-                row[i][spec.field_index] = None
-            i += 1
-        return row
-
-    def _create_impl(self):
-        filter_data_views = self.filter_data_views
-        dumps = pickle.dumps
-        HIGHEST_PROTOCOL = pickle.HIGHEST_PROTOCOL
-        return coder_impl.CallbackCoderImpl(
-            lambda x: dumps(filter_data_views(x), HIGHEST_PROTOCOL), pickle.loads)
flink-python/pyflink/fn_execution/beam/beam_operations_fast.pyx
@@ -23,7 +23,7 @@ from libc.stdint cimport *
 from apache_beam.utils.windowed_value cimport WindowedValue
 
 from pyflink.fn_execution.coder_impl_fast cimport LengthPrefixBaseCoderImpl
-from pyflink.fn_execution.beam.beam_stream cimport BeamInputStream, BeamOutputStream
+from pyflink.fn_execution.beam.beam_stream_fast cimport BeamInputStream, BeamOutputStream
 from pyflink.fn_execution.beam.beam_coder_impl_fast cimport InputStreamWrapper, BeamCoderImpl
 from pyflink.fn_execution.table.operations import BundleOperation
 