In [1]:
import os
import sys
import gc
import pickle
import warnings
import pandas as pd
import numpy as np
import seaborn as sns
from typing import *
from tqdm.notebook import tqdm
from pathlib import Path
from matplotlib import pyplot as plt

# Use the fully-qualified option name: the bare pattern 'max_columns' is
# ambiguous on recent pandas releases (it also matches
# 'styler.render.max_columns') and raises OptionError there.
pd.set_option('display.max_columns', 50)
# NOTE: this silences *every* warning for the rest of the session.
warnings.simplefilter('ignore')

In [2]:
# Directory the notebook is executed from; later cells build data/output paths from it.
base_dir = Path().resolve()

# Make the parent directory importable. The membership check keeps the cell
# idempotent, so re-running it does not pile up duplicate sys.path entries.
project_root = str(base_dir / '../')
if project_root not in sys.path:
    sys.path.append(project_root)

# NOTE(review): these star imports presumably supply names used by later cells
# that are not defined in this notebook (e.g. StratifiedKFoldWrapper,
# build_predictor, Classifier, cross_entropy_with_logits) -- confirm and
# consider switching to explicit imports.
from utils.preprocess import *
from utils.model import *

fail to import apex_C: apex was not installed or installed without --cpp_ext.
fail to import amp_C: apex was not installed or installed without --cpp_ext.


