<a href="https://colab.research.google.com/github/spate472/RecreatingRetinaUNet/blob/main/Attempt2_ToyDataPyHealth_Gemini.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import cv2
import numpy as np
import random
import pandas as pd
from tqdm.notebook import tqdm # Use notebook version for Colab
import math

# --- Configuration ---
IMAGE_SIZE = 320  # Images and masks are IMAGE_SIZE x IMAGE_SIZE pixels
SPLITS = {
    "train": 1000,
    "val": 500,
    "test": 1000,
}
# Total 2500 images per task

# Object properties
OBJECT_INTENSITY = 0.2  # Foreground value on the [0, 1] float scale (background is 0)
NOISE_AMPLITUDE = 0.1 # Uniform noise range [-0.1, 0.1]

# Task 1 & 2 specific parameters (radii in pixels)
CIRCLE_RADIUS_T12 = 15
DONUT_OUTER_RADIUS_T12 = 15
DONUT_INNER_RADIUS_T12 = 7

# Task 3 specific parameters (radii in pixels)
# A filled cv2.circle of radius r at an integer centre covers columns
# cx - r .. cx + r, i.e. it is 2*r + 1 pixels wide. With integer radii only odd
# diameters are possible, so these two classes render 19 px and 21 px wide.
CIRCLE_RADIUS_T3_CLS1 = 9 # Rendered diameter 19 px
CIRCLE_RADIUS_T3_CLS2 = 10 # Rendered diameter 21 px (an even 20 px is not reachable)

# --- Helper Functions ---

def create_base_image(size):
    """Return an all-black canvas: a ``(size, size)`` float32 array of zeros."""
    canvas = np.zeros(shape=(size, size), dtype=np.float32)
    return canvas

def add_noise(image_array, amplitude=None):
    """Add zero-mean uniform noise to an image and clip the result to [0, 1].

    Args:
        image_array (np.ndarray): Image with intensities nominally in [0, 1].
        amplitude (float, optional): Half-width of the noise range. Samples are
            drawn from U(-amplitude, amplitude). When omitted, the module-level
            ``NOISE_AMPLITUDE`` is used, which matches the previous behaviour.

    Returns:
        np.ndarray: Noisy image with the same shape, clipped to [0.0, 1.0].
        Noise comes from NumPy's global RNG, so seed ``np.random`` if you need
        reproducible output.
    """
    if amplitude is None:
        # Resolve at call time so later changes to NOISE_AMPLITUDE take effect.
        amplitude = NOISE_AMPLITUDE
    noise = np.random.uniform(-amplitude, amplitude, image_array.shape)
    return np.clip(image_array + noise, 0.0, 1.0)

def draw_circle(image, center, radius, color, thickness=-1):
    """Rasterise a circle into ``image`` in place using ``cv2.circle``.

    With the default ``thickness=-1`` the disc is filled. A positive value
    draws only an outline of that width. Returns None.
    """
    cv2.circle(image, center, radius, color, thickness=thickness)

def draw_donut(image, center, outer_radius, inner_radius, color):
    """Draw a filled ring into ``image`` in place.

    First a disc of ``outer_radius`` is painted with ``color``. Then a disc of
    ``inner_radius`` at the same centre is reset to 0. Because the canvases
    in this script start at 0, the hole matches the background of both images
    and masks.
    """
    # The order matters: outer disc first, then punch the hole on top of it.
    for ring_radius, fill_value in ((outer_radius, color), (inner_radius, 0)):
        cv2.circle(image, center, ring_radius, fill_value, -1)

def get_bounding_box(mask, class_id):
    """Return the tight box ``(xmin, ymin, xmax, ymax)`` around ``class_id`` pixels.

    x is the column index and y is the row index. Both corners are inclusive:
    the max values are the last occupied column/row, not one past it.
    If ``class_id`` does not occur in ``mask``, ``(0, 0, 0, 0)`` is returned.
    """
    rows, cols = np.nonzero(mask == class_id)
    if rows.size == 0:
        # Defensive fallback; the generator always draws the object first.
        return 0, 0, 0, 0
    return int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max())

def save_image(image_array, path):
    """Save a float image with intensities in [0, 1] as an 8-bit image file.

    OpenCV chooses the file format from the extension of ``path`` (PNG here).
    Values are clipped to [0, 1] and rounded to the nearest of the 256 levels.
    A bare ``(x * 255).astype(np.uint8)`` would truncate, biasing every pixel
    downward, and would wrap around for any value outside the range.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the file could not be written.
    """
    scaled = np.clip(image_array, 0.0, 1.0) * 255.0
    image_uint8 = np.rint(scaled).astype(np.uint8)
    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(path, image_uint8):
        raise OSError(f"cv2.imwrite failed to write image: {path}")

def save_mask(mask_array, path):
    """Save an integer label mask (class IDs) as an 8-bit image file.

    IDs are cast with ``astype(np.uint8)``, so they must lie in [0, 255]
    or they will wrap. Use a lossless format such as PNG so the IDs survive
    a round trip.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the file could not be written.
    """
    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(path, mask_array.astype(np.uint8)):
        raise OSError(f"cv2.imwrite failed to write mask: {path}")

# --- Data Generation Function ---

def _render_sample(task_id):
    """Draw one noise-free sample for an already-validated ``task_id``.

    Returns:
        tuple: ``(image, mask, class_id)``:
            - ``image``: float32 array holding ``OBJECT_INTENSITY`` on the object.
            - ``mask``: uint8 array holding ``class_id`` on the labelled region
              and 0 elsewhere.
            - ``class_id``: the sampled class, 1 or 2.
    """
    image = create_base_image(IMAGE_SIZE)
    mask = create_base_image(IMAGE_SIZE).astype(np.uint8)  # Mask stores integer class IDs

    class_id = random.choice([1, 2])
    if task_id == 3:
        # Task 3: 1 = smaller circle, 2 = larger circle
        radius = CIRCLE_RADIUS_T3_CLS1 if class_id == 1 else CIRCLE_RADIUS_T3_CLS2
        margin = radius + 1
    else:
        # Tasks 1 & 2: 1 = filled circle, 2 = donut
        margin = max(CIRCLE_RADIUS_T12, DONUT_OUTER_RADIUS_T12) + 1

    # Sample the centre so that the whole object lies inside the image.
    center_x = random.randint(margin, IMAGE_SIZE - margin - 1)
    center_y = random.randint(margin, IMAGE_SIZE - margin - 1)
    center = (center_x, center_y)

    if task_id == 3:
        draw_circle(image, center, radius, OBJECT_INTENSITY)
        draw_circle(mask, center, radius, class_id)
    elif class_id == 1:
        # Circle: image and mask share the same filled disc (tasks 1 and 2).
        draw_circle(image, center, CIRCLE_RADIUS_T12, OBJECT_INTENSITY)
        draw_circle(mask, center, CIRCLE_RADIUS_T12, class_id)
    else:
        draw_donut(image, center, DONUT_OUTER_RADIUS_T12, DONUT_INNER_RADIUS_T12, OBJECT_INTENSITY)
        if task_id == 1:
            # Task 1: mask follows the donut shape exactly.
            draw_donut(mask, center, DONUT_OUTER_RADIUS_T12, DONUT_INNER_RADIUS_T12, class_id)
        else:
            # Task 2: mask is the filled outer disc even though the image has a hole.
            draw_circle(mask, center, DONUT_OUTER_RADIUS_T12, class_id)

    return image, mask, class_id


