In [1]:
import os
import shutil
import warnings

import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms as T
from torchvision.datasets import ImageFolder
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import DataLoader
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm
from sklearn import metrics, preprocessing

from train import create_net
from utils.eval import eval_net
from utils.datasets import AortaDataset3D

# Hide every warning for the rest of the session to keep cell output clean.
# NOTE(review): this also hides deprecation / data-loading warnings; consider a narrower filter.
warnings.filterwarnings("ignore")

# Fixed seeds so repeated runs are reproducible (one seed per RNG source).
NUMPY_SEED, TORCH_SEED, CUDA_SEED = 63910, 53152, 7987
np.random.seed(NUMPY_SEED)
torch.manual_seed(TORCH_SEED)
torch.cuda.manual_seed_all(CUDA_SEED)

# Ask cuDNN for deterministic kernels.
torch.backends.cudnn.deterministic = True

In [2]:
# Expose only physical GPU 2 to this process, so the code below addresses it as cuda:0.
# NOTE(review): only takes effect if CUDA has not yet been initialised in this kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

# Use the visible GPU when there is one, otherwise fall back to the CPU.
# (Assign torch.device('cpu') here to force CPU execution.)
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0') if use_cuda else torch.device('cpu')

In [9]:
# Validation preprocessing:
#   1. resize so the shorter side is IMG_SIZE pixels (aspect ratio preserved),
#   2. crop the central IMG_SIZE x IMG_SIZE patch,
#   3. convert the PIL image to a tensor (8-bit images are scaled to [0, 1]).
IMG_SIZE = 81
transform = T.Compose([T.Resize(IMG_SIZE), T.CenterCrop(IMG_SIZE), T.ToTensor()])

In [10]:
# Validation split: one sub-folder per class.
VAL_DIR = '/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/val/'


def load_image(path):
    """Read an image file fully into memory and return it as a PIL image.

    Replaces ``lambda path: Image.open(path)``:
      * a module-level function can be pickled for DataLoader worker processes
        (a lambda cannot when workers are started with the 'spawn' method);
      * ``Image.open`` is lazy, so the original kept the file handle open until
        garbage collection -- here the pixels are loaded inside ``with`` and the
        file is closed immediately.
    No mode conversion is applied, so the image keeps its stored mode
    (e.g. single-channel stays single-channel).

    Args:
        path: filesystem path of the image.

    Returns:
        The loaded ``PIL.Image.Image``.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        img.load()
    return img


val = ImageFolder(VAL_DIR, transform=transform, loader=load_image)
# Sample order does not affect the metrics, so iterate deterministically.
val_loader = DataLoader(val, batch_size=100, shuffle=False, num_workers=8, pin_memory=True, drop_last=False)

In [11]:
# Two-stage cascade: net1 is a 3-class classifier; net2 has a single output and
# is used to re-classify the samples net1 assigns to class 2.
# Positional create_net arguments, as echoed by its log: ResNet depth,
# input channels, output channels, checkpoint path.
RESNET_DEPTH = 34
IN_CHANNELS = 1
NET1_CKPT = 'details/checkpoints/CrossEntropy/11-11_20:14:53/Net_best.pth'
NET2_CKPT = 'details/checkpoints/CrossEntropy/11-11_20:14:22/Net_best.pth'
net1 = create_net(device, RESNET_DEPTH, IN_CHANNELS, 3, NET1_CKPT, entire=True)
net2 = create_net(device, RESNET_DEPTH, IN_CHANNELS, 1, NET2_CKPT, entire=True)

# Switch both models to eval mode (BatchNorm layers use their running statistics).
# The final expression echoes net2's architecture as the cell output.
net1.eval()
net2.eval()

[INFO]: **********************************************************************
Network: ResNet_34
	1 input channels
	3 output channels (classes)
	3D model: False

[INFO]: Model loaded from details/checkpoints/CrossEntropy/11-11_20:14:53/Net_best.pth
[INFO]: **********************************************************************
Network: ResNet_34
	1 input channels
	1 output channels (classes)
	3D model: False

[INFO]: Model loaded from details/checkpoints/CrossEntropy/11-11_20:14:22/Net_best.pth


ResNet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tra

In [12]:
# Evaluate the two-stage cascade on the validation set.
# Stage 1 (net1) predicts one of 3 classes; the samples it assigns to class 2
# are passed to the single-output stage-2 model (net2), whose thresholded
# sigmoid splits them into final class 2 (<= 0.5) or class 3 (> 0.5).
num_val_batches = len(val_loader)  # the number of batches
n_val = len(val)  # the number of validation samples (previously hard-coded as 12434)

true_list = []
pred_list = []


def _cascade_predict(stage1, stage2, imgs):
    """Return the final cascade class index (0-3) for every image in ``imgs``.

    Args:
        stage1: 3-way classifier; the argmax of its outputs picks class 0, 1 or 2.
        stage2: single-output model, run only on the images ``stage1`` assigns to
            class 2; sigmoid > 0.5 maps an image to class 3, otherwise it stays 2.
        imgs: batch of images already on the models' device.

    Returns:
        1-D ``torch.long`` tensor with one class index per image.
    """
    with torch.no_grad():
        # argmax over the raw logits: softmax is monotonic, so applying it first
        # (as the original loop did) cannot change which class wins.
        pred = stage1(imgs).argmax(dim=1)
        refine = pred > 1
        # Run stage 2 only when stage 1 routed at least one sample to it; this
        # avoids a pointless forward pass on an empty batch.
        if refine.any():
            logits2 = stage2(imgs[refine])
            # NOTE(review): assumes stage2 returns (k, 1) outputs; squeeze(-1) makes them (k,).
            positive = (torch.sigmoid(logits2) > 0.5).long().squeeze(-1)
            pred[refine] = positive + 2
    return pred


for imgs, true_categories in tqdm(val_loader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
    # Labels are only collected for sklearn, so they never need to go to the GPU.
    true_list += true_categories.tolist()
    imgs = imgs.to(device=device, dtype=torch.float32)
    pred_list.extend(_cascade_predict(net1, net2, imgs).tolist())

# drop_last=False, so every validation sample has exactly one label and one prediction.
assert len(true_list) == len(pred_list) == n_val
    


                                                                                                           

In [13]:
# Per-class precision / recall / F1 of the cascade on the validation set, to 4 decimals.
val_report = metrics.classification_report(true_list, pred_list, digits=4)
print(val_report)

              precision    recall  f1-score   support

           0     0.9134    0.8979    0.9056      6522
           1     0.8167    0.9041    0.8582      4370
           2     0.7243    0.5579    0.6303      1149
           3     0.5700    0.4351    0.4935       393

    accuracy                         0.8540     12434
   macro avg     0.7561    0.6987    0.7219     12434
weighted avg     0.8511    0.8540    0.8505     12434

