In [1]:
import tensorflow as tf
from tensorflow.keras import layers, Model, losses
import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import time  # Add time module
from tensorflow.keras.preprocessing.image import smart_resize

# Define the weighted VAE loss function
def vae_loss(inputs, outputs, mu, log_var, kl_weight=0.25):
    """Return the weighted VAE objective: reconstruction BCE + kl_weight * KL.

    Args:
        inputs: target images. Binary cross-entropy expects them in [0, 1];
            presumably the caller scales them that way -- TODO confirm.
        outputs: reconstructions with the same shape as ``inputs``.
        mu: latent means produced by the encoder.
        log_var: latent log-variances produced by the encoder.
        kl_weight: multiplier applied to the KL term. The default 0.25
            down-weights the KL divergence so that reconstruction dominates
            (the previously hard-coded behaviour).

    Returns:
        A scalar tensor.
    """
    # binary_crossentropy reduces over the last axis; reduce_mean then
    # averages the remaining (batch / spatial) positions.
    reconstruction_loss = tf.reduce_mean(losses.binary_crossentropy(inputs, outputs))
    # Closed-form KL(N(mu, exp(log_var)) || N(0, 1)), averaged (not summed)
    # over latent dimensions and batch.
    kl_loss = -0.5 * tf.reduce_mean(1 + log_var - tf.square(mu) - tf.exp(log_var))
    return reconstruction_loss + kl_weight * kl_loss

def multiscale_encoder(input_shape, latent_dim=128):
    """Build the VAE encoder with a multi-scale first stage.

    Three parallel 32-filter convolutions (3x3, 5x5, 7x7 kernels) look at the
    input with different receptive fields. Their feature maps are
    concatenated, passed through three max-pool / conv stages, flattened and
    projected to the parameters of the latent Gaussian.

    Args:
        input_shape: shape of a single input image, e.g. ``(512, 512, 1)``.
        latent_dim: length of the latent vector. Defaults to 128, the value
            ``build_vae`` passes to ``build_decoder``; encoder and decoder must
            agree on it.

    Returns:
        A ``Model`` named ``'multiscale_encoder'`` whose outputs are
        ``[z, mu, log_var]``, with ``z`` drawn via the reparameterisation trick.
    """
    inputs = layers.Input(shape=input_shape)

    # Multi-scale convolution: same filter count, three kernel sizes.
    branch_3x3 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    branch_5x5 = layers.Conv2D(32, (5, 5), activation='relu', padding='same')(inputs)
    branch_7x7 = layers.Conv2D(32, (7, 7), activation='relu', padding='same')(inputs)

    # Concatenate feature maps from all branches along the channel axis.
    x = layers.Concatenate()([branch_3x3, branch_5x5, branch_7x7])

    # Downsampling and deeper feature extraction: three 2x2 'same' poolings
    # shrink each spatial side by 8 (rounded up).
    x = layers.MaxPooling2D((2, 2), padding='same')(x)
    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2), padding='same')(x)
    x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2), padding='same')(x)
    x = layers.Flatten()(x)

    # Latent distribution parameters. NOTE: keep the layer creation order
    # unchanged -- Keras matches .h5 weight files by topology by default, so
    # reordering would break loading of previously saved weights.
    mu = layers.Dense(latent_dim)(x)
    log_var = layers.Dense(latent_dim)(x)

    def sampling(args):
        """Reparameterisation trick: z = mean + exp(0.5 * log_var) * eps."""
        z_mean, z_log_var = args
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    # Noise is drawn on every call (there is no training/inference switch),
    # so predictions through this encoder are stochastic as well.
    z = layers.Lambda(sampling)([mu, log_var])

    return Model(inputs, [z, mu, log_var], name='multiscale_encoder')

# Decoder for the VAE
def build_decoder(latent_dim):
    """Build the decoder mapping a latent vector to a single-channel image.

    The latent vector is projected by a dense layer to a 64x64x128 feature
    map, then refined by three (transposed-conv, 2x upsampling) stages
    (64 -> 128 -> 256 -> 512 per side). A final 1-filter transposed conv
    with a sigmoid produces the 512x512x1 output.

    Args:
        latent_dim: length of the latent vector fed to the decoder.

    Returns:
        A ``Model`` named ``'decoder'``.
    """
    latent = layers.Input(shape=(latent_dim,))
    features = layers.Dense(64 * 64 * 128, activation='relu')(latent)
    features = layers.Reshape((64, 64, 128))(features)

    # Layers are created in the same order as before (conv, upsample, ...),
    # which keeps saved weights loadable.
    for n_filters in (128, 64, 32):
        features = layers.Conv2DTranspose(
            n_filters, (3, 3), activation='relu', padding='same'
        )(features)
        features = layers.UpSampling2D((2, 2))(features)

    reconstruction = layers.Conv2DTranspose(
        1, (3, 3), activation='sigmoid', padding='same'
    )(features)
    return Model(latent, reconstruction, name='decoder')

