In [None]:
# Import Required Libraries
import os, gc, math, cv2, pandas as pd, numpy as np, tensorflow as tf
from skimage.io import imread
from skimage.transform import resize
from sklearn.model_selection import train_test_split

# Notebook housekeeping: silence every warning category so cell output stays
# readable, and make sure the cyclic garbage collector is switched on.
import warnings

warnings.filterwarnings('ignore')
gc.enable()

In [None]:
# Global parameters & paths (everything a reader may want to tune lives here)

# --- Inputs ---
TRAIN_DIR = "train_v2"                        # directory holding the source images
CSV_PATH = "train_ship_segmentations_v2.csv"  # CSV with ImageId / EncodedPixels columns

# --- Preprocessing and sampling ---
IMAGE_SIZE = (384, 384)    # output shape handed to skimage's resize() for images and masks
SAMPLED_NO_SHIPS = 8000    # how many ship-free images are drawn into the dataset
SEED = 42                  # random seed used for the train/validation split

# --- Outputs ---
NUM_SHARDS = 30            # default number of TFRecord files per split
OUTPUT_DIR = "./tfrecords"
os.makedirs(OUTPUT_DIR, exist_ok=True)  # create the output folder up front; no error if it exists

In [None]:
def rle_decode(mask_rle, shape=(768, 768)):
    """
    Turn a run-length encoded (RLE) string into a 2D binary mask.

    The string is a whitespace-separated sequence of ``start length`` pairs,
    where ``start`` is a 1-based index into the flattened mask.

    Parameters:
        mask_rle (str): The run-length encoded mask string.
        shape (tuple): Shape used to reshape the flat mask before transposing.

    Returns:
        mask (ndarray): uint8 array of 0/1 values. The flat mask is reshaped to
        ``shape`` and then transposed, so the result has shape
        ``(shape[1], shape[0])`` (identical to ``shape`` for square masks).
    """
    tokens = np.asarray(mask_rle.split(), dtype=int)
    # Even positions are the 1-based run starts, odd positions the run lengths.
    run_starts = tokens[0::2] - 1
    run_ends = run_starts + tokens[1::2]

    flat_mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat_mask[begin:end] = 1

    return flat_mask.reshape(shape).T

In [None]:
def preprocess(image, mask):
    """
    Bring an image and its mask to the common IMAGE_SIZE.

    Parameters:
        image (ndarray): Original image array.
        mask (ndarray): Original mask array.

    Returns:
        image, mask: The resized image cast to uint8, and the resized mask
        re-binarised to 0/1 and cast to uint8.
    """
    # Image: anti-aliased resize, keeping the original value range.
    resized_image = resize(image, IMAGE_SIZE, preserve_range=True, anti_aliasing=True)
    image_u8 = resized_image.astype(np.uint8)

    # Mask: order=0 with no anti-aliasing so no in-between values are created.
    resized_mask = resize(mask, IMAGE_SIZE, preserve_range=True, anti_aliasing=False, order=0)
    mask_u8 = resized_mask.astype(np.uint8)
    binary_mask = (mask_u8 > 0).astype(np.uint8)

    return image_u8, binary_mask

def create_tf_example(image_id, df):
    """
    Create a TFRecord Example for a given image.

    Reads the image and its RLE-encoded mask rows, merges the decoded masks into
    one binary mask, resizes both via ``preprocess`` and serialises them as
    JPEG (image, quality 90) and PNG (mask).

    Parameters:
        image_id (str): Filename of the image inside TRAIN_DIR.
        df (DataFrame): DataFrame with 'ImageId' and 'EncodedPixels' columns.

    Returns:
        tf.train.Example with the features:
            'image_id' (bytes)  - the encoded filename,
            'image'    (bytes)  - JPEG-encoded image,
            'mask'     (bytes)  - PNG-encoded 0/1 mask,
            'has_ship' (int64)  - 1 if the image has at least one non-null RLE row.
        Returns None if the image file is missing, cannot be read, is not a
        3-channel image, or cannot be encoded.
    """
    image_path = os.path.join(TRAIN_DIR, image_id)
    if not os.path.exists(image_path):
        return None

    # A corrupted file makes imread raise; skip that single image instead of
    # aborting the whole shard-writing run.
    try:
        image = imread(image_path)
    except Exception as exc:
        print(f"Skipping unreadable image {image_id}: {exc}")
        return None
    if image.ndim != 3 or image.shape[2] != 3:
        return None

    rles = df.loc[df['ImageId'] == image_id, 'EncodedPixels'].dropna()
    mask = np.zeros((768, 768), dtype=np.uint8)
    for rle in rles:
        # Element-wise maximum is a logical OR for 0/1 masks and, unlike
        # summing uint8 values, can never wrap around.
        np.maximum(mask, rle_decode(rle), out=mask)

    image, mask = preprocess(image, mask)

    # skimage's imread yields RGB, while OpenCV encoders interpret 3-channel
    # input as BGR. Convert first so the stored JPEG decodes with correct colours.
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    img_ok, img_buf = cv2.imencode(".jpg", image_bgr, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
    mask_ok, mask_buf = cv2.imencode(".png", mask)
    if not (img_ok and mask_ok):
        print(f"Skipping image {image_id}: encoding failed")
        return None

    feature = {
        'image_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_id.encode()])),
        'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_buf.tobytes()])),
        'mask': tf.train.Feature(bytes_list=tf.train.BytesList(value=[mask_buf.tobytes()])),
        'has_ship': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(len(rles) > 0)]))
    }

    return tf.train.Example(features=tf.train.Features(feature=feature))


