# Segmentação Semântica - Oxford-IIIT Pet


In [None]:
!pip install torchmetrics opencv-python matplotlib tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import torchmetrics
import matplotlib.pyplot as plt
import numpy as np

# Runtime configuration: run on the GPU when CUDA is available, else on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'PyTorch: {torch.__version__}')
print(f'Device: {device}')

# Seed both NumPy and PyTorch so repeated runs draw the same random numbers.
np.random.seed(42)
torch.manual_seed(42)


In [None]:
# Image transforms: resize to the 224x224 network input, convert to a float
# tensor in [0, 1] and normalize with the usual ImageNet mean/std.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def _trimap_to_class_index(mask):
    """Shift trimap labels {1, 2, 3} to int64 class indices {0, 1, 2}.

    Defined as a named function (not a lambda) so the transform stays
    picklable if the DataLoader is later given worker processes.
    """
    # Cast before subtracting: the incoming tensor is uint8 and must not wrap.
    return mask.long() - 1


# Mask transforms.
# - NEAREST interpolation keeps the labels discrete; the default bilinear
#   resize blends neighbouring label values along class boundaries.
# - PILToTensor keeps the raw integer labels. ToTensor would divide them by
#   255, so a later `.long()` would collapse every pixel to class 0.
mask_transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),
    transforms.Lambda(_trimap_to_class_index),
])

trainval_dataset = datasets.OxfordIIITPet(
    root='./data',
    split='trainval',
    target_types='segmentation',
    download=True,
    transform=image_transform,
    target_transform=mask_transform
)

test_dataset = datasets.OxfordIIITPet(
    root='./data',
    split='test',
    target_types='segmentation',
    download=True,
    transform=image_transform,
    target_transform=mask_transform
)

# Split trainval into train (80%) and validation (20%).
trainval_size = len(trainval_dataset)
train_size = int(0.8 * trainval_size)
val_size = trainval_size - train_size

# Draw the split from a seeded permutation instead of a contiguous slice.
# NOTE(review): the trainval file list appears to be grouped by breed, in which
# case a contiguous split puts only the last breeds in validation -- confirm.
# A private generator is used so the global RNG state is left untouched.
split_rng = torch.Generator().manual_seed(42)
shuffled_indices = torch.randperm(trainval_size, generator=split_rng).tolist()

train_subset = Subset(trainval_dataset, shuffled_indices[:train_size])
val_subset = Subset(trainval_dataset, shuffled_indices[train_size:])
# Only the first (up to) 200 test samples are kept, as a quick sanity subset.
test_subset = Subset(test_dataset, range(min(200, len(test_dataset))))

print("✅ Dataset carregado!")
print(f"Train: {len(train_subset)} amostras")
print(f"Val: {len(val_subset)} amostras")
print(f"Test: {len(test_subset)} amostras")
# Index order follows the dataset trimap (1=pet, 2=background, 3=border)
# after the -1 shift applied by mask_transform.
print("Classes: 0=Animal, 1=Background, 2=Border")


In [None]:
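# # Simple U-Net -- first attempt, kept commented out for reference only.
# # It does not run: each decoder nn.Sequential feeds the ConvTranspose2d output
# # (128 / 64 channels) straight into a Conv2d that expects the concatenated
# # 256 / 128 channels, and the skip-connection torch.cat only happens afterwards.
# # The UNet class in the next cell is the working version.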
# class UNet(nn.Module):
#     def __init__(self):
#         super(UNet, self).__init__()

#         # Encoder
#         self.inc = nn.Sequential(
#             nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
#             nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()
#         )
#         self.down1 = nn.Sequential(
#             nn.MaxPool2d(2),
#             nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
#             nn.Conv2d(128, 128, 3, padding=1), nn.ReLU()
#         )
#         self.down2 = nn.Sequential(
#             nn.MaxPool2d(2),
#             nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(),
#             nn.Conv2d(256, 256, 3, padding=1), nn.ReLU()
#         )

#         # Decoder
#         self.up1 = nn.Sequential(
#             nn.ConvTranspose2d(256, 128, 2, stride=2),
#             nn.Conv2d(256, 128, 3, padding=1), nn.ReLU(),
#             nn.Conv2d(128, 128, 3, padding=1), nn.ReLU()
#         )
#         self.up2 = nn.Sequential(
#             nn.ConvTranspose2d(128, 64, 2, stride=2),
#             nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(),
#             nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()
#         )
#         self.outc = nn.Conv2d(64, 3, 1)

#     def forward(self, x):
#         x1 = self.inc(x)
#         x2 = self.down1(x1)
#         x3 = self.down2(x2)

#         x = self.up1(x3)
#         x = torch.cat([x, x2], dim=1)
#         x = self.up2(x)
#         x = torch.cat([x, x1], dim=1)

#         return self.outc(x)

# model = UNet().to(device)
# print(f"Modelo: {sum(p.numel() for p in model.parameters()):,} parâmetros")


In [None]:
# Corrected U-Net
class UNet(nn.Module):
    """Two-level U-Net for dense per-pixel classification.

    The encoder halves the resolution twice (64 -> 128 -> 256 channels); the
    decoder upsamples twice with transposed convolutions and concatenates the
    matching encoder feature map (skip connection) before each conv block.

    Layer attribute names and construction order are kept stable so that
    existing ``state_dict`` checkpoints and seeded initializations still match.

    Args:
        in_channels: Number of channels of the input image (default 3).
        num_classes: Number of output logit channels (default 3).
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 3):
        super().__init__()

        # Encoder
        self.inc = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()
        )
        self.down1 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1), nn.ReLU()
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1), nn.ReLU()
        )

        # Decoder
        # Upsampling is kept separate from the conv block because the skip
        # connection has to be concatenated between the two.
        self.up_trans_1 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        # Input: 128 (from up_trans_1) + 128 (from x2) = 256 channels.
        self.up_conv_1 = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1), nn.ReLU()
        )

        self.up_trans_2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        # Input: 64 (from up_trans_2) + 64 (from x1) = 128 channels.
        self.up_conv_2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()
        )

        # 1x1 convolution producing one logit map per class.
        self.outc = nn.Conv2d(64, num_classes, 1)

    @staticmethod
    def _match_size(x, ref):
        """Zero-pad ``x`` so its height/width equal those of ``ref``.

        MaxPool2d floors odd sizes, so after upsampling ``x`` can be one pixel
        smaller than the encoder feature map it is concatenated with.
        """
        diff_h = ref.size(2) - x.size(2)
        diff_w = ref.size(3) - x.size(3)
        if diff_h == 0 and diff_w == 0:
            return x
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        return nn.functional.pad(
            x,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        """Return logits of shape (B, num_classes, H, W) for input (B, C, H, W)."""
        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)

        # Decoder level 1: upsample, align, concatenate skip, convolve.
        x = self._match_size(self.up_trans_1(x3), x2)
        x = torch.cat([x, x2], dim=1)
        x = self.up_conv_1(x)

        # Decoder level 2: same pattern with the first encoder feature map.
        x = self._match_size(self.up_trans_2(x), x1)
        x = torch.cat([x, x1], dim=1)
        x = self.up_conv_2(x)

        return self.outc(x)

# Build the network, move its parameters to the selected device and report
# the total parameter count.
model = UNet()
model.to(device)  # nn.Module.to moves the parameters in place and returns self
print(f"Modelo: {sum(param.numel() for param in model.parameters()):,} parâmetros")

In [None]:
# Training
train_loader = DataLoader(train_subset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_subset, batch_size=16, shuffle=False)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Metrics from torchmetrics
jaccard = torchmetrics.JaccardIndex(task='multiclass', num_classes=3).to(device)
accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=3).to(device)


def _to_class_indices(masks):
    """Return ``masks`` as an int64 (B, H, W) tensor of class indices.

    CrossEntropyLoss and the metrics expect integer class targets without a
    channel dimension, so a singleton channel axis is dropped.

    Floating-point masks are assumed to be trimap labels {1, 2, 3} scaled by
    1/255 (what ``transforms.ToTensor`` produces); they are decoded back to
    {0, 1, 2}. A plain ``.long()`` would truncate all of them to 0.
    NOTE(review): float masks that already hold class indices are not
    supported by this helper -- confirm against the mask transform in use.
    """
    if masks.dim() == 4:  # (B, 1, H, W) -> (B, H, W)
        masks = masks.squeeze(1)
    if masks.is_floating_point():
        masks = (masks * 255).round().long() - 1
        masks = masks.clamp(0, 2)
    return masks.long()


EPOCHS = 5
for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    train_loss = 0.0

    for images, masks in train_loader:
        images = images.to(device)
        masks = _to_class_indices(masks.to(device))

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # ---- validation pass ----
    model.eval()
    val_loss = 0.0
    # Start each epoch from clean metric state; update() accumulates over all
    # batches and compute() yields the epoch-level value, so a smaller last
    # batch is not over-weighted as it is when averaging per-batch scores.
    jaccard.reset()
    accuracy.reset()

    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = _to_class_indices(masks.to(device))

            outputs = model(images)
            val_loss += criterion(outputs, masks).item()

            preds = torch.argmax(outputs, dim=1)
            jaccard.update(preds, masks)
            accuracy.update(preds, masks)

    val_iou = jaccard.compute().item()
    val_acc = accuracy.compute().item()

    print(f"Epoch {epoch+1}: Loss={train_loss/len(train_loader):.4f}, "
          f"Val Loss={val_loss/len(val_loader):.4f}, "
          f"Val IoU={val_iou:.4f}, Val Acc={val_acc:.4f}")


In [None]:
# Visualização e Salvar Modelo
def show_results(dataset, num_samples=3):
    """Plot image, ground-truth mask and predicted mask for the first samples.

    Uses the module-level ``model`` and ``device``. The model is switched to
    eval mode and is left in that mode afterwards.

    Args:
        dataset: Indexable dataset yielding ``(image, mask)`` pairs, where
            ``image`` is a normalized (3, H, W) tensor.
        num_samples: Maximum number of samples to display; capped at
            ``len(dataset)`` so a short dataset does not raise IndexError.
    """
    num_classes = 3
    # Normalization statistics used to undo transforms.Normalize for display;
    # built once because they do not change between samples.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    model.eval()
    with torch.no_grad():
        for i in range(min(num_samples, len(dataset))):
            image, true_mask = dataset[i]

            logits = model(image.unsqueeze(0).to(device))
            pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

            # Undo the normalization and clip to the displayable [0, 1] range.
            image_denorm = torch.clamp(image * std + mean, 0, 1)

            if true_mask.dim() == 3:  # (1, H, W) -> (H, W)
                true_mask = true_mask.squeeze(0)
            if true_mask.is_floating_point():
                # Assumes a ToTensor-scaled trimap (labels 1..3 divided by
                # 255); decode to class indices 0..2 -- TODO confirm against
                # the mask transform in use.
                true_mask = (true_mask * 255).round() - 1
            true_mask_np = true_mask.numpy()

            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            axes[0].imshow(image_denorm.permute(1, 2, 0))
            axes[0].set_title('Original')
            axes[0].axis('off')

            # A fixed colour range makes a class get the same colour in both
            # panels; auto-scaling would recolour a prediction that contains
            # fewer classes than the ground truth.
            axes[1].imshow(true_mask_np, cmap='tab10', vmin=0, vmax=num_classes - 1)
            axes[1].set_title('Ground Truth')
            axes[1].axis('off')

            axes[2].imshow(pred_mask, cmap='tab10', vmin=0, vmax=num_classes - 1)
            axes[2].set_title('Predição')
            axes[2].axis('off')

            plt.tight_layout()
            plt.show()

# Preview three test predictions, then persist only the learned weights
# (the state_dict), not the whole module object.
show_results(test_subset, num_samples=3)
torch.save(obj=model.state_dict(), f='unet_pet_segmentation.pth')
