In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PartialCrossEntropyLoss(nn.Module):
    def __init__(self):
        super(PartialCrossEntropyLoss, self).__init__()

    def forward(self, predictions, targets, mask):
        # Compute focal loss
        ce_loss = F.cross_entropy(predictions, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** 2 * ce_loss

        # Apply mask
        masked_loss = focal_loss * mask

        # Normalize
        loss = masked_loss.sum() / mask.sum()
        return loss

In [5]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import VOCSegmentation

# Load dataset
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),  # Convert PIL image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = VOCSegmentation(root='./data', year='2012', image_set='train', download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Define a simple segmentation network
class SimpleSegmentationNet(nn.Module):
    def __init__(self):
        super(SimpleSegmentationNet, self).__init__()
        self.backbone = torchvision.models.resnet18(pretrained=True)
        self.conv = nn.Conv2d(512, 21, kernel_size=1)

    def forward(self, x):
        x = self.backbone(x)
        x = self.conv(x)
        return x

# Training loop
model = SimpleSegmentationNet()
criterion = PartialCrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    for images, targets in dataloader:
        # Simulate point labels
        mask = (targets != 255).float()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, targets, mask)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/10], Loss: {loss.item():.4f}')

Using downloaded and verified file: ./data/VOCtrainval_11-May-2012.tar
Extracting ./data/VOCtrainval_11-May-2012.tar to ./data


TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.PngImagePlugin.PngImageFile'>