<a href="https://colab.research.google.com/github/tanmayg/LS/blob/master/FB_Covid.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pytorch_lightning

In [43]:
import logging
import os
from argparse import ArgumentParser
from pathlib import Path
from warnings import warn

import numpy as np
import pytorch_lightning as pl
import torch
import yaml
#import argparse

#argparser = argparse.ArgumentParser()

In [44]:
!pwd

/content


In [45]:
from transforms import (
    Compose,
    HistogramNormalize,
    NanToInt,
    RemapLabel,
    TensorToRGB,
)

In [46]:
from xray_datamodule import XrayDataModule

In [47]:
from torchvision import transforms
from sip_finetune import SipModule

In [6]:
#!python train_sip.py --pretrained_file mimic-chexpert_lr_0.01_bs_128_fd_128_qs_65536.pt

In [7]:
# Create the parser in this cell so that the cell neither depends on a name defined elsewhere
# nor fails with "conflicting option string" when it is executed a second time.
argparser = ArgumentParser()
argparser.add_argument(
    "--pretrained_file",
    help="Pretrained File",
    default="mimic-chexpert_lr_0.01_bs_128_fd_128_qs_65536.pt",
)
argparser.add_argument("--im_size", default=224, type=int)
argparser.add_argument("--uncertain_label", default=np.nan, type=float)
argparser.add_argument("--nan_label", default=np.nan, type=float)
# The argument list is passed explicitly; inside a notebook sys.argv belongs to the kernel.
args = argparser.parse_args(
    ["--pretrained_file", "mimic-chexpert_lr_0.01_bs_128_fd_128_qs_65536.pt"]
)

In [8]:
args

Namespace(im_size=224, nan_label=nan, pretrained_file='mimic-chexpert_lr_0.01_bs_128_fd_128_qs_65536.pt', uncertain_label=nan)

In [6]:
def build_args(arg_defaults=None, argv=None):
    """
    Assemble the training configuration namespace.

    Seeds all RNGs with 1234, merges ``arg_defaults`` over the built-in defaults, parses the
    arguments, resolves the dataset directory from ``data.yaml`` (in the current working
    directory) when none was given, selects the validation pathologies for the dataset and
    appends a ``ModelCheckpoint`` callback. The ``checkpoints`` directory under
    ``default_root_dir`` is created if it does not exist.

    Args:
        arg_defaults: Optional dict of parser defaults that override the built-in ones.
        argv: Optional list of argument strings to parse instead of ``sys.argv[1:]``.
            Pass ``[]`` inside a notebook, where ``sys.argv`` holds the kernel's arguments.

    Returns:
        The parsed ``argparse.Namespace``. ``default_root_dir``, ``dataset_dir`` and
        ``val_pathology_list`` are always set; ``resume_from_checkpoint`` is set to the most
        recently modified ``*.ckpt`` when the checkpoint directory already holds one and no
        checkpoint was requested explicitly.

    Raises:
        ValueError: If ``dataset_name`` is not one of "nih", "mimic", "chexpert" or
            "mimic-chexpert".
    """
    pl.seed_everything(1234)
    data_config = Path.cwd() / "data.yaml"

    # Built-in defaults; entries supplied by the caller take precedence.
    defaults = {
        "accelerator": "ddp",
        "batch_size": 32,
        "max_epochs": 5,
        "gpus": 1,
        "num_workers": 10,
        "callbacks": [],
    }
    if arg_defaults is not None:
        defaults.update(arg_defaults)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument("--im_size", default=224, type=int)
    parser.add_argument("--uncertain_label", default=np.nan, type=float)
    parser.add_argument("--nan_label", default=np.nan, type=float)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = XrayDataModule.add_model_specific_args(parser)
    parser = SipModule.add_model_specific_args(parser)
    parser.set_defaults(**defaults)
    # argv=None keeps the original behavior of reading sys.argv[1:].
    args = parser.parse_args(argv)

    if args.default_root_dir is None:
        args.default_root_dir = Path.cwd()

    if args.pretrained_file is None:
        warn("Pretrained file not specified, training from scratch.")
    else:
        logging.info(f"Loading pretrained file from {args.pretrained_file}")

    if args.dataset_dir is None:
        with open(data_config, "r") as f:
            paths = yaml.load(f, Loader=yaml.SafeLoader)["paths"]

        # One chain of mutually exclusive branches. Previously "nih" was tested by a separate
        # `if`, so it fell through to the final `else` and always raised.
        if args.dataset_name == "nih":
            args.dataset_dir = paths["nih"]
        elif args.dataset_name == "mimic":
            args.dataset_dir = paths["mimic"]
        elif args.dataset_name == "chexpert":
            args.dataset_dir = paths["chexpert"]
        elif args.dataset_name == "mimic-chexpert":
            args.dataset_dir = [paths["chexpert"], paths["mimic"]]
        else:
            raise ValueError("Unrecognized path config.")

    if args.dataset_name in ("chexpert", "mimic", "mimic-chexpert"):
        args.val_pathology_list = [
            "Atelectasis",
            "Cardiomegaly",
            "Consolidation",
            "Edema",
            "Pleural Effusion",
        ]
    elif args.dataset_name == "nih":
        # Same pathologies; this dataset uses the label "Effusion".
        args.val_pathology_list = [
            "Atelectasis",
            "Cardiomegaly",
            "Consolidation",
            "Edema",
            "Effusion",
        ]
    else:
        raise ValueError("Unrecognized dataset.")

    # ------------
    # checkpoints
    # ------------
    checkpoint_dir = Path(args.default_root_dir) / "checkpoints"
    if not checkpoint_dir.exists():
        checkpoint_dir.mkdir(parents=True)
    elif args.resume_from_checkpoint is None:
        # Resume from the newest checkpoint (by modification time) when one exists.
        ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime)
        if ckpt_list:
            args.resume_from_checkpoint = str(ckpt_list[-1])

    args.callbacks.append(
        pl.callbacks.ModelCheckpoint(dirpath=checkpoint_dir, verbose=True)
    )

    return args

