In [1]:
import os
import sys
import gc
import pickle
import warnings
import pandas as pd
import numpy as np
import seaborn as sns
from typing import *
from tqdm.notebook import tqdm
from pathlib import Path
from matplotlib import pyplot as plt

# Use the fully qualified option name: since pandas 1.4 the bare pattern
# 'max_columns' also matches 'styler.render.max_columns', and set_option raises
# OptionError ("Pattern matched multiple keys").
pd.set_option('display.max_columns', 50)
# NOTE(review): this silences *all* warnings (including deprecations) for the session.
warnings.simplefilter('ignore')

In [4]:
# Current working directory (the notebook's directory when launched from Jupyter).
base_dir = Path().resolve()

# Make the parent directory importable so the `utils.*` imports below resolve.
# Guarded so re-running this cell does not keep appending duplicate entries.
_project_root = str(base_dir.parent)
if _project_root not in sys.path:
    sys.path.append(_project_root)

from utils.preprocess import *
from utils.model import *
from utils.train import *

In [12]:
def seed_everything(seed: int) -> None:
    """Seed every RNG used during training so runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and every
    CUDA device), and switches cuDNN to deterministic, non-autotuned mode.
    Also releases cached CUDA memory, which is convenient because this is
    called at the start of each fold.

    Args:
        seed: Seed value applied to all RNGs.
    """
    # Imported locally: the notebook's import cell imports neither `random` nor
    # `torch`, so without this the function relies on the `utils` star imports.
    import random
    import torch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all visible GPUs, not just the current one (training may run on e.g. "cuda:2").
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may pick different algorithms run to run; disable it too.
    torch.backends.cudnn.benchmark = False
    torch.cuda.empty_cache()

In [5]:
from dataclasses import dataclass, field, asdict
import yaml


@dataclass
class Config:
    """Experiment configuration: general, data, model and training settings.

    Build with defaults, override via :meth:`update`, and snapshot the final
    values with :meth:`to_yaml`.
    """

    # General
    debug: bool = False
    outdir: str = "results02"
    device: str = "cuda:2"

    # Data config
    imgdir_name: str = "../../data/ChestXRay14"
    seed: int = 111
    n_splits: int = 10
    label_smoothing: float = 1e-2

    # Model config
    model_name: str = "resnet18"
    model_mode: str = "normal"  # normal, cnn_fixed supported

    # Training config
    epoch: int = 20
    lr: float = 1e-3
    lr_decay: float = 0.9
    batchsize: int = 8
    valid_batchsize: int = 16
    patience: int = 3
    num_workers: int = 4
    snapshot_freq: int = 5
    scheduler_type: str = ""
    scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)
    aug_kwargs: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    mixup_prob: float = 0.
    oversample: bool = True

    def update(self, param_dict: Dict) -> "Config":
        """Overwrite fields in place from ``param_dict`` and return ``self``.

        All keys are validated before anything is assigned, so a bad key
        leaves the config untouched.

        Raises:
            ValueError: if a key is not a declared dataclass field.
        """
        from dataclasses import fields

        # Validate against declared fields rather than `hasattr`: `hasattr` also
        # accepts method names ("update", "to_yaml"), which setattr would shadow.
        valid_keys = {f.name for f in fields(self)}
        unknown = [key for key in param_dict if key not in valid_keys]
        if unknown:
            raise ValueError(f"[ERROR] Unexpected key for flag = {unknown[0]}")
        for key, value in param_dict.items():
            setattr(self, key, value)
        return self

    def to_yaml(self, filepath: str, width: int = 120):
        """Dump all fields to a YAML file, creating parent directories if needed."""
        Path(filepath).parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, 'w') as f:
            yaml.dump(asdict(self), f, width=width)

In [6]:
# Cosine-restart cycle length T_0 in scheduler steps:
# ~15000 samples/epoch x 15 epochs / batchsize, rounded up (= 7032 at batchsize 32).
# Derived here so it cannot drift from "batchsize" again.
# NOTE(review): 15000 presumably approximates the training-set size, and the
# scheduler is presumably stepped once per batch -- confirm in utils.train.
BATCHSIZE = 32
T0_STEPS = (15000 * 15 + BATCHSIZE - 1) // BATCHSIZE

config_dict = {
    "debug": True,
    # Data config
    "n_splits": 5,
    "label_smoothing": 0,
    # Model
    "model_name": "resnet18",
    # Training
    "num_workers": 4,
    "epoch": 25,
    "batchsize": BATCHSIZE,
    "lr": 1e-4,
    "lr_decay": 0.9,
    "patience": 5,
    "scheduler_type": "CosineAnnealingWarmRestarts",
    # NOTE: 'verbose' is deprecated for LR schedulers in recent PyTorch releases.
    "scheduler_kwargs": {"T_0": T0_STEPS, 'verbose': True},
    "aug_kwargs": {
        "HorizontalFlip": {"p": 0.5},
        "ShiftScaleRotate": {"scale_limit": 0.15, "rotate_limit": 10, "p": 0.5},
        "RandomBrightnessContrast": {"p": 0.5},
        "CoarseDropout": {"max_holes": 8, "max_height": 25, "max_width": 25, "p": 0.5},
        "Blur": {"blur_limit": [3, 7], "p": 0.5},
        "Downscale": {"scale_min": 0.25, "scale_max": 0.9, "p": 0.3},
        "RandomGamma": {"gamma_limit": [80, 120], "p": 0.6},
        "Normalize": {},
    },
    "mixup_prob": 0,
    "oversample": False,
}

