In [None]:

# Install necessary libraries
!pip install monai pylibjpeg pylibjpeg-libjpeg pylibjpeg-openjpeg

# Import libraries
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import Sequence
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from monai.transforms import LoadImage  # Monai for DICOM loading
from tensorflow.keras.callbacks import ModelCheckpoint

# TPU Setup
# Locate the TPU, connect this process to it, initialise it, and build the
# distribution strategy whose scope (`with strategy.scope():`) is used later
# when the model is created and compiled.
# NOTE(review): tpu='' presumably lets the resolver auto-detect the TPU from
# the runtime environment (e.g. a hosted notebook TPU) — confirm for this image.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
# Must run after connecting and before the strategy is constructed.
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)


# Step 1: Function to load DICOM images using monai
# A single MONAI reader is reused for every image instead of being rebuilt per call.
_DICOM_LOADER = LoadImage(image_only=True)


def load_dicom_images(image_path, img_size):
    """Load one DICOM file as a resized float32 array scaled to [0, 1].

    Args:
        image_path: Path to the ``.dcm`` file.
        img_size: Target ``(height, width)`` passed to ``tf.image.resize``.

    Returns:
        ``np.ndarray`` of shape ``(*img_size, C)``, dtype float32, values in
        [0, 1]. Single-channel inputs are replicated to ``C == 3`` so they
        match the 3-channel backbone input.
    """
    # MONAI may return a MetaTensor; hand TensorFlow a plain float32 array.
    img = np.asarray(_DICOM_LOADER(image_path), dtype=np.float32)

    # Grayscale (H, W) -> (H, W, 1) so tf.image.resize sees a channel axis.
    if img.ndim == 2:
        img = img[..., np.newaxis]

    # DICOM pixel data is commonly stored with more than 8 bits, so dividing
    # by 255 does not map it into [0, 1]. Scale by the image's own range
    # instead; a constant image maps to all zeros rather than dividing by 0.
    lo, hi = float(img.min()), float(img.max())
    img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)

    # Default (bilinear) resizing does not overshoot, so values stay in [0, 1].
    img_resized = tf.image.resize(img, img_size)

    if img_resized.shape[-1] == 1:
        img_resized = tf.image.grayscale_to_rgb(img_resized)

    # NOTE(review): ImageNet ResNet50 weights were trained with
    # tf.keras.applications.resnet50.preprocess_input scaling, not [0, 1];
    # consider matching it.
    return img_resized.numpy()

