In [11]:
import torchvision.transforms as transforms
import torch
from PIL import Image
from collections import OrderedDict
import torch.nn.functional as F
from torch import nn
import os, time
import torchvision.models as models
from models.resnetxt_wsl import resnext101_32x8d_wsl, resnext101_32x16d_wsl, resnext101_32x32d_wsl
import cv2
import matplotlib.pyplot as plt
import numpy as np

os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3,4,5,6,7,8,9'
args = {}
args['arch'] = 'resnext101_32x16d_wsl'
args['pretrained'] = False
args['num_classes'] = 42
args['image_size'] = 288


class classfication_service():
    def __init__(self, model_path):
        self.model = self.build_model(model_path)
        self.pre_img = self.preprocess_img()
        self.model.eval()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.label_id_name_dict = \
            {
                "0": "其他垃圾/一次性快餐盒",
                "1": "其他垃圾/污损塑料",
                "2": "其他垃圾/烟蒂",
                "3": "其他垃圾/牙签",
                "4": "其他垃圾/破碎花盆及碟碗",
                "5": "其他垃圾/竹筷",
                "6": "厨余垃圾/剩饭剩菜",
                "7": "厨余垃圾/大骨头",
                "8": "厨余垃圾/水果果皮",
                "9": "厨余垃圾/水果果肉",
                "10": "厨余垃圾/茶叶渣",
                "11": "厨余垃圾/菜叶菜根",
                "12": "厨余垃圾/蛋壳",
                "13": "厨余垃圾/鱼骨",
                "14": "可回收物/充电宝",
                "15": "可回收物/包",
                "16": "可回收物/化妆品瓶",
                "17": "可回收物/塑料玩具",
                "18": "可回收物/塑料碗盆",
                "19": "可回收物/塑料衣架",
                "20": "可回收物/快递纸袋",
                "21": "可回收物/插头电线",
                "22": "可回收物/旧衣服",
                "23": "可回收物/易拉罐",
                "24": "可回收物/枕头",
                "25": "可回收物/毛绒玩具",
                "26": "可回收物/洗发水瓶",
                "27": "可回收物/玻璃杯",
                "28": "可回收物/皮鞋",
                "29": "可回收物/砧板",
                "30": "可回收物/纸板箱",
                "31": "可回收物/调料瓶",
                "32": "可回收物/酒瓶",
                "33": "可回收物/金属食品罐",
                "34": "可回收物/锅",
                "35": "可回收物/食用油桶",
                "36": "可回收物/饮料瓶",
                "37": "有害垃圾/干电池",
                "38": "有害垃圾/软膏",
                "39": "有害垃圾/过期药物",
                "40": "可回收物/纸",
                "41": "可回收物/锡箔纸",
            }

    def build_model(self, model_path):
        
        if args['arch'] == 'resnext101_32x16d_wsl':
            model = resnext101_32x16d_wsl()
        if args['arch'] == 'resnext101_32x8d':
            model = models.__dict__[args['arch']]()
        elif args['arch'] == 'efficientnet-b7':
            model = EfficientNet.from_name(args['arch'])
        

        layerName, layer = list(model.named_children())[-1]
        exec("model." + layerName + "=nn.Linear(layer.in_features," + str(args['num_classes']) + ")")
        #model = nn.DataParallel(model)
        if torch.cuda.is_available():
            modelState = torch.load(model_path)
            for key in list(modelState.keys()):
                if 'fc.1.' in key:
                    modelState[key.replace('fc.1.', 'fc.')] = modelState[key]
                    del modelState[key]
            model.load_state_dict(modelState)
            model = model.cuda()
        else:
            modelState = torch.load(model_path, map_location='cpu')
            model.load_state_dict(modelState)
        return model

    def preprocess_img(self):
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        infer_transformation = transforms.Compose([
            Resize((int(288 * (256 / 224)), int(288 * (256 / 224)))),
            transforms.CenterCrop(288),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        return infer_transformation

    def _preprocess(self, data):
        preprocessed_data = {}
        for k, v in data.items():
            for file_name, file_content in v.items():
                img = Image.open(file_content)
                img = self.pre_img(img)
                preprocessed_data[k] = img
        return preprocessed_data

    def _inference(self, data):
        """
        model inference function
        Here are a inference example of resnet, if you use another model, please modify this function
        """
        img = data['input_img']
        img = img.unsqueeze(0)
        img = img.to(self.device)
        with torch.no_grad():
            pred_score = self.model(img)

        if pred_score is not None:
            _, pred_label = torch.max(pred_score.data, 1)
            result = {'result': self.label_id_name_dict[str(pred_label[0].item())]}
        else:
            result = {'result': 'predict score is None'}

        return result,pred_label

    def _postprocess(self, data):
        return data
    
    def draw_CAM(self, data, save_name,file_path, visual_heatmap=True):
        # 图像加载&预处理
        img = data
        # 获取模型输出的feature/score
        img = img.unsqueeze(0)
        img = img.to(self.device)
        
        features_blobs = []
        def hook_feature(module, input, output):
            features_blobs.append(output.data.cpu().numpy())
        self.model._modules.get('layer4').register_forward_hook(hook_feature)
        features = self.model(img)
        
        params = list(self.model.parameters()) # 将参数变换为列表
        weight_fc = np.squeeze(params[-2].data.cpu().numpy()) # 提取softmax 层的参数
        
        def returnCAM(feature_conv, weight_softmax, class_idx):
            size_upsample = (288, 288)
            bz, nc, h, w = feature_conv.shape # 获取feature_conv特征的尺寸
            output_cam = []
            for idx in class_idx:
                cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h*w)))
                cam = cam.reshape(h, w)
                cam = cam - np.min(cam)
                cam_img = cam / np.max(cam)
                cam_img = np.uint8(255 * cam_img)
                output_cam.append(cv2.resize(cam_img, size_upsample))
            return output_cam 
       
        h_x = F.softmax(features, dim=1).data.squeeze() # 分类分值
        # 对分类的预测类别分值排序，输出预测值和在列表中的位置
        probs, idx = h_x.sort(0, True) 
        # 转换数据类型
        probs = probs.cpu().numpy()
        idx = idx.cpu().numpy()
        print(probs,idx) 
        CAMs = returnCAM(features_blobs[0], weight_fc, [idx[0]])
        img = cv2.imread(file_path)
        height, width, _ = img.shape
        heatmap = cv2.applyColorMap(cv2.resize(CAMs[0],(width, height)), cv2.COLORMAP_JET)
        result = heatmap * 0.3 + img * 0.5
        cv2.imwrite(save_name+'cam.jpg', result)
       

