<a href="https://colab.research.google.com/github/shauryasawai/Medical_Research_Lab_Task/blob/main/train_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
from pathlib import Path

In [2]:
# ============================================================================
# 1. DATASET CLASS
# ============================================================================

class FetalLandmarkDataset(Dataset):
    """Dataset for fetal ultrasound landmark detection.

    Each item is ``(image, landmarks, img_name)``:
      * ``image``: float32 tensor of shape (1, img_size, img_size), scaled to [0, 1].
      * ``landmarks``: float32 tensor of 8 coordinates; x values are divided by
        the original image width and y values by the original image height.
      * ``img_name``: the value of the CSV ``image_name`` column.
    """

    # Number of landmark coordinate columns that follow ``image_name`` in the CSV.
    _NUM_COORDS = 8

    def __init__(self, csv_path, image_dir, img_size=256):
        """
        Args:
            csv_path: Path (or buffer) of a CSV with columns
                [image_name, ofd_1_x, ofd_1_y, ..., bpd_2_y]; the 8 columns after
                the first are read by position.
            image_dir: Directory containing ultrasound images
            img_size: Target side length of the square resized image (default: 256)
        """
        self.df = pd.read_csv(csv_path)
        self.image_dir = image_dir
        self.img_size = img_size

        # Original (width, height) per image name, recorded by __getitem__.
        # NOTE: when items are loaded through DataLoader worker processes
        # (num_workers > 0), entries are written into the workers' copies of
        # this dataset and never appear in the main process's dict.
        self.original_sizes = {}

    def __len__(self):
        return len(self.df)

    @classmethod
    def _normalized_landmarks(cls, row, orig_w, orig_h):
        """Return the row's 8 landmark coordinates normalized by image size.

        Args:
            row: One CSV row as a pandas Series; column 0 is the image name and
                the next 8 columns are (x, y) pairs in original-image pixels.
            orig_w: Original image width, used to divide the x coordinates.
            orig_h: Original image height, used to divide the y coordinates.

        Returns:
            float32 NumPy array of length 8.

        Raises:
            ValueError: if the row has fewer than 8 coordinate columns.
        """
        # Explicit positional selection; plain row[1:9] relies on pandas falling
        # back from label-based to position-based slicing.
        coords = row.iloc[1:1 + cls._NUM_COORDS].to_numpy().astype(np.float32)
        if coords.shape[0] != cls._NUM_COORDS:
            raise ValueError(
                f"Expected {cls._NUM_COORDS} landmark columns after 'image_name', "
                f"got {coords.shape[0]}"
            )
        coords[0::2] /= orig_w  # x coordinates
        coords[1::2] /= orig_h  # y coordinates
        return coords

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Load image as single-channel grayscale
        img_name = row['image_name']
        img_path = os.path.join(self.image_dir, img_name)
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread reports a missing/unreadable file by returning None
        if image is None:
            raise FileNotFoundError(f"Image not found: {img_path}")

        # Record original size before resizing
        orig_h, orig_w = image.shape
        self.original_sizes[img_name] = (orig_w, orig_h)

        # Resize to a square image. Aspect ratio is not preserved, but the
        # landmarks stay aligned because they are normalized per axis.
        image = cv2.resize(image, (self.img_size, self.img_size))

        # Scale 8-bit intensities to [0, 1] and add the channel dimension -> (1, H, W)
        image = image.astype(np.float32) / 255.0
        image = torch.tensor(image, dtype=torch.float32).unsqueeze(0)

        # Landmarks are assumed to be given in original image coordinates
        landmarks = torch.tensor(
            self._normalized_landmarks(row, orig_w, orig_h), dtype=torch.float32
        )

        return image, landmarks, img_name

In [3]:
# ============================================================================
# 2. MODEL ARCHITECTURE
# ============================================================================

