In [None]:
import os
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize
import matplotlib.image as mpimg

# Define dataset directory (absolute, machine-specific path)
dataset_dir = r"C:\Users\priya\Documents\DL project\BraTS2021_00061"

# Define paths for all image types
t1_path = os.path.join(dataset_dir, 'BraTS2021_00061_t1.nii.gz')
t1ce_path = os.path.join(dataset_dir, 'BraTS2021_00061_t1ce.nii.gz')
flair_path = os.path.join(dataset_dir, 'BraTS2021_00061_flair.nii.gz')
t2_path = os.path.join(dataset_dir, 'BraTS2021_00061_t2.nii.gz')

# Load all image types
t1_img = nib.load(t1_path).get_fdata()
t1ce_img = nib.load(t1ce_path).get_fdata()
flair_img = nib.load(flair_path).get_fdata()
t2_img = nib.load(t2_path).get_fdata()

# Check shape: the export loop indexes every modality with the same slice
# index, so all volumes must share the shape of the T1 volume.
image_shape = t1_img.shape
for _modality, _volume in (('t1ce', t1ce_img), ('flair', flair_img), ('t2', t2_img)):
    if _volume.shape != image_shape:
        raise ValueError(
            f"Modality '{_modality}' has shape {_volume.shape}, expected {image_shape}"
        )

# Set parameters for slice selection
start_slice = 15  # Skip the first 15 slices
end_slice = image_shape[2] - 15  # Skip the last 15 slices

# Target size for resizing
target_size = (512, 512)

# Directory to save processed images (relative to the working directory)
output_dir = 'processed_images'
os.makedirs(output_dir, exist_ok=True)

# Initialize slice counter; it is also used as the output file name
slice_counter = 1

# Function to normalize and resize image slices
def process_and_save_slice(image_slice, slice_counter):
    """Min-max normalise one slice, resize it and save it as a grayscale PNG.

    Parameters
    ----------
    image_slice : numpy.ndarray
        Intensity slice to export.
    slice_counter : int
        Number used as the output file name (``<slice_counter>.png``).

    Returns
    -------
    int
        ``slice_counter + 1``, the number to use for the next slice.

    Notes
    -----
    Reads the module-level ``target_size`` and ``output_dir``.
    """
    lowest = np.min(image_slice)
    intensity_range = np.max(image_slice) - lowest
    if intensity_range > 0:
        # Normalize the slice to [0, 1] range
        image_slice = (image_slice - lowest) / intensity_range
    else:
        # A constant slice (e.g. pure background) has zero range; dividing
        # would fill the image with NaNs, so export an all-zero image instead.
        image_slice = np.zeros_like(image_slice, dtype=float)

    # Resize the image to the target size
    resized_image = resize(image_slice, target_size, mode='reflect', anti_aliasing=True)

    # Save the image as a .png file
    output_path = os.path.join(output_dir, f'{slice_counter}.png')
    plt.imsave(output_path, resized_image, cmap='gray')

    return slice_counter + 1

# Walk through the retained slice range and export every modality in turn.
# Files are numbered consecutively, so each slice index yields four PNGs in
# the order T1, T1ce, FLAIR, T2.
for slice_idx in range(start_slice, end_slice):
    for volume in (t1_img, t1ce_img, flair_img, t2_img):
        slice_counter = process_and_save_slice(volume[:, :, slice_idx], slice_counter)

print(f"Processed images saved in '{output_dir}'")


In [None]:
import os
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize
import matplotlib.image as mpimg

# Define dataset directory (absolute, machine-specific path)
dataset_dir = r"C:\Users\priya\Documents\DL project\BraTS2021_00061"

# Define the path of the T1 volume; the other modalities are disabled in this cell
t1_path = os.path.join(dataset_dir, 'BraTS2021_00061_t1.nii.gz')
#t1ce_path = os.path.join(dataset_dir, 'BraTS2021_00061_t1ce.nii.gz')
#flair_path = os.path.join(dataset_dir, 'BraTS2021_00061_flair.nii.gz')
#t2_path = os.path.join(dataset_dir, 'BraTS2021_00061_t2.nii.gz')

# Load the T1 volume only
t1_img = nib.load(t1_path).get_fdata()
#t1ce_img = nib.load(t1ce_path).get_fdata()
#flair_img = nib.load(flair_path).get_fdata()
#t2_img = nib.load(t2_path).get_fdata()

# Record the volume shape; its third axis is used as the slice axis below
image_shape = t1_img.shape

# Set parameters for slice selection
start_slice = 15  # Skip the first 15 slices
end_slice = image_shape[2] - 15  # Skip the last 15 slices

# Target size for resizing
target_size = (512, 512)

# Directory to save processed images (relative to the working directory)
output_dir = 'processed_images_t1'
os.makedirs(output_dir, exist_ok=True)

# Initialize slice counter; it is also used as the output file name
slice_counter = 1

# Function to normalize and resize image slices
def process_and_save_slice(image_slice, slice_counter):
    """Min-max normalise one slice, resize it and save it as a grayscale PNG.

    Parameters
    ----------
    image_slice : numpy.ndarray
        Intensity slice to export.
    slice_counter : int
        Number used as the output file name (``<slice_counter>.png``).

    Returns
    -------
    int
        ``slice_counter + 1``, the number to use for the next slice.

    Notes
    -----
    Reads the module-level ``target_size`` and ``output_dir``.
    """
    lowest = np.min(image_slice)
    intensity_range = np.max(image_slice) - lowest
    if intensity_range > 0:
        # Normalize the slice to [0, 1] range
        image_slice = (image_slice - lowest) / intensity_range
    else:
        # A constant slice (e.g. pure background) has zero range; dividing
        # would fill the image with NaNs, so export an all-zero image instead.
        image_slice = np.zeros_like(image_slice, dtype=float)

    # Resize the image to the target size
    resized_image = resize(image_slice, target_size, mode='reflect', anti_aliasing=True)

    # Save the image as a .png file
    output_path = os.path.join(output_dir, f'{slice_counter}.png')
    plt.imsave(output_path, resized_image, cmap='gray')

    return slice_counter + 1

