Skip to content

Commit

Permalink
fix syntax and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
notoriousno committed Jan 22, 2017
1 parent 5d929d5 commit f3c3420
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 21 deletions.
6 changes: 3 additions & 3 deletions src/klein/app.py
Expand Up @@ -24,7 +24,7 @@ def iscoroutine(*args, **kwargs):

from twisted.web.iweb import IRenderable
from twisted.web.template import renderElement
from twisted.web.server import Site, Request
from twisted.web.server import Request
from twisted.internet import reactor, endpoints

try:
Expand All @@ -35,7 +35,7 @@ def ensureDeferred(*args, **kwagrs):

from zope.interface import implementer

from klein.resource import KleinResource
from klein.resource import KleinResource, KleinSite
from klein.interfaces import IKleinRequest

__all__ = ['Klein', 'run', 'route', 'resource']
Expand Down Expand Up @@ -367,7 +367,7 @@ def run(self, host=None, port=None, logFile=None,
host)

endpoint = endpoints.serverFromString(reactor, endpoint_description)
endpoint.listen(Site(self.resource()))
endpoint.listen(KleinSite(self.resource()))
reactor.run()


Expand Down
22 changes: 7 additions & 15 deletions src/klein/resource.py
Expand Up @@ -288,33 +288,25 @@ class KleinHTTPRequest(server.Request):
def getArg(self, key):
"""
Get a single arg value.
@raises KeyError: If key doesn't exist
@raises ValueError: If there is more than 1 value
@return: L{list} of L{bytes}
"""
key = ensure_utf8_bytes(key)
value = self.arg[key]
value = self.args[key]
if len(value) != 1:
raise ValueError('Too many values for: {0}'.format(key))
return value[0]

def getArgs(self, key):
"""
Get the list of values for a key.
"""
key = ensure_utf8_bytes(key)
return self.args.get(key, [])

def appendArg(self, key, value):
"""
Append a value into the list.
"""
key = ensure_utf8_bytes(key)
self.args.setdefault(key, []).append(value)
def setArg(self, key, value):
"""
Set a value for a given key. The value will always be in a list.
@return: L{list} of L{bytes}
"""
key = ensure_utf8_bytes(key)
self.args[key] = [value]
return self.args.get(key, [])



Expand Down
6 changes: 3 additions & 3 deletions src/klein/test/test_app.py
Expand Up @@ -246,12 +246,12 @@ def foo(request):


@patch('klein.app.KleinResource')
@patch('klein.app.Site')
@patch('klein.app.KleinSite')
@patch('klein.app.log')
@patch('klein.app.reactor')
def test_run(self, reactor, mock_log, mock_site, mock_kr):
"""
L{Klein.run} configures a L{KleinResource} and a L{Site}
L{Klein.run} configures a L{KleinResource} and a L{KleinSite}
listening on the specified interface and port, and logs
to stdout.
"""
Expand All @@ -270,7 +270,7 @@ def test_run(self, reactor, mock_log, mock_site, mock_kr):


@patch('klein.app.KleinResource')
@patch('klein.app.Site')
@patch('klein.app.KleinSite')
@patch('klein.app.log')
@patch('klein.app.reactor')
def test_runWithLogFile(self, reactor, mock_log, mock_site, mock_kr):
Expand Down
61 changes: 61 additions & 0 deletions src/klein/test/test_request.py
@@ -0,0 +1,61 @@
from __future__ import absolute_import

from klein import Klein
from klein.resource import KleinHTTPRequest, KleinSite
from klein.test.util import TestCase
from twisted.web.test.requesthelper import DummyChannel

class BytesUnicodeTest(TestCase):

def setUp(self):
self.encoding = 'utf-8'
app = Klein()
site = KleinSite(app.resource())
channel = site.buildProtocol('127.0.0.1')
self.request = channel.requestFactory(DummyChannel(), None)
self.request.args = {}

def test_getArg(self):
str_key = 'test'
bytes_key = str_key.encode(self.encoding)
value = b'hello world'

self.request.args[bytes_key] = [value]
self.assertEquals(self.request.getArg(str_key), value)
self.assertEquals(self.request.getArg(bytes_key), value)

def test_getArg_not_1(self):
"""
Raise exception if there are more or less values than 1
"""
str_key = 'test'
bytes_key = str_key.encode(self.encoding)
values = []

self.request.args[bytes_key] = values
self.assertRaises(ValueError, self.request.getArg, str_key)
self.assertRaises(ValueError, self.request.getArg, bytes_key)

values.extend([b'hello', b'world'])
self.assertRaises(ValueError, self.request.getArg, str_key)
self.assertRaises(ValueError, self.request.getArg, bytes_key)

def test_getArgs(self):
str_key = 'test'
bytes_key = str_key.encode(self.encoding)
values = [b'hello world', b'hey earth']
self.request.args[bytes_key] = values

self.assertEquals(len(self.request.getArgs(str_key)), len(values))
self.assertEquals(self.request.getArgs(str_key), values)
self.assertEquals(self.request.getArgs(bytes_key), values)

def test_getArgs_no_key(self):
"""
By default, an empty list is returned if a key doesn't exist
"""
str_key = 'test'
bytes_key = str_key.encode(self.encoding)

self.assertEquals(self.request.getArgs(str_key), [])
self.assertEquals(self.request.getArgs(bytes_key), [])

0 comments on commit f3c3420

Please sign in to comment.