In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
from cv2 import cv2
from tqdm.auto import tqdm

import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
from torchvision import models
import torch.optim as optim
import torch.nn.functional as F

import torchmetrics

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from pprint import pprint

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer

from shutil import rmtree

In [2]:
# Fix the global random seed (and the DataLoader worker seeds) so runs are repeatable.
pl.seed_everything(seed=0, workers=True)

0

In [3]:
class Config:
    """Notebook-wide settings: training hyper-parameters and dataset locations."""

    # Training hyper-parameters
    BATCH_SIZE = 64
    EPOCHS = 10

    # Spatial size every image is resized to before it reaches the network
    IMG_HEIGHT = 224
    IMG_WIDTH = 224

    # CSV holding one row per image with its "|"-separated finding labels
    label_file_path = "../input/data/Data_Entry_2017.csv"

    # Every entry three levels below the data root.
    # Path.glob returns a one-shot generator; storing it directly means the
    # attribute is silently empty the second time it is iterated (e.g. when a
    # cell is re-run). Materialise it as a sorted list so it can be reused.
    image_paths = sorted(Path("../input/data/").glob("*/*/*"))

In [4]:
# Load the per-image metadata/label table from the path set in Config.
df = pd.read_csv(Config.label_file_path)
# Preview the first rows.
df.head()

Unnamed: 0,Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Gender,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y],Unnamed: 11
0,00000001_000.png,Cardiomegaly,0,1,58,M,PA,2682,2749,0.143,0.143,
1,00000001_001.png,Cardiomegaly|Emphysema,1,1,58,M,PA,2894,2729,0.143,0.143,
2,00000001_002.png,Cardiomegaly|Effusion,2,1,58,M,PA,2500,2048,0.168,0.168,
3,00000002_000.png,No Finding,0,2,81,M,PA,2500,2048,0.171,0.171,
4,00000003_000.png,Hernia,0,3,81,F,PA,2582,2991,0.143,0.143,


In [5]:
# Collect every distinct label that appears in the "|"-separated
# 'Finding Labels' column.
label_set = set()
for label in tqdm(df['Finding Labels'].values):
    label_set.update(label.split("|"))
# Iteration order of a set of strings depends on the per-process hash seed,
# so sort to get the same label order on every kernel restart.
label_list = sorted(label_set)

  0%|          | 0/112120 [00:00<?, ?it/s]

In [6]:
# Tally how many rows mention each label (a row with several findings
# contributes once to each of them).
label_count = dict.fromkeys(label_list, 0)
for findings in tqdm(df['Finding Labels'].values):
    for finding in findings.split("|"):
        label_count[finding] += 1
pprint(label_count)

  0%|          | 0/112120 [00:00<?, ?it/s]

{'Atelectasis': 11559,
 'Cardiomegaly': 2776,
 'Consolidation': 4667,
 'Edema': 2303,
 'Effusion': 13317,
 'Emphysema': 2516,
 'Fibrosis': 1686,
 'Hernia': 227,
 'Infiltration': 19894,
 'Mass': 5782,
 'No Finding': 60361,
 'Nodule': 6331,
 'Pleural_Thickening': 3385,
 'Pneumonia': 1431,
 'Pneumothorax': 5302}


In [7]:
# Replace each file name in 'Image Index' with the full path of that file.
# The rows are matched to files by *file name* rather than by position in a
# sorted listing, so a differently ordered CSV, a stray extra file or a missing
# image can no longer shift labels onto the wrong picture unnoticed.
path_by_name = {image_path.name: image_path for image_path in Config.image_paths}
# Path(...).name accepts both the original strings and already-resolved Paths,
# which keeps this cell safe to run again on an updated frame.
image_names = df['Image Index'].map(lambda entry: Path(entry).name)
resolved_paths = image_names.map(path_by_name)
unresolved = resolved_paths.isna()
if unresolved.any():
    raise FileNotFoundError(
        f"{int(unresolved.sum())} of {len(df)} label rows have no matching image file "
        f"(first missing: {image_names[unresolved].iloc[0]!r}). "
        "Check the data directory, and that Config.image_paths has not already been consumed."
    )
