# 1. 손실 함수 및 메트릭 정의

In [None]:
# 📈 개선된 손실 함수 및 메트릭 시스템
class AdaptiveBraTSLoss(nn.Module):
    """Composite segmentation loss: weighted CE + class-weighted Dice + focal.

    The three terms are combined as
    ``ce_weight * CE + dice_weight * Dice + focal_weight * Focal``.

    Args:
        ce_weight: Multiplier for the class-weighted cross-entropy term.
        dice_weight: Multiplier for the multi-class soft Dice term.
        focal_weight: Multiplier for the focal term.

    Inputs to ``forward``:
        predictions: Raw logits with the class axis at dim 1
            (``(B, C, *spatial)``).
        targets: Integer class indices shaped ``(B, *spatial)``.
    """

    def __init__(self, ce_weight=1.0, dice_weight=1.0, focal_weight=0.5):
        super().__init__()
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight

        # Per-class weights (index 0 is down-weighted relative to the rest).
        # Stored as a non-persistent buffer so the module no longer depends on
        # a module-level `device` global and the state_dict layout is unchanged.
        self.register_buffer(
            "class_weights",
            torch.tensor([0.1, 2.0, 1.5, 2.5]),
            persistent=False,
        )
        print("📊 적응형 손실 함수 초기화 (CE + Dice + Focal)")

    def forward(self, predictions, targets):
        """Return the scalar combined loss for one batch."""
        # cross_entropy / one_hot both require int64 class indices.
        targets = targets.long()

        # Class-weighted cross-entropy; weights follow the logits' device.
        ce_loss = F.cross_entropy(
            predictions,
            targets,
            weight=self.class_weights.to(
                device=predictions.device, dtype=predictions.dtype
            ),
        )

        dice_loss = self._multiclass_dice_loss(predictions, targets)
        focal_loss = self._focal_loss(predictions, targets)

        return (self.ce_weight * ce_loss +
                self.dice_weight * dice_loss +
                self.focal_weight * focal_loss)

    def _dice_class_weights(self, num_classes, like):
        """Return a length-``num_classes`` weight vector matching ``like``.

        Classes beyond the configured weights fall back to weight 1.0.
        """
        weights = self.class_weights.to(device=like.device, dtype=like.dtype)
        if num_classes <= weights.numel():
            return weights[:num_classes]
        padding = torch.ones(
            num_classes - weights.numel(), device=like.device, dtype=like.dtype
        )
        return torch.cat([weights, padding])

    def _multiclass_dice_loss(self, predictions, targets):
        """Class-weighted soft Dice loss.

        The per-class Dice scores are combined as a weighted average
        (normalised by the weight sum), so the score stays in [0, 1] and the
        returned loss ``1 - score`` cannot become negative.
        """
        smooth = 1e-6
        num_classes = predictions.shape[1]
        probs = F.softmax(predictions, dim=1)

        # One-hot puts the class axis last; move it to dim 1 regardless of
        # how many spatial dimensions the input has.
        one_hot = F.one_hot(targets.long(), num_classes=num_classes)
        one_hot = one_hot.movedim(-1, 1).to(probs.dtype)

        # Reduce over batch and all spatial dims, keeping the class axis.
        reduce_dims = (0,) + tuple(range(2, probs.dim()))
        intersection = (probs * one_hot).sum(dim=reduce_dims)
        denominator = probs.sum(dim=reduce_dims) + one_hot.sum(dim=reduce_dims)
        dice_per_class = (2.0 * intersection + smooth) / (denominator + smooth)

        weights = self._dice_class_weights(num_classes, probs)
        weighted_dice = (dice_per_class * weights).sum() / weights.sum()
        return 1.0 - weighted_dice

    def _focal_loss(self, predictions, targets, alpha=0.25, gamma=2.0):
        """Focal loss: ``alpha * (1 - pt)^gamma * CE`` averaged over elements."""
        ce_loss = F.cross_entropy(predictions, targets.long(), reduction='none')
        # pt is the predicted probability of the true class.
        pt = torch.exp(-ce_loss)
        focal_loss = alpha * (1 - pt) ** gamma * ce_loss
        return focal_loss.mean()