# Step 2: Data Generator for DICOM images
class DICOMDataGenerator(Sequence):
    """Keras ``Sequence`` yielding batches of preprocessed DICOM images.

    Every row of ``df`` is loaded from ``df['image_path']`` with
    ``load_dicom_images``. When ``df`` has a ``cancer`` column, batches are
    ``(images, labels)``. Otherwise (an unlabelled split) batches are
    ``images`` only, which is what ``model.predict`` needs.

    Args:
        df: DataFrame with an ``image_path`` column and optionally ``cancer``.
        batch_size: Rows per batch; the final batch may be smaller.
        img_size: Target ``(height, width)`` for every image.
        augment: Apply random geometric augmentation to each batch.
        shuffle: Shuffle row order at the end of every epoch.
    """

    def __init__(self, df, batch_size, img_size, augment=False, shuffle=True):
        super().__init__()
        self.df = df.reset_index(drop=True)
        self.batch_size = batch_size
        self.img_size = img_size
        self.augment = augment
        self.shuffle = shuffle
        self.indices = np.arange(len(self.df))
        # Unlabelled frames must not be indexed with 'cancer'; doing so inside
        # the image try-block used to drop every row as a "loading error".
        self.has_labels = 'cancer' in self.df.columns
        # Built once rather than on every batch.
        self._augmenter = ImageDataGenerator(rotation_range=20,
                                             width_shift_range=0.1,
                                             height_shift_range=0.1,
                                             shear_range=0.2,
                                             zoom_range=0.2,
                                             horizontal_flip=True)

    def __len__(self):
        """Number of batches per epoch (last partial batch included)."""
        return int(np.ceil(len(self.df) / self.batch_size))

    def __getitem__(self, index):
        """Return batch ``index`` as ``(images, labels)``, or ``images`` if unlabelled."""
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        batch_df = self.df.iloc[batch_indices]

        images = []
        labels = []

        for _, row in batch_df.iterrows():
            dicom_path = row['image_path']
            try:
                img = load_dicom_images(dicom_path, self.img_size)
            except Exception as e:
                print(f"Error loading image {dicom_path}: {e}")
                if self.has_labels:
                    # Training/validation: skip the unreadable sample.
                    continue
                # Inference: keep a blank placeholder so outputs remain
                # aligned one-to-one with the rows of ``df``.
                img = np.zeros((*self.img_size, 3), dtype=np.float32)
            images.append(img)
            if self.has_labels:
                labels.append(row['cancer'])

        if images:
            images = np.stack(images)
        else:
            # Keep a 4-D shape even when every image in the batch failed.
            images = np.zeros((0, *self.img_size, 3), dtype=np.float32)

        if self.augment and len(images) > 0:
            images = self._augmenter.flow(images, shuffle=False, batch_size=len(images))[0]

        if not self.has_labels:
            return images
        return images, np.array(labels)

    def on_epoch_end(self):
        """Reshuffle row order between epochs when ``shuffle`` is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indices)

# Step 3: Load CSV and update file paths
df = pd.read_csv('/kaggle/input/rsna-breast-cancer-detection/train.csv')
# Path layout: train_images/<patient_id>/<image_id>.dcm (vectorised, no row-wise apply).
df['image_path'] = ('/kaggle/input/rsna-breast-cancer-detection/train_images/'
                    + df['patient_id'].astype(str) + '/'
                    + df['image_id'].astype(str) + '.dcm')

# Step 4: Split data into training and validation sets at the patient level.
# A patient can own several images (they share one directory), so splitting
# individual rows would leak a patient's images into both sets and inflate
# validation scores. Patients are stratified by whether any of their images
# is labelled as cancer.
patient_has_cancer = df.groupby('patient_id')['cancer'].max()
train_patients, valid_patients = train_test_split(
    patient_has_cancer.index,
    test_size=0.2,
    stratify=patient_has_cancer.values,
    random_state=42,
)
train_df = df[df['patient_id'].isin(train_patients)]
valid_df = df[df['patient_id'].isin(valid_patients)]

# Image size and batch size
img_size = (224, 224)
batch_size = 8  # Reduced batch size to manage memory usage

# Step 5: Create data generators for training and validation
train_gen = DICOMDataGenerator(train_df, batch_size=batch_size, img_size=img_size, augment=True)
valid_gen = DICOMDataGenerator(valid_df, batch_size=batch_size, img_size=img_size, augment=False)

# Use TPU strategy for model training
with strategy.scope():
    # Step 6: Build the model (Transfer learning with ResNet50)
    base_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False, input_shape=(*img_size, 3))
    base_model.trainable = False  # Freeze base model layers

    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')  # Binary classification
    ])

    # With a rare positive class, accuracy is dominated by the majority label,
    # so also track ROC AUC. 'accuracy' is kept because the plots read it.
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy', tf.keras.metrics.AUC(name='auc')])
    model.summary()

    # Step 7: Save the best model during training
    checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True)

    # Step 8: Train the model with early stopping and checkpointing
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

    # Balanced class weights (n_total / (2 * n_class)) so the minority class
    # is not ignored by the loss. Falls back to None if a class is absent.
    n_total = len(train_df)
    n_pos = int(train_df['cancer'].sum())
    n_neg = n_total - n_pos
    class_weights = ({0: n_total / (2 * n_neg), 1: n_total / (2 * n_pos)}
                     if n_pos and n_neg else None)

    history = model.fit(train_gen,
                        validation_data=valid_gen,
                        epochs=10,  # Increased epochs
                        class_weight=class_weights,
                        callbacks=[early_stopping, checkpoint])

# Step 9: Plot training and validation accuracy and loss
def _plot_train_val(hist, train_key, train_label, val_key, val_label, title):
    """Plot one train/validation curve pair from a Keras history dict, then show it."""
    for key, label in ((train_key, train_label), (val_key, val_label)):
        plt.plot(hist[key], label=label)
    plt.legend()
    plt.title(title)
    plt.show()


_plot_train_val(history.history,
                train_key='accuracy', train_label='Train Accuracy',
                val_key='val_accuracy', val_label='Validation Accuracy',
                title='Model Accuracy')
_plot_train_val(history.history,
                train_key='loss', train_label='Train Loss',
                val_key='val_loss', val_label='Validation Loss',
                title='Model Loss')

# Step 10: Make predictions for submission
test_dir = '/kaggle/input/rsna-breast-cancer-detection/test_images'

# Load test data from CSV and create file paths for test images
# (layout: <test_dir>/<patient_id>/<image_id>.dcm).
test_df = pd.read_csv('/kaggle/input/rsna-breast-cancer-detection/test.csv')
test_df['image_path'] = (test_dir + '/'
                         + test_df['patient_id'].astype(str) + '/'
                         + test_df['image_id'].astype(str) + '.dcm')

# Create a test data generator
test_gen = DICOMDataGenerator(test_df, batch_size=1, img_size=img_size, augment=False)

# Make predictions (one sigmoid probability per test image)
predictions = model.predict(test_gen).flatten()

# If the generator dropped any rows, predictions can no longer be matched to
# test_df, so stop with a clear message instead of writing a misaligned file.
if len(predictions) != len(test_df):
    raise ValueError(
        f"Got {len(predictions)} predictions for {len(test_df)} test images; "
        "some test images were skipped by the generator."
    )

image_preds = pd.DataFrame({
    'prediction_id': test_df['prediction_id'],
    'cancer': predictions,
})

# NOTE(review): prediction_id is presumably shared by several images (e.g.
# all views of one breast) — verify against test.csv. Averaging per id
# yields one row per id; it is a no-op when ids are already unique.
submission = image_preds.groupby('prediction_id', as_index=False)['cancer'].mean()

# Binarise the per-id probability at 0.5 (values above 0.5 -> 1, else 0).
submission['cancer'] = (submission['cancer'] > 0.5).astype(int)

# Save submission
submission.to_csv('submission.csv', index=False)