def fetch_pos_weights(dataset_name, csv, label_list, uncertain_label, nan_label):
    """
    Compute one positive-class weight per label as ``negatives / max(positives, 1)``.

    Args:
        dataset_name: "nih" counts labels by substring match in the "Finding Labels" column;
            any other name expects one column per label holding 1 / 0 / -1 / NaN.
        csv: pandas DataFrame with the label information.
        label_list: Labels to compute weights for; fixes the order of the result.
        uncertain_label: How -1 entries are counted (non-"nih" only): 1 adds them to the
            positives, -1 adds them to the negatives, anything else ignores them.
        nan_label: How missing entries are counted (non-"nih" only), with the same
            1 / -1 / other convention as ``uncertain_label``.

    Returns:
        A float64 ``torch.Tensor`` with one weight per entry of ``label_list``.
    """
    if dataset_name == "nih":
        findings = csv["Finding Labels"]
        pos = np.array([findings.str.contains(lab).sum() for lab in label_list])
        neg = np.array([(~findings.str.contains(lab)).sum() for lab in label_list])
        # max(pos, 1) avoids dividing by zero for labels that never occur.
        ratios = neg / np.maximum(pos, 1)
    else:
        labels = csv[label_list]
        pos = (labels == 1).sum()
        neg = (labels == 0).sum()

        if uncertain_label == 1:
            pos = pos + (labels == -1).sum()
        elif uncertain_label == -1:
            neg = neg + (labels == -1).sum()

        if nan_label == 1:
            pos = pos + labels.isna().sum()
        elif nan_label == -1:
            neg = neg + labels.isna().sum()

        ratios = (neg / np.maximum(pos, 1)).values

    # np.float64 replaces the np.float alias, which no longer exists in current NumPy.
    return torch.tensor(ratios.astype(np.float64))

In [17]:
# Run configuration, set by hand here instead of through `args = build_args()`.
dataset_name = "mimic"
pretrained_file = "mimic-chexpert_lr_0.01_bs_128_fd_128_qs_65536.pt"

# Side length handed to Resize / CenterCrop in the transform lists.
im_size = 224

# Values handed to RemapLabel / NanToInt and to fetch_pos_weights.
uncertain_label = np.nan
nan_label = np.nan

# DataLoader settings.
batch_size, num_workers = 64, 4

