In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
import torch
import torchvision
import pytorch_lightning as pl

import os
import PIL.Image as Image
import base64
import io
import json
import requests
from urllib.parse import urlencode
from k12libs.utils.nb_easy import k12ai_get_top_dir, RACEURL
from pprint import pprint

np.__version__, torch.__version__, torchvision.__version__, pl.__version__

('1.19.4', '1.8.0.dev20210103+cu101', '0.9.0.dev20210103+cu101', '1.2.0')

In [2]:
# Configuration consumed by the request cells further down in this notebook.
# `model` is interpolated into the server-side class path
# "raceai.models.backbone.<model>".
model = 'Resnet18'
# Forwarded as the trainer's `default_root_dir`.
root_dir = '/raceai/data/tmp/pl_cleaner_robot_resnet18'
data_root = '/raceai/data/datasets/cleaner_robot/'
# Forwarded as the dataset's `input_size`, `mean` and `std` params.
input_size = 224
mean = [
    0.7566,
    0.7323,
    0.7016,
]
std = [
    0.1805,
    0.194,
    0.2195,
]
num_classes = 3  # 0: ring, 1: data cable / power cord, 2: waste paper / cigarette butt
# Plain string literal: the original used an f-string without any placeholder.
resume_weights = "/raceai/data/ckpts/cleaner_robot/pl_resnet18_acc90.pth"

## RestAPI

请求方式: `POST http://116.85.5.40:9119/raceai/framework/inference`

### 图片为URL方式输入(修改data_source)

**输入**:

```json
{
    "task": "cls.inference.pl", 
    "cfg": {
        "data": {
            "class_name": "raceai.data.process.PathListDataLoader", 
            "params": {
                "data_source": [
                    "https://img14.360buyimg.com/n0/jfs/t1/115543/8/11330/109373/5efc3fcdE2dfb9242/39c1e05e0371cde9.jpg", 
                    "https://jmy-pic.wejianzhan.com/0/pic/e621772e901d6d2f0ff072582e430173.jpg@f_webp,q_90", 
                    "http://5b0988e595225.cdn.sohucs.com/images/20190201/2976462bd67042d5b1b14c1d0f4a6717.jpeg", 
                    "/raceai/data/datasets/cleaner_robot/imgs/ring/ring_5.jpg", 
                    "/raceai/data/datasets/cleaner_robot/imgs/cable/cable_5.jpg"
                ], 
                "dataset": {
                    "class_name": "raceai.data.PredictListImageDataset", 
                    "params": {
                        "input_size": 224, 
                        "mean": [
                            0.7566, 
                            0.7323, 
                            0.7016
                        ], 
                        "std": [
                            0.1805, 
                            0.194, 
                            0.2195
                        ]
                    }
                }, 
                "sample": {
                    "batch_size": 32, 
                    "num_workers": 4
                }
            }
        }, 
        "model": {
            "class_name": "raceai.models.backbone.Resnet18", 
            "params": {
                "device": "gpu", 
                "num_classes": 3, 
                "weights": false
            }
        }, 
        "trainer": {
            "default_root_dir": "/raceai/data/tmp/pl_cleaner_robot_resnet18", 
            "gpus": 1, 
            "resume_from_checkpoint": "/raceai/data/ckpts/cleaner_robot/pl_resnet18_acc90.pth"
        }
    }
}
```

**输出**

只需要解析出每个结果中 `probs_sorted.indices[0]` 的值(即概率最高的类别下标), 即可判断图片所属分类: 0 = 戒指, 1 = 数据线/电源线, 2 = 废纸/烟头

