-
Notifications
You must be signed in to change notification settings - Fork 44
/
writers.py
371 lines (304 loc) · 11 KB
/
writers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
import collections
import csv
import json
import os
import shutil
from pysparkling import Row
from pysparkling.sql.casts import cast_to_string
from pysparkling.sql.expressions.aggregate.aggregations import Aggregation
from pysparkling.sql.expressions.mappers import StarOperator
from pysparkling.sql.functions import col
from pysparkling.sql.internal_utils.options import Options
from pysparkling.sql.internal_utils.readwrite import to_option_stored_value
from pysparkling.sql.utils import AnalysisException
from pysparkling.utils import portable_hash, get_json_encoder
class InternalWriter(object):
    """Accumulates write configuration for a DataFrame.

    This is a builder: every setter returns ``self`` so calls can be
    chained, and :meth:`save` hands the collected configuration over to a
    format-specific writer class.
    """

    def __init__(self, df):
        self._df = df
        # Defaults mirror Spark's DataFrameWriter.
        self._source = "parquet"
        self._mode = "errorifexists"
        self._options = {}
        self._partitioning_col_names = None
        self._num_buckets = None
        self._bucket_col_names = None
        self._sort_col_names = None

    def option(self, k, v):
        """Record one writer option; option keys are case-insensitive."""
        self._options[k.lower()] = to_option_stored_value(v)
        return self

    def mode(self, mode):
        """Set the save mode (append/overwrite/ignore/errorifexists)."""
        self._mode = mode
        return self

    def format(self, source):
        """Set the output format name (e.g. "csv", "json")."""
        self._source = source
        return self

    def partitionBy(self, partitioning_col_names):
        """Set the columns used to partition the output folders."""
        self._partitioning_col_names = partitioning_col_names
        return self

    def bucketBy(self, num_buckets, *bucket_cols):
        """Set the bucketing column(s) and the number of buckets."""
        self._num_buckets = num_buckets
        self._bucket_col_names = bucket_cols
        return self

    def sortBy(self, sort_cols):
        """Set the columns used to sort rows within each bucket."""
        self._sort_col_names = sort_cols
        return self

    def save(self, writer_class, path=None):
        """Build ``writer_class`` from the collected settings and run its save()."""
        self.option("path", path)
        writer = writer_class(
            self._df,
            self._mode,
            self._options,
            self._partitioning_col_names,
            self._num_buckets,
            self._bucket_col_names,
            self._sort_col_names,
        )
        return writer.save()
class WriteInFolder(Aggregation):
    """
    This use the computation engine of pysparkling to write the values in a folder.

    Its behaviour is similar to a collect_list except that items are written
    in a folder instead of being returned:
    Its evaluation only returns the number of items written while writing them
    using the writer given in the constructor, more specifically its write method.

    Pre-formatting is done as defined by writer.preformat during the merge phase.
    """

    def __init__(self, writer):
        """
        :param writer: a DataWriter subclass instance whose ``preformat`` and
            ``write`` methods perform the actual serialization.
        """
        super(WriteInFolder, self).__init__()
        # Select all columns of the incoming rows.
        self.column = col(StarOperator())
        self.writer = writer
        # First row seen, kept as a reference to resolve partition values.
        self.ref_value = None
        # Pre-formatted rows accumulated so far.
        self.items = []

    def merge(self, row, schema):
        # Evaluate the star operator to get all cell values of this row.
        row_value = self.column.eval(row, schema)
        if self.ref_value is None:
            # Remember the first row (with its field names) so that write()
            # can later extract partition column values from it.
            ref_value = Row(*row_value)
            ref_value.__fields__ = schema.names
            self.ref_value = ref_value
        self.items.append(
            self.writer.preformat(row_value, schema)
        )

    def mergeStats(self, other, schema):
        # Combine with another partial aggregation of the same kind.
        self.items += other.items
        if self.ref_value is None:
            self.ref_value = other.ref_value

    def eval(self, row, schema):
        # Flush the accumulated items to disk; returns the number of rows
        # written. ``pre_evaluation_schema`` is provided by the Aggregation
        # base class (not visible here) — presumably the schema of the rows
        # before aggregation.
        return self.writer.write(
            self.items,
            self.ref_value,
            self.pre_evaluation_schema
        )

    def __str__(self):
        return "write_in_folder({0})".format(self.column)
