In [1]:
import os
import gc
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
from typing import *
from tqdm.notebook import tqdm
from pathlib import Path
from matplotlib import pyplot as plt

# Show up to 50 columns when a DataFrame is rendered.
# The fully-qualified key is used on purpose: option names are matched as
# patterns, and the bare 'max_columns' is ambiguous (OptionError) on pandas
# versions that also define 'styler.render.max_columns'.
pd.set_option('display.max_columns', 50)

In [2]:
from dataclasses import dataclass, field, asdict
import yaml


@dataclass
class Config:
    """Settings for this evaluation notebook.

    Every field has a default; `update` overrides a subset of them from a
    plain dict and `to_yaml` writes the current values to a YAML file.
    """

    # General
    # Device string handed to `torch.device`.
    device: str = 'cuda'

    # Data config
    # CSV file read with `pd.read_csv`.
    imgconf_file: str = '../../data/VinBigData/train.csv'
    # Root directory passed to the K-fold wrapper as `datadir`.
    imgdir_name: str = "../../data/ChestXRay14"

    # Model
    # Directory that contains the `predictor_last.pt` checkpoint.
    model_dir: str = "../VinBigData-2Class/results"
    # Forwarded to `build_predictor` as `model_mode`.
    model_mode: str = 'normal'
    # Forwarded to `build_predictor` as `model_name`.
    model_name: str = 'resnet18'

    # Evaluation
    # DataLoader batch size and worker count.
    batch_size: int = 32
    num_workers: int = 4

    def update(self, param_dict: Dict) -> "Config":
        """Overwrite fields in place from `param_dict` and return `self`.

        Keys are applied in dict order. The first key that is not an existing
        attribute raises `ValueError`; keys before it have already been set.
        """
        for name in param_dict:
            if not hasattr(self, name):
                raise ValueError(f"[ERROR] Unexpected key for flag = {name}")
            setattr(self, name, param_dict[name])
        return self

    def to_yaml(self, filepath: str, width: int = 120):
        """Dump all fields to `filepath` as YAML, wrapping lines at `width`."""
        with open(filepath, 'w') as stream:
            payload = asdict(self)
            yaml.dump(payload, stream, width=width)

In [3]:
import torch
from torch import nn
from torch.nn import Linear
import torch.nn.functional as F
import timm


class CNNFixedPredictor(nn.Module):
    """Frozen CNN feature extractor followed by a trainable linear head.

    Args:
        cnn: Module used as the feature extractor. It must expose a
            `num_features` attribute giving the size of its output features.
        num_classes: Output size of the linear head.
    """

    def __init__(self, cnn: nn.Module, num_classes: int = 2):
        super().__init__()
        self.cnn = cnn
        self.lin = Linear(cnn.num_features, num_classes)
        print("cnn.num_features", cnn.num_features)

        # Freeze the backbone so that only `self.lin` receives gradients.
        # https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
        for frozen_param in self.cnn.parameters():
            frozen_param.requires_grad = False

    def forward(self, x):
        """Return the linear head applied to the backbone features of `x`."""
        return self.lin(self.cnn(x))


def build_predictor(model_name: str, model_mode: str = "normal"):
    """Create the two-class predictor network from a timm architecture name.

    Args:
        model_name: Architecture name passed to `timm.create_model`.
        model_mode: "normal" returns the timm model itself, created with a
            2-class output. "cnn_fixed" creates the timm model with
            `num_classes=0` and wraps it in `CNNFixedPredictor`, which freezes
            the backbone and adds a 2-class linear layer.

    Raises:
        ValueError: If `model_mode` is neither "normal" nor "cnn_fixed".
    """
    if model_mode not in ("normal", "cnn_fixed"):
        raise ValueError(f"[ERROR] Unexpected value model_mode={model_mode}")

    if model_mode == "normal":
        # Plain timm model with its own 2-class head.
        return timm.create_model(model_name, pretrained=True, num_classes=2, in_chans=3)

    # "cnn_fixed": timm model used as a feature extractor (num_classes=0).
    # https://rwightman.github.io/pytorch-image-models/feature_extraction/
    backbone = timm.create_model(model_name, pretrained=True, num_classes=0, in_chans=3)
    return CNNFixedPredictor(backbone, num_classes=2)

        