class LandmarkCNN(nn.Module):
    """CNN that regresses landmark coordinates from a single-channel image.

    Four conv blocks, each ending in a 2x2 max-pool, reduce an
    ``img_size x img_size`` input to ``img_size // 16`` per side with 256
    channels. An MLP head then maps the flattened features to
    ``num_landmarks`` outputs. No output activation is applied.
    """

    def __init__(self, num_landmarks=8, img_size=256):
        """
        Args:
            num_landmarks: Number of regression outputs (8 = four (x, y) points).
            img_size: Side length of the square input images. It must match the
                resize target of the data pipeline. The default of 256
                reproduces the previously hard-coded head size.

        Raises:
            ValueError: if img_size < 16 (the four poolings would reduce the
                feature map to zero size).
        """
        super().__init__()
        if img_size < 16:
            raise ValueError(f"img_size must be >= 16, got {img_size}")
        self.img_size = img_size

        # Four MaxPool2d(2, 2) layers floor-halve each side four times, which
        # equals floor division by 16 (padding=1 convs keep spatial size).
        feat_side = img_size // 16

        # Feature extraction layers
        self.features = nn.Sequential(
            # Block 1
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # S -> S/2

            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # S/2 -> S/4

            # Block 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # S/4 -> S/8

            # Block 4
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # S/8 -> S/16
        )

        # Regression head (layer order unchanged so state_dict keys stay compatible)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * feat_side * feat_side, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, num_landmarks)
        )

    def forward(self, x):
        """Map a (batch, 1, img_size, img_size) tensor to (batch, num_landmarks)."""
        x = self.features(x)
        x = self.regressor(x)
        return x

In [4]:
# ============================================================================
# 3. TRAINING UTILITIES
# ============================================================================

def calculate_pixel_error(pred, target, img_size=256):
    """
    Calculate mean Euclidean distance error in pixels.

    Args:
        pred: Predicted landmarks (normalized, 0-1) laid out as consecutive
            (x, y) pairs, e.g. shape (batch, 2*K). Accepts a tensor on any
            device (gradient tracking is ignored) or an array-like.
        target: Ground truth landmarks (normalized, 0-1), same shape as pred.
        img_size: Scale used to convert normalized coordinates to pixels.

    Returns:
        Mean pixel distance over every (x, y) point in the input, as a float.
    """
    # detach() lets callers pass tensors that still require grad; .numpy()
    # raises on those otherwise.
    pred_pixels = torch.as_tensor(pred).detach().cpu().numpy() * img_size
    target_pixels = torch.as_tensor(target).detach().cpu().numpy() * img_size

    # Flatten to (num_points, 2). This works for any number of landmarks and
    # gives the same mean as grouping points per sample first.
    pred_points = pred_pixels.reshape(-1, 2)
    target_points = target_pixels.reshape(-1, 2)

    # Euclidean distance for each point
    distances = np.linalg.norm(pred_points - target_points, axis=1)

    return float(distances.mean())


def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module mapping image batches to landmark predictions.
        dataloader: Yields ``(images, landmarks, names)`` batches.
        criterion: Loss function; its per-batch value is assumed to be a mean
            over the batch (reduction='mean', as with nn.MSELoss()).
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device to move each batch to.

    Returns:
        (avg_loss, avg_error): per-sample means over the epoch. Each batch is
        weighted by its size, so a smaller final batch does not skew them.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_error = 0.0
    num_samples = 0

    pbar = tqdm(dataloader, desc="Training")
    for images, landmarks, _ in pbar:
        images = images.to(device)
        landmarks = landmarks.to(device)
        batch_size = images.size(0)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, landmarks)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Per-batch means, re-weighted by batch size for the epoch average
        batch_loss = loss.item()
        error = float(calculate_pixel_error(outputs.detach(), landmarks))
        total_loss += batch_loss * batch_size
        total_error += error * batch_size
        num_samples += batch_size

        pbar.set_postfix({'loss': f'{batch_loss:.4f}', 'error': f'{error:.2f}px'})

    if num_samples == 0:
        raise ValueError("train_epoch received a dataloader with no samples")

    avg_loss = total_loss / num_samples
    avg_error = total_error / num_samples

    return avg_loss, avg_error


def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating weights.

    Args:
        model: Module mapping image batches to landmark predictions.
        dataloader: Yields ``(images, landmarks, names)`` batches.
        criterion: Loss function; its per-batch value is assumed to be a mean
            over the batch (reduction='mean', as with nn.MSELoss()).
        device: Device to move each batch to.

    Returns:
        (avg_loss, avg_error): per-sample means over all batches. Each batch is
        weighted by its size, so a smaller final batch does not skew them.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    num_samples = 0

    with torch.no_grad():
        for images, landmarks, _ in tqdm(dataloader, desc="Validating"):
            images = images.to(device)
            landmarks = landmarks.to(device)
            batch_size = images.size(0)

            outputs = model(images)
            loss = criterion(outputs, landmarks)

            # Per-batch means, re-weighted by batch size for the overall average
            total_loss += loss.item() * batch_size
            total_error += float(calculate_pixel_error(outputs, landmarks)) * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("validate received a dataloader with no samples")

    avg_loss = total_loss / num_samples
    avg_error = total_error / num_samples

    return avg_loss, avg_error

