-
Notifications
You must be signed in to change notification settings - Fork 1
/
loss.py
137 lines (115 loc) · 5 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# list all the additional loss functions
import torch
import torch.nn as nn
import torch.nn.functional as F
################## entropy loss (continuous target) #####################
def cross_entropy_soft(pred):
    """Return the batch-mean entropy of ``softmax(pred)`` over dim 1.

    Despite the name, no target distribution is involved: this is the
    Shannon entropy of the model's own prediction, averaged over dim 0.

    Args:
        pred: logits; the softmax is taken over dim 1 (presumably
            ``(batch, num_classes)`` — confirm against callers).

    Returns:
        0-dim tensor.
    """
    probs = F.softmax(pred, dim=1)
    log_probs = F.log_softmax(pred, dim=1)
    per_sample_entropy = -(probs * log_probs).sum(dim=1)
    return per_sample_entropy.mean()
################## attentive entropy loss (source + target) #####################
def attentive_entropy(pred, pred_domain):
    """Return the entropy of ``pred`` weighted by ``1 + entropy(pred_domain)``.

    Both entropies are taken over the softmax along dim 1. Samples whose
    domain prediction is more uncertain receive a larger weight. Gradients
    flow through the weights as well (they are not detached).

    Args:
        pred: class logits; softmax over dim 1.
        pred_domain: domain logits; softmax over dim 1. Its per-sample
            entropy must broadcast against that of ``pred`` (presumably the
            same batch size — confirm against callers).

    Returns:
        0-dim tensor: mean over samples of ``weight * entropy(pred)``.
    """
    def _row_entropy(logits):
        # Entropy of softmax(logits), one value per row.
        return -(F.softmax(logits, dim=1) * F.log_softmax(logits, dim=1)).sum(dim=1)

    attention = 1 + _row_entropy(pred_domain)
    return (attention * _row_entropy(pred)).mean()
################## ensemble-based loss #####################
# discrepancy loss used in MCD (CVPR 18)
def dis_MCD(out1, out2):
    """Return the mean absolute difference between two softmax outputs.

    Both inputs are turned into probabilities with a softmax over dim 1,
    and the element-wise L1 gap is averaged over every element.

    Args:
        out1: logits from the first classifier.
        out2: logits from the second classifier (same shape as ``out1``).

    Returns:
        0-dim tensor.
    """
    prob_gap = F.softmax(out1, dim=1) - F.softmax(out2, dim=1)
    return prob_gap.abs().mean()
