# 🛩️ F-18 Chart Curve Extraction - Training Notebook

Bu notebook sentetik veri üretip U-Net modeli eğitir.

**Çalıştırma:**
1. Runtime > Change runtime type > GPU (T4 veya daha iyi)
2. Tüm hücreleri sırayla çalıştır

In [2]:
# GPU check via the IPython `!` shell escape. If this prints
# "command not found", the runtime most likely has no NVIDIA GPU attached;
# switch via Runtime > Change runtime type.
!nvidia-smi

/bin/bash: line 1: nvidia-smi: command not found


In [None]:
# Required libraries (most are already installed on Colab)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance
from dataclasses import dataclass, field
from typing import List, Tuple, Optional
import random
import math
import io
from tqdm.auto import tqdm

# Report the PyTorch version and whether a CUDA device is visible.
cuda_available = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. Sentetik Veri Üretici

In [None]:
@dataclass
class ChartConfig:
    """Knobs describing one synthetic chart rendered by ``draw_chart_matplotlib``.

    Field order is part of the interface (positional construction works), so
    new fields should only ever be appended at the end.
    """

    # Axis view limits in data units (x: Mach number, y: specific range,
    # matching the axis labels the renderer draws).
    x_min: float = 0.3
    x_max: float = 1.0
    y_min: float = 0.04
    y_max: float = 0.15

    # Curve family: how many curves, which shape family
    # ('peaked' / 'rising' / 'falling' / 'mixed'), and their line width.
    n_curves: int = 8
    curve_type: str = 'peaked'
    curve_lw: float = 0.6

    # Optional decorations, each toggled independently.
    add_grid: bool = True
    add_arrows: bool = True
    add_envelope_optimum: bool = True
    add_envelope_endurance: bool = False
    add_vmax_line: bool = False
    add_text_boxes: bool = True
    add_fuel_labels: bool = True
    add_drag_labels: bool = True


def generate_curve_shape(x, curve_type, curve_index, total_curves):
    """Return one randomly jittered, normalized curve sampled at ``x``.

    Args:
        x: 1-D array-like of sample positions. Only its min/max matter: it is
            rescaled to [0, 1] internally before the shape is evaluated.
        curve_type: 'peaked', 'rising' or 'falling'. Any other value
            (including 'mixed') picks one of those three at random.
        curve_index: position of this curve within its family. Together with
            ``total_curves`` it sets ``alt`` in [0, 1], which lifts the curve.
        total_curves: number of curves in the family.

    Returns:
        Float array shaped like ``x`` with values inside (0, 1). Callers map
        these values to data units. Randomness comes from the ``random`` module.
    """
    x = np.asarray(x)
    alt = curve_index / max(total_curves - 1, 1)
    x_norm = (x - x.min()) / (x.max() - x.min() + 1e-8)

    if curve_type == 'peaked':
        peak_pos = 0.30 + random.uniform(-0.05, 0.05)
        start_y = 0.12 + random.uniform(-0.02, 0.02)
        peak_y = 0.45 + alt * 0.40 + random.uniform(-0.03, 0.03)
        end_y = 0.20 + alt * 0.25 + random.uniform(-0.02, 0.02)

        # Vectorized piecewise shape: ease-out rise up to the peak, then a
        # concave (power 0.7) descent. Each side is evaluated only on its own
        # samples, so the fractional power never sees a negative base.
        y = np.empty_like(x_norm, dtype=float)
        before = x_norm <= peak_pos
        up = x_norm[before] / peak_pos
        y[before] = start_y + (peak_y - start_y) * (1 - (1 - up) ** 2)
        down = (x_norm[~before] - peak_pos) / (1 - peak_pos)
        y[~before] = peak_y - (peak_y - end_y) * down ** 0.7
        return y

    if curve_type == 'rising':
        start_y = 0.08 + alt * 0.05 + random.uniform(-0.02, 0.02)
        end_y = 0.55 + alt * 0.30 + random.uniform(-0.03, 0.03)
        curvature = random.uniform(0.7, 1.3)
        return start_y + (end_y - start_y) * (x_norm ** curvature)

    if curve_type == 'falling':
        start_y = 0.65 + alt * 0.25 + random.uniform(-0.03, 0.03)
        end_y = 0.12 + alt * 0.10 + random.uniform(-0.02, 0.02)
        curvature = random.uniform(0.5, 1.0)
        return start_y - (start_y - end_y) * (x_norm ** curvature)

    # 'mixed' (or anything unrecognized): pick a concrete family per curve.
    return generate_curve_shape(x, random.choice(['peaked', 'rising', 'falling']),
                                curve_index, total_curves)