In [5]:
# ============================================================================
# 4. MEASUREMENT UTILITIES
# ============================================================================

def calculate_measurements(landmarks, img_size=256):
    """
    Calculate OFD and BPD lengths from normalized landmarks.

    Args:
        landmarks: 1-D sequence (NumPy array, list, or CPU tensor) of at least 8
            normalized values [ofd_1_x, ofd_1_y, ofd_2_x, ofd_2_y,
                               bpd_1_x, bpd_1_y, bpd_2_x, bpd_2_y];
            any values after the first 8 are ignored.
        img_size: Scale used to convert normalized coordinates to pixels.

    Returns:
        (OFD, BPD) in pixels, as floats.

    Raises:
        ValueError: if ``landmarks`` is not 1-D or has fewer than 8 values.
    """
    # asarray makes lists scale numerically (list * int would repeat the list)
    coords = np.asarray(landmarks, dtype=np.float64)
    if coords.ndim != 1 or coords.shape[0] < 8:
        raise ValueError(
            f"Expected a 1-D array of at least 8 landmark values, got shape {coords.shape}"
        )

    # Rows: ofd_1, ofd_2, bpd_1, bpd_2 as (x, y) in pixels
    points = coords[:8].reshape(4, 2) * img_size

    ofd = np.linalg.norm(points[0] - points[1])
    bpd = np.linalg.norm(points[2] - points[3])

    return float(ofd), float(bpd)