def accuracy(y: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Fraction of positions where the top-scoring class of `y` equals `t`.

    Args:
        y: Scores with the classes on the last axis.
        t: Class indices; its shape must equal `y.shape[:-1]`.

    Returns:
        A scalar float tensor: number of matches divided by `t.nelement()`.
    """
    assert y.shape[:-1] == t.shape, f"y {y.shape}, t {t.shape} is inconsistent."
    # torch.max over a dim returns (values, indices); only the indices are needed.
    _, predicted = torch.max(y.detach(), dim=-1)
    n_correct = (predicted == t).sum().float()
    return n_correct / t.nelement()


def accuracy_with_logits(y: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Accuracy when the target `t` has the same shape as the scores `y`.

    The ground-truth class of each row is taken as the index of the largest
    entry of `t` along the last axis; the result is then `accuracy(y, label)`.
    """
    assert y.shape == t.shape
    _, hard_label = torch.max(t.detach(), dim=-1)
    return accuracy(y, hard_label)


def cross_entropy_with_logits(input, target, dim=-1):
    """Cross entropy between `target` weights and `log_softmax(input)`.

    Computes `-sum(target * log_softmax(input, dim), dim)` for every example
    and returns the mean over the remaining elements.
    """
    log_prob = F.log_softmax(input, dim)
    per_example = torch.sum(- target * log_prob, dim)
    return per_example.mean()


import torch.nn.functional as F
from torch import nn
import pytorch_pfn_extras as ppe


class Classifier(nn.Module):
    """Wraps a predictor network with a loss, metric reporting and inference helpers.

    Args:
        predictor: Module mapping an image batch to per-class scores.
        lossfun: Callable `(outputs, targets) -> loss`.
    """

    def __init__(self, predictor, lossfun=cross_entropy_with_logits):
        super().__init__()
        self.predictor = predictor
        self.lossfun = lossfun
        # Prepended to the metric names reported in `forward`.
        self.prefix = ""

    def forward(self, image, targets):
        """Compute the loss for one batch and report loss/accuracy.

        Returns:
            `(loss, metrics)` where `metrics` maps "<prefix>loss" and
            "<prefix>acc" to Python floats.
        """
        outputs = self.predictor(image)
        loss = self.lossfun(outputs, targets)

        metrics = {}
        metrics[f"{self.prefix}loss"] = loss.item()
        metrics[f"{self.prefix}acc"] = accuracy_with_logits(outputs, targets).item()

        ppe.reporting.report(metrics, self)
        return loss, metrics

    def predict(self, data_loader):
        """Return the index of the most probable class for every example."""
        return torch.argmax(self.predict_proba(data_loader), dim=1)

    def predict_proba(self, data_loader):
        """Return softmax probabilities for every example in `data_loader`.

        The module is switched to eval mode and gradients are disabled. A
        batch may be a tensor, or a tuple/list whose first element is the
        image tensor. Per-batch results are concatenated along dim 0.
        """
        device: torch.device = next(self.parameters()).device
        self.eval()

        probabilities = []
        with torch.no_grad():
            for batch in data_loader:
                # For (image, ...) batches only the first element is used.
                images = batch[0] if isinstance(batch, (tuple, list)) else batch
                scores = self.predictor(images.to(device))
                probabilities.append(torch.softmax(scores, dim=-1))
        return torch.cat(probabilities)

fail to import apex_C: apex was not installed or installed without --cpp_ext.
fail to import amp_C: apex was not installed or installed without --cpp_ext.


In [4]:
import albumentations as A


class Transform:
    """Callable image augmentation built from a name -> kwargs mapping.

    Each key of `aug_kwargs` is looked up as an attribute of the
    albumentations module `A` and instantiated with its kwargs; the resulting
    transforms are composed in the mapping's iteration order.
    """

    def __init__(self, aug_kwargs: Dict):
        augmentations = []
        for aug_name, params in aug_kwargs.items():
            augmentations.append(getattr(A, aug_name)(**params))
        self.transform = A.Compose(augmentations)

    def __call__(self, image):
        """Apply the composed transform and return the resulting image."""
        return self.transform(image=image)["image"]

In [5]:
class ChestRayDataset(torch.utils.data.Dataset):
    """Dataset yielding `(image, soft_label)` pairs read from image files.

    Args:
        filepaths: Paths of the image files, one per example.
        labels: Binary label (0 or 1) of each example, aligned with `filepaths`.
        label_smoothing: Amount added to a 0 label / subtracted from a 1 label.
        mixup_prob: Probability of blending an example with a randomly drawn
            second example.
        image_transform: Optional callable applied to the image array returned
            by `cv2.imread` before it is converted to a tensor.

    Each item is `(img, label_logit)`: `img` is a float32 tensor with the last
    axis of the loaded array moved to the front, and `label_logit` is the
    float32 tensor `[1 - label, label]`.
    """

    def __init__(self, filepaths: List[str], labels: List[int], label_smoothing: float = 0.,
                 mixup_prob: float = 0., image_transform: Optional["Transform"] = None):
        self.filepaths = filepaths
        self.labels = labels
        self.label_smoothing = label_smoothing
        self.mixup_prob = mixup_prob
        self.image_transform = image_transform

    def __len__(self) -> int:
        return len(self.filepaths)

    def __getitem__(self, idx: int):
        img, label = self.__get_single_example(idx=idx)
        img, label = self.__mixup(img=img, label=label)

        # Two-element soft target: index 0 holds (1 - label), index 1 holds label.
        label_logit = torch.tensor([1 - label, label], dtype=torch.float32)

        return img, label_logit

    def __get_single_example(self, idx: int) -> Tuple[torch.Tensor, float]:
        """Load example `idx` and return `(image tensor, smoothed label)`.

        Raises:
            OSError: If `cv2.imread` cannot read the file.
        """
        path = self.filepaths[idx]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than later in transpose.
            raise OSError(f"cv2.imread could not read image file: {path}")
        if self.image_transform is not None:
            img = self.image_transform(img)
        # Move the last axis first (H, W, C -> C, H, W for an imread array).
        img = torch.tensor(np.transpose(img, (2, 0, 1)).astype(np.float32))

        label = self.labels[idx]
        if label == 0:
            return img, float(label) + self.label_smoothing
        else:
            return img, float(label) - self.label_smoothing

    def __mixup(self, img: torch.Tensor, label: float):
        """With probability `mixup_prob`, blend the example with a random one.

        A partner index and a weight `p` are drawn uniformly; image and label
        become `p * current + (1 - p) * partner`.
        """
        if np.random.uniform() < self.mixup_prob:
            pair_idx = np.random.randint(0, len(self.filepaths))
            p = float(np.random.uniform())
            pair_img, pair_label = self.__get_single_example(idx=pair_idx)
            # assumes both images have the same shape after the transform — TODO confirm
            img = img * p + pair_img * (1 - p)
            label = label * p + pair_label * (1 - p)
        return img, label

In [6]:
from sklearn.model_selection import StratifiedKFold
import cv2
from torch.utils.data.dataset import Dataset


class StratifiedKFoldWrapper:
    """Stratified K-fold splits over the image list found under `datadir`.

    The image list is built from `Data_Entry_2017.csv` joined with the image
    files found on disk, reduced to one row per 'Unique Image Index', and
    split with `StratifiedKFold` on the 'Normal' column. Iterating or indexing
    the wrapper yields `(train_dataset, valid_dataset)` for a fold.

    Args:
        datadir: Directory holding `Data_Entry_2017.csv` and the image folders.
        n_splits: Number of folds.
        shuffle: Whether to shuffle before splitting.
        seed: Random state for the shuffle; ignored when `shuffle` is False.
        label_smoothing: Forwarded to the training dataset only.
        mixup_prob: Forwarded to the training dataset only.
        aug_kwargs: Mapping of augmentation name -> kwargs for `Transform`
            (training dataset only). `None` means no augmentation.
    """

    def __init__(self, datadir: Path, n_splits: int, shuffle: bool, seed: int,
                 label_smoothing: float = 0., mixup_prob: float = 0., aug_kwargs: Optional[Dict] = None):
        self.datadir = datadir
        fl = self.__load_filelist()
        fp = self.__load_filepath()

        # Left join: every CSV entry is kept, with its on-disk path where found.
        df = pd.merge(fl, fp, how='left', on='Image Index')[['Image Index', 'Normal', 'Unique Image Index', 'File Path']]
        # Keep a single row per 'Unique Image Index'.
        df = df.groupby('Unique Image Index').first().reset_index()
        self.datalist = df

        # scikit-learn rejects a non-None random_state when shuffle is False,
        # so the seed is only passed along when shuffling is requested.
        skf = StratifiedKFold(n_splits=n_splits, shuffle=shuffle, random_state=seed if shuffle else None)
        self.split_idxs = list(skf.split(self.datalist['File Path'].values, self.datalist['Normal'].values))
        # Index of the current fold; -1 means "before the first fold".
        self.__i = -1

        self.label_smoothing = label_smoothing
        self.mixup_prob = mixup_prob
        self.image_transform = Transform(aug_kwargs={} if aug_kwargs is None else aug_kwargs)

    def __load_filelist(self):
        """Read the entry CSV and derive the 'Unique Image Index' and 'Normal' columns."""
        df = pd.read_csv(self.datadir / 'Data_Entry_2017.csv')
        # Part of the file name before the first underscore.
        df['Unique Image Index'] = df['Image Index'].map(lambda x: x.split('_')[0])
        # 1 when the finding label is exactly 'No Finding', else 0.
        df['Normal'] = (df['Finding Labels'] == 'No Finding').astype('int8')
        return df

    def __load_filepath(self):
        """Map every file under `<datadir>/<*image*>/images/` to its full path."""
        records = {'Image Index': [], 'File Path': []}

        image_dirs = [name for name in sorted(os.listdir(self.datadir)) if 'image' in name]
        for dn in image_dirs:
            for fn in sorted(os.listdir(self.datadir / dn / 'images')):
                records['Image Index'].append(fn)
                records['File Path'].append(str(self.datadir / dn / 'images' / fn))

        return pd.DataFrame(records)

    def __iter__(self):
        # Restart from the first fold; the wrapper is its own iterator.
        self.__i = -1
        return self

    def __next__(self):
        self.__i += 1
        if self.__i < 0 or len(self.split_idxs) <= self.__i:
            raise StopIteration()
        return self.generate_datasets()

    def __len__(self) -> int:
        return len(self.split_idxs)

    def __getitem__(self, idx):
        """Return `(train_dataset, valid_dataset)` for fold `idx`.

        Raises:
            ValueError: If `idx` is negative or not smaller than the fold count.
        """
        if idx < 0 or len(self.split_idxs) <= idx:
            raise ValueError(f"fold index {idx} is out of range for {len(self.split_idxs)} folds")
        # Indexing also moves the iteration cursor to this fold.
        self.__i = idx
        return self.generate_datasets()

    def generate_datasets(self) -> Tuple["ChestRayDataset", "ChestRayDataset"]:
        """Build the datasets of the current fold.

        Returns:
            `(train_ds, valid_ds)`. The training dataset gets label smoothing,
            mixup and the image transform; the validation dataset is created
            with the `ChestRayDataset` defaults for all three.
        """
        train_idx, valid_idx = self.split_idxs[self.__i]
        filepaths = self.datalist['File Path'].values
        labels = self.datalist['Normal'].values

        train_ds = ChestRayDataset(
            filepaths=filepaths[train_idx].tolist(),
            labels=labels[train_idx].tolist(),
            label_smoothing=self.label_smoothing,
            mixup_prob=self.mixup_prob,
            image_transform=self.image_transform
        )
        valid_ds = ChestRayDataset(
            filepaths=filepaths[valid_idx].tolist(),
            labels=labels[valid_idx].tolist(),
        )

        return train_ds, valid_ds

In [7]:
# Start from the defaults declared on Config; add entries to `config_dict`
# to override individual fields.
config_dict = {}
config = Config()
config.update(config_dict)

# Absolute path of the current working directory.
base_dir = Path().resolve()

In [8]:
# Echo the effective configuration so it is recorded in the cell output.
config

Config(device='cuda', imgconf_file='../../data/VinBigData/train.csv', imgdir_name='../../data/ChestXRay14', model_dir='../VinBigData-2Class/results', model_mode='normal', model_name='resnet18', batch_size=32, num_workers=4)

In [9]:
device = torch.device(config.device)
predictor = build_predictor(model_name=config.model_name, model_mode=config.model_mode)

# Restore the weights on the CPU first: without `map_location`, tensors are
# restored to the device they were saved from, which fails on a machine
# lacking that device. The model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
predictor.load_state_dict(
    torch.load(base_dir / config.model_dir / 'predictor_last.pt', map_location='cpu')
)
predictor.eval()

# `.to()` returns the module itself; assigning the result keeps the long
# module repr out of the cell output.
classifier = Classifier(predictor).to(device)




In [10]:
# Table read from the CSV configured as `imgconf_file`.
train = pd.read_csv(base_dir / config.imgconf_file)

# Five shuffled, stratified folds over the images under `imgdir_name`.
# Label smoothing, mixup and augmentation are all switched off.
skf = StratifiedKFoldWrapper(
    base_dir / config.imgdir_name,
    n_splits=5, shuffle=True, seed=111,
    label_smoothing=0, mixup_prob=0, aug_kwargs={},
)

In [14]:
from torch.utils.data.dataloader import DataLoader


# Each fold is a (train, valid) pair; score the validation part of fold 0.
valid_dataset = skf[0][1]
valid_loader = DataLoader(
    dataset=valid_dataset,
    shuffle=False,  # keep dataset order so predictions line up with `labels`
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    pin_memory=True,
)

# Softmax outputs as a NumPy array with one row per validation example.
valid_pred = classifier.predict_proba(valid_loader).cpu().numpy()

# One column per predicted class plus the dataset's 'Normal' flag as ground truth.
valid_pred_df = pd.DataFrame({
    "class0": valid_pred[:, 0],
    "class1": valid_pred[:, 1],
}).assign(truth=valid_dataset.labels)