### SE-ResNeXt50_32x4d評価
#### 前提条件
* `DAY` には学習時に指定したものと同じ値を指定すること(学習済みモデルを `progress` + `DAY` + `/model/fold{i}_best.pth` から読み込むため)

In [1]:
import os
import glob
import random
from PIL import Image

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms as T
from torch.utils.data import DataLoader
import albumentations as A
import pretrainedmodels

In [2]:
# Pick the first CUDA device when one is available; otherwise fall back to
# the plain string "cpu".
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = "cpu"
# Number of CUDA devices visible to PyTorch (0 on a CPU-only machine).
count = torch.cuda.device_count()
def fix_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed passed unchanged to each generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)


# Seed everything once up front so repeated runs are reproducible.
SEED = 42
fix_seed(SEED)

In [5]:
def make_datapath_list(root_path):
    """Return the sorted file names of the ``*.jpeg`` files in ``data/<root_path>``.

    Args:
        root_path: Sub-directory of ``data``. A leading path separator
            (as in ``'/crip_test'``) is accepted and ignored.

    Returns:
        list[str]: Base names (without the directory part) of the matching
        files in sorted order; an empty list when nothing matches.
    """
    # Strip the leading separator so os.path.join does not treat root_path as
    # an absolute path and throw away the "data" prefix.
    sub_dir = root_path.lstrip("/" + os.sep)
    pattern = os.path.join("data", sub_dir, "*.jpeg")
    return [os.path.basename(p) for p in sorted(glob.glob(pattern))]

In [6]:
# Locations of this run's artifacts: <BASE1><DAY>/model and <BASE1><DAY>/pic.
BASE1 = "progress"
DAY = "/test"  # must be the same name that was used at training time
MODEL_PATH = f"{BASE1}{DAY}/model"
PIC_PATH = f"{BASE1}{DAY}/pic"
# Create both directories up front; existing ones are left untouched.
for out_dir in (MODEL_PATH, PIC_PATH):
    os.makedirs(out_dir, exist_ok=True)

In [7]:
def _load_image(path):
    """Open the image at *path*, read its pixel data and release the file."""
    # Image.open is lazy: it keeps the file open until the pixels are read.
    # Loading inside the with-block means only one file handle is open at a
    # time, instead of one per test image.
    with Image.open(path) as img:
        img.load()
    return img


# File names of the test images, sorted, then the decoded images themselves.
test = make_datapath_list('/crip_test')
test_data = [_load_image('data/crip_test/' + i) for i in test]

# Resize every image to a common size and stack them into one uint8 array.
# convert("RGB") guarantees three channels per image so that images stored in
# another mode (e.g. grayscale) do not make the stacked array ragged.
size = (224, 224)
test_resize = np.array(
    [
        np.array(img.convert("RGB").resize(size, Image.LANCZOS), dtype="uint8")
        for img in test_data
    ]
)

In [8]:
class MyDataset3(torch.utils.data.Dataset):
    """Dataset over in-memory image arrays with optional labels.

    Each item is converted to a PIL image, passed through ``P`` and then
    ``A``, and returned as a float32 array together with its label.

    Args:
        file_list: Indexable collection of image arrays accepted by
            ``Image.fromarray``.
        label_list: Optional labels aligned with ``file_list``. When omitted
            every item gets the placeholder label ``0``.
        P: Optional callable applied to the PIL image. The ``A`` step
            transposes its result with ``(1, 2, 0)``, so it presumably returns
            a channel-first tensor -- confirm against the caller.
        A: Optional callable invoked as ``A(image=array)`` that returns a
            mapping with an ``'image'`` entry.

    Note:
        The final ``transpose(2, 1, 0)`` / ``astype`` calls are NumPy array
        operations, so a configuration that leaves a tensor or PIL image at
        that point (for example ``A=None``) is not supported.
    """

    def __init__(self, file_list, label_list = None, P = None, A = None):
        self.P = P
        self.A = A
        self.data = file_list
        self.label = label_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = Image.fromarray(self.data[idx])
        label = self.label[idx] if self.label is not None else 0
        # Compare against None instead of relying on truthiness: a transform
        # container that defines __len__ (an empty Compose, for instance) is
        # falsy and would silently skip its step.
        if self.P is not None:
            data = self.P(data)
        if self.A is not None:
            # Channel-first -> channel-last before handing the array to A.
            data = np.asarray(data).transpose(1, 2, 0)
            data = self.A(image=data)['image']
        # NOTE(review): (2, 1, 0) reverses the axes, so a (H, W, C) array
        # becomes (C, W, H) -- height and width are swapped. Left unchanged
        # because the saved models were presumably trained through the same
        # code path; confirm before switching to (2, 0, 1).
        data = data.transpose(2, 1, 0)
        return data.astype('f'), label

