In [7]:
import math
import os
from itertools import islice
sep = os.sep
import numpy as np
import pydicom
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset, random_split
import torchvision.transforms as tmf

In [8]:
import img_utils as iu
import nnviz as viz
from measurements import ScoreAccumulator

In [9]:
# Hemorrhage sub-type names. NNTrainer.test appends them to the image ID
# ("<file>_<subtype>") when it writes one prediction row per sub-type, in this
# order: prediction index 0 is EPIDURAL ... index 5 is ANY.
EPIDURAL = 'epidural'
INTRAPARENCHYMAL = 'intraparenchymal'
INTRAVENTRICULAR = 'intraventricular'
SUBARACHNOID = 'subarachnoid'
SUBDURAL = 'subdural'
ANY = 'any'

In [10]:
# Input locations, relative to the notebook's working directory. They are
# copied into the SKDB run configuration further down.
# Label CSV and DICOM folder for the training split.
train_mapping_file = '../input/rsna-intracranial-hemorrhage-detection/stage_1_train.csv'
train_images_dir = '../input/rsna-intracranial-hemorrhage-detection/stage_1_train_images/'
# The test "mapping" is the sample submission file; SkullDataset ignores its
# label column whenever mode != 'train'.
test_mapping_file = '../input/rsna-intracranial-hemorrhage-detection/stage_1_sample_submission.csv'
test_images_dir = '../input/rsna-intracranial-hemorrhage-detection/stage_1_test_images/'

In [21]:
# Preprocessing + augmentation pipeline used by SkullDataset.get_train_val_loader.
# It receives the uint8 numpy image produced in SkullDataset.__getitem__.
# NOTE(review): get_train_val_loader splits ONE dataset built with this pipeline,
# so the random flips presumably also apply to the validation subset — confirm
# that this is intended.
train_transforms = tmf.Compose([
    tmf.ToPILImage(),
    # 512x512 is the input size SkullNet's first fully connected layer (4 * 8 * 8) is sized for.
    # interpolation=2 is presumably PIL's bilinear filter code — TODO confirm
    # against the installed torchvision version.
    tmf.Resize((512, 512), interpolation=2),
    tmf.RandomHorizontalFlip(),
    tmf.RandomVerticalFlip(),
    tmf.ToTensor()
])

In [22]:
# Deterministic preprocessing used by SkullDataset.get_test_loader: the same
# resize as train_transforms, without the random flips.
test_transforms = tmf.Compose([
    tmf.ToPILImage(),
    # interpolation=2 is presumably PIL's bilinear filter code — TODO confirm
    # against the installed torchvision version.
    tmf.Resize((512, 512), interpolation=2),
    tmf.ToTensor()
])

