In [1]:
%matplotlib inline

import sys
import json
from pathlib import Path

from PIL import Image
from tqdm import tqdm_notebook as tqdm
import numpy as np
import pandas as pd

from bananas.utils import images
from bananas.dataset import DataSet, DataType, Feature

# Root path of project relative to this notebook
ROOT = Path('..')

sys.path.insert(1, str(ROOT / 'scripts'))
from datamodels import *
from utils import *

In [None]:
X, y = [], []
root = ROOT / 'quickdraw'
for file in tqdm(list(root.iterdir())[1:3]):
    name = file.name.split('.')[0]
    for img in tqdm(np.load(file)[:300]):
        y.append(name)
        img = img.reshape(28, 28)
        img = images.ndarray_to_pil(img)
        img = img.resize((300, 300))
        img = images.pil_to_ndarray(img)
        X.append(img.reshape((1, *img.shape)).astype(float) / 255)

In [3]:
ds = DataSet([Feature(X)], target=Feature(y), batch_size=32)

In [4]:
from bananas.statistics.scoring import ScoringFunction
from quick_draw_learner import QDClassifier
learner = QDClassifier(random_seed=0)

### Train model using Quick Draw data

In [5]:
import pickle

model_path = ROOT / 'models' / 'quick_draw_model.pth'
if model_path.exists():
    # If the model has already been trained, load it
    with open(model_path.absolute(), 'rb') as fh:
        learner.model_ = pickle.load(fh)

else:
    # Otherwise train it now
    learner.train(ds.input_fn, progress=True, max_score=.975, max_steps=1000)
    with open(model_path.absolute(), 'wb') as fh:
        pickle.dump(learner.model_, fh)

### Save pre-trained model to be used as starting point

In [6]:
quick_draw_learner = learner
quick_draw_model = learner.model_

### Read subject data from local file

In [7]:
df = pd.read_csv(ROOT / 'datasets' / 'subject_diagnosis.csv', index_col=0)

# Convert non-primitive fields
df['processed_path'] = df['processed_path'].apply(lambda x: Path(x))
df['image_path'] = df['image_path'].apply(lambda x: Path(x))
df['template_path'] = df['template_path'].apply(lambda x: Path(x))
df['drawing_box'] = df['drawing_box'].apply(lambda x: Box.load(x))
df['template_box'] = df['template_box'].apply(lambda x: Box.load(x))

# Remove all unnecessary columns from our dataset
feat_keys = ['processed_path']
group_columns = ['diagnosis']
df = df[group_columns + feat_keys].copy()

# Normalize all feature columns
df = df.dropna()
for col in feat_keys:
    df[col] = df[col].apply(lambda x: str(ROOT / x))

df.head()

Unnamed: 0_level_0,diagnosis,processed_path
key,Unnamed: 1_level_1,Unnamed: 2_level_1
002_1,SANO,../processed/muellePsic_002Ev1.pdf_pg-16.jpg
002_1,SANO,../processed/casaPsic_002Ev1.pdf_pg-18.jpg
002_1,SANO,../processed/minimentalPsic_002Ev1.pdf_pg-3.jpg
002_1,SANO,../processed/picoPsic_002Ev1.pdf_pg-16.jpg
002_1,SANO,../processed/cruzPsic_002Ev1.pdf_pg-17.jpg


### Train transfer learning model

In [8]:
from itertools import product, combinations
    
# Define all possible hyperparameters
batch_sizes = [32, 24]
test_splits = [.2, .25]
validation_splits = [.2, .25]
skip_cats_opts = [
    'pico',
    'muelle',
    'minimental']
skip_cats_combos = sum([list(combinations(skip_cats_opts, i))
                        for i in range(len(skip_cats_opts))], [])

# Initialize random number generator without seed to randomize hyperparamters
rnd = np.random.RandomState()

# Cross product all hyperparameters
parameter_combinations = list(product(
    batch_sizes, test_splits, validation_splits, skip_cats_combos))
rnd.shuffle(parameter_combinations)

target_label = 'SANO'
target_column = 'diagnosis'