class DataWriter(object):
    """Base class for format-specific writers (CSV, JSON, ...).

    Groups the DataFrame by the partitioning columns (when any are given)
    and delegates the serialization of each group to the subclass through a
    :class:`WriteInFolder` aggregation.
    """

    # Defaults shared by every format; user options are layered on top
    # through ``Options``.
    default_options = dict(
        dateFormat="yyyy-MM-dd",
        timestampFormat="yyyy-MM-dd'T'HH:mm:ss.SSSXXX",
    )

    def __init__(self, df, mode, options, partitioning_col_names, num_buckets,
                 bucket_col_names, sort_col_names):
        """
        :param df: pysparkling.sql.DataFrame
        :param mode: str
        :param options: Dict[str, Optional[str]]
        :param partitioning_col_names: Optional[List[str]]
        :param num_buckets: Optional[int]
        :param bucket_col_names: Optional[List[str]]
        :param sort_col_names: Optional[List[str]]
        """
        self.mode = mode
        self.options = Options(self.default_options, options)
        self.partitioning_col_names = partitioning_col_names if partitioning_col_names else []
        self.num_buckets = num_buckets
        # Bug fix: these two previously tested ``partitioning_col_names``
        # (copy-paste error), which silently discarded bucketing and sorting
        # configuration whenever no partitioning was requested.
        self.bucket_col_names = bucket_col_names if bucket_col_names else []
        self.sort_col_names = sort_col_names if sort_col_names else []
        if self.partitioning_col_names:
            self.apply_on_aggregated_data = df.groupBy(*self.partitioning_col_names).agg
        else:
            self.apply_on_aggregated_data = df.select

    @property
    def path(self):
        # Output root, normalized without a trailing slash.
        return self.options["path"].rstrip("/")

    @property
    def compression(self):
        # Compression is not supported yet.
        return None

    @property
    def encoding(self):
        # Custom encodings are not supported yet.
        return None

    def save(self):
        """Apply the save mode to the output folder, then write all rows.

        Creates a ``_SUCCESS`` marker file on completion, mirroring Spark.

        :raises AnalysisException: when the path exists and the mode is
            "error"/"errorifexists".
        """
        output_path = self.path
        mode = self.mode
        if os.path.exists(output_path):
            if mode == "ignore":
                return
            if mode in ("error", "errorifexists"):
                raise AnalysisException("path {0} already exists.;".format(output_path))
            if mode == "overwrite":
                shutil.rmtree(output_path)
                os.makedirs(output_path)
        else:
            os.makedirs(output_path)
        # WriteInFolder calls back into self.preformat/self.write per group.
        self.apply_on_aggregated_data(col(WriteInFolder(writer=self))).collect()
        success_path = os.path.join(output_path, "_SUCCESS")
        with open(success_path, "w"):
            pass

    def preformat(self, row, schema):
        """Convert one row into the item later passed to write(); subclass hook."""
        raise NotImplementedError

    def write(self, items, ref_value, schema):
        """
        Write a list of rows (items) which have a given schema

        Returns the number of rows written
        """
        raise NotImplementedError
