In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from skimage import io
from tensorflow.keras.callbacks import ModelCheckpoint, LambdaCallback
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import layers, models
from tensorflow.image import ssim

# Load the clean brain image used as the reconstruction target.
from skimage.util import img_as_float

image_path = r"C:\Users\priya\Documents\DL project\61_processed_image.png"  # Path to the clean image
brain_image = io.imread(image_path, as_gray=True)
# `as_gray=True` only converts *colour* images (to 64-bit floats in [0, 1]);
# an image that is already single-channel is returned unconverted (e.g. uint8
# in [0, 255]). The model's sigmoid output and the SSIM loss (max_val=1.0)
# both assume [0, 1], so rescale explicitly. img_as_float leaves float input
# as-is and rescales integer input by its dtype range.
brain_image = img_as_float(brain_image)
brain_image = np.expand_dims(brain_image, axis=-1)  # Add channel dimension -> (H, W, 1)

# Function to apply mask to the image, hiding part of it
def mask_image(image, mask_center_x, mask_center_y, mask_size=100):
    """Return a copy of *image* with a square patch around a centre set to 0.

    The patch spans rows ``[mask_center_x - mask_size, mask_center_x + mask_size)``
    and columns ``[mask_center_y - mask_size, mask_center_y + mask_size)``, clipped
    to the image bounds. Note that ``mask_center_x`` indexes axis 0 (rows) and
    ``mask_center_y`` indexes axis 1 (columns). The input array is not modified.

    Args:
        image: Array whose first two axes are spatial, e.g. ``(H, W, C)``.
        mask_center_x: Row index of the patch centre.
        mask_center_y: Column index of the patch centre.
        mask_size: Half-width of the square patch, in pixels.

    Returns:
        A masked copy of ``image``.
    """
    masked_image = np.copy(image)
    # Clamp at 0: a negative slice bound counts from the END of the axis, which
    # would silently mask nothing (or the wrong region) near the top/left border
    # instead of a patch clipped at the edge. Upper bounds past the end are
    # already clipped by NumPy slicing.
    row_start = max(mask_center_x - mask_size, 0)
    row_stop = max(mask_center_x + mask_size, 0)
    col_start = max(mask_center_y - mask_size, 0)
    col_stop = max(mask_center_y + mask_size, 0)
    masked_image[row_start:row_stop, col_start:col_stop] = 0
    return masked_image

# Visualize masked area
def visualize_mask(image, mask_center_x, mask_center_y, mask_size=100):
    """Display *image* beside a copy masked by ``mask_image`` with the same args.

    Only channel 0 of each array is shown (both arrays are indexed ``[:, :, 0]``).
    """
    preview = mask_image(image, mask_center_x, mask_center_y, mask_size)
    _, (left, right) = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (left, image, "Original Image"),
        (right, preview, f"Masked Area at ({mask_center_x},{mask_center_y})"),
    )
    for axis, data, caption in panels:
        axis.imshow(data[:, :, 0], cmap='gray')
        axis.set_title(caption)
    plt.show()


def attention_block(input_tensor, filters):
    """Gate *input_tensor* element-wise with a learned sigmoid mask.

    A 1x1 convolution with ``filters`` output channels and sigmoid activation
    yields weights in (0, 1), which are multiplied into the input. The callers
    in this file pass the input's channel count as ``filters``.
    """
    gate = layers.Conv2D(filters, (1, 1), activation='sigmoid')
    weights = gate(input_tensor)
    return layers.multiply([input_tensor, weights])

# Encoder Block
def encoder_block(input_tensor, filters, use_attention=False):
    """Two 3x3 conv (ReLU) + batch-norm stages followed by 2x2 max pooling.

    Args:
        input_tensor: Keras tensor to encode.
        filters: Channel count of both convolutions.
        use_attention: If True, ``attention_block`` is applied between the stages.

    Returns:
        ``(pooled, skip)``: the 2x2 max-pooled output and the pre-pooling
        feature map, which the decoder uses as a skip connection.
    """
    def conv_bn(tensor):
        tensor = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return layers.BatchNormalization()(tensor)

    features = conv_bn(input_tensor)
    if use_attention:
        features = attention_block(features, filters)
    features = conv_bn(features)
    pooled = layers.MaxPooling2D((2, 2))(features)
    return pooled, features

# Decoder Block
def decoder_block(input_tensor, skip_tensor, filters, use_attention=False):
    """Upsample 2x, concatenate the skip connection, then two conv + BN stages.

    Args:
        input_tensor: Lower-resolution Keras tensor to upsample.
        skip_tensor: Encoder feature map concatenated after upsampling; its
            spatial size must match the upsampled tensor.
        filters: Channel count of the transposed conv and both convolutions.
        use_attention: If True, ``attention_block`` is applied between the stages.

    Returns:
        The decoded Keras tensor.
    """
    upsampled = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(input_tensor)
    merged = layers.concatenate([upsampled, skip_tensor])

    def conv_bn(tensor):
        tensor = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return layers.BatchNormalization()(tensor)

    decoded = conv_bn(merged)
    if use_attention:
        decoded = attention_block(decoded, filters)
    return conv_bn(decoded)

