From 4622cb0b5bdc939fe6b26c9906b307779c8203d0 Mon Sep 17 00:00:00 2001 From: Yikang Shen Date: Wed, 28 Feb 2018 14:42:46 -0500 Subject: [PATCH] PRPN code --- .gitignore | 1 + LSTMCell.py | 47 +++++++++ ParsingNetwork.py | 66 +++++++++++++ PredictNetwork.py | 56 +++++++++++ ReadingNetwork.py | 54 +++++++++++ blocks.py | 48 ++++++++++ data.py | 54 +++++++++++ demo.py | 133 ++++++++++++++++++++++++++ hinton.py | 43 +++++++++ main.py | 238 ++++++++++++++++++++++++++++++++++++++++++++++ model_PRPN.py | 112 ++++++++++++++++++++++ 11 files changed, 852 insertions(+) create mode 100644 LSTMCell.py create mode 100644 ParsingNetwork.py create mode 100644 PredictNetwork.py create mode 100644 ReadingNetwork.py create mode 100644 blocks.py create mode 100644 data.py create mode 100644 demo.py create mode 100644 hinton.py create mode 100644 main.py create mode 100644 model_PRPN.py diff --git a/.gitignore b/.gitignore index 7bbc71c..98c26b5 100644 --- a/.gitignore +++ b/.gitignore @@ -99,3 +99,4 @@ ENV/ # mypy .mypy_cache/ +*.txt diff --git a/LSTMCell.py b/LSTMCell.py new file mode 100644 index 0000000..682de3f --- /dev/null +++ b/LSTMCell.py @@ -0,0 +1,47 @@ +import torch.nn.functional as F +import torch.nn as nn +from torch.nn.modules.rnn import * + + +class LayerNorm(nn.Module): + + def __init__(self, features, eps=1e-6): + super(LayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.ones(features)) + self.beta = nn.Parameter(torch.zeros(features)) + self.eps = eps + + def forward(self, x): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.gamma * (x - mean) / (std + self.eps) + self.beta + + +class LSTMCell(RNNCellBase): + + def __init__(self, input_size, hidden_size, bias=True, dropout=0): + super(LSTMCell, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.ih = nn.Sequential(nn.Linear(input_size, 4 * hidden_size, bias), LayerNorm(4 * hidden_size)) + self.hh = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size, bias), LayerNorm(4 * hidden_size)) + self.c_norm = LayerNorm(hidden_size) + self.drop = nn.Dropout(dropout) + + def forward(self, input, hidden): + + hx, cx = hidden + gates = self.ih(input) + self.hh(hx) + + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate = F.sigmoid(ingate) + forgetgate = F.sigmoid(forgetgate) + cellgate = F.tanh(cellgate) + outgate = F.sigmoid(outgate) + + cy = forgetgate * cx + ingate * cellgate + hy = outgate * F.tanh(self.c_norm(cy)) + + return hy, cy \ No newline at end of file diff --git a/ParsingNetwork.py b/ParsingNetwork.py new file mode 100644 index 0000000..d8072cd --- /dev/null +++ b/ParsingNetwork.py @@ -0,0 +1,66 @@ +import numpy +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable + + +class ParsingNetwork(nn.Module): + def __init__(self, ninp, nhid, nslots=5, nlookback=1, resolution=0.1, dropout=0.4, hard=False): + super(ParsingNetwork, self).__init__() + + self.nhid = nhid + self.ninp = ninp + self.nslots = nslots + self.nlookback = nlookback + self.resolution = resolution + self.hard = hard + + self.drop = nn.Dropout(dropout) + + # Attention layers + self.gate = nn.Sequential(nn.Dropout(dropout), + nn.Conv1d(ninp, nhid, (nlookback + 1)), + nn.BatchNorm1d(nhid), + nn.ReLU(), + nn.Dropout(dropout), + nn.Conv1d(nhid, 2, 1, groups=2), + nn.ReLU()) + + def forward(self, emb, parser_state): + emb_last, cum_gate = parser_state + ntimestep = emb.size(0) + + emb_last = torch.cat([emb_last, emb], dim=0) + emb = emb_last.transpose(0, 1).transpose(1, 2) # bsz, ninp, ntimestep + nlookback + + gates = self.gate(emb) # bsz, 2, ntimestep + gate = gates[:, 0, :] + gate_next = gates[:, 1, :] + cum_gate = torch.cat([cum_gate, gate], dim=1) + gate_hat = torch.stack([cum_gate[:, i:i + ntimestep] for i in range(self.nslots, 0, -1)], + dim=2) # bsz, ntimestep, nslots + + if self.hard: + memory_gate = (F.hardtanh((gate[:, :, None] - gate_hat) / self.resolution * 2 + 1) + 1) / 2 + else: + memory_gate = F.sigmoid( + (gate[:, :, None] - gate_hat) / self.resolution * 10 + 5) # bsz, ntimestep, nslots + memory_gate = torch.cumprod(memory_gate, dim=2) # bsz, ntimestep, nlookback+1 + memory_gate = torch.unbind(memory_gate, dim=1) + + if self.hard: + memory_gate_next = (F.hardtanh((gate_next[:, :, None] - gate_hat) / self.resolution * 2 + 1) + 1) / 2 + else: + memory_gate_next = F.sigmoid( + (gate_next[:, :, None] - gate_hat) / self.resolution * 10 + 5) # bsz, ntimestep, nslots + memory_gate_next = torch.cumprod(memory_gate_next, dim=2) # bsz, ntimestep, nlookback+1 + memory_gate_next = torch.unbind(memory_gate_next, dim=1) + + return (memory_gate, memory_gate_next), gate, (emb_last[-self.nlookback:], cum_gate[:, -self.nslots:]) + + def init_hidden(self, bsz): + weight = next(self.parameters()).data + self.ones = Variable(weight.new(bsz, 1).zero_() + 1) + return Variable(weight.new(self.nlookback, bsz, self.ninp).zero_()), \ + Variable(weight.new(bsz, self.nslots).zero_() + numpy.inf) \ No newline at end of file diff --git a/PredictNetwork.py b/PredictNetwork.py new file mode 100644 index 0000000..d4746b0 --- /dev/null +++ b/PredictNetwork.py @@ -0,0 +1,56 @@ +import math + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from blocks import softmax, ResBlock + + +class PredictNetwork(nn.Module): + def __init__(self, ninp, nout, nslots, dropout, nlayers=1): + super(PredictNetwork, self).__init__() + + self.ninp = ninp + self.nout = nout + self.nslots = nslots + self.nlayers = nlayers + + self.drop = nn.Dropout(dropout) + + self.projector_pred = nn.Sequential(nn.Dropout(dropout), + nn.Linear(ninp, ninp), + nn.Dropout(dropout)) + + if nlayers > 0: + self.res = ResBlock(ninp*2, nout, dropout, nlayers) + else: + self.res = None + + self.ffd = nn.Sequential(nn.Dropout(dropout), + nn.Linear(ninp * 2, nout), + nn.BatchNorm1d(nout), + nn.Tanh() + ) + + def forward(self, input, input_memory): + input = torch.cat([input, input_memory], dim=1) + if self.nlayers > 0: + input = self.res(input) + output = self.ffd(input) + return output + + def attention(self, input, memory, gate_time): + key = self.projector_pred(input) + # select memory to use + logits = torch.bmm(memory, key[:, :, None]).squeeze(2) + logits = logits / math.sqrt(self.ninp) + attention = softmax(logits, gate_time) + selected_memory_h = (memory * attention[:, :, None]).sum(dim=1) + memory = torch.cat([input[:, None, :], memory[:, :-1, :]], dim=1) + return selected_memory_h, memory, attention + + def init_hidden(self, bsz): + weight = next(self.parameters()).data + self.ones = Variable(weight.new(bsz, 1).zero_() + 1.) + return Variable(weight.new(bsz, self.nslots, self.ninp).zero_()) diff --git a/ReadingNetwork.py b/ReadingNetwork.py new file mode 100644 index 0000000..1d517b8 --- /dev/null +++ b/ReadingNetwork.py @@ -0,0 +1,54 @@ +import math + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from LSTMCell import LSTMCell +from blocks import softmax + + +class ReadingNetwork(nn.Module): + def __init__(self, ninp, nout, nslots, dropout, idropout): + super(ReadingNetwork, self).__init__() + + self.ninp = ninp + self.nout = nout + self.nslots = nslots + self.drop = nn.Dropout(dropout) + self.memory_rnn = LSTMCell(ninp, nout) + self.projector_summ = nn.Sequential(nn.Dropout(idropout), + nn.Linear(ninp + nout, nout), + nn.Dropout(idropout)) + + def forward(self, input, memory, gate_time, rmask): + memory_h, memory_c = memory + + # attention + selected_memory_h, selected_memory_c, attention0 = self.attention(input, memory_h, memory_c, + gate=gate_time) + + # recurrent + input = self.drop(input) + h_i, c_i = self.memory_rnn(input, (selected_memory_h * rmask, selected_memory_c)) + + # updata memory + memory_h = torch.cat([h_i[:, None, :], memory_h[:, :-1, :]], dim=1) + memory_c = torch.cat([c_i[:, None, :], memory_c[:, :-1, :]], dim=1) + + return h_i, (memory_h, memory_c), attention0 + + def attention(self, input, memory_h, memory_c, gate=None): + # select memory to use + key = self.projector_summ(torch.cat([input, memory_h[:, 0, :]], dim=1)) + logits = torch.bmm(memory_h, key[:, :, None]).squeeze(2) + logits = logits / math.sqrt(self.nout) + attention = softmax(logits, gate) + selected_memory_h = (memory_h * attention[:, :, None]).sum(dim=1) + selected_memory_c = (memory_c * attention[:, :, None]).sum(dim=1) + return selected_memory_h, selected_memory_c, attention + + def init_hidden(self, bsz): + weight = next(self.parameters()).data + return Variable(weight.new(bsz, self.nslots, self.nout).zero_()), \ + Variable(weight.new(bsz, self.nslots, self.nout).zero_()) diff --git a/blocks.py b/blocks.py new file mode 100644 index 0000000..92327cd --- /dev/null +++ b/blocks.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def stick_breaking(logits): + e = F.sigmoid(logits) + z = (1 - e).cumprod(dim=1) + p = torch.cat([e.narrow(1, 0, 1), e[:, 1:] * z[:, :-1]], dim=1) + + return p + + +def softmax(x, mask=None): + max_x, _ = x.max(dim=-1, keepdim=True) + e_x = torch.exp(x - max_x) + if not (mask is None): + e_x = e_x * mask + out = e_x / (e_x.sum(dim=-1, keepdim=True) + 1e-8) + + return out + + +class ResBlock(nn.Module): + def __init__(self, ninp, nout, dropout, nlayers=1): + super(ResBlock, self).__init__() + + self.nlayers = nlayers + + self.drop = nn.Dropout(dropout) + + self.res = nn.ModuleList( + [nn.Sequential( + nn.Linear(ninp, ninp), + nn.BatchNorm1d(ninp), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(ninp, ninp), + nn.BatchNorm1d(ninp), + ) + for _ in range(nlayers)] + ) + + def forward(self, input): + # input = self.drop(input) + for i in range(self.nlayers): + input = F.relu(self.res[i](input) + input) + return input diff --git a/data.py b/data.py new file mode 100644 index 0000000..cd88108 --- /dev/null +++ b/data.py @@ -0,0 +1,54 @@ +import os +import torch + +class Dictionary(object): + def __init__(self): + self.word2idx = {} + self.idx2word = [] + + def add_word(self, word): + if word not in self.word2idx: + self.idx2word.append(word) + self.word2idx[word] = len(self.idx2word) - 1 + return self.word2idx[word] + + def __len__(self): + return len(self.idx2word) + + def __getitem__(self, key): + if self.word2idx.has_key(key): + return self.word2idx[key] + else: + return self.word2idx[''] + + +class Corpus(object): + def __init__(self, path): + self.dictionary = Dictionary() + self.train = self.tokenize(os.path.join(path, 'train.txt')) + self.valid = self.tokenize(os.path.join(path, 'valid.txt')) + self.test = self.tokenize(os.path.join(path, 'test.txt')) + + def tokenize(self, path): + """Tokenizes a text file.""" + assert os.path.exists(path) + # Add words to the dictionary + with open(path, 'r') as f: + tokens = 0 + for line in f: + words = line.strip().split() + [''] + tokens += len(words) + for word in words: + self.dictionary.add_word(word) + + # Tokenize file content + with open(path, 'r') as f: + ids = torch.LongTensor(tokens) + token = 0 + for line in f: + words = line.strip().split() + [''] + for word in words: + ids[token] = self.dictionary.word2idx[word] + token += 1 + + return ids diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..bcf54d8 --- /dev/null +++ b/demo.py @@ -0,0 +1,133 @@ +import argparse +import copy +import numpy +import torch +from torch.autograd import Variable +from hinton import plot + +import matplotlib.pyplot as plt + +import data + +def softmax(x): + """Compute softmax values for each sets of scores in x.""" + e_x = numpy.exp(x - numpy.max(x, axis=1, keepdims=True)) + return e_x / e_x.sum(axis=1, keepdims=True) + +numpy.set_printoptions(precision=2, suppress=True, linewidth=5000) + +parser = argparse.ArgumentParser(description='PyTorch PTB Language Model') + +# Model parameters. +parser.add_argument('--data', type=str, default='./data/penn', + help='location of the data corpus') +parser.add_argument('--checkpoint', type=str, default='./model/model.pt', + help='model checkpoint to use') +parser.add_argument('--seed', type=int, default=1111, + help='random seed') +args = parser.parse_args() + + +# def build_tree_depth(depth, sen): +# depth = depth.tolist() +# sorted_idx = numpy.argsort(depth) +# parse_tree = copy.copy(sen) +# i2i = numpy.arange(len(parse_tree)) +# for idx in sorted_idx: +# idx_mapped = i2i[idx] +# new_node = parse_tree[idx_mapped] +# d = depth[idx_mapped] +# if idx < len(sen) - 1 and depth[idx_mapped + 1] <= d: +# new_node = [new_node, parse_tree.pop(idx_mapped + 1)] +# depth.pop(idx_mapped + 1) +# i2i[idx + 1:] -= 1 +# if idx > 0 and depth[idx_mapped - 1] < d: +# idx_mapped -= 1 +# new_node = [parse_tree.pop(idx_mapped), new_node] +# depth.pop(idx_mapped) +# i2i[idx:] -= 1 +# parse_tree[idx_mapped] = new_node +# return parse_tree + +def build_tree(depth, sen): + assert len(depth) == len(sen) + + if len(depth) == 1: + parse_tree = sen[0] + else: + idx_max = numpy.argmax(depth) + parse_tree = [] + if len(sen[:idx_max]) > 0: + tree0 = build_tree(depth[:idx_max], sen[:idx_max]) + parse_tree.append(tree0) + tree1 = sen[idx_max] + if len(sen[idx_max+1:]) > 0: + tree2 = build_tree(depth[idx_max+1:], sen[idx_max+1:]) + tree1 = [tree1, tree2] + if parse_tree == []: + parse_tree = tree1 + else: + parse_tree.append(tree1) + return parse_tree + +def MRG(tr): + if isinstance(tr, str): + return '(' + tr + ')' + # return tr + ' ' + else: + s = '(' + for subtr in tr: + s += MRG(subtr) + s += ')' + return s + +# Set the random seed manually for reproducibility. +torch.manual_seed(args.seed) + +with open(args.checkpoint, 'rb') as f: + model = torch.load(f) +model.eval() +print model + +model.cpu() + +corpus = data.Corpus(args.data) +ntokens = len(corpus.dictionary) +hidden = model.init_hidden(1) +input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True) + +while True: + sens = raw_input('Input a sentences:') + hidden = model.init_hidden(1) + for s in sens.split('\t'): + words = s.strip().split() + x = numpy.array([corpus.dictionary[w] for w in words]) + input = Variable(torch.LongTensor(x[:, None])) + + # hidden = model.init_hidden(1) + output, hidden = model(input, hidden) + output = output.squeeze().data.numpy()[:-1] + output = numpy.log(softmax(output)) + output = numpy.pad(output, ((1, 0), (0, 0)), 'constant', constant_values=0) + output = numpy.exp(-output[range(len(words)), x]) + + attentions = model.attentions.squeeze().data.numpy() + gates = model.gates.squeeze().data.numpy() + phrase = [] + sentence = [] + for i in range(len(words)): + print '%15s\t%7.1f\t%.3f\t%s' % (words[i], output[i], gates[i], plot(attentions[i], 1).replace('\n', '\t')) + midx = numpy.argmax(gates[i]) + if midx > 0: + if phrase != []: + sentence.append(phrase) + phrase = [] + phrase.append(words[i]) + sentence.append(phrase) + + print output[1:].mean() + + parse_tree = build_tree(gates, words) + print MRG(parse_tree) + # parse_tree = build_tree_depth(gates, words) + # print MRG(parse_tree) diff --git a/hinton.py b/hinton.py new file mode 100644 index 0000000..b1d239e --- /dev/null +++ b/hinton.py @@ -0,0 +1,43 @@ +# coding=utf-8 +from __future__ import print_function +import numpy as np +chars = [" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"] + + +class BarHack(str): + + def __str__(self): + return self.internal + + def __len__(self): + return 1 + + +def plot(arr, max_val=None): + if max_val is None: + max_arr = arr + max_val = max(abs(np.max(max_arr)), abs(np.min(max_arr))) + + opts = np.get_printoptions() + np.set_printoptions(edgeitems=500) + s = str(np.array2string(arr, + formatter={ + 'float_kind': lambda x: visual(x, max_val), + 'int_kind': lambda x: visual(x, max_val)}, + max_line_width=5000 + )) + np.set_printoptions(**opts) + + return s + + +def visual(val, max_val): + if abs(val) == max_val: + step = len(chars) - 1 + else: + step = int(abs(float(val) / max_val) * len(chars)) + colourstart = "" + colourend = "" + if val < 0: + colourstart, colourend = '\033[90m', '\033[0m' + return colourstart + chars[step] + colourend \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..f4153de --- /dev/null +++ b/main.py @@ -0,0 +1,238 @@ +import argparse +import math +import random +import time + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.optim.lr_scheduler as lr_scheduler +from torch.autograd import Variable + +import data +from model_PRPN import PRPN + +parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model') +parser.add_argument('--data', type=str, default='./data/penn', + help='location of the data corpus') +parser.add_argument('--emsize', type=int, default=800, + help='size of word embeddings') +parser.add_argument('--nhid', type=int, default=1200, + help='number of hidden units per layer') +parser.add_argument('--nlayers', type=int, default=2, + help='number of layers') +parser.add_argument('--lr', type=float, default=0.003, + help='initial learning rate') +parser.add_argument('--weight_decay', type=float, default=1e-6, + help='weight decay') +parser.add_argument('--clip', type=float, default=1., + help='gradient clipping') +parser.add_argument('--epochs', type=int, default=100, + help='upper epoch limit') +parser.add_argument('--batch_size', type=int, default=64, metavar='N', + help='batch size') +parser.add_argument('--bptt', type=int, default=35, + help='sequence length') +parser.add_argument('--dropout', type=float, default=0.7, + help='dropout applied to output layers (0 = no dropout)') +parser.add_argument('--idropout', type=float, default=0.4, + help='dropout applied to layers (0 = no dropout)') +parser.add_argument('--rdropout', type=float, default=0.4, + help='dropout applied to recurrent states (0 = no dropout)') +parser.add_argument('--tied', action='store_true', + help='tie the word embedding and softmax weights') +parser.add_argument('--hard', action='store_true', + help='use hard sigmoid') +parser.add_argument('--res', type=int, default=0, + help='number of resnet block in predict network') +parser.add_argument('--seed', type=int, default=1111, + help='random seed') +parser.add_argument('--cuda', action='store_true', + help='use CUDA') +parser.add_argument('--log-interval', type=int, default=100, metavar='N', + help='report interval') +parser.add_argument('--save', type=str, default=None, + help='path to save the final model') +parser.add_argument('--load', type=str, default=None, + help='path to save the final model') +parser.add_argument('--nslots', type=int, default=15, + help='number of memory slots') +parser.add_argument('--nlookback', type=int, default=5, + help='number of look back steps when predict gate') +parser.add_argument('--resolution', type=float, default=0.1, + help='syntactic distance resolution') +parser.add_argument('--device', type=int, default=0, + help='select GPU') +args = parser.parse_args() + +torch.cuda.set_device(args.device) + +if args.save is None: + args.save = 'model.pt' +args.save = './model/' + args.save + +# Set the random seed manually for reproducibility. +torch.manual_seed(args.seed) +if torch.cuda.is_available(): + if not args.cuda: + print("WARNING: You have a CUDA device, so you should probably run with --cuda") + else: + torch.cuda.manual_seed(args.seed) + +############################################################################### +# Load data +############################################################################### + +corpus = data.Corpus(args.data) + + +def batchify(data, bsz, random_start_idx=False): + # Work out how cleanly we can divide the dataset into bsz parts. + nbatch = data.size(0) // bsz + # Trim off any extra elements that wouldn't cleanly fit (remainders). + if random_start_idx: + start_idx = random.randint(0, data.size(0) % bsz - 1) + else: + start_idx = 0 + data = data.narrow(0, start_idx, nbatch * bsz) + # Evenly divide the data across the bsz batches. + data = data.view(bsz, -1).t().contiguous() + if args.cuda: + data = data.cuda() + return data + + +eval_batch_size = 10 +val_data = batchify(corpus.valid, eval_batch_size) +test_data = batchify(corpus.test, eval_batch_size) + +############################################################################### +# Build the model +############################################################################### + +ntokens = len(corpus.dictionary) +model = PRPN(ntokens, args.emsize, args.nhid, args.nlayers, + args.nslots, args.nlookback, args.resolution, + args.dropout, args.idropout, args.rdropout, + args.tied, args.hard, args.res) + +if not (args.load is None): + with open(args.load, 'rb') as f: + model = torch.load(f) + +if args.cuda: + model.cuda() + +criterion = nn.CrossEntropyLoss() + + +############################################################################### +# Training code +############################################################################### + +def repackage_hidden(h): + """Wraps hidden states in new Variables, to detach them from their history.""" + if type(h) == Variable: + return Variable(h.data) + else: + if isinstance(h, list): + return [repackage_hidden(v) for v in h] + else: + return tuple(repackage_hidden(v) for v in h) + + +def get_batch(source, i, evaluation=False): + seq_len = min(args.bptt, len(source) - 1 - i) + data = Variable(source[i:i + seq_len], volatile=evaluation) + target = Variable(source[i + 1:i + 1 + seq_len].view(-1)) + return data, target + + +def evaluate(data_source): + # Turn on evaluation mode which disables dropout. + model.eval() + total_loss = 0 + ntokens = len(corpus.dictionary) + hidden = model.init_hidden(eval_batch_size) + for i in range(0, data_source.size(0) - 1, args.bptt): + data, targets = get_batch(data_source, i, evaluation=True) + output, hidden = model(data, hidden) + output_flat = output.view(-1, ntokens) + total_loss += len(data) * criterion(output_flat, targets).data + hidden = repackage_hidden(hidden) + return total_loss[0] / len(data_source) + + +def train(): + # Turn on training mode which enables dropout. + model.train() + total_loss = 0 + start_time = time.time() + ntokens = len(corpus.dictionary) + hidden = model.init_hidden(args.batch_size) + train_data = batchify(corpus.train, args.batch_size, random_start_idx=True) + for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)): + data, targets = get_batch(train_data, i) + # Starting each batch, we detach the hidden state from how it was previously produced. + # If we didn't, the model would try backpropagating all the way to start of the dataset. + hidden = repackage_hidden(hidden) + optimizer.zero_grad() + output, hidden = model(data, hidden) + loss = criterion(output.view(-1, ntokens), targets) + loss.backward() + + # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. + torch.nn.utils.clip_grad_norm(model.parameters(), args.clip) + optimizer.step() + + total_loss += loss.data + + if batch % args.log_interval == 0 and batch > 0: + cur_loss = total_loss[0] / args.log_interval + elapsed = time.time() - start_time + print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | ' + 'loss {:5.2f} | ppl {:8.2f}'.format( + epoch, batch, len(train_data) // args.bptt, lr, + elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss))) + total_loss = 0 + start_time = time.time() + + +# Loop over epochs. +lr = args.lr +best_val_loss = None +optimizer = optim.Adam(model.parameters(), lr=0.003, betas=(0, 0.999), eps=1e-9, weight_decay=args.weight_decay) +scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', 0.1, patience=1, threshold=0) + +# At any point you can hit Ctrl + C to break out of training early. +try: + for epoch in range(1, args.epochs + 1): + epoch_start_time = time.time() + train() + val_loss = evaluate(val_data) + print('-' * 89) + print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' + 'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time), + val_loss, math.exp(val_loss))) + print('-' * 89) + # Save the model if the validation loss is the best we've seen so far. + if not best_val_loss or val_loss < best_val_loss: + with open(args.save, 'wb') as f: + torch.save(model, f) + best_val_loss = val_loss + scheduler.step(val_loss) + +except KeyboardInterrupt: + print('-' * 89) + print('Exiting from training early') + +# Load the best saved model. +with open(args.save, 'rb') as f: + model = torch.load(f) + +# Run on test data. +test_loss = evaluate(test_data) +print('=' * 89) +print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format( + test_loss, math.exp(test_loss))) +print('=' * 89) diff --git a/model_PRPN.py b/model_PRPN.py new file mode 100644 index 0000000..4fb6676 --- /dev/null +++ b/model_PRPN.py @@ -0,0 +1,112 @@ +import torch +import torch.nn as nn + +from ParsingNetwork import ParsingNetwork +from PredictNetwork import PredictNetwork +from ReadingNetwork import ReadingNetwork + + +class PRPN(nn.Module): + """Container module with an encoder, a recurrent module, and a decoder.""" + + def __init__(self, ntoken, ninp, nhid, nlayers, + nslots=5, nlookback=1, resolution=0.1, + dropout=0.4, idropout=0.4, rdropout=0.1, + tie_weights=False, hard=False, res=1): + super(PRPN, self).__init__() + + self.nhid = nhid + self.ninp = ninp + self.nlayers = nlayers + self.nslots = nslots + self.nlookback = nlookback + + self.drop = nn.Dropout(dropout) + self.idrop = nn.Dropout(idropout) + self.rdrop = nn.Dropout(rdropout) + + # Feedforward layers + self.encoder = nn.Embedding(ntoken, ninp) + self.parser = ParsingNetwork(ninp, nhid, nslots, nlookback, resolution, idropout, hard) + self.reader = nn.ModuleList([ReadingNetwork(ninp, nhid, nslots, dropout=dropout, idropout=idropout), ] + + [ReadingNetwork(nhid, nhid, nslots, dropout=idropout, idropout=idropout) + for i in range(nlayers - 1)]) + self.predictor = PredictNetwork(nhid, ninp, nslots, idropout, res) + self.decoder = nn.Linear(ninp, ntoken) + + if tie_weights: + self.decoder.weight = self.encoder.weight + + self.attentions = None + self.gates = None + + self.init_weights() + + def init_weights(self): + initrange = 0.01 + self.encoder.weight.data.uniform_(-initrange, initrange) + self.decoder.bias.data.fill_(0) + self.decoder.weight.data.uniform_(-initrange, initrange) + + def clip_grad_norm(self, clip): + for model in self.reader: + torch.nn.utils.clip_grad_norm(model.memory_rnn.parameters(), clip) + + def forward(self, input, hidden_states): + ntimestep = input.size(0) + bsz = input.size(1) + emb = self.encoder(input) # timesteps, bsz, ninp + output_h = [] + output_memory = [] + attentions = [] + + reader_state, parser_state, predictor_state = hidden_states # memory_h: bsz, nslots, nhid + + (memory_gate, memory_gate_next), gate, parser_state = self.parser(emb, parser_state) + + rmask = torch.autograd.Variable(torch.ones(self.nlayers, self.nhid)) + if input.is_cuda: rmask = rmask.cuda() + rmask = self.rdrop(rmask) + + for i in range(input.size(0)): + emb_i = emb[i] # emb_i: bsz, nhid + attention = [] + attention.append(memory_gate[i]) + + # summarize layer + h_i = emb_i + for j in range(self.nlayers): + hidden = reader_state[j] + + h_i, new_memory, attention0 = self.reader[j](h_i, hidden, memory_gate[i], rmask[j]) + + # updata states + attention.append(attention0) + reader_state[j] = new_memory + + # predict layer + selected_memory_h, predictor_state, attention1 = self.predictor.attention(h_i, predictor_state, + gate_time=memory_gate_next[i]) + output_h.append(h_i) + output_memory.append(selected_memory_h) + + attention.append(memory_gate_next[i]) + attention.append(attention1) + attentions.append(torch.stack(attention, dim=1)) + + self.attentions = torch.stack(attentions, dim=0) + self.gates = gate + + output_h = torch.stack(output_h, dim=0) + output_memory = torch.stack(output_memory, dim=0) + output = self.predictor(output_h.view(-1, self.nhid), output_memory.view(-1, self.nhid)) + + output = self.drop(output) + decoded = self.decoder(output) + return decoded.view(ntimestep, bsz, -1), (reader_state, parser_state, predictor_state) + + def init_hidden(self, bsz): + return [self.reader[i].init_hidden(bsz) + for i in range(self.nlayers)], \ + self.parser.init_hidden(bsz), \ + self.predictor.init_hidden(bsz)