In [1]:
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

from lovasz_losses import lovasz_softmax
from lovasz_losses_fast import LovaszSoftmaxFast

In [2]:
Niter = 50
B, N, H, W = 16, 20, 512, 512

class Model(nn.Module):
    def __init__(self, B, N, H, W, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.x = torch.nn.Parameter(torch.rand((B, N, H, W), requires_grad=True))

    def forward(self, *args, **kwargs):
        return self.x

X = Model(B, N, H, W).cuda()
y = (N * torch.rand(B, H, W, device='cuda')).long()

In [3]:
lovasz_softmax_fast = LovaszSoftmaxFast(N).cuda()

b_fast = lovasz_softmax_fast(F.softmax(X(), dim=1), y)
b_fast.backward()
grad_fast = torch.clone(X.x.grad)

b_original = lovasz_softmax(F.softmax(X(), dim=1), y)
b_original.backward()
grad_original = torch.clone(X.x.grad)

with torch.no_grad():
    print(f"Loss\n\tOriginal:\t{b_original}\n\tFast:   \t{b_fast}\n\tDifference:  \t{(b_original-b_fast).abs()}")

    delta_grad = (grad_fast - grad_original).abs()
    print(f"x Gradients")
    print(f"\tMin:\t\t{delta_grad.min()}")
    print(f"\tMedian:\t\t{delta_grad.median()}")
    print(f"\tMax:\t\t{delta_grad.max()}")

del b_fast, grad_fast, b_original, grad_original, lovasz_softmax_fast
torch.cuda.empty_cache()

Loss
	Original:	0.9499844908714294
	Fast:   	0.9499836564064026
	Difference:  	8.344650268554688e-07
x Gradients
	Min:		1.3363890549733526e-10
	Median:		5.667801783459936e-10
	Max:		2.1741865197100196e-08


In [4]:
def time(func, X, y):
    times = []
    for iter in range(Niter):
        torch.cuda.synchronize()

        if iter >= 10:    
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)     
            start.record()

        b = func(F.softmax(X(), dim=-1), y)
        b.backward()

        if iter >= 10:
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))

        del b, X.x.grad
        torch.cuda.empty_cache()

    print(f"{func} Timings:")
    print(f"\tMin:\t\t{np.min(times):.2f} ms")
    print(f"\tMedian:\t\t{np.median(times):.2f} ms")
    print(f"\tMax:\t\t{np.max(times):.2f} ms")


time(LovaszSoftmaxFast(N).cuda(), X, y)
time(lovasz_softmax, X, y)

LovaszSoftmaxFast() Timings:
	Min:		129.00 ms
	Median:		129.17 ms
	Max:		131.17 ms
<function lovasz_softmax at 0x7fe4e18885e0> Timings:
	Min:		268.23 ms
	Median:		268.84 ms
	Max:		272.30 ms