```json
{
    "errno": 0, 
    "result": [
        {
            "image_id": "-1", 
            "image_path": "/tmp/tmpadvoxy2h/39c1e05e0371cde9.jpg", 
            "probs": [
                0.0023420988582074642,  // 该图片是戒指的概率
                0.9976153373718262,     // 该图片是数据线/电源线的概率
                0.00004260418063495308  // 该图片是废纸/烟头的概率
            ], 
            "probs_sorted": {
                "indices": [            // 概率排序后对应的下标
                    1, 
                    0, 
                    2
                ], 
                "values": [             // 概率排序
                    0.9976153373718262, 
                    0.0023420988582074642, 
                    0.00004260418063495308
                ]
            }
        }, 
        {
            "image_id": "-1", 
            "image_path": "/tmp/tmpadvoxy2h/e621772e901d6d2f0ff072582e430173.jpg%40f_webp%2Cq_90", 
            "probs": [
                0.9999988079071045, 
                0.0000012288410289329477, 
                4.06742373115776e-8
            ], 
            "probs_sorted": {
                "indices": [
                    0, 
                    1, 
                    2
                ], 
                "values": [
                    0.9999988079071045, 
                    0.0000012288410289329477, 
                    4.06742373115776e-8
                ]
            }
        }, 
        {
            "image_id": "-1", 
            "image_path": "/tmp/tmpadvoxy2h/2976462bd67042d5b1b14c1d0f4a6717.jpeg", 
            "probs": [
                5.383871126696249e-9, 
                3.1869282679508615e-7, 
                0.9999996423721313
            ], 
            "probs_sorted": {
                "indices": [
                    2, 
                    1, 
                    0
                ], 
                "values": [
                    0.9999996423721313, 
                    3.1869282679508615e-7, 
                    5.383871126696249e-9
                ]
            }
        }, 
        {
            "image_id": "-1", 
            "image_path": "/raceai/data/datasets/cleaner_robot/imgs/ring/ring_5.jpg", 
            "probs": [
                0.9998032450675964, 
                0.00008256352884927765, 
                0.00011420287773944438
            ], 
            "probs_sorted": {
                "indices": [
                    0, 
                    2, 
                    1
                ], 
                "values": [
                    0.9998032450675964, 
                    0.00011420287773944438, 
                    0.00008256352884927765
                ]
            }
        }, 
        {
            "image_id": "-1", 
            "image_path": "/raceai/data/datasets/cleaner_robot/imgs/cable/cable_5.jpg", 
            "probs": [
                0.00011358539632055908, 
                0.9997881054878235, 
                0.00009832031355472282
            ], 
            "probs_sorted": {
                "indices": [
                    1, 
                    0, 
                    2
                ], 
                "values": [
                    0.9997881054878235, 
                    0.00011358539632055908, 
                    0.00009832031355472282
                ]
            }
        }
    ]
}
```

### 样例

In [3]:
# Classification request whose images are given as a list of URLs and
# filesystem paths (handled by the PathListDataLoader class named below).
reqdata = {
    "task": "cls.inference.pl",
    "cfg": {
        "data": {
            "class_name": "raceai.data.process.PathListDataLoader",
            "params": {
                "data_source": [
                    "https://img14.360buyimg.com/n0/jfs/t1/115543/8/11330/109373/5efc3fcdE2dfb9242/39c1e05e0371cde9.jpg",
                    "https://jmy-pic.wejianzhan.com/0/pic/e621772e901d6d2f0ff072582e430173.jpg@f_webp,q_90",
                    "http://5b0988e595225.cdn.sohucs.com/images/20190201/2976462bd67042d5b1b14c1d0f4a6717.jpeg",
                    "/raceai/data/datasets/cleaner_robot/imgs/ring/ring_5.jpg",
                    "/raceai/data/datasets/cleaner_robot/imgs/cable/cable_5.jpg",
                ],
                "dataset": {
                    "class_name": "raceai.data.PredictListImageDataset",
                    "params": {
                        "input_size": input_size,
                        "mean": mean,
                        "std": std,
                    },
                },
                "sample": {
                    "batch_size": 32,
                    "num_workers": 4,
                },
            },
        },
        "model": {
            "class_name": f"raceai.models.backbone.{model}",
            "params": {
                "device": 'gpu',
                "num_classes": num_classes,
                "weights": False,
            },
        },
        "trainer": {
            "default_root_dir": root_dir,
            "gpus": 1,
            "resume_from_checkpoint": resume_weights,
        },
    },
}

uri = f'{RACEURL}/raceai/framework/inference'

# (connect, read) timeout in seconds: without one, requests waits forever and
# an unreachable or stuck server would hang the kernel. The read timeout is
# deliberately generous; tune it if inference legitimately takes longer.
response = requests.post(url=uri, json=reqdata, timeout=(10, 300))
try:
    resdata = response.json()
except ValueError as err:
    # The body was not JSON (e.g. a proxy/HTML error page): surface the HTTP
    # status and the start of the body instead of a bare decode error.
    raise RuntimeError(
        f'non-JSON response from {uri} (HTTP {response.status_code}): '
        f'{response.text[:500]}'
    ) from err

# On failure, print only the server traceback when one is provided. `result`
# is checked to be a dict first so a string error message is not subscripted
# with 'traceback'. In every other case print the whole response.
result = resdata.get('result')
if resdata.get('errno') != 0 and isinstance(result, dict) and 'traceback' in result:
    print(result['traceback'])
else:
    print(resdata)

