In [1]:
import argparse
import os
import time
from typing import Optional

import cv2
import numpy as np
from PIL import Image
import torch
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision
# Report the library versions and whether a CUDA device is visible to this kernel.
_environment_report = (
    ("PyTorch version:", torch.__version__),
    ("Torchvision version:", torchvision.__version__),
    ("CUDA is available:", torch.cuda.is_available()),
)
for _label, _value in _environment_report:
    print(_label, _value)

PyTorch version: 2.0.0+cu117
Torchvision version: 0.15.1+cu117
CUDA is available: True


In [2]:
from segment_anything import sam_model_registry, SamPredictor

# Point torch's cache directory (presumably used for downloaded weights) at a
# project location, but only when the environment has not already chosen one,
# so a user/CI-provided TORCH_HOME is respected instead of silently overwritten.
os.environ.setdefault('TORCH_HOME', '/sunjinsheng/home_torch')

class SAM:
    """Segment Anything model plus its predictor, with per-call timing.

    Each public method forwards to the wrapped ``SamPredictor`` and also returns
    how long that call took, in milliseconds. On a CUDA device the time is
    measured with CUDA events; on any other device (or when CUDA is unavailable)
    a host-side ``time.perf_counter`` measurement is used instead, so the wrapper
    also works with ``device='cpu'``.
    """

    def __init__(self, model_type, sam_ckpt, device='cuda') -> None:
        """Build the model from a checkpoint and wrap it in a predictor.

        Args:
            model_type: key into ``sam_model_registry`` (e.g. ``"vit_h"``).
            sam_ckpt: checkpoint path passed to the registry's model builder.
            device: device the model is moved to; also selects the timing method.
        """
        # Use the checkpoint argument, not a notebook-level global of a similar name.
        self.sam = sam_model_registry[model_type](checkpoint=sam_ckpt)
        self.sam.to(device=device)
        self.device = device

        self.predictor = SamPredictor(self.sam)

    def _uses_cuda_timer(self) -> bool:
        """Return True when calls should be timed with CUDA events."""
        # Instances built without __init__ fall back to the original CUDA default.
        device = getattr(self, 'device', 'cuda')
        return torch.device(device).type == 'cuda' and torch.cuda.is_available()

    def _timed(self, fn, *args, **kwargs):
        """Call ``fn(*args, **kwargs)`` and return ``(result, elapsed_ms)``."""
        if self._uses_cuda_timer():
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            result = fn(*args, **kwargs)
            end.record()
            # elapsed_time() requires both events to have completed; without this
            # wait it can fail with a "device not ready" error.
            end.synchronize()
            return result, start.elapsed_time(end)
        started = time.perf_counter()
        result = fn(*args, **kwargs)
        return result, (time.perf_counter() - started) * 1000.0

    def set_image(self, image):
        """Run ``SamPredictor.set_image`` on ``image`` and return the elapsed ms.

        ``image`` is forwarded unchanged to the predictor.
        """
        _, elapsed_ms = self._timed(self.predictor.set_image, image)
        return elapsed_ms

    def get_masks(
        self, 
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ):
        """Run prompted mask prediction on the image last passed to ``set_image``.

        All arguments are forwarded by keyword to ``SamPredictor.predict``.

        Returns:
            ``(masks, scores, logits, elapsed_ms)``: the predictor's three outputs
            followed by the duration of the call in milliseconds.
        """
        (masks, scores, logits), elapsed_ms = self._timed(
            self.predictor.predict,
            point_coords=point_coords,
            point_labels=point_labels,
            box=box,
            mask_input=mask_input,
            multimask_output=multimask_output,
            return_logits=return_logits,
        )
        return masks, scores, logits, elapsed_ms

In [3]:
# Model configuration. The checkpoint path may be overridden through the
# SAM_CHECKPOINT environment variable; the default is the original local path.
sam_checkpoint = os.environ.get("SAM_CHECKPOINT", "/sun/home_models/sam/sam_vit_h_4b8939.pth")
model_type = "vit_h"  # key into sam_model_registry; presumably must match the checkpoint
device = "cuda"
sam = SAM(model_type, sam_checkpoint, device=device)

In [4]:
# Load the example image and compute its embedding; the displayed value is the
# time SAM.set_image reports for the call.
image_path = 'notebooks/images/truck.jpg'
image_bgr = cv2.imread(image_path)
# cv2.imread returns None (instead of raising) when the file is missing or unreadable.
if image_bgr is None:
    raise FileNotFoundError(f"could not read image: {image_path}")
# OpenCV loads color images as BGR; convert to RGB before handing it to SAM.
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
sam.set_image(image)

3520.907470703125

In [5]:
# One point prompt at [500, 375] with label 1 (presumably a foreground click —
# confirm against SAM's label convention).
input_point = np.array([[500, 375]])
input_label = np.array([1])

