# High-Level Setup & Imports

In [4]:
import torch

# General Python Packages
import os, time, numbers, math

# Torch Packages
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.optim import lr_scheduler, SGD
from torch.autograd import Variable
from torch import nn
from torch.nn import DataParallel
from torch.nn import Module

# General Analytics Packages
import pandas as pd
import numpy as np

# Visualization / Image Packages
import matplotlib.pyplot as plt
from PIL import Image

# Randomization Functions
from random import random as randuni

### Data Utilities & Dataset Class

In [5]:
def is_image_file(fname, extensions=('.png',)):
    """Checks if a file is an image.

    Args:
        fname (string): path to a file
        extensions (iterable of str, optional): file extensions, including the leading
            dot, that count as images. Matching is case-insensitive.
            Defaults to ``('.png',)``.
    Returns:
        bool: True if the filename ends with one of ``extensions``
    """
    # str.endswith accepts a tuple, so a single call checks every allowed extension.
    allowed = tuple(ext.lower() for ext in extensions)
    return fname.lower().endswith(allowed)

def create_label_maps(details_df, excluded_labels=('No Finding',)):
    """ Take a descriptive dataframe and extract the unique labels and map to index values
    Args:
        details_df: Dataframe with the image details; its 'Finding Labels' column holds
                    '|'-separated label strings
        excluded_labels (iterable of str, optional): labels to drop from the label space.
                    A single string is treated as one label. Defaults to ('No Finding',).
    Returns:
        label_list: sorted list of unique labels in the dataframe
        label_to_index: map from labels to indices (index = position in label_list)
    """
    # TODO: Research paper also excludes these labels but need to figure out how to handle
    #       cases that have these as positive findings (completely exclude?)
    #       excluded_labels = ('Edema', 'Hernia', 'Emphysema', 'Fibrosis', 'No Finding',
    #                          'Pleural_Thickening', 'Consolidation')
    if isinstance(excluded_labels, str):
        # Guard against set('No Finding') silently becoming a set of characters.
        excluded_labels = (excluded_labels,)
    excluded = set(excluded_labels)

    # Split every distinct label-combination string on '|' to collect the individual labels.
    unique_labels = {label
                     for group in details_df['Finding Labels'].unique()
                     for label in group.split('|')}

    # Drop labels we do not want to include; sorting makes index assignment deterministic.
    label_list = sorted(label for label in unique_labels if label not in excluded)
    label_to_index = {label: idx for idx, label in enumerate(label_list)}

    return label_list, label_to_index

def create_image_list(dir):
    """ Collect every image file that lives below the subfolders of a root directory.

    Only the immediate subdirectories of ``dir`` (and everything nested under them)
    are searched; files sitting directly inside ``dir`` are skipped.
    Args:
        dir (string): root directory of images; a leading ``~`` is expanded
    Returns:
        image_list: list of (image_name, full_image_path) tuples, ordered by
                    subfolder, then by walked directory path, then by file name
    """
    root = os.path.expanduser(dir)
    found = []
    for entry in sorted(os.listdir(root)):
        top = os.path.join(root, entry)
        if os.path.isdir(top):
            for current_dir, _subdirs, names in sorted(os.walk(top)):
                found.extend((name, os.path.join(current_dir, name))
                             for name in sorted(names)
                             if is_image_file(name))
    return found

def pil_loader(path):
    """ Opens path as file with Pillow (https://github.com/python-pillow/Pillow/issues/835)
    Args:
        path (string): File path to the image
    Returns:
        img: Image in RGB format
    """
    # The context manager guarantees the file handle is closed (the previous version leaked it).
    # Image.open is lazy, so convert('RGB') must run while the file is still open: it reads
    # the pixel data and returns a new, fully loaded image that no longer needs the handle.
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
        
def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """ Convert tensor array to an image and plot it (only use post-dataset transform).

    Only the first element along dim 0 of ``inp`` is shown.
    Args:
        inp: tensor whose first element is transposed with (1, 2, 0); presumably a batch of
             (channels, height, width) images -- TODO confirm against caller
        title (str, optional): title for the plot
        mean (sequence of float, optional): per-channel mean that was subtracted during
             normalization; defaults to the values previously hard-coded here
        std (sequence of float, optional): per-channel std that was divided out during
             normalization; defaults to the values previously hard-coded here
    """
    img = inp[0]
    # detach()/cpu() let tensors that require grad or live on a GPU be converted as well.
    img = img.detach().cpu().numpy().transpose((1, 2, 0))
    # Undo normalization: x_norm = (x - mean) / std  =>  x = x_norm * std + mean
    img = np.asarray(std) * img + np.asarray(mean)
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

