In [None]:
from matplotlib import pyplot as plt
# import glob
import numpy as np
import random
import os
import string
from PIL import Image 
import time
import tensorflow as tf
import pickle
import json
import os

from nets import vgg
from preprocessing import vgg_preprocessing

%matplotlib inline

# Take an array of shape (n, height, width) or (n, height, width, channels)
# and visualize each (height, width) slice in a grid of approx. sqrt(n) by sqrt(n).
def vis_square(data, padsize=1, padval=0):
    """Tiles a stack of maps into one roughly square grid and shows it.

    The values are rescaled to [0, 1] before tiling. The input is never
    modified: normalization happens on a floating-point copy, so arrays (or
    views of arrays) that the caller still needs keep their original values.

    Args:
      data: Array-like of shape (n, height, width) or
        (n, height, width, channels). Integer input is accepted.
      padsize: Number of padding pixels appended below and to the right of
        every tile.
      padval: Value used for the padding and for the empty tiles that complete
        the square grid. It is applied after rescaling.

    Returns:
      None. The grid is drawn with `plt.imshow` on the current figure.
    """
    data = np.asarray(data)
    # astype() always copies, which protects the caller's array; promoting to
    # at least float32 makes the division below valid for integer input.
    data = data.astype(np.result_type(data.dtype, np.float32))

    data -= data.min()
    peak = data.max()
    if peak > 0:
        # A constant input has zero range; skip the division instead of
        # producing NaNs from 0 / 0.
        data /= peak

    # force the number of filters to be square
    n = int(np.ceil(np.sqrt(data.shape[0])))
    padding = ((0, n ** 2 - data.shape[0]), (0, padsize), (0, padsize)) + ((0, 0),) * (data.ndim - 3)
    data = np.pad(data, padding, mode='constant', constant_values=(padval, padval))

    # tile the filters into an image
    data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])

    plt.imshow(data)

In [None]:
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the fish dataset.

The dataset scripts used to create the dataset can be found at:
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tensorflow as tf

from datasets import dataset_utils

# Short alias for the TF-Slim package used throughout this notebook.
slim = tf.contrib.slim

# File-name pattern of the dataset sources; the '%s' placeholder is replaced
# with the split name inside get_split().
_FILE_PATTERN = 'fish_%s_*.tfrecord'

# Recognized split names mapped to the sample count reported for each split.
# get_split() rejects any split name that is not a key here.
SPLITS_TO_SIZES = {'train': 11002, 'validation': 2750, 'test': 1000}

# Number of classes reported on the returned Dataset.
_NUM_CLASSES = 8

# Human-readable descriptions of the items the dataset decoder provides.
_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and 7',
    'bb_label': 'A list of boundary boxes'
}


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
    """Gets a dataset tuple with instructions for reading fish.

    Args:
      split_name: One of the keys of `SPLITS_TO_SIZES`
        ('train', 'validation' or 'test').
      dataset_dir: The base directory of the dataset sources.
      file_pattern: The file pattern to use when matching the dataset sources.
        A '%s' placeholder in the pattern is replaced by `split_name`, and the
        result is joined onto `dataset_dir`. Defaults to `_FILE_PATTERN`.
      reader: The TensorFlow reader type. Defaults to `tf.TFRecordReader`.

    Returns:
      A `slim.dataset.Dataset` whose `data_sources` is the resolved file
      pattern and whose `num_samples` comes from `SPLITS_TO_SIZES`.

    Raises:
      ValueError: if `split_name` is not a key of `SPLITS_TO_SIZES`.
    """
    if split_name not in SPLITS_TO_SIZES:
        raise ValueError('split name %s was not recognized.' % split_name)

    if not file_pattern:
        file_pattern = _FILE_PATTERN
    # Resolve the pattern for caller-supplied patterns as well as the default.
    # A pattern without a placeholder is joined as-is, so an absolute path
    # passed by the caller is still used verbatim by os.path.join.
    if '%s' in file_pattern:
        file_pattern = file_pattern % split_name
    file_pattern = os.path.join(dataset_dir, file_pattern)

    # Allowing None in the signature so that dataset_factory can use the default.
    if reader is None:
        reader = tf.TFRecordReader

    # Features read from every serialized example, with the default used when
    # a key is absent.
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
        'image/bb_label': tf.FixedLenFeature((), tf.string, default_value=b''),
    }

    # Items exposed by the decoder and the feature keys they are built from.
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
        'bb_label': slim.tfexample_decoder.Tensor('image/bb_label'),
    }

    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)

    # The label-name mapping is optional: it stays None unless dataset_utils
    # reports that the dataset directory has labels.
    labels_to_names = None
    if dataset_utils.has_labels(dataset_dir):
        labels_to_names = dataset_utils.read_label_file(dataset_dir)

    return slim.dataset.Dataset(
        data_sources=file_pattern,
        reader=reader,
        decoder=decoder,
        num_samples=SPLITS_TO_SIZES[split_name],
        items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
        num_classes=_NUM_CLASSES,
        labels_to_names=labels_to_names)



