In [65]:
import csv
import json
import os
import random
import tempfile
from random import sample
from urllib.request import urlretrieve

import requests

from k12libs.utils.nb_easy import k12ai_get_top_dir, RACEURL
from k12libs.utils.nb_easy import K12AI_DATASETS_ROOT

In [75]:
# Remote (S3) location of the rgarbage dataset; the test-split index CSV lives here.
R_PREFIX = 'https://raceai.s3.didiyunapi.com/data/datasets/cv/rgarbage'
CSV_FILE = R_PREFIX + '/test.csv'

# Path prefix prepended to each CSV `image_path` before it is sent to the
# inference API (presumably where the service sees the dataset — confirm).
L_PREFIX = '/raceai/data/datasets/rgarbage'

# Inference endpoint of the RaceAI service.
API_INFERENCE = '{}/raceai/framework/inference'.format(RACEURL)

In [76]:
# Download the test-split index CSV and draw a reproducible sample of rows
# to send to the inference service in the cells below.
N_SAMPLES = 10     # number of rows to sample
SAMPLE_SEED = 0    # fixed so Restart & Run All picks the same images every time

csv_file = urlretrieve(CSV_FILE, os.path.join(tempfile.gettempdir(), os.path.basename(CSV_FILE)))
csv_path = csv_file[0]  # urlretrieve returns (local_filename, headers)

# The CSV contains non-ASCII (Chinese) paths, so don't rely on the locale's
# default encoding; utf-8-sig also strips a leading BOM if one is present.
# newline='' is how the csv module expects its input files to be opened.
with open(csv_path, 'r', encoding='utf-8-sig', newline='') as fr:
    data = list(csv.DictReader(fr))
assert data, 'no rows read from {}'.format(CSV_FILE)

# min() keeps sampling valid when the CSV has fewer than N_SAMPLES rows.
data_10 = random.Random(SAMPLE_SEED).sample(data, min(N_SAMPLES, len(data)))
data_10[0]['image_path']

'imgs/可回收物/可回收物_不锈钢制品/img_不锈钢制品_188.jpeg'

In [135]:
# ---- model / checkpoint settings --------------------------------------------
model = 'Resnet18'
num_classes = 4
input_size = 224
# Normalisation statistics passed to the dataset (presumably per-channel,
# matching the values used at training time — confirm).
mean = [0.6535, 0.6132, 0.5643]
std = [0.2165, 0.2244, 0.2416]

root_dir = '/raceai/data/tmp/pl_rgarbage_resnet18'
resume_weights = '/raceai/data/ckpts/rgarbage/pl_resnet18_acc70.pth'

# ---- inference request template ---------------------------------------------
# NOTE: this is a *Python expression*, not JSON. The bare names (input_size,
# mean, std, num_classes, root_dir, resume_weights) and the f-string for the
# model class are resolved against the notebook globals when the filled-in
# template is passed to eval(). The single `%s` placeholder must be replaced
# with a list literal of image paths that is safely quoted, e.g.
# `reqdata % json.dumps(paths, ensure_ascii=False)`.
reqdata = '''{
    "task": "cls.inference.pl",
    "cfg": {
        "data": {
            "class_name": "raceai.data.process.PathListDataLoader",
            "params": {
                "data_source": %s,
                "dataset": {
                    "class_name": "raceai.data.PredictListImageDataset",
                     "params": {
                         "input_size": input_size,
                         "mean": mean,
                         "std": std
                     }
                 },
                "sample": {
                    "batch_size": 32,
                    "num_workers": 4,
                }
             }
        },
        "model": {
            "class_name": f"raceai.models.backbone.{model}",  
            "params": {
                "device": 'gpu',
                "num_classes": num_classes,
                "weights": False
            }
        },
        "trainer": {
            "default_root_dir": root_dir,
            "gpus": 1,
            "resume_from_checkpoint": resume_weights
        }
    }
}'''

## Single Image