# Instantiate the loss (Dice term weighted 2x relative to cross-entropy)
criterion = AdaptiveBraTSLoss(ce_weight=1.0, dice_weight=2.0, focal_weight=0.5)


# 2. 성능 메트릭 클래스

In [None]:
class ComprehensiveMetrics:
    """Accumulates segmentation metrics across batches.

    Tracks flattened predictions/targets for voxel-wise classification
    metrics, plus per-sample mean Dice and IoU over the foreground classes
    (class ids ``1 .. num_classes - 1``) and optional per-batch timings.
    """

    def __init__(self, num_classes=4):
        self.num_classes = num_classes
        self.reset()
        self.class_names = ['Background', 'NCR/NET', 'ED', 'ET']
        print("📏 포괄적 메트릭 시스템 초기화")

    def reset(self):
        """Clear all accumulated state."""
        self.all_predictions = []
        self.all_targets = []
        self.dice_scores = []
        self.iou_scores = []
        self.processing_times = []

    def update(self, predictions, targets, processing_time=None):
        """Fold one batch into the running metrics.

        Args:
            predictions: Either 5-D logits (class axis at dim 1), which are
                converted to labels via argmax, or already-decoded labels.
            targets: Ground-truth labels with the same shape as the decoded
                predictions.
            processing_time: Optional duration to record for this batch.
        """
        with torch.no_grad():
            if predictions.dim() == 5:  # (B, C, H, W, D)
                probs = F.softmax(predictions, dim=1)
                preds = torch.argmax(probs, dim=1)
            else:
                preds = predictions

            # Move to CPU and store every element for the sklearn-style metrics.
            preds_np = preds.cpu().numpy().flatten()
            targets_np = targets.cpu().numpy().flatten()

            self.all_predictions.extend(preds_np)
            self.all_targets.extend(targets_np)

            # Per-sample Dice and IoU
            batch_dice = self._calculate_batch_dice(preds, targets)
            batch_iou = self._calculate_batch_iou(preds, targets)

            self.dice_scores.extend(batch_dice)
            self.iou_scores.extend(batch_iou)

            # `is not None` so that a legitimate 0.0 timing is still recorded.
            if processing_time is not None:
                self.processing_times.append(processing_time)

    def _foreground_overlaps(self, pred_b, target_b):
        """Yield (intersection, pred_count, target_count) per foreground class."""
        for class_id in range(1, self.num_classes):  # skip background (0)
            pred_mask = pred_b == class_id
            target_mask = target_b == class_id
            yield (
                (pred_mask & target_mask).sum().item(),
                pred_mask.sum().item(),
                target_mask.sum().item(),
            )

    def _calculate_batch_dice(self, preds, targets):
        """Return one mean foreground Dice score per sample in the batch.

        A class absent from both prediction and target scores 1.0.
        """
        batch_dice = []
        for b in range(preds.shape[0]):
            dice_per_class = []
            for intersection, pred_count, target_count in \
                    self._foreground_overlaps(preds[b], targets[b]):
                denominator = pred_count + target_count
                if denominator > 0:
                    dice_per_class.append(2.0 * intersection / denominator)
                else:
                    dice_per_class.append(1.0)
            batch_dice.append(np.mean(dice_per_class))
        return batch_dice

    def _calculate_batch_iou(self, preds, targets):
        """Return one mean foreground IoU score per sample in the batch.

        Uses the same convention as the Dice helper: a class absent from both
        prediction and target scores 1.0.
        """
        batch_iou = []
        for b in range(preds.shape[0]):
            iou_per_class = []
            for intersection, pred_count, target_count in \
                    self._foreground_overlaps(preds[b], targets[b]):
                union = pred_count + target_count - intersection
                if union > 0:
                    iou_per_class.append(intersection / union)
                else:
                    iou_per_class.append(1.0)
            batch_iou.append(np.mean(iou_per_class))
        return batch_iou

    def get_comprehensive_results(self):
        """Return a dict of aggregated metrics.

        Returns ``{'error': ...}`` when nothing has been accumulated yet;
        otherwise accuracy / weighted precision, recall, F1 over all stored
        elements, mean Dice, mean IoU, mean processing time and element count.
        """
        if not self.all_predictions or not self.all_targets:
            return {'error': '계산할 데이터가 없습니다.'}

        # Element-wise classification metrics
        accuracy = accuracy_score(self.all_targets, self.all_predictions)
        precision = precision_score(self.all_targets, self.all_predictions,
                                    average='weighted', zero_division=0)
        recall = recall_score(self.all_targets, self.all_predictions,
                              average='weighted', zero_division=0)
        f1 = f1_score(self.all_targets, self.all_predictions,
                      average='weighted', zero_division=0)

        # Segmentation metrics
        mean_dice = np.mean(self.dice_scores) if self.dice_scores else 0.0
        mean_iou = np.mean(self.iou_scores) if self.iou_scores else 0.0

        # Timing
        avg_processing_time = np.mean(self.processing_times) if self.processing_times else 0.0

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'dice_score': mean_dice,
            'iou_score': mean_iou,
            'processing_time': avg_processing_time,
            'total_samples': len(self.all_predictions)
        }

