# In [None]:
# segmentation
# https://grok.com/chat/fb747915-a406-41b4-9b29-c5e6cc86eb49

import os
import json
import base64
import torch
import torch.nn as nn
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from io import BytesIO
from tqdm import tqdm


# Custom dataset
class SegmentationDataset(Dataset):
    """Instance-segmentation dataset that yields Mask R-CNN style targets.

    The annotation file is a JSON object mapping image file names (relative to
    ``img_dir``) to records of the form::

        {"objects": [{"bbox": {"x_min": ..., "y_min": ..., "x_max": ..., "y_max": ...},
                      "label": <class id, converted to int64>,
                      "base64_mask": <Base64-encoded mask image>},
                     ...]}

    Each item is ``(image, target)``. ``target`` is a dict with these keys:

    * ``boxes``: float32, shape (N, 4).
    * ``labels``: int64, shape (N,).
    * ``masks``: binary uint8, shape (N, H, W).
    * ``image_id``: int64 tensor ``[idx]``.

    When an image has no objects the tensors are still correctly shaped, e.g.
    boxes (0, 4) and masks (0, H, W) using the image size.
    """

    def __init__(self, img_dir, annotation_file, transform=None):
        """
        Args:
            img_dir: directory containing the image files.
            annotation_file: path to the UTF-8 JSON annotation file described
                in the class docstring.
            transform: optional callable applied to the PIL image only (the
                target is NOT transformed), e.g. ``ToTensor()``.
        """
        self.img_dir = img_dir
        self.transform = transform
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(annotation_file, "r", encoding="utf-8") as f:
            self.annotations = json.load(f)
        # Fixed ordering so that an integer index maps to exactly one image.
        self.imgs = list(self.annotations.keys())

    def __len__(self):
        return len(self.imgs)

    @staticmethod
    def _decode_mask(b64_mask):
        """Decode a Base64-encoded image into a binary (0/1) uint8 array."""
        mask_bytes = base64.b64decode(b64_mask)
        with Image.open(BytesIO(mask_bytes)) as mask_img:
            mask = np.array(mask_img.convert("L"), dtype=np.uint8)  # grayscale
        return (mask > 0).astype(np.uint8)

    @staticmethod
    def _build_target(boxes, labels, masks, idx, height, width):
        """Assemble the Mask R-CNN target dict from per-object Python lists.

        Args:
            boxes: list of ``[x_min, y_min, x_max, y_max]``.
            labels: list of class ids.
            masks: list of 2-D uint8 arrays, all of the same shape.
            idx: dataset index, stored as ``image_id``.
            height, width: image size, used only to shape an empty mask tensor.
        """
        # reshape keeps the (N, 4) contract even when N == 0.
        boxes_t = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels_t = torch.as_tensor(labels, dtype=torch.int64)
        if masks:
            # Stacking first avoids torch's slow list-of-ndarrays conversion.
            masks_np = np.stack(masks).astype(np.uint8, copy=False)
        else:
            masks_np = np.zeros((0, height, width), dtype=np.uint8)
        return {
            "boxes": boxes_t,
            "labels": labels_t,
            "masks": torch.from_numpy(masks_np),
            "image_id": torch.tensor([idx]),
        }

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # convert() returns a new, fully loaded image, so the file can be closed.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        width, height = img.size

        annot = self.annotations[img_name]
        boxes, labels, masks = [], [], []
        for obj in annot["objects"]:
            bbox = obj["bbox"]
            boxes.append([bbox["x_min"], bbox["y_min"], bbox["x_max"], bbox["y_max"]])
            labels.append(obj["label"])
            masks.append(self._decode_mask(obj["base64_mask"]))

        target = self._build_target(boxes, labels, masks, idx, height, width)

        # NOTE: only the image is transformed; geometric transforms would
        # desynchronise it from boxes/masks.
        if self.transform is not None:
            img = self.transform(img)

        return img, target


