In [27]:
import zipfile

In [30]:
with zipfile.ZipFile("train.zip", "r") as zip_ref:
    zip_ref.extractall()

## Import Libraries

In [8]:
!apt-get update && apt-get install -y python3-opencv

Ign:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64  InRelease
Hit:2 http://security.ubuntu.com/ubuntu bionic-security InRelease
Hit:3 http://archive.ubuntu.com/ubuntu bionic InRelease                  
Hit:4 http://archive.ubuntu.com/ubuntu bionic-updates InRelease          
Ign:5 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64  InRelease
Hit:6 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64  Release
Hit:7 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64  Release
Hit:8 http://archive.ubuntu.com/ubuntu bionic-backports InRelease
Reading package lists... Done                    
Reading package lists... Done
Building dependency tree       
Reading state information... Done
python3-opencv is already the newest version (3.2.0+dfsg-4ubuntu0.1).
0 upgraded, 0 newly installed, 0 to remove and 77 not upgraded.


In [9]:
!pip install sklearn



In [14]:
!pip3 install opencv-python

Collecting opencv-python
  Downloading opencv_python-4.5.5.62-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (60.4 MB)
[K     |████████████████████████████████| 60.4 MB 57 kB/s s eta 0:00:01
Installing collected packages: opencv-python
Successfully installed opencv-python-4.5.5.62


In [16]:
!pip3 install pandas

Collecting pandas
  Downloading pandas-1.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.7 MB)