def fig_to_array(fig, dpi=150, tight=True):
    """Rasterize a matplotlib figure into an RGB ``uint8`` array and close it.

    Args:
        fig: figure to render. It is always closed afterwards, even when
            saving fails, so repeated calls cannot accumulate open figures.
        dpi: output resolution in dots per inch.
        tight: if True, crop to the tight bounding box of the drawn content
            (0.02 in padding) on a white background. If False, render the full
            canvas using the figure's own face colour.

    Returns:
        ``(height, width, 3)`` uint8 array.
    """
    if tight:
        save_kwargs = dict(bbox_inches='tight', pad_inches=0.02, facecolor='white')
    else:
        save_kwargs = dict(facecolor=fig.get_facecolor())

    with io.BytesIO() as buf:
        try:
            fig.savefig(buf, format='png', dpi=dpi, edgecolor='none', **save_kwargs)
        finally:
            plt.close(fig)
        buf.seek(0)
        with Image.open(buf) as img:
            return np.array(img.convert('RGB'))

In [None]:
def _figure_to_rgb(fig, bbox_inches, facecolor, dpi=150):
    """Rasterize ``fig`` cropped to ``bbox_inches`` (a Bbox in inches) to RGB uint8, then close it."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', dpi=dpi, bbox_inches=bbox_inches,
                    facecolor=facecolor, edgecolor='none')
    finally:
        plt.close(fig)
    buf.seek(0)
    with Image.open(buf) as img:
        return np.array(img.convert('RGB'))


def _points_at(curves_data, frac):
    """Return the point at fraction ``frac`` of each curve's samples, sorted by y."""
    pts = [(cx[int(len(cx) * frac)], cy[int(len(cy) * frac)]) for cx, cy in curves_data]
    return sorted(pts, key=lambda p: p[1])


def _sample_curves(config):
    """Sample ``config.n_curves`` curves as ``(x, y)`` array pairs in data units."""
    x = np.linspace(config.x_min + 0.02, config.x_max - 0.02, 400)
    curves_data = []
    for i in range(config.n_curves):
        y_norm = generate_curve_shape(x, config.curve_type, i, config.n_curves)
        y = config.y_min + y_norm * (config.y_max - config.y_min)
        # Keep every sample strictly inside the y-limits.
        y = np.clip(y, config.y_min + 0.001, config.y_max - 0.001)
        curves_data.append((x.copy(), y))
    return curves_data


def _style_axes(ax, config):
    """Set limits, optional grid/ticks, heavy frame lines and axis labels."""
    ax.set_xlim(config.x_min, config.x_max)
    ax.set_ylim(config.y_min, config.y_max)

    if config.add_grid:
        x_range = config.x_max - config.x_min
        y_range = config.y_max - config.y_min
        x_major = 0.1 if x_range > 0.5 else 0.05
        y_major = 0.01 if y_range < 0.08 else 0.02
        ax.set_xticks(np.arange(config.x_min, config.x_max + 0.001, x_major))
        ax.set_xticks(np.arange(config.x_min, config.x_max + 0.001, x_major / 2), minor=True)
        ax.set_yticks(np.arange(config.y_min, config.y_max + 0.001, y_major))
        ax.set_yticks(np.arange(config.y_min, config.y_max + 0.001, y_major / 2), minor=True)
        ax.grid(True, which='major', linewidth=0.8, alpha=0.5, color='black')
        ax.grid(True, which='minor', linewidth=0.4, alpha=0.3, color='black')

    ax.axhline(y=config.y_min, color='black', linewidth=2.0, zorder=10)
    ax.axvline(x=config.x_min, color='black', linewidth=2.0, zorder=10)
    ax.tick_params(axis='both', which='major', length=6, width=1.5, direction='in')
    ax.tick_params(axis='both', which='minor', length=3, width=1.0, direction='in')
    for spine in ax.spines.values():
        spine.set_linewidth(1.5)

    ax.set_xlabel('MACH NUMBER', fontsize=10, fontweight='bold')
    ax.set_ylabel('SPECIFIC RANGE — NAUTICAL MILES PER POUND OF FUEL', fontsize=8)


