-
Notifications
You must be signed in to change notification settings - Fork 21
/
model.py
105 lines (82 loc) · 3.83 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
from log_uniform import LogUniformSampler
import util
class SampledSoftmax(nn.Module):
    """Output layer that scores only a sampled subset of classes in training.

    In training mode ``forward`` asks a ``LogUniformSampler`` for
    ``nsampled`` candidate ids and returns logits for the true class followed
    by the candidates. In evaluation mode it returns the logits of the
    underlying linear layer over all ``ntokens`` classes.
    """

    def __init__(self, ntokens, nsampled, nhid, tied_weight):
        """Create the sampler and the output projection.

        Args:
            ntokens: Number of output classes of the linear layer.
            nsampled: Number of candidates requested from the sampler.
            nhid: Size of the feature vectors fed to the linear layer.
            tied_weight: Weight assigned directly to ``self.params.weight``;
                when ``None`` the layer keeps its own weight, which is
                passed to ``util.initialize``.
        """
        super(SampledSoftmax, self).__init__()

        # Parameters
        self.ntokens = ntokens
        self.nsampled = nsampled

        self.sampler = LogUniformSampler(self.ntokens)
        self.params = nn.Linear(nhid, ntokens)

        if tied_weight is not None:
            self.params.weight = tied_weight
        else:
            util.initialize(self.params.weight)

    @staticmethod
    def _to_device_of(tensor, reference):
        """Return ``tensor`` placed on the device that holds ``reference``."""
        # Only move to a GPU when the reference actually lives on one, so the
        # layer also runs with CPU tensors.
        if reference.data.is_cuda:
            return tensor.cuda(reference.data.get_device())
        return tensor

    def forward(self, inputs, labels):
        """Return ``(logits, targets)`` for the current mode.

        Args:
            inputs: Feature matrix; ``sampled`` unpacks it as
                ``(batch_size, features)``.
            labels: Class ids, one per row of ``inputs``.

        Returns:
            The result of ``sampled`` while ``self.training`` is set,
            otherwise the result of ``full``.
        """
        if self.training:
            # The sampler works on host-side numpy data.
            sample_values = self.sampler.sample(
                self.nsampled, labels.data.cpu().numpy())
            return self.sampled(
                inputs, labels, sample_values, remove_accidental_match=True)
        return self.full(inputs, labels)

    def sampled(self, inputs, labels, sample_values,
                remove_accidental_match=False):
        """Compute logits for the true class and the sampled candidates.

        Args:
            inputs: 2-D feature matrix ``(batch_size, features)``.
            labels: Class ids used to select the true-class weights.
            sample_values: Triple ``(sample_ids, true_freq, sample_freq)``;
                the two frequency sequences are log-subtracted from the true
                and sampled logits respectively.
            remove_accidental_match: When true, positions reported by
                ``self.sampler.accidental_match`` are set to ``-1e37`` in the
                sampled logits.

        Returns:
            ``(logits, new_targets)`` where column 0 of ``logits`` holds the
            true-class logit, the remaining columns hold the sampled logits,
            and ``new_targets`` is an all-zero long tensor of length
            ``batch_size`` (i.e. it points at column 0).
        """
        # Both tensors must live on the same device (CPU, or the same GPU).
        assert inputs.data.is_cuda == labels.data.is_cuda
        if inputs.data.is_cuda:
            assert inputs.data.get_device() == labels.data.get_device()

        # Unpacking also enforces that ``inputs`` is two-dimensional.
        batch_size, _ = inputs.size()

        sample_ids, true_freq, sample_freq = sample_values
        sample_ids = self._to_device_of(
            Variable(torch.LongTensor(sample_ids)), labels)
        true_freq = self._to_device_of(
            Variable(torch.FloatTensor(true_freq)), labels)
        sample_freq = self._to_device_of(
            Variable(torch.FloatTensor(sample_freq)), labels)

        # gather true labels - weights and biases
        true_weights = torch.index_select(self.params.weight, 0, labels)
        true_bias = torch.index_select(self.params.bias, 0, labels)

        # gather sample ids - weights and biases
        sample_weights = torch.index_select(self.params.weight, 0, sample_ids)
        sample_bias = torch.index_select(self.params.bias, 0, sample_ids)

        # calculate logits
        true_logits = torch.sum(torch.mul(inputs, true_weights), dim=1) + true_bias
        sample_logits = torch.matmul(inputs, torch.t(sample_weights)) + sample_bias

        # remove true labels from sample set
        if remove_accidental_match:
            acc_hits = self.sampler.accidental_match(
                labels.data.cpu().numpy(), sample_ids.data.cpu().numpy())
            # NOTE(review): presumably a sequence of (row, column) pairs --
            # confirm against the sampler. Transposed into index sequences.
            acc_hits = list(zip(*acc_hits))
            if acc_hits:
                # Explicit (rows, cols) index rather than a list of tuples,
                # which depends on the list being read as a tuple index.
                rows, cols = acc_hits
                sample_logits[rows, cols] = -1e37

        # perform correction
        true_logits = true_logits.sub(torch.log(true_freq))
        sample_logits = sample_logits.sub(torch.log(sample_freq))

        # return logits and new_labels
        logits = torch.cat(
            (torch.unsqueeze(true_logits, dim=1), sample_logits), dim=1)
        new_targets = self._to_device_of(
            Variable(torch.zeros(batch_size).long()), labels)
        return logits, new_targets

    def full(self, inputs, labels):
        """Return logits over every class together with the given labels."""
        return self.params(inputs), labels
class RNNModel(nn.Module):
    """LSTM stack with dropout and an optional linear projection."""

    def __init__(self, ntokens, ninp, nhid, nout, nlayers, proj, dropout):
        """Build the dropout, LSTM and (optionally) projection layers.

        Args:
            ntokens: Accepted for interface compatibility; not used here.
            ninp: Input feature size of the LSTM.
            nhid: Hidden size of the LSTM.
            nout: Output size of the projection layer.
            nlayers: Number of stacked LSTM layers.
            proj: When truthy, a ``nhid -> nout`` linear layer is applied to
                the LSTM output; otherwise ``self.proj`` is ``None``.
            dropout: Dropout probability, used for the input/output dropout
                and passed to the LSTM.
        """
        super(RNNModel, self).__init__()

        # Sizes needed later by init_hidden.
        self.nhid = nhid
        self.nlayers = nlayers

        # Layers
        self.drop = nn.Dropout(dropout)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)

        self.proj = nn.Linear(nhid, nout) if proj else None
        if self.proj is not None:
            util.initialize(self.proj.weight)

    def forward(self, inputs, hidden):
        """Run the network and flatten the first two output dimensions.

        Args:
            inputs: Input passed (after dropout) to the LSTM.
            hidden: Initial LSTM state, e.g. the value of ``init_hidden``.

        Returns:
            ``(output, hidden)`` where ``output`` is viewed as
            ``(size(0) * size(1), size(2))`` and ``hidden`` is the state
            returned by the LSTM.
        """
        dropped = self.drop(inputs)
        features, next_hidden = self.rnn(dropped, hidden)

        projector = self.proj
        if projector is not None:
            features = projector(features)

        features = self.drop(features)
        n_rows = features.size(0) * features.size(1)
        return features.view(n_rows, features.size(2)), next_hidden

    def init_hidden(self, bsz):
        """Return a pair of zero tensors shaped ``(nlayers, bsz, nhid)``."""
        dims = (self.nlayers, bsz, self.nhid)
        return tuple(Variable(torch.zeros(*dims)) for _ in range(2))