In [1]:
%matplotlib inline

import sys
import json
from pathlib import Path

from PIL import Image
from tqdm import tqdm_notebook as tqdm
import numpy as np
import pandas as pd
from torchvision import models as torchvision_models

from bananas.training.criteria import HaltCriteria
from bananas.dataset import DataSet, DataType, Feature

# Root path of project relative to this notebook
ROOT = Path('..')

# Make the project's `scripts` directory importable so that the two helper
# modules below can be resolved. Position 1 keeps the first entry of sys.path
# ahead of it.
sys.path.insert(1, str(ROOT / 'scripts'))
# NOTE(review): `Box` and `ImageAugmenterLoader` are used further down without an
# explicit import, so they presumably come from one of these star imports --
# TODO confirm which module and import them by name.
from datamodels import *
from utils import *

### Load pre-trained models to be used as starting point

In [2]:
# Instantiate both torchvision architectures with their pre-trained weights.
# They are used further down as the base networks for transfer learning.
# The builders are called in the same order as before: ResNet-34, then GoogLeNet.
resnet34_model, googlenet_model = (
    build(pretrained=True)
    for build in (torchvision_models.resnet34, torchvision_models.googlenet))

### Read subject data from local file

In [3]:
# Columns kept from the subject data: the grouping/target column plus the
# feature column holding the path of each processed image
feat_keys = ['processed_path']
group_columns = ['diagnosis']

# Read the subject data and immediately discard every column that is not needed.
# The remaining CSV columns (image/template paths and boxes) used to be parsed
# and then thrown away by this very selection, so they are no longer converted.
df = pd.read_csv(ROOT / 'datasets' / 'subject_diagnosis.csv', index_col=0)
df = df[group_columns + feat_keys].copy()

# Drop incomplete rows *before* touching the paths: joining a missing value
# (NaN) onto ROOT raises a TypeError, so such rows must be filtered out first
df = df.dropna()

# Normalize all feature columns into path strings relative to this notebook
for col in feat_keys:
    df[col] = df[col].apply(lambda x: str(ROOT / x))

df.head()

Unnamed: 0_level_0,diagnosis,processed_path
key,Unnamed: 1_level_1,Unnamed: 2_level_1
002_1,SANO,../processed/muellePsic_002Ev1.pdf_pg-16.jpg
002_1,SANO,../processed/casaPsic_002Ev1.pdf_pg-18.jpg
002_1,SANO,../processed/minimentalPsic_002Ev1.pdf_pg-3.jpg
002_1,SANO,../processed/picoPsic_002Ev1.pdf_pg-16.jpg
002_1,SANO,../processed/cruzPsic_002Ev1.pdf_pg-17.jpg


### Train transfer learning model

In [4]:
from itertools import chain, combinations, product

# Define all possible hyperparameters
models = [resnet34_model, googlenet_model]
batch_sizes = [24, 32]
test_splits = [.2, .25]
validation_splits = [.2, .25]

# Image categories that can optionally be excluded from the data
skip_cats_opts = [
    'pico',
    'muelle',
    'minimental']

# Searching over skipped categories is currently switched off: every trial uses
# all categories. This used to be done by computing the combinations and then
# silently overwriting them with `[[]]`; the flag makes that choice explicit.
SEARCH_SKIP_CATS = False
if SEARCH_SKIP_CATS:
    # Every subset of the optional categories of size 0 .. n-1, i.e. all
    # subsets except the one that skips every listed category
    skip_cats_combos = list(chain.from_iterable(
        combinations(skip_cats_opts, i) for i in range(len(skip_cats_opts))))
else:
    skip_cats_combos = [[]]

# Initialize random number generator without seed to randomize the order in
# which the hyperparameter combinations are tried
rnd = np.random.RandomState()

# Cross product all hyperparameters
parameter_combinations = list(product(
    models, batch_sizes, test_splits, validation_splits, skip_cats_combos))
rnd.shuffle(parameter_combinations)

# The classifier is binary: rows whose `target_column` equals `target_label`
# form the positive class
target_label = 'SANO'
target_column = 'diagnosis'

In [None]:
import copy

from bananas.core.mixins import HighDimensionalMixin
from bananas.sampling.cross_validation import DataSplit
from bananas.statistics import scoring
from bananas.statistics.scoring import ScoringFunction
from coconuts.learners.transfer_learning import TransferLearningModel, BaseNNClassifier

# Upper bound on how many seeds are tried while looking for a balanced split,
# so that an unlucky configuration fails loudly instead of looping forever
MAX_SPLIT_ATTEMPTS = 1000

# Store results in a list to display them later
trial_results = []

