In [None]:
# Download the PyTorch/XLA environment setup script and run it to install the XLA build
# plus the listed apt packages.
# NOTE(review): the script is fetched from the `master` branch and installs the `nightly`
# build, so neither is pinned and reruns may install different versions.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

In [None]:
# Install the EfficientNet and albumentations packages quietly (stdout is discarded).
# NOTE(review): versions are unpinned; this notebook imports `albumentations.pytorch.ToTensor`,
# which newer albumentations releases may not provide -- consider pinning.
!pip install -q efficientnet_pytorch > /dev/null
!pip install -q albumentations > /dev/null

In [None]:
# Pin PyTorch Lightning: the notebook uses this release's API
# (e.g. `ModelCheckpoint(filepath=...)`, `Trainer(tpu_cores=...)`).
! pip install pytorch-lightning==0.9.1rc4

In [None]:
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

In [None]:
# Imports
import torch
import torch.nn as nn

import torchvision
import torchvision.transforms as transforms

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.debug.metrics as met

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)


import cv2
import matplotlib.pyplot as plt
from collections import defaultdict 
from efficientnet_pytorch import EfficientNet

 
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from torch import optim
from torchvision import datasets, transforms, models


from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score
from torchvision.transforms import ToTensor, RandomHorizontalFlip, Resize
from efficientnet_pytorch import EfficientNet
from transformers import AdamW, get_cosine_schedule_with_warmup
from albumentations import *
from albumentations.pytorch import ToTensor
from tqdm import tqdm
import json
import time

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

In [None]:
# Load the per-image metadata table that already carries the JPEG file names.
saved_df = pd.read_csv('/kaggle/input/rsna-train-df-with-jpg/train_df_with_jpg_file_names.csv')

# Keep only the identifiers, the image-level label and the file name.
# `.copy()` makes `train_df` an independent frame, so adding columns to it later
# does not raise pandas' SettingWithCopyWarning.
train_df = saved_df[['StudyInstanceUID', 'SeriesInstanceUID', 'SOPInstanceUID',
                     'pe_present_on_image', 'new_file_names']].copy()
train_df.head(3)

In [None]:
# Build the full JPEG path for every row: <root>/<study>/<series>/<file name>.
# `assign` returns a new frame instead of writing into a frame derived from `saved_df`,
# which avoids SettingWithCopyWarning and keeps the cell idempotent on re-run.
train_df = train_df.assign(
    path='/kaggle/input/rsna-str-pe-detection-jpeg-256/train-jpegs/'
         + train_df['StudyInstanceUID'].astype(str) + '/'
         + train_df['SeriesInstanceUID'].astype(str) + '/'
         + train_df['new_file_names']
)

In [None]:
# Read one sample image (row label 1) to check that the constructed paths resolve.
sample_path = train_df.loc[1, 'path']
image = cv2.imread(sample_path)
# cv2.imread returns None instead of raising when the file is missing or cannot be decoded.
assert image is not None, f'could not read image at {sample_path}'
image.shape

In [None]:
# cv2.imread returns channels in BGR order while matplotlib interprets arrays as RGB,
# so convert before plotting; the trailing `;` suppresses the AxesImage repr.
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB));

# RSNA model-building part

In [None]:
# import wandb
# wandb.login(key=os.environ['WANDB_API_KEY'])  # key redacted: load it from the environment and revoke the one that was committed
# causes the training to not end... :(

In [None]:
# run = wandb.init(project='SNAR bare min logger',
#                  name = '1k rows Adam',
#                  notes = 'R EB0 train_df[:600000]',
#                  config={  # and include hyperparameters and metadata
#                      "learning_rate": 1e-3, # from lr rate finder
#                      "epochs": 2
#                  })
# config = wandb.config  # We'll use this to configure our experiment

https://www.kaggle.com/c/rsna-intracranial-hemorrhage-detection/discussion/112290  

In [None]:
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from pytorch_lightning.metrics.functional import accuracy
from torchvision import datasets, transforms
# import os

