In [None]:
# Mount Google Drive at /content/drive so files under /content/drive/MyDrive are readable (Google Colab only).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from skimage import io, filters
from skimage.exposure import equalize_adapthist
from tensorflow.keras.layers import Conv2D, Input, MaxPooling2D, UpSampling2D, concatenate
from tensorflow.keras.models import Model

In [None]:
# Dataset root: edit this to match where the dataset lives on your Drive.
DATA_ROOT = '/content/drive/MyDrive/brain_mri_dataset'
image_dir = os.path.join(DATA_ROOT, 'images')
mask_dir = os.path.join(DATA_ROOT, 'masks')


def list_nifti_files(directory):
    """Return full paths of the ``.nii.gz`` files in ``directory``, sorted by file name.

    ``os.listdir`` returns entries in arbitrary order. Sorting both directories is what keeps
    ``images[i]`` and ``masks[i]`` pointing at the same scan. This assumes image and mask file
    names sort identically (e.g. shared IDs) — TODO confirm for this dataset.
    """
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.endswith('.nii.gz')
    ]


image_files = list_nifti_files(image_dir)
mask_files = list_nifti_files(mask_dir)
if len(image_files) != len(mask_files):
    raise ValueError(
        f'Found {len(image_files)} images but {len(mask_files)} masks; every image needs a mask.'
    )

# NOTE(review): scikit-image's documented io plugins do not include 'nibabel', so these calls may
# raise "No plugin named nibabel" — verify; nibabel.load(path).get_fdata() is the usual alternative.
images = [io.imread(f, plugin='nibabel') for f in image_files]
masks = [io.imread(f, plugin='nibabel') for f in mask_files]

In [None]:
# Contrast-limited adaptive histogram equalisation (CLAHE) of every loaded image;
# clip_limit=0.03 caps how strongly local contrast may be amplified.
clahe_images = [equalize_adapthist(scan, clip_limit=0.03) for scan in images]

In [None]:
def _scale_to_unit_max(image):
    """Divide ``image`` by its maximum so its brightest value becomes 1.0.

    An all-zero image (maximum 0) would otherwise become all-NaN through 0/0. It is returned
    as float zeros of the same shape instead.
    """
    peak = image.max()
    if peak == 0:
        return np.zeros_like(image, dtype=np.float64)
    return image / peak


clahe_images = [_scale_to_unit_max(image) for image in clahe_images]

In [None]:
from sklearn.model_selection import train_test_split


def _to_batch(samples):
    """Stack a sequence of equally-shaped arrays into one float32 batch.

    Keras interprets a Python *list* of arrays passed to ``fit``/``evaluate`` as "one array per
    model input". The single-input models in this notebook therefore need one stacked array.
    A batch of 2-D samples ``(N, H, W)`` gets a trailing channel axis ``(N, H, W, 1)``.

    Raises:
        ValueError: from ``np.stack`` if the samples do not all share the same shape.
    """
    batch = np.stack([np.asarray(sample, dtype=np.float32) for sample in samples])
    if batch.ndim == 3:
        batch = batch[..., np.newaxis]
    return batch


# Hold out 20% of the image/mask pairs for validation; the fixed seed makes the split reproducible.
train_images, val_images, train_masks, val_masks = train_test_split(
    clahe_images, masks, test_size=0.2, random_state=42
)

# NOTE(review): the models below expect 2-D inputs of shape (256, 256, 1). If each loaded sample is
# a 3-D volume it must be sliced into 2-D images (and resized) before this point — TODO confirm.
train_images, val_images, train_masks, val_masks = (
    _to_batch(split) for split in (train_images, val_images, train_masks, val_masks)
)

