Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix import of CSV files for python 3 #537

Merged
merged 2 commits into from
Aug 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions pydal/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,14 +860,15 @@ def import_from_csv_file(self, ifile, id_map=None, null='<NULL>',
id_offset = {} # only used if id_map is None
map_tablenames = map_tablenames or {}
for line in ifile:
line = line.strip()
line = line.decode().strip()
if not line:
continue
elif line == 'END':
return
elif not line.startswith('TABLE ') or \
not line[6:] in self.tables:
raise SyntaxError('invalid file format')
elif not line.startswith('TABLE ') :
raise SyntaxError('Invalid file format')
elif not line[6:] in self.tables:
raise SyntaxError('Unknown table : %s' % line[6:])
else:
tablename = line[6:]
tablename = map_tablenames.get(tablename,tablename)
Expand Down
31 changes: 15 additions & 16 deletions pydal/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@
'id': '[1-9]\d*',
'decimal': '\d{1,10}\.\d{2}',
'integer': '[+-]?\d*',
'float': '[+-]?\d*(\.\d*)?',
'float': '[+-]?\d*(\.\d*)?',
'double': '[+-]?\d*(\.\d*)?',
'date': '\d{4}\-\d{2}\-\d{2}',
'time': '\d{2}\:\d{2}(\:\d{2}(\.\d*)?)?',
'datetime':'\d{4}\-\d{2}\-\d{2} \d{2}\:\d{2}(\:\d{2}(\.\d*)?)?',
}
}

class Row(BasicStorage):

Expand Down Expand Up @@ -381,9 +381,9 @@ def fields(self):

def _structure(self):
keys = ['name','type','writable','listable','searchable','regex','options',
'default','label','unique','notnull','required']
'default','label','unique','notnull','required']
def noncallable(obj): return obj if not callable(obj) else None
return [{key: noncallable(getattr(field, key)) for key in keys}
return [{key: noncallable(getattr(field, key)) for key in keys}
for field in self if field.readable and not field.type=='password']

@cachedprop
Expand Down Expand Up @@ -415,7 +415,7 @@ def _enable_record_versioning(self,
clones.append(
field.clone(unique=False, type=field.type if nfk else 'bigint')
)

d = dict(format=self._format)
if migrate:
d['migrate'] = migrate
Expand All @@ -427,7 +427,7 @@ def _enable_record_versioning(self,
d['redefine'] = redefine
archive_db.define_table(
archive_name,
Field(current_record, field_type, label=current_record_label),
Field(current_record, field_type, label=current_record_label),
*clones, **d)

self._before_update.append(
Expand All @@ -447,7 +447,7 @@ def _enable_record_versioning(self,
self._common_filter = lambda q: reduce(
AND, [query(q), newquery(q)])
else:
self._common_filter = newquery
self._common_filter = newquery

def _validate(self, **vars):
errors = Row()
Expand Down Expand Up @@ -795,7 +795,7 @@ def validate_and_insert(self, **fields):
response.id = self.insert(**new_fields)
return response

def validate_and_update(self, _key=DEFAULT, **fields):
def validate_and_update(self, _key=DEFAULT, **fields):
response, new_fields = self._validate_fields(fields, 'update')
#: select record(s) for update
if _key is DEFAULT:
Expand Down Expand Up @@ -891,9 +891,9 @@ def import_from_csv_file(self,
null = '<NULL>',
unique = 'uuid',
id_offset = None, # id_offset used only when id_map is None
transform = None,
transform = None,
validate=False,
**kwargs
**kwargs
):
"""
Import records from csv file.
Expand All @@ -912,7 +912,6 @@ def import_from_csv_file(self,
incrementing order.
Will keep the id numbers in restored table.
"""

if validate:
inserting=self.validate_and_insert
else:
Expand All @@ -925,8 +924,8 @@ def import_from_csv_file(self,
if restore:
self._db[self].truncate()

reader = csv.reader(csvfile, delimiter=delimiter,
quotechar=quotechar, quoting=quoting)
csvfile = csvfile.read().decode()
reader = csv.reader(csvfile, delimiter=delimiter,quotechar=quotechar, quoting=quoting)
colnames = None
if isinstance(id_map, dict):
if self._tablename not in id_map:
Expand Down Expand Up @@ -1287,10 +1286,10 @@ def abs(self):
return Expression(
self.db, self._dialect.aggregate, self, 'ABS', self.type)

def cast(self, cast_as, **kwargs):
def cast(self, cast_as, **kwargs):
return Expression(
self.db, self._dialect.cast, self, self._dialect.types[cast_as] % kwargs, cast_as)

def lower(self):
return Expression(
self.db, self._dialect.lower, self, None, self.type)
Expand Down Expand Up @@ -1614,7 +1613,7 @@ class Field(Expression, Serializable):
def __init__(self, fieldname, type='string', length=None, default=DEFAULT,
required=False, requires=DEFAULT, ondelete='CASCADE',
notnull=False, unique=False, uploadfield=True, widget=None,
label=None, comment=None,
label=None, comment=None,
writable=True, readable=True,
searchable=True, listable=True,
regex=None, options=None,
Expand Down