In [7]:
import os
import warnings

from tqdm import tqdm
from PIL import Image
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as T

from train import create_net

warnings.filterwarnings("ignore")
np.random.seed(63910)
torch.manual_seed(53152)
torch.cuda.manual_seed_all(7987)
torch.backends.cudnn.deterministic = True

os.environ['CUDA_VISIBLE_DEVICES'] = '2'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [2]:
n_channels, n_classes, batch_size = 1, 4, 128

transform = T.Compose([
    T.Resize(81), # 缩放图片(Image)，保持长宽比不变，最短边为img_size像素
    T.CenterCrop(81), # 从图片中间切出img_size*img_size的图片
    T.ToTensor(), # 将图片(Image)转成Tensor，归一化至[0, 1]
])

train = ImageFolder('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/train/', transform=transform, loader=lambda path: Image.open(path))
val = ImageFolder('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/val/', transform=transform, loader=lambda path: Image.open(path))
train_loader = DataLoader(train, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=False)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=False)

In [3]:
net = create_net(device, 'resnet', 34, n_channels, n_classes, '/nfs3-p2/zsxm/temp_path/single81.pth', entire=True)
net.eval()
print('')

[INFO]: **********************************************************************
Network: ResNet_34
	1 input channels
	4 output channels (classes)
	3D model: False

[INFO]: Model loaded from /nfs3-p2/zsxm/temp_path/single81.pth



In [None]:
idx = 0
score_list = [[]] * n_classes
for imgs, true_categories in tqdm(train_loader, total=len(train_loader), desc='Train Dataset', unit='batch', leave=False):
    imgs.to(device=device, dtype=torch.float32)
    true_categories.to(device=device, dtype=torch.long)
    
    categories_pred = net(imgs)
    labels_pred = categories_pred.argmax(dim=1)
    for i in range(len(true_categories)):
        if labels_pred[i] == true_categories[i]:
            print(true_categories[i].item())
            score_list[true_categories[i].item()].append((categories_pred[i, labels_pred[i]], idx))
        idx += 1

In [12]:
for i in range(n_classes):
    print(len(score_list[i]))

53083
53083
53083
53083


In [13]:
score_list[0] == score_list[1]

True

In [14]:
len(score_list)

4