class Resize(object):
    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        ratio = self.size[0] / self.size[1]
        w, h = img.size
        if w / h < ratio:
            t = int(h * ratio)
            w_padding = (t - w) // 2
            img = img.crop((-w_padding, 0, w+w_padding, h))
        else:
            t = int(w / ratio)
            h_padding = (t - h) // 2
            img = img.crop((0, -h_padding, w, h+h_padding))

        img = img.resize(self.size, self.interpolation)

        return img

    

In [12]:
model_path = 'data0/garbage/res_16_288_last1/model_30_9993_9394.pth'
save_path = 'garbage_classify/huawei-garbage/data/garbage_classify/train_data'
infer = classfication_service(model_path)
import os
path = os.getcwd()


input_dir = 'val'

for i in range(42):
    a = os.path.join(path,input_dir,str(i))
    files = os.listdir(a)
    for file_name in files:
        if not file_name[0]==".":
            file_path = os.path.join(path,input_dir,str(i),file_name)
            img = Image.open(file_path)
            
            img = infer.pre_img(img)
            
            result,pred_lb = infer._inference({'input_img': img})
            print(pred_lb)
            infer.draw_CAM(img,str(i),file_path)
            break
            
   

tensor([0], device='cuda:0')
[9.99999642e-01 1.18940399e-07 1.14833732e-07 6.28970582e-08
 2.64253366e-08 2.20053291e-08 1.56330096e-08 1.14185434e-08
 1.09564411e-08 1.01484279e-08 7.29086835e-09 4.48723503e-09
 3.95748723e-09 3.91302235e-09 2.41420350e-09 2.05274131e-09
 1.97068228e-09 1.64648184e-09 1.56562474e-09 1.44961310e-09
 1.29134803e-09 1.28428124e-09 8.25279622e-10 7.86405163e-10
 7.40997430e-10 7.40663919e-10 5.11953979e-10 4.79704054e-10
 4.50967874e-10 4.32223229e-10 3.77830461e-10 2.93876617e-10
 2.05661627e-10 1.52745858e-10 1.43162301e-10 1.07501084e-10
 8.45190307e-11 7.45353015e-11 7.02524636e-11 5.34729899e-11
 3.43849775e-11 1.65411990e-11] [ 0  6 35 41 31 33  9  1 12  8 29 27 34 15 28 11 18 10 16 13  5 21  3  4
 19  2 40 14  7 17 30 24 39 23 20 25 38 32 36 22 37 26]