In [24]:
class SkullDataset(Dataset):
    """Dataset of DICOM slices, each paired with six 0/1 hemorrhage labels.

    ``indices`` holds ``[file_name, labels]`` pairs built by
    :meth:`load_data_indices`; pixel data is only read in ``__getitem__``.
    ``image_dir`` and ``mapping_file`` must be assigned before loading; the
    ``get_test_loader`` / ``get_train_val_loader`` factories do that.
    """

    def __init__(self, conf=None, mode=None, transforms=None):
        """
        :param conf: run configuration dict; ``conf['Params']['expand_patch_by']``
                     and ``conf['load_lim']`` are read here (both optional)
        :param mode: 'train' keeps the labels from the mapping file, any other
                     value replaces them with 0
        :param transforms: optional callable applied to the uint8 image array
        """
        self.transforms = transforms
        self.mode = mode
        # A missing conf (or a missing 'Params' section) used to fail with
        # AttributeError on the chained .get() calls below.
        self.conf = conf if conf is not None else {}
        self.image_dir = None
        self.mapping_file = None
        self.expand_by = (self.conf.get('Params') or {}).get('expand_patch_by')
        self.indices = []
        # Upper bound on the number of images to index.
        self.LIM = self.conf.get('load_lim', 10e10)

    def load_data_indices(self, validate_pth=False):
        """Fill ``self.indices`` from ``self.mapping_file``.

        The file is read as a header row followed by groups of six
        ``ID_<file>_<subtype>,<label>`` rows, one group per image. A group is
        kept only when all six rows are well formed and refer to the same
        image; incomplete, blank or mismatching groups are skipped instead of
        raising. Loading stops once ``self.LIM`` images are indexed.

        :param validate_pth: when true, skip images whose ``.dcm`` file does
                             not exist under ``self.image_dir``
        """
        rows_per_image = 6
        with open(self.mapping_file) as infile:
            next(infile, None)  # skip the header row; an empty file yields nothing
            linecount = 1
            while True:
                print('Reading Line: {}'.format(linecount), end='\r')

                rows = [r.rstrip().split(',') for r in islice(infile, rows_per_image)]
                if not rows:
                    break  # end of file

                image_file, cat_label = None, []
                for row in rows:
                    # A blank line or a row with extra commas cannot be
                    # unpacked into (name, label): drop the whole group.
                    if len(row) != 2:
                        break
                    hname, label = row
                    name_parts = hname.split('_')
                    if len(name_parts) != 3:
                        break
                    ID, file_ID, _ = name_parts
                    label = int(label) if self.mode == 'train' else 0
                    fname_ = ID + '_' + file_ID + '.dcm'

                    if validate_pth and not os.path.exists(os.path.join(self.image_dir, fname_)):
                        break

                    if image_file and fname_ != image_file:
                        print('Mismatch Line: {}'.format(linecount), end='\r')
                        break
                    image_file = fname_
                    cat_label.append(label)

                if image_file and len(cat_label) == rows_per_image:
                    self.indices.append([image_file, np.array(cat_label)])
                    if len(self) >= self.LIM:
                        break
                linecount += rows_per_image

    def __getitem__(self, index):
        """Read, normalise and transform one slice.

        :return: dict with 'inputs' (image), 'labels' (six-element array) and
                 'index' (position in ``self.indices``)
        """
        image_file, label = self.indices[index]

        dcm = pydicom.dcmread(os.path.join(self.image_dir, image_file))
        # Presumably rescale2d_unsigned maps the pixels into [0, 1] before the
        # scale to 8-bit — confirm against img_utils.
        img_arr = np.array(iu.rescale2d_unsigned(dcm.pixel_array) * 255, dtype=np.uint8)
        img_arr = iu.apply_clahe(img_arr)

        if self.transforms is not None:
            img_arr = self.transforms(img_arr)

        return {'inputs': img_arr, 'labels': label, 'index': index}

    def __len__(self):
        """Number of indexed images."""
        return len(self.indices)

    @classmethod
    def get_test_loader(cls, conf, shuffle_indices=False):
        """Build the test DataLoader described by ``conf``.

        Reads ``conf['Dirs']['test_image_dir']``, ``conf['test_mapping_file']``,
        ``conf['validate_img_pth']`` and ``conf['Params']['batch_size']``.
        """
        testset = cls(conf, 'test', test_transforms)
        testset.image_dir = conf['Dirs']['test_image_dir']
        testset.mapping_file = conf['test_mapping_file']
        testset.load_data_indices(conf['validate_img_pth'])
        return DataLoader(dataset=testset, batch_size=conf['Params']['batch_size'], shuffle=shuffle_indices,
                          num_workers=5, drop_last=False)

    @classmethod
    def get_train_val_loader(cls, conf, shuffle_indices=True, drop_last_batch=True, split_ratio=[0.8, 0.2]):
        """Build the (train, validation) DataLoaders described by ``conf``.

        :param split_ratio: ``split_ratio[0]`` is the fraction of images used
                            for training (rounded up); the validation subset
                            receives every remaining image, so the two sizes
                            always add up to the dataset length
        :return: ``(train_loader, validation_loader)``

        NOTE(review): both subsets come from one dataset built with
        ``train_transforms``, so validation images presumably get the random
        flips too, and ``drop_last_batch`` also drops validation samples —
        confirm that this is intended.
        """
        full_dataset = cls(conf, 'train', train_transforms)
        full_dataset.image_dir = conf['Dirs']['train_image_dir']
        full_dataset.mapping_file = conf['train_mapping_file']
        full_dataset.load_data_indices(conf['validate_img_pth'])

        total = len(full_dataset)
        # Derive the second size from the first: two independent roundings can
        # add up to total +/- 1, which random_split rejects.
        size_a = min(total, math.ceil(split_ratio[0] * total))
        size_b = total - size_a
        dataset_a, dataset_b = random_split(full_dataset, [size_a, size_b])

        loader_a = DataLoader(dataset_a,
                              batch_size=conf['Params']['batch_size'],
                              shuffle=shuffle_indices, num_workers=3,
                              drop_last=drop_last_batch)
        loader_b = DataLoader(dataset_b,
                              batch_size=conf['Params']['batch_size'],
                              shuffle=shuffle_indices, num_workers=3,
                              drop_last=drop_last_batch)
        return loader_a, loader_b

