Commit

PRPN code
yikangshen committed Feb 28, 2018
1 parent f674432 commit 4622cb0
Showing 11 changed files with 852 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -99,3 +99,4 @@ ENV/

# mypy
.mypy_cache/
*.txt
47 changes: 47 additions & 0 deletions LSTMCell.py
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.rnn import RNNCellBase


class LayerNorm(nn.Module):

def __init__(self, features, eps=1e-6):
super(LayerNorm, self).__init__()
self.gamma = nn.Parameter(torch.ones(features))
self.beta = nn.Parameter(torch.zeros(features))
self.eps = eps

def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.gamma * (x - mean) / (std + self.eps) + self.beta


class LSTMCell(RNNCellBase):

def __init__(self, input_size, hidden_size, bias=True, dropout=0):
super(LSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.ih = nn.Sequential(nn.Linear(input_size, 4 * hidden_size, bias), LayerNorm(4 * hidden_size))
self.hh = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size, bias), LayerNorm(4 * hidden_size))
self.c_norm = LayerNorm(hidden_size)
self.drop = nn.Dropout(dropout)

def forward(self, input, hidden):

hx, cx = hidden
gates = self.ih(input) + self.hh(hx)

ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

ingate = F.sigmoid(ingate)
forgetgate = F.sigmoid(forgetgate)
cellgate = F.tanh(cellgate)
outgate = F.sigmoid(outgate)

cy = forgetgate * cx + ingate * cellgate
hy = outgate * F.tanh(self.c_norm(cy))

return hy, cy
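
For reference, a minimal usage sketch of this layer-normalized LSTMCell, assuming PyTorch of the same era as the commit (hence the Variable wrapper); the batch size, dimensions, and random inputs below are illustrative assumptions only.

import torch
from torch.autograd import Variable

from LSTMCell import LSTMCell

bsz, ninp, nhid = 4, 10, 20                   # assumed toy sizes
cell = LSTMCell(ninp, nhid, dropout=0.1)

hx = Variable(torch.zeros(bsz, nhid))         # initial hidden state
cx = Variable(torch.zeros(bsz, nhid))         # initial cell state

inputs = Variable(torch.randn(6, bsz, ninp))  # (seq_len, bsz, ninp)
for x_t in inputs:                            # run the cell one step at a time
    hx, cx = cell(x_t, (hx, cx))
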
66 changes: 66 additions & 0 deletions ParsingNetwork.py
@@ -0,0 +1,66 @@
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class ParsingNetwork(nn.Module):
def __init__(self, ninp, nhid, nslots=5, nlookback=1, resolution=0.1, dropout=0.4, hard=False):
super(ParsingNetwork, self).__init__()

self.nhid = nhid
self.ninp = ninp
self.nslots = nslots
self.nlookback = nlookback
self.resolution = resolution
self.hard = hard

self.drop = nn.Dropout(dropout)

# Attention layers
self.gate = nn.Sequential(nn.Dropout(dropout),
nn.Conv1d(ninp, nhid, (nlookback + 1)),
nn.BatchNorm1d(nhid),
nn.ReLU(),
nn.Dropout(dropout),
nn.Conv1d(nhid, 2, 1, groups=2),
nn.ReLU())

def forward(self, emb, parser_state):
emb_last, cum_gate = parser_state
ntimestep = emb.size(0)

emb_last = torch.cat([emb_last, emb], dim=0)
emb = emb_last.transpose(0, 1).transpose(1, 2) # bsz, ninp, ntimestep + nlookback

gates = self.gate(emb) # bsz, 2, ntimestep
gate = gates[:, 0, :]
gate_next = gates[:, 1, :]
cum_gate = torch.cat([cum_gate, gate], dim=1)
gate_hat = torch.stack([cum_gate[:, i:i + ntimestep] for i in range(self.nslots, 0, -1)],
dim=2) # bsz, ntimestep, nslots

if self.hard:
memory_gate = (F.hardtanh((gate[:, :, None] - gate_hat) / self.resolution * 2 + 1) + 1) / 2
else:
memory_gate = F.sigmoid(
(gate[:, :, None] - gate_hat) / self.resolution * 10 + 5) # bsz, ntimestep, nslots
        memory_gate = torch.cumprod(memory_gate, dim=2)  # bsz, ntimestep, nslots
memory_gate = torch.unbind(memory_gate, dim=1)

if self.hard:
memory_gate_next = (F.hardtanh((gate_next[:, :, None] - gate_hat) / self.resolution * 2 + 1) + 1) / 2
else:
memory_gate_next = F.sigmoid(
(gate_next[:, :, None] - gate_hat) / self.resolution * 10 + 5) # bsz, ntimestep, nslots
        memory_gate_next = torch.cumprod(memory_gate_next, dim=2)  # bsz, ntimestep, nslots
memory_gate_next = torch.unbind(memory_gate_next, dim=1)

return (memory_gate, memory_gate_next), gate, (emb_last[-self.nlookback:], cum_gate[:, -self.nslots:])

def init_hidden(self, bsz):
weight = next(self.parameters()).data
self.ones = Variable(weight.new(bsz, 1).zero_() + 1)
return Variable(weight.new(self.nlookback, bsz, self.ninp).zero_()), \
Variable(weight.new(bsz, self.nslots).zero_() + numpy.inf)
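
A minimal sketch of calling the parsing network on a batch of embeddings; every size below is an arbitrary illustrative choice (nhid must be divisible by 2 because of the grouped Conv1d), and eval() is used only to make the BatchNorm and Dropout layers deterministic for the sketch.

import torch
from torch.autograd import Variable

from ParsingNetwork import ParsingNetwork

ninp, nhid, bsz, ntimestep = 10, 8, 4, 6      # assumed toy sizes
parser = ParsingNetwork(ninp, nhid, nslots=5, nlookback=1)
parser.eval()

emb = Variable(torch.randn(ntimestep, bsz, ninp))
parser_state = parser.init_hidden(bsz)        # zero look-back embeddings, +inf cumulative gates

(memory_gate, memory_gate_next), gate, parser_state = parser(emb, parser_state)
# memory_gate is a tuple of ntimestep tensors, each of shape (bsz, nslots)
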
56 changes: 56 additions & 0 deletions PredictNetwork.py
@@ -0,0 +1,56 @@
import math

import torch
import torch.nn as nn
from torch.autograd import Variable

from blocks import softmax, ResBlock


class PredictNetwork(nn.Module):
def __init__(self, ninp, nout, nslots, dropout, nlayers=1):
super(PredictNetwork, self).__init__()

self.ninp = ninp
self.nout = nout
self.nslots = nslots
self.nlayers = nlayers

self.drop = nn.Dropout(dropout)

self.projector_pred = nn.Sequential(nn.Dropout(dropout),
nn.Linear(ninp, ninp),
nn.Dropout(dropout))

if nlayers > 0:
self.res = ResBlock(ninp*2, nout, dropout, nlayers)
else:
self.res = None

self.ffd = nn.Sequential(nn.Dropout(dropout),
nn.Linear(ninp * 2, nout),
nn.BatchNorm1d(nout),
nn.Tanh()
)

def forward(self, input, input_memory):
input = torch.cat([input, input_memory], dim=1)
if self.nlayers > 0:
input = self.res(input)
output = self.ffd(input)
return output

def attention(self, input, memory, gate_time):
key = self.projector_pred(input)
# select memory to use
logits = torch.bmm(memory, key[:, :, None]).squeeze(2)
logits = logits / math.sqrt(self.ninp)
attention = softmax(logits, gate_time)
selected_memory_h = (memory * attention[:, :, None]).sum(dim=1)
memory = torch.cat([input[:, None, :], memory[:, :-1, :]], dim=1)
return selected_memory_h, memory, attention

def init_hidden(self, bsz):
weight = next(self.parameters()).data
self.ones = Variable(weight.new(bsz, 1).zero_() + 1.)
return Variable(weight.new(bsz, self.nslots, self.ninp).zero_())
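
A sketch of one possible call pattern (attention over the memory slots first, then forward on the attended summary); the sizes, the all-ones gate, and the zero-initialized memory are illustrative assumptions.

import torch
from torch.autograd import Variable

from PredictNetwork import PredictNetwork

ninp, nout, nslots, bsz = 10, 16, 5, 4        # assumed toy sizes
net = PredictNetwork(ninp, nout, nslots, dropout=0.1, nlayers=1)
net.eval()

x = Variable(torch.randn(bsz, ninp))
memory = net.init_hidden(bsz)                 # (bsz, nslots, ninp), all zeros
gate_time = Variable(torch.ones(bsz, nslots)) # attention mask over the slots

selected, memory, attn = net.attention(x, memory, gate_time)
output = net(x, selected)                     # (bsz, nout)
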
54 changes: 54 additions & 0 deletions ReadingNetwork.py
@@ -0,0 +1,54 @@
import math

import torch
import torch.nn as nn
from torch.autograd import Variable

from LSTMCell import LSTMCell
from blocks import softmax


class ReadingNetwork(nn.Module):
def __init__(self, ninp, nout, nslots, dropout, idropout):
super(ReadingNetwork, self).__init__()

self.ninp = ninp
self.nout = nout
self.nslots = nslots
self.drop = nn.Dropout(dropout)
self.memory_rnn = LSTMCell(ninp, nout)
self.projector_summ = nn.Sequential(nn.Dropout(idropout),
nn.Linear(ninp + nout, nout),
nn.Dropout(idropout))

def forward(self, input, memory, gate_time, rmask):
memory_h, memory_c = memory

# attention
selected_memory_h, selected_memory_c, attention0 = self.attention(input, memory_h, memory_c,
gate=gate_time)

# recurrent
input = self.drop(input)
h_i, c_i = self.memory_rnn(input, (selected_memory_h * rmask, selected_memory_c))

        # update memory
memory_h = torch.cat([h_i[:, None, :], memory_h[:, :-1, :]], dim=1)
memory_c = torch.cat([c_i[:, None, :], memory_c[:, :-1, :]], dim=1)

return h_i, (memory_h, memory_c), attention0

def attention(self, input, memory_h, memory_c, gate=None):
# select memory to use
key = self.projector_summ(torch.cat([input, memory_h[:, 0, :]], dim=1))
logits = torch.bmm(memory_h, key[:, :, None]).squeeze(2)
logits = logits / math.sqrt(self.nout)
attention = softmax(logits, gate)
selected_memory_h = (memory_h * attention[:, :, None]).sum(dim=1)
selected_memory_c = (memory_c * attention[:, :, None]).sum(dim=1)
return selected_memory_h, selected_memory_c, attention

def init_hidden(self, bsz):
weight = next(self.parameters()).data
return Variable(weight.new(bsz, self.nslots, self.nout).zero_()), \
Variable(weight.new(bsz, self.nslots, self.nout).zero_())
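
A calling sketch for the reading network; the sizes, the all-ones gate, and the all-ones rmask are illustrative assumptions (rmask is multiplied element-wise into the attended hidden state, so an all-ones tensor simply disables that mask).

import torch
from torch.autograd import Variable

from ReadingNetwork import ReadingNetwork

ninp, nout, nslots, bsz = 10, 16, 5, 4        # assumed toy sizes
reader = ReadingNetwork(ninp, nout, nslots, dropout=0.1, idropout=0.1)
reader.eval()

x = Variable(torch.randn(bsz, ninp))
memory_h, memory_c = reader.init_hidden(bsz)  # each (bsz, nslots, nout), all zeros
gate_time = Variable(torch.ones(bsz, nslots))
rmask = Variable(torch.ones(bsz, nout))

h, (memory_h, memory_c), attn = reader(x, (memory_h, memory_c), gate_time, rmask)
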
48 changes: 48 additions & 0 deletions blocks.py
@@ -0,0 +1,48 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def stick_breaking(logits):
e = F.sigmoid(logits)
z = (1 - e).cumprod(dim=1)
p = torch.cat([e.narrow(1, 0, 1), e[:, 1:] * z[:, :-1]], dim=1)

return p


def softmax(x, mask=None):
max_x, _ = x.max(dim=-1, keepdim=True)
e_x = torch.exp(x - max_x)
    if mask is not None:
e_x = e_x * mask
out = e_x / (e_x.sum(dim=-1, keepdim=True) + 1e-8)

return out


class ResBlock(nn.Module):
def __init__(self, ninp, nout, dropout, nlayers=1):
super(ResBlock, self).__init__()

self.nlayers = nlayers

self.drop = nn.Dropout(dropout)

self.res = nn.ModuleList(
[nn.Sequential(
nn.Linear(ninp, ninp),
nn.BatchNorm1d(ninp),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(ninp, ninp),
nn.BatchNorm1d(ninp),
)
for _ in range(nlayers)]
)

def forward(self, input):
# input = self.drop(input)
for i in range(self.nlayers):
input = F.relu(self.res[i](input) + input)
return input
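
Both helpers return a weighting over the last dimension: stick_breaking allocates sigmoid(l_i) * prod_{j<i}(1 - sigmoid(l_j)) to position i, so each row sums to at most one, while softmax zeroes out positions where the optional mask is zero before normalizing. A small illustration with arbitrary logits:

import torch
from torch.autograd import Variable

from blocks import stick_breaking, softmax

logits = Variable(torch.randn(2, 5))

p = stick_breaking(logits)                    # each row sums to <= 1
mask = Variable(torch.Tensor([[1, 1, 1, 0, 0],
                              [1, 1, 1, 1, 1]]))
q = softmax(logits, mask)                     # masked positions get ~0 probability
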
54 changes: 54 additions & 0 deletions data.py
@@ -0,0 +1,54 @@
import os
import torch

class Dictionary(object):
def __init__(self):
self.word2idx = {}
self.idx2word = []

def add_word(self, word):
if word not in self.word2idx:
self.idx2word.append(word)
self.word2idx[word] = len(self.idx2word) - 1
return self.word2idx[word]

def __len__(self):
return len(self.idx2word)

def __getitem__(self, key):
        if key in self.word2idx:
return self.word2idx[key]
else:
return self.word2idx['<unk>']


class Corpus(object):
def __init__(self, path):
self.dictionary = Dictionary()
self.train = self.tokenize(os.path.join(path, 'train.txt'))
self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
self.test = self.tokenize(os.path.join(path, 'test.txt'))

def tokenize(self, path):
"""Tokenizes a text file."""
assert os.path.exists(path)
# Add words to the dictionary
with open(path, 'r') as f:
tokens = 0
for line in f:
words = line.strip().split() + ['</s>']
tokens += len(words)
for word in words:
self.dictionary.add_word(word)

# Tokenize file content
with open(path, 'r') as f:
ids = torch.LongTensor(tokens)
token = 0
for line in f:
words = line.strip().split() + ['</s>']
for word in words:
ids[token] = self.dictionary.word2idx[word]
token += 1

return ids
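
A usage sketch for the corpus loader; 'data/penn' is an assumed placeholder path that must contain train.txt, valid.txt and test.txt, and the training vocabulary is assumed to include an '<unk>' token for Dictionary.__getitem__ to fall back on.

from data import Corpus

corpus = Corpus('data/penn')    # assumed placeholder path

print(len(corpus.dictionary))   # vocabulary size
print(corpus.train.size(0))     # number of training tokens (LongTensor of word ids)
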
