In [4]:
import os
# SM_FRAMEWORK must be set before segmentation_models is imported.
os.environ["SM_FRAMEWORK"] = "tf.keras"  # Force TensorFlow Keras compatibility

import numpy as np
import tensorflow as tf
from segmentation_models import Unet  # ✅ Import pre-trained U-Net
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import ModelCheckpoint


In [5]:
# Sanity check: if the package were missing, this import would raise ImportError
# and the confirmation message below would never be printed.
import segmentation_models
print("Segmentation Models installed successfully!")


Segmentation Models installed successfully!


## Train your Unet with membrane data
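The membrane data lives in `data/membrane/` (training images in `data/membrane/train/image`, masks in `data/membrane/train/label`); this is a per-pixel binary segmentation task.

Images and masks are both loaded as grayscale with shape `(batch_size, rows, cols, 1)`. Before training, the images are replicated to three channels to match the ImageNet-pretrained ResNet34 encoder, while the masks keep their single channel.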

### Train with data generator

In [6]:
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.mixed_precision import set_global_policy

# ✅ Enable Mixed Precision & XLA
# The Keras policy is the supported mixed-precision mechanism in TF2: layers compute in
# float16 and compile() wraps the optimizer with loss scaling. The older grappler
# "auto_mixed_precision" graph rewrite is a separate mechanism and is deliberately NOT
# enabled on top of it.
set_global_policy('mixed_float16')
tf.config.optimizer.set_jit(True)

# ✅ Enable GPU Memory Growth
# Memory growth can only be configured before the GPUs are initialised. Once the
# MirroredStrategy below has touched them, re-running this cell raises RuntimeError,
# so report it instead of aborting the whole cell.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        print(f"Could not enable memory growth on {gpu.name}: {err}")

# ✅ Data Generator Optimization
AUTOTUNE = tf.data.AUTOTUNE

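# ✅ trainGenerator is defined later in this cell. To use a copy from another
# script instead, uncomment the import below (data_loader is not part of this notebook).
# from data_loader import trainGenerator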
def convert_grayscale_to_rgb(generator):
    """Widen single-channel images in an ``(image, mask)`` stream to three channels.

    The ImageNet-pretrained encoder used below takes 3-channel input, so each
    grayscale image batch is replicated along its last (channel) axis. Masks
    are passed through untouched.

    Args:
        generator: Iterable of ``(image, mask)`` pairs whose images are
            channel-last numpy arrays.

    Yields:
        ``(image, mask)`` pairs where ``image`` has 3 channels.
    """
    for image, mask in generator:
        # Only replicate genuine single-channel input so an already-RGB batch is not
        # blown up to 9 channels (e.g. if this wrapper is applied twice).
        if image.shape[-1] == 1:
            image = np.repeat(image, 3, axis=-1)
        yield image, mask
# ✅ Augmentation generator used by trainGenerator below (imported once).
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# ✅ Data augmentation applied identically to images and masks (see trainGenerator).
# Note: rotation_range is in degrees, so 0.2 allows at most ±0.2° of rotation;
# the shift/shear/zoom values are fractions of the image size.
data_gen_args = dict(
    rotation_range=0.2,
    width_shift_range=0.05,
    height_shift_range=0.05,
    shear_range=0.05,
    zoom_range=0.05,
    horizontal_flip=True,
    fill_mode='nearest'
)

# ✅ Paired image/mask generator for the membrane training set
def trainGenerator(batch_size, train_path, image_folder, label_folder, aug_dict, save_to_dir=None,
                   target_size=(256, 256), seed=1,
                   image_save_prefix='image', mask_save_prefix='mask'):
    """Return an iterator of augmented ``(image_batch, mask_batch)`` pairs.

    Images are read from ``train_path/image_folder`` and masks from
    ``train_path/label_folder``, both as grayscale and resized to ``target_size``.

    Args:
        batch_size: Number of samples per batch.
        train_path: Directory containing ``image_folder`` and ``label_folder``.
        image_folder: Sub-directory holding the input images.
        label_folder: Sub-directory holding the masks.
        aug_dict: Keyword arguments for ``ImageDataGenerator``; the same
            augmentation settings are used for images and masks.
        save_to_dir: If given, augmented samples are also written to this directory.
        target_size: ``(rows, cols)`` that every image and mask is resized to.
        seed: Random seed shared by BOTH directory iterators. It must be identical
            so that shuffling order and random transforms stay aligned between each
            image and its mask.
        image_save_prefix: File-name prefix for saved images (with ``save_to_dir``).
        mask_save_prefix: File-name prefix for saved masks (with ``save_to_dir``).

    Returns:
        An iterator yielding ``(image, mask)`` tuples. Images keep the pixel range
        produced by ``aug_dict``; masks are binarised to {0, 1} float32 so they are
        valid targets for binary cross-entropy.
    """
    # NOTE(review): images are not rescaled here; confirm this matches the input
    # range expected by the encoder (see segmentation_models.get_preprocessing).
    image_datagen = ImageDataGenerator(**aug_dict)
    mask_datagen = ImageDataGenerator(**aug_dict)

    image_generator = image_datagen.flow_from_directory(
        train_path,
        classes=[image_folder],
        class_mode=None,
        color_mode='grayscale',
        target_size=target_size,
        batch_size=batch_size,
        save_to_dir=save_to_dir,
        save_prefix=image_save_prefix,
        seed=seed,
    )

    mask_generator = mask_datagen.flow_from_directory(
        train_path,
        classes=[label_folder],
        class_mode=None,
        color_mode='grayscale',
        target_size=target_size,
        batch_size=batch_size,
        save_to_dir=save_to_dir,
        save_prefix=mask_save_prefix,
        seed=seed,
    )

    def _pairs():
        # Iterates lazily; the directory iterators above are built eagerly so that
        # missing folders are reported when trainGenerator is called.
        for image, mask in zip(image_generator, mask_generator):
            # Masks arrive in [0, 255] unless aug_dict rescales them, and augmentation
            # interpolation leaves intermediate values; threshold to a hard {0, 1} target.
            if mask.max() > 1.0:
                mask = mask / 255.0
            mask = (mask > 0.5).astype(np.float32)
            yield image, mask

    return _pairs()

# ✅ Augmented (image, mask) stream from data/membrane/train, with the grayscale
# images widened to the 3 channels the ResNet34 encoder expects.
myGene = convert_grayscale_to_rgb(
    trainGenerator(
        batch_size=16,
        train_path='data/membrane/train',
        image_folder='image',
        label_folder='label',
        aug_dict=data_gen_args,
        save_to_dir=None,
    )
)


# ✅ Initialize U-Net Model (Multi-GPU)
# Model variables must be created inside strategy.scope() for MirroredStrategy to mirror them.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # ResNet34 encoder initialised from ImageNet, single output channel with a sigmoid
    # (segmentation_models' defaults, spelled out because the loss below depends on them).
    model = Unet('resnet34', encoder_weights='imagenet', input_shape=(256, 256, 3),
                 classes=1, activation='sigmoid')
    # The network already ends in a sigmoid, so the loss receives probabilities, not logits.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
                  loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
                  metrics=['accuracy'])

# summary() prints the table itself and returns None, so it is not wrapped in print().
model.summary()

# ✅ Train the model. There is no validation split, so the checkpoint tracks the
# training loss and only overwrites 'unet_membrane.keras' when that loss improves.
steps_per_epoch = 300
model_checkpoint = ModelCheckpoint(
    'unet_membrane.keras',
    monitor='loss',
    verbose=1,
    save_best_only=True,
)
model.fit(
    myGene,
    steps_per_epoch=steps_per_epoch,
    epochs=5,
    callbacks=[model_checkpoint],
)


NameError: name 'trainGenerator' is not defined

### Train with npy file

In [None]:
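# Alternative: train from pre-augmented .npy arrays instead of the generator.
# NOTE(review): geneTrainNpy is not defined in this notebook; if it returns single-channel
# images they need the same grayscale-to-RGB conversion as the generator path.
#imgs_train,imgs_mask_train = geneTrainNpy("data/membrane/train/aug/","data/membrane/train/aug/")
#model.fit(imgs_train, imgs_mask_train, batch_size=2, epochs=10, verbose=1,validation_split=0.2, shuffle=True, callbacks=[model_checkpoint])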

### Test the model and save the predicted results

In [None]:
# NOTE(review): testGenerator and saveResult are not defined anywhere in this notebook;
# they presumably come from the original unet repo's data.py — import them before running.

# Reload the best checkpoint written by the training cell's ModelCheckpoint. Loading the
# full saved model restores the exact ResNet34 architecture that was trained;
# compile=False skips the loss/optimizer, which inference does not need.
model = tf.keras.models.load_model("unet_membrane.keras", compile=False)

# The model takes 3-channel input; replicate single-channel test images to match training.
testGene = (
    np.repeat(img, 3, axis=-1) if img.shape[-1] == 1 else img
    for img in testGenerator("data/membrane/test")
)
# predict_generator is deprecated in TF2; Model.predict accepts generators directly.
results = model.predict(testGene, steps=30, verbose=1)
saveResult("data/membrane/test", results)