In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np
from tqdm.notebook import tqdm
import random
import os
import torch.nn.functional as F
import math
import wandb
from torch.optim import AdamW
from collections import defaultdict
import psutil
import sys
import json
import traceback
from datetime import datetime


In [2]:
# Make float32 the default dtype for floating-point tensors created from here on.
torch.set_default_dtype(torch.float32)

In [3]:
from dataset import AdvancedXRayTransforms, XRayDataset, custom_collate, AdvancedBatchSampler
from model import create_model_and_optimizer

# Set random seeds for reproducibility
def seed_everything(seed=11):
    """Seed the random number generators used by this notebook.

    Args:
        seed: Value passed to Python's ``random``, NumPy, and PyTorch (CPU and
            CUDA) seeding functions; its string form is also stored in the
            ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Every library gets the same seed value.
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    # Request deterministic cuDNN algorithms.
    torch.backends.cudnn.deterministic = True

seed_everything()

# Pick the Metal (MPS) backend when PyTorch reports it, otherwise stay on the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: mps


In [4]:
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5):
    """Create a ``LambdaLR`` with linear warmup followed by a cosine-shaped multiplier.

    For the first ``num_warmup_steps`` steps the learning-rate multiplier is
    ``step / max(1, num_warmup_steps)``. Afterwards it is
    ``max(0, 0.5 * (1 + cos(pi * num_cycles * progress)))`` where ``progress``
    goes from 0 to 1 over the remaining steps. Note that ``num_cycles``
    multiplies ``pi`` directly: ``num_cycles=1`` ends at a multiplier of 0,
    while the default ``0.5`` ends at 0.5.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Number of scheduler steps spent warming up.
        num_training_steps: Total number of scheduler steps.
        num_cycles: Factor applied to ``pi * progress`` inside the cosine.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping the multiplier.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))
    cycles = float(num_cycles)

    def lr_lambda(current_step):
        # Linear ramp during warmup.
        if current_step < num_warmup_steps:
            return float(current_step) / warmup_span

        # Cosine-shaped multiplier afterwards, never below zero.
        progress = float(current_step - num_warmup_steps) / decay_span
        cosine_value = 0.5 * (1.0 + math.cos(math.pi * cycles * progress))
        return max(0.0, cosine_value)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [5]:
class EnhancedTrainingMonitor:
    """Tracks gradient, per-class and memory statistics during training.

    On construction the monitor registers a backward hook on every trainable
    parameter. The hook records the gradient norm and also *modifies* the
    gradient (non-finite entries are zeroed, large norms are rescaled), so it
    affects training, not only logging.

    All histories in ``self.stats`` are append-only lists and grow for the
    lifetime of the monitor.
    """

    # Per-parameter gradient norm above which the hook rescales the gradient.
    GRAD_NORM_LIMIT = 10

    def __init__(self, model, class_distribution, log_dir='training_logs'):
        """
        Initialize the enhanced training monitor with comprehensive tracking capabilities.

        Args:
            model: The PyTorch model being trained
            class_distribution: Dictionary mapping class indices to their frequencies
            log_dir: Directory for saving training logs (created if missing)
        """
        self.model = model
        self.num_classes = len(class_distribution)
        self.class_distribution = class_distribution
        self.log_dir = log_dir

        # Initialize statistics tracking
        self.stats = {
            'gradient_norms': defaultdict(list),
            'layer_metrics': defaultdict(list),
            'class_accuracies': {i: [] for i in range(self.num_classes)},
            'class_predictions': {i: [] for i in range(self.num_classes)},
        }

        # Create log directory
        os.makedirs(log_dir, exist_ok=True)

        # Setup gradient monitoring
        self.setup_gradient_hooks()

        # Compute initial class weights
        self.class_weights = self.compute_class_weights()

    def setup_gradient_hooks(self):
        """Register the gradient hook on every parameter that requires grad."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # ``name=name`` binds the current name; a plain closure would
                # see only the last name of the loop.
                param.register_hook(lambda grad, name=name: self._gradient_monitor(grad, name))

    def _gradient_monitor(self, grad, param_name):
        """Sanitize, record and clip one parameter's gradient.

        Non-finite entries are zeroed *before* the norm is measured. Measuring
        first would make a single NaN turn the norm into NaN (so the clipping
        comparison is skipped) and a single Inf turn the scale factor into 0
        (so the whole gradient is wiped).

        Args:
            grad: Gradient tensor passed by autograd, or None.
            param_name: Name under which the norm is recorded.

        Returns:
            The gradient to use in place of ``grad`` (None if ``grad`` is None).
        """
        if grad is None:
            return grad

        # Replace NaN/Inf entries with zeros so one bad value cannot poison the norm.
        non_finite = torch.isnan(grad) | torch.isinf(grad)
        if non_finite.any():
            grad = torch.where(non_finite, torch.zeros_like(grad), grad)

        # The recorded value is the norm after sanitizing and before clipping.
        grad_norm = grad.norm().item()
        self.stats['gradient_norms'][param_name].append(grad_norm)

        # Rescale unusually large gradients down to the limit.
        if grad_norm > self.GRAD_NORM_LIMIT:
            grad = grad * (self.GRAD_NORM_LIMIT / grad_norm)

        return grad

    def compute_class_weights(self):
        """Compute balanced class weights based on class distribution.

        Each class with a positive count gets
        ``total / (num_classes * count)``; missing or empty classes get 1.0.

        Returns:
            A 1-D tensor with one weight per class index.
        """
        total_samples = sum(self.class_distribution.values())
        weights = []
        for i in range(self.num_classes):
            if i in self.class_distribution and self.class_distribution[i] > 0:
                weight = total_samples / (self.num_classes * self.class_distribution[i])
            else:
                weight = 1.0
            weights.append(weight)
        return torch.tensor(weights)

    def log_batch_metrics(self, batch_idx, loss, accuracy, optimizer, outputs, targets, global_step):
        """Log batch-level metrics to W&B and update per-class histories.

        Args:
            batch_idx: Index of the batch within the epoch; every 50th batch
                also triggers the imbalance check and a printed summary.
            loss: Scalar loss value for the batch.
            accuracy: Scalar accuracy for the batch.
            optimizer: Optimizer whose first param group supplies the logged LR.
            outputs: Model outputs; the arg-max over dim 1 is the prediction.
            targets: Target class indices compared against the predictions.
            global_step: Step passed to ``wandb.log``.
        """
        # Global L2 norm over all gradients that currently exist.
        squared_norm_sum = 0
        for p in self.model.parameters():
            if p.grad is not None:
                squared_norm_sum += p.grad.detach().norm(2).item() ** 2
        total_norm = squared_norm_sum ** 0.5

        # Get predictions and update class-wise metrics
        _, preds = torch.max(outputs, 1)
        for class_idx in range(self.num_classes):
            # Accuracy is only defined for classes present in this batch.
            mask = targets == class_idx
            if mask.sum() > 0:
                class_acc = (preds[mask] == targets[mask]).float().mean().item()
                self.stats['class_accuracies'][class_idx].append(class_acc)

            # Track how often each class is predicted.
            class_preds = (preds == class_idx).sum().item()
            self.stats['class_predictions'][class_idx].append(class_preds)

        # Resident memory of this process, in GiB.
        memory_stats = psutil.Process().memory_info()
        memory_used_gb = memory_stats.rss / (1024 ** 3)

        # Prepare metrics dictionary
        metrics = {
            'batch/gradient_norm': total_norm,
            'batch/learning_rate': optimizer.param_groups[0]['lr'],
            'batch/loss': loss,
            'batch/accuracy': accuracy,
            'memory/used_gb': memory_used_gb
        }

        # Per-class accuracy averaged over the last 100 recorded batches.
        for class_idx in range(self.num_classes):
            history = self.stats['class_accuracies'][class_idx]
            if history:
                metrics[f'class_{class_idx}/accuracy'] = np.mean(history[-100:])

        # Log to wandb
        wandb.log(metrics, step=global_step)

        # Check for class imbalance every 50 batches
        if batch_idx % 50 == 0:
            self._check_class_imbalance()
            self._print_batch_summary(batch_idx, loss, accuracy, total_norm, memory_used_gb)

    def _check_class_imbalance(self):
        """Warn when a class's share of recent predictions is below 10% or above 40%."""
        # Classes without any recorded batch are skipped; np.mean of an empty
        # list would yield NaN (with a RuntimeWarning) and disable the check.
        pred_counts = {
            i: np.mean(preds[-100:])
            for i, preds in self.stats['class_predictions'].items()
            if preds
        }
        total_preds = sum(pred_counts.values())

        if total_preds > 0:
            pred_distribution = {k: v/total_preds for k, v in pred_counts.items()}

            for class_idx, freq in pred_distribution.items():
                if freq < 0.1:
                    print(f"\nWARNING: Class {class_idx} is severely underrepresented "
                          f"({freq*100:.1f}% of predictions)")
                elif freq > 0.4:
                    print(f"\nWARNING: Class {class_idx} is severely overrepresented "
                          f"({freq*100:.1f}% of predictions)")

    def _print_batch_summary(self, batch_idx, loss, accuracy, grad_norm, memory_used_gb):
        """Print the batch metrics plus each class's mean accuracy over its last 50 entries."""
        print("\n" + "="*50)
        print(f"Batch {batch_idx} Detailed Metrics:")
        print("-"*20)
        print(f"Loss: {loss:.4f}")
        print(f"Accuracy: {accuracy:.2%}")
        print(f"Gradient Norm: {grad_norm:.4f}")
        print(f"Memory Used: {memory_used_gb:.2f}GB")
        print("\nPer-class Accuracies:")
        for class_idx, accs in self.stats['class_accuracies'].items():
            if accs:
                recent_acc = np.mean(accs[-50:])
                print(f"Class {class_idx}: {recent_acc:.2%}")
        print("="*50)

    def log_validation_metrics(self, val_metrics, epoch, global_step):
        """Print validation metrics and log the scalar ones to W&B.

        Args:
            val_metrics: Dict with keys ``val_loss``, ``accuracy``,
                ``specificity``, ``sensitivity`` and ``confusion_matrix``
                (a square 2-D array indexed ``[true, predicted]``).
            epoch: Epoch number shown in the header.
            global_step: Step passed to ``wandb.log``.
        """
        print("\n" + "="*50)
        print(f"Validation Metrics - Epoch {epoch}")
        print("-"*20)
        print(f"Loss: {val_metrics['val_loss']:.4f}")
        print(f"Accuracy: {val_metrics['accuracy']:.2%}")
        print(f"Specificity: {val_metrics['specificity']:.2%}")
        print(f"Sensitivity: {val_metrics['sensitivity']:.2%}")

        # Print confusion matrix
        cm = val_metrics['confusion_matrix']
        print("\nConfusion Matrix:")
        print(cm)

        # A matrix built without explicit labels can be smaller than
        # num_classes; never index past what it actually contains.
        rows_available = cm.shape[0]
        if rows_available != self.num_classes:
            print(f"\nNote: confusion matrix has {rows_available} classes but "
                  f"{self.num_classes} were expected; row indices may not match class indices.")

        # Calculate and print per-class metrics
        print("\nPer-class Metrics:")
        for i in range(min(self.num_classes, rows_available)):
            tp = cm[i, i]
            fp = cm[:, i].sum() - tp
            fn = cm[i, :].sum() - tp

            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

            print(f"\nClass {i}:")
            print(f"Precision: {precision:.2%}")
            print(f"Recall: {recall:.2%}")
            print(f"F1-Score: {f1:.2%}")

        print("="*50)

        # Log to wandb
        wandb.log({
            'val/loss': val_metrics['val_loss'],
            'val/accuracy': val_metrics['accuracy'],
            'val/specificity': val_metrics['specificity'],
            'val/sensitivity': val_metrics['sensitivity']
        }, step=global_step)

In [6]:
def verify_data_loading(train_loader):
    """Fetch the first batch from ``train_loader`` and print a short report.

    The batch is expected to unpack into ``(inputs, targets_a, targets_b, lams)``.
    Any exception raised while fetching or summarising the batch is reported
    on stdout and then re-raised.
    """
    print("\nVerifying data loading...")
    try:
        inputs, targets_a, targets_b, lams = next(iter(train_loader))

        print("\nInitial batch statistics:")
        print(f"Input shape: {inputs.shape}")

        # Size of the input tensor in MiB: bytes per element times element count.
        batch_mb = inputs.element_size() * inputs.nelement() / (1024**2)
        print(f"Memory usage per batch: {batch_mb:.2f}MB")

        for label, target_tensor in (("A", targets_a), ("B", targets_b)):
            print(f"Target {label} distribution: {torch.bincount(target_tensor)}")

        lam_low = lams.min().item()
        lam_high = lams.max().item()
        print(f"Lambda range: {lam_low:.3f} - {lam_high:.3f}")
        print("\nData loading verification completed successfully!")

    except Exception as e:
        print(f"Error during data loading verification: {str(e)}")
        raise e

In [7]:
class DebugMonitor:
    """Accumulates gradient-norm, activation and weight statistics and logs them to W&B.

    Each ``update_*`` call appends to an in-memory history keyed by name;
    ``log_to_wandb`` sends the most recent entry of every history.
    """

    def __init__(self):
        self.grad_norms = defaultdict(list)
        self.activation_stats = defaultdict(list)
        self.weight_stats = defaultdict(list)

    @staticmethod
    def _summarize(tensor):
        """Return mean/std/max/min of ``tensor`` as Python floats."""
        return {
            'mean': tensor.mean().item(),
            'std': tensor.std().item(),
            'max': tensor.max().item(),
            'min': tensor.min().item()
        }

    def update_grad_stats(self, model):
        """Record the gradient norm of every parameter that currently has a gradient."""
        for name, param in model.named_parameters():
            if param.grad is not None:
                self.grad_norms[name].append(param.grad.norm().item())

    def update_activation_stats(self, name, tensor):
        """Record summary statistics of an activation tensor under ``name``."""
        self.activation_stats[name].append(self._summarize(tensor))

    def update_weight_stats(self, model):
        """Record summary statistics of every parameter of ``model``."""
        for name, param in model.named_parameters():
            self.weight_stats[name].append(self._summarize(param.data))

    def log_to_wandb(self, step):
        """Log the latest value of every tracked statistic to W&B at ``step``.

        All scalars are collected into one dictionary and sent with a single
        ``wandb.log`` call instead of one call per scalar. Nothing is logged
        when no statistics have been recorded yet.
        """
        payload = {}

        # Latest gradient norm per parameter.
        for name, norms in self.grad_norms.items():
            if norms:
                payload[f'grad_norm/{name}'] = norms[-1]

        # Latest activation and weight summaries share the same layout.
        for prefix, histories in (('activation', self.activation_stats),
                                  ('weight', self.weight_stats)):
            for name, stats_list in histories.items():
                if stats_list:
                    for stat_name, value in stats_list[-1].items():
                        payload[f'{prefix}/{name}/{stat_name}'] = value

        if payload:
            wandb.log(payload, step=step)

In [8]:
def clear_memory():
    """Clear unused memory caches."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        # MPS (Apple Silicon) doesn't need explicit cleanup
        pass
    
    import gc
    gc.collect()

In [9]:
def create_data_loaders(config):
    """Create the training and validation data loaders and print a distribution check.

    Args:
        config: Dict providing ``train_dir``, ``val_dir``, ``batch_size`` and
            ``num_workers``.

    Returns:
        Tuple ``(train_loader, val_loader)``. The training loader draws its
        batches from an ``AdvancedBatchSampler``; the validation loader
        iterates in order without shuffling.
    """
    # Data augmentation and preprocessing
    train_transform = AdvancedXRayTransforms.get_train_transform()
    val_transform = AdvancedXRayTransforms.get_val_transform()

    # Create datasets
    # NOTE(review): mixup_alpha is fixed at 0.2 here and does not read
    # config['mixup_alpha'] — confirm whether that is intended.
    train_dataset = XRayDataset(
        root_dir=config['train_dir'],
        transform=train_transform,
        phase='train',
        mixup_alpha=0.2
    )

    val_dataset = XRayDataset(
        root_dir=config['val_dir'],
        transform=val_transform,
        phase='val'
    )

    # Use the configured batch size so training and validation batches agree
    # (previously hard-coded to 32 regardless of the config).
    train_sampler = AdvancedBatchSampler(
        dataset=train_dataset,
        batch_size=config['batch_size'],
        balance_strategy='oversample',  # or 'weights' or 'stratified'
        oversample_multiplier=1.2,
        dynamic_balance=True,
        min_class_samples=2
    )

    # Only request pinned host memory when CUDA is present; without it
    # DataLoader cannot use pinning and may just emit a warning.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=config['num_workers'],
        pin_memory=pin_memory,
        collate_fn=custom_collate
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=pin_memory,
        collate_fn=custom_collate
    )

    # Verify batch distribution
    print("\nVerifying training batch distribution...")
    verify_batch_distribution(train_loader, num_batches=5)

    return train_loader, val_loader

def verify_batch_distribution(loader, num_batches=5, num_classes=5):
    """Print the class distribution of ``targets_a`` in the first few batches.

    Args:
        loader: Iterable yielding ``(inputs, targets_a, targets_b, lams)``
            batches; each element of ``targets_a`` must support ``.item()``.
        num_batches: Maximum number of batches to inspect.
        num_classes: Class indices ``0..num_classes-1`` are always listed,
            even with a count of zero. Labels outside that range are counted
            and listed as well instead of raising ``KeyError``.
    """
    class_counts = {i: 0 for i in range(num_classes)}
    for i, (_, targets_a, _, _) in enumerate(loader):
        if i >= num_batches:
            break
        for target in targets_a:
            label = target.item()
            class_counts[label] = class_counts.get(label, 0) + 1

    total = sum(class_counts.values())
    print("\nBatch distribution check:")
    if total == 0:
        # An empty loader would otherwise divide by zero below.
        print("No samples were drawn from the loader.")
        return

    for cls in sorted(class_counts):
        count = class_counts[cls]
        print(f"Class {cls}: {count} samples ({count/total*100:.1f}%)")

In [10]:
class IntegratedTrainer:
    """Train/validate loop with gradient accumulation, W&B logging, checkpoints and early stopping.

    Config keys read by this class: ``class_distribution``,
    ``accumulation_steps``, ``warmup_epochs``, ``gradient_clip_val``,
    ``clear_cache_freq``, ``num_epochs``, ``early_stopping_patience`` and
    ``output_dir``.

    Scheduler handling: a ``ReduceLROnPlateau`` scheduler is stepped once per
    epoch with the validation accuracy; any other scheduler is stepped once
    per optimizer step and never at epoch level.
    """

    def __init__(self, model, train_loader, val_loader, criterion, optimizer, scheduler, device, config):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.config = config
        self.global_step = 0
        self.best_val_acc = 0
        # Defined up front so validate() also works before train() was called.
        self.current_epoch = 0

        # Initialize monitoring system
        self.training_monitor = EnhancedTrainingMonitor(
            model=model,
            class_distribution=config['class_distribution'],
            log_dir='training_logs'
        )

        # Enable gradient monitoring in wandb
        wandb.watch(model, log="gradients", log_freq=200)

    def _uses_plateau_scheduler(self):
        """Return True when the scheduler is a ReduceLROnPlateau (needs a metric to step)."""
        return isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

    def train_one_epoch(self, epoch):
        """Run one pass over the training loader.

        Args:
            epoch: Zero-based epoch index. Before ``config['warmup_epochs']``
                the loss uses ``targets_a`` only; from then on it is the
                ``lams``-weighted mix of the losses against both target sets.

        Returns:
            Tuple ``(mean_loss, mean_accuracy)`` averaged over the loader length.
        """
        self.model.train()
        running_loss = 0.0
        running_acc = 0.0
        accumulation_steps = self.config['accumulation_steps']
        use_mixed_targets = epoch >= self.config['warmup_epochs']

        if epoch == 0:
            print(f"\nVerifying batch distribution for epoch {epoch}...")
            verify_batch_distribution(self.train_loader, num_batches=5)

        pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader))

        for batch_idx, (inputs, targets_a, targets_b, lams) in pbar:
            try:
                # Move to device
                inputs = inputs.to(self.device, non_blocking=True)
                targets_a = targets_a.to(self.device, non_blocking=True)
                targets_b = targets_b.to(self.device, non_blocking=True)
                lams = lams.to(self.device, non_blocking=True)

                # The optimizer only steps on every accumulation_steps-th batch.
                is_accumulation_step = (batch_idx + 1) % accumulation_steps != 0

                # Forward pass
                outputs = self.model(inputs)
                if use_mixed_targets:
                    loss_a = self.criterion(outputs, targets_a)
                    loss_b = self.criterion(outputs, targets_b)
                    loss = (lams * loss_a + (1 - lams) * loss_b).mean()
                else:
                    loss = self.criterion(outputs, targets_a)

                # Scale loss so the accumulated gradient is an average.
                loss = loss / accumulation_steps

                # Backward pass
                loss.backward()

                if not is_accumulation_step:
                    # Gradient clipping
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config['gradient_clip_val']
                    )

                    self.optimizer.step()
                    self.optimizer.zero_grad(set_to_none=True)

                    # ReduceLROnPlateau needs a metric and is stepped in train().
                    if self.scheduler is not None and not self._uses_plateau_scheduler():
                        self.scheduler.step()

                # Calculate accuracy against the dominant target of each sample.
                with torch.no_grad():
                    _, preds = torch.max(outputs, 1)
                    if use_mixed_targets:
                        targets = torch.where(lams.view(-1) > 0.5, targets_a, targets_b)
                    else:
                        targets = targets_a
                    accuracy = (preds == targets).float().mean()

                # Update metrics (undo the accumulation scaling for reporting).
                batch_loss = loss.item() * accumulation_steps
                running_loss += batch_loss
                running_acc += accuracy.item()

                # Log batch metrics
                if batch_idx % 50 == 0:
                    metrics = {
                        'batch/loss': batch_loss,
                        'batch/accuracy': accuracy.item(),
                        'batch/learning_rate': self.optimizer.param_groups[0]['lr'],
                        'batch/memory_used': psutil.Process().memory_info().rss / 1024 / 1024 / 1024  # GB
                    }

                    wandb.log(metrics, step=self.global_step)

                    self.training_monitor.log_batch_metrics(
                        batch_idx=batch_idx,
                        loss=batch_loss,
                        accuracy=accuracy.item(),
                        optimizer=self.optimizer,
                        outputs=outputs,
                        targets=targets,
                        global_step=self.global_step
                    )

                    # global_step is a logging counter: it advances once per
                    # logged batch and once per epoch, not once per batch.
                    self.global_step += 1

                # Clear memory periodically
                if batch_idx % self.config['clear_cache_freq'] == 0:
                    clear_memory()

                # Update progress bar
                pbar.set_description(
                    f'Epoch {epoch} | Loss: {running_loss/(batch_idx+1):.4f} | '
                    f'Acc: {running_acc/(batch_idx+1):.4f}'
                )

            except Exception as e:
                # Best effort: report the failing batch and keep training.
                print(f"Error in batch {batch_idx}: {str(e)}")
                clear_memory()
                continue

        num_batches = max(1, len(self.train_loader))
        return running_loss / num_batches, running_acc / num_batches

    def validate(self):
        """Evaluate on the validation loader, log the metrics and return them.

        Returns:
            Dict with ``val_loss``, ``accuracy``, ``specificity``,
            ``sensitivity`` and ``confusion_matrix``.
        """
        self.model.eval()
        val_loss = 0.0
        predictions = []
        targets = []

        with torch.no_grad():
            for inputs, targets_a, _, _ in tqdm(self.val_loader, desc='Validating'):
                inputs = inputs.to(self.device, non_blocking=True)
                targets_a = targets_a.to(self.device, non_blocking=True)

                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets_a)

                val_loss += loss.item()
                _, preds = torch.max(outputs, 1)

                predictions.extend(preds.cpu().numpy())
                targets.extend(targets_a.cpu().numpy())

        predictions = np.array(predictions)
        targets = np.array(targets)

        metrics = self._compute_validation_metrics(val_loss, predictions, targets)

        # Log validation metrics with current global step
        self.training_monitor.log_validation_metrics(metrics, self.current_epoch, self.global_step)

        return metrics

    def _build_confusion_matrix(self, y_true, y_pred):
        """Confusion matrix with one row/column per class in ``config['class_distribution']``.

        Passing explicit labels keeps the shape fixed even when a class is
        absent from both the targets and the predictions.
        """
        labels = list(range(len(self.config['class_distribution'])))
        return confusion_matrix(y_true, y_pred, labels=labels)

    @staticmethod
    def _mean_of_defined_ratios(numerator, denominator):
        """Mean of ``numerator / denominator`` over entries whose denominator is positive.

        Returns 0.0 when no entry has a positive denominator, so classes
        without support never turn the average into NaN.
        """
        numerator = np.asarray(numerator, dtype=float)
        denominator = np.asarray(denominator, dtype=float)
        defined = denominator > 0
        if not defined.any():
            return 0.0
        return float(np.mean(numerator[defined] / denominator[defined]))

    def _macro_specificity(self, cm):
        """Macro-averaged TN / (TN + FP) computed from a confusion matrix."""
        tp = np.diag(cm)
        fp = cm.sum(axis=0) - tp
        fn = cm.sum(axis=1) - tp
        tn = cm.sum() - (tp + fp + fn)
        return self._mean_of_defined_ratios(tn, tn + fp)

    def _macro_sensitivity(self, cm):
        """Macro-averaged TP / (TP + FN) computed from a confusion matrix."""
        tp = np.diag(cm)
        fn = cm.sum(axis=1) - tp
        return self._mean_of_defined_ratios(tp, tp + fn)

    def _compute_validation_metrics(self, val_loss, predictions, targets):
        """Compute comprehensive validation metrics

        Args:
            val_loss: Sum of per-batch validation losses.
            predictions: Array of predicted class indices.
            targets: Array of true class indices.
        """
        # The confusion matrix is built once and shared by all derived metrics.
        cm = self._build_confusion_matrix(targets, predictions)

        metrics = {
            'val_loss': val_loss / max(1, len(self.val_loader)),
            'accuracy': np.mean(predictions == targets),
            'specificity': self._macro_specificity(cm),
            'sensitivity': self._macro_sensitivity(cm),
            'confusion_matrix': cm
        }

        return metrics

    def _plot_confusion_matrix(self, cm, epoch):
        """Render the confusion matrix as a heatmap and log it to W&B as an image."""
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title(f'Confusion Matrix - Epoch {epoch}')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        wandb.log({"confusion_matrix": wandb.Image(plt)}, step=self.global_step)
        plt.close()

    def _calculate_specificity(self, y_true, y_pred):
        """Macro-averaged specificity; classes with TN + FP == 0 are left out of the mean."""
        return self._macro_specificity(self._build_confusion_matrix(y_true, y_pred))

    def _calculate_sensitivity(self, y_true, y_pred):
        """Macro-averaged sensitivity; classes with no true samples are left out of the mean."""
        return self._macro_sensitivity(self._build_confusion_matrix(y_true, y_pred))

    def train(self):
        """Run the full training loop with validation, checkpointing and early stopping.

        Returns:
            The best validation accuracy observed.
        """
        # Add verification before training starts
        verify_data_loading(self.train_loader)

        patience_counter = 0
        for epoch in range(self.config['num_epochs']):
            self.current_epoch = epoch
            print(f'\nEpoch {epoch+1}/{self.config["num_epochs"]}')

            # Train
            train_loss, train_acc = self.train_one_epoch(epoch)

            # Validate
            val_metrics = self.validate()

            # Log epoch metrics
            wandb.log({
                'epoch': epoch,
                'train/loss': train_loss,
                'train/accuracy': train_acc,
                'val/loss': val_metrics['val_loss'],
                'val/accuracy': val_metrics['accuracy'],
                'val/specificity': val_metrics['specificity'],
                'val/sensitivity': val_metrics['sensitivity'],
                'learning_rate': self.optimizer.param_groups[0]['lr']
            }, step=self.global_step)

            # Plot confusion matrix
            self._plot_confusion_matrix(val_metrics['confusion_matrix'], epoch)

            # Only the metric-driven scheduler steps per epoch. Every other
            # scheduler already advanced once per optimizer step inside
            # train_one_epoch; stepping it here again would double-count.
            if self._uses_plateau_scheduler():
                self.scheduler.step(val_metrics['accuracy'])

            # Early stopping check
            if val_metrics['accuracy'] > self.best_val_acc:
                self.best_val_acc = val_metrics['accuracy']
                patience_counter = 0
                self._save_checkpoint(epoch, val_metrics)
                print(f'\nBest val accuracy updated to: {val_metrics["accuracy"]}')
            else:
                patience_counter += 1

            if patience_counter >= self.config['early_stopping_patience']:
                print(f'\nEarly stopping triggered after {epoch + 1} epochs')
                break

            self.global_step += 1

        return self.best_val_acc

    def _save_checkpoint(self, epoch, val_metrics):
        """Save model checkpoint with monitoring stats"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'best_val_acc': val_metrics['accuracy'],
            'config': self.config,
            'monitoring_stats': self.training_monitor.stats,
            'global_step': self.global_step
        }
        torch.save(
            checkpoint,
            os.path.join(self.config["output_dir"], 'best_model.pth')
        )
        wandb.log({"best_val_accuracy": val_metrics['accuracy']}, step=self.global_step)

In [11]:
def train_xray_model(config):
    """
    Enhanced training function with improved initialization, monitoring, and error handling

    The passed-in ``config`` is validated first and then shallow-copied; only
    the copy has its ``output_dir`` replaced by a timestamped sub-directory.
    The caller's dict therefore stays unchanged, and calling this function
    again with the same dict does not nest timestamp directories.

    Args:
        config: Training configuration dict (see ``_validate_config`` for the
            keys checked up front). ``seed`` defaults to 11 and
            ``verify_data`` to True when absent.

    Returns:
        Tuple ``(best_acc, output_dir)``: the best validation accuracy and
        the timestamped directory that holds the saved config and summary.

    Raises:
        ValueError: If the config fails validation.
        Exception: Any error raised during data loading, model creation or
            training is printed with its traceback and re-raised.
    """
    # Fail fast on an incomplete config before touching device, RNG or disk.
    _validate_config(config)

    # Work on a copy so the caller's 'output_dir' is not overwritten.
    config = dict(config)

    # Set device with better handling
    device = _setup_device()

    # Set deterministic behavior for reproducibility
    _set_deterministic(seed=config.get('seed', 11))

    # Create output directory with timestamp
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = os.path.join(config['output_dir'], timestamp)
    os.makedirs(output_dir, exist_ok=True)
    config['output_dir'] = output_dir  # The trainer saves checkpoints under this path

    # Save config for reproducibility
    with open(os.path.join(output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=4)

    try:
        # Initialize data loaders with error handling
        train_loader, val_loader = create_data_loaders(config)

        # Create model, optimizer, and criterion
        model, optimizer, criterion = create_model_and_optimizer(config)
        model = model.to(device)
        criterion = criterion.to(device)

        # Enhanced scheduler initialization
        scheduler = _create_scheduler(optimizer, train_loader, config)

        # Initialize trainer with memory profiling
        trainer = IntegratedTrainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            config=config
        )

        # Optional data verification
        if config.get('verify_data', True):
            print("\nVerifying data loading...")
            verify_data_loading(train_loader)

        # Log hardware info and memory usage
        _log_system_info()

        # Train model with error handling
        best_acc = trainer.train()

        # Save training summary
        _save_training_summary(config, best_acc, output_dir)

        print(f"\nTraining completed successfully.")
        print(f"Best validation accuracy: {best_acc:.4f}")
        print(f"Model and logs saved to: {output_dir}")

        return best_acc, output_dir

    except Exception as e:
        print(f"\nError during training: {str(e)}")
        traceback.print_exc()
        raise

def _setup_device():
    """Set up and return the appropriate device with better error handling"""
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using MPS device")
        # Verify MPS is working properly
        try:
            torch.zeros(1).to(device)
        except Exception as e:
            print(f"MPS initialization failed: {e}")
            print("Falling back to CPU")
            device = torch.device("cpu")
    else:
        device = torch.device("cpu")
        print("Using CPU device")
    
    return device

def _set_deterministic(seed):
    """Set seeds and deterministic behavior"""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
    
    # Enable deterministic behavior
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def _validate_config(config):
    """Validate essential config parameters"""
    required_keys = [
        'num_epochs', 'output_dir', 'class_distribution',
        'warmup_epochs', 'scheduler_cycles'
    ]
    
    for key in required_keys:
        if key not in config:
            raise ValueError(f"Missing required config parameter: {key}")
            
    if config['warmup_epochs'] >= config['num_epochs']:
        raise ValueError("warmup_epochs should be less than num_epochs")

def _create_scheduler(optimizer, train_loader, config):
    """Create learning rate scheduler with proper initialization"""
    num_training_steps = len(train_loader) * config['num_epochs']
    num_warmup_steps = len(train_loader) * config['warmup_epochs']
    
    return get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=config['scheduler_cycles']
    )

def _log_system_info():
    """Print Python/PyTorch versions, CPU count, RAM totals and MPS availability."""
    bytes_per_gib = 1024**3

    print("\nSystem Information:")
    print(f"Python version: {sys.version}")
    print(f"PyTorch version: {torch.__version__}")
    print(f"Number of CPUs: {os.cpu_count()}")

    # System-wide memory figures, converted from bytes to GiB.
    ram = psutil.virtual_memory()
    total_gb = ram.total / bytes_per_gib
    available_gb = ram.available / bytes_per_gib
    print(f"Total RAM: {total_gb:.1f} GB")
    print(f"Available RAM: {available_gb:.1f} GB")

    # Accelerator note, printed only when the MPS backend is present.
    mps_present = torch.backends.mps.is_available()
    if mps_present:
        print("MPS device available")

def _save_training_summary(config, best_acc, output_dir):
    """Write a JSON summary of the run to ``<output_dir>/training_summary.json``.

    The summary holds the best accuracy, a completion timestamp
    (``YYYY-MM-DD HH:MM:SS``, local time) and the config that was used.

    Args:
        config: Training configuration mapping, embedded under ``'config'``.
            JSON object keys are strings, so non-string keys such as integer
            class indices are written as strings. Values JSON cannot encode
            are stored as their ``str()`` representation.
        best_acc: Best accuracy reached; coerced with ``float()``.
        output_dir: Directory receiving the file. Created if it is missing.
    """
    summary = {
        'best_accuracy': float(best_acc),
        'training_completed': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'config': config
    }

    # Serialize first: if encoding fails, no half-written file is left behind.
    # default=str keeps one exotic config value from aborting the summary
    # at the very end of training.
    payload = json.dumps(summary, indent=4, default=str)

    # The summary may be the first thing written to this directory.
    os.makedirs(output_dir, exist_ok=True)
    summary_path = os.path.join(output_dir, 'training_summary.json')
    with open(summary_path, 'w') as f:
        f.write(payload)

In [12]:
# Run configuration. In this cell it is logged to wandb as the run config and
# passed to train_xray_model(). _validate_config requires the keys
# 'num_epochs', 'output_dir', 'class_distribution', 'warmup_epochs' and
# 'scheduler_cycles'.
config = {
    # Model parameters
    'num_classes': 5,
    'batch_size': 32,  # Reduced for M1 memory constraints
    'num_epochs': 100,  # Upper bound on epochs; see 'early_stopping_patience' below
    
    # Optimization parameters
    'base_learning_rate': 1e-4,  # Reduced for more stable training
    'weight_decay': 2e-5,  # Reduced to prevent over-regularization
    'gradient_clip_val': 0.5,  # More aggressive clipping for stability
    # NOTE(review): presumably gradient accumulation (effective batch size
    # batch_size * accumulation_steps) - confirm in the training loop.
    'accumulation_steps': 4,
    
    # Loss function parameters
    'focal_loss_params': {
        'gamma': 0.5,  # Reduced from 2.0 for more stable training
        'smoothing': 0.01  # Reduced smoothing for medical images
    },

    # Memory optimization for M1
    # NOTE(review): 'empty_cache_freq' here and the top-level
    # 'clear_cache_freq' further down look like two knobs for the same thing;
    # check which one the training code actually reads.
    'memory_management': {
        'gradient_checkpointing': True,  # Enable to save memory
        'empty_cache_freq': 3
    },
    
    # Learning rate schedule
    # 'scheduler_cycles' is passed as num_cycles to
    # get_cosine_schedule_with_warmup. That function's multiplier is
    # 0.5 * (1 + cos(pi * num_cycles * progress)), so a value of 1 makes the
    # LR fall from its peak to 0 over the post-warmup steps (half a cosine
    # period, not a full oscillation).
    'scheduler_cycles': 1,
    'warmup_epochs': 3,  # Linear warmup length in epochs; must be < num_epochs
    'warmup_mixup_epochs': 5,
    
    # Mixup parameters
    'mixup_alpha': 0.1,  # Reduced for medical domain
    
    # Early stopping
    'early_stopping_patience': 15,  # Increased to allow convergence
    
    # Data parameters
    # NOTE(review): train_dir/val_dir are absolute paths on one machine; the
    # notebook will not run elsewhere without editing them.
    'train_dir': '/Users/Viku/Datasets/Medical/Knee/train',
    'val_dir': '/Users/Viku/Datasets/Medical/Knee/val',
    'output_dir': 'training_output_v2',
    # Class index -> frequency for the 5 classes.
    'class_distribution': {0: 2286, 1: 1046, 2: 1516, 3: 757, 4: 173},
    
    # System parameters for M1
    'num_workers': 2,  # Reduced for M1
    'pin_memory': False,
    'prefetch_factor': 2,
    'clear_cache_freq': 2,
    
    # Monitoring
    'verify_data': True,
    'monitor_gradients': True,
    'save_checkpoint_freq': 3,
    'gradient_logging_freq': 200,
    
    # M1 specific
    'mps_memory_limit': 12 * 1024 * 1024 * 1024,  # 12 * 1024**3 bytes (12 GiB)
    'compile_model': False  # Disabled for better stability
}


if __name__ == "__main__":
    # Start the wandb run; the full config dict is logged with it.
    wandb.init(
        project="xray-classification-v2",
        config=config,
        name="improved-training-run-v5"
    )

    # Exit status reported to wandb when the run is closed. It stays 0 unless
    # training raises, so a crashed run is not recorded as a successful one.
    run_exit_code = 0
    try:
        train_xray_model(config)
    except KeyboardInterrupt:
        # A manual stop is treated as a deliberate, clean shutdown.
        print("\nTraining interrupted by user")
    except Exception as e:
        run_exit_code = 1
        print(f"\nError during training: {e}")
        # Bare raise propagates the active exception with its original traceback.
        raise
    finally:
        # Always close the run, passing the outcome along.
        wandb.finish(exit_code=run_exit_code)

[34m[1mwandb[0m: Currently logged in as: [33mvignesh-rox03[0m ([33mvignesh-rox03-vellore-institute-of-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Using MPS device

Error during training: 1 validation error for InitSchema
size
  Input should be a valid tuple [type=tuple_type, input_value=224, input_type=int]
    For further information visit https://errors.pydantic.dev/2.10/v/tuple_type

Error during training: 1 validation error for InitSchema
size
  Input should be a valid tuple [type=tuple_type, input_value=224, input_type=int]
    For further information visit https://errors.pydantic.dev/2.10/v/tuple_type


Traceback (most recent call last):
  File "/var/folders/_2/xx5z8xdj6j98wh4vt2b59jz80000gp/T/ipykernel_59648/2248032029.py", line 26, in train_xray_model
    train_loader, val_loader = create_data_loaders(config)
  File "/var/folders/_2/xx5z8xdj6j98wh4vt2b59jz80000gp/T/ipykernel_59648/4202239987.py", line 5, in create_data_loaders
    train_transform = AdvancedXRayTransforms.get_train_transform()
  File "/Users/Viku/GitHub/Cloned/knee_osteoarthritis_detection/Code/Vignesh/Classification_V2/dataset.py", line 22, in get_train_transform
    A.RandomResizedCrop(
  File "/Users/Viku/GitHub/Cloned/knee_osteoarthritis_detection/.venv/lib/python3.10/site-packages/albumentations/core/validation.py", line 35, in custom_init
    config = dct["InitSchema"](**full_kwargs)
  File "/Users/Viku/GitHub/Cloned/knee_osteoarthritis_detection/.venv/lib/python3.10/site-packages/pydantic/main.py", line 214, in __init__
    validated_self = self.__pydantic_validator__.validate_python(data, self_instance=self)


ValidationError: 1 validation error for InitSchema
size
  Input should be a valid tuple [type=tuple_type, input_value=224, input_type=int]
    For further information visit https://errors.pydantic.dev/2.10/v/tuple_type

In [14]:
# import shutil
# import os
# from datetime import datetime

# # Create backup with timestamp
# timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# current_model_path = "training_output/best_model.pth"
# backup_path = f"training_output/best_model_acc60_{timestamp}.pth"

# if os.path.exists(current_model_path):
#     shutil.copy(current_model_path, backup_path)
#     print(f"Backed up current model to: {backup_path}")

In [15]:
# model, optimizer, criterion = create_model_and_optimizer(config)
# model = model.to(device)

# pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

# print( pytorch_total_params )

In [16]:
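# NOTE: disabled legacy cell, kept for reference only. Before re-enabling, check:
#   - config has no 'hidden_size' or 'gradient_accumulation_steps' key
#     (the current config uses 'accumulation_steps').
#   - Every fold builds its datasets from the same train_dir/val_dir, so the
#     loop repeats one fixed split n_splits times instead of cross-validating.
#   - The class_weights tensor computed per fold is never used; the loss is
#     built from optimizer.class_weights instead.
#   - classification_report is not imported in the notebook's import cell.
# def train_with_kfold(config, device, num_workers=2):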
#     print("Initializing k-fold training...")
#     output_dir = 'output'
#     os.makedirs(output_dir, exist_ok=True)
    
#     n_splits = 5
#     fold_results = []
#     fold_histories = []
    
#     for fold in range(n_splits):
#         print(f"\nStarting Fold {fold + 1}/{n_splits}")
#         fold_dir = f'{output_dir}/fold_{fold+1}'
#         os.makedirs(fold_dir, exist_ok=True)

#         # Calculate class weights
#         class_counts = torch.tensor([config['class_distribution'][i] for i in range(config['num_classes'])])
#         class_weights = 1.0 / class_counts
#         class_weights = class_weights / class_weights.sum() * config['num_classes']
#         class_weights = class_weights.to(device)
        
#         # Create datasets with full transforms
#         train_transform, val_transform = get_improved_transforms()
#         train_dataset = XRayDataset(config['train_dir'], transform=train_transform, mixup_alpha=0.2)
#         val_dataset = XRayDataset(config['val_dir'], transform=val_transform)
        
#         # Use optimized sampler for training
#         train_sampler = OptimizedBatchSampler(train_dataset, config['batch_size'])
        
#         # Create data loaders
#         train_loader = DataLoader(
#             train_dataset,
#             batch_sampler=train_sampler,
#             num_workers=num_workers,
#             pin_memory=True
#         )
        
#         val_loader = DataLoader(
#             val_dataset,
#             batch_size=config['batch_size'],
#             shuffle=False,
#             num_workers=num_workers,
#             pin_memory=True
#         )
        
#         # Initialize model
#         model = IntegratedXRayClassifier(
#             num_classes=config['num_classes'],
#             hidden_size=config['hidden_size']
#         ).to(device)
        
#         # Initialize the balanced optimizer and loss
#         optimizer = BalancedOptimizer(
#             model, 
#             config['base_learning_rate'],
#             config['weight_decay'],
#             config['class_distribution']
#         )

#         criterion = ImprovedCombinedLoss(
#             class_weights=optimizer.class_weights.to(device)
# )
        
#         # Scheduler setup
#         num_warmup_steps = 100
#         num_training_steps = len(train_loader) * config['num_epochs']
#         scheduler = get_linear_schedule_with_warmup(
#             optimizer,
#             num_warmup_steps=num_warmup_steps,
#             num_training_steps=num_training_steps
#         )
        
#         # Training loop variables
#         best_val_acc = 0
#         patience_counter = 0
#         best_val_preds = None
#         best_val_targets = None
#         best_val_probs = None
        
#         fold_history = {
#             'train_loss': [], 'train_acc': [],
#             'val_loss': [], 'val_acc': []
#         }
        
#         # Training loop
#         for epoch in range(config['num_epochs']):
#             print(f'\nEpoch {epoch+1}/{config["num_epochs"]}')
            
#             train_loss, train_acc = train_one_epoch(
#                 model, train_loader, criterion, optimizer, scheduler, device,
#                 config['gradient_accumulation_steps']
#             )
            
#             val_loss, val_acc, val_preds, val_targets, val_probs = validate(
#                 model, val_loader, criterion, device
#             )
            
#             # Update history
#             fold_history['train_loss'].append(train_loss)
#             fold_history['train_acc'].append(train_acc)
#             fold_history['val_loss'].append(val_loss)
#             fold_history['val_acc'].append(val_acc)
            
#             print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}')
#             print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}')
            
#             # Save plots
#             save_plots(epoch, fold, 
#                       {'loss': fold_history['train_loss'], 'acc': fold_history['train_acc']},
#                       {'loss': fold_history['val_loss'], 'acc': fold_history['val_acc']},
#                       fold_dir)
            
#             # Check for improvement
#             if val_acc > best_val_acc:
#                 best_val_acc = val_acc
#                 best_val_preds = val_preds
#                 best_val_targets = val_targets
#                 best_val_probs = val_probs
                
#                 # Save best model
#                 torch.save({
#                     'fold': fold,
#                     'epoch': epoch,
#                     'model_state_dict': model.state_dict(),
#                     'optimizer_state_dict': optimizer.state_dict(),
#                     'scheduler_state_dict': scheduler.state_dict(),
#                     'best_val_acc': best_val_acc,
#                     'config': config
#                 }, f'{fold_dir}/best_model.pth')
                
#                 patience_counter = 0
#             else:
#                 patience_counter += 1
            
#             if patience_counter >= config['early_stopping_patience']:
#                 print('Early stopping triggered')
#                 break
        
#         fold_histories.append(fold_history)
#         fold_results.append(best_val_acc)
        
#         # Save final metrics
#         with open(f'{fold_dir}/classification_report.txt', 'w') as f:
#             f.write(classification_report(best_val_targets, best_val_preds))
    
#     # Print final results
#     print("\nCross-validation results:")
#     for fold, acc in enumerate(fold_results):
#         print(f"Fold {fold + 1}: {acc:.4f}")
#     print(f"Mean accuracy: {np.mean(fold_results):.4f} ± {np.std(fold_results):.4f}")

# # Start k-fold training
# print("\nStarting k-fold cross validation training...")
# train_with_kfold(
#     config=config,
#     device=device
# )

In [17]:
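# NOTE: disabled legacy cell, kept for reference only. classification_report is
# not imported in the notebook's import cell, and the call at the bottom relies
# on module-level `model` and `val_loader` plus a 'best_model.pth' in the
# working directory - confirm those (and plot_roc_curves) exist before
# re-enabling.
# def evaluate_model(model, test_loader, device, num_classes):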
#     model.eval()
#     all_preds = []
#     all_probs = []
#     all_targets = []
#     class_correct = [0] * num_classes
#     class_total = [0] * num_classes
    
#     with torch.no_grad():
#         for inputs, targets in tqdm(test_loader, desc='Evaluating'):
#             inputs = inputs.to(device)
#             outputs = model(inputs)
#             probs = torch.softmax(outputs, dim=1)
#             _, preds = torch.max(outputs, 1)
            
#             # Collect predictions and probabilities
#             all_preds.extend(preds.cpu().numpy())
#             all_probs.extend(probs.cpu().numpy())
#             all_targets.extend(targets.numpy())
            
#             # Calculate class-wise accuracy
#             correct = (preds.cpu() == targets)
#             for i in range(len(targets)):
#                 label = targets[i]
#                 class_correct[label] += correct[i].item()
#                 class_total[label] += 1
    
#     all_probs = np.array(all_probs)
#     all_targets = np.array(all_targets)
    
#     # Print class-wise accuracy
#     print("\nClass-wise Accuracy:")
#     for i in range(num_classes):
#         acc = class_correct[i] / class_total[i] if class_total[i] > 0 else 0
#         print(f'Class {i}: {acc:.4f} ({class_correct[i]}/{class_total[i]})')
    
#     # Plot confusion matrix
#     cm = confusion_matrix(all_targets, all_preds)
#     plt.figure(figsize=(10, 8))
#     sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
#     plt.title('Confusion Matrix')
#     plt.xlabel('Predicted')
#     plt.ylabel('True')
#     plt.show()
    
#     # Plot ROC curves
#     plot_roc_curves(all_targets, all_probs, num_classes)
    
#     # Print classification report
#     print("\nClassification Report:")
#     print(classification_report(all_targets, all_preds))
    
#     return np.mean([class_correct[i]/class_total[i] for i in range(num_classes) if class_total[i] > 0])

# # Load best model
# checkpoint = torch.load('best_model.pth')
# model.load_state_dict(checkpoint['model_state_dict'])
# evaluate_model(model, val_loader, device, config['num_classes'])