def generate_data_for_task(task_id, base_dir):
    """Generate the train/val/test splits for one toy task on disk.

    Layout written under ``base_dir``::

        task{task_id}/{split}/images/img_XXXX.png
        task{task_id}/{split}/masks/mask_XXXX.png
        task{task_id}/{split}/annotations.csv

    Split sizes come from ``SPLITS``, and each image contains exactly one
    object. ``annotations.csv`` has one row per image with these columns:
        - image_path, mask_path: relative to the split directory and always
          '/'-separated.
        - xmin, ymin, xmax, ymax: inclusive pixel box taken from the mask.
        - class_id
    Randomness comes from the global ``random`` and ``np.random`` generators.

    Args:
        task_id (int): One of:
            - 1: circle vs donut, with donut-shaped masks.
            - 2: circle vs donut, with filled-disc masks.
            - 3: small vs large circle.
        base_dir (str): Root output directory. Existing sub-directories are
            reused.

    Raises:
        ValueError: If ``task_id`` is not 1, 2 or 3.
    """
    # Validate before creating any directories or files.
    if task_id not in (1, 2, 3):
        raise ValueError(f"task_id must be 1, 2 or 3, got {task_id!r}")

    print(f"--- Generating Task {task_id} ---")
    task_dir = os.path.join(base_dir, f"task{task_id}")
    os.makedirs(task_dir, exist_ok=True)

    for split, num_images in SPLITS.items():
        print(f"Generating {split} split ({num_images} images)...")
        split_dir = os.path.join(task_dir, split)
        img_dir = os.path.join(split_dir, "images")
        mask_dir = os.path.join(split_dir, "masks")
        os.makedirs(img_dir, exist_ok=True)
        os.makedirs(mask_dir, exist_ok=True)

        annotations = []
        for i in tqdm(range(num_images), desc=f"Task {task_id} - {split}"):
            img_filename = f"img_{i:04d}.png"
            mask_filename = f"mask_{i:04d}.png"

            image, mask, class_id = _render_sample(task_id)
            noisy_image = add_noise(image)  # Noise goes on the image only, never the mask
            xmin, ymin, xmax, ymax = get_bounding_box(mask, class_id)

            save_image(noisy_image, os.path.join(img_dir, img_filename))
            save_mask(mask, os.path.join(mask_dir, mask_filename))

            # Use '/' explicitly: os.path.join would write '\' on Windows and
            # make the CSV non-portable.
            annotations.append([f"images/{img_filename}", f"masks/{mask_filename}",
                                xmin, ymin, xmax, ymax, class_id])

        annot_df = pd.DataFrame(annotations, columns=["image_path", "mask_path", "xmin", "ymin", "xmax", "ymax", "class_id"])
        annot_path = os.path.join(split_dir, "annotations.csv")
        annot_df.to_csv(annot_path, index=False)
        print(f"Saved annotations to {annot_path}")

# --- Main Execution ---
if __name__ == "__main__":
    # Generate all three toy datasets, then print and plot a quick sanity check.
    # In a notebook cell __name__ == "__main__", so this runs whenever the cell runs.
    # Define the main directory to store the datasets
    # In Colab, '/content/' is a good place for temporary storage
    datasets_root_dir = "/content/toy_datasets"
    os.makedirs(datasets_root_dir, exist_ok=True)

    # Generate data for all three tasks
    # (each call writes task{N}/{train,val,test}/{images,masks,annotations.csv})
    generate_data_for_task(1, datasets_root_dir) # Task 1: Distinguishing Shapes
    generate_data_for_task(2, datasets_root_dir) # Task 2: Learning Patterns
    generate_data_for_task(3, datasets_root_dir) # Task 3: Distinguishing Scales

    print("\n--- Dataset Generation Complete ---")
    print(f"Datasets saved in: {datasets_root_dir}")
    print("Each task folder contains train/val/test splits.")
    print("Each split folder contains 'images', 'masks', and 'annotations.csv'.")

    # Example: Display structure of Task 1's train split
    # NOTE: the '!' lines below are IPython shell escapes. They only work inside
    # Jupyter/Colab, not when this is run as a plain Python script.
    print("\nExample structure for Task 1 train split:")
    !ls -lh {datasets_root_dir}/task1/train
    print("\nFirst 5 rows of Task 1 train annotations:")
    # head -n 6 = the CSV header line + 5 data rows
    !head -n 6 {datasets_root_dir}/task1/train/annotations.csv

    # Example: Display one generated image and mask from Task 1 train
    import matplotlib.pyplot as plt
    print("\nExample Image and Mask (Task 1, Train, Image 0):")
    example_img_path = os.path.join(datasets_root_dir, "task1/train/images/img_0000.png")
    example_mask_path = os.path.join(datasets_root_dir, "task1/train/masks/mask_0000.png")

    if os.path.exists(example_img_path) and os.path.exists(example_mask_path):
        # Both files were written as single-channel 8-bit images
        img_display = cv2.imread(example_img_path, cv2.IMREAD_GRAYSCALE)
        mask_display = cv2.imread(example_mask_path, cv2.IMREAD_GRAYSCALE)

        fig, axes = plt.subplots(1, 2, figsize=(8, 4))
        axes[0].imshow(img_display, cmap='gray')
        axes[0].set_title("Example Image")
        axes[0].axis('off')
        axes[1].imshow(mask_display, cmap='gray', vmin=0, vmax=2) # Max value is 2 for class IDs
        axes[1].set_title("Example Mask")
        axes[1].axis('off')
        plt.tight_layout()
        plt.show()
    else:
        print("Could not find example image/mask to display.")

In [None]:
!pip install torch torchvision pandas opencv-python-headless tqdm matplotlib

In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import to_tensor # Use functional transforms
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# --- Custom PyTorch Dataset ---

