In [1]:
import os

import matplotlib.pyplot as plt
import tensorflow as tf

In [2]:
# Raw dataset root and the training-image folder inside it.
data_path = 'data/'
train_image_path = f'{data_path}train_v2/'

# Root of the preprocessed output tree (one sub-folder per CSV split).
data_preprocessed_dir_path = 'data_p/'

In [3]:
def load_csv(path):
    """Stream an (ImageId, EncodedPixels) CSV file as a tf.data dataset.

    Rows are yielded once (num_epochs=1), in file order (shuffle=False), each
    as a batch of size one mapping column name -> tensor.
    """
    columns = ['ImageId', 'EncodedPixels']
    return tf.data.experimental.make_csv_dataset(
        file_pattern=path,
        batch_size=1,  # make_csv_dataset requires a batch size; 1 keeps rows separate
        column_names=columns,
        shuffle=False,
        num_epochs=1,
    )

In [4]:
IMG_HEIGHT = 768
IMG_WIDTH = 768

def decode_label_mask(encoded_pixels, image_height, image_width):
    """Decode a run-length-encoded mask string into a dense float32 mask.

    ``encoded_pixels`` is a whitespace-separated string of ``start length``
    pairs: ``start`` is a 1-based index into the flattened image and
    ``length`` is how many consecutive flat pixels are set to 1.0. Flat
    indices are treated as column-major (they run down each column first).

    Args:
        encoded_pixels: String holding the RLE pairs. Assumed to be a scalar,
            as it is when called per row after ``unbatch()``. An empty string
            yields an all-zero mask.
        image_height: Number of rows in the returned mask.
        image_width: Number of columns in the returned mask.

    Returns:
        float32 tensor of shape (image_height, image_width): 1.0 on pixels
        covered by a run, 0.0 elsewhere.
    """
    num_pixels = image_height * image_width

    # "s1 l1 s2 l2 ..." -> (n_runs, 2) int32 tensor of (start, length) pairs.
    pairs = tf.strings.to_number(tf.strings.split(encoded_pixels), out_type=tf.int32)
    pairs = tf.reshape(pairs, (-1, 2))
    starts = pairs[:, 0] - 1  # RLE starts are 1-based
    ends = starts + pairs[:, 1]

    # Expand all runs into flat pixel indices at once and scatter them in a
    # single op, instead of copying the whole mask once per run.
    indices = tf.ragged.range(starts, ends).flat_values
    mask = tf.tensor_scatter_nd_update(
        tf.zeros([num_pixels], dtype=tf.float32),
        indices=tf.expand_dims(indices, axis=1),
        updates=tf.ones_like(indices, dtype=tf.float32),
    )

    # Column-major: each consecutive group of `image_height` flat values is
    # one image column, so lay them out as (width, height) and transpose to
    # (height, width). This is also correct for non-square images.
    return tf.transpose(tf.reshape(mask, (image_width, image_height)))

In [5]:
def make_dir(csv_path, base_dir=None):
    """Create (if missing) the image/ and label/ output folders for a CSV split.

    The split name is the CSV file name up to its first '.', e.g.
    'some/dir/m.csv' -> 'm'. The folders created are
    ``<base_dir>/<split>/image/`` and ``<base_dir>/<split>/label/``.

    Args:
        csv_path: Path of the CSV file; only its final component is used.
        base_dir: Root of the preprocessed tree. Defaults to the notebook's
            ``data_preprocessed_dir_path``.

    Returns:
        (image_dir, label_dir). Each ends with a path separator so callers can
        append a file name with plain string concatenation.
    """
    if base_dir is None:
        base_dir = data_preprocessed_dir_path

    split_name = os.path.basename(csv_path).split('.')[0]
    split_dir = os.path.join(base_dir, split_name)
    # The trailing '' component makes join() end the path with a separator.
    image = os.path.join(split_dir, 'image', '')
    label = os.path.join(split_dir, 'label', '')

    # exist_ok avoids the check-then-create race and makes re-runs idempotent.
    for directory in (image, label):
        os.makedirs(directory, exist_ok=True)

    return image, label

In [22]:
csv_file = 'm.csv'
image_dir, labels_dir = make_dir(csv_file)

csv = load_csv(data_preprocessed_dir_path + csv_file)


def process_img(file):
    """Copy one training image into ``image_dir`` and return its raw bytes.

    ``file`` is the row's ImageId (a string tensor), resolved relative to
    ``train_image_path`` and written under the same name in ``image_dir``.
    """
    # FIXME: use some variation of cp instead of read/write
    # (tf.io.gfile.copy takes Python strings, not tensors, so inside
    # Dataset.map it would need a tf.py_function wrapper)
    img = tf.io.read_file(train_image_path + file)
    tf.io.write_file(image_dir + file, img)  # FIXME: drop .jpg extension
    return img


def process_label(label, file):
    """Decode a row's RLE string, save it serialized to ``labels_dir/<file>``.

    Returns the dense (IMG_HEIGHT, IMG_WIDTH) float32 mask.
    """
    mask = decode_label_mask(label, IMG_HEIGHT, IMG_WIDTH)
    encoded_mask = tf.io.serialize_tensor(mask)
    tf.io.write_file(labels_dir + file, encoded_mask)
    return mask


def process_batch(csv_item):
    """Process one CSV row (dict with 'ImageId', 'EncodedPixels') -> (image bytes, mask)."""
    X = process_img(csv_item['ImageId'])
    y = process_label(csv_item['EncodedPixels'], csv_item['ImageId'])
    return X, y


# unbatch() splits load_csv's size-1 batches into single rows. num_parallel_calls
# lets map run process_batch on several rows concurrently; without it map is
# sequential. The file writes are the purpose of this cell, so the dataset is
# just drained one row at a time instead of being stacked into large batches.
# NOTE(review): if m.csv has several rows per ImageId, each row overwrites the
# same label file, and with parallel calls which write survives is not
# predictable. Confirm the CSV holds one row per image.
r = csv.unbatch().map(process_batch, num_parallel_calls=tf.data.AUTOTUNE)
n_processed = 0
for _ in r:
    n_processed += 1
n_processed