In [11]:
pip install numpy tensorflow scikit-learn matplotlib

Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt
import io
from PIL import Image

# Path to your datasets (adjust as needed)
data_dir = 'dataset'  # Ensure it contains subfolders for each class (e.g., 'Brain Tumor', 'Alzheimers', etc.)

# Image size and batch settings
img_height, img_width = 224, 224
batch_size = 32
epochs = 20  # Increased epochs to improve learning

# Create a function to convert image to PNG bytes and handle in-memory processing
def convert_image_to_png_in_memory(image_path):
    try:
        with Image.open(image_path) as img:
            # Resize the image
            img_resized = img.resize((img_height, img_width), Image.Resampling.LANCZOS)

            # Convert to PNG using in-memory BytesIO object
            byteImgIO = io.BytesIO()
            img_resized.save(byteImgIO, format="PNG")
            byteImgIO.seek(0)  # Go back to the start of the stream
            byteImg = byteImgIO.read()  # Get the byte data

            return byteImg
    except Exception as e:
        print(f"Error processing {image_path}: {e}")
        return None

# Create ImageDataGenerator instances for training and validation
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    zoom_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    horizontal_flip=True,
    validation_split=0.2  # 20% validation split
)

train_generator = train_datagen.flow_from_directory(
    data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset='training'  # Use 'training' subset
)

validation_generator = train_datagen.flow_from_directory(
    data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation'  # Use 'validation' subset
)

# Get the number of classes (diseases)
num_classes = len(train_generator.class_indices)

# Build a CNN model
model = Sequential()

# Convolutional layers
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())

# Fully connected layers
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

# Compile the model
model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

# Add callbacks for early stopping and model checkpointing
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
model_checkpoint = ModelCheckpoint('brain_disease_mri_model.h5', monitor='val_loss', save_best_only=True)

# Train the model
history = model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // batch_size,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // batch_size,
    epochs=epochs,
    callbacks=[early_stopping, model_checkpoint]
)

# Save the final model
model.save('brain_disease_mri_model_final.h5')

# Plot training and validation accuracy and loss
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(len(acc))  # Adjust based on actual epochs run

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()


Found 64801 images belonging to 6 classes.
Found 16198 images belonging to 6 classes.
Epoch 1/20
  39/2025 [..............................] - ETA: 54:23 - loss: 0.9327 - accuracy: 0.7708

UnknownError: Graph execution error:

UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x0000021A8EC1D2C0>
Traceback (most recent call last):

  File "c:\Users\ar646\Desktop\Projects\BDD\pluto\lib\site-packages\tensorflow\python\ops\script_ops.py", line 268, in __call__
    ret = func(*args)

  File "c:\Users\ar646\Desktop\Projects\BDD\pluto\lib\site-packages\tensorflow\python\autograph\impl\api.py", line 643, in wrapper
    return func(*args, **kwargs)

  File "c:\Users\ar646\Desktop\Projects\BDD\pluto\lib\site-packages\tensorflow\python\data\ops\from_generator_op.py", line 198, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))

  File "c:\Users\ar646\Desktop\Projects\BDD\pluto\lib\site-packages\keras\src\engine\data_adapter.py", line 917, in wrapped_generator
    for data in generator_fn():

  File "c:\Users\ar646\Desktop\Projects\BDD\pluto\lib\site-packages\keras\src\engine\data_adapter.py", line 1064, in generator_fn
    yield x[i]

  File "c:\Users\ar646\Desktop\Projects\BDD\pluto\lib\site-packages\keras\src\preprocessing\image.py", line 116, in __getitem__
    return self._get_batches_of_transformed_samples(index_array)

  File "c:\Users\ar646\Desktop\Projects\BDD\pluto\lib\site-packages\keras\src\preprocessing\image.py", line 370, in _get_batches_of_transformed_samples
    img = image_utils.load_img(

  File "c:\Users\ar646\Desktop\Projects\BDD\pluto\lib\site-packages\keras\src\utils\image_utils.py", line 423, in load_img
    img = pil_image.open(io.BytesIO(f.read()))

  File "c:\Users\ar646\Desktop\Projects\BDD\pluto\lib\site-packages\PIL\Image.py", line 3498, in open
    raise UnidentifiedImageError(msg)

PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x0000021A8EC1D2C0>


	 [[{{node PyFunc}}]]
	 [[IteratorGetNext]] [Op:__inference_train_function_5167]