In [None]:
# A simple experiment on cell image segmentation.
# To form an image in RGB format, we concatenate the input filter
# images in R-G-B order to form a new image. We chose one of the RGB
# filter color cell images as the mask for segmenting the RGB color
# cell image. We designed a simple stack of down-sampling and
# up-sampling TensorFlow layers for encoding and decoding. Both the
# concatenated images and the segmented images are shown in this
# experiment.

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import tensorflow as tf
import os
# using python csv library for writing results
import csv
# loading data in windows environment
import pathlib
import re
print (tf.__version__)
from kaggle_datasets import KaggleDatasets
import PIL.Image
import cv2
from io import StringIO
# Hardware platform: You may want to scale your training onto multiple GPUs on one machine,
# or multiple machines in a network (with 0 or more GPUs each), or on Cloud TPUs.
# Try to resolve a TPU first; a ValueError from the resolver is treated as
# "no TPU available" and `tpu` is set to None.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None


# Pick the distribution strategy that the model is later built under
# (`strategy.scope()`), depending on whether a TPU was resolved.
if tpu:
    # Passing in the name of the CloudTPU
    tf.config.experimental_connect_to_cluster(tpu)
    # The TPU initialization code has to be at the beginning of the program
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # A distribution strategy is an abstraction that can be used to drive models on CPU, GPUs or TPUs
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    # No TPU: fall back to the central-storage strategy.
    strategy = tf.distribute.experimental.CentralStorageStrategy()
    print ("GPU VERSION")
    
# Locate the competition dataset on GCS and build the path of the folder
# holding the training TFRecord shards.
GCS_DS_PATH = KaggleDatasets().get_gcs_path('hpa-single-cell-image-classification')
GCS_PATH = GCS_DS_PATH + '/train_tfrecords'
#TEST_GCS_PATH = GCS_DS_PATH + '/test_tfrecords'
# Let tf.data choose parallelism settings automatically.
AUTO = tf.data.experimental.AUTOTUNE

