<a href="https://colab.research.google.com/github/rozapkk13/unet/blob/master/trainUnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
import os
os.environ["SM_FRAMEWORK"] = "tf.keras"  # Force TensorFlow Keras compatibility

import tensorflow as tf
from segmentation_models import Unet  # ✅ Import pre-trained U-Net
from tensorflow.keras.callbacks import ModelCheckpoint


In [9]:
import segmentation_models
print("Segmentation Models installed successfully!")


Segmentation Models installed successfully!


## Train your Unet with membrane data
membrane data is in folder membrane/, it is a binary classification task.

The input shape of image and mask are the same :(batch_size,rows,cols,channel = 1)

### Train with data generator

In [10]:
import tensorflow as tf
import numpy as np  # ✅ Fix NameError: np is not defined

from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.mixed_precision import set_global_policy

# ✅ Enable Mixed Precision & XLA
set_global_policy('mixed_float16')
tf.config.optimizer.set_jit(True)
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

# ✅ Enable GPU Memory Growth
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

# ✅ Data Generator Optimization
AUTOTUNE = tf.data.AUTOTUNE

# ✅ Import trainGenerator if it's in another script
# from data_loader import trainGenerator
def convert_grayscale_to_rgb(generator):
    for batch in generator:
        image, mask = batch
        image = np.repeat(image, 3, axis=-1)  # ✅ Convert grayscale to RGB
        yield image, mask
# ✅ Define it here if missing
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# ✅ Define Data Augmentation Arguments
data_gen_args = dict(
    rotation_range=0.2,
    width_shift_range=0.05,
    height_shift_range=0.05,
    shear_range=0.05,
    zoom_range=0.05,
    horizontal_flip=True,
    fill_mode='nearest'
)

# ✅ Fix trainGenerator indentation
def trainGenerator(batch_size, train_path, image_folder, label_folder, aug_dict, save_to_dir=None):
    image_datagen = ImageDataGenerator(**aug_dict)
    mask_datagen = ImageDataGenerator(**aug_dict)

    image_generator = image_datagen.flow_from_directory(
        train_path,
        classes=[image_folder],
        class_mode=None,
        color_mode='grayscale',
        target_size=(256, 256),
        batch_size=batch_size,
        save_to_dir=save_to_dir
    )

    mask_generator = mask_datagen.flow_from_directory(
        train_path,
        classes=[label_folder],
        class_mode=None,
        color_mode='grayscale',
        target_size=(256, 256),
        batch_size=batch_size,
        save_to_dir=save_to_dir
    )

    train_generator = zip(image_generator, mask_generator)
    return train_generator

# ✅ Now it should work without indentation errors
myGene = trainGenerator(16, 'data/membrane/train', 'image', 'label', data_gen_args, save_to_dir=None)


myGene = convert_grayscale_to_rgb(myGene)


# ✅ Initialize U-Net Model (Multi-GPU)
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = Unet('resnet34', encoder_weights='imagenet', input_shape=(256, 256, 3))
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
                  loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
                  metrics=['accuracy'])

print(model.summary())

# ✅ Train the model
model_checkpoint = ModelCheckpoint('unet_membrane.keras', monitor='loss', verbose=1, save_best_only=True)
steps_per_epoch = 500
model.fit(myGene, steps_per_epoch=steps_per_epoch, epochs=5, callbacks=[model_checkpoint])


