# Dental Caries Detection - Training Notebook

This notebook trains a dental caries detection model on a Google Colab GPU.

## Setup Steps:
1. Verify GPU availability
2. Install dependencies
3. Clone repository and set up environment
4. Mount Google Drive and verify the dataset files
5. Start training

In [None]:
# First, verify GPU is enabled.
# nvidia-smi is the NVIDIA command-line tool; if this command fails, the
# runtime presumably has no GPU attached (in Colab: change the runtime type).
!nvidia-smi

In [None]:
# Install PyTorch with CUDA support
!pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118

# Install other dependencies
!pip install albumentations==1.3.1 opencv-python==4.8.0.74 numpy==1.24.3 tqdm==4.65.0

# Import basic libraries
import os
import sys
import torch
import gc

# Set CUDA environment variables
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

# Clear any existing memory
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

In [None]:
# Clone repository
!git clone https://github.com/projectprasanth42/dental-caries-detection.git
%cd dental-caries-detection

# Add project to path
project_path = os.path.abspath('.')
if project_path not in sys.path:
    sys.path.append(project_path)

# Test CUDA setup
def test_cuda():
    """Print GPU diagnostics and run a tiny tensor addition on the GPU.

    Returns:
        bool: True when CUDA is available and the test addition ran on the
        GPU; False when CUDA is unavailable or any step raised an exception.
    """
    try:
        print("\nGPU Information:")
        print(f"PyTorch Version: {torch.__version__}")
        cuda_available = torch.cuda.is_available()
        print(f"CUDA Available: {cuda_available}")
        if not cuda_available:
            # Return an explicit False; this path used to fall off the end of
            # the function and return None.
            return False

        print(f"GPU Device: {torch.cuda.get_device_name(0)}")
        print(f"CUDA Version: {torch.version.cuda}")

        # Test small tensor operations
        x = torch.ones(2, 2, device='cuda')
        y = x + x
        print("\nCUDA Test Successful!")
        print(f"Test tensor device: {y.device}")
        print(f"Current memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

        # Drop the test tensors and hand their cached memory back
        del x, y
        torch.cuda.empty_cache()
        return True
    except Exception as e:
        print(f"\nError testing CUDA: {str(e)}")
        return False

# Run the CUDA check once and keep its result in cuda_ok.
cuda_ok = test_cuda()
# Any falsy result aborts this cell with a RuntimeError.
if not cuda_ok:
    raise RuntimeError("CUDA setup failed. Please ensure GPU is enabled in Colab.")

In [None]:
from google.colab import drive

# Make Google Drive visible under /content/drive
drive.mount('/content/drive')

# Folder on Drive that holds the .npy dataset files
DRIVE_PATH = '/content/drive/MyDrive/dental_caries_dataset'

# Build the project configuration object
from src.configs.model_config import ModelConfig

config = ModelConfig()

# Point each dataset attribute of the config at its file inside DRIVE_PATH
_DATASET_FILES = (
    ('train_data_path', 'X_train.npy'),
    ('train_labels_path', 'y_train.npy'),
    ('val_data_path', 'X_val.npy'),
    ('val_labels_path', 'y_val.npy'),
)
for _attr, _filename in _DATASET_FILES:
    setattr(config, _attr, os.path.join(DRIVE_PATH, _filename))

# Verify dataset
import numpy as np

def verify_dataset(cfg=None):
    """Check that every dataset file exists and print the shape of each array.

    Args:
        cfg: Object exposing ``train_data_path``, ``train_labels_path``,
            ``val_data_path`` and ``val_labels_path``. When omitted, the
            notebook-level ``config`` is used.

    Raises:
        FileNotFoundError: At the first listed path that does not exist.
    """
    if cfg is None:
        cfg = config

    print("\nChecking dataset:")
    for name, path in [
        ('Training Data', cfg.train_data_path),
        ('Training Labels', cfg.train_labels_path),
        ('Validation Data', cfg.val_data_path),
        ('Validation Labels', cfg.val_labels_path)
    ]:
        if not os.path.exists(path):
            print(f"{name}: ✗ Not found at {path}")
            raise FileNotFoundError(f"Dataset file not found: {path}")

        # Memory-map the file instead of reading it: only the shape is needed,
        # so the array contents are never pulled into RAM.
        data = np.load(path, mmap_mode='r')
        print(f"{name}: ✓ Found - Shape: {data.shape}")
        del data

# Run the check now; a missing file raises FileNotFoundError in this cell.
verify_dataset()

In [None]:
from src.training.memory_efficient_train import memory_efficient_training
import logging

# Configure logging.
# force=True replaces any handlers already attached to the root logger. Without
# it, basicConfig does nothing when the root logger is already configured, and
# the INFO messages below would be dropped.
# NOTE(review): hosted notebook runtimes such as Colab appear to pre-configure
# the root logger — confirm in the target environment.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)


def _release_gpu_memory():
    """Run the garbage collector, then empty the CUDA cache and reset the
    peak-memory statistics when CUDA is available."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()


# Additional memory cleanup before training
_release_gpu_memory()

# Start training with error handling: on failure, log the error, release GPU
# memory, then re-raise so the cell still fails visibly.
try:
    memory_efficient_training(config)
except Exception as e:
    logging.error(f"Training failed: {str(e)}")
    logging.info("Cleaning up GPU memory...")
    _release_gpu_memory()
    raise