# Build the full VAE
def build_vae(input_shape):
    """Assemble and compile the full VAE.

    The encoder's sampled latent ``z`` feeds the decoder. The training
    objective is attached with ``add_loss(vae_loss(...))``, so ``compile`` is
    called without a ``loss`` argument.

    Args:
        input_shape: shape of a single input image, e.g. ``(512, 512, 1)``.

    Returns:
        ``(vae, encoder, decoder)``.
    """
    enc = multiscale_encoder(input_shape)
    # 128 must match the encoder's latent size.
    dec = build_decoder(latent_dim=128)

    image_in = layers.Input(shape=input_shape)
    latent, mean, logvar = enc(image_in)
    image_out = dec(latent)

    model = Model(image_in, image_out, name='vae')
    model.add_loss(vae_loss(inputs=image_in, outputs=image_out, mu=mean, log_var=logvar))
    model.compile(optimizer='adam')
    return model, enc, dec



import tkinter as tk
from tkinter import filedialog, Toplevel, Label
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras import layers, Model, losses
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import precision_score, recall_score, f1_score, jaccard_score, roc_auc_score
from PIL import Image, ImageTk
from test_cropper import crop_brain_image
from test_cropper import crop_mask_image
from test_cropper import resize_image
from sklearn.metrics import confusion_matrix

from skimage.exposure import match_histograms

def normalize_contrast(original, reconstructed):
    """Return ``reconstructed`` histogram-matched to ``original``.

    No ``channel_axis`` is passed to ``match_histograms``, so each array is
    handled as a single intensity distribution (suited to single-channel
    images).
    """
    matched = match_histograms(image=reconstructed, reference=original)
    return matched

def calculate_anomaly_mask(original, reconstructed, error_threshold=0.2):
    """Threshold the per-pixel reconstruction error into an anomaly mask.

    ``reconstructed`` is first histogram-matched to ``original`` (via
    ``normalize_contrast``). The error is then the absolute per-pixel
    difference.

    Args:
        original: the input image.
        reconstructed: the model's reconstruction of ``original``.
        error_threshold: pixels whose error strictly exceeds this value are
            marked anomalous.

    Returns:
        ``(anomaly_mask, reconstruction_error)``: a boolean array and the
        absolute-error array it was derived from.
    """
    matched = normalize_contrast(original, reconstructed)
    error_map = np.abs(original - matched)
    return error_map > error_threshold, error_map