#TRAINING_FILENAMES = tf.io.gfile.glob("../input/cassava-leaf-disease-classification/" + 'train_tfrecords/*.tfrec')
# All training TFRecord shard paths under GCS_PATH.
TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/*.tfrec')
#TEST_NAMES = tf.io.gfile.glob("../input/cassava-leaf-disease-classification/" + 'test_images/*.jpg')

# Image height/width; used below as the spatial shape of the Keras model inputs.
IMG_HEIGHT=2048
IMG_WIDTH=2048

"""
Reading the tfrecord data, creating dataset from
tfrecord data.

"""
# Raw (unparsed) TFRecord dataset over all training shards.
# NOTE(review): nothing else in the visible file reads DATASET; the parsed
# pipeline further down is built from a separate TFRecordDataset.
DATASET = tf.data.TFRecordDataset(TRAINING_FILENAMES)
# Debug snippet: dump the first record as a tf.train.Example to inspect its schema.
#for dat in DATASET.take(1):
    #V = tf.train.Example()
    #V.ParseFromString(dat.numpy())
    #print (V)

def mapping_image_id(dat):
    """Parse one serialized TFRecord example into (image, target, name).

    Args:
        dat: a serialized example holding the string features "image",
            "target" and "image_name".

    Returns:
        A tuple ``(image, target, name)``: the PNG decoded to 3 channels and
        divided by 255 as float32, plus the raw "target" and "image_name"
        string tensors.
    """
    # All three features are scalar strings, so one spec can be shared.
    scalar_string = tf.io.FixedLenFeature([], tf.string)
    feature_spec = {
        "image": scalar_string,
        "target": scalar_string,
        "image_name": scalar_string,
    }
    example = tf.io.parse_single_example(dat, feature_spec)

    # Decode the PNG bytes and normalize the pixel values.
    pixels = tf.io.decode_png(example['image'], channels=3)
    normalized = tf.cast(pixels/255, dtype=tf.float32)

    return normalized, example['target'], example['image_name']

# Processing Input Training Data
# Read the training shards (parallel reads tuned by tf.data) ...
train_data_set = tf.data.TFRecordDataset(TRAINING_FILENAMES, num_parallel_reads=AUTO)
#train_img = train_data_set.map(mapping_image)
#train_img = train_img.batch(8000)
#train_target = train_data_set.map(mapping_id)
#train_target = train_target.batch(8000)
# ... and parse every record into an (image, target, name) tuple.
train_data = train_data_set.map(mapping_image_id)

# Creating training and testing dataset.
# Getting ready for model data input.
def data_input(S):
    """Collect up to ``S`` parsed samples from the module-level ``train_data``.

    Args:
        S: maximum number of samples to collect.

    Returns:
        A tuple ``(images, names, labels)`` of three parallel lists: the image
        tensors, the image names and the labels, the latter two decoded from
        UTF-8 bytes to ``str``.
    """
    images, names, labels = [], [], []
    for count, (image, label, name) in enumerate(train_data):
        # Stop as soon as the requested number of samples has been gathered.
        if count >= S:
            break
        images.append(image)
        names.append(name.numpy().decode('utf-8'))
        labels.append(label.numpy().decode('utf-8'))
    return images, names, labels

# Number of samples loaded, displayed and fed to the model.
DATA_SIZE = 9

# Per-channel PNG file lists. Each list is sorted so that the same index is
# used for all four channels when they are combined later.
green_files = sorted(tf.io.gfile.glob('../input/hpa-single-cell-image-classification/train/*_green.png'))
yellow_files = sorted(tf.io.gfile.glob('../input/hpa-single-cell-image-classification/train/*_yellow.png'))
blue_files = sorted(tf.io.gfile.glob('../input/hpa-single-cell-image-classification/train/*_blue.png'))
red_files = sorted(tf.io.gfile.glob('../input/hpa-single-cell-image-classification/train/*_red.png'))

# Load Data Images
def LoadImages(S, filenames=None):
    """Read the raw, undecoded contents of the first ``S`` files.

    The original body read a module-level ``filenames`` that this file never
    defines, so every call with ``S > 0`` failed. The list can now be passed
    explicitly; a module-level ``filenames`` is still honored when present.

    Args:
        S: number of files to read.
        filenames: indexable collection of file paths. When omitted, a
            module-level ``filenames`` is used if one exists.

    Returns:
        A list of ``S`` tensors as returned by ``tf.io.read_file`` (the files
        are not decoded).

    Raises:
        NameError: if ``S > 0`` and no file list is available.
    """
    if filenames is None:
        # Backward compatibility with the old implicit global lookup.
        filenames = globals().get("filenames")
    if filenames is None and S > 0:
        raise NameError(
            "name 'filenames' is not defined; pass the file list as "
            "LoadImages(S, filenames)"
        )
    return [tf.io.read_file(filenames[index]) for index in range(S)]

# Load single-channel cell images for one color filter
def Gen_color_Img(S, names, size=(2048, 2048)):
    """Load, decode and resize the first ``S`` PNG files in ``names``.

    Args:
        S: number of files to load.
        names: indexable collection of PNG file paths (any color channel).
        size: target ``(height, width)`` passed to ``tf.image.resize``.
            Defaults to the 2048x2048 size that was previously hard-coded.

    Returns:
        A list of ``S`` image tensors, each decoded with a single channel and
        bicubically resized to ``size``.
    """
    target_size = list(size)
    return [
        tf.image.resize(
            # Read the raw data and decode it as a one-channel image
            tf.io.decode_png(tf.io.read_file(names[index]), channels=1),
            target_size,
            method='bicubic',
        )
        for index in range(S)
    ]
# Load DATA_SIZE single-channel images per color filter, in the order
# red, green, blue, yellow.
RED_CELL, GREEN_CELL, BLUE_CELL, YELLOW_CELL = [
    Gen_color_Img(DATA_SIZE, channel_files)
    for channel_files in (red_files, green_files, blue_files, yellow_files)
]

"""
Left here
"""
def input_data_processing(S):
    """Build ``S`` three-channel images from the per-filter image lists.

    For each index the red, green and blue filter images (module-level
    ``RED_CELL``, ``GREEN_CELL``, ``BLUE_CELL``) are concatenated, in that
    order, along axis 2.

    Args:
        S: number of images to build.

    Returns:
        A list of ``S`` concatenated image tensors.
    """
    # Nothing is plotted here, so the stray plt.show() of the original was dropped.
    return [
        tf.concat([RED_CELL[index], GREEN_CELL[index], BLUE_CELL[index]], 2)
        for index in range(S)
    ]

# Displaying the cell image
def Display_M(dat, data_size):
    """Show the first ``data_size`` images of ``dat`` in a 3-column grid.

    Args:
        dat: indexable collection of images accepted by ``imshow``.
        data_size: number of images to show.

    The grid keeps the original 3x3 layout for up to nine images and grows by
    rows beyond that, instead of failing in ``add_subplot``.
    """
    # NOTE(review): imshow expects float RGB data in [0, 1]; the caller passes
    # resized PNG channels that are presumably in the 0-255 range - confirm
    # whether the data should be rescaled before display.
    columns = 3
    rows = max(3, -(-data_size // columns))  # ceil division, at least 3 rows
    fig = plt.figure(figsize=(18,18))
    for Z in range(data_size):
        ax = fig.add_subplot(rows, columns, Z+1)
        ax.imshow(dat[Z])
        ax.set_title("Unsegmented {}".format(Z))
    plt.show()
    
# Processing input data: build DATA_SIZE three-channel images from the
# red/green/blue filter images
input_data = input_data_processing(DATA_SIZE)
# Display data
Display_M(input_data, DATA_SIZE)
# Convert the list of images to a single tensor; it is later passed to the model
input_data = tf.convert_to_tensor(input_data)

# Creating mask
def LoadFilters(F,S):
    """Read and decode the first ``S`` PNG files of ``F`` as 3-channel images.

    Args:
        F: indexable collection of PNG file paths.
        S: number of files to load.

    Returns:
        A list of ``S`` decoded image tensors (no resizing is applied).
    """
    return [
        tf.io.decode_png(tf.io.read_file(F[position]), channels=3)
        for position in range(S)
    ]

# The mask is built from the red filter images, decoded as 3-channel images,
# then stacked into a single tensor that is later passed to the model.
# NOTE(review): unlike Gen_color_Img, LoadFilters does not resize, so stacking
# assumes all mask images share one size - confirm against the dataset.
mask_cell = LoadFilters(red_files,DATA_SIZE)
mask_cell = tf.convert_to_tensor(mask_cell)
#green_imgs.reverse()

# Extracting Data
#IMG, NAME, LAB = data_input(DATA_SIZE)
#print (NAME)
#LAB = np.array(LAB)
#print (LAB.shape)

#IMG = tf.convert_to_tensor(IMG)
#print (IMG.shape)
        
# Showing the cell image      

# Encoder Stack
def down_sample(b_id):
    """Build one encoder stage: stride-2 depthwise conv, batch norm, ReLU6.

    Args:
        b_id: block id, used only to name the layers.

    Returns:
        A ``tf.keras.Sequential`` that halves the spatial size of its input
        (3x3 depthwise convolution, stride 2, 'same' padding).
    """
    init = tf.random_normal_initializer(0., 0.02)
    X = tf.keras.Sequential()
    # DepthwiseConv2D initializes its kernel from `depthwise_initializer`;
    # the previous `kernel_initializer` keyword did not reach the depthwise kernel.
    X.add(tf.keras.layers.DepthwiseConv2D(kernel_size=(3,3),depthwise_initializer=init, padding='same',name='Conv2D_{}'.format(b_id), strides=(2,2)))
    # The model inputs are channels-last, so normalize over the last axis
    # (axis=1 normalized over image rows instead of channels).
    X.add(tf.keras.layers.BatchNormalization(axis=-1, name='batch_norm_{}'.format(b_id)))
    X.add(tf.keras.layers.ReLU(max_value=6.0,name='relu_{}'.format(b_id)))
    return X

# Decoder Stack
def up_sample(filters, b_id):
    """Build one decoder stage: stride-2 transposed conv, batch norm, ReLU6.

    Args:
        filters: number of output channels of the transposed convolution.
        b_id: block id, used only to name the layers.

    Returns:
        A ``tf.keras.Sequential`` that doubles the spatial size of its input
        (3x3 transposed convolution, stride 2, 'same' padding).
    """
    X = tf.keras.Sequential()
    X.add(tf.keras.layers.Conv2DTranspose(filters, kernel_size=(3,3), padding='same', name='Transpose_Conv2D_{}'.format(b_id),strides=(2,2)))
    # The model inputs are channels-last, so normalize over the last axis
    # (axis=1 normalized over image rows instead of channels).
    X.add(tf.keras.layers.BatchNormalization(axis=-1,name='batch_norm_transpose_{}'.format(b_id)))
    X.add(tf.keras.layers.ReLU(max_value=6.0,name='relu_{}'.format(b_id)))
    return X

# Generate Data for model
def gen_data_patch(S, IMG_DATA):
    """Stack the first ``S`` entries of ``IMG_DATA`` into a single tensor.

    Args:
        S: number of entries to take.
        IMG_DATA: indexable collection of image tensors.

    Returns:
        The result of ``tf.convert_to_tensor`` on the list of selected entries.
    """
    selected = [IMG_DATA[position] for position in range(S)]
    return tf.convert_to_tensor(selected)

# Stack the per-filter image lists into one tensor per color.
# NOTE(review): none of these four tensors is read again in the visible file.
# Color red data
RED_DATA = gen_data_patch(DATA_SIZE, RED_CELL)
# Color green data
GREEN_DATA = gen_data_patch(DATA_SIZE, GREEN_CELL)
# Color blue data
BLUE_DATA = gen_data_patch(DATA_SIZE, BLUE_CELL)
# Color yellow data
YELLOW_DATA = gen_data_patch(DATA_SIZE, YELLOW_CELL)

# Build the encoder/decoder model under the selected distribution strategy.
with strategy.scope():
    b_id = 0
    # RGB composite input
    model_input = tf.keras.layers.Input(shape=[IMG_WIDTH,IMG_HEIGHT,3]) #2048X2048
    # Color green cell input (created but not wired into the model below)
    model_input_green = tf.keras.layers.Input(shape=[IMG_WIDTH,IMG_HEIGHT,1]) #2048X2048
    # Color red cell input (created but not wired into the model below)
    model_input_red = tf.keras.layers.Input(shape=[IMG_WIDTH,IMG_HEIGHT,1]) #2048X2048
    # Mask input (3-channel mask image)
    model_input_mask = tf.keras.layers.Input(shape=[IMG_WIDTH,IMG_HEIGHT,3]) #2048X2048

    # Concatenate image and mask: 3 + 3 = 6 channels
    the_input = tf.keras.layers.concatenate([model_input, model_input_mask])

    # Encoder layers: each stage halves the spatial size (stride 2, 'same' padding)
    down_layers = [down_sample(b_id+1), #1024X1024
                   down_sample(b_id+2), #512X512
                   down_sample(b_id+3), #256X256
                   down_sample(b_id+4), #128X128
                   down_sample(b_id+5), #64X64
                   down_sample(b_id+6), #32X32
                  ]
    # Decoder layers: each stage doubles the spatial size
    up_layers = [up_sample(512,b_id+4),#64X64
                 up_sample(256,b_id+5),#128X128
                 up_sample(128,b_id+6),#256X256
                 up_sample(64,b_id+7),#512X512
                 up_sample(32,b_id+8),#1024X1024
                ]

    # Run the encoder, remembering every stage output for the skip connections
    X = the_input
    down_out = []
    for layers in down_layers:
        X = layers(X)
        down_out.append(X)
    # The deepest encoder output is the decoder's starting point
    Z = down_out[-1]
    # Remaining encoder outputs, deepest first, to pair with the decoder stages
    down_out = reversed(down_out[:-1])

    # Decoder: up-sample, then concatenate the encoder output of matching size
    for up, down in zip(up_layers, down_out):
        Z = up(Z)
        Z = tf.keras.layers.Concatenate()([Z, down])

    # Final output layer: up-sample back to the input size with 3 output channels
    Z = tf.keras.layers.Conv2DTranspose(3, kernel_size=(3,3), padding='same', name='Transpose_Conv2D_{}'.format(b_id),strides=(2,2))(Z)
    # `learning_rate` replaces the deprecated `lr` keyword
    model_opt_func = tf.keras.optimizers.SGD(learning_rate=0.0001)
    model = tf.keras.Model(inputs=[model_input, model_input_mask], outputs=Z)
    #model.compile(optimizer=model_opt_func, loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
    #model.summary()

# Displaying the cell image
def Display_S(dat, g_id):
    """Show the first ``g_id`` images of ``dat`` in a 3-column grid.

    Args:
        dat: indexable collection of images accepted by ``imshow``.
        g_id: number of images to show.

    The grid keeps the original 3x3 layout for up to nine images and grows by
    rows beyond that, instead of failing in ``add_subplot``.
    """
    columns = 3
    rows = max(3, -(-g_id // columns))  # ceil division, at least 3 rows
    fig = plt.figure(figsize=(18,18))
    for Z in range(g_id):
        ax = fig.add_subplot(rows, columns, Z+1)
        ax.imshow(dat[Z])
        ax.set_title("Segmented {}".format(Z))
    plt.show()

# Model output: run the model on the composite images and the mask in
# inference mode (training=False). The model is never compiled or fitted in
# this file, so it runs with its initial weights.
gen_out = model([input_data,mask_cell],  training=False)

# Displaying the model result
Display_S(gen_out, DATA_SIZE)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session