Permalink
Find file
Fetching contributors…
Cannot retrieve contributors at this time
106 lines (85 sloc) 3.18 KB
"""
Code for sequence generation
"""
import numpy
import copy
def gen_sample(tparams, f_init, f_next, ctx, options, trng=None, k=1, maxlen=30,
               stochastic=True, argmax=False, use_unk=False):
    """
    Generate a sample, using either beam search or stochastic sampling.

    Parameters
    ----------
    tparams, options, trng
        Not used by this function; kept so existing call sites keep working.
    f_init : callable
        Called once as ``f_init(ctx)``; its result is used as the initial
        decoder state.
    f_next : callable
        Called every step as ``f_next(next_w, next_state)``. The first three
        items of its result are read as ``(next_p, next_w, next_state)``.
        ``next_p`` is indexed as a 2-D array whose second axis is the
        vocabulary (assumed to be one row per live hypothesis -- confirm
        against the model code).
    ctx
        Passed through to ``f_init`` untouched.
    k : int
        Beam width. ``k > 1`` requires ``stochastic=False``.
    maxlen : int
        Maximum number of decoding steps.
    stochastic : bool
        If true, follow a single path (sampled or arg-max); otherwise run
        beam search.
    argmax : bool
        Stochastic mode only: pick the most probable word instead of the
        word returned by ``f_next``.
    use_unk : bool
        Beam search only: when false, word index 1 (presumably the
        unknown-word token -- confirm against the vocabulary) is never
        selected.

    Returns
    -------
    (sample, sample_score)
        Stochastic mode: ``sample`` is a list of word indices and
        ``sample_score`` is the running sum of the probabilities of the
        emitted words.
        Beam search: ``sample`` is a list of hypotheses (each a list of word
        indices) and ``sample_score`` the matching list of costs (accumulated
        negative log-probabilities, lower is better). Finished hypotheses
        come first, in the order they finished, followed by any hypotheses
        still alive when decoding stopped.

    Word index 0 terminates a sequence in both modes.

    Raises
    ------
    AssertionError
        If ``k > 1`` while ``stochastic`` is true.
    """
    if k > 1:
        assert not stochastic, 'Beam search does not support stochastic sampling'

    next_state = f_init(ctx)
    # -1 is the "no previous word" marker handed to f_next on the first step.
    next_w = -1 * numpy.ones((1,)).astype('int64')

    if stochastic:
        sample = []
        sample_score = 0
        for _ in range(maxlen):
            ret = f_next(next_w, next_state)
            next_p, next_w, next_state = ret[0], ret[1], ret[2]
            # NOTE(review): with argmax=True the word fed back to f_next on
            # the next step is still the one f_next returned, not the
            # arg-max word -- confirm this is intended.
            nw = next_p[0].argmax() if argmax else next_w[0]
            sample.append(nw)
            # NOTE(review): this accumulates raw probabilities, not
            # log-probabilities; kept as-is because callers may rely on it.
            sample_score += next_p[0, nw]
            if nw == 0:
                break
        return sample, sample_score

    # ---- beam search -----------------------------------------------------
    sample = []
    sample_score = []
    live_k = 1
    dead_k = 0
    hyp_samples = [[] for _ in range(live_k)]
    hyp_scores = numpy.zeros(live_k).astype('float32')

    for _ in range(maxlen):
        ret = f_next(next_w, next_state)
        next_p, next_state = ret[0], ret[2]
        voc_size = next_p.shape[1]

        # Cost of extending every live hypothesis with every word.
        cand_scores = hyp_scores[:, None] - numpy.log(next_p)
        if not use_unk:
            # Make word 1 prohibitively expensive for every hypothesis.
            cand_scores[:, 1] = 1e20
        cand_flat = cand_scores.flatten()

        # Keep only as many candidates as there are unfinished beam slots.
        ranks_flat = cand_flat.argsort()[:(k - dead_k)]
        # Floor division keeps these integral on both Python 2 and 3, so
        # they remain valid indices.
        trans_indices = ranks_flat // voc_size
        word_indices = ranks_flat % voc_size
        costs = cand_flat[ranks_flat]

        live_samples = []
        live_scores = []
        live_states = []
        for ti, wi, cost in zip(trans_indices, word_indices, costs):
            hyp = hyp_samples[ti] + [wi]
            score = numpy.float32(cost)
            if wi == 0:
                # Hypothesis emitted the end marker: retire it.
                sample.append(hyp)
                sample_score.append(score)
                dead_k += 1
            else:
                live_samples.append(hyp)
                live_scores.append(score)
                # Copy so later steps never alias rows of f_next's output.
                live_states.append(copy.copy(next_state[ti]))

        hyp_samples = live_samples
        hyp_scores = numpy.array(live_scores)
        live_k = len(live_samples)

        if live_k < 1:
            break
        if dead_k >= k:
            break

        next_w = numpy.array([w[-1] for w in hyp_samples])
        next_state = numpy.array(live_states)

    # Dump every hypothesis that was still alive when decoding stopped.
    for idx in range(live_k):
        sample.append(hyp_samples[idx])
        sample_score.append(hyp_scores[idx])

    return sample, sample_score