{'errno': 0, 'result': [{'image_id': '-1', 'image_path': '/tmp/tmp1e8xa9i0/39c1e05e0371cde9.jpg', 'probs': [0.0023420988582074642, 0.9976153373718262, 4.260418063495308e-05], 'probs_sorted': {'indices': [1, 0, 2], 'values': [0.9976153373718262, 0.0023420988582074642, 4.260418063495308e-05]}}, {'image_id': '-1', 'image_path': '/tmp/tmp1e8xa9i0/e621772e901d6d2f0ff072582e430173.jpg%40f_webp%2Cq_90', 'probs': [0.9999988079071045, 1.2288410289329477e-06, 4.06742373115776e-08], 'probs_sorted': {'indices': [0, 1, 2], 'values': [0.9999988079071045, 1.2288410289329477e-06, 4.06742373115776e-08]}}, {'image_id': '-1', 'image_path': '/tmp/tmp1e8xa9i0/2976462bd67042d5b1b14c1d0f4a6717.jpeg', 'probs': [5.383871126696249e-09, 3.1869282679508615e-07, 0.9999996423721313], 'probs_sorted': {'indices': [2, 1, 0], 'values': [0.9999996423721313, 3.1869282679508615e-07, 5.383871126696249e-09]}}, {'image_id': '-1', 'image_path': '/raceai/data/datasets/cleaner_robot/imgs/ring/ring_5.jpg', 'probs': [0.9998032450

### 图片为Base64方式输入

与URL方式相比, 请求中只有以下两个字段不同, 其余配置保持一致:

1. cfg.data.class_name: `raceai.data.process.Base64DataLoader`

2. cfg.data.params.data_source: 图片文件内容经 Base64 编码后的字符串 (`b64string`), 而不是 URL/路径列表

### 样例

In [4]:
# Read one sample image and Base64-encode it; the path is derived from the
# `data_root` config value instead of repeating the absolute dataset path.
image_path = os.path.join(data_root, 'imgs', 'cable', 'cable_5.jpg')
with open(image_path, 'rb') as fr:
    b64data = base64.b64encode(fr.read()).decode()

# Classification request whose image is embedded as a Base64 string.
reqdata = {
    "task": "cls.inference.pl",
    "cfg": {
        "data": {
            "class_name": "raceai.data.process.Base64DataLoader",
            "params": {
                # NOTE(review): the payload is wrapped in literal double quotes,
                # exactly as before; presumably the server-side loader expects
                # or tolerates them -- confirm before removing.
                "data_source": f'"{b64data}"',
                "dataset": {
                    "class_name": "raceai.data.PredictListImageDataset",
                    "params": {
                        "input_size": input_size,
                        "mean": mean,
                        "std": std,
                    },
                },
                "sample": {
                    "batch_size": 32,
                    "num_workers": 4,
                },
            },
        },
        "model": {
            "class_name": f"raceai.models.backbone.{model}",
            "params": {
                "device": 'gpu',
                "num_classes": num_classes,
                "weights": False,
            },
        },
        "trainer": {
            "default_root_dir": root_dir,
            "gpus": 1,
            "resume_from_checkpoint": resume_weights,
        },
    },
}

uri = f'{RACEURL}/raceai/framework/inference'

# (connect, read) timeout in seconds: without one, requests waits forever and
# an unreachable or stuck server would hang the kernel. The read timeout is
# deliberately generous; tune it if inference legitimately takes longer.
response = requests.post(url=uri, json=reqdata, timeout=(10, 300))
try:
    resdata = response.json()
except ValueError as err:
    # The body was not JSON (e.g. a proxy/HTML error page): surface the HTTP
    # status and the start of the body instead of a bare decode error.
    raise RuntimeError(
        f'non-JSON response from {uri} (HTTP {response.status_code}): '
        f'{response.text[:500]}'
    ) from err

# On failure, print only the server traceback when one is provided. `result`
# is checked to be a dict first so a string error message is not subscripted
# with 'traceback'. In every other case print the whole response.
result = resdata.get('result')
if resdata.get('errno') != 0 and isinstance(result, dict) and 'traceback' in result:
    print(result['traceback'])
else:
    print(resdata)

{'errno': 0, 'result': [{'image_id': '-1', 'image_path': '/tmp/b4img_1614672344.png', 'probs': [0.00011358528718119487, 0.9997881054878235, 9.832012437982485e-05], 'probs_sorted': {'indices': [1, 0, 2], 'values': [0.9997881054878235, 0.00011358528718119487, 9.832012437982485e-05]}}]}