In [None]:
# Sanity check: (rows, columns) of the full training frame.
train_df.shape

In [None]:
# or use it below like
# train_df_subset = train_df[:4000]

In [None]:
class SimpleDataset(Dataset):
    """Map-style dataset that reads one image from disk per item.

    Args:
        image_ids_df: object exposing ``.values`` and ``len()`` (e.g. a pandas
            Series) whose entries are image file paths.
        labels_df: object exposing ``.values`` and ``len()`` holding the labels,
            aligned position-by-position with ``image_ids_df``.
        transform: optional callable invoked as ``transform(image=..., label=...)``
            that must return a mapping containing ``'image'`` and ``'label'``.

    Raises:
        ValueError: if the number of paths and labels differ.
    """

    def __init__(self, image_ids_df, labels_df, transform=None):
        if len(image_ids_df) != len(labels_df):
            raise ValueError(
                f"got {len(image_ids_df)} image paths but {len(labels_df)} labels"
            )
        self.image_ids = image_ids_df
        self.labels = labels_df
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the item at position ``idx``.

        Raises:
            FileNotFoundError: if the image cannot be read from disk.
        """
        path = self.image_ids.values[idx]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None; fail here with the
            # offending path rather than deep inside the transform pipeline.
            raise FileNotFoundError(f"image is missing or unreadable: {path!r}")

        sample = {
            'image': image,
            'label': self.labels.values[idx],
        }

        if self.transform:
            sample = self.transform(**sample)

        return sample['image'], sample['label']

    def __len__(self):
        """Number of items, i.e. the number of image paths."""
        return len(self.image_ids)

In [None]:
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

class RNSAModel(pl.LightningModule):
    """EfficientNet feature extractor with a 2-way linear head for ``pe_present_on_image``.

    The module owns its data pipeline: ``prepare_data`` makes a stratified
    train/validation split of the supplied frame and the ``*_dataloader`` hooks
    serve the resulting datasets.

    Args:
        train_df: frame with a ``path`` column (image file paths) and a
            ``pe_present_on_image`` column (labels). Defaults to the
            module-level ``train_df`` as bound when the class was defined.
        encoder: EfficientNet variant to load; must be a key of
            ``N_CHANNELS_BY_ENCODER``.
        batch_size: batch size used by every dataloader.
        num_workers: worker processes used by every dataloader.

    Raises:
        ValueError: if ``encoder`` is not a supported variant.
    """

    # Classifier input width per encoder (the table the original code declared
    # but never read; the 'efficientnet-b0' entry matches the previously
    # hard-coded 1280).
    N_CHANNELS_BY_ENCODER = {
        'efficientnet-b0': 1280, 'efficientnet-b1': 1280, 'efficientnet-b2': 1408,
        'efficientnet-b3': 1536, 'efficientnet-b4': 1792, 'efficientnet-b5': 2048,
        'efficientnet-b6': 2304, 'efficientnet-b7': 2560,
    }

    def __init__(self, train_df = train_df, encoder='efficientnet-b0',
                 batch_size=16, num_workers=4):
        super().__init__()

        if encoder not in self.N_CHANNELS_BY_ENCODER:
            raise ValueError(
                f"unsupported encoder {encoder!r}; "
                f"expected one of {sorted(self.N_CHANNELS_BY_ENCODER)}"
            )

        self.encoder = encoder
        self.net = EfficientNet.from_pretrained(self.encoder)
        self.df = train_df
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Filled in by prepare_data(); test_ds has no producer in this class yet.
        self.train_ds = None
        self.val_ds = None
        self.test_ds = None

        # Attribute names and layer order are kept so existing checkpoints
        # (keys such as `classifier.1.weight`) still load.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=self.N_CHANNELS_BY_ENCODER[encoder],
                      out_features=2, bias=True),
        )

    def forward(self, x):
        """Return raw (un-normalised) class logits for a batch of images."""
        features = self.net.extract_features(x)
        pooled = self.avg_pool(features)
        return self.classifier(pooled)

    def training_step(self, batch, batch_nb):
        """Compute the cross-entropy loss for one training batch and log it."""
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return {'loss': loss, 'log': {'train_loss': loss}}

    def validation_step(self, batch, batch_nb):
        """Return the cross-entropy loss and accuracy for one validation batch."""
        x, y = batch
        y_hat = self(x)
        return {
            'val_loss': F.cross_entropy(y_hat, y),
            'val_accuracy': accuracy(y_hat, y, num_classes = 2),
        }

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation metrics and log them."""
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        avg_accuracy = torch.stack([out['val_accuracy'] for out in outputs]).mean()

        # 'val_loss' must stay in the logged metrics: the checkpoint callback
        # configured in this notebook monitors it. Accuracy is logged as well
        # instead of only being printed.
        tensorboard_logs = {'val_loss': avg_loss, 'val_acc': avg_accuracy}

        print(f"'avg_val_loss': {avg_loss} and 'val_acc': {avg_accuracy}")
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_nb):
        """Return the cross-entropy loss for one test batch."""
        x, y = batch
        return {'test_loss': F.cross_entropy(self(x), y)}

    def test_epoch_end(self, outputs):
        """Average the per-batch test loss and expose it to logger and progress bar."""
        avg_loss = torch.stack([out['test_loss'] for out in outputs]).mean()
        logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        """Use Adam over all parameters with a fixed learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

    @staticmethod
    def _build_transform():
        """Normalisation + tensor conversion shared by the train and validation sets."""
        # NOTE(review): these are the ImageNet statistics in RGB order, while
        # SimpleDataset loads images with cv2.imread (BGR order) -- confirm the
        # intended channel order before changing it (existing checkpoints were
        # trained with the current behaviour).
        return Compose([
            Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), always_apply=True),
            ToTensor()
            ])

    def prepare_data(self):
        """Split ``self.df`` 75/25 (stratified, seed 42) and build both datasets."""
        # NOTE(review): Lightning documents prepare_data as a single-process hook
        # and recommends assigning state in setup(); kept here to preserve the
        # notebook's working behaviour -- confirm before moving.
        image_ids = self.df['path']
        labels = self.df['pe_present_on_image']

        X_train, X_val, y_train, y_val = train_test_split(image_ids, labels,
                                                          test_size=0.25,
                                                          random_state=42,
                                                          stratify =labels)

        self.train_ds = SimpleDataset(X_train, y_train, transform = self._build_transform())
        self.val_ds = SimpleDataset(X_val, y_val, transform = self._build_transform())

    def _make_loader(self, dataset, shuffle=False):
        """Build a DataLoader with the batch size / worker count configured on the model."""
        return DataLoader(dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=shuffle)

    def train_dataloader(self):
        """Training loader; reshuffled every epoch."""
        return self._make_loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Validation loader (fixed order)."""
        return self._make_loader(self.val_ds)

    def test_dataloader(self):
        """Test loader.

        Raises:
            RuntimeError: if no test dataset has been assigned to ``self.test_ds``.
        """
        if self.test_ds is None:
            # The original code read self.test_ds without ever defining it,
            # which surfaced as a bare AttributeError.
            raise RuntimeError(
                "no test dataset available: assign `model.test_ds` before calling "
                "trainer.test() (processing of the DICOM test images is still a TODO)"
            )
        # todo processing of dicom images
        return self._make_loader(self.test_ds)

