In [None]:
import tensorflow as tf
import pandas as pd
import os
from matplotlib import pyplot as plt
from tqdm import tqdm
import numpy as np
import gc

In [None]:
# Training split: list the satellite tiles and the road masks, then keep only
# the tiles for which a mask with the identical file name exists.
filenames_inp = pd.DataFrame(
    {"file_names": os.listdir("../input/massachusetts-roads-dataset/road_segmentation_ideal/training/input")}
)
filenames_output = pd.DataFrame(
    {"file_names": os.listdir("../input/massachusetts-roads-dataset/road_segmentation_ideal/training/output")}
)
filenames = filenames_inp[filenames_inp["file_names"].isin(filenames_output["file_names"])]

In [None]:
# Reading the validation split (the dataset's testing/ folder)

val_filenames_inp=pd.DataFrame({"file_names":os.listdir("../input/massachusetts-roads-dataset/road_segmentation_ideal/testing/input")})
val_filenames_output=pd.DataFrame({"file_names":os.listdir("../input/massachusetts-roads-dataset/road_segmentation_ideal/testing/output")})
# Keep only validation inputs that have a mask in testing/output; the match must be
# against the testing mask list, not the training one.
val_filenames=val_filenames_inp.loc[val_filenames_inp.file_names.isin(val_filenames_output["file_names"]),:]

In [None]:
# Creating the dataset
NUM_PARALLEL_CALLS_DS=os.cpu_count()

TRAIN_DIR = '../input/massachusetts-roads-dataset/road_segmentation_ideal/training'
VAL_DIR = '../input/massachusetts-roads-dataset/road_segmentation_ideal/testing'


def get_image_data(image_name, data_dir=TRAIN_DIR):
    """Load one image and its mask from the ``input/`` and ``output/`` folders of a split.

    Args:
        image_name: string tensor holding the file name shared by ``input/`` and ``output/``.
        data_dir: split root directory; defaults to the training split so existing
            ``map(get_image_data)`` calls behave as before.

    Returns:
        Tuple ``(image, mask, image_name)``, where image and mask are decoded PNGs
        cast to int32.
    """
    img = tf.io.read_file(data_dir + '/input/' + image_name)
    msk = tf.io.read_file(data_dir + '/output/' + image_name)
    return (tf.cast(tf.image.decode_png(img), tf.int32),
            tf.cast(tf.image.decode_png(msk), tf.int32),
            image_name)


with tf.device('/cpu:0'):
    train_data = (tf.data.Dataset.from_tensor_slices(filenames['file_names'])
                  .map(get_image_data, num_parallel_calls=NUM_PARALLEL_CALLS_DS)
                  .batch(10))
    # Validation names are listed from testing/, so the files must also be read from testing/.
    validation_data = (tf.data.Dataset.from_tensor_slices(val_filenames['file_names'])
                       .map(lambda name: get_image_data(name, VAL_DIR),
                            num_parallel_calls=NUM_PARALLEL_CALLS_DS)
                       .batch(10))

In [None]:
# Output folders for the .npy patches written by the export cells below.
# exist_ok=True keeps this cell safe to re-run (plain `mkdir` fails once they exist).
PATCH_DIRS = ("train/images", "train/output", "test/images", "test/output")
for patch_dir in PATCH_DIRS:
    os.makedirs(patch_dir, exist_ok=True)

In [None]:
train_csv_lst=[]  # NOTE(review): not referenced by any visible cell


def extract_grid_patches(batch, size=512, stride=493):
    """Cut every image of a 4-D ``batch`` into a grid of flattened square patches.

    Wraps ``tf.image.extract_patches`` with 'VALID' padding, so only patches lying
    fully inside the image are produced. With stride < size, neighbouring patches
    overlap by ``size - stride`` pixels (19 px with the defaults).

    Returns:
        Tensor indexed [image, grid_row, grid_col, flattened patch pixels].
    """
    return tf.image.extract_patches(images=batch,
                                    sizes=[1, size, size, 1],
                                    strides=[1, stride, stride, 1],
                                    rates=[1, 1, 1, 1],
                                    padding='VALID')


def save_images_at_row(path, flnm, patch_lst, rw, channels=3, patch_size=512):
    """Save every patch in grid row ``rw`` of one image as its own ``.npy`` file.

    Args:
        path: output directory prefix, including the trailing '/'.
        flnm: source file name as UTF-8 bytes (as yielded by the dataset).
        patch_lst: patch grid of a single image, indexed [row, col, flattened pixels].
        rw: grid row to save.
        channels: channel count used to un-flatten each patch.
        patch_size: patch side length used to un-flatten each patch.

    Files are named ``<stem>_<row>_<col>.npy``. The column count is read from
    ``patch_lst`` instead of being assumed to be 3.
    """
    # splitext only strips the final extension, so names with extra dots keep their full stem.
    stem = os.path.splitext(flnm.decode('utf8'))[0]
    for col in range(patch_lst.shape[1]):
        patch = np.reshape(patch_lst[rw, col, :].numpy(), (patch_size, patch_size, channels))
        np.save(path + stem + '_' + str(rw) + '_' + str(col) + '.npy', patch, allow_pickle=True)


for imgs in tqdm(train_data):
    # imgs = (images, masks, file names) for one batch; see get_image_data.
    patches_img = extract_grid_patches(imgs[0])
    patches_mask = extract_grid_patches(imgs[1])

    for k, flnm in enumerate(imgs[2].numpy().tolist()):
        for u in range(patches_img.shape[1]):
            save_images_at_row("train/images/", flnm, patches_img[k], u, channels=3)
            save_images_at_row("train/output/", flnm, patches_mask[k], u, channels=1)
    gc.collect()

In [None]:
# Preview one exported RGB training patch (grid row 0, column 1 of img-542).
sample_patch = np.load('./train/images/img-542_0_1.npy', allow_pickle=True)
plt.imshow(sample_patch)

In [None]:
# Preview the matching road mask. Masks are exported with shape (512, 512, 1), and
# matplotlib's imshow rejects a trailing channel axis of size 1, so squeeze to 2-D.
sample_mask = np.load('./train/output/img-542_0_1.npy', allow_pickle=True)
plt.imshow(np.squeeze(sample_mask, axis=-1), cmap='gray')

In [None]:
def extract_grid_patches(batch, size=512, stride=493):
    """Cut every image of a 4-D ``batch`` into a grid of flattened square patches.

    Wraps ``tf.image.extract_patches`` with 'VALID' padding, so only patches lying
    fully inside the image are produced.

    Returns:
        Tensor indexed [image, grid_row, grid_col, flattened patch pixels].
    """
    return tf.image.extract_patches(images=batch,
                                    sizes=[1, size, size, 1],
                                    strides=[1, stride, stride, 1],
                                    rates=[1, 1, 1, 1],
                                    padding='VALID')


def save_images_at_row(path, flnm, patch_lst, rw, channels=3, patch_size=512):
    """Save every patch in grid row ``rw`` of one image as its own ``.npy`` file.

    Args:
        path: output directory prefix, including the trailing '/'.
        flnm: source file name as UTF-8 bytes (as yielded by the dataset).
        patch_lst: patch grid of a single image, indexed [row, col, flattened pixels].
        rw: grid row to save.
        channels: channel count used to un-flatten each patch.
        patch_size: patch side length used to un-flatten each patch.

    Files are named ``<stem>_<row>_<col>.npy``; the column count comes from ``patch_lst``.
    """
    # splitext only strips the final extension, so names with extra dots keep their full stem.
    stem = os.path.splitext(flnm.decode('utf8'))[0]
    for col in range(patch_lst.shape[1]):
        patch = np.reshape(patch_lst[rw, col, :].numpy(), (patch_size, patch_size, channels))
        np.save(path + stem + '_' + str(rw) + '_' + str(col) + '.npy', patch, allow_pickle=True)


for imgs in tqdm(validation_data):
    # imgs = (images, masks, file names) for one batch; see get_image_data.
    patches_img = extract_grid_patches(imgs[0])
    patches_mask = extract_grid_patches(imgs[1])

    for k, flnm in enumerate(imgs[2].numpy().tolist()):
        for u in range(patches_img.shape[1]):
            save_images_at_row("test/images/", flnm, patches_img[k], u, channels=3)
            save_images_at_row("test/output/", flnm, patches_mask[k], u, channels=1)
    # Free per-batch intermediates before the next batch is loaded.
    gc.collect()

In [None]:
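# NOTE(review): if re-enabled, imshow rejects a (512, 512, 1) array; reshape the mask to (512, 512) instead.
# def plt_img_at_idx(k,i,j):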
#     fig, (ax1, ax2) = plt.subplots(1, 2)
#     ax1.imshow(tf.reshape(patches_img[k,i,j,:],(512,512,3)))
#     ax2.imshow(tf.reshape(patches_mask[k,i,j,:],(512,512,1)))
    
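# NOTE(review): merge_and_plot is unfinished and would fail if re-enabled:
# `patches` is undefined (patches_img?), channel index 3 is out of range for a
# 3-channel array, `512:512` is an empty slice, and three assignments target the
# same [0:512, 0:512] tile.
# def merge_and_plot(k):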
#     fin_img=np.zeros((1536,1536,3))
#     fin_img[0:512,0:512,3]=tf.reshape(patches[k,0,0,:],(512,512,3))
#     fin_img[512:512,0:512,3]=tf.reshape(patches[k,0,1,:],(512,512,3))
#     fin_img[0:512,0:512,3]=tf.reshape(patches[k,1,0,:],(512,512,3))
#     fin_img[0:512,0:512,3]=tf.reshape(patches[k,1,1,:],(512,512,3))

In [None]:
# fig, (ax1, ax2) = plt.subplots(1, 2)
# ax1.imshow(imgs[0][0])
# ax2.imshow(imgs[1][0])


# plt_img_at_idx(0,0,0)
# plt_img_at_idx(0,0,1)
# plt_img_at_idx(0,0,2)

# plt_img_at_idx(0,1,0)
# plt_img_at_idx(0,1,1)
# plt_img_at_idx(0,1,2)

# plt_img_at_idx(0,2,0)
# plt_img_at_idx(0,2,1)
# plt_img_at_idx(0,2,2)