In [None]:
import numpy as np
import tensorflow as tf

import scipy.io as sio
from os import path
from PIL import Image
from PIL import ImageOps
import matplotlib.pyplot as plt

In [None]:
# Each zebrafish sample is expanded into 3 variants: the original, its reflection about the y-axis
# ('mirror'), and its reflection about the x-axis ('flip').
# Each of these 3 variants is then rotated to generate the desired number of images per sample.

In [None]:
# Names of sample files: sample_names-xxxxx
sample_names = 'zebrafish'
record_dir = path.join('..', 'data', '2d_data', 'tf_records', 'zebrafish_and_beads', 'ground_truths')

# Target number of augmented images per zebrafish sample. Each sample yields 3 variants
# (original, 'flip', 'mirror') at every rotation angle, so this must be a positive multiple of 3.
num_images_per = 99
# Output image size (y, x) produced by the augmentation.
obj_dims = (648, 486)

# Fail early: fewer than 3 would give zero rotations (division by zero later on).
assert num_images_per >= 3 and num_images_per % 3 == 0, \
    'num_images_per must be a positive multiple of 3'

In [None]:
# MATLAB .mat file holding the zebrafish samples (read with scipy.io.loadmat under the 'fish' key).
zebrafish_mat_file = path.join(
    '..', 'data', '2d_data', 'real_data',
    'fish_double.mat',
)

In [None]:
# loadmat already returns numpy arrays, so no extra copy is needed.
zfs = sio.loadmat(zebrafish_mat_file)['fish']
if zfs.ndim != 3:
    raise ValueError("expected a 3-D 'fish' array, got shape {}".format(zfs.shape))
zfs = zfs.transpose(2, 0, 1)    # (y, x, n) -> (n, y, x): one 2-D sample per leading index

In [None]:
fig, ax = plt.subplots()
ax.imshow(zfs[0])
ax.set_title('zebrafish sample 0, shape {}'.format(zfs[0].shape))
plt.show()

In [None]:
def calc_pad_widths(im, target_dims):
    """
    Compute zero-padding widths that let ``im`` be rotated without being cut off.

    The image is padded to a square whose side is the larger of ``max(target_dims)``
    and ``int(max(im.shape) * sqrt(2))``. The latter is the longest side times sqrt(2),
    which bounds the image diagonal from above (up to integer truncation).
    When the padding along an axis is odd, the extra pixel goes before the image
    (top / left).

    Arguments:
        - im: 2-D numpy array (y, x).
        - target_dims: target image size (y, x).

    Returns:
        ((top, bottom), (left, right)) pad widths, usable as ``np.pad``'s ``pad_width``.
    """
    side = max(np.max(target_dims),
               int(np.max(np.shape(im)) * np.sqrt(2)))

    def split(total):
        # Larger half before the image, smaller half after it.
        return int(np.ceil(total / 2)), int(np.floor(total / 2))

    height, width = im.shape
    return split(side - height), split(side - width)

def crop(im, target_dims):
    """
    Center-crop a 2-D numpy array to ``target_dims`` (y, x) and return the crop.

    When the excess along an axis is odd, the extra row/column is removed from the
    top/left. This matches how ``calc_pad_widths`` places the extra padding, so
    cropping a padded image back to its original shape recovers it exactly.

    Raises:
        ValueError: if ``target_dims`` is larger than ``im`` along either axis
            (a negative offset would otherwise silently select the wrong region).
    """
    y1, x1 = im.shape
    y2, x2 = target_dims
    y_diff, x_diff = y1 - y2, x1 - x2
    if y_diff < 0 or x_diff < 0:
        raise ValueError('target_dims {} exceed image shape {}'.format(
            tuple(target_dims), im.shape))
    c_y1, c_y2 = int(np.ceil(y_diff / 2)), y1 - int(np.floor(y_diff / 2))
    c_x1, c_x2 = int(np.ceil(x_diff / 2)), x1 - int(np.floor(x_diff / 2))

    return im[c_y1:c_y2, c_x1:c_x2]

