In [2]:
import itertools
import os

import numpy as np
import tensorflow as tf
from keras import utils
from keras.applications.vgg16 import (
    VGG16, preprocess_input, decode_predictions)
from keras.preprocessing import image
from tensorflow.contrib import slim

from datasets import flowers
from nets import resnet_v1
from nets import resnet_utils
from preprocessing import vgg_preprocessing

In [None]:
class CustomDataGen(object):
    """Endless generator of ``(images, one_hot_labels)`` batches for Keras.

    Entries of the ``data`` list passed to :meth:`generate_batch` are strings
    of the form ``"<image_path> <integer_label>"``. Each image is loaded at
    ``(dim_x, dim_y)``, converted to an array and passed through VGG16's
    ``preprocess_input``. Each label is one-hot encoded over ``num_class``
    classes.
    """

    def __init__(self, dim_x, dim_y, dim_z, num_class, batch_size):
        """Store output geometry and batching parameters.

        Args:
            dim_x: image height, passed as ``target_size[0]`` to ``load_img``.
            dim_y: image width, passed as ``target_size[1]`` to ``load_img``.
            dim_z: channel count of the output array. It must match what
                ``img_to_array`` produces (presumably 3 for RGB; confirm).
            num_class: number of classes for the one-hot labels.
            batch_size: number of samples in each yielded batch.
        """
        self.batch_size = batch_size
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.dim_z = dim_z
        self.num_class = num_class

    def randomize_ind(self, data):
        """Return a random permutation of ``range(len(data))`` (global NumPy RNG)."""
        return np.random.permutation(len(data))

    @staticmethod
    def _parse_line(line):
        """Split ``"<path> <label>"`` into ``(path, int(label))``.

        Only the LAST whitespace run separates the label, so image paths
        containing spaces and trailing newlines are handled.
        """
        path, label = line.strip().rsplit(None, 1)
        return path, int(label)

    def get_data(self, list):  # parameter name kept for keyword-compatible callers
        """Load and preprocess the images and labels for one batch.

        Args:
            list: sequence of ``"<image_path> <label>"`` strings.

        Returns:
            ``(X, y)``. ``X`` has shape ``(len(list), dim_x, dim_y, dim_z)``
            and holds the preprocessed images. ``y`` has shape
            ``(len(list), num_class)`` and holds the one-hot labels.

        Raises:
            ValueError: if a label is outside ``[0, num_class)``.
        """
        lines = list
        count = len(lines)
        # Size by the actual number of entries so a short list never returns
        # uninitialised (np.empty) rows.
        X = np.empty((count, self.dim_x, self.dim_y, self.dim_z))
        # One-hot rows are written directly: zeros everywhere, 1 at the label.
        y = np.zeros((count, self.num_class))

        for row, line in enumerate(lines):
            im_path, label = self._parse_line(line)
            if not 0 <= label < self.num_class:
                raise ValueError('label {} out of range [0, {}) in line {!r}'.format(
                    label, self.num_class, line))
            img = image.load_img(im_path, target_size=(self.dim_x, self.dim_y))
            # preprocess_input works on a batch, so add and then drop axis 0.
            x = np.expand_dims(image.img_to_array(img), axis=0)
            X[row] = preprocess_input(x)[0]
            y[row, label] = 1.0

        return X, y

    def generate_batch(self, data):
        """Yield shuffled full batches from ``data`` forever.

        Each pass reshuffles ``data`` and yields ``len(data) // batch_size``
        batches. A trailing partial batch is dropped.

        Raises:
            ValueError: on the first ``next()`` if ``data`` holds fewer than
                ``batch_size`` entries. Without this check the generator
                would spin forever without yielding.
        """
        num_batch = len(data) // self.batch_size
        if num_batch == 0:
            raise ValueError('need at least batch_size={} entries, got {}'.format(
                self.batch_size, len(data)))

        while True:
            indexes = self.randomize_ind(data)
            for batch_id in range(num_batch):
                start = batch_id * self.batch_size
                batch_lines = [data[k] for k in indexes[start:start + self.batch_size]]
                yield self.get_data(batch_lines)


In [None]:
def oversample_by_cond(images, label, oversample_factors=(1, 1, 4)):
    """Replicate a single example along axis 0 according to its class.

    Builds graph ops that concatenate ``images`` and ``label`` with themselves
    ``oversample_factors[c]`` times when ``label == c``. Under-represented
    classes therefore appear more often in the training stream.

    Args:
        images: tensor whose axis 0 is the replication axis (the caller in
            this notebook passes ``tf.expand_dims(image_decoded, 0)``).
        label: integer tensor holding exactly one class id. It is reshaped
            to a scalar predicate, so it must have a single element.
        oversample_factors: sequence of positive ints. Entry ``c`` is the
            replication factor for class ``c``. Defaults to the original
            ``(1, 1, 4)``.

    Returns:
        ``[images, label]``, each concatenated along axis 0 by the matching
        factor. They are unchanged for factor 1 or for labels that have no
        entry in ``oversample_factors``.
    """
    # Predicates test the ORIGINAL label. Once a branch fires, `label` becomes
    # a multi-element tensor that could no longer be reshaped to a scalar.
    class_label = label

    for class_id, factor in enumerate(oversample_factors):
        if factor == 1:
            continue  # concatenating a single copy is the identity
        pred = tf.reshape(tf.equal(class_label, tf.convert_to_tensor([class_id])), [])

        # Default arguments bind this iteration's tensors and factor.
        def replicate(imgs=images, lbl=label, k=factor):
            return tf.concat([imgs] * k, 0), tf.concat([lbl] * k, 0)

        def keep(imgs=images, lbl=label):
            return imgs, lbl

        # Both branches return a tuple. Graph-mode tf.cond requires the two
        # branches to have the same structure (the original mixed tuple/list).
        images, label = tf.cond(pred, replicate, keep)

    return [images, label]

