In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import tensorflow as tf 
import cv2
import matplotlib
import matplotlib.pyplot as plt
import time
import random
import tensorflow_addons as tfa
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm



AUTO = tf.data.AUTOTUNE  # passed below as num_parallel_reads / num_parallel_calls so tf.data picks the parallelism level at runtime

In [None]:
def rle_decode(mask_rle, shape):
    """Decode a run-length-encoded mask string into a 2-D binary array.

    Args:
        mask_rle: space-separated "start length start length ..." string;
            start positions are 1-based indices into the flattened image.
        shape: (height, width) of the array to return.

    Returns:
        np.ndarray of dtype uint8 with the given shape: 1 = mask, 0 = background.
    """
    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[::2], dtype=int) - 1  # 1-based -> 0-based
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = 1
    return flat.reshape(shape)


In [None]:
def _bytes_feature(value):
  """Wrap a single string / bytes value in a tf.train.Feature (bytes_list).

  Credit: adapted from the TensorFlow TFRecord documentation and the
  reference book, to convert values to bytes for the TFRecord.
  An EagerTensor is first unwrapped with .numpy() because BytesList
  cannot unpack a string from an EagerTensor.
  """
  eager_tensor_type = type(tf.constant(0))
  raw = value.numpy() if isinstance(value, eager_tensor_type) else value
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))


def _int64_feature(value):
  """Wrap a single bool / enum / int / uint value in a tf.train.Feature (int64_list)."""
  int_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int_list)

In [None]:
def build_tfrecord_2(csv_path='../input/sartorius-cell-instance-segmentation/train.csv',
                     image_dir='../input/sartorius-cell-instance-segmentation/train',
                     shape=(520, 704),
                     out_pattern="Train_type_{}.tfrec"):
    """Write one ZLIB-compressed TFRecord file per cell type from the training set.

    Each record holds:
      - 'image': float32 bytes of the image read with cv2.imread, scaled to [0, 1]
      - 'mask' : float32 bytes of the (height, width) union of all instance
                 masks of that image, clipped to 0/1
      - 'label': integer cell-type id assigned by sklearn's LabelEncoder

    Records whose label is n are written to out_pattern.format(n)
    ("Train_type_n.tfrec" by default).

    Args:
        csv_path: training annotations CSV (uses columns id, annotation, cell_type).
        image_dir: directory containing the <id>.png training images.
        shape: (height, width) of every training image / mask.
        out_pattern: format string for the output file names.

    Raises:
        FileNotFoundError: if an image listed in the CSV cannot be read.
    """
    start = time.time()

    df = pd.read_csv(csv_path)
    # One entry per image id, in order of first appearance. A single groupby
    # replaces the former per-image `df.loc[df['id'] == id]` full-table scan.
    grouped = df.groupby('id', sort=False)
    annotations = grouped['annotation'].apply(list)
    cell_types = grouped['cell_type'].last()

    lbl = LabelEncoder()
    labels = lbl.fit_transform(cell_types.to_numpy())
    option = tf.io.TFRecordOptions(compression_level=2, compression_type="ZLIB")

    # One output file per encoded class (derived from the data, not hard-coded).
    for n, class_name in enumerate(lbl.classes_):
        class_ids = cell_types.index[labels == n]
        out_path = out_pattern.format(n)
        with tf.io.TFRecordWriter(out_path, options=option) as writer:
            for img_id in tqdm(class_ids, desc=f"{class_name} -> {out_path}"):
                img = _load_train_image(image_dir, img_id)
                all_masks = _union_masks(annotations[img_id], shape)
                writer.write(_serialize_example(img, all_masks, n))

    elapsed = time.time() - start
    print("Time spent: ", elapsed)


def _load_train_image(image_dir, img_id):
    """Read <image_dir>/<img_id>.png with cv2 and return it as float32 in [0, 1]."""
    img_path = os.path.join(image_dir, f"{img_id}.png")
    img = cv2.imread(img_path)
    # cv2.imread signals a missing/unreadable file by returning None, not by raising.
    if img is None:
        raise FileNotFoundError(f"Could not read image {img_path}")
    return np.true_divide(img, 255, dtype=np.float32)


def _union_masks(rles, shape):
    """Decode every RLE string of one image and merge them into one 0/1 float32 mask."""
    all_masks = np.zeros(shape, dtype=np.float32)
    for rle in rles:
        all_masks += rle_decode(rle, shape)
    all_masks[all_masks > 1] = 1  # overlapping cells: keep the mask strictly binary
    return all_masks