In [3]:
def seed_everything(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also switches cuDNN to deterministic mode, disables its auto-tuner and
    releases cached CUDA memory.

    Args:
        seed: Value used to seed every generator.
    """
    # Imported locally: no import visible before this cell provides `random`
    # or `torch`, so the function would otherwise depend on names leaking in
    # through the star imports.
    import random
    import torch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device, not only the current one (training may run on a
    # non-default device such as "cuda:1").
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different kernels between runs.
    torch.backends.cudnn.benchmark = False
    torch.cuda.empty_cache()

In [4]:
from dataclasses import dataclass, field, asdict
import yaml


@dataclass
class Config:
    """Experiment configuration with defaults for every field.

    Instances are normally created with the defaults and then overridden via
    `update`; `to_yaml` dumps the resulting values to disk.
    """

    # General
    debug: bool = False
    outdir: str = "results01"
    device: str = "cuda:1"

    # Data config
    imgdir_name: str = "../../data/ChestXRay14"
    seed: int = 111
    n_splits: int = 10
    label_smoothing: float = 1e-2

    # Model config
    model_name: str = "resnet18"
    model_mode: str = "normal"  # normal, cnn_fixed supported

    # Training config
    epoch: int = 20
    lr: float = 1e-3
    lr_decay: float = 0.9
    batchsize: int = 8
    valid_batchsize: int = 16
    patience: int = 3
    num_workers: int = 4
    snapshot_freq: int = 5
    scheduler_type: str = ""
    scheduler_kwargs: Dict[str, Any] = field(default_factory=lambda: {})
    scheduler_trigger: List[Union[int, str]] = field(default_factory=lambda: [1, "iteration"])
    aug_kwargs: Dict[str, Dict[str, Any]] = field(default_factory=lambda: {})
    mixup_prob: float = 0.

    def update(self, param_dict: Dict) -> "Config":
        """Overwrite fields in place with the values from `param_dict`.

        Args:
            param_dict: Mapping from field name to new value.

        Returns:
            This instance, so the call can be chained.

        Raises:
            ValueError: If a key is not a declared field. All keys are
                validated before any value is assigned, so a failing call
                leaves the instance unchanged.
        """
        # Check against the declared dataclass fields rather than `hasattr`,
        # which would also accept method names such as "update" or "to_yaml"
        # and silently replace them.
        known_fields = self.__dataclass_fields__
        for key in param_dict:
            if key not in known_fields:
                raise ValueError(f"[ERROR] Unexpected key for flag = {key}")
        for key, value in param_dict.items():
            setattr(self, key, value)
        return self

    def to_yaml(self, filepath: str, width: int = 120):
        """Write all fields to `filepath` as YAML.

        Args:
            filepath: Destination file; missing parent directories are created.
            width: Line width forwarded to `yaml.dump`.
        """
        target = Path(filepath)
        # Create the parent directory so a fresh output dir does not fail the dump.
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, 'w') as f:
            yaml.dump(asdict(self), f, width=width)

In [5]:
# Overrides applied on top of the Config defaults for this run.
config_dict = {
    "debug": True,
    # Data Config
    "n_splits": 5,
    # Model
    "model_name": "resnet18",
    # Training
    "num_workers": 4,
    "epoch": 10,
    "batchsize": 32,
    "lr": 1e-4,
    "lr_decay": 0.9,
    "patience": 5,
    "scheduler_type": "CosineAnnealingWarmRestarts",
    "scheduler_kwargs": {"T_0": 7032, 'verbose': True},  # presumably ceil(15000 * 15 epoch / batchsize) with batchsize=32 (= 7031.25); NOTE(review): this comment used to say batchsize=8, which gives 28125 -- confirm
    # The scheduler is stepped once per training iteration, so T_0 is counted in iterations.
    "scheduler_trigger": [1, "iteration"],
    # Augmentation name -> keyword arguments; forwarded to the dataset wrapper as `aug_kwargs`.
    "aug_kwargs": {
        "HorizontalFlip": {"p": 0.5},
        "ShiftScaleRotate": {"scale_limit": 0.15, "rotate_limit": 10, "p": 0.5},
        "RandomBrightnessContrast": {"p": 0.5},
        "CoarseDropout": {"max_holes": 8, "max_height": 25, "max_width": 25, "p": 0.5},
        "Blur": {"blur_limit": [3, 7], "p": 0.5},
        "Downscale": {"scale_min": 0.25, "scale_max": 0.9, "p": 0.3},
        "RandomGamma": {"gamma_limit": [80, 120], "p": 0.6},
    },
    "mixup_prob": 0
}

# `update` raises ValueError for any key that Config does not know about.
config = Config().update(config_dict)

In [6]:
# Build the cross-validation splitter from the configuration values.
# NOTE(review): StratifiedKFoldWrapper is not defined in this notebook;
# presumably it comes from the `utils` star imports -- confirm.
# The training cell iterates over `skf` and unpacks each item as a
# (train_dataset, valid_dataset) pair.
skf = StratifiedKFoldWrapper(
    datadir=base_dir / config.imgdir_name,
    n_splits=config.n_splits,
    shuffle=True,
    seed=config.seed,
    label_smoothing=config.label_smoothing,
    mixup_prob=config.mixup_prob,
    aug_kwargs=config.aug_kwargs,
    debug=config.debug,
    oversample=False
)

In [7]:
# Ensure the output directory exists before writing into it; on a fresh
# checkout the dump would otherwise fail because the directory is missing.
(base_dir / config.outdir).mkdir(parents=True, exist_ok=True)
config.to_yaml(str(base_dir / config.outdir / 'config.yaml'))
# Bare expression: displays the effective configuration as the cell output.
config

Config(debug=True, outdir='results01', device='cuda:1', imgdir_name='../../data/ChestXRay14', seed=111, n_splits=5, label_smoothing=0.01, model_name='resnet18', model_mode='normal', epoch=10, lr=0.0001, lr_decay=0.9, batchsize=32, valid_batchsize=16, patience=5, num_workers=4, snapshot_freq=5, scheduler_type='CosineAnnealingWarmRestarts', scheduler_kwargs={'T_0': 7032, 'verbose': True}, scheduler_trigger=[1, 'iteration'], aug_kwargs={'HorizontalFlip': {'p': 0.5}, 'ShiftScaleRotate': {'scale_limit': 0.15, 'rotate_limit': 10, 'p': 0.5}, 'RandomBrightnessContrast': {'p': 0.5}, 'CoarseDropout': {'max_holes': 8, 'max_height': 25, 'max_width': 25, 'p': 0.5}, 'Blur': {'blur_limit': [3, 7], 'p': 0.5}, 'Downscale': {'scale_min': 0.25, 'scale_max': 0.9, 'p': 0.3}, 'RandomGamma': {'gamma_limit': [80, 120], 'p': 0.6}}, mixup_prob=0)

In [8]:
import dataclasses
import pytorch_pfn_extras.training.extensions as E
import torch
from torch import nn, optim
from torch.utils.data.dataloader import DataLoader
from ignite.engine import Events, Events, create_supervised_evaluator, create_supervised_trainer
from ignite.handlers import EarlyStopping
from ignite.metrics import Accuracy, Loss
from ignite.contrib.handlers import ProgressBar
from ignite.contrib.handlers.param_scheduler import LRScheduler
from ignite.contrib.metrics import AveragePrecision

In [9]:
def score_function(engine):
    """Score for early stopping: the negated 'BCE' metric of `engine`.

    The score is treated as "higher is better", so the loss is returned with
    its sign flipped.
    """
    metrics = engine.state.metrics
    return -metrics['BCE']


def discreted_output_transform(output):
    """Reduce a `(y_pred, y)` pair to index tensors via argmax over the last dim.

    Both the prediction and the target are replaced by the index of their
    largest entry along the last dimension.
    """
    scores, targets = output
    return scores.argmax(dim=-1), targets.argmax(dim=-1)


def probability_output_transform(output):
    """Turn a `(y_pred, y)` pair into (softmax value of column 1, target index).

    `y_pred` is soft-maxed over dim 1 and only column 1 is kept; `y` is
    reduced to the argmax index over dim 1.
    """
    scores, targets = output
    column_one_prob = scores.softmax(dim=1)[:, 1]
    target_index = targets.argmax(dim=1)
    return column_one_prob, target_index


def train(epochs: int, model: nn.Module, train_loader: DataLoader, valid_loader: DataLoader, criterion: Callable,
          device: str, lr: float, patience: int, lr_decay: float, lr_scheduler: str, lr_scheduler_kwargs: Dict[str, Any]):
    """Train `model` with Adam under an ignite trainer and return a per-epoch history.

    After every training epoch the model is evaluated on both loaders; the
    ACC / BCE / AP metrics, the current learning rate and progress counters
    are appended to the history. Early stopping watches the validation 'BCE'.

    Args:
        epochs: Maximum number of training epochs.
        model: Network to optimise; moved to `device` in place.
        train_loader: Batches used for training and for the train-set metrics.
        valid_loader: Batches used for validation metrics and early stopping.
        criterion: Loss callable used both for training and as the 'BCE' metric.
        device: Device string, e.g. "cuda:1".
        lr: Initial learning rate for Adam.
        patience: Evaluations without validation-loss improvement before stopping.
        lr_decay: Currently unused; kept so existing callers keep working.
        lr_scheduler: Class name inside `torch.optim.lr_scheduler`. An empty
            string disables scheduling.
        lr_scheduler_kwargs: Keyword arguments for the scheduler constructor.

    Returns:
        pandas.DataFrame with one row per completed epoch. 'epoch' and
        'iterations' are the trainer's epoch and cumulative iteration
        counters; 'elapsed time' is wall-clock seconds since the handlers
        were set up.
    """
    import time  # local import: only needed for the 'elapsed time' column

    model.to(torch.device(device))
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr)

    trainer = create_supervised_trainer(
        model,
        optimizer,
        criterion,
        device=device
    )

    # An empty scheduler name means "no scheduling" instead of crashing in
    # getattr(optim.lr_scheduler, "").
    scheduler = None
    if lr_scheduler:
        scheduler_cls = getattr(optim.lr_scheduler, lr_scheduler)
        scheduler = LRScheduler(scheduler_cls(optimizer, **(lr_scheduler_kwargs or {})))
        # Stepped after every iteration, not once per epoch.
        trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)

    pbar = ProgressBar(False)
    pbar.attach(trainer)

    def _build_metrics():
        # Fresh metric objects per evaluator so no accumulated state is shared.
        return {
            'ACC': Accuracy(discreted_output_transform),
            'BCE': Loss(criterion),
            'AP': AveragePrecision(probability_output_transform),
        }

    train_evaluator = create_supervised_evaluator(model, metrics=_build_metrics(), device=device)
    valid_evaluator = create_supervised_evaluator(model, metrics=_build_metrics(), device=device)

    columns = ['epoch', 'elapsed time', 'iterations', 'lr', 'train BCE', 'valid BCE',
               'train ACC', 'valid ACC', 'train AP', 'valid AP']
    history = {col: [] for col in columns}
    started_at = time.perf_counter()

    # The three handlers below are registered in the order they must run:
    # the progress message reads values the first two append.
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics

        history['train BCE'].append(metrics['BCE'])
        history['train ACC'].append(metrics['ACC'])
        history['train AP'].append(metrics['AP'])

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        valid_evaluator.run(valid_loader)
        metrics = valid_evaluator.state.metrics

        # Progress counters come from the trainer (`engine`): the evaluator
        # is re-run from scratch each time, so its own epoch counter is
        # always 1 and its epoch_length is the validation loader length.
        history['epoch'].append(engine.state.epoch)
        history['iterations'].append(engine.state.iteration)
        history['elapsed time'].append(time.perf_counter() - started_at)
        if scheduler is not None:
            history['lr'].append(scheduler.get_param())
        else:
            history['lr'].append(optimizer.param_groups[0]['lr'])

        history['valid BCE'].append(metrics['BCE'])
        history['valid ACC'].append(metrics['ACC'])
        history['valid AP'].append(metrics['AP'])

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_progress_bar(engine):
        pbar.log_message(
            f"train BCE: {history['train BCE'][-1]:.2f} "
            f"train ACC: {history['train ACC'][-1]:.2f} "
            f"train AP: {history['train AP'][-1]:.2f} "
            f"valid BCE: {history['valid BCE'][-1]:.2f} "
            f"valid ACC: {history['valid ACC'][-1]:.2f} "
            f"valid AP: {history['valid AP'][-1]:.2f}"
        )

    # Early stopping: evaluated each time the validation evaluator finishes
    # its pass; `score_function` returns the negated validation BCE.
    handler = EarlyStopping(patience=patience, score_function=score_function, trainer=trainer)
    valid_evaluator.add_event_handler(Events.EPOCH_COMPLETED, handler)

    trainer.run(train_loader, max_epochs=epochs)
    return pd.DataFrame(history)

In [10]:
# Train one model per cross-validation fold and store its history and weights.
for fold, (train_dataset, valid_dataset) in enumerate(skf):
    seed_everything(seed=config.seed)

    # Create the output directory *before* training. Creating it afterwards
    # with os.mkdir meant a missing parent or an already existing fold
    # directory (e.g. on a re-run) raised only after training had finished,
    # discarding the trained model. Existing files in the directory are
    # overwritten.
    fold_dir = base_dir / config.outdir / f'fold-{fold + 1}'
    fold_dir.mkdir(parents=True, exist_ok=True)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batchsize,
        num_workers=config.num_workers,
        shuffle=True,
        pin_memory=True,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=config.valid_batchsize,
        num_workers=config.num_workers,
        shuffle=False,
        pin_memory=True,
    )

    predictor = build_predictor(model_name=config.model_name, model_mode=config.model_mode)
    model = Classifier(predictor)

    history_df = train(
        epochs=config.epoch,
        model=model,
        train_loader=train_loader,
        valid_loader=valid_loader,
        criterion=cross_entropy_with_logits,
        device=config.device,
        lr=config.lr,
        lr_decay=config.lr_decay,
        patience=config.patience,
        lr_scheduler=config.scheduler_type,
        lr_scheduler_kwargs=config.scheduler_kwargs
    )

    history_df.to_csv(str(fold_dir / 'history.csv'))
    # Weights of the final epoch, not necessarily the best-scoring one.
    torch.save(model.state_dict(), str(fold_dir / 'model_last.pt'))

    # Drop references so memory can be reclaimed before the next fold.
    del model, predictor, history_df, train_dataset, valid_dataset, train_loader, valid_loader
    gc.collect()

    # NOTE(review): unconditional break -- only the first fold is ever
    # trained. Presumably a debugging shortcut; confirm before a full run.
    break

Epoch     0: adjusting learning rate of group 0 to 1.0000e-04.


[1/32]   3%|3          [00:00<?]

train BCE: 0.61 train ACC: 0.71 train AP: 0.78 valid BCE: 0.65 valid ACC: 0.71 valid AP: 0.77


[1/32]   3%|3          [00:00<?]

train BCE: 0.60 train ACC: 0.72 train AP: 0.83 valid BCE: 0.64 valid ACC: 0.72 valid AP: 0.79


[1/32]   3%|3          [00:00<?]

train BCE: 0.53 train ACC: 0.76 train AP: 0.86 valid BCE: 0.58 valid ACC: 0.72 valid AP: 0.79


[1/32]   3%|3          [00:00<?]

train BCE: 0.52 train ACC: 0.76 train AP: 0.86 valid BCE: 0.58 valid ACC: 0.72 valid AP: 0.80


[1/32]   3%|3          [00:00<?]

train BCE: 0.50 train ACC: 0.77 train AP: 0.89 valid BCE: 0.58 valid ACC: 0.71 valid AP: 0.81


[1/32]   3%|3          [00:00<?]

train BCE: 0.50 train ACC: 0.78 train AP: 0.90 valid BCE: 0.61 valid ACC: 0.72 valid AP: 0.83


[1/32]   3%|3          [00:00<?]

train BCE: 0.47 train ACC: 0.77 train AP: 0.92 valid BCE: 0.66 valid ACC: 0.72 valid AP: 0.82


[1/32]   3%|3          [00:00<?]

train BCE: 0.48 train ACC: 0.80 train AP: 0.91 valid BCE: 0.68 valid ACC: 0.72 valid AP: 0.81


[1/32]   3%|3          [00:00<?]

train BCE: 0.51 train ACC: 0.76 train AP: 0.92 valid BCE: 0.75 valid ACC: 0.72 valid AP: 0.80


[1/32]   3%|3          [00:00<?]

2021-03-03 11:50:31,945 ignite.handlers.early_stopping.EarlyStopping INFO: EarlyStopping: Stop training


train BCE: 0.50 train ACC: 0.76 train AP: 0.94 valid BCE: 0.84 valid ACC: 0.54 valid AP: 0.79