# Go through the SAM wrapper rather than sam.predictor directly, so this call is
# made (and timed) exactly the way the server's interaction requests are.
masks, scores, logits, time_used = sam.get_masks(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

In [6]:
from flask import request, Flask
import os
import cv2
import time
import base64
import numpy as np
# Flask application exposing the SAM wrapper over HTTP (required by the route below).
app = Flask(__name__)

# Values of the integer form field 'flag' selecting the request type.
TYPE_IMAGE = 1
TYPE_INTERACTION = 2


def _decode_b64_array(field, dtype):
    """Return base64 form field ``field`` as a writable flat array of ``dtype``.

    Returns None when the field is absent. Raises ValueError (binascii.Error)
    on malformed base64, or when the byte count is not a multiple of the dtype size.
    """
    encoded = request.form.get(field)
    if encoded is None:
        return None
    raw = base64.b64decode(encoded.encode('utf-8'))
    # np.frombuffer gives a read-only view; copy it so torch.as_tensor inside the
    # predictor does not warn about a non-writable array.
    return np.frombuffer(raw, dtype=dtype).copy()


def _encode_b64_array(array):
    """Return the raw bytes of ``array`` as a base64 text string."""
    return base64.b64encode(array.tobytes()).decode('utf-8')


def _handle_image():
    """Decode the uploaded image, hand it to SAM and report the time taken."""
    try:
        buffer = _decode_b64_array('img', np.uint8)
    except ValueError:
        buffer = None
    if buffer is None or buffer.size == 0:
        # Missing/malformed payload: report failure (flag=0) instead of a 500.
        return str(dict(flag=0))
    image_bgr = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if image_bgr is None:
        return str(dict(flag=0))
    # cv2.imdecode returns BGR; convert to RGB to match the notebook's own
    # BGR->RGB conversion before sam.set_image.
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    time_used = sam.set_image(image_rgb)  # image-processing entry point
    return str(dict(flag=1, time_used=time_used))


def _handle_interaction():
    """Run point-prompted prediction on the current image and return encoded outputs."""
    try:
        coords = _decode_b64_array('point_coords', np.int32)
        labels = _decode_b64_array('point_labels', np.bool_)
        multimask_raw = request.form.get('multimask_output')
        if coords is None or labels is None or multimask_raw is None:
            return str(dict(flag=0))
        point_coords = coords.reshape((-1, 2))
        point_labels = labels.reshape([-1])
        multimask_output = bool(int(multimask_raw))
    except ValueError:
        return str(dict(flag=0))
    if len(point_coords) != len(point_labels):
        # Every point needs exactly one label.
        return str(dict(flag=0))

    masks, scores, logits, time_used = sam.get_masks(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=multimask_output,
    )
    # Arrays are sent as base64 of their raw bytes; the client must know the
    # dtypes (bool masks and float32 scores/logits per the original comments).
    return str(dict(
        flag=2,
        masks=_encode_b64_array(masks),
        scores=_encode_b64_array(scores),
        logits=_encode_b64_array(logits),
        time_used=time_used,
    ))


@app.route("/", methods=['POST'])
def get_frame():
    """HTTP entry point; Flask calls this for every POST to "/".

    Dispatches on the integer form field 'flag':
      * TYPE_IMAGE: field 'img' is a base64-encoded image file; it is decoded and
        set on SAM. Responds with str(dict(flag=1, time_used=...)).
      * TYPE_INTERACTION: fields 'point_coords' (base64 int32 pairs),
        'point_labels' (base64 bools) and 'multimask_output' ("0"/"1").
        Responds with str(dict(flag=2, masks=..., scores=..., logits=..., time_used=...)).
    Invalid payloads respond with str(dict(flag=0)); a missing or unknown flag
    responds with '0'. Responses are Python-literal strings, not JSON.
    """
    try:
        flag = int(request.form.get('flag'))
    except (TypeError, ValueError):
        return '0'
    if flag == TYPE_IMAGE:
        return _handle_image()
    if flag == TYPE_INTERACTION:
        return _handle_interaction()
    return '0'

In [None]:
if __name__ == "__main__":
    # Bind address and port; override with SAM_SERVER_HOST / SAM_SERVER_PORT
    # (e.g. when the container assigns the port). app.run serves until it is
    # interrupted, so this cell keeps the kernel busy.
    host = os.environ.get('SAM_SERVER_HOST', '0.0.0.0')
    port = int(os.environ.get('SAM_SERVER_PORT', '11083'))
    app.run(host, port=port)

 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:11083
 * Running on http://172.17.0.3:11083
[33mPress CTRL+C to quit[0m
172.17.0.1 - - [26/Apr/2023 18:29:35] "POST / HTTP/1.1" 200 -
  labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
172.17.0.1 - - [26/Apr/2023 18:29:40] "POST / HTTP/1.1" 200 -
172.17.0.1 - - [26/Apr/2023 18:29:43] "POST / HTTP/1.1" 200 -
172.17.0.1 - - [26/Apr/2023 18:30:26] "POST / HTTP/1.1" 200 -
172.17.0.1 - - [26/Apr/2023 18:30:29] "POST / HTTP/1.1" 200 -
172.17.0.1 - - [26/Apr/2023 18:30:31] "POST / HTTP/1.1" 200 -
