## Семинар 3: Бинарная сегментация для Луны

В этом семинаре мы будем решать задачу **бинарной сегментации** — определять, где на лунной поверхности находятся камни.


In [None]:
import os
import random
from PIL import Image
import numpy as np
import pandas as pd
from collections import Counter, defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import seaborn as sns

import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

from tqdm import tqdm

# Seed every RNG used in the notebook (PyTorch, NumPy and the stdlib `random` module).
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)


## Создание кастомного Dataset для Moon Segmentation

### Структура данных:
- **render/** - исходные изображения лунной поверхности (1000 изображений)
- **ground/** - бинарные маски сегментации (0 = фон, 255 = камни)



In [None]:
class MoonSegmentationDataset(Dataset):
    """Lunar-surface images paired with binary rock masks.

    Images are read from ``<root_dir>/images/<image_folder>/<image_id>.png`` and
    masks from ``<root_dir>/images/<mask_folder>/ground<mask_id>.png``, where
    ``mask_id`` is ``image_id`` with the substring ``render`` removed
    (e.g. ``render0001`` -> ``ground0001.png``).

    Args:
        root_dir: dataset root directory.
        image_folder: sub-folder of ``images/`` that holds the input images.
        mask_folder: sub-folder of ``images/`` that holds the masks.
        image_ids: file names without the ``.png`` extension; when ``None`` every
            ``.png`` file found in the image folder is used (sorted by name).
        augmentation: optional callable invoked as ``augmentation(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` keys.
        preprocessing: optional callable with the same calling convention,
            applied after ``augmentation``.
    """

    def __init__(self, root_dir, image_folder='render', mask_folder='ground',
                 image_ids=None, augmentation=None, preprocessing=None):
        self.root_dir = root_dir
        self.image_folder = image_folder
        self.mask_folder = mask_folder
        self.augmentation = augmentation
        self.preprocessing = preprocessing

        if image_ids is None:
            images_dir = os.path.join(root_dir, 'images', image_folder)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping stable between runs and machines.
            self.image_ids = sorted(
                os.path.splitext(file_name)[0]
                for file_name in os.listdir(images_dir)
                if file_name.endswith('.png')
            )
        else:
            self.image_ids = image_ids

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair for sample ``idx``.

        Raises:
            FileNotFoundError: if the image or the mask file cannot be read.
        """
        image_id = self.image_ids[idx]

        image_path = os.path.join(self.root_dir, 'images', self.image_folder, f"{image_id}.png")

        # Mask files do not carry the "render" prefix: render0001 -> ground0001.
        mask_id = image_id.replace('render', '')
        mask_path = os.path.join(self.root_dir, 'images', self.mask_folder, f"ground{mask_id}.png")

        # cv2.imread signals failure by returning None instead of raising,
        # so check explicitly to get a readable error message.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # OpenCV loads images as BGR; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Binarize the mask to {0, 1}: 0 = background, 1 = rocks.
        mask = (mask > 0).astype(np.float32)

        # Augmentations come first, preprocessing second.
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask


### Создание аугментаций с Albumentations

Для задачи сегментации важно применять одинаковые трансформации к изображению и маске!


In [None]:
# Training augmentations: resize to 256x256, random flips and 90-degree rotations,
# a random shift/scale/rotate, optional noise or blur, and brightness/contrast jitter.
train_augmentation = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=45, p=0.5),
    # With probability 0.3, apply one of: Gaussian noise, Gaussian blur, median blur.
    A.OneOf([
        A.GaussNoise(var_limit=(10.0, 50.0)),
        A.GaussianBlur(blur_limit=(3, 7)),
        A.MedianBlur(blur_limit=5),
    ], p=0.3),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
])

# Preprocessing shared by train and validation: normalize with the same mean/std
# constants that `denormalize` inverts, then convert to torch tensors.
preprocessing = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# Validation uses only the deterministic resize (no random augmentations).
val_augmentation = A.Compose([
    A.Resize(256, 256),
])



### Разделение данных на train/val


In [None]:
# Path to the dataset root
DATA_ROOT = "data/MOON_SEGMENTATION_BINARY/"

images_dir = os.path.join(DATA_ROOT, 'images', 'render')
# os.listdir returns entries in arbitrary order. Sort before taking the first 100
# so that the subset (and therefore the seeded train/val split) is reproducible.
all_images = sorted(
    img.replace('.png', '') for img in os.listdir(images_dir) if img.endswith('.png')
)[:100]

print(f"Всего изображений: {len(all_images)}")

train_ids, val_ids = train_test_split(all_images, test_size=0.2, random_state=42)

print(f"Train: {len(train_ids)} изображений")
print(f"Val: {len(val_ids)} изображений")

train_dataset = MoonSegmentationDataset(
    root_dir=DATA_ROOT,
    image_folder='render',
    mask_folder='ground',
    image_ids=train_ids,
    augmentation=train_augmentation,
    preprocessing=preprocessing
)

val_dataset = MoonSegmentationDataset(
    root_dir=DATA_ROOT,
    image_folder='render',
    mask_folder='ground',
    image_ids=val_ids,
    augmentation=val_augmentation,
    preprocessing=preprocessing
)

BATCH_SIZE = 4
NUM_WORKERS = 0

# Pinned host memory only helps host -> GPU copies, so enable it only when CUDA is available.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available()
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available()
)

print(f"Батчей в train: {len(train_loader)}")
print(f"Батчей в val: {len(val_loader)}")


## Визуализация данных

Посмотрим на примеры изображений и их маски сегментации


In [None]:
def denormalize(img_tensor):
    """Undo the mean/std normalization so an image can be displayed.

    Args:
        img_tensor: normalized image tensor whose channel dimension has size 3
            and is followed by two spatial dimensions, e.g. ``(3, H, W)`` or
            ``(B, 3, H, W)``.

    Returns:
        Tensor of the same shape, on the same device, clamped to ``[0, 1]``.
    """
    # Create the constants on the input's device so CUDA tensors work as well.
    mean = torch.tensor([0.485, 0.456, 0.406], device=img_tensor.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=img_tensor.device).view(3, 1, 1)
    img = img_tensor * std + mean
    return img.clamp(0, 1)

# Take one training batch and show up to four images (top row) with their masks (bottom row).
images, masks = next(iter(train_loader))

print(f"Размер батча изображений: {images.shape}")
print(f"Размер батча масок: {masks.shape}")

fig, axes = plt.subplots(2, 4, figsize=(16, 8))
image_axes, mask_axes = axes

shown = min(4, len(images))
for i, (image_ax, mask_ax) in enumerate(zip(image_axes[:shown], mask_axes[:shown])):
    # CHW tensor -> HWC array, as expected by imshow.
    img = denormalize(images[i]).permute(1, 2, 0).numpy()
    mask = masks[i].numpy()

    image_ax.imshow(img)
    image_ax.set_title(f"Изображение {i+1}")
    image_ax.axis('off')

    mask_ax.imshow(mask, cmap='gray')
    mask_ax.set_title(f"Маска {i+1} (белый=камни)")
    mask_ax.axis('off')

plt.suptitle("Примеры данных: Бинарная сегментация камней на Луне", fontsize=16)
plt.tight_layout()
plt.show()


---

## Кастомная U-Net архитектура

### Что такое U-Net?
**U-Net** - это архитектура для семантической сегментации, предложенная в 2015 году.

### Основные компоненты:
1. **Encoder (Downsampling)** - сжимает изображение, извлекая признаки
2. **Bottleneck** - самый глубокий слой с максимальным количеством каналов
3. **Decoder (Upsampling)** - восстанавливает разрешение
4. **Skip Connections** - соединяют encoder и decoder для сохранения деталей



In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

    ``padding=1`` keeps the spatial size unchanged. The first convolution maps
    ``in_channels`` to ``out_channels``; the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)



### Полная кастомная U-Net архитектура


In [None]:
class UNet(nn.Module):
    """U-Net built from ``DoubleConv`` blocks.

    The encoder applies one ``DoubleConv`` per entry of ``features``, each followed
    by 2x2 max pooling. The bottleneck doubles the last feature count. The decoder
    mirrors the encoder with transposed convolutions, concatenating the matching
    encoder output (skip connection) before every ``DoubleConv``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels produced by the final 1x1 convolution.
        features: channel counts of the encoder levels, shallowest first;
            defaults to ``[64, 128, 256, 512]``.

    The forward pass returns the raw output of the final convolution (no activation).
    """

    def __init__(self, in_channels=3, out_channels=1, features=None):
        super().__init__()

        # Build the default per instance instead of sharing a mutable default argument.
        if features is None:
            features = [64, 128, 256, 512]

        self.encoder_blocks = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        prev_channels = in_channels
        for feature in features:
            self.encoder_blocks.append(DoubleConv(prev_channels, feature))
            prev_channels = feature

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # decoder_blocks alternates [upsample, DoubleConv, upsample, DoubleConv, ...].
        for feature in reversed(features):
            self.decoder_blocks.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            # Input has feature * 2 channels: the skip connection plus the upsampled tensor.
            self.decoder_blocks.append(
                DoubleConv(feature * 2, feature)
            )

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

        total_params = sum(p.numel() for p in self.parameters())
        print(f"Параметров: {total_params:,}")
        print(f"Уровней encoder: {len(features)}")
        print(f"Конфигурация каналов: {features}")

    def forward(self, x):
        skip_connections = []

        for encoder_block in self.encoder_blocks:
            x = encoder_block(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # The deepest skip connection is consumed first.
        for level, skip_connection in enumerate(reversed(skip_connections)):
            x = self.decoder_blocks[2 * level](x)

            # Pooling floors odd sizes, so the upsampled map can be smaller than the
            # skip connection; only the spatial dimensions matter for the concat.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = F.interpolate(x, size=skip_connection.shape[2:],
                                  mode='bilinear', align_corners=True)

            x = torch.cat([skip_connection, x], dim=1)
            x = self.decoder_blocks[2 * level + 1](x)

        return self.final_conv(x)


# Smoke test: build a U-Net and pass one random 256x256 RGB input through it
# to compare the input and output shapes.
model = UNet(in_channels=3, out_channels=1, features=[64, 128, 256, 512])

test_input = torch.randn(1, 3, 256, 256)
test_output = model(test_input)

print(f"   Вход:  {test_input.shape}")
print(f"   Выход: {test_output.shape}")


---

## Функции потерь и метрики для сегментации

### Метрики сегментации:
- **Dice Loss** - популярная функция потерь для сегментации
- **IoU (Intersection over Union)** - метрика качества сегментации
- **Pixel Accuracy** - точность предсказания пикселей


In [None]:
class DiceLoss(nn.Module):
    """Dice loss for binary segmentation.

    Dice = 2 * |X ∩ Y| / (|X| + |Y|), computed over the whole batch on
    sigmoid probabilities; the loss is ``1 - Dice``.

    Args:
        smooth: small constant added to numerator and denominator so the ratio
            stays defined when both prediction and target are empty.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, predictions, targets):
        """Return the scalar Dice loss.

        Args:
            predictions: raw logits.
            targets: ground-truth mask with the same number of elements as
                ``predictions``.
        """
        predictions = torch.sigmoid(predictions)

        # reshape (unlike view) also accepts non-contiguous tensors.
        predictions = predictions.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (predictions * targets).sum()
        dice = (2. * intersection + self.smooth) / (predictions.sum() + targets.sum() + self.smooth)

        return 1 - dice


class CombinedLoss(nn.Module):
    """Weighted sum of binary cross-entropy (on logits) and Dice loss.

    Args:
        bce_weight: multiplier of the BCE term.
        dice_weight: multiplier of the Dice term.
    """

    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, predictions, targets):
        # Both terms take raw logits; each loss applies the sigmoid itself.
        weighted_bce = self.bce_weight * self.bce(predictions, targets)
        weighted_dice = self.dice_weight * self.dice(predictions, targets)
        return weighted_bce + weighted_dice


def dice_coefficient(predictions, targets, threshold=0.5, smooth=1e-6):
    """Return the Dice coefficient of thresholded predictions as a Python float.

    Args:
        predictions: raw logits; sigmoid is applied, then ``> threshold`` binarizes.
        targets: ground-truth mask with the same number of elements.
        threshold: probability cut-off for the positive class.
        smooth: small constant that keeps the ratio defined for empty masks.
    """
    predictions = torch.sigmoid(predictions)
    predictions = (predictions > threshold).float()

    # reshape (unlike view) also accepts non-contiguous tensors.
    predictions = predictions.reshape(-1)
    targets = targets.reshape(-1)

    intersection = (predictions * targets).sum()
    dice = (2. * intersection + smooth) / (predictions.sum() + targets.sum() + smooth)

    return dice.item()


def iou_score(predictions, targets, threshold=0.5, smooth=1e-6):
    """Return the intersection-over-union of thresholded predictions as a Python float.

    Args:
        predictions: raw logits; sigmoid is applied, then ``> threshold`` binarizes.
        targets: ground-truth mask with the same number of elements.
        threshold: probability cut-off for the positive class.
        smooth: small constant that keeps the ratio defined for empty masks.
    """
    predictions = torch.sigmoid(predictions)
    predictions = (predictions > threshold).float()

    # reshape (unlike view) also accepts non-contiguous tensors.
    predictions = predictions.reshape(-1)
    targets = targets.reshape(-1)

    intersection = (predictions * targets).sum()
    union = predictions.sum() + targets.sum() - intersection

    iou = (intersection + smooth) / (union + smooth)

    return iou.item()


def pixel_accuracy(predictions, targets, threshold=0.5):
    """Return the fraction of pixels whose thresholded prediction equals the target.

    Args:
        predictions: raw logits; sigmoid is applied, then ``> threshold`` binarizes.
        targets: ground-truth mask with the same number of elements.
        threshold: probability cut-off for the positive class.
    """
    predictions = torch.sigmoid(predictions)
    predictions = (predictions > threshold).float()

    # Flatten both tensors (as the other metrics do) so that e.g. [B, 1, H, W]
    # logits and [B, H, W] masks are compared element by element instead of
    # being silently broadcast against each other.
    predictions = predictions.reshape(-1)
    targets = targets.reshape(-1)

    correct = (predictions == targets).float().sum()
    total = targets.numel()

    return (correct / total).item()



In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one training pass over ``train_loader``.

    Masks get a channel dimension via ``unsqueeze(1)`` before the loss is computed.

    Returns:
        Tuple ``(avg_loss, avg_dice, avg_iou)`` averaged over batches.
    """
    model.train()

    totals = {'loss': 0.0, 'dice': 0.0, 'iou': 0.0}
    progress = tqdm(train_loader, desc="Training")

    for images, masks in progress:
        images = images.to(device)
        masks = masks.unsqueeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Metrics are for monitoring only, so no gradients are tracked here.
        with torch.no_grad():
            batch_stats = {
                'loss': loss.item(),
                'dice': dice_coefficient(outputs, masks),
                'iou': iou_score(outputs, masks),
            }

        for stat_name, stat_value in batch_stats.items():
            totals[stat_name] += stat_value

        progress.set_postfix({
            stat_name: f'{stat_value:.4f}'
            for stat_name, stat_value in batch_stats.items()
        })

    num_batches = len(train_loader)
    return (
        totals['loss'] / num_batches,
        totals['dice'] / num_batches,
        totals['iou'] / num_batches,
    )


def validate_epoch(model, val_loader, criterion, device):
    """Evaluate the model on ``val_loader`` without tracking gradients.

    Masks get a channel dimension via ``unsqueeze(1)`` before the loss is computed.

    Returns:
        Tuple ``(avg_loss, avg_dice, avg_iou, avg_acc)`` averaged over batches.
    """
    model.eval()

    totals = {'loss': 0.0, 'dice': 0.0, 'iou': 0.0, 'acc': 0.0}
    progress = tqdm(val_loader, desc="Validation")

    with torch.no_grad():
        for images, masks in progress:
            images = images.to(device)
            masks = masks.unsqueeze(1).to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            batch_stats = {
                'loss': loss.item(),
                'dice': dice_coefficient(outputs, masks),
                'iou': iou_score(outputs, masks),
                'acc': pixel_accuracy(outputs, masks),
            }

            for stat_name, stat_value in batch_stats.items():
                totals[stat_name] += stat_value

            progress.set_postfix({
                stat_name: f'{stat_value:.4f}'
                for stat_name, stat_value in batch_stats.items()
            })

    num_batches = len(val_loader)
    return (
        totals['loss'] / num_batches,
        totals['dice'] / num_batches,
        totals['iou'] / num_batches,
        totals['acc'] / num_batches,
    )



### Основной цикл обучения


In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs=20, save_path='best_unet.pth'):
    """Train ``model`` for ``num_epochs`` and checkpoint the best validation Dice.

    Args:
        model: network to train; it is moved to ``device``.
        train_loader: loader passed to ``train_epoch``.
        val_loader: loader passed to ``validate_epoch``.
        criterion: loss called as ``criterion(outputs, masks)``.
        optimizer: optimizer for the model parameters.
        scheduler: optional LR scheduler. A ``ReduceLROnPlateau`` is stepped with
            the validation Dice (so it should be created with ``mode='max'``);
            any other scheduler is stepped without arguments.
        device: device used for training and validation.
        num_epochs: number of epochs to run.
        save_path: file the best checkpoint is written to.

    Returns:
        Dict of per-epoch lists with keys ``train_loss``, ``train_dice``,
        ``train_iou``, ``val_loss``, ``val_dice``, ``val_iou``, ``val_acc`` and
        ``lr`` (learning rate of the first parameter group).
    """
    model = model.to(device)

    history = {
        'train_loss': [],
        'train_dice': [],
        'train_iou': [],
        'val_loss': [],
        'val_dice': [],
        'val_iou': [],
        'val_acc': [],
        'lr': []
    }

    best_val_dice = 0.0

    for epoch in range(num_epochs):
        print(f"Эпоха {epoch+1}/{num_epochs}")

        # Training pass
        train_loss, train_dice, train_iou = train_epoch(
            model, train_loader, criterion, optimizer, device
        )

        # Validation pass
        val_loss, val_dice, val_iou, val_acc = validate_epoch(
            model, val_loader, criterion, device
        )

        # Only ReduceLROnPlateau takes a metric in step(); other schedulers
        # must be stepped without arguments.
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_dice)
            else:
                scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        # Record the epoch in the history
        history['train_loss'].append(train_loss)
        history['train_dice'].append(train_dice)
        history['train_iou'].append(train_iou)
        history['val_loss'].append(val_loss)
        history['val_dice'].append(val_dice)
        history['val_iou'].append(val_iou)
        history['val_acc'].append(val_acc)
        history['lr'].append(current_lr)

        # Report the results
        print(f"\n Результаты эпохи {epoch+1}:")
        print(f"  Train - Loss: {train_loss:.4f}, Dice: {train_dice:.4f}, IoU: {train_iou:.4f}")
        print(f"  Val   - Loss: {val_loss:.4f}, Dice: {val_dice:.4f}, IoU: {val_iou:.4f}, Acc: {val_acc:.4f}")
        print(f"  LR: {current_lr:.6f}")

        # Keep the checkpoint with the best validation Dice
        # ('epoch' is stored 0-based).
        if val_dice > best_val_dice:
            best_val_dice = val_dice
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_dice': val_dice,
                'val_iou': val_iou,
            }, save_path)
            print(f"  Сохранена лучшая модель! Dice: {val_dice:.4f}")

    print(f"Лучший Val Dice: {best_val_dice:.4f}")

    return history



### Запуск обучения


In [None]:
# Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"  Используем устройство: {device}")

model = UNet(in_channels=3, out_channels=1, features=[32, 64, 128, 256])

# Loss function (combined BCE + Dice)
criterion = CombinedLoss(bce_weight=0.5, dice_weight=0.5)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Learning-rate scheduler (halves the LR on a plateau).
# mode='max' because train_model steps it with the validation Dice.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='max',
    factor=0.5, 
    patience=3, 
)

NUM_EPOCHS = 5

history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    num_epochs=NUM_EPOCHS,
    save_path='best_moon_unet.pth'
)


---

## Визуализация результатов обучения


In [None]:
def plot_training_history(history):
    """Plot loss, Dice, IoU and learning-rate curves, then print the last-epoch values.

    Args:
        history: dict of per-epoch lists with keys ``train_loss``, ``val_loss``,
            ``train_dice``, ``val_dice``, ``train_iou``, ``val_iou``, ``val_acc``
            and ``lr``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # (axis, train key, val key, train label, val label, panel title, y-axis label)
    metric_panels = (
        (axes[0, 0], 'train_loss', 'val_loss', 'Train Loss', 'Val Loss', 'Loss', 'Loss'),
        (axes[0, 1], 'train_dice', 'val_dice', 'Train Dice', 'Val Dice', 'Dice Coefficient', 'Dice'),
        (axes[1, 0], 'train_iou', 'val_iou', 'Train IoU', 'Val IoU', 'IoU Score', 'IoU'),
    )
    for ax, train_key, val_key, train_label, val_label, title, y_label in metric_panels:
        ax.plot(history[train_key], label=train_label, marker='o')
        ax.plot(history[val_key], label=val_label, marker='o')
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Learning rate is drawn on a log scale.
    lr_ax = axes[1, 1]
    lr_ax.plot(history['lr'], label='Learning Rate', marker='o', color='green')
    lr_ax.set_title('Learning Rate', fontsize=14, fontweight='bold')
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('LR')
    lr_ax.set_yscale('log')
    lr_ax.legend()
    lr_ax.grid(True, alpha=0.3)

    plt.suptitle('История обучения U-Net', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

    # Values of the final epoch.
    final_report = (
        ('Train Loss', 'train_loss'),
        ('Val Loss', 'val_loss'),
        ('Train Dice', 'train_dice'),
        ('Val Dice', 'val_dice'),
        ('Train IoU', 'train_iou'),
        ('Val IoU', 'val_iou'),
        ('Val Accuracy', 'val_acc'),
    )
    for caption, key in final_report:
        print(f"{caption}: {history[key][-1]:.4f}")

# Plot the curves recorded while training the plain U-Net.
plot_training_history(history)


## Визуализация предсказаний модели


In [None]:
def visualize_predictions(model, dataloader, device, num_samples=4):
    """Show image, ground-truth mask, prediction and overlay for the first batch.

    Args:
        model: segmentation model that returns logits.
        dataloader: loader yielding ``(images, masks)``; only its first batch is used.
        device: device the model lives on.
        num_samples: maximum number of rows to draw; capped by the batch size.
    """
    model.eval()

    images, masks = next(iter(dataloader))

    with torch.no_grad():
        # Bring the logits back to the CPU right away: the masks stay on the CPU and
        # the per-sample metrics below need both tensors on the same device.
        outputs = model(images.to(device)).cpu()

    predictions = (torch.sigmoid(outputs) > 0.5).float()
    images = images.cpu()
    masks = masks.cpu()

    # Never create more rows than there are samples in the batch.
    num_rows = min(num_samples, len(images))
    if num_rows <= 0:
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(num_rows, 4, figsize=(16, 4 * num_rows), squeeze=False)

    for i in range(num_rows):
        img = denormalize(images[i]).permute(1, 2, 0).numpy()
        mask_true = masks[i].numpy()
        mask_pred = predictions[i, 0].numpy()

        axes[i, 0].imshow(img)
        axes[i, 0].set_title('Исходное изображение', fontsize=12)
        axes[i, 0].axis('off')

        axes[i, 1].imshow(mask_true, cmap='gray')
        axes[i, 1].set_title('Истинная маска', fontsize=12)
        axes[i, 1].axis('off')

        axes[i, 2].imshow(mask_pred, cmap='gray')
        axes[i, 2].set_title('Предсказание', fontsize=12)
        axes[i, 2].axis('off')

        # Paint predicted rock pixels pure green on a copy of the image.
        overlay = img.copy()
        overlay[mask_pred > 0.5] = [0, 1, 0]

        axes[i, 3].imshow(overlay)
        axes[i, 3].set_title('Наложение', fontsize=12)
        axes[i, 3].axis('off')

        dice = dice_coefficient(outputs[i:i+1], masks[i:i+1])
        iou = iou_score(outputs[i:i+1], masks[i:i+1])

        fig.text(0.5, 1 - (i + 0.5) / num_rows,
                 f'Dice: {dice:.4f} | IoU: {iou:.4f}',
                 ha='center', fontsize=10, fontweight='bold')

    plt.suptitle('Предсказания модели U-Net', fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.show()

# Show predictions of the model as it is after the last training epoch.
visualize_predictions(model, val_loader, device, num_samples=4)


## Загрузка и тестирование лучшей модели


In [None]:
# Re-create the architecture with the same `features` used for training,
# then load the checkpoint written by train_model.
best_model = UNet(in_channels=3, out_channels=1, features=[32, 64, 128, 256])

checkpoint = torch.load('best_moon_unet.pth', map_location=device)
best_model.load_state_dict(checkpoint['model_state_dict'])
best_model = best_model.to(device)

# 'epoch' is stored 0-based by train_model, hence the +1 for display.
print(f"Эпоха: {checkpoint['epoch'] + 1}")
print(f"Val Dice: {checkpoint['val_dice']:.4f}")
print(f"Val IoU: {checkpoint['val_iou']:.4f}")

visualize_predictions(best_model, val_loader, device, num_samples=6)


In [None]:
def predict_image(model, image_path, device, threshold=0.5):
    """Predict a binary mask for a single image file.

    The image goes through ``val_augmentation`` and then ``preprocessing``
    before it is fed to the model.

    Args:
        model: segmentation model that returns logits.
        image_path: path to the image file.
        device: device the model lives on.
        threshold: probability cut-off for the positive class.

    Returns:
        Tuple ``(prediction, original_image)``: the thresholded mask as a NumPy
        array with singleton dimensions squeezed out, and the RGB image as
        loaded (before resizing).

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    model.eval()

    # cv2.imread signals failure by returning None instead of raising.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_image = image.copy()

    augmented = val_augmentation(image=image)
    preprocessed = preprocessing(image=augmented['image'])
    image_tensor = preprocessed['image'].unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image_tensor)
        prediction = torch.sigmoid(output)
        prediction = (prediction > threshold).float()

    prediction = prediction.cpu().squeeze().numpy()

    return prediction, original_image


def visualize_single_prediction(image_path, prediction, original_image):
    """Show the original image, the predicted mask and a green overlay side by side.

    Args:
        image_path: path of the image; only its base name is used in the title.
        prediction: 2-D predicted mask.
        original_image: RGB image with values in the 0-255 range.
    """
    # Resize the image to the mask resolution and scale it to [0, 1].
    mask_height, mask_width = prediction.shape[0], prediction.shape[1]
    overlay = cv2.resize(original_image, (mask_width, mask_height))
    overlay = overlay.astype(float) / 255.0
    # Saturate the green channel wherever the mask is positive.
    green_channel = overlay[:, :, 1]
    overlay[:, :, 1] = np.where(prediction > 0.5, 1, green_channel)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    panels = (
        (original_image, 'Исходное изображение', {}),
        (prediction, 'Предсказание модели', {'cmap': 'gray'}),
        (overlay, 'Наложение', {}),
    )
    for ax, (picture, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(picture, **imshow_kwargs)
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.axis('off')

    plt.suptitle(f'Инференс: {os.path.basename(image_path)}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


# Pick a random validation image and run single-image inference with the best U-Net.
random_image_id = np.random.choice(val_ids)
image_path = os.path.join(DATA_ROOT, 'images', 'render', f'{random_image_id}.png')

prediction, original_image = predict_image(best_model, image_path, device)
visualize_single_prediction(image_path, prediction, original_image)


---

## Часть 2: U-Net с ResNet18 Backbone

Теперь создадим более мощную версию U-Net, используя предобученную ResNet18 как энкодер (backbone).



In [None]:
import torchvision.models as models

class ResNet18UNet(nn.Module):
    """U-Net whose encoder is a torchvision ResNet18.

    The five encoder stages are taken from ResNet18 (stem, layer1..layer4). The
    decoder upsamples with transposed convolutions and concatenates the matching
    encoder output before every ``DoubleConv``. A last transposed convolution
    restores the input resolution and a 1x1 convolution produces the logits.

    Args:
        out_channels: number of output channels.
        pretrained: load ImageNet weights into the encoder when ``True``.
    """

    def __init__(self, out_channels=1, pretrained=True):
        super().__init__()

        # Prefer the `weights=` API (the `pretrained=` flag is deprecated in newer
        # torchvision releases) and fall back to the legacy flag when it is missing.
        weights_enum = getattr(models, 'ResNet18_Weights', None)
        if weights_enum is not None:
            resnet = models.resnet18(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )
        else:
            resnet = models.resnet18(pretrained=pretrained)

        # Encoder stages taken from ResNet18
        self.encoder1 = nn.Sequential(
            resnet.conv1,      # 64 channels, stride=2
            resnet.bn1,
            resnet.relu
        )
        self.pool1 = resnet.maxpool  # stride=2

        self.encoder2 = resnet.layer1  # 64 channels
        self.encoder3 = resnet.layer2  # 128 channels, stride=2
        self.encoder4 = resnet.layer3  # 256 channels, stride=2
        self.encoder5 = resnet.layer4  # 512 channels, stride=2

        # Decoder block 1 (512 -> 256)
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder4 = DoubleConv(256 + 256, 256)  # concat with encoder4

        # Decoder block 2 (256 -> 128)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = DoubleConv(128 + 128, 128)  # concat with encoder3

        # Decoder block 3 (128 -> 64)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = DoubleConv(64 + 64, 64)  # concat with encoder2

        # Decoder block 4 (64 -> 64)
        self.upconv1 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.decoder1 = DoubleConv(64 + 64, 64)  # concat with encoder1

        # Final upsampling back to the input resolution
        self.final_upconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)

        # Output layer
        self.out_conv = nn.Conv2d(32, out_channels, kernel_size=1)

    @staticmethod
    def _match_spatial_size(x, reference):
        """Resize ``x`` to the height/width of ``reference`` when they differ."""
        if x.shape[2:] != reference.shape[2:]:
            x = F.interpolate(x, size=reference.shape[2:],
                              mode='bilinear', align_corners=True)
        return x

    def forward(self, x):
        # Shape comments assume H and W are divisible by 32; for other sizes the
        # upsampled maps are resized to match the skip connections.
        # Input: [B, 3, H, W]

        enc1 = self.encoder1(x)            # [B, 64, H/2, W/2]
        enc1_pooled = self.pool1(enc1)     # [B, 64, H/4, W/4]

        enc2 = self.encoder2(enc1_pooled)  # [B, 64, H/4, W/4]
        enc3 = self.encoder3(enc2)         # [B, 128, H/8, W/8]
        enc4 = self.encoder4(enc3)         # [B, 256, H/16, W/16]
        enc5 = self.encoder5(enc4)         # [B, 512, H/32, W/32]

        # Decoder block 1
        dec4 = self._match_spatial_size(self.upconv4(enc5), enc4)  # [B, 256, H/16, W/16]
        dec4 = torch.cat([dec4, enc4], dim=1)                      # [B, 512, H/16, W/16]
        dec4 = self.decoder4(dec4)                                 # [B, 256, H/16, W/16]

        # Decoder block 2
        dec3 = self._match_spatial_size(self.upconv3(dec4), enc3)  # [B, 128, H/8, W/8]
        dec3 = torch.cat([dec3, enc3], dim=1)                      # [B, 256, H/8, W/8]
        dec3 = self.decoder3(dec3)                                 # [B, 128, H/8, W/8]

        # Decoder block 3
        dec2 = self._match_spatial_size(self.upconv2(dec3), enc2)  # [B, 64, H/4, W/4]
        dec2 = torch.cat([dec2, enc2], dim=1)                      # [B, 128, H/4, W/4]
        dec2 = self.decoder2(dec2)                                 # [B, 64, H/4, W/4]

        # Decoder block 4
        dec1 = self._match_spatial_size(self.upconv1(dec2), enc1)  # [B, 64, H/2, W/2]
        dec1 = torch.cat([dec1, enc1], dim=1)                      # [B, 128, H/2, W/2]
        dec1 = self.decoder1(dec1)                                 # [B, 64, H/2, W/2]

        # Final upsampling, aligned with the input resolution
        final = self._match_spatial_size(self.final_upconv(dec1), x)  # [B, 32, H, W]

        # Output layer
        out = self.out_conv(final)         # [B, out_channels, H, W]

        return out


# Smoke test: build the ResNet18-UNet, pass one random 256x256 RGB input through it
# and report the input/output shapes and parameter counts.
resnet_unet = ResNet18UNet(out_channels=1, pretrained=True)

test_input = torch.randn(1, 3, 256, 256)
test_output = resnet_unet(test_input)

print(f"   Вход:  {test_input.shape}")
print(f"   Выход: {test_output.shape}")

total_params = sum(p.numel() for p in resnet_unet.parameters())
trainable_params = sum(p.numel() for p in resnet_unet.parameters() if p.requires_grad)

print(f"   Всего параметров: {total_params:,}")
print(f"   Обучаемых параметров: {trainable_params:,}")


### Обучение ResNet18-UNet

Теперь обучим новую модель на том же датасете и сравним результаты с простой U-Net.


In [None]:
resnet_model = ResNet18UNet(out_channels=1, pretrained=True)

resnet_criterion = CombinedLoss(bce_weight=0.5, dice_weight=0.5)

# Split the parameters into encoder (ResNet) and decoder groups by parameter name.
encoder_params = []
decoder_params = []

for name, param in resnet_model.named_parameters():
    if 'encoder' in name:
        encoder_params.append(param)
    else:
        decoder_params.append(param)

# The pretrained encoder is fine-tuned with a smaller LR; the new decoder uses a larger one.
resnet_optimizer = torch.optim.Adam([
    {'params': encoder_params, 'lr': 1e-4},  # smaller LR for the pretrained layers
    {'params': decoder_params, 'lr': 1e-3}   # regular LR for the new layers
])

# train_model steps the scheduler with the validation Dice, which should grow,
# so the plateau detection must use mode='max'. With mode='min' every improvement
# would count as "no progress" and the LR would be reduced while the model is
# still getting better.
resnet_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    resnet_optimizer,
    mode='max',
    factor=0.5,
    patience=5,
)

NUM_EPOCHS_RESNET = 10
SAVE_PATH_RESNET = 'best_resnet18_unet.pth'

print(f"   Эпох: {NUM_EPOCHS_RESNET}")
print(f"   Encoder LR: 1e-4 (fine-tuning)")
print(f"   Decoder LR: 1e-3")
print(f"   Модель будет сохранена в: {SAVE_PATH_RESNET}")


In [None]:
# Train the ResNet18-UNet on the same loaders as the plain U-Net.
# The 'lr' history tracks the first parameter group, i.e. the encoder LR.
resnet_history = train_model(
    model=resnet_model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=resnet_criterion,
    optimizer=resnet_optimizer,
    scheduler=resnet_scheduler,
    num_epochs=NUM_EPOCHS_RESNET,
    device=device,
    save_path=SAVE_PATH_RESNET
)


### Визуализация результатов ResNet18-UNet


In [None]:
# Plot the curves recorded while training the ResNet18-UNet.
plot_training_history(resnet_history)


In [None]:
# No pretrained download is needed here: all weights come from the checkpoint.
best_resnet_model = ResNet18UNet(out_channels=1, pretrained=False)

checkpoint = torch.load(SAVE_PATH_RESNET, map_location=device)
best_resnet_model.load_state_dict(checkpoint['model_state_dict'])
best_resnet_model = best_resnet_model.to(device)

# 'epoch' is stored 0-based by train_model; show it 1-based, like the epoch
# printed for the plain U-Net checkpoint.
print(f"   Best epoch: {checkpoint['epoch'] + 1}")
print(f"   Best val Dice: {checkpoint['val_dice']:.4f}")


In [None]:

# Show predictions of the best ResNet18-UNet checkpoint on the first validation batch.
visualize_predictions(best_resnet_model, val_loader, device, num_samples=4)


### Сравнение моделей: U-Net vs ResNet18-UNet

Давайте сравним результаты обеих моделей на одних и тех же изображениях.


In [None]:
def compare_models(simple_model, resnet_model, dataloader, device, num_samples=3):
    """Plot both models' predictions for the first batch side by side.

    Columns: input image, ground truth, plain U-Net prediction, ResNet18-UNet
    prediction, and the pixels where the two predictions disagree.

    Args:
        simple_model: plain U-Net that returns logits.
        resnet_model: ResNet18-UNet that returns logits.
        dataloader: loader yielding ``(images, masks)``; only its first batch is used.
        device: device both models live on.
        num_samples: maximum number of rows to draw; capped by the batch size.
    """
    simple_model.eval()
    resnet_model.eval()

    images, masks = next(iter(dataloader))

    with torch.no_grad():
        device_images = images.to(device)
        # Bring the logits back to the CPU right away: the masks stay on the CPU and
        # the per-sample metrics below need both tensors on the same device.
        simple_outputs = simple_model(device_images).cpu()
        resnet_outputs = resnet_model(device_images).cpu()

    simple_preds = (torch.sigmoid(simple_outputs) > 0.5).float()
    resnet_preds = (torch.sigmoid(resnet_outputs) > 0.5).float()
    images = images.cpu()
    masks = masks.cpu()

    # Never create more rows than there are samples in the batch.
    num_rows = min(num_samples, len(images))
    if num_rows <= 0:
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(num_rows, 5, figsize=(20, 4 * num_rows), squeeze=False)

    for i in range(num_rows):
        # De-normalize the image for display
        img = denormalize(images[i]).permute(1, 2, 0).numpy()
        # Masks may arrive as [B, H, W] or [B, 1, H, W]; keep only the H x W plane.
        # (Indexing masks[i, 0] on a [B, H, W] batch would yield a single row.)
        mask_true = masks[i].reshape(masks.shape[-2:]).numpy()
        simple_pred = simple_preds[i, 0].numpy()
        resnet_pred = resnet_preds[i, 0].numpy()

        axes[i, 0].imshow(img)
        axes[i, 0].set_title('Исходное', fontsize=12)
        axes[i, 0].axis('off')

        axes[i, 1].imshow(mask_true, cmap='gray')
        axes[i, 1].set_title('Ground Truth', fontsize=12)
        axes[i, 1].axis('off')

        axes[i, 2].imshow(simple_pred, cmap='gray')
        simple_dice = dice_coefficient(simple_outputs[i:i+1], masks[i:i+1])
        simple_iou = iou_score(simple_outputs[i:i+1], masks[i:i+1])
        axes[i, 2].set_title(f'Simple U-Net\nDice: {simple_dice:.3f} | IoU: {simple_iou:.3f}', fontsize=10)
        axes[i, 2].axis('off')

        axes[i, 3].imshow(resnet_pred, cmap='gray')
        resnet_dice = dice_coefficient(resnet_outputs[i:i+1], masks[i:i+1])
        resnet_iou = iou_score(resnet_outputs[i:i+1], masks[i:i+1])
        axes[i, 3].set_title(f'ResNet18-UNet\nDice: {resnet_dice:.3f} | IoU: {resnet_iou:.3f}', fontsize=10)
        axes[i, 3].axis('off')

        # Non-zero wherever the two binary predictions differ.
        diff = np.abs(simple_pred - resnet_pred)
        axes[i, 4].imshow(diff, cmap='hot')
        axes[i, 4].set_title('Difference\n(red = disagree)', fontsize=10)
        axes[i, 4].axis('off')

    plt.suptitle('Сравнение моделей: U-Net vs ResNet18-UNet', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Compare the two best checkpoints on the first validation batch.
compare_models(best_model, best_resnet_model, val_loader, device, num_samples=3)
