# Medical Tumor Detection and Segmentation System - Development Roadmap

This notebook outlines the development plan and implementation strategy for enhancing our medical imaging AI system. We'll cover:

1. Advanced model architectures with attention mechanisms
2. Optimized data pipelines for medical images
3. Production deployment and clinical integration
4. Testing and validation protocols
5. Performance optimization strategies

## Project Overview

- **Current Stack**: MONAI, PyTorch, FastAPI, React/TypeScript
- **Target Features**: Multi-modal tumor detection, clinical workflow integration, HIPAA compliance
- **Focus Areas**: Accuracy, reliability, scalability, clinical usability

## 1. Development Environment Setup

First, let's ensure our development environment is properly configured with all required dependencies and GPU support.

In [None]:
import os
import torch
import monai
from monai.utils import set_determinism
import pytorch_lightning as pl
from typing import Dict, List, Tuple, Optional

# Check CUDA availability
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

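# Set up MONAI deterministic training (seeds Python, NumPy and PyTorch and
# switches cuDNN to deterministic mode)
set_determinism(seed=42)

# Note: setting torch.backends.cudnn.benchmark = True here would undo the
# deterministic cuDNN setting above. Mixed precision (including gradient
# scaling) is handled by the Lightning Trainer through `precision` below,
# so no manual GradScaler is needed.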

# Initialize PyTorch Lightning trainer with GPU support
trainer = pl.Trainer(
    accelerator='auto',
    devices='auto',
    precision=16,  # Mixed precision training
    max_epochs=100,
    logger=True,
    callbacks=[
        pl.callbacks.ModelCheckpoint(
            dirpath='checkpoints',
            filename='best_model',
            monitor='val_loss',
            mode='min'
        )
    ]
)
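
Before launching long training runs, it is worth confirming that every package this notebook relies on is installed. The standard-library sketch below reports each version or flags it as missing. The package list is our assumption based on the imports used in this notebook, written as pip distribution names, which can differ from import names (for example, PyJWT is imported as `jwt`).

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Assumed requirement list (pip distribution names, not import names)
REQUIRED = ["torch", "monai", "pytorch-lightning", "pydicom", "fastapi", "PyJWT", "prometheus-client"]

for package, version in installed_versions(REQUIRED).items():
    print(f"{package:20s} {version or 'NOT INSTALLED'}")
```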

## 2. Advanced Model Architecture

We'll implement an attention-enhanced 3D UNet for tumor segmentation with the following features:
1. Multi-scale feature extraction
2. Self-attention mechanisms
3. Deep supervision
4. Uncertainty estimation through Monte Carlo Dropout
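
Self-attention (item 2) compares every voxel with every other voxel, so for a feature map with N = D x H x W voxels the attention map holds N^2 entries. The arithmetic below shows why this is only affordable on coarse feature maps. It assumes float32 values, batch size 1, the 192 x 192 x 64 crop used in the data pipeline, and a halving of each dimension at four downsampling levels; it counts the attention map alone, not activations or gradients.

```python
def attention_map_bytes(spatial_shape, bytes_per_element=4, batch_size=1):
    """Memory for one N x N self-attention map, where N = D * H * W voxels."""
    d, h, w = spatial_shape
    n = d * h * w
    return batch_size * n * n * bytes_per_element


input_shape = (192, 192, 64)
for level in range(5):  # level 0 = full resolution, level 4 = bottleneck
    level_shape = tuple(s // 2**level for s in input_shape)
    gib = attention_map_bytes(level_shape) / 2**30
    print(f"level {level}: {level_shape} -> {gib:,.4f} GiB")
```

At full resolution the map alone would need 20,736 GiB, while at the 12 x 12 x 4 bottleneck it needs 1,327,104 bytes (about 1.3 MB).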

In [None]:
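import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from monai.networks.blocks import Convolution

class AttentionBlock(nn.Module):
    """SAGAN-style self-attention over every voxel of a 3D feature map.

    The attention map has (D*H*W)^2 entries, so the block is only practical on
    low-resolution feature maps; the model below applies it at the bottleneck.
    """
    def __init__(self, in_channels: int):
        super().__init__()
        inner_channels = max(in_channels // 8, 1)
        self.query = nn.Conv3d(in_channels, inner_channels, 1)
        self.key = nn.Conv3d(in_channels, inner_channels, 1)
        self.value = nn.Conv3d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))  # block starts as an identity mapping
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, C, D, H, W = x.size()
        N = D * H * W
        query = self.query(x).view(batch_size, -1, N).permute(0, 2, 1)  # (B, N, C')
        key = self.key(x).view(batch_size, -1, N)                         # (B, C', N)
        attention = self.softmax(torch.bmm(query, key))                   # (B, N, N)
        value = self.value(x).view(batch_size, -1, N)                     # (B, C, N)
        out = torch.bmm(value, attention.permute(0, 2, 1))                # (B, C, N)
        return self.gamma * out.view(batch_size, C, D, H, W) + x

def conv_block(in_channels: int, out_channels: int, stride: int, dropout: float) -> nn.Sequential:
    """Two 3x3x3 convolutions; MONAI's Convolution defaults to instance norm + PReLU."""
    return nn.Sequential(
        Convolution(3, in_channels, out_channels, strides=stride, dropout=dropout),
        Convolution(3, out_channels, out_channels, dropout=dropout),
    )

class AttentionUNet(nn.Module):
    """3D UNet with bottleneck self-attention, deep supervision and dropout for MC Dropout.

    Each spatial input dimension must be divisible by 2 ** (len(features) - 1).
    All outputs are logits (apply a sigmoid to obtain probabilities).
    """
    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        features: Tuple[int, ...] = (32, 64, 128, 256, 512),
        dropout: float = 0.3
    ):
        super().__init__()
        # Encoder: the first stage keeps full resolution, later stages halve it
        self.encoders = nn.ModuleList([
            conv_block(c_in, c_out, 1 if i == 0 else 2, dropout)
            for i, (c_in, c_out) in enumerate(zip((in_channels,) + tuple(features[:-1]), features))
        ])
        
        # Self-attention on the coarsest feature map only
        self.attention = AttentionBlock(features[-1])
        
        # Decoder: upsample x2, concatenate the matching skip connection, convolve
        decoder_channels = list(reversed(features))  # e.g. [512, 256, 128, 64, 32]
        self.upsamples = nn.ModuleList([
            nn.ConvTranspose3d(c_in, c_out, kernel_size=2, stride=2)
            for c_in, c_out in zip(decoder_channels[:-1], decoder_channels[1:])
        ])
        self.decoders = nn.ModuleList([
            conv_block(2 * c, c, 1, dropout) for c in decoder_channels[1:]
        ])
        
        # Deep supervision heads: one logit map per decoder stage (coarse to fine)
        self.deep_supervision = nn.ModuleList([
            nn.Conv3d(c, out_channels, 1) for c in decoder_channels[1:]
        ])

    def forward(self, x, return_features=False):
        skips = []
        for encoder in self.encoders:
            x = encoder(x)
            skips.append(x)
        
        x = self.attention(skips.pop())
        
        features, deep_outputs = [], []
        for up, decoder, head in zip(self.upsamples, self.decoders, self.deep_supervision):
            x = decoder(torch.cat([up(x), skips.pop()], dim=1))
            features.append(x)
            deep_outputs.append(head(x))
        
        # The finest head is the main output; coarser heads serve deep supervision
        output = deep_outputs.pop()
        
        if return_features:
            return output, deep_outputs, features
        return output

# Create model instance (3D volumes)
model = AttentionUNet(
    in_channels=1,  # Single channel input (e.g., T1 MRI)
    out_channels=1,  # Binary segmentation (logits)
    features=(32, 64, 128, 256, 512),
    dropout=0.3  # For Monte Carlo Dropout uncertainty estimation
)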

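# Example of uncertainty estimation using Monte Carlo Dropout
def predict_with_uncertainty(
    model: nn.Module,
    input_tensor: torch.Tensor,
    num_samples: int = 30
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Perform inference with uncertainty estimation using MC Dropout.
    
    The model is put in eval mode, then only its dropout layers are switched back
    to train mode, so every forward pass samples a different dropout mask.
    
    Args:
        model: The neural network model (must contain dropout layers and output logits)
        input_tensor: Input image tensor of shape (B, C, D, H, W)
        num_samples: Number of Monte Carlo samples
        
    Returns:
        Tuple of (mean foreground probability, per-voxel standard deviation)
    """
    model.eval()
    for m in model.modules():
        if isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
            m.train()
    
    with torch.no_grad():
        # Sigmoid turns logits into probabilities before aggregating
        predictions = torch.stack(
            [torch.sigmoid(model(input_tensor)) for _ in range(num_samples)]
        )
    
    model.eval()  # restore deterministic inference for later calls
    
    # Mean and spread across samples (torch.std uses the n-1 denominator by default)
    mean_pred = torch.mean(predictions, dim=0)
    uncertainty = torch.std(predictions, dim=0)
    
    return mean_pred, uncertainty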
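
`predict_with_uncertainty` reduces the Monte Carlo samples to a per-voxel mean and standard deviation. For a single voxel, the same reduction looks like this in plain Python. We also add the binary predictive entropy of the mean probability as an illustration of a common alternative uncertainty score; it is not part of the function above, and the sample values are made up. `statistics.stdev` uses the n - 1 denominator, matching `torch.std`'s default, and needs at least two samples.

```python
import math
import statistics
from typing import Sequence, Tuple


def mc_dropout_summary(samples: Sequence[float]) -> Tuple[float, float, float]:
    """Return (mean probability, sample std, binary predictive entropy in bits) for one voxel."""
    mean = statistics.mean(samples)
    std = statistics.stdev(samples)
    p = min(max(mean, 1e-12), 1 - 1e-12)  # keep log2 away from 0
    entropy = -(p * math.log2(p) + (1 - p) * math.log2(1 - p))
    return mean, std, entropy


# Hypothetical sigmoid outputs from 5 dropout samples for two voxels
confident_voxel = [0.97, 0.98, 0.96, 0.99, 0.97]
ambiguous_voxel = [0.35, 0.70, 0.52, 0.41, 0.64]
for name, samples in [("confident", confident_voxel), ("ambiguous", ambiguous_voxel)]:
    mean, std, entropy = mc_dropout_summary(samples)
    print(f"{name}: mean={mean:.3f} std={std:.3f} entropy={entropy:.3f} bits")
```

Voxels whose samples disagree get both a larger standard deviation and an entropy close to 1 bit, which makes them natural candidates for review.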

## 3. Medical Image Pipeline

Let's implement an efficient data pipeline for handling medical images with:
1. Cached DICOM loading (`cache_rate` trades RAM for faster epochs)
2. Advanced augmentation strategies
3. Quality control checks
4. Multi-modal fusion

In [None]:
import pydicom
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,  # replaces AddChanneld, deprecated and removed in newer MONAI releases
    ScaleIntensityd,
    RandRotate90d,
    RandZoomd,
    RandGaussianNoised,
    RandAdjustContrastd,
    RandGaussianSmoothd,
    SpatialPadd,
    RandSpatialCropd,
    ToTensord
)
from monai.data import CacheDataset, ThreadDataLoader, partition_dataset

class MedicalImageDataset:
    def __init__(
        self,
        data_dir: str,
        cache_rate: float = 1.0,
        num_workers: int = 4
    ):
        self.data_dir = data_dir
        self.cache_rate = cache_rate
        self.num_workers = num_workers
        
        # Define transforms for training
        self.train_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityd(keys=["image"]),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 1]),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.1,
                mode=("area", "nearest"),  # nearest-neighbor keeps label values binary
                prob=0.5
            ),
            RandGaussianNoised(keys=["image"], prob=0.3),
            RandAdjustContrastd(keys=["image"], prob=0.3),
            RandGaussianSmoothd(keys=["image"], prob=0.3),
            SpatialPadd(keys=["image", "label"], spatial_size=[192, 192, 64]),
            RandSpatialCropd(
                keys=["image", "label"],
                roi_size=[192, 192, 64],
                random_size=False
            ),
            ToTensord(keys=["image", "label"])
        ])
        
        # Define transforms for validation (no augmentation)
        self.val_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityd(keys=["image"]),
            SpatialPadd(keys=["image", "label"], spatial_size=[192, 192, 64]),
            ToTensord(keys=["image", "label"])
        ])

    def prepare_data(self, train_files, val_files):
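        # CacheDataset keeps the output of the deterministic transforms (up to the
        # first random one) in RAM so epochs skip reloading; cache_rate=1.0 caches every item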
        train_ds = CacheDataset(
            data=train_files,
            transform=self.train_transforms,
            cache_rate=self.cache_rate
        )
        
        val_ds = CacheDataset(
            data=val_files,
            transform=self.val_transforms,
            cache_rate=self.cache_rate
        )
        
        # Create data loaders
        train_loader = ThreadDataLoader(
            train_ds,
            batch_size=2,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=torch.cuda.is_available()
        )
        
        val_loader = ThreadDataLoader(
            val_ds,
            batch_size=1,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=torch.cuda.is_available()
        )
        
        return train_loader, val_loader

    @staticmethod
    def verify_dicom_quality(dicom_file) -> bool:
        """Verify DICOM metadata and pixel data.

        `dicom_file` may be a file path or a binary file-like object
        (e.g. an uploaded file's stream).
        """
        try:
            dcm = pydicom.dcmread(dicom_file)
            
            # Check required metadata
            required_tags = [
                'PatientID',
                'StudyDate',
                'Modality',
                'PixelSpacing',
                'SliceThickness'
            ]
            
            for tag in required_tags:
                if tag not in dcm:
                    print(f"Missing required tag: {tag}")
                    return False
            
            # Check that pixel data exists before decoding it
            if 'PixelData' not in dcm:
                print("No pixel data found")
                return False
            
            pixel_array = dcm.pixel_array
            
            # Check for empty or corrupted images
            if pixel_array.size == 0:
                print("Empty pixel array")
                return False
            
            # A constant image (zero standard deviation) carries no usable signal;
            # a zero mean alone is not a defect, so it is not checked
            if pixel_array.std() == 0:
                print("Image has no variation")
                return False
            
            return True
            
        except Exception as e:
            print(f"Error reading DICOM file: {e}")
            return False

    @staticmethod
    def fuse_modalities(mr_image: torch.Tensor, ct_image: torch.Tensor) -> torch.Tensor:
        """
        Fuse MR and CT images using attention-based fusion.
        
        Args:
            mr_image: MRI image tensor
            ct_image: CT image tensor
            
        Returns:
            Fused image tensor
        """
        # Ensure same size
        if mr_image.shape != ct_image.shape:
            ct_image = F.interpolate(
                ct_image,
                size=mr_image.shape[2:],
                mode='trilinear',
                align_corners=False
            )
        
        # Calculate attention weights
        mr_attention = torch.sigmoid(mr_image)
        ct_attention = torch.sigmoid(ct_image)
        
        # Normalize attention weights
        total_attention = mr_attention + ct_attention
        mr_weight = mr_attention / total_attention
        ct_weight = ct_attention / total_attention
        
        # Weighted fusion
        fused_image = (mr_weight * mr_image) + (ct_weight * ct_image)
        
        return fused_image

# Example usage
dataset = MedicalImageDataset(data_dir="path/to/data")
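
`prepare_data` expects `train_files` and `val_files` as lists of dictionaries keyed by `"image"` and `"label"`, matching the transform keys. A minimal sketch for building them, assuming a hypothetical naming scheme in which `case001.nii.gz` has its mask in `case001_seg.nii.gz` and each file belongs to a different patient. If a patient has several scans, split by patient ID instead so that no patient appears in both sets.

```python
import random
from pathlib import Path
from typing import Dict, List, Tuple


def make_file_lists(
    image_paths: List[str],
    label_suffix: str = "_seg",  # hypothetical naming convention
    val_fraction: float = 0.2,
    seed: int = 42,
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
    """Pair each image with its label file and split the pairs into train/val lists."""
    pairs = []
    for image in sorted(image_paths):  # sort first so the seeded shuffle is reproducible
        path = Path(image)
        stem = path.name.split(".")[0]
        extension = path.name[len(stem):]  # keeps multi-part extensions such as .nii.gz
        label = path.with_name(f"{stem}{label_suffix}{extension}")
        pairs.append({"image": str(path), "label": str(label)})
    random.Random(seed).shuffle(pairs)
    n_val = max(1, round(len(pairs) * val_fraction))
    return pairs[n_val:], pairs[:n_val]


paths = [f"path/to/data/case{i:03d}.nii.gz" for i in range(10)]
train_files, val_files = make_file_lists(paths)
print(len(train_files), len(val_files))  # 8 2
```

With real files on disk, the lists go straight into the loaders: `train_loader, val_loader = dataset.prepare_data(train_files, val_files)`.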

## 4. Production System Design

Implement production-ready features including:
1. FastAPI endpoints for clinical integration
2. HIPAA-compliant security measures
3. Monitoring and logging
4. Containerized deployment
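
It helps to see what `jwt.decode(..., algorithms=["HS256"])` in `verify_token` checks. A JWT is `header.payload.signature` in unpadded base64url; for HS256 the signature is an HMAC-SHA256 over `header.payload`, and an `exp` claim, if present, must lie in the future. The standard-library sketch below is for illustration and testing only (the endpoint should keep using PyJWT), and names such as `sign_hs256` are our own.

```python
import base64
import hashlib
import hmac
import json
import time
from typing import Any, Dict


def _b64url_encode(data: bytes) -> str:
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def _b64url_decode(segment: str) -> bytes:
    # JWT segments drop base64 padding; restore it before decoding
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def sign_hs256(payload: Dict[str, Any], secret: bytes) -> str:
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = ".".join(
        _b64url_encode(json.dumps(part, separators=(",", ":")).encode()) for part in (header, payload)
    )
    signature = hmac.new(secret, signing_input.encode("ascii"), hashlib.sha256).digest()
    return f"{signing_input}.{_b64url_encode(signature)}"


def verify_hs256(token: str, secret: bytes) -> Dict[str, Any]:
    """Check structure, pinned algorithm, signature and expiry; raise ValueError on failure."""
    try:
        header_b64, payload_b64, signature_b64 = token.split(".")
    except ValueError:
        raise ValueError("malformed token") from None
    header = json.loads(_b64url_decode(header_b64))
    if header.get("alg") != "HS256":  # never let the token choose its own algorithm
        raise ValueError("unexpected algorithm")
    expected = hmac.new(secret, f"{header_b64}.{payload_b64}".encode("ascii"), hashlib.sha256).digest()
    if not hmac.compare_digest(expected, _b64url_decode(signature_b64)):  # constant-time compare
        raise ValueError("bad signature")
    payload = json.loads(_b64url_decode(payload_b64))
    if "exp" in payload and time.time() >= payload["exp"]:
        raise ValueError("token expired")
    return payload
```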

In [None]:
from fastapi import FastAPI, File, UploadFile, HTTPException, Security
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Optional
import jwt
import logging
from datetime import datetime
import prometheus_client
from prometheus_client import Counter, Histogram
import time

# Initialize FastAPI app with security
app = FastAPI(title="Medical Imaging API")
security = HTTPBearer()

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Prometheus metrics
PREDICTION_TIME = Histogram(
    'prediction_request_latency_seconds',
    'Time spent processing prediction request'
)
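PREDICTION_REQUESTS = Counter(
    'prediction_requests_total',
    'Total number of prediction requests'
)

# Expose the metrics above at /metrics for Prometheus to scrape
app.mount("/metrics", prometheus_client.make_asgi_app())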

# Configure logging (FileHandler raises an error if the log directory is missing)
os.makedirs('logs', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/app.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

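# Security utilities
# The signing key is read from the environment (JWT_SECRET_KEY is our assumed
# variable name) instead of being hard-coded in source control.
JWT_SECRET_KEY = os.environ.get("JWT_SECRET_KEY")

def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)) -> dict:
    """Verify the JWT (signature and expiry) and return its payload."""
    if not JWT_SECRET_KEY:
        raise HTTPException(status_code=500, detail="Authentication is not configured")
    try:
        payload = jwt.decode(
            credentials.credentials,
            JWT_SECRET_KEY,
            algorithms=["HS256"]  # pin the algorithm; never accept the token's own choice
        )
        return payload
    except jwt.InvalidTokenError:
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token"
        )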

# API endpoints
@app.post("/api/v1/predict")
async def predict(
    files: List[UploadFile] = File(...),
    token: dict = Security(verify_token)
):
    """
    Process medical images and return predictions with HIPAA compliance.
    """
    PREDICTION_REQUESTS.inc()
    start_time = time.time()
    
    try:
        # Log request (excluding PHI)
        logger.info(
            f"Processing prediction request from user {token['sub']}, "
            f"files: {len(files)}"
        )
        
        # Process files
        results = []
        for file in files:
            # Verify DICOM quality
            if not MedicalImageDataset.verify_dicom_quality(file.file):
                raise HTTPException(
                    status_code=400,
                    detail=f"Quality check failed for file {file.filename}"
                )
            
            # Process image and get prediction
            # (Implementation details omitted for brevity)
            
            results.append({
                "filename": file.filename,
                "prediction": "prediction_result",
                "confidence": 0.95,
                "processing_time": time.time() - start_time
            })
        
        return {"results": results}
        
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail="Error processing medical images"
        )
    finally:
        PREDICTION_TIME.observe(time.time() - start_time)

# Dockerfile content for containerization
dockerfile_content = """
FROM python:3.9-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \\
    build-essential \\
    curl \\
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install Python packages
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Create non-root user
RUN useradd -m appuser && chown -R appuser:appuser /app
USER appuser

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV MODEL_PATH=/app/models/tumor_detection.pt

# Expose port
EXPOSE 8000

# Start application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
"""

print("Dockerfile content for containerized deployment:")

## 5. Performance Optimization

Implement strategies for optimizing model performance:
1. Mixed precision training
2. Efficient data loading
3. Model quantization
4. Batch processing optimization

In [None]:
import time
import torch
import torch.nn as nn
import torch.quantization
from torch.profiler import profile, record_function, ProfilerActivity

class OptimizedInference:
    def __init__(self, model: nn.Module, device: torch.device):
        self.device = device
        # eval() disables dropout so predictions are deterministic
        self.model = model.to(device).eval()
        
        # Quantize model for CPU inference.
        # Dynamic quantization only converts supported layers such as nn.Linear and
        # recurrent layers, not nn.Conv3d, so it leaves an all-convolutional UNet
        # largely unchanged; quantizing the convolutions needs static quantization.
        if device.type == 'cpu':
            self.model = torch.quantization.quantize_dynamic(
                self.model,
                {torch.nn.Linear},
                dtype=torch.qint8
            )
    
    @torch.no_grad()
    def predict_batch(
        self,
        batch: torch.Tensor,
        batch_size: int = 4
    ) -> torch.Tensor:
        """
        Predict in chunks of `batch_size`, using float16 autocast on CUDA only.
        """
        results = []
        use_amp = self.device.type == 'cuda'
        
        for i in range(0, len(batch), batch_size):
            batch_slice = batch[i:i + batch_size].to(self.device)
            
            with torch.autocast(device_type=self.device.type, enabled=use_amp):
                output = self.model(batch_slice)
            
            # Cast back to float32 so all chunks concatenate with one dtype
            results.append(output.float().cpu())
        
        return torch.cat(results, dim=0)
    
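    def profile_inference(self, sample_input: torch.Tensor):
        """
        Profile model inference performance (CPU, plus CUDA kernels on a GPU).
        """
        activities = [ProfilerActivity.CPU]
        if self.device.type == 'cuda':
            activities.append(ProfilerActivity.CUDA)
        
        with profile(activities=activities, with_stack=True) as prof:
            with record_function("model_inference"):
                self.predict_batch(sample_input)
        
        sort_key = "cuda_time_total" if self.device.type == 'cuda' else "cpu_time_total"
        print(prof.key_averages().table(sort_by=sort_key, row_limit=10))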

# Example usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimized_model = OptimizedInference(model, device)

# Profile performance
sample_input = torch.randn(10, 1, 192, 192, 64)
optimized_model.profile_inference(sample_input)

# Benchmark inference speed
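def benchmark_inference(model, input_tensor, num_runs=100, warmup=3):
    """
    Benchmark batched inference with an OptimizedInference instance.
    """
    # Untimed warm-up runs absorb one-off costs (cuDNN autotuning, memory allocation)
    for _ in range(warmup):
        model.predict_batch(input_tensor)
    
    # predict_batch already runs under no_grad and copies outputs to the CPU,
    # which waits for pending CUDA work before the timer stops
    start_time = time.perf_counter()
    for _ in range(num_runs):
        model.predict_batch(input_tensor)
    avg_time = (time.perf_counter() - start_time) / num_runs
    
    num_images = len(input_tensor)
    print(f"Average time per {num_images}-volume batch: {avg_time:.4f} seconds")
    print(f"Throughput: {num_images / avg_time:.2f} images/second")

# Run benchmark
benchmark_inference(optimized_model, sample_input)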
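
Averages hide tail latency, which matters when a clinician is waiting on a result. A small, framework-agnostic timing helper (our own sketch) reports the median and a nearest-rank 95th percentile alongside the mean, after discarding warm-up calls. It assumes the callable blocks until its work is finished: `predict_batch` does, because it copies outputs to the CPU, but a raw CUDA call would need `torch.cuda.synchronize()` inside the callable.

```python
import math
import statistics
import time
from typing import Callable, Dict


def benchmark(fn: Callable[[], object], num_runs: int = 20, warmup: int = 3) -> Dict[str, float]:
    """Time fn() over num_runs calls after `warmup` untimed calls; results are in seconds."""
    for _ in range(warmup):
        fn()  # absorbs one-off costs such as autotuning and memory allocation
    timings = []
    for _ in range(num_runs):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    timings.sort()
    p95_index = math.ceil(0.95 * len(timings)) - 1  # nearest-rank percentile
    return {
        "mean": statistics.mean(timings),
        "median": statistics.median(timings),
        "p95": timings[p95_index],
        "max": timings[-1],
    }


stats = benchmark(lambda: sum(range(100_000)), num_runs=10)
print({name: f"{value * 1e3:.3f} ms" for name, value in stats.items()})
```

For the model above, the call would be `benchmark(lambda: optimized_model.predict_batch(sample_input), num_runs=10)`.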

## Next Steps and Implementation Timeline

1. Immediate Tasks (1-2 weeks):
   - Set up development environment and testing infrastructure
   - Implement basic model architecture with attention mechanisms
   - Create efficient data pipeline for DICOM processing

2. Short-term Goals (2-4 weeks):
   - Enhance model with multi-modal fusion and uncertainty estimation
   - Implement production API endpoints and security measures
   - Set up monitoring and logging infrastructure

3. Medium-term Goals (1-2 months):
   - Optimize model performance and deployment
   - Complete clinical validation and testing
   - Implement full GUI functionality
   - Prepare documentation and deployment guides

4. Long-term Goals (2-3 months):
   - Clinical integration and workflow optimization
   - Performance monitoring and continuous improvement
   - Advanced feature implementation (longitudinal analysis, etc.)

## Action Items

1. [ ] Create development environment setup script
2. [ ] Implement attention-enhanced UNet model
3. [ ] Set up DICOM processing pipeline
4. [ ] Create API endpoints with security measures
5. [ ] Implement performance optimization features
6. [ ] Set up monitoring and logging
7. [ ] Complete documentation and deployment guides

Monitor the GitHub project board for detailed task tracking and progress updates.