In [1]:
import os
import h5py
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import h5py
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, accuracy_score

In [2]:

# --- Step 1: Load Data ---
def load_h5_data(directory, is_training=True, target_size=(256, 256), max_samples=None):
    images, masks = [], []
    sample_count = 0

    for filename in os.listdir(directory):
        if max_samples and sample_count >= max_samples:
            break
        if filename.endswith('.h5'):
            try:
                with h5py.File(os.path.join(directory, filename), 'r') as f:
                    image = f['image'][:]
                    if is_training:
                        mask = f['label'][:]

                    if image.ndim == 3:
                        for i in range(image.shape[0]):
                            img = tf.image.resize(np.expand_dims(image[i], -1), target_size).numpy()
                            images.append(img)
                            if is_training:
                                msk = tf.image.resize(np.expand_dims(mask[i], -1), target_size, method='nearest').numpy()
                                masks.append(np.squeeze(msk))
                    elif image.ndim == 2:
                        img = tf.image.resize(np.expand_dims(image, -1), target_size).numpy()
                        images.append(img)
                        if is_training:
                            msk = tf.image.resize(np.expand_dims(mask, -1), target_size, method='nearest').numpy()
                            masks.append(np.squeeze(msk))
                    sample_count += 1
            except Exception as e:
                print(f"Error loading {filename}: {e}")
                continue
    images = np.array(images, dtype=np.float32)
    masks = np.array(masks, dtype=np.uint8) if is_training else None
    return images, masks


In [3]:
# Load datasets
train_dir = 'E:/TEJA/NEW/cardiac/dataset/ACDC_preprocessed/ACDC_training_slices'
test_dir = 'E:/TEJA/NEW/cardiac/dataset/ACDC_preprocessed/ACDC_testing_volumes'

train_images, train_masks = load_h5_data(train_dir, is_training=True, target_size=(256, 256), max_samples=1500)
test_images, _ = load_h5_data(test_dir, is_training=False, target_size=(256, 256), max_samples=100)

# Normalize
train_images = train_images / np.max(train_images)
test_images = test_images / np.max(test_images)

# Train-validation split
X_train, X_val, y_train, y_val = train_test_split(train_images, train_masks, test_size=0.2, random_state=42)

print(f"Training shape: {X_train.shape}, Validation shape: {X_val.shape}, Test shape: {test_images.shape}")


Training shape: (1200, 256, 256, 1), Validation shape: (300, 256, 256, 1), Test shape: (1076, 256, 256, 1)


In [4]:
def attention_block(x, g, inter_channels):
    theta_x = layers.Conv2D(inter_channels, 1)(x)
    phi_g = layers.Conv2D(inter_channels, 1)(g)
    add = layers.Activation('relu')(layers.add([theta_x, phi_g]))
    psi = layers.Conv2D(1, 1)(add)
    psi = layers.Activation('sigmoid')(psi)
    return layers.Multiply()([x, psi])

def conv_block(x, filters):
    x = layers.Conv2D(filters, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filters, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    return x

def attention_unet(input_shape=(256, 256, 1), num_classes=4):
    inputs = layers.Input(shape=input_shape)

    # Encoder
    c1 = conv_block(inputs, 64)
    p1 = layers.MaxPooling2D((2, 2))(c1)
    c2 = conv_block(p1, 128)
    p2 = layers.MaxPooling2D((2, 2))(c2)
    c3 = conv_block(p2, 256)
    p3 = layers.MaxPooling2D((2, 2))(c3)
    c4 = conv_block(p3, 512)
    p4 = layers.MaxPooling2D((2, 2))(c4)

    # Bottleneck
    c5 = conv_block(p4, 1024)

    # Decoder + Attention
    u4 = layers.UpSampling2D((2, 2))(c5)
    att4 = attention_block(c4, u4, 512)
    u4 = layers.Concatenate()([u4, att4])
    c6 = conv_block(u4, 512)

    u3 = layers.UpSampling2D((2, 2))(c6)
    att3 = attention_block(c3, u3, 256)
    u3 = layers.Concatenate()([u3, att3])
    c7 = conv_block(u3, 256)

    u2 = layers.UpSampling2D((2, 2))(c7)
    att2 = attention_block(c2, u2, 128)
    u2 = layers.Concatenate()([u2, att2])
    c8 = conv_block(u2, 128)

    u1 = layers.UpSampling2D((2, 2))(c8)
    att1 = attention_block(c1, u1, 64)
    u1 = layers.Concatenate()([u1, att1])
    c9 = conv_block(u1, 64)

    outputs = layers.Conv2D(num_classes, (1, 1), activation='softmax')(c9)

    model = models.Model(inputs, outputs)
    return model


In [5]:

# Instantiate model
model = attention_unet(input_shape=(256, 256, 1), num_classes=4)

# Print model architecture summary
model.summary()


In [6]:

# EarlyStopping: stop if no improvement for 10 epochs
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True
)

