I normally use PyTorch as my framework, but after hearing about PyTorch Lightning I wanted to try it.
PyTorch Lightning helps structure PyTorch code and scale it to multi-GPU environments.

PyTorch Lightning is designed to make a PyTorch training loop easy to follow and easy to modify. Want to use a new scheduler? Then simply modify the configure_optimizers method! The beauty of it is that it automates the boring stuff that clogs pure PyTorch code: the loops, .zero_grad(), .eval(), torch.save etc. are all handled by the framework, so you just have to focus on the ML part. The best thing for researchers is that it comes with automatic logging through TensorBoard to compare your many experiments, and easy switching between a single GPU, DataParallel, TPUs, mixed precision etc.
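To make the scheduler point concrete, here is a minimal sketch of what configure_optimizers could return once a scheduler is added. It is written as a free function (configure_optimizers_sketch is a hypothetical name) so it runs with plain PyTorch and no LightningModule; the dictionary with "optimizer" and "lr_scheduler" keys is one of the return formats Lightning documents, and CosineAnnealingLR is only an example choice.

```python
import torch
import torch.nn as nn


def configure_optimizers_sketch(model: nn.Module, lr: float = 1e-4, epochs: int = 3):
    """Hypothetical stand-in for LightningModule.configure_optimizers with a scheduler."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=3e-6)
    # Example scheduler: cosine annealing over the whole run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    # "interval": "epoch" asks Lightning to step the scheduler once per epoch.
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"},
    }


config = configure_optimizers_sketch(nn.Linear(8, 4))
print(type(config["optimizer"]).__name__, type(config["lr_scheduler"]["scheduler"]).__name__)
```

Swapping schedulers then only means changing the one line that builds scheduler; the training loop itself stays untouched.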

TorchMetrics together with Lightning has made metric handling much easier: a metric is initialised once as an attribute of the module and then simply called on each batch to update it.
Logging with self.log makes it easy to track metrics, either in the progress bar or in a logger such as TensorBoard.
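Under the hood, a TorchMetrics metric keeps running state: each batch call updates it, compute() aggregates everything seen so far, and reset() clears it. When the metric object itself is passed to self.log, Lightning handles compute and reset; when only the computed tensor is logged, the reset is up to you. The plain-Python class below (RunningAccuracy is hypothetical, not part of TorchMetrics) illustrates that cycle:

```python
class RunningAccuracy:
    """Plain-Python illustration (not torchmetrics) of the update / compute / reset cycle."""

    def __init__(self):
        self.reset()

    def reset(self):
        # State accumulated across batches within one epoch.
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        # Called once per batch: only accumulates counts.
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self):
        # Called at epoch end: aggregates every batch seen since the last reset().
        return self.correct / self.total if self.total else 0.0


metric = RunningAccuracy()
metric.update([0, 1, 2], [0, 1, 1])  # batch 1: 2 of 3 correct
metric.update([3], [3])              # batch 2: 1 of 1 correct
epoch_value = metric.compute()       # 3 / 4 = 0.75 over the epoch
metric.reset()                       # the next epoch starts from a clean state
```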

The whole kernel was made possible using the following references:
* Lightning Docs: [Lightning Docs](https://pytorch-lightning.readthedocs.io/en/latest/)
* Lightning based Kernel on Kaggle: [Arnaud's Kernel](https://www.kaggle.com/arroqc/siim-isic-pytorch-lightning-starter-seresnext50/notebook)
* Preprocessing of DataFrame: [Tanay's Kernel](https://www.kaggle.com/heyytanay/siim-pytorch-classification-only-training-effnets)
* DICOM to image conversion: [Xhlulu's Kernel](https://www.kaggle.com/xhlulu/siim-covid-19-convert-to-jpg-256px)
* Logic behind the metric from: [This Thread](https://www.kaggle.com/c/siim-covid19-detection/discussion/241300)
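The 2/3 factor applied to the average precision later in this notebook comes from that thread. As a rough sketch, assuming the leaderboard score averages AP over six classes (the four study-level labels plus two image-level ones), the study-level part contributes 4/6 = 2/3 of it. The helpers below are hypothetical and use a simple non-interpolated AP (mean precision at the rank of each positive), not the competition's exact scoring code:

```python
import numpy as np


def average_precision(y_true, y_score):
    """Non-interpolated AP for one class: mean precision at the rank of each positive
    after sorting by descending score (ties are not treated specially)."""
    y_true = np.asarray(y_true)
    order = np.argsort(-np.asarray(y_score), kind="stable")
    hits = y_true[order]
    if hits.sum() == 0:
        return 0.0
    precision_at_k = np.cumsum(hits) / np.arange(1, len(hits) + 1)
    return float(precision_at_k[hits == 1].mean())


def study_level_contribution(y_true, y_score):
    """Mean AP over the study classes, scaled by 4/6 under the six-class assumption."""
    y_true, y_score = np.asarray(y_true), np.asarray(y_score)
    per_class = [average_precision(y_true[:, c], y_score[:, c]) for c in range(y_true.shape[1])]
    return (4 / 6) * float(np.mean(per_class))


print(average_precision([1, 0, 1, 0], [0.9, 0.8, 0.7, 0.1]))  # (1/1 + 2/3) / 2, about 0.833
```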


# Installing Dependencies and Libraries

In [None]:
!cp /kaggle/input/gdcm-conda-install/gdcm.tar .
!tar -xvzf gdcm.tar
!conda install --offline ./gdcm/gdcm-2.8.9-py37h71b2a6d_0.tar.bz2
!rm -rf ./gdcm.tar

In [None]:
import numpy as np 
import pandas as pd 
import os
import cv2
from tqdm import tqdm 
import glob
import sys

import torch 
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2

import pytorch_lightning as pl
import torchmetrics
from pytorch_lightning import loggers as pl_loggers


sys.path.append("../input/timm-pytorch-image-models/pytorch-image-models-master")
import timm

sys.path.append("../input/iterative-stratification/iterative-stratification-master")
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
from sklearn.metrics import roc_auc_score

# Configurations 

In [None]:
TRAIN_BATCH_SIZE = 4
VAL_BATCH_SIZE = 8
IMG_SIZE = 512
LABELS = ['Negative for Pneumonia','Typical Appearance', 'Indeterminate Appearance', 'Atypical Appearance']
NUM_FOLDS = 6
MODEL_NAME = 'efficientnet_b7'
FOLD = 3
EPOCHS = 3
STUDY_LEN = 0

This function preprocesses the CSV files containing the provided labels.
* The function takes train as an argument: when True, the training data is processed; otherwise the test data is processed (the test branch is currently commented out).
* For test data, the sample_submission.csv dataframe mixes study-level classification records and image-level detection records. The study-level rows are therefore separated by extracting each id's suffix and keeping the rows before the first id ending in _image.

The training part of the function takes reference from: [Tanay's Kernel](https://www.kaggle.com/heyytanay/siim-pytorch-classification-only-training-effnets)
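The id handling can be tried on a toy frame. Both suffixes, _study and _image, are six characters long, which is why x[:-6] recovers the UID. The sketch below uses made-up ids and a vectorised str.endswith split as an alternative to the loop in the commented-out test branch:

```python
import pandas as pd

# Made-up ids in the same "<uid>_study" / "<uid>_image" shape as sample_submission.csv.
sub = pd.DataFrame({"id": ["aaa111_study", "bbb222_study", "ccc333_image", "ddd444_image"]})

# Both suffixes are 6 characters, so dropping the last 6 characters leaves the UID.
sub["uid"] = sub["id"].str[:-6]

# Vectorised split; unlike searching for the first "_image" row,
# it does not assume that all study rows come first.
is_study = sub["id"].str.endswith("_study")
study_df = sub[is_study].reset_index(drop=True)
image_df = sub[~is_study].reset_index(drop=True)
print(study_df)
```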

In [None]:
def preprocess_df(train = True):
    if train:
        df_image = pd.read_csv("../input/siim-covid19-detection/train_study_level.csv")
        df_det = pd.read_csv("../input/siim-covid19-detection/train_image_level.csv")
        df_image['StudyInstanceUID'] = df_image['id'].apply(lambda x : x[:-6])
        df = df_det.merge(df_image, on='StudyInstanceUID')
        path = []
        TRAIN_DIR = "../input/siim-covid19-detection/train/"
        for instance_id in tqdm(df['StudyInstanceUID']):
            path.append(glob.glob(os.path.join(TRAIN_DIR, instance_id +"/*/*"))[0])
        df['path'] = path
        df = df.drop(['id_x', 'id_y'], axis=1)
        return df
    
#     else:
#         df= pd.read_csv("../input/siim-covid19-detection/sample_submission.csv")
#         study_indices = df['id'].apply(lambda x : x[-6:])
#         STUDY_LEN =0
#         for i in range(len(study_indices)):
#             if study_indices[i] == '_image':
#                 STUDY_LEN = i
#                 break
            
#         df['StudyInstanceUID'] = df['id'].apply(lambda x : x[:-6])
#         df = df.iloc[:STUDY_LEN,:]
#         path = []
#         TEST_DIR = "../input/siim-covid19-detection/test"
#         for instance_id in tqdm(df['StudyInstanceUID']):
#             path.append(glob.glob(os.path.join(TEST_DIR, instance_id +"/*/*"))[0])
#         df['path'] = path

#         return df,STUDY_LEN
    

In [None]:
def dicom2array(path, voi_lut=True, fix_monochrome=True):
    dicom = pydicom.dcmread(path)  # dcmread replaces the deprecated read_file (removed in pydicom 3)
    # VOI LUT (if available by DICOM device) is used to
    # transform raw DICOM data to "human-friendly" view
    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        data = dicom.pixel_array
    # depending on this value, X-ray may look inverted - fix that:
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        data = np.amax(data) - data
    data = data - np.min(data)
    data = data / np.max(data)
    data = (data * 255).astype(np.uint8)
    return data
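
The intensity steps of dicom2array can be checked without a DICOM file. This sketch (to_uint8 is a hypothetical helper) repeats the MONOCHROME1 inversion and min-max scaling on a synthetic 12-bit array; the divide-by-zero guard for constant images is an addition, not part of the original function:

```python
import numpy as np


def to_uint8(data, monochrome1=False):
    """Intensity steps of dicom2array without DICOM reading or VOI LUT."""
    data = np.asarray(data, dtype=np.float64)
    if monochrome1:
        # In MONOCHROME1 the minimum value is meant to be displayed as white, so flip it.
        data = np.amax(data) - data
    data = data - np.min(data)
    peak = np.max(data)
    # Added guard (not in the original): a constant image would otherwise divide by zero.
    if peak > 0:
        data = data / peak
    return (data * 255).astype(np.uint8)


raw = np.array([[0, 1024], [2048, 4095]])  # synthetic 12-bit pixel values
print(to_uint8(raw))
print(to_uint8(raw, monochrome1=True))
```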

In [None]:
#Stratified KFold
def stratifiedKFold(df,num_folds,random_state):
    y = df[LABELS]
    df['fold'] = 0
    #split data
    mskf = MultilabelStratifiedKFold(n_splits=num_folds, shuffle= True, random_state=random_state)
    fold_col = df.columns.get_loc('fold')  # locate the column by name instead of assuming it is last
    for i, (_, test_index) in enumerate(mskf.split(df, y)):
        df.iloc[test_index, fold_col] = i
    return df
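
A quick way to check that fold assignment is balanced is to cross-tabulate folds against classes. The sketch below swaps in scikit-learn's StratifiedKFold on toy data in place of MultilabelStratifiedKFold. That substitution only holds under the assumption that every row has exactly one positive label (one-hot), so the argmax class can be stratified on directly:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold

LABELS = ['Negative for Pneumonia', 'Typical Appearance',
          'Indeterminate Appearance', 'Atypical Appearance']

# Toy one-hot labels: 24 hypothetical studies, 6 per class.
rng = np.random.default_rng(0)
classes = rng.permutation(np.repeat(np.arange(4), 6))
df = pd.DataFrame(np.eye(4, dtype=int)[classes], columns=LABELS)

df["fold"] = -1
skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
for fold, (_, val_idx) in enumerate(skf.split(df[LABELS], df[LABELS].values.argmax(axis=1))):
    # Assign by column name so the code does not rely on 'fold' being the last column.
    df.loc[df.index[val_idx], "fold"] = fold

print(pd.crosstab(df["fold"], df[LABELS].values.argmax(axis=1)))
```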

In [None]:
#Torch dataset, not trying any augmentations as of now
class SIIM_COVID(Dataset):
    def __init__(self,df,
                 train = True,
                 transforms=None,
                 IMG_SIZE = 256
                ):
        self.imageList = df['path'].values
        self.transform = None
        if transforms is None:
            self.transform = A.Compose([
                                            A.Resize(IMG_SIZE,IMG_SIZE),
                                            A.Normalize(
                                                mean=[0.485, 0.456, 0.406],
                                                std=[0.229, 0.224, 0.225],
                                            ),
                                            ToTensorV2()
                                        ])
        else:
            self.transform = transforms
        self.train  = train    
        if self.train == True:
            self.labels = df[LABELS].values
    
    def __len__(self):
        return len(self.imageList)
    
    def __getitem__(self,idx):
        file_path = self.imageList[idx]
        img = dicom2array(file_path)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # dicom2array returns a single-channel image
        image = self.transform(image=img)
        if self.train == True:
            image = image['image']
            label = self.labels[idx]
            return image, label
        else:
            return image
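
For reference, this NumPy sketch mirrors what the default SIIM_COVID transform does to one grayscale image, with the resize left out: replicate the single channel to RGB, apply the ImageNet mean/std normalisation (A.Normalize's default max_pixel_value of 255 is assumed), and move channels first as ToTensorV2 does. gray_to_chw is a hypothetical name:

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def gray_to_chw(img_u8):
    """Grayscale uint8 (H, W) -> normalised float32 (3, H, W); resize omitted."""
    # Replicate the channel, like cv2.COLOR_GRAY2RGB.
    rgb = np.repeat(img_u8[:, :, None], 3, axis=2).astype(np.float32)
    # A.Normalize with max_pixel_value=255 is equivalent to (x / 255 - mean) / std.
    norm = (rgb / 255.0 - IMAGENET_MEAN) / IMAGENET_STD
    # HWC -> CHW, the layout ToTensorV2 produces.
    return norm.transpose(2, 0, 1).astype(np.float32)


x = gray_to_chw(np.full((4, 5), 255, dtype=np.uint8))
print(x.shape)
```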

In [None]:
# EfficientNet defined separately from the LightningModule: at inference time its weights
# can then be loaded on their own instead of through the Lightning module class
class EffNet(nn.Module):
    def __init__(self,model_name,
                 num_classes,
                 pretrained = True
                ):
        super().__init__()
        self.model = timm.create_model(model_name,pretrained = pretrained )
        n_features = self.model.classifier.in_features
        self.model.global_pool = nn.Identity()
        self.model.classifier = nn.Identity()
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(n_features,num_classes)
        self.soft = nn.Softmax(dim=1)
        
    def forward(self,x): 
        bs = x.size(0) # bs -> batch size
        features = self.model(x)
        pooled_features = self.pooling(features).view(bs,-1)
        output = self.fc(pooled_features) 
        # Softmax keeps the outputs as probabilities between 0 and 1 to avoid problems with torchmetrics;
        # any loss applied to these outputs must therefore expect probabilities, not logits
        output = self.soft(output)

        return output 
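
With global_pool and classifier replaced by nn.Identity, the timm backbone should return the unpooled feature map, and EffNet's own head does the pooling. The shape flow can be checked with a tiny stand-in backbone (a single hypothetical Conv2d instead of EfficientNet):

```python
import torch
import torch.nn as nn


class TinyHead(nn.Module):
    """EffNet's head (pool -> flatten -> Linear -> Softmax) on a hypothetical one-layer backbone."""

    def __init__(self, n_features=16, num_classes=4):
        super().__init__()
        self.backbone = nn.Conv2d(3, n_features, kernel_size=3, padding=1)  # (B, C, H, W)
        self.pooling = nn.AdaptiveAvgPool2d(1)                              # (B, C, 1, 1)
        self.fc = nn.Linear(n_features, num_classes)                        # (B, num_classes)
        self.soft = nn.Softmax(dim=1)

    def forward(self, x):
        bs = x.size(0)
        features = self.backbone(x)
        pooled = self.pooling(features).view(bs, -1)  # flatten to (B, C)
        return self.soft(self.fc(pooled))             # each row sums to 1


probs = TinyHead()(torch.randn(2, 3, 32, 32))
print(probs.shape)  # torch.Size([2, 4])
```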

In [None]:
class LitSIIM(pl.LightningModule):
    def __init__(self,model,
                 df,fold_no,
                 train_transforms = None,
                 valid_transforms = None
                ):
        super(LitSIIM,self).__init__()
        self.model = model
        
        self.train_transforms = train_transforms 
        self.valid_transforms = valid_transforms

        self.train_dataset = SIIM_COVID(df[df['fold'] != fold_no], train = True,
                                        transforms = train_transforms, IMG_SIZE = IMG_SIZE)
        self.valid_dataset = SIIM_COVID(df[df['fold'] == fold_no], train = True,  # train=True so both images and targets are returned
                                        transforms = valid_transforms, IMG_SIZE = IMG_SIZE)
                
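        # EffNet ends in a softmax, so its outputs are already probabilities:
        # BCELoss is used because BCEWithLogitsLoss would apply a second (sigmoid) activation
        self.train_loss = nn.BCELoss()
        self.valid_loss = nn.BCELoss()
        
        # Written against the pre-0.11 torchmetrics API (newer versions also need a task argument)
        self.train_avg_prec = torchmetrics.AveragePrecision(num_classes = 4, pos_label = 1)
        self.valid_avg_prec = torchmetrics.AveragePrecision(num_classes = 4, pos_label = 1)
        
        self.learning_rate = 1e-4  # the value actually passed to the optimizer below
        
    def forward(self,batch):
        return self.model(batch)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate, weight_decay=3e-6)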
    
    def prepare_data(self):
        pass
    
    
    def train_dataloader(self):
        train_loader = DataLoader(self.train_dataset,
                                 batch_size =  TRAIN_BATCH_SIZE,
                                 shuffle = True, # reshuffle the training set every epoch
                                 sampler = None, 
                                 num_workers = os.cpu_count()
                                 )

        return train_loader
    
    def val_dataloader(self):
        val_loader = DataLoader(self.valid_dataset,
                               batch_size = VAL_BATCH_SIZE,
                               shuffle = False,
                               num_workers = os.cpu_count()
                               )
        return val_loader
    
    def training_step(self,batch,batch_idx):
        image,labels = batch 
        logits = self(image)  # softmax probabilities from EffNet
        loss = self.train_loss(logits,labels.type_as(logits))
        # 2/3 scaling follows the metric logic from the discussion thread linked above
        train_avg_prec_batch = (2/3) * self.train_avg_prec(logits,labels)
        self.log("train_loss_batch",loss,prog_bar = True)
        self.log("train_avg_prec_batch",train_avg_prec_batch,prog_bar = True)
        return {
            'loss':loss,
            'y_pred':logits,
            'y_true':labels
        }
    
    def training_epoch_end(self,outputs):
        avg_prec = (2/3) * self.train_avg_prec.compute()
        self.train_avg_prec.reset()  # only the tensor is logged, so the metric state must be reset by hand
        mean_loss = torch.stack([out['loss'] for out in outputs]).mean()
        print(f"Train Loss Epoch_{self.current_epoch + 1}: {mean_loss}")
        print(f"Train Average Precision Epoch_{self.current_epoch + 1}: {avg_prec}")
        self.log("train_avg_prec_epoch",avg_prec,prog_bar = True)
    
    def validation_step(self,batch,batch_idx):
        image,labels = batch 
        logits = self(image)
        loss = self.valid_loss(logits,labels.type_as(logits))
        valid_avg_prec_batch = self.valid_avg_prec(logits,labels)
        self.log("val_loss",loss,prog_bar = True)
        self.log("val_avg_prec_batch",(2/3) * valid_avg_prec_batch,prog_bar = True) 
        return {
            'loss'   : loss,
            'y_pred' : logits,
            'target' : labels
        }
        
    def validation_epoch_end(self,outputs):
        avg_prec = (2/3)*self.valid_avg_prec.compute()
        self.valid_avg_prec.reset()  # so the next validation run starts from a clean state
        mean_loss = torch.stack([out['loss'] for out in outputs]).mean()
        print(f"Valid Loss Epoch_{self.current_epoch + 1}: {mean_loss}")
        print(f"Valid Average Precision Epoch_{self.current_epoch + 1} End: {avg_prec}")
        self.log('val_avg_prec_epoch',avg_prec,prog_bar = True)
#         PATH = f"model_state_dict/B5_E{self.current_epoch+1}.pth"
#         torch.save(self.model.state_dict(), PATH)
    
    def test_step(self,batch,batch_idx):
        logits = self(batch['image'])
        return logits

    def test_epoch_end(self, outputs):
        probs = torch.cat(outputs,dim = 0)
        probs = probs.detach().cpu().numpy()
        self.test_predicts = probs  
        return {'dummy_item': 0}
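
EffNet.forward ends in a softmax, so LitSIIM receives probabilities rather than logits. The short check below shows why that matters when choosing the loss: BCEWithLogitsLoss applies its own sigmoid, so it equals BCELoss on sigmoid(logits), and feeding it softmax outputs squashes them a second time into roughly [0.5, 0.73]:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(3, 4)
target = torch.eye(4)[torch.tensor([0, 2, 3])]  # one-hot targets

# BCEWithLogitsLoss applies the sigmoid itself ...
with_logits = nn.BCEWithLogitsLoss()(logits, target)
# ... so it matches BCELoss on sigmoid(logits).
plain = nn.BCELoss()(torch.sigmoid(logits), target)
print(torch.allclose(with_logits, plain, atol=1e-6))

# Softmax outputs lie in (0, 1); a second sigmoid maps them into (0.5, sigmoid(1) ~ 0.731).
probs = torch.softmax(logits, dim=1)
squashed = torch.sigmoid(probs)
print(squashed.min().item(), squashed.max().item())
```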

In [None]:
from pytorch_lightning.callbacks.progress import ProgressBar
from tqdm import tqdm

class LitProgressBar(ProgressBar):
    def init_train_tqdm(self):
        bar = super().init_train_tqdm()
        bar.leave = True
        return bar
        
    def init_validation_tqdm(self):
        bar = super().init_validation_tqdm()
        bar.set_description('Valid')
        return bar

    # Callbacks receive on_*_epoch_end hooks; methods named training_epoch_end /
    # validation_epoch_end on a callback are never called by Lightning.
    def on_train_epoch_end(self, trainer, pl_module, *args):
        super().on_train_epoch_end(trainer, pl_module, *args)
        metrics = trainer.progress_bar_dict
        # tqdm.write prints above the active bars without breaking them
        tqdm.write(
            f"Epoch {trainer.current_epoch + 1} training loss={metrics.get('train_loss_batch')} "
            f"Average Precision={metrics.get('train_avg_prec_epoch')}"
        )

    def on_validation_epoch_end(self, trainer, pl_module):
        super().on_validation_epoch_end(trainer, pl_module)
        metrics = trainer.progress_bar_dict
        tqdm.write(
            f"Epoch {trainer.current_epoch + 1} validation loss={metrics.get('val_loss')} "
            f"Valid Average Precision={metrics.get('val_avg_prec_epoch')}"
        )



In [None]:
! mkdir output
! mkdir model_state_dict # only needed if the weights are saved with torch.save(model.state_dict(), ...) and reloaded with load_state_dict()

In [None]:
if __name__ == '__main__':
    df = preprocess_df(train = True)
    df = stratifiedKFold(df = df,num_folds = NUM_FOLDS,random_state = 42)
    
    model = EffNet(
                    model_name = MODEL_NAME,
                    num_classes= len(LABELS),
                    pretrained = True
                )
    lit = LitSIIM(
            model = model,
            df = df,fold_no = FOLD,
            train_transforms = None,
            valid_transforms = None
    )
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
                                        monitor = 'val_avg_prec_epoch',
                                        dirpath = "output/",
                                        save_top_k = 2, # Number of best checkpoints to keep; with only 3 epochs, 2 is enough here
                                        mode = 'max', # Keep the checkpoints with the highest validation average precision
                                        filename = '{epoch}_' + MODEL_NAME + '_{val_avg_prec_epoch:.3f}_{val_loss:.3f}', # named after MODEL_NAME instead of a hard-coded "effnetb5"
    )
    
    trainer = pl.Trainer(
                         max_epochs=EPOCHS, 
                         gradient_clip_val=1,
                         callbacks=[checkpoint_callback,LitProgressBar()],
                         gpus=1,
                         progress_bar_refresh_rate=0
                     )
    trainer.fit(lit)
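
Since EffNet was kept separate for inference, its weights can be pulled out of a saved Lightning checkpoint by stripping the "model." prefix that LitSIIM's self.model attribute adds to every key. This sketch simulates the checkpoint with two hypothetical stand-in modules instead of loading a real .ckpt file from output/. With a real file, the dict would come from torch.load(path, map_location="cpu"); depending on the PyTorch version, weights_only=False may also be needed.

```python
import torch
import torch.nn as nn


class Inner(nn.Module):
    """Hypothetical stand-in for EffNet."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)


class Wrapper(nn.Module):
    """Hypothetical stand-in for LitSIIM: the network lives in self.model."""
    def __init__(self):
        super().__init__()
        self.model = Inner()


# A Lightning checkpoint is a dict whose "state_dict" entry holds the module's weights;
# it is simulated here rather than read from disk.
checkpoint = {"state_dict": Wrapper().state_dict()}

prefix = "model."
inner_state = {k[len(prefix):]: v for k, v in checkpoint["state_dict"].items()
               if k.startswith(prefix)}

net = Inner()
net.load_state_dict(inner_state)  # strict by default: raises on missing or unexpected keys
net.eval()
```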