# Pneumonia Balanced Databunches

Includes support for balanced classes and a subsetted training set size.  The default is about 100 entries; use scale to grow or shrink the size.

Databunches:
- get_db_np() - returns 2 classes: NORMAL vs PNEUMONIA
- get_db_vb() - returns 2 classes: viral vs bacterial
- get_db_nvb() - returns 3 classes: normal, viral, bacterial

In [1]:
# Notebook setup: (re)load the autoreload extension so edited modules are
# re-imported before each cell runs, and render matplotlib figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import collections

In [3]:
from fastai.vision import *
from fastai.callbacks import SaveModelCallback
from fastai.metrics import error_rate
import os

In [4]:
import fastai_addons   #add plot2 extension -- learn.recorder.plot2()
from fastai_addons import interpretation_summary, plot_confusion_matrix, \
                          get_accuracy, analyze_confidence, accuracy_vs_threshold, \
                          show_incremental_accuracy,  analyze_low_confidence, \
                          plot_confusion_matrix_thresh, get_val_stats

In [5]:
from pneumonia_loaders import *

# Config

In [6]:
# Run configuration. None of these names are referenced by the cells shown
# in this notebook; they are presumably shared with sibling notebooks.

# Network architecture and the prefix used for this experiment.
# NOTE(review): `prefix` is presumably a filename prefix for saved models -- confirm.
model = models.resnet18
prefix = 'other_classifier2_'

# NOTE(review): presumably image size and batch size -- confirm against the
# loaders; the databunch summaries below report 224x224 images.
size = 500
bs = 64

In [7]:
# Root of the chest X-ray dataset, relative to the current working directory.
path = Path('data') / 'chest_xray'
# List the dataset root so the available split folders are visible.
path.ls()

[WindowsPath('data/chest_xray/data'),
 WindowsPath('data/chest_xray/test'),
 WindowsPath('data/chest_xray/train'),
 WindowsPath('data/chest_xray/val')]

# Code

# Analysis

### Get Unfiltered Counts

In [8]:
# Tally label counts for every split under each labelling / filtering
# combination. No sample_func is passed, so these are unsampled counts.
split_names = ('train', 'val', 'test')
count_variants = (
    {},                                                       # default labels, all files
    {'filter_func': filter_files},                            # default labels, filtered files
    {'label_func': get_labels},                               # get_labels labels, all files
    {'label_func': get_labels, 'filter_func': filter_files},  # get_labels labels, filtered files
)
for variant_idx, variant_kwargs in enumerate(count_variants):
    if variant_idx:
        print('')  # blank separator between groups, none after the last
    for split_name in split_names:
        print(characterize_labellist(path/split_name, **variant_kwargs))

{'NORMAL': 1341, 'PNEUMONIA': 3875, '_total': 5216}
{'NORMAL': 8, 'PNEUMONIA': 8, '_total': 16}
{'NORMAL': 234, 'PNEUMONIA': 390, '_total': 624}

{'PNEUMONIA': 3875, '_total': 3875}
{'PNEUMONIA': 8, '_total': 8}
{'PNEUMONIA': 390, '_total': 390}

{'normal': 1341, 'bacteria': 2530, 'virus': 1345, '_total': 5216}
{'normal': 8, 'bacteria': 8, '_total': 16}
{'normal': 234, 'bacteria': 242, 'virus': 148, '_total': 624}

{'bacteria': 2530, 'virus': 1345, '_total': 3875}
{'bacteria': 8, '_total': 8}
{'bacteria': 242, 'virus': 148, '_total': 390}


### Get Sampled Counts

In [9]:
# Test-split virus count (148) over test-split normal count (234), both taken
# from the unfiltered counts printed above. The rounded result (0.6325) is the
# default_prob given to sample_vb_t in the next cell.
148/234

0.6324786324786325

In [10]:
# Per-label keep probabilities used to balance the classes in each split.
# The ratios come from the unfiltered counts printed above.
# NOTE(review): default_prob presumably applies to labels missing from
# `probs` -- confirm in sample_files. Sampling looks unseeded: the same
# sampler yields different counts on repeated calls in the outputs below.

