In [None]:
from PIL import Image, ImageDraw, ImageFont
from glob import glob
import os, sys, random
import pandas as pd
import numpy as np
from shutil import copyfile
import matplotlib.pyplot as plt

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import xml.etree.ElementTree as ET
import torchvision.transforms as T
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.transform import GeneralizedRCNNTransform

import utils
from engine import train_one_epoch, evaluate

# Data locations (relative to the working directory).
images_path = './train/images'
annotations_path = './train/annotations'

# Class name -> label index. Index 0 is reserved for the background class,
# so the foreground classes are numbered from 1.
label_map = {
    class_name: label_index
    for label_index, class_name in enumerate(("Text", "Math", "Image"), start=1)
}

# Total number of classes the detector predicts: background + one per entry
# in label_map (1 + 3 = 4).
num_classes = len(label_map) + 1

class VOCDataset(Dataset):
    """Detection dataset built from Pascal VOC-style XML annotation files.

    Every ``*.xml`` file in ``annotations_path`` is paired with the image of the
    same name but a ``.jpg`` extension in ``images_path``. One item of the
    dataset is one image together with all of its annotated boxes; XML files
    without any ``<object>`` element produce no item.

    Args:
        images_path: Directory that holds the ``.jpg`` images.
        annotations_path: Directory that holds the ``.xml`` annotations.
        transforms: Optional callable ``(image, target) -> (image, target)``
            applied to every sample.

    Attributes:
        data: ``pandas.DataFrame`` with one row per annotated object and
            columns ``filename, fields, xmin, xmax, ymin, ymax`` (``fields``
            holds the class name).
        image_files: Unique image filenames in order of first appearance in
            ``data``; dataset indices are positions in this list.
    """

    def __init__(self, images_path, annotations_path, transforms=None):
        self.images_path = images_path
        self.annotations_path = annotations_path
        self.transforms = transforms

        # glob() order depends on the filesystem; sorting keeps the
        # index -> image mapping (and therefore any seeded split) reproducible.
        annotations = sorted(glob(os.path.join(self.annotations_path, '*.xml')))
        self.data = self._load_annotations(annotations)

        # Index the dataset by image rather than by object row. Row indexing
        # returned an image with k boxes k times per epoch, gave one image
        # several image_ids, and re-filtered the whole DataFrame per access.
        self._records = {
            name: group
            for name, group in self.data.groupby('filename', sort=False)
        }
        self.image_files = list(self._records)

    def _load_annotations(self, annotations):
        """Parse VOC XML files into a DataFrame with one row per ``<object>``.

        Args:
            annotations: Iterable of paths to ``.xml`` files.

        Returns:
            DataFrame with columns ``filename, fields, xmin, xmax, ymin, ymax``.
        """
        rows = []
        for xml_path in annotations:
            # Swap only the extension (str.replace would also hit '.xml'
            # occurring elsewhere in the name).
            stem = os.path.splitext(os.path.basename(xml_path))[0]
            filename = stem + '.jpg'
            root = ET.parse(xml_path).getroot()
            for node in root.iter('object'):
                obj_class = node.find('name').text
                # Some annotation tools write coordinates as floats ("12.0");
                # a plain int() would raise ValueError on those.
                xmin, xmax, ymin, ymax = (
                    int(float(node.find(f'bndbox/{tag}').text))
                    for tag in ('xmin', 'xmax', 'ymin', 'ymax')
                )
                rows.append([filename, obj_class, xmin, xmax, ymin, ymax])
        return pd.DataFrame(rows, columns=['filename', 'fields', 'xmin', 'xmax', 'ymin', 'ymax'])

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the image at position ``idx``.

        ``target`` contains ``boxes`` (float32 rows ``xmin, ymin, xmax, ymax``),
        ``labels`` (int64, looked up in the module-level ``label_map``; unknown
        class names map to 0), ``image_id`` (index of the returned image),
        ``area`` and ``iscrowd`` (all zeros). If the requested image file is
        missing, the next existing image (wrapping around) is returned instead.

        Raises:
            IndexError: ``idx`` is out of range.
            FileNotFoundError: none of the dataset's image files exist.
        """
        num_images = len(self.image_files)
        if not -num_images <= idx < num_images:
            raise IndexError(f"index {idx} out of range for {num_images} images")
        idx %= num_images

        # Skip missing files, but try each image at most once so an empty or
        # wrong image directory raises instead of recursing forever.
        for _ in range(num_images):
            filename = self.image_files[idx]
            img_path = os.path.join(self.images_path, filename)
            if os.path.exists(img_path):
                break
            print(f"File not found: {img_path}")
            idx = (idx + 1) % num_images
        else:
            raise FileNotFoundError(f"No image files found under {self.images_path}")

        # Context manager closes the underlying file handle.
        with Image.open(img_path) as pil_img:
            # NOTE(review): images are loaded as single-channel grayscale while
            # the pretrained backbone is normally fed 3-channel input — confirm
            # this is intended (alternative: convert("RGB")).
            img = pil_img.convert("L")

        record = self._records[filename]
        boxes = torch.as_tensor(
            record[['xmin', 'ymin', 'xmax', 'ymax']].to_numpy(dtype='float32'),
            dtype=torch.float32,
        ).reshape(-1, 4)
        labels = torch.as_tensor(
            [label_map.get(class_name, 0) for class_name in record['fields']],
            dtype=torch.int64,
        )

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            # Extra keys presumably needed by a COCO-style evaluator such as
            # engine.evaluate (verify against engine.py); not used for training.
            "area": (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]),
            "iscrowd": torch.zeros((labels.shape[0],), dtype=torch.int64),
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of distinct annotated images (not annotated objects)."""
        return len(self.image_files)