In [None]:
def write_tfrecord(filename, image_ids, df):
    """
    Serialise a batch of images into one GZIP-compressed TFRecord file.

    Each image id is converted with ``create_tf_example``; ids for which no
    example is produced are skipped.

    Parameters:
        filename (str): Output filename for the TFRecord.
        image_ids (list): List of image filenames to be written.
        df (DataFrame): DataFrame containing RLE mask information.

    Prints the filename together with the number of samples written.
    """
    gzip_options = tf.io.TFRecordOptions(compression_type='GZIP')
    written = 0
    with tf.io.TFRecordWriter(filename, options=gzip_options) as record_writer:
        for image_id in image_ids:
            record = create_tf_example(image_id, df)
            if not record:
                # No example for this id; leave it out of the shard.
                continue
            record_writer.write(record.SerializeToString())
            written += 1
    print(f"{filename} - {written} samples")

def write_shards(prefix, image_ids, df, num_shards=NUM_SHARDS):
    """
    Split the dataset into multiple shards and write each shard as a TFRecord file.

    Ids are distributed as evenly as possible: the first
    ``len(image_ids) % num_shards`` shards receive one extra id. Splitting into
    ceil-sized chunks instead can use up all ids early and leave the trailing
    shards empty (e.g. 100 ids over 30 shards -> 25 shards of 4 and 5 empty
    files). Exactly ``num_shards`` files are always written, so the
    "-of-{num_shards}" part of the filenames stays accurate; shards are only
    empty when there are fewer ids than shards.

    Parameters:
        prefix (str): Prefix for the output filename (e.g., 'train' or 'val').
        image_ids (list): List of image filenames to be processed.
        df (DataFrame): DataFrame with RLE data.
        num_shards (int): Number of shards/files to create (must be >= 1).

    Raises:
        ValueError: If num_shards is smaller than 1.
    """
    if num_shards < 1:
        raise ValueError(f"num_shards must be >= 1, got {num_shards}")

    base_size, remainder = divmod(len(image_ids), num_shards)
    start = 0
    for i in range(num_shards):
        # The first `remainder` shards absorb the leftover ids, one each.
        stop = start + base_size + (1 if i < remainder else 0)
        shard_ids = image_ids[start:stop]
        start = stop
        fname = os.path.join(OUTPUT_DIR, f"{prefix}_sharded_{i+1:02d}-of-{num_shards}.tfrecord.gz")
        write_tfrecord(fname, shard_ids, df)

In [None]:
# Load CSV & Split Data, then Write TFRecord Shards
df = pd.read_csv(CSV_PATH)
df['has_ship'] = df['EncodedPixels'].notnull()

# Images with at least one mask row vs. images that only have null rows.
ship_ids = df[df['has_ship']]['ImageId'].unique()
no_ship_ids = np.setdiff1d(df['ImageId'].unique(), ship_ids)

# Draw the ship-free subsample from a generator seeded with SEED so the dataset
# is reproducible across kernel restarts (an unseeded draw changes every run).
# The sample size is capped because sampling without replacement raises when
# more items are requested than exist.
rng = np.random.default_rng(SEED)
n_no_ship = min(SAMPLED_NO_SHIPS, len(no_ship_ids))
sampled_no_ship = rng.choice(no_ship_ids, size=n_no_ship, replace=False)

filtered_ids = np.concatenate([ship_ids, sampled_no_ship])
train_ids, val_ids = train_test_split(filtered_ids, test_size=0.1, random_state=SEED)

print(f"Train: {len(train_ids)} images")
print(f"Val: {len(val_ids)} images")

write_shards("train", train_ids, df)
write_shards("val", val_ids, df)

Train: 45500 images
Val: 5056 images
✅ ./tfrecords/train_sharded_01-of-30.tfrecord.gz - 1517 samples
✅ ./tfrecords/train_sharded_02-of-30.tfrecord.gz - 1517 samples
✅ ./tfrecords/train_sharded_03-of-30.tfrecord.gz - 1517 samples
✅ ./tfrecords/train_sharded_04-of-30.tfrecord.gz - 1517 samples
✅ ./tfrecords/train_sharded_05-of-30.tfrecord.gz - 1517 samples
✅ ./tfrecords/train_sharded_06-of-30.tfrecord.gz - 1517 samples
✅ ./tfrecords/train_sharded_07-of-30.tfrecord.gz - 1517 samples
✅ ./tfrecords/train_sharded_08-of-30.tfrecord.gz - 1517 samples
✅ ./tfrecords/train_sharded_09-of-30.tfrecord.gz - 1517 samples
✅ ./tfrecords/train_sharded_10-of-30.tfrecord.gz - 1517 samples
✅ ./tfrecords/train_sharded_11-of-30.tfrecord.gz - 1517 samples
✅ ./tfrecords/train_sharded_12-of-30.tfrecord.gz - 1517 samples
✅ ./tfrecords/train_sharded_13-of-30.tfrecord.gz - 1517 samples
✅ ./tfrecords/train_sharded_14-of-30.tfrecord.gz - 1517 samples
✅ ./tfrecords/train_sharded_15-of-30.tfrecord.gz - 1517 samples
✅ .