In [1]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
import os
import pandas as pd
from keras_unet.models import custom_unet
from keras.optimizers import Adam
from keras.metrics import MeanIoU

# Define image dimensions and batch size
img_height = 224
img_width = 224
batch_size = 32

# Define data generators
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

test_datagen = ImageDataGenerator(rescale=1./255)

# Define the root directory containing subfolders for each class
root_dir = 'Osteosarcoma-UT'

# Get the list of class names
classes = os.listdir(root_dir)

# Split the data into training, validation, and test sets
train_data = []
val_data = []
test_data = []

for cls in classes:
    cls_dir = os.path.join(root_dir, cls)
    images = [os.path.join(cls_dir, img) for img in os.listdir(cls_dir)]
    train_images, temp_images = train_test_split(images, test_size=0.2, random_state=42)
    val_images, test_images = train_test_split(temp_images, test_size=0.5, random_state=42)
    
    train_data.extend([(img, cls) for img in train_images])
    val_data.extend([(img, cls) for img in val_images])
    test_data.extend([(img, cls) for img in test_images])

# Create DataFrames for train, validation, and test sets
train_df = pd.DataFrame(train_data, columns=['file_path', 'label'])
val_df = pd.DataFrame(val_data, columns=['file_path', 'label'])
test_df = pd.DataFrame(test_data, columns=['file_path', 'label'])

# Define generators for train, validation, and test sets
train_generator = train_datagen.flow_from_dataframe(
    dataframe=train_df,
    x_col='file_path',
    y_col='label',
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='input',  # For segmentation, use 'input' mode
    shuffle=True
)

validation_generator = test_datagen.flow_from_dataframe(
    dataframe=val_df,
    x_col='file_path',
    y_col='label',
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='input',  # For segmentation, use 'input' mode
    shuffle=False
)

test_generator = test_datagen.flow_from_dataframe(
    dataframe=test_df,
    x_col='file_path',
    y_col='label',
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='input',  # For segmentation, use 'input' mode
    shuffle=False
)

# Define the input shape
input_shape = (img_height, img_width, 3)

# Define the number of classes
num_classes = len(classes)

# Load the pre-trained U-Net model
model = custom_unet(
    input_shape,
    filters=64,
    num_classes=num_classes,
    output_activation='softmax'  # Use softmax activation for multi-class segmentation
)

# Compile the model
model.compile(
    optimizer=Adam(), 
    loss='categorical_crossentropy',  # Use categorical cross-entropy for multi-class segmentation
    metrics=['accuracy', MeanIoU(num_classes=num_classes)]
)

# Print model summary
model.summary()

# Define early stopping callback
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Train the model
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=100,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    callbacks=[early_stopping]
)


-----------------------------------------
keras-unet init: TF version is >= 2.0.0 - using `tf.keras` instead of `Keras`
-----------------------------------------
Found 914 validated image filenames.
Found 114 validated image filenames.
Found 116 validated image filenames.


2024-02-21 15:57:26.799477: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1
2024-02-21 15:57:26.799501: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 8.00 GB
2024-02-21 15:57:26.799507: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 2.67 GB
2024-02-21 15:57:26.799787: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-02-21 15:57:26.800337: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 224, 224, 3)]        0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, 224, 224, 64)         1728      ['input_1[0][0]']             
                                                                                                  
 batch_normalization (Batch  (None, 224, 224, 64)         256       ['conv2d[0][0]']              
 Normalization)                                                                                   
                                                                                                  
 spatial_dropout2d (Spatial  (None, 224, 224, 64)         0         ['batch_normalization[0][0

2024-02-21 15:57:30.607121: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
2024-02-21 15:57:30.858303: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] model_pruner failed: INVALID_ARGUMENT: Graph does not contain terminal node Adam/AssignAddVariableOp.