class CSVWriter(DataWriter):
    """Writes aggregated rows as CSV files, one file per partition."""

    def check_options(self):
        """Raise NotImplementedError when an unsupported option was requested."""
        unsupported_options = {
            "compression",
            "encoding",
            "chartoescapequoteescaping",
            "escape",
            "escapequotes"
        }
        options_requested_but_not_supported = set(self.options) & unsupported_options
        if options_requested_but_not_supported:
            raise NotImplementedError(
                "Pysparkling does not support yet the following options: {0}".format(
                    options_requested_but_not_supported
                )
            )

    def preformat_cell(self, value, field):
        """Render a single value as the string placed in its CSV cell."""
        if value is None:
            value = self.nullValue
        else:
            value = cast_to_string(
                value,
                from_type=field.dataType,
                options=self.options
            )
        # Bug fix: the trimming directions were swapped — "leading" used
        # rstrip (end of string) and "trailing" used lstrip (start).
        if self.ignoreLeadingWhiteSpace:
            value = value.lstrip()
        if self.ignoreTrailingWhiteSpace:
            value = value.rstrip()
        if value == "":
            return self.emptyValue
        return value

    def preformat(self, row, schema):
        """Pre-render every cell of the row to its CSV string form."""
        return tuple(
            self.preformat_cell(value, field)
            for value, field in zip(row, schema.fields)
        )

    @property
    def sep(self):
        return self.options.get("sep", ",")

    @property
    def quote(self):
        quote = self.options.get("quote", '"')
        # An empty quote string disables quoting; csv still needs a
        # quotechar, so use NUL which never appears in the data.
        return "\u0000" if quote == "" else quote

    @property
    def escape(self):
        return self.options.get("escape", "\\")

    @property
    def header(self):
        return self.options.get("header", "false") != "false"

    @property
    def nullValue(self):
        return self.options.get("nullvalue", "")

    @property
    def escapeQuotes(self):
        return self.options.get("escapequotes", "true") != "false"

    @property
    def quoteAll(self):
        return self.options.get("quoteall", "false") != "false"

    @property
    def ignoreLeadingWhiteSpace(self):
        # Bug fix: the lookup key contained an uppercase "S" while option
        # keys are stored lowercased (see InternalWriter.option), so the
        # option could never be matched consistently.
        return self.options.get("ignoreleadingwhitespace", 'false') != "false"

    @property
    def ignoreTrailingWhiteSpace(self):
        # Bug fix: same lowercase-key fix as ignoreLeadingWhiteSpace.
        return self.options.get("ignoretrailingwhitespace", 'false') != "false"

    @property
    def charToEscapeQuoteEscaping(self):
        # Not supported yet (see check_options).
        return None

    @property
    def emptyValue(self):
        return self.options.get("emptyvalue", '""')

    @property
    def lineSep(self):
        return self.options.get("linesep", "\n")

    def write(self, items, ref_value, schema):
        """Write the pre-formatted rows to their partition file.

        Returns the number of rows written.
        """
        self.check_options()
        output_path = self.path
        if not items:
            return 0
        partition_parts = [
            "{0}={1}".format(col_name, ref_value[col_name])
            for col_name in self.partitioning_col_names
        ]
        partition_folder = "/".join([output_path] + partition_parts)
        # Bug fix: the partition sub-folders were never created, so any
        # partitioned write failed on open(); mirrors JSONWriter.write.
        if not os.path.exists(partition_folder):
            os.makedirs(partition_folder)
        file_path = "{0}/part-00000-{1}.csv".format(
            partition_folder, portable_hash(ref_value)
        )
        # pylint: disable=W0511
        # todo: Add support of:
        #  - all files systems (not only local)
        #  - compression
        #  - encoding
        #  - charToEscapeQuoteEscaping
        #  - escape
        #  - escapeQuotes
        with open(file_path, "w") as f:
            writer = csv.writer(
                f,
                delimiter=self.sep,
                quotechar=self.quote,
                quoting=csv.QUOTE_ALL if self.quoteAll else csv.QUOTE_MINIMAL,
                lineterminator=self.lineSep
            )
            if self.header:
                writer.writerow(schema.names)
            writer.writerows(items)
        return len(items)
class JSONWriter(DataWriter):
    """Writes aggregated rows as newline-delimited JSON, one file per partition."""

    def __init__(self, df, mode, options, partitioning_col_names,
                 num_buckets, bucket_col_names, sort_col_names):
        super(JSONWriter, self).__init__(df, mode, options, partitioning_col_names,
                                         num_buckets, bucket_col_names, sort_col_names)
        # Encoder built from the options (e.g. date/timestamp formats).
        self.encoder = get_json_encoder(self.options)

    def check_options(self):
        """Raise NotImplementedError when an unsupported option was requested."""
        unsupported_options = {
            "compression",
            "encoding",
            "chartoescapequoteescaping",
            "escape",
            "escapequotes"
        }
        unsupported_requested = unsupported_options.intersection(self.options)
        if unsupported_requested:
            raise NotImplementedError(
                "Pysparkling does not support yet the following options: {0}".format(
                    unsupported_requested
                )
            )

    @property
    def lineSep(self):
        """Record separator appended after each JSON document."""
        return self.options.get("linesep", "\n")

    def preformat(self, row, schema):
        """Serialize one row as a JSON document string, field order preserved."""
        document = collections.OrderedDict(zip(schema.names, row))
        serialized = json.dumps(document, cls=self.encoder, separators=(',', ':'))
        return serialized + self.lineSep

    def write(self, items, ref_value, schema):
        """Append the pre-serialized documents to their partition file.

        Returns the number of documents written.
        """
        self.check_options()
        root = self.path
        if not items:
            return 0
        folder_parts = [root]
        folder_parts.extend(
            "{0}={1}".format(name, ref_value[name])
            for name in self.partitioning_col_names
        )
        partition_folder = "/".join(folder_parts)
        file_path = "{0}/part-00000-{1}.json".format(
            partition_folder, portable_hash(ref_value)
        )
        if not os.path.exists(partition_folder):
            os.makedirs(partition_folder)
        with open(file_path, "a") as sink:
            sink.writelines(items)
        return len(items)