# Multiclass segmentation: PyTorch version
We want to optimize the mean Jaccard index of the non-void classes.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import division, print_function

In [None]:
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch.optim import Adam
from torch import nn

In [None]:
import lovasz_losses as L
import lovasz_losses_fast as FastL

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from tqdm.auto import tqdm

In [None]:
# add the ../demo_helpers directory to sys.path so that demo_utils can be imported
import sys
sys.path.append('../demo_helpers')
from demo_utils import pil, pil_grid, dummy_triangles

In [None]:
from collections import OrderedDict as OD

In [None]:
# Enable GPU use only when a CUDA device is actually available, so that the
# cuda() helper becomes a no-op on CPU-only machines instead of raising.
CUDA = torch.cuda.is_available()

In [None]:
def cuda(x):
    """Return ``x.cuda()`` when the module-level ``CUDA`` flag is truthy, else ``x`` unchanged."""
    return x.cuda() if CUDA else x

## Multiclass case: batch of 5 random images with classes [0, 1, 2] and void (255)
We assume that we want to optimize the Jaccard index of all non-void classes 0, 1 and 2.

### Labels

In [None]:
# Seed NumPy's RNG, then build B dummy label images containing the classes
# 0, 1, 2 and the void value 255.
np.random.seed(18)
C = 3    # number of non-void classes
B = 5    # number of images in the batch
H = 200  # size argument passed to dummy_triangles
labels_ = [dummy_triangles(H, [0, 255, 1, 2]) for _ in range(B)]
label_arrays = [np.array(image) for image in labels_]
labels = torch.stack([torch.from_numpy(arr) for arr in label_arrays]).long()
pil_grid(labels_, 5, margin=1)

### Features

In [None]:
# Seed both NumPy and PyTorch: the random draws below (and the Gaussian noise
# added afterwards) come from torch's generator, which np.random.seed alone
# does not control.
np.random.seed(57)
torch.manual_seed(57)
B, H, W = labels.size()
labels_ = labels.clone()
# Void pixels (255) cannot be one-hot encoded over C classes, so assign each
# of them a random class in [0, C - 1]; the mask is computed once and reused.
void_mask = labels_ == 255
labels_[void_mask] = labels_[void_mask].random_(C) # random feats for void
# One-hot encode along the class dimension (dim 1): (B, H, W) -> (B, C, H, W).
labels_1hot = torch.zeros(B, C, H, W)
labels_1hot.scatter_(1, labels_.unsqueeze(1), 1);

In [None]:
# Features are the one-hot labels corrupted in place by additive Gaussian
# noise (mean 0, std 2).
feats = labels_1hot.clone().float()
feats.add_(feats.new(feats.size()).normal_(0, 2))

In [None]:
# Send labels and features through the cuda() helper; the features are
# additionally wrapped in a Variable.
labels, feats = cuda(labels), Variable(cuda(feats))

## Model

### Definition

In [None]:
class Model(nn.Module):
    """Single 3x3 convolution with a residual (identity) connection.

    The convolution maps C channels to C channels with padding 1, so its
    output has the same shape as its input and can be added to it.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.Conv2d(in_channels=C, out_channels=C, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the input plus its convolved version."""
        residual = self.conv(x)
        return x + residual

In [None]:
# Build the model and keep a cloned copy of its freshly initialised state,
# so that the training cells can reset the model to this starting point.
m = Model()
init_state_dict = OD((name, tensor.clone()) for name, tensor in m.state_dict().items())

In [None]:
# Pass the model through the cuda() helper (calls .cuda() when CUDA is set).
m = cuda(m)

### Initial prediction

In [None]:
# Predicted class per pixel: index of the maximum over the class dimension (dim 1).
preds = m(feats).data.max(1)[1]
initial_ious = L.iou(preds, labels, C, ignore=255, per_image=False)
print("Initial batch-IoUs:", initial_ious)
print("Initial mIoU:", np.mean(initial_ious))

In [None]:
# Convert each predicted label map to an image and show them in a grid.
pred_images = [pil(pred.byte().cpu().numpy()) for pred in preds]
pil_grid(pred_images, 5, margin=1)

## Lovász-Softmax training

In [None]:
# Reset the model to the saved initial weights before training.
m.load_state_dict(init_state_dict)