images = tf.expand_dims(image_decoded, 0)                                                                                                                                                                                                                               

if train:
    # Oversample the train set in order to balance the classes
    [images, labels] = oversample_by_cond(images, label_index)

    # Distort all the concatenated version of the training image
    thread_id = itertools.cycle(range(num_preprocess_threads))
    images = tf.map_fn(lambda img: vgg_preprocessing(img, bbox, train,
        next(thread_id), summariesFlag=False), images)
    images_and_labels = [images, labels]


In [None]:

# Default input resolution declared by resnet_v1_101. The training cell below
# uses it as both the height and the width passed to load_batch.
image_size = resnet_v1.resnet_v1_101.default_image_size


def get_init_fn(checkpoint_exclude_scopes=("resnet_v1_101/logits",),
                checkpoint_path=None):
    """Returns a function run by the chief worker to warm-start the training.

    Restores every model variable from a pretrained checkpoint, except those
    whose op name starts with one of ``checkpoint_exclude_scopes``. By default
    that is the classification head, which is built with
    ``num_classes=dataset.num_classes`` and so typically does not match the
    pretrained head's shape.

    Args:
        checkpoint_exclude_scopes: variable-name prefix, or iterable of
            prefixes, to skip. The default is the logits scope created by
            ``resnet_v1.resnet_v1_101``, whose variables live under
            ``resnet_v1_101/``. The previous ``"resnet_v1/logits"`` matched
            nothing, so the pretrained logits were restored into the new head.
        checkpoint_path: checkpoint file to restore from. Defaults to
            ``os.path.join(checkpoints_dir, 'resnet_v1.ckpt')``.
            NOTE(review): TF-Slim publishes this model as
            ``resnet_v1_101.ckpt``. Confirm the file name on disk.

    Returns:
        The callable produced by ``slim.assign_from_checkpoint_fn``.
    """
    if isinstance(checkpoint_exclude_scopes, str):
        checkpoint_exclude_scopes = [checkpoint_exclude_scopes]
    # str.startswith accepts a tuple and matches any of its prefixes.
    exclusions = tuple(scope.strip() for scope in checkpoint_exclude_scopes)

    variables_to_restore = [
        var for var in slim.get_model_variables()
        if not var.op.name.startswith(exclusions)
    ]

    if checkpoint_path is None:
        checkpoint_path = os.path.join(checkpoints_dir, 'resnet_v1.ckpt')

    return slim.assign_from_checkpoint_fn(checkpoint_path, variables_to_restore)


# Log/checkpoint directory for this run. This is deliberately not the Inception
# walkthrough's '/tmp/inception_finetuned/'. slim.learning.train resumes from any
# checkpoint already in `logdir` (and then skips `init_fn`), so it would try to
# load Inception variables into this ResNet graph.
train_dir = '/tmp/resnet_v1_101_finetuned/'
NUM_TRAIN_STEPS = 2      # smoke-test length; increase for an actual fine-tune
LEARNING_RATE = 0.01

with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)

    # NOTE(review): `flowers_data_dir` and `load_batch` are not defined in any
    # visible cell. Confirm they come from an earlier cell of the notebook.
    dataset = flowers.get_split('train', flowers_data_dir)
    images, _, labels = load_batch(dataset, height=image_size, width=image_size)

    # Build ResNet-v1-101 under the ResNet arg scope. The previous
    # inception.inception_v1_arg_scope() was never imported and targets a
    # different architecture.
    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
        logits, _ = resnet_v1.resnet_v1_101(
            images, num_classes=dataset.num_classes, is_training=True)
    # NOTE(review): some TF-Slim versions return ResNet logits shaped
    # [batch, 1, 1, num_classes]. If the loss below rejects the shape, squeeze
    # axes 1 and 2 first. Confirm against the installed `nets` package.

    # Specify the loss function:
    one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)
    slim.losses.softmax_cross_entropy(logits, one_hot_labels)
    total_loss = slim.losses.get_total_loss()

    # Create some summaries to visualize the training process:
    tf.summary.scalar('losses/Total Loss', total_loss)

    # Specify the optimizer and create the train op:
    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
    train_op = slim.learning.create_train_op(total_loss, optimizer)

    # Run the training. get_init_fn() supplies the warm-start from the
    # pretrained checkpoint.
    final_loss = slim.learning.train(
        train_op,
        logdir=train_dir,
        init_fn=get_init_fn(),
        number_of_steps=NUM_TRAIN_STEPS)

print('Finished training. Last batch loss %f' % final_loss)