In [6]:
class XrayImageSet(Dataset):
    """ Chest X-ray image dataset returning (image, multi-hot label tensor) pairs.

    Args:
        image_root (string): root directory of the images in form image/subfolder/*.png
        csv_file (string): path to the CSV data file. Must contain an 'Image Index' column
            (image file name) and a 'Finding Labels' column ('|'-separated label string).
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
    Attributes:
        labels (list): sorted list of the possible label names.
        label_to_index (dict): look from label name to a label index
        imgs (list): List of (filename, image path) tuples
        image_details (DataFrame): the parsed CSV contents
        max_label_index (int): largest label index; targets have max_label_index + 1 entries
    """

    def __init__(self, image_root, csv_file, transform=None, target_transform=None, loader = pil_loader):
        """ Create an instance of the Xray Dataset """
        img_details = pd.read_csv(csv_file)

        labels, label_to_index = create_label_maps(img_details)
        imgs = create_image_list(image_root)

        self.imgs = imgs
        self.image_details = img_details
        self.image_root = image_root
        self.labels = labels
        self.label_to_index = label_to_index
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.max_label_index = max(label_to_index.values())
        # Build the file-name -> label-string map once, up front, so every DataLoader
        # worker starts with it instead of scanning the dataframe for each item.
        self._labels_by_fname = None
        self._label_lookup()

    def __getitem__(self, index):
        """ Get image,labels pair by index
        Args:
            index (int): Index into ``self.imgs``
        Returns:
            tuple: (image, target) where target is the multi-hot FloatTensor returned by
                   get_one_hot_labels (passed through target_transform when one is set).
        """
        fname, path = self.imgs[index]
        target = self.get_one_hot_labels(fname)
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """ Calculate length of the dataset (number of images) """
        return len(self.imgs)

    def _label_lookup(self):
        """ Return the cached file-name -> label-string dict, building it on first use.

        When a file name appears in several rows, the first row wins (as the original
        boolean-mask lookup did). Set ``self._labels_by_fname = None`` to force a rebuild
        after modifying ``image_details``.
        """
        lookup = getattr(self, '_labels_by_fname', None)
        if lookup is None:
            details = self.image_details.drop_duplicates(subset='Image Index', keep='first')
            lookup = dict(zip(details['Image Index'], details['Finding Labels']))
            self._labels_by_fname = lookup
        return lookup

    def get_labels(self, fname):
        """ Return the '|'-separated label string for the file.

        Uses an O(1) dict lookup rather than a full dataframe scan per call.
        Raises:
            KeyError: if ``fname`` has no row in ``image_details``.
        """
        return self._label_lookup()[fname]

    def one_hot_labels(self, labels):
        """ Convert the labels string (with each label separated by |) into a multi-hot list.

        Returns a list of 0/1 ints of length ``max_label_index + 1``. Position i is 1 when the
        label mapped to index i appears in ``labels``. 'No Finding' and labels missing from
        ``label_to_index`` set no position. Returns None when ``labels`` is None or NaN
        (e.g. an empty CSV cell).
        """
        if labels is None or (isinstance(labels, float) and math.isnan(labels)):
            return None

        hot = {self.label_to_index.get(label)
               for label in labels.split('|')
               if label != 'No Finding'}

        # To invert: '|'.join(index_to_label[i] for i, v in enumerate(out) if v == 1)
        return [1 if idx in hot else 0 for idx in range(self.max_label_index + 1)]

    def get_one_hot_labels(self, fname):
        """ Get the multi-hot encoded label vector for the provided file as a FloatTensor """
        labels = self.get_labels(fname)
        one_hot_labels = self.one_hot_labels(labels)
        return torch.FloatTensor(one_hot_labels)

# Setup

In [7]:
img_data = XrayImageSet(
    '/user/images_processed/',   # image_root: directory whose subfolders hold the .png images
    '/user/img_details.csv',     # csv_file: per-image metadata ('Image Index', 'Finding Labels')
    transform=None,
    target_transform=None,
)

In [13]:
# Multi-hot label matrix: one row per CSV row, one 0/1 column per label.
# Building the frame from a list of lists avoids .apply(pd.Series), which creates a
# separate Series object for every row and is very slow on a dataset this size.
df = pd.DataFrame(
    img_data.image_details['Finding Labels'].map(img_data.one_hot_labels).tolist(),
    index=img_data.image_details.index,
    columns=img_data.labels,
)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13
0,0,1,0,0,0,0,0,0,0,0,0,0,0,0
1,0,1,0,0,0,1,0,0,0,0,0,0,0,0
2,0,1,0,0,1,0,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,1,0,0,0,0,0,0
5,0,0,0,0,0,0,0,1,0,0,0,0,0,0
6,0,0,0,0,0,0,0,1,0,0,0,0,0,0
7,0,0,0,0,0,0,0,1,1,0,0,0,0,0
8,0,0,0,0,0,0,0,1,0,0,0,0,0,0
9,0,0,0,0,0,0,0,1,0,0,0,0,0,0


In [31]:
# Number of images carrying each finding (column totals of the multi-hot matrix).
df.sum(axis='index')

