In [None]:
import matplotlib.pyplot as plt 
from pydicom import dcmread
import os 
import cv2
import torch.nn as nn
import pandas as pd 
import numpy as np
import torch 
from typing import List
from matplotlib.patches import Rectangle 

# Directory containing the training images; get_sample() looks up
# "<image_id>.dicom" inside this folder.
BASE_FOLDER = "/kaggle/input/vinbigdata-chest-xray-abnormalities-detection/train"

In [None]:
# Annotation table. The helpers below read these columns from it:
# image_id, class_name, x_min, y_min, x_max, y_max.
df = pd.read_csv("/kaggle/input/vinbigdata-chest-xray-abnormalities-detection/train.csv")

In [None]:
def get_sample(df, index):
    """Load one annotated image and build a binary mask of its bounding box.

    Parameters
    ----------
    df : pandas.DataFrame
        Annotation table. The selected row must provide ``image_id`` and
        ``class_name`` and, unless the class is ``'No finding'``, the box
        corners ``x_min``, ``y_min``, ``x_max`` and ``y_max``.
    index : int
        Positional row index (passed to ``df.iloc``).

    Returns
    -------
    tuple
        ``(img, mask, class_name)`` where ``img`` is the pixel array read from
        ``<BASE_FOLDER>/<image_id>.dicom``, ``mask`` is a zero array of the
        same shape as ``img`` set to 1 inside the bounding box (all zeros for
        ``'No finding'`` rows), and ``class_name`` is the row's label.
    """
    sample = df.iloc[index]
    image_id = sample['image_id']
    class_name = sample['class_name']
    path = os.path.join(BASE_FOLDER, image_id + ".dicom")

    img = dcmread(path).pixel_array
    mask = np.zeros(img.shape)

    if class_name != 'No finding':
        # int() accepts both numpy scalars and plain Python floats, whereas
        # .astype() only exists on numpy scalars.
        x_min = int(sample['x_min'])
        y_min = int(sample['y_min'])
        x_max = int(sample['x_max'])
        y_max = int(sample['y_max'])

        # Axis 0 of the image is y (rows) and axis 1 is x (columns), so the
        # row slice must span the box height and the column slice its width.
        mask[y_min:y_max, x_min:x_max] = 1

    return img, mask, class_name
        
def plot_image(df, index):
    """Show an image next to its bounding-box mask.

    The left panel shows the image with the annotated box outlined in red
    (when the row has one); the right panel shows the mask returned by
    ``get_sample``.

    Parameters
    ----------
    df : pandas.DataFrame
        Annotation table, as expected by ``get_sample``.
    index : int
        Positional row index of the sample to draw.
    """
    sample = df.iloc[index]
    img, mask, class_name = get_sample(df, index)

    fig, ax = plt.subplots(1, 2, dpi=100)
    ax[0].imshow(img, cmap=plt.cm.gray)
    ax[0].set_title(str(class_name))
    ax[1].imshow(mask, cmap=plt.cm.gray)
    ax[1].set_title("mask")

    if class_name != 'No finding':
        x_min = int(sample['x_min'])
        y_min = int(sample['y_min'])
        width = int(sample['x_max']) - x_min
        height = int(sample['y_max']) - y_min

        # Rectangle's signature is (xy, width, height): horizontal extent first.
        rect = Rectangle(
            (x_min, y_min),
            width,
            height,
            edgecolor='red',
            facecolor='none',
        )
        # Only add the patch when a box exists; 'No finding' rows have none.
        ax[0].add_patch(rect)

    plt.tight_layout()
    plt.show()

In [None]:
# Visual check: draw the sample at row position 3 together with its mask.
plot_image(df, 3);

In [None]:
def double_conv(in_c, out_c):
    """Return two stacked 3x3 convolutions, each followed by an in-place ReLU.

    The first convolution maps ``in_c`` channels to ``out_c``; the second keeps
    ``out_c`` channels. No padding argument is passed, so ``nn.Conv2d`` uses
    its default and each convolution shrinks both spatial dimensions by 2.
    """
    layers = []
    channels = in_c
    for _ in range(2):
        layers.append(nn.Conv2d(channels, out_c, kernel_size=3))
        layers.append(nn.ReLU(inplace=True))
        # The second convolution consumes the first one's output channels.
        channels = out_c
    return nn.Sequential(*layers)

