In [None]:
#!pip install nibabel monai torch pydicom torchvision transformers tensorflow
#!pip install transformers[torch]

In [23]:
import os
import cv2
import pydicom
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.metrics import jaccard_score


In [None]:
# Function to load and resize DICOM images
def load_and_resize_dicom_images(folder_path, target_size=(256, 256)):
    """Load every ``.dcm`` file directly inside ``folder_path`` and resize it.

    Args:
        folder_path: Directory to scan (not recursive).
        target_size: ``(width, height)`` handed to ``cv2.resize``.

    Returns:
        Array stacking the resized pixel arrays along axis 0. When the folder
        contains no ``.dcm`` files an empty ``(0, height, width)`` array is
        returned so the result keeps a predictable rank.
    """
    # os.listdir() returns entries in arbitrary order; sorting makes the dataset
    # order (and therefore any seeded split built on top of it) reproducible.
    dicom_files = sorted(
        os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith('.dcm')
    )
    if not dicom_files:
        width, height = target_size
        return np.empty((0, height, width), dtype=np.uint8)

    # NOTE(review): assumes every file holds a single-frame 2-D pixel array - confirm.
    images = [cv2.resize(pydicom.dcmread(f).pixel_array, target_size) for f in dicom_files]
    return np.array(images)

# Function to load and resize JPG images
def load_and_resize_jpg_images(folder_path, target_size=(256, 256)):
    """Load every ``.jpg``/``.jpeg`` file directly inside ``folder_path`` as grayscale.

    Args:
        folder_path: Directory to scan (not recursive).
        target_size: ``(width, height)`` handed to ``cv2.resize``.

    Returns:
        Array stacking the resized grayscale images along axis 0. When the
        folder contains no matching files an empty ``(0, height, width)``
        array is returned so the result keeps a predictable rank.

    Raises:
        IOError: If a matching file cannot be decoded by OpenCV.
    """
    # Sorted for a reproducible dataset order; os.listdir() order is arbitrary.
    jpg_files = sorted(
        os.path.join(folder_path, f)
        for f in os.listdir(folder_path)
        if f.endswith(('.jpg', '.jpeg'))
    )
    if not jpg_files:
        width, height = target_size
        return np.empty((0, height, width), dtype=np.uint8)

    images = []
    for path in jpg_files:
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising;
        # surface the offending path rather than failing later inside cv2.resize.
        if image is None:
            raise IOError(f"Could not decode image file: {path}")
        images.append(cv2.resize(image, target_size))
    return np.array(images)

# Dataset root: each immediate sub-folder is scanned for DICOM and JPG slices
base_path = '/content/drive/MyDrive/Brain-MRI-Images-HF/ST000001'
# Sorted so the image order does not depend on the filesystem's listing order
subfolders = sorted(
    os.path.join(base_path, d) for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d))
)

image_stacks = []
for subfolder in subfolders:
    dicom_images = load_and_resize_dicom_images(subfolder)
    jpg_images = load_and_resize_jpg_images(subfolder)
    for stack in (dicom_images, jpg_images):
        # A folder may hold only one of the two formats; skip empty stacks so
        # their shape cannot break the concatenation below.
        if len(stack) > 0:
            image_stacks.append(stack.astype(np.float32))

if not image_stacks:
    raise RuntimeError(f"No DICOM or JPG images found under {base_path}")

# Flatten the list of per-folder stacks into one (n_images, height, width) array
all_images = np.concatenate(image_stacks, axis=0)

# Normalize every image to [0, 1] on its own intensity range. DICOM pixel data
# and 8-bit JPGs can live on very different scales, so dividing by one global
# maximum would squash the lower-range images towards zero.
per_image_min = all_images.min(axis=(1, 2), keepdims=True)
per_image_max = all_images.max(axis=(1, 2), keepdims=True)
value_range = per_image_max - per_image_min
value_range[value_range == 0] = 1.0  # constant images: avoid division by zero
all_images = (all_images - per_image_min) / value_range