# Instantiate the module-level metrics tracker
metrics = ComprehensiveMetrics(num_classes=config.num_classes)


# 3. 훈련 시스템

In [None]:
# 🚀 훈련 시스템
class Trainer:
    """Runs the train/validate loop with checkpointing and early stopping.

    Args:
        model: The network to optimise.
        train_loader: Iterable yielding ``(volumes, targets)`` training batches.
        val_loader: Iterable yielding ``(volumes, targets)`` validation batches.
        config: Object providing ``learning_rate``, ``weight_decay``,
            ``device``, ``num_classes``, ``gradient_clip_norm``,
            ``memory_cleanup_interval``, ``model_save_path``, ``in_channels``
            and ``base_features``.
        criterion: Optional loss module. Defaults to ``AdaptiveBraTSLoss()``
            when omitted.
    """

    def __init__(self, model, train_loader, val_loader, config, criterion=None):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config

        # Optimiser
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
            betas=(0.9, 0.999),
            eps=1e-8
        )

        # LR schedule, stepped once per epoch in train()
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=5,
            T_mult=2,
            eta_min=1e-6
        )

        # Loss: honour a caller-supplied criterion, else use the default one
        self.criterion = criterion if criterion is not None else AdaptiveBraTSLoss()

        # Per-epoch history
        self.history = {
            'train_loss': [], 'val_loss': [],
            'train_dice': [], 'val_dice': [],
            'learning_rate': [], 'epoch_time': []
        }

        self.best_dice = 0.0
        self.best_loss = float('inf')
        self.patience_counter = 0
        self.early_stop_patience = 8

        print("🚀 스마트 훈련 시스템 초기화")
        print(f" └── 옵티마이저: AdamW (lr={config.learning_rate})")
        print(" └── 스케줄러: CosineAnnealingWarmRestarts")
        print(f" └── 조기 종료: patience={self.early_stop_patience}")

    def train_epoch(self, epoch):
        """Run one training epoch.

        Returns:
            ``(avg_loss, dice_score, epoch_time_seconds)``. Batches that raise
            or produce a NaN/Inf loss are skipped and excluded from the average.
        """
        self.model.train()
        epoch_start_time = time.time()
        running_loss = 0.0
        train_metrics = ComprehensiveMetrics(self.config.num_classes)
        processed_batches = 0

        pbar = tqdm(self.train_loader, desc=f"훈련 Epoch {epoch}")

        for batch_idx, (volumes, targets) in enumerate(pbar):
            # Best-effort per batch: one failing batch is reported and skipped
            # instead of aborting the whole epoch.
            try:
                volumes = volumes.to(self.config.device, non_blocking=True)
                targets = targets.to(self.config.device, non_blocking=True)

                self.optimizer.zero_grad()

                batch_start_time = time.time()
                outputs = self.model(volumes)
                processing_time = time.time() - batch_start_time

                loss = self.criterion(outputs, targets)

                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"⚠️ 배치 {batch_idx}: NaN/Inf 손실 감지, 스킵")
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    max_norm=self.config.gradient_clip_norm
                )

                self.optimizer.step()

                running_loss += loss.item()
                # Detach so the metrics never hold on to the autograd graph.
                train_metrics.update(outputs.detach(), targets, processing_time)
                processed_batches += 1

                current_lr = self.optimizer.param_groups[0]['lr']
                pbar.set_postfix({
                    'Loss': f'{loss.item():.4f}',
                    'LR': f'{current_lr:.2e}'
                })

                if batch_idx % 3 == 0:
                    self._cleanup_memory()

            except Exception as e:
                print(f"⚠️ 배치 {batch_idx} 훈련 실패: {e}")
                continue

        avg_loss = running_loss / max(processed_batches, 1)
        train_results = train_metrics.get_comprehensive_results()
        epoch_time = time.time() - epoch_start_time

        # The metrics object returns an error dict when no batch succeeded;
        # fall back to 0.0 rather than raising KeyError.
        return avg_loss, train_results.get('dice_score', 0.0), epoch_time

    def validate_epoch(self, epoch):
        """Run one validation epoch without gradient tracking.

        Returns:
            ``(avg_loss, dice_score, full_results_dict)``.
        """
        self.model.eval()
        running_loss = 0.0
        val_metrics = ComprehensiveMetrics(self.config.num_classes)
        processed_batches = 0

        pbar = tqdm(self.val_loader, desc=f"검증 Epoch {epoch}")

        with torch.no_grad():
            for batch_idx, (volumes, targets) in enumerate(pbar):
                try:
                    volumes = volumes.to(self.config.device, non_blocking=True)
                    targets = targets.to(self.config.device, non_blocking=True)

                    batch_start_time = time.time()
                    outputs = self.model(volumes)
                    processing_time = time.time() - batch_start_time

                    loss = self.criterion(outputs, targets)

                    if not (torch.isnan(loss) or torch.isinf(loss)):
                        running_loss += loss.item()
                        val_metrics.update(outputs, targets, processing_time)
                        processed_batches += 1

                    pbar.set_postfix({'Loss': f'{loss.item():.4f}'})

                except Exception as e:
                    # Still skip the batch, but report it like the training
                    # loop does instead of failing silently.
                    print(f"⚠️ 배치 {batch_idx} 검증 실패: {e}")
                    continue

        avg_loss = running_loss / max(processed_batches, 1)
        val_results = val_metrics.get_comprehensive_results()

        return avg_loss, val_results.get('dice_score', 0.0), val_results

    def train(self, num_epochs):
        """Train for up to ``num_epochs`` epochs and return the history dict.

        Saves a checkpoint whenever validation Dice improves and stops early
        after ``early_stop_patience`` consecutive epochs without improvement.
        """
        print(f"🎯 훈련 시작: {num_epochs} 에포크")
        print("=" * 60)

        for epoch in range(1, num_epochs + 1):
            print(f"\n📅 Epoch {epoch}/{num_epochs}")

            train_loss, train_dice, epoch_time = self.train_epoch(epoch)
            val_loss, val_dice, val_results = self.validate_epoch(epoch)

            # Advance the LR schedule once per epoch
            self.scheduler.step()
            current_lr = self.optimizer.param_groups[0]['lr']

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['train_dice'].append(train_dice)
            self.history['val_dice'].append(val_dice)
            self.history['learning_rate'].append(current_lr)
            self.history['epoch_time'].append(epoch_time)

            print(f"훈련 - Loss: {train_loss:.4f}, Dice: {train_dice:.4f}")
            print(f"검증 - Loss: {val_loss:.4f}, Dice: {val_dice:.4f}")
            print(f"학습률: {current_lr:.2e}, 시간: {epoch_time:.1f}초")

            # Best-model tracking drives both checkpointing and early stopping
            is_best = val_dice > self.best_dice
            if is_best:
                self.best_dice = val_dice
                self.patience_counter = 0
                self.save_checkpoint(epoch, is_best=True)
                print(f"🏆 새로운 최고 성능: Dice {val_dice:.4f}")
            else:
                self.patience_counter += 1

            if self.patience_counter >= self.early_stop_patience:
                print(f"⏹️ 조기 종료: {self.early_stop_patience} 에포크 개선 없음")
                break

            # Periodic full cleanup (cache release + gc)
            if epoch % self.config.memory_cleanup_interval == 0:
                self._cleanup_memory(full_cleanup=True)
                print("🧹 전체 메모리 정리 수행")

        print(f"\n🎉 훈련 완료! 최고 Dice: {self.best_dice:.4f}")
        return self.history

    def save_checkpoint(self, epoch, is_best=False):
        """Write model/optimizer/scheduler state and history to disk.

        The file is ``best_model.pth`` when ``is_best`` is true, otherwise
        ``checkpoint_epoch_{epoch}.pth``, inside ``config.model_save_path``.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_dice': self.best_dice,
            'history': self.history,
            'config_dict': {
                'in_channels': self.config.in_channels,
                'num_classes': self.config.num_classes,
                'base_features': self.config.base_features,
            }
        }

        # Create the target directory on demand so the first save cannot fail
        # just because the folder does not exist yet.
        os.makedirs(self.config.model_save_path, exist_ok=True)

        filename = "best_model.pth" if is_best else f"checkpoint_epoch_{epoch}.pth"
        filepath = os.path.join(self.config.model_save_path, filename)
        torch.save(checkpoint, filepath)

        if is_best:
            print(f"💾 최고 성능 모델 저장: {filepath}")

    def _cleanup_memory(self, full_cleanup=False):
        """Release cached accelerator memory; optionally run the Python GC."""
        if self.config.device.type == 'mps':
            torch.mps.empty_cache()
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()

        if full_cleanup:
            gc.collect()

# Create the trainer (the class defined in this file is `Trainer`)
trainer = Trainer(model, train_loader, val_loader, config)


# 4. 훈련 실행 및 결과 시각화

In [None]:
# 🎯 모델 훈련 실행
def monitor_memory():
    """Print current MPS memory usage in MB, or a notice when unavailable.

    Any failure while querying the MPS counters is reported with a message
    instead of being raised.
    """
    # Guard the attribute lookup: not every torch build exposes the MPS backend.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is None or not mps_backend.is_available():
        print("💾 MPS 메모리 모니터링 불가")
        return

    try:
        allocated = torch.mps.current_allocated_memory() / 1024**2
        reserved = torch.mps.driver_allocated_memory() / 1024**2
        print(f"💾 MPS 메모리 - 할당: {allocated:.1f}MB, 예약: {reserved:.1f}MB")
    except Exception:
        # Best-effort diagnostics; unlike a bare `except:` this lets
        # KeyboardInterrupt / SystemExit propagate.
        print("💾 MPS 메모리 정보 조회 실패")

def visualize_training_results(history):
    """Draw a 2x2 summary figure from a training history dict.

    Panels: train/val loss, train/val Dice, learning rate (log y-axis) and
    per-epoch duration. Prints a warning and returns without plotting when
    ``history['train_loss']`` is empty.
    """
    train_losses = history['train_loss']
    if not train_losses:
        print("⚠️ 훈련 히스토리가 비어있습니다.")
        return

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('📊 훈련 결과 분석', fontsize=16, fontweight='bold')

    epoch_axis = range(1, len(train_losses) + 1)

    def _finish(ax, title, ylabel, with_legend):
        """Apply the title, axis labels, optional legend and grid to a panel."""
        ax.set_title(title)
        ax.set_xlabel('에포크')
        ax.set_ylabel(ylabel)
        if with_legend:
            ax.legend()
        ax.grid(True, alpha=0.3)

    # Panel 1: loss curves
    loss_ax = axes[0, 0]
    loss_ax.plot(epoch_axis, train_losses, 'b-', label='훈련 Loss', linewidth=2)
    loss_ax.plot(epoch_axis, history['val_loss'], 'r-', label='검증 Loss', linewidth=2)
    _finish(loss_ax, '손실 함수 변화', 'Loss', True)

    # Panel 2: Dice curves, y-axis pinned to [0, 1]
    dice_ax = axes[0, 1]
    dice_ax.plot(epoch_axis, history['train_dice'], 'g-', label='훈련 Dice', linewidth=2)
    dice_ax.plot(epoch_axis, history['val_dice'], 'orange', label='검증 Dice', linewidth=2)
    _finish(dice_ax, 'Dice 점수 변화', 'Dice Score', True)
    dice_ax.set_ylim(0, 1)

    # Panel 3: learning rate on a log scale
    lr_ax = axes[1, 0]
    lr_ax.semilogy(epoch_axis, history['learning_rate'], 'purple', linewidth=2)
    _finish(lr_ax, '학습률 변화', 'Learning Rate (log scale)', False)

    # Panel 4: wall-clock time per epoch
    time_ax = axes[1, 1]
    time_ax.bar(epoch_axis, history['epoch_time'], color='skyblue', alpha=0.7)
    _finish(time_ax, '에포크별 훈련 시간', '시간 (초)', False)

    plt.tight_layout()
    plt.show()

print("🚀 BraTS 2021 개선 버전 훈련 시작!")
print("=" * 70)

# Memory state before the run
monitor_memory()

try:
    # Smoke-test the model with a single random volume before training
    print("🧪 모델 구조 테스트...")
    # Channel count comes from the config rather than a hard-coded 4.
    test_volume = torch.randn(1, config.in_channels, 64, 64, 64).to(config.device)

    # Run the probe in eval mode so normalisation layers do not fold random
    # noise into their running statistics; restore the previous mode after.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            test_output = model(test_volume)
    finally:
        model.train(was_training)

    print(f" ✅ 테스트 성공: {test_volume.shape} → {test_output.shape}")
    del test_volume, test_output

    if config.device.type == 'mps':
        torch.mps.empty_cache()

    # Actual training run
    print(f"\n🎯 훈련 시작 ({config.epochs} 에포크)")
    history = trainer.train(num_epochs=config.epochs)

    # Plot the collected history
    print("\n📊 훈련 결과 시각화")
    visualize_training_results(history)

    # Final summary
    print("\n🎉 훈련 완료!")
    print(f" 🏆 최고 Dice 점수: {trainer.best_dice:.4f}")

    if history['train_loss']:
        print(f" 📉 최종 훈련 Loss: {history['train_loss'][-1]:.4f}")
        print(f" 📉 최종 검증 Loss: {history['val_loss'][-1]:.4f}")
        print(f" 📈 최종 검증 Dice: {history['val_dice'][-1]:.4f}")
        print(f" ⏱️ 총 훈련 시간: {sum(history['epoch_time']):.1f}초")
        print(f" ⚡ 평균 에포크 시간: {np.mean(history['epoch_time']):.1f}초")

except Exception as e:
    # Top-level boundary: report the failure with a traceback, then fall
    # through to cleanup.
    print(f"❌ 훈련 중 오류 발생: {e}")
    import traceback
    traceback.print_exc()

finally:
    # Always release cached memory, whether or not training succeeded
    trainer._cleanup_memory(full_cleanup=True)
    monitor_memory()
    print("🧹 메모리 정리 완료")

print(f"\n🏁 모든 작업 완료: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
