# Imports

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import cv2
import PIL
import glob
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import albumentations as alb
from pydicom import dcmread

# Helper

In [None]:
def plot(img):
    """Draw ``img`` with the gray colormap in a new 10x10-inch figure.

    The figure is titled "Raw visualization". Nothing is returned; the image
    is drawn on the current pyplot figure.
    """
    plt.figure(figsize=(10, 10))
    plt.title("Raw visualization")
    plt.imshow(img, cmap='gray')


def get_labels(df, label_names=None):
    """Fill ``df['label']`` in place with the name of each row's active class.

    For every row, the first column in ``label_names`` whose value equals 1
    is written to the ``'label'`` column. Rows where none of those columns
    equals 1 are left untouched.

    Args:
        df: DataFrame with one indicator column per entry of
            ``label_names``. Modified in place.
        label_names: Class column names in priority order. When omitted, the
            module-level ``labels`` list is used, so that list must be
            defined before the call.

    Returns:
        None.
    """
    if label_names is None:
        label_names = labels

    # Positional boolean masks (not index-aligned Series) so that frames with
    # duplicated index values are still handled row by row.
    unassigned = np.ones(len(df), dtype=bool)
    for name in label_names:
        hit = unassigned & (df[name] == 1).to_numpy()
        if hit.any():
            df.loc[hit, 'label'] = name
        # A row keeps the first matching class; later columns must not
        # overwrite it.
        unassigned &= ~hit

In [None]:
def get_paths(df):
    """Collect every ``.dcm`` file below each row's ``path`` with its label.

    For each value of ``df.path``, every entry matching ``path + '*'`` is
    treated as a folder and searched for ``*.dcm`` files. Each file found is
    paired with the value stored at column position 6 of the same row.

    Args:
        df: DataFrame with a ``path`` column. The label is read by position
            (column index 6).
            NOTE(review): in this notebook that position appears to be the
            'label' column added after 'path' -- confirm before reusing the
            function on a frame with a different column layout.

    Returns:
        A ``(paths, labels)`` tuple of two lists of equal length.
    """
    paths, path_labels = [], []

    for i, path_ in tqdm(enumerate(df.path), total=len(df)):
        for folder in glob.glob(path_ + '*'):
            # Scan the folder once; the same result sizes both lists, so
            # they cannot drift apart.
            dcm_files = glob.glob(folder + '/*.dcm')
            if not dcm_files:
                continue
            paths.extend(dcm_files)
            path_labels.extend([df.iloc[i, 6]] * len(dcm_files))
    return paths, path_labels


def convert_dicom(path):
    """Read the DICOM file at ``path`` and return its ``pixel_array``."""
    return dcmread(path).pixel_array

def augmentations(path):
    """Load the DICOM image at ``path`` and return it after augmentation.

    The image is read with ``convert_dicom`` and passed through an
    albumentations pipeline of ``Emboss`` followed by ``RandomGamma``, both
    with ``p=1``.
    """
    source = convert_dicom(path)

    # Transforms that were tried here and are currently disabled:
    #   alb.RandomBrightnessContrast(p=1)
    #   alb.Downscale(scale_min=0.1, scale_max=0.5, p=1)
    #   alb.Equalize(p=1)
    steps = [
        alb.Emboss(p=1),
        alb.RandomGamma(p=1),
    ]
    pipeline = alb.Compose(steps)

    result = pipeline(image=source)
    return result['image']

# Data processing

In [None]:
# Load the study-level training annotations and preview the first rows.
study_level = pd.read_csv('../input/siim-covid19-detection/train_study_level.csv')
study_level.head()

In [None]:
# Keep only the part of each id before the first underscore, then derive the
# study folder from it. Both steps are vectorised string operations, which
# avoids a Python-level call per row and also works on an empty frame.
study_level['id'] = study_level['id'].str.split('_').str[0]
study_level['path'] = "../input/siim-covid19-detection/train/" + study_level['id'] + '/'

study_level.head(5)

In [None]:
# Indicator columns checked by get_labels, in priority order: the first one
# equal to 1 in a row becomes that row's label.
labels = [
    'Negative for Pneumonia',
    'Typical Appearance',
    'Indeterminate Appearance',
    'Atypical Appearance',
]

# Start every row with an empty label; get_labels fills it in place.
study_level['label'] = ''

get_labels(study_level)

In [None]:
# Display the frame to inspect the derived 'path' and 'label' columns.
study_level

# Image and label segregation

In [None]:
# Expand each study row into its DICOM file paths and matching labels, then
# show both list lengths (they should be equal).
train_images_path, train_images_labels = get_paths(study_level)
len(train_images_path), len(train_images_labels)

# Class Distribution

