Skip to content

Commit

Permalink
[FLINK-12588][python] Add TableSchema for Python Table API.
Browse files Browse the repository at this point in the history
This closes apache#8561
  • Loading branch information
WeiZhong94 authored and sjwiesman committed Jun 26, 2019
1 parent 023b304 commit fc2ff3d
Show file tree
Hide file tree
Showing 12 changed files with 819 additions and 57 deletions.
2 changes: 2 additions & 0 deletions flink-python/pyflink/java_gateway.py
Expand Up @@ -115,6 +115,8 @@ def import_flink_view(gateway):
java_import(gateway.jvm, "org.apache.flink.table.sources.*")
java_import(gateway.jvm, "org.apache.flink.table.sinks.*")
java_import(gateway.jvm, "org.apache.flink.table.python.*")
java_import(gateway.jvm, "org.apache.flink.table.types.*")
java_import(gateway.jvm, "org.apache.flink.table.types.logical.*")
java_import(gateway.jvm, "org.apache.flink.python.bridge.*")
java_import(gateway.jvm, "org.apache.flink.api.common.typeinfo.TypeInformation")
java_import(gateway.jvm, "org.apache.flink.api.common.typeinfo.Types")
Expand Down
4 changes: 3 additions & 1 deletion flink-python/pyflink/table/__init__.py
Expand Up @@ -41,6 +41,7 @@
from pyflink.table.types import DataTypes, UserDefinedType, Row
from pyflink.table.window import Tumble, Session, Slide, Over
from pyflink.table.table_descriptor import Rowtime, Schema, OldCsv, FileSystem, Kafka, Elasticsearch
from pyflink.table.table_schema import TableSchema

__all__ = [
'TableEnvironment',
Expand All @@ -64,5 +65,6 @@
'UserDefinedType',
'Row',
'Kafka',
'Elasticsearch'
'Elasticsearch',
'TableSchema'
]
9 changes: 9 additions & 0 deletions flink-python/pyflink/table/table.py
Expand Up @@ -19,6 +19,7 @@

from py4j.java_gateway import get_method
from pyflink.java_gateway import get_gateway
from pyflink.table.table_schema import TableSchema

from pyflink.table.window import GroupWindow
from pyflink.util.utils import to_jarray
Expand Down Expand Up @@ -532,6 +533,14 @@ def insert_into(self, table_path, *table_path_continued):
j_table_path = to_jarray(gateway.jvm.String, table_path_continued)
self._j_table.insertInto(table_path, j_table_path)

def get_schema(self):
    """
    Returns the :class:`TableSchema` of this table.

    :return: The schema of this table.
    """
    j_schema = self._j_table.getSchema()
    return TableSchema(j_table_schema=j_schema)

def print_schema(self):
"""
Prints the schema of this table to the console in a tree format.
Expand Down
26 changes: 26 additions & 0 deletions flink-python/pyflink/table/table_descriptor.py
Expand Up @@ -179,6 +179,19 @@ def __init__(self):
self._j_schema = gateway.jvm.Schema()
super(Schema, self).__init__(self._j_schema)

def schema(self, table_schema):
    """
    Sets the schema with field names and the types. Required.

    This method overwrites existing fields added with
    :func:`~pyflink.table.table_descriptor.Schema.field`.

    :param table_schema: The :class:`TableSchema` object.
    :return: This schema object.
    """
    updated = self._j_schema.schema(table_schema._j_table_schema)
    self._j_schema = updated
    return self

def field(self, field_name, field_type):
"""
Adds a field with the field name and the data type or type string. Required.
Expand Down Expand Up @@ -287,6 +300,19 @@ def line_delimiter(self, delimiter):
self._j_csv = self._j_csv.lineDelimiter(delimiter)
return self

