Skip to content

Commit

Permalink
gh-76728: Coerce DictReader and DictWriter fieldnames argument to a l…
Browse files Browse the repository at this point in the history
…ist (GH-32225)
  • Loading branch information
dignissimus committed Aug 25, 2022
1 parent c09fa75 commit cd492d4
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 0 deletions.
4 changes: 4 additions & 0 deletions Doc/library/csv.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ The :mod:`csv` module defines the following classes:
All other optional or keyword arguments are passed to the underlying
:class:`reader` instance.

If the argument passed to *fieldnames* is an iterator, it will be coerced to a :class:`list`.

.. versionchanged:: 3.6
Returned rows are now of type :class:`OrderedDict`.

Expand Down Expand Up @@ -209,6 +211,8 @@ The :mod:`csv` module defines the following classes:
Note that unlike the :class:`DictReader` class, the *fieldnames* parameter
of the :class:`DictWriter` class is not optional.

If the argument passed to *fieldnames* is an iterator, it will be coerced to a :class:`list`.

A short usage example::

import csv
Expand Down
4 changes: 4 additions & 0 deletions Lib/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ class unix_dialect(Dialect):
class DictReader:
def __init__(self, f, fieldnames=None, restkey=None, restval=None,
dialect="excel", *args, **kwds):
if fieldnames is not None and iter(fieldnames) is fieldnames:
fieldnames = list(fieldnames)
self._fieldnames = fieldnames # list of keys for the dict
self.restkey = restkey # key to catch long rows
self.restval = restval # default value for short rows
Expand Down Expand Up @@ -133,6 +135,8 @@ def __next__(self):
class DictWriter:
def __init__(self, f, fieldnames, restval="", extrasaction="raise",
dialect="excel", *args, **kwds):
if fieldnames is not None and iter(fieldnames) is fieldnames:
fieldnames = list(fieldnames)
self.fieldnames = fieldnames # list of keys for the dict
self.restval = restval # for writing short dicts
if extrasaction.lower() not in ("raise", "ignore"):
Expand Down
28 changes: 28 additions & 0 deletions Lib/test/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,34 @@ def test_write_field_not_in_field_names_ignore(self):
csv.DictWriter.writerow(writer, dictrow)
self.assertEqual(fileobj.getvalue(), "1,2\r\n")

def test_dict_reader_fieldnames_accepts_iter(self):
fieldnames = ["a", "b", "c"]
f = StringIO()
reader = csv.DictReader(f, iter(fieldnames))
self.assertEqual(reader.fieldnames, fieldnames)

def test_dict_reader_fieldnames_accepts_list(self):
fieldnames = ["a", "b", "c"]
f = StringIO()
reader = csv.DictReader(f, fieldnames)
self.assertEqual(reader.fieldnames, fieldnames)

def test_dict_writer_fieldnames_rejects_iter(self):
fieldnames = ["a", "b", "c"]
f = StringIO()
writer = csv.DictWriter(f, iter(fieldnames))
self.assertEqual(writer.fieldnames, fieldnames)

def test_dict_writer_fieldnames_accepts_list(self):
fieldnames = ["a", "b", "c"]
f = StringIO()
writer = csv.DictWriter(f, fieldnames)
self.assertEqual(writer.fieldnames, fieldnames)

def test_dict_reader_fieldnames_is_optional(self):
f = StringIO()
reader = csv.DictReader(f, fieldnames=None)

def test_read_dict_fields(self):
with TemporaryFile("w+", encoding="utf-8") as fileobj:
fileobj.write("1,2,abc\r\n")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The constructors for :class:`~csv.DictWriter` and :class:`~csv.DictReader` now coerce the ``fieldnames`` argument to a :class:`list` if it is an iterator.

0 comments on commit cd492d4

Please sign in to comment.