# Using Common CNN Classification Models
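Using the pretrained models provided by PyTorch (torchvision)
* PyTorch's pretrained models were trained on ImageNet (>1 million images)
* Run the tests on a GPU if one is available; otherwise fall back to the CPU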


Download the test image and the ImageNet class-index dictionary

In [1]:
!wget -O test.jpg https://raw.githubusercontent.com/tetenlost/demo_obj_detection/main/data/pexels-alexandru-rotariu-733416.jpg
!wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json

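If `wget` is not available on your machine, the same downloads can be done in plain Python. The following is a minimal sketch that uses only the standard library. The helper name `download` is hypothetical and is not part of the original notebook.

```python
import shutil
import urllib.request
from pathlib import Path


def download(url: str, dest: str) -> Path:
    """Hypothetical helper: save `url` to `dest` (overwriting it), like `wget -O dest url`."""
    dest_path = Path(dest)
    # urlopen follows HTTP redirects; stream the response body to disk in chunks
    with urllib.request.urlopen(url) as response, open(dest_path, "wb") as out:
        shutil.copyfileobj(response, out)
    return dest_path
```

Usage, for the two files fetched above: `download("https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json", "imagenet_class_index.json")`, and likewise for `test.jpg`.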

## Import the required packages

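* torch: the core PyTorch deep-learning package
* torch.nn: basic building blocks (layers) for neural networks
* torch.optim: optimizers for training
* torch.optim.lr_scheduler: learning-rate schedulers
* torchvision: PyTorch's computer-vision package
* torchvision.datasets: datasets and tools for building training data
* torchvision.models: well-known model architectures with pretrained weights (no need to build them from scratch)
* torchvision.transforms: image preprocessing transforms
* PIL.Image: loading image files
* numpy: array conversion for displaying images
* json: reading JSON files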


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import json
import numpy as np
#-----------------------
# The following import only works in Google Colab
from google.colab.patches import cv2_imshow


## Define the main test function

In [8]:
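def main(model, transform, idx2label):
    print("Your model:", model.__class__.__name__)
    # Path of the test image
    path = '/content/test.jpg'
    # Check whether a GPU is available; otherwise use the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Move the model to the GPU/CPU
    model.to(device)
    # Put the model in evaluation (inference) mode instead of training mode
    model.eval()
    # Disable gradient tracking, since we are not training
    with torch.no_grad():
        # Read the image; convert to RGB in case it is grayscale or has an alpha channel
        img = Image.open(path).convert("RGB")
        # Convert the image to the model's input format (1, C, H, W) and move it to the GPU/CPU
        img = transform(img).unsqueeze(0).to(device)
        # Classify: output has shape (1, 1000), one score per ImageNet class
        output = model(img)
        # Convert the image back to (H, W, C) and reverse the channels (RGB -> BGR) for cv2_imshow
        test_img = np.asarray(img[0].permute(1, 2, 0).cpu())[:, :, ::-1]
        # Map pixel values from (-1, 1) back to (0, 255), undoing Normalize(mean=0.5, std=0.5)
        test_img = (test_img + 1) / 2 * 255
        # Show the image
        cv2_imshow(test_img)
        print("-----detect_result-----")
        # Sort the class scores from highest to lowest confidence
        _, ids = torch.sort(output[0].cpu(), descending=True)
        # Keep the top five
        top5 = ids[:5]
        # Convert the top-five class indices to names and print them
        for rank, class_id in enumerate(top5, start=1):
            print("top%d: %s" % (rank, idx2label[int(class_id)]))
        print("-----------------------")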
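The ranking step at the end of `main` sorts the raw scores and keeps the first five indices. The following plain-Python sketch reproduces that step without PyTorch. It also turns the scores into probabilities with softmax, which `main` itself does not do. The six labels and scores are made-up values for illustration only.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert raw scores (logits) into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits: Sequence[float], idx2label: Sequence[str], k: int = 5) -> List[Tuple[str, float]]:
    """Return the k most confident (label, probability) pairs, highest first."""
    probs = softmax(logits)
    # Same idea as torch.sort(..., descending=True) followed by [:k]
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(idx2label[i], probs[i]) for i in order[:k]]


# Hypothetical 6-class output, for illustration only
labels = ["tench", "goldfish", "great_white_shark", "tiger_shark", "hammerhead", "electric_ray"]
scores = [0.1, 3.2, -1.0, 2.5, 0.0, 1.7]
for rank, (name, p) in enumerate(top_k(scores, labels), start=1):
    print(f"top{rank}: {name} ({p:.1%})")
```

Softmax preserves the order of the scores, so the top-5 names match what sorting the raw outputs would give. The probabilities only make the confidence easier to read.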

Load the class table (used to convert the model's output indices into class names)

In [4]:
with open("imagenet_class_index.json") as class_table:
    class_idx = json.load(class_table)
idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
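The list comprehension above depends on the file's layout. Each key is a class index stored as a string, and each value is a `[WordNet ID, readable name]` pair. Below is a self-contained sketch on a three-entry excerpt, assumed to mirror the first entries of `imagenet_class_index.json`.

```python
import json

# Excerpt in the same layout as imagenet_class_index.json (first 3 of 1000 entries)
excerpt = '{"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"], "2": ["n01484850", "great_white_shark"]}'

class_idx = json.loads(excerpt)
# Keys are strings, so look them up with str(k); element [1] is the readable name
idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
# Element [0] is the WordNet synset ID
idx2wnid = [class_idx[str(k)][0] for k in range(len(class_idx))]
print(idx2label)  # ['tench', 'goldfish', 'great_white_shark']
```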

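Define the data-preprocessing transform

1. Resize -> scale the image to 224×224

2. ToTensor -> convert the image to a PyTorch tensor (C, H, W) with values in [0, 1]

3. Normalize -> standardize each channel as (x - mean) / std; with mean = std = 0.5 the values are mapped from [0, 1] to [-1, 1]. (torchvision's pretrained ImageNet models were trained with mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225], so those values match their training conditions more closely.)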

In [5]:
val_transform = transforms.Compose([
                    transforms.Resize((224,224)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]),
                    ])
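`ToTensor` scales 8-bit pixels to [0, 1], and `Normalize` then computes (x - mean) / std for each channel. The inverse is x * std + mean. With mean = std = 0.5 that inverse becomes (x + 1) / 2, which is why `main` can rescale the image with `(test_img + 1) / 2 * 255`. The following plain-Python sketch shows both directions for a single pixel, using the 0.5 values from `val_transform` and the ImageNet statistics that torchvision documents.

```python
from typing import List, Sequence

# (mean, std) per RGB channel
HALF = ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])                  # values used in val_transform
IMAGENET = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # torchvision's documented stats


def normalize(pixel: Sequence[int], mean: Sequence[float], std: Sequence[float]) -> List[float]:
    """8-bit RGB pixel -> ToTensor scaling to [0, 1] -> Normalize: (x - mean) / std."""
    return [(p / 255.0 - m) / s for p, m, s in zip(pixel, mean, std)]


def denormalize(values: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> List[int]:
    """Invert normalize(): x * std + mean, rescale to [0, 255], round and clip."""
    return [min(255, max(0, round((v * s + m) * 255))) for v, m, s in zip(values, mean, std)]


print(normalize((0, 128, 255), *HALF))      # black -> -1.0, white -> 1.0
print(normalize((0, 128, 255), *IMAGENET))  # ImageNet stats give a wider range
```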

## Define the model

Pick one of the models below and run its cell.

### ALEXNET (ImageNet top-1 error: 43.45%)

In [None]:
model = models.alexnet(pretrained=True)

Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth


### VGG19 (ImageNet top-1 error: 27.62%)

In [None]:
model = models.vgg19(pretrained=True)

### VGG19-BN (ImageNet top-1 error: 25.76%)

Variant with batch normalization (BN), less prone to exploding gradients

In [None]:
model = models.vgg19_bn(pretrained=True)

Downloading: "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth" to /root/.cache/torch/hub/checkpoints/vgg19_bn-c79401a0.pth


### GOOGLENET (ImageNet top-1 error: 30.22%)

In [None]:
model = models.googlenet(pretrained=True)

### RESNET 18 (ImageNet top-1 error: 30.24%)

In [6]:
model = models.resnet18(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


### RESNET 50 (ImageNet top-1 error: 23.85%)

In [12]:
model = models.resnet50(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


### RESNET 152 (ImageNet top-1 error: 21.69%)

In [None]:
model = models.resnet152(pretrained=True)

## Test

In [None]:
main(model,val_transform,idx2label)

### Download another image

In [None]:
# Quote the URL: '&' and '?' are special characters in the shell
!wget -O test.jpg "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRPiNOxJBhg1qT5gh7zY327f-vahP6_e2SO6A&usqp=CAU"
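Without quotes, the shell treats `&` as "run in the background" and cuts the URL at that point. It may also try to expand `?` as a filename wildcard. If you build the download command in Python, for example for `subprocess.run(..., shell=True)`, the standard library's `shlex.quote` can add the quoting for you. A small sketch:

```python
import shlex

# URL from the cell above; it contains '?' and '&', which the shell treats specially
url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRPiNOxJBhg1qT5gh7zY327f-vahP6_e2SO6A&usqp=CAU"

# shlex.quote wraps the URL in single quotes so the shell passes it through unchanged
cmd = "wget -O test.jpg " + shlex.quote(url)
print(cmd)
```

After downloading the new `test.jpg`, run `main(model, val_transform, idx2label)` again to classify it.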