In [None]:
import pandas as pd
import os
import shutil
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout

In [None]:
# Inputs: the labels CSV and the folder holding every original image.
CSV_PATH = '../data/full_dataset/train.csv'
SOURCE_IMAGE_DIR = '../data/full_dataset/train_images/'

# Outputs written by this notebook (created in the next cell if missing).
SUBSET_DIR = '../data/subset/'    # copied images for the balanced subset
MODELS_DIR = '../models/'         # saved Keras model
RESULTS_DIR = '../results/'       # learning-curve figure

In [None]:
# Parameters
SAMPLES_PER_CLASS = 250  # max images drawn per diagnosis class (smaller classes are kept whole)
IMG_SIZE = 224           # square input resolution passed to the generators and MobileNetV2
BATCH_SIZE = 32
EPOCHS = 15
SEED = 42                # same value as the random_state=42 used for sampling and splitting

# Seed Python, NumPy and TensorFlow RNGs so weight init, dropout and augmentation repeat run-to-run.
# NOTE(review): bit-exact GPU reproducibility additionally requires
# tf.config.experimental.enable_op_determinism(); not enabled here.
tf.keras.utils.set_random_seed(SEED)

# Create necessary directories if they don't exist
os.makedirs(SUBSET_DIR, exist_ok=True)
os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

In [None]:
print("--- Step 3: Creating Balanced Data Subset ---")
df = pd.read_csv(CSV_PATH)

# Draw up to SAMPLES_PER_CLASS rows per diagnosis; a class with fewer rows is kept whole,
# so the subset is only balanced up to the size of its smallest class.
# Built with an explicit per-group concat instead of groupby().apply(): pandas 2.2 deprecates
# (and later versions remove) passing the grouping column into apply, which would silently
# drop 'diagnosis' from the result after reset_index(drop=True). Row selection is unchanged.
balanced_df = pd.concat(
    [
        group.sample(SAMPLES_PER_CLASS, random_state=42) if len(group) >= SAMPLES_PER_CLASS else group
        for _, group in df.groupby('diagnosis')
    ],
    ignore_index=True,
)
print(balanced_df['diagnosis'].value_counts().sort_index().to_string())

print("Copying files for the subset...")
# Iterate the column directly: iterrows() was unused beyond this field and can upcast dtypes.
for id_code in balanced_df['id_code']:
    image_filename = f"{id_code}.png"
    source_path = os.path.join(SOURCE_IMAGE_DIR, image_filename)
    destination_path = os.path.join(SUBSET_DIR, image_filename)
    if not os.path.exists(destination_path):  # Avoid re-copying if script is run again
        shutil.copyfile(source_path, destination_path)

# Report the subset itself rather than os.listdir(SUBSET_DIR): files left over from earlier runs
# (e.g. with a different SAMPLES_PER_CLASS) would inflate a directory count. Such leftovers are
# harmless for training because the generators read filenames from the dataframe.
print(f"Subset created with {len(balanced_df)} images.")


In [None]:
print("\n--- Step 4: Preparing DataFrame for Keras Generators ---")
# Build a new frame instead of mutating balanced_df in place: the old in-place version appended
# '.png' again on every re-run ('x.png.png'), after which the generators found no images.
# Labels are cast to str because Keras' categorical class_mode expects string class labels.
keras_df = balanced_df.assign(
    id_code=balanced_df['id_code'].astype(str) + '.png',
    diagnosis=balanced_df['diagnosis'].astype(str),
)
train_df, val_df = train_test_split(
    keras_df, test_size=0.2, random_state=42, stratify=keras_df['diagnosis']
)
print(f"Training set size: {len(train_df)}, Validation set size: {len(val_df)}")