Found 0 images belonging to 1 classes.
Found 0 images belonging to 1 classes.
Downloading data from https://github.com/qubvel/classification_models/releases/download/0.0.1/resnet34_imagenet_1000_no_top.h5
[1m85521592/85521592[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


None
Epoch 1/5
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 96ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00
Epoch 1: loss improved from inf to 0.00000, saving model to unet_membrane.keras
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m74s[0m 107ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00
Epoch 2/5
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 93ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00
Epoch 2: loss did not improve from 0.00000
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m46s[0m 93ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00
Epoch 3/5
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 100ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00
Epoch 3: loss did not improve from 0.00000
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m50s[0m 100ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00
Epoch 4/5
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 92ms/step

<keras.src.callbacks.history.History at 0x7f1cc1417990>

### Train with npy file

In [11]:
#imgs_train,imgs_mask_train = geneTrainNpy("data/membrane/train/aug/","data/membrane/train/aug/")
#model.fit(imgs_train, imgs_mask_train, batch_size=2, nb_epoch=10, verbose=1,validation_split=0.2, shuffle=True, callbacks=[model_checkpoint])

### test your model and save predicted results

In [17]:
import os
import numpy as np
import cv2
import tensorflow as tf
from skimage import io, transform
from google.colab import files
import zipfile

# ✅ Step 1: Upload and extract dataset
uploaded = files.upload()  # Upload ZIP file manually

zip_filename = list(uploaded.keys())[0]  # Get uploaded file name
extract_path = "/content/data/membrane"  # Destination folder

# Extract ZIP file
with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

test_path = os.path.join(extract_path, "test")  # Path to test images
print(f"✅ Path exists: {test_path}")
print(f"📂 Contents: {os.listdir(test_path)}")

# ✅ Step 2: Define test image generator
def testGenerator(test_path, target_size=(256, 256)):
    """
    Loads test images, ensures correct format, resizes, normalizes, and yields images for prediction.
    """
    for file_name in os.listdir(test_path):
        img_path = os.path.join(test_path, file_name)
        img = io.imread(img_path)

        # Convert grayscale images to RGB
        if len(img.shape) == 2:  # If grayscale
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

        # Resize image
        img = transform.resize(img, target_size, mode='constant', anti_aliasing=True)

        # Normalize to [0,1]
        img = img / 255.0

        # Expand dimensions to match model input shape
        img = np.expand_dims(img, axis=0)

        yield (img,)  # Return as a tuple

testGene = testGenerator(test_path)

print(os.listdir("/content/"))  # List all files in /content/
model_path = "/content/unet_membrane.keras"  # Ensure the correct path
print("✅ Model file exists:", os.path.exists(model_path))

# ✅ Load Model
model = tf.keras.models.load_model(model_path)
print("✅ Model Loaded Successfully!")

# ✅ Step 3: Run Predictions
results = model.predict(testGene, steps=30, verbose=1)

# ✅ Step 4: Save Results
def saveResult(save_path, npyfile):
    """
    Saves predicted images to the output folder.
    """
    os.makedirs(save_path, exist_ok=True)
    for i, img in enumerate(npyfile):
        img = np.squeeze(img)  # Remove extra channel dimension
        img = (img * 255).astype(np.uint8)  # Convert back to 0-255
        save_filename = os.path.join(save_path, f"{i}_predict.png")
        cv2.imwrite(save_filename, img)
        print(f"✅ Saved: {save_filename}")

# ✅ Step 5: Save predicted results
saveResult(test_path, results)

print("🎉 Prediction Complete! Check saved images in:", test_path)


Saving test.zip to test (3).zip
✅ Path exists: /content/data/membrane/test
📂 Contents: ['13_predict.png', '11_predict.png', '26_predict.png', '16.png', '17_predict.png', '4_predict.png', '24_predict.png', '13.png', '23.png', '16_predict.png', '6_predict.png', '22_predict.png', '29_predict.png', '2_predict.png', '10_predict.png', '12_predict.png', '8.png', '2.png', '1.png', '4.png', '19.png', '7_predict.png', '1_predict.png', '27_predict.png', '9_predict.png', '3.png', '18_predict.png', '18.png', '20.png', '14.png', '29.png', '9.png', '5_predict.png', '22.png', '19_predict.png', '26.png', '15_predict.png', '7.png', '21_predict.png', '6.png', '25.png', '11.png', '0_predict.png', '3_predict.png', '28.png', '8_predict.png', '5.png', '24.png', '10.png', '14_predict.png', '0.png', '12.png', '28_predict.png', '23_predict.png', '25_predict.png', '20_predict.png', '21.png', '17.png', '15.png', '27.png']
['.config', 'unet_membrane.keras', 'test.zip', 'test (1).zip', 'test (2).zip', 'test (3).zip

  img = (img * 255).astype(np.uint8)  # Convert back to 0-255