# --- NORMAL vs PNEUMONIA (folder labels) ---
sample_pneumonia = partial(
    sample_files,
    probs={'NORMAL': 1.0, 'PNEUMONIA': 0.346},  # train: 1341 / 3875 ~= 0.346
    default_prob=0,
)
sample_pneumonia_v = partial(
    sample_files,
    probs={'NORMAL': 1.0, 'PNEUMONIA': 1.0},    # val: already 8 vs 8, keep all
    default_prob=0,
)
sample_pneumonia_t = partial(
    sample_files,
    probs={'NORMAL': 1.0, 'PNEUMONIA': 0.6},    # test: 234 / 390 = 0.6
    default_prob=0,
)

# --- normal / virus / bacteria (labels from get_labels) ---
sample_vb = partial(
    sample_files,
    probs={'bacteria': 0.5316, 'virus': 1.0},   # train: 1345 / 2530 ~= 0.5316
    default_prob=1.0,
)
sample_vb_v = partial(
    sample_files,
    probs={'bacteria': 1.0, 'virus': 1.0},      # val: keep everything
    default_prob=1.0,
)
sample_vb_t = partial(
    sample_files,
    probs={'bacteria': 0.6115, 'virus': 1.0},   # test: 148 / 242 ~= 0.6115
    default_prob=0.6325,                        # test: 148 / 234 ~= 0.6325
)

In [11]:
# Re-run the label tallies with the class-balancing samplers applied.
# Each split has its own sampler; call order matches train -> val -> test.
pneumonia_samplers = (
    ('train', sample_pneumonia),
    ('val', sample_pneumonia_v),
    ('test', sample_pneumonia_t),
)
vb_samplers = (
    ('train', sample_vb),
    ('val', sample_vb_v),
    ('test', sample_vb_t),
)
sampled_variants = (
    ({}, pneumonia_samplers),
    ({'filter_func': filter_files}, pneumonia_samplers),
    ({'label_func': get_labels}, vb_samplers),
    ({'label_func': get_labels, 'filter_func': filter_files}, vb_samplers),
)
for variant_idx, (variant_kwargs, samplers) in enumerate(sampled_variants):
    if variant_idx:
        print('')  # blank separator between groups, none after the last
    for split_name, sampler in samplers:
        print(characterize_labellist(path/split_name, sample_func=sampler, **variant_kwargs))

{'NORMAL': 1341, 'PNEUMONIA': 1353, '_total': 2694}
{'NORMAL': 8, 'PNEUMONIA': 8, '_total': 16}
{'NORMAL': 234, 'PNEUMONIA': 228, '_total': 462}

{'PNEUMONIA': 1316, '_total': 1316}
{'PNEUMONIA': 8, '_total': 8}
{'PNEUMONIA': 226, '_total': 226}

{'normal': 1341, 'bacteria': 1366, 'virus': 1345, '_total': 4052}
{'normal': 8, 'bacteria': 8, '_total': 16}
{'normal': 133, 'bacteria': 146, 'virus': 148, '_total': 427}

{'bacteria': 1336, 'virus': 1345, '_total': 2681}
{'bacteria': 8, '_total': 8}
{'bacteria': 141, 'virus': 148, '_total': 289}


### Reduce size

In [12]:
# 100 / 302 ~= 0.3311, the p_sample used for the filtered test split in the
# next cell.
# NOTE(review): 302 is presumably an observed size of the sampled test set
# from an earlier run -- confirm.
100/302

0.33112582781456956

In [13]:
# Tally labels after both class balancing and an overall down-sampling step.
# Each row is (split, sampler, p_sample); None means p_sample is not passed,
# so the val split keeps whatever default characterize_labellist uses.
# NOTE(review): the p_sample values look chosen to bring each set to roughly
# 100 items -- confirm.
reduced_variants = (
    ({}, (
        ('train', sample_pneumonia, 0.0373),
        ('val', sample_pneumonia_v, None),
        ('test', sample_pneumonia_t, 0.216),
    )),
    ({'label_func': get_labels}, (
        ('train', sample_vb, 0.0249),
        ('val', sample_vb_v, None),
        ('test', sample_vb_t, 0.2273),
    )),
    ({'label_func': get_labels, 'filter_func': filter_files}, (
        ('train', sample_vb, 0.0376),
        ('val', sample_vb_v, None),
        ('test', sample_vb_t, 0.3311),
    )),
)
for variant_idx, (variant_kwargs, split_rows) in enumerate(reduced_variants):
    if variant_idx:
        print('')  # blank separator between groups, none after the last
    for split_name, sampler, keep_fraction in split_rows:
        call_kwargs = dict(variant_kwargs, sample_func=sampler)
        if keep_fraction is not None:
            call_kwargs['p_sample'] = keep_fraction
        print(characterize_labellist(path/split_name, **call_kwargs))

