# FasterRCNN

In [11]:
# imports
import torch
from torch import nn
from torchvision.models import resnet18
from torch.utils.data import Dataset
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import platform
from torch.autograd import Variable
import time
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import  FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
import glob
import os
from sklearn import metrics
from sklearn.model_selection import StratifiedKFold

# vis
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
import sklearn.metrics
from math import ceil
import cv2
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import PIL

In [7]:
# Load the training metadata table; `train` is consumed by the fold-split cell.
train = pd.read_csv(filepath_or_buffer='../input/train_exploded_filled.csv')

In [12]:
class Config:
    """Notebook-wide settings, read directly as class attributes (``Config.seed``)."""

    # Data handling
    train_pcent: float = 0.8
    reshape_size: tuple = (400, 400)  # default ``reshape_size`` for SIIM

    # Presumably loader batch sizes / epoch count — not used in the visible cells.
    TRAIN_BS: int = 4
    VALID_BS: int = 4
    NB_EPOCHS: int = 3

    # Model description
    model_name: str = 'FastRCNN'
    num_classes: int = 4

    # Passed as ``random_state`` to StratifiedKFold in the fold-split cell.
    seed: int = 2021

# Splits

In [16]:
# Split into folds: each row receives the number (0-4) of the StratifiedKFold
# split in which it is a validation row; folds are stratified on integer_label.
# NOTE(review): the split is row-level. If several rows share one image /
# StudyInstanceUID (the source file is "exploded"), the same image may end up
# in more than one fold — consider StratifiedGroupKFold grouped by study.
df_folds = train.copy()
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=Config.seed)
fold_of_row = np.empty(len(df_folds), dtype=int)
splits = skf.split(X=df_folds.index, y=df_folds.integer_label)
for n, (train_index, val_index) in enumerate(splits):
    # val_index holds positional row indices of this split's validation set.
    fold_of_row[val_index] = n
df_folds['fold'] = fold_of_row
print(df_folds.groupby(['fold', 'integer_label']).size())

fold  integer_label
0     0                 334
      1                1160
      2                 286
      3                 121
1     0                 334
      1                1160
      2                 286
      3                 121
2     0                 334
      1                1160
      2                 286
      3                 121
3     0                 334
      1                1160
      2                 286
      3                 121
4     0                 333
      1                1159
      2                 286
      3                 122
dtype: int64


# Dataset & Dataloader

In [None]:
class SIIM(Dataset):
    """Map-style dataset that loads DICOM images listed in a dataframe.

    Rows of ``df`` must provide a ``StudyInstanceUID`` column and a ``path``
    column pointing at a DICOM file. In training mode ``__getitem__`` also
    returns a label tensor built from the row's columns (see there).
    """

    def __init__(self, df, is_train=True, augments=None,
                 reshape_size=Config.reshape_size):
        """
        Args:
            df: sample table. A shuffled, re-indexed copy is stored; the
                caller's dataframe is not modified.
            is_train: if True, ``__getitem__`` returns ``(image, label)``;
                otherwise it returns only the image.
            augments: optional callable, invoked as ``augments(image=array)``.
            reshape_size: stored as ``self.reshape_size``.
                NOTE(review): not applied anywhere in this class yet.
        """
        super().__init__()
        # Shuffle once up front; reset_index so positional access is 0..n-1.
        self.df = df.sample(frac=1).reset_index(drop=True)
        self.is_train = is_train
        self.augments = augments
        self.reshape_size = reshape_size

    def __len__(self) -> int:
        """Return the number of rows (samples) in the dataset."""
        # The original called len() on df.shape[0], an int -> TypeError.
        return len(self.df)

    @staticmethod
    def dicom2array(path: str, voi_lut=True, fix_monochrome=True):
        """Read a DICOM file and return its pixels rescaled to uint8 in [0, 255].

        Args:
            path: path of the DICOM file.
            voi_lut: apply the file's VOI LUT (if present) via ``apply_voi_lut``.
            fix_monochrome: invert MONOCHROME1 images so that they are not
                displayed inverted.

        Returns:
            ``np.uint8`` array with the same shape as ``pixel_array``.
        """
        # dcmread replaces the deprecated read_file (removed in pydicom 3).
        dicom = pydicom.dcmread(path)
        # VOI LUT (if available by DICOM device) is used to
        # transform raw DICOM data to "human-friendly" view
        if voi_lut:
            data = apply_voi_lut(dicom.pixel_array, dicom)
        else:
            data = dicom.pixel_array
        # depending on this value, X-ray may look inverted - fix that:
        if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
            data = np.amax(data) - data
        data = data - np.min(data)
        peak = np.max(data)
        # A constant image has peak 0; dividing would yield NaN -> garbage uint8.
        # Leave it as all zeros instead.
        if peak > 0:
            data = data / peak
        return (data * 255).astype(np.uint8)

    def load_bbox(self, idx: int):
        """Placeholder for bounding-box loading; not implemented yet.

        Raises:
            NotImplementedError: always. (The original body returned an
            undefined name and had no ``self`` parameter.)
        """
        raise NotImplementedError("SIIM.load_bbox is not implemented yet")

    def __getitem__(self, idx: int):
        """Return sample ``idx``.

        Returns:
            ``(image, label)`` when ``is_train``; otherwise ``image`` alone.
            Without ``augments``, ``image`` is a float32 ``(C, H, W)`` tensor
            with values still in 0..255.
        """
        image_id = self.df['StudyInstanceUID'].values[idx]
        image_path = self.df['path'].values[idx]
        image = self.dicom2array(image_path)
        if self.augments:
            augmented = self.augments(image=image)
            # albumentations-style transforms return {'image': ...}; unwrap
            # a dict result and pass anything else through unchanged.
            image = augmented['image'] if isinstance(augmented, dict) else augmented
        else:
            # A single-frame grayscale DICOM yields a 2-D (H, W) array, on
            # which a (2, 0, 1) transpose fails; add a channel axis first.
            if image.ndim == 2:
                image = image[:, :, np.newaxis]
            image = torch.tensor(np.transpose(image, (2, 0, 1)).astype(np.float32))
        if self.is_train:
            # NOTE(review): uses the *first* row with this StudyInstanceUID and
            # takes columns [4:-2] positionally — depends on the CSV column
            # order; confirm against the label columns of the input file.
            label = self.df[self.df['StudyInstanceUID'] == image_id].values.tolist()[0][4:-2]
            return image, torch.tensor(label)
        return image