optimizer = Adam(m.parameters(), lr=0.005)
# Route the loss module through the cuda() helper so that it follows the CUDA
# flag like the model and the data, instead of calling .cuda() unconditionally.
lovasz_softmax_fast = cuda(FastL.LovaszSoftmaxFast(C, ignore_index=255))

loss_plot = []  # loss value at every iteration
iou_plot = []   # mean IoU over the C classes at every iteration

# `iteration` rather than `iter`, to avoid shadowing the builtin for the rest
# of the notebook session.
for iteration in tqdm(range(1000)):
    optimizer.zero_grad()
    out = m(feats)
    # The loss is given class probabilities, so apply softmax over the class dim.
    out = F.softmax(out, dim=1)
    loss = lovasz_softmax_fast(out, labels)
    loss.backward()
    optimizer.step()

    _, preds = out.data.max(1)
    loss_plot.append(loss.item())
    # Track the same multiclass metric that is printed as "mIoU" elsewhere in
    # this notebook; L.iou_binary is the binary-segmentation helper and does
    # not take the number of classes (NOTE(review): confirm in lovasz_losses).
    iou_plot.append(np.mean(L.iou(preds, labels, C, ignore=255, per_image=False)))

In [None]:
# Side-by-side curves for the Lovász-Softmax run: loss on the left, IoU on the right.
plt.figure(figsize=(10, 5))
plt.suptitle(u'Lovász-Softmax training')
panels = [(loss_plot, 'loss'), (iou_plot, 'Image-IoU (%)')]
for position, (series, y_label) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(series)
    plt.ylabel(y_label)
    plt.xlabel('iteration')

### Final prediction

In [None]:
# Predicted class per pixel after training: index of the maximum over dim 1.
preds = m(feats).data.max(1)[1]
final_ious = L.iou(preds, labels, C, ignore=255, per_image=False)
print("Final batch-IoUs:", final_ious)
print("Final mIoU:", np.mean(final_ious))

In [None]:
# Convert each predicted label map to an image and show them in a grid.
pred_images = [pil(pred.byte().cpu().numpy()) for pred in preds]
pil_grid(pred_images, 5, margin=1)

## Cross-entropy training

In [None]:
# Reset the model to the saved initial weights before training.
m.load_state_dict(init_state_dict)

optimizer = Adam(m.parameters(), lr=0.005)

loss_plot_x = []  # cross-entropy loss at every iteration
iou_plot_x = []   # mean IoU over the C classes at every iteration

# `iteration` rather than `iter`, to avoid shadowing the builtin for the rest
# of the notebook session.
for iteration in tqdm(range(1000)):
    optimizer.zero_grad()
    out = m(feats)
    # L.xloss is called on the raw model outputs (no softmax is applied here).
    loss = L.xloss(out, labels, ignore=255)
    loss.backward()
    optimizer.step()

    _, preds = out.data.max(1)
    loss_plot_x.append(loss.item())
    # Track the same multiclass metric that is printed as "mIoU" elsewhere in
    # this notebook; L.iou_binary is the binary-segmentation helper and does
    # not take the number of classes (NOTE(review): confirm in lovasz_losses).
    iou_plot_x.append(np.mean(L.iou(preds, labels, C, ignore=255, per_image=False)))

In [None]:
# Side-by-side curves for the cross-entropy run: loss on the left, IoU on the right.
plt.figure(figsize=(10, 5))
# This run uses the multiclass cross-entropy loss (L.xloss over C classes),
# so the title must not say "Binary".
plt.suptitle(u'Cross-entropy training')
plt.subplot(1, 2, 1)
plt.plot(loss_plot_x)
plt.ylabel('loss')
plt.xlabel('iteration')

plt.subplot(1, 2, 2)
plt.plot(iou_plot_x)
plt.ylabel('Image-IoU (%)')
plt.xlabel('iteration')

### Final prediction

In [None]:
# Predicted class per pixel after training: index of the maximum over dim 1.
preds = m(feats).data.max(1)[1]
final_ious = L.iou(preds, labels, C, ignore=255, per_image=False)
print("Final batch-IoUs:", final_ious)
print("Final mIoU:", np.mean(final_ious))

In [None]:
# Convert each predicted label map to an image and show them in a grid.
pred_images = [pil(pred.byte().cpu().numpy()) for pred in preds]
pil_grid(pred_images, 5, margin=1)