From cd492d43a2980faf0ef4a3f99c665023a506414c Mon Sep 17 00:00:00 2001 From: Sam Ezeh Date: Thu, 25 Aug 2022 11:13:24 +0100 Subject: [PATCH] gh-76728: Coerce DictReader and DictWriter fieldnames argument to a list (GH-32225) --- Doc/library/csv.rst | 4 +++ Lib/csv.py | 4 +++ Lib/test/test_csv.py | 28 +++++++++++++++++++ .../2022-04-01-09-43-54.bpo-32547.NIUiNC.rst | 1 + 4 files changed, 37 insertions(+) create mode 100644 Misc/NEWS.d/next/Library/2022-04-01-09-43-54.bpo-32547.NIUiNC.rst diff --git a/Doc/library/csv.rst b/Doc/library/csv.rst index 9dec7240d9c50f..0cab95e1500a66 100644 --- a/Doc/library/csv.rst +++ b/Doc/library/csv.rst @@ -167,6 +167,8 @@ The :mod:`csv` module defines the following classes: All other optional or keyword arguments are passed to the underlying :class:`reader` instance. + If the argument passed to *fieldnames* is an iterator, it will be coerced to a :class:`list`. + .. versionchanged:: 3.6 Returned rows are now of type :class:`OrderedDict`. @@ -209,6 +211,8 @@ The :mod:`csv` module defines the following classes: Note that unlike the :class:`DictReader` class, the *fieldnames* parameter of the :class:`DictWriter` class is not optional. + If the argument passed to *fieldnames* is an iterator, it will be coerced to a :class:`list`. + A short usage example:: import csv diff --git a/Lib/csv.py b/Lib/csv.py index bfc850ee96dab6..0de5656a4eed7b 100644 --- a/Lib/csv.py +++ b/Lib/csv.py @@ -81,6 +81,8 @@ class unix_dialect(Dialect): class DictReader: def __init__(self, f, fieldnames=None, restkey=None, restval=None, dialect="excel", *args, **kwds): + if fieldnames is not None and iter(fieldnames) is fieldnames: + fieldnames = list(fieldnames) self._fieldnames = fieldnames # list of keys for the dict self.restkey = restkey # key to catch long rows self.restval = restval # default value for short rows @@ -133,6 +135,8 @@ def __next__(self): class DictWriter: def __init__(self, f, fieldnames, restval="", extrasaction="raise", dialect="excel", *args, **kwds): + if fieldnames is not None and iter(fieldnames) is fieldnames: + fieldnames = list(fieldnames) self.fieldnames = fieldnames # list of keys for the dict self.restval = restval # for writing short dicts if extrasaction.lower() not in ("raise", "ignore"): diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py index 95a19dd46cb4ff..51ca1f2ab29285 100644 --- a/Lib/test/test_csv.py +++ b/Lib/test/test_csv.py @@ -736,6 +736,34 @@ def test_write_field_not_in_field_names_ignore(self): csv.DictWriter.writerow(writer, dictrow) self.assertEqual(fileobj.getvalue(), "1,2\r\n") + def test_dict_reader_fieldnames_accepts_iter(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + reader = csv.DictReader(f, iter(fieldnames)) + self.assertEqual(reader.fieldnames, fieldnames) + + def test_dict_reader_fieldnames_accepts_list(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + reader = csv.DictReader(f, fieldnames) + self.assertEqual(reader.fieldnames, fieldnames) + + def test_dict_writer_fieldnames_rejects_iter(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + writer = csv.DictWriter(f, iter(fieldnames)) + self.assertEqual(writer.fieldnames, fieldnames) + + def test_dict_writer_fieldnames_accepts_list(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + writer = csv.DictWriter(f, fieldnames) + self.assertEqual(writer.fieldnames, fieldnames) + + def test_dict_reader_fieldnames_is_optional(self): + f = StringIO() + reader = csv.DictReader(f, fieldnames=None) + def test_read_dict_fields(self): with TemporaryFile("w+", encoding="utf-8") as fileobj: fileobj.write("1,2,abc\r\n") diff --git a/Misc/NEWS.d/next/Library/2022-04-01-09-43-54.bpo-32547.NIUiNC.rst b/Misc/NEWS.d/next/Library/2022-04-01-09-43-54.bpo-32547.NIUiNC.rst new file mode 100644 index 00000000000000..4599b73cc342ca --- /dev/null +++ b/Misc/NEWS.d/next/Library/2022-04-01-09-43-54.bpo-32547.NIUiNC.rst @@ -0,0 +1 @@ +The constructors for :class:`~csv.DictWriter` and :class:`~csv.DictReader` now coerce the ``fieldnames`` argument to a :class:`list` if it is an iterator.