# Main Tkinter GUI
class BrainImageProcessorApp:
    """Tkinter GUI that flags anomalous regions in a brain image with a VAE.

    Workflow:
      * "Load Image" crops the chosen image with ``crop_brain_image``.
      * "Load Segmentation Mask" (optional) crops a ground-truth mask with the
        same crop dimensions.
      * "Process Image" fine-tunes the pretrained VAE on that single image,
        reconstructs it and thresholds the reconstruction error into an
        anomaly mask. Results open in separate Toplevel windows.
    """

    # Spatial size the VAE is built for (input_shape=(512, 512, 1)).
    MODEL_SIZE = (512, 512)

    def __init__(self, root, weights_path="vae_model_epoch_2000.h5", adaptation_epochs=250):
        """Create the widgets and load the pretrained VAE.

        Args:
            root: the Tk root window.
            weights_path: Keras weights file loaded into the VAE.
            adaptation_epochs: number of single-image fine-tuning steps run by
                ``process_image``.
        """
        self.root = root
        self.root.title("Brain Image Processor")

        self.load_img_button = tk.Button(root, text="Load Image", command=self.load_image)
        self.load_img_button.pack()

        self.load_mask_button = tk.Button(root, text="Load Segmentation Mask", command=self.load_mask, state=tk.DISABLED)
        self.load_mask_button.pack()

        self.process_button = tk.Button(root, text="Process Image", command=self.process_image, state=tk.DISABLED)
        self.process_button.pack()

        self.vae, _, _ = build_vae(input_shape=(*self.MODEL_SIZE, 1))
        self.vae.load_weights(weights_path)
        # Snapshot of the pretrained weights: process_image restores these
        # before every fine-tuning run so adapting to one image does not
        # leak into the next one.
        self._pretrained_weights = self.vae.get_weights()
        self.adaptation_epochs = adaptation_epochs

        # Mean reconstruction error above which an image is labelled anomalous.
        self.threshold = 0.3
        self.filepath = None
        self.original_image = None
        self.crop_dims = None
        self.mask = None          # mask as read from disk (uncropped)
        self.cropped_mask = None  # mask cropped with self.crop_dims

        self.toggle_fullscreen_button = tk.Button(root, text="Toggle Fullscreen", command=self.toggle_fullscreen)
        self.toggle_fullscreen_button.pack()

    def _show_error(self, message):
        """Report a recoverable problem to the user in a modal dialog."""
        from tkinter import messagebox
        messagebox.showerror("Brain Image Processor", message, parent=self.root)

    def load_image(self):
        """Ask for an image file, crop it to the brain region and display it."""
        filepath = filedialog.askopenfilename()
        if not filepath:
            return  # dialog cancelled
        img = cv2.imread(filepath)
        if img is None:
            # cv2.imread returns None (instead of raising) for unreadable files.
            self._show_error(f"Could not read image file:\n{filepath}")
            return
        self.filepath = filepath
        self.original_image, self.crop_dims = crop_brain_image(img)
        # A mask loaded for a previous image was cropped with that image's
        # crop dimensions and no longer applies.
        self.mask = None
        self.cropped_mask = None
        self.process_button.config(state=tk.NORMAL)
        self.load_mask_button.config(state=tk.NORMAL)
        self.display_image(resize_image(self.original_image), "Original Image")

    def load_mask(self):
        """Ask for a segmentation mask, crop it like the image and display it."""
        mask_path = filedialog.askopenfilename()
        if not mask_path:
            return  # dialog cancelled
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            self._show_error(f"Could not read mask file:\n{mask_path}")
            return
        self.mask = mask
        # Crop with the image's crop dimensions so the mask stays aligned with
        # it; process_image compares the anomaly mask against this crop.
        self.cropped_mask = crop_mask_image(mask, self.crop_dims)
        self.display_image(resize_image(self.cropped_mask), "Original Segmentation Mask")

    def display_image(self, image, title):
        """Show ``image`` as an 8-bit grayscale picture in a new window.

        Non-uint8 input is clipped to [0, 255] and cast to uint8. 3-channel
        input is treated as BGR (OpenCV order) and converted to grayscale.
        """
        new_window = Toplevel(self.root)
        new_window.title(title)

        if image.dtype != np.uint8:
            image = np.clip(image, 0, 255).astype(np.uint8)
        if image.ndim == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        photo = ImageTk.PhotoImage(Image.fromarray(image))
        label = Label(new_window, image=photo)
        label.image = photo  # keep a reference so the image isn't garbage-collected
        label.pack()

    def _prepare_input(self, image):
        """Return ``image`` resized to MODEL_SIZE, as a batch of shape (1, H, W, 1).

        Pixel values are divided by 255 (assumes 8-bit input -- TODO confirm
        for crop_brain_image's output).
        """
        image = resize_image(image, size=self.MODEL_SIZE)
        if image.ndim == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        image = image / 255.0
        return image[np.newaxis, ..., np.newaxis]

    def _adapt_to(self, batch):
        """Reset the VAE to its pretrained state, then fine-tune it on ``batch``."""
        self.vae.set_weights(self._pretrained_weights)
        # Re-compiling creates a fresh optimizer, so Adam's moment estimates
        # from a previously processed image do not carry over either.
        self.vae.compile(optimizer='adam')
        for epoch in range(self.adaptation_epochs):
            self.vae.train_on_batch(batch, batch)
            print(f"Adaptation Epoch {epoch + 1}/{self.adaptation_epochs} completed.")

    def _mask_metrics(self, anomaly_mask, reconstruction_error):
        """Return metric lines comparing ``anomaly_mask`` with the loaded mask.

        The cropped ground-truth mask is resized to MODEL_SIZE so it lines up
        pixel for pixel with ``anomaly_mask``. Resizing may interpolate, so
        any pixel > 0 counts as foreground. Returns an explanatory line
        instead of raising when shapes still disagree or when ROC-AUC is
        undefined.
        """
        truth = resize_image(self.cropped_mask, size=self.MODEL_SIZE)
        if truth.shape != anomaly_mask.shape:
            return [f"Mask metrics unavailable: mask shape {truth.shape} "
                    f"!= anomaly mask shape {anomaly_mask.shape}"]

        mask_flat = truth.ravel() > 0
        pred_flat = anomaly_mask.ravel()

        lines = [
            f"Precision: {precision_score(mask_flat, pred_flat):.4f}",
            f"Recall: {recall_score(mask_flat, pred_flat):.4f}",
            f"F1-Score: {f1_score(mask_flat, pred_flat):.4f}",
            f"Jaccard Score: {jaccard_score(mask_flat, pred_flat):.4f}",
        ]
        # roc_auc_score raises ValueError when y_true holds a single class
        # (e.g. an empty mask for a healthy scan).
        if np.unique(mask_flat).size < 2:
            lines.append("ROC-AUC: undefined (mask contains a single class)")
        else:
            roc_auc = roc_auc_score(mask_flat, reconstruction_error.ravel())
            lines.append(f"ROC-AUC: {roc_auc:.4f}")
        return lines

    def process_image(self):
        """Fine-tune on the loaded image, then display the results and metrics.

        The VAE starts from the pretrained weights (with a fresh optimizer)
        on every call, so results do not depend on previously processed
        images. Mean error and predicted label are always reported; the
        overlap metrics are added when a segmentation mask is loaded.
        """
        if self.original_image is None:
            return

        batch = self._prepare_input(self.original_image)
        self._adapt_to(batch)

        reconstructed = np.squeeze(self.vae.predict(batch))
        original = np.squeeze(batch)

        # Histogram-matched reconstruction, for display only;
        # calculate_anomaly_mask performs the same matching internally, so it
        # is given the raw reconstruction rather than matching twice.
        reconstructed_normalized = normalize_contrast(original, reconstructed)
        anomaly_mask, reconstruction_error = calculate_anomaly_mask(
            original, reconstructed, error_threshold=0.1
        )

        mean_error = reconstruction_error.mean()
        predicted_label = int(mean_error > self.threshold)

        self.display_image(reconstructed_normalized * 255, "Reconstructed Image")
        self.display_image(anomaly_mask.astype(np.uint8) * 255, "Anomaly Mask")

        lines = [f"Mean Reconstruction Error: {mean_error:.4f}"]
        if self.cropped_mask is not None:
            lines.extend(self._mask_metrics(anomaly_mask, reconstruction_error))
        lines.append(f"Predicted Label: {'Anomalous' if predicted_label else 'Normal'}")
        self.show_metrics("\n".join(lines))

    def show_metrics(self, text):
        """Show ``text`` left-justified in a new "Metrics" window."""
        metrics_window = Toplevel(self.root)
        metrics_window.title("Metrics")
        metrics_label = Label(metrics_window, text=text, font=("Arial", 12), justify=tk.LEFT)
        metrics_label.pack()

    def toggle_fullscreen(self):
        """Flip the root window between fullscreen and windowed mode."""
        # Normalise via getboolean in case the attribute comes back as a string
        # (where `not '0'` would wrongly be False).
        is_fullscreen = self.root.getboolean(self.root.attributes('-fullscreen'))
        self.root.attributes('-fullscreen', not is_fullscreen)

