# Experiment 2: Cross Validation

In [1]:
import os
import sys
import pickle
import glob
import time
from tqdm import tqdm
from collections import Counter

# scikit-learn
from sklearn.metrics import f1_score, accuracy_score
from sklearn.model_selection import StratifiedKFold

# Data preprocessing
import cv2
import numpy as np
import pandas as pd

# data visualization
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
%matplotlib inline

# pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
from torchvision import datasets, transforms
torch.manual_seed(0)
print(f'PyTorch version: {torch.__version__}')

# device setting
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'This notebook use {device}')

# ignore warnings
import warnings
warnings.filterwarnings('ignore')

PyTorch version: 1.7.1
This notebook use cuda:0


In [2]:
# 파일 경로 사용자 정의
class path:
    data = '/opt/ml/input/original_data'
    train = f'{data}/train'
    train_img = f'{train}/images'
    train_df = f'{train}/train.csv'
    test = f'{data}/eval'
    test_img = f'{test}/images'
    test_df = f'{test}/info.csv'

In [3]:
BATCH_SIZE = 16
NUM_WORKERS = 2
LEARNING_RATE = 1e-4
EPOCHS = 3

In [4]:
df = pd.read_csv(f'{path.train}/train_modified.csv')
pd.DataFrame(df['target'].value_counts())

Unnamed: 0,target
4,4085
3,3660
0,2745
1,2050
16,817
10,817
15,732
9,732
12,549
6,549


## 1. Dataset Definition

In [6]:
class MaskDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        
    def set_transform(self, transform):
        self.transform = transform
        
    def __getitem__(self, idx):
        data = self.df.iloc[idx]
        target = data.target
        image = Image.open(data.path)
        
        if self.transform:
            image = self.transform(image)
            
        return image, target
    
    def __len__(self):
        return len(self.df)

In [7]:
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [9]:
train_transforms = transforms.Compose([
    transforms.CenterCrop(384),
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.5, saturation=0.5, hue=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.548, 0.504, 0.479], std=[0.237, 0.247, 0.246]),
    AddGaussianNoise(0., 1.),
])

In [10]:
valid_transforms = transforms.Compose([
    transforms.CenterCrop(384),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.548, 0.504, 0.479], std=[0.237, 0.247, 0.246]),
])

## 2. Modeling

In [12]:
model = torchvision.models.resnet18(pretrained=False)
n_features = model.fc.in_features
model.fc = nn.Linear(n_features, 18)
model = model.cuda()

In [13]:
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss().to(device)

## 3. Training

In [26]:
def test_eval(model, valid_dataset):
    model.eval()
    with torch.no_grad():
        y_true, y_pred = [], []
        for image, label in tqdm(valid_dataset):
            X = image.float().to(device)
            y = label.item()
            _, pred = torch.max(model(X), 1)
            pred = pred.item()
            y_true.append(y)
            y_pred.append(pred)
        y_true, y_pred = np.array(y_true), np.array(y_pred)
        f1 = f1_score(y_true, y_pred, average='macro')
        accuracy = accuracy_score(y_true, y_pred)
    model.train()
    return f1, accuracy

