Commit
1 parent f674432 · commit 4622cb0
Showing 11 changed files with 852 additions and 0 deletions.
.gitignore
@@ -99,3 +99,4 @@ ENV/

# mypy
.mypy_cache/
*.txt
LSTMCell.py
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn


class LayerNorm(nn.Module):

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        # Normalize over the last dimension, then rescale and shift.
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta


class LSTMCell(nn.Module):
    # Subclasses nn.Module directly: the original wildcard import of
    # torch.nn.modules.rnn pulled in RNNCellBase, whose constructor signature
    # differs across PyTorch versions and whose machinery is unused here.

    def __init__(self, input_size, hidden_size, bias=True, dropout=0):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # Layer-normalized input-to-hidden and hidden-to-hidden projections,
        # producing all four gate pre-activations at once.
        self.ih = nn.Sequential(nn.Linear(input_size, 4 * hidden_size, bias), LayerNorm(4 * hidden_size))
        self.hh = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size, bias), LayerNorm(4 * hidden_size))
        self.c_norm = LayerNorm(hidden_size)
        self.drop = nn.Dropout(dropout)

    def forward(self, input, hidden):
        hx, cx = hidden
        gates = self.ih(input) + self.hh(hx)

        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        # torch.sigmoid/torch.tanh replace the deprecated F.sigmoid/F.tanh.
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = forgetgate * cx + ingate * cellgate
        hy = outgate * torch.tanh(self.c_norm(cy))

        return hy, cy
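
As a quick sanity check, the cell can be exercised on its own. This is a minimal sketch, assuming the file above is importable as LSTMCell.py; the batch size and dimensions are illustrative:

import torch
from LSTMCell import LSTMCell

cell = LSTMCell(input_size=10, hidden_size=20)
x = torch.randn(3, 10)                              # batch of 3 inputs
state = (torch.zeros(3, 20), torch.zeros(3, 20))    # initial (hx, cx)
hy, cy = cell(x, state)
print(hy.shape, cy.shape)  # torch.Size([3, 20]) torch.Size([3, 20])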
ParsingNetwork.py
@@ -0,0 +1,66 @@
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class ParsingNetwork(nn.Module):
    def __init__(self, ninp, nhid, nslots=5, nlookback=1, resolution=0.1, dropout=0.4, hard=False):
        super(ParsingNetwork, self).__init__()

        self.nhid = nhid
        self.ninp = ninp
        self.nslots = nslots
        self.nlookback = nlookback
        self.resolution = resolution
        self.hard = hard

        self.drop = nn.Dropout(dropout)

        # Attention layers
        self.gate = nn.Sequential(nn.Dropout(dropout),
                                  nn.Conv1d(ninp, nhid, (nlookback + 1)),
                                  nn.BatchNorm1d(nhid),
                                  nn.ReLU(),
                                  nn.Dropout(dropout),
                                  nn.Conv1d(nhid, 2, 1, groups=2),
                                  nn.ReLU())

    def forward(self, emb, parser_state):
        emb_last, cum_gate = parser_state
        ntimestep = emb.size(0)

        emb_last = torch.cat([emb_last, emb], dim=0)
        emb = emb_last.transpose(0, 1).transpose(1, 2)  # bsz, ninp, ntimestep + nlookback

        gates = self.gate(emb)  # bsz, 2, ntimestep
        gate = gates[:, 0, :]
        gate_next = gates[:, 1, :]
        cum_gate = torch.cat([cum_gate, gate], dim=1)
        gate_hat = torch.stack([cum_gate[:, i:i + ntimestep] for i in range(self.nslots, 0, -1)],
                               dim=2)  # bsz, ntimestep, nslots

        if self.hard:
            memory_gate = (F.hardtanh((gate[:, :, None] - gate_hat) / self.resolution * 2 + 1) + 1) / 2
        else:
            memory_gate = torch.sigmoid(
                (gate[:, :, None] - gate_hat) / self.resolution * 10 + 5)  # bsz, ntimestep, nslots
        memory_gate = torch.cumprod(memory_gate, dim=2)  # bsz, ntimestep, nslots
        memory_gate = torch.unbind(memory_gate, dim=1)

        if self.hard:
            memory_gate_next = (F.hardtanh((gate_next[:, :, None] - gate_hat) / self.resolution * 2 + 1) + 1) / 2
        else:
            memory_gate_next = torch.sigmoid(
                (gate_next[:, :, None] - gate_hat) / self.resolution * 10 + 5)  # bsz, ntimestep, nslots
        memory_gate_next = torch.cumprod(memory_gate_next, dim=2)  # bsz, ntimestep, nslots
        memory_gate_next = torch.unbind(memory_gate_next, dim=1)

        return (memory_gate, memory_gate_next), gate, (emb_last[-self.nlookback:], cum_gate[:, -self.nslots:])

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        self.ones = Variable(weight.new(bsz, 1).zero_() + 1)
        return Variable(weight.new(self.nlookback, bsz, self.ninp).zero_()), \
               Variable(weight.new(bsz, self.nslots).zero_() + numpy.inf)
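
A minimal sketch of one forward pass through the parser, assuming the file is saved as ParsingNetwork.py; the dimensions and the eval() call (to sidestep batch statistics on a toy batch) are illustrative, not part of the commit:

import torch
from ParsingNetwork import ParsingNetwork

parser = ParsingNetwork(ninp=8, nhid=16, nslots=5, nlookback=1)
parser.eval()
emb = torch.randn(4, 2, 8)           # ntimestep=4, bsz=2, ninp=8
state = parser.init_hidden(bsz=2)
(gates, gates_next), gate, state = parser(emb, state)
print(gate.shape)                    # torch.Size([2, 4]): one gate per timestep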
PredictNetwork.py
@@ -0,0 +1,56 @@
import math

import torch
import torch.nn as nn
from torch.autograd import Variable

from blocks import softmax, ResBlock


class PredictNetwork(nn.Module):
    def __init__(self, ninp, nout, nslots, dropout, nlayers=1):
        super(PredictNetwork, self).__init__()

        self.ninp = ninp
        self.nout = nout
        self.nslots = nslots
        self.nlayers = nlayers

        self.drop = nn.Dropout(dropout)

        self.projector_pred = nn.Sequential(nn.Dropout(dropout),
                                            nn.Linear(ninp, ninp),
                                            nn.Dropout(dropout))

        if nlayers > 0:
            self.res = ResBlock(ninp * 2, nout, dropout, nlayers)
        else:
            self.res = None

        self.ffd = nn.Sequential(nn.Dropout(dropout),
                                 nn.Linear(ninp * 2, nout),
                                 nn.BatchNorm1d(nout),
                                 nn.Tanh())

    def forward(self, input, input_memory):
        input = torch.cat([input, input_memory], dim=1)
        if self.nlayers > 0:
            input = self.res(input)
        output = self.ffd(input)
        return output

    def attention(self, input, memory, gate_time):
        key = self.projector_pred(input)
        # select memory to use
        logits = torch.bmm(memory, key[:, :, None]).squeeze(2)
        logits = logits / math.sqrt(self.ninp)
        attention = softmax(logits, gate_time)
        selected_memory_h = (memory * attention[:, :, None]).sum(dim=1)
        memory = torch.cat([input[:, None, :], memory[:, :-1, :]], dim=1)
        return selected_memory_h, memory, attention

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        self.ones = Variable(weight.new(bsz, 1).zero_() + 1.)
        return Variable(weight.new(bsz, self.nslots, self.ninp).zero_())
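
The attention and prediction steps are meant to be chained: attention selects and updates the memory, and the forward pass consumes the selected slot. A rough usage sketch under assumed dimensions (the module comes from the commit; this call sequence is illustrative):

import torch
from PredictNetwork import PredictNetwork

pred = PredictNetwork(ninp=16, nout=16, nslots=5, dropout=0.1, nlayers=1)
pred.eval()                              # avoid BatchNorm batch-stats on a toy batch
memory = pred.init_hidden(bsz=2)         # bsz x nslots x ninp
x = torch.randn(2, 16)
gate_time = torch.ones(2, 5)             # attend over all slots
selected, memory, attn = pred.attention(x, memory, gate_time)
out = pred(x, selected)                  # bsz x nout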
ReadingNetwork.py
@@ -0,0 +1,54 @@
import math

import torch
import torch.nn as nn
from torch.autograd import Variable

from LSTMCell import LSTMCell
from blocks import softmax


class ReadingNetwork(nn.Module):
    def __init__(self, ninp, nout, nslots, dropout, idropout):
        super(ReadingNetwork, self).__init__()

        self.ninp = ninp
        self.nout = nout
        self.nslots = nslots
        self.drop = nn.Dropout(dropout)
        self.memory_rnn = LSTMCell(ninp, nout)
        self.projector_summ = nn.Sequential(nn.Dropout(idropout),
                                            nn.Linear(ninp + nout, nout),
                                            nn.Dropout(idropout))

    def forward(self, input, memory, gate_time, rmask):
        memory_h, memory_c = memory

        # attention
        selected_memory_h, selected_memory_c, attention0 = self.attention(input, memory_h, memory_c,
                                                                          gate=gate_time)

        # recurrent
        input = self.drop(input)
        h_i, c_i = self.memory_rnn(input, (selected_memory_h * rmask, selected_memory_c))

        # update memory
        memory_h = torch.cat([h_i[:, None, :], memory_h[:, :-1, :]], dim=1)
        memory_c = torch.cat([c_i[:, None, :], memory_c[:, :-1, :]], dim=1)

        return h_i, (memory_h, memory_c), attention0

    def attention(self, input, memory_h, memory_c, gate=None):
        # select memory to use
        key = self.projector_summ(torch.cat([input, memory_h[:, 0, :]], dim=1))
        logits = torch.bmm(memory_h, key[:, :, None]).squeeze(2)
        logits = logits / math.sqrt(self.nout)
        attention = softmax(logits, gate)
        selected_memory_h = (memory_h * attention[:, :, None]).sum(dim=1)
        selected_memory_c = (memory_c * attention[:, :, None]).sum(dim=1)
        return selected_memory_h, selected_memory_c, attention

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        return Variable(weight.new(bsz, self.nslots, self.nout).zero_()), \
               Variable(weight.new(bsz, self.nslots, self.nout).zero_())
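
One reading step can be sketched as follows, assuming the file is saved as ReadingNetwork.py; gate_time and the all-ones rmask (the recurrent dropout mask) are illustrative:

import torch
from ReadingNetwork import ReadingNetwork

reader = ReadingNetwork(ninp=16, nout=16, nslots=5, dropout=0.1, idropout=0.1)
memory = reader.init_hidden(bsz=2)   # (memory_h, memory_c), each bsz x nslots x nout
x = torch.randn(2, 16)
gate_time = torch.ones(2, 5)         # attend over all 5 slots
rmask = torch.ones(2, 16)
h, memory, attn = reader(x, memory, gate_time, rmask)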
blocks.py
@@ -0,0 +1,48 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def stick_breaking(logits):
    e = torch.sigmoid(logits)
    z = (1 - e).cumprod(dim=1)
    p = torch.cat([e.narrow(1, 0, 1), e[:, 1:] * z[:, :-1]], dim=1)

    return p


def softmax(x, mask=None):
    max_x, _ = x.max(dim=-1, keepdim=True)
    e_x = torch.exp(x - max_x)
    if mask is not None:
        e_x = e_x * mask
    out = e_x / (e_x.sum(dim=-1, keepdim=True) + 1e-8)

    return out


class ResBlock(nn.Module):
    def __init__(self, ninp, nout, dropout, nlayers=1):
        super(ResBlock, self).__init__()

        self.nlayers = nlayers

        self.drop = nn.Dropout(dropout)

        self.res = nn.ModuleList(
            [nn.Sequential(
                nn.Linear(ninp, ninp),
                nn.BatchNorm1d(ninp),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(ninp, ninp),
                nn.BatchNorm1d(ninp),
            )
                for _ in range(nlayers)]
        )

    def forward(self, input):
        # input = self.drop(input)
        for i in range(self.nlayers):
            input = F.relu(self.res[i](input) + input)
        return input
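
A quick check of the two helpers (illustrative values): masked positions receive zero weight in the masked softmax, and stick_breaking yields per-row weights summing to at most 1:

import torch
from blocks import softmax, stick_breaking

logits = torch.randn(2, 5)
mask = torch.tensor([[1., 1., 1., 0., 0.],
                     [1., 1., 1., 1., 1.]])
p = softmax(logits, mask)          # last two columns of row 0 are zero
q = stick_breaking(logits)
print(p.sum(dim=-1), q.sum(dim=-1))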
data.py
@@ -0,0 +1,54 @@
import os
import torch


class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)

    def __getitem__(self, key):
        # dict.has_key() is Python 2 only; use the `in` operator instead.
        if key in self.word2idx:
            return self.word2idx[key]
        else:
            return self.word2idx['<unk>']


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.strip().split() + ['</s>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r') as f:
            ids = torch.LongTensor(tokens)
            token = 0
            for line in f:
                words = line.strip().split() + ['</s>']
                for word in words:
                    ids[token] = self.dictionary.word2idx[word]
                    token += 1

        return ids
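
Hypothetical usage, assuming the file is saved as data.py and a data directory contains train.txt / valid.txt / test.txt whose vocabulary includes an '<unk>' token (as in Penn Treebank-style data); the path is illustrative:

from data import Corpus

corpus = Corpus('./data/penn')
print(len(corpus.dictionary))           # vocabulary size
print(corpus.train.size(0))             # number of training tokens
print(corpus.dictionary['zzz-unseen'])  # unseen words fall back to '<unk>'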
(The remaining 4 changed files are not shown.)