# Launch the GUI: create the Tk root, attach the app, then hand control to
# Tk's event loop (blocks until the main window is closed).
root = tk.Tk()
app = BrainImageProcessorApp(root)
app.root.mainloop()


  check_for_updates()


Adaptation Epoch 1/250 completed.
Adaptation Epoch 2/250 completed.
Adaptation Epoch 3/250 completed.
Adaptation Epoch 4/250 completed.
Adaptation Epoch 5/250 completed.
Adaptation Epoch 6/250 completed.
Adaptation Epoch 7/250 completed.
Adaptation Epoch 8/250 completed.
Adaptation Epoch 9/250 completed.
Adaptation Epoch 10/250 completed.
Adaptation Epoch 11/250 completed.
Adaptation Epoch 12/250 completed.
Adaptation Epoch 13/250 completed.
Adaptation Epoch 14/250 completed.
Adaptation Epoch 15/250 completed.
Adaptation Epoch 16/250 completed.
Adaptation Epoch 17/250 completed.
Adaptation Epoch 18/250 completed.
Adaptation Epoch 19/250 completed.
Adaptation Epoch 20/250 completed.
Adaptation Epoch 21/250 completed.
Adaptation Epoch 22/250 completed.
Adaptation Epoch 23/250 completed.
Adaptation Epoch 24/250 completed.
Adaptation Epoch 25/250 completed.
Adaptation Epoch 26/250 completed.
Adaptation Epoch 27/250 completed.
Adaptation Epoch 28/250 completed.
Adaptation Epoch 29/250 compl