In [None]:
from PIL import Image
import numpy as np
import pandas as pd 
import os, cv2, re, random
import numpy as np
import pandas as pd
from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing.image import img_to_array, load_img
from keras import layers, models, optimizers
from keras import backend as K
from sklearn.model_selection import train_test_split
from PIL import Image
from keras.models import *
from keras.layers import *
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint

In [None]:
# Directory whose entries are treated as the training images.
TRAIN_DIR = '/home/Addy/MURA-v1.1/training/'

# Target image geometry used by read_image / prep_data.
ROWS = 299
COLS = 299
CHANNELS = 3

# os.listdir returns entries in arbitrary order; sorting keeps the sample order
# (and therefore the seeded train_test_split further down) stable between runs.
MURA_train = [os.path.join(TRAIN_DIR, name) for name in sorted(os.listdir(TRAIN_DIR))]



def read_image(file_path):
    """Load an image file with OpenCV in colour mode and resize it.

    Parameters
    ----------
    file_path : str
        Path of the image file to read.

    Returns
    -------
    The array produced by ``cv2.resize`` with ``dsize=(ROWS, COLS)`` and
    bicubic interpolation.

    Raises
    ------
    IOError
        If ``cv2.imread`` returns ``None`` for ``file_path`` (file missing,
        unreadable, or not an image).
    """
    img = cv2.imread(file_path, cv2.IMREAD_COLOR) #cv2.IMREAD_GRAYSCALE
    if img is None:
        # Fail here with the offending path instead of letting cv2.resize
        # raise an opaque assertion error on a None input.
        raise IOError('Could not read image: {}'.format(file_path))
    # The (ROWS, COLS) order is kept as-is: prep_data stores the transposed
    # result into a (CHANNELS, ROWS, COLS) slot and depends on it.
    return cv2.resize(img, (ROWS, COLS), interpolation=cv2.INTER_CUBIC)


def prep_data(images):
    """Read every path in ``images`` and pack the results into one uint8 array.

    Each image is loaded through ``read_image`` and stored fully transposed
    (``.T`` reverses all axes), giving an array of shape
    ``(len(images), CHANNELS, ROWS, COLS)``. A progress line is printed for
    every 250th image, starting with the first.
    """
    total = len(images)
    # Uninitialised buffer; every slot is overwritten in the loop below.
    stacked = np.empty((total, CHANNELS, ROWS, COLS), dtype=np.uint8)

    for index, path in enumerate(images):
        stacked[index] = read_image(path).T
        if index % 250 == 0:
            print('Processed {} of {}'.format(index, total))

    return stacked

# Load every path in MURA_train into a single (N, CHANNELS, ROWS, COLS) uint8 array.
# NOTE(review): `train` is not referenced by any later cell visible in this
# notebook (the model is fed from prepare_data's X / Y instead) -- confirm
# whether this load is still needed.
train = prep_data(MURA_train)

print("Train shape: {}".format(train.shape))

In [None]:
def atoi(text):
    """Return ``int(text)`` when ``text.isdigit()`` is true, otherwise ``text`` unchanged."""
    if text.isdigit():
        return int(text)
    return text

def natural_keys(text):
    """Split ``text`` into alternating non-digit / digit chunks, converting digit runs to ints.

    Example: ``natural_keys('img12.png')`` -> ``['img', 12, '.png']``.
    Presumably meant as a ``key=`` function for natural ("human") sorting of
    file names; no caller is visible in this notebook.
    """
    # Raw string: '\d' inside a plain literal is an invalid escape sequence,
    # which Python reports with a warning and plans to reject in the future.
    return [atoi(chunk) for chunk in re.split(r'(\d+)', text)]