def _draw_envelopes(ax, config, curves_data):
    """Draw the optional OPTIMUM CRUISE and MAXIMUM ENDURANCE envelope lines with labels."""
    y_span = config.y_max - config.y_min

    if config.add_envelope_optimum:
        if config.curve_type == 'peaked':
            # Envelope through each curve's maximum.
            pts = sorted(((cx[np.argmax(cy)], cy.max()) for cx, cy in curves_data),
                         key=lambda p: p[1])
        else:
            pts = _points_at(curves_data, 0.5)
        ex, ey = zip(*pts)
        ax.plot(ex, ey, 'k-', linewidth=1.2)
        ax.text(ex[0] - 0.03, ey[-1] + y_span * 0.02,
                'OPTIMUM\nCRUISE', fontsize=8, ha='right', va='bottom')

    if config.add_envelope_endurance:
        ex, ey = zip(*_points_at(curves_data, 0.2))
        ax.plot(ex, ey, 'k-', linewidth=1.2)
        ax.text(ex[-1] - 0.02, ey[0] - y_span * 0.02,
                'MAXIMUM\nENDURANCE', fontsize=8, ha='right', va='top')


def _draw_fuel_flow_arrows(ax, config, curves_data):
    """Draw one leader-line arrow per curve, plus optional dashed clutter and labels.

    Curves are visited from the last entry of ``curves_data`` backwards. At
    most 12 curves (one per fuel-flow value) get an arrow. The fuel-flow
    number is written at the arrow's tail only when ``config.add_fuel_labels``
    is set.
    """
    fuel_flows = ['3000', '3500', '4000', '4500', '5000', '5500',
                  '6000', '6500', '7000', '7500', '8000', '8500']
    x_span = config.x_max - config.x_min
    y_span = config.y_max - config.y_min

    for label, (cx, cy) in zip(fuel_flows, reversed(curves_data)):
        if random.random() < 0.5:
            # Head on the curve's last sample, tail slightly to the right.
            x_head, y_head = cx[-1], cy[-1]
            x_tail = x_head + random.uniform(0.04, 0.08)
            y_tail = y_head + random.uniform(-0.005, 0.005)
        else:
            # Head on a sample from the middle half of the curve, tail at a
            # random 20-70 degree slant above or below it.
            arrow_idx = random.randint(len(cx) // 4, 3 * len(cx) // 4)
            x_head, y_head = cx[arrow_idx], cy[arrow_idx]
            angle = random.uniform(20, 70)
            dist = random.uniform(0.05, 0.10)
            if random.random() >= 0.5:
                angle = -angle
            theta = math.radians(angle)
            x_tail = x_head + dist * math.cos(theta)
            y_tail = y_head + dist * math.sin(theta)

        # Leader line
        ax.plot([x_tail, x_head], [y_tail, y_head], color="black", linewidth=0.6)

        # Arrow head: fc='none' 40% of the time, otherwise fc='black'.
        if random.random() < 0.4:
            arrow_style = random.choice(["-|>", "->"])
            fill_style = "none"
        else:
            arrow_style = random.choice(["-|>", "-|>", "->"])
            fill_style = "black"
        ax.annotate(
            "",
            xy=(x_head, y_head),
            xytext=(x_tail, y_tail),
            arrowprops=dict(
                arrowstyle=arrow_style,
                lw=random.uniform(0.7, 1.1),
                color="black",
                fc=fill_style,
                shrinkA=0, shrinkB=0,
                mutation_scale=random.uniform(12, 18),
            ),
        )

        # Dashed distractor stroke, placed independently of the arrow.
        if random.random() < 0.85:
            dash_len = random.uniform(0.06, 0.14)
            dash_angle = math.radians(random.choice([35, 45, 55, 65, 75]))
            base_x = config.x_min + x_span * random.uniform(0.25, 0.75)
            base_y = config.y_min + y_span * random.uniform(0.25, 0.75)
            ax.plot([base_x, base_x + dash_len * math.cos(dash_angle)],
                    [base_y, base_y + dash_len * math.sin(dash_angle)],
                    color="black", linewidth=0.6, linestyle=(0, (6, 4)))

        if config.add_fuel_labels:
            label_x = x_tail + random.uniform(0.06, 0.10)
            ax.text(label_x, y_tail + random.uniform(-0.002, 0.002),
                    label, fontsize=8, va='center', ha='left')


def _draw_extra_dashes(ax, config):
    """With probability 0.7, scatter 2-6 short standalone dashed strokes as clutter."""
    if random.random() >= 0.7:
        return
    x_span = config.x_max - config.x_min
    y_span = config.y_max - config.y_min
    for _ in range(random.randint(2, 6)):
        dcx = config.x_min + x_span * random.uniform(0.2, 0.8)
        dcy = config.y_min + y_span * random.uniform(0.2, 0.8)
        dash_len = random.uniform(0.04, 0.10)
        dash_angle = math.radians(random.choice([35, 45, 55, 65, 75]))
        ax.plot([dcx, dcx + dash_len * math.cos(dash_angle)],
                [dcy, dcy + dash_len * math.sin(dash_angle)],
                color="black", linewidth=0.5, linestyle=(0, (5, 3)))


def _draw_text_boxes(ax, config):
    """Draw the boxed TOTAL FUEL FLOW title and the CRUISE/DASH AOA legend."""
    box = dict(boxstyle='square,pad=0.3', facecolor='white', edgecolor='black')
    ax.text(
        config.x_max - 0.05, config.y_max - 0.005,
        'TOTAL FUEL FLOW—\nPOUNDS PER HOUR',
        fontsize=8, ha='right', va='top', bbox=box,
    )
    legend_x = config.x_min + (config.x_max - config.x_min) * 0.15
    legend_y = config.y_max - (config.y_max - config.y_min) * 0.1
    ax.text(
        legend_x, legend_y,
        '◄─ CRUISE    DASH ─►\n      AOA          AOA\n(USED FOR INTERFERENCE\n DRAG DETERMINATION)',
        fontsize=7, ha='left', va='top', bbox=dict(box),
    )


def _draw_drag_labels(ax, config):
    """Stack 4-7 drag-index numbers, starting 65% across and 15% up the plot."""
    labels = ['0.00', '25.00', '50.00', '75.00', '100.00', '125.00', '150.00']
    y_span = config.y_max - config.y_min
    base_x = config.x_min + (config.x_max - config.x_min) * 0.65
    base_y = config.y_min + y_span * 0.15
    for i, lbl in enumerate(labels[:random.randint(4, 7)]):
        ax.text(base_x + random.uniform(-0.02, 0.02),
                base_y + i * y_span * 0.05,
                lbl, fontsize=7, alpha=0.9)


def _draw_vmax_line(ax, curves_data):
    """Draw a dashed line through the point 85% along each curve, labelled V_max(MIL)."""
    vx, vy = zip(*_points_at(curves_data, 0.85))
    ax.plot(vx, vy, 'k--', linewidth=0.8)
    ax.text(vx[-1], vy[-1] + 0.003, r'$V_{max}$(MIL)', fontsize=7)


def draw_chart_matplotlib(config: ChartConfig, W: int = 512, H: int = 512):
    """Render one synthetic chart and its pixel-aligned curve mask.

    The chart (grid, curves, envelopes, arrows, labels, clutter) is drawn on
    one figure. The mask is drawn on a second figure containing only the
    curves (white on black, slightly thicker). Both figures share the figure
    size, the axes rectangle, the final view limits and the exact tight crop
    box. Their pixels therefore line up before both are resized to ``W x H``.

    Args:
        config: chart parameters (ranges, curve family, decoration flags).
        W, H: output width and height in pixels.

    Returns:
        full_img: RGB uint8 array of shape ``(H, W, 3)``.
        mask: uint8 array of shape ``(H, W)``: 255 on curve pixels, 0 elsewhere.
        curves_data: list of ``(x, y)`` arrays in data units, one per curve.
    """
    fig_w, fig_h = W / 100, H / 100
    curves_data = _sample_curves(config)

    # ========== FULL IMAGE ==========
    fig1, ax1 = plt.subplots(figsize=(fig_w, fig_h))
    _style_axes(ax1, config)
    for cx, cy in curves_data:
        ax1.plot(cx, cy, 'k-', linewidth=config.curve_lw)
    _draw_envelopes(ax1, config, curves_data)
    if config.add_arrows:
        _draw_fuel_flow_arrows(ax1, config, curves_data)
    _draw_extra_dashes(ax1, config)
    if config.add_text_boxes:
        _draw_text_boxes(ax1, config)
    if config.add_drag_labels:
        _draw_drag_labels(ax1, config)
    if config.add_vmax_line:
        _draw_vmax_line(ax1, curves_data)

    # Freeze the geometry the mask must reproduce:
    # - the axes rectangle;
    # - the final view limits (set_xticks/set_yticks may nudge them);
    # - the tight crop box (in inches) of everything drawn.
    axes_box = ax1.get_position()
    xlim, ylim = ax1.get_xlim(), ax1.get_ylim()
    crop = fig1.get_tightbbox(fig1.canvas.get_renderer()).padded(0.02)

    full_img = _figure_to_rgb(fig1, crop, facecolor='white')
    full_img = cv2.resize(full_img, (W, H))

    # ========== MASK (curves only) ==========
    fig2, ax2 = plt.subplots(figsize=(fig_w, fig_h))
    ax2.set_position(axes_box)
    ax2.set_xlim(xlim)
    ax2.set_ylim(ylim)
    ax2.axis('off')
    fig2.patch.set_facecolor('black')
    ax2.set_facecolor('black')
    for cx, cy in curves_data:
        ax2.plot(cx, cy, 'w-', linewidth=config.curve_lw + 0.6)

    mask = _figure_to_rgb(fig2, crop, facecolor='black')
    mask = cv2.resize(mask, (W, H))
    mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
    _, mask = cv2.threshold(mask, 20, 255, cv2.THRESH_BINARY)

    return full_img, mask, curves_data


def random_config() -> ChartConfig:
    """Sample a randomized ChartConfig for one synthetic training chart.

    The axis limits come from small fixed menus of (min, max) pairs. The curve
    count, curve family and line width are drawn at random. Each decoration
    flag is switched on independently with its own probability.
    """
    x_menu = ((0.30, 0.95), (0.30, 1.00), (0.40, 1.10), (0.50, 1.20),
              (0.50, 1.30), (0.50, 1.40), (0.60, 1.40))
    y_menu = ((0.04, 0.15), (0.05, 0.15), (0.06, 0.17), (0.07, 0.18),
              (0.08, 0.19), (0.08, 0.20), (0.05, 0.14))
    # 'peaked' appears twice, so it is sampled twice as often as each other family.
    families = ('peaked', 'peaked', 'rising', 'falling', 'mixed')
    # Probability of each optional decoration. Dict order fixes the order in
    # which the random draws happen.
    flag_odds = {
        'add_grid': 0.95,
        'add_arrows': 0.85,
        'add_envelope_optimum': 0.70,
        'add_envelope_endurance': 0.35,
        'add_vmax_line': 0.25,
        'add_text_boxes': 0.75,
        'add_fuel_labels': 0.80,
        'add_drag_labels': 0.55,
    }

    lo_x, hi_x = random.choice(x_menu)
    lo_y, hi_y = random.choice(y_menu)
    n_curves = random.randint(4, 12)
    curve_type = random.choice(families)
    curve_lw = random.uniform(0.3, 0.6)
    flags = {name: random.random() < odds for name, odds in flag_odds.items()}

    return ChartConfig(
        x_min=lo_x, x_max=hi_x,
        y_min=lo_y, y_max=hi_y,
        n_curves=n_curves,
        curve_type=curve_type,
        curve_lw=curve_lw,
        **flags,
    )


def add_scan_artifacts(img, strength=1.0, mask=None):
    """Degrade a rendered chart so it looks scanned/photocopied.

    Steps, in order:
    - a small random rotation (white fill);
    - brightness and contrast jitter;
    - additive Gaussian noise;
    - a JPEG re-encode at random quality 50-80.

    Args:
        img: RGB image array. Assumed uint8 ``(H, W, 3)``, as produced by
            ``draw_chart_matplotlib``.
        strength: scales the rotation angle and the noise standard deviation.
            Brightness, contrast and JPEG quality are not scaled.
        mask: optional 2-D uint8 target mask aligned with ``img``. When given,
            it gets the same rotation (nearest-neighbour, zero fill) so it
            stays pixel-aligned with the returned image. Photometric steps are
            not applied to it.

    Returns:
        The degraded RGB uint8 image, or ``(image, rotated_mask)`` when
        ``mask`` is provided.
    """
    # One angle shared by image and mask keeps them geometrically aligned.
    angle = random.uniform(-1.2, 1.2) * strength
    pil_img = Image.fromarray(img).rotate(angle, fillcolor=(255, 255, 255), resample=Image.BICUBIC)

    # Brightness/contrast
    pil_img = ImageEnhance.Brightness(pil_img).enhance(random.uniform(0.90, 1.10))
    pil_img = ImageEnhance.Contrast(pil_img).enhance(random.uniform(0.88, 1.12))

    # Additive Gaussian noise (NumPy's global RNG)
    arr = np.asarray(pil_img, dtype=np.float32) / 255.0
    noise = np.random.normal(0, 0.012 * strength, arr.shape)
    arr = np.clip(arr + noise, 0, 1)

    # JPEG round-trip is always applied (the original note says this matches "v5").
    with io.BytesIO() as buf:
        Image.fromarray((arr * 255).astype(np.uint8)).save(
            buf, format='JPEG', quality=random.randint(50, 80))
        buf.seek(0)
        with Image.open(buf) as jpeg:
            out = np.array(jpeg.convert('RGB'))

    if mask is None:
        return out
    rotated_mask = Image.fromarray(mask).rotate(angle, resample=Image.NEAREST, fillcolor=0)
    return out, np.array(rotated_mask)

In [None]:
# Smoke-test the synthetic generator: render one chart, degrade it, and show input vs. target.
config = random_config()
full_img, mask, curves = draw_chart_matplotlib(config, W=512, H=512)
full_img = add_scan_artifacts(full_img)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
views = (
    (full_img, f'Input ({config.n_curves} curves, lw={config.curve_lw:.2f})', {}),
    (mask, 'Target Mask', {'cmap': 'gray'}),
)
for panel, (picture, caption, imshow_kwargs) in zip(axes, views):
    panel.imshow(picture, **imshow_kwargs)
    panel.set_title(caption)
    panel.axis('off')
plt.tight_layout()
plt.show()

## 2. Dataset ve DataLoader

In [None]:
class SyntheticChartDataset(Dataset):
    """Map-style dataset that renders a brand-new random chart on every access.

    ``idx`` is ignored: every ``__getitem__`` call samples a fresh
    ``random_config()``. ``size`` only fixes how many items one pass yields.

    NOTE(review): ``add_scan_artifacts`` rotates the input image by a small
    random angle, but the mask is returned unrotated. Thin-curve targets can
    therefore be offset by a few pixels near the image borders.
    """

    def __init__(self, size=1000, img_size=512, add_artifacts=True):
        self.size, self.img_size, self.add_artifacts = size, img_size, add_artifacts

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        side = self.img_size
        image, target, _curves = draw_chart_matplotlib(random_config(), W=side, H=side)
        if self.add_artifacts:
            image = add_scan_artifacts(image)

        # Image: HWC uint8 -> CHW float in [0, 1]. Mask: (H, W) -> (1, H, W), scaled by 1/255.
        image_t = torch.from_numpy(image).permute(2, 0, 1).float().div(255.0)
        target_t = torch.from_numpy(target).float().div(255.0).unsqueeze(0)
        return image_t, target_t


# Sanity check: tensor shapes and value ranges of one generated sample.
dataset = SyntheticChartDataset(size=100)
img, mask = dataset[0]
print(f"Image shape: {img.shape}, Mask shape: {mask.shape}")
for label, tensor in (("Image", img), ("Mask", mask)):
    print(f"{label} range: [{tensor.min():.3f}, {tensor.max():.3f}]")

## 3. U-Net Model

In [None]:
class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages.

    Spatial size is preserved (``padding=1``); channels go ``in_ch -> out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    """U-Net that outputs a per-pixel curve probability map.

    Args:
        in_channels: channels of the input image (3 for RGB).
        out_channels: channels of the output map.
        features: encoder widths, one per resolution level (any sequence of
            ints). The bottleneck uses ``2 * features[-1]`` channels.

    ``forward`` returns sigmoid probabilities of shape
    ``(N, out_channels, H, W)``, with the same H, W as the input. The input
    must be large enough to survive ``len(features)`` 2x poolings. Odd sizes
    are handled by resizing each upsampled map to its skip connection.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        # Tuple default avoids a shared mutable default; lists still work.
        features = list(features)
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)

        # Encoder: one DoubleConv per resolution level.
        for f in features:
            self.downs.append(DoubleConv(in_channels, f))
            in_channels = f

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder, deepest level first: a 2x transposed-conv upsample, then a
        # DoubleConv over the concatenated [skip, upsampled] channels.
        # Stored as one flat list (up, conv, up, conv, ...). Changing this
        # layout would rename state_dict keys such as 'ups.0.weight' and break
        # existing checkpoints.
        for f in reversed(features):
            self.ups.append(nn.ConvTranspose2d(f * 2, f, 2, 2))
            self.ups.append(DoubleConv(f * 2, f))

        self.final = nn.Conv2d(features[0], out_channels, 1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        for i, skip in enumerate(reversed(skip_connections)):
            x = self.ups[2 * i](x)
            # Pooling floors odd sizes; match the skip's spatial dims only.
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)
            x = self.ups[2 * i + 1](torch.cat([skip, x], dim=1))

        return torch.sigmoid(self.final(x))


# Smoke-test the architecture on one random 512x512 RGB input.
model = UNet()
x = torch.randn(1, 3, 512, 512)
y = model(x)
n_params = sum(p.numel() for p in model.parameters())
print(f"Input: {x.shape} -> Output: {y.shape}")
print(f"Model parameters: {n_params:,}")

## 4. Loss Function

In [None]:
class DiceBCELoss(nn.Module):
    """Weighted sum of soft-Dice loss and binary cross-entropy.

    ``loss = dice_weight * (1 - dice) + (1 - dice_weight) * BCE``

    Input requirements:
    - ``pred`` must already be probabilities in [0, 1]. ``nn.BCELoss`` is
      used, matching the sigmoid output of ``UNet``.
    - ``target`` has the same shape, with values in [0, 1].

    The Dice term pools all pixels of the whole batch; it is not averaged
    per sample.
    """

    def __init__(self, dice_weight=0.5):
        super().__init__()
        self.dice_weight = dice_weight
        self.bce = nn.BCELoss()

    def forward(self, pred, target):
        bce_loss = self.bce(pred, target)

        # Soft Dice; `smooth` avoids 0/0 when prediction and target are both empty.
        smooth = 1e-5
        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work.
        pred_flat = pred.reshape(-1)
        target_flat = target.reshape(-1)
        intersection = (pred_flat * target_flat).sum()
        dice = (2.0 * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
        dice_loss = 1 - dice

        return self.dice_weight * dice_loss + (1 - self.dice_weight) * bce_loss

## 5. Eğitim

In [None]:
# Hyperparameters
BATCH_SIZE = 4
NUM_EPOCHS = 30
LEARNING_RATE = 1e-4
DATASET_SIZE = 2000  # samples drawn per epoch (each rendered on the fly)
IMG_SIZE = 512
NUM_WORKERS = 0  # 0 = load in the main process (works around worker crashes on Colab)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Pinned (page-locked) host memory only speeds up host->GPU copies. On a
# CPU-only runtime it is useless, and recent PyTorch versions warn about it.
PIN_MEMORY = device.type == 'cuda'

# Model, optimizer, loss
model = UNet().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# One cosine decay spanning NUM_EPOCHS scheduler.step() calls.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
criterion = DiceBCELoss(dice_weight=0.5)

# Dataset
train_dataset = SyntheticChartDataset(size=DATASET_SIZE, img_size=IMG_SIZE, add_artifacts=True)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``.

    Args:
        model: network to train (switched to train mode here).
        loader: iterable of ``(images, masks)`` batches.
        criterion: loss callable ``criterion(preds, masks)`` returning a scalar tensor.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        Average of the per-batch losses, weighted by batch size so a smaller
        final batch is not over-weighted. Returns ``nan`` if the loader
        yields no batches; the original divided by zero in that case.
    """
    model.train()
    loss_sum = 0.0
    n_seen = 0

    pbar = tqdm(loader, desc='Training')
    for imgs, masks in pbar:
        imgs = imgs.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        preds = model(imgs)
        loss = criterion(preds, masks)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_n = imgs.size(0)
        loss_sum += batch_loss * batch_n
        n_seen += batch_n
        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    return loss_sum / n_seen if n_seen else float('nan')


def visualize_predictions(model, dataset, device, n=3):
    """Plot input / ground truth / prediction rows for dataset items ``0..n-1``.

    ``squeeze=False`` keeps ``axes`` 2-D, so ``n=1`` works. The model's
    train/eval mode is restored afterwards, even if inference fails.
    """
    was_training = model.training
    model.eval()
    fig, axes = plt.subplots(n, 3, figsize=(12, 4 * n), squeeze=False)

    try:
        with torch.no_grad():
            for i, row in enumerate(axes):
                img, mask = dataset[i]
                pred = model(img.unsqueeze(0).to(device)).cpu().squeeze()
                panels = (
                    (img.permute(1, 2, 0), 'Input', None),
                    (mask.squeeze(), 'Ground Truth', 'gray'),
                    (pred, 'Prediction', 'gray'),
                )
                for ax, (image, title, cmap) in zip(row, panels):
                    ax.imshow(image, cmap=cmap)
                    ax.set_title(title)
                    ax.axis('off')
    finally:
        model.train(was_training)

    plt.tight_layout()
    plt.show()

In [None]:
# Training loop
history = []

# SyntheticChartDataset renders a brand-new random chart on every __getitem__
# call, so a single dataset/loader already yields fresh data every epoch.
train_dataset = SyntheticChartDataset(size=DATASET_SIZE, img_size=IMG_SIZE, add_artifacts=True)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Pinned host memory only helps host->GPU copies.
    pin_memory=device.type == 'cuda',
)

