Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: api server #3

Merged
merged 6 commits into from
Feb 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
*.pyc
*.pkl
.vscode/

assets

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://github.com/olivierlacan/keep-a

- Panorama pre-process from [HorizonNet](https://github.com/sunset1995/HorizonNet). (https://github.com/yushiang-demo/DuLa-Net/pull/1)
- Update packages and add `Dockerfile`, `requirements` to freeze versions. (https://github.com/yushiang-demo/DuLa-Net/pull/2)
- Implement api server. (https://github.com/yushiang-demo/DuLa-Net/pull/3)

### Changed

Expand Down
9 changes: 2 additions & 7 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,5 @@ COPY . /app

RUN pip install --no-cache-dir --upgrade pip

# https://pytorch.org/
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

#RUN pip install matplotlib scikit-image opencv-python pylsd-nova==1.2.0
RUN pip install --no-cache-dir -r requirements.txt

ENTRYPOINT ["python"]
#RUN pip install torch torchvision torchaudio matplotlib scikit-image opencv-python pylsd-nova==1.2.0 Celery redis Flask Flask-RESTx
RUN pip install --no-cache-dir -r requirements.txt
6 changes: 3 additions & 3 deletions Model/dulanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def forward_from_feats(self, x, feats):
return out

class DuLaNet(nn.Module):
def __init__(self, backbone):
def __init__(self, backbone, gpu=True):
super(DuLaNet, self).__init__()

self.model_equi = DulaNet_Branch(backbone)
Expand All @@ -74,10 +74,10 @@ def __init__(self, backbone):
nn.Linear(64, 1)
)

self.e2p = E2P(cf.pano_size, cf.fp_size, cf.fp_fov)
self.e2p = E2P(cf.pano_size, cf.fp_size, cf.fp_fov, gpu=gpu)

fuse_dim = [int((cf.pano_size[0]/32)*2**i) for i in range(6)]
self.e2ps_f = [E2P((n, n*2), n, cf.fp_fov) for n in fuse_dim]
self.e2ps_f = [E2P((n, n*2), n, cf.fp_fov, gpu=gpu) for n in fuse_dim]

def forward(self, pano_view):

Expand Down
Binary file modified README.md
Binary file not shown.
Empty file added api/__init__.py
Empty file.
16 changes: 16 additions & 0 deletions api/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from flask import Flask, send_from_directory
from flask_restx import Api

app = Flask(__name__)
app.config['RESTX_MASK_SWAGGER'] = False
api = Api(app, version='1.0', title='DuLa-Net APIs', prefix='/api', base_url='/api')

import traceback
# Global error handler for all other exceptions
@api.errorhandler(Exception)
def handle_unexpected_error(error):
# Log the error for debugging purposes
app.logger.error('Unhandled Exception: %s', traceback.format_exc())

# Return a generic error response
return {'message': 'An unexpected error occurred'}, 500
1 change: 1 addition & 0 deletions api/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
STATIC_FOLDER = 'assets/storage'
13 changes: 13 additions & 0 deletions api/models/Request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from flask_restx import reqparse

class Request():
def __init__(self):
self.parser = reqparse.RequestParser(bundle_errors=True)

def addFile(self, name, required):
self.parser.add_argument(
name,
required=required,
type=reqparse.FileStorage,
location='files'
)
18 changes: 18 additions & 0 deletions api/models/Task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from flask_restx import fields
from api.app import api

Layout = api.model('Layout',{
'data': fields.Raw,
})

Images = api.model('Images',{
'preview': fields.String,
'origin': fields.String,
'aligned': fields.String,
})

Task = api.model('Task', {
'uuid': fields.String,
'images': fields.Nested(Images),
'layout': fields.Nested(Layout),
})
7 changes: 7 additions & 0 deletions api/models/Tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from flask_restx import fields
from api.app import api
from .Task import Task

Tasks = api.model('Tasks', {
'tasks': fields.List(fields.Nested(Task))
})
3 changes: 3 additions & 0 deletions api/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .Request import *
from .Task import *
from .Tasks import *
40 changes: 40 additions & 0 deletions api/resources/Task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
import uuid
from tasks import inference_from_file
from api.constant import STATIC_FOLDER

from api.app import api

from flask import request
from flask_restx import Resource
from api.models import Request, Task

input = Request()
input.addFile(name='file',required=True)


class Task(Resource):
@api.expect(input.parser)
@api.marshal_with(Task)
def post(self):
"""Create a task"""
args = input.parser.parse_args()
file = args.file
if file:

id = str(uuid.uuid4())
output = os.path.join(STATIC_FOLDER, id)
os.makedirs(output)
inference_from_file(file, output)

output = {
'uuid': id,
'images':{
'origin': f"{request.referrer}files/storage/{id}/image.jpg",
'preview': f"{request.referrer}files/storage/{id}/vis.jpg",
'aligned': f"{request.referrer}files/storage/{id}/raw.jpg",
}
}
return output, 200
else:
return {}, 400
14 changes: 14 additions & 0 deletions api/resources/Tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import os
from flask_restx import Resource

from api.constant import STATIC_FOLDER
from api.app import api
from api.models import Tasks

class Tasks(Resource):
@api.marshal_with(Tasks)
def get(self):
"""List all task_id"""
task_ids = os.listdir(STATIC_FOLDER)
return { "tasks": task_ids }, 200

2 changes: 2 additions & 0 deletions api/resources/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .Task import Task
from .Tasks import Tasks
32 changes: 32 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
version: '3'

services:
nginx:
image: nginx:latest
ports:
- "80:80"
volumes:
- ./assets:/usr/share/nginx/html:ro
- ./nginx/nginx.conf:/etc/nginx/conf.d/default.conf:ro

redis:
image: redis:latest

server:
build:
dockerfile: Dockerfile
command: python server.py
volumes:
- ./:/app
depends_on:
- redis


worker:
build:
dockerfile: Dockerfile
command: celery -A tasks worker --loglevel=info
volumes:
- ./:/app
depends_on:
- redis
19 changes: 19 additions & 0 deletions nginx/nginx.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# nginx.conf

server {
listen 80;
server_name localhost;

location /files {
alias /usr/share/nginx/html;
index index.html;
}

location / {
proxy_pass http://server:5000;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
}
Binary file modified requirements.txt
Binary file not shown.
15 changes: 15 additions & 0 deletions server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from flask import send_file
from flask_restx import Resource

from api.constant import STATIC_FOLDER
from api.app import app, api
from api.resources import Task, Tasks

admin = api.namespace('admin', description='Inspect system info.')
admin.add_resource(Tasks, '/tasks')

task = api.namespace('task', description='Run a DuLa-Net task.')
task.add_resource(Task,'/')

if __name__ == '__main__':
app.run(host='0.0.0.0', debug=True)
104 changes: 104 additions & 0 deletions tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import io
import os
import sys
import argparse

import numpy as np
from PIL import Image

import torch
from torch.autograd import Variable
from torchvision import transforms

import Layout
import Utils

import config as cf
from Model import DuLaNet, E2P

import postproc

from Preprocess.pano_lsd_align import panoEdgeDetection, rotatePanorama

import base64

gpu = False
device = torch.device('cuda' if torch.cuda.is_available() and gpu else 'cpu')

def preprocess(img, q_error=0.7, refine_iter=3):
img_ori = np.array(img.resize((1024, 512), Image.Resampling.BICUBIC))[..., :3]

# VP detection and line segment extraction
_, vp, _, _, panoEdge, _, _ = panoEdgeDetection(img_ori,
qError=q_error,
refineIter=refine_iter)
panoEdge = (panoEdge > 0)

# Align images with VP
i_img = rotatePanorama(img_ori / 255.0, vp[2::-1])

return Image.fromarray(np.uint8(i_img * 255.0))

def predict(model, img):
model.eval()

trans = transforms.Compose([
transforms.Resize((cf.pano_size)),
transforms.ToTensor()
])
color = torch.unsqueeze(trans(img), 0).to(device)

[fp, fc, h] = model(color)

e2p = E2P(cf.pano_size, cf.fp_size, cf.fp_fov, gpu=gpu)
[fc_up, fc_down] = e2p(fc)

[fp, fc_up, fc_down, h] = Utils.var2np([fp, fc_up, fc_down, h])
fp_pts, fp_pred = postproc.run(fp, fc_up, fc_down, h)

# Visualization
scene_pred = Layout.pts2scene(fp_pts, h)
edge = Layout.genLayoutEdgeMap(scene_pred, [512 , 1024, 3], dilat=2, blur=0)

img = img.resize((1024,512))
img = np.array(img, float) / 255
vis = img * 0.5 + edge * 0.5

vis = Image.fromarray(np.uint8(vis* 255))
return vis, scene_pred

from celery import Celery

BROKER_URL = 'redis://redis:6379/0'
BACKEND_URL = 'redis://redis:6379/0'
app = Celery('tasks', broker=BROKER_URL, backend=BACKEND_URL)

@app.task
def inference(image_data_base64, output, seed=224, backbone='resnet18',ckpt = './Model/ckpt/res18_realtor.pkl'):

# initialize DuLa-net
np.random.seed(seed)
torch.manual_seed(seed)

model = DuLaNet(backbone,gpu=gpu).to(device)

#model.load_state_dict(torch.load(args.ckpt))
model.load_state_dict(torch.load(ckpt, map_location=str(device)))

image_data = base64.b64decode(image_data_base64)
pil_image = Image.open(io.BytesIO(image_data))
img = preprocess(pil_image)
vis, scene_pred = predict(model, img)

pil_image.save(os.path.join(output,"raw.jpg"))
img.save(os.path.join(output,"image.jpg"))
vis.save(os.path.join(output,"vis.jpg"))
Layout.saveSceneAsJson(os.path.join(output,"layout.json"), scene_pred)
return output

def inference_from_file(file, output):
pil_image = Image.open(io.BytesIO(file.read()))
img_bytes = io.BytesIO()
pil_image.save(img_bytes, format='JPEG')
img_base64 = base64.b64encode(img_bytes.getvalue()).decode()
inference.delay(img_base64, output)