# UNet Model with Attention
def build_unet(input_shape=(512, 512, 1)):
    """Build a 3-level U-Net with attention gates on the two shallowest levels.

    Encoder widths are 64/128/256, followed by a 512-channel bottleneck with
    dropout 0.3, a mirrored decoder and a 1x1 sigmoid conv producing a
    single-channel image in (0, 1).

    Args:
        input_shape: ``(height, width, channels)``. Height and width should be
            divisible by 8 so the three pool/upsample pairs produce matching
            shapes for the skip concatenations.

    Returns:
        An uncompiled ``keras.Model``.
    """
    inputs = layers.Input(shape=input_shape)

    # (filters, use_attention) per level, shallowest first.
    levels = [(64, True), (128, True), (256, False)]

    x = inputs
    skips = []
    for filters, attend in levels:
        x, skip = encoder_block(x, filters, use_attention=attend)
        skips.append(skip)

    # Bottleneck
    x = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)

    # Decoder mirrors the encoder, deepest level first.
    for (filters, attend), skip in zip(reversed(levels), reversed(skips)):
        x = decoder_block(x, skip, filters, use_attention=attend)

    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(x)
    return models.Model(inputs, outputs)

# Custom Loss Function (weighted MSE + SSIM dissimilarity)
def combined_loss(alpha=0.2):
    """Return a Keras loss: ``alpha * MSE + (1 - alpha) * (1 - mean SSIM)``.

    SSIM is computed with ``max_val=1.0``, so images are expected in [0, 1].
    The returned closure is intentionally named ``loss``; the model-loading
    code supplies it via ``custom_objects={'loss': ...}``.
    """
    mse_weight = alpha
    ssim_weight = 1 - alpha

    def loss(y_true, y_pred):
        squared_error = tf.square(y_true - y_pred)
        mse_term = tf.reduce_mean(squared_error)
        mean_ssim = tf.reduce_mean(ssim(y_true, y_pred, max_val=1.0))
        return mse_weight * mse_term + ssim_weight * (1 - mean_ssim)

    return loss

# Callback to display (and save) the reconstructed image every `interval` epochs
class DisplayReconstructedImage(tf.keras.callbacks.Callback):
    """Keras callback that periodically plots the model's reconstruction of a fixed input.

    Whenever the 0-based epoch index is a multiple of ``interval`` (with the
    default of 50: the 1st, 51st, 101st, ... epochs), the model is run on
    ``image``. The first channel of the first prediction is then displayed and
    saved to ``reconstructed_epoch_<n>.png``, where ``n`` is the 1-based epoch.

    Args:
        image: Input passed straight to ``model.predict``; presumably a batch
            of shape ``(1, H, W, 1)`` -- TODO confirm with the caller.
        output_path: Stored as ``self.output_path`` but not used for saving.
            NOTE(review): the per-epoch file name ignores it.
        interval: Period, in epochs, between reconstructions.
    """

    def __init__(self, image, output_path="reconstructed_image.png", interval=50):
        super().__init__()
        self.image = image
        self.interval = interval
        self.output_path = output_path

    def on_epoch_end(self, epoch, logs=None):
        """Reconstruct and display ``self.image`` on every ``interval``-th epoch."""
        if epoch % self.interval == 0:
            # `self.model` is attached by Keras when the callback is given to fit().
            # verbose=0 keeps predict's progress bar from interleaving with fit's log.
            reconstructed_image = self.model.predict(self.image, verbose=0)
            self.display_image(reconstructed_image, epoch)

    def display_image(self, reconstructed_image, epoch):
        """Show and save channel 0 of the first image in ``reconstructed_image``."""
        fig = plt.figure(figsize=(5, 5))
        plt.imshow(reconstructed_image[0, :, :, 0], cmap='gray')
        plt.title(f"Reconstructed Image at Epoch {epoch+1}")
        plt.axis('off')
        plt.savefig(f"reconstructed_epoch_{epoch+1}.png")  # Save image for visual inspection
        plt.show()
        # Release the figure: with non-interactive backends show() does not close
        # it, so one figure per interval would otherwise accumulate over training.
        plt.close(fig)

# SSIM Metric
def ssim_metric(y_true, y_pred):
    """Mean SSIM over the batch; pixel values are assumed to lie in [0, 1]."""
    per_image_scores = ssim(y_true, y_pred, max_val=1.0)
    return tf.reduce_mean(per_image_scores)