In [None]:
class AttentionBlock(tf.keras.layers.Layer):
    """Additive attention gate for a U-Net skip connection (Oktay et al., 2018).

    Call it on ``[x, g]``:
    - ``x`` is the encoder skip-connection feature map.
    - ``g`` is a gating signal that has already been resampled to ``x``'s spatial size.

    It computes::

        alpha = sigmoid(psi(relu(theta_x(x) + phi_g(g))))

    Here ``theta_x``/``phi_g`` are 1x1 convolutions with ``inter_channels`` filters and ``psi``
    is a 1x1 convolution with one filter. It returns ``x * alpha``: the skip features re-weighted
    per pixel. When called on a single tensor, that tensor gates itself.

    Args:
        input_shape: Accepted only for compatibility with the previous signature; unused
            (Keras infers shapes on the first call).
        inter_channels: Number of filters in the intermediate 1x1 projections.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, input_shape=None, inter_channels=64, **kwargs):
        super().__init__(**kwargs)
        # Do not assign ``self.input_shape``: in tf.keras 2.x it is a read-only Layer property
        # and the assignment raises AttributeError.
        self.inter_channels = inter_channels
        self.theta_x = Conv2D(inter_channels, (1, 1), padding='same')
        self.phi_g = Conv2D(inter_channels, (1, 1), padding='same')
        self.psi = Conv2D(1, (1, 1), activation='sigmoid', padding='same')

    def call(self, inputs):
        if isinstance(inputs, (list, tuple)):
            x, g = inputs
        else:
            x = g = inputs
        attention = self.psi(tf.nn.relu(self.theta_x(x) + self.phi_g(g)))
        # (H, W, 1) attention map broadcasts across every channel of x.
        return x * attention

    def get_config(self):
        # Needed so models containing this layer can be saved and re-loaded.
        config = super().get_config()
        config.update({'inter_channels': self.inter_channels})
        return config


def _attention_conv_block(tensor, filters):
    """Apply two 3x3 'same'-padded ReLU convolutions (one U-Net stage)."""
    tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
    return Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)


def _attention_up_block(coarse, skip, filters):
    """Upsample ``coarse`` 2x, gate ``skip`` with it, concatenate both and convolve."""
    gating = UpSampling2D(size=(2, 2))(coarse)
    gated_skip = AttentionBlock(inter_channels=max(filters // 2, 1))([skip, gating])
    merged = concatenate([gating, gated_skip], axis=3)
    return _attention_conv_block(merged, filters)


def attention_unet(input_shape):
    """Build an Attention U-Net for binary segmentation.

    The encoder has four 2x max-pool stages (32 → 512 filters). The decoder mirrors it with four
    2x upsampling stages, and each skip connection is re-weighted by an ``AttentionBlock`` gated
    by the coarser decoder features. The output therefore has the same height and width as the
    input.

    Args:
        input_shape: ``(H, W, C)`` of one input image; ``H`` and ``W`` must be divisible by 16
            so the four pool/upsample stages line up with the skip connections.

    Returns:
        A ``tf.keras.Model`` mapping ``(batch, H, W, C)`` to a ``(batch, H, W, 1)`` sigmoid mask.
    """
    inputs = Input(input_shape)

    # Encoder.
    conv1 = _attention_conv_block(inputs, 32)
    conv2 = _attention_conv_block(MaxPooling2D(pool_size=(2, 2))(conv1), 64)
    conv3 = _attention_conv_block(MaxPooling2D(pool_size=(2, 2))(conv2), 128)
    conv4 = _attention_conv_block(MaxPooling2D(pool_size=(2, 2))(conv3), 256)
    conv5 = _attention_conv_block(MaxPooling2D(pool_size=(2, 2))(conv4), 512)

    # Decoder: every level gates its skip connection, back up to full resolution.
    conv6 = _attention_up_block(conv5, conv4, 256)
    conv7 = _attention_up_block(conv6, conv3, 128)
    conv8 = _attention_up_block(conv7, conv2, 64)
    conv9 = _attention_up_block(conv8, conv1, 32)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
    return Model(inputs=[inputs], outputs=[outputs])

In [None]:
# Both architectures take the same single-channel 256x256 input and are compiled identically.
INPUT_SHAPE = (256, 256, 1)


def _compile_binary_segmenter(model):
    """Compile ``model`` with Adam, binary cross-entropy and an accuracy metric."""
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])


# NOTE(review): `nested_unet` is not defined in any visible cell of this notebook — TODO define or
# import it before this cell, otherwise the next line raises NameError.
nested_unet_model = nested_unet(input_shape=INPUT_SHAPE)
attention_unet_model = attention_unet(input_shape=INPUT_SHAPE)

_compile_binary_segmenter(nested_unet_model)
_compile_binary_segmenter(attention_unet_model)

In [None]:
# Train both models with identical settings so their validation curves are directly comparable.
EPOCHS = 10
BATCH_SIZE = 32

nested_unet_history = nested_unet_model.fit(
    train_images,
    train_masks,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(val_images, val_masks),
)
attention_unet_history = attention_unet_model.fit(
    train_images,
    train_masks,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(val_images, val_masks),
)

In [None]:
# Score both models on the held-out split. With metrics=['accuracy'], evaluate() yields (loss, accuracy).
nested_unet_loss, nested_unet_acc = nested_unet_model.evaluate(
    val_images, val_masks
)
attention_unet_loss, attention_unet_acc = attention_unet_model.evaluate(
    val_images, val_masks
)

# NOTE(review): pixel accuracy on segmentation masks is likely dominated by background pixels;
# Dice / IoU would compare the models more meaningfully — TODO consider adding them.
print(f'Nested U-Net Loss: {nested_unet_loss:.4f}, Accuracy: {nested_unet_acc:.4f}')
print(f'Attention U-Net Loss: {attention_unet_loss:.4f}, Accuracy: {attention_unet_acc:.4f}')

In [None]:
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image
import numpy as np
import tensorflow as tf

import functools

app = FastAPI()

templates = Jinja2Templates(directory="templates")


@functools.lru_cache(maxsize=1)
def _get_model():
    """Load the Keras model on first use and reuse it for every later request."""
    # NOTE(review): nothing in this notebook saves 'best_model.h5' — save the chosen model there
    # first. A model containing the custom AttentionBlock layer likely also needs
    # custom_objects={'AttentionBlock': AttentionBlock} to load — TODO confirm.
    return tf.keras.models.load_model('best_model.h5')


def _preprocess(pil_image, size=(256, 256)):
    """Convert a PIL image to a ``(1, H, W, 1)`` float32 batch scaled to [0, 1].

    The image is forced to 8-bit greyscale first, so RGB/RGBA uploads also match the model's
    single input channel and the division by 255 is always valid.
    ``size`` is PIL's ``(width, height)``.
    """
    gray = pil_image.convert('L').resize(size)
    pixels = np.asarray(gray, dtype=np.float32) / 255.0
    return pixels.reshape((1, size[1], size[0], 1))


@app.post("/predict")
def predict(file: UploadFile = File(...)):
    """Segment an uploaded 2-D image and return the predicted probability mask as nested lists.

    This is a plain ``def`` rather than ``async def``: decoding and ``model.predict`` are blocking
    CPU work. FastAPI runs sync endpoints in a threadpool instead of stalling the event loop.
    """
    # NOTE(review): training used CLAHE + max-normalisation, not /255 — TODO align preprocessing.
    with Image.open(file.file) as uploaded:
        batch = _preprocess(uploaded)
    prediction = _get_model().predict(batch)
    # numpy arrays are not JSON-serialisable; return plain nested lists.
    return {"prediction": prediction.tolist()}

In [None]:
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt

# This cell is meant to run as a standalone Streamlit script, so it imports what it uses.
from PIL import Image
import tensorflow as tf


@st.cache_resource
def _load_segmentation_model():
    """Load the Keras model once per Streamlit process instead of on every script rerun."""
    # NOTE(review): nothing in this notebook saves 'best_model.h5'; a model with the custom
    # AttentionBlock layer likely also needs custom_objects to load — TODO confirm.
    return tf.keras.models.load_model('best_model.h5')


st.title("Brain MRI Metastasis Segmentation")

# PIL cannot decode NIfTI (.nii.gz), so accept the 2-D image formats it can actually read.
uploaded_file = st.file_uploader(
    "Upload Brain MRI Image", type=["png", "jpg", "jpeg", "tif", "tiff"]
)

if uploaded_file is not None:
    # Resize the PIL image *before* converting to numpy: ndarray.resize works in place and
    # returns None, which previously turned `image` into None.
    with Image.open(uploaded_file) as uploaded:
        gray = uploaded.convert('L').resize((256, 256))
    image = np.asarray(gray, dtype=np.float32) / 255.0
    image = image.reshape((1, 256, 256, 1))

    model = _load_segmentation_model()
    prediction = model.predict(image)

    st.write("Prediction:")
    fig, ax = plt.subplots()
    ax.imshow(prediction[0, :, :, 0], cmap='gray')
    ax.axis('off')
    st.pyplot(fig)