In [None]:

# Short alias for TF-Slim, re-bound here (an earlier cell binds the same name).
slim = tf.contrib.slim


def load_batch(dataset, batch_size=32, height=224, width=224, is_training=False):
    """Loads a single batch of data.

    Args:
      dataset: The dataset to load.
      batch_size: The number of images in the batch.
      height: The height of each image after preprocessing.
      width: The width of each image after preprocessing.
      is_training: Whether or not we're currently training or evaluating;
        forwarded to the VGG preprocessing.

    Returns:
      images: A batch of image samples that have been preprocessed with
        `vgg_preprocessing.preprocess_image`.
      images_raw: A batch of the same samples, only resized to
        [height, width], that can be used for visualization.
      labels: The batch of 'label' items provided by the dataset.
    """
    data_provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        common_queue_capacity=batch_size * 3,
        common_queue_min=8)
    image_raw, label = data_provider.get(['image', 'label'])

    # Preprocess image for usage by VGG.
    image = vgg_preprocessing.preprocess_image(
        image_raw, height, width, is_training=is_training)

    # Preprocess the image for display purposes: add a batch axis for the
    # resize, then drop exactly that axis again. Squeezing without an explicit
    # axis would also remove any other size-1 dimension.
    image_raw = tf.expand_dims(image_raw, 0)
    image_raw = tf.image.resize_images(image_raw, [height, width])
    image_raw = tf.squeeze(image_raw, [0])

    # Batch it up.
    images, images_raw, labels = tf.train.batch(
        [image, image_raw, label],
        batch_size=batch_size,
        num_threads=1,
        capacity=2 * batch_size)

    return images, images_raw, labels




In [None]:
# Note that this may take several minutes!!
# Fine-tune a pre-trained network (VGG-16).



# Short alias for TF-Slim, re-bound so this cell does not depend on the alias
# from earlier cells.
slim = tf.contrib.slim
# Input resolution taken from the VGG-16 model definition; used as both the
# height and the width when loading batches below.
image_size = vgg.vgg_16.default_image_size


def get_init_fn():
    """Builds the initializer that warm-starts training from VGG-16 weights.

    Every slim model variable is restored from `vgg_16.ckpt` inside the
    module-level `checkpoints_dir` (looked up when this function is called),
    except variables whose op name starts with "vgg_16/fc8".

    Returns:
      The function produced by `slim.assign_from_checkpoint_fn` for the
      selected variables.
    """
    # Scopes left out of the restore.
    skipped_prefixes = tuple(prefix.strip() for prefix in ["vgg_16/fc8"])

    variables_to_restore = [
        variable
        for variable in slim.get_model_variables()
        if not variable.op.name.startswith(skipped_prefixes)
    ]

    ## vgg
    checkpoint_file = os.path.join(checkpoints_dir, 'vgg_16.ckpt')
    return slim.assign_from_checkpoint_fn(checkpoint_file, variables_to_restore)