# Learning Rate Decay: Exponential decay every 10 epochs
def lr_schedule(epoch, lr):
    if epoch > 0 and epoch % 10 == 0:
        return lr * 0.8  # reduce LR by 20%
    return lr

lr_decay = tf.keras.callbacks.LearningRateScheduler(lr_schedule)

# (Optional) Save best model
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "best_attention_unet.h5",
    monitor='val_loss',
    save_best_only=True
)

# Compile the model
model = attention_unet(input_shape=(256, 256, 1), num_classes=4)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the model with callbacks
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=15,
    batch_size=8,
    callbacks=[early_stop, lr_decay, checkpoint],
    verbose=1
)


Epoch 1/15
[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12s/step - accuracy: 0.8621 - loss: 0.5114   



[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1983s[0m 13s/step - accuracy: 0.8627 - loss: 0.5097 - val_accuracy: 0.2564 - val_loss: 1.9286 - learning_rate: 0.0010
Epoch 2/15
[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13s/step - accuracy: 0.9774 - loss: 0.0804   



[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2137s[0m 14s/step - accuracy: 0.9774 - loss: 0.0803 - val_accuracy: 0.9640 - val_loss: 0.1755 - learning_rate: 0.0010
Epoch 3/15
[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14s/step - accuracy: 0.9880 - loss: 0.0396   



[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2245s[0m 15s/step - accuracy: 0.9880 - loss: 0.0396 - val_accuracy: 0.9723 - val_loss: 0.1328 - learning_rate: 0.0010
Epoch 4/15
[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14s/step - accuracy: 0.9913 - loss: 0.0266   



[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2329s[0m 16s/step - accuracy: 0.9913 - loss: 0.0266 - val_accuracy: 0.9832 - val_loss: 0.0475 - learning_rate: 0.0010
Epoch 5/15
[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14s/step - accuracy: 0.9921 - loss: 0.0231   



[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2238s[0m 15s/step - accuracy: 0.9921 - loss: 0.0231 - val_accuracy: 0.9900 - val_loss: 0.0333 - learning_rate: 0.0010
Epoch 6/15
[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2118s[0m 14s/step - accuracy: 0.9936 - loss: 0.0179 - val_accuracy: 0.9810 - val_loss: 0.0645 - learning_rate: 0.0010
Epoch 7/15
[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13s/step - accuracy: 0.9937 - loss: 0.0170   



[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2099s[0m 14s/step - accuracy: 0.9938 - loss: 0.0170 - val_accuracy: 0.9924 - val_loss: 0.0226 - learning_rate: 0.0010
Epoch 8/15
[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13s/step - accuracy: 0.9933 - loss: 0.0182   



[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2100s[0m 14s/step - accuracy: 0.9933 - loss: 0.0182 - val_accuracy: 0.9935 - val_loss: 0.0178 - learning_rate: 0.0010
Epoch 9/15
[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13s/step - accuracy: 0.9941 - loss: 0.0155   



[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2097s[0m 14s/step - accuracy: 0.9941 - loss: 0.0155 - val_accuracy: 0.9944 - val_loss: 0.0150 - learning_rate: 0.0010
Epoch 10/15
[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2136s[0m 14s/step - accuracy: 0.9944 - loss: 0.0148 - val_accuracy: 0.9936 - val_loss: 0.0176 - learning_rate: 0.0010
Epoch 11/15
[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2096s[0m 14s/step - accuracy: 0.9954 - loss: 0.0118 - val_accuracy: 0.9944 - val_loss: 0.0158 - learning_rate: 8.0000e-04
Epoch 12/15
[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2106s[0m 14s/step - accuracy: 0.9961 - loss: 0.0099 - val_accuracy: 0.9951 - val_loss: 0.0150 - learning_rate: 8.0000e-04
Epoch 13/15
[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13s/step - accuracy: 0.9960 - loss: 0.0101   



[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2132s[0m 14s/step - accuracy: 0.9960 - loss: 0.0101 - val_accuracy: 0.9953 - val_loss: 0.0127 - learning_rate: 8.0000e-04
Epoch 14/15
[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13s/step - accuracy: 0.9964 - loss: 0.0090   



[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2095s[0m 14s/step - accuracy: 0.9964 - loss: 0.0090 - val_accuracy: 0.9953 - val_loss: 0.0127 - learning_rate: 8.0000e-04
Epoch 15/15
[1m150/150[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2097s[0m 14s/step - accuracy: 0.9962 - loss: 0.0094 - val_accuracy: 0.9894 - val_loss: 0.0430 - learning_rate: 8.0000e-04


In [7]:
# Save model after training (if not already saved via ModelCheckpoint)
model.save("final_attention_unet_model.h5")