################## MMD-based loss #####################
def mmd_linear(f_of_X, f_of_Y):
    """Return the biased squared MMD between two batches under a linear kernel.

    With the linear kernel k(a, b) = a^T b, the biased MMD^2 estimate is the
    squared Euclidean distance between the two batch means:

        ||mean_i f(x_i) - mean_j f(y_j)||^2

    Args:
        f_of_X: 2-D tensor of features; rows are samples (presumably
            ``(batch_size, k)`` — confirm against callers).
        f_of_Y: 2-D tensor of features with the same number of columns.
            The number of rows may differ from ``f_of_X``.

    Returns:
        0-dim tensor.
    """
    # mean(delta @ delta.T) == ||mean(delta)||^2, so averaging first gives the
    # same value without materialising an n x n Gram matrix (O(n*k) instead
    # of O(n^2*k)). Averaging each side separately also allows different
    # batch sizes for the two inputs.
    mean_delta = f_of_X.mean(dim=0) - f_of_Y.mean(dim=0)
    return (mean_delta * mean_delta).sum()
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Return a sum of Gaussian (RBF) kernels over the stacked source + target rows.

    ``source`` and ``target`` are concatenated along dim 0 (source rows first)
    and, for every pair of rows (a, b), the value

        sum_i exp(-||a - b||^2 / (base * kernel_mul ** (i - kernel_num // 2)))

    is computed for ``i`` in ``range(kernel_num)``.

    Args:
        source: 2-D tensor; rows are samples, ``size(1)`` is the feature dim.
        target: 2-D tensor with the same number of columns as ``source``.
        kernel_mul: ratio between consecutive bandwidths.
        kernel_num: number of bandwidths summed.
        fix_sigma: if truthy, used as ``base``; otherwise ``base`` is the mean
            squared distance over off-diagonal pairs, detached from the
            autograd graph. A tensor passed here is never modified.

    Returns:
        Tensor of shape ``(n_source + n_target, n_source + n_target)``.
    """
    total = torch.cat([source, target], dim=0)
    n_samples = total.size(0)
    n_features = total.size(1)
    # total0[i, j] = total[j], total1[i, j] = total[i]
    total0 = total.unsqueeze(0).expand(n_samples, n_samples, n_features)
    total1 = total.unsqueeze(1).expand(n_samples, n_samples, n_features)
    l2_distance = ((total0 - total1) ** 2).sum(2)
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        # Data-driven bandwidth; detached so the heuristic gets no gradient.
        bandwidth = torch.sum(l2_distance.detach()) / (n_samples ** 2 - n_samples)
    # Out-of-place on purpose: ``bandwidth /= ...`` would divide a tensor
    # supplied as ``fix_sigma`` in the caller's scope.
    bandwidth = bandwidth / (kernel_mul ** (kernel_num // 2))
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
    kernel_val = [torch.exp(-l2_distance / bw) for bw in bandwidth_list]
    return sum(kernel_val)


def _mmd_from_kernels(kernels, batch_size, ver):
    """Estimate MMD^2 from a stacked (source rows first) kernel matrix.

    ver=1: linear-time estimate pairing sample i with sample (i + 1) mod
        ``batch_size``; requires at least ``2 * batch_size`` rows. The
        absolute value of the summed terms is divided by ``batch_size``.
    ver=2: biased quadratic estimate
        mean(K_ss) + mean(K_tt) - mean(K_st) - mean(K_ts), where the first
        ``batch_size`` rows/cols are source and the rest are target (the two
        sides may have different sizes).

    Raises:
        ValueError: if ``ver`` is neither 1 nor 2.
    """
    if ver == 1:
        s1 = torch.arange(batch_size, device=kernels.device)
        s2 = (s1 + 1) % batch_size
        t1, t2 = s1 + batch_size, s2 + batch_size
        terms = kernels[s1, s2] + kernels[t1, t2] - kernels[s1, t2] - kernels[s2, t1]
        return terms.sum().abs() / float(batch_size)
    if ver == 2:
        xx = kernels[:batch_size, :batch_size]
        yy = kernels[batch_size:, batch_size:]
        xy = kernels[:batch_size, batch_size:]
        yx = kernels[batch_size:, :batch_size]
        # Per-block means equal mean(xx + yy - xy - yx) for equal batch sizes
        # and stay valid when source and target sizes differ.
        return xx.mean() + yy.mean() - xy.mean() - yx.mean()
    raise ValueError('ver == 1 or 2')


def mmd_rbf(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None, ver=2):
    """Return the MMD^2 estimate between ``source`` and ``target`` with a multi-RBF kernel.

    Args:
        source: 2-D tensor; rows are samples.
        target: 2-D tensor with the same number of columns. For ``ver=1`` it
            must have at least as many rows as ``source``.
        kernel_mul, kernel_num, fix_sigma: forwarded to ``guassian_kernel``.
        ver: 1 for the linear-time estimate, 2 for the biased quadratic one.

    Returns:
        0-dim tensor.

    Raises:
        ValueError: if ``ver`` is neither 1 nor 2.
    """
    batch_size = int(source.size()[0])
    kernels = guassian_kernel(source, target, kernel_mul=kernel_mul,
                              kernel_num=kernel_num, fix_sigma=fix_sigma)
    return _mmd_from_kernels(kernels, batch_size, ver)


def JAN(source_list, target_list, kernel_muls=(2.0, 2.0), kernel_nums=(2, 5),
        fix_sigma_list=(None, None), ver=2):
    """Return the joint MMD^2 over several feature layers.

    For each layer ``i`` a multi-RBF kernel matrix is built from
    ``source_list[i]`` and ``target_list[i]`` with the i-th kernel settings;
    the matrices are multiplied element-wise and the product is fed to the
    same estimator as ``mmd_rbf`` (presumably the JMMD of Joint Adaptation
    Networks — confirm).

    Args:
        source_list: per-layer 2-D source tensors; ``source_list[0]`` fixes
            the source batch size.
        target_list: per-layer 2-D target tensors, aligned with
            ``source_list``.
        kernel_muls, kernel_nums, fix_sigma_list: per-layer kernel settings,
            indexed by layer; each must have at least ``len(source_list)``
            entries.
        ver: 1 for the linear-time estimate, 2 for the biased quadratic one.

    Returns:
        0-dim tensor.

    Raises:
        ValueError: if ``ver`` is neither 1 nor 2.
    """
    batch_size = int(source_list[0].size()[0])
    joint_kernels = None
    for i, source in enumerate(source_list):
        kernels = guassian_kernel(source, target_list[i],
                                  kernel_mul=kernel_muls[i],
                                  kernel_num=kernel_nums[i],
                                  fix_sigma=fix_sigma_list[i])
        joint_kernels = kernels if joint_kernels is None else joint_kernels * kernels
    return _mmd_from_kernels(joint_kernels, batch_size, ver)
def ivan_CEL(my_outputs, my_labels, weights):
    """Return a weighted cross-entropy normalised by the batch size.

    Each row's log-softmax (over dim 1) is read at its label, multiplied by
    ``weights`` (cast to float), summed, negated and divided by the number
    of rows. Note the normaliser is the row count, not the sum of weights.

    Args:
        my_outputs: logits; rows are samples, softmax over dim 1.
        my_labels: one class index per row, usable for advanced indexing.
        weights: broadcast against the per-row picked log-probabilities
            (presumably one weight per sample — confirm against callers).

    Returns:
        0-dim tensor.
    """
    batch = my_outputs.size()[0]
    log_probs = F.log_softmax(my_outputs, dim=1)
    picked = log_probs[range(batch), my_labels]
    weighted = picked * weights.float()
    return -weighted.sum() / batch
def H_loss(x):
    """Return the batch-mean of ``sum(p * log p)`` with ``p = softmax(x)`` over dim 1.

    This is the negative entropy of the prediction, so minimising it
    increases entropy.

    Args:
        x: logits; softmax over dim 1.

    Returns:
        0-dim tensor.
    """
    log_p = F.log_softmax(x, dim=1)
    p = F.softmax(x, dim=1)
    return torch.mean(torch.sum(p * log_p, dim=1))