for epoch in range(NUM_EPOCHS):
    epoch_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    scheduler.step()

    history.append(epoch_loss)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Loss: {epoch_loss:.4f} - LR: {scheduler.get_last_lr()[0]:.6f}")

    # Visualize predictions every 5 epochs
    if (epoch + 1) % 5 == 0:
        visualize_predictions(model, train_dataset, device, n=2)

In [None]:
# Training-loss curve. Epochs are numbered from 1 to match the training log.
epochs = range(1, len(history) + 1)
plt.figure(figsize=(10, 4))
plt.plot(epochs, history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True)
plt.show()

## 6. Model Kaydetme

In [None]:
# Save only the model weights (state_dict) to the current working directory.
checkpoint_path = 'curve_unet_colab.pt'
torch.save(model.state_dict(), checkpoint_path)
print(f"Model saved: {checkpoint_path}")

# Optional: also save the weights to Google Drive
# from google.colab import drive
# drive.mount('/content/drive')
# torch.save(model.state_dict(), '/content/drive/MyDrive/curve_unet_colab.pt')

## 7. Test - Gerçek Görüntü Üzerinde

In [None]:
def predict_and_show(model, image_path_or_array, device, img_size=512, threshold=0.5):
    """Segment curves in a chart image and plot the input, raw and binary outputs.

    Args:
        model: network expected to return per-pixel probabilities (e.g. ``UNet``).
        image_path_or_array: path to an image file, or an image array
            (expected uint8). Arrays are assumed RGB. Grayscale ``(H, W)`` and
            RGBA ``(H, W, 4)`` arrays are converted to RGB.
        device: torch device used for inference.
        img_size: square side the image is resized to before inference. The
            default 512 matches the training resolution.
        threshold: probability cut-off for the binary mask.

    Returns:
        uint8 mask (0/255) at the original image resolution.

    Raises:
        FileNotFoundError: if a path is given and OpenCV cannot read it.
    """
    if isinstance(image_path_or_array, str):
        bgr = cv2.imread(image_path_or_array)
        if bgr is None:
            raise FileNotFoundError(f"Could not read image: {image_path_or_array}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    else:
        img = np.asarray(image_path_or_array)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

    # Resize and normalize
    orig_h, orig_w = img.shape[:2]
    img_resized = cv2.resize(img, (img_size, img_size))
    img_tensor = torch.from_numpy(img_resized).permute(2, 0, 1).float() / 255.0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            pred = model(img_tensor.unsqueeze(0).to(device)).cpu().squeeze().numpy()
    finally:
        model.train(was_training)

    pred_binary = (pred > threshold).astype(np.uint8) * 255

    # Nearest-neighbour upscaling keeps the mask strictly 0/255.
    pred_full = cv2.resize(pred_binary, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (img, 'Input', None),
        (pred, 'Raw Prediction', 'gray'),
        (pred_full, 'Binary Mask', 'gray'),
    )
    for ax, (image, title, cmap) in zip(axes, panels):
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    return pred_full


# Sanity check on a freshly generated, non-square (800x600) synthetic chart.
test_config = random_config()
test_img = add_scan_artifacts(draw_chart_matplotlib(test_config, W=800, H=600)[0])
pred_mask = predict_and_show(model, test_img, device)

In [None]:
# Load a real image (optional)
# from google.colab import files
# uploaded = files.upload()
# for fn in uploaded.keys():
#     pred_mask = predict_and_show(model, fn, device)