def augment(im, target_dims, num_degrees, transform=None):
    """
    Pad, optionally reflect, rotate, and center-crop a 2-D numpy array.

    The image is first zero-padded (via calc_pad_widths) so the rotation is not
    clipped. It is then converted to a PIL image, reflected according to
    ``transform``, rotated, converted back to numpy, and center-cropped to
    ``target_dims``.

    Arguments:
        - im: 2-D numpy array (y, x) to be transformed.
        - target_dims: target image size (y, x).
        - num_degrees: rotation angle in degrees (PIL rotates counter-clockwise).
        - transform: reflection applied before rotation:
            - None: no reflection
            - 'flip': reflects about the x-axis (top <-> bottom)
            - 'mirror': reflects about the y-axis (left <-> right)

    Returns:
        numpy array of shape target_dims.

    Raises:
        ValueError: if ``transform`` is not one of the values listed above.
    """
    reflections = {None: None, 'flip': ImageOps.flip, 'mirror': ImageOps.mirror}
    if transform not in reflections:
        raise ValueError(
            "transform must be None, 'flip' or 'mirror', got {!r}".format(transform))

    padded = np.pad(im, pad_width=calc_pad_widths(im, target_dims))
    pil_im = Image.fromarray(padded)
    reflect = reflections[transform]
    if reflect is not None:
        pil_im = reflect(pil_im)
    # Image.rotate defaults: nearest-neighbour resampling, rotation about the image
    # center, output kept at the padded size (expand=False).
    pil_im = pil_im.rotate(num_degrees)

    return crop(np.asarray(pil_im), target_dims)

def _create_example(plane):
    """
    Wrap a numpy array as a tf.train.Example.

    The array is flattened with ``ravel`` (C order) into one FloatList stored under
    the 'plane' key. Its shape is not recorded, so readers must already know it
    (presumably obj_dims; confirm against the reader).
    """
    flat_values = tf.train.FloatList(value=plane.ravel())
    features = tf.train.Features(
        feature={'plane': tf.train.Feature(float_list=flat_values)})
    return tf.train.Example(features=features)

In [None]:
# Every sample yields num_rotations angles x 3 variants (original, 'flip', 'mirror').
# Sample i occupies record indices [num_images_per * i, num_images_per * (i + 1)), and
# within a sample the index is 3 * rotation_index + variant_index.
num_rotations = num_images_per // 3
degree_increment = 360 / num_rotations
transforms = (None, 'flip', 'mirror')

# TFRecordWriter does not create missing parent directories.
tf.io.gfile.makedirs(record_dir)

for i, zf in enumerate(zfs):
    for j in range(num_rotations):
        num_degrees = degree_increment * j
        for k, transform in enumerate(transforms):
            zf_aug = augment(zf, obj_dims, num_degrees, transform=transform)
            record_index = num_images_per * i + len(transforms) * j + k
            record_file = path.join(record_dir, sample_names + '-%.5d' % record_index)
            with tf.io.TFRecordWriter(record_file) as writer:
                writer.write(_create_example(zf_aug).SerializeToString())
        

In [None]:
# Show the three unrotated variants of the first sample side by side.
variant_transforms = (None, 'flip', 'mirror')
fig, axes = plt.subplots(1, len(variant_transforms), figsize=(30, 10))
for ax, transform in zip(axes, variant_transforms):
    ax.imshow(augment(zfs[0], obj_dims, 0, transform=transform))
    ax.set_title('transform: {}'.format(transform))
plt.show()

In [None]:
# Sanity check: padding then cropping back to the original shape must be lossless,
# and cropping the padded image to the output size must give exactly that size.
dims = obj_dims    # output size (y, x) used for this check
original = zfs[0]
padded = np.pad(original, pad_width=calc_pad_widths(original, dims))
plt.imshow(padded)
plt.title('padded: {}'.format(padded.shape))
plt.show()

cropped = crop(padded, dims)
plt.imshow(cropped)
plt.title('cropped: {}'.format(cropped.shape))
plt.show()

assert cropped.shape == tuple(dims)
assert np.array_equal(crop(padded, original.shape), original)


In [None]:
# Rotation sweep of the first sample's 'flip' variant in 36-degree steps.
fig, axes = plt.subplots(2, 5, figsize=(20, 20))
for i, ax in enumerate(axes.flat):
    num_degrees = i * 36
    ax.imshow(augment(zfs[0], obj_dims, num_degrees, transform='flip'))
    ax.set_title('flip, {} deg'.format(num_degrees))

plt.show()