def schema(self, table_schema):
    """
    Sets the schema with field names and the types. Required.

    This method overwrites existing fields added with
    :func:`~pyflink.table.table_descriptor.OldCsv.field`.

    :param table_schema: The :class:`TableSchema` object.
    :return: This schema object.
    """
    updated = self._j_csv.schema(table_schema._j_table_schema)
    self._j_csv = updated
    return self

def field(self, field_name, field_type):
"""
Adds a format field with the field name and the data type or type string. Required.
Expand Down
161 changes: 161 additions & 0 deletions flink-python/pyflink/table/table_schema.py
@@ -0,0 +1,161 @@
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import sys

from pyflink.java_gateway import get_gateway
from pyflink.table.types import _to_java_type, _from_java_type
from pyflink.util.utils import to_jarray

# On Python 3 the separate ``unicode`` type of Python 2 no longer exists, so
# alias it to ``str`` for isinstance checks that must accept both.
# Use the version_info tuple rather than the lexicographic string comparison
# ``sys.version >= '3'``, which compares the whole banner string and would
# misorder a hypothetical major version >= 10.
if sys.version_info >= (3, 0):
    unicode = str

__all__ = ['TableSchema']


class TableSchema(object):
    """
    A table schema that represents a table's structure with field names and data types.

    A schema can be created from parallel lists of field names and data types,
    from an existing Java ``TableSchema`` object (internal use), or via
    :func:`TableSchema.builder`.
    """

    def __init__(self, field_names=None, data_types=None, j_table_schema=None):
        """
        :param field_names: A list of field names.
        :param data_types: A list of field data types, parallel to ``field_names``.
        :param j_table_schema: An existing Java ``TableSchema`` object to wrap
                               (internal use); when given, ``field_names`` and
                               ``data_types`` are ignored.
        :raises ValueError: If neither ``j_table_schema`` nor both ``field_names``
                            and ``data_types`` are given, or if the two lists
                            differ in length.
        """
        if j_table_schema is not None:
            self._j_table_schema = j_table_schema
            return
        # Validate eagerly so the user gets a clear error here instead of an
        # obscure py4j failure from the JVM calls below.
        if field_names is None or data_types is None:
            raise ValueError(
                "field_names and data_types must both be provided when "
                "j_table_schema is not given.")
        if len(field_names) != len(data_types):
            raise ValueError(
                "Number of field names %d must be equal to number of data types %d."
                % (len(field_names), len(data_types)))
        gateway = get_gateway()
        j_field_names = to_jarray(gateway.jvm.String, field_names)
        j_data_types = to_jarray(gateway.jvm.TypeInformation,
                                 [_to_java_type(item) for item in data_types])
        self._j_table_schema = gateway.jvm.TableSchema(j_field_names, j_data_types)

    def copy(self):
        """
        Returns a deep copy of the table schema.

        :return: A deep copy of the table schema.
        """
        return TableSchema(j_table_schema=self._j_table_schema.copy())

    def get_field_data_types(self):
        """
        Returns all field data types as a list.

        :return: A list of all field data types.
        """
        return [_from_java_type(item) for item in self._j_table_schema.getFieldDataTypes()]

    def get_field_data_type(self, field):
        """
        Returns the specified data type for the given field index or field name.

        :param field: The index of the field or the name of the field.
        :return: The data type of the specified field, or ``None`` if the field
                 does not exist.
        :raises TypeError: If ``field`` is neither an int nor a string.
        """
        if not isinstance(field, (int, str, unicode)):
            raise TypeError("Expected field index or field name, got %s" % type(field))
        optional_result = self._j_table_schema.getFieldDataType(field)
        if optional_result.isPresent():
            return _from_java_type(optional_result.get())
        else:
            return None

    def get_field_count(self):
        """
        Returns the number of fields.

        :return: The number of fields.
        """
        return self._j_table_schema.getFieldCount()

    def get_field_names(self):
        """
        Returns all field names as a list.

        :return: The list of all field names.
        """
        return list(self._j_table_schema.getFieldNames())

    def get_field_name(self, field_index):
        """
        Returns the specified name for the given field index.

        :param field_index: The index of the field.
        :return: The field name, or ``None`` if the index is out of range.
        """
        optional_result = self._j_table_schema.getFieldName(field_index)
        if optional_result.isPresent():
            return optional_result.get()
        else:
            return None

    def to_row_data_type(self):
        """
        Converts a table schema into a (nested) data type describing a
        :func:`pyflink.table.types.DataTypes.ROW`.

        :return: The row data type.
        """
        return _from_java_type(self._j_table_schema.toRowDataType())

    def __repr__(self):
        return self._j_table_schema.toString()

    def __eq__(self, other):
        # py4j forwards ``==`` between Java objects to Java's equals() —
        # NOTE(review): relies on py4j JavaObject equality semantics.
        return isinstance(other, self.__class__) and self._j_table_schema == other._j_table_schema

    def __hash__(self):
        return self._j_table_schema.hashCode()

    def __ne__(self, other):
        # Needed on Python 2, which does not derive __ne__ from __eq__.
        return not self.__eq__(other)

    @classmethod
    def builder(cls):
        """
        Returns a :class:`TableSchema.Builder` for fluently constructing a schema.

        :return: A new builder instance.
        """
        return cls.Builder()

    class Builder(object):
        """
        Builder for creating a :class:`TableSchema`.
        """

        def __init__(self):
            self._field_names = []
            self._field_data_types = []

        def field(self, name, data_type):
            """
            Add a field with name and data type.

            The call order of this method determines the order of fields in the schema.

            :param name: The field name.
            :param data_type: The field data type.
            :return: This object.
            :raises ValueError: If ``name`` or ``data_type`` is None.
            """
            # Raise explicitly rather than assert: asserts are stripped
            # when Python runs with -O.
            if name is None:
                raise ValueError("name must not be None.")
            if data_type is None:
                raise ValueError("data_type must not be None.")
            self._field_names.append(name)
            self._field_data_types.append(data_type)
            return self

        def build(self):
            """
            Returns a :class:`TableSchema` instance.

            :return: The :class:`TableSchema` instance.
            """
            return TableSchema(self._field_names, self._field_data_types)
12 changes: 7 additions & 5 deletions flink-python/pyflink/table/tests/test_calc.py
Expand Up @@ -103,17 +103,18 @@ def test_from_element(self):
a.fromstring('ABCD')
t = t_env.from_elements(
[(1, 1.0, "hi", "hello", datetime.date(1970, 1, 2), datetime.time(1, 0, 0),
datetime.datetime(1970, 1, 2, 0, 0), array.array("d", [1]), ["abc"],
[datetime.date(1970, 1, 2)], Decimal(1), Row("a", "b")(1, 2.0),
datetime.datetime(1970, 1, 2, 0, 0), [1.0, None], array.array("d", [1.0, 2.0]),
["abc"], [datetime.date(1970, 1, 2)], Decimal(1), Row("a", "b")(1, 2.0),
{"key": 1.0}, a, ExamplePoint(1.0, 2.0),
PythonOnlyPoint(3.0, 4.0))])
field_names = ["a", "b", "c", "d", "e", "f", "g", "h",
"i", "j", "k", "l", "m", "n", "o", "p"]
"i", "j", "k", "l", "m", "n", "o", "p", "q"]
field_types = [DataTypes.BIGINT(), DataTypes.DOUBLE(), DataTypes.STRING(),
DataTypes.STRING(), DataTypes.DATE(),
DataTypes.TIME(),
DataTypes.TIMESTAMP(),
DataTypes.ARRAY(DataTypes.DOUBLE()),
DataTypes.ARRAY(DataTypes.DOUBLE(False)),
DataTypes.ARRAY(DataTypes.STRING()),
DataTypes.ARRAY(DataTypes.DATE()),
DataTypes.DECIMAL(),
Expand All @@ -130,8 +131,9 @@ def test_from_element(self):
t_env.execute()
actual = source_sink_utils.results()

expected = ['1,1.0,hi,hello,1970-01-02,01:00:00,1970-01-02 00:00:00.0,[1.0],[abc],'
'[1970-01-02],1,1,2.0,{key=1.0},[65, 66, 67, 68],[1.0, 2.0],[3.0, 4.0]']
expected = ['1,1.0,hi,hello,1970-01-02,01:00:00,1970-01-02 00:00:00.0,[1.0, null],'
'[1.0, 2.0],[abc],[1970-01-02],1,1,2.0,{key=1.0},[65, 66, 67, 68],[1.0, 2.0],'
'[3.0, 4.0]']
self.assert_equals(actual, expected)


Expand Down
30 changes: 30 additions & 0 deletions flink-python/pyflink/table/tests/test_descriptor.py
Expand Up @@ -19,6 +19,7 @@

from pyflink.table.table_descriptor import (FileSystem, OldCsv, Rowtime, Schema, Kafka,
Elasticsearch)
from pyflink.table.table_schema import TableSchema
from pyflink.table.table_sink import CsvTableSink
from pyflink.table.types import DataTypes
from pyflink.testing.test_case_utils import (PyFlinkTestCase, PyFlinkStreamTableTestCase,
Expand Down Expand Up @@ -506,6 +507,22 @@ def test_field(self):
'format.property-version': '1'}
assert properties == expected

def test_schema(self):
    table_schema = TableSchema(["a", "b"], [DataTypes.INT(), DataTypes.STRING()])

    # Applying a TableSchema should expand into per-field format properties.
    descriptor = OldCsv().schema(table_schema)

    expected = {'format.fields.0.name': 'a',
                'format.fields.0.type': 'INT',
                'format.fields.1.name': 'b',
                'format.fields.1.type': 'VARCHAR',
                'format.type': 'csv',
                'format.property-version': '1'}
    assert descriptor.to_properties() == expected


class RowTimeDescriptorTests(PyFlinkTestCase):

Expand Down Expand Up @@ -738,6 +755,19 @@ def test_rowtime(self):
'schema.3.type': 'VARCHAR'}
assert properties == expected

def test_schema(self):
    table_schema = TableSchema(["a", "b"], [DataTypes.INT(), DataTypes.STRING()])

    # Applying a TableSchema should expand into per-field schema properties.
    descriptor = Schema().schema(table_schema)

    expected = {'schema.0.name': 'a',
                'schema.0.type': 'INT',
                'schema.1.name': 'b',
                'schema.1.type': 'VARCHAR'}
    assert descriptor.to_properties() == expected


class AbstractTableDescriptorTests(object):

Expand Down
Expand Up @@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

from pyflink.table.table_schema import TableSchema
from pyflink.table.types import DataTypes
from pyflink.testing import source_sink_utils
from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase
Expand All @@ -36,6 +36,21 @@ def test_print_schema(self):
result = t.group_by("c").select("a.sum, c as b")
result.print_schema()

def test_get_schema(self):
    t_env = self.t_env
    source = t_env.from_elements(
        [(1, 'Hi', 'Hello'), (2, 'Hello', 'Hello'), (2, 'Hello', 'Hello')],
        ['a', 'b', 'c'])
    t_env.register_table_sink(
        "Results",
        ["a", "b"],
        [DataTypes.BIGINT(), DataTypes.STRING()],
        source_sink_utils.TestRetractSink())

    # The aggregated projection should carry a BIGINT sum and a STRING key.
    actual_schema = source.group_by("c").select("a.sum as a, c as b").get_schema()

    assert actual_schema == TableSchema(["a", "b"],
                                        [DataTypes.BIGINT(), DataTypes.STRING()])


if __name__ == '__main__':
import unittest
Expand Down

0 comments on commit fc2ff3d

Please sign in to comment.