In [4]:
from tensorflow.keras.applications import ResNet101
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping, TensorBoard, ModelCheckpoint, ReduceLROnPlateau
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import RMSprop, SGD
from tensorflow.keras import regularizers

In [5]:
combined_data_dir = '/Users/gabeprice/Desktop/Blackwell Research 2024/Alzheimers Research/Alzheimer_s Photos/combined'

In [6]:
# Load the pre-trained ResNet-101 model without the top classification layer
weights_path = '/Users/gabeprice/Desktop/Blackwell Research 2024/Alzheimers Research/Alzheimer_s Photos/model_attempts/res_net-101/resnet101_weights_tf_dim_ordering_tf_kernels_notop.h5'
base_model = ResNet101(weights=weights_path, include_top=False, input_shape=(208, 176, 3))

# Add custom layers on top of the base model
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(128, activation='relu', kernel_regularizer=l2(0.001))(x)
x = Dropout(0.2)(x)
predictions = Dense(4, activation='softmax')(x)


In [7]:
# Define callbacks
early_stopping = EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=10, min_lr=1e-6, verbose=1)

In [8]:
# Create the final model
model = Model(inputs=base_model.input, outputs=predictions)

In [9]:
# Compile the model
optimizer = Adam(learning_rate=1e-4, clipnorm=1.0)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

In [10]:
# Display the model summary
model.summary()

In [11]:
# Define the combined data generator with validation split
combined_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2  # 20% of data for validation
)

# Load the training data
combined_train_data = combined_datagen.flow_from_directory(
    combined_data_dir,
    target_size=(208, 176),
    batch_size=8,
    class_mode='categorical',
    subset='training'  # Use 80% of the data for training
)

# Load the validation data
combined_val_data = combined_datagen.flow_from_directory(
    combined_data_dir,
    target_size=(208, 176),
    batch_size=8,
    class_mode='categorical',
    subset='validation'  # Use 20% of the data for validation
)

Found 5121 images belonging to 4 classes.
Found 1279 images belonging to 4 classes.


In [12]:
# Define ModelCheckpoint callback
checkpoint_filepath = 'model_checkpoint.weights.h5'
model_checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_best_only=True
)

In [13]:
from sklearn.utils.class_weight import compute_class_weight

class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(combined_train_data.classes),
    y=combined_train_data.classes
)
class_weights = dict(enumerate(class_weights))

print("Class weights:", class_weights)


Class weights: {0: 1.7855648535564854, 1: 24.620192307692307, 2: 0.50009765625, 3: 0.7144252232142857}


In [14]:
# Load previous weights if they exist
try:
    model.load_weights(checkpoint_filepath)
    print("Loaded weights from checkpoint")
except:
    print("No checkpoint found, training from scratch")

No checkpoint found, training from scratch


In [15]:
# Train the model
hist = model.fit(
    combined_train_data,
    epochs=50,
    validation_data=combined_val_data,
    class_weight=class_weights,  # Pass the computed class weights here
    callbacks=[early_stopping, reduce_lr, model_checkpoint_callback]
)

Epoch 1/50


  self._warn_if_super_not_called()


[1m 21/641[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m59:05[0m 6s/step - accuracy: 0.3932 - loss: 1.2983

KeyboardInterrupt: 

In [None]:
# Evaluate the model
test_loss, test_accuracy = model.evaluate(test_generator, verbose=1)
print(f'Test Loss: {test_loss}')
print(f'Test Accuracy: {test_accuracy}')
