Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tensor2tensor/insights/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ Start guide, a sample configuration would be:
"hparams": "",
"hparams_set": "transformer_base_single_gpu",
"problem": "translate_ende_wmt32k"
},
}]
}
}],
"language": [{
"code": "en",
"name": "English",
"name": "English"
},{
"code": "de",
"name": "German",
"name": "German"
}]
}
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,5 +150,6 @@ <h4>Rapid Response</h4>
on-response="handleTranslationResponse_">
</iron-ajax>
</template>
<script src="../d3/d3.js"></script>
<script src="explore-view.js"></script>
</dom-module>
22 changes: 21 additions & 1 deletion tensor2tensor/insights/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from flask import jsonify
from flask import request
from flask import send_from_directory
from flask.json import JSONEncoder
import numpy as np
from gunicorn.app.base import BaseApplication
from gunicorn.six import iteritems
from tensor2tensor.insights import transformer_model
Expand All @@ -36,6 +38,23 @@
"Path to static javascript and html files to serve.")


_NUMPY_INT_DTYPES = [
np.int8, np.int16, np.int32, np.int64
]
_NUMPY_FP_DTYPES = [
np.float16, np.float32, np.float64
]
class NumpySerializationFix(JSONEncoder):
"""json module cannot serialize numpy datatypes, reinterpret them first"""
def default(self, obj):
obj_type = type(obj)
if obj_type in _NUMPY_INT_DTYPES:
return int(obj)
if obj_type in _NUMPY_FP_DTYPES:
return float(obj)
return json.JSONEncoder.default(self, obj)


class DebugFrontendApplication(BaseApplication):
"""A local custom application for GUnicorns.

Expand Down Expand Up @@ -100,6 +119,7 @@ def main(_):
__name__.split(".")[0],
static_url_path="/polymer",
static_folder=FLAGS.static_path)
app.json_encoder = NumpySerializationFix

# Disable static file caching.
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0
Expand All @@ -112,7 +132,7 @@ def language_list(): # pylint: disable=unused-variable
JSON for the languages.
"""
return jsonify({
"language": languages.values()
"language": list(languages.values())
})

@app.route("/api/list_models/")
Expand Down