Skip to content
This repository has been archived by the owner on Mar 28, 2022. It is now read-only.

Commit

Permalink
Add support for gzip compression in request and response (#90)
Browse files Browse the repository at this point in the history
* add support for gzip compression in request and response

* py27 compatibility
  • Loading branch information
agermanidis committed Dec 27, 2019
1 parent 6956838 commit 59c9442
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 1 deletion.
1 change: 1 addition & 0 deletions requirements.txt
Expand Up @@ -8,3 +8,4 @@ colorcet>=2.0.1
Flask-Sockets==0.2.1
scipy>=1.2.1
urllib3[secure]>=1.25.7
flask-compress>=1.3.1
2 changes: 2 additions & 0 deletions runway/model.py
Expand Up @@ -14,6 +14,7 @@
from flask_sockets import Sockets
from gevent.pywsgi import WSGIServer
from geventwebsocket.handler import WebSocketHandler
from flask_compress import Compress
from .exceptions import RunwayError, MissingInputError, MissingOptionError, \
InferenceError, UnknownCommandError, SetupError
from .data_types import *
Expand Down Expand Up @@ -47,6 +48,7 @@ def __init__(self):
try: self.app.config['JSON_AS_ASCII'] = False
except TypeError: pass
CORS(self.app)
Compress(self.app)
self.define_error_handlers()
self.define_routes()

Expand Down
16 changes: 15 additions & 1 deletion runway/utils.py
Expand Up @@ -12,6 +12,7 @@
import urllib3
import multiprocessing
import certifi
import json
if sys.version_info[0] < 3:
from cStringIO import StringIO as IO
from urlparse import urlparse
Expand Down Expand Up @@ -42,7 +43,12 @@ def wrapped(*args, **kwargs):
return wrapped

def get_json_or_none_if_invalid(request):
return request.get_json(force=True, silent=True)
if request.headers.get('content-encoding') == 'gzip' and request.headers.get('content-type') == 'application/json':
data = request.get_data()
decompressed = gzip_decompress(data)
return json.loads(decompressed)
else:
return request.get_json(force=True, silent=True)

def serialize_command(cmd):
ret = {}
Expand Down Expand Up @@ -121,6 +127,14 @@ def extract_tarball(path):
return extracted_dir


def gzip_compress(data):
compressed_data = IO()
g = gzip.GzipFile(fileobj=compressed_data, mode='w')
g.write(data)
g.close()
return compressed_data.getvalue()


def gzip_decompress(data):
compressed_data = IO(data)
return gzip.GzipFile(fileobj=compressed_data, mode='r').read()
Expand Down
46 changes: 46 additions & 0 deletions tests/test_model.py
Expand Up @@ -9,15 +9,21 @@
import json
import pytest
import time
import gzip
from time import sleep
from runway.model import RunwayModel
from runway.__version__ import __version__ as model_sdk_version
from runway.data_types import category, text, number, array, image, vector, file, any as any_type
from runway.exceptions import *
from runway.utils import gzip_decompress, gzip_compress
from utils import *
from deepdiff import DeepDiff
from flask import abort
from multiprocessing import Process
if sys.version_info[0] < 3:
from cStringIO import StringIO as IO
else:
from io import BytesIO as IO

from pytest_cov.embed import cleanup_on_sigterm
cleanup_on_sigterm()
Expand Down Expand Up @@ -493,6 +499,46 @@ def times_two(model, args):
assert response.is_json
assert json.loads(response.data) == { 'output': 10 }

def test_post_command_json_mime_type_with_gzip():

rw = RunwayModel()

@rw.command('times_two', inputs={ 'input': number }, outputs={ 'output': number })
def times_two(model, args):
return args['input'] * 2

rw.run(debug=True)

client = get_test_client(rw)
headers = {
'content-type': 'application/json',
'content-encoding': 'gzip'
}
response = client.post('/times_two', data=gzip_compress(json.dumps({ 'input': 5 }).encode('utf-8')), headers=headers)
assert response.is_json
assert json.loads(response.data) == { 'output': 10 }

def test_post_command_json_mime_type_with_gzip_response():

rw = RunwayModel()
rw.app.config['COMPRESS_MIN_SIZE'] = 0

@rw.command('times_two', inputs={ 'input': number }, outputs={ 'output': number })
def times_two(model, args):
return args['input'] * 2

rw.run(debug=True)

client = get_test_client(rw)
response = client.post('/times_two',
json={ 'input': 5 },
headers={
'accept-encoding': 'gzip'
}
)
assert response.is_json
assert json.loads(gzip_decompress(response.data)) == { 'output': 10 }

def test_post_command_form_encoding():

rw = RunwayModel()
Expand Down

0 comments on commit 59c9442

Please sign in to comment.