In [None]:
print("\n--- Step 5: Setting up Data Generators ---")
# Augmentation is applied to the training stream only; validation images are just rescaled.
# NOTE(review): the ImageNet MobileNetV2 weights expect inputs scaled to [-1, 1]
# (tf.keras.applications.mobilenet_v2.preprocess_input), whereas rescale=1/255 yields [0, 1].
# Left unchanged so the saved model's input convention stays the same for downstream code —
# consider switching both here and at inference time.
train_datagen = ImageDataGenerator(
    rescale=1./255., rotation_range=20, width_shift_range=0.1,
    height_shift_range=0.1, shear_range=0.1, zoom_range=0.1,
    horizontal_flip=True, fill_mode='nearest'
)
val_datagen = ImageDataGenerator(rescale=1./255.)

train_generator = train_datagen.flow_from_dataframe(
    dataframe=train_df, directory=SUBSET_DIR, x_col='id_code', y_col='diagnosis',
    target_size=(IMG_SIZE, IMG_SIZE), batch_size=BATCH_SIZE, class_mode='categorical',
    seed=42,  # fixes shuffling and random-transform order so runs are repeatable
)
validation_generator = val_datagen.flow_from_dataframe(
    dataframe=val_df, directory=SUBSET_DIR, x_col='id_code', y_col='diagnosis',
    target_size=(IMG_SIZE, IMG_SIZE), batch_size=BATCH_SIZE, class_mode='categorical',
    # Metrics don't need shuffling; a fixed order keeps predictions aligned with
    # validation_generator.filenames for any later per-image analysis.
    shuffle=False,
)

# Each generator builds its own label->index mapping; training is meaningless if they disagree.
assert train_generator.class_indices == validation_generator.class_indices, (
    train_generator.class_indices, validation_generator.class_indices
)

In [None]:
# BUILD THE BASELINE MODEL (MobileNetV2)
print("\n--- Step 6: Building the Baseline Model ---")
# Size the output layer from the generator's label mapping instead of a hard-coded 5, so the
# softmax width always matches the one-hot targets the generators emit.
num_classes = len(train_generator.class_indices)

# Frozen ImageNet backbone: only the new classification head below is trained.
base_model = MobileNetV2(input_shape=(IMG_SIZE, IMG_SIZE, 3), include_top=False, weights='imagenet')
base_model.trainable = False

x = GlobalAveragePooling2D()(base_model.output)
x = Dense(128, activation='relu')(x)
x = Dropout(0.5)(x)
predictions = Dense(num_classes, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
print("Model built successfully.")

In [None]:
print("\n--- Step 7: Starting Model Training ---")
# Use the generators' own lengths (ceil(samples / batch_size)) so every image is seen once per
# epoch. The previous floor division (len(df) // BATCH_SIZE) skipped the final partial batch
# every epoch, so validation metrics covered only part of val_df. len(generator) also reflects
# any rows flow_from_dataframe discarded because their image file was missing.
history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=validation_generator,
    steps_per_epoch=len(train_generator),
    validation_steps=len(validation_generator),
)
print("Training complete.")

In [None]:
print("\n--- Step 8: Saving Model and Plotting Results ---")
model_path = os.path.join(MODELS_DIR, 'baseline_model.h5')
plot_path = os.path.join(RESULTS_DIR, 'baseline_learning_curves.png')

model.save(model_path)
print(f"Model saved to {model_path}")

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
# Derive the x-axis from what was actually recorded rather than EPOCHS, so the plot still works
# if a callback (e.g. EarlyStopping) ends training early; epochs are numbered from 1.
epochs_range = range(1, len(loss) + 1)

fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))

acc_ax.plot(epochs_range, acc, label='Training Accuracy')
acc_ax.plot(epochs_range, val_acc, label='Validation Accuracy')
acc_ax.set(title='Training and Validation Accuracy', xlabel='Epoch', ylabel='Accuracy')
acc_ax.legend(loc='lower right')

loss_ax.plot(epochs_range, loss, label='Training Loss')
loss_ax.plot(epochs_range, val_loss, label='Validation Loss')
loss_ax.set(title='Training and Validation Loss', xlabel='Epoch', ylabel='Loss')
loss_ax.legend(loc='upper right')

fig.tight_layout()
fig.savefig(plot_path)
print(f"Learning curves plot saved to {plot_path}")
plt.show()