In [42]:
import tensorflow as tf

In [43]:
import os
# Restrict CUDA to device index 0 for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Turn on memory growth for every GPU TensorFlow can see. Looping over an
# empty list is a no-op, so no explicit emptiness check is needed.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [44]:
# Text file listing one TFRecord path per line.
source_pth = r'./data/datalist/training_mr.txt'
with open(source_pth, 'r') as f:
    rows = f.readlines()

# Strip surrounding whitespace instead of blindly dropping the last character:
# the final line may not end with a newline, in which case `row[:-1]` would
# truncate the path itself. Blank lines are skipped so that no empty filename
# is handed to the dataset reader.
imagea_list = [row.strip() for row in rows if row.strip()]

In [45]:
# Dataset of raw serialized records read from the listed TFRecord files.
imagea_dataset = tf.data.TFRecordDataset(filenames=imagea_list)

In [46]:
# Feature specs shared by several keys: a scalar int64 and a scalar byte string.
_int64_scalar = tf.io.FixedLenFeature([], tf.int64)
_bytes_scalar = tf.io.FixedLenFeature([], tf.string)

# Schema used to parse one serialized example.
decomp_feature = {
    # Image size: dimensions of the stack of 3 consecutive slices
    # (expected 256, 256, 3).
    'dsize_dim0': _int64_scalar,
    'dsize_dim1': _int64_scalar,
    'dsize_dim2': _int64_scalar,
    # Label size: dimensions of the middle slice (expected 256, 256, 1).
    'lsize_dim0': _int64_scalar,
    'lsize_dim1': _int64_scalar,
    'lsize_dim2': _int64_scalar,
    # Raw bytes of the image slices, size [256, 256, 3].
    'data_vol': _bytes_scalar,
    # Raw bytes of the label slices, size [256, 256, 3].
    'label_vol': _bytes_scalar,
}

In [47]:
from typing import Dict, Any, Tuple

def decode(serialized_example: tf.Tensor) -> Dict[str, tf.Tensor]:
    """Parse one serialized example into a dictionary of feature tensors.

    Args:
        serialized_example: Scalar string tensor holding a single serialized
            record, as yielded by the TFRecord dataset.

    Returns:
        Mapping from each key of ``decomp_feature`` to its parsed tensor.
    """
    # `tf.string` is a DType instance, not a type, so the argument is
    # annotated as the tensor that actually arrives here.
    return tf.io.parse_single_example(serialized_example, decomp_feature)

In [48]:
# Parse every serialized record into its feature dictionary.
decoded_dataset = imagea_dataset.map(
    map_func=decode,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)

In [49]:
def parse(data: Dict[str, Any]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Turn one parsed example into an (image, one-hot label) pair.

    Both ``data['data_vol']`` and ``data['label_vol']`` are decoded as float32
    and reshaped to [256, 256, 3]. Only the middle slice (index 1 on the last
    axis) of each is kept.

    Args:
        data: Feature dictionary containing the raw-byte entries
            ``'data_vol'`` and ``'label_vol'``.

    Returns:
        A tuple ``(image, batch_y)``: ``image`` is the middle image slice with
        shape [256, 256, 1]; ``batch_y`` is the middle label slice cast to
        uint8 and one-hot encoded with depth 5, shape [256, 256, 5].
    """
    stack_shape = [256, 256, 3]
    middle = 1  # index of the centre slice along the last axis

    image_stack = tf.reshape(
        tf.io.decode_raw(data['data_vol'], tf.float32), stack_shape)
    label_stack = tf.reshape(
        tf.io.decode_raw(data['label_vol'], tf.float32), stack_shape)

    # The label is stored with 3 slices, but only the middle one is used.
    label_ids = tf.cast(label_stack[:, :, middle], tf.uint8)
    batch_y = tf.one_hot(label_ids, 5)

    # Keep a trailing channel axis on the image: [256, 256] -> [256, 256, 1].
    image = tf.expand_dims(image_stack[:, :, middle], axis=2)
    return image, batch_y

In [50]:
# Convert every feature dictionary into an (image, one-hot label) pair.
parsed_dataset = decoded_dataset.map(
    map_func=parse,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)

In [51]:
# Sanity check: show the Python type of a single raw record.
for data in imagea_dataset.take(count=1):
    print(type(data))

<class 'tensorflow.python.framework.ops.EagerTensor'>


In [53]:
# Sanity check: print the shapes of the second parsed record.
iter_data = iter(parsed_dataset)
next(iter_data)  # the first record is skipped
image, label = next(iter_data)
print(tf.shape(image), tf.shape(label))

tf.Tensor([256 256   1], shape=(3,), dtype=int32) tf.Tensor([256 256   5], shape=(3,), dtype=int32)


In [57]:
# Endless input stream: repeat, then shuffle with a 100-element buffer, then
# group into batches of 2. Repeating before shuffling lets records from
# adjacent passes over the data mix inside the shuffle buffer.
source_dataset = (
    parsed_dataset
    .repeat()
    .shuffle(buffer_size=100)
    .batch(batch_size=2)
)