In [27]:
def train_model(train, test, model, criterion, optimizer, print_every=1):
    print(f"============ Training Starts! ============")
    best_accuracy = 0
    for epoch in range(EPOCHS):
        loss_sum = 0
        for images, label in tqdm(train):
            X = images.float().to(device)
            y = label.to(device)
            
            y_pred = model(X)
            loss = criterion(y_pred, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            loss_sum += loss
            
        if ((epoch % print_every) == 0) or (epoch == (EPOCHS - 1)):
            loss_avg = loss_sum / len(train)
            f1, accuracy = test_eval(model, test)
            print(f">> epoch:[{epoch + 1}/{EPOCHS}] cost: {loss_avg:5.3f} test_accuracy: {accuracy:5.3f} test_f1_score: {f1:5.3f}")
            
    print(f"============ Training Done! ============")

In [28]:
def cross_validation(df, model, criterion, optimizer, k_folds=5):
    skf = StratifiedKFold(n_splits=5)
    for n_iter, (train_idx, valid_idx) in enumerate(skf.split(df, df.target), start=1):
        print(f'>> Cross Validation {n_iter} Starts!')
        train, valid = df.loc[train_idx], df.loc[valid_idx]
        train_dataset, valid_dataset = MaskDataset(train), MaskDataset(valid)
        
        # augmentation 설정
        train_dataset.set_transform(train_transforms)
        valid_dataset.set_transform(valid_transforms)
        
        # DataLoader 생성
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
        valid_loader = DataLoader(valid_dataset, shuffle=False)
        
        train_model(train_loader, valid_loader, model, criterion, optimizer)
        print()

In [29]:
df = pd.read_csv(f'{path.train}/train_modified.csv')[['path', 'target']]
cross_validation(df, model, criterion, optimizer)

>> Cross Validation 1 Starts!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=945.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f03dc605430>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/opt/conda/lib/python3.8/multiprocessing/process.py", line 147, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f03dc605430>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3780.0), HTML(value='')))


>> epoch:[1/3] cost: 0.905 test_accuracy: 0.358 test_f1_score: 0.233


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=945.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f03dc605430>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/opt/conda/lib/python3.8/multiprocessing/process.py", line 147, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f03dc605430>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3780.0), HTML(value='')))


>> epoch:[2/3] cost: 0.662 test_accuracy: 0.527 test_f1_score: 0.384


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=945.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f03dc605430>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/opt/conda/lib/python3.8/multiprocessing/process.py", line 147, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f03dc605430>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3780.0), HTML(value='')))


>> epoch:[3/3] cost: 0.540 test_accuracy: 0.579 test_f1_score: 0.412

>> Cross Validation 2 Starts!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=945.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f03dc605430>
Traceback (most recent call last):
      File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
self._shutdown_workers()    
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/opt/conda/lib/python3.8/multiprocessing/process.py", line 147, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f03dc605430>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3780.0), HTML(value='')))


>> epoch:[1/3] cost: 0.498 test_accuracy: 0.653 test_f1_score: 0.542


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=945.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f03dc605430>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/opt/conda/lib/python3.8/multiprocessing/process.py", line 147, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f03dc605430>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3780.0), HTML(value='')))


>> epoch:[2/3] cost: 0.395 test_accuracy: 0.584 test_f1_score: 0.455


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=945.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3780.0), HTML(value='')))


>> epoch:[3/3] cost: 0.329 test_accuracy: 0.696 test_f1_score: 0.568

>> Cross Validation 3 Starts!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=945.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3780.0), HTML(value='')))


>> epoch:[1/3] cost: 0.310 test_accuracy: 0.766 test_f1_score: 0.607


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=945.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3780.0), HTML(value='')))


>> epoch:[2/3] cost: 0.253 test_accuracy: 0.726 test_f1_score: 0.579


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=945.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3780.0), HTML(value='')))


>> epoch:[3/3] cost: 0.208 test_accuracy: 0.788 test_f1_score: 0.632

>> Cross Validation 4 Starts!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=945.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3780.0), HTML(value='')))


>> epoch:[1/3] cost: 0.235 test_accuracy: 0.871 test_f1_score: 0.791


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=945.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3780.0), HTML(value='')))


>> epoch:[2/3] cost: 0.178 test_accuracy: 0.807 test_f1_score: 0.735


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=945.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3780.0), HTML(value='')))


>> epoch:[3/3] cost: 0.150 test_accuracy: 0.845 test_f1_score: 0.762

>> Cross Validation 5 Starts!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=945.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3780.0), HTML(value='')))


>> epoch:[1/3] cost: 0.171 test_accuracy: 0.857 test_f1_score: 0.802


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=945.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3780.0), HTML(value='')))


>> epoch:[2/3] cost: 0.136 test_accuracy: 0.889 test_f1_score: 0.833


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=945.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3780.0), HTML(value='')))


>> epoch:[3/3] cost: 0.121 test_accuracy: 0.830 test_f1_score: 0.720

