In [1]:
import argparse
import datetime
import os
import sys

import numpy as np

from torch.utils.tensorboard import SummaryWriter

import torch
import torch.nn as nn
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

sys.path.append(os.path.dirname(os.getcwd()))

from util.util import enumerateWithEstimate
from dsets import LunaDataset
from util.logconf import logging
from model import LunaModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Module-level logger for this notebook/module.
# The level is set once: the earlier INFO assignment was immediately
# overridden by DEBUG, so only the effective (DEBUG) level is kept.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

In [3]:
# Row indices into the per-sample metrics tensor that computeBatchLoss fills
# and logMetrics reads: one row each for the label, the predicted
# probability and the loss. METRICS_SIZE is the total number of rows.
METRICS_LABEL_NDX, METRICS_PRED_NDX, METRICS_LOSS_NDX = range(3)
METRICS_SIZE = 3

In [None]:
class LunaTrainingApp:
    """Command-line style training application for the LUNA classifier.

    The constructor parses CLI arguments, builds the model and optimizer and
    prepares bookkeeping state; ``main()`` runs the train/validate loop and
    logs per-epoch metrics to the module logger and to TensorBoard.
    """

    def __init__(self, sys_argv=None):
        """Parse arguments and build the model and optimizer.

        Args:
            sys_argv: list of argument strings to parse. When ``None``,
                ``sys.argv[1:]`` is used.
        """
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        self.cli_args = self._buildArgParser().parse_args(sys_argv)
        # Timestamp used to give each TensorBoard run a unique directory.
        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')

        # TensorBoard writers are created lazily by initTensorboardWriters().
        self.trn_writer = None
        self.val_writer = None
        # Running number of training samples seen; used as the TensorBoard
        # global step.
        self.totalTrainingSamples_count = 0

        # --augmented switches every augmentation on; the individual
        # --augment-* flags enable them one at a time.
        # NOTE(review): this dict is built but not forwarded to LunaDataset
        # anywhere in this class -- confirm whether the dataset accepts it.
        self.augmentation_dict = {}
        if self.cli_args.augmented or self.cli_args.augment_flip:
            self.augmentation_dict['flip'] = True
        if self.cli_args.augmented or self.cli_args.augment_offset:
            # Empirically chosen value; a better one may exist.
            self.augmentation_dict['offset'] = 0.1
        if self.cli_args.augmented or self.cli_args.augment_scale:
            # Empirically chosen value; a better one may exist.
            self.augmentation_dict['scale'] = 0.2
        if self.cli_args.augmented or self.cli_args.augment_rotate:
            self.augmentation_dict['rotate'] = True
        if self.cli_args.augmented or self.cli_args.augment_noise:
            # Empirically chosen value; a better one may exist.
            self.augmentation_dict['noise'] = 25.0

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        self.model = self.initModel()
        self.optimizer = self.initOptimizer()

    @staticmethod
    def _buildArgParser():
        """Return the ArgumentParser describing every supported CLI option."""
        parser = argparse.ArgumentParser()
        parser.add_argument('--num-workers',
                            help='Number of worker processes for background data loading',
                            default=8,
                            type=int,
                           )
        parser.add_argument('--batch-size',
                            help='Batch size to use for training',
                            default=32,
                            type=int,
                           )
        parser.add_argument('--epochs',
                            help='Number of epochs to train for',
                            default=1,
                            type=int,
                           )
        parser.add_argument('--balanced',
                            help='Balance the training data to half positive, half negative.',
                            action='store_true',
                            default=False,
                           )
        # The augmentation flags below are read by __init__ when it builds
        # augmentation_dict, so they must be declared here.
        parser.add_argument('--augmented',
                            help='Enable every training-data augmentation.',
                            action='store_true',
                            default=False,
                           )
        parser.add_argument('--augment-flip',
                            help='Enable the "flip" training-data augmentation.',
                            action='store_true',
                            default=False,
                           )
        parser.add_argument('--augment-offset',
                            help='Enable the "offset" training-data augmentation.',
                            action='store_true',
                            default=False,
                           )
        parser.add_argument('--augment-scale',
                            help='Enable the "scale" training-data augmentation.',
                            action='store_true',
                            default=False,
                           )
        parser.add_argument('--augment-rotate',
                            help='Enable the "rotate" training-data augmentation.',
                            action='store_true',
                            default=False,
                           )
        parser.add_argument('--augment-noise',
                            help='Enable the "noise" training-data augmentation.',
                            action='store_true',
                            default=False,
                           )
        parser.add_argument('--tb-prefix',
                            default='p2ch11',
                            help='Data prefix to use for Tensorboard run. Defaults to chapter.',
                           )
        parser.add_argument('comment',
                            help='Comment suffix for Tensorboard run.',
                            nargs='?',
                            default='dwlpt',
                           )
        return parser

    def initModel(self):
        """Create the LunaModel and move it to the GPU when CUDA is available."""
        model = LunaModel()
        if self.use_cuda:
            log.info('Using CUDA; {} devices.'.format(torch.cuda.device_count()))
            if torch.cuda.device_count() > 1:  # more than one GPU detected
                model = nn.DataParallel(model)  # wrap the model
            model = model.to(self.device)  # send the model parameters to the GPU
        return model

    def initOptimizer(self):
        """Return the SGD optimizer (lr=0.001, momentum=0.99) over the model parameters."""
        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
        # return Adam(self.model.parameters())

    # Backward-compatible alias for the original (misspelled) method name.
    initOpitmizer = initOptimizer

    def initTrainDl(self):
        """Build the DataLoader over the training split of LunaDataset."""
        train_ds = LunaDataset(  # custom dataset
            val_stride=10,
            isValSet_bool=False,
            ratio_int=int(self.cli_args.balanced),  # Python converts True to 1
        )

        # The per-device batch size is multiplied by the CUDA device count.
        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        train_dl = DataLoader(  # off-the-shelf class
            train_ds,
            batch_size=batch_size,  # batching is handled by the loader
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,  # pinned memory transfers to the GPU quickly
        )

        return train_dl

    def initValDl(self):
        """Build the DataLoader over the validation split of LunaDataset."""
        val_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=True,
        )

        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        val_dl = DataLoader(
            val_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return val_dl

    def initTensorboardWriters(self):
        """Create the training and validation SummaryWriters on first use."""
        if self.trn_writer is None:
            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)

            self.trn_writer = SummaryWriter(
                log_dir=log_dir + '-trn_cls-' + self.cli_args.comment)
            self.val_writer = SummaryWriter(
                log_dir=log_dir + '-val_cls-' + self.cli_args.comment)

    def main(self):
        """Run training and validation for the configured number of epochs."""
        log.info('Starting {}, {}'.format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        val_dl = self.initValDl()  # built much like the training loader

        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)

            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
            self.logMetrics(epoch_ndx, 'val', valMetrics_t)

    def doTraining(self, epoch_ndx, train_dl):
        """Train for one epoch.

        Returns:
            The per-sample metrics tensor of shape
            ``(METRICS_SIZE, len(train_dl.dataset))``, moved to the CPU.
        """
        self.model.train()
        trnMetrics_g = torch.zeros(  # initialise an empty metrics array
            METRICS_SIZE,
            len(train_dl.dataset),
            device=self.device,
        )

        batch_iter = enumerateWithEstimate(  # batch loop with time estimation
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()  # clear any leftover gradients

            loss_var = self.computeBatchLoss(
                batch_ndx,
                batch_tup,
                train_dl.batch_size,
                trnMetrics_g
            )

            # This is where the model weights are actually adjusted.
            loss_var.backward()
            self.optimizer.step()

        self.totalTrainingSamples_count += len(train_dl.dataset)

        return trnMetrics_g.to('cpu')

    def doValidation(self, epoch_ndx, val_dl):
        """Evaluate for one epoch without gradient tracking.

        Returns:
            The per-sample metrics tensor of shape
            ``(METRICS_SIZE, len(val_dl.dataset))``, moved to the CPU.
        """
        with torch.no_grad():
            self.model.eval()  # switch off training-only behaviour
            valMetrics_g = torch.zeros(
                METRICS_SIZE,
                len(val_dl.dataset),
                device=self.device,
            )

            batch_iter = enumerateWithEstimate(
                val_dl,
                "E{} Validation".format(epoch_ndx),
                start_ndx=val_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(
                    batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g
                )

        return valMetrics_g.to('cpu')

    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
        """Run the model on one batch and record per-sample metrics.

        Args:
            batch_ndx: zero-based index of the batch within the epoch.
            batch_tup: 4-tuple ``(input_t, label_t, series_list, center_list)``;
                only the first two items are used here.
            batch_size: nominal loader batch size, used to locate this
                batch's columns in ``metrics_g``.
            metrics_g: metrics tensor, written in place at the columns
                belonging to this batch.

        Returns:
            The mean cross-entropy loss over the batch (a scalar tensor).
        """
        input_t, label_t, _series_list, _center_list = batch_tup

        input_g = input_t.to(self.device, non_blocking=True)
        label_g = label_t.to(self.device, non_blocking=True)

        logits_g, probability_g = self.model(input_g)

        # reduction='none' yields one loss value per sample.
        loss_func = nn.CrossEntropyLoss(reduction='none')
        loss_g = loss_func(
            logits_g,
            label_g[:, 1],  # class index taken from the one-hot label
        )

        # The last batch may be short, so the end index uses the actual
        # number of labels rather than batch_size.
        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_t.size(0)

        # None of the metrics need gradients, so detach them.
        metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_g[:, 1].detach()
        metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_g[:, 1].detach()
        metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_g.detach()

        return loss_g.mean()  # combine the per-sample losses into one value

    def logMetrics(
        self,
        epoch_ndx,
        mode_str,
        metrics_t,
        classificationThreshold=0.5,
    ):
        """Summarise one epoch's metrics and write them to the log and TensorBoard.

        Args:
            epoch_ndx: epoch number used in the log lines.
            mode_str: prefix of the writer attribute to use; ``main()`` passes
                ``'trn'`` or ``'val'`` (resolved as ``<mode_str>_writer``).
            metrics_t: metrics tensor with rows indexed by the ``METRICS_*``
                constants and one column per sample.
            classificationThreshold: labels/predictions at or below this
                value count as negative, above it as positive.
        """
        self.initTensorboardWriters()
        log.info("E{} {}".format(epoch_ndx, type(self).__name__,))

        negLabel_mask = metrics_t[METRICS_LABEL_NDX] <= classificationThreshold
        negPred_mask = metrics_t[METRICS_PRED_NDX] <= classificationThreshold

        posLabel_mask = ~negLabel_mask
        posPred_mask = ~negPred_mask

        neg_count = int(negLabel_mask.sum())  # convert to plain Python ints
        pos_count = int(posLabel_mask.sum())

        trueNeg_count = neg_correct = int((negLabel_mask & negPred_mask).sum())
        truePos_count = pos_correct = int((posLabel_mask & posPred_mask).sum())

        falsePos_count = neg_count - neg_correct
        falseNeg_count = pos_count - pos_correct

        metrics_dict = {}
        metrics_dict['loss/all'] = metrics_t[METRICS_LOSS_NDX].mean()
        metrics_dict['loss/neg'] = metrics_t[METRICS_LOSS_NDX, negLabel_mask].mean()
        metrics_dict['loss/pos'] = metrics_t[METRICS_LOSS_NDX, posLabel_mask].mean()

        # Dividing by np.float32 avoids integer division (and yields nan/inf
        # instead of raising when a count is zero).
        metrics_dict['correct/all'] = (pos_correct + neg_correct) / np.float32(metrics_t.shape[1]) * 100
        metrics_dict['correct/neg'] = neg_correct / np.float32(neg_count) * 100
        metrics_dict['correct/pos'] = pos_correct / np.float32(pos_count) * 100

        precision = metrics_dict['pr/precision'] = truePos_count / np.float32(truePos_count + falsePos_count)
        recall = metrics_dict['pr/recall'] = truePos_count / np.float32(truePos_count + falseNeg_count)

        metrics_dict['pr/f1_score'] = 2 * (precision * recall) / (precision + recall)

        log.info(
            ("E{} {:8} {loss/all:.4f} loss, "
             + "{correct/all:-5.1f}% correct, "
             + "{pr/precision:.4f} precision, "
             + "{pr/recall:.4f} recall, "
             + "{pr/f1_score:.4f} f1 score").format(
                epoch_ndx,
                mode_str,
                **metrics_dict,
            )
        )

        log.info(
            ("E{} {:8} {loss/neg:.4f} loss, "
             + "{correct/neg:-5.1f}% correct, ({neg_correct:} of {neg_count:})").format(
                epoch_ndx,
                mode_str + '_neg',
                neg_correct=neg_correct,
                neg_count=neg_count,
                **metrics_dict,
            )
        )

        log.info(
            ("E{} {:8} {loss/pos:.4f} loss, "
             + "{correct/pos:-5.1f}% correct ({pos_correct:} of {pos_count:})").format(
                epoch_ndx,
                mode_str + '_pos',
                pos_correct=pos_correct,
                pos_count=pos_count,
                **metrics_dict,
            )
        )

        writer = getattr(self, mode_str + '_writer')

        for key, value in metrics_dict.items():
            writer.add_scalar(key, value, self.totalTrainingSamples_count)

        writer.add_pr_curve(
            'pr',
            metrics_t[METRICS_LABEL_NDX],
            metrics_t[METRICS_PRED_NDX],
            self.totalTrainingSamples_count,
        )

        # 51 bin edges spanning 0.0 .. 1.0 in steps of 0.02.
        bins = [x / 50.0 for x in range(51)]

        # Only predictions strictly inside (0.01, 0.99) on the relevant side
        # are included in the histograms.
        negHist_mask = negLabel_mask & (metrics_t[METRICS_PRED_NDX] > 0.01)
        posHist_mask = posLabel_mask & (metrics_t[METRICS_PRED_NDX] < 0.99)

        if negHist_mask.any():
            writer.add_histogram(
                'is_neg',
                metrics_t[METRICS_PRED_NDX, negHist_mask],
                self.totalTrainingSamples_count,
                bins=bins,
            )

        if posHist_mask.any():
            writer.add_histogram(
                'is_pos',
                metrics_t[METRICS_PRED_NDX, posHist_mask],
                self.totalTrainingSamples_count,
                bins=bins,
            )