for model, batch_size, test_split, validation_split, skip_cats in tqdm(parameter_combinations, leave=False):

    # Re-initialize seed every time; the features always use this initial seed
    random_seed = 0

    # Flag every image whose path matches one of the skipped categories.
    # NOTE(review): the sample paths displayed above look like
    # '../processed/muellePsic_...', which would never contain
    # 'processed_path_<category>' -- confirm this pattern before enabling
    # skipped categories.
    mask = df['processed_path'].astype(str).apply(
        lambda impath: any(('processed_path_%s' % cat) in impath for cat in skip_cats))
    keep = ~mask

    # Create a single feature containing all image data
    image_loader = ImageAugmenterLoader(
        df.loc[keep, 'processed_path'].values,
        resize=(3, 224, 224),
        normalize=True,
        convert='RGB')
    features = [Feature(
        image_loader,
        kind=DataType.HIGH_DIMENSIOAL,
        sample_size=10,
        random_seed=random_seed)]

    # Define target feature. It must be restricted to the same rows as the
    # images, otherwise inputs and labels go out of step as soon as any
    # category is skipped.
    target_feature = Feature(
        (df.loc[keep, target_column] == target_label).values, random_seed=random_seed)

    # Keep changing the seed until the held-out subset is representative
    for random_seed in range(MAX_SPLIT_ATTEMPTS):

        # Build dataset, making sure that we have a left-out validation subset
        dataset = DataSet(
            features,
            name=target_label,
            target=target_feature,
            random_seed=random_seed,
            batch_size=batch_size,
            test_split=test_split,
            validation_split=validation_split)

        # Compute the class balance of the held-out subset (the dataset's
        # VALIDATION split) to tell what minimum accuracy we should beat
        test_idx = dataset.sampler.subsamplers[DataSplit.VALIDATION].data
        test_classes = target_feature[test_idx]
        test_class_balance = sum(test_classes) / len(test_classes)

        # Accept the split once its class balance is within 5% of ground truth
        true_class_balance = sum(target_feature[:]) / len(target_feature)
        if abs(test_class_balance - true_class_balance) < .05:
            break
    else:
        raise RuntimeError(
            'No split within 5%% of the true class balance after %d seeds'
            % MAX_SPLIT_ATTEMPTS)

    # Instantiate learner using a private copy of the pre-trained model, so
    # that whatever the learner does to the network (not visible from this
    # notebook) cannot leak into later trials that reuse the same base model
    learner = TransferLearningModel(
        copy.deepcopy(model),
        freeze_base_model=True,
        scoring_function=ScoringFunction.ACCURACY) \
        .apply_mixin(BaseNNClassifier, HighDimensionalMixin)

    # Train learner using train dataset
    learner.train(dataset.input_fn, progress=True, max_steps=200)

    # Test learner predictions using left-out dataset
    # We have to do it one datapoint at a time instead of in batch to prevent overflow
    y_true, y_prob = [], []
    for i in tqdm(test_idx, leave=False):
        X, y = dataset[i:i+1]
        y = learner.label_encoder_.transform(y)
        y_true.append(y[0])
        y_prob.append(learner.predict_proba(X)[0])
    score_auroc = scoring.score_auroc(y_true, y_prob)
    score_accuracy = scoring.score_accuracy(y_true, y_prob)
    score_precision = scoring.score_precision(y_true, y_prob)
    score_recall = scoring.score_recall(y_true, y_prob)

    # Store trial results. The naive baseline always predicts the majority
    # class of the held-out subset.
    naive_accuracy = max(test_class_balance, 1 - test_class_balance)
    trial_results.append({
        'Model': model.__class__.__name__,
        'Subset splits': (test_split, validation_split),
        'Skipped categories': ', '.join(skip_cats),
        'Batch size': batch_size,
        'Δ Naive Classifier': score_accuracy - naive_accuracy,
        'Accuracy': score_accuracy,
        'Precision': score_precision,
        'Recall': score_recall,
        'Area under ROC': score_auroc,
    })

In [6]:
# One row per trial, indexed by the name of the base model
results_by_model = pd.DataFrame.from_records(trial_results).set_index('Model')


def top_by_accuracy(group, limit=10):
    """Return the `limit` most accurate trials of one model, best first."""
    return group.sort_values('Accuracy', ascending=False).head(limit)


# Show the best trials of every model
results_by_model.groupby(level=0, group_keys=False).apply(top_by_accuracy)

Unnamed: 0_level_0,Subset splits,Skipped categories,Batch size,Δ Naive Classifier,Accuracy,Precision,Recall,Area under ROC
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
GoogLeNet,"(0.2, 0.2)",,24,0.064935,0.590909,0.4,0.098765,0.579753
GoogLeNet,"(0.2, 0.2)",,32,-0.038961,0.487013,0.253521,0.222222,0.488919
ResNet,"(0.2, 0.2)",,24,0.019481,0.545455,0.607056,0.504059,0.469276
ResNet,"(0.2, 0.2)",,32,0.012987,0.538961,0.504755,0.255623,0.50833
