Skip to content

Commit

Permalink
refactor targets (#12)
Browse files Browse the repository at this point in the history
set filename in targets from Content-Disposition header
  • Loading branch information
kolomenkin authored and siddhantgoel committed May 21, 2018
1 parent 704e666 commit a9a4b9e
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 5 deletions.
4 changes: 4 additions & 0 deletions streaming_form_data/_parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class Part:

self._reading = False

def set_multipart_filename(self, value):
self.target.multipart_filename = value

def start(self):
self._reading = True
self.target.start()
Expand Down Expand Up @@ -229,6 +232,7 @@ cdef class _Parser:
name = params.get('name')
if name:
part = self._part_for(name) or self.default_part
part.set_multipart_filename(params.get('filename'))
part.start()

self.set_active_part(part)
Expand Down
29 changes: 25 additions & 4 deletions streaming_form_data/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,22 @@ class BaseTarget:
data_received.
"""

def __init__(self):
self.multipart_filename = None

# 'multipart_filename ' is filled before start() call.
# It contains optional 'filename' value from 'Content-Disposition' header
# Default value is None in case 'filename' is not present.
#
# NOTE! You should be very careful with this value
# because it comes from the user.
# You should never use it without filtering
# to construct filename on disk.
#
# Example library for filtering user strings
# for use in URLs, filenames:
# https://github.com/un33k/python-slugify

def start(self):
pass

Expand All @@ -18,12 +34,16 @@ def finish(self):


class NullTarget(BaseTarget):
def __init__(self):
super().__init__()

def data_received(self, chunk):
pass


class ValueTarget(BaseTarget):
def __init__(self):
super().__init__()
self._values = []

def data_received(self, chunk):
Expand All @@ -35,25 +55,26 @@ def value(self):


class FileTarget(BaseTarget):
def __init__(self, filename):
def __init__(self, filename, allow_overwrite=True):
super().__init__()
self.filename = filename

self._openmode = 'wb' if allow_overwrite else 'xb'
self._fd = None

def start(self):
self._fd = open(self.filename, 'wb')
self._fd = open(self.filename, self._openmode)

def data_received(self, chunk):
self._fd.write(chunk)
self._fd.flush()

def finish(self):
self._fd.flush()
self._fd.close()


class SHA256Target(BaseTarget):
def __init__(self):
super().__init__()
self._hash = hashlib.sha256()

def data_received(self, chunk):
Expand Down
99 changes: 98 additions & 1 deletion tests/test_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,136 @@
import tempfile
from unittest import TestCase

from streaming_form_data.targets import ValueTarget, FileTarget
from streaming_form_data.targets \
import BaseTarget, NullTarget, ValueTarget, FileTarget


class NullTargetTestCase(TestCase):
def test_basic(self):
target = NullTarget()

target.multipart_filename = 'file001.txt'

target.start()
self.assertEqual(target.multipart_filename, 'file001.txt')

target.data_received(b'hello')

target.finish()

self.assertEqual(target.multipart_filename, 'file001.txt')

def test_not_sent(self):
target = NullTarget()
self.assertTrue(target.multipart_filename is None)


class ValueTargetTestCase(TestCase):
def test_basic(self):
target = ValueTarget()
self.assertEqual(target.value, b'')

target.multipart_filename = None

target.start()
self.assertTrue(target.multipart_filename is None)
self.assertEqual(target.value, b'')

target.data_received(b'hello')
target.data_received(b' ')
target.data_received(b'world')

target.finish()

self.assertTrue(target.multipart_filename is None)
self.assertEqual(target.value, b'hello world')

def test_not_sent(self):
target = ValueTarget()
self.assertEqual(target.value, b'')
self.assertTrue(target.multipart_filename is None)


class FileTargetTestCase(TestCase):
def test_basic(self):
filename = os.path.join(tempfile.gettempdir(), 'file.txt')

target = FileTarget(filename)

target.multipart_filename = 'file001.txt'

target.start()
self.assertEqual(target.filename, filename)
self.assertEqual(target.multipart_filename, 'file001.txt')
self.assertTrue(os.path.exists(filename))

target.data_received(b'hello')
target.data_received(b' ')
target.data_received(b'world')

target.finish()

self.assertTrue(os.path.exists(filename))

self.assertEqual(target.filename, filename)
self.assertEqual(target.multipart_filename, 'file001.txt')

with open(filename, 'rb') as file_:
self.assertEqual(file_.read(), b'hello world')

def test_not_sent(self):
filename = os.path.join(tempfile.gettempdir(), 'file_not_sent.txt')

target = FileTarget(filename)

self.assertFalse(os.path.exists(filename))

self.assertEqual(target.filename, filename)
self.assertTrue(target.multipart_filename is None)


class CustomTarget(BaseTarget):
def __init__(self):
super().__init__()
self._values = []

def start(self):
self._values.append(b'[start]')

def data_received(self, chunk):
self._values.append(chunk)

def finish(self):
self._values.append(b'[finish]')

@property
def value(self):
return b' '.join(self._values)


class CustomTargetTestCase(TestCase):
def test_basic(self):
target = CustomTarget()
self.assertEqual(target.value, b'')

target.multipart_filename = 'file.txt'

target.start()
self.assertEqual(target.multipart_filename, 'file.txt')
self.assertEqual(target.value, b'[start]')

target.data_received(b'chunk1')
target.data_received(b'chunk2')
self.assertEqual(target.value, b'[start] chunk1 chunk2')
target.data_received(b'chunk3')

target.finish()

self.assertEqual(target.multipart_filename, 'file.txt')
self.assertEqual(target.value,
b'[start] chunk1 chunk2 chunk3 [finish]')

def test_not_sent(self):
target = CustomTarget()
self.assertEqual(target.value, b'')
self.assertTrue(target.multipart_filename is None)

0 comments on commit a9a4b9e

Please sign in to comment.