/
losses.py
36 lines (29 loc) · 1.19 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
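# output and target must be of size m x n (grid points by samples).
# NOTE: when `dim` is left as None, every function below reduces over all axes
# except axis 0, so the un-aggregated result has one entry per index of axis 0.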
def relative_error(output, target, quadrature=None, agg=None, dim=None):
    """Return the square root of the relative squared error.

    All arguments are forwarded unchanged to ``relative_squared_error``.

    Args:
        output: Predicted tensor.
        target: Reference tensor; must broadcast against ``output``.
        quadrature: Optional weights multiplied element-wise into the squared
            terms before they are summed.
        agg: ``"mean"`` or ``"sum"`` to collapse the result to a scalar; any
            other value leaves it un-aggregated.
        dim: Axis or axes to reduce over. ``None`` means every axis except
            the first.

    Returns:
        Tensor holding ``sqrt(relative_squared_error(...))``.

    Note:
        The square root is taken *after* aggregation. With ``agg="mean"`` the
        result is the root of the mean relative squared error, not the mean
        of the individual relative errors.
    """
    return torch.sqrt(
        relative_squared_error(output, target, quadrature, agg, dim)
    )
def relative_squared_error(output, target, quadrature=None, agg=None, dim=None):
    """Return sum((output - target)^2) / sum(target^2) along ``dim``.

    Args:
        output: Predicted tensor.
        target: Reference tensor; must broadcast against ``output``.
        quadrature: Optional weights. When given, both the squared residual
            and the squared target are multiplied element-wise by it before
            summation, so it must broadcast against them.
        agg: ``"mean"`` averages the per-entry ratios, ``"sum"`` adds them.
            Any other value (including ``None``) returns the ratios as-is.
        dim: Axis or axes to reduce over. ``None`` means every axis except
            the first.

    Returns:
        Tensor of ratios (one per index that was not reduced), or a 0-dim
        tensor when ``agg`` is ``"mean"`` or ``"sum"``.

    Note:
        Summation uses ``torch.nansum``, so NaN entries contribute zero to
        the numerator and denominator independently. The denominator is not
        guarded: an all-zero (or all-NaN) target slice yields inf or NaN.
    """
    if dim is None:
        dim = tuple(range(1, output.ndim))

    residual_sq = (output - target) ** 2
    reference_sq = target ** 2
    if quadrature is not None:
        # Apply the same weights to both terms so the ratio stays consistent.
        residual_sq = residual_sq * quadrature
        reference_sq = reference_sq * quadrature

    loss = torch.nansum(residual_sq, dim=dim) / torch.nansum(reference_sq, dim=dim)

    if agg == "mean":
        return torch.mean(loss)
    if agg == "sum":
        return torch.sum(loss)
    # Unrecognised agg values intentionally fall through un-aggregated.
    return loss
def squared_error(output, target, quadrature=None, agg=None, dim=None):
    """Return the (optionally weighted) sum of squared residuals along ``dim``.

    Args:
        output: Predicted tensor.
        target: Reference tensor; must broadcast against ``output``.
        quadrature: Optional weights multiplied element-wise into the squared
            residual before summation; must broadcast against it.
        agg: ``"mean"`` averages the per-entry sums, ``"sum"`` adds them.
            Any other value (including ``None``) returns the sums as-is.
        dim: Axis or axes to reduce over. ``None`` means every axis except
            the first.

    Returns:
        Tensor of summed squared residuals (one per index that was not
        reduced), or a 0-dim tensor when ``agg`` is ``"mean"`` or ``"sum"``.

    Note:
        Summation uses ``torch.nansum``, so NaN residuals count as zero.
    """
    if dim is None:
        dim = tuple(range(1, output.ndim))

    residual_sq = (output - target) ** 2
    if quadrature is not None:
        residual_sq = residual_sq * quadrature

    loss = torch.nansum(residual_sq, dim=dim)

    if agg == "mean":
        return torch.mean(loss)
    if agg == "sum":
        return torch.sum(loss)
    # Unrecognised agg values intentionally fall through un-aggregated.
    return loss