# Display a sample image to verify loading
plt.imshow(all_images[0], cmap='gray')
plt.title("Sample Image")
plt.show()


In [20]:
# Sanity check: print the shape of the stacked image array (first axis indexes images)
print(all_images.shape)

(763, 256, 256)


### Creating Binary Masks (if necessary)


In [26]:
def create_binary_masks(images, threshold=0.5):
    """Threshold ``images`` into 0/1 masks.

    A pixel becomes 1.0 when it is strictly greater than ``threshold`` times
    the maximum value found anywhere in ``images``; otherwise it is 0.0.

    Returns:
        float32 array with the same shape as ``images``.
    """
    cutoff = threshold * np.max(images)
    return np.greater(images, cutoff).astype(np.float32)

# Derive the target masks from the images with the simple global threshold
all_masks = create_binary_masks(all_images)

# 80/20 train/validation split (fixed seed), then append a trailing channel
# axis to each of the four arrays so they match the network's input layout
X_train, X_val, y_train, y_val = (
    part[..., np.newaxis]
    for part in train_test_split(all_images, all_masks, test_size=0.2, random_state=42)
)

for caption, array in (
    ("X_train shape:", X_train),
    ("X_val shape:", X_val),
    ("y_train shape:", y_train),
    ("y_val shape:", y_val),
):
    print(caption, array.shape)


X_train shape: (610, 256, 256, 1)
X_val shape: (153, 256, 256, 1)
y_train shape: (610, 256, 256, 1)
y_val shape: (153, 256, 256, 1)


### Build U-Net Model

In [27]:
def unet_model_with_dropout(input_size=(256, 256, 1)):
    """Build a U-Net style Keras model with dropout inside every conv stage.

    The encoder has four conv stages (64, 128, 256, 512 filters), each followed
    by 2x2 max pooling, then a 1024-filter bottleneck. The decoder mirrors the
    encoder: it upsamples by 2, concatenates the matching encoder output and
    applies another conv stage. A final 1x1 sigmoid convolution produces a
    single-channel output.

    Args:
        input_size: Shape of one input sample, passed to ``Input``.

    Returns:
        The uncompiled ``Model``.
    """

    def conv_stage(tensor, filters, drop_rate):
        # conv -> dropout -> conv, all 3x3 / ReLU / 'same' padding
        tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        tensor = Dropout(drop_rate)(tensor)
        return Conv2D(filters, 3, activation='relu', padding='same')(tensor)

    def up_stage(tensor, skip, filters, drop_rate):
        # upsample, merge with the encoder skip connection, then convolve
        merged = Concatenate()([UpSampling2D(size=(2, 2))(tensor), skip])
        return conv_stage(merged, filters, drop_rate)

    inputs = Input(input_size)

    # Encoder: remember each stage's output for the skip connections
    skips = []
    features = inputs
    for filters, drop_rate in ((64, 0.1), (128, 0.1), (256, 0.2), (512, 0.2)):
        features = conv_stage(features, filters, drop_rate)
        skips.append(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)

    # Bottleneck
    features = conv_stage(features, 1024, 0.3)

    # Decoder: consume the skips deepest-first
    for filters, drop_rate in ((512, 0.2), (256, 0.2), (128, 0.1), (64, 0.1)):
        features = up_stage(features, skips.pop(), filters, drop_rate)

    outputs = Conv2D(1, 1, activation='sigmoid')(features)

    return Model(inputs=[inputs], outputs=[outputs])

# Build the Keras U-Net and compile it for per-pixel binary classification
model = unet_model_with_dropout()
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Stop after 10 epochs without val_loss improvement (restoring the best weights)
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
# Keep only the checkpoint with the lowest val_loss on disk
best_checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True)
callbacks = [early_stopping, best_checkpoint]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def _nd_layers(spatial_dims):
    """Return the (conv, batch-norm, max-pool, transposed-conv) classes for 2-D or 3-D data."""
    if spatial_dims == 2:
        return nn.Conv2d, nn.BatchNorm2d, nn.MaxPool2d, nn.ConvTranspose2d
    if spatial_dims == 3:
        return nn.Conv3d, nn.BatchNorm3d, nn.MaxPool3d, nn.ConvTranspose3d
    raise ValueError(f"spatial_dims must be 2 or 3, got {spatial_dims!r}")