def _serialize_example(img, mask, label):
    """Serialize one (image, mask, label) triple as a tf.train.Example byte string."""
    feature = {
        'image': _bytes_feature(img.tobytes()),
        'mask': _bytes_feature(mask.tobytes()),
        'label': _int64_feature(int(label)),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()

In [None]:
def decode_mask(image_data):
    """Turn the raw 'image' and 'mask' bytes of one parsed example into tensors.

    image_data: dict of parsed features whose 'image' and 'mask' entries hold
        raw float32 bytes (as written by build_tfrecord_2).
    Returns an (image, mask) tuple reshaped to (520, 704, 3) and (520, 704, 1).
    """
    height, width = 520, 704
    image = tf.reshape(tf.io.decode_raw(image_data['image'], tf.float32), [height, width, 3])
    mask = tf.reshape(tf.io.decode_raw(image_data['mask'], tf.float32), [height, width, 1])
    return image, mask


def read_mask(exemple):
    """Parse one serialized tf.train.Example and return its (image, mask) pair.

    The 'label' feature is parsed as part of the schema but is not returned.
    """
    schema = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'mask': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(exemple, schema)
    return decode_mask(parsed)

In [None]:
def load_dataset(path, Augment=False, Big=False, label=True):
    """Build a tf.data pipeline over ZLIB-compressed TFRecord file(s).

    Args:
        path: TFRecord file path (presumably a list of paths also works,
            since it is handed straight to TFRecordDataset).
        Augment: if True, map an augmentation function over the examples
            (augment3L when label is True, augment3 otherwise).
        Big: if True, repeat the dataset 6 times; this happens before the
            augmentation map.
        label: if True parse records with read_label, otherwise with
            read_mask, which yields (image, mask) pairs.

    NOTE(review): read_label, augment3 and augment3L are not defined in the
    visible notebook; with the default label=True this raises NameError
    unless they are defined elsewhere -- confirm.
    """
    ds = tf.data.TFRecordDataset(path, compression_type="ZLIB", num_parallel_reads=AUTO)
    ds = ds.map(read_label if label else read_mask, num_parallel_calls=AUTO)
    if Big:
        ds = ds.repeat(6)
    if Augment:
        ds = ds.map(augment3L if label else augment3, num_parallel_calls=AUTO)
    return ds

In [None]:
def victor(path, Augment=False, big=False, label=True, n_samples=5):
    """Plot the first few examples of a TFRecord file as a visual sanity check.

    Args:
        path: TFRecord path forwarded to load_dataset.
        Augment: forwarded to load_dataset(Augment=...).
        big: forwarded to load_dataset(Big=...).
        label: if True, records are parsed with the label reader and each row
            shows one image titled with its label; if False, records are
            parsed with read_mask and each row shows image | mask side by side.
        n_samples: number of examples to draw (default 5, as before).

    Raises:
        ValueError: if n_samples < 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    data = load_dataset(path, Augment=Augment, Big=big, label=label)
    samples = data.take(n_samples)
    # 8 inches of height per row keeps the original 15x40 figure for 5 rows.
    figsize = (15, 8 * n_samples)

    if label:
        # squeeze=False keeps axarr 2-D even when n_samples == 1.
        fig, axarr = plt.subplots(n_samples, 1, figsize=figsize, squeeze=False)
        for row, (image, target) in enumerate(samples):
            print(image.shape, type(image))
            print(target.shape)

            ax = axarr[row, 0]
            ax.imshow(image)
            ax.axis('off')
            ax.set_title(f'Label {target}')
    else:
        fig, axarr = plt.subplots(n_samples, 2, figsize=figsize, squeeze=False)
        for row, (image, mask) in enumerate(samples):
            print(image.shape, type(image))
            print(mask.shape)
            # Distinct mask values; build_tfrecord_2 clips masks to 0/1.
            print(tf.unique(tf.reshape(mask, [-1])).y)

            axarr[row, 0].imshow(image)
            axarr[row, 0].axis('off')
            axarr[row, 0].set_title(f'Image {row}')

            axarr[row, 1].imshow(mask)
            axarr[row, 1].axis('off')
            axarr[row, 1].set_title(f'Masks {row}')

    fig.tight_layout(h_pad=0.1, w_pad=0.1)
    plt.show()
    

In [None]:
if __name__ == "__main__":
    # TFRecord files produced by build_tfrecord_2, one per encoded cell type.
    path0, path1, path2 = "./Train_type_0.tfrec", "./Train_type_1.tfrec", "./Train_type_2.tfrec"

    build_tfrecord_2()

    # Optional visual check of the freshly written files (disabled):
    # for tfrec_path in (path0, path1, path2):
    #     victor(tfrec_path, Augment=False, big=True, label=False)