# U-Net ve Hibrit Model Mimarileri

Bu notebook, standart U-Net ve topolojik özelliklerle desteklenmiş hibrit segmentasyon modellerini içerir.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import cv2
from tqdm import tqdm
import time
import warnings
warnings.filterwarnings('ignore')

# Önceki notebook'lardan import (gerçek uygulamada ayrı modül olacak)
import sys
sys.path.append('.')

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Kullanılan device: {device}")

# Global matplotlib / seaborn styling for the figures produced below.
plt.style.use('default')
sns.set_palette('husl')

## Standart U-Net Mimarisi

In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolution blocks: (Conv2d -> BatchNorm2d -> ReLU) x 2.

    Both convolutions use ``padding=1`` and no bias (BatchNorm supplies the
    shift), so the spatial size of the input is preserved.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second convolution.
        mid_channels: channels produced by the first convolution; any falsy
            value (``None``, ``0``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        layers = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Attribute name kept as ``double_conv`` so state_dict keys stay stable.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through both convolution blocks."""
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Encoder stage: halve H and W with 2x2 max pooling, then apply DoubleConv.

    Args:
        in_channels: channels entering the stage.
        out_channels: channels leaving the stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Downsample ``x`` by 2 and refine it with two convolution blocks."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder stage: upsample, align to the skip tensor, concatenate, DoubleConv.

    Args:
        in_channels: channels of the concatenated tensor fed to DoubleConv; on
            the transposed-convolution path ``x1`` must also have this many
            channels.
        out_channels: channels produced by the stage.
        bilinear: if True, upsample with parameter-free bilinear interpolation
            and give DoubleConv ``in_channels // 2`` mid channels; otherwise use
            a learned 2x2 transposed convolution that outputs
            ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s H/W, concatenate ``[x2, x1]`` on channels and convolve."""
        upsampled = self.up(x1)
        # Tensors are (N, C, H, W): measure how far the upsampled map falls short of the skip tensor.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        # F.pad order is (left, right, top, bottom); an odd gap puts the extra pixel right/bottom.
        left = gap_w // 2
        top = gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to per-class scores."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        """Return the raw scores (no activation applied) for each output channel."""
        return self.conv(x)

In [None]:
class UNet(nn.Module):
    """Classic U-Net: an input block, four downsampling stages, four upsampling stages, 1x1 head.

    Args:
        n_channels: channels of the input image.
        n_classes: channels of the output logit map.
        bilinear: use bilinear upsampling instead of transposed convolutions;
            in that mode the bottleneck and decoder widths are divided by
            ``factor = 2``.
    """

    def __init__(self, n_channels=1, n_classes=1, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        factor = 2 if bilinear else 1

        # Encoder (module creation order unchanged so parameter init stays identical).
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return per-pixel logits for ``x``."""
        skips = []
        features = self.inc(x)
        # Keep each encoder output as a skip connection before descending.
        for stage in (self.down1, self.down2, self.down3, self.down4):
            skips.append(features)
            features = stage(features)
        # Decode, pairing the deepest decoder stage with the deepest skip.
        decoders = (self.up1, self.up2, self.up3, self.up4)
        for stage, skip in zip(decoders, reversed(skips)):
            features = stage(features, skip)
        return self.outc(features)

# Build the standard U-Net and run a shape smoke test on a dummy 512x512 image.
model = UNet(n_channels=1, n_classes=1)
print(f"U-Net parametrelerinin sayısı: {sum(param.numel() for param in model.parameters()):,}")

# Dummy single-channel input with batch size 1 (also reused by the hybrid model cell).
test_input = torch.randn(1, 1, 512, 512)
with torch.no_grad():
    output = model(test_input)
print(f"Giriş boyutu: {test_input.shape}")
print(f"Çıkış boyutu: {output.shape}")

## Topolojik Destekli Hibrit U-Net

In [None]:
class TopologyAwareUNet(nn.Module):
    """
    U-Net whose bottleneck is fused with a map built from hand-crafted image features.

    The encoder/decoder follow the same layout as ``UNet``. In ``forward``:
    1. A fixed-length feature vector is computed per image on the CPU
       (``extract_simple_topo_features``).
    2. An MLP projects it to a ``TOPO_MAP_SIZE x TOPO_MAP_SIZE`` map in (0, 1).
    3. The map is bilinearly resized to the bottleneck resolution and appended
       as one extra channel, then fused back by a DoubleConv.
    4. A 1x1-conv sigmoid gate rescales the last decoder features before the
       output convolution.

    Args:
        n_channels: input channels (only channel 0 feeds the feature extractor).
        n_classes: output logit channels.
        bilinear: bilinear upsampling instead of transposed convolutions.
        topo_feature_dim: width of the learned embedding of the feature vector.
    """

    # Length of the vector built by extract_simple_topo_features (input width of topo_processor).
    TOPO_FEATURE_SIZE = 19
    # Side length of the square map the embedding is projected to.
    TOPO_MAP_SIZE = 16

    def __init__(self, n_channels=1, n_classes=1, bilinear=False, topo_feature_dim=32):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.topo_feature_dim = topo_feature_dim

        # Standard U-Net encoder (module creation order unchanged for identical init).
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

        # Learned embedding of the hand-crafted feature vector (simplified version).
        self.topo_processor = nn.Sequential(
            nn.Linear(self.TOPO_FEATURE_SIZE, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, topo_feature_dim),
            nn.ReLU()
        )

        # Projection of the embedding to a flattened TOPO_MAP_SIZE x TOPO_MAP_SIZE map in (0, 1).
        self.topo_spatial = nn.Sequential(
            nn.Linear(topo_feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, self.TOPO_MAP_SIZE * self.TOPO_MAP_SIZE),
            nn.Sigmoid()
        )

        # Bottleneck fusion: encoder channels plus one channel for the feature map.
        bottleneck_channels = 1024 // factor + 1
        self.feature_fusion = DoubleConv(bottleneck_channels, 1024 // factor)

        # U-Net decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

        # Per-pixel attention gate in (0, 1) applied to the last decoder features.
        self.attention = nn.Sequential(
            nn.Conv2d(64, 32, 1),
            nn.ReLU(),
            nn.Conv2d(32, 1, 1),
            nn.Sigmoid()
        )

    def extract_simple_topo_features(self, x):
        """
        Compute a fixed-length, non-learned feature vector for each image.

        Placeholder for real topological features: the original notes that a
        TopologicalLayer / persistent homology is meant to replace it. The
        values are computed with NumPy/OpenCV on the CPU from ``x[i, 0]`` and
        carry no gradient.

        Args:
            x: image batch; ``x[i, 0]`` must be a 2-D image (assumed (N, C, H, W)).

        Returns:
            float32 tensor of shape (N, TOPO_FEATURE_SIZE) on ``x.device``. The
            values per image are:
            - the number of connected-component labels of ``img > 0.5``
              (OpenCV's count, background label included);
            - mean, std, min and max;
            - mean and std of the Sobel gradient magnitude;
            - a 10-bin histogram normalised by the pixel count;
            - the fraction of pixels above the mean;
            - the median.
            The vector is zero-padded or truncated to TOPO_FEATURE_SIZE.
        """
        features = []
        for i in range(x.shape[0]):
            # detach(): Tensor.numpy() refuses tensors that require grad
            # (e.g. when gradients w.r.t. the input are requested).
            img = x[i, 0].detach().cpu().numpy()

            # Connected components of the 0.5-thresholded image.
            binary = (img > 0.5).astype(np.uint8)
            num_labels, _ = cv2.connectedComponents(binary)

            # Gradient-based statistics.
            grad_x = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=3)
            grad_mag = np.sqrt(grad_x**2 + grad_y**2)

            hist_counts, _ = np.histogram(img.flatten(), bins=10)

            feature_vec = [
                num_labels,
                img.mean(), img.std(), img.min(), img.max(),
                grad_mag.mean(), grad_mag.std(),
                *hist_counts / img.size,
                (img > img.mean()).sum() / img.size,
                np.median(img),
            ]

            # Pad / truncate so the length always matches topo_processor's input width.
            padded = feature_vec + [0.0] * max(0, self.TOPO_FEATURE_SIZE - len(feature_vec))
            features.append(padded[:self.TOPO_FEATURE_SIZE])

        # reshape keeps an empty batch as (0, TOPO_FEATURE_SIZE) instead of (0,).
        stacked = np.asarray(features, dtype=np.float32).reshape(-1, self.TOPO_FEATURE_SIZE)
        return torch.from_numpy(stacked).to(x.device)

    def forward(self, x):
        """
        Args:
            x: input batch (assumed (N, n_channels, H, W)).

        Returns:
            Logits with ``n_classes`` channels at the input's H and W (the decoder
            pads each stage to its skip connection).
        """
        # Standard U-Net encoder
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)  # bottleneck

        # Hand-crafted features -> learned embedding -> square map.
        topo_features = self.extract_simple_topo_features(x)
        topo_processed = self.topo_processor(topo_features)
        topo_spatial = self.topo_spatial(topo_processed)
        topo_map = topo_spatial.view(-1, 1, self.TOPO_MAP_SIZE, self.TOPO_MAP_SIZE)

        # Resize the map to the bottleneck resolution.
        topo_map_upsampled = F.interpolate(
            topo_map,
            size=(x5.shape[2], x5.shape[3]),
            mode='bilinear',
            align_corners=True
        )

        # Append it as an extra channel and fuse.
        fused_features = torch.cat([x5, topo_map_upsampled], dim=1)
        fused_features = self.feature_fusion(fused_features)

        # U-Net decoder
        x = self.up1(fused_features, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        # Attention gate
        attention_weights = self.attention(x)
        x = x * attention_weights

        logits = self.outc(x)
        return logits

# Build the hybrid model and run the same shape smoke test on test_input from the U-Net cell.
hybrid_model = TopologyAwareUNet(n_channels=1, n_classes=1, topo_feature_dim=32)
print(f"Hibrit U-Net parametrelerinin sayısı: {sum(param.numel() for param in hybrid_model.parameters()):,}")

with torch.no_grad():
    hybrid_output = hybrid_model(test_input)
print(f"Hibrit model çıkış boyutu: {hybrid_output.shape}")

## Loss Fonksiyonları

In [None]:
class DiceLoss(nn.Module):
    """
    Soft Dice loss on logits, pooled over the whole batch.

    ``pred`` goes through a sigmoid, both tensors are flattened, and
    ``1 - (2*sum(p*t) + smooth) / (sum(p) + sum(t) + smooth)`` is returned.

    Args:
        smooth: additive constant keeping the ratio defined when both the
            prediction and the target are empty.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """
        Args:
            pred: raw logits, any shape.
            target: mask with the same number of elements as ``pred``.

        Returns:
            0-dim loss tensor; lies in [0, 1] when ``target`` values are in [0, 1].
        """
        # reshape (not view) so non-contiguous inputs, e.g. permuted or sliced
        # tensors, are accepted too; contiguous inputs are not copied.
        probs = torch.sigmoid(pred).reshape(-1)
        target = target.reshape(-1)

        intersection = (probs * target).sum()
        dice = (2. * intersection + self.smooth) / (probs.sum() + target.sum() + self.smooth)

        return 1 - dice


class CombinedLoss(nn.Module):
    """
    Weighted sum of soft Dice loss and binary cross-entropy, both computed on logits.

    Args:
        dice_weight: multiplier for the DiceLoss term.
        bce_weight: multiplier for the BCEWithLogitsLoss term.
    """

    def __init__(self, dice_weight=0.5, bce_weight=0.5):
        super().__init__()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight
        self.dice_loss = DiceLoss()
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, pred, target):
        """Return ``dice_weight * Dice(pred, target) + bce_weight * BCE(pred, target)``."""
        weighted_dice = self.dice_weight * self.dice_loss(pred, target)
        weighted_bce = self.bce_weight * self.bce_loss(pred, target)
        return weighted_dice + weighted_bce


class TopologicalLoss(nn.Module):
    """
    Penalise differences in the number of connected foreground components.

    For each image, the prediction (after a sigmoid) and the target are
    thresholded at 0.5. Their 0-th Betti number (connected components minus the
    background label) is counted with OpenCV. The loss is ``weight`` times the
    batch mean of the absolute count difference.

    NOTE(review): the counts are computed in NumPy/OpenCV and wrapped in a new
    tensor, so this term carries no gradient back to ``pred``. As written it can
    be reported but does not steer optimisation.

    Args:
        weight: multiplier applied to the mean absolute count difference.
    """

    def __init__(self, weight=0.1):
        super().__init__()
        self.weight = weight

    def compute_betti_0(self, mask):
        """
        Count the foreground connected components of ``mask[i, 0] > 0.5`` for every i.

        Returns:
            float32 tensor of shape (N,) on ``mask.device``.
        """
        counts = [
            # connectedComponents also labels the background, hence the -1.
            cv2.connectedComponents((sample > 0.5).cpu().numpy().astype(np.uint8))[0] - 1
            for sample in mask[:, 0]
        ]
        return torch.FloatTensor(counts).to(mask.device)

    def forward(self, pred, target):
        """Return ``weight * mean(|betti0(sigmoid(pred)) - betti0(target)|)``."""
        pred_counts = self.compute_betti_0(torch.sigmoid(pred))
        target_counts = self.compute_betti_0(target)
        return self.weight * (pred_counts - target_counts).abs().mean()


class HybridLoss(nn.Module):
    """
    Combined loss for the hybrid model: segmentation term plus topological term.

    ``total = (dice_weight + bce_weight) * CombinedLoss + topo_weight * TopologicalLoss``

    ``CombinedLoss`` receives the Dice/BCE weights renormalised to sum to 1.

    Args:
        dice_weight: relative weight of the Dice term.
        bce_weight: relative weight of the BCE term. The sum of the two is the
            overall weight of the segmentation term.
        topo_weight: weight of the topological (Betti-0 L1) term. This is the
            only scaling applied to that term.
    """

    def __init__(self, dice_weight=0.4, bce_weight=0.4, topo_weight=0.2):
        super().__init__()
        seg_weight = dice_weight + bce_weight
        if seg_weight:
            dice_share = dice_weight / seg_weight
            bce_share = bce_weight / seg_weight
        else:
            # Segmentation term switched off: the split is irrelevant (it is
            # multiplied by 0) but must not divide by zero.
            dice_share = bce_share = 0.5
        self.segmentation_loss = CombinedLoss(dice_share, bce_share)
        # weight=1.0 so that topo_weight is the single scale factor; the
        # TopologicalLoss default (0.1) would silently shrink the term another 10x.
        self.topological_loss = TopologicalLoss(weight=1.0)
        self.seg_weight = seg_weight
        self.topo_weight = topo_weight

    def forward(self, pred, target):
        """
        Args:
            pred: raw logits.
            target: binary mask matching ``pred``.

        Returns:
            Tuple ``(total_loss, seg_loss, topo_loss)``:
            - ``seg_loss`` is the normalised Dice+BCE value;
            - ``topo_loss`` is the unweighted mean absolute Betti-0 difference.
        """
        seg_loss = self.segmentation_loss(pred, target)
        topo_loss = self.topological_loss(pred, target)

        total_loss = self.seg_weight * seg_loss + self.topo_weight * topo_loss

        return total_loss, seg_loss, topo_loss

# Smoke-test the loss functions on random logits and a random binary mask.
print("Loss fonksiyonları test ediliyor...")

pred = torch.randn(2, 1, 64, 64)                      # dummy logits
target = torch.randint(0, 2, (2, 1, 64, 64)).float()  # dummy 0/1 mask

dice_loss = DiceLoss()
combined_loss = CombinedLoss()
hybrid_loss = HybridLoss()

with torch.no_grad():
    dice_val = dice_loss(pred, target)
    combined_val = combined_loss(pred, target)
    hybrid_val, seg_val, topo_val = hybrid_loss(pred, target)

print(f"Dice Loss: {dice_val.item():.4f}")
print(f"Combined Loss: {combined_val.item():.4f}")
print(f"Hybrid Loss: {hybrid_val.item():.4f} (Seg: {seg_val.item():.4f}, Topo: {topo_val.item():.4f})")

## Değerlendirme Metrikleri

In [None]:
class SegmentationMetrics:
    """
    Threshold-based metrics for binary segmentation.

    Every method takes raw logits ``pred`` (a sigmoid is applied and compared
    with ``threshold``) and a 0/1 ``target`` with the same number of elements.
    Dice, IoU, precision and recall are pooled over all pixels of the batch.
    The Hausdorff distance is computed per image and averaged.

    Args:
        threshold: probability cut-off used to binarise predictions.
    """

    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def _binarize(self, pred):
        """Return ``sigmoid(pred) > threshold`` as a float 0/1 tensor with ``pred``'s shape."""
        return (torch.sigmoid(pred) > self.threshold).float()

    @staticmethod
    def _boundary_points(mask):
        """
        Return the (K, 2) float row/column coordinates of a 2-D bool mask's boundary pixels.

        A foreground pixel is on the boundary when binary erosion removes it.
        The erosion uses SciPy's default 4-connected structuring element and
        treats pixels outside the image as background.
        """
        from scipy.ndimage import binary_erosion

        boundary = mask & ~binary_erosion(mask)
        return np.argwhere(boundary).astype(np.float64)

    def dice_score(self, pred, target, smooth=1e-6):
        """Return the Dice coefficient ``(2|P∩T| + smooth) / (|P| + |T| + smooth)`` as a float."""
        # reshape (not view): accepts non-contiguous inputs such as permuted tensors.
        pred = self._binarize(pred).reshape(-1)
        target = target.reshape(-1)

        intersection = (pred * target).sum()
        dice = (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

        return dice.item()

    def iou_score(self, pred, target, smooth=1e-6):
        """Return intersection over union ``(|P∩T| + smooth) / (|P∪T| + smooth)`` as a float."""
        pred = self._binarize(pred).reshape(-1)
        target = target.reshape(-1)

        intersection = (pred * target).sum()
        union = pred.sum() + target.sum() - intersection

        iou = (intersection + smooth) / (union + smooth)

        return iou.item()

    def precision_recall(self, pred, target):
        """Return ``(precision, recall)`` as floats; 1e-6 in the denominators avoids division by zero."""
        pred = self._binarize(pred).reshape(-1)
        target = target.reshape(-1)

        tp = (pred * target).sum()
        fp = (pred * (1 - target)).sum()
        fn = ((1 - pred) * target).sum()

        precision = tp / (tp + fp + 1e-6)
        recall = tp / (tp + fn + 1e-6)

        return precision.item(), recall.item()

    def hausdorff_distance(self, pred, target):
        """
        Symmetric Hausdorff distance (in pixels) between predicted and target mask boundaries.

        For each image, channel 0 of both masks is reduced to its boundary
        pixels. The distance is ``max(h(P, T), h(T, P))``, where ``h(A, B)`` is
        the largest distance from a point of A to its nearest point of B.
        Images where either mask is empty contribute ``inf``.

        Returns:
            Mean over the batch (``inf`` if any image contributed ``inf``).
        """
        from scipy.spatial.distance import directed_hausdorff

        pred = self._binarize(pred)
        distances = []

        for i in range(pred.shape[0]):
            pred_np = pred[i, 0].detach().cpu().numpy().astype(bool)
            # uint8 truncation, then any non-zero value is foreground.
            target_np = target[i, 0].detach().cpu().numpy().astype(np.uint8).astype(bool)

            pred_points = self._boundary_points(pred_np)
            target_points = self._boundary_points(target_np)

            if len(pred_points) == 0 or len(target_points) == 0:
                distances.append(float('inf'))
                continue

            distances.append(max(
                directed_hausdorff(pred_points, target_points)[0],
                directed_hausdorff(target_points, pred_points)[0],
            ))

        return np.mean(distances)

    def compute_all_metrics(self, pred, target):
        """
        Compute every metric for one batch.

        Returns:
            dict with keys 'dice', 'iou', 'precision', 'recall', 'f1' and 'hausdorff'.
            The Hausdorff value is ``inf`` if its computation fails.
        """
        dice = self.dice_score(pred, target)
        iou = self.iou_score(pred, target)
        precision, recall = self.precision_recall(pred, target)
        f1 = 2 * (precision * recall) / (precision + recall + 1e-6)

        try:
            hausdorff = self.hausdorff_distance(pred, target)
        except Exception:
            # Best effort: a failing distance computation must not hide the other
            # metrics; KeyboardInterrupt/SystemExit still propagate.
            hausdorff = float('inf')

        return {
            'dice': dice,
            'iou': iou,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'hausdorff': hausdorff
        }

# Smoke-test the metrics on a random logit map and a random binary mask.
print("Metrikler test ediliyor...")
metrics = SegmentationMetrics()

pred_test = torch.randn(1, 1, 64, 64)                      # dummy logits
target_test = torch.randint(0, 2, (1, 1, 64, 64)).float()  # dummy 0/1 mask

with torch.no_grad():
    all_metrics = metrics.compute_all_metrics(pred_test, target_test)
print("Test metrikleri:")
for metric_name, value in all_metrics.items():
    print(f"  {metric_name}: {value:.4f}")

## Model Karşılaştırma Fonksiyonu

In [None]:
def _plot_first_sample(images, masks, unet_pred, hybrid_pred, unet_metric, hybrid_metric):
    """Show item 0 of a batch: input, ground truth and sigmoid maps on top, 0.5-thresholded masks below."""
    _, axes = plt.subplots(2, 4, figsize=(16, 8))

    image = images[0, 0].cpu().numpy()
    mask = masks[0, 0].cpu().numpy()
    unet_prob = torch.sigmoid(unet_pred[0, 0]).cpu().numpy()
    hybrid_prob = torch.sigmoid(hybrid_pred[0, 0]).cpu().numpy()
    unet_binary = (unet_prob > 0.5).astype(np.float32)
    hybrid_binary = (hybrid_prob > 0.5).astype(np.float32)

    panels = [
        # Top row: original, ground truth, U-Net and hybrid probabilities.
        (axes[0, 0], image, 'gray', 'Orijinal Görüntü'),
        (axes[0, 1], mask, 'hot', 'Ground Truth'),
        (axes[0, 2], unet_prob, 'hot', f'U-Net (Dice: {unet_metric["dice"]:.3f})'),
        (axes[0, 3], hybrid_prob, 'hot', f'Hibrit (Dice: {hybrid_metric["dice"]:.3f})'),
        # Bottom row: binary predictions.
        (axes[1, 0], image, 'gray', 'Orijinal'),
        (axes[1, 1], mask, 'gray', 'Ground Truth'),
        (axes[1, 2], unet_binary, 'gray', f'U-Net Binary (IoU: {unet_metric["iou"]:.3f})'),
        (axes[1, 3], hybrid_binary, 'gray', f'Hibrit Binary (IoU: {hybrid_metric["iou"]:.3f})'),
    ]
    for ax, data, cmap, title in panels:
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()


def _plot_metric_bars(metrics_names, avg_unet_metrics, avg_hybrid_metrics):
    """Draw a grouped bar chart of the averaged metrics for both models."""
    _, ax = plt.subplots(figsize=(12, 6))

    x = np.arange(len(metrics_names))
    width = 0.35

    unet_values = [avg_unet_metrics[m] for m in metrics_names]
    hybrid_values = [avg_hybrid_metrics[m] for m in metrics_names]

    ax.bar(x - width/2, unet_values, width, label='Standart U-Net', alpha=0.7)
    ax.bar(x + width/2, hybrid_values, width, label='Hibrit U-Net', alpha=0.7)

    ax.set_xlabel('Metrikler')
    ax.set_ylabel('Skor')
    ax.set_title('Model Performans Karşılaştırması')
    ax.set_xticks(x)
    ax.set_xticklabels(metrics_names)
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


def _print_metric_table(metrics_names, avg_unet_metrics, avg_hybrid_metrics):
    """Print the averaged metrics and the hybrid's relative change (%) versus the U-Net."""
    print("\nOrtalama Performans Metrikleri:")
    print("-" * 50)
    print(f"{'Metrik':<12} {'U-Net':<10} {'Hibrit':<10} {'İyileştirme':<12}")
    print("-" * 50)

    for key in metrics_names:
        unet_val = avg_unet_metrics[key]
        hybrid_val = avg_hybrid_metrics[key]
        if unet_val == 0:
            # Relative change is undefined against a zero baseline.
            print(f"{key:<12} {unet_val:<10.4f} {hybrid_val:<10.4f} {'n/a':<12}")
            continue
        improvement = ((hybrid_val - unet_val) / unet_val) * 100

        print(f"{key:<12} {unet_val:<10.4f} {hybrid_val:<10.4f} {improvement:<12.2f}%")


def compare_models(unet_model, hybrid_model, test_loader, device, num_samples=5):
    """
    Evaluate two segmentation models side by side and visualise the results.

    Both models are put in eval mode and run without gradients on at most
    ``num_samples`` batches of ``test_loader``. Each batch must unpack as
    ``(images, masks)``. For the first batch, a 2x4 figure shows the input,
    the ground truth, the sigmoid probabilities and the 0.5-thresholded masks.
    The per-batch metrics from ``SegmentationMetrics.compute_all_metrics`` are
    averaged, then plotted as a bar chart and printed as a table (Hausdorff is
    excluded from both). The table shows the relative change of the hybrid
    over the U-Net.

    Args:
        unet_model: baseline model returning logits.
        hybrid_model: model compared against the baseline.
        test_loader: iterable of ``(images, masks)`` batches.
        device: device each batch is moved to.
        num_samples: maximum number of batches to evaluate.

    Returns:
        ``(avg_unet_metrics, avg_hybrid_metrics)``: dicts mapping metric name to
        its mean over the evaluated batches.

    Raises:
        ValueError: if no batch was evaluated (empty loader or ``num_samples <= 0``).
    """
    unet_model.eval()
    hybrid_model.eval()
    metrics = SegmentationMetrics()

    unet_metrics = []
    hybrid_metrics = []

    with torch.no_grad():
        for i, (images, masks) in enumerate(test_loader):
            if i >= num_samples:
                break

            images = images.to(device)
            masks = masks.to(device)

            # Model predictions (logits)
            unet_pred = unet_model(images)
            hybrid_pred = hybrid_model(images)

            unet_metric = metrics.compute_all_metrics(unet_pred, masks)
            hybrid_metric = metrics.compute_all_metrics(hybrid_pred, masks)

            unet_metrics.append(unet_metric)
            hybrid_metrics.append(hybrid_metric)

            # Visualise only the first batch.
            if i == 0:
                _plot_first_sample(images, masks, unet_pred, hybrid_pred, unet_metric, hybrid_metric)

    if not unet_metrics:
        raise ValueError(
            "compare_models: no batches were evaluated (empty test_loader or num_samples <= 0)"
        )

    # Average each metric over the evaluated batches.
    metric_keys = list(unet_metrics[0].keys())
    avg_unet_metrics = {key: np.mean([m[key] for m in unet_metrics]) for key in metric_keys}
    avg_hybrid_metrics = {key: np.mean([m[key] for m in hybrid_metrics]) for key in metric_keys}

    # Hausdorff may be inf and lives on a different scale, so keep it out of the chart/table.
    metrics_names = [m for m in metric_keys if m != 'hausdorff']

    _plot_metric_bars(metrics_names, avg_unet_metrics, avg_hybrid_metrics)
    _print_metric_table(metrics_names, avg_unet_metrics, avg_hybrid_metrics)

    return avg_unet_metrics, avg_hybrid_metrics

# Confirm that the cell defining compare_models has run.
print("Model karşılaştırma fonksiyonu hazır.")

## Sonraki Adımlar

Bu notebook'ta geliştirdiğimiz bileşenler:

1. **Standart U-Net mimarisi**
2. **Topolojik destekli hibrit U-Net**
3. **Özelleşmiş loss fonksiyonları** (Dice, BCE ve topolojik; topolojik kayıp eşiklenmiş maskelerden OpenCV ile hesaplandığı için şu an gradyan taşımaz)
4. **Kapsamlı değerlendirme metrikleri**
5. **Model karşılaştırma araçları**

Sonraki notebook'ta bu modelleri eğiteceğiz ve detaylı performans analizini yapacağız.