In [9]:
# Training pipeline, applied in list order: resize and center-crop to im_size, random
# horizontal/vertical flips, tensor conversion, then the project transforms.
# NOTE(review): NanToInt(nan_label) is included here but not in val_transform_list;
# confirm that the asymmetry is intended.
train_transform_list = [
        transforms.Resize(im_size),
        transforms.CenterCrop(im_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        HistogramNormalize(),
        TensorToRGB(),
        RemapLabel(-1, uncertain_label),
        NanToInt(nan_label),
    ]

In [10]:
# Evaluation pipeline (also used for the test split below): the same steps as
# train_transform_list, but without the random flips and without NanToInt.
val_transform_list = [
        transforms.Resize(im_size),
        transforms.CenterCrop(im_size),
        transforms.ToTensor(),
        HistogramNormalize(),
        TensorToRGB(),
        RemapLabel(-1, uncertain_label),
    ]

In [11]:
# Read the dataset locations from the "paths" section of data.yaml (safe loader only).
with open("data.yaml", "r") as f:
    paths = yaml.safe_load(f)["paths"]

In [12]:
# Pick the data location for the chosen dataset from the paths read out of data.yaml.
# The branches form a single if/elif chain: with "nih" tested by a separate `if`, that
# dataset used to fall through to the final `else` and raise.
if dataset_name == "nih":
    dataset_dir = paths["nih"]
elif dataset_name == "mimic":
    dataset_dir = paths["mimic"]
elif dataset_name == "chexpert":
    dataset_dir = paths["chexpert"]
elif dataset_name == "mimic-chexpert":
    # The combined dataset takes both directories, CheXpert first.
    dataset_dir = [paths["chexpert"], paths["mimic"]]
else:
    raise ValueError("Unrecognized path config.")

In [15]:
# Pathologies used for validation. The NIH list names the last entry "Effusion" where the
# other datasets use "Pleural Effusion"; the remaining four names are shared.
if dataset_name == "nih":
    val_pathology_list = [
        "Atelectasis",
        "Cardiomegaly",
        "Consolidation",
        "Edema",
        "Effusion",
    ]
elif dataset_name in ("chexpert", "mimic", "mimic-chexpert"):
    val_pathology_list = [
        "Atelectasis",
        "Cardiomegaly",
        "Consolidation",
        "Edema",
        "Pleural Effusion",
    ]
else:
    raise ValueError("Unrecognized dataset.")

In [16]:
dataset_dir, val_pathology_list

('mimic_data',
 ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion'])

In [63]:
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import os
from argparse import ArgumentParser
from typing import Callable, List, Optional, Union

from base_dataset import BaseDataset
import numpy as np
import pytorch_lightning as pl
import torch
from mimic_cxr import MimicCxrJpgDataset


class TwoImageDataset(torch.utils.data.Dataset):
    """
    Adapter that yields two separately fetched views of the same underlying sample.

    Args:
        dataset: Already-constructed ``BaseDataset`` that the samples are drawn from.
    """

    def __init__(self, dataset: BaseDataset):
        assert isinstance(dataset, BaseDataset)
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # The wrapped dataset is indexed twice. Any difference between the two views comes
        # from randomness drawn inside its transforms, which therefore have to sample from
        # the process-level generator.
        first, second = self.dataset[idx], self.dataset[idx]

        return {
            "image0": first["image"],
            "image1": second["image"],
            "label": first["labels"],
        }


def fetch_dataset(
    dataset_name: str,
    dataset_dir: Union[List[Union[str, os.PathLike]], Union[str, os.PathLike]],
    split: str,
    transform: Optional[Callable],
    two_image: bool = False,
    label_list="all",
):
    """
    Build the dataset object selected by ``dataset_name``.

    Args:
        dataset_name: One of "nih", "mimic", "chexpert" or "mimic-chexpert".
        dataset_dir: Data location; a list of two locations (CheXpert first, then MIMIC)
            for "mimic-chexpert", a single location otherwise.
        split: "train", "val" or "test".
        transform: Transform passed to the dataset, or ``None``.
        two_image: When ``True``, wrap the dataset in ``TwoImageDataset``.
        label_list: Labels to load, forwarded to the dataset ("all" by default).

    Returns:
        The constructed dataset, wrapped in ``TwoImageDataset`` when requested.

    Raises:
        ValueError: If ``dataset_name`` is not recognized.
        AssertionError: If ``split`` is invalid or ``dataset_dir`` has the wrong kind
            (list vs. single location) for the dataset.
    """
    # NOTE(review): only BaseDataset and MimicCxrJpgDataset are imported in the visible part
    # of this cell; NIHChestDataset, CheXpertDataset and CombinedXrayDataset must be
    # importable for their branches to work.

    assert split in ("train", "val", "test")
    dataset: Union[BaseDataset, TwoImageDataset]

    # determine the dataset
    # The branches form a single if/elif chain: with "nih" tested by a separate `if`, the
    # NIH dataset used to be built and then discarded by the final `else` raising.
    if dataset_name == "nih":
        assert not isinstance(dataset_dir, list)
        dataset = NIHChestDataset(
            directory=dataset_dir,
            split=split,
            transform=transform,
            label_list=label_list,
            resplit=True,
        )
    elif dataset_name == "mimic":
        assert not isinstance(dataset_dir, list)
        print("label_list from fetch_dataset: ", label_list)
        dataset = MimicCxrJpgDataset(
            directory=dataset_dir,
            split=split,
            transform=transform,
            label_list=label_list,
        )
    elif dataset_name == "chexpert":
        assert not isinstance(dataset_dir, list)
        dataset = CheXpertDataset(
            directory=dataset_dir,
            split=split,
            transform=transform,
            label_list=label_list,
        )
    elif dataset_name == "mimic-chexpert":
        assert isinstance(dataset_dir, list)
        dataset = CombinedXrayDataset(
            dataset_list=["chexpert_v1", "mimic-cxr"],
            directory_list=dataset_dir,
            transform_list=[transform, transform],
            label_list=[label_list, label_list],
            split_list=[split, split],
        )
    else:
        raise ValueError(f"dataset {dataset_name} not recognized")

    if two_image is True:
        dataset = TwoImageDataset(dataset)

    return dataset


def worker_init_fn(worker_id):
    """Seed NumPy's global RNG from the seed torch assigned to this DataLoader worker."""
    # The modulo folds the worker seed into the range np.random.seed accepts.
    np.random.seed(
        torch.utils.data.get_worker_info().seed % (2 ** 32 - 1)  # pylint: disable=no-member
    )


class XrayDataModule(pl.LightningDataModule):
    """
    Lightning data module that owns the train / val / test X-ray datasets.

    All three datasets are built eagerly in the constructor via ``fetch_dataset``.

    Args:
        dataset_name: Name of the dataset.
        dataset_dir: Location of the data.
        label_list: Labels to load for training.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of workers for dataloaders.
        use_two_images: Whether to return two augmentations of same image from
            dataset (for MoCo pretraining).
        train_transform: Transform for training loop.
        val_transform: Transform for validation loop.
        test_transform: Transform for test loop.
    """

    def __init__(
        self,
        dataset_name: str,
        dataset_dir: Union[List[Union[str, os.PathLike]], Union[str, os.PathLike]],
        label_list: Union[str, List[str]] = "all",
        batch_size: int = 1,
        num_workers: int = 4,
        use_two_images: bool = False,
        train_transform: Optional[Callable] = None,
        val_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
    ):
        super().__init__()
        print("label_list from XrayDataModule: ", label_list)
        self.dataset_name = dataset_name
        self.dataset_dir = dataset_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        def load_split(split, transform):
            # All splits share the dataset name, location, labels and two-image setting.
            return fetch_dataset(
                self.dataset_name,
                self.dataset_dir,
                split,
                transform,
                label_list=label_list,
                two_image=use_two_images,
            )

        self.train_dataset = load_split("train", train_transform)
        self.val_dataset = load_split("val", val_transform)
        self.test_dataset = load_split("test", test_transform)

        # The two-image wrapper has no label_list attribute of its own to report.
        wrapped = isinstance(self.train_dataset, TwoImageDataset)
        self.label_list = None if wrapped else self.train_dataset.label_list

    def __dataloader(self, split: str) -> torch.utils.data.DataLoader:
        """Create the DataLoader for ``split``; only the training split is shuffled."""
        assert split in ("train", "val", "test")
        is_train = split == "train"
        if is_train:
            dataset = self.train_dataset
        else:
            dataset = self.val_dataset if split == "val" else self.test_dataset

        # drop_last=True applies to every split, so a final partial batch is discarded.
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            drop_last=True,
            shuffle=is_train,
            worker_init_fn=worker_init_fn,
        )

    def train_dataloader(self):
        return self.__dataloader(split="train")

    def val_dataloader(self):
        return self.__dataloader(split="val")

    def test_dataloader(self):
        return self.__dataloader(split="test")

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the data-related options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        options = (
            ("--dataset_name", "mimic", str),
            ("--dataset_dir", None, str),
            ("--batch_size", 64, int),
            ("--num_workers", 4, int),
        )
        for flag, default, kind in options:
            parser.add_argument(flag, default=default, type=kind)

        return parser

In [65]:
# Build the data module for the configured dataset. label_list is left at its default
# ("all"), and the validation transform list is reused for the test split.
data_module = XrayDataModule(
        dataset_name=dataset_name,
        dataset_dir=dataset_dir,
        batch_size=batch_size,
        num_workers=num_workers,
        train_transform=Compose(train_transform_list),
        val_transform=Compose(val_transform_list),
        test_transform=Compose(val_transform_list),
    )

label_list from XrayDataModule:  all
label_list from fetch_dataset:  all
label_list from fetch_dataset:  all
label_list from fetch_dataset:  all


In [66]:
data_module.label_list

['No Finding',
 'Enlarged Cardiomediastinum',
 'Cardiomegaly',
 'Lung Opacity',
 'Lung Lesion',
 'Edema',
 'Consolidation',
 'Pneumonia',
 'Atelectasis',
 'Pneumothorax',
 'Pleural Effusion',
 'Pleural Other',
 'Fracture',
 'Support Devices']

In [33]:
assert not isinstance(dataset_dir, list)

In [37]:
from typing import Callable, List, Optional, Union

In [38]:
# fetch_dataset() annotates its transform argument as Optional[Callable], meaning "a callable
# or None". The annotation object itself is not a usable transform, so use None (no
# transform), which is also XrayDataModule's default for its transform arguments.
transform = None

In [61]:
from mimic_cxr import MimicCxrJpgDataset
# Scratch check: build the MIMIC training split directly (bypassing XrayDataModule) with
# label_list="all", so that its label_list attribute can be inspected in the next cell.
dataset = MimicCxrJpgDataset(
            directory=dataset_dir,
            split="train",
            transform=transform,
            label_list="all",
        )

In [62]:
dataset.label_list

['No Finding',
 'Enlarged Cardiomediastinum',
 'Cardiomegaly',
 'Lung Opacity',
 'Lung Lesion',
 'Edema',
 'Consolidation',
 'Pneumonia',
 'Atelectasis',
 'Pneumothorax',
 'Pleural Effusion',
 'Pleural Other',
 'Fracture',
 'Support Devices']

In [42]:
Union[str, List[str]]

typing.Union[str, typing.List[str]]

In [None]:
# ------------
# model
# ------------

In [67]:
fetch_pos_weights(
    dataset_name=dataset_name,
    csv=data_module.train_dataset.csv,
    label_list=data_module.label_list,
    uncertain_label=uncertain_label,
    nan_label=nan_label,
)

tensor([0.0000, 0.8726, 0.3917, 0.0671, 0.1601, 1.1363, 0.9775, 1.6701, 0.0348,
        4.0171, 0.5680, 0.0750, 0.2347, 0.0618], dtype=torch.float64)

In [68]:
# Keep the per-label positive-class weights of the training split for the model below.
# The weights follow the order of data_module.label_list.
pos_weights = fetch_pos_weights(
    dataset_name=dataset_name,
    csv=data_module.train_dataset.csv,
    label_list=data_module.label_list,
    uncertain_label=uncertain_label,
    nan_label=nan_label,
)

In [70]:
# Model hyperparameters consumed by the SipModule cell below.
arch = "densenet121"
max_epochs = 5
# The SipModule cell reads learning_rate, which no cell defined before.
# NOTE(review): 1e-2 is a placeholder chosen to match the "lr_0.01" in the pretrained
# checkpoint's file name; confirm or tune it for fine-tuning.
learning_rate = 1e-2

In [None]:
 model = SipModule(
        arch=arch,
        num_classes=len(data_module.label_list),
        pretrained_file=pretrained_file,
        label_list=data_module.label_list,
        val_pathology_list=val_pathology_list,
        learning_rate=learning_rate,
        pos_weights=pos_weights,
        epochs=max_epochs,
    )