# Model factory
def get_model(num_classes, hidden_layer=256):
    """Build a pretrained Mask R-CNN whose heads are resized for ``num_classes``.

    Args:
        num_classes: number of output classes, including the background class.
        hidden_layer: channel width of the replacement mask-predictor head.
            The default 256 matches the value previously hard-coded here.

    Returns:
        The torchvision Mask R-CNN model with new box and mask predictors.
    """
    # NOTE(review): `pretrained=True` is deprecated in newer torchvision
    # (>= 0.13) in favour of `weights="DEFAULT"`; kept for older releases.
    model = maskrcnn_resnet50_fpn(pretrained=True)

    # Replace the box classifier so it predicts `num_classes` classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = (
        torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
            in_features, num_classes
        )
    )

    # Replace the mask head; its input width is read from the existing head.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = (
        torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
            in_features_mask, hidden_layer, num_classes
        )
    )

    return model


# Training loop
def train_model(model, data_loader, optimizer, device, num_epochs=10):
    """Train a detection model in place for ``num_epochs`` epochs.

    Args:
        model: a model that, in train mode, returns a dict of loss tensors
            when called as ``model(images, targets)``.
        data_loader: iterable of ``(images, targets)`` batches. ``images`` is
            a sequence of tensors; ``targets`` is a sequence of dicts of tensors.
        optimizer: optimizer over the model's trainable parameters.
        device: device that every image and target tensor is moved to.
        num_epochs: number of passes over ``data_loader``.

    Returns:
        list[float]: the mean (summed) loss of each epoch. It is ``nan`` for an
        epoch that yielded no batches; previously that case raised
        ZeroDivisionError.
    """
    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        running_loss = 0.0
        num_batches = 0
        for images, targets in tqdm(data_loader):
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # Total loss is the unweighted sum of all loss components.
            loss_dict = model(images, targets)
            losses = sum(loss_dict.values())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            running_loss += losses.item()
            num_batches += 1

        # Count batches actually seen instead of len(), which may not exist
        # or be zero.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}, Loss: {mean_loss}")
        epoch_losses.append(mean_loss)
    return epoch_losses


# Validation
def evaluate_model(model, data_loader, device):
    """Run inference over ``data_loader`` and print one sample prediction.

    The model is switched to eval mode and run without gradient tracking.
    Targets in the batches are ignored. Only the first prediction of the first
    batch is printed, so the log is not flooded with one dump per batch.

    Args:
        model: a model that, in eval mode, returns a list of per-image
            prediction dicts when called as ``model(images)``.
        data_loader: iterable of ``(images, targets)`` batches.
        device: device the images are moved to.
    """
    model.eval()
    sample_printed = False
    with torch.no_grad():
        for images, _targets in tqdm(data_loader):
            images = [image.to(device) for image in images]
            outputs = model(images)
            if not sample_printed and outputs:
                print("Sample predictions:", outputs[0])
                sample_printed = True


def _collate_batch(batch):
    """Regroup ``[(img, target), ...]`` into ``((img, ...), (target, ...))``.

    Detection images may differ in size, so they are not stacked into one tensor.
    """
    return tuple(zip(*batch))


def main():
    """Train Mask R-CNN on the configured data, run a sample evaluation, save weights."""
    # Configuration (paths are placeholders).
    img_dir = "path/to/images"
    annotation_file = "path/to/annotations.json"
    num_classes = 2  # number of object classes + 1 for background
    batch_size = 2  # kept small because masks need a lot of memory
    num_epochs = 10
    learning_rate = 0.005

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda") if use_cuda else torch.device("cpu")
    print(f"Using device: {device}")

    shared_loader_args = {"batch_size": batch_size, "collate_fn": _collate_batch}

    train_set = SegmentationDataset(img_dir, annotation_file, transform=ToTensor())
    train_loader = DataLoader(train_set, shuffle=True, **shared_loader_args)

    # For demonstration the validation set is the training data itself;
    # use a separate held-out split in a real task.
    val_set = SegmentationDataset(img_dir, annotation_file, transform=ToTensor())
    val_loader = DataLoader(val_set, shuffle=False, **shared_loader_args)

    model = get_model(num_classes)
    model.to(device)

    trainable = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.SGD(
        trainable,
        lr=learning_rate,
        momentum=0.9,
        weight_decay=0.0005,
    )

    train_model(model, train_loader, optimizer, device, num_epochs)
    evaluate_model(model, val_loader, device)

    torch.save(model.state_dict(), "mask_rcnn_model.pth")
    print("Model saved as mask_rcnn_model.pth")


# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()