In [136]:
# Send one request per sampled image and print "<file>: <CSV label> vs <arg-max class>".
# NOTE(review): the saved outputs of this cell (In [136]) and of the batch cell
# (In [131]) disagree on some files; they come from different kernel runs with
# the config cell (In [135]) re-run in between, so re-run both before comparing.
for row in data_10:
    imgpath = os.path.join(L_PREFIX, row['image_path'])
    # json.dumps quotes and escapes the path, so a path from the downloaded CSV
    # cannot break out of its string literal in the eval'd request template.
    cfg = eval(reqdata % json.dumps([imgpath], ensure_ascii=False))
    resdata = requests.post(url=API_INFERENCE, json=cfg).json()
    if resdata['errno'] != 0:
        # Surface server-side failures instead of silently skipping the image.
        print('{}: request failed: {}'.format(imgpath, resdata))
        continue
    result = resdata['result'][0]
    fname = result['fname']
    probs = result['probs']
    print('{}: {} vs {}'.format(fname, row['label'], probs.index(max(probs))))

img_不锈钢制品_188.jpeg: 2 vs 2
img_毛巾_290.jpeg: 0 vs 0
img_纽扣电池_78.jpeg: 3 vs 2
img_药片_48.jpeg: 3 vs 3
img_面包_39.jpeg: 1 vs 1
img_面包_232.jpeg: 1 vs 1
img_电池_529.jpeg: 3 vs 3
img_红花油_28.jpeg: 3 vs 3
img_汉堡_109.jpeg: 1 vs 3
img_药片_73.jpeg: 3 vs 0


## Batch Images

In [131]:
# Send all sampled images in one request and print
# "<file>: <CSV label> vs <arg-max class>" for each; the raw response is displayed last.
images = [os.path.join(L_PREFIX, row['image_path']) for row in data_10]
labels = [row['label'] for row in data_10]
# json.dumps produces a safely-quoted list literal for the eval'd request template.
cfg = eval(reqdata % json.dumps(images, ensure_ascii=False))
resdata = requests.post(url=API_INFERENCE, json=cfg).json()
if resdata['errno'] != 0:
    print('batch request failed: {}'.format(resdata))
elif len(resdata['result']) != len(labels):
    # Results are paired with labels by position (assumes the service answers
    # in request order); zip() would silently truncate or misalign on a mismatch.
    print('expected {} results, got {}'.format(len(labels), len(resdata['result'])))
else:
    for item, label in zip(resdata['result'], labels):
        fname = item['fname']
        probs = item['probs']
        print('{}: {} vs {}'.format(fname, label, probs.index(max(probs))))
resdata

img_不锈钢制品_188.jpeg: 2 vs 2
img_毛巾_290.jpeg: 0 vs 0
img_纽扣电池_78.jpeg: 3 vs 3
img_药片_48.jpeg: 3 vs 3
img_面包_39.jpeg: 1 vs 1
img_面包_232.jpeg: 1 vs 1
img_电池_529.jpeg: 3 vs 3
img_红花油_28.jpeg: 3 vs 3
img_汉堡_109.jpeg: 1 vs 1
img_药片_73.jpeg: 3 vs 3


{'errno': 0,
 'result': [{'fname': 'img_不锈钢制品_188.jpeg',
   'probs': [9.478302672505379e-05,
    0.0001377913577016443,
    0.999762237071991,
    5.1964780141133815e-06]},
  {'fname': 'img_毛巾_290.jpeg',
   'probs': [0.7620434165000916,
    0.23186975717544556,
    0.005751080345362425,
    0.00033575136330910027]},
  {'fname': 'img_纽扣电池_78.jpeg',
   'probs': [0.006104649975895882,
    7.1398867476091255e-06,
    0.3644082546234131,
    0.6294800043106079]},
  {'fname': 'img_药片_48.jpeg',
   'probs': [0.16693489253520966,
    0.054228123277425766,
    0.05048124119639397,
    0.7283557653427124]},
  {'fname': 'img_面包_39.jpeg',
   'probs': [4.2797486798917816e-07,
    0.9985795021057129,
    2.1652081159118097e-06,
    0.0014179437421262264]},
  {'fname': 'img_面包_232.jpeg',
   'probs': [0.002373164752498269,
    0.9967215657234192,
    0.0009043923346325755,
    7.759157938380667e-07]},
  {'fname': 'img_电池_529.jpeg',
   'probs': [0.009048981592059135,
    0.00017791068239603192,
    7.48