# Dental Caries Detection - Training Notebook

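This notebook trains a dental caries detection model (Mask R-CNN) on a Google Colab GPU. It then evaluates the best checkpoint, visualizes its predictions, and saves the trained weights.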

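## Steps:
1. Verify GPU availability
2. Install dependencies
3. Clone repository and set up environment
4. Mount Google Drive and verify the dataset
5. Start training
6. Evaluate the best checkpoint and visualize predictions
7. Run inference on new images and save the model to Drive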

In [None]:
# First, verify GPU is enabled.
# nvidia-smi lists the attached NVIDIA GPU(s); if none is listed, enable a GPU
# runtime in Colab before running the rest of the notebook.
!nvidia-smi

In [None]:
# Install PyTorch with CUDA support
!pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118

# Install other dependencies
!pip install albumentations==1.3.1 opencv-python==4.8.0.74 numpy==1.24.3 tqdm==4.65.0

# Import basic libraries
import os
import sys
import torch
import gc

# Set CUDA environment variables
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

# Clear any existing memory
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

In [None]:
# Clone repository.
# The clone is skipped when it already exists and the directory is addressed by an
# absolute path, so re-running this cell neither nests a second clone inside the
# repo nor changes directory one level deeper each time.
import subprocess

WORK_DIR = '/content'  # Colab's working directory
REPO_URL = 'https://github.com/projectprasanth42/dental-caries-detection.git'
REPO_DIR = os.path.join(WORK_DIR, 'dental-caries-detection')

if not os.path.isdir(REPO_DIR):
    subprocess.run(['git', 'clone', REPO_URL, REPO_DIR], check=True)
os.chdir(REPO_DIR)

# Add project to path. Prepend so the project's `src` package takes precedence
# over any installed package that happens to share the name.
project_path = os.path.abspath('.')
if project_path not in sys.path:
    sys.path.insert(0, project_path)

# Test CUDA setup
def test_cuda():
    """Print GPU/PyTorch information and run a tiny tensor operation on CUDA.

    Returns:
        bool: True if a CUDA tensor could be created and used; False if CUDA is
        unavailable or the test operation raised. (The previous version fell
        through and returned None when CUDA was unavailable.)
    """
    try:
        print("\nGPU Information:")
        print(f"PyTorch Version: {torch.__version__}")
        print(f"CUDA Available: {torch.cuda.is_available()}")
        if not torch.cuda.is_available():
            return False

        print(f"GPU Device: {torch.cuda.get_device_name(0)}")
        print(f"CUDA Version: {torch.version.cuda}")

        # Test small tensor operations
        x = torch.ones(2, 2, device='cuda')
        y = x + x
        print("\nCUDA Test Successful!")
        print(f"Test tensor device: {y.device}")
        print(f"Current memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

        del x, y
        torch.cuda.empty_cache()
        return True
    except Exception as e:
        print(f"\nError testing CUDA: {str(e)}")
        return False

# Fail fast: every later cell assumes a working GPU.
cuda_ok = test_cuda()
if not cuda_ok:
    raise RuntimeError(
        "CUDA setup failed. Please ensure GPU is enabled in Colab."
    )

In [None]:
from google.colab import drive

# Mount Google Drive; the dataset .npy files are read from there.
drive.mount('/content/drive')

# Set data paths
DRIVE_PATH = '/content/drive/MyDrive/dental_caries_dataset'

# Create config
from src.configs.model_config import ModelConfig

config = ModelConfig()

# Point the four dataset paths of the config at the Drive folder (same order as
# the file names listed below).
(config.train_data_path,
 config.train_labels_path,
 config.val_data_path,
 config.val_labels_path) = [
    os.path.join(DRIVE_PATH, filename)
    for filename in ('X_train.npy', 'y_train.npy', 'X_val.npy', 'y_val.npy')
]

# Verify dataset
import numpy as np


def verify_dataset(cfg=None):
    """Check that every dataset .npy file exists and print its array shape.

    Arrays are opened memory-mapped (``mmap_mode='r'``), so only the .npy header
    is read to get the shape instead of loading each full array into RAM.

    Args:
        cfg: object with ``train_data_path``, ``train_labels_path``,
            ``val_data_path`` and ``val_labels_path`` attributes. Defaults to the
            notebook-level ``config``.

    Raises:
        FileNotFoundError: if any of the four files is missing. All missing
            paths are reported, not just the first one.
    """
    if cfg is None:
        cfg = config  # notebook-level config object

    print("\nChecking dataset:")
    missing = []
    for name, path in [
        ('Training Data', cfg.train_data_path),
        ('Training Labels', cfg.train_labels_path),
        ('Validation Data', cfg.val_data_path),
        ('Validation Labels', cfg.val_labels_path)
    ]:
        if os.path.exists(path):
            data = np.load(path, mmap_mode='r')
            print(f"{name}: ✓ Found - Shape: {data.shape}")
            del data  # release the memory map / file handle
        else:
            print(f"{name}: ✗ Not found at {path}")
            missing.append(path)

    if missing:
        raise FileNotFoundError(f"Dataset file(s) not found: {', '.join(missing)}")

# Fail fast with FileNotFoundError if any dataset file is missing, before the long training cell runs.
verify_dataset()

In [None]:
from src.training.memory_efficient_train import memory_efficient_training
import logging

# Configure logging. force=True replaces any handlers already attached to the
# root logger: without it basicConfig silently does nothing when the notebook
# environment has configured logging first, and training INFO logs are lost.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)


def _release_gpu_memory():
    """Run the garbage collector and release cached CUDA memory, if a GPU is present."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()


# Additional memory cleanup before training
_release_gpu_memory()

# Start training with error handling; on failure, log, free GPU memory and re-raise.
try:
    memory_efficient_training(config)
except Exception as e:
    logging.error(f"Training failed: {str(e)}")
    logging.info("Cleaning up GPU memory...")
    _release_gpu_memory()
    raise

## Model Evaluation and Visualization

After training, load the checkpoint with the lowest recorded loss and compute the average losses on the validation set. Then visualize predictions on a few random validation images.

In [None]:
import torch
from src.models.mask_rcnn import DentalCariesMaskRCNN
import matplotlib.pyplot as plt
import cv2

def load_best_model(config, checkpoint_dir='.', device='cuda'):
    """Build a DentalCariesMaskRCNN and load the lowest-loss checkpoint into it.

    Every ``*.pth`` file in ``checkpoint_dir`` is read on the CPU. The one whose
    ``'loss'`` entry is smallest supplies ``'model_state_dict'``. The winning
    state is kept in memory, so no checkpoint is loaded twice.

    Args:
        config: object exposing ``num_classes`` and ``hidden_dim``.
        checkpoint_dir: directory searched for ``.pth`` files (default: the
            current working directory, as before).
        device: device the returned model is moved to (default ``'cuda'``).

    Returns:
        The model with the best checkpoint's weights loaded, on ``device``.

    Raises:
        FileNotFoundError: if there are no ``.pth`` files, or none of them
            records a usable numeric ``'loss'``.
    """
    checkpoints = sorted(f for f in os.listdir(checkpoint_dir) if f.endswith('.pth'))
    if not checkpoints:
        raise FileNotFoundError("No checkpoint files found!")

    # Find the best checkpoint based on loss
    best_loss = float('inf')
    best_checkpoint = None
    best_state = None

    for checkpoint in checkpoints:
        path = os.path.join(checkpoint_dir, checkpoint)
        # map_location='cpu' keeps the scan from moving every checkpoint's weights
        # onto the GPU they were saved from.
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        state = torch.load(path, map_location='cpu')
        loss = state.get('loss') if isinstance(state, dict) else None
        if loss is None:
            continue  # not a training checkpoint (e.g. an exported model file)
        try:
            loss = float(loss)
        except (TypeError, ValueError, RuntimeError):
            continue
        if loss < best_loss:  # NaN/inf losses never win
            best_loss, best_checkpoint, best_state = loss, path, state

    if best_state is None:
        raise FileNotFoundError(
            f"No checkpoint in {checkpoint_dir!r} records a usable 'loss' value."
        )

    print(f"Loading best model from {best_checkpoint} with loss {best_loss:.4f}")
    model = DentalCariesMaskRCNN(
        num_classes=config.num_classes,
        hidden_dim=config.hidden_dim
    )
    model.load_state_dict(best_state['model_state_dict'])
    return model.to(device)

# Load the best model
model = load_best_model(config).eval()  # eval() returns the module itself

In [None]:
from src.data.dataset import DentalCariesDataset
from torch.utils.data import DataLoader
import numpy as np

def _set_loss_mode(model):
    """Put ``model`` in training mode (so it returns losses) but keep BatchNorm and dropout layers in eval mode."""
    model.train()
    for module in model.modules():
        # Frozen so evaluation neither updates BatchNorm running stats nor applies dropout.
        if isinstance(module, (torch.nn.modules.batchnorm._BatchNorm,
                               torch.nn.modules.dropout._DropoutNd)):
            module.eval()


def evaluate_model(model, config, device='cuda'):
    """Compute average validation losses for ``model``.

    Torchvision-style detection models only return a loss dict when called in
    training mode. In eval mode they return per-image detections, which is what
    the visualisation cells rely on (``model([img])[0]['masks']``). Losses are
    therefore computed in training mode under ``torch.no_grad()``, with
    normalisation/dropout layers frozen. The model is left in eval mode afterwards.

    Args:
        model: detection model taking ``(images, targets)``.
        config: object exposing ``val_data_path`` and ``val_labels_path``.
        device: device images/targets are moved to (default ``'cuda'``).

    Returns:
        (avg_loss, metrics): ``avg_loss`` is a float, the mean of the summed loss
        components per batch. ``metrics`` maps each loss name to its mean. The
        keys 'detection_loss', 'classification_loss' and 'segmentation_loss' are
        always present; any other component the model returns is added.

    Raises:
        ValueError: if the validation dataset is empty.
        TypeError: if the model does not return a dict of losses.
    """
    val_dataset = DentalCariesDataset(
        config.val_data_path,
        config.val_labels_path,
        is_training=False
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=lambda batch: tuple(zip(*batch))
    )
    num_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError("Validation dataset is empty; cannot compute average loss.")

    total_loss = 0.0
    metrics = {
        'detection_loss': 0.0,
        'classification_loss': 0.0,
        'segmentation_loss': 0.0
    }

    print("\nEvaluating model on validation set...")
    _set_loss_mode(model)
    try:
        with torch.no_grad():
            for images, targets in val_loader:
                images = [img.to(device) for img in images]
                # Non-tensor entries (e.g. integer ids) are passed through unchanged.
                targets = [
                    {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in t.items()}
                    for t in targets
                ]

                loss_dict = model(images, targets)
                if not isinstance(loss_dict, dict):
                    raise TypeError(
                        f"Expected a dict of losses from the model, got {type(loss_dict).__name__}"
                    )
                # float() moves each scalar to the host so nothing accumulates on the GPU.
                total_loss += sum(float(v) for v in loss_dict.values())
                for k, v in loss_dict.items():
                    metrics[k] = metrics.get(k, 0.0) + float(v)
    finally:
        model.eval()

    # Calculate average metrics
    avg_loss = total_loss / num_batches
    metrics = {k: v / num_batches for k, v in metrics.items()}

    print(f"\nValidation Results:")
    print(f"Average Loss: {avg_loss:.4f}")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")

    return avg_loss, metrics

# Evaluate the model
val_loss, metrics = evaluate_model(model=model, config=config)

In [None]:
def _overlay_mask(ax, mask, threshold=0.5):
    """Draw ``mask`` semi-transparently on ``ax``, leaving pixels <= ``threshold`` uncoloured.

    Accepts (H, W) masks or (1, H, W) masks (the leading channel is dropped).
    """
    if mask.ndim == 3:
        mask = mask[0]
    # Masking the background stops the colormap from tinting the whole image.
    ax.imshow(np.ma.masked_where(mask <= threshold, mask),
              alpha=0.5, cmap='jet', vmin=0, vmax=1)


def visualize_predictions(model, dataset, num_samples=5, score_threshold=0.5,
                          device='cuda', seed=None):
    """Plot image / ground-truth / prediction rows for random samples of ``dataset``.

    Args:
        model: detection model; in eval mode ``model([image])[0]`` must contain
            'masks', 'scores' and 'labels'.
        dataset: indexable dataset returning ``(image_tensor, target)`` where
            ``target['masks']`` holds the ground-truth masks.
        num_samples: number of rows (samples are drawn with replacement).
        score_threshold: minimum score for a predicted instance to be drawn.
        device: device the image is moved to for inference.
        seed: optional RNG seed for reproducible sample selection.
    """
    rng = np.random.default_rng(seed)
    model.eval()
    fig, axes = plt.subplots(num_samples, 3, figsize=(20, 4 * num_samples), squeeze=False)

    for row, idx in enumerate(rng.integers(len(dataset), size=num_samples)):
        image, target = dataset[int(idx)]

        # Get prediction
        with torch.no_grad():
            prediction = model([image.to(device)])[0]

        # Move tensors to CPU for visualization (CHW -> HWC for imshow)
        image_np = image.cpu().numpy().transpose(1, 2, 0)
        masks = prediction['masks'].cpu().numpy()
        scores = prediction['scores'].cpu().numpy()
        labels = prediction['labels'].cpu().numpy()

        ax_img, ax_gt, ax_pred = axes[row]

        # Original image
        ax_img.imshow(image_np)
        ax_img.set_title('Original Image')
        ax_img.axis('off')

        # Ground truth
        ax_gt.imshow(image_np)
        for mask in target['masks'].cpu().numpy():
            _overlay_mask(ax_gt, mask)
        ax_gt.set_title('Ground Truth')
        ax_gt.axis('off')

        # Prediction; labels are stacked so they do not overlap each other
        ax_pred.imshow(image_np)
        drawn = 0
        for mask, score, label in zip(masks, scores, labels):
            if score > score_threshold:
                _overlay_mask(ax_pred, mask)
                ax_pred.text(10, 10 + 20 * drawn, f'Class {label}: {score:.2f}',
                             color='white', bbox=dict(facecolor='red', alpha=0.5))
                drawn += 1
        ax_pred.set_title('Prediction')
        ax_pred.axis('off')

    fig.tight_layout()
    plt.show()

# Build the validation split for plotting, then show three random samples
val_dataset = DentalCariesDataset(config.val_data_path, config.val_labels_path, is_training=False)
visualize_predictions(model, val_dataset, num_samples=3)

## Model Inference

Use the trained model to make predictions on new images.

In [None]:
def predict_image(model, image_path, confidence_threshold=0.5, device='cuda'):
    """Run ``model`` on one image file and plot the predictions above the threshold.

    Args:
        model: detection model; in eval mode ``model([image])[0]`` must contain
            'masks', 'scores' and 'labels'.
        image_path: path to an image readable by OpenCV.
        confidence_threshold: minimum score for a predicted instance to be drawn.
        device: device the image tensor is moved to for inference.

    Returns:
        The raw prediction dict for the image.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    # Load and preprocess image. cv2.imread returns None instead of raising on failure.
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    # HWC uint8 -> CHW float in [0, 1].
    # NOTE(review): assumes training inputs used the same scaling — confirm against DentalCariesDataset.
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0

    # Make prediction
    model.eval()
    with torch.no_grad():
        prediction = model([image_tensor.to(device)])[0]

    # Visualize results
    fig, (ax_img, ax_pred) = plt.subplots(1, 2, figsize=(10, 5))

    ax_img.imshow(image)
    ax_img.set_title('Original Image')
    ax_img.axis('off')

    ax_pred.imshow(image)
    masks = prediction['masks'].cpu().numpy()
    scores = prediction['scores'].cpu().numpy()
    labels = prediction['labels'].cpu().numpy()

    drawn = 0
    for mask, score, label in zip(masks, scores, labels):
        if score > confidence_threshold:
            mask2d = mask[0] if mask.ndim == 3 else mask
            # Mask out the background so only the predicted region is tinted.
            ax_pred.imshow(np.ma.masked_where(mask2d <= 0.5, mask2d),
                           alpha=0.5, cmap='jet', vmin=0, vmax=1)
            # Stack one label line per drawn instance so texts do not overlap.
            ax_pred.text(10, 10 + 20 * drawn, f'Class {label}: {score:.2f}',
                         color='white', bbox=dict(facecolor='red', alpha=0.5))
            drawn += 1

    ax_pred.set_title('Prediction')
    ax_pred.axis('off')
    plt.show()

    return prediction

# Example usage:
# prediction = predict_image(model, 'path_to_new_image.jpg')

## Save Model for Deployment

Save the trained weights to Google Drive, together with the configuration and validation metrics, for future use or deployment.

In [None]:
# Save model to Drive
SAVE_PATH = os.path.join(DRIVE_PATH, 'trained_model.pth')

# Store the weights as CPU tensors. By default torch.load restores tensors to the
# device they were saved from, so CUDA tensors would not load on a CPU-only
# deployment machine without map_location. Values are replaced in the returned
# state_dict (not in the model), which keeps its version metadata intact.
state_dict = model.state_dict()
for name in list(state_dict):
    state_dict[name] = state_dict[name].detach().cpu()

torch.save({
    'model_state_dict': state_dict,
    'config': config.__dict__,
    'metrics': metrics
}, SAVE_PATH)

print(f"Model saved to {SAVE_PATH}")