Atelectasis           11535
Cardiomegaly           2772
Consolidation          4667
Edema                  2303
Effusion              13307
Emphysema              2516
Fibrosis               1686
Hernia                  227
Infiltration          19870
Mass                   5746
Nodule                 6323
Pleural_Thickening     3385
Pneumonia              1353
Pneumothorax           5298
dtype: int64

In [46]:
# An image has "No Finding" when none of its label columns is set.
# (Replaces Series.clip_upper, deprecated in pandas 0.24 and removed in 1.0.)
no_finding = df.sum(axis=1).eq(0)
no_finding.sum()  # Count of "No Finding"

60412

### Distribution of Findings per Image

In [51]:
# Number of findings per image, listed in count order (0, 1, 2, ...) rather than by frequency.
findings_per_image = df.sum(axis=1)
findings_per_image.value_counts().sort_index()

0    60412
1    30973
2    14292
3     4829
4     1233
5      298
6       64
7       16
9        2
8        1
dtype: int64

### Correlation Matrix

In [21]:
# Pearson correlation between the 0/1 label columns (equivalent to the phi coefficient).
df.corr(method='pearson')

Unnamed: 0,Atelectasis,Cardiomegaly,Consolidation,Edema,Effusion,Emphysema,Fibrosis,Hernia,Infiltration,Mass,Nodule,Pleural_Thickening,Pneumonia,Pneumothorax
Atelectasis,1.0,0.015846,0.109043,-0.003298,0.172467,0.032538,0.011227,0.010872,0.093393,0.018087,-0.008338,0.025178,0.02791,0.0314
Cardiomegaly,0.015846,1.0,0.015419,0.028371,0.129824,-0.00706,0.004397,0.001773,0.0138,-0.011217,-0.012033,0.009168,0.001341,-0.022464
Consolidation,0.109043,0.015419,1.0,0.020822,0.10122,-0.000521,0.003236,-0.005413,0.045948,0.073477,0.031903,0.028733,0.02359,0.000309
Edema,-0.003298,0.028371,0.020822,1.0,0.061957,-0.009204,-0.013243,-0.002326,0.094005,0.002844,0.000306,-0.002032,0.174042,-0.022471
Effusion,0.172467,0.129824,0.10122,0.061957,1.0,0.011244,-0.002743,-0.003645,0.117842,0.070291,0.018954,0.07192,0.023342,0.047596
Emphysema,0.032538,-0.00706,-0.000521,-0.009204,0.011244,1.0,-0.000908,-0.001466,0.000175,0.022683,-0.007019,0.026408,-0.005163,0.177981
Fibrosis,0.011227,0.004397,0.003236,-0.013243,-0.002743,-0.000908,1.0,0.007478,0.008868,0.009504,0.02253,0.053578,-0.006273,0.000115
Hernia,0.010872,0.001773,-0.005413,-0.002326,-0.003645,-0.001466,0.007478,1.0,-0.003756,0.012028,-0.00241,0.00133,-0.001344,-0.001614
Infiltration,0.093393,0.0138,0.045948,0.094005,0.117842,0.000175,0.008868,-0.003756,1.0,0.014055,0.042873,0.020353,0.070854,0.00045
Mass,0.018087,-0.011217,0.073477,0.002844,0.070291,0.022683,0.009504,0.012028,0.014055,1.0,0.099937,0.064893,-0.002719,0.029068


### Co-occurrence Counts by Disease

In [22]:
# Label-by-label co-occurrence counts: entry (a, b) = number of images with both a and b.
coocc = df.transpose().dot(df)

In [23]:
# Diagonal = images with that finding (matches df.sum(axis=0)); off-diagonal = images with both findings.
coocc

Unnamed: 0,Atelectasis,Cardiomegaly,Consolidation,Edema,Effusion,Emphysema,Fibrosis,Hernia,Infiltration,Mass,Nodule,Pleural_Thickening,Pneumonia,Pneumothorax
Atelectasis,11535,369,1222,221,3269,423,220,40,3259,727,585,495,243,772
Cardiomegaly,369,2772,169,127,1060,44,51,7,583,99,108,111,36,48
Consolidation,1222,169,4667,162,1287,103,79,4,1220,602,428,251,114,222
Edema,221,127,162,2303,592,30,9,3,979,128,131,64,330,33
Effusion,3269,1060,1287,592,13307,359,188,21,3990,1244,909,848,253,995
Emphysema,423,44,103,30,359,2516,36,4,447,212,115,151,21,746
Fibrosis,220,51,79,9,188,36,1686,8,345,115,166,176,11,80
Hernia,40,7,4,3,21,4,8,227,33,25,10,8,2,9
Infiltration,3259,583,1220,979,3990,447,345,33,19870,1151,1544,749,571,943
Mass,727,99,602,128,1244,212,115,25,1151,5746,894,448,62,424


In [None]:
#