class Compose:
    """Apply a sequence of paired transforms one after another.

    Each transform is a callable taking ``(image, target)`` and returning a new
    ``(image, target)`` pair; the output of one step feeds the next.

    Args:
        transforms: Ordered sequence of such callables.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        current_img, current_tgt = image, target
        for step in self.transforms:
            current_img, current_tgt = step(current_img, current_tgt)
        return current_img, current_tgt

class ToTensor:
    """Convert ``image`` to a tensor with torchvision; pass ``target`` through unchanged."""

    def __call__(self, image, target):
        return T.functional.to_tensor(image), target

class RandomVerticalFlip:
    """Randomly flip an image top-to-bottom and mirror its bounding boxes.

    ``image`` must be a tensor whose dimension 1 is the height (e.g. the CHW
    output of ``ToTensor``). Boxes in ``target["boxes"]`` are rows of
    ``[xmin, ymin, xmax, ymax]`` and are updated in place.

    Args:
        p: Probability of applying the flip.
    """

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, image, target):
        # Guard clause: no flip this time.
        if not (torch.rand(1) < self.p):
            return image, target

        image = T.functional.vflip(image)
        if "boxes" in target:
            coords = target["boxes"]
            img_height = image.shape[1]
            # A row y maps to img_height - y, so the old ymax becomes the new
            # ymin and the old ymin becomes the new ymax.
            coords[:, [1, 3]] = img_height - coords[:, [3, 1]]
            target["boxes"] = coords
        return image, target

def get_transform(train):
    """Build the per-sample transform pipeline.

    Args:
        train: If truthy, a random vertical flip (p=0.2) is added after the
            tensor conversion.

    Returns:
        A ``Compose`` over ``(image, target)`` pairs.
    """
    steps = [ToTensor(), RandomVerticalFlip(0.2)] if train else [ToTensor()]
    return Compose(steps)


# Load a pretrained Faster R-CNN (ResNet-50 FPN) and replace its box predictor
# with one that outputs `num_classes` classes (background + document classes).
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)


# Build the datasets and split them into train and test parts. The same
# annotations are loaded twice with different transforms so that the random
# vertical flip is applied only to training samples; splitting a single
# training-transform dataset would augment the test subset as well.
full_dataset = VOCDataset(images_path, annotations_path, get_transform(train=True))
eval_dataset = VOCDataset(images_path, annotations_path, get_transform(train=False))

torch.manual_seed(1)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
# One seeded shuffle of the indices, cut 80/20. Both datasets are built from
# the same annotation files, so an index refers to the same item in each
# (assuming the directory does not change between the two constructions).
split_indices = torch.randperm(len(full_dataset)).tolist()
train_dataset = torch.utils.data.Subset(full_dataset, split_indices[:train_size])
test_dataset = torch.utils.data.Subset(eval_dataset, split_indices[train_size:])

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=2, collate_fn=utils.collate_fn)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False, num_workers=2, collate_fn=utils.collate_fn)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# Optimize only trainable parameters (those with requires_grad=True); the LR
# is multiplied by 0.1 every 3 scheduler steps.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

def save_checkpoint(epoch, model, optimizer, scheduler, path):
    """Save a resumable training checkpoint.

    The file holds a dict with keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict`` and ``scheduler_state_dict``.

    Args:
        epoch: Epoch number stored under ``'epoch'``.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        scheduler: LR scheduler whose ``state_dict()`` is saved.
        path: Destination file; missing parent directories are created.
    """
    directory = os.path.dirname(path)
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when the path actually contains one (not for e.g. 'checkpoint.pth').
    if directory:
        os.makedirs(directory, exist_ok=True)
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict()
    }
    # Write to a temporary file and swap it in with os.replace, so an
    # interruption mid-save cannot leave a truncated checkpoint at `path`.
    tmp_path = path + '.tmp'
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, path)

def load_checkpoint(path, model, optimizer, scheduler, map_location="cpu"):
    """Restore model, optimizer and scheduler state from a training checkpoint.

    Expects the dict layout written by ``save_checkpoint`` (keys ``epoch``,
    ``model_state_dict``, ``optimizer_state_dict``, ``scheduler_state_dict``).

    Args:
        path: Checkpoint file to read.
        model: Module to load weights into.
        optimizer: Optimizer to restore.
        scheduler: LR scheduler to restore.
        map_location: Passed to ``torch.load``; the default ``"cpu"`` lets a
            checkpoint written on a GPU be read on a CPU-only machine.

    Returns:
        Index of the next epoch to run: the stored ``epoch`` + 1, because the
        stored epoch was already completed when the checkpoint was written.
        Returns 0 (start from scratch) when ``path`` does not exist.
    """
    if not os.path.exists(path):
        print(f"Контрольная точка по пути {path} не найдена!")
        return 0

    checkpoint = torch.load(path, map_location=map_location)
    # The load_state_dict calls copy the loaded values into the existing
    # parameters/state, so tensors loaded on CPU end up on the model's device.
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # Returning the stored epoch itself would make a resumed run repeat an
    # epoch that has already been trained.
    return checkpoint['epoch'] + 1

def plot_metrics(metrics, metric_name):
    """Draw a metric series against its index (labelled "Epoch") and show it.

    Creates a new 10x5-inch figure with a legend and grid, titled
    "<metric_name> Curve", then calls ``plt.show()``.

    Args:
        metrics: Sequence of values, one per epoch.
        metric_name: Used as the line label, y-axis label and title prefix.
    """
    _, ax = plt.subplots(figsize=(10, 5))
    ax.plot(metrics, label=metric_name)
    ax.set(xlabel="Epoch", ylabel=metric_name, title=f"{metric_name} Curve")
    ax.legend()
    ax.grid(True)
    plt.show()

# Per-epoch averages of the total loss and of each individual loss term.
# These lists start empty on every run, so after a resume they only contain
# the epochs trained in this session.
losses = []
classifier_losses = []
box_reg_losses = []
objectness_losses = []
rpn_box_reg_losses = []

num_epochs = 10
save_path = './model_checkpoint.pth'
# Restore model/optimizer/scheduler state if a checkpoint exists. The return
# value is the first epoch index the loop below runs (0 when none is found).
start_epoch = load_checkpoint(save_path, model, optimizer, lr_scheduler)

for epoch in range(start_epoch, num_epochs):
    # One pass over train_loader; the returned logger carries per-loss meters.
    metric_logger = train_one_epoch(model, optimizer, train_loader, device, epoch, print_freq=100)

    # NOTE(review): global_avg is presumably the average of each meter over the
    # whole epoch (torchvision references' SmoothedValue) — verify in utils.py.
    epoch_losses = metric_logger.meters["loss"].global_avg
    epoch_loss_classifier = metric_logger.meters["loss_classifier"].global_avg
    epoch_loss_box_reg = metric_logger.meters["loss_box_reg"].global_avg
    epoch_loss_objectness = metric_logger.meters["loss_objectness"].global_avg
    epoch_loss_rpn_box_reg = metric_logger.meters["loss_rpn_box_reg"].global_avg

    losses.append(epoch_losses)
    classifier_losses.append(epoch_loss_classifier)
    box_reg_losses.append(epoch_loss_box_reg)
    objectness_losses.append(epoch_loss_objectness)
    rpn_box_reg_losses.append(epoch_loss_rpn_box_reg)

    # StepLR is stepped once per epoch, so the LR is multiplied by gamma=0.1
    # every step_size=3 epochs.
    lr_scheduler.step()
    # Overwrite the checkpoint after every epoch, storing the epoch just finished.
    save_checkpoint(epoch, model, optimizer, lr_scheduler, save_path)

def save_trained_model(model, path):
    """Serialize only the model's ``state_dict`` to ``path`` (no optimizer or epoch info)."""
    weights = model.state_dict()
    torch.save(weights, path)

# Save the final weights to their own file. Writing this bare state_dict to
# './model_checkpoint.pth' would overwrite the training checkpoint, and the next
# run's load_checkpoint would then fail with KeyError on 'model_state_dict'.
trained_model_path = './trained_model.pth'
save_trained_model(model, trained_model_path)