hm.py
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import init
from torch import nn, autograd


class HM(autograd.Function):
    """Hybrid-memory lookup: the forward pass computes similarities between the
    batch features and the memory bank; the backward pass additionally updates
    the touched memory entries with a momentum rule."""

    @staticmethod
    def forward(ctx, inputs, indexes, features, momentum):
        ctx.features = features
        ctx.momentum = momentum
        ctx.save_for_backward(inputs, indexes)
        # Similarity of each input feature to every memory slot: (B, num_samples)
        outputs = inputs.mm(ctx.features.t())
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        inputs, indexes = ctx.saved_tensors
        grad_inputs = None
        if ctx.needs_input_grad[0]:
            grad_inputs = grad_outputs.mm(ctx.features)

        # Momentum update of the memory entries indexed by this batch,
        # followed by re-normalization to unit length
        for x, y in zip(inputs, indexes):
            ctx.features[y] = ctx.momentum * ctx.features[y] + (1. - ctx.momentum) * x
            ctx.features[y] /= ctx.features[y].norm()

        return grad_inputs, None, None, None


def hm(inputs, indexes, features, momentum=0.5):
    return HM.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device))
class HybridMemory(nn.Module):
    def __init__(self, num_features, num_samples, temp=0.05, momentum=0.2):
        super(HybridMemory, self).__init__()
        self.num_features = num_features
        self.num_samples = num_samples
        self.momentum = momentum
        self.temp = temp

        # Memory bank of per-sample features and their (pseudo-)labels
        self.register_buffer('features', torch.zeros(num_samples, num_features))
        self.register_buffer('labels', torch.zeros(num_samples).long())

    def forward(self, inputs, indexes):
        # inputs: B*2048, features: L*2048
        inputs = hm(inputs, indexes, self.features, self.momentum)
        inputs /= self.temp
        B = inputs.size(0)

        def masked_softmax(vec, mask, dim=1, epsilon=1e-6):
            exps = torch.exp(vec)
            masked_exps = exps * mask.float().clone()
            masked_sums = masked_exps.sum(dim, keepdim=True) + epsilon
            return (masked_exps / masked_sums)

        targets = self.labels[indexes].clone()
        labels = self.labels.clone()

        # Average the similarities of memory entries sharing the same label,
        # i.e. compare each input against class/cluster centroids
        sim = torch.zeros(labels.max() + 1, B).float().cuda()
        sim.index_add_(0, labels, inputs.t().contiguous())
        nums = torch.zeros(labels.max() + 1, 1).float().cuda()
        nums.index_add_(0, labels, torch.ones(self.num_samples, 1).float().cuda())
        mask = (nums > 0).float()
        sim /= (mask * nums + (1 - mask)).clone().expand_as(sim)
        mask = mask.expand_as(sim)

        # Softmax restricted to non-empty classes, then NLL against each input's own label
        masked_sim = masked_softmax(sim.t().contiguous(), mask.t().contiguous())
        return F.nll_loss(torch.log(masked_sim + 1e-6), targets)
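
A minimal usage sketch (not part of the original file): it assumes a CUDA device, L2-normalized 2048-d features, and that the memory bank would normally be initialized from features extracted by a backbone and labels obtained from clustering; the random tensors below stand in for those and are illustrative only.

# Hypothetical usage sketch; shapes, initialization, and the dummy data are assumptions.
import torch
import torch.nn.functional as F

num_samples, num_features = 1000, 2048
memory = HybridMemory(num_features, num_samples, temp=0.05, momentum=0.2).cuda()

# Assume the bank was filled beforehand (e.g. extracted features + cluster labels).
memory.features.copy_(F.normalize(torch.randn(num_samples, num_features), dim=1).cuda())
memory.labels.copy_(torch.randint(0, 50, (num_samples,)).cuda())

# One training step: `feats` would normally come from the encoder.
raw = torch.randn(16, num_features, device='cuda', requires_grad=True)
feats = F.normalize(raw, dim=1)
indexes = torch.randint(0, num_samples, (16,), device='cuda')

loss = memory(feats, indexes)
loss.backward()  # also triggers the momentum update of memory.features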