In [None]:
def prepare_data(list_of_images, target_size=(299, 299)):
    """
    Load, resize and label a collection of image files.

    Parameters
    ----------
    list_of_images : iterable of str
        Paths of the images to load. A path is labelled 1 when the substring
        'positive' occurs anywhere in it, otherwise 0.
    target_size : tuple of int, optional
        ``dsize`` handed to ``cv2.resize``. Defaults to (299, 299), the value
        that used to be hard-coded.

    Returns
    -------
    x : list
        The resized images, in input order.
    y : list of int
        The 0/1 labels, aligned element-for-element with ``x``.

    Raises
    ------
    IOError
        If ``cv2.imread`` returns ``None`` for one of the paths.
    """
    x = [] # images as arrays
    y = [] # labels

    # Single pass: image and label are appended together, so x and y stay
    # aligned even when ``list_of_images`` is a one-shot iterator.
    for image_path in list_of_images:
        image = cv2.imread(image_path)
        if image is None:
            # Name the offending file instead of failing inside cv2.resize.
            raise IOError('Could not read image: {}'.format(image_path))
        x.append(cv2.resize(image, tuple(target_size)))
        y.append(1 if 'positive' in image_path else 0)

    return x, y

In [None]:
# Load the resized images (X) and their 0/1 labels (Y) for every path in MURA_train.
X, Y = prepare_data(MURA_train)
# Print the image data format the Keras backend is configured with.
print(K.image_data_format())

In [None]:
# Mini-batch size shared by the data generators and the steps-per-epoch arithmetic.
batch_size = 16

# Hold out 20% of the samples for validation; random_state pins the shuffle
# for a given input order.
X_train, X_val, Y_train, Y_val = train_test_split(
    X, Y,
    test_size=0.2,
    random_state=1,
)
nb_train_samples, nb_validation_samples = len(X_train), len(X_val)

In [None]:
# Training images: pixel values multiplied by 1/255, plus random shear / zoom /
# horizontal-flip augmentation drawn afresh for every batch.
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

# Validation images: rescale only. Random augmentation here would perturb the
# evaluation data differently on every epoch, making val_loss /
# val_binary_accuracy (and the "best" checkpoint chosen from them) noisy.
val_datagen = ImageDataGenerator(rescale=1. / 255)

In [None]:
# Each generator gets its own partition from train_test_split. Building both
# from the full X / Y would score "validation" on images the model is also
# trained on and would leave the split unused.
train_generator = train_datagen.flow(np.array(X_train), np.array(Y_train), batch_size=batch_size)
# shuffle=False: evaluation does not benefit from shuffling, and a fixed order
# keeps the batches covered by validation_steps the same on every epoch.
validation_generator = val_datagen.flow(np.array(X_val), np.array(Y_val), batch_size=batch_size, shuffle=False)

In [None]:
from keras.applications.xception import Xception,preprocess_input,decode_predictions
# Xception backbone with ImageNet weights and without its classification head
# (include_top=False); no fixed input shape and no pooling are requested.
base_model = Xception(
    include_top=False,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
)
# Mark the backbone as non-trainable before any model that contains it is compiled.
base_model.trainable = False

In [None]:
# Output shape of the backbone at node 0, queried once and reused below.
base_output_shape = base_model.get_output_shape_at(0)
# Size of the last axis of the backbone output (the feature depth).
pt_depth = base_output_shape[-1]
# Placeholder input for the attention head: backbone output shape minus its leading axis.
pt_features = Input(base_output_shape[1:], name = 'feature_input')

In [None]:
# Batch-normalise the backbone features; the result feeds both the attention
# branch and the masking step below.
bn_features = BatchNormalization()(pt_features)

# Attention branch: a stack of 1x1 convolutions narrows the features to 16
# channels, a stride-1 2x2 average pool smooths them, and a final 1x1
# convolution produces one sigmoid-activated channel (a per-location weight).
attn_layer = Conv2D(128, kernel_size = (1,1), padding = 'same', activation = 'elu')(bn_features)
attn_layer = Conv2D(32, kernel_size = (1,1), padding = 'same', activation = 'elu')(attn_layer)
attn_layer = Conv2D(16, kernel_size = (1,1), padding = 'same', activation = 'elu')(attn_layer)
attn_layer = AvgPool2D((2,2), strides = (1,1), padding = 'same')(attn_layer)
attn_layer = Conv2D(1, kernel_size = (1,1), padding = 'valid', activation = 'sigmoid')(attn_layer)

