Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-76728: Coerce DictReader and DictWriter fieldnames argument to a list #32225

Merged
merged 15 commits into from
Aug 25, 2022
4 changes: 4 additions & 0 deletions Doc/library/csv.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ The :mod:`csv` module defines the following classes:
All other optional or keyword arguments are passed to the underlying
:class:`reader` instance.

If the argument passed to *fieldnames* is an iterator, it will be coerced to a :class:`list`.

.. versionchanged:: 3.6
Returned rows are now of type :class:`OrderedDict`.

Expand Down Expand Up @@ -209,6 +211,8 @@ The :mod:`csv` module defines the following classes:
Note that unlike the :class:`DictReader` class, the *fieldnames* parameter
of the :class:`DictWriter` class is not optional.

If the argument passed to *fieldnames* is an iterator, it will be coerced to a :class:`list`.

A short usage example::

import csv
Expand Down
4 changes: 4 additions & 0 deletions Lib/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class unix_dialect(Dialect):
class DictReader:
def __init__(self, f, fieldnames=None, restkey=None, restval=None,
dialect="excel", *args, **kwds):
if fieldnames is not None and iter(fieldnames) is fieldnames:
fieldnames = list(fieldnames)
self._fieldnames = fieldnames # list of keys for the dict
self.restkey = restkey # key to catch long rows
self.restval = restval # default value for short rows
Expand Down Expand Up @@ -130,6 +132,8 @@ def __next__(self):
class DictWriter:
def __init__(self, f, fieldnames, restval="", extrasaction="raise",
dialect="excel", *args, **kwds):
if fieldnames is not None and iter(fieldnames) is fieldnames:
fieldnames = list(fieldnames)
self.fieldnames = fieldnames # list of keys for the dict
self.restval = restval # for writing short dicts
if extrasaction.lower() not in ("raise", "ignore"):
Expand Down
28 changes: 28 additions & 0 deletions Lib/test/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,34 @@ def test_write_field_not_in_field_names_ignore(self):
csv.DictWriter.writerow(writer, dictrow)
self.assertEqual(fileobj.getvalue(), "1,2\r\n")

def test_dict_reader_fieldnames_accepts_iter(self):
fieldnames = ["a", "b", "c"]
f = StringIO()
reader = csv.DictReader(f, iter(fieldnames))
self.assertEqual(reader.fieldnames, fieldnames)

def test_dict_reader_fieldnames_accepts_list(self):
fieldnames = ["a", "b", "c"]
f = StringIO()
reader = csv.DictReader(f, fieldnames)
self.assertEqual(reader.fieldnames, fieldnames)

def test_dict_writer_fieldnames_rejects_iter(self):
fieldnames = ["a", "b", "c"]
f = StringIO()
writer = csv.DictWriter(f, iter(fieldnames))
self.assertEqual(writer.fieldnames, fieldnames)

def test_dict_writer_fieldnames_accepts_list(self):
fieldnames = ["a", "b", "c"]
f = StringIO()
writer = csv.DictWriter(f, fieldnames)
self.assertEqual(writer.fieldnames, fieldnames)

def test_dict_reader_fieldnames_is_optional(self):
f = StringIO()
reader = csv.DictReader(f, fieldnames=None)

def test_read_dict_fields(self):
with TemporaryFile("w+", encoding="utf-8") as fileobj:
fileobj.write("1,2,abc\r\n")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The constructors for :class:`~csv.DictWriter` and :class:`~csv.DictReader` now coerce the ``fieldnames`` argument to a :class:`list` if it is an iterator.