In [1]:
import os
import csv
import json
import requests
from random import sample
from urllib.request import urlretrieve
from k12libs.utils.nb_easy import k12ai_get_top_dir, RACEURL
from k12libs.utils.nb_easy import K12AI_DATASETS_ROOT

In [2]:
# Remote root of the rgarbage dataset (the test CSV is fetched from here).
R_PREFIX = 'https://raceai.s3.didiyunapi.com/data/datasets/cv/rgarbage'
# Filesystem prefix joined onto each CSV `image_path` before it is sent to the
# inference API (presumably a path visible to the server -- confirm).
L_PREFIX = '/raceai/data/datasets/rgarbage'
# URL of the CSV listing the test images.
CSV_FILE = '/'.join((R_PREFIX, 'test.csv'))

# Endpoint that the inference requests are POSTed to.
API_INFERENCE = '{}/raceai/framework/inference'.format(RACEURL)

In [3]:
# Download the test CSV into /tmp; urlretrieve returns (local_path, headers).
csv_file = urlretrieve(CSV_FILE, os.path.join('/tmp/', os.path.basename(CSV_FILE)))
# Decode explicitly instead of relying on the locale default: the image paths
# contain Chinese characters (assumes the file is UTF-8 -- confirm).
# newline='' is what the csv module asks for so that quoted fields with
# embedded newlines are parsed correctly.
with open(csv_file[0], 'r', encoding='utf-8', newline='') as fr:
    reader = csv.DictReader(fr)
    data = list(reader)
# Draw up to 10 random rows; capping at len(data) avoids the ValueError that
# sample() raises when the CSV holds fewer than 10 rows.
data_10 = sample(data, min(10, len(data)))
# Show one sampled path as a quick sanity check of the CSV contents.
data_10[0]['image_path']

'imgs/厨余垃圾/厨余垃圾_面包/汉堡_厨余垃圾/img_汉堡_66.jpeg'

In [4]:
# Model selection and checkpoint. These names are looked up when the request
# template string is evaluated, so they must keep exactly these spellings.
model = 'Resnet18'
num_classes = 4
resume_weights = "/raceai/data/ckpts/rgarbage/pl_resnet18_acc70.pth"
root_dir = '/raceai/data/tmp/pl_rgarbage_resnet18'

# Image preprocessing parameters forwarded to the dataset section of the
# request (three values each -- presumably per-channel statistics; confirm).
input_size = 224
mean = [0.6535, 0.6132, 0.5643]
std = [0.2165, 0.2244, 0.2416]

# Request template for the inference API.
#
# NOTE: this is Python source text, not JSON. It contains bare names
# (input_size, mean, std, num_classes, root_dir, resume_weights), an f-string
# that reads `model`, a single-quoted string, `False`, and a trailing comma.
# The cells below fill the single `%s` placeholder (the "data_source" value)
# with a list literal and pass the result to eval(), which resolves those
# names from the notebook's globals at that moment and yields a dict.
# Do not add any other `%` character to the template, or the `%` formatting
# will break.
reqdata = '''{
    "task": "cls.inference.pl",
    "cfg": {
        "data": {
            "class_name": "raceai.data.process.PathListDataLoader",
            "params": {
                "data_source": %s,
                "dataset": {
                    "class_name": "raceai.data.PredictListImageDataset",
                     "params": {
                         "input_size": input_size,
                         "mean": mean,
                         "std": std
                     }
                 },
                "sample": {
                    "batch_size": 32,
                    "num_workers": 4,
                }
             }
        },
        "model": {
            "class_name": f"raceai.models.backbone.{model}",  
            "params": {
                "device": 'gpu',
                "num_classes": num_classes,
                "weights": False
            }
        },
        "trainer": {
            "default_root_dir": root_dir,
            "gpus": 1,
            "resume_from_checkpoint": resume_weights
        }
    }
}'''

## Single Image

In [None]:
# Send one inference request per sampled image and print
# "<file name>: <CSV label> vs <index of the highest probability>".
for row in data_10:
    imgpath = os.path.join(L_PREFIX, row['image_path'])
    # Let json.dumps build the one-element list literal so quotes and
    # backslashes in the path are escaped (same approach as the batch cell),
    # instead of gluing quote characters around the raw path by hand.
    # NOTE(review): eval() executes the filled template as Python code and the
    # paths come from a downloaded CSV -- only run against a trusted source.
    cfg = eval(reqdata % json.dumps([imgpath], ensure_ascii=False))
    resdata = json.loads(requests.post(url=API_INFERENCE, json=cfg).text)
    if resdata['errno'] == 0:
        result = resdata['result'][0]
        fname = result['fname']
        probs = result['probs']
        print('{}: {} vs {}'.format(fname, row['label'], probs.index(max(probs))))
    else:
        # Report failed requests instead of dropping them silently, so a
        # missing output line can be traced back to the server response.
        print('{}: inference failed: {}'.format(imgpath, resdata))

img_汉堡_66.jpeg: 1 vs 1
img_汉堡_302.jpeg: 1 vs 1
img_螺丝刀_1.jpeg: 2 vs 3
img_书_168.jpeg: 2 vs 2
img_白板笔_177.jpeg: 0 vs 2
img_汉堡_131.jpeg: 1 vs 1


## Batch Images

In [None]:
# Batch variant: collect every sampled path (and its CSV label) up front and
# send all of them in a single inference request.
images, labels = [], []
for row in data_10:
    images.append(os.path.join(L_PREFIX, row['image_path']))
    labels.append(row['label'])

# Fill the template's data_source slot with the whole list, then evaluate the
# resulting Python expression into the request dict.
data_source = json.dumps(images, ensure_ascii=False)
cfg = eval(reqdata % data_source)

response = requests.post(url=API_INFERENCE, json=cfg)
resdata = json.loads(response.text)

if resdata['errno'] == 0:
    # Results are paired with labels by position.
    for item, label in zip(resdata['result'], labels):
        fname, probs = item['fname'], item['probs']
        print('{}: {} vs {}'.format(fname, label, probs.index(max(probs))))
# Display the raw response as the cell output.
resdata