In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os, sys
from random import sample

In [None]:
# This file contains all the main external libs we'll use
from fastai.imports import *

In [None]:
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *

In [None]:
from sklearn.metrics import confusion_matrix

### Data

In [None]:
PATH = 'data/cell_images/'


assert os.path.exists(PATH)
!ls -l {PATH}

In [None]:
TEST_PATH = f'{PATH}test/'
TRAIN_PATH = f'{PATH}train/'
VALID_PATH = f'{PATH}valid/'

CLASS_DIR_NAMES = ['Parasitized', 'Uninfected']

In [None]:
# save current dir
notebook_dir = os.getcwd()
notebook_dir

In [None]:
os.chdir(PATH)
!pwd

#### Create the test and validation sets if they don't exist yet

In [None]:
def is_testdir(path, test_name='test'):
    """Return True if `path` names the test-set directory.

    Compares the last path component with `test_name` instead of searching for a
    'test/' substring, which also matched unrelated paths such as '.../latest/...'
    or any path with a parent folder called 'test'.

    Args:
        path: directory path, with or without a trailing separator.
        test_name: name of the test-set directory (defaults to 'test').
    """
    return os.path.basename(os.path.normpath(path)) == test_name

In [None]:
def create_dataset(path, dataset_perc, seed=None):
    """Move a random `dataset_perc` percent of each class's images into `path`.

    Does nothing if `path` is already a directory, so re-running the cell is safe
    (note: a run that crashed after creating `path` is also skipped on re-run).
    Images are taken from `PATH/<class>/` for every class in CLASS_DIR_NAMES, so
    the result does not depend on the current working directory. For the test
    directory (see `is_testdir`) files are placed directly in `path`, without
    class folders; for any other directory they go into `path/<class>/`.

    Args:
        path: target directory; its parent directory must already exist.
        dataset_perc: percentage (0-100) of each class's current files to move.
        seed: optional seed for the random selection, for a reproducible split.
            With None, the module-level `random` state is used.

    Raises:
        FileExistsError: if a file with the same name already exists in the
            target (e.g. two classes share a file name in the flat test dir).
    """
    import random

    if os.path.isdir(path):
        return
    pick = sample if seed is None else random.Random(seed).sample
    flat = is_testdir(path)
    os.mkdir(path)

    for class_dir in CLASS_DIR_NAMES:
        src_dir = os.path.join(PATH, class_dir)
        thumbs = os.path.join(src_dir, 'Thumbs.db')
        if os.path.exists(thumbs):
            os.remove(thumbs)  # Windows thumbnail cache, not an image
        # Only regular files; sorted so that a given seed always picks the same files.
        files = sorted(name for name in os.listdir(src_dir)
                       if os.path.isfile(os.path.join(src_dir, name)))
        n_files = round(len(files) * (dataset_perc / 100))

        dst_dir = path if flat else os.path.join(path, class_dir)
        if not flat:
            os.mkdir(dst_dir)  # one sub-directory per class for labelled sets
        for file in pick(files, n_files):
            dst = os.path.join(dst_dir, file)
            if os.path.exists(dst):
                # os.rename would silently overwrite on POSIX.
                raise FileExistsError(dst)
            os.rename(os.path.join(src_dir, file), dst)
                

In [None]:
create_dataset(TEST_PATH, 5)

In [None]:
create_dataset(VALID_PATH, 10)

#### Move the remaining images into the training set

In [None]:
if not os.path.isdir(TRAIN_PATH):
    os.mkdir(TRAIN_PATH)
    # Whatever was not moved into test/ or valid/ becomes the training set:
    # move each class folder as a whole under train/ (works for any number of classes).
    for class_dir in CLASS_DIR_NAMES:
        os.rename(f'{PATH}{class_dir}/', f'{TRAIN_PATH}{class_dir}/')

## Check images

In [None]:
!pwd

In [None]:
os.chdir(notebook_dir)

In [None]:
!pwd

In [None]:
def display_file(dir_name, subset='valid'):
    """Show the first image (in sorted order) of one class and print its shape.

    Args:
        dir_name: class sub-directory name, e.g. 'Parasitized' or 'Uninfected'.
        subset: dataset split to look in ('valid' by default).

    Raises:
        FileNotFoundError: if the class directory is missing or holds no files.
    """
    class_dir = os.path.join(PATH, subset, dir_name)
    # Pure-Python listing instead of `!ls | head`: no shell quoting issues with
    # the interpolated path, and a clear error instead of an IndexError.
    files = sorted(name for name in os.listdir(class_dir)
                   if os.path.isfile(os.path.join(class_dir, name)))
    if not files:
        raise FileNotFoundError(f'no images in {class_dir}')
    img = plt.imread(os.path.join(class_dir, files[0]))
    plt.imshow(img)
    plt.title(dir_name)
    print(img.shape)

In [None]:
display_file('Parasitized')

In [None]:
display_file('Uninfected')

## Data augmentation

In [None]:
bs = 64
arch = resnet34

In [None]:
tfms = tfms_from_model(arch, bs, aug_tfms=transforms_top_down)

In [None]:
data = ImageClassifierData.from_paths(PATH, tfms=tfms, bs=bs, test_name='test')

## Train

In [None]:
learn = ConvLearner.pretrained(arch, data, precompute=False)

In [None]:
learn.lr_find()

In [None]:
learn.sched.plot_lr()

In [None]:
learn.sched.plot()

In [None]:
lr = 0.09 # pre_compute = True
lr = 0.008 # pre_compute = False
lr = 0.001 # tfms = side on / top down

In [None]:
learn.fit(lr, 3, cycle_len=1)

#### Fine-tuning

In [None]:
lrs = [lr/9, lr/3, lr]

In [None]:
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
learn.fit(lrs, 3, cycle_len=1, cycle_mult=2)

In [None]:
learn.sched.plot_lr()

In [None]:
learn.sched.plot_loss()

In [None]:
learn.save(f'{bs}_all')
learn.load(f'{bs}_all')

#### Test time data augmentation (TTA)

##### Validation set

In [None]:
log_preds, y = learn.TTA() 

In [None]:
probs = np.mean(np.exp(log_preds), 0)

In [None]:
accuracy_np(probs, y)

## Analyze results

### Confusion matrix

In [None]:
predications = np.argmax(probs, axis=1)

In [None]:
cm = confusion_matrix(y, predications)

In [None]:
plot_confusion_matrix(cm, data.classes)