# Export the retained T1 slices; the other modalities are not processed in
# this cell, so files are numbered 1, 2, 3, ... in slice order.
for slice_idx in range(start_slice, end_slice):
    slice_counter = process_and_save_slice(t1_img[:, :, slice_idx], slice_counter)

print(f"Processed images saved in '{output_dir}'")


In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, losses
import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import time  # Add time module
from tensorflow.keras.preprocessing.image import smart_resize

# Define the weighted VAE loss function
def vae_loss(inputs, outputs, mu, log_var, kl_weight=0.25):
    """Return binary cross-entropy reconstruction loss plus a weighted KL term.

    Parameters
    ----------
    inputs, outputs : tensor
        Reference image batch and its reconstruction.
    mu, log_var : tensor
        Mean and log-variance of the approximate posterior.
    kl_weight : float, optional
        Multiplier applied to the KL term (default 0.25, the value that was
        previously hard-coded). Values below 1 weight reconstruction more
        heavily than the KL term.

    Returns
    -------
    tensor
        Scalar loss; both terms are averaged over all elements.
    """
    reconstruction_loss = tf.reduce_mean(losses.binary_crossentropy(inputs, outputs))
    kl_loss = -0.5 * tf.reduce_mean(1 + log_var - tf.square(mu) - tf.exp(log_var))
    return reconstruction_loss + kl_weight * kl_loss

