In [None]:
import tensorflow as tf
import pandas as pd
import cv2
from tqdm import tqdm

# Dataset location and schema.
BASE_PATH = "../data"
ID_COL = "StudyInstanceUID"  # per-image identifier, also used as the image file stem
SPLITS = 5  # number of shards ("folds") written for each split

# Metadata tables; the test rows come from the sample submission file.
train = pd.read_csv(f"{BASE_PATH}/train.csv")
test = pd.read_csv(f"{BASE_PATH}/sample_submission.csv")

# Image path for every row: <BASE_PATH>/<split>/<id>.jpg
train["path"] = train[ID_COL].astype(str).map(lambda uid: f"{BASE_PATH}/train/{uid}.jpg")
test["path"] = test[ID_COL].astype(str).map(lambda uid: f"{BASE_PATH}/test/{uid}.jpg")

# Label columns are every train column except the id, the patient id and the
# path added above. This runs before "fold" exists so that it is not a label.
target_cols = train.columns.drop([ID_COL, "PatientID", "path"])

# Shard index per row: SPLITS equal-width bins over the default RangeIndex
# produced by read_csv, i.e. contiguous runs of rows.
# NOTE(review): rows are neither shuffled nor grouped by PatientID; if these
# shards are used as CV folds, one patient may land in several folds — confirm.
train["fold"] = pd.cut(train.index, bins=SPLITS, labels=False)
test["fold"] = pd.cut(test.index, bins=SPLITS, labels=False)

# Not referenced by the code shown here; kept for any later cells using them.
filename_train = "train.tfrecord"
filename_test = "test.tfrecord"


# The helpers below wrap Python values in tf.train.Feature objects compatible with tf.Example.

def _bytes_feature(value):
    """Return a tf.train.Feature holding ``value`` as a one-element bytes_list.

    Accepted inputs:
      * ``bytes`` — stored as-is;
      * ``str`` — encoded as UTF-8 first;
      * an eager ``tf.Tensor`` — converted with ``.numpy()`` first.
    """
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    if isinstance(value, str):
        # BytesList only accepts bytes; encode text so that string input
        # works instead of raising TypeError inside the protobuf constructor.
        value = value.encode("utf-8")
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Wrap a single float/double ``value`` in a tf.train.Feature (float_list)."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)


def _int64_feature(value):
    """Wrap a single bool/enum/int/uint ``value`` in a tf.train.Feature (int64_list)."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def _int64_list_feature(value):
    """Wrap a sequence of bool/enum/int/uint values in a tf.train.Feature (int64_list).

    Unlike ``_int64_feature``, ``value`` is passed through as the whole list
    rather than being wrapped in a one-element list.
    """
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)


def serialize_train(image, target, image_name):
    """Serialize one labelled sample into ``tf.train.Example`` bytes.

    Args:
        image: Encoded image bytes, stored under "image".
        target: Object with a ``tolist()`` method whose items are integers
            (``save_tfrecord`` passes the row's pandas Series of label
            columns); stored as an int64 list under "target".
        image_name: Sample identifier as bytes, stored under "image_name".

    Returns:
        The serialized Example as a byte string.
    """
    features = tf.train.Features(feature={
        "image": _bytes_feature(image),
        "target": _int64_list_feature(target.tolist()),
        "image_name": _bytes_feature(image_name),
    })
    return tf.train.Example(features=features).SerializeToString()


def serialize_test(image, image_name):
    """Serialize one unlabelled sample into ``tf.train.Example`` bytes.

    Args:
        image: Encoded image bytes, stored under "image".
        image_name: Sample identifier as bytes, stored under "image_name".

    Returns:
        The serialized Example as a byte string.
    """
    features = tf.train.Features(feature={
        "image": _bytes_feature(image),
        "image_name": _bytes_feature(image_name),
    })
    return tf.train.Example(features=features).SerializeToString()


def save_tfrecord(filename, data, train=True):
    """Write each row of ``data`` as one serialized tf.Example into a TFRecord file.

    The image at ``row["path"]`` is loaded with OpenCV and re-encoded as JPEG
    at quality 100 before being stored.

    Args:
        filename: Output TFRecord path.
        data: DataFrame with a "path" column, the ``ID_COL`` column (string
            values) and, when ``train`` is True, every column in ``target_cols``.
        train: If True, also store the row's labels (``serialize_train``);
            otherwise store only the image and its name (``serialize_test``).

    Raises:
        FileNotFoundError: If OpenCV cannot read the image at ``row["path"]``.
        ValueError: If OpenCV fails to JPEG-encode a loaded image.
    """
    with tf.io.TFRecordWriter(filename) as writer:
        # iterrows() is a generator with no len(); give tqdm the total explicitly.
        for _, row in tqdm(data.iterrows(), total=len(data)):
            path = row["path"]
            img = cv2.imread(path)
            if img is None:
                # cv2.imread reports failure by returning None rather than raising;
                # without this check the error shows up later as an opaque cv2 assertion.
                raise FileNotFoundError(f"could not read image: {path}")
            ok, buf = cv2.imencode('.jpg', img, (cv2.IMWRITE_JPEG_QUALITY, 100))
            if not ok:
                raise ValueError(f"failed to JPEG-encode image: {path}")
            img_bytes = buf.tobytes()
            image_name = str.encode(row[ID_COL])
            if train:
                example = serialize_train(img_bytes, row[target_cols], image_name)
            else:
                example = serialize_test(img_bytes, image_name)
            writer.write(example)


# Write SPLITS shards per split, all train shards first:
# train_<k>.tfrecord (with labels), then test_<k>.tfrecord (images and names only).
for split_name, frame, has_labels in (("train", train, True), ("test", test, False)):
    for fold_idx in range(SPLITS):
        shard = frame[frame["fold"] == fold_idx]
        save_tfrecord(f"{split_name}_{fold_idx}.tfrecord", shard, train=has_labels)
