<h2 align=center style="color:red; border:1px dotted red">Chest X-ray - Faster RCNN</h2>

<pre>

                 .=.
         .---._.-.=.-._.---.
        / ':-(_.-: :-._)-:` \
       / /' (__.-: :-.__) `\ \
      / /  (___.-` '-.___)  \ \
     / /   (___.-'^`-.___)   \ \
    / /    (___.-'=`-.___)    \ \
   / /     (____.'=`.____)     \ \
  / /       (___.'=`.___)       \ \
 (_.;       `---'.=.`---'       ;._)
</pre>

### By Alin Cijov

In [None]:
import numpy as np
import pandas as pd
import pydicom
import time

import torch
import torchvision
import torchvision.transforms as T
from collections import defaultdict, deque
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN

import cv2
import os

from PIL import Image
from skimage import exposure
from pydicom.pixel_data_handlers.util import apply_voi_lut
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Root of the competition data. The trailing slash matters: file names are
# built from this prefix by plain string concatenation.
path = '../input/vinbigdata-chest-xray-abnormalities-detection/'
train_csv = f"{path}train.csv"
df = pd.read_csv(train_csv)
df.head()

In [None]:
# Substitute 0 for every missing value in the annotation table.
df = df.fillna(value=0)

<h2 align=center style="color:red; border:1px dotted red">Dataset</h2>

In [None]:
class XrayDataset(object):
    """Detection dataset yielding one annotated bounding box per item.

    Every row of ``df`` whose ``class_name`` differs from ``'No finding'`` is
    one sample: the DICOM image named by ``image_id`` plus the single box
    described by ``x_min, y_min, x_max, y_max``.

    NOTE(review): an image with several annotation rows therefore appears
    several times, each time with only one of its boxes -- confirm this is the
    intended sampling scheme.

    Args:
        df: annotation table with columns ``image_id``, ``class_name``,
            ``x_min``, ``y_min``, ``x_max``, ``y_max``.
        path: dataset root, concatenated as ``path + "train/" + id + ".dicom"``.
        transforms: optional callable applied to the image only. If it changes
            the image's height/width the box is rescaled to match.
    """

    # Rows of this class carry no box. Index 0 is reserved for it so that no
    # real finding is ever encoded with label 0.
    BACKGROUND_CLASS = 'No finding'

    def __init__(self, df, path, transforms=None):
        # select only those classes that have boxes
        self.df = df[df['class_name'] != self.BACKGROUND_CLASS]
        self.transforms = transforms

        # Background first, then the findings in order of first appearance.
        foreground = [name for name in df['class_name'].unique()
                      if name != self.BACKGROUND_CLASS]
        self.categories = np.array([self.BACKGROUND_CLASS] + foreground, dtype=object)
        self.idx_to_categories = {k: v for k, v in enumerate(self.categories)}
        self.categories_to_idx = {v: k for k, v in enumerate(self.categories)}

        # Paths and boxes both come from the *filtered* frame so that position
        # ``idx`` addresses the same annotation row in each of them.
        self.images_paths = path + "train/" + self.df['image_id'] + ".dicom"
        self.boxes = self.df[['x_min', 'y_min', 'x_max', 'y_max']]

    def __len__(self):
        """Number of annotated boxes (rows kept by the constructor)."""
        return len(self.df)

    @staticmethod
    def _spatial_size(img):
        """Return ``(height, width)`` of a CHW tensor or an HWC array-like."""
        if torch.is_tensor(img):
            return int(img.shape[-2]), int(img.shape[-1])
        arr = np.asarray(img)
        return int(arr.shape[0]), int(arr.shape[1])

    def get_image(self, dicom):
        """Decode a DICOM dataset into a histogram-equalised float32 H x W x 3 array.

        Args:
            dicom: dataset object exposing ``pixel_array`` and
                ``PhotometricInterpretation`` (as returned by pydicom).

        Returns:
            ``numpy.ndarray`` of dtype float32 with the grey image repeated on
            three channels along the last axis.
        """
        # RescaleIntercept is deliberately not read: a constant offset is
        # cancelled by the minimum subtraction below.
        slope = dicom.RescaleSlope if "RescaleSlope" in dicom else 1.0
        image = apply_voi_lut(dicom.pixel_array, dicom)
        if dicom.PhotometricInterpretation == "MONOCHROME1":
            # invert the intensities
            image = np.amax(image) - image

        # Work in float64 from here on; the former int16 round trip could wrap
        # around for large pixel values.
        image = image.astype(np.float64)
        if slope != 1:
            image = float(slope) * image

        image = image - image.min()
        peak = image.max()
        if peak > 0:  # a constant image would otherwise divide by zero
            image = image / peak

        # Equalise the single grey channel once, then replicate it; the three
        # channels are identical copies, so equalising them jointly is redundant.
        image = exposure.equalize_hist(image).astype('float32')
        return np.stack([image, image, image], axis=-1)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(img, target)`` where ``target`` holds ``boxes`` (float32, shape
            ``(1, 4)``), ``labels`` (int64, shape ``(1,)``), ``area`` (float32,
            shape ``(1,)``) and ``iscrowd`` (int64 zeros, shape ``(1,)``).
        """
        row = self.df.iloc[idx]
        dicom = pydicom.dcmread(self.images_paths.iloc[idx])
        img = self.get_image(dicom)
        src_h, src_w = self._spatial_size(img)

        if self.transforms is not None:
            img = self.transforms(img)

        # np.array (not asarray) guarantees a private copy we may scale in place.
        boxes = np.array(self.boxes.iloc[idx].values, dtype=np.float32).reshape(1, 4)

        # Keep the box aligned with the image when the transform resized it.
        dst_h, dst_w = self._spatial_size(img)
        if (dst_h, dst_w) != (src_h, src_w):
            boxes[:, [0, 2]] *= dst_w / src_w
            boxes[:, [1, 3]] *= dst_h / src_h

        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        num_boxes = boxes.shape[0]

        target = {}
        target['boxes'] = torch.as_tensor(boxes, dtype=torch.float32)
        target['labels'] = torch.tensor(
            [self.categories_to_idx[row['class_name']]], dtype=torch.int64)
        target['area'] = torch.as_tensor(area, dtype=torch.float32)
        # one flag per box (previously sized by the number of dataframe columns)
        target['iscrowd'] = torch.zeros((num_boxes,), dtype=torch.int64)

        return img, target

In [None]:
def get_transform(train, dim_size):
    """Build the image-only preprocessing pipeline.

    Args:
        train: when true, the image is additionally resized to ``dim_size``.
        dim_size: target size handed to ``T.Resize``.

    Returns:
        A ``T.Compose`` that converts the image to a tensor and, for training,
        resizes it.
    """
    transforms = [T.ToTensor()]
    if train:
        transforms.append(T.Resize(dim_size))
        # No T.RandomHorizontalFlip here: this pipeline only ever sees the
        # image, so a random flip would mirror the pixels while the bounding
        # box kept its original coordinates, silently corrupting the target.
    return T.Compose(transforms)

In [None]:
# Target size handed to the training transform.
dim_size = (256, 256)
xrayds = XrayDataset(df, path, transforms=get_transform(train=True, dim_size=dim_size))

# Fetch one sample and display it; the leading axis is moved last for imshow.
img, target = xrayds[100]
plt.imshow(img.permute(1, 2, 0))

In [None]:
# Size of the classification head: one output per entry of the category table.
num_classes = len(xrayds.categories)

dataset = XrayDataset(df, path, get_transform(train=True, dim_size=dim_size))
# The evaluation copy must not go through the training-only transforms.
dataset_test = XrayDataset(df, path, get_transform(train=False, dim_size=dim_size))

# use only 100 dicom files for demo (fewer if the dataset is smaller, so the
# sampled indices can never run past its end)
nr_dicom = min(100, len(dataset))
indices = torch.randperm(nr_dicom).tolist()
dataset = torch.utils.data.Subset(dataset, indices)


def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples.

    ``[(img0, tgt0), (img1, tgt1)]`` becomes ``[(img0, img1), (tgt0, tgt1)]``,
    leaving the elements themselves untouched (no stacking).
    """
    columns = zip(*batch)
    return [column for column in columns]