# Build the VAE model
def build_vae(input_shape):
    """Build a convolutional VAE with a 128-dimensional dense latent space.

    Parameters
    ----------
    input_shape : tuple
        ``(height, width, channels)`` of one input image. Height and width
        must be multiples of 8: the encoder halves the resolution three
        times and the decoder doubles it three times.

    Returns
    -------
    tuple
        ``(vae, encoder, decoder)``: the compiled end-to-end model, the
        encoder mapping an image to ``[z, mu, log_var]`` and the decoder
        mapping a latent vector back to an image.

    Raises
    ------
    ValueError
        If height or width is not a multiple of 8.

    Notes
    -----
    The loss is attached with ``add_loss`` and is computed between the
    model's own input and its output; ``compile`` receives no ``loss``
    argument, so this function configures no target-based loss.
    """
    latent_dim = 128
    height, width, channels = input_shape
    if height % 8 or width % 8:
        raise ValueError(
            f"input height and width must be multiples of 8, got {height}x{width}"
        )
    # Shape of the last encoder feature map; the decoder starts from the same
    # shape instead of a hard-coded 64x64x128 so other input sizes work too.
    feature_shape = (height // 8, width // 8, 128)

    # Encoder
    inputs = layers.Input(shape=input_shape)
    x = inputs
    for filters in (32, 64, 128):
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        x = layers.MaxPooling2D((2, 2), padding='same')(x)
    x = layers.Flatten()(x)

    mu = layers.Dense(latent_dim)(x)
    log_var = layers.Dense(latent_dim)(x)

    # Sampling function (reparameterisation: z = mu + sigma * epsilon)
    def sampling(args):
        mu, log_var = args
        batch = tf.shape(mu)[0]
        dim = tf.shape(mu)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return mu + tf.exp(0.5 * log_var) * epsilon

    z = layers.Lambda(sampling)([mu, log_var])

    # Decoder
    decoder_input = layers.Input(shape=(latent_dim,))
    x = layers.Dense(feature_shape[0] * feature_shape[1] * feature_shape[2],
                     activation='relu')(decoder_input)
    x = layers.Reshape(feature_shape)(x)
    for filters in (128, 64, 32):
        x = layers.Conv2DTranspose(filters, (3, 3), activation='relu', padding='same')(x)
        x = layers.UpSampling2D((2, 2))(x)
    # Sigmoid keeps the reconstruction in [0, 1]; emit as many channels as the input has.
    outputs = layers.Conv2DTranspose(channels, (3, 3), activation='sigmoid', padding='same')(x)

    encoder = Model(inputs, [z, mu, log_var], name='encoder')
    decoder = Model(decoder_input, outputs, name='decoder')

    # Run the encoder once and reuse its three outputs, rather than calling
    # it separately for z, mu and log_var.
    z_out, mu_out, log_var_out = encoder(inputs)
    vae_outputs = decoder(z_out)
    vae = Model(inputs, vae_outputs, name='vae')

    # Add loss function
    vae.add_loss(vae_loss(inputs, vae_outputs, mu_out, log_var_out))
    vae.compile(optimizer='adam')

    return vae, encoder, decoder

# Function to load preprocessed PNG slices from your dataset directory
def load_slices_from_png(folder_path):
    """Load every ``.png`` file in a folder as a normalised grayscale array.

    Parameters
    ----------
    folder_path : str
        Directory to scan (non-recursive).

    Returns
    -------
    numpy.ndarray
        Stack of images scaled to [0, 1], each with a trailing channel axis.
        Files are visited in ``sorted()`` order of their names, which is
        lexicographic (``'10.png'`` comes before ``'2.png'``).

    Raises
    ------
    IOError
        If a ``.png`` file cannot be decoded.
    """
    all_slices = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith('.png'):
            continue
        file_path = os.path.join(folder_path, filename)
        img = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising
            raise IOError(f"Could not read image file: {file_path}")
        img = img / 255.0  # Normalize to [0,1]
        img = np.expand_dims(img, axis=-1)  # Add channel dimension
        all_slices.append(img)
    return np.array(all_slices)

# Function to inject anomaly into an image
def inject_anomaly(image, window_size=200):
    """Paste a distorted, randomly shaped patch of the image back into it.

    A square window is cut from a random location, altered with a
    brightness/contrast change and an elastic deformation, masked with 1-3
    random ellipses/polygons and copied to another random location.

    Parameters
    ----------
    image : numpy.ndarray
        Source image; the callers in this notebook pass a squeezed 2-D slice.
    window_size : int, optional
        Side length of the square window (default 200).

    Returns
    -------
    tuple
        ``(anomaly_image, masked_anomaly)``: a copy of ``image`` containing
        the anomaly, and the masked window that was pasted.

    Raises
    ------
    ValueError
        If the image is smaller than ``window_size`` in either dimension.
    """
    h, w = image.shape[:2]
    if h < window_size or w < window_size:
        raise ValueError(
            f"image of size {h}x{w} is smaller than window_size={window_size}"
        )
    # np.random.randint excludes its upper bound, so "+ 1" is required for the
    # last valid offset to be reachable (and for h == window_size to work).
    max_row = h - window_size + 1
    max_col = w - window_size + 1

    # Select a random window area from the image
    src_row, src_col = np.random.randint(0, max_row), np.random.randint(0, max_col)
    anomaly_window = image[src_row:src_row + window_size, src_col:src_col + window_size]

    # Apply contrast change and elastic deformation
    anomaly_window = A.RandomBrightnessContrast(p=1.0, brightness_limit=(-.4,0.4), contrast_limit=(-0.6,0.6))(image=anomaly_window)["image"]
    anomaly_window = A.ElasticTransform(p=1.0, alpha=50, sigma=50)(image=anomaly_window)["image"]

    # Create a random mask (you can modify this to get different shapes)
    mask = np.zeros_like(anomaly_window, dtype=np.uint8)
    num_shapes = np.random.randint(1, 4)  # 1 to 3 shapes

    for _ in range(num_shapes):
        shape_type = np.random.choice(['ellipse', 'polygon'])
        if shape_type == 'ellipse':
            center = (np.random.randint(0, window_size), np.random.randint(0, window_size))
            axes = (np.random.randint(10, window_size // 2), np.random.randint(10, window_size // 2))
            angle = np.random.randint(0, 180)
            cv2.ellipse(mask, center, axes, angle, 0, 360, (255, 255, 255), -1)
        elif shape_type == 'polygon':
            num_points = np.random.randint(3, 7)
            points = np.array([[
                (np.random.randint(0, window_size), np.random.randint(0, window_size))
                for _ in range(num_points)
            ]], dtype=np.int32)
            cv2.fillPoly(mask, points, (255, 255, 255))

    # Apply mask to the anomaly window (pixels outside the mask become zero)
    masked_anomaly = cv2.bitwise_and(anomaly_window, anomaly_window, mask=mask)

    # Place the modified, random-shaped window back into a copy of the image
    anomaly_image = image.copy()
    dst_row, dst_col = np.random.randint(0, max_row), np.random.randint(0, max_col)

    # window_region is a view, so copying into it edits anomaly_image in place
    window_region = anomaly_image[dst_row:dst_row + window_size, dst_col:dst_col + window_size]
    np.copyto(window_region, masked_anomaly, where=mask.astype(bool))

    return anomaly_image, masked_anomaly
# Albumentations augmentation
def augment_image(image):
    """Apply a random albumentations pipeline to a [0, 1] grayscale image.

    Parameters
    ----------
    image : numpy.ndarray
        Image with values expected in [0, 1]; a trailing channel axis of
        size 1 is dropped before augmentation.

    Returns
    -------
    numpy.ndarray
        Augmented image rescaled to [0, 1]. A single-channel input comes
        back without its channel axis.
    """
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]  # Remove the channel if it is a grayscale image with a single channel

    # Clip before converting: astype(np.uint8) does not saturate, so values
    # slightly outside [0, 1] would otherwise end up as arbitrary intensities.
    image = np.clip(image * 255, 0, 255).astype(np.uint8)  # Convert to uint8
    transform = A.Compose([
         A.AdvancedBlur(p=0.5),
         A.CLAHE(p=0.5),
         A.Downscale(p=0.5),
         A.Emboss(p=0.5),
         A.Equalize(p=0.5),
        # A.FancyPCA(p=0.5),
         A.GaussNoise(p=0.5),
         A.RandomBrightnessContrast(p=0.5),
        # A.CoarseDropout(p=0.5),
        # A.PixelDropout(p=0.5)
    ])
    augmented = transform(image=image)
    return augmented['image'] / 255.0

# Read the training slices from disk
folder_path = r"cropped"
slices = load_slices_from_png(folder_path)

# Bring every slice to 512x512 and stack the results into one array
slices = np.array([smart_resize(slice_img, (512, 512)) for slice_img in slices])

# Create the VAE together with its encoder and decoder sub-models
vae, encoder, decoder = build_vae(input_shape=(512, 512, 1))

# Training configuration
epochs = 5000
batch_size = 64  # not referenced by the training loop in this cell

vae.summary()

for epoch in range(epochs + 1):
    start_time = time.time()  # Start timer
    loss = None  # stays None when `slices` is empty

    for slice_img in slices:
        # Clean reconstruction target with explicit batch and channel axes.
        # (Previously two axes were appended to an array that already had a
        # channel axis, producing a 5-D target.)
        original_img = np.squeeze(slice_img).copy()[np.newaxis, ..., np.newaxis]

        # Inject an anomaly only on every 10th epoch (0, 10, 20, ...)
        if epoch % 10 == 0:
            anomaly_img, _ = inject_anomaly(np.squeeze(slice_img))
        else:
            anomaly_img = slice_img

        anomaly_img = np.expand_dims(np.squeeze(anomaly_img), axis=-1)  # Ensure channel dimension
        # augment_image returns the image without a channel axis; add the
        # batch and channel axes explicitly instead of relying on Keras to
        # fix the rank.
        augmented_img = augment_image(anomaly_img)[np.newaxis, ..., np.newaxis]

        # Train the VAE on the augmented image and its reconstruction target.
        # NOTE(review): build_vae compiles the model without a `loss=` and adds
        # a loss computed from the model's own input, so `original_img` does
        # not appear to influence training - confirm whether that is intended.
        loss = vae.train_on_batch(augmented_img, original_img)

    # Calculate epoch duration
    epoch_time = time.time() - start_time

    # Save model every 100 epochs
    if epoch % 100 == 0:
        vae.save(f'vae_model_epoch_{epoch}.h5')

    # Display images and print loss every 10 epochs (skipped if nothing was trained)
    if epoch % 10 == 0 and loss is not None:
        print(f"Epoch {epoch}/{epochs}, Loss: {loss:.4f}, Time: {epoch_time:.2f}s")

        # Reconstruct the last image of this epoch
        reconstructed_img = vae.predict(augmented_img)

        # Plot the images
        fig, axes = plt.subplots(1, 4, figsize=(20, 5))

        # Original image
        axes[0].imshow(np.squeeze(original_img), cmap='gray')
        axes[0].set_title("Original Image")

        # Anomaly image
        axes[1].imshow(np.squeeze(anomaly_img), cmap='gray')
        axes[1].set_title("Anomaly Image")

        # Augmented image
        axes[2].imshow(np.squeeze(augmented_img), cmap='gray')
        axes[2].set_title("Augmented Image")

        # Reconstructed image
        axes[3].imshow(np.squeeze(reconstructed_img), cmap='gray')
        axes[3].set_title("Reconstructed Image")

        plt.show()
        plt.close(fig)  # release the figure; hundreds are created over a full run

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, losses
import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import time  # Add time module
from tensorflow.keras.preprocessing.image import smart_resize

def dice_loss(y_true, y_pred):
    """Soft Dice loss over all elements, with a smoothing constant of 1.

    Computes ``1 - (2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1)``.
    """
    overlap = tf.reduce_sum(y_true * y_pred)
    denominator = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    dice_score = (2. * overlap + 1) / (denominator + 1)
    return 1 - dice_score

# Define IoU loss
def iou_loss(y_true, y_pred):
    """Soft IoU (Jaccard) loss over all elements, with a smoothing constant of 1.

    The union is obtained as ``sum(y_true) + sum(y_pred) - intersection``.
    """
    overlap = tf.reduce_sum(y_true * y_pred)
    combined = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - overlap
    iou_score = (overlap + 1) / (combined + 1)
    return 1 - iou_score

from tensorflow.keras import layers, Model
import tensorflow as tf

def build_vae(input_shape):
    """Build a fully convolutional VAE with a spatial (feature-map) latent.

    Parameters
    ----------
    input_shape : tuple
        Shape of one input image, e.g. ``(512, 512, 1)``. The encoder halves
        the resolution four times and the decoder doubles it four times.

    Returns
    -------
    tuple
        ``(vae, encoder, decoder)``: the compiled end-to-end model, the
        encoder mapping an image to ``[z, mu, log_var]`` and the decoder
        mapping a latent feature map back to an image.

    Notes
    -----
    ``vae_loss`` is not defined in this cell; it must already exist in the
    kernel (other cells of this notebook define it). The loss is attached
    with ``add_loss`` and is computed between the model's own input and its
    output; ``compile`` receives no ``loss`` argument.
    """
    # Encoder
    inputs = layers.Input(shape=input_shape)
    x = inputs
    for filters in (32, 64, 128, 256):
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        x = layers.MaxPooling2D((2, 2), padding='same')(x)

    # Bottleneck layers. These heads are linear on purpose: a ReLU here would
    # force mu >= 0 and log_var >= 0 (variance >= 1), so the posterior could
    # never have a negative mean or a variance below 1.
    mu = layers.Conv2D(256, (3, 3), padding='same')(x)
    log_var = layers.Conv2D(256, (3, 3), padding='same')(x)

    # Sampling function (reparameterisation: z = mu + sigma * epsilon)
    def sampling(args):
        mu, log_var = args
        epsilon = tf.random.normal(tf.shape(mu))
        return mu + tf.exp(0.5 * log_var) * epsilon

    z = layers.Lambda(sampling)([mu, log_var])

    # Decoder starts with an input placeholder for `z`
    decoder_input = layers.Input(shape=z.shape[1:])
    x = decoder_input
    for filters in (256, 128, 64, 32):
        x = layers.Conv2DTranspose(filters, (3, 3), activation='relu', padding='same')(x)
        x = layers.UpSampling2D((2, 2))(x)
    outputs = layers.Conv2DTranspose(1, (3, 3), activation='sigmoid', padding='same')(x)

    # Build encoder and decoder models
    encoder = Model(inputs, [z, mu, log_var], name='encoder')
    decoder = Model(decoder_input, outputs, name='decoder')

    # Run the encoder once and reuse its three outputs, rather than calling
    # it separately for z, mu and log_var.
    z_out, mu_out, log_var_out = encoder(inputs)
    vae_outputs = decoder(z_out)
    vae = Model(inputs, vae_outputs, name='bigger_vae')

    # Add VAE loss and compile
    vae.add_loss(vae_loss(inputs, vae_outputs, mu_out, log_var_out))
    vae.compile(optimizer='adam')

    return vae, encoder, decoder


# Function to load preprocessed PNG slices from your dataset directory
def load_slices_from_png(folder_path):
    """Load every ``.png`` file in a folder as a normalised grayscale array.

    Parameters
    ----------
    folder_path : str
        Directory to scan (non-recursive).

    Returns
    -------
    numpy.ndarray
        Stack of images scaled to [0, 1], each with a trailing channel axis.
        Files are visited in ``sorted()`` order of their names, which is
        lexicographic (``'10.png'`` comes before ``'2.png'``).

    Raises
    ------
    IOError
        If a ``.png`` file cannot be decoded.
    """
    all_slices = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith('.png'):
            continue
        file_path = os.path.join(folder_path, filename)
        img = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising
            raise IOError(f"Could not read image file: {file_path}")
        img = img / 255.0  # Normalize to [0,1]
        img = np.expand_dims(img, axis=-1)  # Add channel dimension
        all_slices.append(img)
    return np.array(all_slices)

# Function to inject anomaly into an image
def inject_anomaly(image, window_size=200):
    """Paste a distorted, randomly shaped patch of the image back into it.

    A square window is cut from a random location, altered with a
    brightness/contrast change and an elastic deformation, masked with 1-3
    random ellipses/polygons and copied to another random location.

    Parameters
    ----------
    image : numpy.ndarray
        Source image; the callers in this notebook pass a squeezed 2-D slice.
    window_size : int, optional
        Side length of the square window (default 200).

    Returns
    -------
    tuple
        ``(anomaly_image, masked_anomaly)``: a copy of ``image`` containing
        the anomaly, and the masked window that was pasted.

    Raises
    ------
    ValueError
        If the image is smaller than ``window_size`` in either dimension.
    """
    h, w = image.shape[:2]
    if h < window_size or w < window_size:
        raise ValueError(
            f"image of size {h}x{w} is smaller than window_size={window_size}"
        )
    # np.random.randint excludes its upper bound, so "+ 1" is required for the
    # last valid offset to be reachable (and for h == window_size to work).
    max_row = h - window_size + 1
    max_col = w - window_size + 1

    # Select a random window area from the image
    src_row, src_col = np.random.randint(0, max_row), np.random.randint(0, max_col)
    anomaly_window = image[src_row:src_row + window_size, src_col:src_col + window_size]

    # Apply contrast change and elastic deformation
    anomaly_window = A.RandomBrightnessContrast(p=1.0, brightness_limit=(-.4,0.4), contrast_limit=(-0.6,0.6))(image=anomaly_window)["image"]
    anomaly_window = A.ElasticTransform(p=1.0, alpha=50, sigma=50)(image=anomaly_window)["image"]

    # Create a random mask (you can modify this to get different shapes)
    mask = np.zeros_like(anomaly_window, dtype=np.uint8)
    num_shapes = np.random.randint(1, 4)  # 1 to 3 shapes

    for _ in range(num_shapes):
        shape_type = np.random.choice(['ellipse', 'polygon'])
        if shape_type == 'ellipse':
            center = (np.random.randint(0, window_size), np.random.randint(0, window_size))
            axes = (np.random.randint(10, window_size // 2), np.random.randint(10, window_size // 2))
            angle = np.random.randint(0, 180)
            cv2.ellipse(mask, center, axes, angle, 0, 360, (255, 255, 255), -1)
        elif shape_type == 'polygon':
            num_points = np.random.randint(3, 7)
            points = np.array([[
                (np.random.randint(0, window_size), np.random.randint(0, window_size))
                for _ in range(num_points)
            ]], dtype=np.int32)
            cv2.fillPoly(mask, points, (255, 255, 255))

    # Apply mask to the anomaly window (pixels outside the mask become zero)
    masked_anomaly = cv2.bitwise_and(anomaly_window, anomaly_window, mask=mask)

    # Place the modified, random-shaped window back into a copy of the image
    anomaly_image = image.copy()
    dst_row, dst_col = np.random.randint(0, max_row), np.random.randint(0, max_col)

    # window_region is a view, so copying into it edits anomaly_image in place
    window_region = anomaly_image[dst_row:dst_row + window_size, dst_col:dst_col + window_size]
    np.copyto(window_region, masked_anomaly, where=mask.astype(bool))

    return anomaly_image, masked_anomaly
# Albumentations augmentation
def augment_image(image):
    """Apply a random albumentations pipeline to a [0, 1] grayscale image.

    Parameters
    ----------
    image : numpy.ndarray
        Image with values expected in [0, 1]; a trailing channel axis of
        size 1 is dropped before augmentation.

    Returns
    -------
    numpy.ndarray
        Augmented image rescaled to [0, 1]. A single-channel input comes
        back without its channel axis.
    """
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]  # Remove the channel if it is a grayscale image with a single channel

    # Clip before converting: astype(np.uint8) does not saturate, so values
    # slightly outside [0, 1] would otherwise end up as arbitrary intensities.
    image = np.clip(image * 255, 0, 255).astype(np.uint8)  # Convert to uint8
    transform = A.Compose([
         A.AdvancedBlur(p=0.5),
         A.CLAHE(p=0.5),
         A.Downscale(p=0.5),
         A.Emboss(p=0.5),
         A.Equalize(p=0.5),
        # A.FancyPCA(p=0.5),
         A.GaussNoise(p=0.5),
         A.RandomBrightnessContrast(p=0.5),
        # A.CoarseDropout(p=0.5),
        # A.PixelDropout(p=0.5)
    ])
    augmented = transform(image=image)
    return augmented['image'] / 255.0

# Read the training slices from disk
folder_path = r"cropped"
slices = load_slices_from_png(folder_path)

# Bring every slice to 512x512 and stack the results into one array
slices = np.array([smart_resize(slice_img, (512, 512)) for slice_img in slices])

# Create the VAE together with its encoder and decoder sub-models
vae, encoder, decoder = build_vae(input_shape=(512, 512, 1))

# Training configuration
epochs = 5000
batch_size = 64  # not referenced by the training loop in this cell

vae.summary()

for epoch in range(epochs + 1):
    start_time = time.time()  # Start timer
    loss = None  # stays None when `slices` is empty

    for slice_img in slices:
        # Clean reconstruction target with explicit batch and channel axes.
        # (Previously two axes were appended to an array that already had a
        # channel axis, producing a 5-D target.)
        original_img = np.squeeze(slice_img).copy()[np.newaxis, ..., np.newaxis]

        # Inject an anomaly only on every 10th epoch (0, 10, 20, ...)
        if epoch % 10 == 0:
            anomaly_img, _ = inject_anomaly(np.squeeze(slice_img))
        else:
            anomaly_img = slice_img

        anomaly_img = np.expand_dims(np.squeeze(anomaly_img), axis=-1)  # Ensure channel dimension
        # augment_image returns the image without a channel axis; add the
        # batch and channel axes explicitly instead of relying on Keras to
        # fix the rank.
        augmented_img = augment_image(anomaly_img)[np.newaxis, ..., np.newaxis]

        # Train the VAE on the augmented image and its reconstruction target.
        # NOTE(review): build_vae compiles the model without a `loss=` and adds
        # a loss computed from the model's own input, so `original_img` does
        # not appear to influence training - confirm whether that is intended.
        loss = vae.train_on_batch(augmented_img, original_img)

    # Calculate epoch duration
    epoch_time = time.time() - start_time

    # Save model every 100 epochs
    if epoch % 100 == 0:
        vae.save(f'vae_model_epoch_{epoch}.h5')

    # Display images and print loss every 10 epochs (skipped if nothing was trained)
    if epoch % 10 == 0 and loss is not None:
        print(f"Epoch {epoch}/{epochs}, Loss: {loss:.4f}, Time: {epoch_time:.2f}s")

        # Reconstruct the last image of this epoch
        reconstructed_img = vae.predict(augmented_img)

        # Plot the images
        fig, axes = plt.subplots(1, 4, figsize=(20, 5))

        # Original image
        axes[0].imshow(np.squeeze(original_img), cmap='gray')
        axes[0].set_title("Original Image")

        # Anomaly image
        axes[1].imshow(np.squeeze(anomaly_img), cmap='gray')
        axes[1].set_title("Anomaly Image")

        # Augmented image
        axes[2].imshow(np.squeeze(augmented_img), cmap='gray')
        axes[2].set_title("Augmented Image")

        # Reconstructed image
        axes[3].imshow(np.squeeze(reconstructed_img), cmap='gray')
        axes[3].set_title("Reconstructed Image")

        plt.show()
        plt.close(fig)  # release the figure; hundreds are created over a full run

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, losses
import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import time  # Add time module
from tensorflow.keras.preprocessing.image import smart_resize

# Define the weighted VAE loss function
def vae_loss(inputs, outputs, mu, log_var, kl_weight=0.75):
    """Return binary cross-entropy reconstruction loss plus a weighted KL term.

    Parameters
    ----------
    inputs, outputs : tensor
        Reference image batch and its reconstruction.
    mu, log_var : tensor
        Mean and log-variance of the approximate posterior.
    kl_weight : float, optional
        Multiplier applied to the KL term (default 0.75, the value that was
        previously hard-coded). Values below 1 weight reconstruction more
        heavily than the KL term.

    Returns
    -------
    tensor
        Scalar loss; both terms are averaged over all elements.
    """
    reconstruction_loss = tf.reduce_mean(losses.binary_crossentropy(inputs, outputs))
    kl_loss = -0.5 * tf.reduce_mean(1 + log_var - tf.square(mu) - tf.exp(log_var))
    return reconstruction_loss + kl_weight * kl_loss

# Build the VAE model
def build_vae(input_shape):
    """Build a convolutional VAE with a 128-dimensional dense latent space.

    Parameters
    ----------
    input_shape : tuple
        ``(height, width, channels)`` of one input image. Height and width
        must be multiples of 8: the encoder halves the resolution three
        times and the decoder doubles it three times.

    Returns
    -------
    tuple
        ``(vae, encoder, decoder)``: the compiled end-to-end model, the
        encoder mapping an image to ``[z, mu, log_var]`` and the decoder
        mapping a latent vector back to an image.

    Raises
    ------
    ValueError
        If height or width is not a multiple of 8.

    Notes
    -----
    The loss is attached with ``add_loss`` and is computed between the
    model's own input and its output; ``compile`` receives no ``loss``
    argument, so this function configures no target-based loss.
    """
    latent_dim = 128
    height, width, channels = input_shape
    if height % 8 or width % 8:
        raise ValueError(
            f"input height and width must be multiples of 8, got {height}x{width}"
        )
    # Shape of the last encoder feature map; the decoder starts from the same
    # shape instead of a hard-coded 64x64x128 so other input sizes work too.
    feature_shape = (height // 8, width // 8, 128)

    # Encoder
    inputs = layers.Input(shape=input_shape)
    x = inputs
    for filters in (32, 64, 128):
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        x = layers.MaxPooling2D((2, 2), padding='same')(x)
    x = layers.Flatten()(x)

    mu = layers.Dense(latent_dim)(x)
    log_var = layers.Dense(latent_dim)(x)

    # Sampling function (reparameterisation: z = mu + sigma * epsilon)
    def sampling(args):
        mu, log_var = args
        batch = tf.shape(mu)[0]
        dim = tf.shape(mu)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return mu + tf.exp(0.5 * log_var) * epsilon

    z = layers.Lambda(sampling)([mu, log_var])

    # Decoder
    decoder_input = layers.Input(shape=(latent_dim,))
    x = layers.Dense(feature_shape[0] * feature_shape[1] * feature_shape[2],
                     activation='relu')(decoder_input)
    x = layers.Reshape(feature_shape)(x)
    for filters in (128, 64, 32):
        x = layers.Conv2DTranspose(filters, (3, 3), activation='relu', padding='same')(x)
        x = layers.UpSampling2D((2, 2))(x)
    # Sigmoid keeps the reconstruction in [0, 1]; emit as many channels as the input has.
    outputs = layers.Conv2DTranspose(channels, (3, 3), activation='sigmoid', padding='same')(x)

    encoder = Model(inputs, [z, mu, log_var], name='encoder')
    decoder = Model(decoder_input, outputs, name='decoder')

    # Run the encoder once and reuse its three outputs, rather than calling
    # it separately for z, mu and log_var.
    z_out, mu_out, log_var_out = encoder(inputs)
    vae_outputs = decoder(z_out)
    vae = Model(inputs, vae_outputs, name='vae')

    # Add loss function
    vae.add_loss(vae_loss(inputs, vae_outputs, mu_out, log_var_out))
    vae.compile(optimizer='adam')

    return vae, encoder, decoder

# Function to load preprocessed PNG slices from your dataset directory
def load_slices_from_png(folder_path):
    """Load every ``.png`` file in a folder as a normalised grayscale array.

    Parameters
    ----------
    folder_path : str
        Directory to scan (non-recursive).

    Returns
    -------
    numpy.ndarray
        Stack of images scaled to [0, 1], each with a trailing channel axis.
        Files are visited in ``sorted()`` order of their names, which is
        lexicographic (``'10.png'`` comes before ``'2.png'``).

    Raises
    ------
    IOError
        If a ``.png`` file cannot be decoded.
    """
    all_slices = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith('.png'):
            continue
        file_path = os.path.join(folder_path, filename)
        img = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising
            raise IOError(f"Could not read image file: {file_path}")
        img = img / 255.0  # Normalize to [0,1]
        img = np.expand_dims(img, axis=-1)  # Add channel dimension
        all_slices.append(img)
    return np.array(all_slices)

# Function to deform an image and inject an anomaly into it
def deform_and_inject_anomaly(image, window_size=200):
    """Elastically deform the image, then paste a distorted patch into it.

    The whole image is passed through an elastic transform (applied with
    probability 0.7). A square window is then cut from a random location of
    the result, altered with a brightness/contrast change and a second
    elastic transform (also p=0.7), masked with 1-3 random ellipses/polygons
    and copied to another random location.

    Parameters
    ----------
    image : numpy.ndarray
        Source image; the caller in this notebook passes a squeezed 2-D slice.
    window_size : int, optional
        Side length of the square window (default 200).

    Returns
    -------
    tuple
        ``(anomaly_image, masked_anomaly)``: the deformed image containing
        the anomaly, and the masked window that was pasted.

    Raises
    ------
    ValueError
        If the image is smaller than ``window_size`` in either dimension.
    """
    h, w = image.shape[:2]
    if h < window_size or w < window_size:
        raise ValueError(
            f"image of size {h}x{w} is smaller than window_size={window_size}"
        )
    # np.random.randint excludes its upper bound, so "+ 1" is required for the
    # last valid offset to be reachable (and for h == window_size to work).
    max_row = h - window_size + 1
    max_col = w - window_size + 1

    # Step 1: Apply global elastic deformation to the entire image
    deformed_image = A.ElasticTransform(p=0.7, alpha=100, sigma=100)(image=image)["image"]

    # Step 2: Select a random window area from the deformed image for anomaly injection
    src_row, src_col = np.random.randint(0, max_row), np.random.randint(0, max_col)
    anomaly_window = deformed_image[src_row:src_row + window_size, src_col:src_col + window_size]

    # Apply contrast change and elastic deformation specifically to the anomaly window
    anomaly_window = A.RandomBrightnessContrast(p=1.0, brightness_limit=(-0.4, 0.4), contrast_limit=(-0.6, 0.6))(image=anomaly_window)["image"]
    anomaly_window = A.ElasticTransform(p=0.7, alpha=50, sigma=50)(image=anomaly_window)["image"]

    # Create a random mask for a unique anomaly shape
    mask = np.zeros_like(anomaly_window, dtype=np.uint8)
    num_shapes = np.random.randint(1, 4)  # 1 to 3 shapes

    for _ in range(num_shapes):
        shape_type = np.random.choice(['ellipse', 'polygon'])
        if shape_type == 'ellipse':
            center = (np.random.randint(0, window_size), np.random.randint(0, window_size))
            axes = (np.random.randint(10, window_size // 2), np.random.randint(10, window_size // 2))
            angle = np.random.randint(0, 180)
            cv2.ellipse(mask, center, axes, angle, 0, 360, (255, 255, 255), -1)
        elif shape_type == 'polygon':
            num_points = np.random.randint(3, 7)
            points = np.array([[
                (np.random.randint(0, window_size), np.random.randint(0, window_size))
                for _ in range(num_points)
            ]], dtype=np.int32)
            cv2.fillPoly(mask, points, (255, 255, 255))

    # Apply mask to the anomaly window (pixels outside the mask become zero)
    masked_anomaly = cv2.bitwise_and(anomaly_window, anomaly_window, mask=mask)

    # Place the modified, random-shaped window back into a copy of the deformed image
    anomaly_image = deformed_image.copy()
    dst_row, dst_col = np.random.randint(0, max_row), np.random.randint(0, max_col)

    # window_region is a view, so copying into it edits anomaly_image in place
    window_region = anomaly_image[dst_row:dst_row + window_size, dst_col:dst_col + window_size]
    np.copyto(window_region, masked_anomaly, where=mask.astype(bool))

    return anomaly_image, masked_anomaly

# Albumentations augmentation
def augment_image(image):
    """Run a [0, 1] grayscale image through the albumentations pipeline.

    Every transform in the pipeline is currently commented out, so the
    function only converts the image to 8-bit and back to [0, 1].

    Parameters
    ----------
    image : numpy.ndarray
        Image with values expected in [0, 1]; a trailing channel axis of
        size 1 is dropped before processing.

    Returns
    -------
    numpy.ndarray
        Processed image rescaled to [0, 1]. A single-channel input comes
        back without its channel axis.
    """
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]  # Remove the channel if it is a grayscale image with a single channel

    # Clip before converting: astype(np.uint8) does not saturate, so values
    # slightly outside [0, 1] would otherwise end up as arbitrary intensities.
    image = np.clip(image * 255, 0, 255).astype(np.uint8)  # Convert to uint8
    transform = A.Compose([
        #  A.AdvancedBlur(p=0.5),
        #  A.CLAHE(p=0.5),
        #  A.Downscale(p=0.5),
        #  A.Emboss(p=0.5),
        #  A.Equalize(p=0.5),
        # # A.FancyPCA(p=0.5),
        #  A.GaussNoise(p=0.5),
        #  A.RandomBrightnessContrast(p=0.5),
        # # A.CoarseDropout(p=0.5),
        # # A.PixelDropout(p=0.5)
    ])
    augmented = transform(image=image)
    return augmented['image'] / 255.0

# Read the training slices from disk
folder_path = r"cropped"
slices = load_slices_from_png(folder_path)

# Bring every slice to 512x512 and stack the results into one array
slices = np.array([smart_resize(slice_img, (512, 512)) for slice_img in slices])

# Create the VAE together with its encoder and decoder sub-models
vae, encoder, decoder = build_vae(input_shape=(512, 512, 1))

# Training configuration
epochs = 5000
batch_size = 64  # not referenced by the training loop in this cell

vae.summary()

for epoch in range(epochs + 1):
    start_time = time.time()  # Start timer
    loss = None  # stays None when `slices` is empty

    for slice_img in slices:
        # Clean reconstruction target with explicit batch and channel axes.
        # (Previously two axes were appended to an array that already had a
        # channel axis, producing a 5-D target.)
        original_img = np.squeeze(slice_img).copy()[np.newaxis, ..., np.newaxis]

        # Deform the image and inject an anomaly only on every 10th epoch (0, 10, 20, ...)
        if epoch % 10 == 0:
            anomaly_img, _ = deform_and_inject_anomaly(np.squeeze(slice_img))
        else:
            anomaly_img = slice_img

        anomaly_img = np.expand_dims(np.squeeze(anomaly_img), axis=-1)  # Ensure channel dimension
        # augment_image returns the image without a channel axis; add the
        # batch and channel axes explicitly instead of relying on Keras to
        # fix the rank.
        augmented_img = augment_image(anomaly_img)[np.newaxis, ..., np.newaxis]

        # Train the VAE on the augmented image and its reconstruction target.
        # NOTE(review): build_vae compiles the model without a `loss=` and adds
        # a loss computed from the model's own input, so `original_img` does
        # not appear to influence training - confirm whether that is intended.
        loss = vae.train_on_batch(augmented_img, original_img)

    # Calculate epoch duration
    epoch_time = time.time() - start_time

    # Save model every 100 epochs
    if epoch % 100 == 0:
        vae.save(f'vae_model_epoch_{epoch}.h5')

    # Display images and print loss every 10 epochs (skipped if nothing was trained)
    if epoch % 10 == 0 and loss is not None:
        print(f"Epoch {epoch}/{epochs}, Loss: {loss:.4f}, Time: {epoch_time:.2f}s")

        # Reconstruct the last image of this epoch
        reconstructed_img = vae.predict(augmented_img)

        # Plot the images
        fig, axes = plt.subplots(1, 4, figsize=(20, 5))

        # Original image
        axes[0].imshow(np.squeeze(original_img), cmap='gray')
        axes[0].set_title("Original Image")

        # Anomaly image
        axes[1].imshow(np.squeeze(anomaly_img), cmap='gray')
        axes[1].set_title("Anomaly Image")

        # Augmented image
        axes[2].imshow(np.squeeze(augmented_img), cmap='gray')
        axes[2].set_title("Augmented Image")

        # Reconstructed image
        axes[3].imshow(np.squeeze(reconstructed_img), cmap='gray')
        axes[3].set_title("Reconstructed Image")

        plt.show()
        plt.close(fig)  # release the figure; hundreds are created over a full run