csvreader.py
import itertools
from functools import partial
from pysparkling.fileio import TextFile
from pysparkling.sql.casts import get_caster
from pysparkling.sql.internal_utils.options import Options
from pysparkling.sql.internal_utils.readers.utils import (
    resolve_partitions,
    guess_schema_from_strings,
)
from pysparkling.sql.internals import DataFrameInternal
from pysparkling.sql.schema_utils import infer_schema_from_rdd
from pysparkling.sql.types import StructType, StringType, StructField, create_row


class CSVReader(object):
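    """Reads CSV files into a DataFrameInternal, mimicking Spark's CSV source.

    ``default_options`` matches Spark's CSV reader defaults; user-supplied
    options override them.
    """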
default_options = dict(
lineSep=None,
encoding="utf-8",
sep=",",
inferSchema=False,
header=False
)

    def __init__(self, spark, paths, schema, options):
self.spark = spark
self.paths = paths
self.schema = schema
self.options = Options(self.default_options, options)

    def read(self):
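        """Load the configured paths and return a DataFrameInternal.

        Values are parsed as plain strings first, then cast to the types
        required by the resolved schema.
        """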
sc = self.spark._sc
paths = self.paths
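        # Map each input file to its partition values (Hive-style key=value
        # directories) and collect the corresponding partition schema.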
partitions, partition_schema = resolve_partitions(paths)
rdd_filenames = sc.parallelize(sorted(partitions.keys()), len(partitions))
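        # One partition per file; each file is parsed independently into
        # Rows of strings.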
rdd = rdd_filenames.flatMap(partial(
parse_csv_file,
partitions,
partition_schema,
self.schema,
self.options
))
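        # Pick the schema: the caller's schema wins; otherwise guess the
        # types from the string values when inferSchema is set; otherwise
        # derive a schema from the parsed rows.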
if self.schema is not None:
schema = self.schema
elif self.options.inferSchema:
fields = rdd.take(1)[0].__fields__
schema = guess_schema_from_strings(fields, rdd.collect(), options=self.options)
else:
schema = infer_schema_from_rdd(rdd)
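        # The parsed values are all strings at this point; describe them as
        # such so the caster below knows the source types.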
schema_with_string = StructType(fields=[
StructField(field.name, StringType()) for field in schema.fields
])
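        # Partition columns are appended last by parse_csv_file; swap the
        # trailing fields for the typed partition schema.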
if partition_schema:
partitions_fields = partition_schema.fields
full_schema = StructType(schema.fields[:-len(partitions_fields)] + partitions_fields)
else:
full_schema = schema
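        # Cast every string value to its target field type.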
cast_row = get_caster(
from_type=schema_with_string, to_type=full_schema, options=self.options
)
casted_rdd = rdd.map(cast_row)
casted_rdd._name = paths
return DataFrameInternal(
sc,
casted_rdd,
schema=full_schema
)


def parse_csv_file(partitions, partition_schema, schema, options, file_name):
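    """Parse one CSV file into a list of Rows.

    Note: records are split naively on ``options.sep``; separators inside
    quoted fields are not handled.
    """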
f_content = TextFile(file_name).load(encoding=options.encoding).read()
records = (f_content.split(options.lineSep)
if options.lineSep is not None
else f_content.splitlines())
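    # Spark-style string option: header == "true" promotes the first record
    # to column names.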
if options.header == "true":
header = records[0].split(options.sep)
records = records[1:]
else:
header = None
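    # An empty field is read back as None.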
null_value = ""
rows = []
for record in records:
row = csv_record_to_row(
record, options, schema, header, null_value, partition_schema, partitions[file_name]
)
row.set_input_file_name(file_name)
rows.append(row)
return rows


def csv_record_to_row(record, options, schema=None, header=None,
null_value=None, partition_schema=None, partition=None):
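    """Convert a single CSV record into a Row.

    Column names come from the schema if given, then from the header, and
    otherwise fall back to Spark's positional names (_c0, _c1, ...).
    Partition values, if any, are appended after the data columns.
    """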
record_values = [val if val != null_value else None for val in record.split(options.sep)]
if schema is not None:
field_names = [f.name for f in schema.fields]
elif header is not None:
field_names = header
else:
field_names = ["_c{0}".format(i) for i, field in enumerate(record_values)]
partition_field_names = [
f.name for f in partition_schema.fields
] if partition_schema else []
row = create_row(
itertools.chain(field_names, partition_field_names),
itertools.chain(record_values, partition or [])
)
return row
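

# Usage sketch (illustration only, not part of this module): this reader is
# normally reached through pysparkling's SQL layer. The session and reader
# calls below assume pysparkling mirrors PySpark's API here, and "data.csv"
# is a hypothetical file:
#
#     from pysparkling import Context
#     from pysparkling.sql.session import SparkSession
#
#     spark = SparkSession(Context())
#     df = spark.read.option("header", "true").csv("data.csv")
#     print(df.collect())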