# Paths and batch size for the fine-tuning run.
# train_dir: log/checkpoint directory passed to slim.learning.train.
# fish_data_dir: dataset directory passed to get_split.
# checkpoints_dir: read by get_init_fn() to locate 'vgg_16.ckpt'.
train_dir = 'tmp/fish/vgg16_finetuned/'
fish_data_dir = 'tmp/fish/dataset'
checkpoints_dir = 'tmp/fish/my_checkpoints'
batch_size = 16
with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)
    
    dataset = get_split('train', fish_data_dir)
    
    # NOTE(review): load_batch is called without is_training, so its default
    # (False) is forwarded to the preprocessing while the model below is built
    # with is_training=True. Confirm this is intended.
    images, _, labels = load_batch(dataset, height=image_size, width=image_size, batch_size=batch_size)
    
    # Create the model under the default VGG arg scope.
    with slim.arg_scope(vgg.vgg_arg_scope()):
        logits, _ = vgg.vgg_16(images, num_classes=dataset.num_classes, is_training=True)
        
    # Specify the loss function (the one-hot encoding is pinned to the CPU):
    with tf.device('/cpu:0'):
        one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)
    # The return value is discarded here; presumably the loss is registered
    # with slim so that get_total_loss() picks it up -- TODO confirm.
    slim.losses.softmax_cross_entropy(logits, one_hot_labels)
    total_loss = slim.losses.get_total_loss()

    # Create some summaries to visualize the training process:
    tf.summary.scalar('losses/Total_Loss', total_loss)
  
    # Specify the optimizer and create the train op:
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
    train_op = slim.learning.create_train_op(total_loss, optimizer)
    
    # Run the training:
    # GPU error on inception-v3 because of motherboard: "failed to query event: CUDA_ERROR_LAUNCH_FAILED,Unexpected Event status: 1"

#     config = tf.ConfigProto(device_count = {'gpu': 0})
    
    # Session configuration with GPU memory growth enabled.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Only 5 steps are run here.
    final_loss = slim.learning.train(
        train_op,
        logdir=train_dir,
        init_fn=get_init_fn(),
        number_of_steps=5,
        session_config=config,)
        
print('Finished training. Last batch loss %f' % final_loss)

In [None]:
# Visualize intermediate activations of the fine-tuned VGG-16 on one
# validation batch. Requires a checkpoint in `train_dir` (written by the
# fine-tuning cell).


slim = tf.contrib.slim

image_size = 224
batch_size = 8
fish_data_dir = 'tmp/fish/dataset'
train_dir = 'tmp/fish/vgg16_finetuned/'

convnets = []

with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)

    dataset = get_split('validation', fish_data_dir)
    images, images_raw, labels = load_batch(dataset, height=image_size, width=image_size, batch_size=batch_size)

    # Create the model under the default VGG arg scope.
    with slim.arg_scope(vgg.vgg_arg_scope()):
        _convnet, _end_points = vgg.vgg_16(images, num_classes=dataset.num_classes, is_training=False)

    checkpoint_path = tf.train.latest_checkpoint(train_dir)
    if checkpoint_path is None:
        # latest_checkpoint returns None when the directory holds no
        # checkpoint; fail here with an actionable message instead of passing
        # None on to the restore function.
        raise IOError(
            'No checkpoint found in %r; run the fine-tuning cell first.' % train_dir)
    init_fn = slim.assign_from_checkpoint_fn(
        checkpoint_path,
        slim.get_variables_to_restore())

    with tf.Session() as sess:
        with slim.queues.QueueRunners(sess):
            sess.run(tf.local_variables_initializer())
            init_fn(sess)
            # Evaluate every end point and the display images in a single run
            # so that both come from the same batch.
            np_end_points, np_images_raw = sess.run([_end_points, images_raw])
            plt.rcParams['figure.figsize'] = (30.0, 15.0)
            np_convnet = np_end_points['vgg_16/conv2/conv2_2']
            for i in range(batch_size):
                image = np_images_raw[i]
                convnet = np_convnet[i]
                # Channels-last activations -> one tile per channel. The grid
                # is drawn before plt.figure() below, so it lands on its own
                # figure, separate from the input image.
                vis_square(convnet.transpose(2, 0, 1))

                plt.figure()
                plt.imshow(image.astype(np.uint8))
                plt.axis('off')
                plt.show()
                

In [None]:
# Repeat the per-example visualization for the 'vgg_16/fc7' end point.
# This cell relies on `np_end_points`, `np_images_raw` and `batch_size` left
# behind by the visualization cell above, so that cell must be run first.
# NOTE(review): assumes each fc7 activation is rank 3 so that
# transpose(2, 0, 1) is valid -- TODO confirm.
np_convnet = np_end_points['vgg_16/fc7']
for i in range(batch_size): 
    image = np_images_raw[i]
    convnet = np_convnet[i]
    # Drawn before plt.figure(), so the grid and the input image end up on
    # separate figures.
    vis_square(convnet.transpose(2, 0, 1))

    plt.figure()
    plt.imshow(image.astype(np.uint8))
    plt.axis('off')
    plt.show()

In [None]:
# List the names of the end points fetched by the visualization cell
# (requires `np_end_points` from that cell).
np_end_points.keys()