{'NORMAL': 48, 'PNEUMONIA': 56, '_total': 104}
{'NORMAL': 8, 'PNEUMONIA': 8, '_total': 16}
{'NORMAL': 48, 'PNEUMONIA': 33, '_total': 81}

{'normal': 39, 'bacteria': 26, 'virus': 30, '_total': 95}
{'normal': 8, 'bacteria': 8, '_total': 16}
{'normal': 37, 'bacteria': 37, 'virus': 33, '_total': 107}

{'virus': 54, 'bacteria': 61, '_total': 115}
{'bacteria': 8, '_total': 8}
{'bacteria': 56, 'virus': 45, '_total': 101}


### Verify Normal / Pneumonia

In [14]:
# Build the NORMAL vs PNEUMONIA databunch with default settings and show its
# summary. Per the repr below, the validation items come from the `test` folder.
data = get_db_np(path)
data

ImageDataBunch;

Train: LabelList (108 items)
x: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
NORMAL,NORMAL,NORMAL,NORMAL,NORMAL
Path: data\chest_xray\train;

Valid: LabelList (99 items)
x: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
NORMAL,NORMAL,NORMAL,NORMAL,NORMAL
Path: data\chest_xray\test;

Test: None

In [15]:
# Report the class balance of the training and validation label lists.
for section_title, category_list in (
    ('Training set:', data.train_ds.y),
    ('\nValidation set:', data.valid_ds.y),
):
    print(section_title)
    show_categories(category_list)

Training set:
  NORMAL    :    51     47.2%
  PNEUMONIA :    57     52.8%
  Total     :   108

Validation set:
  NORMAL    :    52     52.5%
  PNEUMONIA :    47     47.5%
  Total     :    99


### Verify Viral / Bacterial

In [16]:
# Build the virus vs bacteria databunch with default settings and show its
# summary. Per the repr below, the validation items come from the `test` folder.
data = get_db_vb(path)
data

ImageDataBunch;

Train: LabelList (103 items)
x: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
virus,virus,virus,virus,virus
Path: data\chest_xray\train;

Valid: LabelList (95 items)
x: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
bacteria,bacteria,bacteria,bacteria,bacteria
Path: data\chest_xray\test;

Test: None

In [17]:
# Report the class balance of the training and validation label lists.
for section_title, category_list in (
    ('Training set:', data.train_ds.y),
    ('\nValidation set:', data.valid_ds.y),
):
    print(section_title)
    show_categories(category_list)

Training set:
  bacteria  :    45     43.7%
  virus     :    58     56.3%
  Total     :   103

Validation set:
  bacteria  :    51     53.7%
  virus     :    44     46.3%
  Total     :    95


### Verify Normal / Viral / Bacterial

In [18]:
# Build the three-class normal / virus / bacteria databunch with default
# settings and show its summary. Per the repr below, the validation items come
# from the `test` folder.
data = get_db_nvb(path)
data

ImageDataBunch;

Train: LabelList (121 items)
x: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
normal,normal,normal,normal,normal
Path: data\chest_xray\train;

Valid: LabelList (106 items)
x: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
normal,normal,normal,normal,normal
Path: data\chest_xray\test;

Test: None

In [19]:
# Report the class balance of the training and validation label lists.
for section_title, category_list in (
    ('Training set:', data.train_ds.y),
    ('\nValidation set:', data.valid_ds.y),
):
    print(section_title)
    show_categories(category_list)

Training set:
  bacteria  :    39     32.2%
  normal    :    42     34.7%
  virus     :    40     33.1%
  Total     :   121

Validation set:
  bacteria  :    44     41.5%
  normal    :    27     25.5%
  virus     :    35     33.0%
  Total     :   106