config = Config().update(config_dict)

In [7]:
# Cross-validation fold iterator; iterated below as (train_dataset, valid_dataset) pairs.
skf = StratifiedKFoldWrapper(
    # data source
    datadir=base_dir / config.imgdir_name,
    # fold split
    n_splits=config.n_splits,
    seed=config.seed,
    shuffle=True,
    # options forwarded from the config
    debug=config.debug,
    oversample=config.oversample,
    label_smoothing=config.label_smoothing,
    mixup_prob=config.mixup_prob,
    aug_kwargs=config.aug_kwargs,
)

In [8]:
# Snapshot the resolved config into the output directory.
# Create the directory if needed. Before writing anything, refuse to continue if it
# already holds fold results, so a re-run cannot overwrite a finished run's config.yaml.
results_dir = base_dir / config.outdir
results_dir.mkdir(parents=True, exist_ok=True)
assert not any('fold' in p.name for p in results_dir.iterdir()), \
    f"{results_dir} already contains fold outputs; choose a new `outdir`"
config.to_yaml(str(results_dir / 'config.yaml'))
config

Config(debug=True, outdir='results02', device='cuda:2', imgdir_name='../../data/ChestXRay14', seed=111, n_splits=5, label_smoothing=0, model_name='resnet18', model_mode='normal', epoch=25, lr=0.0001, lr_decay=0.9, batchsize=32, valid_batchsize=16, patience=5, num_workers=4, snapshot_freq=5, scheduler_type='CosineAnnealingWarmRestarts', scheduler_kwargs={'T_0': 7032, 'verbose': True}, aug_kwargs={'HorizontalFlip': {'p': 0.5}, 'ShiftScaleRotate': {'scale_limit': 0.15, 'rotate_limit': 10, 'p': 0.5}, 'RandomBrightnessContrast': {'p': 0.5}, 'CoarseDropout': {'max_holes': 8, 'max_height': 25, 'max_width': 25, 'p': 0.5}, 'Blur': {'blur_limit': [3, 7], 'p': 0.5}, 'Downscale': {'scale_min': 0.25, 'scale_max': 0.9, 'p': 0.3}, 'RandomGamma': {'gamma_limit': [80, 120], 'p': 0.6}, 'Normalize': {}}, mixup_prob=0, oversample=False)

In [None]:
# check outdir: it must not already contain entries whose name includes 'fold'
# (the training loop writes `fold-N` directories). A missing directory counts as empty.
_outdir = base_dir / config.outdir
assert not _outdir.exists() or not any('fold' in p.name for p in _outdir.iterdir()), \
    f"{_outdir} already contains fold outputs; choose a new `outdir`"

In [13]:
from torch.utils.data.dataloader import DataLoader


for fold, (train_dataset, valid_dataset) in enumerate(skf):
    # Re-seed per fold so every fold starts from the same RNG state.
    seed_everything(seed=config.seed)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batchsize,
        num_workers=config.num_workers,
        shuffle=True,
        pin_memory=True,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=config.valid_batchsize,
        num_workers=config.num_workers,
        shuffle=False,
        pin_memory=True,
    )

    predictor = build_predictor(model_name=config.model_name, model_mode=config.model_mode)
    model = Classifier(predictor)

    history_df = train(
        epochs=config.epoch,
        model=model,
        train_loader=train_loader,
        valid_loader=valid_loader,
        criterion=cross_entropy_with_logits,
        device=config.device,
        lr=config.lr,
        lr_decay=config.lr_decay,
        patience=config.patience,
        lr_scheduler=config.scheduler_type,
        lr_scheduler_kwargs=config.scheduler_kwargs
    )

    # Outputs go to <outdir>/fold-<k>/ (1-based). parents/exist_ok so a missing
    # outdir or a leftover fold dir cannot crash *after* training and lose the model.
    fold_dir = base_dir / config.outdir / f'fold-{fold + 1}'
    fold_dir.mkdir(parents=True, exist_ok=True)
    history_df.to_csv(str(fold_dir / 'history.csv'))
    torch.save(model.state_dict(), str(fold_dir / 'model_last.pt'))

    del model, predictor, history_df, train_dataset, valid_dataset, train_loader, valid_loader
    gc.collect()

    # Debug runs stop after the first fold; full runs train every fold.
    if config.debug:
        break

Epoch     0: adjusting learning rate of group 0 to 1.0000e-04.


[1/32]   3%|3          [00:00<?]

train BCE: 0.57 train ACC: 0.73 train AP: 0.81 valid BCE: 2.84 valid ACC: 0.30 valid AP: 0.76


[1/32]   3%|3          [00:00<?]

train BCE: 0.57 train ACC: 0.73 train AP: 0.81 valid BCE: 35.32 valid ACC: 0.26 valid AP: 0.73


[1/32]   3%|3          [00:00<?]

Engine run is terminating due to exception: 
Engine run is terminating due to exception: 


KeyboardInterrupt: 