In [25]:
# img_plot = images_arr.copy()
# plt.tight_layout()
# fig, axes = plt.subplots(4, 3, figsize=(10, 18), gridspec_kw = {'wspace':0.01, 'hspace':0.01})
# for i in range(axes.shape[0]):
#     for j in range(axes.shape[1]):
#         axes[i, j].imshow(img_plot.pop(), 'gray')
#         axes[i, j].set_xticklabels([])
#         axes[i, j].set_yticklabels([])
# plt.show()

# Model

In [26]:
from torch import nn


class _DoubleConvolution(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    With the default ``p=0`` every convolution trims one pixel from each
    border, so the block shrinks height and width by 4 in total.
    """

    def __init__(self, in_channels, middle_channel, out_channels, p=0):
        """
        :param in_channels: channels of the incoming feature map
        :param middle_channel: channels produced by the first convolution
        :param out_channels: channels produced by the second convolution
        :param p: padding applied by both convolutions
        """
        super().__init__()
        stages = []
        # Same layer order as always (conv, norm, relu) x 2, so the parameter
        # names under ``encode`` are unchanged.
        for source, target in ((in_channels, middle_channel), (middle_channel, out_channels)):
            stages.append(nn.Conv2d(source, target, kernel_size=3, padding=p))
            stages.append(nn.BatchNorm2d(target))
            stages.append(nn.ReLU(inplace=True))
        self.encode = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution stages."""
        encoded = self.encode(x)
        return encoded


class SkullNet(nn.Module):
    """Convolutional classifier emitting ``num_classes`` scores for each of six labels.

    Six ``_DoubleConvolution`` stages (unpadded) with 2x2 max-pooling between
    them reduce a 512x512 input to a 4-channel 8x8 map, which two linear
    layers turn into an output of shape ``(batch, num_classes, 6)``.

    ``F`` (torch.nn.functional) and ``torch`` are looked up when ``forward`` /
    ``match_and_concat`` run; in this notebook they are imported in later cells.
    """

    def __init__(self, num_channels, num_classes):
        """
        :param num_channels: channels of the input image
        :param num_classes: number of classes scored per label; this used to
                            be ignored and the head was fixed at 2 classes
        """
        super(SkullNet, self).__init__()
        self.reduce_by = 2
        self.num_classes = num_classes
        # One prediction per hemorrhage sub-type (five sub-types plus "any").
        self.num_labels = 6

        def width(channels):
            # Channel counts are the nominal widths divided by reduce_by.
            return int(channels / self.reduce_by)

        self.C1 = _DoubleConvolution(num_channels, width(64), width(64))
        self.C2 = _DoubleConvolution(width(64), width(128), width(128))
        self.C3 = _DoubleConvolution(width(128), width(256), width(256))
        self.C4 = _DoubleConvolution(width(256), width(512), width(256))
        self.C5 = _DoubleConvolution(width(256), width(128), width(128))
        self.C6 = _DoubleConvolution(width(128), width(32), 4)
        self.fc1 = nn.Linear(4 * 8 * 8, 64)
        # 12 outputs for the default two classes, exactly as before.
        self.fc2 = nn.Linear(64, self.num_classes * self.num_labels)

    def forward(self, x):
        """
        :param x: image batch; the linear head is sized for the 4x8x8 map that
                  a 512x512 input produces
        :return: tensor of shape ``(batch, num_classes, 6)``
        """
        c1 = self.C1(x)
        c1_mxp = F.max_pool2d(c1, kernel_size=2, stride=2)

        c2 = self.C2(c1_mxp)
        c2_mxp = F.max_pool2d(c2, kernel_size=2, stride=2)

        c3 = self.C3(c2_mxp)
        c3_mxp = F.max_pool2d(c3, kernel_size=2, stride=2)

        c4 = self.C4(c3_mxp)
        c4_mxp = F.max_pool2d(c4, kernel_size=2, stride=2)

        c5 = self.C5(c4_mxp)
        c5_mxp = F.max_pool2d(c5, kernel_size=2, stride=2)

        c6 = self.C6(c5_mxp)

        # Flatten per sample. view(-1, 256) would silently fold a wrong-sized
        # feature map into a different batch size; keeping the batch dimension
        # makes fc1 raise a shape error instead.
        fc1 = self.fc1(c6.view(c6.shape[0], -1))
        # NOTE(review): there is no activation between fc1 and fc2; left as is
        # so existing checkpoints keep their meaning.
        fc2 = self.fc2(fc1)
        out = fc2.view(fc2.shape[0], self.num_classes, -1)
        return out

    @staticmethod
    def match_and_concat(bypass, upsampled, crop=True):
        """Centre-crop ``bypass`` to ``upsampled``'s height/width and join them on the channel axis.

        Not called by ``forward``.

        :param crop: when false the tensors are concatenated without cropping
        """
        if crop:
            c = (bypass.size()[2] - upsampled.size()[2]) // 2
            # Negative padding removes c pixels from every border.
            bypass = F.pad(bypass, (-c, -c, -c, -c))
        return torch.cat((upsampled, bypass), 1)


# Sanity check: build the one-channel, two-class network and report how many
# trainable parameters it has.
m = SkullNet(1, 2)
torch_total_params = sum(
    weight.numel()
    for weight in m.parameters()
    if weight.requires_grad
)
print('Total Params:', torch_total_params)

Total Params: 1016360


# Training, Validation, and Test Module

In [27]:
import sys

import torch.nn.functional as F


class NNTrainer:
    """Runs training, validation, checkpointing and test-time inference for one run.

    The trainer owns three CSV loggers (train, validation, test/submission)
    opened under ``conf['Dirs']['logs']`` and a checkpoint dict that tracks
    the best validation F1 seen so far.
    """

    def __init__(self, conf=None, model=None, optimizer=None):
        """
        :param conf: run configuration dict; needs 'Dirs', 'Params',
                     'checkpoint_file' and 'f_dyn_weights'
        :param model: network to train or evaluate; moved to the selected device
        :param optimizer: optimizer bound to ``model``'s parameters
        """
        # Initialize parameters and directories before-hand so that we can clearly track which ones are used
        self.conf = conf
        self.log_dir = self.conf.get('Dirs').get('logs', 'net_logs')
        self.epochs = self.conf.get('Params').get('epochs', 100)
        self.log_frequency = self.conf.get('Params').get('log_frequency', 10)
        self.validation_frequency = self.conf.get('Params').get('validation_frequency', 1)
        self.mode = self.conf.get('Params').get('mode', 'test')

        # Initialize necessary logging conf
        self.checkpoint_file = os.path.join(self.log_dir, self.conf.get('checkpoint_file'))

        self.log_headers = self.get_log_headers()
        _log_key = self.conf.get('checkpoint_file').split('.')[0]
        self.test_logger = NNTrainer.get_logger(log_file=os.path.join(self.log_dir, 'submission.csv'),
                                                header=self.log_headers.get('test', ''))
        # Train/validation loggers only exist in 'train' mode.
        if self.mode == 'train':
            self.train_logger = NNTrainer.get_logger(log_file=os.path.join(self.log_dir, _log_key + '-TRAIN.csv'),
                                                     header=self.log_headers.get('train', ''))
            self.val_logger = NNTrainer.get_logger(log_file=os.path.join(self.log_dir, _log_key + '-VAL.csv'),
                                                   header=self.log_headers.get('validation', ''))

        # Callable taking the conf and returning the per-class loss weights;
        # epoch_ce_loss calls it once per batch.
        self.f_dyn_weights = self.conf.get('f_dyn_weights')

        # Handle gpu/cpu
        if torch.cuda.is_available():
            self.device = torch.device("cuda" if self.conf['Params'].get('use_gpu', False) else "cpu")
        else:
            print('### GPU not found.')
            self.device = torch.device("cpu")

        # Initialization to save model
        self.model = model.to(self.device)
        self.optimizer = optimizer
        # 'total_epochs' used to be spelled 'total_epochs:' here, which left a
        # stray key in every saved checkpoint.
        self.checkpoint = {'total_epochs': 0, 'epochs': 0, 'state': None, 'score': 0.0, 'model': 'EMPTY'}
        self.patience = self.conf.get('Params').get('patience', 35)

    def test(self, data_loaders=None):
        """Write one probability row per image and sub-type to the test logger.

        :param data_loaders: a single DataLoader over a SkullDataset (the
                             plural name is historical)
        """
        print('------Running test------')
        label_names = (EPIDURAL, INTRAPARENCHYMAL, INTRAVENTRICULAR, SUBARACHNOID, SUBDURAL, ANY)
        self.model.eval()
        try:
            with torch.no_grad():
                for data in data_loaders:
                    inputs = data['inputs'].to(self.device).float()
                    outputs = self.model(inputs)
                    # Softmax over the class axis; index 1 is the score logged as the label's probability.
                    probabilities = F.softmax(outputs, 1)[:, 1, :]
                    for ix, pred in zip(data['index'], probabilities):
                        file = data_loaders.dataset.indices[ix][0].split('.')[0]
                        rows = [file + '_' + name + ',' + str(pred[k].item())
                                for k, name in enumerate(label_names)]
                        print(file, end='\r')
                        NNTrainer.flush(self.test_logger, '\n'.join(rows))

            self._on_test_end(log_file=self.test_logger.name)
        finally:
            # The old guard (`not logger and not logger.closed`) could never be
            # true for an open file, so the log was never closed.
            NNTrainer._close_logger(self.test_logger)

    def train(self, data_loader=None, validation_loader=None, epoch_run=None):
        """Train for ``self.epochs`` epochs with periodic validation and early stopping.

        :param data_loader: training DataLoader
        :param validation_loader: validation DataLoader
        :param epoch_run: callable running one pass over a loader, e.g. ``self.epoch_ce_loss``
        """
        print('Training...')
        try:
            for epoch in range(1, self.epochs + 1):
                self.model.train()
                self._adjust_learning_rate(epoch=epoch)
                self.checkpoint['total_epochs'] = epoch

                # Run one epoch
                epoch_run(epoch=epoch, data_loader=data_loader, logger=self.train_logger)

                self._on_epoch_end(data_loader=data_loader, log_file=self.train_logger.name)

                # Validation_frequency is the number of epoch until validation
                if epoch % self.validation_frequency == 0:
                    print('############# Running validation... ####################')
                    self.model.eval()
                    with torch.no_grad():
                        self.validation(epoch=epoch, validation_loader=validation_loader, epoch_run=epoch_run)
                    self._on_validation_end(data_loader=validation_loader, log_file=self.val_logger.name)
                    if self.early_stop(patience=self.patience):
                        return
                    print('########################################################')
        finally:
            # Runs on normal completion, early stop and errors alike; the old
            # early-stop return skipped the (also inverted) close checks.
            NNTrainer._close_logger(getattr(self, 'train_logger', None))
            NNTrainer._close_logger(getattr(self, 'val_logger', None))

    def _on_epoch_end(self, **kw):
        """Plot the training log; expects ``log_file`` and ``data_loader`` keywords."""
        viz.plot_column_keys(file=kw['log_file'], batches_per_epoch=len(kw['data_loader']),
                             keys=['F1', 'LOSS', 'ACCURACY'])
        viz.plot_cmap(file=kw['log_file'], save=True, x='PRECISION', y='RECALL')

    def _on_validation_end(self, **kw):
        """Plot the validation log; expects ``log_file`` and ``data_loader`` keywords."""
        viz.plot_column_keys(file=kw['log_file'], batches_per_epoch=len(kw['data_loader']),
                             keys=['F1', 'ACCURACY'])
        viz.plot_cmap(file=kw['log_file'], save=True, x='PRECISION', y='RECALL')

    def _on_test_end(self, **kw):
        """Plot the test log; expects a ``log_file`` keyword.

        NOTE(review): the test log is written with the 'ID,Label' header, yet
        the plots below ask for F1/ACCURACY/PRECISION/RECALL columns — confirm
        how nnviz handles that.
        """
        viz.y_scatter(file=kw['log_file'], y='F1', label='ID', save=True, title='Test')
        viz.y_scatter(file=kw['log_file'], y='ACCURACY', label='ID', save=True, title='Test')
        viz.xy_scatter(file=kw['log_file'], save=True, x='PRECISION', y='RECALL', label='ID', title='Test')

    # Headers for log files
    def get_log_headers(self):
        """Return the CSV header line for each logger.

        epoch_ce_loss writes the same eight fields to the train and the
        validation logger, so both use the eight-column header; the previous
        five-column validation header did not match its rows.
        """
        per_batch = 'ID,EPOCH,BATCH,PRECISION,RECALL,F1,ACCURACY,LOSS'
        return {
            'train': per_batch,
            'validation': per_batch,
            'test': 'ID,Label'
        }

    def validation(self, epoch=None, validation_loader=None, epoch_run=None):
        """Run one pass over the validation loader and checkpoint on a better F1."""
        score_acc = ScoreAccumulator()
        epoch_run(epoch=epoch, data_loader=validation_loader, logger=self.val_logger, score_acc=score_acc)
        p, r, f1, a = score_acc.get_prfa()
        print('>>> PRF1: ', [p, r, f1, a])
        self._save_if_better(score=f1)

    def resume_from_checkpoint(self, parallel_trained=False):
        """Load ``self.checkpoint_file`` and restore the model weights from it.

        :param parallel_trained: strip the ``module.`` prefix that DataParallel
                                 adds to parameter names before loading
        :raises: whatever ``torch.load`` raises (e.g. a missing file); errors
                 while applying the weights are only printed
        """
        # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
        self.checkpoint = torch.load(self.checkpoint_file, map_location=self.device)
        print(self.checkpoint_file, 'Loaded...')
        try:
            state = self.checkpoint['state']
            if parallel_trained:
                from collections import OrderedDict
                prefix = 'module.'
                # Only strip the prefix where it is present; slicing every key
                # by 7 characters mangled keys that never had it.
                state = OrderedDict(
                    (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state.items())
            self.model.load_state_dict(state)
        except Exception as e:
            # Deliberately best-effort: report and continue with the current weights.
            print('ERROR: ' + str(e))

    def _save_if_better(self, score=None):
        """Save a checkpoint when ``score`` beats the best score so far (no-op in test mode)."""

        if self.mode == 'test':
            return

        if score > self.checkpoint['score']:
            print('Score improved: ',
                  str(self.checkpoint['score']) + ' to ' + str(score) + ' BEST CHECKPOINT SAVED')
            self.checkpoint['state'] = self.model.state_dict()
            self.checkpoint['epochs'] = self.checkpoint['total_epochs']
            self.checkpoint['score'] = score
            self.checkpoint['model'] = str(self.model)
            torch.save(self.checkpoint, self.checkpoint_file)
        else:
            print('Score did not improve:' + str(score) + ' BEST: ' + str(self.checkpoint['score']) + ' Best EP: ' + (
                str(self.checkpoint['epochs'])))

    def early_stop(self, patience=35):
        """True once ``patience`` validations have passed since the best checkpoint."""
        return self.checkpoint['total_epochs'] - self.checkpoint['epochs'] >= patience * self.validation_frequency

    @staticmethod
    def get_logger(log_file=None, header=''):
        """Open ``log_file`` for writing and write ``header`` as its first line.

        If the file already exists the user is asked whether to overwrite it;
        answering 'N' or 'n' exits the process.

        :return: the open file object
        """

        if os.path.isfile(log_file):
            print('### CRITICAL!!! "' + log_file + '" already exists.')
            ip = input('Override? [Y/N]: ')
            if ip == 'N' or ip == 'n':
                sys.exit(1)

        file = open(log_file, 'w')
        NNTrainer.flush(file, header)
        return file

    @staticmethod
    def flush(logger, msg):
        """Write ``msg`` plus a newline to ``logger`` and flush it; no-op when ``logger`` is None."""
        if logger is not None:
            logger.write(msg + '\n')
            logger.flush()

    @staticmethod
    def _close_logger(logger):
        """Close ``logger`` if it exists and is still open."""
        if logger is not None and not logger.closed:
            logger.close()

    def _adjust_learning_rate(self, epoch):
        """Every 30th epoch multiply each learning rate that is still >= 1e-5 by 0.7."""
        if epoch % 30 == 0:
            for param_group in self.optimizer.param_groups:
                if param_group['lr'] >= 1e-5:
                    param_group['lr'] = param_group['lr'] * 0.7

    def epoch_ce_loss(self, **kw):
        """
        One pass over a loader with a class-weighted cross-entropy loss
        (log-softmax followed by NLL). In training mode the weights are
        updated and the score accumulator is reset for every batch; otherwise
        the accumulator passed as ``score_acc`` collects the whole pass.
        :param kw: 'epoch', 'data_loader', 'logger' and, when the model is not
                   in training mode, 'score_acc'
        :return: None
        """
        running_loss = 0.0
        score_acc = ScoreAccumulator() if self.model.training else kw.get('score_acc')
        assert isinstance(score_acc, ScoreAccumulator)

        for i, data in enumerate(kw['data_loader'], 1):
            inputs, labels = data['inputs'].to(self.device).float(), data['labels'].to(self.device).long()

            if self.model.training:
                self.optimizer.zero_grad()

            outputs = self.model(inputs)
            _, predicted = torch.max(outputs, 1)
            # The class weights are re-drawn from f_dyn_weights for every batch.
            loss = F.nll_loss(F.log_softmax(outputs, 1), labels,
                              weight=torch.FloatTensor(self.f_dyn_weights(self.conf)).to(self.device))

            if self.model.training:
                loss.backward()
                self.optimizer.step()

            current_loss = loss.item()
            running_loss += current_loss

            if self.model.training:
                score_acc.reset()

            p, r, f1, a = score_acc.add_tensor(predicted, labels).get_prfa()

            if i % self.log_frequency == 0:
                print('Epochs[%d/%d] Batch[%d/%d] loss:%.5f pre:%.3f rec:%.3f f1:%.3f acc:%.3f' %
                      (
                          kw['epoch'], self.epochs, i, len(kw['data_loader']),
                          running_loss / self.log_frequency, p, r, f1,
                          a))
                running_loss = 0.0
            # Row layout: ID (always 0), EPOCH, BATCH, PRECISION, RECALL, F1, ACCURACY, LOSS.
            self.flush(kw['logger'],
                       ','.join(str(x) for x in [0, kw['epoch'], i, p, r, f1, a, current_loss]))

# Training setup

In [28]:
"""
### author: Aashis Khanal
### sraashis@gmail.com
### date: 9/10/2018
"""

import os
import traceback

import torch
import torch.optim as optim


def run(runs):
    """Execute every run configuration in ``runs``.

    For each configuration: create the folders listed under 'Dirs', build the
    model and optimizer, train when ``Params['mode'] == 'train'``, reload the
    saved checkpoint, write the test predictions and finally store the run's
    score in ``<logs>/score.txt``. An exception inside a run is printed and
    does not stop the remaining runs.

    :param runs: iterable of configuration dicts (see ``SKDB``); may be empty
    """
    for R in runs:
        # Every entry of 'Dirs' is created if missing, the image folders included.
        for k, folder in R['Dirs'].items():
            os.makedirs(folder, exist_ok=True)
        R['acc'] = ScoreAccumulator()
        # The checkpoint is named after the second component of the mapping
        # path ('input' for '../input/...').
        # NOTE(review): this raises IndexError when the path holds no os.sep — confirm the intended naming.
        R['checkpoint_file'] = R['train_mapping_file'].split(os.sep)[1] + '.tar'

        model = SkullNet(R['Params']['num_channels'], R['Params']['num_classes'])
        if R['Params']['distribute']:
            model = torch.nn.DataParallel(model)
            model.float()
        # Built once, after the optional wrapping; DataParallel exposes the
        # wrapped module's parameters, so the optimizer sees the same tensors.
        optimizer = optim.Adam(model.parameters(), lr=R['Params']['learning_rate'])

        try:
            trainer = NNTrainer(model=model, conf=R, optimizer=optimizer)
            if R.get('Params').get('mode') == 'train':
                train_loader, val_loader = SkullDataset.get_train_val_loader(R)
                print('### Train Val Batch size:', len(train_loader.dataset), len(val_loader.dataset))
                trainer.train(data_loader=train_loader, validation_loader=val_loader,
                              epoch_run=trainer.epoch_ce_loss)

            test_loader = SkullDataset.get_test_loader(conf=R)
            trainer.resume_from_checkpoint(parallel_trained=R.get('Params').get('parallel_trained'))

            trainer.test(test_loader)
        except Exception:
            # Best effort: report the failure and carry on.
            traceback.print_exc()

        # This used to sit after the loop, where it only covered the last run
        # and raised NameError for an empty `runs`.
        print(R['acc'].get_prfa())
        with open(R['Dirs']['logs'] + os.sep + 'score.txt', "w") as f:
            f.write(', '.join(str(s) for s in R['acc'].get_prfa()))


In [29]:
# Hyper-parameters and switches read by run(), NNTrainer and SkullDataset.
Params = {
    # Passed to SkullNet(num_channels, num_classes) in run().
    'num_channels': 1,
    'num_classes': 2,
    # Batch size of the train, validation and test DataLoaders.
    'batch_size': 16,
    'epochs': 1,
    # Learning rate of the Adam optimizer created in run().
    'learning_rate': 0.001,
    # Only honoured when torch.cuda.is_available() (see NNTrainer.__init__).
    'use_gpu': True,
    # Wrap the model in torch.nn.DataParallel.
    'distribute': True,
    # Not read by any code shown in this notebook.
    'shuffle': True,
    # Print the running loss every `log_frequency` batches.
    'log_frequency': 10,
    # Validate every `validation_frequency` epochs.
    'validation_frequency': 1,
    # 'train' trains before testing; the train/validation loggers only exist in this mode.
    'mode': 'train',
    # Forwarded to NNTrainer.resume_from_checkpoint.
    'parallel_trained': False,
}
# Run configuration handed to run(); one dict describes one experiment.
SKDB = {
    'Params': Params,
    # run() calls os.makedirs(..., exist_ok=True) on every folder listed here;
    # 'logs' receives the log CSVs, the checkpoint and score.txt.
    'Dirs': {
        'train_image_dir': train_images_dir,
        'test_image_dir': test_images_dir,
        'logs': 'logs'},
    'train_mapping_file': train_mapping_file,
    'test_mapping_file': test_mapping_file,
    # Called with the conf once per batch in NNTrainer.epoch_ce_loss; the
    # result is used as the class weights of the loss. This one draws two
    # integers from 1..100 with NumPy's global, unseeded generator, so the
    # weights differ on every call.
    'f_dyn_weights': lambda x: np.random.choice(np.arange(1, 101, 1), 2),
    # Passed to SkullDataset.load_data_indices: check that each .dcm file exists.
    'validate_img_pth': False,
    # Maximum number of images SkullDataset indexes.
    'load_lim': 1000
}

In [30]:
# Execute the single SKDB configuration: train (mode == 'train'), reload the
# saved checkpoint and write the test predictions. NNTrainer.get_logger asks
# on stdin before overwriting log files left by an earlier run.
run([SKDB])

### CRITICAL!!! logs/submission.csv" already exists.
Override? [Y/N]: 
### CRITICAL!!! logs/input-TRAIN.csv" already exists.
Override? [Y/N]: 
### CRITICAL!!! logs/input-VAL.csv" already exists.
Override? [Y/N]: 
### Train Val Batch size: 800 200
Training...
Epochs[1/1] Batch[10/50] loss:0.50157 pre:0.001 rec:0.001 f1:0.001 acc:0.927
Epochs[1/1] Batch[20/50] loss:0.18213 pre:0.001 rec:0.001 f1:0.001 acc:0.802
Epochs[1/1] Batch[30/50] loss:0.37367 pre:0.001 rec:0.001 f1:0.001 acc:0.938


KeyboardInterrupt: 