# Build the U-Net for 512x512 single-channel images and compile it with the
# combined loss (alpha=0.8 -> 80% MSE, 20% SSIM term) and an SSIM metric.
input_shape = (512, 512, 1)
unet_model = build_unet(input_shape=input_shape)
unet_model.compile(
    optimizer='adam',
    loss=combined_loss(alpha=0.8),
    metrics=[ssim_metric],
)


# Checkpoint callback: every 100 epochs, save the model if the training loss
# has improved.
# An integer `save_freq` counts *batches*, not epochs. Each fit() call trains
# on a single sample with the default batch size, i.e. one batch per epoch, so
# 100 batches == 100 epochs. (`len(brain_image)` is the image height, 512, not
# a step count; the period it produced was never reached.)
# There is no validation data, so monitor the training loss: with the default
# monitor 'val_loss' absent from the logs, save_best_only skips every save.
checkpoint_cb = ModelCheckpoint(
    "unet_attention_model.keras",
    monitor="loss",
    save_freq=100,
    save_best_only=True,
)

# Function to visualize the reconstruction and masked region
def show_reconstruction(epoch, logs, masked_brain_image):
    """Every 50 epochs, plot the masked input, current reconstruction and original.

    Intended as a ``LambdaCallback`` ``on_epoch_end`` hook. Reads the
    module-level ``unet_model`` and ``brain_image``.

    Args:
        epoch: 0-based epoch index supplied by Keras.
        logs: Epoch metrics supplied by Keras (unused).
        masked_brain_image: The masked ``(H, W, 1)`` input currently trained on.
    """
    if epoch % 50 != 0:
        return
    batch = np.expand_dims(masked_brain_image, axis=0)
    # verbose=0 keeps predict's progress bar from interleaving with fit's log.
    reconstructed_img = unet_model.predict(batch, verbose=0)[0, :, :, 0]
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(masked_brain_image[:, :, 0], cmap='gray')
    ax[0].set_title(f"Masked Input at Epoch {epoch}")
    ax[1].imshow(reconstructed_img, cmap='gray')
    ax[1].set_title(f"Reconstructed Image Epoch {epoch}")
    ax[2].imshow(brain_image[:, :, 0], cmap='gray')
    ax[2].set_title("Original Image")
    plt.show()
    # Release the figure: with non-interactive backends show() does not close
    # it, and this runs many times per training session.
    plt.close(fig)

# Function to progressively mask the image and train
def progressive_training(model, brain_image, epochs_per_mask=500, mask_size=5,
                         mask_centers=None):
    """Train *model* to inpaint *brain_image* under a sequence of square masks.

    For each centre, the image is masked with ``mask_image``, and the masked
    region is previewed. The model is then fit for ``epochs_per_mask`` epochs
    to map the masked image back to the clean original. Weights carry over
    from one mask to the next.

    Args:
        model: Compiled Keras model taking ``(1, H, W, 1)`` inputs.
        brain_image: Clean ``(H, W, 1)`` target image.
        epochs_per_mask: Epochs to train on each masked variant.
        mask_size: Half-width of the square mask. The same value drives both
            training and the preview, so the preview shows exactly the region
            the model is trained to fill.
        mask_centers: Iterable of ``(row, col)`` centres; ``None`` selects the
            default 3 x 9 grid below.
    """
    if mask_centers is None:
        # Rows sweep 100..332 within each of three column bands (180, 256, 332),
        # visiting one column band completely before moving to the next.
        rows = (100, 120, 140, 160, 200, 220, 240, 260, 332)
        mask_centers = [(row, col) for col in (180, 256, 332) for row in rows]

    # The target (clean image with a batch axis) is the same for every mask.
    target = np.expand_dims(brain_image, axis=0)

    for center_x, center_y in mask_centers:
        masked_brain_image = mask_image(brain_image, center_x, center_y, mask_size=mask_size)

        # Visualize the current masked region (same size as the training mask)
        visualize_mask(brain_image, center_x, center_y, mask_size=mask_size)

        print(f"Training on masked image with mask at center ({center_x}, {center_y})")

        # Bind the current masked image as a default argument so the callback
        # can never observe a later loop value (late-binding closure).
        reconstruction_cb = LambdaCallback(
            on_epoch_end=lambda epoch, logs, img=masked_brain_image: show_reconstruction(epoch, logs, img)
        )

        # Train on the masked image; the target is the original clean image.
        model.fit(
            x=np.expand_dims(masked_brain_image, axis=0),
            y=target,
            epochs=epochs_per_mask,
            callbacks=[checkpoint_cb, reconstruction_cb],
        )
        # After the model has learned to reconstruct the masked region, we move on to the next mask

