## RSNA: STR Pulmonary Embolism Detection


### What is Pulmonary Embolism?

![image.png](https://a360-rtmagazine.s3.amazonaws.com/wp-content/uploads/2019/10/lung-pulmo-embolism-1500-1280x640.jpg)

Pulmonary embolism (PE) is a blood clot that blocks an artery in the lungs.<br/>
It is a very serious condition: about one-third of people whose PE goes undiagnosed or untreated
die from it. <br/>
However, prompt treatment can help prevent permanent lung damage.

### What causes Pulmonary Embolism?

Pulmonary embolism is usually caused by a blood clot that forms in a deep vein elsewhere in the body and travels to the lungs.<br/>
The main cause of such clots is [deep vein thrombosis](https://www.healthline.com/health/deep-venous-thrombosis).<br/> The clots that cause pulmonary embolism usually form in the legs or pelvis.

### Causes of blood clots

**Injury or damage**: bone fractures, tissue damage, etc.<br/>
**Inactivity**: sitting for long periods of time or lying in bed because of illness.<br/>
**Medical conditions**: some medical conditions can cause blood clots to form in the body.<br/>

### Symptoms of pulmonary embolism

The symptoms of pulmonary embolism depend on the size of the blood clot.<br/>
The main symptom is shortness of breath, which can come on gradually or suddenly.
Other symptoms include:

* anxiety
* clammy or bluish skin
* chest pain that may extend into your arm, jaw, neck, and shoulder
* fainting
* irregular heartbeat
* lightheadedness
* rapid breathing
* rapid heartbeat
* restlessness
* spitting up blood
* weak pulse

Read more about pulmonary embolism [here](https://www.healthline.com/health/pulmonary-embolus).

## 1. What is the competition about? 💡

Here we are provided with CT pulmonary angiography (CTPA) scans, a common type of medical imaging used to diagnose PE.<br/>
The data contains many such scans, each made up of a series of CT slices, and our model has to predict whether each image shows PE, along with several exam-level labels.

**Let me know if any information or code is incorrect and I will correct it, and<br/> 
  if you find this notebook useful, please UPVOTE 😀**


## 2. Metric: Weighted Log Loss 📏

The competition metric is a weighted log loss, in which each label has its own weight.

We have to predict 10 labels in total: 9 at the exam/study level and 1 at the image level.

So the submission file should have the following number of rows: <br/>
(number of images) + (number of exam/study labels × number of exams/studies)
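The row count above is simple arithmetic; a tiny helper makes it explicit. In the notebook, `n_images` could come from `test_data.shape[0]` and `n_exams` from `test_data["StudyInstanceUID"].nunique()`, and the result compared against `sample.shape[0]` as a quick check. The function name is hypothetical, just for illustration.

```python
def expected_submission_rows(n_images: int, n_exams: int, n_exam_labels: int = 9) -> int:
    """One row per image (PE present on image) plus one row per (exam, exam-level label) pair."""
    return n_images + n_exam_labels * n_exams


# Hypothetical example: 2 exams with 150 and 200 slices -> 350 image rows + 2 * 9 exam rows
print(expected_submission_rows(n_images=350, n_exams=2))
```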

Exam/study-level labels and their weights:

| Label | Weight |
|---|---|
| Negative for PE | 0.0736196319 |
| Indeterminate | 0.09202453988 |
| Chronic | 0.1042944785 |
| Acute & Chronic | 0.1042944785 |
| Central PE | 0.1877300613 |
| Left PE | 0.06257668712 |
| Right PE | 0.06257668712 |
| RV/LV Ratio >= 1 | 0.2346625767 |
| RV/LV Ratio < 1 | 0.0782208589 |

### 2.1 Exam/Study- level weighted log-loss

Let $y_{ij} = 1$ if label $j$ is present for exam $i$ (else 0), let $p_{ij}$ be the predicted probability,<br/>
and let $w_j$ be the weight of label $j$. The weighted log loss is then:

$$L_{ij} = -w_j \left[ y_{ij} \log(p_{ij}) + (1 - y_{ij}) \log(1 - p_{ij}) \right]$$

These exam-level terms are not averaged on their own; they are combined with the image-level terms below into a single weighted average.

### 2.2 Image level weighted log-loss

Let $y_{ik} = 1$ if PE is present in image $k$ of exam $i$ (else 0), and let $p_{ik}$ be the predicted probability.<br/>
The weight of the image-level label is $w = 0.0736196319$ (the same as "Negative for PE"). <br/>
Let $q_i$ be the fraction of images in exam $i$ that are positive for PE. Then:

$$L_{ik} = -w \, q_i \left[ y_{ik} \log(p_{ik}) + (1 - y_{ik}) \log(1 - p_{ik}) \right]$$

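The total loss is the average of all image and exam losses, divided by the average of all row weights.<br/>
To get the average row weight, sum the weights of all image rows and all exam-level label rows and divide by the number of rows. Since both averages are over the same number of rows, this equals the sum of all weighted losses divided by the sum of all row weights.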
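As a rough illustration, the metric described above can be sketched in NumPy. This is my own sketch, not the official scoring code: it assumes an image row's weight is $w \cdot q_i$, clips probabilities with a small `eps` to avoid `log(0)`, and uses the label names from `train.csv`.

```python
import numpy as np

# Exam-level label weights (same values as the table above), keyed by train.csv column names.
EXAM_WEIGHTS = {
    "negative_exam_for_pe": 0.0736196319,
    "indeterminate": 0.09202453988,
    "chronic_pe": 0.1042944785,
    "acute_and_chronic_pe": 0.1042944785,
    "central_pe": 0.1877300613,
    "leftsided_pe": 0.06257668712,
    "rightsided_pe": 0.06257668712,
    "rv_lv_ratio_gte_1": 0.2346625767,
    "rv_lv_ratio_lt_1": 0.0782208589,
}
IMAGE_WEIGHT = 0.0736196319  # w for pe_present_on_image


def _bce(y, p, eps=1e-15):
    """Element-wise binary cross-entropy; eps clipping is an assumption to avoid log(0)."""
    p = np.clip(p, eps, 1 - eps)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))


def competition_loss(exam_true, exam_pred, image_true, image_pred):
    """
    exam_true / exam_pred: dict label -> array of shape (n_exams,)
    image_true / image_pred: lists with one array of per-image labels/probabilities per exam
    Returns sum(weighted losses) / sum(row weights).
    """
    total_loss, total_weight = 0.0, 0.0
    for label, w in EXAM_WEIGHTS.items():
        y = np.asarray(exam_true[label], dtype=float)
        p = np.asarray(exam_pred[label], dtype=float)
        total_loss += np.sum(w * _bce(y, p))
        total_weight += w * len(y)
    for y, p in zip(image_true, image_pred):
        y = np.asarray(y, dtype=float)
        p = np.asarray(p, dtype=float)
        q = y.mean()  # q_i: fraction of positive images in this exam
        total_loss += np.sum(IMAGE_WEIGHT * q * _bce(y, p))
        total_weight += IMAGE_WEIGHT * q * len(y)
    return total_loss / total_weight
```

A useful sanity check: predicting 0.5 everywhere gives a loss of exactly $\log 2$, because every row contributes $w \log 2$ and the weights cancel. Note also that images of a negative exam ($q_i = 0$) get zero weight under this formula.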

As if that were not confusing enough, there is another catch. Full details are on the [evaluation](https://www.kaggle.com/c/rsna-str-pulmonary-embolism-detection/overview/evaluation) page.

### 2.3 All labels must be logically consistent

All the labels in the submission file must be logically consistent or your submission will be disqualified.

What does logically consistent mean?

At the image level, any image with a predicted probability > 0.5 counts as positive for PE.

At the exam level, we have

1. Negative, Indeterminate, (Positive): an exam can only be one of these. If any image is predicted positive, neither Negative nor Indeterminate may have a predicted probability > 0.5. Conversely, if no image is positive (p > 0.5), then exactly one of Negative or Indeterminate must have p > 0.5.

2. Right, Left, Central: if any image is predicted positive (p > 0.5), at least one of these labels must have p > 0.5, and more than one may. When no image is predicted positive, none of these labels may have p > 0.5.

3. RV/LV ratio: only one of the two labels can apply, and one must apply if at least one image is positive.

   * If any image in the exam is positive, exactly one of the two must have p > 0.5; both cannot have p > 0.5.

4. Acute, Chronic, Acute & Chronic: an exam cannot be both Chronic and Acute & Chronic, so:

   * only one of them can have p > 0.5;
   * it is also possible that neither has p > 0.5;
   * in other words, it is inconsistent to predict Chronic with p > 0.5 and Acute & Chronic with p > 0.5.
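The rules above can be turned into a per-exam check. This is a simplified sketch of my own that encodes only the rules listed in this section; the host's notebook linked right after this is the authoritative check and may enforce more. The function name and the dict-of-probabilities input format are assumptions.

```python
def check_exam_consistency(image_probs, exam_probs, thr=0.5):
    """Return a list of rule violations for one exam (empty list = consistent).

    image_probs: iterable of pe_present_on_image probabilities for the exam's images
    exam_probs:  dict mapping exam-level label (train.csv column name) -> probability
    """
    pos = {label: prob > thr for label, prob in exam_probs.items()}
    any_positive_image = any(p > thr for p in image_probs)
    errors = []

    if any_positive_image:
        # Rule 1: a positive exam cannot also be negative or indeterminate
        if pos["negative_exam_for_pe"] or pos["indeterminate"]:
            errors.append("positive image but negative/indeterminate > 0.5")
        # Rule 2: at least one location label must be > 0.5
        if not (pos["rightsided_pe"] or pos["leftsided_pe"] or pos["central_pe"]):
            errors.append("positive image but no right/left/central > 0.5")
        # Rule 3: exactly one RV/LV label must be > 0.5
        if pos["rv_lv_ratio_gte_1"] == pos["rv_lv_ratio_lt_1"]:
            errors.append("positive image needs exactly one RV/LV label > 0.5")
    else:
        # Rule 1: exactly one of negative / indeterminate must be > 0.5
        if pos["negative_exam_for_pe"] == pos["indeterminate"]:
            errors.append("no positive image needs exactly one of negative/indeterminate > 0.5")
        # Rule 2: no location label may be > 0.5
        if pos["rightsided_pe"] or pos["leftsided_pe"] or pos["central_pe"]:
            errors.append("no positive image but a location label > 0.5")

    # Rule 4: chronic and acute & chronic are mutually exclusive
    if pos["chronic_pe"] and pos["acute_and_chronic_pe"]:
        errors.append("chronic_pe and acute_and_chronic_pe both > 0.5")
    return errors
```

Running this on every `StudyInstanceUID` before writing the submission would catch the most common mistakes early.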

The host-provided code for checking the consistency of a submission is available [here](https://www.kaggle.com/anthracene/host-confirmed-label-consistency-check)<br/>
and will be used in this notebook.

### Importing Libraries 📘

In [None]:
import os
import sys
import glob
import tqdm
from typing import Dict
import cv2
from collections import Counter
import random

import pydicom as dicom
from joblib import Parallel, delayed


import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML


plt.style.use("fivethirtyeight")

import plotly.express as px
import plotly.graph_objs as go
from plotly.offline import iplot
import plotly.figure_factory as ff

# suppress warnings
import warnings
warnings.filterwarnings('ignore')

from colorama import Fore, Back, Style
y_ = Fore.YELLOW
r_ = Fore.RED
g_ = Fore.GREEN
b_ = Fore.BLUE
m_ = Fore.MAGENTA
sr_ = Style.RESET_ALL

### Getting Data 💽

In [None]:
folder_path = "../input/rsna-str-pulmonary-embolism-detection"
train_path = folder_path + "/train/"
test_path = folder_path + "/test/"
    
train_data = pd.read_csv(folder_path + "/train.csv")
test_data  = pd.read_csv(folder_path + "/test.csv")
sample = pd.read_csv(folder_path + "/sample_submission.csv")

In [None]:
print(f"{y_}Number of rows in train data: {r_}{train_data.shape[0]}\n{y_}Number of columns in train data: {r_}{train_data.shape[1]}")
print(f"{g_}Number of rows in test data: {r_}{test_data.shape[0]}\n{g_}Number of columns in test data: {r_}{test_data.shape[1]}")
print(f"{b_}Number of rows in submission data: {r_}{sample.shape[0]}\n{b_}Number of columns in submission data:{r_}{sample.shape[1]}")

**There are four columns in the training set that are for information only and require no predictions: QA Contrast, QA Motion, True Filling Defect not PE, and Flow Artifact.**

In [None]:
columns_only_for_info = ["qa_motion","qa_contrast","flow_artifact","true_filling_defect_not_pe"]
cols_ID = ["StudyInstanceUID","SeriesInstanceUID","SOPInstanceUID"]

def highlight_cols(x):
    df = x.copy()
    df.loc[:,:] = 'background-color: lightgreen'
    df[columns_only_for_info] = 'background-color: red'
    return df 

train_data.head().style.apply(highlight_cols, axis=None)

In [None]:
test_data.head()

In [None]:
sample.head()

Each image is stored at `StudyInstanceUID/SeriesInstanceUID/SOPInstanceUID.dcm`. <br/>
Let us perform a sanity check of whether the StudyInstanceUID, SeriesInstanceUID and SOPInstanceUID values in the CSV files match the actual folders.

In [None]:
def sanity_check():
    """Check that the UIDs in train.csv / test.csv match the folder structure on disk."""
    cols = ["StudyInstanceUID", "SeriesInstanceUID", "SOPInstanceUID"]

    def collect_uids(root):
        # folder layout: <root>/<StudyInstanceUID>/<SeriesInstanceUID>/<SOPInstanceUID>.dcm
        studies, series, sops = [], [], []
        for study in os.listdir(root):
            studies.append(study)
            for serie in os.listdir(root + study):
                series.append(serie)
                # drop the ".dcm" extension to recover the SOPInstanceUID
                sops.extend(os.path.splitext(f)[0] for f in os.listdir(root + study + "/" + serie))
        return [studies, series, sops]

    for name, df, root in [("train", train_data, train_path), ("test", test_data, test_path)]:
        for folder_uids, col in zip(collect_uids(root), cols):
            csv_uids, folder_uids = set(df[col].unique()), set(folder_uids)

            if len(csv_uids) == len(folder_uids):
                print(f"{g_}Number of {col} in {name} data and folder are the same")
            else:
                print(f"{r_}Number of {col} in {name} data and folder are not the same")

            if csv_uids == folder_uids:
                print(f"{g_}{col} in {name} csv and folder are the same")
            else:
                print(f"{r_}{col} in {name} csv and folder are not the same")
                # show (at most 10) UIDs listed in the csv but missing from the folder
                print(f"Not in {name} folder:", list(csv_uids - folder_uids)[:10])

In [None]:
sanity_check()

In [None]:
train_data["ImagePath"] =train_path+ train_data[cols_ID[0]]+"/"+train_data[cols_ID[1]]+"/"+train_data[cols_ID[2]]+".dcm"
test_data["ImagePath"] = test_path+ test_data[cols_ID[0]]+"/"+test_data[cols_ID[1]]+"/"+test_data[cols_ID[2]]+".dcm"

train_data["ImageDict"] =train_path+ train_data[cols_ID[0]]+"/"+train_data[cols_ID[1]]
test_data["ImageDict"] = test_path+ test_data[cols_ID[0]]+"/"+test_data[cols_ID[1]]

In [None]:
dicom.dcmread(train_data.loc[0,"ImagePath"])

### Basic Physics Behind CT Scans

* The idea behind a CT (computed tomography) scan is that when an X-ray beam passes through tissue,
  the [attenuation coefficient](https://radiopaedia.org/articles/attenuation-coefficient?lang=us) of that tissue can be measured.

* The attenuation coefficient measures how much the beam's intensity is reduced as it passes through
  a given material.
  
* Because different body parts have different attenuation coefficients, they can be separated easily in the image.
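The bullets above describe what is usually written as the Beer-Lambert law, $I = I_0 e^{-\mu x}$, for a beam of intensity $I_0$ crossing a thickness $x$ of material with attenuation coefficient $\mu$. This sketch assumes a monochromatic beam (real CT beams are polychromatic, so it is a simplification), and any $\mu$ values you plug in are illustrative, not measured tissue constants.

```python
import numpy as np


def transmitted_intensity(i0, mu, thickness):
    """Beer-Lambert law for one homogeneous material: I = I0 * exp(-mu * x)."""
    return i0 * np.exp(-mu * thickness)


def transmitted_through_layers(i0, mus, thicknesses):
    """Layers crossed in sequence multiply: I = I0 * exp(-sum(mu_k * x_k))."""
    return i0 * np.exp(-np.sum(np.asarray(mus) * np.asarray(thicknesses)))


# Illustrative (hypothetical) coefficients in 1/cm for two materials, each 5 cm thick
print(transmitted_intensity(1.0, mu=0.02, thickness=5.0))
print(transmitted_intensity(1.0, mu=0.2, thickness=5.0))
```

The material with the larger $\mu$ lets through much less of the beam, which is exactly the contrast a CT scanner measures along many directions and then reconstructs into a slice.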

### Hounsfield Units

* In 1979, **[Godfrey Hounsfield](https://en.wikipedia.org/wiki/Godfrey_Hounsfield)** was awarded the Nobel Prize in Physiology or Medicine for his part in developing CT scanners.

* A CT scanner measures how strongly each small volume of tissue attenuates X-rays (its radiodensity). <br>
  To convert these values into a grayscale image, Hounsfield came up with a linear transformation<br/>
  from radiodensity to gray levels.<br/>

* Hounsfield units (HU) are defined so that water is 0 HU and air is -1000 HU,<br/>
  with all other values scaled linearly. Most body tissues lie roughly in the range<br/>
  -1000 HU (e.g. air in the lungs) to +1000 HU.
 
* This range of roughly -1000 to +1000 HU is then mapped onto 256 gray levels to form the image.

* Because the spectral composition of the X-ray beam depends on the tube voltage and other parameters, raw values differ between scanners; Hounsfield units<br/>
  normalize them, which makes it easier to compare CT scans from different machines.
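Two conversions are involved here. The first is the usual definition of the Hounsfield scale, $HU = 1000 \cdot (\mu - \mu_{water}) / (\mu_{water} - \mu_{air})$, which pins water to 0 HU and air to -1000 HU. The second is how a DICOM file stores HU: as integers that are mapped back with the `RescaleSlope` and `RescaleIntercept` tags (the same rescaling used in section 3.9 of this notebook). The `mu_water` value in the example is a placeholder, not a measured constant.

```python
import numpy as np


def mu_to_hu(mu, mu_water, mu_air=0.0):
    """Hounsfield scale: water -> 0 HU, air -> -1000 HU, linear in between."""
    return 1000.0 * (mu - mu_water) / (mu_water - mu_air)


def pixels_to_hu(pixel_array, slope, intercept):
    """Stored DICOM values -> HU using the RescaleSlope / RescaleIntercept tags."""
    return pixel_array.astype(np.float32) * slope + intercept


# Placeholder attenuation coefficients (arbitrary units) just to show the scale
mu_water = 0.2
print(mu_to_hu(mu_water, mu_water))  # water
print(mu_to_hu(0.0, mu_water))       # air
```

With a real file, `pixels_to_hu(ds.pixel_array, ds.RescaleSlope, ds.RescaleIntercept)` would give the HU image, assuming `ds` was loaded with `dicom.dcmread`.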

### Window and Level
* A grayscale image has 256 gray levels, but our eyes cannot see small changes between neighbouring<br/>
  gray levels. 

* So instead of spreading the whole Hounsfield range across the gray levels, we select a window of Hounsfield<br/>
  values and spread the 256 gray levels across that window only.
  
* Everything below the window is shown as black and everything above it as white.<br/>

* The level is the center of this window (the width is its size). 

* We can adjust the window to get a better view of the particular tissue we<br/>
  want to observe.

Watch [this](https://www.youtube.com/watch?v=KZld-5W99cI&feature=youtu.be) video for more info.
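Windowing as described above is a clip followed by a linear rescale. A minimal sketch, assuming the input is already in HU; the level/width values in the example are arbitrary, and in practice the `WindowCenter`/`WindowWidth` tags read in the metadata section below could be passed in instead. The function name is hypothetical.

```python
import numpy as np


def apply_window(hu_image, level, width):
    """Map HU values in [level - width/2, level + width/2] onto 0..255; clip everything else."""
    lower = level - width / 2
    upper = level + width / 2
    clipped = np.clip(hu_image, lower, upper)         # below -> lower (black), above -> upper (white)
    scaled = (clipped - lower) / (upper - lower) * 255.0
    return np.round(scaled).astype(np.uint8)


# Example window: level 40 HU, width 400 HU -> covers -160..240 HU
print(apply_window(np.array([-1000.0, 40.0, 1000.0]), level=40, width=400))
```

Air (-1000 HU) falls below the window and becomes black, the level itself lands mid-gray, and anything denser than 240 HU saturates to white.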

### Voxel Size

* A voxel is the 3D counterpart of a pixel: a pixel is the smallest element of a 2D image, and a voxel is the smallest element of a 3D volume.
  Here its size is given by PixelSpacing in the x and y directions and SliceThickness in the z direction.
  
* There are two types of voxels: isotropic, where the size is the same along all three axes, and anisotropic, where the sizes differ.

* Smaller voxels give a higher spatial resolution (finer detail).

![image](https://res.cloudinary.com/mtree/image/upload/f_auto,q_auto,f_jpg,fl_attachment:ce531-fig07-voxel/dentalcare/%2F-%2Fmedia%2Fdentalcareus%2Fprofessional-education%2Fce-courses%2Fcourse0501-0600%2Fce531%2Fimages%2Fce531-fig07-voxel.jpg%3Fh%3D400%26la%3Den-us%26w%3D700%26v%3D1-201710231815?h=400&la=en-US&w=700)
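The voxel size translates directly into a few numbers that matter when stacking slices into a 3D volume. This sketch assumes a `(z, y, x)` volume whose spacing is `(SliceThickness, PixelSpacing[0], PixelSpacing[1])`, which treats slices as contiguous; the 1 mm target spacing is an arbitrary example, and the helper names are hypothetical.

```python
import numpy as np


def voxel_volume(pixel_spacing, slice_thickness):
    """Volume of one voxel in mm^3: row spacing * column spacing * slice thickness."""
    return float(pixel_spacing[0]) * float(pixel_spacing[1]) * float(slice_thickness)


def is_isotropic(pixel_spacing, slice_thickness, tol=1e-6):
    """True when the voxel has (almost) the same size along all three axes."""
    spacing = [float(pixel_spacing[0]), float(pixel_spacing[1]), float(slice_thickness)]
    return max(spacing) - min(spacing) < tol


def isotropic_shape(shape_zyx, spacing_zyx, new_spacing=1.0):
    """Shape a (z, y, x) volume would have after resampling to cubic voxels of new_spacing mm."""
    return tuple(int(round(n * s / new_spacing)) for n, s in zip(shape_zyx, spacing_zyx))


# Hypothetical scan: 100 slices of 512x512, 2 mm slices, 0.5 mm in-plane pixels
print(voxel_volume((0.5, 0.5), 2.0))
print(is_isotropic((0.5, 0.5), 2.0))
print(isotropic_shape((100, 512, 512), (2.0, 0.5, 0.5)))
```

Resampling to isotropic voxels (e.g. with an interpolation routine such as `scipy.ndimage.zoom`) makes 3D models treat all three axes the same way, at the cost of interpolating between slices.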





### Metadata of Image

In [None]:
N = 10000

def get_value(data):
    # some tags (e.g. WindowCenter, WindowWidth) can hold several values; keep the first one
    if isinstance(data, dicom.multival.MultiValue):
        return int(data[0])
    return int(data)
        
def get_meta_features(path):
    data = dicom.dcmread(path)
   
    slicethickness = data.SliceThickness
    windowwidth = get_value(data.WindowWidth)
    rows = data.Rows
    columns = data.Columns
    windowcenter = get_value(data.WindowCenter)
    intercepts = data.RescaleIntercept
    slopes = data.RescaleSlope
    pixelspacingrows = data.PixelSpacing[0]    # spacing between rows (mm)
    pixelspacingcolumn = data.PixelSpacing[1]  # spacing between columns (mm)
    kvp = data.KVP
    tableheight = data.TableHeight
    xray = data.XRayTubeCurrent
    exposure = data.Exposure
    modality = data.Modality
    rotationdirection = data.RotationDirection
    instancenumber = data.InstanceNumber
    
    # order must match the column names used when building the DataFrame
    final_data = [slicethickness,windowwidth,rows,columns,windowcenter,intercepts,
                 slopes,pixelspacingrows,pixelspacingcolumn,kvp,tableheight,
                 xray,exposure,modality,rotationdirection,instancenumber]
    return final_data

meta_data = Parallel(n_jobs = -1, verbose = 1)(map(delayed(get_meta_features),train_data["ImagePath"].sample(n=N)))

In [None]:
meta_data = pd.DataFrame(meta_data,
    columns = ["SliceThickness",
                "WindowWidth",
                "Rows",
                "Columns",
                "WindowCenter",
                "Intercept",
                "Slope",
                "PixelSpacingRows",
                "PixelSpacingColumns",
                "KVP",
                "TableHeight",
                "XRay",
                "Exposure",
                "Modality",
                "RotationDirection",
                "InstanceNumber"])

meta_data["Area"] = meta_data["Rows"] * meta_data["Columns"]
meta_data["PixelArea"] = meta_data["PixelSpacingRows"] * meta_data["PixelSpacingColumns"]
meta_data.head()

## 3. EDA

### 3.1 Rows and Columns Range

In [None]:
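def dist(column,color):
    # sns.distplot is deprecated; histplot with a KDE overlay is its replacement
    sns.histplot(meta_data[column],label=column,color=color,kde=True,stat="density")
    plt.legend()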

plt.figure(figsize=(15,7))
plt.subplot(121)
dist("Rows","blue")
plt.subplot(122)
dist("Columns","green")
plt.show()

All 10,000 sampled images have 512 rows and 512 columns, so it is probably safe to assume that every image has the same size.

### 3.2 Distribution of Pixel Spacing Columns and Rows

In [None]:
plt.figure(figsize=(15,7))
plt.subplot(121)
dist("PixelSpacingRows","purple")
plt.subplot(122)
dist("PixelSpacingColumns","red")
plt.show()

### 3.3 Distribution of Pixel Area

In [None]:
plt.figure(dpi=100)
dist("PixelArea","yellow")

### 3.4 Distribution of Window width

In [None]:
plt.figure(dpi=100)
dist("WindowWidth","orange")

### 3.5 Distribution of KVP

In [None]:
plt.figure(dpi=100)
dist("KVP","blue")

### 3.6 Distribution of TableHeight

In [None]:
plt.figure(dpi=100)
dist("TableHeight","brown")

### 3.7 Distribution of XRay

In [None]:
plt.figure(dpi=100)
dist("XRay","pink")

### 3.8 Countplot of modality and rotation direction


In [None]:
plt.subplot(121)
sns.countplot(x=meta_data['Modality'])
plt.subplot(122)
sns.countplot(x=meta_data["RotationDirection"])
plt.show()

All sampled images have modality CT and rotation direction CW.

### 3.9 Distribution of Image Raw and Rescaled values

In [None]:
def distribution_of_image_values(n,train=True):
    samples = train_data.sample(n=n) if train else test_data.sample(n=n)
    image_paths = samples["ImagePath"].values
    
    plt.figure(figsize=(15,7))
    
    for image_path in image_paths:
        image_data = dicom.dcmread(image_path)
        try:
            image = image_data.pixel_array.flatten()
        except Exception:
            # skip files whose pixel data cannot be decoded
            continue
        # convert stored values to Hounsfield units
        rescaled_image = image * image_data.RescaleSlope + image_data.RescaleIntercept
        
        plt.subplot(121)
        # sns.distplot is deprecated; draw each histogram as an outline so they can overlap
        sns.histplot(image, element="step", fill=False, stat="density")
        plt.title("Raw Image")
        
        plt.subplot(122)
        sns.histplot(rescaled_image, element="step", fill=False, stat="density")
        plt.title("Rescaled Image")
    plt.show()  

In [None]:
distribution_of_image_values(100)

In [None]:
import gc
del meta_data
gc.collect()

## 4. Visualizing Images 🖼️


#### What is a DICOM image?

**DICOM (Digital Imaging and Communications in Medicine)** is a standard developed and<br/>
maintained by the **National Electrical Manufacturers Association (NEMA)** for storing<br/>
and transferring medical images such as CT (computed tomography), <br/>
MRI (magnetic resonance imaging) and other imaging modalities.

DICOM is a very interesting standard; to read further, click [here](https://en.wikipedia.org/wiki/DICOM).


### 4.1 Single Image 🖼️

In [None]:
def show_image(train=True):
    image_path = train_data["ImagePath"].sample(n=1).values[0] if train\
                 else test_data["ImagePath"].sample(n=1).values[0]
    print(f"{y_} Image {r_}{image_path}")
    image = dicom.dcmread(image_path)
    image = image.pixel_array
    plt.figure(figsize=(7,7))
    plt.imshow(image,cmap='gray')
    plt.axis('off')
    plt.show()

In [None]:
show_image()

### 4.2 Grid of sorted images of a random patient

In [None]:
def show_grid(cmap='gray',train=True):
    single_sample = train_data.sample(n=1) if train else test_data.sample(n=1)
    image_dict = single_sample["ImageDict"].values[0]
    
    # dcmread replaces the deprecated read_file (removed in pydicom 3.0)
    images = [dicom.dcmread(image_dict+"/"+filename) for filename in os.listdir(image_dict)]
    images.sort(key = lambda x: float(x.ImagePositionPatient[2]))
    plt.figure(figsize=(10,10))
    
    for i,image in enumerate(images[:100]):
        plt.subplot(10,10,i+1)
        plt.imshow(image.pixel_array,cmap=cmap)
        plt.axis('off')
    plt.show()

In [None]:
show_grid()

In [None]:
show_grid(cmap='jet',train=False)

In [None]:
show_grid(cmap='RdYlBu')

### 4.3 Animation

In [None]:
def show_animation(train=True):
    single_sample = train_data.sample(n=1) if train else test_data.sample(n=1)
    image_dict = single_sample["ImageDict"].values[0]

    images = [dicom.dcmread(image_dict+"/"+filename) for filename in os.listdir(image_dict)]
    images.sort(key = lambda x: float(x.ImagePositionPatient[2]))
    fig = plt.figure()
    ims = list()
    for image in images:
        img = plt.imshow(image.pixel_array,cmap='gray',animated=True)
        plt.axis('off')
        ims.append([img])
    ani = animation.ArtistAnimation(fig,ims,interval=100,blit=False,repeat_delay=1000)
    return ani

ani = show_animation()    

In [None]:
HTML(ani.to_jshtml())

### 4.4 3D Reconstruction

In [None]:
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the '3d' projection on older matplotlib)

def get_x_y_coordinate(image,n1,n2):
    # (row, col) indices of all pixels whose value lies in [n1, n2]
    return np.argwhere((image >= n1) & (image <= n2))
          
def reconstruct(train=True,n1=-1000,n2=2000,s=2,color='b',alpha=0.01,number=1000):
    single_sample = train_data.sample(n=1) if train else test_data.sample(n=1)
    image_dict = single_sample["ImageDict"].values[0]
    
    images = [dicom.dcmread(image_dict+"/"+filename) for filename in os.listdir(image_dict)]
    # sort slices by their position along the patient's z-axis
    images.sort(key = lambda x: float(x.ImagePositionPatient[2]))
    
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(111,projection='3d')
    
    for i,image in enumerate(images[:number]):
        img = image.pixel_array
        arr = get_x_y_coordinate(img,n1,n2)
        x = arr[:,0]
        y = arr[:,1]
        # stack the slices: slice i sits at height i * SliceThickness
        z = np.full(shape=len(x),fill_value = i*float(image.SliceThickness))
        ax.scatter(x,y,z,s=s,c=color,alpha=alpha)
    
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    plt.show()

In [None]:
reconstruct(n1=1500,n2=2000,alpha=0.1,s=1)

Well, I thought it would work.

## 5. PyTorch Baseline Model 🔥

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import models
from torch.utils.data import Dataset,DataLoader
import cv2
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from sklearn.model_selection import KFold

import vtk
from vtk.util import numpy_support
from tqdm.auto import tqdm

In [None]:
SEED  = 42

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark mode picks the fastest (possibly non-deterministic) kernels, so disable it
    torch.backends.cudnn.benchmark = False

seed_everything(SEED)

In [None]:
target_columns = ['pe_present_on_image', 'negative_exam_for_pe', 'rv_lv_ratio_gte_1', 
                  'rv_lv_ratio_lt_1','leftsided_pe', 'chronic_pe','rightsided_pe', 
                  'acute_and_chronic_pe', 'central_pe', 'indeterminate']

# vtk is used to read the DICOM files because pydicom was raising errors on some of them

reader = vtk.vtkDICOMImageReader()
def get_img(path):
    reader.SetFileName(path)
    reader.Update()
    _extent = reader.GetDataExtent()
    ConstPixelDims = [_extent[1]-_extent[0]+1, _extent[3]-_extent[2]+1, _extent[5]-_extent[4]+1]

    ConstPixelSpacing = reader.GetPixelSpacing()
    imageData = reader.GetOutput()
    pointData = imageData.GetPointData()
    arrayData = pointData.GetArray(0)
    ArrayDicom = numpy_support.vtk_to_numpy(arrayData)
    ArrayDicom = ArrayDicom.reshape(ConstPixelDims, order='F')
    ArrayDicom = cv2.resize(ArrayDicom,(512,512))
    return ArrayDicom


def convert_to_rgb(array):
    # replicate the single-channel 512x512 slice into 3 channels, channel-first: (3, 512, 512)
    array = array.reshape((512, 512)).astype(np.float32)
    return np.stack([array, array, array], axis=0)

In [None]:
class RsnaDataset(Dataset):
    
    def __init__(self,df,transforms=None):
        super().__init__()
        self.image_paths = df['ImagePath'].unique()
        self.df = df
        self.transforms = transforms
    
    def __getitem__(self,index):
        
        image_path = self.image_paths[index]
        data = self.df[self.df['ImagePath']==image_path]
        labels = data[target_columns].values.reshape(-1)
        image = get_img(image_path)
        image = convert_to_rgb(image)
        
        if self.transforms:
            image = self.transforms(image=image)['image']
            
        image = torch.tensor(image,dtype=torch.float)        
        labels = torch.tensor(labels,dtype=torch.float)
        
        return image,labels
           
    def __len__(self):
        return self.image_paths.shape[0]  

In [None]:
classes = len(target_columns)
model = models.resnet18(pretrained=True)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features,classes)

config = {
    "learning_rate": 0.001,
    "train_batch_size": 32,
    "valid_batch_size": 32,
    "epochs": 10,
    "nfolds": 3,
    "number_of_samples": 7000,
}

train_data = train_data.sample(n=config["number_of_samples"]).reset_index(drop=True)

In [None]:
def run(plot_losses=True):
  
    def train_loop(train_loader,model,loss_fn,device,optimizer,lr_scheduler=None):
        model.train()
        total_loss = 0
        tqdm_loader = tqdm(train_loader)
        for i, (images, targets) in enumerate(tqdm_loader):
            images,targets = images.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(outputs,targets)
            loss.backward()
                
            total_loss += loss.item()

            optimizer.step()
            if lr_scheduler != None:
                lr_scheduler.step()
                    
        total_loss /= len(train_loader)
        return total_loss
    
    def valid_loop(valid_loader,model,loss_fn,device):
        model.eval()
        total_loss = 0
        predictions = list()
        tqdm_loader = tqdm(valid_loader)

        # gradients are not needed for validation; skipping them saves memory and time
        with torch.no_grad():
            for i, (images, targets) in enumerate(tqdm_loader):
                images, targets = images.to(device),targets.to(device)
                
                outputs = model(images)

                loss = loss_fn(outputs,targets)
                # outputs are logits; apply torch.sigmoid to turn them into probabilities
                predictions.extend(outputs.cpu().numpy())
                
                total_loss += loss.item()
        total_loss /= len(valid_loader)
            
        return total_loss,np.array(predictions)    
    
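    import copy
    from sklearn.model_selection import GroupKFold

    # group the folds by study so slices of one exam never land in both train and validation
    kfold = GroupKFold(n_splits=config["nfolds"])

    fold_train_losses = list()
    fold_valid_losses = list()
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"{device} is used")
    
    loss_fn = nn.BCEWithLogitsLoss()

    # keep the pretrained weights so that every fold starts from the same initial model
    initial_state = copy.deepcopy(model.state_dict())
    
    for k, (train_idx, valid_idx) in enumerate(kfold.split(train_data, groups=train_data["StudyInstanceUID"])):
        
        x_train,x_valid = train_data.loc[train_idx,:],train_data.loc[valid_idx,:]
        
        # reset the weights; otherwise fold k would continue from the model trained on fold k-1
        model.load_state_dict(initial_state)
        model.to(device)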

        train_ds = RsnaDataset(x_train)
        train_dl = DataLoader(train_ds,
                             batch_size = config["train_batch_size"],
                             shuffle=True
                             )

        valid_ds = RsnaDataset(x_valid)
        valid_dl = DataLoader(valid_ds,
                             batch_size = config["valid_batch_size"],
                             shuffle=False
                             )
        
        optimizer = optim.Adam(model.parameters(),lr=config["learning_rate"])
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max= 300,eta_min= 0.000001)

        print(f"Fold {k}")
        best_loss = 999
        train_losses = list()
        valid_losses = list()
        
        for i in range(config['epochs']):
            train_loss = train_loop(train_dl,model,loss_fn,device,optimizer,lr_scheduler)
            valid_loss,predictions = valid_loop(valid_dl,model,loss_fn,device)
            
            train_losses.append(train_loss)
            valid_losses.append(valid_loss)
                          
            print(f"epoch:{i} Training | loss:{train_loss}  Validation | loss:{valid_loss}  ")
            
            if valid_loss <= best_loss:
                print(f"{g_}Validation loss Decreased from {best_loss} to {valid_loss}{sr_}")
                best_loss = valid_loss
                torch.save(model.state_dict(),f'model{k}.bin')
                
        fold_train_losses.append(train_losses)
        fold_valid_losses.append(valid_losses)
        
    if plot_losses == True:
        plt.figure(figsize=(20,14))
        for i, (t,v) in enumerate(zip(fold_train_losses,fold_valid_losses)):
            plt.subplot(2,5,i+1)
            plt.title(f"Fold {i}")
            plt.plot(t,label="train_loss")
            plt.plot(v,label="valid_loss")
            plt.legend()
        plt.show() 

In [None]:
run()