# Модель сегментации людей на изображениях (бинарные маски силуэтов)

Датасет: [Supervisely Filtered Segmentation Person Dataset](https://www.kaggle.com/datasets/tapakah68/supervisely-filtered-segmentation-person-dataset)


In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
from PIL import Image
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
from model import HumanSegmentationModel

In [14]:
class HumanSegmentationDataset(Dataset):
    """Paired image/mask dataset for binary person segmentation.

    ``data_dir`` must contain an ``images/`` subdirectory and a ``masks/``
    subdirectory. Each image is matched to the mask with the same file stem:
    ``<stem>.png`` is tried first, then a mask with that stem and any other
    extension.

    Each item is a tuple ``(image, mask)``:

    * ``image`` -- float tensor of shape (3, H, W), normalized with the
      ImageNet mean/std;
    * ``mask`` -- float32 tensor of shape (1, H, W) with values in {0, 1};
      grayscale mask pixels strictly greater than ``mask_threshold`` are 1.

    Args:
        data_dir: Root directory containing ``images/`` and ``masks/``.
        img_size: Size passed to ``PIL.Image.resize`` (``(width, height)``).
        mask_threshold: Grayscale level above which a mask pixel counts as
            foreground. The default of 128 suits masks stored as 0/255.

    Raises:
        FileNotFoundError: from ``__getitem__`` when an image has no mask.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, data_dir, img_size=(256, 256), mask_threshold=128):
        self.data_dir = data_dir
        self.img_size = img_size
        self.mask_threshold = mask_threshold
        self.images_dir = os.path.join(data_dir, 'images')
        self.masks_dir = os.path.join(data_dir, 'masks')

        # Sorted so the index -> file mapping is reproducible; os.listdir
        # returns entries in arbitrary, platform-dependent order.
        self.image_files = sorted(
            f for f in os.listdir(self.images_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

        # The transforms are stateless, so build them once, not per item.
        self._to_tensor = transforms.ToTensor()
        self._normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

        # stem -> mask filename. Built lazily the first time a mask is not
        # found under the default "<stem>.png" name, so the masks directory
        # is scanned at most once instead of once per item.
        self._mask_index = None

    def __len__(self):
        return len(self.image_files)

    def _find_mask_path(self, img_name):
        """Return the path of the mask whose file stem matches ``img_name``."""
        stem = os.path.splitext(img_name)[0]
        default_path = os.path.join(self.masks_dir, stem + '.png')
        if os.path.exists(default_path):
            return default_path

        if self._mask_index is None:
            index = {}
            for fname in sorted(os.listdir(self.masks_dir)):
                # First (alphabetical) file wins for duplicate stems.
                index.setdefault(os.path.splitext(fname)[0], fname)
            self._mask_index = index

        mask_name = self._mask_index.get(stem)
        if mask_name is None:
            raise FileNotFoundError(
                f"No mask found for image '{img_name}' in '{self.masks_dir}'"
            )
        return os.path.join(self.masks_dir, mask_name)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_name)
        mask_path = self._find_mask_path(img_name)

        # convert() returns a fully loaded copy, so the source files can be
        # closed deterministically right away.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        with Image.open(mask_path) as msk:
            mask = msk.convert('L')

        # Bilinear for the photo; nearest for the mask so labels stay binary.
        image = image.resize(self.img_size, Image.BILINEAR)
        mask = mask.resize(self.img_size, Image.NEAREST)

        image = self._normalize(self._to_tensor(image))

        binary = np.array(mask) > self.mask_threshold
        mask = torch.from_numpy(binary.astype(np.float32)).unsqueeze(0)

        return image, mask

In [15]:
def dice_coefficient(pred, target, smooth=1e-6):
    """Compute the soft Dice score between two tensors.

    Both tensors are flattened before comparison, so any (matching) shape
    works. Values are used as-is with no thresholding, so for probability
    maps this is the "soft" Dice.

    Returns:
        0-d tensor equal to
        ``(2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)``.
    """
    p = pred.reshape(-1)
    t = target.reshape(-1)
    overlap = torch.sum(p * t)
    denominator = torch.sum(p) + torch.sum(t)
    return (2. * overlap + smooth) / (denominator + smooth)

def iou_coefficient(pred, target, smooth=1e-6):
    """Compute the soft Intersection-over-Union (Jaccard) score of two tensors.

    Inputs are flattened and used without thresholding.

    Returns:
        0-d tensor equal to ``(I + smooth) / (sum(pred) + sum(target) - I + smooth)``,
        where ``I = sum(pred * target)``.
    """
    p, t = pred.reshape(-1), target.reshape(-1)
    inter = torch.sum(p * t)
    union = torch.sum(p) + torch.sum(t) - inter
    return (inter + smooth) / (union + smooth)

def save_visualization(images, masks, predictions, epoch, save_dir='visualizations', max_rows=3):
    """Save a grid of (image, ground-truth mask, binarized prediction) rows.

    Writes ``<save_dir>/epoch_<epoch>.png`` with one row per sample, up to
    ``max_rows`` rows. Does nothing beyond creating ``save_dir`` when there
    are no samples.

    Args:
        images: Sequence of ImageNet-normalized image tensors, each (3, H, W).
        masks: Sequence of mask tensors indexed as ``mask[0]`` -> (H, W).
        predictions: Sequence of prediction tensors indexed like ``masks``;
            values > 0.5 are drawn as foreground.
        epoch: Epoch number used in the output file name.
        save_dir: Output directory (created if missing).
        max_rows: Maximum number of samples to draw.
    """
    os.makedirs(save_dir, exist_ok=True)

    n_rows = min(max_rows, len(images))
    if n_rows <= 0:
        return  # plt.subplots(0, ...) would raise

    # Undo the ImageNet normalization applied by the dataset.
    mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    # squeeze=False keeps ``axes`` 2-D even when there is a single row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(12, 4 * n_rows), squeeze=False)
    try:
        for row_axes, img_t, mask_t, pred_t in zip(axes, images, masks, predictions):
            img = img_t.detach().cpu().numpy()
            img = np.clip(img * std + mean, 0, 1).transpose(1, 2, 0)
            mask = mask_t[0].detach().cpu().numpy()
            pred = pred_t[0].detach().cpu().numpy()
            pred_binary = (pred > 0.5).astype(np.float32)

            panels = (
                (img, "Image", None),
                (mask, "GT Mask", 'gray'),
                (pred_binary, "Prediction", 'gray'),
            )
            for ax, (data, title, cmap) in zip(row_axes, panels):
                ax.imshow(data, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')

        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, f'epoch_{epoch}.png'), dpi=100)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

def _run_train_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimization pass over ``loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for images, masks in tqdm(loader, desc=desc):
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(loader)


def _run_validation(model, loader, device, num_samples=3):
    """Evaluate on ``loader``; return (mean Dice, mean IoU, sample triples).

    Metrics are averaged per batch and computed on the raw model outputs
    (soft Dice/IoU, no thresholding). Samples are (image, mask, output)
    tensors taken from the first element of the first ``num_samples`` batches.
    """
    model.eval()
    dice_sum, iou_sum = 0.0, 0.0
    samples = []
    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)

            dice_sum += dice_coefficient(outputs, masks).item()
            iou_sum += iou_coefficient(outputs, masks).item()

            if len(samples) < num_samples:
                samples.append((images[0], masks[0], outputs[0]))
    n_batches = len(loader)
    return dice_sum / n_batches, iou_sum / n_batches, samples


def _checkpoint(epoch, model, optimizer, val_dice, val_iou):
    """Bundle the training state that is written to disk after an epoch."""
    return {
        'epoch': epoch,  # 0-based, while the log prints epoch + 1
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'val_dice': val_dice,
        'val_iou': val_iou,
    }


def train(config=None):
    """Train ``HumanSegmentationModel`` and checkpoint it after every epoch.

    Each epoch trains on ``<train_dir>``, evaluates soft Dice/IoU on
    ``<val_dir>``, prints a summary line, and saves a visualization grid. It
    writes ``best_model.pt`` whenever validation Dice improves, and always
    writes ``last_model.pth`` into ``checkpoint_dir``.

    Args:
        config: Optional dict whose keys override the defaults below
            (``batch_size``, ``num_epochs``, ``learning_rate``, ``img_size``,
            ``train_dir``, ``val_dir``, ``checkpoint_dir``, ``device``).

    Raises:
        ValueError: if the training or validation dataset is empty.
    """
    defaults = {
        'batch_size': 4,
        'num_epochs': 50,
        'learning_rate': 0.001,
        'img_size': (256, 256),
        'train_dir': 'dataset/train',
        'val_dir': 'dataset/val',
        'checkpoint_dir': 'checkpoints',
        'device': 'cuda' if torch.cuda.is_available() else 'cpu'
    }
    config = {**defaults, **(config or {})}
    device = config['device']

    os.makedirs(config['checkpoint_dir'], exist_ok=True)

    train_dataset = HumanSegmentationDataset(config['train_dir'], config['img_size'])
    val_dataset = HumanSegmentationDataset(config['val_dir'], config['img_size'])
    # Fail early with a clear message instead of a ZeroDivisionError later.
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError(
            f"Empty dataset: {len(train_dataset)} training images in "
            f"'{config['train_dir']}', {len(val_dataset)} validation images "
            f"in '{config['val_dir']}'"
        )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=0
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=0
    )

    model = HumanSegmentationModel().to(device)
    # NOTE(review): nn.BCELoss expects probabilities, so this assumes the
    # model's forward ends in a sigmoid. Confirm against model.py; if it
    # returns logits, switch to nn.BCEWithLogitsLoss.
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

    # Both file names are kept as-is (note the .pt / .pth mismatch) because
    # downstream loaders may rely on them.
    best_path = os.path.join(config['checkpoint_dir'], 'best_model.pt')
    last_path = os.path.join(config['checkpoint_dir'], 'last_model.pth')

    best_dice = 0

    for epoch in range(config['num_epochs']):
        avg_train_loss = _run_train_epoch(
            model, train_loader, criterion, optimizer, device, f'Epoch {epoch+1}'
        )
        avg_val_dice, avg_val_iou, val_samples = _run_validation(model, val_loader, device)

        print(f'Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, '
              f'Val Dice: {avg_val_dice:.4f}, Val IoU: {avg_val_iou:.4f}')

        if val_samples:
            sample_imgs, sample_masks, sample_preds = zip(*val_samples)
            save_visualization(
                list(sample_imgs),
                list(sample_masks),
                list(sample_preds),
                epoch + 1
            )

        state = _checkpoint(epoch, model, optimizer, avg_val_dice, avg_val_iou)
        if avg_val_dice > best_dice:
            best_dice = avg_val_dice
            torch.save(state, best_path)
        torch.save(state, last_path)

    print(f'Best Dice: {best_dice:.4f}')

In [16]:
# Run training with the configuration defined inside train(); the per-epoch
# progress and metrics log is shown below.
train()

Epoch 1: 100%|██████████| 600/600 [03:28<00:00,  2.88it/s]


Epoch 1: Train Loss: 0.5639, Val Dice: 0.3585, Val IoU: 0.2207


Epoch 2: 100%|██████████| 600/600 [03:01<00:00,  3.30it/s]


Epoch 2: Train Loss: 0.5311, Val Dice: 0.4369, Val IoU: 0.2817


Epoch 3: 100%|██████████| 600/600 [03:10<00:00,  3.15it/s]


Epoch 3: Train Loss: 0.4682, Val Dice: 0.5564, Val IoU: 0.3899


Epoch 4: 100%|██████████| 600/600 [03:08<00:00,  3.17it/s]


Epoch 4: Train Loss: 0.4548, Val Dice: 0.5430, Val IoU: 0.3770


Epoch 5: 100%|██████████| 600/600 [03:07<00:00,  3.20it/s]


Epoch 5: Train Loss: 0.4310, Val Dice: 0.5289, Val IoU: 0.3632


Epoch 6: 100%|██████████| 600/600 [03:06<00:00,  3.21it/s]


Epoch 6: Train Loss: 0.4121, Val Dice: 0.5814, Val IoU: 0.4156


Epoch 7: 100%|██████████| 600/600 [03:09<00:00,  3.17it/s]


Epoch 7: Train Loss: 0.3859, Val Dice: 0.5931, Val IoU: 0.4262


Epoch 8: 100%|██████████| 600/600 [03:10<00:00,  3.15it/s]


Epoch 8: Train Loss: 0.3760, Val Dice: 0.6341, Val IoU: 0.4695


Epoch 9: 100%|██████████| 600/600 [03:11<00:00,  3.13it/s]


Epoch 9: Train Loss: 0.3573, Val Dice: 0.6292, Val IoU: 0.4649


Epoch 10: 100%|██████████| 600/600 [03:09<00:00,  3.16it/s]


Epoch 10: Train Loss: 0.3467, Val Dice: 0.6451, Val IoU: 0.4823


Epoch 11: 100%|██████████| 600/600 [03:27<00:00,  2.90it/s]


Epoch 11: Train Loss: 0.3384, Val Dice: 0.6454, Val IoU: 0.4842


Epoch 12: 100%|██████████| 600/600 [06:21<00:00,  1.57it/s]


Epoch 12: Train Loss: 0.3216, Val Dice: 0.6270, Val IoU: 0.4613


Epoch 13: 100%|██████████| 600/600 [05:10<00:00,  1.93it/s]


Epoch 13: Train Loss: 0.3150, Val Dice: 0.6930, Val IoU: 0.5367


Epoch 14: 100%|██████████| 600/600 [03:07<00:00,  3.21it/s]


Epoch 14: Train Loss: 0.3061, Val Dice: 0.6829, Val IoU: 0.5246


Epoch 15: 100%|██████████| 600/600 [03:06<00:00,  3.21it/s]


Epoch 15: Train Loss: 0.2974, Val Dice: 0.6834, Val IoU: 0.5258


Epoch 16: 100%|██████████| 600/600 [03:14<00:00,  3.09it/s]


Epoch 16: Train Loss: 0.2893, Val Dice: 0.6776, Val IoU: 0.5181


Epoch 17: 100%|██████████| 600/600 [07:25<00:00,  1.35it/s]


Epoch 17: Train Loss: 0.2902, Val Dice: 0.6800, Val IoU: 0.5215


Epoch 18: 100%|██████████| 600/600 [07:11<00:00,  1.39it/s]


Epoch 18: Train Loss: 0.2763, Val Dice: 0.7139, Val IoU: 0.5614


Epoch 19: 100%|██████████| 600/600 [07:06<00:00,  1.41it/s]


Epoch 19: Train Loss: 0.2703, Val Dice: 0.7062, Val IoU: 0.5524


Epoch 20: 100%|██████████| 600/600 [07:11<00:00,  1.39it/s]


Epoch 20: Train Loss: 0.2626, Val Dice: 0.6902, Val IoU: 0.5322


Epoch 21: 100%|██████████| 600/600 [07:05<00:00,  1.41it/s]


Epoch 21: Train Loss: 0.2600, Val Dice: 0.7004, Val IoU: 0.5448


Epoch 22: 100%|██████████| 600/600 [07:00<00:00,  1.43it/s]


Epoch 22: Train Loss: 0.2503, Val Dice: 0.7293, Val IoU: 0.5831


Epoch 23: 100%|██████████| 600/600 [06:49<00:00,  1.47it/s]


Epoch 23: Train Loss: 0.2371, Val Dice: 0.6853, Val IoU: 0.5280


Epoch 24: 100%|██████████| 600/600 [06:58<00:00,  1.43it/s]


Epoch 24: Train Loss: 0.2362, Val Dice: 0.7216, Val IoU: 0.5715


Epoch 25: 100%|██████████| 600/600 [07:12<00:00,  1.39it/s]


Epoch 25: Train Loss: 0.2264, Val Dice: 0.7172, Val IoU: 0.5669


Epoch 26: 100%|██████████| 600/600 [07:12<00:00,  1.39it/s]


Epoch 26: Train Loss: 0.2154, Val Dice: 0.7367, Val IoU: 0.5903


Epoch 27: 100%|██████████| 600/600 [07:06<00:00,  1.41it/s]


Epoch 27: Train Loss: 0.2061, Val Dice: 0.7454, Val IoU: 0.6009


Epoch 28: 100%|██████████| 600/600 [07:04<00:00,  1.41it/s]


Epoch 28: Train Loss: 0.2090, Val Dice: 0.7297, Val IoU: 0.5809


Epoch 29: 100%|██████████| 600/600 [07:08<00:00,  1.40it/s]


Epoch 29: Train Loss: 0.1898, Val Dice: 0.7459, Val IoU: 0.6016


Epoch 30: 100%|██████████| 600/600 [07:14<00:00,  1.38it/s]


Epoch 30: Train Loss: 0.1770, Val Dice: 0.7258, Val IoU: 0.5754


Epoch 31: 100%|██████████| 600/600 [07:09<00:00,  1.40it/s]


Epoch 31: Train Loss: 0.1641, Val Dice: 0.7436, Val IoU: 0.5995


Epoch 32: 100%|██████████| 600/600 [07:07<00:00,  1.40it/s]


Epoch 32: Train Loss: 0.1616, Val Dice: 0.7343, Val IoU: 0.5855


Epoch 33: 100%|██████████| 600/600 [07:14<00:00,  1.38it/s]


Epoch 33: Train Loss: 0.1533, Val Dice: 0.7475, Val IoU: 0.6058


Epoch 34: 100%|██████████| 600/600 [07:11<00:00,  1.39it/s]


Epoch 34: Train Loss: 0.1483, Val Dice: 0.7526, Val IoU: 0.6109


Epoch 35: 100%|██████████| 600/600 [07:07<00:00,  1.40it/s]


Epoch 35: Train Loss: 0.1390, Val Dice: 0.7417, Val IoU: 0.5961


Epoch 36: 100%|██████████| 600/600 [07:09<00:00,  1.40it/s]


Epoch 36: Train Loss: 0.1291, Val Dice: 0.7435, Val IoU: 0.5983


Epoch 37: 100%|██████████| 600/600 [07:12<00:00,  1.39it/s]


Epoch 37: Train Loss: 0.1232, Val Dice: 0.7580, Val IoU: 0.6187


Epoch 38: 100%|██████████| 600/600 [07:07<00:00,  1.40it/s]


Epoch 38: Train Loss: 0.1147, Val Dice: 0.7493, Val IoU: 0.6061


Epoch 39: 100%|██████████| 600/600 [07:07<00:00,  1.40it/s]


Epoch 39: Train Loss: 0.1147, Val Dice: 0.7490, Val IoU: 0.6069


Epoch 40: 100%|██████████| 600/600 [07:13<00:00,  1.38it/s]


Epoch 40: Train Loss: 0.1069, Val Dice: 0.7555, Val IoU: 0.6155


Epoch 41: 100%|██████████| 600/600 [07:08<00:00,  1.40it/s]


Epoch 41: Train Loss: 0.1017, Val Dice: 0.7543, Val IoU: 0.6133


Epoch 42: 100%|██████████| 600/600 [07:08<00:00,  1.40it/s]


Epoch 42: Train Loss: 0.0951, Val Dice: 0.7555, Val IoU: 0.6138


Epoch 43: 100%|██████████| 600/600 [07:09<00:00,  1.40it/s]


Epoch 43: Train Loss: 0.0901, Val Dice: 0.7523, Val IoU: 0.6112


Epoch 44: 100%|██████████| 600/600 [07:10<00:00,  1.39it/s]


Epoch 44: Train Loss: 0.0918, Val Dice: 0.7565, Val IoU: 0.6154


Epoch 45: 100%|██████████| 600/600 [07:06<00:00,  1.41it/s]


Epoch 45: Train Loss: 0.0845, Val Dice: 0.7525, Val IoU: 0.6106


Epoch 46: 100%|██████████| 600/600 [07:04<00:00,  1.41it/s]


Epoch 46: Train Loss: 0.0810, Val Dice: 0.7624, Val IoU: 0.6236


Epoch 47: 100%|██████████| 600/600 [07:09<00:00,  1.40it/s]


Epoch 47: Train Loss: 0.0803, Val Dice: 0.7541, Val IoU: 0.6122


Epoch 48: 100%|██████████| 600/600 [07:06<00:00,  1.41it/s]


Epoch 48: Train Loss: 0.0742, Val Dice: 0.7488, Val IoU: 0.6048


Epoch 49: 100%|██████████| 600/600 [07:06<00:00,  1.41it/s]


Epoch 49: Train Loss: 0.0727, Val Dice: 0.7479, Val IoU: 0.6046


Epoch 50: 100%|██████████| 600/600 [07:12<00:00,  1.39it/s]


Epoch 50: Train Loss: 0.0739, Val Dice: 0.7587, Val IoU: 0.6186
Best Dice: 0.7624