def crop_and_cat(source, target):
    """Center-crop ``source`` to ``target``'s spatial size and join on channels.

    Parameters
    ----------
    source : torch.Tensor
        4-D tensor whose dimensions 2 and 3 are at least as large as those of
        ``target``.
    target : torch.Tensor
        4-D tensor that fixes the spatial size of the result.

    Returns
    -------
    torch.Tensor
        ``torch.cat([cropped_source, target], 1)``: the cropped source
        channels first, then the target channels.

    Raises
    ------
    ValueError
        If ``source`` is smaller than ``target`` in dimension 2 or 3.
    """
    source_rows, source_cols = source.shape[2], source.shape[3]
    target_rows, target_cols = target.shape[2], target.shape[3]

    if source_rows < target_rows or source_cols < target_cols:
        raise ValueError(
            "source spatial size {} is smaller than target spatial size {}".format(
                (source_rows, source_cols), (target_rows, target_cols)
            )
        )

    row_start = (source_rows - target_rows) // 2
    col_start = (source_cols - target_cols) // 2

    # Slice by start + target size rather than a symmetric margin, so an odd
    # size difference still yields exactly the target's spatial shape.
    cropped = source[
        :,
        :,
        row_start:row_start + target_rows,
        col_start:col_start + target_cols,
    ]
    return torch.cat([cropped, target], 1)
        
class UNET(nn.Module):
    """U-Net style encoder/decoder built from ``double_conv`` blocks.

    The encoder applies five ``double_conv`` blocks (64 to 1024 channels) with
    2x2 max-pooling between them. The decoder upsamples four times with
    ``nn.ConvTranspose2d``, each time concatenating the matching encoder
    feature map (center-cropped by ``crop_and_cat``) before another
    ``double_conv``. A final 1x1 convolution produces the output channels.

    Parameters
    ----------
    in_channels : int, optional
        Number of channels of the input image (default 1).
    out_channels : int, optional
        Number of channels produced by the final 1x1 convolution (default 2).
    """

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        self.down_conv1 = double_conv(in_channels, 64)
        self.down_conv2 = double_conv(64, 128)
        self.down_conv3 = double_conv(128, 256)
        self.down_conv4 = double_conv(256, 512)
        self.down_conv5 = double_conv(512, 1024)

        # Decoder: each up_conv takes twice the up_trans output channels
        # because the cropped encoder feature map is concatenated first.
        self.up_trans1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_conv1 = double_conv(1024, 512)

        self.up_trans2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv2 = double_conv(512, 256)

        self.up_trans3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv3 = double_conv(256, 128)

        self.up_trans4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv4 = double_conv(128, 64)

        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, image):
        """Run the network on ``image`` and return the output feature map.

        Parameters
        ----------
        image : torch.Tensor
            4-D input whose dimension 1 has ``in_channels`` entries.

        Returns
        -------
        torch.Tensor
            Output of the final 1x1 convolution (``out_channels`` channels).
        """
        # Encoder: keep the pre-pooling activations for the skip connections.
        skip1 = self.down_conv1(image)
        skip2 = self.down_conv2(self.max_pool(skip1))
        skip3 = self.down_conv3(self.max_pool(skip2))
        skip4 = self.down_conv4(self.max_pool(skip3))
        x = self.down_conv5(self.max_pool(skip4))

        # Decoder: upsample, merge with the matching encoder level, convolve.
        x = self.up_trans1(x)
        x = crop_and_cat(skip4, x)
        x = self.up_conv1(x)

        x = self.up_trans2(x)
        x = crop_and_cat(skip3, x)
        x = self.up_conv2(x)

        x = self.up_trans3(x)
        x = crop_and_cat(skip2, x)
        x = self.up_conv3(x)

        x = self.up_trans4(x)
        x = crop_and_cat(skip1, x)
        x = self.up_conv4(x)

        return self.out(x)


In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# cell still runs on CPU-only runtimes.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

unet = UNET()
unet.to(device)

# Take only the image from the (img, mask, class_name) tuple.
img = get_sample(df, 1)[0]
# 2-D pixel array -> float32 tensor shaped (1, 1, rows, cols).
x = img.reshape(1, 1, img.shape[0], img.shape[1]).astype(np.float32)
x = torch.tensor(x)
x = x.to(device)

# Forward-pass smoke test: no gradients are needed, so skip building the
# autograd graph to keep memory usage down.
with torch.no_grad():
    pred = unet(x)

# Show the output shape when forward() returns a tensor (it may return None).
pred.shape if pred is not None else None

In [None]:
# Shape of the network input built above: (1, 1, img.shape[0], img.shape[1]).
x.shape