[K     |████████████████████████████████| 11.7 MB 15.1 MB/s eta 0:00:01
Installing collected packages: pandas
Successfully installed pandas-1.4.0


In [1]:
import os, torch, copy, cv2, sys, random
# from datetime import datetime, timezone, timedelta
from PIL import Image
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

## Set Arguments & hyperparameters

In [2]:
# 시드(seed) 설정

RANDOM_SEED = 2022

torch.manual_seed(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

In [3]:
# parameters

### 데이터 디렉토리 설정 ###
DATA_DIR= 'data'
NUM_CLS = 3

EPOCHS = 50
BATCH_SIZE = 32
LEARNING_RATE = 0.0005
EARLY_STOPPING_PATIENCE = 20
INPUT_SHAPE = 128

os.environ["CUDA_VISIBLE_DEVICES"]="0"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Dataloader

#### Train & Validation Set loader

In [4]:
class CustomDataset(Dataset):
    def __init__(self, data_dir, mode, input_shape):
        self.data_dir = data_dir
        self.mode = mode
        self.input_shape = input_shape
        
        # Loading dataset
        self.db = self.data_loader()
        
        # Dataset split
        if self.mode == 'train':
            self.db = self.db[:int(len(self.db) * 0.9)]
        elif self.mode == 'val':
            self.db = self.db[int(len(self.db) * 0.9):]
            self.db.reset_index(inplace=True)
        else:
            print(f'!!! Invalid split {self.mode}... !!!')
            
        # Transform function
        self.transform = transforms.Compose([transforms.Resize(self.input_shape),
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    def data_loader(self):
        print('Loading ' + self.mode + ' dataset..')
        if not os.path.isdir(self.data_dir):
            print(f'!!! Cannot find {self.data_dir}... !!!')
            sys.exit()
        
        # (COVID : 1, No : 0)
        db = pd.read_csv(os.path.join(self.data_dir, 'train.csv'))
        
        return db

    def __len__(self):
        return len(self.db)

    def __getitem__(self, index):
        data = copy.deepcopy(self.db.loc[index])

        # Loading image
        cvimg = cv2.imread(os.path.join(self.data_dir,'train',data['file_name']), cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        if not isinstance(cvimg, np.ndarray):
            raise IOError("Fail to read %s" % data['file_name'])

        # Preprocessing images
        trans_image = self.transform(Image.fromarray(cvimg))

        return trans_image, data['COVID']


## Model

In [5]:
import torch.nn.functional as F

class custom_CNN(nn.Module):
    def __init__(self, num_classes):
        super(custom_CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=25, kernel_size=5)
        
        self.fc1 = nn.Linear(in_features=25*29*29, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=num_classes)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x))) # (32, 3, 128, 128) -> (32, 8, 62, 62)
        x = self.pool(F.relu(self.conv2(x))) # (32, 8, 62, 62) -> (32, 25, 29, 29)
        
        x = torch.flatten(x,1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        
        output = self.softmax(x)
        
        return output

## Utils
### EarlyStopper

In [6]:
class LossEarlyStopper():
    """Early stopper
    
    Attributes:
        patience (int): loss가 줄어들지 않아도 학습할 epoch 수
        patience_counter (int): loss 가 줄어들지 않을 때 마다 1씩 증가, 감소 시 0으로 리셋
        min_loss (float): 최소 loss
        stop (bool): True 일 때 학습 중단

    """

    def __init__(self, patience: int)-> None:
        self.patience = patience

        self.patience_counter = 0
        self.min_loss = np.Inf
        self.stop = False
        self.save_model = False

    def check_early_stopping(self, loss: float)-> None:
        """Early stopping 여부 판단"""  

        if self.min_loss == np.Inf:
            self.min_loss = loss
            return None

        elif loss > self.min_loss:
            self.patience_counter += 1
            msg = f"Early stopping counter {self.patience_counter}/{self.patience}"

            if self.patience_counter == self.patience:
                self.stop = True
                
        elif loss <= self.min_loss:
            self.patience_counter = 0
            self.save_model = True
            msg = f"Validation loss decreased {self.min_loss} -> {loss}"
            self.min_loss = loss
        
        print(msg)

### Trainer

In [7]:
class Trainer():
    """ epoch에 대한 학습 및 검증 절차 정의"""
    
    def __init__(self, loss_fn, model, device, metric_fn, optimizer=None, scheduler=None):
        """ 초기화
        """
        self.loss_fn = loss_fn
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.metric_fn = metric_fn

    def train_epoch(self, dataloader, epoch_index):
        """ 한 epoch에서 수행되는 학습 절차"""
        
        self.model.train()
        train_total_loss = 0
        target_lst = []
        pred_lst = []
        prob_lst = []

        for batch_index, (img, label) in enumerate(dataloader):
            img = img.to(self.device)
            label = label.to(self.device).float()
            
            pred = self.model(img)
            
            loss = self.loss_fn(pred[:,1], label)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            
            train_total_loss += loss.item()
            prob_lst.extend(pred[:, 1].cpu().tolist())
            target_lst.extend(label.cpu().tolist())
            pred_lst.extend(pred.argmax(dim=1).cpu().tolist())
        self.train_mean_loss = train_total_loss / batch_index
        self.train_score, f1 = self.metric_fn(y_pred=pred_lst, y_answer=target_lst)
        msg = f'Epoch {epoch_index}, Train loss: {self.train_mean_loss}, Acc: {self.train_score}, F1-Macro: {f1}'
        print(msg)

    def validate_epoch(self, dataloader, epoch_index):
        """ 한 epoch에서 수행되는 검증 절차
        """
        self.model.eval()
        val_total_loss = 0
        target_lst = []
        pred_lst = []
        prob_lst = []

        for batch_index, (img, label) in enumerate(dataloader):
            img = img.to(self.device)
            label = label.to(self.device).float()
            pred = self.model(img)
            
            loss = self.loss_fn(pred[:,1], label)
            val_total_loss += loss.item()
            prob_lst.extend(pred[:, 1].cpu().tolist())
            target_lst.extend(label.cpu().tolist())
            pred_lst.extend(pred.argmax(dim=1).cpu().tolist())
        self.val_mean_loss = val_total_loss / batch_index
        self.validation_score, f1 = self.metric_fn(y_pred=pred_lst, y_answer=target_lst)
        msg = f'Epoch {epoch_index}, Val loss: {self.val_mean_loss}, Acc: {self.validation_score}, F1-Macro: {f1}'
        print(msg)

### Metrics

In [8]:
from sklearn.metrics import accuracy_score, f1_score

def get_metric_fn(y_pred, y_answer):
    """ 성능을 반환하는 함수"""
    
    assert len(y_pred) == len(y_answer), 'The size of prediction and answer are not same.'
    accuracy = accuracy_score(y_answer, y_pred)
    f1 = f1_score(y_answer, y_pred, average='macro')
    return accuracy, f1

## Train
### 학습을 위한 객체 선언

#### Load Dataset & Dataloader

In [9]:
# Load dataset & dataloader
train_dataset = CustomDataset(data_dir=DATA_DIR, mode='train', input_shape=INPUT_SHAPE)
validation_dataset = CustomDataset(data_dir=DATA_DIR, mode='val', input_shape=INPUT_SHAPE)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True)
print('Train set samples:',len(train_dataset),  'Val set samples:', len(validation_dataset))

Loading train dataset..
Loading val dataset..
Train set samples: 581 Val set samples: 65


#### Load model and other utils

In [10]:
# Load Model
model = custom_CNN(NUM_CLS).to(DEVICE)

# # Save Initial Model
# torch.save(model.state_dict(), 'initial.pt')

# Set optimizer, scheduler, loss function, metric function
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler =  optim.lr_scheduler.OneCycleLR(optimizer=optimizer, pct_start=0.1, div_factor=1e5, max_lr=0.0001, epochs=EPOCHS, steps_per_epoch=len(train_dataloader))
loss_fn = nn.BCELoss()
metric_fn = get_metric_fn


# Set trainer
trainer = Trainer(loss_fn, model, DEVICE, metric_fn, optimizer, scheduler)

# Set earlystopper
early_stopper = LossEarlyStopper(patience=EARLY_STOPPING_PATIENCE)

In [11]:
model

custom_CNN(
  (conv1): Conv2d(3, 8, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(8, 25, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=21025, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=3, bias=True)
  (softmax): Softmax(dim=1)
)

### epoch 단위 학습 진행

In [12]:
for epoch_index in tqdm(range(EPOCHS)):

    trainer.train_epoch(train_dataloader, epoch_index)
    trainer.validate_epoch(validation_dataloader, epoch_index)

    # early_stopping check
    early_stopper.check_early_stopping(loss=trainer.val_mean_loss)

    if early_stopper.stop:
        print('Early stopped')
        break

    if early_stopper.save_model:
        check_point = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()
        }
        torch.save(check_point, 'best.pt')


  0% 0/50 [00:00<?, ?it/s]

Epoch 0, Train loss: 0.764556841717826, Acc: 0.5060240963855421, F1-Macro: 0.3436806678449216


  2% 1/50 [00:43<35:23, 43.34s/it]

Epoch 0, Val loss: 1.1947570145130157, Acc: 0.5076923076923077, F1-Macro: 0.336734693877551
Epoch 1, Train loss: 0.7169728941387601, Acc: 0.4664371772805508, F1-Macro: 0.31807511737089206
Epoch 1, Val loss: 1.000713735818863, Acc: 0.5076923076923077, F1-Macro: 0.336734693877551
Validation loss decreased 1.1947570145130157 -> 1.000713735818863


  4% 2/50 [01:26<34:38, 43.31s/it]

Epoch 2, Train loss: 0.6843900812996758, Acc: 0.504302925989673, F1-Macro: 0.40734444066643993
Epoch 2, Val loss: 1.0047152638435364, Acc: 0.5692307692307692, F1-Macro: 0.5190274841437632
Early stopping counter 1/20


  6% 3/50 [02:09<33:50, 43.21s/it]

Epoch 3, Train loss: 0.6564783288372887, Acc: 0.6316695352839932, F1-Macro: 0.6161014227843098
Epoch 3, Val loss: 0.7358929961919785, Acc: 0.6153846153846154, F1-Macro: 0.554672513017265
Validation loss decreased 1.000713735818863 -> 0.7358929961919785


  8% 4/50 [02:52<33:00, 43.06s/it]

Epoch 4, Train loss: 0.5926444315248065, Acc: 0.6953528399311532, F1-Macro: 0.6924960753532182
Epoch 4, Val loss: 0.6825853288173676, Acc: 0.676923076923077, F1-Macro: 0.6351242983159584
Validation loss decreased 0.7358929961919785 -> 0.6825853288173676


 10% 5/50 [03:36<32:28, 43.31s/it]

Epoch 5, Train loss: 0.5512094017532136, Acc: 0.7108433734939759, F1-Macro: 0.7062053023188615
Epoch 5, Val loss: 1.0974988639354706, Acc: 0.6153846153846154, F1-Macro: 0.5656241646618552
Early stopping counter 1/20


 12% 6/50 [04:18<31:35, 43.07s/it]

Epoch 6, Train loss: 0.48244575162728626, Acc: 0.7555938037865749, F1-Macro: 0.7553843781873385
Epoch 6, Val loss: 0.4659470170736313, Acc: 0.8769230769230769, F1-Macro: 0.8761904761904762
Validation loss decreased 0.6825853288173676 -> 0.4659470170736313


 14% 7/50 [05:01<30:53, 43.11s/it]

Epoch 7, Train loss: 0.4233052283525467, Acc: 0.810671256454389, F1-Macro: 0.8106572335987865
Epoch 7, Val loss: 0.493279829621315, Acc: 0.7846153846153846, F1-Macro: 0.7804054054054055
Early stopping counter 1/20


 16% 8/50 [05:45<30:15, 43.23s/it]

Epoch 8, Train loss: 0.3987894919183519, Acc: 0.7934595524956971, F1-Macro: 0.7934589406327764
Epoch 8, Val loss: 0.4904090166091919, Acc: 0.8307692307692308, F1-Macro: 0.8281663061764
Early stopping counter 2/20


 18% 9/50 [06:29<29:42, 43.48s/it]

Epoch 9, Train loss: 0.3555283380879296, Acc: 0.8433734939759037, F1-Macro: 0.5626827409408993
Epoch 9, Val loss: 0.7752856016159058, Acc: 0.8307692307692308, F1-Macro: 0.8293148722845547
Early stopping counter 3/20


 20% 10/50 [07:12<28:58, 43.46s/it]

Epoch 10, Train loss: 0.3194955007897483, Acc: 0.8795180722891566, F1-Macro: 0.8792573152194566
Epoch 10, Val loss: 0.5446661487221718, Acc: 0.8, F1-Macro: 0.7929429061504534
Early stopping counter 4/20


 22% 11/50 [07:56<28:17, 43.53s/it]

Epoch 11, Train loss: 0.26288362261321807, Acc: 0.8984509466437177, F1-Macro: 0.5993691727281198
Epoch 11, Val loss: 0.46485792845487595, Acc: 0.8307692307692308, F1-Macro: 0.8266666666666667
Validation loss decreased 0.4659470170736313 -> 0.46485792845487595


 24% 12/50 [08:39<27:30, 43.43s/it]

Epoch 12, Train loss: 0.2575937095615599, Acc: 0.8898450946643718, F1-Macro: 0.889750705661899
Epoch 12, Val loss: 0.440654993057251, Acc: 0.7384615384615385, F1-Macro: 0.7384615384615384
Validation loss decreased 0.46485792845487595 -> 0.440654993057251


 26% 13/50 [09:23<26:48, 43.48s/it]

Epoch 13, Train loss: 0.21862823474738333, Acc: 0.927710843373494, F1-Macro: 0.9274164762992576
Epoch 13, Val loss: 0.570954717695713, Acc: 0.7692307692307693, F1-Macro: 0.7636363636363637
Early stopping counter 1/20


 28% 14/50 [10:06<26:06, 43.52s/it]

Epoch 14, Train loss: 0.1700389716360304, Acc: 0.9483648881239243, F1-Macro: 0.9481546259280411
Epoch 14, Val loss: 0.4497650238336064, Acc: 0.7692307692307693, F1-Macro: 0.7683535281539557
Early stopping counter 2/20


 30% 15/50 [10:49<25:17, 43.36s/it]

Epoch 15, Train loss: 0.14232690880695978, Acc: 0.9483648881239243, F1-Macro: 0.9482174688057041
Epoch 15, Val loss: 1.2209692597389221, Acc: 0.8, F1-Macro: 0.7998104714522627
Early stopping counter 3/20


 32% 16/50 [11:33<24:38, 43.49s/it]

Epoch 16, Train loss: 0.11204545427527693, Acc: 0.9776247848537005, F1-Macro: 0.9775385484662393
Epoch 16, Val loss: 0.4790608361363411, Acc: 0.7692307692307693, F1-Macro: 0.7690120824449184
Early stopping counter 4/20


 34% 17/50 [12:17<23:57, 43.57s/it]

Epoch 17, Train loss: 0.09748435620632437, Acc: 0.9724612736660929, F1-Macro: 0.97232344885078
Epoch 17, Val loss: 0.47392392964684404, Acc: 0.8, F1-Macro: 0.799239724400095
Early stopping counter 5/20


 36% 18/50 [13:01<23:16, 43.65s/it]

Epoch 18, Train loss: 0.07856279756459925, Acc: 0.9896729776247849, F1-Macro: 0.9896212933190425
Epoch 18, Val loss: 0.5260882608708926, Acc: 0.7692307692307693, F1-Macro: 0.7656813266041816
Early stopping counter 6/20


 38% 19/50 [13:45<22:37, 43.79s/it]

Epoch 19, Train loss: 0.0715867065721088, Acc: 0.9845094664371773, F1-Macro: 0.9844356934287015
Epoch 19, Val loss: 0.48499242169054924, Acc: 0.7692307692307693, F1-Macro: 0.7690120824449183
Early stopping counter 7/20


 40% 20/50 [14:28<21:49, 43.65s/it]

Epoch 20, Train loss: 0.04930533722249998, Acc: 0.9982788296041308, F1-Macro: 0.9982714352442751
Epoch 20, Val loss: 0.5193452819439699, Acc: 0.7692307692307693, F1-Macro: 0.7690120824449183
Early stopping counter 8/20


 42% 21/50 [15:12<21:02, 43.53s/it]

Epoch 21, Train loss: 0.04547811651395427, Acc: 1.0, F1-Macro: 1.0
Epoch 21, Val loss: 0.5671897018328309, Acc: 0.7692307692307693, F1-Macro: 0.7692307692307692
Early stopping counter 9/20


 44% 22/50 [15:55<20:17, 43.50s/it]

Epoch 22, Train loss: 0.0387476057642036, Acc: 1.0, F1-Macro: 1.0
Epoch 22, Val loss: 0.5514565294142812, Acc: 0.7846153846153846, F1-Macro: 0.7845643939393939
Early stopping counter 10/20


 46% 23/50 [16:38<19:34, 43.50s/it]

Epoch 23, Train loss: 0.03563912378417121, Acc: 0.9982788296041308, F1-Macro: 0.998270632603189
Epoch 23, Val loss: 0.7840234413743019, Acc: 0.7692307692307693, F1-Macro: 0.7656813266041816
Early stopping counter 11/20


 48% 24/50 [17:22<18:54, 43.64s/it]

Epoch 24, Train loss: 0.027618575427267287, Acc: 1.0, F1-Macro: 1.0
Epoch 24, Val loss: 0.5746201686561108, Acc: 0.7538461538461538, F1-Macro: 0.7537878787878789
Early stopping counter 12/20


 50% 25/50 [18:07<18:14, 43.78s/it]

Epoch 25, Train loss: 0.01975189955232458, Acc: 1.0, F1-Macro: 1.0
Epoch 25, Val loss: 0.5765942633279337, Acc: 0.7692307692307693, F1-Macro: 0.7692307692307692
Early stopping counter 13/20


 52% 26/50 [18:50<17:29, 43.75s/it]

Epoch 26, Train loss: 0.024368970706644986, Acc: 1.0, F1-Macro: 1.0
Epoch 26, Val loss: 0.9213598072528839, Acc: 0.7692307692307693, F1-Macro: 0.7690120824449184
Early stopping counter 14/20


 54% 27/50 [19:34<16:46, 43.78s/it]

Epoch 27, Train loss: 0.021975598304480728, Acc: 0.9982788296041308, F1-Macro: 0.9982714352442751
Epoch 27, Val loss: 1.920305758714676, Acc: 0.7692307692307693, F1-Macro: 0.7690120824449184
Early stopping counter 15/20


 56% 28/50 [20:18<16:07, 43.97s/it]

Epoch 28, Train loss: 0.016027137223217223, Acc: 1.0, F1-Macro: 1.0
Epoch 28, Val loss: 0.7723749876022339, Acc: 0.8, F1-Macro: 0.799239724400095
Early stopping counter 16/20


 58% 29/50 [21:03<15:26, 44.12s/it]

Epoch 29, Train loss: 0.015790124889463186, Acc: 1.0, F1-Macro: 1.0
Epoch 29, Val loss: 0.6304820190816827, Acc: 0.8, F1-Macro: 0.799239724400095
Early stopping counter 17/20


 60% 30/50 [21:47<14:43, 44.15s/it]

Epoch 30, Train loss: 0.012013095198199153, Acc: 1.0, F1-Macro: 1.0
Epoch 30, Val loss: 0.6167648331029341, Acc: 0.7692307692307693, F1-Macro: 0.7690120824449184
Early stopping counter 18/20


 62% 31/50 [22:31<13:54, 43.93s/it]

Epoch 31, Train loss: 0.01111464907363471, Acc: 1.0, F1-Macro: 1.0
Epoch 31, Val loss: 1.4175937473773956, Acc: 0.7846153846153846, F1-Macro: 0.7841555977229602
Early stopping counter 19/20


 64% 32/50 [23:14<13:10, 43.92s/it]

Epoch 32, Train loss: 0.010770689126931958, Acc: 1.0, F1-Macro: 1.0


 64% 32/50 [23:57<13:28, 44.93s/it]

Epoch 32, Val loss: 0.645515990909189, Acc: 0.7846153846153846, F1-Macro: 0.7841555977229602
Early stopping counter 20/20
Early stopped





## Inference
### 모델 로드

In [13]:
TRAINED_MODEL_PATH = 'best.pt'

### Load dataset

In [14]:
class TestDataset(Dataset):
    def __init__(self, data_dir, input_shape):
        self.data_dir = data_dir
        self.input_shape = input_shape
        
        # Loading dataset
        self.db = self.data_loader()
        
        # Transform function
        self.transform = transforms.Compose([transforms.Resize(self.input_shape),
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    def data_loader(self):
        print('Loading test dataset..')
        if not os.path.isdir(self.data_dir):
            print(f'!!! Cannot find {self.data_dir}... !!!')
            sys.exit()
        
        db = pd.read_csv(os.path.join(self.data_dir, 'sample_submission.csv'))
        return db
    
    def __len__(self):
        return len(self.db)
    
    def __getitem__(self, index):
        data = copy.deepcopy(self.db.loc[index])
        
        # Loading image
        cvimg = cv2.imread(os.path.join(self.data_dir,'test',data['file_name']), cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        if not isinstance(cvimg, np.ndarray):
            raise IOError("Fail to read %s" % data['file_name'])

        # Preprocessing images
        trans_image = self.transform(Image.fromarray(cvimg))

        return trans_image, data['file_name']

In [15]:
# Load dataset & dataloader
test_dataset = TestDataset(data_dir=DATA_DIR, input_shape=INPUT_SHAPE)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Loading test dataset..


### 추론 진행

In [16]:
model.load_state_dict(torch.load(TRAINED_MODEL_PATH)['model'])

# Prediction
file_lst = []
pred_lst = []
prob_lst = []
model.eval()
with torch.no_grad():
    for batch_index, (img, file_num) in tqdm(enumerate(test_dataloader)):
        img = img.to(DEVICE)
        pred = model(img)
        print(pred)
        file_lst.extend(list(file_num))
        pred_lst.extend(pred.argmax(dim=1).tolist())
        prob_lst.extend(pred[:, 1].tolist())

1it [00:01,  1.98s/it]

tensor([[1.3801e-01, 8.5759e-01, 4.3982e-03],
        [9.9975e-01, 2.2514e-04, 2.2397e-05],
        [9.9552e-01, 3.8589e-03, 6.2211e-04],
        [2.4024e-03, 9.9732e-01, 2.7396e-04],
        [9.9582e-01, 3.8751e-03, 3.0537e-04],
        [3.4187e-03, 9.9551e-01, 1.0666e-03],
        [9.0431e-01, 9.2363e-02, 3.3232e-03],
        [3.9819e-03, 9.9514e-01, 8.7793e-04],
        [9.9998e-01, 1.3759e-05, 5.6169e-06],
        [4.7707e-07, 1.0000e+00, 4.7707e-07],
        [7.9055e-01, 1.9400e-01, 1.5453e-02],
        [2.3019e-01, 7.6789e-01, 1.9185e-03],
        [7.4556e-01, 2.4664e-01, 7.7981e-03],
        [9.9996e-01, 2.9626e-05, 9.3462e-06],
        [9.7935e-01, 2.0368e-02, 2.8672e-04],
        [9.9328e-01, 4.3821e-03, 2.3335e-03],
        [8.7028e-02, 9.1082e-01, 2.1515e-03],
        [9.9993e-01, 5.4847e-05, 1.5152e-05],
        [9.4341e-01, 5.6202e-02, 3.8879e-04],
        [9.9657e-01, 2.1727e-03, 1.2566e-03],
        [9.9985e-01, 1.3203e-04, 1.7241e-05],
        [9.9486e-01, 1.9866e-03, 3

2it [00:04,  2.02s/it]

tensor([[9.1338e-01, 8.6042e-02, 5.7739e-04],
        [4.8740e-01, 5.1118e-01, 1.4212e-03],
        [9.9926e-01, 5.4831e-04, 1.9108e-04],
        [6.2832e-01, 3.6371e-01, 7.9713e-03],
        [9.9986e-01, 1.3426e-04, 1.5110e-06],
        [2.2354e-04, 9.9964e-01, 1.3458e-04],
        [8.1557e-01, 1.8050e-01, 3.9229e-03],
        [3.4005e-02, 9.6463e-01, 1.3607e-03],
        [5.1831e-04, 9.9928e-01, 1.9746e-04],
        [1.1261e-01, 8.8409e-01, 3.3070e-03],
        [9.8835e-01, 8.8927e-03, 2.7617e-03],
        [8.8881e-02, 9.1027e-01, 8.5066e-04],
        [4.0766e-01, 5.9106e-01, 1.2794e-03],
        [4.7891e-04, 9.9929e-01, 2.3139e-04],
        [2.6033e-02, 9.7274e-01, 1.2247e-03],
        [1.0000e+00, 1.2000e-06, 1.2000e-06],
        [9.9030e-01, 9.6025e-03, 9.9422e-05],
        [9.7843e-01, 2.0672e-02, 8.9695e-04],
        [9.5035e-02, 9.0239e-01, 2.5742e-03],
        [1.4614e-04, 9.9981e-01, 3.8972e-05],
        [7.3700e-01, 2.5326e-01, 9.7383e-03],
        [9.0691e-01, 8.8441e-02, 4

3it [00:06,  2.07s/it]

tensor([[5.0425e-04, 9.9941e-01, 8.7187e-05],
        [1.5135e-05, 9.9998e-01, 4.7989e-06],
        [2.7440e-01, 7.2495e-01, 6.4874e-04],
        [9.1990e-01, 7.9553e-02, 5.4568e-04],
        [7.2264e-04, 9.9901e-01, 2.7032e-04],
        [9.8680e-01, 1.2784e-02, 4.1208e-04],
        [9.8019e-01, 1.5271e-02, 4.5382e-03],
        [9.9800e-01, 1.7619e-03, 2.4023e-04],
        [9.9995e-01, 1.0613e-05, 3.7890e-05],
        [9.9975e-01, 2.4134e-04, 1.0664e-05],
        [5.6133e-06, 9.9999e-01, 3.7368e-06],
        [9.9999e-01, 7.6593e-06, 3.4212e-06],
        [9.9981e-01, 7.5773e-05, 1.1499e-04],
        [9.5777e-01, 4.2225e-02, 7.5419e-06],
        [2.4648e-02, 9.7514e-01, 2.1342e-04],
        [2.0938e-01, 7.8865e-01, 1.9693e-03],
        [5.2303e-04, 9.9946e-01, 1.3764e-05],
        [9.9244e-01, 6.5077e-03, 1.0513e-03],
        [9.8780e-01, 1.1437e-02, 7.5943e-04],
        [2.1036e-02, 9.7822e-01, 7.4201e-04],
        [1.0619e-01, 8.9042e-01, 3.3894e-03],
        [8.3669e-01, 1.6222e-01, 1

4it [00:06,  1.62s/it]


### 결과 저장

In [17]:
df = pd.DataFrame({'file_name':file_lst, 'COVID':pred_lst})
# df.sort_values(by=['file_name'], inplace=True)
df.to_csv('prediction.csv', index=False)