def visualize_predictions(model, dataset, device, num_samples=4):
    """Plot ground-truth vs. predicted OFD/BPD landmarks for random samples.

    The top row shows ground truth and the bottom row shows model predictions,
    one column per sample. The figure is saved to 'predictions.png' and shown.

    Args:
        model: Landmark regressor returning normalized coordinates.
        dataset: Indexable dataset yielding ``(image, landmarks, img_name)``.
            ``image`` is a (1, H, W) tensor and ``landmarks`` is a normalized
            8-value tensor, as produced by FetalLandmarkDataset.
        device: Device to run the model on.
        num_samples: Number of samples to plot; capped at ``len(dataset)``.

    Raises:
        ValueError: if the dataset is empty or ``num_samples`` < 1.
    """
    # np.random.choice(..., replace=False) cannot draw more items than exist
    num_samples = min(num_samples, len(dataset))
    if num_samples < 1:
        raise ValueError("visualize_predictions needs a non-empty dataset and num_samples >= 1")

    model.eval()

    # squeeze=False keeps axes 2-D (2, num_samples) even when num_samples == 1
    fig, axes = plt.subplots(2, num_samples, figsize=(4 * num_samples, 8), squeeze=False)

    indices = np.random.choice(len(dataset), num_samples, replace=False)

    with torch.no_grad():
        for col, idx in enumerate(indices):
            image, gt_tensor, _ = dataset[idx]

            # The dataset resizes to a square image, so one side length converts
            # normalized coordinates to pixels on the displayed image.
            img_size = image.shape[-1]

            pred_landmarks = model(image.unsqueeze(0).to(device)).cpu().numpy()[0]
            gt_landmarks = gt_tensor.numpy()
            img_display = image.squeeze().numpy()

            panels = (
                (0, 'Ground Truth', gt_landmarks),
                (1, 'Prediction', pred_landmarks),
            )
            for row, label, landmarks in panels:
                pixels = landmarks * img_size
                ofd, bpd = calculate_measurements(landmarks, img_size)

                ax = axes[row, col]
                ax.imshow(img_display, cmap='gray')
                ax.plot([pixels[0], pixels[2]], [pixels[1], pixels[3]],
                        'g-', linewidth=2, label='OFD')
                ax.plot([pixels[4], pixels[6]], [pixels[5], pixels[7]],
                        'b-', linewidth=2, label='BPD')
                ax.scatter(pixels[0::2], pixels[1::2], c='red', s=50)
                ax.set_title(f'{label}\nOFD: {ofd:.1f}px, BPD: {bpd:.1f}px')
                ax.axis('off')
                if col == 0:
                    ax.legend()

    plt.tight_layout()
    plt.savefig('predictions.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Visualization saved as 'predictions.png'")

In [None]:
# ============================================================================
# 5. MAIN TRAINING SCRIPT
# ============================================================================
def _plot_training_curves(history, path='training_curves.png'):
    """Plot per-epoch train/validation loss and pixel error, save to ``path`` and show it.

    ``history`` must contain the lists 'train_loss', 'val_loss',
    'train_error' and 'val_error'.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    panels = (
        (ax1, 'loss', 'Loss', 'Training and Validation Loss'),
        (ax2, 'error', 'Error (pixels)', 'Training and Validation Error'),
    )
    for ax, key, ylabel, title in panels:
        ax.plot(history[f'train_{key}'], label='Train')
        ax.plot(history[f'val_{key}'], label='Validation')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Training curves saved as '{path}'")


def main():
    """Train LandmarkCNN and keep the checkpoint with the lowest validation pixel error.

    After training, plot the training curves and visualize predictions from the
    best checkpoint (not just the last epoch's weights).

    Raises:
        ValueError: if the dataset is too small to give both splits a sample.
    """
    # Configuration
    CSV_PATH = 'landmarks.csv'
    IMAGE_DIR = 'images'
    CHECKPOINT_PATH = 'best_model.pth'
    BATCH_SIZE = 16
    NUM_EPOCHS = 50
    LEARNING_RATE = 1e-3
    IMG_SIZE = 256  # the model's default regression head expects 256x256 inputs
    TRAIN_SPLIT = 0.8

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load dataset
    print("Loading dataset...")
    full_dataset = FetalLandmarkDataset(CSV_PATH, IMAGE_DIR, img_size=IMG_SIZE)

    # Split dataset (fixed seed for a reproducible split)
    train_size = int(TRAIN_SPLIT * len(full_dataset))
    val_size = len(full_dataset) - train_size
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f"Dataset with {len(full_dataset)} samples is too small for a "
            f"{TRAIN_SPLIT:.0%} train split; both splits need at least one sample"
        )
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    print(f"Train samples: {train_size}, Validation samples: {val_size}")

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                            shuffle=False, num_workers=4)

    # Create model
    model = LandmarkCNN(num_landmarks=8).to(device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Loss and optimizer. ReduceLROnPlateau's `verbose` flag is deprecated, so
    # LR reductions are reported manually in the loop below.
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    # Training loop
    best_val_error = float('inf')
    best_epoch = None  # set only when a checkpoint is written during this run
    history = {'train_loss': [], 'train_error': [], 'val_loss': [], 'val_error': []}

    print("\nStarting training...")
    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

        # Train
        train_loss, train_error = train_epoch(model, train_loader, criterion,
                                              optimizer, device)

        # Validate
        val_loss, val_error = validate(model, val_loader, criterion, device)
        # Plain floats keep the history and checkpoint free of NumPy scalars,
        # so torch.load(weights_only=True) can read the checkpoint.
        train_loss, train_error = float(train_loss), float(train_error)
        val_loss, val_error = float(val_loss), float(val_error)

        # Update scheduler and report any learning-rate reduction
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Reducing learning rate: {lr_before:.2e} -> {lr_after:.2e}")

        # Save history
        history['train_loss'].append(train_loss)
        history['train_error'].append(train_error)
        history['val_loss'].append(val_loss)
        history['val_error'].append(val_error)

        print(f"Train Loss: {train_loss:.4f}, Train Error: {train_error:.2f}px")
        print(f"Val Loss: {val_loss:.4f}, Val Error: {val_error:.2f}px")

        # Save best model
        if val_error < best_val_error:
            best_val_error = val_error
            best_epoch = epoch
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_error': val_error,
            }, CHECKPOINT_PATH)
            print(f"✓ Saved best model (error: {val_error:.2f}px)")

    print(f"\nTraining complete! Best validation error: {best_val_error:.2f}px")

    # Plot training curves
    _plot_training_curves(history)

    # Visualize the best checkpoint rather than the final-epoch weights
    if best_epoch is not None:
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded best model from epoch {best_epoch + 1} for visualization")

    print("\nGenerating predictions visualization...")
    visualize_predictions(model, full_dataset, device, num_samples=4)


# Entry point: run the full training pipeline when this file is executed
# directly. A notebook cell also runs with __name__ == '__main__', so executing
# the cell starts training. Importing the module does not.
if __name__ == '__main__':
    main()