In [None]:
from bananas.core.mixins import HighDimensionalMixin
from bananas.sampling.cross_validation import DataSplit
from bananas.statistics import scoring
from bulbasaur.learners.transfer_learning import TransferLearningModel

# Store results in a list to display them later
trial_results = []

for batch_size, test_split, validation_split, skip_cats in tqdm(parameter_combinations, leave=False):

    # Re-initialize seed every time
    random_seed = 0

    # Create a single feature containing all image data
    mask = df['processed_path'].astype(str).apply(
        lambda impath: any([('processed_path_%s' % cat) in impath for cat in skip_cats]))
    features = [Feature(
        ImageAugmenterLoader(df.loc[~mask, 'processed_path'].values),
        kind=DataType.HIGH_DIMENSIOAL,
        sample_size=10,
        random_seed=random_seed)]

    # Define target feature
    target_feature = Feature(
        (df[target_column] == target_label).values, random_seed=random_seed)

    while True:

        # Build dataset, making sure that we have a left-out validation subset
        dataset = DataSet(
            features,
            name=target_label,
            target=target_feature,
            random_seed=random_seed,
            batch_size=batch_size,
            test_split=test_split,
            validation_split=validation_split)

        # Compute test class balance to tell what minimum accuracy we should beat
        test_idx = dataset.sampler.subsamplers[DataSplit.VALIDATION].data
        test_classes = target_feature[test_idx]
        test_class_balance = sum(test_classes) / len(test_classes)

        # Rebuild dataset unless test class balance is within 5% of ground truth
        true_class_balance = sum(target_feature[:] / len(target_feature))
        if abs(test_class_balance - true_class_balance) < .05: break

        # Keep changing the seed to avoid getting stuck
        random_seed += 1

    # Instantiate learner using pre-trained model
    learner = TransferLearningModel(quick_draw_learner, freeze_base_model=False)

    # Train learner using train dataset
    learner.train(dataset.input_fn, progress=True, max_steps=200)

    # Test learner predictions using left-out dataset
    X, y = dataset[test_idx]
    y = learner.label_encoder_.transform(y)
    y_ = learner.predict_proba(X)
    score_auroc = scoring.score_auroc(y, y_)
    score_accuracy = scoring.score_accuracy(y, y_)
    score_precision = scoring.score_precision(y, y_)
    score_recall = scoring.score_recall(y, y_)

    # Store trial results
    naive_accuracy = max(test_class_balance, 1 - test_class_balance)
    trial_results.append({
        'Subset splits': (test_split, validation_split),
        'Skipped categories': ', '.join(skip_cats),
        'Batch size': batch_size,
        'Δ Naive Classifier': score_accuracy - naive_accuracy,
        'Accuracy': score_accuracy,
        'Precision': score_precision,
        'Recall': score_recall,
        'Area under ROC': score_auroc
    })

In [10]:
pd.DataFrame.from_records(trial_results) \
    .sort_values('Accuracy', ascending=False).head(10)

Unnamed: 0,Subset splits,Skipped categories,Batch size,Δ Naive Classifier,Accuracy,Precision,Recall,Area under ROC
34,"(0.2, 0.25)",,32,0.082546,0.597401,0.0,0.0,0.570604
47,"(0.2, 0.25)",muelle,32,0.051546,0.556701,0.0,0.0,0.500053
9,"(0.25, 0.25)","muelle, minimental",24,0.037037,0.539095,0.0,0.0,0.532575
42,"(0.2, 0.2)",minimental,24,0.0,0.525974,0.0,0.0,0.486175
2,"(0.2, 0.2)","pico, minimental",32,0.0,0.525974,0.0,0.0,0.498609
52,"(0.2, 0.2)","muelle, minimental",24,0.0,0.525974,0.0,0.0,0.512916
38,"(0.2, 0.2)","pico, minimental",24,0.0,0.525974,0.0,0.0,0.519998
8,"(0.2, 0.2)","pico, muelle",24,0.0,0.525974,0.0,0.0,0.506004
43,"(0.2, 0.2)",pico,24,0.0,0.525974,0.0,0.0,0.520252
16,"(0.2, 0.2)",muelle,24,0.0,0.525974,0.0,0.0,0.510549