In [None]:
# del model

In [None]:
# Total number of rows available before a training window is selected.
train_df.shape

In [None]:
# Train on one fixed-size window of rows per session; `data_batch_num` selects the window.
data_batch_num = 1
data_batch_size  = 300000

start = data_batch_size * data_batch_num
# The last window may be shorter: cap its upper bound at the number of rows.
end = min(start + data_batch_size, train_df.shape[0])

train_df_subset = train_df[start: end]

(train_df_subset.shape,train_df.shape)

In [None]:
# Build the model on the selected window of rows only (not the full frame).
model = RNSAModel(train_df = train_df_subset)

In [None]:
import os

from pytorch_lightning.callbacks import ModelCheckpoint

# Create the output directory up front and give an explicit filename template.
# With a bare directory path, this Lightning version appears to treat `filepath`
# as "directory" only if it already exists and otherwise splits it into
# directory + file stem, so checkpoint naming depended on whether the folder had
# been created by an earlier run.
os.makedirs('/kaggle/working/checkpoints', exist_ok=True)

# Keep only the single best checkpoint, ranked by the lowest logged `val_loss`.
checkpoint_callback = ModelCheckpoint(filepath='/kaggle/working/checkpoints/{epoch}-{val_loss:.4f}',
                                        save_top_k=1,
                                        verbose=True,
                                        monitor='val_loss',
                                        mode='min',
                                        prefix=''
                                     )

