Commit message:
* Added masked batch normalization, layer normalization, and soft attention utility functions.
* Moved new util functions to new module directory.
* Added build support for modules and fixed filename.
* Added unit tests and various fixes for new modules.
Commit 9c20f7b (parent dae7f53). Showing 7 changed files with 351 additions and 2 deletions.
autograd/module/LayerNormalization.lua (new file, 41 lines):
local util = require 'autograd.util'

return function(opt, params)
   local opt = opt or {}
   local params = params or {}

   local nOutputs = opt.nOutputs or 10
   local p = {gain = torch.ones(1, nOutputs),
              bias = torch.zeros(1, nOutputs)}
   table.insert(params, p)

   local function layer_norm(params, x, eps)
      --[[ Layer Normalization of Ba, Kiros, and Hinton (https://arxiv.org/abs/1607.06450)

      Normalizes activations x at a layer by their mean and std.

      Parameters:
      * `params` - Gain and bias parameters to adjust normalized output.
      * `x` - ([batch,] nOutputs) tensor to be normalized.
      * `eps` - Small constant to avoid divide by zero for small std.

      Returns:
      * `x_corrected` - ([batch,] nOutputs) layer normalized tensor.
      --]]
      local p = params[1] or params
      local eps = eps or 1e-5
      local x_in = x
      if torch.nDimension(x) == 1 then
         x_in = torch.view(x, 1, torch.size(x, 1))
      end
      local n = torch.size(x_in, 2)
      local mean = torch.expand(torch.mean(x_in, 2), torch.size(x_in))
      local x_centered = x_in - mean
      local std = torch.expand(torch.sqrt(torch.sum(torch.cmul(x_centered, x_centered) / n, 2)) + eps, torch.size(x_in))
      local x_normed = torch.cdiv(x_centered, std)
      local gain = torch.expand(p.gain, torch.size(x_in))
      local bias = torch.expand(p.bias, torch.size(x_in))
      local x_corrected = torch.view(torch.cmul(x_normed, gain) + bias, torch.size(x))
      return x_corrected
   end
   return layer_norm, params
end
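For context, a minimal usage sketch of this module inside an autograd-differentiated function (not part of the commit; the sizes and the toy sum loss are illustrative):

local grad = require 'autograd'
local LayerNorm = require 'autograd.module.LayerNormalization'
local layer_norm, ln_params = LayerNorm({nOutputs = 4})

-- Differentiate a toy loss through the normalization.
local f = function(params, x)
   return torch.sum(layer_norm(params, x))
end
local df = grad(f)
local x = torch.randn(2, 4)          -- [batch, nOutputs]
local grads, loss = df(ln_params, x) -- gradients w.r.t. gain and bias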
autograd/module/MaskedBatchNormalization.lua (new file, 63 lines):
local util = require 'autograd.util'

return function(opt, params)
   local opt = opt or {}
   local params = params or {}

   local nOutputs = opt.nOutputs or 10
   local momentum = opt.momentum or 0.1

   local batchNormState = {momentum = momentum, train = 1,
                           running_mean = torch.zeros(1, nOutputs),
                           running_std = torch.ones(1, nOutputs)}

   -- initializing gain to < 1 is recommended for LSTM batch norm.
   local p = {gain = torch.zeros(1, nOutputs):fill(0.1),
              bias = torch.zeros(1, nOutputs)}
   table.insert(params, p)

   local function masked_batch_norm(params, x, mask, state, eps)
      --[[ Masked batch normalization for minibatches with variable length sequences.

      Based on sequence batch norm from Batch Normalized Recurrent Neural Networks by Laurent et al.
      (http://arxiv.org/abs/1510.01378)

      Parameters:
      * `params` - Gain and bias parameters to adjust normalized output.
      * `x` - ([batch[, time],] nOutputs) tensor to be normalized.
      * `mask` - Tensor with the same size as x that is 1 where x is valid and 0 otherwise.
      * `state` - Running mean and std estimates, momentum for estimates, and train flag.
      * `eps` - Small constant to avoid divide by zero for small std.

      Returns:
      * `x_corrected` - ([batch[, time],] nOutputs) batch normalized tensor.
      --]]
      local p = params[1] or params
      local eps = eps or 1e-5
      local train = state.train or 1
      local momentum = (state.momentum or 0.1) * train -- kill state updates during evaluation
      local x_in = x
      local mask_in = mask
      if torch.nDimension(x) == 3 then -- collapse batch and time dimensions
         x_in = torch.view(x, -1, torch.size(x, 3))
         mask_in = torch.view(mask, -1, torch.size(mask, 3))
      elseif torch.nDimension(x) == 1 then -- expand batch dimension
         x_in = torch.view(x, 1, torch.size(x, 1))
         mask_in = torch.view(mask, 1, torch.size(mask, 1))
      end
      local n = torch.sum(mask)
      mask_in = torch.expand(mask_in, torch.size(x_in))
      local x_masked = torch.cmul(x_in, mask_in)
      local mean = torch.sum(x_masked / n, 1)
      state.running_mean = momentum * mean + (1 - momentum) * state.running_mean
      local x_centered = torch.cmul(x_masked - torch.expand(state.running_mean, torch.size(x_in)), mask_in)
      local var = torch.sum(torch.cmul(x_centered, x_centered) / n, 1) + eps
      local std = torch.sqrt(var)
      state.running_std = momentum * std + (1 - momentum) * state.running_std
      local x_normed = torch.cdiv(x_centered, torch.expand(state.running_std, torch.size(x_in)))
      local gain = torch.expand(p.gain, torch.size(x_in))
      local bias = torch.expand(p.bias, torch.size(x_in))
      local x_corrected = torch.view(torch.cmul(x_normed, gain) + bias, torch.size(x))
      return x_corrected
   end
   return masked_batch_norm, params, batchNormState
end
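A sketch of how the returned function, params, and state fit together (not part of the commit; the padded-batch shapes and mask are illustrative):

local grad = require 'autograd'
local MaskedBatchNorm = require 'autograd.module.MaskedBatchNormalization'
local batch_norm, bn_params, bn_state = MaskedBatchNorm({nOutputs = 4, momentum = 0.1})

local x = torch.randn(2, 5, 4)             -- [batch, time, nOutputs], zero-padded
local mask = torch.ones(2, 5, 4)
mask[{2, {4, 5}, {}}] = 0                  -- second sequence has only 3 valid steps

local f = function(params, x, mask)
   return torch.sum(batch_norm(params, x, mask, bn_state))
end
local df = grad(f)
local grads, loss = df(bn_params, x, mask) -- running_mean/running_std updated as a side effect
bn_state.train = 0                         -- freeze the running estimates for evaluation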
autograd/module/SoftAttention.lua (new file, 76 lines):
local functionalize = require('autograd.nnwrapper').functionalize
local nn = functionalize('nn')
local LayerNorm = require 'autograd.module.LayerNormalization'

local softMax = nn.SoftMax()

return function(opt, params)
   local opt = opt or {}
   local params = params or {}

   local layerNormalization = opt.layerNormalization or false
   local hiddenFeatures = opt.hiddenFeatures or 10
   local subjectFeatures = opt.subjectFeatures or 15
   local subjectChoices = opt.subjectChoices or 20

   local p = {W_att_subject = torch.zeros(1, 1, subjectFeatures),
              W_att_h = torch.zeros(hiddenFeatures, subjectChoices),
              b_att = torch.zeros(1, subjectChoices)}

   -- LayerNorm returns the layer_norm function first and its params table second;
   -- capture both, and keep layer_norm in scope for soft_attention below.
   local layer_norm
   if layerNormalization then
      local focus_ln_params
      layer_norm, focus_ln_params = LayerNorm({nOutputs = subjectChoices})
      p.focus_ln_gain = focus_ln_params[1].gain
      p.focus_ln_bias = focus_ln_params[1].bias
      p.b_att = nil
   end
   table.insert(params, p)

   local soft_attention = function(params, subject, h)
      --[[ Soft attention over subject given hidden state.

      Deterministic soft attention of Show, Attend, and Tell by Xu et al. (http://arxiv.org/abs/1502.03044)

      Parameters:
      * `params` - Weights to combine subject and hidden features to score choices.
      * `subject` - ([batch,] subjectFeatures, subjectChoices) tensor.
      * `h` - ([batch,] hiddenFeatures) tensor.

      Returns:
      * `attention` - ([batch,] subjectFeatures) tensor that is the expectation of the attended subject vector.
      * `focus` - ([batch,] subjectChoices) tensor that is the probability of selecting any given subject choice.
      --]]
      local p = params[1] or params
      local subject_in = subject
      local h_in = h
      if torch.nDimension(subject) == 2 then
         subject_in = torch.view(subject, 1, torch.size(subject, 1), torch.size(subject, 2))
      end
      if torch.nDimension(h) == 1 then
         h_in = torch.view(h, 1, torch.size(h, 1))
      end
      local batchSize = torch.size(subject_in, 1)
      local subjectFeatures = torch.size(subject_in, 2)
      local subjectChoices = torch.size(subject_in, 3)
      -- Activations for each subject choice and hidden state.
      local W_subject = torch.expand(p.W_att_subject, batchSize, 1, subjectFeatures)
      local subject_logit = torch.squeeze(torch.bmm(W_subject, subject_in), 2)
      local hidden_logit = h_in * p.W_att_h
      -- Focus distribution over subject choices.
      local focus_logit = subject_logit + hidden_logit
      if layerNormalization then
         focus_logit = layer_norm({gain = p.focus_ln_gain, bias = p.focus_ln_bias}, focus_logit)
      else
         focus_logit = focus_logit + torch.expand(p.b_att, batchSize, subjectChoices)
      end
      local focus = softMax(focus_logit)
      -- Attend to choice in expectation.
      local expanded_focus = torch.expand(torch.view(focus, batchSize, 1, subjectChoices), torch.size(subject_in))
      local attention = torch.squeeze(torch.sum(torch.cmul(subject_in, expanded_focus), 3), 3)
      if torch.nDimension(subject) == 2 then
         attention = torch.squeeze(attention, 1)
         focus = torch.squeeze(focus, 1)
      end
      return attention, focus
   end
   return soft_attention, params
end
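A sketch of calling the attention function (not part of the commit; the feature sizes are illustrative, and the zero-initialized weights yield a uniform focus until trained):

local grad = require 'autograd'
local SoftAttention = require 'autograd.module.SoftAttention'
local attend, att_params = SoftAttention({subjectFeatures = 15, subjectChoices = 20,
                                          hiddenFeatures = 10})

local subject = torch.randn(2, 15, 20)  -- [batch, subjectFeatures, subjectChoices]
local h = torch.randn(2, 10)            -- [batch, hiddenFeatures]

local f = function(params, subject, h)
   local attention, focus = attend(params, subject, h)
   return torch.sum(attention)
end
local df = grad(f)
local grads, loss = df(att_params, subject, h)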
autograd/module/init.lua (new file, 8 lines):
-- autograd native modules
local module = {
   LayerNormalization = require 'autograd.module.LayerNormalization',
   MaskedBatchNormalization = require 'autograd.module.MaskedBatchNormalization',
   SoftAttention = require 'autograd.module.SoftAttention'
}

return module
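With this table in place, the modules can be reached through a single require (a sketch, assuming autograd.module is on the package path):

local module = require 'autograd.module'
local layer_norm, params = module.LayerNormalization({nOutputs = 64})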
Review comment from Atcold (Contributor): LayerNorm returns layer_norm, params. So, in this case, focus_ln_params looks like it's getting layer_norm and not params.