# TPU VM v3-8

In [12]:
!pip install pydicom
!pip install pillow==10.0.0
!pip install ipywidgets
!pip install torch-xla
!pip install pandas
!pip install --upgrade pip


[0m

# Restart the Jupyter kernel

In [13]:
from IPython.display import display_html
def restartkernel():
    """Emit a raw HTML <script> snippet that calls the notebook front end's kernel restart."""
    restart_snippet = "<script>Jupyter.notebook.kernel.restart()</script>"
    display_html(restart_snippet, raw=True)

In [14]:
# Trigger the kernel restart defined above.
# NOTE(review): presumably done so the freshly installed packages are picked up — confirm.
restartkernel()

# Imports

In [15]:
import os
import random
import cv2
import xml.etree.ElementTree as ET
import torch
import torchvision
import pydicom
from pydicom.pixel_data_handlers.util import apply_color_lut
from pydicom.pixel_data_handlers.util import apply_modality_lut
from pydicom.pixel_data_handlers.util import apply_voi_lut
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset
from pathlib import Path
from tqdm import tqdm
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from PIL import Image
import torchvision.transforms as torchvision_transforms
import io
import torch_xla
import torch_xla.core.xla_model as xm
import sqlite3
import pandas as pd

# Initialize Path

In [16]:
# Root folder of the input dataset; the annotation database and the slide files are looked up under it.
path = Path('/kaggle/input/mitosis-wsi-ccmct-training-set/')

# Initialize Database

In [17]:
# Open the SQLite database that later cells query for slide filenames and annotation coordinates.
database = sqlite3.connect(str(path/'MITOS_WSI_CCMCT_ODAEL_train_dcm.sqlite'))

# Define output directories

In [18]:
# Root under which every artifact directory of this notebook is placed.
OUTPUT_DIR = '/kaggle/working/'

# One directory per artifact kind, each a direct child of OUTPUT_DIR.
MODEL_DIR, VISUALIZATION_DIR, PREDICTION_DIR, REPORT_DIR = (
    os.path.join(OUTPUT_DIR, leaf)
    for leaf in ('models', 'visualizations', 'predictions', 'reports')
)

# Define utility functions

In [19]:
def create_output_directories(parent_dir, subdirs=('train', 'val', 'test')):
    """Create ``parent_dir`` and each requested subdirectory beneath it.

    Args:
        parent_dir: Directory that must exist after the call.
        subdirs: Names of the subdirectories to create inside ``parent_dir``.
            An empty sequence (or ``None``) creates only ``parent_dir`` itself.

    Prints one "Created directory" line for every directory that did not exist yet.
    Calling the function again for existing directories is a no-op.
    """
    # The parent is always included: looping over the subdirectories alone creates
    # nothing at all when ``subdirs`` is empty.
    wanted = [parent_dir]
    wanted.extend(os.path.join(parent_dir, name) for name in (subdirs or ()))

    for dir_path in wanted:
        existed = os.path.isdir(dir_path)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(dir_path, exist_ok=True)
        if not existed:
            print(f"Created directory: {dir_path}")

def visualize_predictions(image, predictions, class_names, score_threshold=0.5):
    """Display an image with the model's confident detections drawn on top.

    Args:
        image: Tensor that broadcasts against per-channel constants of shape (3, 1, 1),
            i.e. channel-first. It is expected to have been normalized with the
            mean/std constants below; that normalization is undone for display.
        predictions: Mapping with 'boxes', 'scores' and 'labels' tensors.
        class_names: Sequence indexed by the integer label of a detection.
        score_threshold: Only detections scoring strictly above this value are drawn.
    """
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)

    # Undo the normalization on the CPU and clamp to [0, 1] so imshow does not
    # warn about (or clip unpredictably) out-of-range float values.
    image = (image.detach().cpu() * std + mean).clamp(0.0, 1.0)
    image = np.transpose(image.numpy(), (1, 2, 0))

    fig, ax = plt.subplots(1)
    ax.imshow(image)

    boxes = predictions['boxes'].cpu().numpy()
    scores = predictions['scores'].cpu().numpy()
    labels = predictions['labels'].cpu().numpy()

    for box, score, label in zip(boxes, scores, labels):
        if score > score_threshold:
            xmin, ymin, xmax, ymax = box
            # ``Rectangle`` is the name the notebook imports from matplotlib.patches;
            # the ``patches`` module itself is never imported.
            rect = Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                             linewidth=2, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
            ax.text(xmin, ymin, f'{class_names[int(label)]}: {score:.2f}', color='r')

    ax.axis('off')
    plt.show()