In [9]:
# Test-time preprocessing: convert to a tensor, then normalize each channel.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]
test_tf = T.Compose([T.ToTensor(), T.Normalize(mean=NORM_MEAN, std=NORM_STD)])

# Empty pipeline: no albumentations transforms are applied at test time.
test_A = A.Compose([])

# Batches of 8 in the original order, so outputs line up with the input list.
test_set = MyDataset3(test_resize, P=test_tf, A=test_A)
test_loader = DataLoader(test_set, batch_size=8, shuffle=False)

In [10]:
# Look up the model constructor by name in the pretrainedmodels package and
# build it with the 'imagenet' pretrained setting and a 1000-way output.
model_name = 'se_resnext50_32x4d'
model_factory = pretrainedmodels.__dict__[model_name]
resnet_model = model_factory(num_classes=1000, pretrained='imagenet')

In [11]:
class ResNet(nn.Module):
    """Backbone followed by a two-layer classification head.

    The backbone output (expected to have 1000 features, matching ``fc1``) is
    mapped 1000 -> 100 -> ``class_num`` with ReLU and dropout in between.

    Args:
        pretrained_res_model: Backbone module whose output feeds ``fc1``.
        class_num: Number of output classes.
    """

    def __init__(self, pretrained_res_model, class_num):
        super().__init__()
        # Attribute names are part of the state_dict keys, so they must stay
        # as they are for saved weights to load.
        self.eff = pretrained_res_model
        self.fc1 = nn.Linear(1000, 100)
        self.fc2 = nn.Linear(100, class_num)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, input_ids):
        """Return the raw (pre-softmax) class scores for a batch of inputs."""
        features = self.eff(input_ids)
        hidden = self.dropout(self.relu(self.fc1(features)))
        return self.fc2(hidden)

In [15]:
# Score the test set with each fold's best checkpoint and keep, per fold, the
# softmax probability of class index 1 for every test image.
N_FOLDS = 5  # checkpoints fold0_best.pth .. fold4_best.pth are read from MODEL_PATH
m = nn.Softmax(dim=1)  # parameter-free, so one instance is shared by all folds
final_softmax = {}
for i in range(N_FOLDS):
    # Every fold wraps the same resnet_model object; its weights are replaced
    # by the checkpoint loaded below.
    model = ResNet(resnet_model, 2)
    for param in model.parameters():
        param.requires_grad = False
    model_path = MODEL_PATH + f'/fold{i}_best.pth'
    # map_location remaps the stored tensors onto the device in use, so a
    # checkpoint saved from a CUDA device also loads on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    check_input = []
    check_output = []
    check_softmax = []
    with torch.no_grad():
        for test_imgs, test_labels in test_loader:
            test_imgs = test_imgs.to(device)
            test_output = model(test_imgs)
            # Under no_grad the outputs carry no graph, so no detach is needed.
            test_softmax = m(test_output.cpu())
            check_input.append(test_imgs.cpu())
            check_output.append(test_output.cpu())
            check_softmax.append(test_softmax)

    # Concatenate the per-batch results; inputs are moved to channel-last.
    show_input = torch.cat(check_input).numpy().transpose(0, 2, 3, 1)
    show_output = torch.cat(check_output).numpy()
    show_softmax = torch.cat(check_softmax).numpy()

    # Column 1 of the softmax output is the probability of class index 1.
    final_softmax[f"model{i}"] = show_softmax[:, 1]

In [17]:
# Write the per-fold probabilities (one column per model) to csv/res.csv.
# The output directory is created first, as is done for MODEL_PATH and
# PIC_PATH, so the export does not fail on a fresh checkout.
os.makedirs('csv', exist_ok=True)
pd.DataFrame(final_softmax).to_csv('csv/res.csv', index=False)