In [None]:
from pathlib import Path

# Make sure the checkpoint output directory exists; safe to re-run.
checkpoint_dir = Path('/kaggle/working/checkpoints')
checkpoint_dir.mkdir(parents=True, exist_ok=True)

In [None]:
ls /kaggle/input/rsna-pytorch-lightning-chkpts

In [None]:
# run batch size scaling, result overrides hparams.batch_size
# trainer = pl.Trainer(tpu_cores=8, 
#                      progress_bar_refresh_rate=20,
#                      max_epochs=2,
#                      auto_scale_batch_size='binsearch',# not available
#                      checkpoint_callback=checkpoint_callback,
#                      default_root_dir='/kaggle/working/checkpoints'
#                     )

# Configure training on 8 TPU cores for at most 2 epochs, continuing from a
# previously saved checkpoint and saving new checkpoints through `checkpoint_callback`.
# `pl` is pytorch_lightning here: the later `import pytorch_lightning as pl` re-binds
# the alias first given to torch_xla's parallel_loader in the main import cell.
# NOTE(review): resuming may also restore the checkpoint's epoch counter; if it was
# saved near the end of a 2-epoch run, `max_epochs=2` may leave little or no training
# to do -- confirm.
trainer = pl.Trainer(
                     tpu_cores=8, 
#                      progress_bar_refresh_rate=20,
                     max_epochs=2,
#                      auto_scale_batch_size='binsearch',# not available
                     checkpoint_callback=checkpoint_callback,
                     resume_from_checkpoint='/kaggle/input/rsna-pytorch-lightning-chkpts/zero_300k_pass1_checkpoints-v0.ckpt',
                     default_root_dir='/kaggle/working/checkpoints'
                    )

In [None]:
%%time
# Run training and report the cell's wall time. This calls `fit`, not a tuner:
# batch-size auto-scaling is disabled in the Trainer configuration.
trainer.fit(model)

In [None]:
# MyModel = RNSAModel(train_df = train_df[:1000])

In [None]:
ls .

In [None]:
# # new_model = MyModel.load_from_checkpoint(checkpoint_path="/kaggle/working/checkpoints.ckpt")
# new_model = RNSAModel()
# new_model.load_from_checkpoint("/kaggle/working/checkpoints.ckpt")

In [None]:
# run batch size scaling, result overrides hparams.batch_size
# trainer = pl.Trainer(tpu_cores=8, 
#                      progress_bar_refresh_rate=20,
#                      max_epochs=2,
#                      auto_scale_batch_size='binsearch',# not available
#                      checkpoint_callback=checkpoint_callback,
#                      default_root_dir='/kaggle/working/checkpoints'
#                     )

In [None]:
# trainer = pl.Trainer(resume_from_checkpoint='/kaggle/working/checkpoints.ckpt')

In [None]:
# trainer.fit(model = new_model)

In [None]:
# def weights_update(model, checkpoint):
#     model_dict = model.state_dict()
#     pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in model_dict}
#     model_dict.update(pretrained_dict)
#     model.load_state_dict(model_dict)
#     return model

In [None]:
# # https://github.com/PyTorchLightning/pytorch-lightning/issues/924
# model = weights_update(model=EfficientNet.from_pretrained('efficientnet-b0'),
#                        checkpoint=torch.load("/kaggle/working/checkpoints.ckpt"))