<div class="alert alert-info">
    <h1 align='center'>Brain Tumor Video Classification using PyTorch + W&B Tracking ✨</h1>
</div>

<p style='text-align: center'>
I'm approaching this problem as a video classification problem, using all images in the FLAIR folder. Thanks to Ayush Thakur's <a href='https://www.kaggle.com/ayuraj/train-brain-tumor-as-video-classification-w-b'>Notebook</a>, which does the same in TensorFlow.<br>
I have also used Weights and Biases tracking to keep track of the training process and the experiments I am conducting.
</p>

<div style='text-align: center'>
    <strong>You can upvote this kernel, if you found it useful!</strong>
</div>

<center><img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases"/></center><br>
<p style="text-align:center">WandB is a developer tool that helps companies turn deep learning research projects into deployed software by helping teams track their models, visualize model performance, and easily automate training and improving models.
We will use their tools to log hyperparameters and output metrics from your runs, then visualize and compare results and quickly share findings with your colleagues.<br><br></p>

![img](https://i.imgur.com/BGgfZj3.png)

<div class="alert alert-success">
    <h2 align='center'>📔 Imports and Installation</h2>
</div>

In [None]:
%%sh
pip install -q timm
pip install -q einops
pip install -q rich
pip install -q wandb --upgrade;

In [None]:
import os
import sys
import re
import gc
import platform
import random

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau

import einops

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold, train_test_split

import timm
import glob
import cv2

from rich import print as _pprint
from rich.progress import track

import albumentations as A
from albumentations.pytorch import ToTensorV2

import wandb

import warnings
warnings.simplefilter('ignore')

<div class="alert alert-success">
    <h2 align='center'>⛽ Utility Functions </h2>
</div>

In [None]:
def get_patient_id(patient_id):
    """
    Formats a numeric patient id as a zero-padded string of width five.

    The result is used by `get_path` as a directory name, e.g. `5` becomes
    `'00005'` and `1010` becomes `'01010'`.

    Args:
        patient_id: Integer patient id (anything `int()` accepts).

    Returns:
        The id as a string, left-padded with zeros to at least five characters.
    """
    # One fixed-width format replaces the hand-written range checks, which
    # prepended a stray '0' to ids of 10000 and above (giving e.g. '010000').
    return f"{int(patient_id):05d}"

def get_path(row):
    """
    Returns the FLAIR image directory for the patient in `row`.

    Reads `row.BraTS21ID`, formats it with `get_patient_id`, and places it in
    the training PNG directory template.
    """
    folder_name = get_patient_id(row.BraTS21ID)
    return '../input/rsna-miccai-png/train/{}/FLAIR/'.format(folder_name)

def wandb_log(**kwargs):
    """
    Logs the given key-value pairs to W&B in a single call.

    All keyword arguments are sent together as one dictionary, so metrics
    passed in the same call are recorded together rather than through one
    `wandb.log` call per key.
    NOTE(review): by default each `wandb.log` call advances the run's step
    counter — confirm against the W&B docs for the installed version.

    Args:
        **kwargs: Metric names mapped to the values to log.
    """
    # Skip the call entirely when nothing was passed, as the per-key loop did.
    if kwargs:
        wandb.log(dict(kwargs))
        
def cprint(string):
    """
    Prints `string` through rich's print, wrapped in `[black]` markup tags.
    """
    markup = "[black]{}[/black]".format(string)
    _pprint(markup)

<div class="alert alert-success">
    <h2 align='center'>🚀 Config Dictionary and W&B Integration </h2>
</div>

In [None]:
# Central experiment configuration; also passed to W&B as the run config.
Config = {
    # Data
    'MAX_FRAMES': 12,            # number of frame paths kept per patient
    'EPOCHS': 5,
    'LR': 2e-4,                  # learning rate given to the optimizer
    'IMG_SIZE': (224, 224),      # size every frame is resized to
    # Model
    'FEATURE_EXTRACTOR': 'resnext50_32x4d',  # timm model name
    'DR_RATE': 0.35,             # dropout probability before the final layer
    'NUM_CLASSES': 1,            # output size of the final linear layer
    'RNN_HIDDEN_SIZE': 100,
    'RNN_LAYERS': 1,
    # Data loading
    'TRAIN_BS': 4,
    'VALID_BS': 4,
    'NUM_WORKERS': 4,
    # Run metadata
    'infra': "Kaggle",
    'competition': 'rsna_miccai',
    '_wandb_kernel': 'tanaym',
}

To log in to W&B, you can use the snippet below.

```python
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wb_key = user_secrets.get_secret("WANDB_API_KEY")

wandb.login(key=wb_key)
```
Make sure you have your W&B key stored as `WANDB_API_KEY` under Add-ons -> Secrets

You can view [this](https://www.kaggle.com/ayuraj/experiment-tracking-with-weights-and-biases) notebook to learn more about W&B tracking.

If you don't want to login to W&B, the kernel will still work and log everything to W&B in anonymous mode.

In [None]:
# Start the W&B run and attach the experiment configuration to it.
# NOTE(review): anonymous="allow" presumably lets the run log without a
# logged-in account — confirm against the W&B docs.
run = wandb.init(
    anonymous="allow",
    job_type="train",
    group="vision",
    config=Config,
    project="pytorch",
)

In [None]:
class Augments:
    """
    Holds the albumentations pipelines for training and validation.

    Both pipelines currently consist of a single `ToTensorV2` step.
    """
    # Training-time pipeline.
    train_augments = A.Compose(
        [ToTensorV2(p=1.0)],
        p=1.0,
    )

    # Validation-time pipeline.
    valid_augments = A.Compose(
        [ToTensorV2(p=1.0)],
        p=1.0,
    )

<div class="alert alert-success">
    <h2 align='center'>💻 Custom Dataset Class</h2>
</div>

<div class="alert alert-block alert-info" style="font-size:14px; font-family:verdana; line-height: 1.7em;">
    📌 In this custom Dataset, I am essentially reading "MAX_FRAMES" images from a patient's FLAIR folder, making a list of those frames, and converting it to a torch tensor.
</div>

In [None]:
class RSNADataset(Dataset):
    """
    Dataset returning, per row of `df`, a stack of grayscale frames.

    Each item is a numpy array of shape (H, W, n_frames) built from the first
    `Config['MAX_FRAMES']` naturally-sorted `*.png` files found under the
    row's `path`, together with the row's `MGMT_value` as a float tensor
    unless `is_test` is set.

    Args:
        df: DataFrame with a `path` column (and `MGMT_value` when not testing).
        augments: Optional callable invoked as `augments(image=frame)` whose
            result is indexed with `'image'`.
        is_test: When True, `__getitem__` returns only the frames.
    """
    def __init__(self, df, augments=None, is_test=False):
        self.df = df
        self.augments = augments
        self.is_test = is_test

    def __getitem__(self, idx):
        # Positional lookup: `idx` is a position in [0, len(self)), which only
        # coincides with the index labels when the frame index was reset.
        row = self.df.iloc[idx]
        frames = [self._read_frame(path) for path in self.getPaths(row)]

        # Augment each frame BEFORE stacking; the stacked ndarray cannot be
        # appended to, which made the previous augment branch fail.
        if self.augments:
            frames = [self._augment(frame) for frame in frames]

        frames_tr = np.stack(frames, axis=2)

        if self.is_test:
            return frames_tr
        label = torch.tensor(row['MGMT_value']).float()
        return frames_tr, label

    def __len__(self):
        return len(self.df)

    def getPaths(self, row):
        """Returns the naturally-sorted frame paths selected for `row`."""
        paths = glob.glob(row['path'] + '*.png')
        sorted_paths = self.sort(paths)
        start = 0  # the window always begins at the first frame
        return self._select_window(sorted_paths, start, Config['MAX_FRAMES'])

    def sort(self, entry):
        """Sorts strings so that embedded numbers compare numerically."""
        # https://stackoverflow.com/a/2669120/7636462
        def natural_key(name):
            return [
                int(token) if token.isdigit() else token
                for token in re.split(r'([0-9]+)', name)
            ]

        return sorted(entry, key=natural_key)

    @staticmethod
    def _select_window(sorted_paths, start, max_frames):
        """Returns at most `max_frames` consecutive entries beginning at `start`."""
        # The end of the slice is relative to `start`, so a non-zero start
        # still yields a full window when enough entries exist.
        return sorted_paths[start:start + max_frames]

    @staticmethod
    def _read_frame(path):
        """Reads one image, resizes it to `Config['IMG_SIZE']`, converts to gray."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None; fail with the
            # offending path instead of an obscure error inside cv2.resize.
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.resize(img, Config['IMG_SIZE'])
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    def _augment(self, frame):
        """Applies `self.augments` to one frame and returns a numpy array."""
        out = self.augments(image=frame)['image']
        if torch.is_tensor(out):
            out = out.numpy()
        out = np.asarray(out)
        # Drop a singleton channel axis so frames stack to (H, W, n_frames)
        # exactly as they do without augmentation.
        # NOTE(review): ToTensorV2 presumably returns (1, H, W) for a 2-D
        # input — confirm against the installed albumentations version.
        if out.ndim == 3:
            if out.shape[0] == 1:
                out = out[0]
            elif out.shape[-1] == 1:
                out = out[..., 0]
        return out

<div class="alert alert-success">
    <h2 align='center'>📈 Model Class with ResNext Backbone</h2>
</div>

In [None]:
class ResNextModel(nn.Module):
    """
    Wraps the timm model named by `Config['FEATURE_EXTRACTOR']`.

    The model is created with `pretrained=True` and `in_chans=1`, and
    `forward` simply delegates to it.
    """
    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model(
            Config['FEATURE_EXTRACTOR'],
            pretrained=True,
            in_chans=1,
        )

    def forward(self, x):
        features = self.backbone(x)
        return features

class Identity(nn.Module):
    """Pass-through module: `forward` hands back its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # No computation; used to replace a layer that should be skipped.
        return x

class RSNAModel(nn.Module):
    """
    Per-frame CNN feature extractor followed by an LSTM over the frames.

    `forward` takes a 4-D tensor unpacked as (batch, frames, height, width),
    runs every frame through the backbone as a single-channel image, feeds
    the resulting feature sequence to the LSTM, and maps the LSTM output at
    the last frame through dropout and a linear layer.

    Args:
        pretrained: Accepted for interface compatibility.
            NOTE(review): not forwarded — `ResNextModel` takes no arguments
            and always requests pretrained weights.
    """
    def __init__(self, pretrained=True):
        super().__init__()
        self.backbone = ResNextModel()
        num_features = self.backbone.backbone.fc.in_features

        # Replace the classification head with a pass-through so the backbone
        # returns the features that used to feed `fc`.
        self.backbone.backbone.fc = Identity()
        self.dropout = nn.Dropout(Config['DR_RATE'])
        # batch_first=True so inputs are read as (batch, frames, features).
        # Without it, nn.LSTM reads dim 0 as the sequence axis, and the
        # (batch, 1, features) inputs used before made it step across the
        # samples of a batch instead of across the frames.
        self.rnn = nn.LSTM(
            num_features,
            Config['RNN_HIDDEN_SIZE'],
            Config['RNN_LAYERS'],
            batch_first=True,
        )
        self.fc1 = nn.Linear(Config['RNN_HIDDEN_SIZE'], Config['NUM_CLASSES'])

    def forward(self, x):
        # Unpacking keeps the requirement that `x` is exactly 4-D.
        _, num_frames, _, _ = x.shape

        # Each frame becomes a (batch, 1, H, W) single-channel image. The
        # out-of-place unsqueeze avoids mutating a view of the input.
        frame_features = [
            self.backbone(x[:, idx].unsqueeze(1)) for idx in range(num_frames)
        ]
        sequence = torch.stack(frame_features, dim=1)  # (batch, frames, feat)

        # One LSTM pass over the whole sequence carries the state from frame
        # to frame and also works when there is only a single frame.
        out, _ = self.rnn(sequence)
        out = self.dropout(out[:, -1])
        out = self.fc1(out)
        return out

<div class="alert alert-success">
    <h2 align='center'>🏴‍☠️ Training and Validation Functions</h2>
</div>

In [None]:
def train_one_epoch(model, train_dataloader, optimizer, loss_fn, epoch, device, log_wandb=True, verbose=False):
    """
    Trains model for one epoch

    Args:
        model: Module called on the rearranged frames; its output is flattened.
        train_dataloader: Yields `(frames, targets)` batches.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Called as `loss_fn(preds, targets)`.
        epoch: Current epoch index (not used inside this function).
        device: Device the batches are moved to.
        log_wandb: When truthy, logs every batch loss via `wandb_log`.
        verbose: When truthy, prints the loss every 20 batches.

    Returns:
        Mean of the per-batch losses, or NaN when the dataloader has no batches.
    """
    model.train()
    running_loss = 0
    num_batches = len(train_dataloader)
    prog_bar = tqdm(enumerate(train_dataloader), total=num_batches)
    for batch, (frames, targets) in prog_bar:
        optimizer.zero_grad()

        frames = frames.to(device, torch.float)
        targets = targets.to(device, torch.float)

        # Rearrange the frames into the layout the model expects to receive
        frames = einops.rearrange(frames, 'b h w f -> b f h w')

        preds = model(frames).view(-1)
        loss = loss_fn(preds, targets)

        loss.backward()
        optimizer.step()

        loss_item = loss.item()
        running_loss += loss_item

        prog_bar.set_description(f"loss: {loss_item:.4f}")

        if log_wandb:
            wandb_log(batch_train_loss=loss_item)

        if verbose and batch % 20 == 0:
            print(f"Batch: {batch}, Loss: {loss_item}")

    # An empty dataloader has no average; avoid dividing by zero.
    if num_batches == 0:
        return float('nan')

    return running_loss / num_batches

@torch.no_grad()
def valid_one_epoch(model, valid_dataloader, loss_fn, epoch, device, log_wandb=True, verbose=False):
    """
    Validates the model for one epoch

    Runs with gradient tracking disabled and the model in eval mode.

    Args:
        model: Module called on the rearranged frames; its output is flattened.
        valid_dataloader: Yields `(frames, targets)` batches.
        loss_fn: Called as `loss_fn(preds, targets)`.
        epoch: Current epoch index (not used inside this function).
        device: Device the batches are moved to.
        log_wandb: When truthy, logs every batch loss via `wandb_log`.
        verbose: When truthy, prints the loss every 10 batches.

    Returns:
        Mean of the per-batch losses, or NaN when the dataloader has no batches.
    """
    model.eval()
    running_loss = 0
    num_batches = len(valid_dataloader)
    prog_bar = tqdm(enumerate(valid_dataloader), total=num_batches)
    for batch, (frames, targets) in prog_bar:
        frames = frames.to(device, torch.float)
        targets = targets.to(device, torch.float)

        # Rearrange the frames into the layout the model expects to receive
        frames = einops.rearrange(frames, 'b h w f -> b f h w')
        preds = model(frames).view(-1)
        loss = loss_fn(preds, targets)

        loss_item = loss.item()
        running_loss += loss_item

        prog_bar.set_description(f"val_loss: {loss_item:.4f}")

        if log_wandb:
            wandb_log(batch_val_loss=loss_item)

        if verbose and batch % 10 == 0:
            print(f"Batch: {batch}, Loss: {loss_item}")

    # An empty dataloader has no average; avoid dividing by zero. NaN also
    # never compares as "less than" a previous best loss.
    if num_batches == 0:
        return float('nan')

    return running_loss / num_batches

<div class="alert alert-success">
    <h2 align='center'>🏗 Training and Validating the Model</h2>
</div>

In [None]:
if __name__ == "__main__":
    # Script entry point: pick a device, build the train/validation split,
    # train for Config['EPOCHS'] epochs and keep the best-validation weights.
    log_wandb = True
    if torch.cuda.is_available():
        print("Using GPU: {}\n".format(torch.cuda.get_device_name()))
        device = torch.device('cuda')
    else:
        print("\nGPU not found. Using CPU: {}\n".format(platform.processor()))
        device = torch.device('cpu')

    # Load training csv file and attach each patient's FLAIR directory
    df = pd.read_csv('../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv')
    df['path'] = df.apply(get_path, axis=1)

    # Remove two patient ids from the dataframe since there are no FLAIR
    # directories for these ids.
    df = df.loc[~df.BraTS21ID.isin([109, 709])]
    df = df.reset_index(drop=True)

    # Stratify on the label so both splits keep the same class balance.
    train_df, valid_df = train_test_split(df, test_size=0.1, stratify=df.MGMT_value.values)
    train_df = train_df.reset_index(drop=True)
    valid_df = valid_df.reset_index(drop=True)

    print(f'Size of Training Set: {len(train_df)}, Validation Set: {len(valid_df)}')

    model = RSNAModel()
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=Config['LR'])

    train_loss_fn = nn.BCEWithLogitsLoss()
    valid_loss_fn = nn.BCEWithLogitsLoss()

    print(f"\nUsing Backbone: {Config['FEATURE_EXTRACTOR']}")

    train_data = RSNADataset(train_df)
    valid_data = RSNADataset(valid_df)

    train_loader = DataLoader(
        train_data,
        batch_size=Config['TRAIN_BS'],
        shuffle=True,
        num_workers=Config['NUM_WORKERS']
    )

    valid_loader = DataLoader(
        valid_data,
        batch_size=Config['VALID_BS'],
        shuffle=False,
        num_workers=Config['NUM_WORKERS']
    )

    # Best validation loss seen so far. Starting at infinity guarantees that
    # the first finite validation loss is saved, whatever its magnitude.
    current_loss = float('inf')
    for epoch in range(Config['EPOCHS']):
        print(f"\n{'--'*8} EPOCH: {epoch+1} {'--'*8}\n")

        train_loss = train_one_epoch(model, train_loader, optimizer, train_loss_fn, epoch=epoch, device=device, log_wandb=log_wandb)

        valid_loss = valid_one_epoch(model, valid_loader, valid_loss_fn, epoch=epoch, device=device, log_wandb=log_wandb)

        print(f"val_loss: {valid_loss:.4f}")

        if log_wandb:
            wandb_log(
                train_loss=train_loss,
                valid_loss=valid_loss
            )

        # Checkpoint only when the validation loss improves.
        if valid_loss < current_loss:
            current_loss = valid_loss
            torch.save(model.state_dict(), f"model_{Config['FEATURE_EXTRACTOR']}.pt")

<h3><a href="https://wandb.ai/anony-mouse-125639/pytorch/runs/23gv9jk5?apiKey=6b04b2e314f0ee65d4e8bdc3aa267c124d7624a9">View the complete dashboard here ✨</a></h3>

![Results](https://i.imgur.com/gcn59xp.gif)