In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt

In [2]:
# Fix TensorFlow's global random seed so that TF-level random operations
# start from the same state on every fresh run of the notebook.
tf.random.set_seed(42)

In [3]:
# Root directory of the dataset. The trailing slash matters: the paths below
# are formed by plain string concatenation.
data_path = 'data/'

# CSV that is fed to the input pipeline. Other files that have been used here:
#   'train_ship_segmentations_v2.csv'
#   'demo.csv'
#   'demo_clean.csv'
train_csv_file_path = f'{data_path}train_ship_segmentations_v2_clean.csv'

# Directory that the image file names from the CSV are resolved against.
train_image_path = f'{data_path}train_v2/'

In [4]:
def load_csv(path):
    """Create a `tf.data` dataset that streams the rows of a CSV file.

    Args:
        path: Location of the CSV file, forwarded to
            `tf.data.experimental.make_csv_dataset` as its file pattern.

    Returns:
        The dataset built by `make_csv_dataset`. Columns are exposed under
        the names 'ImageId' and 'EncodedPixels', the file is read exactly
        once (`num_epochs=1`) and in its original order (`shuffle=False`).
    """
    csv_options = dict(
        batch_size=1,  # mandatory argument of make_csv_dataset
        column_names=['ImageId', 'EncodedPixels'],
        num_epochs=1,
        shuffle=False,
    )
    return tf.data.experimental.make_csv_dataset(path, **csv_options)

# Dataset over the rows of the training CSV; every element is a batch of one
# row because load_csv uses batch_size=1.
train_csv = load_csv(train_csv_file_path)

# Debug helper: uncomment to print the first element of the CSV dataset.
# for batch in train_csv.take(1):
#     print(batch)

2024-02-09 12:31:10.873361: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2024-02-09 12:31:10.873379: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 32.00 GB
2024-02-09 12:31:10.873383: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 10.67 GB
2024-02-09 12:31:10.873413: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-02-09 12:31:10.873427: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [5]:
# Spatial size used when decoding masks and as the model's default input size.
# NOTE(review): presumably equal to the size of the dataset's images - confirm.
IMG_HEIGHT = IMG_WIDTH = 768

@tf.function
def process_img(file):
    """Read an image file and decode it as a 3-channel JPEG.

    Args:
        file: Image file name; it is appended to the module-level
            `train_image_path` to obtain the path that is read.

    Returns:
        The decoded image tensor with 3 channels.
    """
    jpeg_bytes = tf.io.read_file(train_image_path + file)
    return tf.io.decode_jpeg(jpeg_bytes, channels=3)

@tf.function
def decode_label_mask(encoded_pixels, image_height, image_width):
    """Decode a run-length-encoded string into a 2-D binary mask.

    `encoded_pixels` is a whitespace-separated sequence of `start length`
    pairs. Starts are 1-based offsets into the flattened image and are
    treated as column-major (running down each column first), which is why
    the flat buffer is reshaped as (width, height) and then transposed.

    Args:
        encoded_pixels: Scalar string tensor holding the encoded runs. An
            empty string produces an all-zero mask.
        image_height: Number of rows of the resulting mask.
        image_width: Number of columns of the resulting mask.

    Returns:
        A float32 tensor of shape (image_height, image_width) that is 1.0 on
        every encoded pixel and 0.0 everywhere else.
    """
    flat_mask = tf.zeros([image_height * image_width], dtype=tf.float32)

    # "s1 l1 s2 l2 ..." -> int32 vector [s1, l1, s2, l2, ...]
    pairs = tf.strings.to_number(tf.strings.split(encoded_pixels), out_type=tf.int32)
    starts = pairs[0::2] - 1  # 1-based starts -> 0-based offsets
    run_lengths = pairs[1::2]

    # Expand all runs into their pixel offsets at once and mark them with a
    # single scatter, instead of one scatter update per run.
    pixel_offsets = tf.ragged.range(starts, starts + run_lengths).flat_values
    flat_mask = tf.tensor_scatter_nd_update(
        flat_mask,
        indices=tf.expand_dims(pixel_offsets, axis=1),
        updates=tf.ones_like(pixel_offsets, dtype=tf.float32),
    )

    # Column-major offsets mean the flat buffer is laid out as (width, height);
    # the transpose turns that into the conventional (height, width) mask.
    # Reshaping as (height, width) first is only equivalent for square images.
    return tf.transpose(tf.reshape(flat_mask, (image_width, image_height)))

@tf.function
def process_label(label):
    """Decode an encoded-pixels string into a mask of IMG_HEIGHT x IMG_WIDTH.

    Args:
        label: Encoded-pixels value, passed unchanged to `decode_label_mask`.

    Returns:
        Whatever `decode_label_mask` returns for the configured image size.
    """
    return decode_label_mask(
        encoded_pixels=label,
        image_height=IMG_HEIGHT,
        image_width=IMG_WIDTH,
    )

@tf.function
def process_batch(csv_item):
    """Convert one CSV row into an (image, mask) pair.

    Args:
        csv_item: Mapping that provides the 'ImageId' and 'EncodedPixels'
            values of a single row.

    Returns:
        Tuple `(image, mask)`: the image loaded by `process_img` and the mask
        decoded by `process_label`.
    """
    return process_img(csv_item['ImageId']), process_label(csv_item['EncodedPixels'])

def plot_ds_element(background, overlay):
    """Display `overlay` semi-transparently on top of `background`.

    A new figure is created on every call.

    Args:
        background: Image drawn first, fully opaque.
        overlay: Image drawn over the background with alpha 0.3.
    """
    # Draw both layers through the Axes created here rather than through
    # pyplot's implicit "current axes", so the two images always share a plot.
    _, ax = plt.subplots()
    ax.imshow(background)
    ax.imshow(overlay, alpha=0.3)