tensor([1], device='cuda:0')
[1.00000000e+00 1.19653654e-09 4.58016403e-10 1.93738928e-10
 9.27324120e-11 6.12211115e-11 5.24460544e-11 4.88892954e-11
 3.94340914e-11 3.81921959e-11 3.48040485e-11 1.3

tensor([11], device='cuda:0')
[9.99997973e-01 4.72951029e-07 2.62727781e-07 2.38841153e-07
 2.23134990e-07 1.42469034e-07 1.13288927e-07 1.07531825e-07
 7.05029422e-08 5.47542207e-08 4.94521650e-08 3.55304408e-08
 2.61897597e-08 2.21148984e-08 2.18550014e-08 1.97790637e-08
 1.96211740e-08 1.78167578e-08 1.58674212e-08 1.34536675e-08
 1.07095399e-08 1.04497859e-08 9.98692062e-09 8.73617090e-09
 8.07915956e-09 7.97940114e-09 7.45197593e-09 6.56660637e-09
 5.92986638e-09 5.40565548e-09 5.34929567e-09 4.57585259e-09
 4.49372806e-09 3.83843668e-09 3.23010019e-09 3.17940918e-09
 2.20731677e-09 1.74955173e-09 1.35682598e-09 1.03644859e-09
 8.89650131e-10 8.54990634e-10] [11 10  5 31  7  3 32  9  6 29  1 13  0 40 19 18 12 17 24 41  8 33 26 28
 25 16 22 37 34 38 30 15  2 35 23 27 36 14  4 20 21 39]
tensor([12], device='cuda:0')
[9.99999046e-01 3.33259663e-07 2.65293664e-07 2.35679110e-07
 1.17429806e-07 2.72514136e-08 1.62096789e-08 8.50975024e-09
 7.24224325e-09 5.93709304e-09 4.26529034e-09 4

tensor([22], device='cuda:0')
[9.99999881e-01 4.54749056e-08 1.14045289e-08 1.11047243e-08
 8.32692759e-09 5.10911580e-09 3.94647515e-09 3.61083297e-09
 1.71403436e-09 1.38078737e-09 1.14735343e-09 9.10658937e-10
 9.02943276e-10 8.09811607e-10 7.36335604e-10 6.25028418e-10
 4.51121085e-10 1.72812611e-10 1.56997235e-10 1.29447453e-10
 1.22156507e-10 1.16293856e-10 1.10091360e-10 1.05550617e-10
 9.01229161e-11 8.36450215e-11 4.83311377e-11 4.70383420e-11
 4.14382591e-11 4.09287049e-11 3.82016606e-11 3.69813624e-11
 3.47611211e-11 2.58190708e-11 2.56034412e-11 1.09307719e-11
 1.01497109e-11 9.68069375e-12 9.66322335e-12 3.19824701e-12
 2.16726420e-12 2.68335321e-13] [22 19 28  6 15 11  1 13  8 24  7 25 10 29 27 30 41 17 12 23  4 18 34 26
  5 20 33  0  9 14 36 38  2 32 40  3 16 39 35 21 37 31]
tensor([23], device='cuda:0')
[1.00000000e+00 9.67029568e-10 3.53301749e-10 2.18676174e-10
 1.35863307e-10 1.07071192e-10 1.02025305e-10 8.04498412e-11
 7.04777001e-11 6.99688571e-11 5.34396728e-11 4

tensor([33], device='cuda:0')
[1.0000000e+00 3.8805292e-09 3.4741077e-09 3.2382370e-09 1.9813078e-09
 1.0244030e-09 8.5569224e-10 7.1694861e-10 5.6832367e-10 2.8155156e-10
 2.3179909e-10 2.2752432e-10 1.7393829e-10 8.8969901e-11 6.8949561e-11
 5.1148637e-11 3.8543242e-11 3.1382143e-11 2.4013111e-11 1.9092519e-11
 1.6310245e-11 1.3007557e-11 1.1366444e-11 9.8274427e-12 9.7073564e-12
 9.3758378e-12 6.6905583e-12 5.7579492e-12 3.5128356e-12 2.3359459e-12
 1.4808467e-12 1.4382510e-12 1.2799488e-12 9.5785326e-13 8.3727459e-13
 8.1750801e-13 7.9118553e-13 5.3899241e-13 2.8857117e-13 1.7608068e-13
 1.4752949e-13 4.1895083e-14] [33 35 37 34 31 27 30 19  5 24  0  1 41 29 40 14  3 38 16 15  6 32 23 39
 21  2 28 26 13  9 36 20 17 25 11 12 10 18  7  4  8 22]
tensor([34], device='cuda:0')
[9.99940395e-01 4.11245201e-05 5.54295366e-06 1.79701794e-06
 1.61312232e-06 1.57889087e-06 1.50210906e-06 1.35269977e-06
 1.10245890e-06 9.78737603e-07 6.51546600e-07 3.51866674e-07
 2.66250368e-07 2.24443610e-07

In [7]:
import os
path = os.getcwd()
print(path)
input_dir = 'val/0'
a = os.path.join(path,input_dir)

files = os.listdir(a)
files = files.stratwith('.')
print(files)


/home/yizj/garbage_classify/huawei-garbage


AttributeError: 'list' object has no attribute 'stratwith'

In [6]:
import os

model_path = 'data0/garbage/res_16_288_last1/model_30_9993_9394.pth'
save_path = 'garbage_classify/huawei-garbage/data/garbage_classify/train_data'
infer = classfication_service(model_path)
path = os.getcwd()
print(path)
input_dir = 'val'


# 导入CSV安装包
import csv

# 1. 创建文件对象
f = open('result.csv','w',newline='')

# 2. 基于文件对象构建 csv写入对象
csv_writer = csv.writer(f)

# 3. 构建列表头
csv_writer.writerow(["","pred","target"])

t1 = int(time.time()*1000)
idx = 0
for i in range(42):
    a = os.path.join(path,input_dir,str(i))
    files = os.listdir(a)
    for file_name in files:
        if not file_name[0]==".":
            
            file_path = os.path.join(path,input_dir,str(i),file_name)
            img = Image.open(file_path)
            img = infer.pre_img(img)
            result,pred_lb = infer._inference({'input_img': img})
            csv_writer.writerow([str(idx),str(pred_lb.cpu().numpy()[0]),str(i)])
            idx = idx+1
        
    
# 5. 关闭文件
f.close()

/home/yizj/garbage_classify/huawei-garbage


range(0, 10)