# In [None]:
import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Select the compute device: the default CUDA GPU when available, else CPU.
# The device string must be "cuda" (or "cuda:<index>"); a bare "cuda:" with a
# trailing colon and no index is rejected by torch.device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Custom dataset
class CustomVOCDataset(Dataset):
    """Detection dataset pairing images with Pascal VOC-style XML annotations.

    Each image found in ``images_folder`` is matched with the XML file of the
    same base name in ``annotations_folder``. Every annotated object receives
    label 1 (the single "bar_chart" class); the ``<name>`` tag is not read.

    Args:
        images_folder: Directory holding ``.png``/``.jpg``/``.jpeg`` images
            (extension match is case-insensitive).
        annotations_folder: Directory holding the matching ``.xml`` files.
        transform: Optional callable applied to the PIL image only. The boxes
            are returned exactly as parsed, so a transform that changes the
            image geometry would leave them out of sync.
    """

    def __init__(self, images_folder, annotations_folder, transform=None):
        self.images_folder = images_folder
        self.annotations_folder = annotations_folder
        self.transform = transform
        # Sort so that index -> image is reproducible across runs and
        # platforms (os.listdir returns entries in arbitrary order).
        self.image_files = sorted(
            f for f in os.listdir(images_folder)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

    def __len__(self):
        """Return the number of images discovered in ``images_folder``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the image at ``idx`` together with its detection target.

        Returns:
            A ``(image, target)`` pair where ``target`` is a dict with
            ``"boxes"`` (float32, shape ``(N, 4)``) and ``"labels"`` (int64,
            shape ``(N,)``).

        Raises:
            FileNotFoundError: If the image or its XML annotation is missing.
        """
        image_file = self.image_files[idx]
        image_path = os.path.join(self.images_folder, image_file)
        annotation_path = os.path.join(
            self.annotations_folder, os.path.splitext(image_file)[0] + ".xml"
        )

        # Load the image; the context manager closes the underlying file as
        # soon as the RGB copy has been produced.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Load the annotations (XML parsing)
        boxes, labels = self.parse_annotation(annotation_path)

        if self.transform is not None:
            image = self.transform(image)

        return image, self._build_target(boxes, labels)

    @staticmethod
    def _build_target(boxes, labels):
        """Convert parsed boxes and labels into the target dict of tensors."""
        return {
            # reshape keeps the (N, 4) layout even when the image has no
            # objects; a bare torch.tensor([]) would have shape (0,).
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.tensor(labels, dtype=torch.int64),
        }

    def parse_annotation(self, annotation_path):
        """Read every ``<object>`` bounding box from a VOC-style XML file.

        Args:
            annotation_path: Path to the XML annotation file.

        Returns:
            ``(boxes, labels)``: ``boxes`` is a list of
            ``[xmin, ymin, xmax, ymax]`` coordinate lists and ``labels`` is a
            list with the constant class id 1 for each box.
        """
        import xml.etree.ElementTree as ET

        root = ET.parse(annotation_path).getroot()
        boxes = []
        labels = []
        for obj in root.findall("object"):
            bndbox = obj.find("bndbox")
            # Parse as float: some annotation tools write fractional
            # coordinates such as "12.5", which int() would reject.
            boxes.append([
                float(bndbox.find(tag).text)
                for tag in ("xmin", "ymin", "xmax", "ymax")
            ])
            labels.append(1)  # 1 stands for "bar_chart"
        return boxes, labels

# Custom collate function
def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples.

    A list of ``(image, target)`` samples becomes ``(images, targets)``, each
    a tuple; nothing is stacked. An empty batch yields an empty tuple.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Training data locations
train_images_folder = "..."
train_annotations_folder = "..."

# Build the dataset and wrap it in a DataLoader
train_dataset = CustomVOCDataset(
    train_images_folder,
    train_annotations_folder,
    transform=ToTensor(),
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
    # The custom collate returns images and targets as tuples.
    collate_fn=collate_fn,
)

# Faster R-CNN model
# Load the pre-trained detector
model = fasterrcnn_resnet50_fpn(pretrained=True)

# Two output classes: "bar_chart" plus "background"
num_classes = 2

# Swap the box predictor head for one sized to num_classes, reusing the
# input width of the existing classification layer.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(
    in_channels=in_features,
    num_classes=num_classes,
)

# Move the model to the selected device
model = model.to(device)

# Optimizer configuration
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=0.0005,
)

# Training loop
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0  # sum of the per-batch total losses for this epoch

    for images, targets in train_loader:
        # Move the batch to the selected device
        images = [img.to(device) for img in images]
        targets = [
            {k: v.to(device) for k, v in t.items()}
            for t in targets
        ]

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # In training mode the model returns a dict of loss terms;
        # their sum is the quantity that is optimized.
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        # Backpropagation and parameter update
        losses.backward()
        optimizer.step()

        epoch_loss += losses.item()

    print(f"Época {epoch + 1}, Perda: {epoch_loss:.4f}")

# Save the trained model weights (state dict only)
torch.save(obj=model.state_dict(), f="faster_rcnn_bar_chart.pth")
print("Modelo salvo com sucesso!")

# No validation is run here; only a reminder to test on new images is printed
print("Treinamento concluído. Teste o modelo em novas imagens.")