df['Image Index'] = resolved_paths

In [8]:
def _pathologies(raw_label):
    """Split a "|"-separated label string into a set, without "No Finding"."""
    findings = raw_label.split("|")
    # Remove "No Finding" from the label list (keep only the 14 pathologies)
    if "No Finding" in findings:
        findings.remove("No Finding")
    return set(findings)


label_mlb = [_pathologies(raw_label) for raw_label in tqdm(df['Finding Labels'].values)]

# Turn the per-image label sets into a multi-hot matrix (one column per class)
mlb = MultiLabelBinarizer()
label_array = mlb.fit_transform(label_mlb)

pprint(mlb.classes_)
pprint(label_array)

  0%|          | 0/112120 [00:00<?, ?it/s]

array(['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
       'Effusion', 'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration',
       'Mass', 'Nodule', 'Pleural_Thickening', 'Pneumonia',
       'Pneumothorax'], dtype=object)
array([[0, 1, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])


In [9]:
# Wrap the multi-hot matrix in a frame whose columns are the class names.
label_df = pd.DataFrame(data=label_array, columns=mlb.classes_)
label_df.head(n=10)

Unnamed: 0,Atelectasis,Cardiomegaly,Consolidation,Edema,Effusion,Emphysema,Fibrosis,Hernia,Infiltration,Mass,Nodule,Pleural_Thickening,Pneumonia,Pneumothorax
0,0,1,0,0,0,0,0,0,0,0,0,0,0,0
1,0,1,0,0,0,1,0,0,0,0,0,0,0,0
2,0,1,0,0,1,0,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,1,0,0,0,0,0,0
5,0,0,0,0,0,0,0,1,0,0,0,0,0,0
6,0,0,0,0,0,0,0,1,0,0,0,0,0,0
7,0,0,0,0,0,0,0,1,1,0,0,0,0,0
8,0,0,0,0,0,0,0,1,0,0,0,0,0,0
9,0,0,0,0,0,0,0,1,0,0,0,0,0,0


In [10]:
# The encoded labels must have exactly one row per row of the source table.
assert label_df.shape[0] == df.shape[0]

In [11]:
# Put the image paths in the first column of the label table.
# DataFrame.insert raises if the column already exists, so drop a previous copy
# first; this lets the cell be re-run without rebuilding label_df.
if "Image_paths" in label_df.columns:
    label_df = label_df.drop(columns="Image_paths")
label_df.insert(loc=0, column="Image_paths", value=df['Image Index'].values)

In [12]:
# Preview the first 10 rows of the combined path + label table.
label_df.head(10)

Unnamed: 0,Image_paths,Atelectasis,Cardiomegaly,Consolidation,Edema,Effusion,Emphysema,Fibrosis,Hernia,Infiltration,Mass,Nodule,Pleural_Thickening,Pneumonia,Pneumothorax
0,../input/data/images_001/images/00000001_000.png,0,1,0,0,0,0,0,0,0,0,0,0,0,0
1,../input/data/images_001/images/00000001_001.png,0,1,0,0,0,1,0,0,0,0,0,0,0,0
2,../input/data/images_001/images/00000001_002.png,0,1,0,0,1,0,0,0,0,0,0,0,0,0
3,../input/data/images_001/images/00000002_000.png,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,../input/data/images_001/images/00000003_000.png,0,0,0,0,0,0,0,1,0,0,0,0,0,0
5,../input/data/images_001/images/00000003_001.png,0,0,0,0,0,0,0,1,0,0,0,0,0,0
6,../input/data/images_001/images/00000003_002.png,0,0,0,0,0,0,0,1,0,0,0,0,0,0
7,../input/data/images_001/images/00000003_003.png,0,0,0,0,0,0,0,1,1,0,0,0,0,0
8,../input/data/images_001/images/00000003_004.png,0,0,0,0,0,0,0,1,0,0,0,0,0,0
9,../input/data/images_001/images/00000003_005.png,0,0,0,0,0,0,0,1,0,0,0,0,0,0


In [13]:
class CheXNet(nn.Module):
    """DenseNet-121 whose classifier is replaced by a linear layer plus sigmoid.

    Args:
        n_classes: number of output units (one independent probability each).
    """

    def __init__(self, n_classes):
        super().__init__()

        backbone = models.densenet121(pretrained=True)
        # Swap the stock classifier for an n_classes-way head; the sigmoid makes
        # the outputs per-class probabilities rather than raw logits.
        head = nn.Sequential(
            nn.Linear(backbone.classifier.in_features, n_classes),
            nn.Sigmoid(),
        )
        backbone.classifier = head
        self.densenet121 = backbone

    def forward(self, x):
        """Return the sigmoid outputs of the network for the batch ``x``."""
        return self.densenet121(x)

In [14]:
# Read the file names that belong to the train/val and the test split.
# The `with` blocks close the handles as soon as the contents are read
# (the bare open() calls used before left both files open for the whole session).
with open("../input/data/train_val_list.txt") as train_file:
    train_images = train_file.read().splitlines()

with open("../input/data/test_list.txt") as test_file:
    test_images = test_file.read().splitlines()

In [15]:
# Get the row positions of the train and test images so the label table can be
# split into train and test dataframes.
# Membership tests against the raw lists are linear, which makes this loop
# quadratic over ~100k rows; sets give the same answers in constant time.
train_names = set(train_images)
test_names = set(test_images)

train_idxs = []
test_idxs = []
loader = enumerate(label_df['Image_paths'].values)
for idx, path in tqdm(loader, total=len(label_df)):
    if path.name in train_names:
        train_idxs.append(idx)
    if path.name in test_names:
        test_idxs.append(idx)

  0%|          | 0/112120 [00:00<?, ?it/s]

In [16]:
# Slice the label table by position into the two splits and renumber each from 0.
train_df, test_df = (
    label_df.iloc[rows, :].reset_index(drop=True)
    for rows in (train_idxs, test_idxs)
)

In [17]:
# Preview the first rows of the training split.
train_df.head()

Unnamed: 0,Image_paths,Atelectasis,Cardiomegaly,Consolidation,Edema,Effusion,Emphysema,Fibrosis,Hernia,Infiltration,Mass,Nodule,Pleural_Thickening,Pneumonia,Pneumothorax
0,../input/data/images_001/images/00000001_000.png,0,1,0,0,0,0,0,0,0,0,0,0,0,0
1,../input/data/images_001/images/00000001_001.png,0,1,0,0,0,1,0,0,0,0,0,0,0,0
2,../input/data/images_001/images/00000001_002.png,0,1,0,0,1,0,0,0,0,0,0,0,0,0
3,../input/data/images_001/images/00000002_000.png,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,../input/data/images_001/images/00000004_000.png,0,0,0,0,0,0,0,0,0,1,1,0,0,0


In [18]:
# Preview the first rows of the test split.
test_df.head()

Unnamed: 0,Image_paths,Atelectasis,Cardiomegaly,Consolidation,Edema,Effusion,Emphysema,Fibrosis,Hernia,Infiltration,Mass,Nodule,Pleural_Thickening,Pneumonia,Pneumothorax
0,../input/data/images_001/images/00000003_000.png,0,0,0,0,0,0,0,1,0,0,0,0,0,0
1,../input/data/images_001/images/00000003_001.png,0,0,0,0,0,0,0,1,0,0,0,0,0,0
2,../input/data/images_001/images/00000003_002.png,0,0,0,0,0,0,0,1,0,0,0,0,0,0
3,../input/data/images_001/images/00000003_003.png,0,0,0,0,0,0,0,1,1,0,0,0,0,0
4,../input/data/images_001/images/00000003_004.png,0,0,0,0,0,0,0,1,0,0,0,0,0,0


In [19]:
class ChestXRayDataset(Dataset):
    """Map-style dataset yielding ``{"image": tensor, "label": tensor}`` samples.

    Args:
        X: indexable collection of image paths (``pathlib.Path`` or ``str``).
        y: indexable collection of multi-hot label rows, parallel to ``X``.
        label_list: class names, in the same order as the columns of ``y``.
        transforms: optional callable invoked as ``transforms(image=img)`` that
            returns a mapping with an ``"image"`` entry; applied before the
            image is converted to a tensor.
    """

    def __init__(self, X,y,label_list, transforms=None):
        self.labels = y
        self.image_paths = X
        self.transforms = transforms
        self.default_transform = ToTensorV2()
        self.label_list = label_list

    def __getitem__(self, idx):
        """Load, transform and tensorise the sample at position ``idx``.

        Raises:
            FileNotFoundError: if the image cannot be read from disk.
        """
        image_path = Path(self.image_paths[idx]).as_posix()
        # IMREAD_COLOR always decodes to 3-channel BGR. IMREAD_UNCHANGED keeps
        # the file's own layout, so single-channel or RGBA files would not be
        # valid input for the BGR->RGB conversion below.
        img = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        lbl = self.labels[idx]

        if self.transforms is not None:
            img = self.transforms(image=img)['image']

        # convert image and label to torch tensor
        img = self.default_transform(image=img)['image']
        lbl = torch.tensor(lbl, dtype=torch.float32)

        return {
            "image": img,
            "label": lbl
        }

    def __len__(self):
        """Number of samples, i.e. number of image paths."""
        return len(self.image_paths)

    def label_interp(self, idx):
        """Return the class names whose flag equals 1 for the sample at ``idx``.

        ``label_list`` may be a NumPy array or a plain sequence; the result is
        always a NumPy array (empty when no flag is set).
        """
        active = np.flatnonzero(np.asarray(self.labels[idx]) == 1)
        return np.asarray(self.label_list)[active]

In [20]:
class ChestDataModule(pl.LightningDataModule):
    """Lightning data module providing train / validation / test loaders.

    Args:
        train_df: frame with an ``"Image_paths"`` column followed by one
            multi-hot column per class; split 80/20 into train and validation.
        test_df: frame with the same layout, used as the held-out test set.
        label_list: class names in the order of the label columns.
    """

    def __init__(self, train_df, test_df, label_list):
        super().__init__()

        # For train/val splitting
        self.X = train_df["Image_paths"].values
        self.y = train_df.iloc[:, 1:].values

        # Test splitting
        self.X_test = test_df["Image_paths"].values
        self.y_test = test_df.iloc[:, 1:].values

        self.label_list = label_list

        def _resize():
            return A.Resize(
                width=Config.IMG_WIDTH,
                height=Config.IMG_HEIGHT
            )

        # Train & evaluation transforms.
        # Normalize goes last so the photometric augmentation works on the
        # original pixel range, and the evaluation pipelines apply the same
        # Normalize so validation/test inputs match the training distribution.
        self.tfms = {
            "train": A.Compose(
                [
                    _resize(),
                    A.HorizontalFlip(p=0.5),
                    A.RandomBrightnessContrast(p=0.2),
                    A.Normalize(),
                ]
            ),
            "val": A.Compose([_resize(), A.Normalize()]),
            "test": A.Compose([_resize(), A.Normalize()]),
        }

    def setup(self, stage=None) -> None:
        """Build the datasets needed for ``stage`` (``None`` builds all of them)."""
        if stage in ("fit", "validate", None):
            from torch.utils.data import Subset

            # Two views over the same samples: augmented for training,
            # deterministic for validation.
            train_ds = ChestXRayDataset(
                X=self.X,
                y=self.y,
                label_list=self.label_list,
                transforms=self.tfms['train']
            )
            val_ds = ChestXRayDataset(
                X=self.X,
                y=self.y,
                label_list=self.label_list,
                transforms=self.tfms['val']
            )

            # 80/20 split; any rounding remainder goes to the training part.
            n_total = len(train_ds)
            val_len = int(n_total * 0.2)
            train_len = n_total - val_len

            # One random permutation shared by both views keeps the two
            # subsets disjoint.
            permutation = torch.randperm(n_total).tolist()
            self.train_data = Subset(train_ds, permutation[:train_len])
            self.val_data = Subset(val_ds, permutation[train_len:])

        if stage in ("test", None):
            # Test dataset
            self.test_data = ChestXRayDataset(
                X=self.X_test,
                y=self.y_test,
                label_list=self.label_list,
                transforms=self.tfms['test']
            )

    def get_weights(self):
        """Return per-class positive and negative label frequencies as tensors.

        Despite the local names, the values are fractions of the train/val
        rows (column sum divided by row count), not raw counts.
        """
        pos_counts = np.sum(self.y, axis=0)/len(self.y)
        neg_counts = 1 - pos_counts
        return torch.tensor(pos_counts), torch.tensor(neg_counts)

    def train_dataloader(self):
        # Reshuffle every epoch so batches are not identical across epochs.
        return DataLoader(self.train_data, batch_size=Config.BATCH_SIZE, num_workers=2, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=Config.BATCH_SIZE, num_workers=2)

    def test_dataloader(self):
        return DataLoader(self.test_data, batch_size=Config.BATCH_SIZE, num_workers=2)

In [21]:
class LitModel(pl.LightningModule):
    """Lightning wrapper that trains CheXNet with binary cross-entropy.

    Logs ``<stage>_loss`` and ``<stage>_acc`` for the train, val and test stages.
    """

    def __init__(self):
        super().__init__()
        self.model = CheXNet(n_classes=14)
        # The network already ends in a sigmoid, so the loss takes probabilities.
        self.loss = nn.BCELoss()
        # One metric object per stage so their internal state never mixes.
        # NOTE(review): argument-free Accuracy() relies on the installed
        # torchmetrics inferring the task from the inputs; confirm the version.
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()

    def configure_optimizers(self):
        return optim.Adam(self.model.parameters(), lr=1e-4)

    def forward(self, images):
        return self.model(images)

    def _shared_step(self, batch, metric, stage):
        """Run one batch, log ``{stage}_loss`` / ``{stage}_acc`` and return the loss."""
        images = batch["image"]
        labels = batch["label"]

        # .float() is a no-op for float32 input and guards against integer
        # image tensors coming from a pipeline without normalisation.
        preds = self.forward(images.float())
        loss = self.loss(input=preds, target=labels)
        acc = metric(preds, labels.int())

        self.log(f"{stage}_loss", loss)
        self.log(f"{stage}_acc", acc, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, self.train_acc, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, self.val_acc, "val")

    def test_step(self, batch, batch_idx):
        # Uses its own metric (test_acc) and computes the loss it returns;
        # previously this step updated val_acc and returned an undefined name.
        return self._shared_step(batch, self.test_acc, "test")
        

In [22]:
model = LitModel()
dm = ChestDataModule(train_df, test_df, label_list=mlb.classes_)

# Both callbacks watch the same logged quantity and treat higher as better.
monitor_kwargs = dict(monitor="val_acc", mode="max", verbose=True)

# Keep the three best checkpoints according to the monitored metric.
checkpoint_callback = ModelCheckpoint(
    dirpath="../working/models",
    filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}',
    save_top_k=3,
    **monitor_kwargs,
)

# Stop once the monitored metric has failed to improve for 3 checks in a row.
early_stop_callback = EarlyStopping(
    min_delta=0.00,
    patience=3,
    **monitor_kwargs,
)

trainer = pl.Trainer(
    logger=True,
    max_epochs=Config.EPOCHS,
    accelerator="auto",
    callbacks=[checkpoint_callback, early_stop_callback],
)


Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth


  0%|          | 0.00/30.8M [00:00<?, ?B/s]

In [23]:
# Train the model, with validation driven by the data module's loaders.
trainer.fit(model=model, datamodule=dm)

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]