# Fan the single-channel mask out to pt_depth channels with a frozen, bias-free
# 1x1 convolution whose kernel is all ones, so it can be multiplied with the features.
up_c2_w = np.ones((1, 1, 1, pt_depth))
up_c2 = Conv2D(pt_depth, kernel_size = (1,1), padding = 'same',
               activation = 'linear', use_bias = False, weights = [up_c2_w])
up_c2.trainable = False
attn_layer = up_c2(attn_layer)

# Weight the features by the mask and pool both globally. Dividing the pooled
# masked features by the pooled mask gives an attention-weighted average.
# This pooling / rescale / dropout chain is built once: only one such chain
# ever fed the classifier, so identically named copies are not created.
mask_features = multiply([attn_layer, bn_features])
gap_features = GlobalAveragePooling2D()(mask_features)
gap_mask = GlobalAveragePooling2D()(attn_layer)
# to account for missing values from the attention model
gap = Lambda(lambda x: x[0]/x[1], name = 'RescaleGAP')([gap_features, gap_mask])
gap_dr = Dropout(0.5)(gap)

# Classifier: one hidden ELU layer with dropout, then a single sigmoid unit.
dr_steps = Dropout(0.5)(Dense(128, activation = 'elu')(gap_dr))
out_layer = Dense(1, activation = 'sigmoid')(dr_steps)

attn_model = Model(inputs = [pt_features], outputs = [out_layer], name = 'attention_model')

attn_model.compile(optimizer = Adam(lr = 1e-3, beta_1=0.9, beta_2 = 0.999, epsilon = None, decay = 0.0, amsgrad = True),
              loss = 'binary_crossentropy', metrics = ['binary_accuracy'])

attn_model.summary()

In [None]:
# End-to-end network: the Xception backbone followed by the attention head,
# compiled with the same optimizer settings, loss and metric as the head itself.
model = Sequential([base_model, attn_model], name = 'combined_model')
model.compile(
    loss = 'binary_crossentropy',
    metrics = ['binary_accuracy'],
    optimizer = Adam(
        lr = 1e-3,
        beta_1 = 0.9,
        beta_2 = 0.999,
        epsilon = None,
        decay = 0.0,
        amsgrad = True,
    ),
)
model.summary()

In [None]:
# Write the weights to disk whenever validation accuracy reaches a new maximum.
checkpoint = ModelCheckpoint(
    'Xception-Attention-best.hdf5',
    monitor='val_binary_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    period=1,
    verbose=1,
)

callbacks_list = [checkpoint]

# Train for 150 epochs; step counts are the number of whole batches in each split.
history = model.fit_generator(
    train_generator,
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=150,
    validation_data=validation_generator,
    validation_steps=nb_validation_samples // batch_size,
    callbacks=callbacks_list,
)

In [None]:
# Save the model as it stands after the last epoch. The best-validation weights
# are written separately to 'Xception-Attention-best.hdf5' by the ModelCheckpoint callback.
model.save('Xception-Attention.h5')

In [None]:
from keras.callbacks import *
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (12,6)

# Per-epoch metrics recorded during training. loss / val_loss are extracted
# here but not drawn in this figure.
acc = history.history['binary_accuracy']
val_acc = history.history['val_binary_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(acc) + 1)  # 1-based epoch numbers for the x axis

plt.title('Training and validation accuracy')
plt.plot(epochs, acc, 'red', label='Training acc')
plt.plot(epochs, val_acc, 'blue', label='Validation acc')
plt.xlabel('Epoch')
plt.ylabel('Binary accuracy')
# One legend call is enough; a second one only rebuilds the same legend.
plt.legend()

plt.show()