In [1]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader, random_split

In [2]:
import pytorch_lightning as pl
import pandas as pd
from PIL import Image
from sklearn import metrics

In [73]:
import cv2
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

## Define Model

In [3]:
class Classifier(pl.LightningModule):
    def __init__(self):
        super(Classifier, self).__init__()
        self.model = models.resnet18(pretrained=False)
        # change 1st conv layer from 3 channel to 1 channel
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        #change to single output
        self.model.fc = nn.Linear(self.model.fc.in_features, 1)
        self.BCELoss = nn.BCEWithLogitsLoss()
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        x = self.model(x)
        return x
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        return optimizer
    
    def BCE_loss(self, logits, labels):
        return self.BCELoss(logits, labels.float())
    
    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)   # we already defined forward and loss in the lightning module. We'll show the full code next
        logits = torch.flatten(logits)
        loss = self.BCE_loss(logits, y)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        logits = torch.flatten(logits)
        loss = self.BCE_loss(logits, y)
        self.log('val_loss', loss)
        
    def predict(self, x):
        x = self.model(x)
        x = self.sigmoid(x)
        return x

In [245]:
class CXRDataset(Dataset):
    def __init__(self, excel_path : str = "./", img_dir : str = "./",
                 img_brt_std : float = 0.10448302):
        """
        Args:
            excel_path (string): Path to excel file with ids and labels.
            img_dir (string): Directory with all the png images.
            img_brt_std (float): Standard deviation of brightness of training dataset.
        """
        self.dataset = self.prepare_dataset(excel_path)
        self.img_dir = img_dir
        self.img_brt_std = img_brt_std
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])
        
    def prepare_dataset(self, excel_path):
        df = pd.read_excel(excel_path)
        df = df.sort_values(by='Abnormal', ascending=False)
        df = df.drop_duplicates(subset="image_id", keep="first")
        df = df.sample(frac=1).reset_index(drop=True)
        return df
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.dataset['image_id'].iloc[idx] + '.png')
        image = Image.open(img_path)
        image = self.transform(image)
        label = self.dataset['Abnormal'].iloc[idx]
        uuid = self.dataset['image_id'].iloc[idx]
        return image, label

In [246]:
class DataModule(pl.LightningDataModule):
    def __init__(self, img_dir: str = "./", train_file: str = "./",
                 val_file:str="./", batch_size: int = 24, num_workers: int = 0):
        """
        Args:
            img_dir (string): Directory with all the png images.
            train_file (string): path to train image excel
            val_file (string): path to validation image excel
            batch_size (int): batch size for training
        """
        super().__init__()
        self.img_dir = img_dir
        self.train_file = train_file
        self.val_file = val_file
        self.batch_size = batch_size
        self.num_workers = num_workers
    
    def prepare_data(self):
        self.train_data = CXRDataset(self.train_file, self.img_dir)
        self.val_data = CXRDataset(self.val_file, self.img_dir)
        
    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=self.batch_size,
                         num_workers=self.num_workers)
        

## Train Model

In [247]:
dataset = DataModule('imgs/', './train.xlsx', './test.xlsx')
dataset.prepare_data()

In [251]:
model = Classifier()

In [252]:
trainer = pl.Trainer(checkpoint_callback=False, max_epochs=8)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [253]:
trainer.fit(model, dataset.train_dataloader(), dataset.val_dataloader())


  | Name    | Type              | Params
----------------------------------------------
0 | model   | ResNet            | 11.2 M
1 | BCELoss | BCEWithLogitsLoss | 0     
2 | sigmoid | Sigmoid           | 0     
----------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.683    Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                      

  rank_zero_warn(


Epoch 0:  83%|████████▎ | 417/501 [25:01<05:01,  3.59s/it, loss=0.341, v_num=6]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/84 [00:00<?, ?it/s][A
Epoch 0:  84%|████████▎ | 419/501 [25:02<04:53,  3.58s/it, loss=0.341, v_num=6]
Validating:   2%|▏         | 2/84 [00:01<01:13,  1.12it/s][A
Epoch 0:  84%|████████▍ | 421/501 [25:04<04:45,  3.56s/it, loss=0.341, v_num=6]
Validating:   5%|▍         | 4/84 [00:03<01:12,  1.11it/s][A
Epoch 0:  84%|████████▍ | 423/501 [25:06<04:37,  3.55s/it, loss=0.341, v_num=6]
Validating:   7%|▋         | 6/84 [00:05<01:10,  1.11it/s][A
Epoch 0:  85%|████████▍ | 425/501 [25:07<04:29,  3.54s/it, loss=0.341, v_num=6]
Validating:  10%|▉         | 8/84 [00:07<01:08,  1.11it/s][A
Epoch 0:  85%|████████▌ | 427/501 [25:09<04:21,  3.53s/it, loss=0.341, v_num=6]
Validating:  12%|█▏        | 10/84 [00:08<01:06,  1.11it/s][A
Epoch 0:  86%|████████▌ | 429/501 [25:11<04:13,  3.51s/it, loss=0.341, v_num=6]
Validating:  14%|█▍        | 12/84 [00:10

In [255]:
trainer.save_checkpoint("models/abnormal_base_model.ckpt")
#new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")

## Evaluate Model

In [28]:
def validate(model, val_loader):
    labels = []
    pred_list = []
    for x, y in val_loader:
        #x = x.to(device)
        preds = model.predict(x)
        labels += list(y.detach().cpu().numpy())
        pred_list += list(preds.detach().cpu().numpy())
    return labels, pred_list

def calc_auc(labels, pred_list):
    results = {}
    results['fpr'], results['tpr'], results['thresholds'] = metrics.roc_curve(labels, pred_list)
    results['auc'] = metrics.auc(results['fpr'], results['tpr'])
    return results

In [256]:
labels, pred_list = validate(model, dataset.val_dataloader())

In [257]:
results = calc_auc(labels, preds)

In [259]:
results['auc']

0.5060999507401561

In [31]:
results['auc']

0.9609576784830366