-
Notifications
You must be signed in to change notification settings - Fork 11
/
mnist_loss.py
153 lines (109 loc) · 5.76 KB
/
mnist_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""The loss function for the MNIST variant of the multitask experiment."""
from abc import abstractmethod, ABC
import torch
import torch.nn.functional as F
from torch import Tensor, nn
def _labels_to_1(labels, mnist_type: str):
"""Returns labels for task 1. We classify three way: class a, class b and other. a and b depend on the task."""
# 2 is class other.
converted = torch.full_like(labels, 2)
if mnist_type == 'numbers':
# 3 and 7 are notoriously hard to distinguish between.
converted[labels == 3] = 0
converted[labels == 7] = 1
elif mnist_type == 'fashion_pullover_coat':
# Try and distinguish between pullovers (2) and coats (4), which are very similar.
converted[labels == 2] = 0
converted[labels == 4] = 1
elif mnist_type == 'fashion_tshirt_shirt':
# Try and distinguish between tshirts (0) and shirts (6).
converted[labels == 0] = 0
converted[labels == 6] = 1
else:
raise ValueError(f'Unknown mnist_type: {mnist_type}')
return converted
def _compute_num_correct(preds: Tensor, labels: Tensor) -> int:
assert preds.shape == labels.shape
return (labels == preds).sum().item()
def compute_num_correct_task1(preds: Tensor, labels: Tensor, mnist_type: str) -> int:
    """Computes the number of correct predictions on a single batch for task 1"""
    # Task 1 predicts the remapped three-way labels, so remap before comparing.
    task1_labels = _labels_to_1(labels, mnist_type)
    return _compute_num_correct(preds, task1_labels)
def compute_num_correct_task2(preds: Tensor, labels: Tensor) -> int:
    """Computes the number of correct predictions on a single batch for task 2"""
    # Task 2 classifies against the raw dataset labels, so no remapping is needed.
    return _compute_num_correct(preds, labels)
class MnistLossFunc(ABC):
    """Interface for a single task's loss: compute a raw loss, then weight it."""

    @abstractmethod
    def get_raw_loss(self, output: Tensor, labels: Tensor, original: Tensor) -> Tensor:
        """Return the unweighted loss, where the implementing class defines the exact loss function.

        :param output: output of the model
        :param labels: labels for classification losses
        :param original: the original image, for reconstruction losses
        """

    @abstractmethod
    def weight_loss(self, loss: Tensor) -> Tensor:
        """Weights the given loss appropriately. e.g. by a fixed weight or a learned uncertainty weight"""
class FixedCELoss(MnistLossFunc):
    """Cross-entropy classification loss scaled by a fixed (non-learned) weight."""

    def __init__(self, weight: float, class_map_func):
        """
        :param weight: fixed multiplier applied to the raw loss; any real number
            (ints such as ``1`` are accepted and normalised to float)
        :param class_map_func: maps dataset labels to this task's target classes
        """
        # Previously this asserted isinstance(weight, float), which rejected
        # valid integer weights; accept any numeric value and coerce.
        assert isinstance(weight, (int, float)), f'weight must be numeric, got {type(weight)}'
        self._weight = float(weight)
        self._class_map_func = class_map_func

    def get_raw_loss(self, output: Tensor, labels: Tensor, _) -> Tensor:
        # The third argument (original image) is unused by classification losses.
        return F.cross_entropy(output, self._class_map_func(labels))

    def weight_loss(self, loss: Tensor) -> Tensor:
        return self._weight * loss
class LearnedCELoss(MnistLossFunc):
    """Cross-entropy classification loss with a learned uncertainty weight.

    ``s`` holds ``log(sigma^2)`` (see ``get_learned_loss``); the weighted loss
    is ``exp(-s) * loss + 0.5 * s``.
    """

    def __init__(self, s: nn.Parameter, class_map_func):
        assert isinstance(s, nn.Parameter)
        self._s = s
        self._class_map_func = class_map_func

    def get_raw_loss(self, output: Tensor, labels: Tensor, _) -> Tensor:
        # The third argument (original image) is unused by classification losses.
        targets = self._class_map_func(labels)
        return F.cross_entropy(output, targets)

    def weight_loss(self, loss: Tensor) -> Tensor:
        precision = torch.exp(-self._s)
        return precision * loss + 0.5 * self._s
