Permalink
Browse files

BUG: io/wavfile: safer writing directly from buffer

Fixes gh-2928
  • Loading branch information...
1 parent 9c5bab4 commit f4d8447a3f8f66acb8f3c428e8cc72dde5494f4c @pv pv committed with rgommers Sep 25, 2013
Showing with 23 additions and 9 deletions.
  1. +12 −5 scipy/io/tests/test_wavfile.py
  2. +11 −4 scipy/io/wavfile.py
@@ -84,21 +84,28 @@ def _check_roundtrip(realfile, rate, dtype, channels):
def test_write_roundtrip():
for realfile in (False, True):
- for signed in ('i', 'u', 'f'):
+ for dtypechar in ('i', 'u', 'f', 'g', 'q'):
for size in (1, 2, 4, 8):
- if size == 1 and signed == 'i':
+ if size == 1 and dtypechar == 'i':
# signed 8-bit integer PCM is not allowed
continue
- if size > 1 and signed == 'u':
+ if size > 1 and dtypechar == 'u':
# unsigned > 8-bit integer PCM is not allowed
continue
- if (size == 1 or size == 2) and signed == 'f':
+ if (size == 1 or size == 2) and dtypechar == 'f':
# 8- or 16-bit float PCM is not expected
continue
+ if dtypechar in 'gq':
+ # no size allowed for these types
+ if size == 1:
+ size = ''
+ else:
+ continue
+
for endianness in ('>', '<'):
if size == 1 and endianness == '<':
continue
for rate in (8000, 32000):
for channels in (1, 2, 5):
- dt = np.dtype('%s%s%d' % (endianness, signed, size))
+ dt = np.dtype('%s%s%s' % (endianness, dtypechar, size))
yield _check_roundtrip, realfile, rate, dt, channels
View
@@ -238,10 +238,8 @@ def write(filename, rate, data):
fid.write(struct.pack('<i', data.nbytes))
if data.dtype.byteorder == '>' or (data.dtype.byteorder == '=' and sys.byteorder == 'big'):
data = data.byteswap()
- if sys.version_info[0] >= 3:
- fid.write(data.ravel().data)
- else:
- fid.write(data.tostring())
+ _array_tofile(fid, data)
+
# Determine file size and place it in correct
# position at start of the file.
size = fid.tell()
@@ -253,3 +251,12 @@ def write(filename, rate, data):
fid.close()
else:
fid.seek(0)
+
+
+if sys.version_info[0] >= 3:
+ def _array_tofile(fid, data):
+ # ravel gives a c-contiguous buffer
+ fid.write(data.ravel().view('b').data)
+else:
+ def _array_tofile(fid, data):
+ fid.write(data.tostring())

0 comments on commit f4d8447

Please sign in to comment.