# Kick off progressive training: 1000 epochs for every mask position.
progressive_training(model=unet_model, brain_image=brain_image, epochs_per_mask=1000)


In [None]:
# Persist the trained network in the native Keras (.keras) format.
unet_model.save(filepath='jaadu2.keras')

In [None]:
import numpy as np
import cv2
import tensorflow as tf
from sklearn.metrics import jaccard_score, f1_score
import matplotlib.pyplot as plt
from tensorflow.image import ssim

def ssim_metric(y_true, y_pred):
    """Mean SSIM over the batch; pixel values are assumed to lie in [0, 1].

    Redefined in this cell so the model loader can resolve the metric by name.
    """
    scores = ssim(y_true, y_pred, max_val=1.0)
    return tf.reduce_mean(scores)
# Load the trained model. The custom loss closure and metric are not built-in
# Keras objects, so they must be supplied by the names they were saved under.
# NOTE: `combined_loss` comes from the training cell and must already be defined.
unet_model = tf.keras.models.load_model(
    'jaadu2.keras',
    custom_objects={
        'loss': combined_loss(alpha=0.8),
        'ssim_metric': ssim_metric,
    },
)

# Load the anomaly brain image and the segmentation mask
anomaly_image_path = r"C:\Users\priya\Documents\DL project\test\processed_image.png"  # Replace with your actual path
segment_mask_path = r"C:\Users\priya\Documents\DL project\test\processed_mask.png"  # Replace with your actual path

# Load images. cv2.imread returns None (rather than raising) when a file is
# missing or unreadable, which would otherwise surface later as an obscure error.
anomaly_image = cv2.imread(anomaly_image_path, cv2.IMREAD_GRAYSCALE)
segment_mask = cv2.imread(segment_mask_path, cv2.IMREAD_GRAYSCALE)
if anomaly_image is None:
    raise FileNotFoundError(f"Could not read anomaly image: {anomaly_image_path}")
if segment_mask is None:
    raise FileNotFoundError(f"Could not read segmentation mask: {segment_mask_path}")
if anomaly_image.shape != segment_mask.shape:
    raise ValueError(
        f"Image shape {anomaly_image.shape} does not match mask shape {segment_mask.shape}"
    )

# IMREAD_GRAYSCALE yields 8-bit values in [0, 255]. The model ends in a sigmoid
# and was trained with SSIM max_val=1.0, i.e. on [0, 1] data, so rescale. Without
# this, the input is far outside the training range, and |image - reconstruction|
# compares values on two different scales.
anomaly_image = anomaly_image.astype(np.float32) / 255.0

# Add batch and channel dimensions to match the model input: (1, H, W, 1)
anomaly_image_exp = np.expand_dims(anomaly_image, axis=(0, -1))

# Ensure the segmentation mask is binary (any non-zero pixel is foreground)
segment_mask = (segment_mask > 0).astype(np.uint8)

# Reconstruct a "clean" version of the anomaly image with the U-Net, then flag
# pixels where the reconstruction differs strongly from the input.
reconstructed_image = np.squeeze(unet_model.predict(anomaly_image_exp))  # drop batch/channel axes

# Per-pixel absolute reconstruction error.
predicted_diff = np.abs(anomaly_image - reconstructed_image)

# Pixels whose error exceeds 40% of the maximum error are labelled anomalous
# (a smaller factor flags more pixels, i.e. higher sensitivity).
threshold = predicted_diff.max() * 0.4
predicted_segment = np.where(predicted_diff > threshold, 1, 0).astype(np.uint8)

# Score the predicted mask against ground truth, treating each pixel as one
# binary sample.
segment_mask_flat = segment_mask.flatten()
predicted_segment_flat = predicted_segment.flatten()

# IoU (Jaccard index) and Dice coefficient (identical to F1 for binary labels).
iou = jaccard_score(segment_mask_flat, predicted_segment_flat, average='binary')
dice = f1_score(segment_mask_flat, predicted_segment_flat, average='binary')

print(f"IoU (Jaccard Index): {iou}")
print(f"Dice Coefficient: {dice}")

# Show input, ground truth, reconstruction, and predicted mask in a 2x2 grid.
plt.figure(figsize=(10, 10))

for _panel_index, (_panel_image, _panel_title) in enumerate(
    [
        (anomaly_image, 'Original Anomaly Brain Image'),
        (segment_mask, 'Ground Truth Segment Mask'),
        (reconstructed_image, 'Reconstructed Clean Brain Image'),
        (predicted_segment, 'Predicted Segment Mask'),
    ],
    start=1,
):
    plt.subplot(2, 2, _panel_index)
    plt.imshow(_panel_image, cmap='gray')
    plt.title(_panel_title)
    plt.axis('off')

plt.tight_layout()
plt.show()