class ToyObjectDetectionDataset(Dataset):
    """
    PyTorch ``Dataset`` over one split of the generated toy data.

    Rows come from ``<split_dir>/annotations.csv``. Each row points to an image
    and a mask file relative to ``split_dir``. Samples are returned as
    ``(image, target)`` using the target-dict layout expected by TorchVision
    detection models.
    """

    def __init__(self, split_dir, transforms=None):
        """
        Args:
            split_dir (str): Directory of one split, e.g.
                '/content/toy_datasets/task1/train'.
            transforms (callable, optional): Applied to the image tensor only.
                Boxes and masks in the target are left untouched.

        If ``annotations.csv`` is missing, an error message is printed and the
        dataset is empty instead of raising.
        """
        self.split_dir = split_dir
        self.img_dir = os.path.join(split_dir, "images")
        self.mask_dir = os.path.join(split_dir, "masks")
        self.transforms = transforms

        csv_path = os.path.join(split_dir, "annotations.csv")
        try:
            self.annotations = pd.read_csv(csv_path)
        except FileNotFoundError:
            print(f"Error: annotations.csv not found in {split_dir}")
            self.annotations = pd.DataFrame()

    def __len__(self):
        """Number of annotation rows (0 when annotations.csv was missing)."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """
        Load sample ``idx``.

        Returns:
            tuple: ``(image_tensor, target)``.
                - ``image_tensor``: the RGB image converted by ``to_tensor``
                  (CxHxW, scaled to [0, 1]), then passed through
                  ``self.transforms`` if set.
                - ``target``: dict with keys 'boxes', 'labels', 'masks',
                  'image_id', 'area' and 'iscrowd'. Each describes the single
                  object in the image.

        Raises:
            FileNotFoundError: If the image or mask file cannot be read.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.annotations.iloc[idx]
        # CSV paths are relative to the split directory, e.g. 'images/img_0000.png'.
        img_path = os.path.join(self.split_dir, row['image_path'])
        mask_path = os.path.join(self.split_dir, row['mask_path'])

        # cv2.imread returns a 3-channel BGR uint8 array by default (None on failure).
        bgr = cv2.imread(img_path)
        if bgr is None:
            raise FileNotFoundError(f"Could not load image: {img_path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # HxWx3 RGB uint8

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)  # HxW uint8 class-ID map
        if mask is None:
            raise FileNotFoundError(f"Could not load mask: {mask_path}")

        # Labels stay 1/2 as stored; 0 is not used for objects (see the model's
        # class count).
        label = row['class_id']
        boxes = torch.as_tensor(
            [[row['xmin'], row['ymin'], row['xmax'], row['ymax']]],
            dtype=torch.float32,
        )  # [1, 4]
        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]

        # Binary instance mask: pixels whose stored class ID equals this label.
        instance_mask = torch.as_tensor((mask == label).astype(np.uint8), dtype=torch.uint8)

        target = {
            "boxes": boxes,
            "labels": torch.as_tensor([label], dtype=torch.int64),  # [1]
            "masks": instance_mask.unsqueeze(0),                    # [1, H, W]
            "image_id": torch.tensor([idx]),
            "area": widths * heights,
            "iscrowd": torch.zeros((1,), dtype=torch.int64),        # no crowd regions
        }

        image_tensor = to_tensor(rgb)
        if self.transforms:
            # Only the image is transformed. Any geometric transform (resize,
            # flip, crop) would leave target['boxes'] / target['masks'] stale.
            image_tensor = self.transforms(image_tensor)

        return image_tensor, target

# --- Collate Function ---
def collate_fn(batch):
    """
    Collate ``(image, target)`` samples for detection-style training.

    Images are stacked into a single [N, C, H, W] tensor, so every image in a
    batch must have the same shape. Targets are kept as a plain list of the
    original per-sample dicts, because each may describe a different number
    of objects.

    Args:
        batch: List of ``(image_tensor, target_dict)`` tuples.

    Returns:
        tuple: ``(images, targets)``.
    """
    targets = [sample[1] for sample in batch]
    stacked = torch.stack([sample[0] for sample in batch], dim=0)
    return stacked, targets


# --- Example Usage ---

if __name__ == "__main__":
    # Smoke test: build the dataset/DataLoader for one split, print one batch's
    # shapes, and plot its first sample. Variables defined here stay as notebook
    # globals.
    # Assume datasets are generated in /content/toy_datasets
    datasets_root_dir = "/content/toy_datasets"
    task_id = 1
    split = 'train' # or 'val' or 'test'
    target_split_dir = os.path.join(datasets_root_dir, f"task{task_id}", split)

    # Check if the directory exists
    if not os.path.exists(target_split_dir):
        print(f"Error: Dataset directory not found: {target_split_dir}")
        print("Please run the dataset generation script first.")
    else:
        print(f"Loading dataset from: {target_split_dir}")

        # 1. Create the Dataset instance
        # We are not applying any further transforms here for simplicity
        toy_dataset = ToyObjectDetectionDataset(split_dir=target_split_dir, transforms=None)

        # Check dataset length (0 means annotations.csv was missing)
        print(f"Dataset size: {len(toy_dataset)} samples")

        # 2. Create the DataLoader instance
        batch_size = 4
        # Use num_workers > 0 for parallel data loading in real training
        # Use pin_memory=True if using GPU for faster host-to-device transfers
        # collate_fn keeps targets as a list of dicts instead of stacking them
        data_loader = DataLoader(
            toy_dataset,
            batch_size=batch_size,
            shuffle=True, # Shuffle for training
            num_workers=0, # Set > 0 for parallel loading, but can cause issues in Colab sometimes
            collate_fn=collate_fn,
            pin_memory=False
        )

        # 3. Iterate through a batch
        print(f"\nFetching one batch of size {batch_size}...")
        try:
            images, targets = next(iter(data_loader))

            # Print shapes and types
            print("Batch loaded successfully!")
            print(f"Images batch shape: {images.shape}") # Should be [batch_size, C, H, W] (C=3: images are converted to RGB)
            print(f"Images batch dtype: {images.dtype}")
            print(f"Number of targets in batch: {len(targets)}") # Should be batch_size (fewer only if the dataset is smaller)

            # Inspect the first target dictionary in the batch
            print("\nInspecting target for the first sample in the batch:")
            first_target = targets[0]
            for key, value in first_target.items():
                 if isinstance(value, torch.Tensor):
                     print(f"  Target['{key}']: shape={value.shape}, dtype={value.dtype}")
                 else:
                     print(f"  Target['{key}']: {value}")

            # 4. Optional: Visualize the first image and its bounding box/mask
            print("\nVisualizing first image in batch...")
            img_to_show = images[0].permute(1, 2, 0).numpy() # Convert CHW -> HWC for plt
            mask_to_show = targets[0]['masks'][0].numpy() # Get the first mask [H, W]
            box_to_show = targets[0]['boxes'][0].numpy() # Get the first box [xmin, ymin, xmax, ymax]
            label_to_show = targets[0]['labels'][0].item()

            fig, axes = plt.subplots(1, 2, figsize=(10, 5))

            # Image with BBox
            axes[0].imshow(img_to_show)
            # imshow's default origin='upper' puts y=0 at the top, so (xmin, ymin)
            # is the box's top-left corner on screen
            rect = patches.Rectangle((box_to_show[0], box_to_show[1]), # top-left corner (xmin, ymin)
                                     box_to_show[2] - box_to_show[0], # width
                                     box_to_show[3] - box_to_show[1], # height
                                     linewidth=2, edgecolor='r', facecolor='none')
            axes[0].add_patch(rect)
            axes[0].set_title(f"Image with BBox (Class {label_to_show})")
            axes[0].axis('off')

            # Mask Overlay
            axes[1].imshow(img_to_show)
            axes[1].imshow(mask_to_show, cmap='jet', alpha=0.5) # Overlay mask
            axes[1].set_title("Image with Mask Overlay")
            axes[1].axis('off')

            plt.tight_layout()
            plt.show()

        # NOTE: with the DataLoader default drop_last=False a short batch is still
        # returned, so StopIteration here really means the dataset is empty.
        except StopIteration:
            print("DataLoader is empty (perhaps dataset is empty or batch size > dataset size).")
        except Exception as e:
             # Top-level demo boundary: report the failure with a traceback rather than abort the cell.
             print(f"An error occurred while loading batch: {e}")
             import traceback
             traceback.print_exc()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.retinanet import RetinaNetHead