# Per-sample dataset of (image, mask) pairs: undo the batch-of-one produced by
# load_csv, then decode every row. Rows are decoded in parallel; the element
# order is unchanged because tf.data maps are deterministic by default.
train_ds = train_csv.unbatch().map(process_batch, num_parallel_calls=tf.data.AUTOTUNE)

# Debug helper: count how many batches of 32 the dataset yields (capped at 1000).
# a = train_ds.batch(32).take(1000).reduce(0, lambda x, _: x + 1).numpy()
# print(a)

# Debug helper: plot the third sample with its mask overlaid.
# for item, label in train_ds.skip(2).take(1):
#     plot_ds_element(item, label)


In [6]:
def dice_coefficient(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient computed over all elements of the inputs.

    dice = (2 * sum(y_true * y_pred) + smooth)
           / (sum(y_true) + sum(y_pred) + smooth)

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor; multiplied element-wise with `y_true`.
        smooth: Constant added to numerator and denominator, which keeps the
            ratio defined when both sums are zero.

    Returns:
        Scalar tensor holding the Dice coefficient.
    """
    overlap = tf.reduce_sum(y_true * y_pred)
    # Sum of both totals (not a set union): the denominator of the Dice formula.
    total = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    numerator = 2. * overlap + smooth
    denominator = total + smooth
    return numerator / denominator

def dice_loss(y_true, y_pred):
    """Dice loss: one minus `dice_coefficient(y_true, y_pred)`.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor.

    Returns:
        Scalar tensor; smaller values correspond to a higher Dice coefficient.
    """
    dice = dice_coefficient(y_true, y_pred)
    return 1 - dice

In [7]:
def _double_conv(x, filters):
    """Apply two 3x3, same-padded, ReLU-activated convolutions with `filters` channels."""
    x = tf.keras.layers.Conv2D(filters, 3, activation='relu', padding='same')(x)
    return tf.keras.layers.Conv2D(filters, 3, activation='relu', padding='same')(x)


def _upsample_and_merge(x, skip, filters):
    """Double the spatial size with a transposed convolution, then concatenate `skip` on the channel axis."""
    x = tf.keras.layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
    return tf.keras.layers.concatenate([x, skip], axis=-1)


def unet(input_size=(IMG_HEIGHT, IMG_WIDTH, 3)):
    """Build a U-Net style model with three down-sampling and three up-sampling stages.

    Inputs are rescaled by 1/255 inside the model, passed through an encoder
    (64, 128, 256 filters), a 512-filter bottleneck and a mirrored decoder
    with skip connections, and finally mapped to one sigmoid channel by a
    1x1 convolution.

    Args:
        input_size: `(height, width, channels)` of a single input image.
            Height and width are halved three times by 2x2 max-pooling and
            doubled three times by the decoder, so they should be divisible
            by 8 for the skip connections to line up.

    Returns:
        An uncompiled `tf.keras` model whose output has the same height and
        width as the input and a single channel.
    """
    # `Rescaling` is available directly under `tf.keras.layers` in current
    # releases; the `experimental.preprocessing` path is kept only as a
    # fallback for older installations that lack the stable name.
    rescaling_cls = getattr(tf.keras.layers, 'Rescaling', None)
    if rescaling_cls is None:
        rescaling_cls = tf.keras.layers.experimental.preprocessing.Rescaling

    inputs = tf.keras.Input(shape=input_size)

    norm_inputs = rescaling_cls(1./255)(inputs)  # TODO: move to preprocessing

    # Encoder
    conv1 = _double_conv(norm_inputs, 64)
    pool1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = _double_conv(pool1, 128)
    pool2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = _double_conv(pool2, 256)
    pool3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv3)

    # Bottom
    conv4 = _double_conv(pool3, 512)

    # Decoder: each stage upsamples, merges the matching encoder output and
    # refines the result with another pair of convolutions.
    up5 = _upsample_and_merge(conv4, conv3, 256)
    conv5 = _double_conv(up5, 256)

    up6 = _upsample_and_merge(conv5, conv2, 128)
    conv6 = _double_conv(up6, 128)

    up7 = _upsample_and_merge(conv6, conv1, 64)
    conv7 = _double_conv(up7, 64)

    # Per-pixel value in [0, 1] from a 1x1 convolution with sigmoid activation.
    outputs = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(conv7)

    model = tf.keras.models.Model(inputs=[inputs], outputs=[outputs])

    return model

# Build the network with its default input size.
model = unet()

# Optimise with Adam against the Dice loss and report the Dice coefficient
# as a metric during training.
model.compile(
    optimizer='adam',
    loss=dice_loss,
    metrics=[dice_coefficient],
)


In [8]:
# Small-scale run: keep at most 10 samples, group them into batches of up to
# 4 and prefetch so input preparation can overlap with training.
# NOTE(review): this rebinds `train_ds`, so running the cell a second time
# would batch the already-batched dataset; re-run the cell that first defines
# `train_ds` before repeating this one.
train_ds = (
    train_ds
    .take(10)
    .batch(4)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

model.fit(train_ds, epochs=1000)

Epoch 1/1000


2024-02-09 12:31:17.234479: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


Epoch 2/1000


2024-02-09 12:31:24.513200: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11074633200316695256
2024-02-09 12:31:24.513235: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12400806594822815843


Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000
Epoch 72/1000
Epoch 73/1000
Epoch 74/1000

KeyboardInterrupt: 