In [8]:
import os
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import shutil


# Enable mixed precision: switch the global Keras dtype policy for all layers built afterwards
from tensorflow.keras import mixed_precision

MIXED_PRECISION_POLICY = 'mixed_float16'
mixed_precision.set_global_policy(MIXED_PRECISION_POLICY)

In [5]:
# Paths
image_dir = '../1792x1792'  # Use the 1792x1792 images
metadata_file = 'project-root/data/BCC_labels.csv'
output_dir = 'organized_data'

# Step 1: Load and process metadata
metadata = pd.read_csv(metadata_file)

# Encode the text labels as 0 (Clear) / 1 (Present).
# Series.map() turns every value that is missing from the mapping into NaN.
# A NaN label cannot be assigned to either class, so such rows are reported
# and excluded here instead of flowing silently into the image folders.
label_codes = {'Clear': 0, 'Present': 1}
raw_labels = metadata['label']
encoded_labels = raw_labels.map(label_codes)
unrecognized = encoded_labels.isna()
if unrecognized.any():
    print(
        f"Warning: dropping {int(unrecognized.sum())} metadata row(s) with unrecognized labels: "
        f"{sorted(raw_labels[unrecognized].dropna().astype(str).unique())}"
    )
metadata = metadata.loc[~unrecognized].copy()
metadata['label'] = encoded_labels[~unrecognized].astype(int)

# Group by StudyID to split data (all slides of one study stay together)
grouped = metadata.groupby('StudyID #')

# Count images per slide: number of .png files found in <image_dir>/<slide_id>/
slide_counts = {}
# dropna(): a missing slide id has no folder and would make os.path.join() raise.
for slide_id in metadata['slide_id'].dropna().unique():
    # str(): os.path.join() only accepts path-like components, ids read as numbers are not.
    slide_folder = os.path.join(image_dir, str(slide_id))
    # isdir() rather than exists(): a regular file that happens to carry the
    # slide's name would make os.listdir() raise NotADirectoryError.
    if os.path.isdir(slide_folder):
        slide_counts[slide_id] = sum(
            1 for file_name in os.listdir(slide_folder) if file_name.endswith('.png')
        )
    else:
        # Slides listed in the metadata but absent on disk are kept with a count of 0.
        slide_counts[slide_id] = 0

print(f"Tissue images per slide: {slide_counts}")
print(f"Total tissue images found: {sum(slide_counts.values())}")

# Split at the study level (80% / 20% of the study ids, fixed seed) so that a
# study id is assigned to exactly one of the two sets.
study_ids = list(grouped.groups.keys())
train_study_ids, val_study_ids = train_test_split(study_ids, test_size=0.2, random_state=42)

# Make sure the per-split output folders exist.
# NOTE(review): existing contents are left in place (exist_ok=True), so files from
# an earlier run with a different split would still be there - confirm that is intended.
train_dir, val_dir = (os.path.join(output_dir, split_name) for split_name in ('train', 'val'))
for split_dir in (train_dir, val_dir):
    os.makedirs(split_dir, exist_ok=True)

