<a href="https://colab.research.google.com/github/tanmayg/LS/blob/master/FB_Covid.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install pytorch_lightning

Collecting pytorch_lightning
[?25l  Downloading https://files.pythonhosted.org/packages/38/38/f010c6de967dd9e3c765a252d0551aff7194bab90b681407c5d702ca22df/pytorch_lightning-1.2.0-py3-none-any.whl (813kB)
[K     |████████████████████████████████| 819kB 995kB/s 
[?25hCollecting fsspec[http]>=0.8.1
[?25l  Downloading https://files.pythonhosted.org/packages/91/0d/a6bfee0ddf47b254286b9bd574e6f50978c69897647ae15b14230711806e/fsspec-0.8.7-py3-none-any.whl (103kB)
[K     |████████████████████████████████| 112kB 3.2MB/s 
[?25hCollecting PyYAML!=5.4.*,>=5.1
[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
[K     |████████████████████████████████| 276kB 2.2MB/s 
Collecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |████████████████████████████████| 82

In [9]:
import logging
import os
from argparse import ArgumentParser
from pathlib import Path
from warnings import warn

import numpy as np
import pytorch_lightning as pl
import torch
import yaml
import argparse

# Notebook-level argument parser (no options registered on it yet).
argparser = argparse.ArgumentParser(add_help=True)

In [1]:
# Show the working directory: build_args() reads data.yaml from, and defaults
# default_root_dir to, Path.cwd().
!pwd

/content


In [4]:
from transforms import (
    Compose,
    HistogramNormalize,
    NanToInt,
    RemapLabel,
    TensorToRGB,
)

In [5]:
from xray_datamodule import XrayDataModule

In [6]:
from torchvision import transforms
from sip_finetune import SipModule

In [None]:
# Equivalent command-line run, kept for reference only (not executed in this notebook):
# !python train_sip.py --pretrained_file mimic-chexpert_lr_0.01_bs_128_fd_128_qs_65536.pt

In [12]:
# Checkpoint produced by the MIMIC+CheXpert pretraining run.
PRETRAINED_FILE = "mimic-chexpert_lr_0.01_bs_128_fd_128_qs_65536.pt"

# Build a fresh parser on every run so re-executing this cell does not raise
# argparse's "conflicting option string: --pretrained_file" error.
argparser = argparse.ArgumentParser()
argparser.add_argument(
    "--pretrained_file", help="Pretrained File", default=PRETRAINED_FILE
)
# Parse an explicit list: inside a kernel, sys.argv holds the kernel's own flags.
args = argparser.parse_args(["--pretrained_file", PRETRAINED_FILE])

In [14]:
# Display the parsed notebook arguments (rendered as the cell's output).
args

Namespace(pretrained_file='mimic-chexpert_lr_0.01_bs_128_fd_128_qs_65536.pt')

In [15]:
def build_args(arg_defaults=None, argv=None):
    """Build the argument namespace for SIP fine-tuning.

    Seeds RNGs via ``pl.seed_everything(1234)``, registers the image/label
    options plus the Lightning Trainer, ``XrayDataModule`` and ``SipModule``
    argument groups, parses the command line, and then fills in derived
    settings: dataset directory (from ``data.yaml`` when not given),
    validation pathology list, checkpoint directory and auto-resume.

    Args:
        arg_defaults: Optional dict of parser defaults overriding the built-in
            ones below (e.g. ``{"gpus": 0}``). The dict itself is not mutated,
            and a ``callbacks`` list passed here is copied, not appended to.
        argv: Optional list of argument strings to parse. ``None`` (default)
            parses ``sys.argv[1:]`` as before. Pass ``[]`` inside a notebook,
            where ``sys.argv`` holds the kernel's own flags that this parser
            would reject.

    Returns:
        argparse.Namespace with parsed and derived options. A
        ``ModelCheckpoint`` callback writing to ``<default_root_dir>/checkpoints``
        is appended to ``callbacks``.

    Raises:
        ValueError: if ``dataset_name`` is not one of ``nih``, ``mimic``,
            ``chexpert`` or ``mimic-chexpert``.
    """
    pl.seed_everything(1234)
    data_config = Path.cwd() / "data.yaml"

    # Built-in parser defaults; caller-supplied ``arg_defaults`` take precedence.
    defaults = {
        # NOTE(review): Lightning's "ddp" accelerator is generally not usable
        # from an interactive notebook; "ddp_spawn"/"dp" (or None for a single
        # GPU) may be needed here — TODO confirm for this environment.
        "accelerator": "ddp",
        "batch_size": 32,
        "max_epochs": 5,
        "gpus": 1,
        "num_workers": 10,
        "callbacks": [],
    }
    if arg_defaults is not None:
        defaults.update(arg_defaults)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument("--im_size", default=224, type=int)
    parser.add_argument("--uncertain_label", default=np.nan, type=float)
    parser.add_argument("--nan_label", default=np.nan, type=float)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = XrayDataModule.add_model_specific_args(parser)
    parser = SipModule.add_model_specific_args(parser)
    parser.set_defaults(**defaults)
    # argv=None keeps the default argparse behaviour of reading sys.argv[1:].
    args = parser.parse_args(argv)

    if args.default_root_dir is None:
        args.default_root_dir = Path.cwd()

    if args.pretrained_file is None:
        warn("Pretrained file not specified, training from scratch.")
    else:
        logging.info(f"Loading pretrained file from {args.pretrained_file}")

    if args.dataset_dir is None:
        with open(data_config, "r") as f:
            paths = yaml.load(f, Loader=yaml.SafeLoader)["paths"]

        # Must be a single if/elif chain: a separate `if` for "mimic" would let
        # "nih" fall through to the `else` below and raise.
        if args.dataset_name == "nih":
            args.dataset_dir = paths["nih"]
        elif args.dataset_name == "mimic":
            args.dataset_dir = paths["mimic"]
        elif args.dataset_name == "chexpert":
            args.dataset_dir = paths["chexpert"]
        elif args.dataset_name == "mimic-chexpert":
            args.dataset_dir = [paths["chexpert"], paths["mimic"]]
        else:
            raise ValueError("Unrecognized path config.")

    # NIH names the effusion label "Effusion"; CheXpert/MIMIC use "Pleural Effusion".
    if args.dataset_name in ("chexpert", "mimic", "mimic-chexpert"):
        args.val_pathology_list = [
            "Atelectasis",
            "Cardiomegaly",
            "Consolidation",
            "Edema",
            "Pleural Effusion",
        ]
    elif args.dataset_name == "nih":
        args.val_pathology_list = [
            "Atelectasis",
            "Cardiomegaly",
            "Consolidation",
            "Edema",
            "Effusion",
        ]
    else:
        raise ValueError("Unrecognized dataset.")

    # ------------
    # checkpoints
    # ------------
    checkpoint_dir = Path(args.default_root_dir) / "checkpoints"
    if not checkpoint_dir.exists():
        checkpoint_dir.mkdir(parents=True)
    elif args.resume_from_checkpoint is None:
        # Auto-resume from the most recently modified checkpoint, if any.
        latest = max(
            checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime, default=None
        )
        if latest is not None:
            args.resume_from_checkpoint = str(latest)

    # Copy so a callbacks list supplied through ``arg_defaults`` is not mutated.
    args.callbacks = list(args.callbacks or [])
    args.callbacks.append(
        pl.callbacks.ModelCheckpoint(dirpath=checkpoint_dir, verbose=True)
    )

    return args

def fetch_pos_weights(dataset_name, csv, label_list, uncertain_label, nan_label):
    """Compute one positive-class weight per label as ``neg / max(pos, 1)``.

    Presumably fed to a BCE-with-logits ``pos_weight``; confirm at the call site.

    Args:
        dataset_name: ``"nih"`` counts a label as positive when its name occurs
            in the ``"Finding Labels"`` column (``str.contains``) and negative
            otherwise. Any other value reads one column per label, counting
            ``== 1`` as positive and ``== 0`` as negative.
        csv: pandas DataFrame holding the label table.
        label_list: label names (NIH) or label column names (other datasets).
        uncertain_label: how ``-1`` entries are counted (non-NIH only):
            ``1`` adds them to positives, ``-1`` adds them to negatives, and
            anything else (e.g. NaN) ignores them.
        nan_label: same rule, applied to missing (NaN) entries.

    Returns:
        1-D float64 ``torch.Tensor`` with one weight per entry of
        ``label_list``. ``pos`` is clamped to at least 1 to avoid division by
        zero.
    """
    if dataset_name == "nih":
        findings = csv["Finding Labels"]
        # Evaluate each substring match once and derive both counts from it.
        hits = [findings.str.contains(lab) for lab in label_list]
        pos = np.array([hit.sum() for hit in hits])
        neg = np.array([(~hit).sum() for hit in hits])
        ratio = neg / np.maximum(pos, 1)
    else:
        labels = csv[label_list]
        pos = (labels == 1).sum()
        neg = (labels == 0).sum()

        # NOTE(review): uncertain entries count as negatives only when
        # uncertain_label == -1 (not 0); verify this matches how the
        # datamodule remaps labels.
        if uncertain_label == 1:
            pos = pos + (labels == -1).sum()
        elif uncertain_label == -1:
            neg = neg + (labels == -1).sum()

        if nan_label == 1:
            pos = pos + labels.isna().sum()
        elif nan_label == -1:
            neg = neg + labels.isna().sum()

        ratio = (neg / np.maximum(pos, 1)).values

    # np.float (an alias of builtin float, i.e. float64) was removed in NumPy 1.24.
    return torch.tensor(ratio.astype(np.float64))