from torchvision.ops import FeaturePyramidNetwork, sigmoid_focal_loss
from torchvision.models._utils import IntermediateLayerGetter # To get intermediate ResNet layers
from torchvision.models.detection.image_list import ImageList

# --- Helper Modules ---

class ConvBlock(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> in-place ReLU.

    The convolution has no bias because it would be redundant with the
    learnable shift of the BatchNorm that follows. Defaults (3x3 kernel,
    stride 1, padding 1) preserve the spatial size.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv(x)
        features = self.bn(features)
        return self.relu(features)

class UNetDecoderBlock(nn.Module):
    """U-Net style decoder stage.

    ``forward(x_up, x_skip)`` performs four steps:
      1. Upsample ``x_up`` by 2 with bilinear interpolation (align_corners=True).
      2. Pad the result, or crop it when the difference is negative, to the
         height/width of ``x_skip``, splitting the difference between both sides.
      3. Concatenate ``[upsampled, x_skip]`` along the channel dimension.
      4. Apply two ``ConvBlock``s.

    Args:
        in_channels_up: Channels of ``x_up``.
        in_channels_skip: Channels of ``x_skip``.
        out_channels: Channels produced by both ConvBlocks.
    """
    def __init__(self, in_channels_up, in_channels_skip, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        concat_channels = in_channels_up + in_channels_skip
        self.conv1 = ConvBlock(concat_channels, out_channels)
        self.conv2 = ConvBlock(out_channels, out_channels)

    def forward(self, x_up, x_skip):
        upsampled = self.upsample(x_up)
        # Size mismatch after upsampling (e.g. odd encoder sizes). Negative
        # padding makes F.pad crop instead.
        d_h = x_skip.shape[2] - upsampled.shape[2]
        d_w = x_skip.shape[3] - upsampled.shape[3]
        # F.pad order for the last two dims: (left, right, top, bottom).
        upsampled = F.pad(upsampled, [d_w // 2, d_w - d_w // 2, d_h // 2, d_h - d_h // 2])
        merged = torch.cat([upsampled, x_skip], dim=1)
        return self.conv2(self.conv1(merged))

class DiceLoss(nn.Module):
    """Multi-class soft Dice loss computed from raw logits.

    The loss is computed as follows:
      - Take a softmax over the channel dimension and one-hot encode the targets.
      - Compute a smoothed Dice score for every (sample, class) pair.
      - Return ``1 - mean`` of those scores. The background class (index 0)
        is included in the mean.

    Args:
        smooth (float): Added to both numerator and denominator. It keeps the
            ratio finite when a class is absent from both prediction and target.
    """
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        """
        Args:
            logits (Tensor): [B, C, H, W] unnormalised class scores.
            targets (Tensor): [B, H, W] int64 class indices in [0, C).

        Returns:
            Tensor: Scalar loss.
        """
        batch_size, num_classes = logits.shape[0], logits.shape[1]
        probs = F.softmax(logits, dim=1).reshape(batch_size, num_classes, -1)
        one_hot = (
            F.one_hot(targets, num_classes=num_classes)  # [B, H, W, C]
            .permute(0, 3, 1, 2)                         # [B, C, H, W]
            .float()
            .reshape(batch_size, num_classes, -1)
        )
        overlap = (probs * one_hot).sum(dim=2)
        total = probs.sum(dim=2) + one_hot.sum(dim=2)
        dice_per_class = (2.0 * overlap + self.smooth) / (total + self.smooth)
        return 1.0 - dice_per_class.mean()


# --- Retina U-Net Model ---

class RetinaUNet(nn.Module):
    def __init__(self, num_classes_det, num_classes_seg,
                 fpn_out_channels=256, backbone_name='resnet50', pretrained_backbone=True,
                 # RetinaNet specific params (can tune these)
                 anchor_sizes=((16,), (32,), (64,), (128,), (256,)), # First len(return_layers_map) entries are used
                 aspect_ratios=((0.5, 1.0, 2.0),) * 5,
                 # Loss weights
                 seg_loss_weight=1.0,
                 # U-Net decoder channels
                 unet_decoder_channels=(256, 128, 64), # [0]: P1 block output, [1]: P0 block output
                 # ResNet layer names for skips; None -> {'layer1': 'C2', ..., 'layer4': 'C5'}
                 return_layers_map=None,
                 # Early ResNet layer for P0 skip connection
                 skip_layer_p0 = 'relu' # Or 'conv1'
                 ):
        """Build the ResNet+FPN backbone, the U-Net decoder and both heads.

        Args:
            num_classes_det (int): ``num_classes`` passed to ``RetinaNetHead``.
            num_classes_seg (int): Output channels of the 1x1 segmentation head.
            fpn_out_channels (int): Channel width of every FPN level.
            backbone_name (str): Name of a ``torchvision.models`` ResNet
                constructor, e.g. 'resnet18' or 'resnet50'.
            pretrained_backbone (bool): Forwarded as ``pretrained=``. Newer
                torchvision deprecates this in favour of ``weights=``.
            anchor_sizes (tuple): Per-level anchor sizes. The first
                ``len(return_layers_map)`` entries are used, one per FPN level.
                The defaults give 16/32/64/128 for P2-P5.
            aspect_ratios (tuple): Only ``aspect_ratios[0]`` is used, and it is
                shared by every level.
            seg_loss_weight (float): Stored as ``self.seg_loss_weight``.
            unet_decoder_channels (tuple): Output channels of the P1 and P0
                decoder blocks (entries [0] and [1]).
            return_layers_map (dict, optional): ``{backbone_layer: feature_name}``
                for ``IntermediateLayerGetter``. Defaults to
                ``{'layer1': 'C2', 'layer2': 'C3', 'layer3': 'C4', 'layer4': 'C5'}``.
            skip_layer_p0 (str): 'relu' declares a stem-width (64-channel) P0
                skip. Any other value declares a 3-channel (raw image) skip.

        Raises:
            ValueError: If ``anchor_sizes`` has fewer entries than FPN levels.
        """
        super().__init__()

        self.num_classes_det = num_classes_det # Including background for detector
        self.num_classes_seg = num_classes_seg # Including background for segmentor
        self.seg_loss_weight = seg_loss_weight

        # A dict default would be shared across every instance, so build a
        # fresh copy here.
        if return_layers_map is None:
            return_layers_map = {'layer1': 'C2', 'layer2': 'C3', 'layer3': 'C4', 'layer4': 'C5'}
        self.return_layers_map = dict(return_layers_map)

        # Validate before building (and possibly downloading) the backbone.
        num_levels = len(self.return_layers_map)
        if len(anchor_sizes) < num_levels:
            raise ValueError(
                f"anchor_sizes needs at least {num_levels} entries (one per FPN level), "
                f"got {len(anchor_sizes)}"
            )

        # --- Backbone (ResNet + Intermediate Layer Getter) ---
        backbone = getattr(models, backbone_name)(pretrained=pretrained_backbone)
        self.body = IntermediateLayerGetter(backbone, return_layers=self.return_layers_map)

        # Stem reused for the decoder skips: conv1 -> bn1 -> relu -> maxpool
        # (stride 2 * 2 = 4). These are the backbone's own module objects, so
        # their weights are shared with self.body.
        self.early_features = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
        )

        # Each stage's output width is the size of its last BatchNorm
        # (BasicBlock.bn2 / Bottleneck.bn3). This covers resnet18/34 as well as
        # resnet50, which yields C2=256, C3=512, C4=1024, C5=2048.
        backbone_out_channels = {}
        for layer_name, feature_name in self.return_layers_map.items():
            stage_bns = [m for m in getattr(backbone, layer_name).modules()
                         if isinstance(m, nn.BatchNorm2d)]
            backbone_out_channels[feature_name] = stage_bns[-1].num_features
        c1_channels = backbone.conv1.out_channels  # 64 for torchvision ResNets
        early_skip_channels = c1_channels if skip_layer_p0 == 'relu' else 3  # 3 = raw input image

        # --- FPN (P2-P5) ---
        # One output map per input map, keyed like the input dict (C2..C5).
        # extra_blocks=None disables the extra P6 level.
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=[backbone_out_channels[k] for k in self.return_layers_map.values()],
            out_channels=fpn_out_channels,
            extra_blocks=None,
        )

        # --- U-Net Decoder Extension (P2 -> P1 -> P0) ---
        # NOTE(review): forward() feeds early_features (stride 4, the same as P2)
        # as the skip for BOTH blocks. Each block upsamples by 2, so
        # UNetDecoderBlock crops its upsampled input back down via negative
        # padding. Confirm the intended skip tensors/strides.
        self.decoder_p1 = UNetDecoderBlock(fpn_out_channels, c1_channels, unet_decoder_channels[0]) # P2->P1
        self.decoder_p0 = UNetDecoderBlock(unet_decoder_channels[0], early_skip_channels, unet_decoder_channels[1]) # P1->P0

        # --- Segmentation Head ---
        self.seg_head = nn.Conv2d(unet_decoder_channels[1], num_classes_seg, kernel_size=1)

        # --- Detection Head ---
        # One anchor-size entry per FPN level. With the default P2-P5 levels
        # (strides 4/8/16/32) this is sizes 16/32/64/128. Every level shares
        # aspect_ratios[0].
        level_anchor_sizes = tuple(anchor_sizes[:num_levels])
        aspect_ratios_per_level = (aspect_ratios[0],) * num_levels
        self.anchor_generator = AnchorGenerator(sizes=level_anchor_sizes, aspect_ratios=aspect_ratios_per_level)

        # RetinaNet Head applied to every FPN level. It assumes each level has
        # the same number of anchors per location.
        self.det_head = RetinaNetHead(
            in_channels=fpn_out_channels,
            num_anchors=self.anchor_generator.num_anchors_per_location()[0],
            num_classes=num_classes_det,
        )

        # --- Loss Functions ---
        self.seg_criterion_ce = nn.CrossEntropyLoss()
        self.seg_criterion_dice = DiceLoss()
        # Detection losses (anchor/GT matching, focal + smooth-L1) are not built
        # here. forward() currently substitutes placeholder zero losses for them.


    def forward(self, images, targets=None):
        """
        Args:
            images (Tensor): Input images, shape [B, C, H, W]
            targets (List[Dict[str, Tensor]], optional): Ground truth for training.
                Each dict contains:
                    'boxes' (Tensor[N, 4]): [xmin, ymin, xmax, ymax]
                    'labels' (Tensor[N]): Integer labels (0 is often background)
                    'masks' (Tensor[N, H, W]): Segmentation masks (uint8)
                Required during training.

        Returns:
            During training (targets is not None):
                Dict[str, Tensor]: Dictionary of losses.
            During inference (targets is None):
                List[Dict[str, Tensor]]: Detections per image.
                Tensor: Segmentation map predictions [B, C_seg, H, W] (logits or probs).
        """
        if isinstance(images, list):
            original_image_sizes = [img.shape[-2:] for img in images]
            images = torch.stack(images, 0) # Stack if input is a list
        else: # Assuming input is already a tensor
            original_image_sizes = [images.shape[-2:]] * images.shape[0]

        # --- Backbone ---
        # Get C1 features separately (used for P1 skip)
        # Note: Depending on exactly where skip is needed, might adjust 'early_features'
        # This gets features AFTER the first maxpool (stride 4)
        c1_features = self.early_features(images) # Output e.g. 64 channels, stride 4

        # Get C2, C3, C4, C5 features by running backbone on original images
        # self.body uses IntermediateLayerGetter and expects original input
        backbone_features = self.body(images)
        # backbone_features is a dict {'C2': ..., 'C3': ..., 'C4': ..., 'C5': ...}

        # --- FPN (P2-P5) ---
        # Ensure the keys match what FPN expects
        fpn_input_features = {k: backbone_features[k] for k in self.return_layers_map.values()}
        fpn_features = self.fpn(fpn_input_features)
        # fpn_features is an OrderedDict {'0': P2, '1': P3, '2': P4, '3': P5} (standard FPN keys)
        p2 = fpn_features['C2'] # P2 corresponds to input C2
        p3 = fpn_features['C3'] # P3 corresponds to input C3
        p4 = fpn_features['C4'] # P4 corresponds to input C4
        p5 = fpn_features['C5'] # P5 corresponds to input C5

        fpn_outputs_for_detection = [p2, p3, p4, p5]

        # --- U-Net Decoder (P2 -> P1 -> P0) ---
        p1 = self.decoder_p1(p2, c1_features)
        # For P0 skip connection, choose the appropriate early feature map
        # Here using c1_features (output of early_features), adjust if needed
        # e.g., if you need skip before maxpool, redefine early_features/c1_features
        p0 = self.decoder_p0(p1, c1_features)

        # --- Segmentation Head ---
        seg_logits = self.seg_head(p0)
        seg_logits = F.interpolate(seg_logits, size=original_image_sizes[0], mode='bilinear', align_corners=False)

        # --- Detection Head ---
        det_cls_logits, det_bbox_reg = self.det_head(fpn_outputs_for_detection)

        # --- Create ImageList Object ---
        # Convert list of [H, W] tensors/lists to list of tuples (H, W)
        original_image_sizes_tuples = [tuple(s) for s in original_image_sizes]
        image_list_obj = ImageList(images, original_image_sizes_tuples)


        # --- Generate Anchors ---
        # Pass the ImageList object to the anchor generator
        anchors = self.anchor_generator(image_list_obj, fpn_outputs_for_detection)

        # --- Output ---
        losses = {}
        detections = None

        if self.training:
            if targets is None:
                raise ValueError("Targets must be provided during training.")

            # !!! ---> Placeholder for actual detection loss calculation <--- !!!
            # Requires implementing or integrating anchor-target matching and loss computation
            # Example: Use torchvision's built-in logic if possible, or implement manually
            # gt_classes_target, gt_regression_target = self.assign_targets(anchors, targets) # Assign targets function needed
            # loss_det_cls = self.compute_cls_loss(det_cls_logits, gt_classes_target) # e.g., Focal Loss
            # loss_det_reg = self.compute_reg_loss(det_bbox_reg, gt_regression_target, gt_classes_target) # e.g., Smooth L1

            # print("Warning: Detection loss calculation is currently a placeholder.")
            loss_det_cls = torch.tensor(0.0, device=images.device, requires_grad=True) # Dummy, ensure requires_grad for backprop test
            loss_det_reg = torch.tensor(0.0, device=images.device, requires_grad=True) # Dummy
            # !!! ---------------------------------------------------------- !!!

            # --- Calculate Segmentation Loss ---
            gt_seg_masks = []
            for i, t in enumerate(targets):
                h, w = original_image_sizes[i]
                combined_mask = torch.zeros((h, w), dtype=torch.long, device=images.device)
                if 'masks' in t and 'labels' in t and t['masks'].numel() > 0: # Check if masks/labels exist and are not empty
                     instance_labels = t['labels']
                     # Resize masks ONCE before the loop if possible
                     instance_masks_resized = F.interpolate(t['masks'].unsqueeze(1).float(), size=(h,w)).squeeze(1).byte()
                     for mask_idx in range(instance_masks_resized.shape[0]):
                         label = instance_labels[mask_idx]
                         mask = instance_masks_resized[mask_idx]
                         combined_mask[mask > 0] = label
                gt_seg_masks.append(combined_mask)
            gt_seg_masks = torch.stack(gt_seg_masks, dim=0)

            loss_seg_ce = self.seg_criterion_ce(seg_logits, gt_seg_masks)
            loss_seg_dice = self.seg_criterion_dice(seg_logits, gt_seg_masks)
            loss_segmentation = loss_seg_ce + loss_seg_dice

            # --- Combine Losses ---
            losses = {
                "loss_detector_cls": loss_det_cls,
                "loss_detector_reg": loss_det_reg,
                "loss_segmentation": loss_segmentation * self.seg_loss_weight,
            }
            # Ensure total loss requires grad if components do
            losses["total_loss"] = loss_det_cls + loss_det_reg + losses["loss_segmentation"]

            return losses

        else: # Inference mode
            # !!! ---> Placeholder for actual detection post-processing <--- !!!
            # Requires applying sigmoid, decoding boxes, NMS etc.
            # print("Warning: Detection post-processing is currently a placeholder.")
            # Simulate output format
            detections = [{"boxes": torch.empty((0, 4)), "scores": torch.empty((0,)), "labels": torch.empty((0,), dtype=torch.long)} for _ in range(images.shape[0])]
            # !!! ---------------------------------------------------------- !!!

            seg_predictions = F.softmax(seg_logits, dim=1)
            return detections, seg_predictions


# --- Example Instantiation (smoke test of the inference and training forward passes) ---
if __name__ == '__main__':
    # Class counts include background; adjust them to your dataset.
    NUM_DET_CLASSES = 3  # e.g. Background, Circle, Donut (or Benign, Malignant)
    NUM_SEG_CLASSES = 3  # e.g. Background, Class1, Class2; should agree with the detector

    model = RetinaUNet(
        num_classes_det=NUM_DET_CLASSES,
        num_classes_seg=NUM_SEG_CLASSES,
        backbone_name='resnet50',
        pretrained_backbone=True,
        seg_loss_weight=1.0,
    )

    # One random batch (2 images, 3 channels, 320x320) shared by both passes.
    dummy_images = torch.randn(2, 3, 320, 320)

    # --- Inference: in eval mode the model returns (detections, seg_predictions) ---
    model.eval()
    with torch.no_grad():
        detections, seg_predictions = model(dummy_images)

    print("--- Inference Output ---")
    print(f"Number of images with detections: {len(detections)}")
    if len(detections) > 0:
        print("Detections for first image (placeholders):", detections[0])
    print(f"Segmentation predictions shape: {seg_predictions.shape}")  # [B, C_seg, H, W]

    # --- Training forward pass: in train mode the model returns a dict of losses ---
    model.train()
    # (boxes, labels) for each image; the instance masks are random binary maps,
    # one per label. Replace these with real data loading.
    _demo_target_specs = [
        ([[50, 50, 100, 100], [150, 150, 180, 180]], [1, 2]),  # Class 1, Class 2
        ([[70, 80, 120, 150]], [1]),                            # Class 1
    ]
    dummy_targets = [
        {
            "boxes": torch.tensor(boxes, dtype=torch.float32),
            "labels": torch.tensor(labels, dtype=torch.int64),
            "masks": torch.randint(0, 2, (len(labels), 320, 320), dtype=torch.uint8),
        }
        for boxes, labels in _demo_target_specs
    ]

    # The detection losses are placeholders, so this only exercises the plumbing.
    try:
        losses = model(dummy_images, dummy_targets)
        print("\n--- Training Output (Losses - Detection Loss Placeholder) ---")
        for loss_name, loss_value in losses.items():
            print(f"{loss_name}: {loss_value.item()}")
    except Exception as e:
        print(f"\nError during training forward pass (expected if detection loss not fully implemented): {e}")

In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.cuda.amp as amp # For mixed precision
import os
import time
import copy
from tqdm.notebook import tqdm # Use notebook version for Colab
import matplotlib.pyplot as plt

# --- Assume these are defined in other files/cells ---
# from dataset import ToyObjectDetectionDataset, collate_fn # Or paste the code here
# from model import RetinaUNet # Or paste the code here
# --- Paste the Dataset and Model code definitions here if not importing ---
# <<< PASTE ToyObjectDetectionDataset and collate_fn HERE >>>
# <<< PASTE RetinaUNet, UNetDecoderBlock, ConvBlock, DiceLoss HERE >>>
# --- End of pasted code ---


# --- Configuration ---
# Paths (Colab layout) and which toy task to train on.
DATASET_ROOT = "/content/toy_datasets"
MODEL_SAVE_DIR = "/content/retina_unet_checkpoints"
TASK_ID = 1  # Choose task 1, 2, or 3
MODEL_SAVE_NAME = f"retina_unet_task{TASK_ID}_best.pth"

# Hardware: use the GPU when present; mixed precision (AMP) is enabled only on CUDA.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
USE_AMP = DEVICE.type == "cuda"

# Optimisation hyperparameters (paper values where the paper specifies them).
LEARNING_RATE = 1e-4
BATCH_SIZE = 20  # the paper's 2D setting
NUM_EPOCHS = 50  # adjust as needed

# Class counts include the background class: BG, Class1, Class2 for the toy data.
NUM_DET_CLASSES_MODEL = 3
NUM_SEG_CLASSES_MODEL = NUM_DET_CLASSES_MODEL  # segmentation uses the same label mapping

# Create the checkpoint directory up front (no-op if it already exists).
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

# --- Training Function ---

# Loss-dict keys consumed by the training loop, in the order
# (total, segmentation, detector classification, detector regression).
_LOSS_KEYS = ("total_loss", "loss_segmentation", "loss_detector_cls", "loss_detector_reg")


def _forward_for_losses(model, images, targets, phase):
    """Run ``model(images, targets)`` so that it yields its loss dict in both phases.

    The RetinaUNet forward in this notebook only builds its loss dict when
    ``self.training`` is True. In eval mode it returns
    ``(detections, seg_predictions)`` even when targets are passed, which made
    every validation loss 0. For the 'val' phase we set ONLY the top-level
    ``training`` flag (``nn.Module.train()`` would recurse into submodules), so
    layers such as BatchNorm/Dropout keep their eval behaviour while the loss
    branch runs. Gradient tracking is controlled by the caller.
    """
    if phase == 'train':
        return model(images, targets)
    was_training = model.training
    model.training = True
    try:
        return model(images, targets)
    finally:
        model.training = was_training


def _to_float(value):
    """Return ``value`` as a Python float (accepts 0-d tensors or plain numbers)."""
    return value.item() if torch.is_tensor(value) else float(value)


def train_model(model, dataloaders, optimizer, num_epochs, model_save_path):
    """
    Main training loop for Retina U-Net.

    Each epoch runs a 'train' phase (with backprop) and a 'val' phase (no
    gradients). Both phases record batch-size-weighted mean losses. The
    state_dict is checkpointed whenever the validation total loss improves.
    Device and mixed precision come from the module-level DEVICE and USE_AMP.

    Args:
        model (nn.Module): The RetinaUNet model instance.
        dataloaders (dict): Dictionary containing 'train' and 'val' DataLoaders
            yielding ``(images, targets)`` with ``targets`` a list of dicts.
        optimizer (torch.optim.Optimizer): The optimizer instance.
        num_epochs (int): Number of epochs to train for.
        model_save_path (str): Path to save the best model checkpoint.

    Returns:
        tuple: ``(model, history)``. The model has the best-validation weights
        loaded. ``history`` maps keys such as 'train_loss' or 'val_loss_seg'
        to per-epoch mean losses (NaN for a phase with no counted samples).
    """
    start_time = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    # Validation loss is the model-selection proxy, since mAP needs the
    # (not yet implemented) detection post-processing.
    best_val_loss = float('inf')

    # AMP gradient scaler; a no-op when USE_AMP is False.
    scaler = amp.GradScaler(enabled=USE_AMP)

    history = {'train_loss': [], 'val_loss': [],
               'train_loss_seg': [], 'val_loss_seg': [],
               'train_loss_det_cls': [], 'val_loss_det_cls': [],
               'train_loss_det_reg': [], 'val_loss_det_reg': []}

    print(f"Starting training on {DEVICE}")

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train() for 'train', eval() for 'val'

            # Batch-size-weighted loss sums, ordered like _LOSS_KEYS.
            running = [0.0] * len(_LOSS_KEYS)
            num_samples = 0  # samples actually accumulated (skipped batches excluded)
            warned_no_loss_dict = False

            progress_bar = tqdm(dataloaders[phase], desc=f"{phase.capitalize()} Epoch {epoch+1}")
            for images, targets in progress_bar:
                images = images.to(DEVICE)
                # Move tensor fields to the device; leave any non-tensor metadata untouched.
                targets = [
                    {k: v.to(DEVICE) if torch.is_tensor(v) else v for k, v in t.items()}
                    for t in targets
                ]

                if is_train:
                    optimizer.zero_grad()

                # Track autograd history only in the train phase.
                with torch.set_grad_enabled(is_train):
                    with amp.autocast(enabled=USE_AMP):
                        outputs = _forward_for_losses(model, images, targets, phase)

                    zero = torch.tensor(0.0, device=DEVICE)
                    if isinstance(outputs, dict):
                        if not all(k in outputs for k in _LOSS_KEYS):
                            print(f"Warning: Loss dict missing keys: {outputs.keys()}")
                        batch_losses = [outputs.get(k, zero) for k in _LOSS_KEYS]
                    else:
                        # Model returned predictions rather than losses; warn once per phase.
                        if not warned_no_loss_dict:
                            print("Warning: Model in eval mode did not return loss dict.")
                            warned_no_loss_dict = True
                        batch_losses = [zero] * len(_LOSS_KEYS)

                    total_loss = batch_losses[0]
                    total_value = _to_float(total_loss)

                    # Backward + optimize only in the training phase.
                    if is_train:
                        if not math.isfinite(total_value):
                            print(f"Warning: Invalid loss detected: {total_value}. Skipping batch.")
                            continue
                        if not (torch.is_tensor(total_loss) and total_loss.requires_grad):
                            # backward() would raise on a constant (e.g. defaulted) loss.
                            print("Warning: Total loss does not require grad. Skipping batch.")
                            continue
                        scaler.scale(total_loss).backward()
                        # Optional gradient clipping; with AMP call scaler.unscale_(optimizer) first:
                        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                        scaler.step(optimizer)
                        scaler.update()

                # Statistics (weighted by batch size so the epoch mean is per-sample).
                batch_size = images.size(0)
                num_samples += batch_size
                for i, loss in enumerate(batch_losses):
                    running[i] += _to_float(loss) * batch_size

                progress_bar.set_postfix(loss=total_value)

            if num_samples:
                epoch_loss, epoch_loss_seg, epoch_loss_det_cls, epoch_loss_det_reg = (
                    s / num_samples for s in running
                )
            else:
                # Empty loader or every batch skipped: there is no meaningful average.
                epoch_loss = epoch_loss_seg = epoch_loss_det_cls = epoch_loss_det_reg = float('nan')

            print(f'{phase.capitalize()} Total Loss: {epoch_loss:.4f}')
            print(f'  Seg Loss: {epoch_loss_seg:.4f} | Det Cls Loss: {epoch_loss_det_cls:.4f} | Det Reg Loss: {epoch_loss_det_reg:.4f}')

            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_loss_seg'].append(epoch_loss_seg)
            history[f'{phase}_loss_det_cls'].append(epoch_loss_det_cls)
            history[f'{phase}_loss_det_reg'].append(epoch_loss_det_reg)

            # Save the best model based on validation loss (NaN never compares as smaller).
            if phase == 'val' and epoch_loss < best_val_loss:
                print(f"Validation loss improved ({best_val_loss:.4f} --> {epoch_loss:.4f}). Saving model...")
                best_val_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), model_save_path)

    time_elapsed = time.time() - start_time
    print(f'\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best Validation Loss: {best_val_loss:.4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model, history

# --- Main Execution ---

def _plot_history(history):
    """Plot total and per-component loss curves recorded by ``train_model``.

    The detector curves are labelled "(PH)" because the model's detection
    losses are placeholders.
    """
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Total Loss')
    plt.plot(history['val_loss'], label='Val Total Loss')
    plt.title('Total Loss vs. Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(history['train_loss_seg'], label='Train Seg Loss')
    plt.plot(history['val_loss_seg'], label='Val Seg Loss')
    plt.plot(history['train_loss_det_cls'], label='Train Det Cls Loss (PH)')  # PH=Placeholder
    plt.plot(history['val_loss_det_cls'], label='Val Det Cls Loss (PH)')
    plt.plot(history['train_loss_det_reg'], label='Train Det Reg Loss (PH)')
    plt.plot(history['val_loss_det_reg'], label='Val Det Reg Loss (PH)')
    plt.title('Component Losses vs. Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    # 1. Prepare Datasets and DataLoaders
    print("Loading datasets...")
    try:
        train_split_dir = os.path.join(DATASET_ROOT, f"task{TASK_ID}", "train")
        val_split_dir = os.path.join(DATASET_ROOT, f"task{TASK_ID}", "val")

        # Add data augmentation here if needed (e.g., using torchvision.transforms or albumentations)
        # train_transforms = ...
        # val_transforms = ...

        train_dataset = ToyObjectDetectionDataset(split_dir=train_split_dir, transforms=None)
        val_dataset = ToyObjectDetectionDataset(split_dir=val_split_dir, transforms=None)

        # Pinned host memory only speeds up host-to-GPU copies; it is useless without CUDA.
        pin_memory = DEVICE.type == "cuda"
        dataloaders = {
            'train': DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                num_workers=2, collate_fn=collate_fn, pin_memory=pin_memory),
            'val': DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=2, collate_fn=collate_fn, pin_memory=pin_memory),
        }
        print(f"Train dataset size: {len(train_dataset)}")
        print(f"Validation dataset size: {len(val_dataset)}")
    except Exception as e:
        print(f"Error loading datasets: {e}")
        print("Please ensure the dataset generation script ran successfully and paths are correct.")
        # `exit()` is the site-module helper meant for the interactive shell and
        # would exit with status 0. Raise SystemExit with a failure code instead,
        # keeping the original error as the cause.
        raise SystemExit(1) from e

    # 2. Initialize Model
    print("\nInitializing model...")
    # Note: Pass num_classes including background for detector head
    model = RetinaUNet(
        num_classes_det=NUM_DET_CLASSES_MODEL,
        num_classes_seg=NUM_SEG_CLASSES_MODEL,
        # Add other RetinaUNet parameters if needed (backbone, etc.)
    )
    model = model.to(DEVICE)

    # 3. Initialize Optimizer (Adam as per paper)
    # Filter parameters that require gradients (useful if freezing backbone layers)
    params_to_update = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(params_to_update, lr=LEARNING_RATE)

    # 4. Run Training
    print("\nStarting training loop...")
    print("--- WARNING: Detection loss/post-processing in model are placeholders! ---")
    print("--- Training/Validation relies on these being correctly implemented. ---")
    print("--- Validation tracks LOSS, not mAP, due to placeholders. ---")

    model_best, history = train_model(
        model,
        dataloaders,
        optimizer,
        num_epochs=NUM_EPOCHS,
        model_save_path=os.path.join(MODEL_SAVE_DIR, MODEL_SAVE_NAME)
    )

    # 5. Plot Training History (Losses)
    print("\nPlotting training history...")
    _plot_history(history)

    print("Training finished.")