In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F



In [2]:

def one_hot(label, n_classes, requires_grad=True):
    print(f'{label.shape=}')
    print(f'{label=}')
    print(f'{n_classes=}')
    
    one_hot_label = torch.eye(
        n_classes, device=device, requires_grad=requires_grad)[label]
    one_hot_label = one_hot_label.transpose(1, 3).transpose(2, 3)

    return one_hot_label


In [3]:
# a = torch.randint(0, 2, (8, 16, 16))

# print(a)

# b = torch.eye(1)

# print(b[a])

In [4]:


class BoundaryLoss(nn.Module):
    """Boundary Loss proposed in:
    Alexey Bokhovkin et al., Boundary Loss for Remote Sensing Imagery Semantic Segmentation
    https://arxiv.org/abs/1905.07852
    """

    def __init__(self, theta0=3, theta=5):
        super().__init__()

        self.theta0 = theta0
        self.theta = theta

    def forward(self, pred, gt):
        """
        Input:
            - pred: the output from model (before softmax)
                    shape (N, C, H, W)
            - gt: ground truth map
                    shape (N, H, w)
        Return:
            - boundary loss, averaged over mini-bathc
        """

        n, c, _, _ = pred.shape

        # softmax so that predicted map can be distributed in [0, 1]
        pred = torch.softmax(pred, dim=1)

        # one-hot vector of ground truth
        one_hot_gt = one_hot(gt, c)

        # boundary map
        gt_b = F.max_pool2d(
            1 - one_hot_gt, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
        gt_b -= 1 - one_hot_gt

        pred_b = F.max_pool2d(
            1 - pred, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
        pred_b -= 1 - pred

        # extended boundary map
        gt_b_ext = F.max_pool2d(
            gt_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2)

        pred_b_ext = F.max_pool2d(
            pred_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2)

        # reshape
        gt_b = gt_b.view(n, c, -1)
        pred_b = pred_b.view(n, c, -1)
        gt_b_ext = gt_b_ext.view(n, c, -1)
        pred_b_ext = pred_b_ext.view(n, c, -1)

        # Precision, Recall
        P = torch.sum(pred_b * gt_b_ext, dim=2) / (torch.sum(pred_b, dim=2) + 1e-7)
        R = torch.sum(pred_b_ext * gt_b, dim=2) / (torch.sum(gt_b, dim=2) + 1e-7)

        # Boundary F1 Score
        BF1 = 2 * P * R / (P + R + 1e-7)

        # summing BF1 Score for each class and average over mini-batch
        loss = torch.mean(1 - BF1)

        return loss



In [6]:

# for debug
import torch.optim as optim
from torchvision.models import segmentation

device = 'cuda' if torch.cuda.is_available() else 'cpu'

img = torch.randn(8, 3, 224, 224).to(device)
gt = torch.randint(0, 2, (8, 224, 224)).to(device)

model = segmentation.fcn_resnet50(num_classes=2).to(device)

optimizer = optim.Adam(model.parameters())
criterion = BoundaryLoss()

y = model(img)

loss = criterion(y['out'], gt)

optimizer.zero_grad()

loss.backward()
optimizer.step()

print(loss)

label.shape=torch.Size([8, 224, 224])
label=tensor([[[1, 1, 1,  ..., 1, 1, 0],
         [0, 1, 1,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 0, 0, 0],
         ...,
         [1, 1, 0,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 1, 0, 0],
         [0, 1, 0,  ..., 1, 1, 0]],

        [[1, 1, 1,  ..., 0, 1, 0],
         [1, 0, 1,  ..., 1, 1, 1],
         [0, 0, 1,  ..., 0, 1, 1],
         ...,
         [1, 1, 1,  ..., 0, 0, 1],
         [1, 0, 1,  ..., 1, 1, 0],
         [0, 0, 0,  ..., 1, 1, 1]],

        [[0, 0, 1,  ..., 0, 0, 1],
         [0, 1, 1,  ..., 1, 1, 0],
         [1, 0, 0,  ..., 1, 1, 0],
         ...,
         [1, 1, 0,  ..., 1, 0, 0],
         [1, 0, 1,  ..., 1, 1, 1],
         [0, 1, 1,  ..., 1, 0, 1]],

        ...,

        [[0, 1, 1,  ..., 0, 1, 0],
         [1, 0, 1,  ..., 1, 0, 1],
         [0, 1, 1,  ..., 1, 1, 0],
         ...,
         [1, 0, 0,  ..., 1, 1, 1],
         [1, 1, 0,  ..., 0, 0, 1],
         [1, 1, 0,  ..., 1, 1, 0]],

        [[1, 0, 1,  ..., 1, 1, 1],
   

In [20]:
a = torch.arange(30).reshape(5, 2, 3)

In [21]:
a = a.max(2)[0]

In [22]:
a

tensor([[ 2,  5],
        [ 8, 11],
        [14, 17],
        [20, 23],
        [26, 29]])