
Commit

Merge branch 'master' of https://github.com/dnouri/inferno
ottonemo committed Nov 17, 2017
2 parents b6ae6fd + 7f6da55 commit 25264a5
Showing 7 changed files with 139 additions and 164 deletions.
22 changes: 12 additions & 10 deletions examples/word_language_model/generate.py
@@ -6,7 +6,7 @@

import data
import model
import learner
import net

parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='./data/penn',
@@ -30,12 +30,12 @@

args = parser.parse_args()

# TODO: set seed
torch.manual_seed(args.seed)

corpus = data.Corpus(args.data)
ntokens = len(corpus.dictionary)

learner = learner.Learner(
net = net.Net(
module=model.RNNModel,
batch_size=1,
use_cuda=args.cuda,
@@ -44,20 +44,22 @@
module__ninp=200,
module__nhid=200,
module__nlayers=2)
learner.initialize()
learner.load_params(args.checkpoint)
net.initialize()
net.load_params(args.checkpoint)

hidden = None
input = skorch.utils.to_var(torch.rand(1, 1).mul(ntokens).long(),
use_cuda=args.cuda)

with open(args.outf, 'w') as outf:
for i in range(args.words):
word_idx, hidden = learner.sample(input=input,
temperature=args.temperature,
hidden=hidden)
input = skorch.utils.to_var(torch.LongTensor([[word_idx]]),
use_cuda=args.cuda)
word_idx, hidden = net.sample(
input=input,
temperature=args.temperature,
hidden=hidden)
input = skorch.utils.to_var(
torch.LongTensor([[word_idx]]),
use_cuda=args.cuda)

word = corpus.dictionary.idx2word[word_idx]
outf.write(word + ('\n' if i % 20 == 19 else ' '))

examples/word_language_model/learner.py → examples/word_language_model/net.py
@@ -5,17 +5,20 @@
from sklearn.metrics import f1_score


class Learner(skorch.NeuralNet):

def __init__(self,
criterion=torch.nn.CrossEntropyLoss,
clip=0.25,
lr=20,
ntokens=10000,
*args, **kwargs):
class Net(skorch.NeuralNet):

def __init__(
self,
criterion=torch.nn.CrossEntropyLoss,
clip=0.25,
lr=20,
ntokens=10000,
*args,
**kwargs
):
self.clip = clip
self.ntokens = ntokens
super(Learner, self).__init__(criterion=criterion, lr=lr, *args, **kwargs)
super(Net, self).__init__(criterion=criterion, lr=lr, *args, **kwargs)

def repackage_hidden(self, h):
"""Wraps hidden states in new Variables, to detach them from their history."""
@@ -27,28 +30,17 @@ def repackage_hidden(self, h):

def on_epoch_begin(self, *args, **kwargs):
super().on_epoch_begin(*args, **kwargs)
self.hidden = self.module_.init_hidden(self.batch_size)

def sample(self, input, temperature=1., hidden=None):
hidden = self.module_.init_hidden(1) if hidden is None else hidden
output, hidden = self.module_(input, hidden)
probas = output.squeeze().data.div(temperature).exp()
sample = torch.multinomial(probas, 1)[-1]
if probas.dim() > 1:
sample = sample[0]
return sample, self.repackage_hidden(hidden)

def sample_n(self, num_words, input, temperature=1., hidden=None):
preds = [None] * num_words
for i in range(num_words):
preds[i], hidden = self.sample(input, hidden=hidden)
input = skorch.utils.to_var(torch.LongTensor([[preds[i]]]),
use_cuda=self.use_cuda)
return preds, hidden
# As an optimization to save tensor allocation for each
# batch we initialize the hidden state only once per epoch.
# This optimization was taken from the original example.
self.hidden = self.module_.init_hidden(self.batch_size)

def train_step(self, X, y, _):
def train_step(self, X, y):
self.module_.train()

# Repackage shared hidden state so that the previous batch
# does not influence the current one.
self.hidden = self.repackage_hidden(self.hidden)
self.module_.zero_grad()
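
The remainder of train_step (forward pass, loss, backward, and the gradient clipping controlled by the clip argument from __init__) is collapsed before the next hunk. A hedged sketch of the usual pattern follows; self.optimizer_ and the pre-0.4 clip_grad_norm call are assumptions, not lines taken from the commit:

        # Illustrative sketch of the collapsed steps.
        output, self.hidden = self.module_(X, self.hidden)
        loss = self.get_loss(output.view(-1, self.ntokens), y)
        loss.backward()

        # Clip the gradient norm to keep backprop through the RNN from exploding.
        torch.nn.utils.clip_grad_norm(self.module_.parameters(), self.clip)
        self.optimizer_.step()
        return loss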

@@ -66,7 +58,8 @@ def train_step(self, X, y, _):
def validation_step(self, X, y):
self.module_.eval()

output, self.hidden = self.module_(X, self.hidden)
hidden = self.module_.init_hidden(self.batch_size)
output, _ = self.module_(X, hidden)
output_flat = output.view(-1, self.ntokens)

return self.get_loss(output_flat, y)
@@ -75,36 +68,33 @@ def evaluation_step(self, X, **kwargs):
self.module_.eval()

X = skorch.utils.to_var(X, use_cuda=self.use_cuda)
hidden = self.module_.init_hidden(self.batch_size)
output, _ = self.module_(X, hidden)

# TODO: resetting the hidden layer here prevents the user from
# manually resetting the hidden layer from outside (when generating
# text for example).
self.hidden = self.module_.init_hidden(X.size(1))

# TODO: decide if predict should be stateful or not.
# I have no good answer for this. Needs discussion.
output, self.hidden = self.module_(X, self.hidden)
return output.view(-1, self.ntokens)

# avoid huge computational graph with unnecessary backprop
# information when predicting from the network.
self.hidden = self.repackage_hidden(self.hidden)
def sample(self, input, temperature=1., hidden=None):
hidden = self.module_.init_hidden(1) if hidden is None else hidden
output, hidden = self.module_(input, hidden)
probas = output.squeeze().data.div(temperature).exp()
sample = torch.multinomial(probas, 1)[-1]
if probas.dim() > 1:
sample = sample[0]
return sample, self.repackage_hidden(hidden)

return output.view(-1, self.ntokens)
def sample_n(self, num_words, input, temperature=1., hidden=None):
preds = [None] * num_words
for i in range(num_words):
preds[i], hidden = self.sample(input, hidden=hidden)
input = skorch.utils.to_var(torch.LongTensor([[preds[i]]]),
use_cuda=self.use_cuda)
return preds, hidden

def score(self, X, y=None):
y_probas = []
y_target = []

# We collect the predictions batch-wise and store them on the host
# side as this data can be quite big and the GPU might run into
# memory issues. We do not calculate F1 on the batches as this
# would introduce an error to the score.
for X, y in self.get_iterator(X, y, train=False):
prediction = skorch.utils.to_numpy(self.evaluation_step(X)).argmax(1)
y_probas.append(prediction)
y_target.append(skorch.utils.to_numpy(y))

y_probas = np.concatenate(y_probas)
y_target = np.concatenate(y_target)

return f1_score(y_probas, y_target, average='micro')
ds = self.get_dataset(X)
target_iterator = self.get_iterator(ds, train=False)

y_true = np.concatenate([skorch.utils.to_numpy(y) for _, y in target_iterator])
y_pred = self.predict(X)

return f1_score(y_true, y_pred, average='micro')
68 changes: 0 additions & 68 deletions examples/word_language_model/predict.py

This file was deleted.

64 changes: 40 additions & 24 deletions examples/word_language_model/train.py
@@ -6,7 +6,7 @@

import data
from model import RNNModel
from learner import Learner
from net import Net

parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='./data/penn',
@@ -17,6 +17,8 @@
help='batch size')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs')
parser.add_argument('--data-limit', type=int, default=-1,
help='Limit the input data to length N.')
parser.add_argument('--seed', type=int, default=1111,
help='random seed')
parser.add_argument('--no-cuda', dest='cuda', action='store_false',
@@ -25,7 +27,7 @@
help='path to save the final model')
args = parser.parse_args()

# TODO: set seed
torch.manual_seed(args.seed)

corpus = data.Corpus(args.data)
ntokens = len(corpus.dictionary)
@@ -35,11 +37,6 @@ def on_epoch_end(self, net, **kwargs):
if not net.history[-1]['valid_loss_best']:
net.lr /= 4.0

class Checkpointing(skorch.callbacks.Callback):
def on_epoch_end(self, net, **kwargs):
if net.history[-1]['valid_loss_best']:
net.save_params(args.save)

class ExamplePrinter(skorch.callbacks.Callback):
def on_epoch_end(self, net, **kwargs):
seed_sentence = "the meaning of"
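
The middle of this callback is collapsed in the diff; it presumably turns the seed sentence into token indices and samples a continuation with net.sample_n. A hypothetical sketch follows; word2idx, num_words=10, and the variable names are assumptions, not the committed code:

        # Hypothetical sketch of the collapsed lines.
        indices = [corpus.dictionary.word2idx[w] for w in seed_sentence.split()]
        input = skorch.utils.to_var(torch.LongTensor([[indices[-1]]]),
                                    use_cuda=args.cuda)
        sentence, _ = net.sample_n(num_words=10, input=input)
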
@@ -49,48 +46,67 @@ def on_epoch_end(self, net, **kwargs):
print(seed_sentence,
" ".join([corpus.dictionary.idx2word[n] for n in sentence]))

def train_split(X, y):
return X, corpus.valid, None, None

learner = Learner(
def my_train_split(X, y):
# Return (corpus.train, corpus.valid) in case the network
# is fitted using net.fit(corpus.train).
#
# TODO: remove dummy y values once #112 is fixed.
#
import numpy as np
return X, corpus.valid, np.zeros(len(X)), np.zeros(len(corpus.valid))

net = Net(
module=RNNModel,
max_epochs=args.epochs,
batch_size=args.batch_size,
use_cuda=args.cuda,
callbacks=[LRAnnealing(), Checkpointing(), ExamplePrinter()],
callbacks=[
skorch.callbacks.Checkpoint(),
skorch.callbacks.ProgressBar(),
LRAnnealing(),
ExamplePrinter()
],
module__rnn_type='LSTM',
module__ntoken=ntokens,
module__ninp=200,
module__nhid=200,
module__nlayers=2,
train_split=train_split,

# Use (corpus.train, corpus.valid) as validation split.
# Even though we are doing a grid search, we use an internal
# validation set to determine when to save (Checkpoint callback)
# and when to decrease the learning rate (LRAnnealing callback).
train_split=my_train_split,

# To demonstrate that skorch is able to use already available
# data loaders as well, we use the data loader from the word
# language model.
iterator_train=data.Loader,
iterator_train__use_cuda=args.cuda,
iterator_train__bptt=args.bptt,
iterator_valid=data.Loader,
iterator_valid__use_cuda=args.cuda,
iterator_valid__bptt=args.bptt)

# NOFIXME: iterator_valid does not use corpus.valid as dataset
# REASON: we use GridSearchCV to generate validation splits
# FIXME: but we need validation data during training (LR annealing)

# FIXME: currently we have iterators for training and validation. Both of those
# supply (X,y) pairs. We do, however, also use the validation generator in
# predict (thus in scoring as well). Therefore we always generate `y` values
# even though we don't need to.

# TODO: easy way to write own score() that accesses the validation data only.
# Demonstrate the use of grid search by testing different learning
# rates while saving the best model at the end.

params = [
{
'lr': [10,20,30],
},
]

pl = GridSearchCV(learner, params)
pl.fit(corpus.train)
pl = GridSearchCV(net, params)

pl.fit(corpus.train[:args.data_limit].numpy())

print("Results of grid search:")
print("Best parameter configuration:", pl.best_params_)
print("Achieved score:", pl.best_score_)
print("Achieved F1 score:", pl.best_score_)

print("Saving best model to '{}'.".format(args.save))
pl.best_estimator_.save_params(args.save)
