-
Notifications
You must be signed in to change notification settings - Fork 510
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add example for inference server #994
Changes from 8 commits
da3dcb5
6abbf14
d5d927d
641e8fa
5a5e532
4e83f6d
effd079
38631c7
5e68724
4ff088f
b94c05b
fed2b8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
import sys | ||
|
||
INFERENCE_RESULT_MARKER = "INFERENCE RESULT:" | ||
|
||
|
||
def run_inference(image_path): | ||
# Perform some computation on the image located at image_path | ||
|
||
# Instead of returning the result, | ||
# print it to stdout so that the server can retrieve the result from the logs | ||
print( | ||
f"{INFERENCE_RESULT_MARKER}Ran inference on the image at '{image_path}' with size {os.path.getsize(image_path)}B." | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_inference(sys.argv[1]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
flask |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
"""Flask inference server. | ||
|
||
Implements a Flask server that handles inference requests for some input via a HTTP handler. | ||
To run the server, run the following command from the root directory: | ||
`FLASK_APP=examples/inference_server/server.py flask run` | ||
""" | ||
|
||
import os | ||
import random | ||
import re | ||
import string | ||
import subprocess | ||
|
||
from flask import Flask, request, abort | ||
from werkzeug.utils import secure_filename | ||
iojw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
import sky | ||
iojw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
from inference import INFERENCE_RESULT_MARKER | ||
iojw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
LOCAL_UPLOAD_FOLDER = '/Users/isaac/Dropbox/Berkeley/Sky/sky/examples/inference_server/uploads' | ||
REMOTE_UPLOAD_FOLDER = '/remote/path/to/folder' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we make this example easier to run? It would be nice if users can run this example without any modification to the code. |
||
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'} | ||
|
||
app = Flask(__name__) | ||
app.config['UPLOAD_FOLDER'] = LOCAL_UPLOAD_FOLDER | ||
|
||
|
||
def run_output(cmd: str, **kwargs) -> str: | ||
completed_process = subprocess.run(cmd, | ||
stdout=subprocess.PIPE, | ||
shell=True, | ||
check=True, | ||
**kwargs) | ||
return completed_process.stdout.decode("ascii").strip() | ||
|
||
|
||
def allowed_file(filename): | ||
return '.' in filename and \ | ||
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS | ||
iojw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
@app.route("/", methods=["GET", "POST"]) | ||
def run_inference(): | ||
if request.method == 'POST': | ||
image = request.files['file'] | ||
if not image or not allowed_file(image.filename): | ||
abort(400, "Invalid image upload") | ||
|
||
filename = secure_filename(image.filename) | ||
local_image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename) | ||
remote_image_path = os.path.join(REMOTE_UPLOAD_FOLDER, filename) | ||
|
||
image.save(local_image_path) | ||
|
||
with sky.Dag() as dag: | ||
workdir = os.path.dirname(os.path.abspath(__file__)) | ||
task_name = "inference_task" | ||
iojw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
setup = 'pip3 install --upgrade pip' | ||
run_fn = f"python inference.py {remote_image_path}" | ||
|
||
task = sky.Task(name=task_name, | ||
setup=setup, | ||
workdir=workdir, | ||
run=run_fn) | ||
|
||
resources = sky.Resources(cloud=sky.Azure()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any reason for using Azure here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I only had access to Azure at the time, so I used it for easier testing locally. I recently gained access to AWS as well - would you suggest to leave it as the default in this case? |
||
task.set_resources(resources) | ||
task.set_file_mounts({ | ||
# Copy model weights to the cluster | ||
# Instead of local path, can also specify a cloud object store URI | ||
'/remote/path/to/model-weights': 'local/path/to/model-weights', | ||
# Copy image to the cluster | ||
remote_image_path: local_image_path, | ||
Comment on lines
+70
to
+74
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice point! Our storage APIs are indeed useful in sending model weights and inputs. |
||
}) | ||
|
||
cluster_name = f"inference-cluster-{''.join(random.choice(string.ascii_lowercase) for _ in range(10))}" | ||
sky.launch( | ||
dag, | ||
cluster_name=cluster_name, | ||
detach_run=True, | ||
) | ||
|
||
cmd_output = run_output(f"sky logs {cluster_name}") | ||
inference_result = re.findall(f'{INFERENCE_RESULT_MARKER}((?:[^\n])+)', | ||
cmd_output) | ||
|
||
# Down the cluster in the background | ||
subprocess.Popen(f"sky down -y {cluster_name}", shell=True) | ||
|
||
return {"result": inference_result} | ||
elif request.method == 'GET': | ||
return ''' | ||
<title>Upload Image</title> | ||
<h1>Upload Image</h1> | ||
<form method=post enctype=multipart/form-data> | ||
<input type=file name=file> | ||
<input type=submit value=Upload> | ||
</form> | ||
''' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we have a more realistic example in addition to (or instead of) this one? As pointed out in
server.py
, we need to send model weights and inputs to the cluster and get the prediction outputs from the cluster. Why don't we show such a complete example here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes sense! I went with a simpler example because I don't have much experience with ML and inference - what do you think would be a good library and model to implement here?