Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion backend/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from translate.graph import Graph

app = Flask(__name__)
app.config['UPLOAD_EXTENSIONS'] = ['.h5']
app.config['UPLOAD_PATH'] = 'uploads'
ok_status = 200
error_status = 400
json_type = {'ContentType': 'application/json'}
text_type = {'ContentType': 'text/plain'}

Expand Down Expand Up @@ -60,6 +63,20 @@ def replace_references(net):
outp.append(i)
layer.output = outp

def check_uploads_path_exists(identifier):
if not os.path.exists(app.config['UPLOAD_PATH']):
os.mkdir(app.config['UPLOAD_PATH'])
if not os.path.exists(os.path.join(app.config['UPLOAD_PATH'], identifier)):
os.mkdir(os.path.join(app.config['UPLOAD_PATH'], identifier))
os.mkdir(os.path.join(app.config['UPLOAD_PATH'], identifier, 'visualizations'))
copyfile(os.path.join('default', 'layer_types_current.json'),
os.path.join(app.config['UPLOAD_PATH'], identifier, 'layer_types_current.json'))
copyfile(os.path.join('default', 'preferences.json'),
os.path.join(app.config['UPLOAD_PATH'], identifier, 'preferences.json'))
copyfile(os.path.join('default', 'groups.json'),
os.path.join(app.config['UPLOAD_PATH'], identifier, 'groups.json'))
copyfile(os.path.join('default', 'legend_preferences.json'),
os.path.join(app.config['UPLOAD_PATH'], identifier, 'legend_preferences.json'))

def check_exists(identifier):
"""Check if the desired model already exists.
Expand Down Expand Up @@ -149,7 +166,12 @@ def get_network(identifier):
object -- a http response containing the network as json
"""
check_exists(identifier)
graph = translate_keras(os.path.join('models', identifier,
check_uploads_path_exists(identifier)
if 'model.h5' in ls(os.path.join(app.config['UPLOAD_PATH'], identifier)):
graph = translate_keras(os.path.join(app.config['UPLOAD_PATH'], identifier,
'model.h5'))
else:
graph = translate_keras(os.path.join('models', identifier,
'model_current.py'))
if isinstance(graph, Graph):
net = {'layers': make_jsonifyable(graph)}
Expand Down Expand Up @@ -192,6 +214,26 @@ def update_code(identifier):
file.write(content.decode("utf-8"))
return content, ok_status, text_type

@app.route('/api/upload_model/<identifier>', methods=['POST'])
def upload_model(identifier):
"""Update the Code.

Arguments:
identifier {String} -- the identifier for the requested network

Returns:
object -- a http response signaling the change worked
"""
check_uploads_path_exists(identifier)
uploaded_file = request.files['model']
filename = uploaded_file.filename
file_ext = os.path.splitext(filename)[1]
if file_ext not in app.config['UPLOAD_EXTENSIONS']:
return "", error_status, text_type
file_path = os.path.join(app.config['UPLOAD_PATH'], identifier, 'model.h5')
uploaded_file.save(file_path)
return "", ok_status, text_type


@app.route('/api/get_layer_types/<identifier>')
def get_layer_types(identifier):
Expand Down
33 changes: 26 additions & 7 deletions backend/translate/translate_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from translate.graph import Graph
import translate.layer as layer

keras_ext = '.h5'


def translate_keras(filename):
"""Translate a keras model defined in a file into the neural network graph.
Expand All @@ -19,13 +21,16 @@ def translate_keras(filename):
epicbox.Profile('python', 'tf_plus_keras:latest')])
general_reader = open('translate/keras_loader.txt', 'rb')
general_code = general_reader.read()
with open(filename, 'rb') as myfile:
keras_code = myfile.read()
try:
return graph_from_external_file(keras_code, general_code)
except Exception as err:
return {'error_class': '', 'line_number': 1,
'detail': str(err)}
if keras_ext in filename:
return graph_from_model_file(filename)
else:
with open(filename, 'rb') as myfile:
keras_code = myfile.read()
try:
return graph_from_external_file(keras_code, general_code)
except Exception as err:
return {'error_class': '', 'line_number': 1,
'detail': str(err)}


def graph_from_external_file(keras_code, general_code):
Expand Down Expand Up @@ -61,6 +66,20 @@ def graph_from_external_file(keras_code, general_code):
graph.resolve_input_names()
return graph

def graph_from_model_file(keras_model_file):
model_keras = keras.models.load_model(keras_model_file)
model_json = model_keras.to_json()
layers_extracted = model_json['config']['layers']
graph = Graph()
previous_node = ''
for index, json_layer in enumerate(layers_extracted):
if len(layers_extracted) > len(model_keras.layers):
index = index - 1
if index >= 0:
previous_node = add_layer_type(json_layer, model_keras.layers[index], graph,
previous_node)
graph.resolve_input_names()
return graph

def add_layer_type(layer_json, model_layer, graph, previous_node):
"""Add a Layer. Layers are identified by name and equipped using the spec.
Expand Down