forked from jantic/DeOldify
-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
88 lines (62 loc) · 2.35 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# import the necessary packages
import os
import sys
import requests
import ssl
from flask import Flask
from flask import request
from flask import jsonify
from flask import send_file
from uuid import uuid4
from os import path
import torch
import fastai
from fasterai.visualize import *
from pathlib import Path
import traceback
# Restrict the process to GPU 0. The CUDA runtime reads CUDA_VISIBLE_DEVICES
# when it initializes, so this has to happen before any code that may touch
# the GPU -- in particular before the colorizers below are constructed
# (presumably they load their models onto the device; verify in fasterai).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Let cuDNN benchmark convolution algorithms and pick one per input shape.
torch.backends.cudnn.benchmark = True

# Built once at import time so every request reuses the same colorizer objects.
image_colorizer = get_image_colorizer(artistic=True)
video_colorizer = get_video_colorizer()

app = Flask(__name__)
# define a predict function as an endpoint
@app.route("/process_image", methods=["POST"])
def process_image():
    """Colorize the image at a caller-supplied URL and send back the result.

    The request body must be JSON with two keys:

    * ``source_url`` -- passed to the colorizer as the image URL.
    * ``render_factor`` -- converted with ``int()`` and passed through.

    Returns:
        On success, the file ``result_images/<uuid>.png`` sent with the
        ``image/jpeg`` mimetype. On any failure, the traceback is printed and
        ``({'message': 'input error'}, 400)`` is returned.
    """
    upload_directory = 'upload'
    result_directory = 'result_images'

    # The name is generated before the ``try`` so the cleanup in ``finally``
    # always has paths to work with, even when request parsing fails.
    random_filename = str(uuid4()) + '.png'
    upload_path = os.path.join(upload_directory, random_filename)
    # NOTE(review): the colorizer is expected to write its output under
    # ``result_images`` using the same file name -- confirm in fasterai.
    result_path = os.path.join(result_directory, random_filename)

    try:
        source_url = request.json["source_url"]
        render_factor = int(request.json["render_factor"])

        # exist_ok avoids the check-then-create race between request threads.
        os.makedirs(upload_directory, exist_ok=True)

        image_colorizer.plot_transformed_image_from_url(
            url=source_url,
            path=upload_path,
            figsize=(20, 20),
            render_factor=render_factor,
            display_render_factor=True,
            compare=False,
        )

        return send_file(result_path, mimetype='image/jpeg')
    except Exception:
        traceback.print_exc()
        return {'message': 'input error'}, 400
    finally:
        # Runs after the response object has been built. Either file may be
        # missing when processing failed early, which is not an error here.
        for leftover in (result_path, upload_path):
            try:
                os.remove(leftover)
            except FileNotFoundError:
                pass
@app.route("/process_video", methods=["POST"])
def process_video():
    """Colorize the video at a caller-supplied URL and send back the result.

    The request body must be JSON with two keys:

    * ``source_url`` -- passed to the video colorizer as the source URL.
    * ``render_factor`` -- converted with ``int()`` and passed through.

    Returns:
        On success, the file ``video/result/<uuid>.mp4`` sent with the
        ``application/octet-stream`` mimetype. On any failure, the traceback
        is printed and ``({'message': 'input error'}, 400)`` is returned.
    """
    upload_directory = 'upload'
    result_directory = 'video/result/'

    # The name is generated before the ``try`` so the cleanup in ``finally``
    # always has a path to work with, even when request parsing fails.
    random_filename = str(uuid4()) + '.mp4'
    # NOTE(review): the colorizer is expected to write its output under
    # ``video/result/`` using the same file name -- confirm in fasterai.
    result_path = os.path.join(result_directory, random_filename)

    try:
        source_url = request.json["source_url"]
        render_factor = int(request.json["render_factor"])

        # exist_ok avoids the check-then-create race between request threads.
        os.makedirs(upload_directory, exist_ok=True)

        video_colorizer.colorize_from_url(source_url, random_filename, render_factor)

        return send_file(result_path, mimetype='application/octet-stream')
    except Exception:
        traceback.print_exc()
        return {'message': 'input error'}, 400
    finally:
        # The result file does not exist when processing failed early; that
        # is not an error here.
        try:
            os.remove(result_path)
        except FileNotFoundError:
            pass
if __name__ == '__main__':
    # Script entry point: serve on every interface, port 5000, with Flask's
    # threaded request handling enabled.
    host, port = '0.0.0.0', 5000
    app.run(port=port, host=host, threaded=True)