In [None]:
# countplot does the counting itself, so it gets the raw column; passing
# value_counts() would plot how often each *count* occurs instead.
sns.countplot(x=study_level['Negative for Pneumonia'])

In [None]:
# countplot does the counting itself, so it gets the raw column; passing
# value_counts() would plot how often each *count* occurs instead.
sns.countplot(x=study_level['Typical Appearance'])

In [None]:
# countplot does the counting itself, so it gets the raw column; passing
# value_counts() would plot how often each *count* occurs instead.
sns.countplot(x=study_level['Indeterminate Appearance'])

In [None]:
# countplot does the counting itself, so it gets the raw column; passing
# value_counts() would plot how often each *count* occurs instead.
sns.countplot(x=study_level['Atypical Appearance'])

In [None]:
# Distribution of the combined label column. The vector is passed as x=
# because newer seaborn releases bind the first positional argument to `data`.
sns.countplot(x=study_level.label)

# Image Visualization

In [None]:
# Read one sample DICOM file (index 10), show its raw pixel data, and print
# the label collected for the same index.
dicom = dcmread(train_images_path[10])
img = dicom.pixel_array
plot(img)

print(train_images_labels[10])

# Feature Extraction

## Albumentations

In [None]:
# Run the albumentations pipeline on the same sample (index 10) and show the
# result; np.array() makes a NumPy copy of whatever the pipeline returned.
img = augmentations(train_images_path[10])
plot(np.array(img))

## OpenCV Image Threshold

In [None]:
# THRESH_TOZERO_INV (per the OpenCV docs): pixels above the threshold (2500)
# become 0, all others keep their value. maxval (255) is not used by this mode.
# retval is the threshold value returned by OpenCV and is not used here.
img = convert_dicom(train_images_path[5])
retval, threshold = cv2.threshold(img, 2500, 255, cv2.THRESH_TOZERO_INV)
plt.imshow(threshold)

In [None]:
# THRESH_TOZERO (per the OpenCV docs): pixels above the threshold (2500) keep
# their value, all others become 0. maxval (255) is not used by this mode.
img = convert_dicom(train_images_path[5])
retval, threshold = cv2.threshold(img, 2500, 255, cv2.THRESH_TOZERO)
plt.imshow(threshold)

In [None]:
# THRESH_TRUNC (per the OpenCV docs): pixels above the threshold (2500) are
# clipped to 2500, all others keep their value. maxval (255) is not used.
img = convert_dicom(train_images_path[5])
retval, threshold = cv2.threshold(img, 2500, 255, cv2.THRESH_TRUNC)
plt.imshow(threshold)

In [None]:
# THRESH_BINARY_INV (per the OpenCV docs): pixels above the threshold (2500)
# become 0, all others become maxval (255).
img = convert_dicom(train_images_path[5])
retval, threshold = cv2.threshold(img, 2500, 255, cv2.THRESH_BINARY_INV)
plt.imshow(threshold)

In [None]:
# THRESH_BINARY (per the OpenCV docs): pixels above the threshold (2500)
# become maxval (255), all others become 0.
img = convert_dicom(train_images_path[5])
retval, threshold = cv2.threshold(img, 2500, 255, cv2.THRESH_BINARY)
plt.imshow(threshold)

## Normal Deep 2D Convolution

In [None]:
class FeatureExtractor(nn.Module):
    """Three single-channel 2D convolutions applied one after another.

    Every layer maps 1 input channel to 1 output channel. No activation or
    pooling is placed between the layers.

    Args:
        fil1: Kernel size of the first convolution.
        fil2: Kernel size of the second convolution.
        fil3: Kernel size of the third convolution.
    """

    def __init__(self, fil1, fil2, fil3):
        super().__init__()
        # Layers are created in conv1 -> conv2 -> conv3 order so that seeded
        # weight initialisation stays reproducible.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=fil1)
        self.conv2 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=fil2)
        self.conv3 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=fil3)

    def forward(self, img):
        """Pass ``img`` through conv1, conv2 and conv3 and return the output."""
        return self.conv3(self.conv2(self.conv1(img)))


# Extractor used by the demo cell, with kernel sizes 32, 32 and 64.
model = FeatureExtractor(fil1=32, fil2=32, fil3=64)

In [None]:
# Load sample 10 as float32 and prepend two size-1 axes so the tensor has the
# leading batch and channel dimensions that Conv2d expects.
# NOTE(review): assumes pixel_array is a single 2D frame -- confirm.
img = torch.Tensor(convert_dicom(train_images_path[10]).astype('float32'))
img = img.unsqueeze(0).unsqueeze(0)

# Drop the leading size-1 axes again, detach from autograd, and display the
# convolved result.
plt.imshow(
    model(img)
    .squeeze(0)
    .squeeze(0)
    .squeeze(0)
    .detach()
    .numpy()
)