def save_checkpoint(model, optimizer, epoch, loss, checkpoint_dir):
    """Persist the training state of ``epoch`` as ``checkpoint_epoch_<epoch>.pth``.

    The saved mapping holds the epoch number, the model and optimizer state dicts
    and the loss value; it is written with ``xm.save`` into ``checkpoint_dir``.
    """
    target_file = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    xm.save(state, target_file)
    print(f"Saved checkpoint to: {target_file}")

def _checkpoint_epoch(filename):
    """Return the epoch encoded in a ``checkpoint_epoch_<N>.<ext>`` name, or None if it has none."""
    prefix = 'checkpoint_epoch_'
    if not filename.startswith(prefix):
        return None
    digits = filename[len(prefix):].split('.')[0]
    try:
        return int(digits)
    except ValueError:
        return None


def load_checkpoint(model, optimizer, checkpoint_dir):
    """Restore the most recent checkpoint found in ``checkpoint_dir``.

    Args:
        model: Module whose ``load_state_dict`` receives the saved weights.
        optimizer: Optimizer to restore, or ``None`` to load the weights only
            (e.g. for inference).
        checkpoint_dir: Directory searched for ``checkpoint_epoch_<N>`` files.

    Returns:
        ``(epoch, loss)`` as stored in the checkpoint, or ``(0, None)`` when the
        directory does not exist or contains no usable checkpoint.
    """
    candidates = []
    if os.path.isdir(checkpoint_dir):
        for name in os.listdir(checkpoint_dir):
            epoch_in_name = _checkpoint_epoch(name)
            # Files whose suffix is not a number (e.g. "checkpoint_epoch_final.pth") are skipped
            # instead of crashing the sort.
            if epoch_in_name is not None:
                candidates.append((epoch_in_name, name))

    if not candidates:
        print("No checkpoints found in the directory.")
        return 0, None

    _, latest_name = max(candidates)
    latest_checkpoint_path = os.path.join(checkpoint_dir, latest_name)

    # Load onto the CPU first; load_state_dict then copies into whatever device the
    # model/optimizer parameters already live on.
    checkpoint = torch.load(latest_checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Loaded checkpoint from epoch {epoch} with loss {loss}")
    return epoch, loss

# Define preprocessing functions

In [20]:
def apply_voi_lut(image, dataset, index=0):
    """Run ``image`` through item ``index`` of the dataset's VOI LUT sequence, if it has one.

    Returns ``image`` untouched when ``dataset`` carries no 'VOILUTSequence'.
    Note: this definition replaces the pydicom function of the same name imported earlier.
    """
    has_voi_lut = 'VOILUTSequence' in dataset
    if not has_voi_lut:
        return image
    voi_item = dataset.VOILUTSequence[index]
    return apply_lut(image, voi_item)

def apply_modality_lut(image, dataset):
    """Run ``image`` through the first item of the dataset's Modality LUT sequence, if present.

    Returns ``image`` untouched when ``dataset`` carries no 'ModalityLUTSequence'.
    Note: this definition replaces the pydicom function of the same name imported earlier.
    """
    has_modality_lut = 'ModalityLUTSequence' in dataset
    if not has_modality_lut:
        return image
    first_item = dataset.ModalityLUTSequence[0]
    return apply_lut(image, first_item)

def apply_color_lut(image, dataset):
    """Run ``image`` through the LUT described directly on ``dataset``, if it has a 'LUTDescriptor'.

    Returns ``image`` untouched otherwise.
    Note: this definition replaces the pydicom function of the same name imported earlier.
    """
    has_descriptor = 'LUTDescriptor' in dataset
    if not has_descriptor:
        return image
    return apply_lut(image, dataset)

def apply_lut(image, dataset, lut_descriptor=None):
    """Map pixel values through the lookup table stored on ``dataset``.

    Args:
        image: Array of pixel values used (after subtracting the first mapped value)
            as indices into the table.
        dataset: Object exposing ``LUTData`` and, unless ``lut_descriptor`` is given,
            ``LUTDescriptor``.
        lut_descriptor: Optional ``(entries, first_mapped, bits)`` triple overriding
            ``dataset.LUTDescriptor``. An entry count of 0 is read as 2**16.

    Returns:
        ``image`` unchanged when there is no ``LUTData`` or the table is the identity
        ramp starting at ``first_mapped``; otherwise a uint16 array with the table
        output rescaled from ``[0, 2**bits - 1]`` to ``[0, 65535]``.
    """
    if lut_descriptor is None:
        lut_descriptor = dataset.LUTDescriptor
    nr_entries = lut_descriptor[0] or 2**16
    first_mapped = lut_descriptor[1]
    bits = lut_descriptor[2]
    output_range = 2**bits - 1
    try:
        lut_data = dataset.LUTData
    except AttributeError:
        return image

    # LUTData is consumed as little-endian byte pairs. Only complete pairs that fit
    # in the table are read, so a trailing odd byte or surplus data no longer raises.
    lut = np.zeros(nr_entries, dtype=np.uint32)
    usable_pairs = min(len(lut_data) // 2, nr_entries)
    for entry in range(usable_pairs):
        low = int(lut_data[2 * entry])
        high = int(lut_data[2 * entry + 1])
        lut[entry] = low + (high << 8)

    identity = np.arange(first_mapped, first_mapped + nr_entries, dtype=lut.dtype)
    if np.array_equal(lut, identity):
        return image

    # Subtract in a signed type: with unsigned pixel data, values below ``first_mapped``
    # would otherwise wrap around to huge indices and be clipped to the LAST entry
    # instead of the first.
    offsets = np.asarray(image).astype(np.int64) - first_mapped
    clipped_img = np.clip(offsets, 0, nr_entries - 1)
    return (lut[clipped_img] / output_range * 65535).astype(np.uint16)


def apply_palette_color(image, dataset):
    """Expand palette indices in ``image`` to 8-bit RGB using the dataset's colour table.

    Args:
        image: Integer array of palette indices.
        dataset: Object exposing ``PaletteColorLookupTableDataArray``.
            NOTE(review): the table is assumed to be a flat sequence of interleaved
            (R, G, B) triples — inferred from the ``// 3`` / ``* 3`` arithmetic this
            function has always used; confirm against the data source.

    Returns:
        An array of shape ``image.shape + (3,)`` and dtype uint8. On any failure the
        error is printed and the original ``image`` is returned unchanged (best effort).
    """
    try:
        lut = dataset.PaletteColorLookupTableDataArray

        if lut is None or len(lut) == 0:
            raise ValueError("No Palette Color Lookup Table found in the DICOM dataset")

        lut_entries = len(lut) // 3
        if lut_entries == 0:
            raise ValueError("No Palette Color Lookup Table found in the DICOM dataset")

        # One row per palette entry; indexing with the whole image yields all three
        # channels per pixel (looking up a single value per pixel cannot fill an
        # array of shape image.shape + (3,)).
        table = np.asarray(lut)[:lut_entries * 3].reshape(lut_entries, 3)
        colored = table[np.asarray(image)]

        if colored.dtype != np.uint8:
            peak = colored.max()
            if peak > 0:
                colored = (colored / peak * 255).astype(np.uint8)
            else:
                # All-zero result: avoid dividing by zero.
                colored = np.zeros(colored.shape, dtype=np.uint8)

        return colored
    except Exception as e:
        print(f"Error applying palette color: {e}")
        return image


def normalize_dicom(image, ds):
    """Apply the lookup-table / colour-space handling that matches the dataset's photometric interpretation.

    Args:
        image: Pixel array taken from ``ds``.
        ds: Dataset exposing ``PhotometricInterpretation`` plus whatever LUT
            attributes the helper functions look for.

    Returns:
        The processed array; for an unhandled photometric interpretation a warning
        is printed and ``image`` is returned unchanged.
    """
    photometric_interpretation = ds.PhotometricInterpretation

    if photometric_interpretation in ('MONOCHROME1', 'MONOCHROME2'):
        # The DICOM grayscale pipeline applies the Modality LUT before the VOI LUT.
        image = apply_modality_lut(image, ds)
        return apply_voi_lut(image, ds, index=0)

    if photometric_interpretation == 'RGB':
        return apply_modality_lut(image, ds)

    if photometric_interpretation in ('YBR_FULL_422', 'YBR_FULL'):
        # DICOM orders the samples (Y, Cb, Cr) while OpenCV's YCrCb conversion expects
        # (Y, Cr, Cb), so the two chroma channels are swapped before converting.
        # NOTE(review): some pydicom decoders already hand back RGB for JPEG data —
        # verify that the array really is still YBR at this point.
        ycrcb = np.ascontiguousarray(image[..., [0, 2, 1]])
        image = cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2RGB)
        return apply_color_lut(image, ds)

    if photometric_interpretation == 'PALETTE COLOR':
        image = apply_palette_color(image, ds)
        return apply_color_lut(image, ds)

    print(f"Warning: Unhandled Photometric Interpretation: {photometric_interpretation}")
    return image

def resize_image(image, target_size):
    """Return ``image`` resized to ``target_size`` using bilinear (INTER_LINEAR) interpolation."""
    resized = cv2.resize(image, dsize=target_size, interpolation=cv2.INTER_LINEAR)
    return resized

def save_patch(patch, output_path, filename):
    """Write ``patch`` as an image file named ``filename`` inside ``output_path``.

    The directory is created if needed.

    Returns:
        The success flag reported by ``cv2.imwrite``; a warning is printed when the
        write did not succeed, so a failed save is no longer silent.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_path, exist_ok=True)
    full_path = os.path.join(output_path, filename)
    written = cv2.imwrite(full_path, patch)
    if not written:
        print(f"Warning: could not write patch to {full_path}")
    return written

## Create output directories with subdirectories

In [21]:
# Model and visualization outputs are split per dataset partition.
create_output_directories(MODEL_DIR, subdirs=['train', 'val', 'test'])
create_output_directories(VISUALIZATION_DIR, subdirs=['train', 'val', 'test'])
# Predictions and reports have no subdirectories, so the directories themselves are
# created directly; later cells write files straight into them.
os.makedirs(PREDICTION_DIR, exist_ok=True)
os.makedirs(REPORT_DIR, exist_ok=True)

# Define model architectures

In [22]:
def create_faster_rcnn_model(num_classes):
    """Build a pretrained Faster R-CNN (ResNet-50 FPN) whose box predictor outputs ``num_classes`` classes."""
    detection = torchvision.models.detection
    detector = detection.fasterrcnn_resnet50_fpn(pretrained=True)

    # Swap the stock classification/regression head for one sized to ``num_classes``,
    # keeping the input feature width of the head it replaces.
    head_width = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(head_width, num_classes)
    return detector

# Dataset class

In [23]:
class MitoticDataset(Dataset):
    """Slide-level dataset: item ``i`` is every annotated patch of slide ``i`` with its detection target.

    Args:
        slides: Sequence of rows whose first element is the slide filename.
        database: Connection whose ``execute`` is used to query the annotation tables.
        image_dir: Directory that contains the slide files.
        transforms: Callable applied to each patch; ``None`` falls back to a plain
            ``ToTensor`` conversion.
        patch_size: Side length, in pixels, of the square patches that are returned.
        box_size: Side length of the ground-truth box centred on each annotation.
            Defaults to ``patch_size`` (the extent that was always used).
    """

    # Only annotations with agreedClass = 2 are loaded (presumably the mitotic-figure
    # class — confirm against the database schema).
    _ANNOTATION_QUERY = """
            SELECT
                C.coordinateX,
                C.coordinateY,
                A.agreedClass
            FROM Annotations AS A
            JOIN Annotations_coordinates AS C ON A.uid = C.annoId
            JOIN Slides ON A.slide = Slides.uid
            WHERE Slides.filename = ? AND A.agreedClass = 2
            """

    def __init__(self, slides, database, image_dir, transforms, patch_size=224, box_size=None):
        self.slides = slides
        self.database = database
        self.image_dir = image_dir
        self.transforms = transforms
        self.patch_size = patch_size
        self.box_size = patch_size if box_size is None else box_size
        if self.box_size < 2:
            # A smaller box has zero half-width and would produce degenerate targets.
            raise ValueError("box_size must be at least 2 pixels")
        self.annotations = self.load_annotations()

    def load_annotations(self):
        """Return one dict (slide, center_x, center_y, class) per selected annotation coordinate."""
        all_annotations = []
        for slide_row in self.slides:
            slide_filename = slide_row[0]
            # Bound parameter instead of string formatting: filenames containing
            # quotes no longer break (or alter) the query.
            rows = self.database.execute(self._ANNOTATION_QUERY, (slide_filename,)).fetchall()
            for coordinate_x, coordinate_y, agreed_class in rows:
                all_annotations.append({
                    'slide': slide_filename,
                    'center_x': coordinate_x,
                    'center_y': coordinate_y,
                    'class': agreed_class,
                })
        return all_annotations

    def __len__(self):
        """Number of slides (not patches)."""
        return len(self.slides)

    @staticmethod
    def _tiles_per_row(total_columns, tile_width):
        """Number of tile columns needed to cover ``total_columns`` pixels (ceiling division)."""
        return -(-int(total_columns) // int(tile_width))

    def _crop_window(self, tile_width, tile_height, center_x, center_y):
        """Return ``(x_start, y_start, x_end, y_end)`` of the patch window inside one tile.

        The window is centred on the annotation where possible and shifted back inside
        the tile near its borders, so it is only smaller than ``patch_size`` when the
        tile itself is.
        """
        def span(center, limit):
            start = center - self.patch_size // 2
            start = max(0, min(start, limit - self.patch_size))
            return start, min(limit, start + self.patch_size)

        x_start, x_end = span(center_x, tile_width)
        y_start, y_end = span(center_y, tile_height)
        return x_start, y_start, x_end, y_end

    def _patch_box(self, center_x, center_y, crop_width, crop_height):
        """Ground-truth box ``[xmin, ymin, xmax, ymax]`` in the coordinates of the final patch.

        ``center_x``/``center_y`` are relative to the crop; the box is clipped to the
        crop and rescaled in case the crop is resized to ``patch_size``.
        """
        half = self.box_size // 2
        scale_x = self.patch_size / crop_width
        scale_y = self.patch_size / crop_height
        return [
            max(0, center_x - half) * scale_x,
            max(0, center_y - half) * scale_y,
            min(crop_width, center_x + half) * scale_x,
            min(crop_height, center_y + half) * scale_y,
        ]

    def __getitem__(self, slide_idx):
        """Return ``(patches, targets)`` for one slide; both lists are empty if it has no annotations."""
        slide_filename = self.slides[slide_idx][0]
        slide_annotations = [a for a in self.annotations if a['slide'] == slide_filename]

        patches = []
        targets = []
        if not slide_annotations:
            # Nothing to extract, so skip reading and decoding the slide altogether.
            return patches, targets

        slide_path = os.path.join(self.image_dir, slide_filename)
        ds = pydicom.dcmread(slide_path)

        tile_width = int(ds.Columns)
        tile_height = int(ds.Rows)
        num_frames = int(ds.NumberOfFrames)
        # A partial tile at the right edge still occupies a frame, hence ceiling division.
        # Frames are assumed to be stored row by row — TODO confirm for these files.
        tiles_per_row = self._tiles_per_row(ds.TotalPixelMatrixColumns, tile_width)

        # Decode once per slide rather than touching the property for every annotation.
        frames = ds.pixel_array
        if num_frames == 1:
            # Single-frame pixel data has no leading frame axis.
            frames = frames[np.newaxis]

        for annotation in slide_annotations:
            tile_x, annotation_x_in_tile = divmod(int(annotation['center_x']), tile_width)
            tile_y, annotation_y_in_tile = divmod(int(annotation['center_y']), tile_height)

            frame_index = tile_y * tiles_per_row + tile_x
            inside = 0 <= tile_x < tiles_per_row and 0 <= frame_index < num_frames
            if not inside:
                print(f"Error: Frame index {frame_index} out of bounds for {slide_filename}")
                continue
            image = frames[frame_index]

            x_start, y_start, x_end, y_end = self._crop_window(
                image.shape[1], image.shape[0], annotation_x_in_tile, annotation_y_in_tile)
            crop_width = x_end - x_start
            crop_height = y_end - y_start

            patch = image[y_start:y_end, x_start:x_end]
            patch = normalize_dicom(patch, ds)

            if crop_height != self.patch_size or crop_width != self.patch_size:
                # Only happens when the tile is smaller than the requested patch.
                patch = resize_image(patch, (self.patch_size, self.patch_size))

            # The box must live in the coordinate frame of the patch handed to the
            # model, not in the frame of the tile it was cut from.
            boxes = torch.tensor([self._patch_box(annotation_x_in_tile - x_start,
                                                  annotation_y_in_tile - y_start,
                                                  crop_width, crop_height)],
                                 dtype=torch.float32)
            labels = torch.tensor([1], dtype=torch.int64)
            target = {
                'boxes': boxes,
                'labels': labels
            }

            if self.transforms is not None:
                patch = self.transforms(patch)
            else:
                # No transform supplied: fall back to a plain tensor conversion.
                patch = torchvision_transforms.ToTensor()(patch)

            patches.append(patch)
            targets.append(target)

        return patches, targets


# Define the training loop

In [24]:
def train_model(model, dataloader, optimizer, device, epochs, batch_size, checkpoint_dir):
    """Train ``model`` slide by slide, resuming from the newest checkpoint if one exists.

    Args:
        model: Detection model that returns a dict of losses when called with
            images and targets in training mode.
        dataloader: Only ``dataloader.dataset`` is used; each dataset item must be a
            ``(patches, targets)`` pair for one slide. The loader's own batching and
            shuffling are bypassed.
        optimizer: Optimizer stepped through ``xm.optimizer_step``.
        device: Device the model and every mini-batch are moved to.
        epochs: Total number of epochs to reach (already completed epochs are skipped).
        batch_size: Number of patches per optimisation step.
        checkpoint_dir: Directory checkpoints are loaded from and written to.

    Returns:
        List with the average loss of every epoch run in this call.
    """
    model.to(device)
    model.train()

    start_epoch, _ = load_checkpoint(model, optimizer, checkpoint_dir)

    # Index the dataset by its own length: len(dataloader) counts loader batches and
    # would skip slides for any loader batch size other than 1.
    dataset = dataloader.dataset
    epoch_losses = []

    for epoch in range(start_epoch, epochs):
        total_loss = 0.0
        num_batches = 0

        for slide_idx in tqdm(range(len(dataset)), unit="slide"):
            patches, targets = dataset[slide_idx]
            for batch_start in range(0, len(patches), batch_size):
                batch_end = batch_start + batch_size
                images = torch.stack(patches[batch_start:batch_end]).to(device)
                targets_batch = [{k: v.to(device) for k, v in t.items()}
                                 for t in targets[batch_start:batch_end]]

                optimizer.zero_grad()
                loss_dict = model(images, targets_batch)
                losses = sum(loss for loss in loss_dict.values())

                # Without the backward pass the optimizer has no gradients to apply
                # and the weights never change.
                losses.backward()
                # barrier=True already issues the XLA step barrier, so no separate
                # mark_step() call is needed afterwards.
                xm.optimizer_step(optimizer, barrier=True)

                total_loss += losses.item()
                num_batches += 1

        # Guard against a dataset without any patches (would divide by zero).
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch + 1}/{epochs}, Average Loss: {avg_loss:.4f}")
        save_checkpoint(model, optimizer, epoch + 1, avg_loss, checkpoint_dir)

    return epoch_losses

# Main execution

In [None]:
if __name__ == "__main__":
    getslides = """SELECT filename FROM Slides"""
    all_slides = database.execute(getslides).fetchall()

    # Fixed seed so the slide-level 60/20/20 split is the same on every run.
    random.seed(42)
    random.shuffle(all_slides)

    train_split = int(0.6 * len(all_slides))
    val_split = int(0.8 * len(all_slides))

    train_slides = all_slides[:train_split]
    val_slides = all_slides[train_split:val_split]
    test_slides = all_slides[val_split:]

    print(f"Number of training slides: {len(train_slides)}")
    print(f"Number of validation slides: {len(val_slides)}")
    print(f"Number of test slides: {len(test_slides)}")

    image_dir = str(path)

    print("# --- Data Transforms ---")

    data_transforms = torchvision_transforms.Compose([
        torchvision_transforms.ToTensor(),
        torchvision_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    print("# --- Create Datasets ---")
    train_dataset = MitoticDataset(train_slides, database, image_dir, transforms=data_transforms)
    val_dataset = MitoticDataset(val_slides, database, image_dir, transforms=data_transforms)
    test_dataset = MitoticDataset(test_slides, database, image_dir, transforms=data_transforms)

    print("# --- Create DataLoaders ---")
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, pin_memory=True, num_workers=4)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, pin_memory=True, num_workers=4)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, pin_memory=True, num_workers=4)

    print("# --- Model---")
    device = xm.xla_device()
    num_classes = 2  # background + mitotic figure
    model = create_faster_rcnn_model(num_classes).to(device)

    print("# --- Optimizer ---")
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print("# --- Directory to store checkpoints ---")
    checkpoint_dir = os.path.join(MODEL_DIR, 'train')
    print(checkpoint_dir)

    print(f"# --- Training ---:{device}")
    batch_size = 4
    train_model(model, train_dataloader, optimizer, device, epochs=20, batch_size=batch_size, checkpoint_dir=checkpoint_dir)

    # Make sure every folder written to below really exists before the first write.
    for out_dir in (PREDICTION_DIR, REPORT_DIR, VISUALIZATION_DIR):
        os.makedirs(out_dir, exist_ok=True)

    print("# --- Visualization --- ")

    print("# --- Load the latest model for Inference ---")
    model_inference = create_faster_rcnn_model(num_classes).to(device)
    _, _ = load_checkpoint(model_inference, optimizer, checkpoint_dir)
    model_inference.eval()

    class_names = ['background', 'mitotic figure']
    score_threshold = 0.5
    # Upper bound on the patches shown inline per slide, to keep the output readable.
    max_visualized_patches = 3
    prediction_columns = ['patch_index', 'xmin', 'ymin', 'xmax', 'ymax', 'score', 'label']
    total_mitotic_figures = 0

    print("# Get model predictions ")
    print("# --- Process and Analyze Predictions ---")
    # Each test slide is predicted and reported in one pass, so the patches of one
    # slide never have to be kept around while the next one is processed.
    for slide_idx in tqdm(range(len(test_dataset)), unit="slide"):
        slide_filename = test_slides[slide_idx][0]
        patches, _ = test_dataset[slide_idx]

        slide_predictions = []  # one prediction dict per patch, in patch order
        with torch.no_grad():
            for i in range(0, len(patches), batch_size):
                image_batch = torch.stack(patches[i:i + batch_size]).to(device)
                for prediction in model_inference(image_batch):
                    slide_predictions.append({k: v.cpu() for k, v in prediction.items()})

        # Collect the confident detections of every patch of this slide.
        rows = []
        for patch_idx, prediction in enumerate(slide_predictions):
            scores = prediction['scores'].numpy()
            confident_mask = scores > score_threshold
            confident_boxes = prediction['boxes'].numpy()[confident_mask]
            confident_scores = scores[confident_mask]
            confident_labels = prediction['labels'].numpy()[confident_mask]
            for box, score, label in zip(confident_boxes, confident_scores, confident_labels):
                rows.append({
                    'patch_index': patch_idx,
                    'xmin': float(box[0]),
                    'ymin': float(box[1]),
                    'xmax': float(box[2]),
                    'ymax': float(box[3]),
                    'score': float(score),
                    'label': int(label),
                })
        prediction_df = pd.DataFrame(rows, columns=prediction_columns)

        num_mitotic_figures = int((prediction_df['label'] == 1).sum())
        total_mitotic_figures += num_mitotic_figures
        print(f"Predictions for Slide {slide_idx + 1}:")
        print(f"  Number of mitotic figures: {num_mitotic_figures}")

        # Patches are passed channel-first, which is the layout visualize_predictions de-normalizes.
        for patch_idx in range(min(max_visualized_patches, len(patches))):
            visualize_predictions(patches[patch_idx].cpu(),
                                  slide_predictions[patch_idx],
                                  class_names=class_names,
                                  score_threshold=score_threshold)

        prediction_df.to_csv(os.path.join(PREDICTION_DIR, f'slide_{slide_idx+1}_predictions.csv'), index=False)

        fig, ax = plt.subplots()
        ax.hist(prediction_df['score'].to_numpy(dtype=float), bins=10, range=(0.0, 1.0))
        ax.set_title(f'Confidence Score Histogram - Slide {slide_idx + 1}')
        ax.set_xlabel('Confidence Score')
        ax.set_ylabel('Frequency')
        fig.savefig(os.path.join(VISUALIZATION_DIR, f'slide_{slide_idx + 1}_score_histogram.png'))
        plt.close(fig)

        with open(os.path.join(REPORT_DIR, f'slide_{slide_idx + 1}_report.txt'), 'w') as f:
            f.write(f"Slide: {slide_filename}\n")
            f.write(f"Number of Mitotic Figures: {num_mitotic_figures}\n")
            # TODO: precision, recall and mAP are not computed anywhere in this notebook;
            # add them here once an evaluation against the ground-truth boxes exists.

    print(f"Total mitotic figures across test slides: {total_mitotic_figures}")
    

Number of training slides: 12
Number of validation slides: 4
Number of test slides: 5
# --- Data Transforms ---
# --- Create Datasets ---