class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by its own batch norm and ReLU.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by both convolutions.
        spatial_dims: 3 (default, volumetric ``Conv3d``) or 2 (``Conv2d``).
    """

    def __init__(self, in_channels, out_channels, spatial_dims=3):
        super(ConvBlock, self).__init__()
        conv, norm, _, _ = _nd_layers(spatial_dims)
        self.conv1 = conv(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = conv(out_channels, out_channels, kernel_size=3, padding=1)
        # Each convolution gets its own normalization layer: sharing a single
        # BatchNorm would mix the running statistics and affine parameters of
        # two different activation distributions.
        self.batchnorm = norm(out_channels)
        self.batchnorm2 = norm(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.batchnorm(self.conv1(x)))
        x = self.relu(self.batchnorm2(self.conv2(x)))
        return x


class nnUNet(nn.Module):
    """U-Net encoder/decoder with four down- and four up-sampling levels.

    Each decoder level upsamples with a stride-2 transposed convolution,
    concatenates the matching encoder output on the channel axis and fuses the
    result with a ``ConvBlock``. The fusing block is required because the
    concatenation doubles the channel count; without it the next layer
    receives twice the channels it was built for and ``forward`` raises.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output (raw logits, no activation).
        spatial_dims: 2 (default) for ``(N, C, H, W)`` batches, 3 for
            ``(N, C, D, H, W)`` volumes.

    Every spatial dimension of the input must be divisible by 16 (four 2x
    poolings) so the upsampled tensors line up with the skip connections.
    """

    def __init__(self, in_channels, out_channels, spatial_dims=2):
        super(nnUNet, self).__init__()
        conv, _, pool, upconv = _nd_layers(spatial_dims)

        # Encoder
        self.conv1 = ConvBlock(in_channels, 32, spatial_dims)
        self.pool1 = pool(kernel_size=2, stride=2)
        self.conv2 = ConvBlock(32, 64, spatial_dims)
        self.pool2 = pool(kernel_size=2, stride=2)
        self.conv3 = ConvBlock(64, 128, spatial_dims)
        self.pool3 = pool(kernel_size=2, stride=2)
        self.conv4 = ConvBlock(128, 256, spatial_dims)
        self.pool4 = pool(kernel_size=2, stride=2)

        # Bottleneck
        self.conv5 = ConvBlock(256, 512, spatial_dims)

        # Decoder: upsample, then fuse (upsampled + skip) back to the level's width
        self.upconv4 = upconv(512, 256, kernel_size=2, stride=2)
        self.dec4 = ConvBlock(512, 256, spatial_dims)
        self.upconv3 = upconv(256, 128, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(256, 128, spatial_dims)
        self.upconv2 = upconv(128, 64, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(128, 64, spatial_dims)
        self.upconv1 = upconv(64, 32, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(64, 32, spatial_dims)

        # 1x1 projection to the requested number of output channels
        self.conv6 = conv(32, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder path (keep each level's output for the skip connections)
        conv1 = self.conv1(x)
        conv2 = self.conv2(self.pool1(conv1))
        conv3 = self.conv3(self.pool2(conv2))
        conv4 = self.conv4(self.pool3(conv3))
        conv5 = self.conv5(self.pool4(conv4))

        # Decoder path
        dec4 = self.dec4(torch.cat((self.upconv4(conv5), conv4), dim=1))
        dec3 = self.dec3(torch.cat((self.upconv3(dec4), conv3), dim=1))
        dec2 = self.dec2(torch.cat((self.upconv2(dec3), conv2), dim=1))
        dec1 = self.dec1(torch.cat((self.upconv1(dec2), conv1), dim=1))

        output = self.conv6(dec1)
        return output
    
# The PyTorch network lives under its own name: rebinding `model` here would
# replace the Keras model that the later cells call .fit() / .predict() on.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch_model = nnUNet(1, 1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)

# Convert the numpy arrays to PyTorch tensors, channels-last -> channels-first
# NOTE(review): the resulting tensors are 4-D (N, 1, H, W); confirm the network
# is configured for 2-D inputs.
X_train_torch = torch.tensor(X_train).permute(0, 3, 1, 2).float().to(device)
y_train_torch = torch.tensor(y_train).permute(0, 3, 1, 2).float().to(device)
X_val_torch = torch.tensor(X_val).permute(0, 3, 1, 2).float().to(device)
y_val_torch = torch.tensor(y_val).permute(0, 3, 1, 2).float().to(device)

# Training loop
epochs = 50
batch_size = 4
for epoch in range(epochs):
    torch_model.train()
    running_loss = 0.0
    for i in range(0, len(X_train_torch), batch_size):
        batch_x = X_train_torch[i:i+batch_size]
        batch_y = y_train_torch[i:i+batch_size]
        optimizer.zero_grad()
        output = torch_model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a shorter final batch is averaged correctly
        running_loss += loss.item() * len(batch_x)
    # Report the mean loss over the epoch, not just the last mini-batch
    print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(X_train_torch)}')

# Evaluation (in mini-batches so the whole validation set is never pushed
# through the network in a single forward pass)
torch_model.eval()
with torch.no_grad():
    val_logits = torch.cat(
        [torch_model(X_val_torch[i:i+batch_size]) for i in range(0, len(X_val_torch), batch_size)],
        dim=0,
    )
    loss = criterion(val_logits, y_val_torch)
    print(f'Validation Loss: {loss.item()}')

# The network outputs raw logits (the loss applies the sigmoid internally), so
# convert to probabilities before thresholding at 0.5.
y_pred = (torch.sigmoid(val_logits) > 0.5).cpu().numpy().astype(np.float32)
y_val_np = y_val_torch.cpu().numpy()

# Calculate the Jaccard score over all validation pixels
jaccard = jaccard_score(y_val_np.flatten().astype(np.int32), y_pred.flatten().astype(np.int32))
print(f'Jaccard Score: {jaccard}')

# Display a sample image, ground truth and prediction
sample_idx = 0
plt.figure(figsize=(12, 6))
plt.subplot(1, 3, 1)
plt.imshow(X_val[sample_idx, :, :, 0], cmap='gray')
plt.title("Image")
plt.subplot(1, 3, 2)
plt.imshow(y_val[sample_idx, :, :, 0], cmap='gray')
plt.title("Ground Truth")
plt.subplot(1, 3, 3)
plt.imshow(y_pred[sample_idx, 0, :, :], cmap='gray')
plt.title("Prediction")
plt.show()

# Save the model weights
torch.save(torch_model.state_dict(), 'best_model.pth')

# Reload the weights into a fresh network (map_location keeps this working on
# machines without a GPU)
torch_model = nnUNet(1, 1)
torch_model.load_state_dict(torch.load('best_model.pth', map_location=device))
torch_model = torch_model.to(device)


### Train and Augment the Data

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Data augmentation settings shared by images and masks
augmentation_args = dict(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest'
)
datagen = ImageDataGenerator(**augmentation_args)
mask_datagen = ImageDataGenerator(**augmentation_args)


def paired_augmented_batches(images, masks, batch_size=32, seed=42):
    """Yield ``(image_batch, mask_batch)`` pairs that received the same random transform.

    ``datagen.flow(images, masks)`` would augment only ``images`` and pass the
    masks through untouched, so the targets would no longer line up with the
    rotated/shifted inputs. Two iterators driven by the same seed apply the
    same shuffling and the same geometric transform to both arrays.
    """
    image_batches = datagen.flow(images, batch_size=batch_size, seed=seed)
    mask_batches = mask_datagen.flow(masks, batch_size=batch_size, seed=seed)
    for image_batch, mask_batch in zip(image_batches, mask_batches):
        # Interpolation during the transform produces fractional mask values;
        # snap them back to {0, 1}.
        yield image_batch, (mask_batch > 0.5).astype(np.float32)


# Apply data augmentation to the training data
train_gen = paired_augmented_batches(X_train, y_train, batch_size=32, seed=42)

# Use the augmented data in model training. The generator never ends, so
# steps_per_epoch defines how many batches make up one epoch.
history = model.fit(
    train_gen,
    steps_per_epoch=len(X_train) // 32,
    epochs=50,
    validation_data=(X_val, y_val),
    callbacks=callbacks
)

In [None]:
# Plotting the training history
def plot_history(history):
    """Plot train/validation accuracy and loss curves side by side.

    Args:
        history: Object whose ``history`` attribute maps the keys
            ``'accuracy'``, ``'val_accuracy'``, ``'loss'`` and ``'val_loss'``
            to per-epoch value sequences.
    """

    def draw_panel(slot, train_key, val_key, train_label, val_label, title, y_label):
        # One subplot: training curve and validation curve for a single metric
        plt.subplot(1, 2, slot)
        plt.plot(history.history[train_key], label=train_label)
        plt.plot(history.history[val_key], label=val_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()

    plt.figure(figsize=(12, 4))
    draw_panel(1, 'accuracy', 'val_accuracy', 'train_accuracy', 'val_accuracy', 'Model Accuracy', 'Accuracy')
    draw_panel(2, 'loss', 'val_loss', 'train_loss', 'val_loss', 'Model Loss', 'Loss')
    plt.show()

plot_history(history)


### Visualize Model Predictions

In [None]:
import matplotlib.pyplot as plt

# Function to visualize predictions
def visualize_prediction(model, image, true_mask=None):
    """Show an input image, optionally its true mask, and the model's predicted mask.

    Args:
        model: Object with a ``predict`` method; it is called on ``image``
            with a leading batch axis added, and channel 0 of the first
            result is displayed.
        image: A single input sample.
        true_mask: Optional ground-truth mask drawn in the middle panel; when
            omitted the middle panel is left empty.
    """
    batch = np.expand_dims(image, axis=0)
    pred_mask = model.predict(batch)[0, :, :, 0]  # drop batch and channel axes

    # (subplot position, title, data) for each panel that will be drawn
    panels = [(1, 'Input Image', image)]
    if true_mask is not None:
        panels.append((2, 'True Mask', true_mask))
    panels.append((3, 'Predicted Mask', pred_mask))

    plt.figure(figsize=(15, 5))
    for slot, title, data in panels:
        plt.subplot(1, 3, slot)
        plt.title(title)
        plt.imshow(data, cmap='gray')

    plt.show()

# Visualize the first five validation samples
for sample_idx in range(5):
    visualize_prediction(model, X_val[sample_idx], y_val[sample_idx])


### Model Evaluation

In [None]:
# Function to compute IoU
def compute_iou(y_true, y_pred):
    """Intersection-over-union of a binary mask and a thresholded prediction.

    Args:
        y_true: Ground-truth mask; values above 0.5 count as foreground.
        y_pred: Predicted scores of the same size; values above 0.5 count as
            foreground.

    Returns:
        IoU as a float in [0, 1]. When both the mask and the prediction are
        empty the prediction is a perfect match, so 1.0 is returned rather
        than the 0.0 a 0/0 division would otherwise be reported as.
    """
    true_fg = np.asarray(y_true).flatten() > 0.5
    pred_fg = np.asarray(y_pred).flatten() > 0.5
    union = np.logical_or(true_fg, pred_fg).sum()
    if union == 0:
        return 1.0
    intersection = np.logical_and(true_fg, pred_fg).sum()
    return float(intersection / union)

# Score every validation sample individually, then report the average IoU
val_predictions = model.predict(X_val)
iou_scores = []
for true_mask, predicted in zip(y_val, val_predictions):
    iou_scores.append(compute_iou(true_mask, predicted))

print(f'Mean IoU on validation set: {np.mean(iou_scores)}')

# Show the first five validation samples next to their predictions
for sample_idx in range(5):
    visualize_prediction(model, X_val[sample_idx], y_val[sample_idx])