class FixedL1Loss(MnistLossFunc):
    """L1 reconstruction loss scaled by a fixed (non-learned) weight."""

    def __init__(self, weight: float):
        """
        :param weight: fixed multiplier applied to the raw loss; any real number
            (ints such as ``1`` are accepted and normalised to float)
        """
        # Previously this asserted isinstance(weight, float), which rejected
        # valid integer weights; accept any numeric value and coerce.
        assert isinstance(weight, (int, float)), f'weight must be numeric, got {type(weight)}'
        self._weight = float(weight)

    def get_raw_loss(self, output: Tensor, _, original: Tensor) -> Tensor:
        # Reconstruction task: compare the model output against the original image.
        return F.l1_loss(output, original)

    def weight_loss(self, loss: Tensor) -> Tensor:
        return self._weight * loss
class LearnedL1Loss(MnistLossFunc):
    """L1 reconstruction loss with a learned uncertainty weight.

    ``s`` holds ``log(sigma^2)`` (see ``get_learned_loss``); the weighted loss
    is ``0.5 * exp(-s) * loss + 0.5 * s``.
    """

    def __init__(self, s: nn.Parameter):
        assert isinstance(s, nn.Parameter)
        self._s = s

    def get_raw_loss(self, output: Tensor, _, original: Tensor) -> Tensor:
        # Reconstruction task: compare the model output against the original image.
        return F.l1_loss(output, original)

    def weight_loss(self, loss: Tensor) -> Tensor:
        precision = torch.exp(-self._s)
        return 0.5 * precision * loss + 0.5 * self._s
class MultitaskMnistLoss(ABC):
    """Combines per-task losses into a single overall training loss.

    Each task has an enabled flag and a loss function (a ``MnistLossFunc``-style
    object exposing ``get_raw_loss`` and ``weight_loss``). Works for any number
    of tasks (the previous implementation hard-coded exactly three in the
    aggregation step).

    Bug fixed: a disabled task's zero loss was still passed through
    ``weight_loss``, so a learned uncertainty weight contributed its
    regularisation term (``0.5 * s``) to the total loss and its parameter
    received gradient even though the task was off — driving ``s`` (and the
    total loss) unboundedly downward. Disabled tasks now contribute a constant
    scalar zero to both the raw and the weighted losses.
    """

    def __init__(self, enabled_tasks, loss_funcs):
        """
        :param enabled_tasks: sequence of per-task bool flags
        :param loss_funcs: one loss function per task, same length as enabled_tasks
        """
        super().__init__()
        assert len(enabled_tasks) == len(loss_funcs), f'enabled_tasks={enabled_tasks}, loss_funcs={loss_funcs}'
        self._enabled_tasks = enabled_tasks
        self._loss_funcs = loss_funcs

    def __call__(self, outputs, labels, original):
        """Returns (overall loss, tuple of per-task raw losses).

        :param outputs: one model output per task
        :param labels: dataset labels, used by classification losses
        :param original: original image, used by reconstruction losses
        """
        assert len(outputs) == len(self._enabled_tasks) == len(self._loss_funcs)
        raw_losses = []
        weighted_losses = []
        for enabled, loss_func, output in zip(self._enabled_tasks, self._loss_funcs, outputs):
            if enabled:
                raw = loss_func.get_raw_loss(output, labels, original)
                weighted = loss_func.weight_loss(raw)
            else:
                # Constant scalar zero; deliberately NOT passed through
                # weight_loss so disabled tasks add nothing to the total loss
                # and their weighting parameters get no gradient.
                raw = torch.tensor(0.0, device=output.device)
                weighted = raw
            raw_losses.append(raw)
            weighted_losses.append(weighted)
        total_loss = sum(weighted_losses)
        return total_loss, tuple(raw_losses)
def get_fixed_loss(enabled_tasks: [bool], weights: [float], mnist_type: str):
    """Returns the fixed weight loss function."""
    # Task 1: three-way classification; task 2: full classification;
    # task 3: image reconstruction.
    loss_funcs = [
        FixedCELoss(weights[0], lambda labels: _labels_to_1(labels, mnist_type)),
        FixedCELoss(weights[1], lambda x: x),
        FixedL1Loss(weights[2]),
    ]
    return MultitaskMnistLoss(enabled_tasks, loss_funcs)
def get_learned_loss(enabled_tasks: [bool], ses: [nn.Parameter], mnist_type: str):
    """Returns the learned uncertainties loss function.

    :param ses s=log(sigma^2) for each task, as in the paper
    """
    # Task 1: three-way classification; task 2: full classification;
    # task 3: image reconstruction.
    loss_funcs = [
        LearnedCELoss(ses[0], lambda labels: _labels_to_1(labels, mnist_type)),
        LearnedCELoss(ses[1], lambda x: x),
        LearnedL1Loss(ses[2]),
    ]
    return MultitaskMnistLoss(enabled_tasks, loss_funcs)