Tissue images per slide: {'slide-2022-02-09T12-26-27-R5-S1': 3, 'slide-2022-02-09T12-28-49-R5-S2': 3, 'slide-2022-02-09T12-30-52-R5-S3': 2, 'slide-2022-02-09T12-33-12-R5-S4': 2, 'slide-2022-02-09T12-36-31-R5-S5': 1, 'slide-2022-02-09T12-38-53-R5-S6': 2, 'slide-2022-02-09T12-41-58-R5-S7': 1, 'slide-2022-02-09T12-44-19-R5-S8': 2, 'slide-2022-02-09T12-47-21-R5-S9': 4, 'slide-2022-02-09T13-08-40-R5-S17': 2, 'slide-2022-02-09T13-11-39-R5-S18': 2, 'slide-2022-02-09T13-14-10-R5-S19': 2, 'slide-2022-02-09T13-15-57-R5-S20': 3, 'slide-2022-02-09T13-18-39-R5-S21': 4, 'slide-2022-02-09T13-21-07-R5-S22': 3, 'slide-2022-02-09T13-23-51-R5-S23': 2, 'slide-2022-02-09T13-26-16-R5-S24': 2, 'slide-2022-02-09T13-29-07-R5-S25': 2, 'slide-2022-02-09T13-31-46-R6-S1': 2, 'slide-2022-02-09T13-33-41-R6-S2': 4, 'slide-2022-02-09T13-36-50-R6-S3': 2, 'slide-2022-02-09T13-39-43-R6-S4': 3, 'slide-2022-02-09T13-42-12-R6-S5': 3, 'slide-2022-02-09T13-44-36-R6-S6': 2, 'slide-2022-02-09T13-46-51-R6-S7': 5, 'slide-2022-02-

In [6]:
def organize_images(group_ids, dest_dir, groups=None, source_root=None):
    """Copy every PNG of the given studies into ``dest_dir/Clear`` or ``dest_dir/Present``.

    Args:
        group_ids: Iterable of study ids; each must be a key of ``groups``.
        dest_dir: Split directory (for example the train or val folder) that
            receives one sub-folder per class.
        groups: Optional pandas GroupBy keyed by study id whose rows carry the
            columns ``slide_id`` and ``label`` (0 = Clear, 1 = Present).
            Defaults to the notebook-level ``grouped``.
        source_root: Optional directory holding one folder of PNG files per
            slide. Defaults to the notebook-level ``image_dir``.

    Rows whose label is neither 0 nor 1 (for example NaN left behind by an
    unmapped label string) or whose slide id is missing are skipped and
    reported; they are never filed under a class.

    Prints the number of copied images and returns None.
    """
    if groups is None:
        groups = grouped
    if source_root is None:
        source_root = image_dir

    class_dirs = {0: 'Clear', 1: 'Present'}

    # Always create both class folders, even when one of them stays empty, so a
    # directory-based loader sees the same set of classes for every split.
    for class_name in class_dirs.values():
        os.makedirs(os.path.join(dest_dir, class_name), exist_ok=True)

    total_copied = 0  # Track copied images
    skipped_rows = 0  # Rows without a usable label or slide id
    for study_id in group_ids:
        group = groups.get_group(study_id)
        # Repeated (slide, label) rows would re-copy the same files and inflate the count.
        slide_rows = group[['slide_id', 'label']].drop_duplicates()
        for slide_id, label in slide_rows.itertuples(index=False, name=None):
            if pd.isna(slide_id) or label not in class_dirs:
                skipped_rows += 1
                continue
            source_dir = os.path.join(source_root, str(slide_id))
            # isdir() rather than exists(): os.listdir() fails on a regular file.
            if not os.path.isdir(source_dir):
                continue
            dest_label_dir = os.path.join(dest_dir, class_dirs[label])
            for file in os.listdir(source_dir):
                if file.endswith('.png'):
                    # Prefix the slide id so equally named tiles of different slides do not overwrite each other
                    new_file_name = f"{slide_id}_{file}"
                    shutil.copy(
                        os.path.join(source_dir, file),
                        os.path.join(dest_label_dir, new_file_name)
                    )
                    total_copied += 1
    if skipped_rows:
        print(f"Warning: skipped {skipped_rows} row(s) with a missing slide id or unrecognized label")
    print(f"Total images copied to {dest_dir}: {total_copied}")

# Populate the training folder first, then the validation folder, from their study ids
for split_study_ids, split_dir in ((train_study_ids, train_dir), (val_study_ids, val_dir)):
    organize_images(split_study_ids, split_dir)

Total images copied to organized_data/train: 165
Total images copied to organized_data/val: 43


In [9]:
# Step 2: Use ImageDataGenerator for data loading
image_size = (1792, 1792)
batch_size = 4  # Adjusted batch size to fit GPU memory

# Pin the class order explicitly. When `classes` is omitted the indices are
# inferred from whichever sub-folders exist, so a split that lacks one class
# folder would be encoded differently from the metadata (Clear = 0, Present = 1).
class_names = ['Clear', 'Present']
generator_seed = 42  # Fixed seed so training shuffling/augmentation can be repeated

# Define ImageDataGenerators with appropriate preprocessing
# Training images get light geometric augmentation plus the ResNet-50 input preprocessing.
train_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet50.preprocess_input,
    rotation_range=20,
    width_shift_range=0.05,
    height_shift_range=0.05,
    shear_range=0.05,
    zoom_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Validation images are only preprocessed, never augmented.
val_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet50.preprocess_input
)

# Create Generators
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=image_size,
    classes=class_names,
    batch_size=batch_size,
    class_mode='binary',
    shuffle=True,
    seed=generator_seed
)

# shuffle=False keeps the validation order fixed between epochs.
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=image_size,
    classes=class_names,
    batch_size=batch_size,
    class_mode='binary',
    shuffle=False
)

Found 165 images belonging to 2 classes.
Found 43 images belonging to 2 classes.


In [None]:
# Step 3: Load the Pre-trained ResNet-50 model without the top layers
# Derive the input shape from image_size (the generators' target_size) instead of
# repeating the numbers, so the network cannot drift out of sync with the data.
input_shape = (*image_size, 3)
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)

# Freeze the base model layers initially
base_model.trainable = False

# Step 4: Build the model
inputs = layers.Input(shape=input_shape)
x = base_model(inputs, training=False)  # backbone is called in inference mode
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(1024, activation='relu')(x)
outputs = layers.Dense(1, activation='sigmoid', dtype='float32')(x)  # Set dtype to float32 for mixed precision
model = models.Model(inputs, outputs)

# Step 5: Compile the model
model.compile(optimizer=Adam(learning_rate=1e-4),
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Step 6: Set up callbacks
# All three callbacks monitor val_loss: checkpoint the best model, stop after
# 5 epochs without improvement, and cut the learning rate 10x after 3.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True, monitor='val_loss'),
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3)
]

In [None]:
# Step 7: Train the model
# Runs up to 10 epochs on the training generator, evaluating on the validation
# generator each epoch and applying the callbacks configured above.
training_options = {
    'epochs': 10,
    'validation_data': val_generator,
    'callbacks': callbacks,
    'verbose': 1,
}
history = model.fit(train_generator, **training_options)