# Training data loader: shuffled mini-batches of two samples, assembled by
# collate_fn and fetched by four worker processes.
loader_options = dict(batch_size=2, shuffle=True, num_workers=4, collate_fn=collate_fn)
data_loader = torch.utils.data.DataLoader(dataset, **loader_options)

<h2 align=center style="color:red; border:1px dotted red">Training</h2>

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Faster R-CNN (ResNet-50 + FPN) requested with pretrained=True.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Swap the box-predictor head for a fresh one with num_classes outputs,
# reusing the input width of the existing classification layer.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

model = model.to(device)

In [None]:
# Optimise only the parameters that still require gradients.
params = list(filter(lambda p: p.requires_grad, model.parameters()))

# SGD with momentum and a small weight decay.
optimizer = torch.optim.SGD(params, lr=1e-4, momentum=0.9, weight_decay=1e-4)

# Multiply the learning rate by 0.1 after every third scheduler step.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [None]:
# Mean training loss of each epoch, collected for plotting.
loss_epoch = []
num_epochs = 5

# The model is called with targets and expected to return a dict of losses,
# which is its training-mode behaviour; set the mode explicitly.
model.train()

for epoch in range(num_epochs):

    loss_iteration = []
    for i, (images, targets) in enumerate(data_loader):

        # Move the batch to the selected device. Passing dtype= to .to() works
        # on both CPU and GPU, unlike torch.cuda.FloatTensor, which broke the
        # CPU fallback chosen when no GPU is available.
        images = [image.to(device, dtype=torch.float32) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)

        # total loss = sum of every component the model reports
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()
        loss_iteration.append(loss_value)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if i % 10 == 0:
            print("Epoch:{:4d}, Iteration:{:4d}, Loss:{:4.4f}"
                  .format(epoch, i, loss_value))

    loss_epoch.append(np.array(loss_iteration).mean())

    # advance the learning-rate schedule once per epoch
    if lr_scheduler is not None:
        lr_scheduler.step()

<h2 align=center style="color:red; border:1px dotted red">Analyze</h2>

In [None]:
# Line chart of the per-epoch mean loss, drawn through the explicit Axes API.
fig, ax = plt.subplots(figsize=(14, 7))
ax.set_xlabel('Epoch', fontsize=15)
ax.set_ylabel('Loss', fontsize=15)
ax.set_title("Mean Loss per Epoch", fontsize=15)
ax.plot(loss_epoch)